Compiles a function to an XLA executable via tracing.
Returns a callable R function that executes the compiled binary.
Unlike jit(), compilation happens eagerly at definition time rather than on first call, so the input shapes and dtypes must be specified upfront via abstract tensors (see nv_aten()).
Usage
xla(f, args, donate = character(), device = NULL)
Arguments
- f (function)
  Function to compile. Must accept and return AnvilTensors.
- args (list)
  List of abstract tensor specifications (e.g. from nv_aten()) describing the expected shapes and dtypes of f's arguments.
- donate (character())
  Names of the arguments whose buffers should be donated, allowing XLA to reuse their memory for the outputs.
- device (character(1))
  Target device, such as "cpu" (used by default) or "cuda".
Value
(function)
A function that accepts AnvilTensor arguments (matching the flat inputs)
and returns the result as AnvilTensors.
Details
Traces f with the given abstract args (via trace_fn()), lowers the resulting graph to StableHLO via stablehlo(), and then immediately compiles it to an XLA executable via pjrt::pjrt_compile().
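A minimal usage sketch follows. It assumes that nv_aten() takes a dtype string followed by a shape vector, that nv_tensor() builds a concrete AnvilTensor from an R array with a dtype argument, and that the arithmetic operators are overloaded for AnvilTensors; check these signatures against their own help pages before relying on them.

```r
library(anvil)

# A simple elementwise function of two tensors.
f <- function(x, y) {
  x * y + x
}

# Abstract specs for two 2x3 float32 inputs.
# Assumption: nv_aten(dtype, shape) argument order.
specs <- list(
  x = nv_aten("f32", c(2L, 3L)),
  y = nv_aten("f32", c(2L, 3L))
)

# Tracing, lowering and XLA compilation all happen here, not on the first call.
f_xla <- xla(f, specs, device = "cpu")

# Concrete inputs matching the specs.
# Assumption: nv_tensor(data, dtype = ...) creates an AnvilTensor.
x <- nv_tensor(matrix(1:6, nrow = 2), dtype = "f32")
y <- nv_tensor(matrix(6:1, nrow = 2), dtype = "f32")

result <- f_xla(x, y)
```

Because the executable is built for the shapes and dtypes in specs, the tensors passed to f_xla should match them. Passing donate = "x" to xla() would additionally let XLA reuse the buffer of x for the output.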
See also
jit() for lazy compilation, compile_to_xla() for the lower-level API.