#pragma once

// NOTE(review): the original `#include` target was lost in this copy of the
// file. CUTLASS_HOST_DEVICE is presumably supplied by <cutlass/cutlass.h> --
// confirm against the build. The header is pulled in only when it is
// available so that this file stays usable on its own.
#if defined(__has_include)
#if __has_include(<cutlass/cutlass.h>)
#include <cutlass/cutlass.h>
#endif
#endif

// Fallback so the sorter still compiles when no definition was provided.
#ifndef CUTLASS_HOST_DEVICE
#if defined(__CUDACC__)
#define CUTLASS_HOST_DEVICE __host__ __device__ inline
#else
#define CUTLASS_HOST_DEVICE inline
#endif
#endif

/**
 * A Functor class to create a sort for fixed sized arrays/containers with a
 * compile time generated Bose-Nelson sorting network.
 *
 * Elements are accessed through operator[] with int indices and ordered with
 * operator<; the sort is performed in place.
 *
 * \tparam NumElements  The number of elements in the array or container to
 *                      sort.
 */
template <unsigned NumElements>
class StaticSort {
  /**
   * Compare-exchange of two slots of the indexable object: after
   * construction a[i0] holds the smaller and a[i1] the larger value.
   */
  template <class A>
  struct Swap {
    template <class T>
    CUTLASS_HOST_DEVICE void s(T& v0, T& v1) {
      // Explicitly code out the Min and Max to nudge the compiler
      // to generate branchless code.
      T t = v0 < v1 ? v0 : v1;  // Min
      v1 = v0 < v1 ? v1 : v0;   // Max
      v0 = t;
    }

    CUTLASS_HOST_DEVICE Swap(A& a, const int& i0, const int& i1) {
      s(a[i0], a[i1]);
    }
  };

  /**
   * Bose-Nelson "bracket" step: merges the sorted run of X elements starting
   * at 1-based position I with the sorted run of Y elements starting at
   * 1-based position J. The general case splits both runs and recurses; the
   * specializations below are the base cases.
   */
  template <class A, int I, int J, int X, int Y>
  struct PB {
    CUTLASS_HOST_DEVICE PB(A& a) {
      enum {
        L = X >> 1,
        M = (X & 1 ? Y : Y + 1) >> 1,
        IAddL = I + L,
        XSubL = X - L
      };
      PB<A, I, J, L, M> p0(a);
      PB<A, IAddL, J + M, XSubL, Y - M> p1(a);
      PB<A, IAddL, J, XSubL, M> p2(a);
    }
  };

  // Base case: one element against one element.
  template <class A, int I, int J>
  struct PB<A, I, J, 1, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) { Swap<A> s(a, I - 1, J - 1); }
  };

  // Base case: one element against a sorted pair.
  template <class A, int I, int J>
  struct PB<A, I, J, 1, 2> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J);
      Swap<A> s1(a, I - 1, J - 1);
    }
  };

  // Base case: a sorted pair against one element.
  template <class A, int I, int J>
  struct PB<A, I, J, 2, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J - 1);
      Swap<A> s1(a, I, J - 1);
    }
  };

  /**
   * Bose-Nelson "star" step: sorts the M elements starting at 1-based
   * position I by sorting each half and then merging the halves with PB.
   * Stop selects the empty specialization once a run has at most one element.
   */
  template <class A, int I, int M, bool Stop = false>
  struct PS {
    CUTLASS_HOST_DEVICE PS(A& a) {
      enum { L = M >> 1, IAddL = I + L, MSubL = M - L };
      PS<A, I, L, (L <= 1)> ps0(a);
      PS<A, IAddL, MSubL, (MSubL <= 1)> ps1(a);
      PB<A, I, IAddL, L, MSubL> pb(a);
    }
  };

  // Recursion terminator: a run of zero or one element is already sorted.
  template <class A, int I, int M>
  struct PS<A, I, M, true> {
    CUTLASS_HOST_DEVICE PS(A&) {}
  };

 public:
  /**
   * Sorts the array/container arr.
   * \param  arr  The array/container to be sorted.
   */
  template <class Container>
  CUTLASS_HOST_DEVICE void operator()(Container& arr) const {
    PS<Container, 1, NumElements, (NumElements <= 1)> ps(arr);
  }

  /**
   * Sorts the array arr.
   * \param  arr  The array to be sorted.
   */
  template <class T>
  CUTLASS_HOST_DEVICE void operator()(T* arr) const {
    PS<T*, 1, NumElements, (NumElements <= 1)> ps(arr);
  }
};