Constructs the XLA backend, which stores array data in PJRT buffers (via
pjrt::pjrt_buffer()) and compiles jitted functions to XLA executables
via stablehlo() and pjrt::pjrt_compile(). This is the default
backend.
Value
An AnvlBackend object with subclass "AnvlBackendXla".
Data representation
An AnvlArray with backend = "xla" wraps a pjrt::pjrt_buffer()
stored in the $data field. The buffer owns the memory holding the tensor
values and may live on any device supported by PJRT (CPU, CUDA, Metal,
...). Calling as_array() transfers the buffer contents back to an R
array; calling nv_array() on an R object uploads it to the requested
device.
Each AnvlArray therefore has an associated device, queryable via
device(). A device is a PJRT device object as returned by
pjrt::as_pjrt_device(), identified by a platform such as "cpu" or "cuda"
and optionally an index, as in "cuda:1".
When device is NULL in nv_array() or the jit() wrapper, the device is
inferred from the existing inputs of a jitted call or, failing that,
taken from the PJRT_PLATFORM environment variable (falling back to
"cpu"). All inputs to an operation must live on the same device.
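The device rules above can be summarised in plain R. This is a minimal sketch, not the backend's implementation: parse_device() and resolve_device() are hypothetical helpers, plain character strings stand in for PJRT device objects, a string without an index is assumed to mean device 0, and the precedence (explicit device, then existing inputs, then PJRT_PLATFORM, then "cpu") is an assumption drawn from the description.

```r
# Hypothetical helpers; plain strings stand in for PJRT device objects.

# Split a device string such as "cuda:1" into its platform and index.
parse_device <- function(x) {
  parts <- strsplit(x, ":", fixed = TRUE)[[1]]
  list(
    platform = parts[[1]],
    # Assumption: a string without an index refers to device 0.
    index = if (length(parts) > 1) as.integer(parts[[2]]) else 0L
  )
}

# Pick the device for a call whose `device` argument may be NULL.
# `input_devices` lists the devices of inputs that already live on a device.
# Assumed precedence: explicit device, then inputs, then PJRT_PLATFORM, then "cpu".
resolve_device <- function(device = NULL, input_devices = character()) {
  if (!is.null(device)) {
    return(device)
  }
  if (length(input_devices) > 0) {
    # Operations require all inputs to live on the same device.
    if (length(unique(input_devices)) > 1) {
      stop("all inputs must live on the same device")
    }
    return(input_devices[[1]])
  }
  platform <- Sys.getenv("PJRT_PLATFORM", unset = "")
  if (nzchar(platform)) platform else "cpu"
}
```

For example, resolve_device(input_devices = c("cuda:1", "cuda:1")) returns "cuda:1", while resolve_device() with PJRT_PLATFORM unset returns "cpu".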
XLA JIT arguments
donate (character(), default character()): names of arguments whose underlying buffers may be donated to (i.e., reused or consumed by) the compiled XLA executable. The caller must not use a donated buffer again after the call; donation can reduce memory usage and avoid copies for large inputs. Must not overlap with static.
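The constraints on donate can be checked before compilation. The sketch below uses a hypothetical helper, check_jit_args(), that is not part of the backend. It enforces the documented rule that donate must not overlap with static. It also applies one extra check, that every donated name is a formal argument of the function, which is an assumption rather than documented behaviour.

```r
# Hypothetical validation of the documented `donate` constraints.
check_jit_args <- function(f, donate = character(), static = character()) {
  stopifnot(is.character(donate), is.character(static))
  arg_names <- names(formals(f))
  # Assumption: every donated name must be a formal argument of `f`.
  unknown <- setdiff(donate, arg_names)
  if (length(unknown) > 0) {
    stop("unknown argument(s) in `donate`: ", paste(unknown, collapse = ", "))
  }
  # Documented rule: `donate` must not overlap with `static`.
  overlap <- intersect(donate, static)
  if (length(overlap) > 0) {
    stop("argument(s) cannot be both donated and static: ",
         paste(overlap, collapse = ", "))
  }
  invisible(donate)
}
```

With f <- function(x, y, n) x + y, check_jit_args(f, donate = "x", static = "n") passes, while check_jit_args(f, donate = "n", static = "n") signals an error because n would be both donated and static.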