
Compiles a function to an XLA executable via tracing.

Returns a callable R function that executes the compiled binary. Unlike jit(), compilation happens eagerly at definition time rather than on first call, so the input shapes and dtypes must be specified upfront via abstract tensors (see nv_aten()).
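Because the executable is built for fixed shapes and dtypes, each input signature needs its own xla() call. The sketch below compiles the same R function twice, once for 2x2 inputs and once for 3x2 inputs. It assumes the package is loaded as `anvil` and uses only xla(), nv_aten() and nv_tensor() as they appear in the Examples section. The names `add`, `add_2x2` and `add_3x2` are illustrative.

```r
library(anvil)

add <- function(x, y) x + y

# Each call to xla() traces and compiles immediately for one fixed signature.
add_2x2 <- xla(add, args = list(x = nv_aten("f32", c(2, 2)), y = nv_aten("f32", c(2, 2))))
add_3x2 <- xla(add, args = list(x = nv_aten("f32", c(3, 2)), y = nv_aten("f32", c(3, 2))))

# Inputs must match the shapes and dtypes the executable was compiled for.
a <- nv_tensor(array(1:6, c(3, 2)), dtype = "f32")
b <- nv_tensor(array(7:12, c(3, 2)), dtype = "f32")
add_3x2(a, b)
```

With jit(), by contrast, compilation is deferred to the first call, when the shapes are known from the actual inputs.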

Usage

xla(f, args, donate = character(), device = NULL)

Arguments

f

(function)
Function to compile. Must accept and return AnvilTensors.

args

(list)
List of abstract tensor specifications (e.g. from nv_aten()) describing the expected shapes and dtypes of f's arguments.

donate

(character())
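Names of the arguments whose buffers should be donated, i.e. handed over to XLA so that they can be reused for the outputs. A donated input should not be used again after the call.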

device

(character(1))
Target device such as "cpu" (used when device is NULL, the default) or "cuda".
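Since donate refers to arguments by name, the entries of args must be named accordingly. The following is a minimal sketch, assuming the package is loaded as `anvil`, that donates the buffer of `x` for an element-wise addition whose output has the same shape and dtype as `x`. The name `add_into_x` is illustrative. Whether XLA actually reuses the buffer is up to the compiler, so the sketch simply treats `a` as consumed after the call.

```r
library(anvil)

# Donate x: its buffer may be reused for the f32{2,2} result.
add_into_x <- xla(function(x, y) x + y,
  args = list(x = nv_aten("f32", c(2, 2)), y = nv_aten("f32", c(2, 2))),
  donate = "x"
)

a <- nv_tensor(array(1:4, c(2, 2)), dtype = "f32")
b <- nv_tensor(array(5:8, c(2, 2)), dtype = "f32")
out <- add_into_x(a, b)
# `a` was donated, so it is not read or passed to another call from here on.
```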

Value

(function)
A function that accepts AnvilTensor arguments (matching the flattened inputs described by args) and returns the result as AnvilTensors.

Details

Traces f with the given abstract args (via trace_fn()), lowers the resulting graph to StableHLO via stablehlo(), and then immediately compiles it to an XLA executable via pjrt::pjrt_compile().

See also

jit() for lazy compilation, compile_to_xla() for the lower-level API.

Examples

f_compiled <- xla(function(x, y) x + y,
  args = list(x = nv_aten("f32", c(2, 2)), y = nv_aten("f32", c(2, 2)))
)
a <- nv_tensor(array(1:4, c(2, 2)), dtype = "f32")
b <- nv_tensor(array(5:8, c(2, 2)), dtype = "f32")
f_compiled(a, b)
#> AnvilTensor
#>   6 10
#>   8 12
#> [ CPUf32{2,2} ]