/*************************************************************************
 * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "device.h"
#include "collectives.h"
#include "primitives.h"

namespace {
  template<typename T, typename RedOp, typename Proto>
  __device__ __forceinline__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
    ncclRing *ring = &ncclShmem.channel.ring;
    const int nranks = ncclShmem.comm.nRanks;
    const int rank = ncclShmem.comm.rank;
    const int prevRank = ring->userRanks[nranks-1];
    const int root = work->root;
    size_t chunkCount;
    size_t channelCount;
    size_t gridOffset;
    // Compute this channel's slice of the buffer (gridOffset/channelCount) and the
    // per-iteration chunk size for the given protocol and element type.
    ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount);
    size_t offset;
    int nelem;

    // Coverity reports that the callee treats &ring->next as an array. However, due to the use of
    // FanSymmetric<1>, only the first element is ever accessed, so it's fine.
    // coverity[callee_ptr_arith:FALSE]
    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
      prims(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg);

    if (prevRank == root) {
      // First rank in the chain (the root's successor in the ring): there is no
      // upstream data to reduce yet, so just send the local chunk downstream.
      for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) {
        offset = gridOffset + elemOffset;
        nelem = min(chunkCount, channelCount - elemOffset);
        prims.send(offset, nelem);
      }
    }
    else if (rank == root) {
      // Root (last in the chain): receive the running reduction, fold in the local
      // chunk, and write the final result to the output buffer.
      for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) {
        offset = gridOffset + elemOffset;
        nelem = min(chunkCount, channelCount - elemOffset);
        prims.recvReduceCopy(offset, offset, nelem, /*postOp=*/true);
      }
    }
    else {
      // Intermediate rank: receive the running reduction, fold in the local chunk,
      // and forward the partial result to the next rank in the ring.
      for (size_t elemOffset = 0; elemOffset < channelCount; elemOffset += chunkCount) {
        offset = gridOffset + elemOffset;
        nelem = min(chunkCount, channelCount - elemOffset);
        prims.recvReduceSend(offset, nelem);
      }
    }
  }
}

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    using Proto = ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>;
    runRing<T, RedOp, Proto>(tid, nthreads, work);
  }
};

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    runRing<T, RedOp, ProtoLL>(tid, nthreads, work);
  }
};

template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    runRing<T, RedOp, ProtoLL128>(tid, nthreads, work);
  }
};
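
// Usage sketch (not part of this file): the RunWorkColl specializations above are
// reached through NCCL's device-side dispatch when the host calls the public
// ncclReduce() API and NCCL selects the RING algorithm with the SIMPLE, LL, or
// LL128 protocol. A minimal, hypothetical host-side caller — assuming an
// already-initialized communicator `comm` and stream `stream` — might look like:
//
//   size_t count = 1 << 20;
//   float *sendbuf, *recvbuf;
//   cudaMalloc(&sendbuf, count * sizeof(float));   // per-rank input
//   cudaMalloc(&recvbuf, count * sizeof(float));   // result, significant on root only
//   // Sum-reduce all ranks' sendbuf into rank 0's recvbuf.
//   ncclReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, 0, comm, stream);
//   cudaStreamSynchronize(stream);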