Skip to contents

Returns a new function that computes the gradient of f via reverse-mode automatic differentiation. f must return a single float scalar. The returned function has the same signature as f and returns the gradients in the same structure as the inputs (or the subset selected by wrt).

Usage

gradient(f, wrt = NULL)

Arguments

f

(function)
Function to differentiate. Arguments can be tensorish (AnvilTensor) or static (non-tensor) values. Must return a single scalar float tensor.

wrt

(character or NULL)
Names of the arguments to compute the gradient with respect to. Only tensorish (float tensor) arguments can be included; static arguments must not appear in wrt. If NULL (the default), the gradient is computed with respect to all arguments (which must all be tensorish in that case).

Value

function

See also

value_and_gradient() to get both the output and gradients, transform_gradient() for the low-level graph transformation.

Examples

f <- function(x, y) sum(x * y)
g <- jit(gradient(f))
g(nv_tensor(c(1, 2), dtype = "f32"), nv_tensor(c(3, 4), dtype = "f32"))
#> $x
#> AnvilTensor
#>  3
#>  4
#> [ CPUf32{2} ] 
#> 
#> $y
#> AnvilTensor
#>  1
#>  2
#> [ CPUf32{2} ] 
#> 

# Differentiate with respect to a single argument
g_x <- jit(gradient(f, wrt = "x"))
g_x(nv_tensor(c(1, 2), dtype = "f32"), nv_tensor(c(3, 4), dtype = "f32"))
#> $x
#> AnvilTensor
#>  3
#>  4
#> [ CPUf32{2} ] 
#> 

# Static (non-tensor) arguments are passed through but cannot be in wrt
f2 <- function(x, power) sum(x^power)
g2 <- jit(gradient(f2, wrt = "x"), static = "power")
g2(nv_tensor(c(1, 2, 3), dtype = "f32"), power = 2L)
#> $x
#> AnvilTensor
#>  2
#>  4
#>  6
#> [ CPUf32{3} ] 
#>