Transforming Code
Just as a real anvil is used to reshape metal, this package is a tool for reshaping code. We call such a rewriting of code a transformation, and there are three types:
- R \(\rightarrow\) Graph: Generic R functions are too complicated to handle directly, so the first step in {anvil} is always to convert them into a computational anvil::Graph object via tracing. Such a Graph is similar to a jaxpr in JAX. It operates only on AnvilTensor objects and applies anvil::Primitive operations to them.
- Graph \(\rightarrow\) Graph: Graphs can be transformed into other Graphs, which changes what the code computes. At the time of writing, there is essentially only one such transformation, namely reverse-mode automatic differentiation via gradient().
- Graph \(\rightarrow\) Executable: To perform the actual computation, the Graph needs to be converted into an executable. Currently, we only support the XLA backend (via {stablehlo} and {pjrt}), but we are working on an experimental {quickr} backend.
Tracing R Functions into Graphs
All functionality in the {anvil} package is centered around the
anvil::Graph class. While it is in principle possible to
create Graphs by hand, these are usually created by tracing
R functions. In general, when we want to convert some code into another
form (in our case, R code into a Graph), there are two
approaches:
- Static analysis, which would require operating on the abstract syntax tree (AST) of the code.
- Dynamic analysis (aka “tracing”), which executes the code and records selected operations.
The {quickr} package follows the former approach, while {anvil} uses
tracing. We start with a simple but illustrative example that
either adds or multiplies two inputs x and y,
depending on the value of op.
library(anvil)
f <- function(x, y, op) {
  if (op == "add") {
    nv_add(x, y)
  } else if (op == "mul") {
    nv_mul(x, y)
  } else {
    stop("Unsupported operation")
  }
}
To trace f, we use anvil::trace_fn(), which takes an
R function and a list of AbstractTensor objects
that specify the types of its inputs.
## AbstractTensor(dtype=f32, shape=)
## <Graph>
## Inputs:
## %x1: f32[]
## %x2: f32[]
## Body:
## %1: f32[] = mul(%x1, %x2)
## Outputs:
## %1: f32[]
The output of trace_fn() is now a Graph
object that represents the computation. The fields of the
Graph are:
- inputs, which are GraphNodes that represent the inputs to the function.
- outputs, which are GraphNodes that represent the outputs of the function.
- calls, which are PrimitiveCalls that take in GraphNodes (and parameters) and produce output GraphNodes.
- in_tree and out_tree, which we do not cover in this vignette.
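To make these fields concrete, the following sketch mirrors the traced graph with plain R lists and prints it in the same style as the <Graph> output above. The helpers graph_node(), prim_call(), fmt_type() and format_graph() are hypothetical and not part of {anvil}; the real classes carry more information, and in_tree/out_tree are left out.

```r
# Hypothetical plain-list stand-ins for GraphNode, PrimitiveCall and Graph.
# They only mirror the fields described above; they are not the {anvil} classes.
graph_node <- function(id, dtype = "f32", shape = integer()) {
  list(id = id, dtype = dtype, shape = shape)
}

prim_call <- function(primitive, inputs, outputs, params = list()) {
  list(primitive = primitive, inputs = inputs, params = params, outputs = outputs)
}

# Format a node type as e.g. "f32[]" (scalar) or "f32[2,3]".
fmt_type <- function(n) sprintf("%s[%s]", n$dtype, paste(n$shape, collapse = ","))

# Render a graph in the style of the printed <Graph> (one output per call).
format_graph <- function(g) {
  ids <- function(nodes) {
    paste0("%", vapply(nodes, `[[`, character(1), "id"), collapse = ", ")
  }
  c(
    "Inputs:",
    vapply(g$inputs, function(n) sprintf("  %%%s: %s", n$id, fmt_type(n)), character(1)),
    "Body:",
    vapply(g$calls, function(cl) {
      out <- cl$outputs[[1L]]
      sprintf("  %%%s: %s = %s(%s)", out$id, fmt_type(out), cl$primitive, ids(cl$inputs))
    }, character(1)),
    "Outputs:",
    vapply(g$outputs, function(n) sprintf("  %%%s: %s", n$id, fmt_type(n)), character(1))
  )
}

x1 <- graph_node("x1")
x2 <- graph_node("x2")
v1 <- graph_node("1")
toy_graph <- list(
  inputs = list(x1, x2),
  calls = list(prim_call("mul", inputs = list(x1, x2), outputs = list(v1))),
  outputs = list(v1)
)
cat(format_graph(toy_graph), sep = "\n")
```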
During trace_fn(), a new GraphDescriptor is
created and the inputs x and y are converted
into anvil::GraphBox objects. Then, the function
f is simply evaluated with the GraphBox
objects as inputs. During this evaluation, we need to distinguish
between two cases:
- A “standard” R function is called: Nothing special happens and the function is simply evaluated.
- An anvil function is called: The operation that underlies the function is recorded in the GraphDescriptor.
The evaluation of the if statement is an example of the
first category. Because we set op = "mul", only the nv_mul()
branch is executed. Calling nv_mul() attaches a
PrimitiveCall that represents the multiplication of the two
tensors to the @calls of the GraphDescriptor. Note
that nv_mul() is itself not a primitive: it performs type
promotion and broadcasting if needed, before calling the primitive
nvl_mul().
A PrimitiveCall object consists of the following
fields:
- primitive: The primitive function that was called.
- inputs: The inputs to the primitive function.
- params: The parameters (non-tensors) to the primitive function.
- outputs: The outputs of the primitive function.
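The tracing mechanism described above can be sketched in a few lines of plain R. This is a hypothetical toy, not {anvil}'s implementation: new_descriptor(), toy_box(), record(), toy_add(), toy_mul(), toy_trace() and f_toy() are made-up names, a descriptor is an environment, and each recorded call carries the four PrimitiveCall fields listed above. The `if` is evaluated as ordinary R code, so only the branch that is taken ends up in the recording.

```r
# A hypothetical, minimal tracer illustrating the idea; it is not {anvil}'s code.
new_descriptor <- function() {
  desc <- new.env()
  desc$calls <- list()   # recorded primitive calls, in order
  desc$counter <- 0L     # used to create fresh node ids
  desc
}

# A box wraps a node id and remembers the descriptor it belongs to.
toy_box <- function(desc, id, dtype = "f32") {
  structure(list(desc = desc, id = id, dtype = dtype), class = "toy_box")
}

# Record one primitive call (primitive, inputs, params, outputs) and box its output.
record <- function(primitive, inputs, params = list()) {
  desc <- inputs[[1L]]$desc
  desc$counter <- desc$counter + 1L
  out <- toy_box(desc, as.character(desc$counter), inputs[[1L]]$dtype)
  desc$calls[[length(desc$calls) + 1L]] <- list(
    primitive = primitive,
    inputs = vapply(inputs, function(b) b$id, character(1)),
    params = params,
    outputs = out$id
  )
  out
}

toy_add <- function(x, y) record("add", list(x, y))
toy_mul <- function(x, y) record("mul", list(x, y))

# Box the tensor inputs, evaluate fn as normal R code, return what was recorded.
toy_trace <- function(fn, input_names, static = list()) {
  desc <- new_descriptor()
  boxes <- lapply(seq_along(input_names), function(i) toy_box(desc, paste0("x", i)))
  names(boxes) <- input_names
  out <- do.call(fn, c(boxes, static))
  list(calls = desc$calls, output = out$id)
}

f_toy <- function(x, y, op) {
  # Ordinary R control flow: only the branch that is taken gets recorded.
  if (op == "add") toy_add(x, y) else if (op == "mul") toy_mul(x, y) else stop("Unsupported operation")
}

tr <- toy_trace(f_toy, c("x", "y"), static = list(op = "mul"))
str(tr$calls)
```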
When the evaluation of f is complete, the
@outputs field of the GraphDescriptor is set
and the Graph is subsequently created from the
GraphDescriptor. The only difference between the
Graph and the GraphDescriptor is that the
latter has some utility fields that are useful during graph creation,
but for the purposes of this tutorial, you can think of them as being
the same.
Transforming Graphs into other Graphs
Once the R function is staged out into a simpler format,
it is ready to be transformed. The {anvil} package does not in any way
dictate how such a Graph to Graph
transformation can be implemented. For most interesting transformations,
however, we need to store some information for each {anvil} primitive
function. In the case of the gradient, we need to store the derivative
rules. For this, anvil::Primitive objects have a
@rules field that can be populated. The derivative rules
are stored as functions under the "backward" name. We can
access a primitive by its name via the prim()
function:
prim("mul")@rules[["backward"]]
## function (inputs, outputs, grads, .required)
## {
## lhs <- inputs[[1L]]
## rhs <- inputs[[2L]]
## grad <- grads[[1L]]
## list(if (.required[[1L]]) nvl_mul(grad, rhs), if (.required[[2L]]) nvl_mul(grad,
## lhs))
## }
## <bytecode: 0x5562969e84c0>
## <environment: namespace:anvil>
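The idea of attaching one rule per transformation to each primitive can be illustrated independently of {anvil}. The registry below, with def_primitive(), def_rule() and get_rule(), is a hypothetical sketch: a transformation only needs to look up its rule by name, so adding a new transformation does not require changing the tracer. The backward rule reuses the signature printed above but works on plain numbers.

```r
# Hypothetical primitive registry; {anvil} itself stores rules in the @rules field
# of anvil::Primitive objects and looks primitives up via prim().
primitives <- new.env()

def_primitive <- function(name) {
  assign(name, list(name = name, rules = list()), envir = primitives)
}

# Attach a rule for one transformation (e.g. "backward" or "stablehlo").
def_rule <- function(name, transformation, fn) {
  p <- get(name, envir = primitives, inherits = FALSE)
  p$rules[[transformation]] <- fn
  assign(name, p, envir = primitives)
}

# Look up a rule, failing loudly if the primitive does not support the transformation.
get_rule <- function(name, transformation) {
  rule <- get(name, envir = primitives, inherits = FALSE)$rules[[transformation]]
  if (is.null(rule)) {
    stop(sprintf("primitive '%s' has no '%s' rule", name, transformation))
  }
  rule
}

def_primitive("mul")
def_rule("mul", "backward", function(inputs, outputs, grads, .required) {
  list(
    if (.required[[1L]]) grads[[1L]] * inputs[[2L]],
    if (.required[[2L]]) grads[[1L]] * inputs[[1L]]
  )
})

# Gradients of 3 * 4 for an incoming gradient of 1
get_rule("mul", "backward")(list(3, 4), list(12), list(1), list(TRUE, TRUE))
```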
The anvil::transform_gradient() function uses these rules
to compute the gradient of a function. It walks the graph backwards
and applies the derivative rules, which appends the “backward pass” to
the graph. Besides the forward graph, the transformation takes the
wrt argument, which specifies the arguments with respect
to which the gradient is computed.
bwd_graph <- transform_gradient(graph, wrt = c("x", "y"))
bwd_graph
## <Graph>
## Inputs:
## %x1: f32[]
## %x2: f32[]
## Constants:
## %c1: f32[]
## Body:
## %1: f32[] = mul(%x1, %x2)
## %2: f32[] = mul(%c1, %x2)
## %3: f32[] = mul(%c1, %x1)
## Outputs:
## %2: f32[]
## %3: f32[]
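The backward walk can be sketched numerically in plain R. This is a hypothetical toy: toy_gradient(), backward_rules and forward_ops are made-up names, and unlike transform_gradient(), which appends new nodes to the graph, it directly computes numbers. The seed gradient of 1 for the output appears to play the role of the constant %c1 in the printed backward graph.

```r
# Hypothetical backward rules with the same (inputs, outputs, grads, .required)
# signature as the {anvil} rule printed above, but operating on plain numbers.
backward_rules <- list(
  mul = function(inputs, outputs, grads, .required) {
    lhs <- inputs[[1L]]
    rhs <- inputs[[2L]]
    grad <- grads[[1L]]
    list(if (.required[[1L]]) grad * rhs, if (.required[[2L]]) grad * lhs)
  },
  add = function(inputs, outputs, grads, .required) {
    grad <- grads[[1L]]
    list(if (.required[[1L]]) grad, if (.required[[2L]]) grad)
  }
)
forward_ops <- list(add = `+`, mul = `*`)

# Recorded calls for f(x1, x2) = (x1 + x2) * x1, in evaluation order.
calls <- list(
  list(primitive = "add", inputs = c("x1", "x2"), outputs = "1"),
  list(primitive = "mul", inputs = c("1", "x1"), outputs = "2")
)

toy_gradient <- function(calls, values, output, wrt) {
  # Forward pass: evaluate each call and remember its value by node id.
  for (cl in calls) {
    args <- lapply(cl$inputs, function(id) values[[id]])
    values[[cl$outputs]] <- do.call(forward_ops[[cl$primitive]], args)
  }
  # Backward pass: seed d(output)/d(output) = 1, then walk the calls in reverse.
  grads <- list()
  grads[[output]] <- 1
  for (cl in rev(calls)) {
    g_out <- grads[[cl$outputs]]
    if (is.null(g_out)) next  # this node does not influence the output
    ins <- lapply(cl$inputs, function(id) values[[id]])
    contribs <- backward_rules[[cl$primitive]](
      ins, list(values[[cl$outputs]]), list(g_out), list(TRUE, TRUE)
    )
    # Accumulate: a node used several times receives the sum of its contributions.
    for (i in seq_along(cl$inputs)) {
      id <- cl$inputs[[i]]
      prev <- if (is.null(grads[[id]])) 0 else grads[[id]]
      grads[[id]] <- prev + contribs[[i]]
    }
  }
  grads[wrt]
}

toy_gradient(calls, list(x1 = 3, x2 = 4), output = "2", wrt = c("x1", "x2"))
```

At (3, 4), the sketch returns 10 for x1 (the analytic derivative 2 * x1 + x2) and 3 for x2 (the derivative x1). Note that x1 is used twice, so its gradient is the sum of two contributions.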
Lowering a Graph
In order to execute a Graph, we need to convert it into
a – wait for it – executable. Here, we show how to compile using the
XLA backend. First, we will translate the Graph into the
StableHLO representation via the {stablehlo} package. Then, we will
compile this program using the XLA compiler that is accessible via the
{pjrt} package.
As with the gradient transformation, the rules for this
transformation are stored in the @rules field of each
primitive.
prim("mul")@rules[["stablehlo"]]
## function (lhs, rhs)
## {
## list(stablehlo::hlo_multiply(lhs, rhs))
## }
## <bytecode: 0x5562969e7618>
## <environment: namespace:anvil>
The anvil::stablehlo function will create a
stablehlo::Func object and will sequentially translate the
PrimitiveCalls into StableHLO operations.
func <- stablehlo(graph)[[1L]]
func
## func.func @main (%0: tensor<f32>, %1: tensor<f32>) -> tensor<f32> {
## %2 = "stablehlo.multiply" (%0, %1): (tensor<f32>, tensor<f32>) -> (tensor<f32>)
## "func.return"(%2): (tensor<f32>) -> ()
## }
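The sequential translation can be mimicked with a hypothetical string-based lowering. In the sketch below, lower_rules and toy_lower() are made-up names, every value is assumed to be a scalar f32, and the text only imitates the printed output; the real rules call {stablehlo} builders such as stablehlo::hlo_multiply() rather than formatting strings. Each recorded call is translated in order, and a lookup table maps graph node ids to SSA names like %0.

```r
# Hypothetical lowering rules: each turns SSA operand names into one line of
# StableHLO-like text. The real {anvil} rules build {stablehlo} ops instead.
lower_rules <- list(
  add = function(lhs, rhs) sprintf('"stablehlo.add" (%s, %s)', lhs, rhs),
  mul = function(lhs, rhs) sprintf('"stablehlo.multiply" (%s, %s)', lhs, rhs)
)

# Translate recorded calls, in order, into a func.func body.
# Simplification: every value is assumed to be a scalar f32 (tensor<f32>).
toy_lower <- function(calls, n_inputs, output) {
  n_inputs <- as.integer(n_inputs)
  ssa <- list()  # maps graph node ids to SSA names such as "%0"
  for (i in seq_len(n_inputs)) ssa[[paste0("x", i)]] <- sprintf("%%%d", i - 1L)
  next_id <- n_inputs
  body <- character()
  for (cl in calls) {
    args <- unname(vapply(cl$inputs, function(id) ssa[[id]], character(1)))
    name <- sprintf("%%%d", next_id)
    next_id <- next_id + 1L
    op <- do.call(lower_rules[[cl$primitive]], as.list(args))
    body <- c(body, sprintf("  %s = %s: (tensor<f32>, tensor<f32>) -> (tensor<f32>)", name, op))
    ssa[[cl$outputs]] <- name
  }
  params <- paste(sprintf("%%%d: tensor<f32>", seq_len(n_inputs) - 1L), collapse = ", ")
  c(
    sprintf("func.func @main (%s) -> tensor<f32> {", params),
    body,
    sprintf('  "func.return"(%s): (tensor<f32>) -> ()', ssa[[output]]),
    "}"
  )
}

calls <- list(list(primitive = "mul", inputs = c("x1", "x2"), outputs = "1"))
cat(toy_lower(calls, n_inputs = 2, output = "1"), sep = "\n")
```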
Now, we can compile the function via pjrt_compile().
hlo_str <- stablehlo::repr(func)
program <- pjrt::pjrt_program(src = hlo_str, format = "mlir")
exec <- pjrt::pjrt_compile(program)
To run the function, we simply pass the tensors to the executable,
which outputs a PJRTBuffer that we can easily convert
to an AnvilTensor.
x <- nv_scalar(3, "f32")
y <- nv_scalar(4, "f32")
out <- pjrt::pjrt_execute(exec, x, y)
out
## PJRTBuffer
## 12.0000
## [ CPUf32{} ]
nv_tensor(out)
## AnvilTensor
## 12.0000
## [ CPUf32{} ]
The User Interface
In the previous sections, we showed how the transformations are
implemented under the hood. The actual user interface is more
convenient and follows the JAX interface.
jit()
The jit() function converts a regular R function into a
just-in-time (JIT) compiled function that can be executed on
AnvilTensors. We apply it to our simple example function,
where we mark the non-tensor parameter op as “static”. This
means that the value of this parameter needs to be known at compile
time.
f_jit <- jit(f, static = "op")
f_jit(x, y, "add")
## AnvilTensor
## 7.0000
## [ CPUf32{} ]
One might think that jit() first calls
trace_fn(), then runs stablehlo(), followed by
pjrt_compile(). This is not the case, because these steps
require the input types to be known. Instead, f_jit is a
“lazy” function that only performs these steps once the inputs are
provided. Running them on every call of f_jit would be very
inefficient, because tracing and compiling take time. Therefore,
f_jit also contains a cache (implemented as an
xlamisc::LRUCache) that checks whether a compiled executable
already exists for the given inputs. For a cache hit, the types of all
AnvilTensors must match exactly (data type and shape) and
all static arguments must be identical. For example, if we run the
function with AnvilTensors of the same type but different
values, the function is not recompiled. We can verify this by checking
the size of the cache, which is already 1 because we called
f_jit on x and y above.
cache_size <- function(f) environment(f)$cache$size
cache_size(f_jit)
## [1] 1
After calling it with tensors of the same types and identical static argument values, the size of the cache remains 1:
## AnvilTensor
## -97.0000
## [ CPUf32{} ]
cache_size(f_jit)
## [1] 1
When we execute the function with tensors of different
dtype or shape, the function will be
recompiled:
## AnvilTensor
## 3
## [ CPUi32{} ]
cache_size(f_jit)
## [1] 2
Also, if we provide different values for static arguments, the function will be recompiled:
## AnvilTensor
## 2.0000
## [ CPUf32{} ]
cache_size(f_jit)
## [1] 3
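The caching behaviour can be reproduced with a small hypothetical sketch. Here toy_jit(), make_cache_key(), compile_op() and toy_cache_size() are made-up names, the "compiler" just returns an R closure, and the LRU eviction is a simple stand-in for xlamisc::LRUCache. The key is built from dtypes, shapes and the deparsed static arguments, never from tensor values, which is why the second call below is a cache hit.

```r
# Hypothetical sketch of a jit-style cache keyed on input *types* and static
# arguments; {anvil} uses an xlamisc::LRUCache, which is not reproduced here.
tensor <- function(value, dtype = "f32", shape = integer()) {
  list(value = value, dtype = dtype, shape = shape)
}

# The key ignores tensor values: only dtype, shape and the static arguments matter.
make_cache_key <- function(tensors, static) {
  types <- vapply(tensors, function(t) {
    sprintf("%s[%s]", t$dtype, paste(t$shape, collapse = ","))
  }, character(1))
  paste(c(types, deparse(static)), collapse = "|")
}

toy_jit <- function(compile, capacity = 10L) {
  cache <- new.env()
  cache$keys <- character()  # least recently used first
  cache$entries <- list()
  function(tensors, static = list()) {
    key <- make_cache_key(tensors, static)
    if (key %in% cache$keys) {
      # Cache hit: mark the key as most recently used.
      cache$keys <- c(setdiff(cache$keys, key), key)
    } else {
      # Cache miss: run the expensive compile step once for this key.
      cache$entries[[key]] <- compile(tensors, static)
      cache$keys <- c(cache$keys, key)
      if (length(cache$keys) > capacity) {
        evicted <- cache$keys[[1L]]
        cache$keys <- cache$keys[-1L]
        cache$entries[[evicted]] <- NULL
      }
    }
    cache$entries[[key]](tensors)
  }
}

# A stand-in "compiler" that counts how often it is invoked.
n_compiled <- new.env()
n_compiled$count <- 0L
compile_op <- function(tensors, static) {
  n_compiled$count <- n_compiled$count + 1L
  op <- switch(static$op, add = `+`, mul = `*`)
  function(tensors) op(tensors[[1L]]$value, tensors[[2L]]$value)
}

toy_cache_size <- function(f) length(environment(f)$cache$keys)

toy_f_jit <- toy_jit(compile_op)
r1 <- toy_f_jit(list(tensor(3), tensor(4)), list(op = "add"))                  # compiles
r2 <- toy_f_jit(list(tensor(10), tensor(20)), list(op = "add"))                # cache hit
r3 <- toy_f_jit(list(tensor(1L, "i32"), tensor(2L, "i32")), list(op = "add"))  # new dtype
r4 <- toy_f_jit(list(tensor(1), tensor(2)), list(op = "mul"))                  # new static value
toy_cache_size(toy_f_jit)
```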
gradient()
Just like jit(), gradient() also returns a
function that will lazily create the graph and transform it, once the
inputs are provided.
Calling g() on AnvilTensors does not
actually compute the gradient; instead, it only outputs the output types
(see the debugging vignette for more details).
g(x, y, "add")
## $x
## 1.0000
## [ CPUf32{} ]
##
## $y
## 1.0000
## [ CPUf32{} ]
If we want to actually compute the gradient, we need to wrap it in
jit().
g_jit <- jit(g, static = "op")
g_jit(x, y, "add")
## $x
## AnvilTensor
## 1.0000
## [ CPUf32{} ]
##
## $y
## AnvilTensor
## 1.0000
## [ CPUf32{} ]
Moreover, we can also use g in another function:
## $x
## AnvilTensor
## 3.0000
## [ CPUf32{} ]
##
## $y
## AnvilTensor
## 7.0000
## [ CPUf32{} ]
So, what is happening here? Once the inputs x and
y are provided to h_jit, a new
GraphDescriptor is created and the inputs x
and y are converted into GraphBox objects.
Then, the addition of x and y is recorded in
the GraphDescriptor. The call into g() is a
bit more involved. First, a new GraphDescriptor is created
and the forward computation of g is recorded. Subsequently,
the backward pass will be added to the descriptor, after which it will
be converted into a Graph. This Graph will
then be inlined into the parent GraphDescriptor
(representing the whole function h), which is then
converted into the main Graph. We can look at this graph
below, where trace_fn internally converts the
AnvilTensors x and y into their
abstract representation.
## <Graph>
## Inputs:
## %x1: f32[]
## %x2: f32[]
## Constants:
## %c1: f32[]
## Body:
## %1: f32[] = add(%x1, %x2)
## %2: f32[] = mul(%1, %x1)
## %3: f32[] = mul(%c1, %x1)
## %4: f32[] = mul(%c1, %1)
## Outputs:
## %3: f32[]
## %4: f32[]
Afterwards, this graph is lowered to StableHLO and subsequently compiled.
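The inlining step can be illustrated with a hypothetical sketch that operates on plain lists of calls with string node ids. inline_graph() is a made-up helper, not an {anvil} function: it maps the child graph's inputs to the parent nodes passed as arguments, gives every child node a fresh parent id, and keeps constant ids unchanged for simplicity. Inlining the gradient graph of mul(a, b) on (%1, %x1) reproduces the body of the graph printed above.

```r
# Hypothetical sketch of inlining a child graph into a parent call list.
# Graphs are plain lists of calls with string node ids (not the {anvil} classes).
inline_graph <- function(parent_calls, child, args, next_id) {
  # Child inputs map to the parent nodes passed as arguments;
  # constants keep their ids in this simplified sketch.
  rename <- c(
    stats::setNames(args, child$inputs),
    stats::setNames(child$constants, child$constants)
  )
  for (cl in child$calls) {
    new_id <- as.character(next_id)  # fresh id in the parent graph
    next_id <- next_id + 1L
    parent_calls[[length(parent_calls) + 1L]] <- list(
      primitive = cl$primitive,
      inputs = unname(rename[cl$inputs]),
      outputs = new_id
    )
    rename[[cl$outputs]] <- new_id
  }
  list(calls = parent_calls, outputs = unname(rename[child$outputs]), next_id = next_id)
}

# Child: forward and backward pass of mul(a, b), seeded with the constant c1.
child <- list(
  inputs = c("a", "b"),
  constants = "c1",
  calls = list(
    list(primitive = "mul", inputs = c("a", "b"), outputs = "1"),
    list(primitive = "mul", inputs = c("c1", "b"), outputs = "2"),
    list(primitive = "mul", inputs = c("c1", "a"), outputs = "3")
  ),
  outputs = c("2", "3")
)

# Parent: %1 = add(%x1, %x2), then call the child on (%1, %x1).
parent_calls <- list(list(primitive = "add", inputs = c("x1", "x2"), outputs = "1"))
res <- inline_graph(parent_calls, child, args = c("1", "x1"), next_id = 2L)
for (cl in res$calls) {
  cat(sprintf("%%%s = %s(%s)\n", cl$outputs, cl$primitive, paste0("%", cl$inputs, collapse = ", ")))
}
cat("outputs:", paste0("%", res$outputs), "\n")
```

As in the printed graph, the inlined forward product %2 is not used by any output; in this sketch it simply stays in the call list.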
More Internals
Debug Mode
For how to use debug mode, see the debugging vignette.
Debug mode differs from jit mode because there is no
context that can initialize a main GraphDescriptor. For
this reason, every primitive initializes its own
GraphDescriptor, which is thrown away after the primitive
returns DebugBox objects. These DebugBox
objects exist only for user interaction and have a nice printer. Whenever
a primitive is evaluated, the DebugBox is converted via
maybe_box_variable into a GraphBox object, which is used
for the actual evaluation. This ensures that we don’t have to
duplicate any evaluation logic, as the graph-building functions only
have to work with GraphBox objects.
What gets lost in debug mode is the identity of values: because each
GraphDescriptor is thrown away, we can only reason about the
types of values, not about whether two values are the same.
Unfortunately, we currently detect debug mode by checking whether a
GraphDescriptor is active. For this reason, we don’t allow
calling local_descriptor() in the global environment. We may
improve this in the future, but for now this approach seems to work.
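The “is a descriptor active?” check can be sketched with a hypothetical toy. In the code below, .toy_state, with_descriptor(), dbg_mul() and the toy_debug_box class are made-up names. When no descriptor is active, dbg_mul() wraps its own call in a throwaway descriptor and returns only the output type; inside an active descriptor, the same code path records the call. This mirrors the idea of sharing one evaluation path, not {anvil}'s actual DebugBox/GraphBox conversion.

```r
# Hypothetical sketch of the "is a descriptor active?" check; not {anvil} code.
.toy_state <- new.env()
.toy_state$active <- NULL  # currently active descriptor, if any

# Run fn with a fresh descriptor active and return its result plus the recorded calls.
with_descriptor <- function(fn) {
  desc <- new.env()
  desc$calls <- character()
  old <- .toy_state$active
  .toy_state$active <- desc
  on.exit(.toy_state$active <- old)
  out <- fn()
  list(out = out, calls = desc$calls)
}

dbg_mul <- function(x, y) {
  if (is.null(.toy_state$active)) {
    # Debug mode: no surrounding trace, so wrap this single call in a throwaway
    # descriptor and return only the output type; node identity is not kept.
    res <- with_descriptor(function() dbg_mul(x, y))
    return(structure(list(dtype = res$out$dtype), class = "toy_debug_box"))
  }
  # Graph-building mode: the single shared code path records the call.
  desc <- .toy_state$active
  desc$calls <- c(desc$calls, "mul")
  list(dtype = x$dtype)
}

# A small printer for the user-facing box.
print.toy_debug_box <- function(x, ...) {
  cat("DebugBox:", x$dtype, "\n")
  invisible(x)
}

t32 <- list(dtype = "f32")
dbg_mul(t32, t32)                                                    # debug mode
with_descriptor(function() dbg_mul(dbg_mul(t32, t32), t32))$calls   # tracing mode
```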