Executes f with abstract array arguments and records every primitive operation into an AnvlGraph.
The resulting graph can be lowered to StableHLO (via stablehlo()) or transformed (e.g. via transform_gradient()).
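To make the idea of tracing concrete, the following is a toy tracer written in Python. It is an illustration only, not how this package is implemented, and every name in it (AbstractArray, Graph, Tracer, trace) is hypothetical. f is called once with placeholder objects that carry only a shape and a dtype. Each arithmetic operation on a placeholder appends a node to a graph instead of computing a value.

```python
# Hypothetical toy tracer: illustrates the technique, NOT this package's code.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class AbstractArray:
    """Shape and dtype only: no concrete values are ever stored."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass
class Graph:
    inputs: List[str] = field(default_factory=list)
    ops: List[tuple] = field(default_factory=list)  # (out, primitive, in_names, aval)
    outputs: List[str] = field(default_factory=list)

    def fresh(self) -> str:
        # Every input and every op result gets a unique variable name.
        return f"v{len(self.inputs) + len(self.ops)}"


class Tracer:
    """Stands in for an array while f runs; operations are recorded, not computed."""

    def __init__(self, graph: Graph, name: str, aval: AbstractArray):
        self.graph, self.name, self.aval = graph, name, aval

    def _record(self, prim: str, other: "Tracer") -> "Tracer":
        # Only abstract information (shape/dtype) is checked and propagated.
        if self.aval != other.aval:
            raise TypeError(f"{prim}: incompatible inputs {self.aval} and {other.aval}")
        out = self.graph.fresh()
        self.graph.ops.append((out, prim, (self.name, other.name), self.aval))
        return Tracer(self.graph, out, self.aval)

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def trace(f, *avals: AbstractArray) -> Graph:
    graph = Graph()
    tracers = []
    for aval in avals:
        name = graph.fresh()
        graph.inputs.append(name)
        tracers.append(Tracer(graph, name, aval))
    result = f(*tracers)  # f runs once, on placeholders
    graph.outputs.append(result.name)
    return graph


x_aval = AbstractArray(shape=(2, 3), dtype="f32")
g = trace(lambda x, y: x * y + x, x_aval, x_aval)
for op in g.ops:
    print(op)  # a "mul" node v2, then an "add" node v3 that reuses v2 and v0
```

The recorded graph, rather than any numeric result, is what a later step would lower (to StableHLO, in this package's case) or rewrite (for example, to compute a gradient).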
Arguments
- f (function): The function to trace. Must not be a JitFunction (i.e. a function that has already been jitted).
- args (list of (AnvlArray | AbstractArray)): The (unflattened) arguments to the function. Mutually exclusive with the args_flat/in_tree pair.
- desc (NULL | GraphDescriptor): Optional descriptor. When NULL (the default), a new descriptor is created.
- mode (character(1)): How to handle the inputs. Options are:
  - "toplevel": Used for jit(). This is the default.
  - "subgraph": Used for tracing subgraphs in higher-order primitives such as prim_while().
  - "inline": Used for transformations such as jit, where the graph is later inlined into the parent graph.
- args_flat (list): Flattened arguments. Must be accompanied by in_tree.
- in_tree (Node): Tree structure describing how args_flat maps back to the arguments of f.
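The args_flat/in_tree pair is the flattened form of args. It consists of the leaves of the (possibly nested) argument list in a fixed order, plus a tree that records how to rebuild the original structure. A minimal Python sketch of that round trip follows. The Node class and the flatten, unflatten, and num_leaves helpers are hypothetical and do not reflect this package's actual Node type. Python lists and dicts stand in for unnamed and named R lists.

```python
# Hypothetical sketch of an args_flat / in_tree round trip; not this package's Node.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass(frozen=True)
class Node:
    """A leaf, an unnamed list ("list"), or a named list ("dict")."""
    kind: str                                # "leaf", "list" or "dict"
    names: Optional[Tuple[str, ...]] = None  # element names, only for "dict"
    children: Tuple["Node", ...] = ()


def flatten(x: Any) -> Tuple[List[Any], Node]:
    """Return (args_flat, in_tree) with leaves in depth-first order."""
    if isinstance(x, (list, dict)):
        keys = tuple(x) if isinstance(x, dict) else None
        items = [x[k] for k in keys] if keys is not None else x
        leaves: List[Any] = []
        children = []
        for item in items:
            sub_leaves, sub_tree = flatten(item)
            leaves.extend(sub_leaves)
            children.append(sub_tree)
        kind = "dict" if keys is not None else "list"
        return leaves, Node(kind, keys, tuple(children))
    return [x], Node("leaf")  # anything that is not a container is a leaf


def num_leaves(tree: Node) -> int:
    if tree.kind == "leaf":
        return 1
    return sum(num_leaves(c) for c in tree.children)


def unflatten(tree: Node, args_flat: List[Any]) -> Any:
    """Rebuild the nested arguments from args_flat using the tree."""
    if len(args_flat) != num_leaves(tree):
        raise ValueError("args_flat does not match in_tree")
    it = iter(args_flat)

    def build(node: Node) -> Any:
        if node.kind == "leaf":
            return next(it)
        items = [build(c) for c in node.children]
        return dict(zip(node.names, items)) if node.kind == "dict" else items

    return build(tree)


args = [1.0, {"w": 2.0, "b": [3.0, 4.0]}]
args_flat, in_tree = flatten(args)
print(args_flat)                              # [1.0, 2.0, 3.0, 4.0]
print(unflatten(in_tree, args_flat) == args)  # True
```

In this sketch, args_flat alone loses the nesting and in_tree alone carries no values. This is presumably why the documentation requires the two to be supplied together, and why they are an alternative to passing args directly.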
Value
An AnvlGraph containing the traced operations.
See also
stablehlo() to lower the graph; jit() / xla() for end-to-end compilation.