This document is relevant for: Trn1, Trn2, Trn3
nki.collectives.all_to_all_v
nki.collectives.all_to_all_v(srcs, dsts, replica_group, metadata_tensor, recv_counts_known=False, has_rdispls=False)
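Perform a variable-length all-to-all exchange among the ranks of the given replica group, reading from the input tensor and writing to the output tensor. Each rank may send a different number of elements to each rank in the group.
Unlike all_to_all, which splits and concatenates along collective_dim, all_to_all_v treats each tensor as a flat buffer of elements. Counts and displacements in the metadata tensor are therefore measured in elements of the tensor flattened in row-major order, not in slices along a particular dimension.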
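To make the flat-buffer semantics concrete, the following host-side NumPy model simulates the data movement for a whole group. It does not call NKI; the function names are hypothetical, and it assumes segments are packed back to back (displacements equal to exclusive prefix sums of the counts), which is only one of the layouts the metadata tensor can express.

```python
import numpy as np


def exclusive_prefix_sum(counts):
    """Offsets of densely packed segments: [0, c0, c0 + c1, ...]."""
    counts = np.asarray(counts, dtype=np.int64)
    return np.concatenate(([0], np.cumsum(counts)[:-1]))


def reference_all_to_all_v(send_bufs, send_counts):
    """Hypothetical host-side model of a variable-length all-to-all.

    send_bufs[r]      -- data held by rank r (any shape, read in row-major order)
    send_counts[r][p] -- number of elements rank r sends to rank p

    Assumes packed segments: in destination-rank order on the send side
    and in source-rank order on the receive side.
    """
    world_size = len(send_bufs)
    send_counts = np.asarray(send_counts, dtype=np.int64)
    # Flatten every buffer in row-major (C) order; shape is irrelevant.
    flat = [np.asarray(b).reshape(-1) for b in send_bufs]
    recv_bufs = []
    for p in range(world_size):
        recv_counts = send_counts[:, p]  # elements arriving at rank p from each rank r
        recv_displs = exclusive_prefix_sum(recv_counts)
        out = np.zeros(int(recv_counts.sum()), dtype=flat[0].dtype)
        for r in range(world_size):
            send_displs = exclusive_prefix_sum(send_counts[r])
            n = send_counts[r, p]
            src = flat[r][send_displs[p]:send_displs[p] + n]
            out[recv_displs[r]:recv_displs[r] + n] = src
        recv_bufs.append(out)
    return recv_bufs


# Rank 0 holds a 2x3 tensor; its first segment (2 elements) does not
# align with a row boundary, which is fine for element-based counts.
out = reference_all_to_all_v(
    [np.arange(6).reshape(2, 3), np.arange(10, 16)],
    [[2, 4], [5, 1]],
)
print(out[0].tolist())  # [0, 1, 10, 11, 12, 13, 14]
print(out[1].tolist())  # [2, 3, 4, 5, 15]
```

Because counts are in elements, rank 0 can send 2 elements of a 2x3 tensor to itself even though that cuts through the first row; all_to_all, which works in slices along collective_dim, could not express this split.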
- Parameters:
srcs – List containing the input tensor to redistribute (must contain exactly one tensor)
dsts – List containing the output tensor that receives the results (must contain exactly one tensor)
replica_group – ReplicaGroup defining rank groups for the collective
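metadata_tensor – Metadata tensor of shape (R, world_size), where R is between 2 and 4, with dtype uint32. Row 0: send counts. Row 1: send displacements. Row 2 (optional, present when recv_counts_known is True): receive counts. Row 3 (optional, present when has_rdispls is True): receive displacements. All values are in elements.
recv_counts_known – If True, the metadata tensor includes receive counts in row 2. Defaults to False.
has_rdispls – If True, the metadata tensor includes receive displacements in row 3. Defaults to False.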
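The row layout above can be illustrated by building the metadata values on the host. The sketch below is a hypothetical NumPy helper, not part of nki.collectives; it assumes packed segments (displacements are exclusive prefix sums) and, because this page does not say how rows are arranged when has_rdispls is set without recv_counts_known, it simply rejects that combination. Moving the result into an NKI tensor is not shown.

```python
import numpy as np


def build_metadata(count_matrix, rank, recv_counts_known=False, has_rdispls=False):
    """Hypothetical helper: per-rank metadata values for all_to_all_v.

    count_matrix[r][p] is the number of elements rank r sends to rank p.
    Displacements are exclusive prefix sums, i.e. segments are assumed to
    be packed back to back in rank order.
    """
    if has_rdispls and not recv_counts_known:
        # Row placement for this combination is not documented here, so the
        # sketch only builds layouts whose rows are contiguous from row 0.
        raise ValueError("sketch requires recv_counts_known when has_rdispls is set")
    counts = np.asarray(count_matrix, dtype=np.int64)

    def prefix(c):
        return np.concatenate(([0], np.cumsum(c)[:-1]))

    send_counts = counts[rank]                 # row 0: elements sent to each rank
    rows = [send_counts, prefix(send_counts)]  # row 1: send displacements
    if recv_counts_known:
        recv_counts = counts[:, rank]          # row 2: elements received from each rank
        rows.append(recv_counts)
        if has_rdispls:
            rows.append(prefix(recv_counts))   # row 3: receive displacements
    return np.stack(rows).astype(np.uint32)


meta = build_metadata([[2, 4], [5, 1]], rank=1, recv_counts_known=True, has_rdispls=True)
print(meta.shape)     # (4, 2)
print(meta.tolist())  # [[5, 1], [0, 5], [4, 1], [0, 4]]
```

Each rank's receive counts are a column of the global count matrix, so supplying rows 2 and 3 only requires knowing what every peer intends to send to this rank.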