Wraps a function so that it is traced, lowered to StableHLO, and compiled to an XLA
executable on its first call. Subsequent calls with the same input shapes and dtypes (and
the same values for any static arguments) reuse the compiled executable from an LRU cache
instead of recompiling. Unlike xla(), which compiles ahead of time, jit() does not create
the executable eagerly; compilation is deferred until the first invocation.
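The call path described above (compile on the first call, then reuse per input signature) can be sketched in base R. This is a hypothetical illustration, not anvil's implementation. make_lazy_cached(), its compile argument, and the signature format are assumptions. typeof() stands in for a tensor dtype, and compile returns an ordinary R function in place of an XLA executable. Static arguments are left out here for brevity.

```r
# Hypothetical sketch (not anvil's code): lazy compilation with an LRU cache.
# `compile` stands in for tracing + lowering to StableHLO + XLA compilation;
# it receives the input signature and returns an ordinary R function.
make_lazy_cached <- function(compile, cache_size = 4L) {
  cache <- list()      # names are signatures; order = least to most recently used
  n_compiles <- 0L

  # Signature of one argument: R type plus shape, e.g. "double[3]".
  arg_signature <- function(a) {
    shape <- if (is.null(dim(a))) length(a) else dim(a)
    paste0(typeof(a), "[", paste(shape, collapse = "x"), "]")
  }

  invoke <- function(...) {
    args <- list(...)
    key <- paste(vapply(args, arg_signature, character(1)), collapse = "|")
    exe <- cache[[key]]                    # NULL if this signature is not cached
    if (is.null(exe)) {
      exe <- compile(key)                  # cache miss: compile now, not earlier
      n_compiles <<- n_compiles + 1L
      if (length(cache) >= cache_size) {
        cache[[1L]] <<- NULL               # evict the least recently used entry
      }
    } else {
      cache[[key]] <<- NULL                # cache hit: drop so it is re-added last
    }
    cache[[key]] <<- exe                   # (re)insert as most recently used
    exe(...)
  }

  list(invoke = invoke,
       n_compiles = function() n_compiles,
       keys = function() names(cache))
}

jf <- make_lazy_cached(function(sig) function(x, y) x + y, cache_size = 2L)
jf$n_compiles()
#> [1] 0
jf$invoke(c(1, 2, 3), c(4, 5, 6))
#> [1] 5 7 9
jf$invoke(c(0, 0, 0), c(1, 1, 1))   # same signature: cache hit
#> [1] 1 1 1
jf$n_compiles()
#> [1] 1
```

Nothing is compiled when the wrapper is created. A call with a new shape or type adds an entry, and once cache_size entries exist the least recently used one is evicted.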
Arguments
- f (function): Function to compile. Must accept and return AnvilTensors (and/or static arguments).
- static (character()): Names of the parameters of f that are not tensors. Static values are embedded as constants in the compiled program, and a new compilation is triggered whenever a static value changes. This is useful, for example, when your function uses R control flow (such as if) that depends on these values.
- cache_size (integer(1)): Maximum number of compiled executables to keep in the LRU cache.
- donate (character()): Names of the arguments whose buffers should be donated. Donated buffers can be aliased with outputs of the same type, allowing in-place operations and reducing memory usage. An argument cannot appear in both donate and static.
- device (NULL | character(1) | PJRTDevice): The device to use if it cannot be inferred from the inputs or constants. Defaults to "cpu".
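The rules for static and donate can be made concrete with two small base-R helpers. Both are hypothetical and not part of anvil: cache_key() and check_static_donate() are illustrative names, and typeof() stands in for a tensor dtype. The key lets non-static arguments contribute only their type and shape, while static arguments contribute their value. A changed static value therefore yields a new key, which corresponds to a recompilation.

```r
# Hypothetical helpers (not anvil API) illustrating static/donate semantics.

# Reject names listed in both `static` and `donate`, as the docs require.
check_static_donate <- function(static = character(), donate = character()) {
  both <- intersect(static, donate)
  if (length(both) > 0L) {
    stop("cannot be both static and donated: ", paste(both, collapse = ", "))
  }
  invisible(TRUE)
}

# Build a cache key from a named list of arguments. Non-static arguments
# contribute only type and shape; static arguments contribute their value.
cache_key <- function(args, static = character()) {
  parts <- vapply(names(args), function(nm) {
    a <- args[[nm]]
    if (nm %in% static) {
      paste0(nm, "=", paste(deparse(a), collapse = ""))
    } else {
      shape <- if (is.null(dim(a))) length(a) else dim(a)
      paste0(nm, ":", typeof(a), "[", paste(shape, collapse = "x"), "]")
    }
  }, character(1))
  paste(parts, collapse = ";")
}

cache_key(list(x = c(1, 2, 3), flag = TRUE), static = "flag")
#> [1] "x:double[3];flag=TRUE"
cache_key(list(x = c(4, 5, 6), flag = FALSE), static = "flag")
#> [1] "x:double[3];flag=FALSE"
```

Changing only the tensor values (1, 2, 3 versus 4, 5, 6) leaves the key unchanged, whereas flipping flag changes it.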
See also
xla() for ahead-of-time compilation, jit_eval() for evaluating an expression once.
Examples
f <- jit(function(x, y) x + y)
f(nv_tensor(1), nv_tensor(2))
#> AnvilTensor
#> 3
#> [ CPUf32{1} ]
# Static arguments allow R control flow that depends on their values
g <- jit(function(x, flag) {
  if (flag) x + 1 else x * 2
}, static = "flag")
g(nv_tensor(3), TRUE)
#> AnvilTensor
#> 4
#> [ CPUf32{1} ]
g(nv_tensor(3), FALSE)
#> AnvilTensor
#> 6
#> [ CPUf32{1} ]