
Constructs the XLA backend, which stores array data in PJRT buffers (via pjrt::pjrt_buffer()) and compiles jitted functions to XLA executables via stablehlo() and pjrt::pjrt_compile(). This is the default backend.

Usage

AnvlBackendXla()

Value

An AnvlBackend object with subclass "AnvlBackendXla".

Data representation

An AnvlArray with backend = "xla" wraps a pjrt::pjrt_buffer() stored in the $data field. The buffer owns the memory holding the tensor values and may live on any device supported by PJRT (CPU, CUDA, Metal, ...). Calling as_array() transfers the buffer contents back to an R array; calling nv_array() on an R object uploads it to the requested device.

Each AnvlArray therefore has an associated device, queryable via device(). A device is an object as returned by pjrt::as_pjrt_device(), identifying a platform such as "cpu" or "cuda", optionally with an index such as "cuda:1". When device is NULL in nv_array() or the jit() wrapper, the device is either inferred from the existing inputs of a jitted call or taken from the PJRT_PLATFORM environment variable, falling back to "cpu" if that variable is unset. Operations require all inputs to live on the same device.
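The device resolution rule can be sketched in base R. resolve_device() is a hypothetical helper, not part of anvil. It represents devices as plain strings such as "cpu" or "cuda:1" rather than pjrt::as_pjrt_device() objects, and it assumes that the devices of existing inputs take precedence over PJRT_PLATFORM, an ordering the text above does not spell out.

```r
# Hypothetical helper, not part of anvil: sketches how a NULL `device`
# could be resolved. Devices are simplified to strings like "cpu" or "cuda:1".
# Assumption: devices of existing inputs take precedence over PJRT_PLATFORM.
resolve_device <- function(device = NULL, input_devices = character()) {
  # An explicitly requested device is used as-is.
  if (!is.null(device)) {
    return(device)
  }
  # Infer from the inputs of a jitted call; all inputs must share one device.
  if (length(input_devices) > 0) {
    uniq <- unique(input_devices)
    if (length(uniq) > 1) {
      stop("inputs live on different devices: ", paste(uniq, collapse = ", "))
    }
    return(uniq)
  }
  # Otherwise use PJRT_PLATFORM, falling back to "cpu" when it is unset or empty.
  platform <- Sys.getenv("PJRT_PLATFORM", unset = "")
  if (nzchar(platform)) platform else "cpu"
}

resolve_device("cuda:1")                          # "cuda:1"
resolve_device(input_devices = c("cpu", "cpu"))   # "cpu"
```

Mixing devices, e.g. resolve_device(input_devices = c("cpu", "cuda")), raises an error in this sketch, mirroring the requirement that all inputs live on the same device.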

XLA JIT arguments

  • donate (character(), default character()): names of arguments whose underlying buffers may be donated to the compiled XLA executable, which can then reuse and consume their memory. This can reduce memory usage and avoid copies for large inputs, but donated buffers must not be used by the caller after the call. Must not overlap with static.
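A minimal sketch of argument checks that donate implies, in base R. check_donate() and its arg_names parameter are hypothetical and not part of anvil. The overlap check mirrors the documented rule that donate must not overlap with static; the check that every donated name is an actual argument is an added assumption.

```r
# Hypothetical validator, not part of anvil: checks `donate` before compilation.
check_donate <- function(donate = character(), static = character(), arg_names) {
  stopifnot(is.character(donate), is.character(static))
  # Assumption: every donated name must refer to an argument of the function.
  unknown <- setdiff(donate, arg_names)
  if (length(unknown) > 0) {
    stop("unknown argument(s) in `donate`: ", paste(unknown, collapse = ", "))
  }
  # Documented rule: an argument cannot be both donated and static.
  # (Presumably because static arguments are compile-time values, not buffers.)
  overlap <- intersect(donate, static)
  if (length(overlap) > 0) {
    stop("argument(s) in both `donate` and `static`: ",
         paste(overlap, collapse = ", "))
  }
  invisible(TRUE)
}

f <- function(x, y, n) x + y
check_donate("x", static = "n", arg_names = names(formals(f)))  # passes silently
```

Running such a check before tracing and compiling means a misspelled or conflicting name fails early, before any buffer is handed to the executable.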