Low-level graph transformation that appends the backward pass to a
traced AnvilGraph. The function f represented by graph must return a single
float scalar. The resulting graph computes the gradients of that scalar with respect
to the inputs specified by wrt.
Each primitive's backward rule is stored in its $rules[["backward"]] entry.
This is the building block used by gradient() and value_and_gradient(); prefer
those higher-level wrappers unless you need to operate on graphs directly.
Arguments
- graph (AnvilGraph)
  The graph to transform. Must produce a single scalar float output.
- wrt (character)
  Names of the graph inputs to differentiate with respect to.
Value
An AnvilGraph whose outputs are the requested gradients.
Examples
graph <- trace_fn(nvl_mul, list(nv_aten("f32", c()), nv_aten("f32", c())))
graph
#> <AnvilGraph>
#> Inputs:
#> %x1: f32[]
#> %x2: f32[]
#> Body:
#> %1: f32[] = mul(%x1, %x2)
#> Outputs:
#> %1: f32[]
transform_gradient(graph, "lhs")
#> <AnvilGraph>
#> Inputs:
#> %x1: f32[]
#> %x2: f32[]
#> Constants:
#> %c1: f32[]
#> Body:
#> %1: f32[] = mul(%x1, %x2)
#> %2: f32[] = mul(%c1, %x2)
#> Outputs:
#> %2: f32[]
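
To make the appended line `%2 = mul(%c1, %x2)` easier to read, here is a minimal base-R sketch of what a backward rule for multiplication might compute. It does not use the package's API: `mul_backward` and `seed` are hypothetical names introduced only for illustration, and the reading of `%c1` as a seed gradient of 1 is an assumption, not something the output above confirms.

```r
# Hypothetical backward rule for out = lhs * rhs (illustration only,
# not the package's implementation). Given the gradient flowing into
# `out`, it returns the gradient with respect to each operand.
mul_backward <- function(lhs, rhs, grad_out) {
  list(
    lhs = grad_out * rhs,  # d(out)/d(lhs) = rhs
    rhs = grad_out * lhs   # d(out)/d(rhs) = lhs
  )
}

x1 <- 3
x2 <- 4

# Assumed seed: the gradient of the scalar output with respect to itself.
seed <- 1

grads <- mul_backward(x1, x2, seed)

# Keep only the gradient for the requested input, as wrt = "lhs" would.
grads[["lhs"]]
```

Under these assumptions the gradient with respect to the first operand is the seed times the second operand, which has the same shape as the `mul(%c1, %x2)` step in the transformed graph.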