This guide explains how to implement a new primitive, focusing on the practical steps. See the internals vignette for more information on how primitives work.
In general, there are two main reasons to add a new primitive:
- The operation is not expressible using existing primitives.
- The operation (or, for example, its derivative) would benefit from a dedicated implementation.
To add a new primitive, its operation must be expressible via the {stablehlo} package. There are various scenarios:
- The operation is expressible in the StableHLO language and:
  - All required operations are already implemented in the {stablehlo} R package.
    \(\rightarrow\) This is the simplest case and the one we assume in this guide.
  - One or more required operations are not yet implemented in the {stablehlo} R package.
    \(\rightarrow\) You first need to implement the missing operations in the {stablehlo} R package. See this issue for which operations are missing and the StableHLO specification for which are available.
- The operation is not available in StableHLO, or cannot be expressed efficiently in it, because:
  - It requires shape dynamism (the output shape depends on data, such as when subsetting with boolean values):
    \(\rightarrow\) This is currently not possible, but we hope we can add support for this in the future, e.g. via a second, dynamic Fortran backend.
  - It cannot be expressed, or can only be expressed inefficiently:
    \(\rightarrow\) You can implement a StableHLO custom call; see the custom print operation in pjrt. This is currently not well documented, so you need to dig into the source code.
Adding a Primitive: Practical Example
Let’s add a new primitive step by step. We’ll implement
prim_repeat_along – a primitive that repeats an array
multiple times along a specified dimension.
For example, repeating c(1, 2, 3) twice along dimension
1 gives c(1, 2, 3, 1, 2, 3).
This primitive has a dynamic input (an array) and two static parameters (how many times to repeat and which dimension).
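To pin down the intended semantics before touching {anvl}, the operation can be sketched in base R. The function name `repeat_along_ref` is hypothetical and not part of {anvl}; it only serves as a reference for what the primitive is expected to compute.

```r
# Hypothetical base R reference (not part of {anvl}): repeat an array
# `times` times along dimension `dim` by indexing with repeated indices.
repeat_along_ref <- function(x, times, dim) {
  if (is.null(base::dim(x))) {
    x <- array(x, dim = length(x)) # treat plain vectors as 1-d arrays
  }
  shape <- base::dim(x)
  stopifnot(dim >= 1, dim <= length(shape), times >= 1)
  # One index vector per dimension; only the repeated dimension changes.
  idx <- lapply(shape, seq_len)
  idx[[dim]] <- rep(seq_len(shape[dim]), times = times)
  do.call(`[`, c(list(x), idx, list(drop = FALSE)))
}

as.vector(repeat_along_ref(c(1, 2, 3), times = 2, dim = 1)) # 1 2 3 1 2 3

# A 3 x 1 matrix repeated twice along dimension 2 becomes a 3 x 2 matrix
# whose two columns both equal c(1, 2, 3).
m <- matrix(c(1, 2, 3), nrow = 3, ncol = 1)
repeat_along_ref(m, times = 2, dim = 2)
```

A reference like this is also handy later as an oracle when writing tests for the primitive.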
Step 1: Define the Primitive
Primitives are created with new_primitive(): it builds
the AnvlPrimitive metadata object that holds the rules,
wraps the body with jit(), attaches the metadata, and
registers the result in the internal primitive registry. The returned
callable becomes the primitive and is bound to a
prim_<name> R symbol.
library(anvl)
prim_repeat_along <- new_primitive(
  "repeat_along",
  function(operand, times, dim) {
    # type of operand is checked by graph_desc_add()
    infer_fn <- function(operand, times, dim) {
      if (!checkmate::test_integerish(dim, lower = 1, upper = ndims(operand), len = 1L)) {
        cli::cli_abort("{.arg dim} must be between 1 and {ndims(operand)}, but is {.val {dim}}")
      }
      if (!checkmate::test_integerish(times, lower = 1, len = 1L)) {
        cli::cli_abort("{.arg times} must be a positive integer, but is {.val {times}}")
      }
      new_shape <- shape(operand)
      new_shape[dim] <- new_shape[dim] * times
      list(AbstractArray(
        dtype = dtype(operand),
        shape = Shape(new_shape),
        ambiguous = operand$ambiguous
      ))
    }
    graph_desc_add(
      self,                    # lexically bound to the AnvlPrimitive
      list(operand = operand), # Dynamic inputs (arrays)
      params = list(           # Static parameters
        times = times,
        dim = dim
      ),
      infer_fn = infer_fn
    )[[1L]] # Extract single output from list
  },
  static = c("times", "dim")
)
The primitive is now callable directly as
prim_repeat_along(x, times, dim).
Key points:
- Pass the lexically bound self (the AnvlPrimitive) as the first argument to graph_desc_add(). new_primitive() installs self into the enclosing environment of the body, so it is always available inside a primitive.
- The inference function receives abstract arrays (types, not values) for dynamic inputs, and actual values for static parameters. It must verify that the arguments to the function are valid, and should throw clear error messages, because {anvl} programs are otherwise hard to debug.
- graph_desc_add() returns a list of outputs; use [[1L]] for single-output primitives.
- Propagate the ambiguous flag from inputs to outputs; see type promotion for what this means.
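The shape logic of the inference function can be tried out in isolation. The following is a minimal base R sketch in which a plain integer vector stands in for the abstract array; `infer_repeat_shape` is a hypothetical name, and the checks only approximate what the checkmate calls in the real inference function do.

```r
# Hypothetical stand-in for the shape part of infer_fn: validate the static
# parameters, then scale the size of the repeated dimension.
infer_repeat_shape <- function(shape, times, dim) {
  if (length(dim) != 1L || is.na(dim) || dim != round(dim) ||
      dim < 1 || dim > length(shape)) {
    stop(sprintf("`dim` must be between 1 and %d, but is %s",
                 length(shape), toString(dim)))
  }
  if (length(times) != 1L || is.na(times) || times != round(times) || times < 1) {
    stop(sprintf("`times` must be a positive integer, but is %s", toString(times)))
  }
  shape[dim] <- shape[dim] * times
  as.integer(shape)
}

infer_repeat_shape(c(3L, 1L), times = 2L, dim = 2L) # 3 2
```

Validating before computing the new shape means an invalid `dim` produces a readable message rather than a confusing downstream failure.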
Every argument that is not a dynamic array must be listed in
static. new_primitive() calls jit()
with backend = "auto" internally, which defers the choice
of backend to call time.
If the primitive is a wrapper around a stablehlo operation, it is
possible to use the corresponding inference function from the stablehlo
package (such as stablehlo::infer_types_concatenate). When
doing so, you need to:
- Convert the abstract arrays to stablehlo ValueTypes using at2vt().
- Call the stablehlo inference function and obtain a list of ValueTypes.
- Convert the ValueTypes back to abstract arrays using vt2at().
- Set the ambiguous flag of the output depending on the inputs (ambiguity is strictly an {anvl} concept, not a stablehlo concept).
Special case: primitives with 0 dynamic inputs
Array constructors (e.g. prim_fill,
prim_iota) have no dynamic array inputs – every argument is
static. Two things change in this case:
- All formals must be listed in static. Otherwise {anvl} would try to interpret a scalar argument like value or shape as an AnvlArray input.
- backend = "auto" cannot infer a device from inputs, because there are no array inputs. Instead, add a device formal to the function and pass device = device_arg("device") to new_primitive(). At call time, the user-supplied device is read from that argument and used to determine both the backend (via backend() dispatch) and the compilation device.
For example, the real definition of prim_fill looks like
this:
prim_fill <- new_primitive(
  "fill",
  function(value, shape, dtype, ambiguous = FALSE, device = NULL) {
    # ... graph_desc_add(self, ...) ...
  },
  static = 1:5,
  device = device_arg("device")
)
See device_arg() for the detailed semantics.
Shortcut helpers for common shapes
prim_repeat_along has its own parameters and needs a
custom body. Many primitives are simpler – elementwise unary or binary
ops, reductions, and comparisons all share a standard shape.
R/primitives.R exposes factories that generate the body for
you:
- make_unary_op(stablehlo_infer) – elementwise unary (e.g. prim_abs, prim_negate).
- make_binary_op(stablehlo_infer) – elementwise binary (e.g. prim_add, prim_mul).
- make_reduce_op(infer_fn) – reductions with dims/drop parameters (e.g. prim_reduce_sum).
- make_compare_op(direction) – comparison ops with a fixed direction string (e.g. prim_eq, prim_lt).
Used together with the corresponding stablehlo inference function, the primitive definition collapses to a one-liner:
prim_add <- new_primitive("add", make_binary_op(stablehlo::infer_types_add))
prim_negate <- new_primitive("negate", make_unary_op(stablehlo::infer_types_negate))
prim_reduce_sum <- new_primitive("reduce_sum", make_reduce_op(), static = 2:3)
Reach for the manual graph_desc_add() form when the
primitive has extra static parameters, a custom inference function, or
multiple outputs.
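The factories above are ordinary function factories: each takes an inference function and returns a body that closes over it. The base R sketch below illustrates that pattern only. `make_binary_body` and `infer_same_shape` are hypothetical names, shapes are plain integer vectors, and none of this is the actual {anvl} implementation.

```r
# Hypothetical factory (not the {anvl} implementation): wrap a shape
# inference function into a body for a two-argument elementwise op.
make_binary_body <- function(infer_shape) {
  force(infer_shape) # evaluate now so the closure does not depend on lazy state
  function(lhs, rhs) {
    # Operands are described only by their shapes in this sketch.
    list(shape = infer_shape(lhs, rhs))
  }
}

# Elementwise ops require identical shapes and return that shape.
infer_same_shape <- function(lhs, rhs) {
  if (!identical(lhs, rhs)) {
    stop("`lhs` and `rhs` must have the same shape")
  }
  lhs
}

# Two ops share one body definition and differ only in what is passed in.
add_body <- make_binary_body(infer_same_shape)
mul_body <- make_binary_body(infer_same_shape)

add_body(c(2L, 3L), c(2L, 3L))$shape # 2 3
```

The benefit of this design is that validation and bookkeeping live in one place, so a new elementwise op only has to supply its inference function.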
Step 2: Add the StableHLO Rule
The StableHLO rule defines how to lower this primitive to actual
operations, which is used in the stablehlo() lowering pass.
We implement repeat_along using concatenation:
prim_repeat_along[["stablehlo"]] <- function(operand, times, dim) {
  operands <- rep(list(operand), times)
  list(rlang::exec(stablehlo::hlo_concatenate, !!!operands, dimension = dim - 1L))
}
The rule receives:
- Dynamic inputs as stablehlo::FuncValues.
- Static parameters as their R values.
It must return a list of stablehlo::FuncValues, even if
there is only one output.
Important: StableHLO uses 0-based indexing, while
{anvl} uses R’s 1-based indexing. Always convert dimension indices by
subtracting 1. Also note that in graph_desc_add() we are
converting the error messages from stablehlo to our 1-based indexing, so
you do not have to worry about that here.
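The index conversion is easiest to get right when it happens exactly once, at the boundary between the 1-based user-facing function and the 0-based operation. The base R sketch below mimics this for matrices only; `concat_zero_based` and `repeat_matrix` are hypothetical helpers and do not call {stablehlo}.

```r
# Hypothetical stand-in for a concatenate op that, like StableHLO, numbers
# dimensions from 0. Only matrices are handled in this sketch.
concat_zero_based <- function(operands, dimension) {
  stopifnot(dimension %in% c(0L, 1L))
  if (dimension == 0L) do.call(rbind, operands) else do.call(cbind, operands)
}

# The user-facing function takes R's 1-based `dim` and subtracts 1 once.
repeat_matrix <- function(x, times, dim) {
  operands <- rep(list(x), times)
  concat_zero_based(operands, dimension = dim - 1L)
}

x <- matrix(c(1, 2, 3), nrow = 3, ncol = 1)
dim(repeat_matrix(x, times = 2L, dim = 2L)) # 3 2
dim(repeat_matrix(x, times = 2L, dim = 1L)) # 6 1
```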
Step 3: Add the Reverse Rule
If the operation should support automatic differentiation, add a
reverse rule. The idea here is the following, where we assume the input
operand has shape (s_1, ..., s_n), which means
that the output (and therefore it’s gradient) has shape
(s_1, ..., s_{dim-1}, s_dim * times, s_{dim+1}, ..., s_n).
- Reshape the gradient to
(s_1, ..., s_{dim-1}, s_dim, times, s_{dim+1}, ..., s_n). - Sum over the
timesdimension and drop thetimesdimension.
prim_repeat_along[["reverse"]] <- function(inputs, outputs, grads, dim, times, .required) {
  if (!.required[[1L]]) {
    return(list(NULL))
  }
  grad <- grads[[1L]]
  operand <- inputs[[1L]]
  old_shape <- shape(operand)
  grad_shape <- shape(grad)
  new_shape <- grad_shape
  new_shape[dim] <- old_shape[dim]
  new_shape <- append(new_shape, times, after = dim - 1L)
  grad_reshaped <- prim_reshape(grad, new_shape)
  grad_summed <- prim_reduce_sum(grad_reshaped, dims = dim, drop = TRUE)
  list(grad_summed)
}
The reverse rule receives:
- inputs: Input GraphValues from the forward pass
- outputs: Output GraphValues from the forward pass
- grads: Gradients flowing back from downstream (one per output)
- Static parameters by name (here: dim, times)
- .required: Logical vector indicating which input gradients are needed
It returns a list with one gradient per input (or NULL
if not required).
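To see why reshape-and-sum yields the right gradient, consider the 1-d case in base R: every input element appears `times` times in the output, so its gradient is the sum of the upstream gradients at those positions. The helper names below are hypothetical. Note also that base R fills matrices column-major, so in this sketch each column holds one repetition; which of the two split axes comes first depends on the memory layout the reshape uses.

```r
# Forward: repeat a vector `times` times (the 1-d case of repeat_along).
repeat_vec <- function(x, times) rep(x, times = times)

# Reverse: split the upstream gradient into one column per repetition and
# sum across repetitions, giving one gradient entry per input element.
repeat_vec_grad <- function(grad, n, times) {
  stopifnot(length(grad) == n * times)
  rowSums(matrix(grad, nrow = n, ncol = times))
}

x <- c(1, 2, 3)
length(repeat_vec(x, times = 2L)) # 6

# Upstream gradient for the 6 output elements.
grad_out <- c(10, 20, 30, 1, 2, 3)
repeat_vec_grad(grad_out, n = length(x), times = 2L) # 11 22 33
```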
Optional: a quickr rule
prim_<name>[["stablehlo"]] and
prim_<name>[["reverse"]] are the two rules you almost
always want. A primitive can optionally also carry a quickr
rule, which lowers it to plain R code for the quickr backend (see
R/rules-quickr.R). The quickr rule is only required if you
want the primitive to run under local_backend("quickr"); if
you skip it, the primitive will still work on the xla backend.
Step 4: Verify the Registration
new_primitive() returns the callable directly and also
registers the primitive in the internal registry used by the graph
machinery, so no separate registration step is needed:
prim_repeat_along
#> function (operand, times, dim)
#> {
#> if (currently_tracing()) {
#> cl <- match.call()
#> cl[[1L]] <- f
#> return(eval.parent(cl))
#> }
#> args <- lapply(as.list(match.call())[-1L], eval, envir = parent.frame())
#> be <- if (!is.null(device_argname) && !is.null(args[[device_argname]])) {
#> dev_val <- args[[device_argname]]
#> if (is.character(dev_val))
#> default_backend()
#> else backend(dev_val)
#> }
#> else {
#> jit_auto_detect_backend(flatten(args[!names(args) %in%
#> static]))
#> }
#> if (is.null(jit_fns[[be]])) {
#> jit_fns[[be]] <<- do.call(jit_with_backend, c(list(f = f,
#> static = static, cache_size = cache_size, backend = be),
#> if (!is.null(device_argname)) {
#> list(device = device_arg(device_argname))
#> } else if (!is.null(device)) {
#> list(device = device)
#> }, dots))
#> }
#> do.call(jit_fns[[be]], args)
#> }
#> <environment: 0x5620b69b2548>
#> attr(,"class")
#> [1] "JitPrimitive" "JitFunction"
#> attr(,"backend")
#> [1] "auto"
#> attr(,"primitive")
#> <AnvlPrimitive:repeat_along>
Step 5: Add an nv_ API Function
In {anvl}, we also offer convenience wrappers around the primitives.
An example is prim_add vs nv_add, where the
latter calls into the former after optionally broadcasting (scalar)
inputs:
nv_add(1L, nv_array(2:3))
#> AnvlArray
#> 3
#> 4
#> [ CPUi32{2} ]
prim_add(1L, nv_array(2:3))
#> Error in `prim_add()`:
#> ! `lhs` and `rhs` must have the same tensor type.
#> ✖ Got tensor<i32> and tensor<2xi32>.
In our case, no such convenience is needed, and the primitive is not so
low-level that it would be awkward to use directly, so we can simply
assign the prim_* function to an nv_* name:
nv_repeat_along <- prim_repeat_along
Note that in the nv_* wrapper function, you can only
access properties of the input arrayish values via the
<extract>_abstract() accessors, such as shape_abstract().
If you, for example, use shape() instead of
shape_abstract(), your function won’t work with R literals.
That is, <extract>_abstract() first converts the input
to an AbstractArray (if possible) and then extracts the
property.
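The division of labour between a strict primitive and its convenience wrapper, as described above for prim_add and nv_add, can be illustrated in plain R. The functions `strict_add` and `friendly_add` are hypothetical and work on ordinary vectors; they are not how {anvl} implements broadcasting.

```r
# Hypothetical strict "primitive": both operands must have the same length.
strict_add <- function(lhs, rhs) {
  if (length(lhs) != length(rhs)) {
    stop("`lhs` and `rhs` must have the same length")
  }
  lhs + rhs
}

# Hypothetical convenience wrapper: broadcast length-1 inputs, then delegate.
friendly_add <- function(lhs, rhs) {
  n <- max(length(lhs), length(rhs))
  if (length(lhs) == 1L) lhs <- rep(lhs, n)
  if (length(rhs) == 1L) rhs <- rep(rhs, n)
  strict_add(lhs, rhs)
}

friendly_add(1L, 2:3) # 3 4
```

Keeping the primitive strict and putting conveniences in the wrapper means the lowering and reverse rules only ever have to handle one well-defined case.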
Using Your Primitive
You can now use the primitive, both eagerly and inside a larger JIT-compiled function:
x <- nv_array(c(1, 2, 3), shape = c(3, 1))
# Eager call -- works because prim_repeat_along is itself jit-compiled.
prim_repeat_along(x, times = 2L, dim = 2L)
#> AnvlArray
#> 1 1
#> 2 2
#> 3 3
#> [ CPUf32{3,2} ]
# Traced into the outer graph when composed with another jit-compiled function.
jit(function(x) prim_repeat_along(x, times = 2L, dim = 2L))(x)
#> AnvlArray
#> 1 1
#> 2 2
#> 3 3
#> [ CPUf32{3,2} ]
You can also compute gradients through it.
Contributing to the Package
If you want to contribute your primitive to {anvl}, there are some additional things to be aware of.
File Organization
- R/primitives.R: Define the prim_* primitive via new_primitive()
- R/rules-stablehlo.R: Add the StableHLO lowering rule
- R/rules-reverse.R: Add the reverse rule (if differentiable)
- R/rules-quickr.R: Add the quickr lowering rule (optional; only if the primitive should run on the quickr backend)
- R/api.R: Add the nv_* wrapper function (or possibly in another api file).
Testing
Tests can go in two places:
- inst/extra-tests/: For tests that compare against torch. These live in inst/ to avoid listing torch as a dependency.
- tests/testthat/: For tests without a torch counterpart.
Important: the describe() /
test_that() label must contain the full primitive name,
e.g. describe("prim_repeat_along", { ... }). The meta tests
in tests/testthat/test-primitives-meta.R verify that every
primitive has corresponding stablehlo and reverse tests, and flag any
that are missing.
Since no torch counterpart exists for prim_repeat_along,
we would add manual tests in:
- tests/testthat/test-primitives-stablehlo.R
- tests/testthat/test-primitives-reverse.R
Also ensure that there are no linter errors, that
devtools::check() passes, and that the code is formatted using
make format.
Higher-Order Primitives
Higher-order primitives are primitives that are parameterized by an R
function or expression. Examples include prim_if and
prim_while. These are generally much more complex to
handle, so we don’t cover them here in detail (for now). The general
idea, however, is that the prim_* function needs
to trace the provided function using trace_fn() and then
forward the resulting graph to the stablehlo lowering rule and the reverse rule.