Skip to contents

This guide explains how to implement a new primitive. It will primarily focus on how to do this. See the internals vignette for more information on how primitives work.

In general, there are two main reasons to add a new primitive:

  1. The operation is not expressible using existing primitives.
  2. The operation (or, for example, its derivative) would benefit from a dedicated implementation.

In order to be able to add a new primitive, it needs to be expressible in the {stablehlo} package. There are various scenarios:

  1. The operation is expressible in the StableHLO language and:
    1. All required operations are already implemented in the {stablehlo} R package.

      \(\rightarrow\) This is the simplest case and the one we assume in this guide.

    2. One or more required operations are not already implemented in the {stablehlo} R package.

      \(\rightarrow\) You first need to implement the missing operations in the {stablehlo} R package. See this issue for which operations are missing and the StableHLO specification for which are available.

  2. The operation is not available/cannot be efficiently expressed in StableHLO, because:
    1. It requires shape dynamism (output shape depends on data, such as using boolean values for subsetting):

      \(\rightarrow\) This is currently not possible, but we hope we can add support for this in the future, e.g. via a second, dynamic Fortran backend.

    2. It cannot be expressed or can only expressed inefficiently:

      \(\rightarrow\) You can implement a stableHLO custom call, see the custom print operation in pjrt. This is currently not well documented, so you need to dig into the source code.

Adding a Primitive: Practical Example

Let’s add a new primitive step by step. We’ll implement prim_repeat_along – a primitive that repeats an array multiple times along a specified dimension.

For example, repeating c(1, 2, 3) twice along dimension 1 gives c(1, 2, 3, 1, 2, 3).

This primitive has a dynamic input (an array) and two static parameters (how many times to repeat and which dimension).

Step 1: Define the Primitive

Primitives are created with new_primitive(): it builds the AnvlPrimitive metadata object that holds the rules, wraps the body with jit(), attaches the metadata, and registers the result in the internal primitive registry. The returned callable becomes the primitive and is bound to a prim_<name> R symbol.

library(anvl)
prim_repeat_along <- new_primitive(
  "repeat_along",
  function(operand, times, dim) {
    # type of operand is checked by graph_desc_add()
    infer_fn <- function(operand, times, dim) {
      if (!checkmate::test_integerish(dim, lower = 1, upper = ndims(operand), len = 1L)) {
        cli::cli_abort("{.arg dim} must be between 1 and {ndims(operand)}, but is {.val dim}")
      }
      if (!checkmate::test_integerish(times, lower = 1, len = 1L)) {
        cli_abort("times must be a positive integer, but is {times}")
      }
      new_shape <- shape(operand)
      new_shape[dim] <- new_shape[dim] * times
      list(AbstractArray(
        dtype = dtype(operand),
        shape = Shape(new_shape),
        ambiguous = operand$ambiguous
      ))
    }

    graph_desc_add(
      self,                       # lexically bound to the AnvlPrimitive
      list(operand = operand),    # Dynamic inputs (arrays)
      params = list(              # Static parameters
        times = times,
        dim = dim
      ),
      infer_fn = infer_fn
    )[[1L]]  # Extract single output from list
  },
  static = c("times", "dim")
)

The primitive is now callable directly as prim_repeat_along(x, times, dim).

Key points:

  • Pass the lexically-bound self (the [AnvlPrimitive]) as the first argument to graph_desc_add(). new_primitive() installs self into the enclosing environment of the body, so this is always available inside a primitive.
  • The inference function receives abstract arrays (types, not values) for dynamic inputs, and actual values for static parameters. It must verify that the arguments to the function are valid, and should throw clear error messages — {anvl} programs are otherwise hard to debug.
  • graph_desc_add() returns a list of outputs; use [[1L]] for single-output primitives.
  • Propagate the ambiguous flag from inputs to outputs, see type promotion for what this means.

Every argument that is not a dynamic array must be listed in static. new_primitive calls jit() with backend = "auto" internally, which defers the choice of backend to call time.

If the primitive is a wrapper around a stablehlo operation, it is possible to use the corresponding inference function from the stablehlo package (such as stablehlo::infer_types_concatenate). When doing so, you need to:

  1. Convert the abstract arrays to stablehlo ValueTypes using at2vt().
  2. Call the stablehlo inference function and obtain a list of ValueTypes.
  3. Convert the ValueTypes back to abstract arrays using vt2at().
  4. Set the ambiguous flag of the output depending on the inputs (ambiguity is strictly an {anvl} concept, not a stablehlo concept).

Special case: primitives with 0 dynamic inputs

Array constructors (e.g. prim_fill, prim_iota) have no dynamic array inputs – every argument is static. Two things change in this case:

  1. All formals must be listed in static. Otherwise {anvl} would try to interpret a scalar argument like value or shape as an AnvlArray input.
  2. backend = "auto" cannot infer a device from inputs, because there are no array inputs. Instead, add a device formal to the function and pass device = device_arg("device") to new_primitive(). At call time, the user-supplied device is read from that argument and used to determine both the backend (via [backend()] dispatch) and the compilation device.

For example, the real definition of prim_fill looks like this:

prim_fill <- new_primitive(
  "fill",
  function(value, shape, dtype, ambiguous = FALSE, device = NULL) {
    # ... graph_desc_add(self, ...) ...
  },
  static = 1:5,
  device = device_arg("device")
)

See [device_arg()] for the detailed semantics.

Shortcut helpers for common shapes

prim_repeat_along has its own parameters and needs a custom body. Many primitives are simpler – elementwise unary or binary ops, reductions, and comparisons all share a standard shape. R/primitives.R exposes factories that generate the body for you:

  • make_unary_op(stablehlo_infer) – elementwise unary (e.g. prim_abs, prim_negate).
  • make_binary_op(stablehlo_infer) – elementwise binary (e.g. prim_add, prim_mul).
  • make_reduce_op(infer_fn) – reductions with dims / drop parameters (e.g. prim_reduce_sum).
  • make_compare_op(direction) – comparison ops with a fixed direction string (e.g. prim_eq, prim_lt).

Used together with the corresponding stablehlo inference function, the primitive definition collapses to a one-liner:

prim_add <- new_primitive("add", make_binary_op(stablehlo::infer_types_add))
prim_negate <- new_primitive("negate", make_unary_op(stablehlo::infer_types_negate))
prim_reduce_sum <- new_primitive("reduce_sum", make_reduce_op(), static = 2:3)

Reach for the manual graph_desc_add() form when the primitive has extra static parameters, a custom inference function, or multiple outputs.

Step 2: Add the StableHLO Rule

The StableHLO rule defines how to lower this primitive to actual operations, which is used in the stablehlo() lowering pass. We implement repeat_along using concatenation:

prim_repeat_along[["stablehlo"]] <- function(operand, times, dim) {
  operands <- rep(list(operand), times)
  list(rlang::exec(stablehlo::hlo_concatenate, !!!operands, dimension = dim - 1L))
}

The rule receives:

It must return a list of stablehlo::FuncValues, even if there is only one output.

Important: StableHLO uses 0-based indexing, while {anvl} uses R’s 1-based indexing. Always convert dimension indices by subtracting 1. Also note that in graph_desc_add() we are converting the error messages from stablehlo to our 1-based indexing, so you do not have to worry about that here.

Step 3: Add the Reverse Rule

If the operation should support automatic differentiation, add a reverse rule. The idea here is the following, where we assume the input operand has shape (s_1, ..., s_n), which means that the output (and therefore it’s gradient) has shape (s_1, ..., s_{dim-1}, s_dim * times, s_{dim+1}, ..., s_n).

  1. Reshape the gradient to (s_1, ..., s_{dim-1}, s_dim, times, s_{dim+1}, ..., s_n).
  2. Sum over the times dimension and drop the times dimension.
prim_repeat_along[["reverse"]] <- function(inputs, outputs, grads, dim, times, .required) {
  if (!.required[[1L]]) {
    return(list(NULL))
  }

  grad <- grads[[1L]]
  operand <- inputs[[1L]]

  old_shape <- shape(operand)
  grad_shape <- shape(grad)

  new_shape <- grad_shape
  new_shape[dim] <- old_shape[dim]
  new_shape <- append(new_shape, times, after = dim - 1L)

  grad_reshaped <- prim_reshape(grad, new_shape)
  grad_summed <- prim_reduce_sum(grad_reshaped, dims = dim, drop = TRUE)
  list(grad_summed)
}

The reverse rule receives:

  • inputs: Input GraphValues from the forward pass
  • outputs: Output GraphValues from the forward pass
  • grads: Gradients flowing back from downstream (one per output)
  • Static parameters by name (here: dim, times)
  • .required: Logical vector indicating which input gradients are needed

It returns a list with one gradient per input (or NULL if not required).

Optional: a quickr rule

prim_<name>[["stablehlo"]] and prim_<name>[["reverse"]] are the two rules you almost always want. A primitive can optionally also carry a quickr rule, which lowers it to plain R code for the quickr backend (see R/rules-quickr.R). The quickr rule is only required if you want the primitive to run under local_backend("quickr"); if you skip it, the primitive will still work on the xla backend.

Step 4: Verify the Registration

new_primitive() returns the callable directly and also registers the primitive in the internal registry used by the graph machinery, so no separate registration step is needed:

prim_repeat_along
#> function (operand, times, dim) 
#> {
#>     if (currently_tracing()) {
#>         cl <- match.call()
#>         cl[[1L]] <- f
#>         return(eval.parent(cl))
#>     }
#>     args <- lapply(as.list(match.call())[-1L], eval, envir = parent.frame())
#>     be <- if (!is.null(device_argname) && !is.null(args[[device_argname]])) {
#>         dev_val <- args[[device_argname]]
#>         if (is.character(dev_val)) 
#>             default_backend()
#>         else backend(dev_val)
#>     }
#>     else {
#>         jit_auto_detect_backend(flatten(args[!names(args) %in% 
#>             static]))
#>     }
#>     if (is.null(jit_fns[[be]])) {
#>         jit_fns[[be]] <<- do.call(jit_with_backend, c(list(f = f, 
#>             static = static, cache_size = cache_size, backend = be), 
#>             if (!is.null(device_argname)) {
#>                 list(device = device_arg(device_argname))
#>             } else if (!is.null(device)) {
#>                 list(device = device)
#>             }, dots))
#>     }
#>     do.call(jit_fns[[be]], args)
#> }
#> <environment: 0x5620b69b2548>
#> attr(,"class")
#> [1] "JitPrimitive" "JitFunction" 
#> attr(,"backend")
#> [1] "auto"
#> attr(,"primitive")
#> <AnvlPrimitive:repeat_along>

Step 5: Add an nv_ API Function

In {anvl}, we also offer convenience wrappers around the primitives. An example is prim_add vs nv_add, where the latter calls into the former after optionally broadcasting (scalar) inputs:

nv_add(1L, nv_array(2:3))
#> AnvlArray
#>  3
#>  4
#> [ CPUi32{2} ]
prim_add(1L, nv_array(2:3))
#> Error in `prim_add()`:
#> ! `lhs` and `rhs` must have the same tensor type.
#>  Got tensor<i32> and tensor<2xi32>.

In our case, no such convenience is needed and the functionality is not too low-level (for it to be generally useful), so we can just reassign the prim_* function to an nv_* function:

nv_repeat_along <- prim_repeat_along

Note that in the nv_* wrapper function, you can only access certain properties of the input arrayish values via:

If you, for example, use shape() instead of shape_abstract(), your function won’t work with R literals. I.e., <extract>_abstract() first converts the input to an AbstractArray (if possible) and then extracts the property.

Using Your Primitive

You can now use the primitive, both eagerly and inside a larger JIT-compiled function:

x <- nv_array(c(1, 2, 3), shape = c(3, 1))
# Eager call -- works because prim_repeat_along is itself jit-compiled.
prim_repeat_along(x, times = 2L, dim = 2L)
#> AnvlArray
#>  1 1
#>  2 2
#>  3 3
#> [ CPUf32{3,2} ]
# Traced into the outer graph when composed with another jit-compiled function.
jit(function(x) prim_repeat_along(x, times = 2L, dim = 2L))(x)
#> AnvlArray
#>  1 1
#>  2 2
#>  3 3
#> [ CPUf32{3,2} ]

And compute gradients through it.

f <- function(x) {
  repeated <- prim_repeat_along(x, times = 2L, dim = 2L)
  sum(repeated)
}

grad_f <- jit(gradient(f))
grad_f(x)[[1L]]
#> AnvlArray
#>  2
#>  2
#>  2
#> [ CPUf32{3,1} ]

Contributing to the Package

If you want to contribute your primitive to {anvl}, there are some additional things to be aware of.

File Organization

  • R/primitives.R: Define the prim_* primitive via new_primitive()
  • R/rules-stablehlo.R: Add the StableHLO lowering rule
  • R/rules-reverse.R: Add the reverse rule (if differentiable)
  • R/rules-quickr.R: Add the quickr lowering rule (optional; only if the primitive should run on the quickr backend)
  • R/api.R: Add the nv_* wrapper function (or possibly in another api file).

Testing

Tests can go in two places:

  1. inst/extra-tests/: For tests that compare against torch. These live in inst/ to avoid listing torch as a dependency.
  2. tests/testthat/: For tests without a torch counterpart.

Important: the describe() / test_that() label must contain the full primitive name, e.g. describe("prim_repeat_along", { ... }). The meta tests in tests/testthat/test-primitives-meta.R verify that every primitive has corresponding stablehlo and reverse tests, and flag any that are missing.

Since no torch counterpart exists for prim_repeat_along, we would add manual tests in:

  • tests/testthat/test-primitives-stablehlo.R
  • tests/testthat/test-primitives-reverse.R

Also, ensure that no linter errors are present, devtools::check() passes, and format the code using make format.

Higher-Order Primitives

Higher-Order Primitives are primitives that parameterized by an R function or expression. Examples include prim_if and prim_while. These are generally much more complex to handle, so we don’t cover them here in detail (for now). The general idea, however, is that the primitive prim_* function needs to trace the provided function using trace_fn() and then forward this graph to the stablehlo lowering rule and reverse rule.