
The {pjrt} package provides an R interface to PJRT (Pretty much Just another RunTime), which allows you to run XLA and stableHLO programs on various hardware backends. These programs are framework- and hardware-agnostic: they can be generated by ML frameworks such as JAX and then run by PJRT on a chosen backend (CPU, GPU, etc.). For a low-level R interface for creating stableHLO programs, see the {stablehlo} package.

Installation

From GitHub:

pak::pak("r-xla/pjrt")

You can also install from r-universe by adding the code below to your .Rprofile and then running install.packages("pjrt") in a new session.

options(repos = c(
  rxla = "https://r-xla.r-universe.dev",
  CRAN = "https://cloud.r-project.org/"
))

Quick Start

Below, we create and run a stableHLO program that adds two f32 tensors of shape (2, 2).

library(pjrt)
src <- r"(
func.func @main(
  %x: tensor<2x2xf32>,
  %y: tensor<2x2xf32>
) -> tensor<2x2xf32> {
  %0 = "stablehlo.add"(%x, %y) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  "func.return"(%0): (tensor<2x2xf32>) -> ()
}
)"
program <- pjrt_program(src, format = "mlir")
program
#> PJRTProgram(format=mlir, code_size=221)
#> 
#> func.func @main(
#>   %x: tensor<2x2xf32>,
#>   %y: tensor<2x2xf32>
#> ) -> tensor<2x2xf32> {
#> ...
executable <- pjrt_compile(program, client = "cpu")

x <- pjrt_buffer(c(1, 2, 3, 4), shape = c(2, 2), dtype = "f32")
x
#> PJRTBuffer 
#>  1.0000 3.0000
#>  2.0000 4.0000
#> [ CPUf32{2x2} ]
y <- pjrt_buffer(c(5, 6, 7, 8), shape = c(2, 2), dtype = "f32")
y
#> PJRTBuffer 
#>  5.0000 7.0000
#>  6.0000 8.0000
#> [ CPUf32{2x2} ]

pjrt_execute(executable, x, y)
#> PJRTBuffer 
#>   6.0000 10.0000
#>   8.0000 12.0000
#> [ CPUf32{2x2} ]
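The printed buffers suggest that pjrt_buffer() fills the requested shape in column-major order, the same way base R's matrix() does. This is an inference from the output above, not a documented guarantee. Under that reading, the Quick Start result can be cross-checked in plain R without {pjrt}:

```r
# Reference computation in base R (no {pjrt} needed).
# Assumption: pjrt_buffer() uses the same column-major fill as matrix(),
# which is what the printed buffers above suggest.
x_r <- matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)
y_r <- matrix(c(5, 6, 7, 8), nrow = 2, ncol = 2)
z_r <- x_r + y_r
z_r
#>      [,1] [,2]
#> [1,]    6   10
#> [2,]    8   12
```

The values match the buffer returned by pjrt_execute(), which makes this a convenient way to sanity-check small programs before moving to larger shapes.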

Main Features

  • Compile stableHLO programs into hardware-specific executables.
  • Provide a runtime to execute compiled programs.
  • Convert buffers to and from R arrays and vectors.
  • Read and write buffers using the safetensors format.
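Because the compile step takes stableHLO as plain text, small element-wise programs like the one in the Quick Start can be templated from R. The helper below, binary_op_src(), is hypothetical and not part of {pjrt}; it is a minimal sketch that assumes both inputs share one shape and dtype and that the op is a binary element-wise stableHLO op such as stablehlo.add or stablehlo.multiply.

```r
# Hypothetical helper (not part of {pjrt}): builds stableHLO source for an
# element-wise binary op on two tensors of the same shape and dtype,
# following the layout of the Quick Start program.
binary_op_src <- function(op, shape, dtype = "f32") {
  # stableHLO tensor types look like tensor<2x2xf32>; a scalar is tensor<f32>
  dims <- if (length(shape) > 0) paste0(paste(as.integer(shape), collapse = "x"), "x") else ""
  ttype <- paste0("tensor<", dims, dtype, ">")
  paste0(
    "func.func @main(\n",
    "  %x: ", ttype, ",\n",
    "  %y: ", ttype, "\n",
    ") -> ", ttype, " {\n",
    "  %0 = \"", op, "\"(%x, %y) : (", ttype, ", ", ttype, ") -> ", ttype, "\n",
    "  \"func.return\"(%0): (", ttype, ") -> ()\n",
    "}\n"
  )
}

# Element-wise multiplication of two 3x3 f32 tensors
mul_src <- binary_op_src("stablehlo.multiply", shape = c(3, 3))
cat(mul_src)
```

The returned string can be passed to pjrt_program(mul_src, format = "mlir") and compiled with pjrt_compile() exactly as in the Quick Start. Calling binary_op_src("stablehlo.add", c(2, 2)) reproduces the Quick Start program text, apart from the leading and trailing newlines.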

Platform Support

  • Linux
    • CPU backend is fully supported.
    • CUDA (NVIDIA GPU) backend is fully supported.
  • Windows
    • CPU backend is fully supported.
    • ⚠️ GPU is only supported via Windows Subsystem for Linux (WSL2).
  • macOS
    • CPU backend is supported.
    • ⚠️ Metal (Apple GPU) backend is available but not fully functional.

Acknowledgements

  • The development of this package is supported by MaRDI.
  • Without OpenXLA, none of this would be possible.
  • The design of the {pjrt} package was inspired by the gopjrt implementation.
  • The project also uses various components from OpenXLA:
  • For Metal, we are using the plugin implementation from jax-metal.