This vignette covers the static shape restriction within
jit(). First, we describe what this means and then we
discuss how to work around it.
Whenever the XLA compiler that underpins jit() compiles
a program, it must know the shape of every intermediate value at compile
time. This means that functions such as unique() cannot be
part of a jit-compiled function, because their output shape depends on
the runtime values (unique(c(1, 1)) outputs a length-1
vector, but unique(c(1, 2)) outputs a length-2 vector).
Other examples are which(x > 0) and
x[x > 0].
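The value dependence is easy to see in plain R, outside of jit(). The sketch below uses only base R and two made-up vectors, a and b, which are not part of {anvl}. Both inputs have length 2, yet the output lengths differ:

```r
# Two hypothetical inputs with identical shape (length 2).
a <- c(1, 1)
b <- c(1, 2)

# Same input shape, different output shapes: the result length
# depends on the values, not on the shape of the input.
length(unique(a))     # 1
length(unique(b))     # 2

length(which(a > 1))  # 0
length(which(b > 1))  # 1

length(a[a > 1])      # 0
length(b[b > 1])      # 1
```

A compiler that only sees "a numeric vector of length 2" cannot tell these cases apart, which is why such calls cannot be compiled to a fixed output shape.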
While this restriction makes jit() less ergonomic, it gives the compiler
more information about your program, which typically lets the compiled executable run faster.
The cost is that some operations require rethinking, and a few cannot
currently be expressed inside jit() at all.
For many operations, you can work around the restriction by keeping
shapes fixed and using a logical mask – the same idea presented in the
Padding section of the efficiency vignette. We cover some
common patterns below. For operations we cannot yet handle in-graph
(such as which() or unique()), you currently
need to convert the AnvlArray back to R, apply the
operation there, and then convert the result back to an
AnvlArray to resume the computation. We hope to lift the
static-shape restriction in the long term and add functions like
which() and unique() to {anvl}’s eager API to
make this more ergonomic.
The masking pattern
The usual workaround is to keep the output shape equal to the input shape – so that everything remains known at trace time – and carry a logical mask that tells us which positions count. Any downstream operation is then modified to ignore the masked-out positions.
Throughout, we will use a single example vector, x, and
compute things that in plain R we would write as
sum(x[x > 0]), max(x[x > 0]), and so
on:
Masked sum
For a sum over the positive entries, we don’t have to filter at all –
we can replace each non-matching entry with 0 and sum the
whole vector. Because adding 0 does nothing, the masked-out
positions don’t contribute to the total:
sum_positive <- jit(function(x) {
  nv_reduce_sum(nv_ifelse(x > 0, x, nv_fill_like(x, 0)), dims = 1L)
})
sum_positive(x)
#> AnvlArray
#> 11
#> [ CPUf32{} ]

This trick works because 0 is the neutral value
of addition. The same idea generalizes to any reduction whose operation
has such a neutral value:
| Operation | Neutral value |
|---|---|
| sum (and additive reductions) | 0 |
| prod (and multiplicative reductions) | 1 |
| max | -Inf (or the dtype’s minimum) |
| min | +Inf (or the dtype’s maximum) |
| any | FALSE |
| all | TRUE |
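The table can be checked in plain R. The following is a minimal base-R sketch, not {anvl} code: the vector v and the helper masked_reduce() are hypothetical names introduced for illustration. It fills the masked-out positions with the neutral value and compares the result against the filtered version:

```r
# Hypothetical example data, chosen for illustration.
v <- c(-2, 1, 3, 0, 2, -1, 5)
keep <- v > 0

# Hypothetical helper: replace masked-out entries with the neutral
# value, then reduce over the full, fixed-length vector.
masked_reduce <- function(values, keep, reduce, neutral) {
  reduce(ifelse(keep, values, neutral))
}

masked_reduce(v, keep, sum,  0)     # 11, same as sum(v[keep])
masked_reduce(v, keep, prod, 1)     # 30, same as prod(v[keep])
masked_reduce(v, keep, max,  -Inf)  # 5,  same as max(v[keep])
masked_reduce(v, keep, min,  Inf)   # 1,  same as min(v[keep])

# Logical reductions work the same way on a logical vector.
big <- v > 2
masked_reduce(big, keep, any, FALSE)  # TRUE,  same as any(big[keep])
masked_reduce(big, keep, all, TRUE)   # FALSE, same as all(big[keep])
```

If no position matches, the masked reduction simply returns the neutral value, for example -Inf for max.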
Masked mean
The mean is a case where this trick alone is not enough:
mean(x[x > 0]) divides by the number of
matching entries, which depends on the data. The matching count
is a scalar, though, so we can compute it inside jit() by
summing the boolean mask, and then divide separately:
mean_positive <- jit(function(x) {
  mask <- x > 0
  total <- nv_reduce_sum(nv_ifelse(mask, x, 0), dims = 1L)
  n <- nv_reduce_sum(mask, dims = 1L)
  total / n
})
mean_positive(x)
#> AnvlArray
#> 11
#> [ CPUf32{} ]

Note that the divisor is the matching count, not
length(x). This is the workaround for
length(x[cond]) more generally: convert the mask to a
numeric type and sum it.
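The arithmetic behind the masked mean can be followed in plain R. This is a base-R sketch with a hypothetical vector v, not {anvl} code. The count comes from summing the mask after converting it to a numeric type, so no value-dependent subsetting is needed:

```r
# Hypothetical example data, chosen for illustration.
v <- c(-2, 1, 3, 0, 2, -1, 5)
mask <- v > 0

total <- sum(ifelse(mask, v, 0))  # masked sum: 11
n <- sum(as.numeric(mask))        # matching count: 4, not length(v) == 7

total / n                         # 2.75
mean(v[mask])                     # 2.75, the filtered reference

# With no matches the count is 0 and the division yields NaN,
# which is also what mean(numeric(0)) returns in base R.
none <- rep(FALSE, length(v))
sum(ifelse(none, v, 0)) / sum(as.numeric(none))  # NaN
```

Whether an empty match should produce NaN or some fallback value is a design decision to make before dividing.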
Subset assignment
The static shape restriction also prevents calls of the form
x[mask] <- update. In some cases, this can be replaced
by nv_ifelse(). For example, you might want to replace all
values < 1 with 1:
nv_ifelse(x < 1, 1, x)
#> AnvlArray
#> 1
#> 1
#> 3
#> 1
#> 2
#> 1
#> 5
#> [ CPUf32{7} ]

What is currently not possible is to actually subset x
so that it only contains values >= 1.
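The difference between the two cases shows up in plain R. In this base-R sketch (the vector v is hypothetical), subset assignment with a scalar update and the elementwise select both return a vector of the original length, whereas true subsetting changes the length:

```r
# Hypothetical example data, chosen for illustration.
v <- c(-2, 1, 3, 0, 2, -1, 5)

# Subset assignment with a scalar update ...
assigned <- v
assigned[assigned < 1] <- 1

# ... equals an elementwise select over the full vector.
selected <- ifelse(v < 1, 1, v)

identical(assigned, selected)  # TRUE
length(selected) == length(v)  # TRUE: the shape is unchanged

# Subsetting, in contrast, yields a value-dependent length.
length(v[v >= 1])              # 4
```

The select form has a fixed output shape, which is what makes it expressible under the static shape restriction.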