
Builds an AnvlPrimitive metadata object, wraps fn with jit(), attaches the metadata via attr(., "primitive"), prepends class "JitPrimitive", and (by default) registers the result under name in the primitive registry.
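The steps above can be illustrated with a minimal base-R sketch. It is a mock, not the anvl implementation. sketch_jit(), sketch_new_primitive() and sketch_registry are hypothetical stand-ins, and the list used as the AnvlPrimitive metadata is an assumption. The sketch only shows how the metadata attribute, the class vector and the registry entry fit together.

```r
# Hypothetical registry; the real primitive registry is internal to anvl.
sketch_registry <- new.env(parent = emptyenv())

# Hypothetical stand-in for jit(): returns fn unchanged (so its formals are kept)
# and only tags it with class "JitFunction". No compilation happens here.
sketch_jit <- function(fn) {
  structure(fn, class = "JitFunction")
}

sketch_new_primitive <- function(name, fn, register = TRUE) {
  # 1. Build a metadata object (assumed shape: a classed list).
  meta <- structure(list(name = name), class = "AnvlPrimitive")
  # 2. Wrap fn with the jit() stand-in.
  out <- sketch_jit(fn)
  # 3. Attach the metadata as attr(., "primitive").
  attr(out, "primitive") <- meta
  # 4. Prepend class "JitPrimitive".
  class(out) <- c("JitPrimitive", class(out))
  # 5. Optionally register the result under `name`.
  if (register) assign(name, out, envir = sketch_registry)
  out
}

p <- sketch_new_primitive("my_add", function(x, y) x + y)
class(p)  # c("JitPrimitive", "JitFunction"), matching the documented return class
p(1, 2)   # still callable; returns 3
```

Because a closure keeps working after a class attribute is added, the wrapped object can be called directly while S3 dispatch can still single it out as a JitPrimitive.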

The backend is always "auto" and cannot be configured.

Usage

new_primitive(
  name,
  fn,
  subgraphs = character(),
  static = character(),
  device = NULL,
  register = TRUE
)

Arguments

name

(character(1))
Primitive name.

fn

(function)
Body of the primitive. Its formals become the formals of the returned JIT-compiled callable. Inside fn, the primitive is accessible through the lexically bound symbol self (an AnvlPrimitive), which is not one of fn's formals; pass it as the first argument to graph_desc_add().
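One way a symbol such as self can be made visible inside fn without being one of its formals is to re-parent fn's closure environment. The base-R sketch below assumes that approach. bind_self() is hypothetical and anvl may do this differently. It does not call graph_desc_add(); it only shows that self is resolved lexically while the formals stay unchanged.

```r
# Hypothetical helper: insert an environment holding `self` between fn and
# its original enclosing environment, so `self` is found by lexical lookup.
bind_self <- function(fn, self) {
  env <- new.env(parent = environment(fn))
  env$self <- self
  environment(fn) <- env
  fn
}

# Assumed metadata shape, for illustration only.
meta <- structure(list(name = "my_prim"), class = "AnvlPrimitive")

body_fn <- function(x) {
  # `self` is not a formal; it is looked up in the enclosing environment.
  list(prim = self$name, x = x)
}

g <- bind_self(body_fn, meta)
g(2)$prim          # returns "my_prim"
names(formals(g))  # returns "x": the formals of body_fn are untouched
```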

subgraphs

(character())
Names of the parameters of fn that are subgraphs (used by higher-order primitives).

static

(character() | integer())
Passed to jit().

device

(NULL | character(1) | device_arg())
Passed to jit(). Useful for primitives with no array inputs (e.g. prim_fill) where the device must come from an explicit argument.

register

(logical(1))
If TRUE (default), register the result under name in the primitive registry.

Value

A callable of class c("JitPrimitive", "JitFunction").