Returns a new function that computes the gradient of f via reverse-mode automatic
differentiation. f must return a single float scalar. The returned function has the
same signature as f and returns the gradients in the same structure as the inputs
(or the subset selected by wrt).
Arguments
- f (function)
  Function to differentiate. Arguments can be tensorish (AnvilTensor) or static (non-tensor) values. Must return a single scalar float tensor.
- wrt (character or NULL)
  Names of the arguments to compute the gradient with respect to. Only tensorish (float tensor) arguments can be included; static arguments must not appear in wrt. If NULL (the default), the gradient is computed with respect to all arguments (which must all be tensorish in that case).
See also
value_and_gradient() to get both the output and gradients,
transform_gradient() for the low-level graph transformation.
Examples
f <- function(x, y) sum(x * y)
g <- jit(gradient(f))
g(nv_tensor(c(1, 2), dtype = "f32"), nv_tensor(c(3, 4), dtype = "f32"))
#> $x
#> AnvilTensor
#> 3
#> 4
#> [ CPUf32{2} ]
#>
#> $y
#> AnvilTensor
#> 1
#> 2
#> [ CPUf32{2} ]
#>
# Differentiate with respect to a single argument
g_x <- jit(gradient(f, wrt = "x"))
g_x(nv_tensor(c(1, 2), dtype = "f32"), nv_tensor(c(3, 4), dtype = "f32"))
#> $x
#> AnvilTensor
#> 3
#> 4
#> [ CPUf32{2} ]
#>
# Static (non-tensor) arguments are passed through but cannot be in wrt
f2 <- function(x, power) sum(x^power)
g2 <- jit(gradient(f2, wrt = "x"), static = "power")
g2(nv_tensor(c(1, 2, 3), dtype = "f32"), power = 2L)
#> $x
#> AnvilTensor
#> 2
#> 4
#> 6
#> [ CPUf32{3} ]
#>
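# Sanity check (a sketch in base R only; it makes no anvil calls): the gradient
# of sum(x^power) with respect to x is power * x^(power - 1), so the result
# above can be compared against a central finite difference.
# num_grad() and f2_r() are hypothetical helpers written for this illustration;
# they are not part of the package.
num_grad <- function(f, x, ..., eps = 1e-6) {
  vapply(seq_along(x), function(i) {
    # Perturb only the i-th component of x by eps
    h <- replace(numeric(length(x)), i, eps)
    (f(x + h, ...) - f(x - h, ...)) / (2 * eps)
  }, numeric(1))
}
f2_r <- function(x, power) sum(x^power)
# Should agree with the $x gradient above up to floating-point error
num_grad(f2_r, c(1, 2, 3), power = 2L)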