Executes f with abstract tensor arguments and records every primitive operation into
an AnvilGraph.
The resulting graph can be lowered to StableHLO (via stablehlo()) or transformed
(e.g. via transform_gradient()).
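The general idea of tracing can be shown with a deliberately tiny base-R toy. This is not how anvil implements trace_fn(). The toy runs a function on placeholder values and uses the S3 Ops group generic to record each arithmetic operation instead of computing it. All names below (new_graph, tracer, record, toy_trace) are hypothetical and exist only for this sketch. A real AnvilGraph also carries shapes, dtypes and nested sub-graphs.

```r
# Hypothetical toy tracer for illustration only; this is NOT anvil's
# implementation and none of these names exist in anvil.

# The "graph": an environment holding an ordered list of recorded nodes.
new_graph <- function() {
  g <- new.env()
  g$nodes <- list()
  g
}

# A placeholder value: carries no data, only the id of the node producing it.
tracer <- function(graph, id) {
  structure(list(graph = graph, id = id), class = "tracer")
}

# Append a node (operation name + input node ids) and return its output.
record <- function(graph, op, inputs) {
  force(inputs)  # evaluate inputs first: they may record constant nodes
  id <- length(graph$nodes) + 1L
  graph$nodes[[id]] <- list(op = op, inputs = inputs)
  tracer(graph, id)
}

# Map an operand to a node id; literal numbers become constant nodes.
as_input <- function(graph, x) {
  if (inherits(x, "tracer")) return(x$id)
  record(graph, paste0("constant(", x, ")"), integer(0))$id
}

# S3 group generic: arithmetic/comparison operators on a tracer are
# recorded as nodes instead of being computed.
Ops.tracer <- function(e1, e2) {
  if (missing(e2)) {  # unary operator, e.g. -x
    return(record(e1$graph, .Generic, e1$id))
  }
  graph <- if (inherits(e1, "tracer")) e1$graph else e2$graph
  record(graph, .Generic, c(as_input(graph, e1), as_input(graph, e2)))
}

# Run f on one placeholder per argument and return the recorded nodes.
toy_trace <- function(f, n_args) {
  graph <- new_graph()
  args <- lapply(seq_len(n_args), function(i) {
    record(graph, paste0("input", i), integer(0))
  })
  out <- do.call(f, args)
  list(nodes = graph$nodes, output = out$id)
}

g <- toy_trace(function(x, y) x * y + 1, n_args = 2)
vapply(g$nodes, function(n) n$op, character(1))
# c("input1", "input2", "*", "constant(1)", "+")
```

Because every operation returns another placeholder, f never touches real data. Only the recorded list of nodes remains, which is loosely the role the AnvilGraph plays for trace_fn().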
Usage
trace_fn(
  f,
  args = NULL,
  desc = NULL,
  toplevel = FALSE,
  lit_to_tensor = FALSE,
  args_flat = NULL,
  in_tree = NULL
)
Arguments
- f (function)
  The function to trace. Must not be a JitFunction (i.e. already jitted).
- args (list of (AnvilTensor|AbstractTensor))
  The (unflattened) arguments to the function. Mutually exclusive with the args_flat / in_tree pair.
- desc (NULL|GraphDescriptor)
  Optional descriptor. When NULL (default), a new descriptor is created.
- toplevel (logical(1))
  If TRUE, concrete AnvilTensor inputs are treated as unknown (traced) values. If FALSE (default), they are treated as known constants.
- lit_to_tensor (logical(1))
  Whether to convert literal inputs to tensors. Used internally by higher-order primitives such as nv_if and nv_while.
- args_flat (list)
  Flattened arguments. Must be accompanied by in_tree.
- in_tree (Node)
  Tree structure describing how args_flat maps back to f's arguments.
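The args_flat / in_tree pair separates the values of an argument list from its nesting. The base-R sketch below illustrates that split in a rough way. It flattens a nested list into its leaves plus a tree describing the shape, then rebuilds the original from the two. anvil's Node class and its own flattening helpers are not shown. The names toy_flatten and toy_unflatten, and the tree layout they use, are hypothetical.

```r
# Hypothetical sketch for illustration only; anvil's Node class and its
# flattening helpers are not shown, and these names do not exist in anvil.

# Split a nested list into a flat list of leaves plus a tree of its shape.
toy_flatten <- function(x) {
  if (!is.list(x)) {
    return(list(leaves = list(x), tree = list(type = "leaf")))
  }
  parts <- lapply(x, toy_flatten)
  list(
    # concatenate the children's leaves, left to right
    leaves = unname(unlist(lapply(parts, `[[`, "leaves"), recursive = FALSE)),
    tree = list(
      type = "list",
      names = names(x),
      children = lapply(parts, `[[`, "tree")
    )
  )
}

# Rebuild the nested structure by consuming leaves in the same order.
toy_unflatten <- function(leaves, tree) {
  pos <- 0L
  build <- function(node) {
    if (identical(node$type, "leaf")) {
      pos <<- pos + 1L
      return(leaves[[pos]])
    }
    out <- lapply(node$children, build)
    names(out) <- node$names
    out
  }
  build(tree)
}

args <- list(x = 1:3, params = list(w = 2, b = 0.5))
flat <- toy_flatten(args)
length(flat$leaves)                                     # 3
identical(toy_unflatten(flat$leaves, flat$tree), args)  # TRUE
```

Keeping the leaves flat gives a tracer a simple list of values to replace with placeholders. The tree is all that is needed to present them to f in their original nesting.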
Value
An AnvilGraph containing the traced operations.
See also
stablehlo() to lower the graph, jit() / xla() for end-to-end
compilation.