/************************************************************************* * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ #ifndef NCCL_PRIMITIVES_H_ #define NCCL_PRIMITIVES_H_ #include #include "reduce_kernel.h" // for reduction funcs #include "common_kernel.h" #include "common.h" #define NCCL_SPINS_BEFORE_CHECK_ABORT 10000 /* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128 * We use these as template args to the Primtiives class instead of integral * enums (e.g. NCCL_PROTO_LL) because for SIMPLE we need to carry a few extra * numbers. Also these types hold methods which let us compute numbers important * to how that protocol operates with a consistent interface so that our * algorithm code can operate protocol parametrically. */ template struct ProtoSimple { static constexpr int Id = NCCL_PROTO_SIMPLE; static constexpr int SlicePerChunk = SlicePerChunk_1; static constexpr int StepPerSlice = StepPerSlice_1; static constexpr int Unroll = Unroll_1; static constexpr int MultimemSrcs = MultimemSrcs_1; static constexpr int MultimemDsts = MultimemDsts_1; // Data bytes (no flags etc) in one step of the fifo queue. __device__ static int calcBytePerStep() { return ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS; } // Granularity of data bytes transferred per thread. __device__ static int calcBytePerGrain() { return sizeof(uint64_t); // Bogus value? Nobody queries this metric for simple. } // Group width is how many consecutive group values a subchannel occupies. static constexpr int MaxGroupWidth = 2; }; struct ProtoLL { static constexpr int Id = NCCL_PROTO_LL; // Data bytes (no flags etc) in one step of the fifo queue. __device__ static int calcBytePerStep() { return ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/2; // Half is data } // Granularity of data bytes transferred per thread. __device__ static int calcBytePerGrain() { return sizeof(uint64_t); // One 16-byte line has 8-bytes of data } // Group width is how many consecutive group values a subchannel occupies. static constexpr int MaxGroupWidth = 1; }; struct ProtoLL128 { static constexpr int Id = NCCL_PROTO_LL128; // Data bytes (no flags etc) in one step of the fifo queue. __device__ static int calcBytePerStep() { return (ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS)*NCCL_LL128_DATAELEMS/NCCL_LL128_LINEELEMS; } // Granularity of data bytes transferred per thread. __device__ static int calcBytePerGrain() { return NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_DATAELEMS*sizeof(uint64_t)/NCCL_LL128_LINEELEMS; } // Group width is how many consecutive group values a subchannel occupies. static constexpr int MaxGroupWidth = 1; }; /* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template * arguments are static bounds on the maximum values. Asymmetric counts are * independent. Symmetric is a static guarantee that nrecv==nsend, so it only * stores one value at runtime. This optimization save 32-bit register, but more * importantly uses fewer predicate registers when unrolling loops. */ template struct FanAsymmetric { static constexpr int MaxRecv = MaxRecv_, MaxSend = MaxSend_; int nr, ns; FanAsymmetric() = default; __device__ FanAsymmetric(int nrecv, int nsend): nr(nrecv), ns(nsend) { // assert(nrecv <= MaxRecv && nsend <= MaxSend); } __device__ int nrecv() const { return MaxRecv ? nr : 0; } __device__ int nsend() const { return MaxSend ? ns : 0; } }; template struct FanSymmetric { static constexpr int MaxRecv = MaxArity, MaxSend = MaxArity; int n; FanSymmetric() = default; __device__ FanSymmetric(int nrecv, int nsend): n(nrecv) { // assert(nrecv == nsend && nrecv <= MaxArity); } __device__ int nrecv() const { return n; } __device__ int nsend() const { return n; } }; // The primitives class. Specialized per protocol in the other headers. template class Primitives; // Used by LL & LL128 to implement direct members in the naive way. template struct PrimitivesWithoutDirect { __device__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) { static_cast(this)->send(inpIx, eltN); } __device__ void directSendFromOutput(intptr_t outIx, int eltN) { static_cast(this)->sendFromOutput(outIx, eltN); } __device__ void directRecv(intptr_t outIx, int eltN) { static_cast(this)->recv(outIx, eltN, /*postOp=*/false); } __device__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { static_cast(this)->copySend(inpIx, outIx, eltN, postOp); } __device__ void directRecvCopyDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { static_cast(this)->recvCopySend(outIx, eltN, /*postOp=*/false); } __device__ void directRecvDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { return; } __device__ void recvReduceCopyDirectSend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { // Direct is only for the send part static_cast(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); } __device__ __forceinline__ void directRecvReduceDirectSend(intptr_t inpIx, intptr_t outIx, ssize_t eltN, bool postOp=false) { static_cast(this)->recvReduceSend(inpIx, eltN); } __device__ __forceinline__ void directRecvReduceCopyDirectSend(intptr_t inpIx, intptr_t outIx, ssize_t eltN, bool postOp=false) { static_cast(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); } }; __device__ inline int checkAbort(int &abortCache, const int abortValue, int &spins) { if (abortCache & abortValue) return 1; if (++spins < NCCL_SPINS_BEFORE_CHECK_ABORT) return 0; spins = 0; int abort = *ncclShmem.comm.abortFlag; if (abort) { ncclShmem.aborted = abort; abortCache |= abortValue; } return abort; } #include "prims_simple.h" #include "prims_ll.h" #include "prims_ll128.h" #endif