/*************************************************************************
 * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "device.h"
#include "collectives.h"
#include "primitives.h"

namespace {
  template<typename T, typename RedOp, typename Proto>
  __device__ __forceinline__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
    ncclRing *ring = &ncclShmem.channel.ring;
    const int rank = ring->userRanks[0];
    const int nextRank = ring->userRanks[1];
    const int root = work->root;
    ssize_t chunkCount;
    ssize_t channelCount;
    ssize_t gridOffset;
    // Compute this channel's share of the collective: starting offset, element
    // count, and the chunk granularity used by the loop below.
    ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), (ssize_t*)nullptr, &gridOffset, &channelCount, &chunkCount);
    size_t offset;
    int nelem;
    int workNthreads;
    bool isNetOffload = work->isOneRPN && work->netRegUsed;

    T *inputBuf = (T*)work->sendbuff;
    T *outputBuf = (T*)work->recvbuff;
    // With network offload, a single warp is enough to drive the registered
    // network transfers; the remaining warps are freed up to perform the local
    // input->output copy on the root rank (see the reduceCopy branch below).
    workNthreads = isNetOffload ? WARP_SIZE : nthreads;

    if (tid < workNthreads) {
      // Coverity reports that the callee treats &ring->next as an array. However, due to the use of
      // FanSymmetric<1>, only the first element is ever accessed, so it's fine.
      // coverity[callee_ptr_arith:FALSE]
      Primitives<T, RedOp, FanSymmetric<1>, 1, Proto, 0>
        prims(tid, workNthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, 0, 0, work);
      for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) {
        offset = gridOffset + elemOffset;
        nelem = min(chunkCount, channelCount - elemOffset);
        if (rank == root) {
          if (inputBuf == outputBuf || isNetOffload) {
            prims.directSend(offset, offset, nelem);
          } else {
            // Out-of-place broadcast: copy into the root's output buffer while
            // sending the chunk down the ring.
            prims.directCopySend(offset, offset, nelem);
          }
        } else if (nextRank == root) {
          // Last rank in the ring: receive only, never forward back to the root.
          prims.directRecv(offset, nelem);
        } else {
          // Intermediate rank: receive a chunk and forward it to the next rank.
          prims.directRecvCopyDirectSend(offset, offset, nelem);
        }
      }
    } else if (inputBuf != outputBuf && rank == root) {
      // Network-offload path: the warps not driving the network perform the
      // root's local input->output copy.
      inputBuf = inputBuf + gridOffset;
      outputBuf = outputBuf + gridOffset;
      reduceCopy<COLL_UNROLL, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
        (tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, channelCount);
    }
    // Re-join the network warp and the copy warps before the work completes.
    if (isNetOffload) barrier_sync(14, nthreads);
  }
}

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
    runRing<T, RedOp, Proto>(tid, nthreads, work);
  }
};

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    runRing<T, RedOp, ProtoLL>(tid, nthreads, work);
  }
};

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    runRing<T, RedOp, ProtoLL128>(tid, nthreads, work);
  }
};