This document is relevant for: Trn2, Trn3
nki.collectives.all_to_all_v
- nki.collectives.all_to_all_v(srcs: List, dsts: List, replica_group: ReplicaGroup, metadata_tensor, recv_counts_known: bool = False, has_rdispls: bool = False, priority: Optional[int] = None) -> None
Executes an all-to-all collective where each rank can send a different number of elements, known only at execution time (rather than at compile time).
Unlike all_to_all, which splits/concatenates along a collective dimension, all_to_all_v treats tensors as flat element buffers. Per-rank send/recv counts and displacements are supplied via a uint32 metadata tensor, making per-rank payload sizes dynamic.
Current restrictions:
On instances with a NeuronSwitch fabric (see Trn3 architecture), all_to_all_v requires LNC=2 and more than one participating device. Multiple ranks per device are supported, but for every replica-group rank-list, every device participating in that rank-list must have all of its ranks (4 under LNC=2) included in the same rank-list. Each rank-list is a set of sequential ranks in the world (e.g. [[1, 2, 3, 4], [5, 6, 7, 8]]). To exclude a rank, keep it in the replica group and set its send_count to 0.
On other instances, all_to_all_v currently supports only inter-node replica groups: each rank-list contains same-indexed ranks from different nodes (a node refers to a different Trn EC2 instance).
- Parameters:
srcs – Input tensor list. Currently supports exactly one tensor. Must be HBM-backed.
dsts – Output tensor list. Currently supports exactly one tensor. Must be HBM-backed.
src and dst element counts can be different; sizes are validated against the metadata at execution time.
replica_group – ReplicaGroup defining which ranks participate.
metadata_tensor – uint32 tensor laid out contiguously in memory. Shape depends on the backing buffer, where rows is 3 when has_rdispls=False and 4 when has_rdispls=True:
HBM: (rows, replica_group_size).
SBUF: (1, rows, replica_group_size). The whole buffer must live on a single partition, so a trivial partition dim is prepended.
For each other rank r in the replica group, the rows are:
Row 0 (send_counts[r]): number of elements sent to rank r. Always an input.
Row 1 (send_displs[r]): offset in elements within src where the chunk destined for rank r begins. Always an input.
Row 2 (recv_counts[r]): number of elements received from rank r. Controlled by recv_counts_known; see that flag.
Row 3 (recv_displs[r]): offset in elements within dst where the chunk from rank r is written. Only present when has_rdispls=True.
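The row layout above can be sketched on the host with plain Python lists. This is an illustrative model only: build_metadata is a hypothetical helper, not part of NKI, and it assumes the chunks are packed back-to-back in src, so that send_displs is the exclusive prefix sum of send_counts. Converting the result into a uint32 tensor in HBM or SBUF is deliberately left out.

```python
from typing import List, Optional

UINT32_MAX = 2**32 - 1


def build_metadata(send_counts: List[int],
                   recv_displs: Optional[List[int]] = None) -> List[List[int]]:
    """Hypothetical host-side helper returning rows shaped (rows, replica_group_size)."""
    n = len(send_counts)

    # Assumption: chunks are packed contiguously in src, so each displacement
    # is the running total of the counts that precede it.
    send_displs = []
    offset = 0
    for count in send_counts:
        send_displs.append(offset)
        offset += count

    # Row 2 is never read as input, so it is zero-filled as a placeholder.
    recv_counts = [0] * n

    rows = [list(send_counts), send_displs, recv_counts]
    if recv_displs is not None:  # corresponds to has_rdispls=True (4 rows)
        if len(recv_displs) != n:
            raise ValueError("recv_displs needs one entry per rank")
        rows.append(list(recv_displs))

    # Every value must be representable as uint32.
    for row in rows:
        if any(not 0 <= value <= UINT32_MAX for value in row):
            raise ValueError("metadata values must fit in uint32")
    return rows


# Four ranks; rank 1 is effectively excluded by a send count of 0.
meta = build_metadata([3, 0, 5, 2])
meta_with_rdispls = build_metadata([1, 2], recv_displs=[0, 4])
```

With these inputs, meta has three rows and meta_with_rdispls has four, matching the has_rdispls=False and has_rdispls=True shapes respectively.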
recv_counts_known – Controls whether row 2 is populated by the collective during execution. Row 2 is never read as input.
True: row 2 is left untouched, avoiding a small per-rank writeback.
False (default): row 2 is an output. Per-rank received counts are written during execution and can be read after the op to learn received sizes.
has_rdispls –
True: row 3 is an input; recv_displs must be populated. The chunk from sender rank r is written at dst[recv_displs[r] : recv_displs[r] + recv_counts[r]].
False: row 3 may be omitted from metadata_tensor (pass a 3-row tensor). Incoming chunks are laid out equally spaced at recv_displs[r] = (dst.total_elements / replica_group_size) * r, regardless of the actual recv_count per rank.
priority – DMA QoS priority level, from 0 to 3; lower values mean higher priority (NeuronCore-v4+ only).
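To make the data movement concrete, the sketch below models the semantics described above in plain Python, with no Neuron hardware or NKI calls involved. simulate_all_to_all_v is a hypothetical name. It assumes the usual all-to-all-v convention that the count rank k receives from rank r equals the count rank r sends to rank k, and it covers only the has_rdispls=False layout, using integer division for the equal spacing. The slot-overflow check is likewise an assumption about what execution-time validation might reject.

```python
from typing import List, Tuple


def simulate_all_to_all_v(
    srcs: List[List[int]],
    send_counts: List[List[int]],
    send_displs: List[List[int]],
    dst_elements: int,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical reference model; index [k][r] means 'on rank k, for peer r'."""
    n = len(srcs)
    stride = dst_elements // n  # equal spacing used when has_rdispls=False
    # Untouched dst elements are modelled as 0; this is a modelling choice only.
    dsts = [[0] * dst_elements for _ in range(n)]
    recv_counts = [[0] * n for _ in range(n)]

    for sender in range(n):
        for receiver in range(n):
            count = send_counts[sender][receiver]
            start = send_displs[sender][receiver]
            chunk = srcs[sender][start:start + count]
            if len(chunk) != count:
                raise ValueError("send count/displacement exceeds src size")
            if count > stride:
                # Assumption: a chunk may not spill into the next rank's slot.
                raise ValueError("chunk does not fit in its equally spaced slot")
            offset = stride * sender
            dsts[receiver][offset:offset + count] = chunk
            # Models the row 2 writeback performed when recv_counts_known=False.
            recv_counts[receiver][sender] = count
    return dsts, recv_counts


# Two ranks, each with a flat src buffer and per-peer counts/displacements.
srcs = [[10, 11, 12], [20, 21, 22, 23]]
send_counts = [[1, 2], [3, 1]]
send_displs = [[0, 1], [0, 3]]
dsts, recv_counts = simulate_all_to_all_v(srcs, send_counts, send_displs, dst_elements=6)
```

In this example each dst buffer has 6 elements and 2 ranks, so chunks from rank 0 start at offset 0 and chunks from rank 1 start at offset 3, whatever their actual length.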