The {pjrt} package can compile and run stablehlo programs. However, some operations are not directly covered by stablehlo, and composing them from other stablehlo ops would be inefficient. For this reason, PJRT supports custom calls. In a stablehlo program, such a custom call is executed via stablehlo.custom_call @<target>(...). In this vignette we explain how to use and create custom calls on CPU and CUDA, and how to make use of LAPACK on CPU and cuSOLVER on CUDA.

Using Custom Calls

{pjrt} ships a few built-in custom calls that are registered automatically when the package is loaded. These include print_tensor and the linear-algebra primitives geqrf / orgqr (QR factorisation), lu, svd, and eigh. The example below uses print_tensor, which takes a single tensor operand, prints it, and returns nothing.

library(pjrt)

src <- r"(
func.func @main(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  stablehlo.custom_call @print_tensor(%x) {
    call_target_name = "print_tensor",
    backend_config = {
      print_header = "my matrix",
      print_footer = "----"
    },
    has_side_effect = true,
    api_version = 4 : i32,
    operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
    result_layouts = []
  } : (tensor<2x3xf32>) -> ()

  "func.return"(%x) : (tensor<2x3xf32>) -> ()
}
)"

exec <- pjrt_compile(pjrt_program(src))
buf  <- pjrt_buffer(matrix(1:6, nrow = 2, ncol = 3), dtype = "f32")
out  <- pjrt_execute(exec, buf)
## my matrix
##  1 3 5
##  2 4 6
## ----
##      [,1] [,2] [,3]
## [1,]    1    3    5
## [2,]    2    4    6

A stablehlo.custom_call op carries:

  • call_target_name – a string that names a handler registered with the PJRT runtime (here, the built-in "print_tensor"). This name also appears as @print_tensor after the op name in the textual form. The @<name> is purely a pretty-printer convention that mirrors call_target_name; only the attribute string is operative at runtime.
  • api_version = 4 : i32 – the calling convention used to invoke the handler. Just set this to 4.
  • has_side_effect = true – required for ops that exist only for their side effect (such as printing). If we did not set this to true here, the compiler might drop the operation in one of its optimization passes.
  • backend_config = {...} (optional) – a dictionary of attributes the handler can read. print_tensor accepts print_header (default "PJRTBuffer") and print_footer (default "[ <dtype>{<shape>} ]"); pass "" for either to suppress that line.
  • operand_layouts / result_layouts – indicate which layout the custom call expects and produces (e.g., row-major vs. column-major). XLA’s documented default for newly created shapes is row-major (see the XLA Shapes documentation), but it’s best to always set these explicitly for clarity.

Next, we will discuss how to register a custom call.

Registering a Custom Call

In order to make a custom call available in a stablehlo program, we need to:

  1. Define what it does.
  2. Register it with one or more PJRT plugins (CUDA, CPU) so it can be used.

Generally, we need different implementations for different plugins (CPU vs. CUDA). We start with the CPU, as it is simpler.

CPU

Handler signature

A CPU handler is just a C++ function that takes one parameter per operand / result of the stablehlo.custom_call op, plus optional runtime-supplied parameters (attributes, contexts, etc.), and returns an xla::ffi::Error. The eigh handler from src/eigh.cpp, for example, is declared as:

#include "xla/ffi/api/ffi.h"   // ships in pjrt's inst/include/
using namespace xla::ffi;       // brings Error, AnyBuffer, Result<>, ... into scope

static Error do_eigh(AnyBuffer input,
                     Result<AnyBuffer> v_out,
                     Result<AnyBuffer> w_out);

The first parameter is the input matrix; the two Result<AnyBuffer> parameters are the output buffers (eigenvectors and eigenvalues) that XLA has already allocated and that the handler fills in. The return value is Error::Success() on success, or one of the Error::Invalid… / Error::Internal factories to surface a runtime failure to XLA.

Registering the handler with XLA_FFI_DEFINE_HANDLER

A C++ function on its own is not yet usable from MLIR – the runtime has no way to know that do_eigh wants three buffers, in that order, with no attributes or contexts. XLA_FFI_DEFINE_HANDLER is the macro that attaches that decoding metadata to the function and emits a symbol PJRT can call:

XLA_FFI_DEFINE_HANDLER(eigh_handler, do_eigh,
                       Ffi::Bind()
                           .Arg<AnyBuffer>()    // symmetric matrix
                           .Ret<AnyBuffer>()    // eigenvectors (n, n)
                           .Ret<AnyBuffer>());  // eigenvalues (n,)

The pieces:

  • eigh_handler – the symbol name the macro emits. This (wrapped in an R external pointer) is the XLA_FFI_Handler* we hand to the pjrt::pjrt_register_custom_call() R function.
  • do_eigh – the C++ function the handler dispatches to.
  • Ffi::Bind().Arg<...>().Ret<...>()... – the binding specification that defines which arguments are passed to do_eigh() during execution.

The most important builder methods:

  • .Arg<AnyBuffer>() / .Ret<AnyBuffer>() – one operand / result of any dtype / rank.
  • .Attrs<Dictionary>() – the whole backend_config as a typed dictionary (used by print_tensor).

Other binding methods exist – statically typed Buffer<dtype, rank>, named .Attr<T>("name") for individual backend_config entries, Token for pure-ordering ops, RemainingArgs / RemainingRets for variadic arity, and additional Ctx<T> types including a host ThreadPool (XLA’s intra-op pool, useful for parallelising hand-rolled CPU work). For the full list see xla/ffi/api/ffi.h.

Calling into BLAS / LAPACK

{pjrt}’s build links against LAPACK and BLAS, so a CPU handler can call any LAPACK / BLAS routine directly.

Most LAPACK factorisation routines need a working buffer whose optimal size depends on the input dimensions and on implementation-specific blocking parameters. The standard idiom is to call the routine twice: first with the workspace-size argument set to -1, which makes LAPACK skip the actual work and write the optimal size into the first element of the workspace buffer; then a second time with a buffer of that size to do the real factorisation.

LAPACK signals success or failure through an integer info out-parameter (info = 0 is success, anything else is an error). This must be checked after every call – both the workspace query and the real call. {pjrt} provides a small helper, lapack_check_info(info, routine_name) in src/ffi_lapack.h, that returns Error::Success() on info == 0 and Error::Internal("<routine_name> failed with info = <info>") otherwise. Wrap each LAPACK call with PJRT_RETURN_IF_ERROR(lapack_check_info(info, "...")) so XLA stops execution rather than reading garbage results downstream.
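The query / allocate / run idiom and the info check can be sketched together as follows. To keep the sketch self-contained it does not call real LAPACK: fake_factor is a hypothetical stand-in that merely follows the lwork = -1 convention, and check_info is a simplified stand-in for lapack_check_info that returns a string instead of an xla::ffi::Error.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a LAPACK routine (NOT a real LAPACK symbol).
// It follows the LAPACK convention: lwork == -1 is a workspace query that
// writes the optimal size into work[0] and does no other work.
void fake_factor(int n, double *a, double *work, int lwork, int *info) {
  if (n < 0) { *info = -1; return; }                          // illegal argument 1
  if (lwork == -1) { work[0] = 4.0 * n; *info = 0; return; }  // made-up optimal size
  if (lwork < n) { *info = -4; return; }                      // workspace too small
  const std::size_t count = static_cast<std::size_t>(n) * n;
  for (std::size_t i = 0; i < count; ++i) a[i] *= 2.0;        // placeholder "factorisation"
  *info = 0;
}

// Simplified stand-in for lapack_check_info: an empty string means success.
std::string check_info(int info, const std::string &routine) {
  if (info == 0) return "";
  return routine + " failed with info = " + std::to_string(info);
}

// Query, allocate, run. info is checked after both calls.
std::string factor_in_place(int n, std::vector<double> &a) {
  assert(n >= 0 && a.size() == static_cast<std::size_t>(n) * n);
  int info = 0;
  double query = 0.0;
  fake_factor(n, a.data(), &query, -1, &info);          // 1) workspace query
  std::string err = check_info(info, "fake_factor (workspace query)");
  if (!err.empty()) return err;

  const int lwork = static_cast<int>(query);
  std::vector<double> work(lwork > 0 ? lwork : 1);
  fake_factor(n, a.data(), work.data(), lwork, &info);  // 2) real call
  return check_info(info, "fake_factor");
}
```

In a real handler the early returns would be xla::ffi::Error values, typically produced through PJRT_RETURN_IF_ERROR as described above.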

Windows pitfall: no f32 LAPACK. R on Windows uses its own bundled LAPACK, which only ships the double-precision routines; single-precision is not available. A handler that calls into single-precision LAPACK will fail to load on Windows. macOS and Linux are unaffected (they use the system LAPACK, which has both precisions). The workaround is to promote f32 to f64 around the LAPACK call on Windows; see the existing linear algebra custom calls for how to do this via C++ templates.
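The promotion idea can be illustrated with a small template. This is only a sketch under assumptions: scale_f64 is a hypothetical double-only routine standing in for a LAPACK entry point, and the helper names are invented here; pjrt's own helpers may be structured differently.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical double-only routine, standing in for a LAPACK entry point
// that exists only in double precision. It scales x in place.
void scale_f64(int n, double *x, double alpha) {
  for (int i = 0; i < n; ++i) x[i] *= alpha;
}

// Generic path: promote to double, call the routine, round back to T.
template <typename T>
void scale(int n, T *x, double alpha) {
  static_assert(std::is_floating_point<T>::value, "T must be a floating-point type");
  assert(n >= 0);
  std::vector<double> tmp(x, x + n);                             // promote
  scale_f64(n, tmp.data(), alpha);
  for (int i = 0; i < n; ++i) x[i] = static_cast<T>(tmp[i]);     // demote
}

// Full specialisation: double needs no promotion and runs in place.
template <>
inline void scale<double>(int n, double *x, double alpha) {
  scale_f64(n, x, alpha);
}
```

The cost of the generic path is one temporary buffer and two conversion passes, which is why it is only worth taking where the single-precision routine is genuinely missing.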

CUDA

Before getting into CUDA-specific handler details, a note on why the CUDA code in pjrt is structured differently from a “natural” CUDA backend.

The library that we based the custom calls on is jaxlib, which ships separate wheels for CPU and CUDA: jaxlib is the base CPU wheel, while jax-cuda12-pjrt / jax-cuda12-plugin are CUDA-specific wheels built with --config=cuda and the CUDA toolkit on the build path. The CUDA wheels can include CUDA SDK (Software Development Kit) headers, link against the CUDA libraries, and use the real types directly.

R packages don’t have an equivalent mechanism. R CMD INSTALL builds one source tree against one Makevars; there is no “platform variant” the way Python wheels have CUDA extras. The same pjrt.so binary has to load and run on machines without the CUDA toolkit installed (and even without an NVIDIA GPU) – otherwise we’d break every CPU-only user. So pjrt’s CUDA path:

  • never #includes any CUDA SDK header,
  • types every CUDA value as opaque (void * for streams and handles, uintptr_t-aliased CUdeviceptr for device addresses) so we don’t depend on the SDK’s typedefs at compile time,
  • maintains a hand-rolled function-pointer table (CudaLibs in ffi_cuda.h) for the cuSOLVER + CUDA driver entry points we use,
  • loads those entry points lazily at runtime via dlopen / dlsym inside get_cuda_libs() (ffi_cuda.cpp), so the package loads cleanly on machines where libcusolver.so and libcuda.so.1 aren’t present (the loader just leaves the table flagged as unloaded, and the CUDA handlers return Error::Internal if invoked).

In effect, CudaLibs + get_cuda_libs() is doing the job of the linker, dynamically and by hand – the symbol lookup that a normal -lcusolver dependency would have the system loader perform at process startup, we instead perform from our own code at runtime. The trade-off is that every cuSOLVER / driver function we want to use has to be transcribed by hand twice: once as a typed function-pointer field in CudaLibs (whose signature must match NVIDIA’s exactly – the compiler can’t check it for us, since we don’t include the SDK headers that would tell it what to expect), and once as a string symbol name passed to dlsym in the loader.
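A miniature version of such a hand-rolled table might look like the sketch below. It is an illustration only: MathLibs and load_libs are hypothetical names, and the symbol looked up is the C library's cos rather than any CUDA entry point, so that the example builds on any POSIX system without the CUDA toolkit.

```cpp
#include <cassert>
#include <dlfcn.h>

// Hypothetical miniature of a lazily loaded function-pointer table.
// The field's signature is transcribed by hand and must match the library's.
struct MathLibs {
  bool loaded = false;
  double (*cos_fn)(double) = nullptr;  // must match `double cos(double)`
};

// Tries to resolve every entry point. On any failure the table is left
// flagged as unloaded instead of aborting, so callers can report an error
// at call time.
MathLibs load_libs(const char *path) {
  assert(path != nullptr);
  MathLibs libs;
  void *handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) return libs;  // library not installed
  libs.cos_fn = reinterpret_cast<double (*)(double)>(dlsym(handle, "cos"));
  if (libs.cos_fn == nullptr) {
    dlclose(handle);                   // symbol missing: give up cleanly
    return libs;
  }
  libs.loaded = true;                  // handle stays open for the process lifetime
  return libs;
}
```

Note how the symbol name "cos" and the pointer type are two independent transcriptions; nothing ties them together at compile time, which is exactly the fragility described above.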

In the future, we might move to a more robust mechanism that avoids these redefinitions, for example by including some CUDA headers (which might come with some licensing strings attached) or by generating the declarations from a DSL.

CUDA execution model

A few CUDA concepts come up in every CUDA handler, so we will briefly review them here.

Host vs. device. The CPU is the host and the GPU is the device. They have separate memory: a host pointer can’t be dereferenced on the device, and a device pointer can’t be dereferenced on the host. Moving data between them requires an explicit memcpy. Data can also be copied directly from device to device.

Streams. A CUDA stream is a FIFO (first in first out) queue of operations that the device executes. Operations enqueued on the same stream run in the order they were enqueued. Operations on different streams can run concurrently when the dependency graph allows it (e.g. a memcpy on one stream overlapping a kernel on another).

Asynchronous execution. Most CUDA driver and cuSOLVER calls – including the *Async memcpys, kernel launches, and the cuSOLVER factorisations – are non-blocking on the host. They enqueue work on the stream and return immediately, before the GPU has actually done anything. The host call’s return only signals “this is now queued”; the data isn’t real until the device gets to that point in the queue. This is the source of most CUDA-handler pitfalls: if you read a destination buffer right after issuing an async memcpy into it, you’ll see whatever was there before, because the copy is still queued.

Stream synchronisation. cuStreamSynchronize(stream) blocks the host until every operation previously enqueued on stream has finished. That’s the point at which “queued” turns into “actually done” from the host’s perspective. Inside an FFI handler, you should not call cuStreamSynchronize unless you genuinely need a host-visible result – syncing forces the host to wait for the GPU and breaks the async pipelining XLA depends on. The only place pjrt syncs today is print_tensor’s CUDA path, because it has to copy data back to the host before formatting it. The linalg handlers never sync: they enqueue their work on the stream and return, and XLA’s own scheduling makes sure any downstream op that reads their outputs runs after the kernel has actually completed.

Status codes. Every CUDA driver and cuSOLVER call returns an int status. 0 is success; anything else is an error and must be propagated up as Error::Internal. A dropped status doesn’t crash – it leaves the output buffer with garbage that surfaces much later as wrong numerical answers, with no error message anywhere. The PJRT_RETURN_IF_GPU_ERROR(expr, what) macro in ffi_cuda.h wraps each call site and propagates non-zero statuses uniformly.
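The early-return pattern behind such a macro can be sketched as follows. RETURN_IF_NONZERO is a hypothetical macro written for this illustration (the real PJRT_RETURN_IF_GPU_ERROR in ffi_cuda.h may be defined differently), Status is a minimal stand-in for xla::ffi::Error, and the fake_* functions stand in for driver calls.

```cpp
#include <cassert>
#include <string>

// Minimal stand-in for xla::ffi::Error, for illustration only.
struct Status {
  bool ok;
  std::string message;
};

// Hypothetical macro: evaluate `expr` exactly once and return early from
// the enclosing function if the resulting status is non-zero.
#define RETURN_IF_NONZERO(expr, what)                                   \
  do {                                                                  \
    const int status_ = (expr);                                         \
    if (status_ != 0) {                                                 \
      return Status{false, std::string(what) + " failed with status " + \
                               std::to_string(status_)};                \
    }                                                                   \
  } while (0)

// Fake driver calls returning an int status (0 = success).
inline int fake_memcpy_async() { return 0; }
inline int fake_kernel_launch(bool fail) { return fail ? 700 : 0; }

// Every call site is wrapped, so no status can be dropped silently.
Status run(bool fail) {
  RETURN_IF_NONZERO(fake_memcpy_async(), "memcpy");
  RETURN_IF_NONZERO(fake_kernel_launch(fail), "kernel launch");
  return Status{true, ""};
}
```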

Handler signature

A CUDA handler has the same shape as the CPU one, but two extra context parameters get bound at the front. From src/eigh_cuda.cpp:

static Error do_eigh_cuda(void *stream,
                          ScratchAllocator &scratch,
                          AnyBuffer input,
                          Result<AnyBuffer> v_out,
                          Result<AnyBuffer> w_out);

XLA_FFI_DEFINE_HANDLER(eigh_handler_cuda, do_eigh_cuda,
                       Ffi::Bind()
                           .Ctx<PlatformStream<void *>>()
                           .Ctx<ScratchAllocator>()
                           .Arg<AnyBuffer>()
                           .Ret<AnyBuffer>()
                           .Ret<AnyBuffer>());

The differences from CPU:

  • void *stream – bound via Ctx<PlatformStream<void *>>(), this is the CUDA stream XLA dispatched the op on. The FFI exposes exactly one stream per invocation – there is no API to ask for a transfer stream or any other – and the handler must enqueue all of its device-side work on it: kernel launches, cuSOLVER calls, memcpy_dtod / memcpy_dtoh / memset_d8, any synchronisation.
  • ScratchAllocator &scratch – use this for any device memory the handler needs that isn’t an FFI input or output: workspace for the library call, an input copy, devInfo, etc. Allocations go through XLA’s runtime allocator (the same one that backs FFI input/output buffers, so the bytes are visible to XLA’s memory accounting) and are freed automatically when the handler returns – the FFI binding’s destructor calls back into the runtime to free every allocation it made.
  • Windows. The handler symbol still has to be defined on Windows (so the [[Rcpp::export]] getter resolves), but the dispatcher returns Error(ErrorCode::kUnimplemented, ...) because pjrt has no CUDA support on Windows. See the #ifdef _WIN32 guards in src/eigh_cuda.cpp.

Using cuSOLVER

cuSOLVER is loaded lazily via dlopen (see the CUDA introduction above), so a CUDA handler that wants to call cuSOLVER goes through the CudaLibs function-pointer table and the Solver prologue defined in ffi_cuda.h. The kernel declares Solver solver(get_cuda_libs()); solver.begin(scratch, stream); and then invokes routines. begin() borrows a stream-bound handle from a pool (see the next paragraph) and allocates a devInfo int in device memory.

Because cuSOLVER/CUDA handles are expensive to create, they are cached in a process-wide singleton pool – a map of stream pointer to free-list of handles, protected by a mutex against concurrent host-side borrow / return. A kernel borrows a handle from its stream’s free-list (creating one the first time), uses it, and returns it on scope exit; handles live for the process and are reused across pjrt_execute calls. The design is adapted from jaxlib.
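The pool design described above can be sketched in plain C++. This is a hypothetical illustration, not pjrt's or jaxlib's code: Handle is an int standing in for an opaque cuSOLVER handle, "creating" a handle just hands out a fresh id, and the class and method names are invented here.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <vector>

using Handle = int;     // stands in for an opaque cuSOLVER handle
using Stream = void *;  // opaque stream pointer, used only as a map key

class HandlePool {
 public:
  // RAII guard: gives the handle back to its stream's free-list on scope exit.
  class Borrowed {
   public:
    Borrowed(HandlePool *pool, Stream stream, Handle handle)
        : pool_(pool), stream_(stream), handle_(handle) {}
    Borrowed(Borrowed &&other)
        : pool_(other.pool_), stream_(other.stream_), handle_(other.handle_) {
      other.pool_ = nullptr;  // the moved-from guard no longer owns the handle
    }
    Borrowed(const Borrowed &) = delete;
    Borrowed &operator=(const Borrowed &) = delete;
    ~Borrowed() {
      if (pool_ != nullptr) pool_->give_back(stream_, handle_);
    }
    Handle get() const {
      assert(pool_ != nullptr);
      return handle_;
    }

   private:
    HandlePool *pool_;
    Stream stream_;
    Handle handle_;
  };

  // Reuses a free handle for this stream, or creates one the first time.
  Borrowed borrow(Stream stream) {
    std::lock_guard<std::mutex> lock(mu_);
    std::vector<Handle> &free_list = free_[stream];
    Handle handle;
    if (free_list.empty()) {
      handle = next_id_++;  // placeholder for the expensive creation call
      ++created_;
    } else {
      handle = free_list.back();
      free_list.pop_back();
    }
    return Borrowed(this, stream, handle);
  }

  int created() const {
    std::lock_guard<std::mutex> lock(mu_);
    return created_;
  }

 private:
  void give_back(Stream stream, Handle handle) {
    std::lock_guard<std::mutex> lock(mu_);
    free_[stream].push_back(handle);
  }

  mutable std::mutex mu_;
  std::map<Stream, std::vector<Handle>> free_;
  Handle next_id_ = 0;
  int created_ = 0;
};
```

A process-wide pool would typically be exposed through a function-local static instance; the sketch leaves that out so the pool can be constructed and inspected directly.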

Most cuSOLVER factorisation routines follow the same query / allocate / run idiom as LAPACK (covered in the CPU section), but the workspace query is a separate function with a _bufferSize suffix (cusolverDnSgeqrf_bufferSize, etc.) rather than an lwork = -1 call as in LAPACK. Note that computing the buffer size is a cheap host-side computation.

cuSOLVER actually has two status mechanisms, for two different failure modes:

  • The return value of every routine is a host-side int returned synchronously when the call comes back, indicating whether the call was accepted and queued onto the stream. We always check this with PJRT_RETURN_IF_GPU_ERROR.
  • The devInfo int is a device-side int cuSOLVER’s kernel writes during execution, using the same convention as LAPACK’s info (0 = success, <0 = illegal argument at position |info|, >0 = routine-specific). The split exists because cuSOLVER is async on a stream: the host return covers pre-launch failures, but actual runtime / numerical status can only be reported after the kernel has run, and the kernel needs somewhere device-side to write it.

We pass devInfo because we have to, but we don’t read it, as reading it on the host would force a stream synchronisation. Again, we follow jaxlib in this decision.

Additional things to pay attention to

There are a few other things to be aware of when implementing custom calls.

Buffer aliasing. Stablehlo inherited XLA’s input-output aliasing. stablehlo.custom_call’s output_operand_aliases attribute lets XLA hand the handler the same pointer for an input and an output. Routines that overwrite their input in place will then read corrupted data. The fix: copy input → destination buffer first, then factor in place there, with a pointer-equality check so the copy is skipped when the buffers already alias (if (in != target) memcpy(...)). jaxlib uses the same pattern in every linalg kernel; pjrt’s maybe_promote_inplace helper bakes the guard in for the CPU side.
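A minimal host-side sketch of the guarded copy is shown below. factor_into is a hypothetical name, and the doubling loop is only a placeholder for a routine that factors its argument in place.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Copies `in` to `out` unless the two already alias, then runs an in-place
// routine on `out`. The buffers are assumed to be either identical or fully
// disjoint; partial overlap is not handled.
void factor_into(const double *in, double *out, std::size_t count) {
  assert(in != nullptr && out != nullptr);
  if (in != out) {
    std::memcpy(out, in, count * sizeof(double));  // skipped when aliased
  }
  for (std::size_t i = 0; i < count; ++i) {
    out[i] *= 2.0;  // placeholder for the in-place factorisation
  }
}
```

Without the pointer check, the aliased case would call memcpy with identical source and destination, which the C standard leaves undefined for overlapping ranges.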

Dimension type mismatch and byte-size overflow. Buffer dimensions arrive as int64_t from input.dimensions(), but LAPACK and cuSOLVER take int. A naive cast wraps silently for axes larger than INT_MAX – use the dim_to_int(dim, name, out) helper in ffi_common.h, which casts and returns Error::InvalidArgument on overflow. That covers each individual dimension, but not their product: even after dim_to_int narrows m and n to int, an expression like m * n * sizeof(T) still evaluates m * n in int arithmetic and overflows for large matrices (e.g. 50000 × 50000) before the result gets widened to size_t. Cast at least one operand to size_t before the multiplication so the whole expression evaluates in size_t: static_cast<size_t>(m) * n * sizeof(T).
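Both checks can be sketched in a few lines. checked_dim_to_int is a simplified, hypothetical stand-in for dim_to_int (it returns a bool rather than an xla::ffi::Error and, as an assumption of this sketch, also rejects negative values), and matrix_bytes is an illustrative helper name.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>

// Narrows a 64-bit dimension to int, reporting failure instead of wrapping.
bool checked_dim_to_int(std::int64_t dim, int *out) {
  if (dim < 0 || dim > INT_MAX) return false;
  *out = static_cast<int>(dim);
  return true;
}

// Byte size of an m x n matrix of T. Casting before multiplying makes the
// whole product evaluate in size_t rather than int.
template <typename T>
std::size_t matrix_bytes(int m, int n) {
  assert(m >= 0 && n >= 0);
  return static_cast<std::size_t>(m) * static_cast<std::size_t>(n) * sizeof(T);
}
```

For m = n = 50000 the element count is 2.5 × 10^9, which already exceeds INT_MAX (about 2.147 × 10^9), so the int product would overflow even before the multiplication by sizeof(T).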

scratch.Allocate(0). ScratchAllocator::Allocate(0) returns std::nullopt, which allocate_workspace<T> propagates as Error::Internal. Downstream use of the custom calls should prevent this.