
Compiles a function to an XLA executable via tracing.

Returns a callable R function that executes the compiled binary. Unlike jit(), compilation happens eagerly at definition time rather than on first call, so the input shapes and dtypes must be specified upfront via abstract arrays (see nv_aval()).
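To make the eager/lazy distinction concrete, here is a toy sketch in base R. It does not use anvil and is not how anvil implements xla() or jit(). "Compiling" just wraps f in a closure that only accepts fixed input shapes, and a counter records when that happens. The names toy_compile(), toy_xla() and toy_jit() are hypothetical.

```r
# Toy model only (not anvil): "compiling" wraps f in a closure specialized to
# fixed input shapes, and a global counter records when compilation happens.
compile_count <- 0

toy_compile <- function(f, shapes) {
  compile_count <<- compile_count + 1
  shapes <- unname(lapply(shapes, as.integer))
  function(...) {
    # The toy artifact only accepts the shapes it was built for.
    stopifnot(identical(unname(lapply(list(...), dim)), shapes))
    f(...)
  }
}

# Eager, in the spirit of xla(): shapes are given up front, compile immediately.
toy_xla <- function(f, shapes) toy_compile(f, shapes)

# Lazy, in the spirit of jit(): compile on the first call, cache per shape signature.
toy_jit <- function(f) {
  cache <- list()
  function(...) {
    key <- paste(vapply(list(...), function(x) paste(dim(x), collapse = "x"), ""),
                 collapse = ",")
    if (is.null(cache[[key]])) {
      cache[[key]] <<- toy_compile(f, lapply(list(...), dim))
    }
    cache[[key]](...)
  }
}

add <- function(x, y) x + y
a <- array(1:4, c(2, 2))
b <- array(5:8, c(2, 2))

f_eager <- toy_xla(add, list(c(2, 2), c(2, 2)))
count_after_xla <- compile_count    # 1: compiled at definition time
f_lazy <- toy_jit(add)
count_after_jit <- compile_count    # still 1: nothing compiled yet
r1 <- f_lazy(a, b)                  # first call triggers compilation
r2 <- f_lazy(a, b)                  # same shapes: reuses the cached closure
count_after_calls <- compile_count  # 2
```

The eager variant needs shapes before any data exists, which is why xla() takes abstract arrays in `args`. The lazy variant discovers them from the first real call.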

Usage

xla(f, args, donate = character(), device = NULL)

Arguments

f

(function)
Function to compile. Must accept and return AnvlArrays.

args

(list)
List of abstract array specifications (e.g. from nv_aval()) describing the expected shapes and dtypes of f's arguments.

donate

(character())
Names of the arguments whose input buffers should be donated to the executable, so that XLA may reuse their memory for the outputs.

device

(character(1) | PJRTDevice)
Target device, given either as a name such as "cpu" or "cuda", or as a PJRTDevice. When NULL (the default), "cpu" is used.

Value

(function)
A function that accepts AnvlArray arguments (one per flattened input described by args) and returns the result as AnvlArrays.

Details

Traces f with the given abstract args using trace_fn(), lowers the resulting graph to StableHLO with stablehlo(), and then compiles it to an XLA executable with pjrt::pjrt_compile().
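The three stages (trace, lower, compile) can be illustrated with a toy pipeline in base R. This is only a sketch under simplifying assumptions: it does not call anvil, trace_fn(), stablehlo() or PJRT, supports only elementwise `+` and `*`, and "compiles" by building an R closure that replays the recorded graph. The helpers toy_trace(), toy_lower() and toy_compile_graph() are hypothetical, and the lowered text is only loosely modeled on StableHLO's textual form.

```r
# Toy pipeline (not anvil internals). Stage 1, trace: call f on placeholder
# values that record each operation into a graph instead of computing it.
trace_state <- new.env()

toy_trace <- function(f, avals) {
  trace_state$nodes <- list()
  inputs <- lapply(seq_along(avals), function(i) {
    structure(c(avals[[i]], list(ref = paste0("%arg", i - 1))), class = "toy_aval")
  })
  out <- do.call(f, inputs)
  list(inputs = inputs, nodes = trace_state$nodes, output = out)
}

# Binary + and * on placeholders append a node and return a new placeholder.
Ops.toy_aval <- function(e1, e2) {
  op <- switch(.Generic, "+" = "add", "*" = "multiply",
               stop("toy tracer does not support ", .Generic))
  stopifnot(inherits(e1, "toy_aval"), inherits(e2, "toy_aval"),
            identical(e1$shape, e2$shape), e1$dtype == e2$dtype)
  ref <- paste0("%", length(trace_state$nodes))
  trace_state$nodes[[length(trace_state$nodes) + 1]] <-
    list(op = op, ref = ref, args = c(e1$ref, e2$ref), dtype = e1$dtype, shape = e1$shape)
  structure(list(dtype = e1$dtype, shape = e1$shape, ref = ref), class = "toy_aval")
}

# Stage 2, lower: print the recorded graph as StableHLO-like text.
tensor_type <- function(v) sprintf("tensor<%sx%s>", paste(v$shape, collapse = "x"), v$dtype)
toy_lower <- function(tr) {
  body <- vapply(tr$nodes, function(n) {
    sprintf("%s = stablehlo.%s %s : %s", n$ref, n$op,
            paste(n$args, collapse = ", "), tensor_type(n))
  }, "")
  c(body, sprintf("return %s : %s", tr$output$ref, tensor_type(tr$output)))
}

# Stage 3, "compile": build a closure that replays the graph on real arrays.
toy_compile_graph <- function(tr) {
  impl <- list(add = `+`, multiply = `*`)
  function(...) {
    args <- list(...)
    vals <- list()
    for (i in seq_along(args)) {
      stopifnot(identical(dim(args[[i]]), as.integer(tr$inputs[[i]]$shape)))
      vals[[tr$inputs[[i]]$ref]] <- args[[i]]
    }
    for (n in tr$nodes) {
      vals[[n$ref]] <- impl[[n$op]](vals[[n$args[1]]], vals[[n$args[2]]])
    }
    vals[[tr$output$ref]]
  }
}

f <- function(x, y) x * y + x
avals <- list(list(dtype = "f32", shape = c(2, 2)),
              list(dtype = "f32", shape = c(2, 2)))
tr <- toy_trace(f, avals)
hlo <- toy_lower(tr)
cat(hlo, sep = "\n")
f_compiled <- toy_compile_graph(tr)
res <- f_compiled(array(1:4, c(2, 2)), array(5:8, c(2, 2)))
```

Because tracing only sees shapes and dtypes, the graph (and thus the compiled artifact) is fixed once traced; calling it with differently shaped inputs fails the shape check rather than retracing.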

See also

jit() for lazy compilation, compile_xla() for the lower-level API.

Examples

f_compiled <- xla(function(x, y) x + y,
  args = list(x = nv_aval("f32", c(2, 2)), y = nv_aval("f32", c(2, 2)))
)
a <- nv_array(array(1:4, c(2, 2)), dtype = "f32")
b <- nv_array(array(5:8, c(2, 2)), dtype = "f32")
f_compiled(a, b)
#> AnvlArray
#>   6 10
#>   8 12
#> [ CPUf32{2,2} ]