Returns the k largest values along the last dimension, sorted in
descending order, together with their indices into that dimension.
To operate on another dimension, transpose the operand so that dimension is
last, call prim_top_k(), and transpose both results back; nv_top_k() does this
for you.
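The transpose pattern can be illustrated in base R. Here an ordinary array stands in for a tensor and aperm() plays the role of the transpose. This is only a sketch of the shape bookkeeping, not how nv_top_k() is implemented. The array x and the values of axis and k are illustrative.

```r
# Top-k along a non-last axis of a base-R array (illustrative only).
x <- array(seq_len(24), dim = c(2, 3, 4))
axis <- 2L   # reduce along the middle dimension
k <- 2L

perm <- c(setdiff(seq_along(dim(x)), axis), axis)  # c(1, 3, 2): axis moved last
xt <- aperm(x, perm)                               # dim 2 x 4 x 3

# Top-k values along the last dimension of xt. apply() puts the per-fiber
# result first (k x 2 x 4), so move that dimension back to the end (2 x 4 x k).
top <- apply(xt, c(1, 2), function(v) sort(v, decreasing = TRUE)[seq_len(k)])
top <- aperm(top, c(2, 3, 1))

# order(perm) is the inverse permutation, so this restores the original
# dimension order with the reduced axis now of size k: 2 x k x 4.
res <- aperm(top, order(perm))
dim(res)  # 2 2 4
```

The inverse permutation order(perm) is what "transpose back" means for an arbitrary rank. It puts every dimension back in place and leaves the reduced one with size k.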
Arguments
- operand (arrayish)
  Tensor of integer, unsigned integer, or floating-point dtype with rank >= 1.
- k (integer(1))
  Number of top elements. Must satisfy 1 <= k <= shape(operand)[ndims(operand)].
Value
A list of two arrayish values:
- the top-k values, with the same dtype as operand;
- their indices along the last dimension, with dtype i32 (matching JAX).
Both have the same shape as operand, except that the last dimension has size k.
Ties between equal values are broken in favor of the lower index.
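The contract above can be written down as a plain base-R reference. It can be useful for checking results on small inputs. top_k_last() is a hypothetical helper and not part of the package API. It also returns 1-based R positions as indices, which is the sketch's own choice. Whether the package's i32 indices are 0- or 1-based is not stated here.

```r
# Hypothetical reference (not the package API): top-k along the last
# dimension of a base-R array, following the documented contract.
top_k_last <- function(x, k) {
  x <- as.array(x)
  d <- dim(x)
  n <- d[length(d)]
  stopifnot(k >= 1, k <= n)
  # R arrays are column-major, so each row of m is one fiber along the
  # last dimension.
  m <- matrix(x, ncol = n)
  idx <- matrix(0L, nrow(m), k)
  for (r in seq_len(nrow(m))) {
    # Descending by value; ties broken by ascending position.
    idx[r, ] <- order(-m[r, ], seq_len(n))[seq_len(k)]
  }
  vals <- matrix(m[cbind(as.vector(row(idx)), as.vector(idx))], nrow(m))
  out_dim <- c(d[-length(d)], k)  # last dimension replaced by k
  list(values = array(vals, out_dim), indices = array(idx, out_dim))
}

x <- matrix(c(1, 4, 3, 4, 2, 1), nrow = 2)  # rows: (1, 3, 2) and (4, 4, 1)
res <- top_k_last(x, 2)
res$values   # rows: (3, 2) and (4, 4)
res$indices  # rows: (2, 3) and (1, 2); the tied 4s keep their original order
```

Sorting on the pair (negated value, position) makes the tie-break explicit. It does not rely on whether a particular sort method is stable.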
StableHLO
Lowers to stablehlo::hlo_top_k().