The {pjrt} package can compile and run stablehlo programs. However,
some operations are not directly covered by stablehlo, and composing them
from other stablehlo ops would be inefficient. For this reason, PJRT
supports custom calls. In a stablehlo program such a custom call can be
executed via: stablehlo.custom_call @<target>(...).
In this vignette we will explain how to use and create custom calls on
CPU and CUDA, and how to make use of LAPACK on CPU and cuSOLVER on
CUDA.
Using Custom Calls
{pjrt} ships a few built-in custom calls registered automatically
when the package is loaded. This includes print_tensor and
the linear-algebra primitives geqrf / orgqr
(QR factorisation), lu, svd, and
eigh. The example below uses print_tensor,
which takes a single tensor operand, prints it, and returns nothing.
```r
library(pjrt)

src <- r"(
func.func @main(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  stablehlo.custom_call @print_tensor(%x) {
    call_target_name = "print_tensor",
    backend_config = {
      print_header = "my matrix",
      print_footer = "----"
    },
    has_side_effect = true,
    api_version = 4 : i32,
    operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
    result_layouts = []
  } : (tensor<2x3xf32>) -> ()
  "func.return"(%x) : (tensor<2x3xf32>) -> ()
}
)"

exec <- pjrt_compile(pjrt_program(src))
buf <- pjrt_buffer(matrix(1:6, nrow = 2, ncol = 3), dtype = "f32")
out <- pjrt_execute(exec, buf)
## my matrix
## 1 3 5
## 2 4 6
## ----
as_array(out)
##      [,1] [,2] [,3]
## [1,]    1    3    5
## [2,]    2    4    6
```
A stablehlo.custom_call op carries:
- call_target_name – a string that names a handler registered with the PJRT runtime (here, the built-in "print_tensor"). This name also appears as @print_tensor after the op name in the textual form; the @<name> is purely a pretty-printer convention that mirrors call_target_name, and only the attribute string is operative at runtime.
- api_version = 4 : i32 – the calling convention used to invoke the handler. Just set this to 4.
- has_side_effect = true – required for ops that exist only for their side effect (such as printing). Without it, the compiler might drop the operation in one of its optimization passes.
- backend_config = {...} (optional) – a dictionary of attributes the handler can read. print_tensor accepts print_header (default "PJRTBuffer") and print_footer (default "[ <dtype>{<shape>} ]"); pass "" for either to suppress that line.
- operand_layouts / result_layouts – the layout the custom call expects for each operand and produces for each result (e.g., row-major vs. column-major). Each layout lists the dimensions from minor to major, so dense<[1, 0]> is row-major for a matrix. XLA’s documented default for newly created shapes is row-major (see the XLA shapes and layout documentation), but it’s best to always set these explicitly for clarity.
Next, we will discuss how to register a custom call.
Registering a Custom Call
In order to make a custom call available in a {stablehlo} program, we need to:
- Define what it does.
- Register it with one or more PJRT plugins (CPU, CUDA) so it can be used.
Generally, we need different implementations for different plugins (CPU vs. CUDA). We will start with CPU, as it is simpler.
CPU
Handler signature
A CPU handler is just a C++ function that takes one parameter per
operand / result of the stablehlo.custom_call op, plus
optional runtime-supplied parameters (attributes, contexts, etc.), and
returns an xla::ffi::Error. The eigh handler from
src/eigh.cpp, for example, is declared as:
```cpp
#include "xla/ffi/api/ffi.h" // ships in pjrt's inst/include/
using namespace xla::ffi;    // brings Error, AnyBuffer, Result<>, ... into scope

static Error do_eigh(AnyBuffer input,
                     Result<AnyBuffer> v_out,
                     Result<AnyBuffer> w_out);
```

The first parameter is the input matrix; the two
Result<AnyBuffer> parameters are the output buffers
(eigenvectors and eigenvalues) that XLA has already allocated and that
the handler fills in. The return value is Error::Success()
on success, or one of the Error::Invalid… /
Error::Internal factories to surface a runtime failure to
XLA.
Registering the handler with XLA_FFI_DEFINE_HANDLER
A C++ function on its own is not yet usable from MLIR – the runtime
has no way to know that do_eigh wants three buffers, in
that order, with no attributes or contexts.
XLA_FFI_DEFINE_HANDLER is the macro that attaches that
decoding metadata to the function and emits a symbol PJRT can call:
```cpp
XLA_FFI_DEFINE_HANDLER(eigh_handler, do_eigh,
                       Ffi::Bind()
                           .Arg<AnyBuffer>()   // symmetric matrix
                           .Ret<AnyBuffer>()   // eigenvectors (n, n)
                           .Ret<AnyBuffer>()); // eigenvalues (n,)
```

The pieces:
- eigh_handler – the symbol name the macro emits. This (wrapped in an R external pointer) is the XLA_FFI_Handler* we hand to the pjrt::pjrt_register_custom_call() R function.
- do_eigh – the C++ function the handler dispatches to.
- Ffi::Bind().Arg<...>().Ret<...>()... – the binding specification that defines which arguments are passed to do_eigh() during execution.
The most important builder methods:
- .Arg<AnyBuffer>() / .Ret<AnyBuffer>() – one operand / result of any dtype / rank (both used above).
- .Attrs<Dictionary>() – the whole backend_config as a typed dictionary (used by print_tensor).
Other binding methods exist – statically typed
Buffer<dtype, rank>, named
.Attr<T>("name") for individual
backend_config entries, Token for
pure-ordering ops, RemainingArgs /
RemainingRets for variadic arity, and additional
Ctx<T> types including a host ThreadPool
(XLA’s intra-op pool, useful for parallelising hand-rolled CPU work).
For the full list see xla/ffi/api/ffi.h.
Calling into BLAS / LAPACK
{pjrt}’s build links against LAPACK and BLAS, so a CPU handler can call any LAPACK / BLAS routine directly.
Most LAPACK factorisation routines need a working buffer whose
optimal size depends on the input dimensions and on
implementation-specific blocking parameters. The standard idiom is to
call the routine twice: first with the workspace-size
argument set to -1, which makes LAPACK skip the actual work
and write the optimal size into the first element of the workspace
buffer; then a second time with a buffer of that size to do the real
factorisation.
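The two-call idiom can be sketched as follows. To keep the example self-contained (no LAPACK link needed), fake_lapack_routine is a hypothetical stand-in that only mimics LAPACK’s calling conventions (pointer arguments, lwork = -1 for the size query, an info out-parameter); in a real handler the actual routine, e.g. dgeqrf_, would take its place.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a LAPACK routine: it follows LAPACK's calling
// conventions but performs no actual factorisation.
static void fake_lapack_routine(const int *m, const int *n, double *a,
                                const int *lda, double *work,
                                const int *lwork, int *info) {
  (void)m;
  (void)a;
  (void)lda;
  const int minimum = std::max(1, *n);
  if (*lwork == -1) {
    // Workspace query: report an "optimal" size in work[0]. The factor 32 is
    // an arbitrary stand-in for an implementation-specific block size.
    work[0] = static_cast<double>(minimum * 32);
    *info = 0;
    return;
  }
  if (*lwork < minimum) {
    *info = -6;  // LAPACK convention: argument 6 (lwork) is illegal
    return;
  }
  *info = 0;  // a real routine would overwrite `a` with the factorisation
}

// Query the workspace size, allocate it, then do the real call.
// Returns LAPACK's info (0 on success).
int run_with_workspace_query(int m, int n, std::vector<double> &a) {
  const int lda = std::max(1, m);
  int info = 0;

  int lwork = -1;        // -1 asks only for the optimal workspace size
  double optimal = 0.0;  // the size comes back as a double in work[0]
  fake_lapack_routine(&m, &n, a.data(), &lda, &optimal, &lwork, &info);
  if (info != 0) return info;  // the query itself can fail

  lwork = static_cast<int>(optimal);
  std::vector<double> work(static_cast<std::size_t>(lwork));
  fake_lapack_routine(&m, &n, a.data(), &lda, work.data(), &lwork, &info);
  return info;
}
```

Note that the query call already returns an info value; it is checked before the allocation so a bad argument is reported instead of being masked by a bogus workspace size.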
LAPACK signals success or failure through an integer
info out-parameter (info = 0 is success,
anything else is an error). This must be checked after every
call – both the workspace query and the real call. {pjrt}
provides a small helper,
lapack_check_info(info, routine_name) in
src/ffi_lapack.h, that returns
Error::Success() on info == 0 and
Error::Internal("<routine_name> failed with info = <info>")
otherwise. Wrap each LAPACK call with
PJRT_RETURN_IF_ERROR(lapack_check_info(info, "...")) so XLA
stops execution rather than reading garbage results downstream.
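A minimal sketch of this check-and-propagate pattern. The real helpers return xla::ffi::Error and live in src/ffi_lapack.h; here Status, check_info, RETURN_IF_ERROR and factor are hypothetical stand-ins so the example compiles without the XLA headers, and the macro body is an assumption about how such a macro is commonly written, not a copy of PJRT_RETURN_IF_ERROR.

```cpp
#include <string>
#include <utility>

// Hypothetical stand-in for xla::ffi::Error.
struct Status {
  bool ok = true;
  std::string message;
  static Status Success() { return {}; }
  static Status Internal(std::string msg) { return {false, std::move(msg)}; }
};

// Mirrors the behaviour described for lapack_check_info.
Status check_info(int info, const std::string &routine) {
  if (info == 0) return Status::Success();
  return Status::Internal(routine + " failed with info = " +
                          std::to_string(info));
}

// Evaluate `expr` once; if it is an error, return it from the enclosing function.
#define RETURN_IF_ERROR(expr)        \
  do {                               \
    Status status_ = (expr);         \
    if (!status_.ok) return status_; \
  } while (0)

// Hypothetical handler body: the two info values stand in for the results of
// the workspace query and of the real LAPACK call.
Status factor(int query_info, int run_info) {
  RETURN_IF_ERROR(check_info(query_info, "dgeqrf (workspace query)"));
  // ... allocate the workspace and run the real call here ...
  RETURN_IF_ERROR(check_info(run_info, "dgeqrf"));
  return Status::Success();
}
```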
Windows pitfall: no f32 LAPACK. R on Windows uses its own bundled LAPACK, which only ships the double-precision routines; single-precision routines are not available. A handler that calls into single-precision LAPACK will fail to load on Windows. macOS and Linux are unaffected (they use the system LAPACK, which has both precisions). The workaround is to promote f32 to f64 around the LAPACK call on Windows; see the existing linear-algebra custom calls for how to do this via C++ templates.
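The promotion can be expressed with a small function template. In this sketch double_only_routine is a hypothetical placeholder for a routine that R’s Windows LAPACK only provides in double precision, and call_promoted is an illustrative name rather than pjrt’s actual helper. In a real handler the promoted path would only be selected on Windows (e.g. behind #ifdef _WIN32), while other platforms call the native f32 routine directly.

```cpp
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical double-only routine (doubles every element and returns a
// LAPACK-style info of 0), standing in for an f64-only LAPACK call.
static int double_only_routine(int n, double *x) {
  for (int i = 0; i < n; ++i) x[i] *= 2.0;
  return 0;
}

// Call the double-only routine for any floating-point T:
// doubles are passed through, floats are widened, processed and narrowed back.
template <typename T>
int call_promoted(int n, T *x) {
  static_assert(std::is_floating_point_v<T>, "T must be float or double");
  if constexpr (std::is_same_v<T, double>) {
    return double_only_routine(n, x);
  } else {
    std::vector<double> tmp(x, x + n);  // f32 -> f64 copy
    const int info = double_only_routine(n, tmp.data());
    for (int i = 0; i < n; ++i) {
      x[i] = static_cast<T>(tmp[static_cast<std::size_t>(i)]);  // f64 -> f32
    }
    return info;
  }
}
```

The cost of this design is an extra temporary buffer and two conversion passes per call; the benefit is that the handler keeps a single f32 interface on every platform.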
CUDA
Before getting into CUDA-specific handler details, a note on why the CUDA code in pjrt is structured differently from a “natural” CUDA backend.
The library that we based the custom calls on is jaxlib, which ships
separate wheels for CPU and CUDA: jaxlib
is the base CPU wheel, jax-cuda12-pjrt /
jax-cuda12-plugin are CUDA-specific wheels built with
--config=cuda and the CUDA toolkit on the build path. The
CUDA wheels can include CUDA SDK (Software Development Kit) headers,
link against the CUDA libraries, and use the real types directly.
R packages don’t have an equivalent mechanism.
R CMD INSTALL builds one source tree against one
Makevars; there is no “platform variant” the way Python
wheels have CUDA extras. The same pjrt.so binary has to
load and run on machines without the CUDA toolkit installed (and even
without an NVIDIA GPU) – otherwise we’d break every CPU-only user. So
pjrt’s CUDA path:
- never #includes any CUDA SDK header,
- types every CUDA value as opaque (void * for streams and handles, a uintptr_t-aliased CUdeviceptr for device addresses) so we don’t depend on the SDK’s typedefs at compile time,
- maintains a hand-rolled function-pointer table (CudaLibs in ffi_cuda.h) for the cuSOLVER + CUDA driver entry points we use,
- loads those entry points lazily at runtime via dlopen / dlsym inside get_cuda_libs() (ffi_cuda.cpp), so the package loads cleanly on machines where libcusolver.so and libcuda.so.1 aren’t present (the loader just leaves the table flagged as unloaded, and the CUDA handlers return Error::Internal if invoked).
In effect, CudaLibs + get_cuda_libs() is
doing the job of the linker, dynamically and by hand –
the symbol lookup that a normal -lcusolver dependency would
have the system loader perform at process startup, we instead perform
from our own code at runtime. The trade-off is that every cuSOLVER /
driver function we want to use has to be transcribed by hand twice: once
as a typed function-pointer field in CudaLibs (whose
signature must match NVIDIA’s exactly – the compiler can’t check it for
us, since we don’t include the SDK headers that would tell it what to
expect), and once as a string symbol name passed to dlsym
in the loader.
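A stripped-down version of this pattern, as a sketch only: LazyLibs, load_libs and get_libs are hypothetical names, and to keep the example runnable on an ordinary Linux machine it resolves cos from libm.so.6 instead of a cuSOLVER entry point from libcusolver.so. The mechanics are the ones described above: a hand-typed pointer field, a string symbol name, and an "unloaded" flag when the library or symbol is missing. It assumes a POSIX system with dlfcn.h (older glibc versions also need -ldl at link time).

```cpp
#include <dlfcn.h>

// Hypothetical miniature of a CudaLibs-style table.
struct LazyLibs {
  bool loaded = false;
  // Hand-transcribed signature: the compiler cannot check it against the
  // library, so it must match the real declaration exactly.
  double (*cos_fn)(double) = nullptr;
};

// Resolve `symbol` from `lib_name` at runtime; leave the table unloaded if
// either the library or the symbol is missing.
LazyLibs load_libs(const char *lib_name, const char *symbol) {
  LazyLibs libs;
  void *handle = dlopen(lib_name, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) return libs;
  void *sym = dlsym(handle, symbol);
  if (sym == nullptr) {
    dlclose(handle);
    return libs;
  }
  libs.cos_fn = reinterpret_cast<double (*)(double)>(sym);
  libs.loaded = true;
  return libs;  // the handle stays open for the lifetime of the process
}

// Loaded on first use; C++11 guarantees thread-safe initialisation of
// function-local statics.
const LazyLibs &get_libs() {
  static const LazyLibs libs = load_libs("libm.so.6", "cos");
  return libs;
}
```

A handler would check get_libs().loaded first and return an error (Error::Internal in pjrt) rather than calling through a null pointer.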
In the future, we might move to a more robust mechanism that avoids these redefinitions, e.g. by including some CUDA headers (which might come with some licensing strings attached) or by generating the declarations from a DSL.
CUDA execution model
A few CUDA concepts come up in every CUDA handler, so we will briefly review them here.
Host vs. device. The CPU is the host and the GPU is the device. They have separate memory: a host pointer can’t be dereferenced on the device, and a device pointer can’t be dereferenced on the host. Moving data between them requires an explicit memcpy; memcpys can also copy data directly from one device buffer to another.
Streams. A CUDA stream is a FIFO (first in first out) queue of operations that the device executes. Operations enqueued on the same stream run in the order they were enqueued. Operations on different streams can run concurrently when the dependency graph allows it (e.g. a memcpy on one stream overlapping a kernel on another).
Asynchronous execution. Most CUDA driver and
cuSOLVER calls – including the *Async memcpys, kernel
launches, and the cuSOLVER factorisations – are non-blocking on
the host. They enqueue work on the stream and return immediately, before
the GPU has actually done anything. The host call’s return only signals
“this is now queued”; the data isn’t real until the device gets to that
point in the queue. This is the source of most CUDA-handler pitfalls: if
you read a destination buffer right after issuing an async memcpy into
it, you’ll see whatever was there before, because the copy is still
queued.
Stream synchronisation.
cuStreamSynchronize(stream) blocks the host until every
operation previously enqueued on stream has finished.
That’s the point at which “queued” turns into “actually done” from the
host’s perspective. Inside an FFI handler, you should
not call cuStreamSynchronize unless you genuinely
need a host-visible result – syncing forces the host to wait for the GPU
and breaks the async pipelining XLA depends on. The only place pjrt
syncs today is print_tensor’s CUDA path, because it has to
copy data back to the host before formatting it. The linalg handlers
never sync: they enqueue their work on the stream and return, and XLA’s
own scheduling makes sure any downstream op that reads their outputs
runs after the kernel has actually completed.
Status codes. Every CUDA driver and cuSOLVER call
returns an int status. 0 is success; anything
else is an error and must be propagated up as
Error::Internal. A dropped status doesn’t crash – it leaves
the output buffer with garbage that surfaces much later as wrong
numerical answers, with no error message anywhere. The
PJRT_RETURN_IF_GPU_ERROR(expr, what) macro in
ffi_cuda.h wraps each call site and propagates non-zero
statuses uniformly.
Handler signature
A CUDA handler has the same shape as the CPU one, but two extra
context parameters get bound at the front. From
src/eigh_cuda.cpp:
```cpp
static Error eigh_cuda_impl(void *stream,
                            ScratchAllocator &scratch,
                            AnyBuffer input,
                            Result<AnyBuffer> v_out,
                            Result<AnyBuffer> w_out);

// do_eigh_cuda: the dispatcher (not shown) that the handler binds to;
// see the Windows note below.
XLA_FFI_DEFINE_HANDLER(eigh_handler_cuda, do_eigh_cuda,
                       Ffi::Bind()
                           .Ctx<PlatformStream<void *>>()
                           .Ctx<ScratchAllocator>()
                           .Arg<AnyBuffer>()
                           .Ret<AnyBuffer>()
                           .Ret<AnyBuffer>());
```

The differences from CPU:
- void *stream – bound via Ctx<PlatformStream<void *>>(), this is the CUDA stream XLA dispatched the op on. The FFI exposes exactly one stream per invocation (there is no API to ask for a transfer stream or any other), and the handler must enqueue all of its device-side work on it: kernel launches, cuSOLVER calls, memcpy_dtod / memcpy_dtoh / memset_d8, any synchronisation.
- ScratchAllocator &scratch – use this for any device memory the handler needs that isn’t an FFI input or output: workspace for the library call, an input copy, devInfo, etc. Allocations go through XLA’s runtime allocator (the same one that backs FFI input/output buffers, so the bytes are visible to XLA’s memory accounting) and are freed automatically when the handler returns: the FFI binding’s destructor calls back into the runtime to free every allocation it made.
- Windows. The handler symbol still has to be defined on Windows (so the [[Rcpp::export]] getter resolves), but the dispatcher returns Error(ErrorCode::kUnimplemented, ...) because pjrt has no CUDA support on Windows. See the #ifdef _WIN32 guards in src/eigh_cuda.cpp.
Using cuSOLVER
cuSOLVER is loaded lazily via dlopen (see CUDA intro
above), so a CUDA handler that wants to call cuSOLVER goes through the
CudaLibs function-pointer table and the Solver
prologue defined in ffi_cuda.h. The kernel declares
Solver solver(get_cuda_libs()); solver.begin(scratch, stream);
and then invokes routines. begin() borrows a stream-bound
handle from a pool (see stateful-handle paragraph below) and allocates a
devInfo int on device memory.
Because cuSOLVER/CUDA handles are expensive to create, they are
cached in a process-wide singleton pool – a map of stream pointer to
free-list of handles, protected by a mutex against concurrent host-side
borrow / return. A kernel borrows a handle from its stream’s free-list
(creating one the first time), uses it, and returns it on scope exit;
handles live for the process and are reused across
pjrt_execute calls. The design is adapted from jaxlib.
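The pool can be sketched in standard C++. Handle is a hypothetical placeholder for an expensive library handle (a real pool would call cusolverDnCreate and cusolverDnSetStream when a free-list is empty), and HandlePool illustrates the design described above; it is not pjrt’s or jaxlib’s code.

```cpp
#include <map>
#include <mutex>
#include <vector>

// Hypothetical stand-in for an expensive-to-create library handle.
struct Handle {
  int id;
};

// Process-wide pool: one free-list of handles per stream, guarded by a mutex.
class HandlePool {
 public:
  // RAII borrow: the handle goes back to its stream's free-list on scope exit.
  class Borrowed {
   public:
    Borrowed(HandlePool *pool, void *stream, Handle h)
        : pool_(pool), stream_(stream), handle_(h) {}
    Borrowed(const Borrowed &) = delete;
    Borrowed &operator=(const Borrowed &) = delete;
    ~Borrowed() { pool_->give_back(stream_, handle_); }
    Handle get() const { return handle_; }

   private:
    HandlePool *pool_;
    void *stream_;
    Handle handle_;
  };

  static HandlePool &instance() {
    static HandlePool pool;  // lives for the whole process
    return pool;
  }

  // Relies on C++17 guaranteed copy elision to return the non-copyable Borrowed.
  Borrowed borrow(void *stream) {
    std::lock_guard<std::mutex> lock(mu_);
    std::vector<Handle> &free_list = free_[stream];
    if (free_list.empty()) {
      // First use on this stream (or all its handles busy): create a new one.
      return Borrowed(this, stream, Handle{next_id_++});
    }
    Handle h = free_list.back();
    free_list.pop_back();
    return Borrowed(this, stream, h);
  }

 private:
  void give_back(void *stream, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    free_[stream].push_back(h);
  }

  std::mutex mu_;
  std::map<void *, std::vector<Handle>> free_;
  int next_id_ = 0;
};
```

Keying the free-lists by stream means a handle that was bound to one stream is never silently reused on another.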
Most cuSOLVER factorisation routines follow the same query / allocate
/ run idiom as LAPACK (covered in the CPU section), but the workspace
query is a separate function with a _bufferSize suffix
(cusolverDnSgeqrf_bufferSize, etc.) rather than LAPACK’s
lwork = -1 convention. Note that computing the buffer
size is a cheap host-side computation.
cuSOLVER actually has two status mechanisms, for two different failure modes:
- The return value of every routine is a host-side int returned synchronously when the call comes back, indicating whether the call was accepted and queued onto the stream. We always check this with PJRT_RETURN_IF_GPU_ERROR.
- The devInfo int is a device-side int that cuSOLVER’s kernel writes during execution, using the same convention as LAPACK’s info (0 = success, <0 = illegal argument at position |info|, >0 = routine-specific). The split exists because cuSOLVER is async on a stream: the host return covers pre-launch failures, but the actual runtime / numerical status can only be reported after the kernel has run, and the kernel needs somewhere device-side to write it.
We pass devInfo because we have to, but don’t read it,
because reading it back on the host would force a stream synchronisation.
Again, we follow jaxlib in this decision.
Additional things to pay attention to
There are a few other things to be aware of when implementing custom calls.
Buffer aliasing. Stablehlo inherited XLA’s
input-output aliasing. stablehlo.custom_call’s
output_operand_aliases attribute lets XLA hand the handler
the same pointer for an input and an output. Routines that
overwrite their input in place will then read corrupted data. The fix:
copy input → destination buffer first, then factor in place there, with
a pointer-equality check so the copy is skipped when the buffers already
alias (if (in != target) memcpy(...)). jaxlib uses the same
pattern in every linalg kernel; pjrt’s
maybe_promote_inplace helper bakes the guard in for the CPU
side.
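The guard looks like this in isolation. copy_unless_aliased, factor_in_place and handler_body are hypothetical names (factor_in_place merely negates the data, standing in for an in-place LAPACK call), so this is a sketch of the pattern rather than maybe_promote_inplace itself. The pointer check matters for a second reason too: memcpy with overlapping (including identical) source and destination is undefined behaviour.

```cpp
#include <cstddef>
#include <cstring>

// Copy `in` into `out` unless XLA handed us the same buffer for both.
template <typename T>
void copy_unless_aliased(const T *in, T *out, std::size_t n) {
  if (static_cast<const void *>(in) != static_cast<const void *>(out)) {
    std::memcpy(out, in, n * sizeof(T));
  }
}

// Hypothetical stand-in for a routine that overwrites its argument in place.
template <typename T>
void factor_in_place(T *a, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) a[i] = -a[i];
}

// Handler body: after the copy, only `out` is read or written.
template <typename T>
void handler_body(const T *in, T *out, std::size_t n) {
  copy_unless_aliased(in, out, n);
  factor_in_place(out, n);
}
```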
Dimension type mismatch and byte-size overflow.
Buffer dimensions arrive as int64_t from
input.dimensions(), but LAPACK and cuSOLVER take
int. A naive cast wraps silently for axes larger than
INT_MAX – use the dim_to_int(dim, name, out)
helper in ffi_common.h, which casts and returns
Error::InvalidArgument on overflow. That covers each
individual dimension, but not their product: even after
dim_to_int has safely narrowed m and n to
int, an expression like m * n * sizeof(T)
still evaluates in int arithmetic and overflows for large
matrices (e.g. 50000 × 50000) before the result gets widened to
size_t. Cast at least one operand to size_t
before the multiplication so the whole expression evaluates in
size_t:
static_cast<size_t>(m) * n * sizeof(T).
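Both pitfalls in one place, as a sketch: checked_dim_to_int is a hypothetical checked narrowing in the spirit of dim_to_int (the real helper returns an xla::ffi::Error; here a std::optional<std::string> error message plays that role, and rejecting negative values is an extra assumption), followed by a size computation with the cast placed first.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

// Narrow an int64_t dimension to int, or return an error message.
std::optional<std::string> checked_dim_to_int(std::int64_t dim,
                                              const char *name, int *out) {
  if (dim < 0 || dim > INT_MAX) {
    return std::string(name) + " = " + std::to_string(dim) +
           " does not fit into an int";
  }
  *out = static_cast<int>(dim);
  return std::nullopt;  // success
}

// Byte size of an m x n matrix of T. Casting m first makes the whole product
// evaluate in size_t; m * n * sizeof(T) would overflow in int before widening.
template <typename T>
std::size_t matrix_bytes(int m, int n) {
  return static_cast<std::size_t>(m) * n * sizeof(T);
}
```

With a 64-bit size_t, matrix_bytes<double>(50000, 50000) evaluates to 20,000,000,000 bytes, whereas the all-int expression already overflows at 50000 * 50000 = 2.5e9 > INT_MAX.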
scratch.Allocate(0).
ScratchAllocator::Allocate(0) returns
std::nullopt, which
allocate_workspace<T> propagates as
Error::Internal. Downstream code that uses these custom calls should
therefore avoid inputs that lead to a zero-byte scratch allocation.