
Main Classes

The {pjrt} package provides an R interface to PJRT. PJRT is part of the OpenXLA project, an open-source compiler ecosystem for machine learning and, more generally, Accelerated Linear Algebra (XLA). PJRT defines a standardized interface that serves as a portability layer, allowing a single framework to be used with different hardware backends: it specifies what a concrete backend needs to be able to do in order to be usable from a frontend (like JAX). Anyone can implement a PJRT plugin (PJRTPlugin) for a specific hardware backend. When using PJRT, such a plugin is loaded as a shared library and wrapped in a PJRT client (PJRTClient). This package currently supports CPU plugins, NVIDIA GPU plugins, and Metal plugins (Apple GPU; experimental).

When a plugin is created for the first time, its shared library is downloaded and cached.

library(pjrt)
plugin <- pjrt_plugin("cpu")
plugin
#> <PJRTPlugin:cpu>

Such a plugin can be used to create a client for a specific platform.

client <- plugin_client_create(plugin, "cpu")
client
#> <PJRTClient:cpu>

Currently, there is exactly one client per platform, and clients are stored in a global cache.

the <- getFromNamespace("the", "pjrt")
the[["clients"]][["cpu"]]
#> <PJRTClient:cpu>
the[["plugins"]][["cpu"]]
#> <PJRTPlugin:cpu>
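
Because of this cache, requesting a client for the same platform again should hand back the cached object. The following is a sketch of the expected behaviour, assuming that repeated calls to plugin_client_create() simply return the cached client.

identical(plugin_client_create(plugin, "cpu"), client)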

A client can have one or more devices. This is not relevant for CPU clients, but it is in principle useful for GPUs, although multi-device usage is not fully supported yet.

cpu_devices <- devices(client)
cpu_device <- cpu_devices[[1L]]
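
For the CPU platform this typically yields a single device; it is the same device we will see later when inspecting a buffer placed on it.

cpu_device
#> <CpuDevice(id=0)>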

The most important operations that a client supports are data handling, compilation and execution.

Data Handling

Using a client, we can move data from the host (standard R objects) to a specific device, creating a buffer (PJRTBuffer), also known as a tensor or array. A buffer has a type, which consists of the type of its (homogeneous) elements (PJRTElementType) and its shape. Scalars are represented as 0-dimensional buffers, and individual dimensions can be 0.

host_data <- 1:4
device_buffer <- pjrt_buffer(host_data, device = cpu_device, shape = c(2L, 2L), dtype = "f32")
device_buffer
#> PJRTBuffer 
#>  1.0000 3.0000
#>  2.0000 4.0000
#> [ CPUf32{2x2} ]
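
Scalars are 0-dimensional buffers. The following is only a sketch of how such a buffer could be created, assuming that an empty shape (shape = integer()) is accepted for a 0-dimensional buffer.

scalar_buffer <- pjrt_buffer(3.5, device = cpu_device, shape = integer(), dtype = "f32")
shape(scalar_buffer)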

After creation, we can inspect the shape and element type of the buffer.

shape(device_buffer)
#> [1] 2 2
elt_type(device_buffer)
#> <f32>
device(device_buffer)
#> <CpuDevice(id=0)>

We can also move the data back to the host from a buffer.

host_data_from_buffer <- as_array(device_buffer)
host_data_from_buffer
#>      [,1] [,2]
#> [1,]    1    3
#> [2,]    2    4

In principle, PJRT also allows moving data from one device to another, but this is not supported by the package yet.

Compilation

PJRT offers an interface to compile programs into executables. These programs can be written in HLO or the newer StableHLO. The {stablehlo} package makes it easy to create StableHLO programs in R. Below, we define a simple program that adds two f32 tensors of shape (2, 2).

src <- r"(
func.func @main(
  %x: tensor<2x2xf32>,
  %y: tensor<2x2xf32>
) -> tensor<2x2xf32> {
  %0 = "stablehlo.add"(%x, %y) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  "func.return"(%0): (tensor<2x2xf32>) -> ()
}
)"

We first create a program (PJRTProgram) object from the source code.

program <- pjrt_program(src, format = "mlir")
program
#> PJRTProgram(format=mlir, code_size=221)
#> 
#> func.func @main(
#>   %x: tensor<2x2xf32>,
#>   %y: tensor<2x2xf32>
#> ) -> tensor<2x2xf32> {
#> ...

We can't do much with a program object other than compile it into an executable (PJRTLoadedExecutable). For compilation, we need to specify the client to use, because the executable is platform-specific: a GPU binary differs from a CPU binary.

executable <- pjrt_compile(program, client = client)
executable
#> <PJRTLoadedExecutable>

Execution

With a buffer and an executable in place, we can execute the program. Here we simply pass the same buffer twice as input. By default, a single output is unpacked into a buffer instead of a list of buffers; this can be disabled by setting simplify = FALSE (sketched below).

result <- pjrt_execute(executable, device_buffer, device_buffer)
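
Converting the result back to the host shows the element-wise sum (the same values reappear in the Serialization section below).

as_array(result)
#>      [,1] [,2]
#> [1,]    2    6
#> [2,]    4    8

As mentioned above, setting simplify = FALSE keeps the result as a list of buffers. A sketch of such a call, assuming the single output then sits in the first list element:

result_list <- pjrt_execute(executable, device_buffer, device_buffer, simplify = FALSE)
result_list[[1L]]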

We could also have passed execution options (PJRTExecuteOptions) to the call.

Serialization

We support reading and writing buffers using the safetensors format, which allows storing named lists of buffers.

tmp <- tempfile(fileext = ".safetensors")
safetensors::safe_save_file(list(x = result), tmp, framework = "pjrt")
reloaded <- safetensors::safe_load_file(tmp, framework = "pjrt")
reloaded$x
#> PJRTBuffer 
#>  2.0000 6.0000
#>  4.0000 8.0000
#> [ CPUf32{2x2} ]
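
The reloaded buffer behaves like any other buffer, so it can, for example, be fed back into the executable compiled above.

doubled_again <- pjrt_execute(executable, reloaded$x, reloaded$x)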