Compiles a function to an XLA executable via tracing and returns a callable R function that executes the compiled binary.
Unlike jit(), which compiles lazily on the first call, xla() compiles eagerly when it is itself called. The input shapes and dtypes must therefore be specified upfront via abstract arrays (see nv_aval()).
Usage
xla(f, args, donate = character(), device = NULL)
Arguments
- f (function): Function to compile. Must accept and return AnvlArrays.
- args (list): List of abstract array specifications (e.g. from nv_aval()) describing the expected shapes and dtypes of f's arguments.
- donate (character()): Names of the arguments whose buffers should be donated to the computation, which lets XLA reuse their memory.
- device (character(1) | PJRTDevice): Target device such as "cpu" (the default) or "cuda".
Value
(function)
A function that accepts AnvlArray arguments matching the flat inputs described by args, and returns the result as AnvlArrays.
Details
Traces f with the given abstract args (via trace_fn()), lowers the resulting graph to StableHLO via stablehlo(), and then compiles it to an XLA executable via pjrt::pjrt_compile().
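A minimal sketch of ahead-of-time compilation with xla(). It assumes two things this page does not document. First, that nv_aval() takes a dtype string followed by a shape vector. Second, that `+` is defined for AnvlArrays inside traced code. Check ?nv_aval before relying on either. The names double_it, specs, double_xla and double_xla_donated are illustrative only.

```r
library(anvil)

# Illustrative function: takes an AnvlArray and returns one.
# ASSUMPTION: `+` is overloaded for AnvlArrays during tracing.
double_it <- function(x) x + x

# Abstract description of f's single argument (no data yet).
# ASSUMPTION: nv_aval(dtype, shape); verify against ?nv_aval.
specs <- list(x = nv_aval("f32", c(2L, 3L)))

# Tracing, lowering and compilation all happen on this line,
# before any concrete input is passed.
double_xla <- xla(double_it, specs, device = "cpu")

# Same function, but x's buffer is donated so XLA may reuse its memory;
# under usual XLA donation semantics x should not be reused after a call.
double_xla_donated <- xla(double_it, specs, donate = "x")
```

Because tracing happens inside xla(), a shape or dtype problem in f should surface at the xla() call rather than on the first invocation. This is the main practical difference from jit().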
See also
jit() for lazy compilation, compile_xla() for the lower-level API.