from dataclasses import dataclass
from enum import Enum
from typing import *
from nki.language import NKIObject
@dataclass
class ReplicaGroup(NKIObject):
    r"""Defines a group of ranks that participate in a collective operation.

    Sub-groups represented by lists of ranks should not have any overlap.
    """

    # Stub: no fields or methods are declared in this listing.
    ...
def all_reduce(
    srcs: List,
    dsts: List,
    replica_group: ReplicaGroup,
    op,
    priority: Optional[int] = None,
) -> None:
    r"""Perform an all-reduce on the given replica group and input/output tensors.

    The ``srcs`` and ``dsts`` parameters accept lists of tensors to support
    coalesced collective communication, which allows multiple tensors to be
    reduced in a single collective operation for improved efficiency.

    Tensors can reside on either HBM or SBUF. However, mixing memory spaces is
    not supported: all tensors must be on HBM or all must be on SBUF. Coalesced
    collective communication (multiple tensors) is only supported when tensors
    are on HBM.

    :param srcs: List of input tensors to reduce
    :param dsts: List of output tensors to store results
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param op: The reduction operation to perform (``nl.add``, ``nl.minimum``,
        or ``nl.maximum``)
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def all_gather(
    srcs: List,
    dsts: List,
    replica_group: ReplicaGroup,
    collective_dim: int,
    priority: Optional[int] = None,
) -> None:
    r"""Perform an all-gather on the given replica group and input/output tensors.

    The ``srcs`` and ``dsts`` parameters accept lists of tensors to support
    coalesced collective communication, which allows multiple tensors to be
    gathered in a single collective operation for improved efficiency.

    Tensors can reside on either HBM or SBUF. However, mixing memory spaces is
    not supported: all tensors must be on HBM or all must be on SBUF. Coalesced
    collective communication (multiple tensors) is only supported when tensors
    are on HBM.

    :param srcs: List of input tensors to gather
    :param dsts: List of output tensors to store results
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param collective_dim: Dimension along which output tensors are
        concatenated. Currently only 0 is supported for HBM tensors. For SBUF
        tensors, 0 or 1 is supported as SBUF collectives currently only operate
        on 2D tensors with a single free dimension.
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def reduce_scatter(
    srcs: List,
    dsts: List,
    replica_group: ReplicaGroup,
    collective_dim: int,
    op,
    priority: Optional[int] = None,
) -> None:
    r"""Perform a reduce-scatter on the given replica group and input/output tensors.

    The ``srcs`` and ``dsts`` parameters accept lists of tensors to support
    coalesced collective communication, which allows multiple tensors to be
    reduced and scattered in a single collective operation for improved
    efficiency.

    Tensors can reside on either HBM or SBUF. However, mixing memory spaces is
    not supported: all tensors must be on HBM or all must be on SBUF. Coalesced
    collective communication (multiple tensors) is only supported when tensors
    are on HBM.

    :param srcs: List of input tensors to reduce and scatter
    :param dsts: List of output tensors to store results
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param collective_dim: Dimension along which input tensors are split.
        Currently only 0 is supported for both HBM and SBUF tensors.
    :param op: The reduction operation to perform (``nl.add``, ``nl.minimum``,
        or ``nl.maximum``)
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def all_to_all(
    srcs: List,
    dsts: List,
    replica_group: ReplicaGroup,
    collective_dim: int,
    priority: Optional[int] = None,
) -> None:
    r"""Perform an all-to-all on the given replica group and input/output tensors.

    The ``srcs`` and ``dsts`` parameters accept lists of tensors to support
    coalesced collective communication, which allows multiple tensors to be
    redistributed in a single collective operation for improved efficiency.

    Tensors must reside on HBM. SBUF is not currently supported for all-to-all.

    :param srcs: List of input tensors to redistribute
    :param dsts: List of output tensors to store results
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param collective_dim: Dimension along which input tensors are split and
        output tensors are concatenated. Currently only 0 is supported.
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def all_to_all_v(
    srcs: List,
    dsts: List,
    replica_group: ReplicaGroup,
    metadata_tensor,
    recv_counts_known: bool = False,
    has_rdispls: bool = False,
    priority: Optional[int] = None,
) -> None:
    r"""Execute an all-to-all collective with per-rank element counts known only at execution time.

    Each rank can send a different number of elements, known only at execution
    time (rather than at compile time).

    Unlike ``all_to_all`` which splits/concatenates along a collective
    dimension, ``all_to_all_v`` treats tensors as flat element buffers.
    Per-rank send/recv counts and displacements are supplied via a uint32
    metadata tensor, making per-rank payload sizes dynamic.

    **Current restrictions:**

    On instances with a NeuronSwitch fabric (see `Trn3 architecture
    <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/about-neuron/arch/neuron-hardware/trn3-arch.html>`_),
    ``all_to_all_v`` requires LNC=2 and more than one participating
    device. Multiple ranks per device are supported, but for every
    replica-group rank-list, every device participating in that
    rank-list must have all of its ranks (4 under LNC=2) included in
    the same rank-list — each rank-list is a set of sequential ranks
    in the world (e.g. ``[[1, 2, 3, 4], [5, 6, 7, 8]]``). To exclude a
    rank, keep it in the replica group and set its ``send_count`` to 0.

    On other instances, ``all_to_all_v`` currently supports only
    inter-node replica groups: each rank-list contains same-indexed
    ranks from different nodes (a node refers to a different Trn EC2
    instance).

    :param srcs: Input tensor list. Currently supports exactly one tensor.
        Must be HBM-backed.
    :param dsts: Output tensor list. Currently supports exactly one tensor.
        Must be HBM-backed. ``src`` and ``dst`` element counts can be
        different; sizes are validated against the metadata at execution
        time.
    :param replica_group: ReplicaGroup defining which ranks participate.
    :param metadata_tensor: ``uint32`` tensor laid out contiguously in
        memory. Shape depends on backing buffer, where ``rows`` is 3 when
        ``has_rdispls=False`` and 4 when ``has_rdispls=True``:

        - HBM: ``(rows, replica_group_size)``.
        - SBUF: ``(1, rows, replica_group_size)`` — the whole buffer must
          live on a single partition, so a trivial partition dim is
          prepended.

        For each other rank ``r`` in the replica group, the rows are:

        - Row 0 ``send_counts[r]``: number of elements sent to rank ``r``.
          Always an input.
        - Row 1 ``send_displs[r]``: offset in elements within ``src`` where
          the chunk destined for rank ``r`` begins. Always an input.
        - Row 2 ``recv_counts[r]``: number of elements received from rank
          ``r``. Controlled by ``recv_counts_known`` — see that flag.
        - Row 3 ``recv_displs[r]``: offset in elements within ``dst`` where
          the chunk from rank ``r`` is written. Only present when
          ``has_rdispls=True``.
    :param recv_counts_known:
        Controls whether row 2 is populated by the collective during
        execution. Row 2 is never read as input.

        - ``True``: row 2 is left untouched, avoiding a small per-rank
          writeback.
        - ``False`` (default): row 2 is an **output** — per-rank received
          counts are written during execution, and can be read after the
          op to learn received sizes.
    :param has_rdispls:
        - ``True``: row 3 is an **input**; recv_displs must be populated.
          The chunk from sender rank ``r`` is written at
          ``dst[recv_displs[r] : recv_displs[r] + recv_counts[r]]``.
        - ``False``: row 3 may be omitted from ``metadata_tensor`` (pass a
          3-row tensor). Incoming chunks are laid out equally-spaced at
          ``recv_displs[r] = (dst.total_elements / replica_group_size) * r``,
          regardless of the actual recv_count per rank.
    :param priority: DMA QoS priority level 0-3, where a lower value means
        higher priority (NeuronCore-v4+ only).
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def collective_permute(
    srcs: List,
    dsts: List,
    source_target_pairs: List[Tuple[int, int]],
    priority: Optional[int] = None,
) -> None:
    r"""Send and receive data between ranks based on explicitly defined source-target pairs.

    Each pair ``(source, target)`` specifies that data from the source rank
    should be sent to the target rank. This gives you full control over the
    communication pattern (e.g., pairwise swaps, arbitrary shuffles).

    Prefer :func:`collective_permute_implicit` when the communication
    follows a ring topology, as the hardware can optimize that pattern.

    Tensors must reside on HBM. SBUF is not currently supported for
    collective_permute. Coalesced collective communication (multiple tensors)
    is not currently supported; each list parameter must contain exactly one
    tensor.

    :param srcs: List of source tensors to send
    :param dsts: List of destination tensors to receive into
    :param source_target_pairs: List of (source, target) rank ID pairs
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def collective_permute_implicit(
    srcs_by_channel: List[List],
    dsts_by_channel: List[List],
    replica_group: ReplicaGroup,
    channel_ids: List[int] = [0],
    priority: Optional[int] = None,
) -> None:
    r"""Send and receive data between ranks in a ring.

    Sources and destinations are implicitly determined by the ring structure
    during runtime. Each rank sends data to its successor and receives from its
    predecessor in the ring. This differs from :func:`collective_permute` where
    users explicitly specify source-target pairs.

    Since the sources and destinations are implicitly determined, use
    :func:`collective_permute_implicit_current_processing_rank_id` to get the
    rank ID whose data is currently being processed.

    The outer dimension of ``srcs_by_channel`` and ``dsts_by_channel``
    corresponds to channels. For each channel, the inner list contains exactly
    one tensor (coalesced collective communication is not currently supported).

    **Channels**: Multiple channels enable overlapping communication, allowing
    concurrent data transfers. The number of available channels depends on the
    replica group and system connectivity (see
    `Neuron Collectives <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/about/collectives.html#system-connectivity>`_).
    The maximum number of channels is 4 for replica groups containing all
    devices inside a node and 2 for other supported replica groups.

    :param srcs_by_channel: List of source tensor lists, one per channel. Each
        inner list must contain exactly one tensor.
    :param dsts_by_channel: List of destination tensor lists, one per channel.
        Each inner list must contain exactly one tensor.
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param channel_ids: List of channel IDs to use for communication (default
        ``[0]`` for single channel). Currently must be consecutive integers
        starting from 0.
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # NOTE(review): ``channel_ids`` uses a mutable list default. It is kept
    # as-is to preserve the published signature; nothing in this stub mutates
    # it, but any real implementation behind it must not either.
    ...
def collective_permute_implicit_reduce(
    srcs0_by_channel: List[List],
    srcs1_by_channel: List[List],
    dsts_by_channel: List[List],
    replica_group: ReplicaGroup,
    op,
    channel_ids: List[int] = [0],
    priority: Optional[int] = None,
) -> None:
    r"""Perform an implicit collective permute with reduction in a ring.

    Sources and destinations are implicitly determined by the ring structure
    during runtime. Combines :func:`collective_permute_implicit` with a
    reduction operation. Each rank reduces its local sources using
    ``op(srcs0_by_channel[i], srcs1_by_channel[i])``, sends the result to its
    successor, and receives its predecessor's reduced result into
    ``dsts_by_channel[i]``.

    Since the sources and destinations are implicitly determined, use
    :func:`collective_permute_implicit_current_processing_rank_id` to get the
    rank ID whose data is currently being processed.

    The outer dimension of ``srcs0_by_channel``, ``srcs1_by_channel``, and
    ``dsts_by_channel`` corresponds to channels. For each channel, the inner
    list contains exactly one tensor (coalesced collective communication is not
    currently supported).

    **Channels**: Multiple channels enable overlapping communication, allowing
    concurrent data transfers. The number of available channels depends on the
    replica group and system connectivity (see
    `Neuron Collectives <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/about/collectives.html#system-connectivity>`_).
    The maximum number of channels is 4 for replica groups containing all
    devices inside a node and 2 for other supported replica groups.

    :param srcs0_by_channel: List of source tensor lists (left operand of
        reduction), one per channel. Each inner list must contain exactly one
        tensor.
    :param srcs1_by_channel: List of source tensor lists (right operand of
        reduction), one per channel. Each inner list must contain exactly one
        tensor.
    :param dsts_by_channel: List of destination tensor lists to receive
        predecessor's reduced result, one per channel. Each inner list must
        contain exactly one tensor.
    :param replica_group: ReplicaGroup defining rank groups for the collective
    :param op: The reduction operation to perform (``nl.add``, ``nl.minimum``,
        or ``nl.maximum``)
    :param channel_ids: List of channel IDs to use for communication (default
        ``[0]`` for single channel). Currently must be consecutive integers
        starting from 0.
    :param priority: DMA quality-of-service priority level 0-3, where a lower
        value means higher priority (NeuronCore-v4+ only)
    """
    # NOTE(review): ``channel_ids`` uses a mutable list default. It is kept
    # as-is to preserve the published signature; nothing in this stub mutates
    # it, but any real implementation behind it must not either.
    ...
def collective_permute_implicit_current_processing_rank_id(
    iteration_id: int,
    replica_group: ReplicaGroup,
    channel_id: int = 0,
):
    r"""Returns the rank ID of the data to be processed in the current ring iteration.

    This function is intended to be used in conjunction with
    :func:`collective_permute_implicit` or
    :func:`collective_permute_implicit_reduce`. Since the sources and
    destinations are implicitly determined in ring algorithms, the rank ID of
    received data can only be determined at runtime.

    At iteration 0, this returns the current rank's own ID (processing local
    data). In subsequent iterations, it returns the rank ID of data received
    from predecessors, progressing around the ring.

    The returned rank ID is a scalar register. To determine the offset of the
    received data chunk within a tensor, use register ALU operations (e.g.,
    multiply the rank ID by chunk size), then use dynamic access pattern
    (``tensor.ap()``) in ISA compute operations (e.g., ``nisa.nc_matmul()``).

    **Typical usage pattern**: In each iteration of a ring algorithm, the
    compute kernel uses this function to identify which rank's data is being
    processed, computes on that data while concurrently triggering the next
    communication step to send already-computed chunks to the successor.

    **Channels**: Multiple channels enable overlapping communication, allowing
    concurrent data transfers. The number of available channels depends on the
    replica group and system connectivity (see
    `Neuron Collectives <https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/about/collectives.html#system-connectivity>`_).
    The maximum number of channels is 4 for replica groups containing all
    devices inside a node and 2 for other supported replica groups.

    :param iteration_id: Current ring step (typically the loop counter).
    :param replica_group: ReplicaGroup defining the ring topology
    :param channel_id: Channel ID for the communication (0 to num_channels-1)
    :return: Scalar register containing the rank ID of the data to be processed
    """
    # Stub: this listing carries only the signature and documentation.
    ...
def rank_id():
    r"""Get the rank ID of the current rank.

    :return: The rank ID of the current rank within the collective group
    """
    # Stub: this listing carries only the signature and documentation.
    ...