Converts a traced AnvilGraph into the StableHLO intermediate representation (IR).
Each graph operation is translated to its corresponding StableHLO op. The result can
be serialized to MLIR text via stablehlo::repr() and subsequently compiled to an
XLA executable with pjrt::pjrt_compile().
The rules for translating to stablehlo are stored in $rules[["stablehlo"]] of the primitives.
This is a low-level function; most users should use jit() or xla() instead.
Usage
stablehlo(graph, constants_as_inputs = TRUE, env = NULL, donate = character())Arguments
- graph
(
AnvilGraph)
The graph to lower (e.g. produced bytrace_fn()).- constants_as_inputs
(
logical(1))
IfTRUE(default), constants are registered as inputs to the StableHLO function so they can be passed in at execution time. IfFALSE, they are not added as inputs. Set toFALSEfor closures. Note thatGraphLiterals are always inlined into the StableHLO function.- env
(
HloEnv|NULL)
Optional environment for reusing variable mappings across nested function lowerings (e.g. for higher-order primitives likenv_while).- donate
(
character())
Names of the arguments whose buffers should be donated. Donated buffers can be aliased with outputs of the same type, enabling in-place operations.
Examples
x <- nv_tensor(c(1, 2))
graph <- trace_fn(function(y) y + x, list(y = nv_aten("f32", shape = c())))
graph
#> <AnvilGraph>
#> Inputs:
#> %x1: f32[]
#> Constants:
#> %c1: f32[2]
#> Body:
#> %1: f32[2] = broadcast_in_dim [shape = 2, broadcast_dimensions = <any>] (%x1)
#> %2: f32[2] = add(%1, %c1)
#> Outputs:
#> %2: f32[2]
stablehlo(graph)
#> [[1]]
#> func.func @main (%0: tensor<2xf32>, %1: tensor<f32>) -> tensor<2xf32> {
#> %2 = "stablehlo.broadcast_in_dim" (%1) {
#> broadcast_dimensions = array<i64>
#> }: (tensor<f32>) -> (tensor<2xf32>)
#> %3 = "stablehlo.add" (%2, %0): (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>)
#> "func.return"(%3): (tensor<2xf32>) -> ()
#> }
#>
#> [[2]]
#> [[2]][[1]]
#> GraphValue(ConcreteTensor(f32, (2)))
#>
#>