
General dot product of two arrays, supporting contraction over arbitrary dimensions and batching.

Usage

prim_dot_general(lhs, rhs, contracting_dims, batching_dims)

Arguments

lhs, rhs

(arrayish)
Left and right operands. Both are promoted to a common data type, and a scalar operand is broadcast to the shape of the other operand.

contracting_dims

(list(integer(), integer()))
A list of two integer vectors giving the dimensions of lhs and rhs to contract over. The i-th entries of the two vectors are paired, and each pair of contracted dimensions must have the same size.

batching_dims

(list(integer(), integer()))
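A list of two integer vectors giving the batch dimensions of lhs and rhs. The i-th entries of the two vectors are paired, and each pair of batch dimensions must have the same size. Use list(integer(0), integer(0)) for no batching.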
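The size constraints on contracting_dims and batching_dims, together with the output-shape rule under Value, can be checked without running the primitive. The following is a minimal sketch in base R; dot_general_shape() is a hypothetical helper, not part of the package, and it works only on the dim() vectors of the two operands (it ignores type promotion and scalar broadcasting).

```r
# Hypothetical helper (not part of the package): validate the documented
# size constraints and return the output shape they imply.
dot_general_shape <- function(lhs_dim, rhs_dim, contracting_dims, batching_dims) {
  lc <- contracting_dims[[1]]; rc <- contracting_dims[[2]]
  lb <- batching_dims[[1]];    rb <- batching_dims[[2]]
  # Contracted dimensions are paired position by position and must match in size.
  if (length(lc) != length(rc) || any(lhs_dim[lc] != rhs_dim[rc]))
    stop("contracting dimension sizes differ")
  # Batch dimensions are paired the same way and must also match in size.
  if (length(lb) != length(rb) || any(lhs_dim[lb] != rhs_dim[rb]))
    stop("batching dimension sizes differ")
  # Free dimensions: neither contracted nor batched, kept in their original order.
  lhs_free <- setdiff(seq_along(lhs_dim), c(lc, lb))
  rhs_free <- setdiff(seq_along(rhs_dim), c(rc, rb))
  # Output: batch dims, then free dims of lhs, then free dims of rhs.
  c(lhs_dim[lb], lhs_dim[lhs_free], rhs_dim[rhs_free])
}
```

For example, dot_general_shape(c(4L, 2L, 3L), c(4L, 3L, 5L), list(3L, 2L), list(1L, 1L)) returns c(4L, 2L, 5L): a batch of four (2 x 3) by (3 x 5) matrix products.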

Value

arrayish
The output shape is the batch dimensions, followed by the remaining (non-contracted, non-batched) dimensions of lhs, followed by the remaining dimensions of rhs.
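To make the value semantics concrete, here is a plain base-R sketch that computes the same result for ordinary R arrays. dot_general_ref() is a hypothetical name; this is not the package's implementation (which lowers to StableHLO), and it assumes numeric arrays of matching type, skipping promotion and scalar broadcasting.

```r
# Hypothetical base-R reference for the documented semantics (not the package's code).
dot_general_ref <- function(lhs, rhs, contracting_dims, batching_dims) {
  lhs <- as.array(lhs); rhs <- as.array(rhs)
  ld <- dim(lhs); rd <- dim(rhs)
  lc <- contracting_dims[[1]]; rc <- contracting_dims[[2]]
  lb <- batching_dims[[1]];    rb <- batching_dims[[2]]
  # Free (non-contracted, non-batched) dimensions, kept in original order.
  lf <- setdiff(seq_along(ld), c(lc, lb))
  rf <- setdiff(seq_along(rd), c(rc, rb))
  nb <- prod(ld[lb]); nc <- prod(ld[lc])
  nlf <- prod(ld[lf]); nrf <- prod(rd[rf])
  # Move axes to (free, contracted, batch) for lhs and (contracted, free, batch)
  # for rhs, then flatten each group so every batch slice is an ordinary matrix.
  L <- array(aperm(lhs, c(lf, lc, lb)), c(nlf, nc, nb))
  R <- array(aperm(rhs, c(rc, rf, rb)), c(nc, nrf, nb))
  out <- array(0, c(nlf, nrf, nb))
  for (b in seq_len(nb)) {
    # matrix() guards against dimension dropping when a group has size 1.
    out[, , b] <- matrix(L[, , b], nlf, nc) %*% matrix(R[, , b], nc, nrf)
  }
  # Unflatten, then reorder to: batch dims, lhs free dims, rhs free dims.
  out_dim <- c(ld[lf], rd[rf], ld[lb])
  if (length(out_dim) == 0L) return(out[[1L]])  # full contraction yields a scalar
  out <- array(out, out_dim)
  k_lf <- length(lf); k_rf <- length(rf)
  aperm(out, c(k_lf + k_rf + seq_along(lb), seq_len(k_lf), k_lf + seq_len(k_rf)))
}
```

With the inputs from the example below, dot_general_ref(matrix(1:6, nrow = 2), matrix(1:6, nrow = 3), list(2L, 1L), list(integer(0), integer(0))) gives the same values (22, 49 / 28, 64) as a base double matrix. The per-slice loop is chosen for readability: each batch index pairs one slice of lhs with the matching slice of rhs.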

Implemented Rules

  • stablehlo

  • quickr

  • reverse

StableHLO

Lowers to stablehlo::hlo_dot_general().


Examples

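# Ordinary matrix product: contract dimension 2 of x with dimension 1 of y
x <- nv_array(matrix(1:6, nrow = 2))  # 2 x 3
y <- nv_array(matrix(1:6, nrow = 3))  # 3 x 2
prim_dot_general(x, y,
  contracting_dims = list(2L, 1L),
  batching_dims = list(integer(0), integer(0))  # no batch dimensions
)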
#> AnvlArray
#>  22 49
#>  28 64
#> [ CPUi32{2,2} ]