Returns an array identical to input, except that the slices at the
positions given by scatter_indices are updated with values from the
update array. Each update is combined with the existing value by
update_computation (by default the new value replaces the old one);
when multiple indices point to the same location, the function is
applied once for each of them.
This is the inverse of prim_gather(): gather reads slices from an array
at given indices, while scatter writes slices into an array at given
indices.
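For intuition, the relationship can be mimicked with base R indexing in the one-dimensional case. This is only an analogy on plain R vectors, not the package's implementation, and the variable names are illustrative.

```r
# Base R analogy for a 1-D scatter followed by a gather (illustrative only).
input   <- c(0, 0, 0, 0, 0)
indices <- c(1L, 3L)
update  <- c(10, 30)

# Scatter: write the update values into a copy of input at the given indices.
scattered <- input
scattered[indices] <- update

# Gather: read the values back from the same indices.
gathered <- scattered[indices]

print(scattered)
print(gathered)
```

Gathering at the indices that were just scattered to recovers the update values, which is the sense in which the two operations are inverses.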
Usage
prim_scatter(
  input,
  scatter_indices,
  update,
  update_window_dims,
  inserted_window_dims,
  input_batching_dims,
  scatter_indices_batching_dims,
  scatter_dims_to_operand_dims,
  index_vector_dim,
  indices_are_sorted = FALSE,
  unique_indices = FALSE,
  update_computation = NULL
)
Arguments
- input
  (arrayish)
  Arrayish value of any data type. The base array to scatter into.
- scatter_indices
  (arrayish of integer type)
  Array of indices. Contains index vectors that map to positions in input via scatter_dims_to_operand_dims. The dimension specified by index_vector_dim holds the index vectors.
- update
  (arrayish)
  Update values array. Must have the same data type as input.
- update_window_dims
  (integer())
  Dimensions of update that are window dimensions, i.e. they correspond to the slice being written into input.
- inserted_window_dims
  (integer())
  Dimensions of input whose slices have size 1 and are inserted (not present) in the update window. Together with update_window_dims and input_batching_dims, these must account for all dimensions of input.
- input_batching_dims
  (integer())
  Dimensions of input that are batch dimensions. Use integer(0) when there are no batch dimensions.
- scatter_indices_batching_dims
  (integer())
  Dimensions of scatter_indices that correspond to batch dimensions. Must have the same length as input_batching_dims.
- scatter_dims_to_operand_dims
  (integer())
  Maps each component of the index vector to an input dimension. For example, scatter_dims_to_operand_dims = c(1L) means each index vector indexes into the first dimension of input.
- index_vector_dim
  (integer(1))
  Dimension of scatter_indices that contains the index vectors. If set to ndims(scatter_indices) + 1, each scalar element of scatter_indices is treated as a length-1 index vector.
- indices_are_sorted
  (logical(1))
  Whether indices are guaranteed to be sorted. Setting to TRUE may improve performance but produces undefined behavior if the indices are not actually sorted. Default FALSE.
- unique_indices
  (logical(1))
  Whether indices are guaranteed to be unique (no duplicates). Setting to TRUE may improve performance but produces undefined behavior if the indices are not actually unique. Default FALSE.
- update_computation
  (function)
  Binary function f(old, new) that combines the existing value in input with the value from update. The default (NULL) uses function(old, new) new, which replaces the old value.
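Two of the constraints stated in the argument descriptions can be checked before calling the function. The following base R sketch does so; check_scatter_dims() is a hypothetical helper that is not part of the package, and it assumes that "account for all dimensions of input" means the three lengths add up to the rank of input.

```r
# Hypothetical helper (not part of the package): checks two constraints
# taken from the argument descriptions.
check_scatter_dims <- function(input_rank,
                               update_window_dims,
                               inserted_window_dims,
                               input_batching_dims,
                               scatter_indices_batching_dims) {
  # Assumption: the counts of window, inserted, and batch dimensions
  # must add up to the rank of input.
  accounted <- length(update_window_dims) +
    length(inserted_window_dims) +
    length(input_batching_dims)
  dims_ok <- accounted == input_rank

  # Batch dimensions of input and scatter_indices must have equal length.
  batch_ok <- length(input_batching_dims) ==
    length(scatter_indices_batching_dims)

  dims_ok && batch_ok
}

# Rank-1 input with scalar updates: one inserted dim, no window or batch dims.
ok <- check_scatter_dims(1L, integer(0), 1L, integer(0), integer(0))

# Rank-2 input, but only one dimension is accounted for.
bad <- check_scatter_dims(2L, integer(0), 1L, integer(0), integer(0))

print(c(ok = ok, bad = bad))
```

The first call mirrors the settings used for a vector input with scalar updates; the second fails because a rank-2 input needs one more window, inserted, or batch dimension.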
Value
arrayish
Has the same data type and shape as input.
The result is ambiguous if input is ambiguous.
Out Of Bounds Behavior
If a computed result index falls outside the bounds of input, the
update for that index is silently ignored.
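The rule can be modelled in base R for a one-dimensional input. This is a sketch of the documented behavior only; scatter_ignore_oob() is a hypothetical function, not the package's implementation.

```r
# Hypothetical base R model of the out-of-bounds rule for a 1-D input.
scatter_ignore_oob <- function(input, indices, update) {
  result <- input
  for (k in seq_along(indices)) {
    i <- indices[k]
    # Skip updates whose target index lies outside 1..length(input).
    if (i < 1L || i > length(result)) next
    result[i] <- update[k]
  }
  result
}

# Index 7 is out of bounds for a length-5 input, so the value 99 is dropped.
res <- scatter_ignore_oob(c(0, 0, 0, 0, 0), c(2L, 7L), c(10, 99))
print(res)
```

Plain base R assignment such as x[7] <- 99 would instead grow the vector, which is why this model needs an explicit bounds check.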
Update Order
When multiple indices in scatter_indices map to the same element
of input, the order in which update_computation is applied is
implementation-defined and may vary between plugins ("cpu", "cuda").
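To see why the order matters, the sketch below applies two updates to the same element in both possible orders. It is a base R model with a hypothetical function scatter_in_order(); the real operation does not expose an order argument.

```r
# Hypothetical base R model: apply updates one at a time in a given order.
scatter_in_order <- function(input, indices, update, update_computation, order) {
  result <- input
  for (k in order) {
    i <- indices[k]
    result[i] <- update_computation(result[i], update[k])
  }
  result
}

input   <- c(0, 0, 0)
indices <- c(2L, 2L)   # both updates target element 2
update  <- c(10, 30)

replace_fn <- function(old, new) new        # the documented default
add_fn     <- function(old, new) old + new  # an accumulating alternative

# With replacement, the result depends on which update is applied last.
rep_fwd <- scatter_in_order(input, indices, update, replace_fn, order = c(1, 2))
rep_rev <- scatter_in_order(input, indices, update, replace_fn, order = c(2, 1))

# With addition, both orders give the same result here.
add_fwd <- scatter_in_order(input, indices, update, add_fn, order = c(1, 2))
add_rev <- scatter_in_order(input, indices, update, add_fn, order = c(2, 1))

print(rbind(rep_fwd, rep_rev, add_fwd, add_rev))
```

When duplicate indices are possible, an order-insensitive update_computation such as addition gives reproducible results in this model, whereas the default replacement does not.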
StableHLO
Lowers to stablehlo::hlo_scatter().
See also
prim_gather(), nv_subset(), nv_subset_assign(), [, [<-
Examples
# Scatter values 10 and 30 into positions 1 and 3 of a zero vector
input <- nv_array(c(0, 0, 0, 0, 0))
indices <- nv_array(matrix(c(1L, 3L), ncol = 1))
updates <- nv_array(c(10, 30))
prim_scatter(
  input, indices, updates,
  update_window_dims = integer(0),
  inserted_window_dims = 1L,
  input_batching_dims = integer(0),
  scatter_indices_batching_dims = integer(0),
  scatter_dims_to_operand_dims = 1L,
  index_vector_dim = 2L
)
#> AnvlArray
#> 10
#> 0
#> 30
#> 0
#> 0
#> [ CPUf32{5} ]
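As a companion to the vector example, the following base R sketch models scattering whole rows into a matrix, which is the situation in which update has a window dimension. It is an illustration on plain R matrices rather than a call into the package. In prim_scatter() terms this would presumably correspond to update_window_dims = 2L and inserted_window_dims = 1L, but that mapping is an assumption and is not verified here.

```r
# Base R model (illustrative only) of scattering whole rows into a matrix.
input   <- matrix(0, nrow = 4, ncol = 2)
indices <- c(1L, 3L)                             # target rows of input
update  <- matrix(c(10, 30, 11, 31), nrow = 2)   # rows: (10, 11) and (30, 31)

result <- input
for (k in seq_along(indices)) {
  # Row k of update is the window written into row indices[k] of input.
  result[indices[k], ] <- update[k, ]
}
print(result)
```

Here the first dimension of input is selected by the index and therefore does not appear in the window, while the second dimension is carried along as the window of each update.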