// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "dispatch_utils.h" // ===================================================================================================================== // Keyword interpretation // ---------------------------------------------------------------- // 1c/2c: The number of channels. 2c means two inputs and two outputs. // Cached, Uncached: Whether cosine and sine are calculated in kernel. Cached means kernel can read these value from // memory rather than calculate these value according to the given theta in memory. // ReuseFreqsFrontPart: Normally, freqs/cos/sin tensors should be repeated before conduct the RoPE operators. With // this value set as true, the repeat is no longer required. Kernel can automatically relocate the // desired element. // sbhd: Shape of tensor: [sequence length, batch size, head count, hidden dimension]. // thd: Shape of tensor. // 2d: 2D image. // NopeFirst: [0, size_r(rotate dim)) is rotated and the rest is just copied if this value is false. // [size_d (size of d dim) - size_r, size_d) is rotated and the front part is just copied if true. // #define ROTATE_STYLE_NEOX 0 #define ROTATE_STYLE_GPTJ 1 // ===================================================================================================================== // Kernel Helper Functions // template __device__ __forceinline__ void load_cos_sin_uncached( float* p_cos_0, float* p_sin_0, float* p_cos_1, float* p_sin_1, const scalar_f_t* __restrict__ p_freqs, const int32_t did, const int32_t size_half_r) { if constexpr (RotateStyle == ROTATE_STYLE_NEOX) { if constexpr (IsForward) { sincosf(float(p_freqs[did]), p_sin_0, p_cos_0); if constexpr (ReuseFreqsFrontPart) { *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { sincosf(float(p_freqs[did + size_half_r]), p_sin_1, p_cos_1); } } else { const float f_did_0 = float(p_freqs[did]); if constexpr (ReuseFreqsFrontPart) { *p_cos_0 = cosf(f_did_0); *p_sin_0 = sinf(f_did_0); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { const float f_did_1 = p_freqs[did + size_half_r]; *p_cos_0 = cosf(f_did_0); *p_sin_0 = sinf(f_did_1); *p_cos_1 = cosf(f_did_1); *p_sin_1 = sinf(f_did_0); } } } else if constexpr (RotateStyle == ROTATE_STYLE_GPTJ) { if constexpr (IsForward) { if constexpr (ReuseFreqsFrontPart) { sincosf(float(p_freqs[did]), p_sin_0, p_cos_0); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { sincosf(float(p_freqs[did * 2]), p_sin_0, p_cos_0); sincosf(float(p_freqs[did * 2 + 1]), p_sin_1, p_cos_1); } } else { if constexpr (ReuseFreqsFrontPart) { sincosf(float(p_freqs[did]), p_sin_0, p_cos_0); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { const float f_did_0 = float(p_freqs[did * 2]); const float f_did_1 = float(p_freqs[did * 2 + 1]); *p_cos_0 = cosf(f_did_0); *p_sin_0 = sinf(f_did_1); *p_cos_1 = cosf(f_did_1); *p_sin_1 = sinf(f_did_0); } } } } template __device__ __forceinline__ void load_cos_sin_cached( float* p_cos_0, float* p_sin_0, float* p_cos_1, float* p_sin_1, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t did, const int32_t size_half_r) { if constexpr (RotateStyle == ROTATE_STYLE_NEOX) { if constexpr (IsForward) { *p_cos_0 = float(p_cos[did]); *p_sin_0 = float(p_sin[did]); if constexpr (ReuseFreqsFrontPart) { *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { *p_cos_1 = float(p_cos[did + size_half_r]); *p_sin_1 = float(p_sin[did + size_half_r]); } } else { if constexpr (ReuseFreqsFrontPart) { *p_cos_0 = float(p_cos[did]); *p_sin_0 = float(p_sin[did]); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { *p_cos_0 = float(p_cos[did]); *p_sin_0 = float(p_sin[did + size_half_r]); *p_cos_1 = float(p_cos[did + size_half_r]); *p_sin_1 = float(p_sin[did]); } } } else if constexpr (RotateStyle == ROTATE_STYLE_GPTJ) { if constexpr (IsForward) { if constexpr (ReuseFreqsFrontPart) { *p_cos_0 = float(p_cos[did]); *p_sin_0 = float(p_sin[did]); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { *p_cos_0 = float(p_cos[did * 2]); *p_sin_0 = float(p_sin[did * 2]); *p_cos_1 = float(p_cos[did * 2 + 1]); *p_sin_1 = float(p_sin[did * 2 + 1]); } } else { if constexpr (ReuseFreqsFrontPart) { *p_cos_0 = float(p_cos[did]); *p_sin_0 = float(p_sin[did]); *p_cos_1 = *p_cos_0; *p_sin_1 = *p_sin_0; } else { *p_cos_0 = float(p_cos[did * 2]); *p_sin_0 = float(p_sin[did * 2 + 1]); *p_cos_1 = float(p_cos[did * 2 + 1]); *p_sin_1 = float(p_sin[did * 2]); } } } } template __device__ __forceinline__ void get_offset( int32_t* p_offset_0, int32_t* p_offset_1, const int32_t did, const int32_t hid, const int32_t stride_d, const int32_t stride_h, const int32_t size_half_r) { const int32_t offset_h = hid * stride_h; if constexpr (RotateStyle == ROTATE_STYLE_NEOX) { *p_offset_0 = offset_h + did * stride_d; *p_offset_1 = *p_offset_0 + size_half_r * stride_d; } else if constexpr (RotateStyle == ROTATE_STYLE_GPTJ) { *p_offset_0 = offset_h + 2 * did * stride_d; if constexpr (StrideDEq1) { // Asking compiler to merge memory ops when accessing adjacent elements. *p_offset_1 = *p_offset_0 + 1; } else { *p_offset_1 = *p_offset_0 + stride_d; } } } template __device__ __forceinline__ void load_payload( o_scalar_t* p_data_0, o_scalar_t* p_data_1, const i_scalar_t* p_buffer, const int32_t did, const int32_t hid, const int32_t stride_d, const int32_t stride_h, const int32_t size_half_r) { int32_t offset_0, offset_1; get_offset(&offset_0, &offset_1, did, hid, stride_d, stride_h, size_half_r); *p_data_0 = o_scalar_t(p_buffer[offset_0]); *p_data_1 = o_scalar_t(p_buffer[offset_1]); } template __device__ __forceinline__ void store_payload( o_scalar_t* p_buffer, const i_scalar_t data_0, const i_scalar_t data_1, const int32_t did, const int32_t hid, const int32_t stride_d, const int32_t stride_h, const int32_t size_half_r) { int32_t offset_0, offset_1; get_offset(&offset_0, &offset_1, did, hid, stride_d, stride_h, size_half_r); p_buffer[offset_0] = o_scalar_t(data_0); p_buffer[offset_1] = o_scalar_t(data_1); } template __device__ __forceinline__ void elementwise_copy( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const int32_t hid_end, const int32_t did_start, const int32_t did_end, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_h, const int32_t stride_o_d) { if (did_end > did_start) { #pragma unroll for (int32_t hid = threadIdx.y; hid < hid_end; hid += blockDim.y) { const int32_t offset_i = hid * stride_i_h; const int32_t offset_o = hid * stride_o_h; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < did_end; did += blockDim.x) { p_output[offset_o + did * stride_o_d] = p_input[offset_i + did * stride_i_d]; } } } } template __device__ __forceinline__ void elementwise_copy_2c( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const int32_t hid_end_x, const int32_t hid_end_y, const int32_t did_start, const int32_t did_end, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_h, const int32_t stride_oy_d) { if (did_end > did_start) { const int32_t hid_min_end = hid_end_x < hid_end_y ? hid_end_x : hid_end_y; #pragma unroll for (int32_t hid = threadIdx.y; hid < hid_min_end; hid += blockDim.y) { const int32_t offset_ix = hid * stride_ix_h; const int32_t offset_iy = hid * stride_iy_h; const int32_t offset_ox = hid * stride_ox_h; const int32_t offset_oy = hid * stride_oy_h; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < did_end; did += blockDim.x) { p_output_x[offset_ox + did * stride_ox_d] = p_input_x[offset_ix + did * stride_ix_d]; p_output_y[offset_oy + did * stride_oy_d] = p_input_y[offset_iy + did * stride_iy_d]; } } #pragma unroll for (int32_t hid = threadIdx.y + hid_min_end; hid < hid_end_x; hid += blockDim.y) { const int32_t offset_ix = hid * stride_ix_h; const int32_t offset_ox = hid * stride_ox_h; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < did_end; did += blockDim.x) { p_output_x[offset_ox + did * stride_ox_d] = p_input_x[offset_ix + did * stride_ix_d]; } } #pragma unroll for (int32_t hid = threadIdx.y + hid_min_end; hid < hid_end_y; hid += blockDim.y) { const int32_t offset_iy = hid * stride_iy_h; const int32_t offset_oy = hid * stride_oy_h; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < did_end; did += blockDim.x) { p_output_y[offset_oy + did * stride_oy_d] = p_input_y[offset_iy + did * stride_iy_d]; } } } } // ===================================================================================================================== // Kernel Functionalities // struct OpUncachedFwd { template __device__ __forceinline__ static void apply_1c( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_h, const int32_t stride_o_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_uncached( &cos_0, &sin_0, &cos_1, &sin_1, p_freqs, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_h; hid += blockDim.y) { float input_0, input_1; load_payload( &input_0, &input_1, p_input, did, hid, stride_i_d, stride_i_h, size_half_r); const float output_0 = input_0 * cos_0 - input_1 * sin_0; const float output_1 = input_1 * cos_1 + input_0 * sin_1; store_payload( p_output, output_0, output_1, did, hid, stride_o_d, stride_o_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy( p_output, p_input, size_h, did_start, did_end, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } } template __device__ __forceinline__ static void apply_2c( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_h, const int32_t stride_oy_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; const int32_t size_min_h = min(size_h_x, size_h_y); #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_uncached( &cos_0, &sin_0, &cos_1, &sin_1, p_freqs, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_min_h; hid += blockDim.y) { float input_x_0, input_x_1, input_y_0, input_y_1; load_payload( &input_x_0, &input_x_1, p_input_x, did, hid, stride_ix_d, stride_ix_h, size_half_r); load_payload( &input_y_0, &input_y_1, p_input_y, did, hid, stride_iy_d, stride_iy_h, size_half_r); const float output_x_0 = input_x_0 * cos_0 - input_x_1 * sin_0; const float output_x_1 = input_x_1 * cos_1 + input_x_0 * sin_1; const float output_y_0 = input_y_0 * cos_0 - input_y_1 * sin_0; const float output_y_1 = input_y_1 * cos_1 + input_y_0 * sin_1; store_payload( p_output_x, output_x_0, output_x_1, did, hid, stride_ox_d, stride_ox_h, size_half_r); store_payload( p_output_y, output_y_0, output_y_1, did, hid, stride_oy_d, stride_oy_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_x; hid += blockDim.y) { float input_x_0, input_x_1; load_payload( &input_x_0, &input_x_1, p_input_x, did, hid, stride_ix_d, stride_ix_h, size_half_r); const float output_x_0 = input_x_0 * cos_0 - input_x_1 * sin_0; const float output_x_1 = input_x_1 * cos_1 + input_x_0 * sin_1; store_payload( p_output_x, output_x_0, output_x_1, did, hid, stride_ox_d, stride_ox_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_y; hid += blockDim.y) { float input_y_0, input_y_1; load_payload( &input_y_0, &input_y_1, p_input_y, did, hid, stride_iy_d, stride_iy_h, size_half_r); const float output_y_0 = input_y_0 * cos_0 - input_y_1 * sin_0; const float output_y_1 = input_y_1 * cos_1 + input_y_0 * sin_1; store_payload( p_output_y, output_y_0, output_y_1, did, hid, stride_oy_d, stride_oy_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy_2c( p_output_x, p_output_y, p_input_x, p_input_y, size_h_x, size_h_y, did_start, did_end, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } } }; struct OpUncachedBwd { template __device__ __forceinline__ static void apply_1c( scalar_t* __restrict__ p_input_grads, const scalar_t* __restrict__ p_output_grads, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, const int32_t stride_o_h, const int32_t stride_o_d, const int32_t stride_i_h, const int32_t stride_i_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_uncached( &cos_0, &sin_0, &cos_1, &sin_1, p_freqs, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_h; hid += blockDim.y) { float output_grad_0, output_grad_1; load_payload( &output_grad_0, &output_grad_1, p_output_grads, did, hid, stride_o_d, stride_o_h, size_half_r); const float input_grad_0 = output_grad_0 * cos_0 + output_grad_1 * sin_0; const float input_grad_1 = output_grad_1 * cos_1 - output_grad_0 * sin_1; store_payload( p_input_grads, input_grad_0, input_grad_1, did, hid, stride_i_d, stride_i_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy( p_input_grads, p_output_grads, size_h, did_start, did_end, stride_o_h, stride_o_d, stride_i_h, stride_i_d); } } template __device__ __forceinline__ static void apply_2c( scalar_t* __restrict__ p_input_grads_x, scalar_t* __restrict__ p_input_grads_y, const scalar_t* __restrict__ p_output_grads_x, const scalar_t* __restrict__ p_output_grads_y, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_h, const int32_t stride_oy_d, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_h, const int32_t stride_iy_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; const int32_t size_min_h = min(size_h_x, size_h_y); #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_uncached( &cos_0, &sin_0, &cos_1, &sin_1, p_freqs, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_min_h; hid += blockDim.y) { float output_grad_x_0, output_grad_x_1, output_grad_y_0, output_grad_y_1; load_payload( &output_grad_x_0, &output_grad_x_1, p_output_grads_x, did, hid, stride_ox_d, stride_ox_h, size_half_r); load_payload( &output_grad_y_0, &output_grad_y_1, p_output_grads_y, did, hid, stride_oy_d, stride_oy_h, size_half_r); const float input_grad_x_0 = output_grad_x_0 * cos_0 + output_grad_x_1 * sin_0; const float input_grad_x_1 = output_grad_x_1 * cos_1 - output_grad_x_0 * sin_1; const float input_grad_y_0 = output_grad_y_0 * cos_0 + output_grad_y_1 * sin_0; const float input_grad_y_1 = output_grad_y_1 * cos_1 - output_grad_y_0 * sin_1; store_payload( p_input_grads_x, input_grad_x_0, input_grad_x_1, did, hid, stride_ix_d, stride_ix_h, size_half_r); store_payload( p_input_grads_y, input_grad_y_0, input_grad_y_1, did, hid, stride_iy_d, stride_iy_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_x; hid += blockDim.y) { float output_grad_x_0, output_grad_x_1; load_payload( &output_grad_x_0, &output_grad_x_1, p_output_grads_x, did, hid, stride_ox_d, stride_ox_h, size_half_r); const float input_grad_x_0 = output_grad_x_0 * cos_0 + output_grad_x_1 * sin_0; const float input_grad_x_1 = output_grad_x_1 * cos_1 - output_grad_x_0 * sin_1; store_payload( p_input_grads_x, input_grad_x_0, input_grad_x_1, did, hid, stride_ix_d, stride_ix_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_y; hid += blockDim.y) { float output_grad_y_0, output_grad_y_1; load_payload( &output_grad_y_0, &output_grad_y_1, p_output_grads_y, did, hid, stride_oy_d, stride_oy_h, size_half_r); const float input_grad_y_0 = output_grad_y_0 * cos_0 + output_grad_y_1 * sin_0; const float input_grad_y_1 = output_grad_y_1 * cos_1 - output_grad_y_0 * sin_1; store_payload( p_input_grads_y, input_grad_y_0, input_grad_y_1, did, hid, stride_iy_d, stride_iy_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy_2c( p_input_grads_x, p_input_grads_y, p_output_grads_x, p_output_grads_y, size_h_x, size_h_y, did_start, did_end, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d); } } }; struct OpCachedFwd { template __device__ __forceinline__ static void apply_1c( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h, const int32_t size_d, const int32_t size_f, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_h, const int32_t stride_o_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_cached( &cos_0, &sin_0, &cos_1, &sin_1, p_cos, p_sin, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_h; hid += blockDim.y) { float input_0, input_1; load_payload( &input_0, &input_1, p_input, did, hid, stride_i_d, stride_i_h, size_half_r); const float output_0 = input_0 * cos_0 - input_1 * sin_0; const float output_1 = input_1 * cos_1 + input_0 * sin_1; store_payload( p_output, output_0, output_1, did, hid, stride_o_d, stride_o_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy( p_output, p_input, size_h, did_start, did_end, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } } template __device__ __forceinline__ static void apply_2c( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_h, const int32_t stride_oy_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; const int32_t size_min_h = min(size_h_x, size_h_y); #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_cached( &cos_0, &sin_0, &cos_1, &sin_1, p_cos, p_sin, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_min_h; hid += blockDim.y) { float input_x_0, input_x_1, input_y_0, input_y_1; load_payload( &input_x_0, &input_x_1, p_input_x, did, hid, stride_ix_d, stride_ix_h, size_half_r); load_payload( &input_y_0, &input_y_1, p_input_y, did, hid, stride_iy_d, stride_iy_h, size_half_r); const float output_x_0 = input_x_0 * cos_0 - input_x_1 * sin_0; const float output_x_1 = input_x_1 * cos_1 + input_x_0 * sin_1; const float output_y_0 = input_y_0 * cos_0 - input_y_1 * sin_0; const float output_y_1 = input_y_1 * cos_1 + input_y_0 * sin_1; store_payload( p_output_x, output_x_0, output_x_1, did, hid, stride_ox_d, stride_ox_h, size_half_r); store_payload( p_output_y, output_y_0, output_y_1, did, hid, stride_oy_d, stride_oy_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_x; hid += blockDim.y) { float input_x_0, input_x_1; load_payload( &input_x_0, &input_x_1, p_input_x, did, hid, stride_ix_d, stride_ix_h, size_half_r); const float output_x_0 = input_x_0 * cos_0 - input_x_1 * sin_0; const float output_x_1 = input_x_1 * cos_1 + input_x_0 * sin_1; store_payload( p_output_x, output_x_0, output_x_1, did, hid, stride_ox_d, stride_ox_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_y; hid += blockDim.y) { float input_y_0, input_y_1; load_payload( &input_y_0, &input_y_1, p_input_y, did, hid, stride_iy_d, stride_iy_h, size_half_r); const float output_y_0 = input_y_0 * cos_0 - input_y_1 * sin_0; const float output_y_1 = input_y_1 * cos_1 + input_y_0 * sin_1; store_payload( p_output_y, output_y_0, output_y_1, did, hid, stride_oy_d, stride_oy_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy_2c( p_output_x, p_output_y, p_input_x, p_input_y, size_h_x, size_h_y, did_start, did_end, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } } }; struct OpCachedBwd { template __device__ __forceinline__ static void apply_1c( scalar_t* __restrict__ p_input_grads, const scalar_t* __restrict__ p_output_grads, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h, const int32_t size_d, const int32_t size_f, const int32_t stride_o_h, const int32_t stride_o_d, const int32_t stride_i_h, const int32_t stride_i_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_cached( &cos_0, &sin_0, &cos_1, &sin_1, p_cos, p_sin, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_h; hid += blockDim.y) { float output_grad_0, output_grad_1; load_payload( &output_grad_0, &output_grad_1, p_output_grads, did, hid, stride_o_d, stride_o_h, size_half_r); const float input_grad_0 = output_grad_0 * cos_0 + output_grad_1 * sin_0; const float input_grad_1 = output_grad_1 * cos_1 - output_grad_0 * sin_1; store_payload( p_input_grads, input_grad_0, input_grad_1, did, hid, stride_i_d, stride_i_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy( p_input_grads, p_output_grads, size_h, did_start, did_end, stride_o_h, stride_o_d, stride_i_h, stride_i_d); } } template __device__ __forceinline__ static void apply_2c( scalar_t* __restrict__ p_input_grads_x, scalar_t* __restrict__ p_input_grads_y, const scalar_t* __restrict__ p_output_grads_x, const scalar_t* __restrict__ p_output_grads_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_h, const int32_t stride_oy_d, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_h, const int32_t stride_iy_d) { // rotate count const int32_t size_r = ReuseFreqsFrontPart ? (size_f << 1) : size_f; const int32_t size_half_r = size_r >> 1; const int32_t did_start = NopeFirst ? (size_d - size_r) : 0; const int32_t size_min_h = min(size_h_x, size_h_y); #pragma unroll for (int32_t did = threadIdx.x + did_start; did < size_half_r; did += blockDim.x) { float cos_0, sin_0, cos_1, sin_1; load_cos_sin_cached( &cos_0, &sin_0, &cos_1, &sin_1, p_cos, p_sin, did - did_start, size_half_r); #pragma unroll for (int32_t hid = threadIdx.y; hid < size_min_h; hid += blockDim.y) { float output_grad_x_0, output_grad_x_1, output_grad_y_0, output_grad_y_1; load_payload( &output_grad_x_0, &output_grad_x_1, p_output_grads_x, did, hid, stride_ox_d, stride_ox_h, size_half_r); load_payload( &output_grad_y_0, &output_grad_y_1, p_output_grads_y, did, hid, stride_oy_d, stride_oy_h, size_half_r); const float input_grad_x_0 = output_grad_x_0 * cos_0 + output_grad_x_1 * sin_0; const float input_grad_x_1 = output_grad_x_1 * cos_1 - output_grad_x_0 * sin_1; const float input_grad_y_0 = output_grad_y_0 * cos_0 + output_grad_y_1 * sin_0; const float input_grad_y_1 = output_grad_y_1 * cos_1 - output_grad_y_0 * sin_1; store_payload( p_input_grads_x, input_grad_x_0, input_grad_x_1, did, hid, stride_ix_d, stride_ix_h, size_half_r); store_payload( p_input_grads_y, input_grad_y_0, input_grad_y_1, did, hid, stride_iy_d, stride_iy_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_x; hid += blockDim.y) { float output_grad_x_0, output_grad_x_1; load_payload( &output_grad_x_0, &output_grad_x_1, p_output_grads_x, did, hid, stride_ox_d, stride_ox_h, size_half_r); const float input_grad_x_0 = output_grad_x_0 * cos_0 + output_grad_x_1 * sin_0; const float input_grad_x_1 = output_grad_x_1 * cos_1 - output_grad_x_0 * sin_1; store_payload( p_input_grads_x, input_grad_x_0, input_grad_x_1, did, hid, stride_ix_d, stride_ix_h, size_half_r); } #pragma unroll for (int32_t hid = threadIdx.y + size_min_h; hid < size_h_y; hid += blockDim.y) { float output_grad_y_0, output_grad_y_1; load_payload( &output_grad_y_0, &output_grad_y_1, p_output_grads_y, did, hid, stride_oy_d, stride_oy_h, size_half_r); const float input_grad_y_0 = output_grad_y_0 * cos_0 + output_grad_y_1 * sin_0; const float input_grad_y_1 = output_grad_y_1 * cos_1 - output_grad_y_0 * sin_1; store_payload( p_input_grads_y, input_grad_y_0, input_grad_y_1, did, hid, stride_iy_d, stride_iy_h, size_half_r); } } // the rest are just forwarded if constexpr (!Inplace) { const int32_t did_start = NopeFirst ? 0 : size_r; const int32_t did_end = NopeFirst ? (size_d - size_r) : size_d; elementwise_copy_2c( p_input_grads_x, p_input_grads_y, p_output_grads_x, p_output_grads_y, size_h_x, size_h_y, did_start, did_end, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d); } } }; // ===================================================================================================================== // Kernel Entries // template __global__ void kn_entry_1c_sbhd_uncached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_i = sid * stride_i_s + bid * stride_i_b; const uint64_t offset_o = sid * stride_o_s + bid * stride_o_b; const uint64_t offset_f = sid * size_f; Op::template apply_1c( p_output + offset_o, p_input + offset_i, p_freqs + offset_f, size_h, size_d, size_f, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } template __global__ void kn_entry_1c_sbhd_uncached_inplace( scalar_t* __restrict__ p_inout, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_s, const int32_t stride_b, const int32_t stride_h, const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset = sid * stride_s + bid * stride_b; const uint64_t offset_f = sid * size_f; Op::template apply_1c( p_inout + offset, p_inout + offset, p_freqs + offset_f, size_h, size_d, size_f, stride_h, stride_d, stride_h, stride_d); } template __global__ void kn_entry_2c_sbhd_uncached( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_ix = sid * stride_ix_s + bid * stride_ix_b; const uint64_t offset_iy = sid * stride_iy_s + bid * stride_iy_b; const uint64_t offset_ox = sid * stride_ox_s + bid * stride_ox_b; const uint64_t offset_oy = sid * stride_oy_s + bid * stride_oy_b; const uint64_t offset_f = sid * size_f; Op::template apply_2c( p_output_x + offset_ox, p_output_y + offset_oy, p_input_x + offset_ix, p_input_y + offset_iy, p_freqs + offset_f, size_h_x, size_h_y, size_d, size_f, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } template __global__ void kn_entry_2c_sbhd_uncached_inplace( scalar_t* __restrict__ p_inout_x, scalar_t* __restrict__ p_inout_y, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_x_s, const int32_t stride_x_b, const int32_t stride_x_h, const int32_t stride_x_d, const int32_t stride_y_s, const int32_t stride_y_b, const int32_t stride_y_h, const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_x = sid * stride_x_s + bid * stride_x_b; const uint64_t offset_y = sid * stride_y_s + bid * stride_y_b; const uint64_t offset_f = sid * size_f; Op::template apply_2c( p_inout_x + offset_x, p_inout_y + offset_y, p_inout_x + offset_x, p_inout_y + offset_y, p_freqs + offset_f, size_h_x, size_h_y, size_d, size_f, stride_x_h, stride_x_d, stride_y_h, stride_y_d, stride_x_h, stride_x_d, stride_y_h, stride_y_d); } template __global__ void kn_entry_1c_sbhd_cached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_i = sid * stride_i_s + bid * stride_i_b; const uint64_t offset_o = sid * stride_o_s + bid * stride_o_b; const uint64_t offset_f = sid * size_f; Op::template apply_1c( p_output + offset_o, p_input + offset_i, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } template __global__ void kn_entry_1c_sbhd_cached_inplace( scalar_t* __restrict__ p_inout, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_s, const int32_t stride_b, const int32_t stride_h, const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset = sid * stride_s + bid * stride_b; const uint64_t offset_f = sid * size_f; Op::template apply_1c( p_inout + offset, p_inout + offset, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_h, stride_d, stride_h, stride_d); } template __global__ void kn_entry_2c_sbhd_cached( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_ix = sid * stride_ix_s + bid * stride_ix_b; const uint64_t offset_iy = sid * stride_iy_s + bid * stride_iy_b; const uint64_t offset_ox = sid * stride_ox_s + bid * stride_ox_b; const uint64_t offset_oy = sid * stride_oy_s + bid * stride_oy_b; const uint64_t offset_f = sid * size_f; Op::template apply_2c( p_output_x + offset_ox, p_output_y + offset_oy, p_input_x + offset_ix, p_input_y + offset_iy, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } template __global__ void kn_entry_2c_sbhd_cached_inplace( scalar_t* __restrict__ p_inout_x, scalar_t* __restrict__ p_inout_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_x_s, const int32_t stride_x_b, const int32_t stride_x_h, const int32_t stride_x_d, const int32_t stride_y_s, const int32_t stride_y_b, const int32_t stride_y_h, const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t offset_x = sid * stride_x_s + bid * stride_x_b; const uint64_t offset_y = sid * stride_y_s + bid * stride_y_b; const uint64_t offset_f = sid * size_f; Op::template apply_2c( p_inout_x + offset_x, p_inout_y + offset_y, p_inout_x + offset_x, p_inout_y + offset_y, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_x_h, stride_x_d, stride_y_h, stride_y_d, stride_x_h, stride_x_d, stride_y_h, stride_y_d); } template __global__ void kn_entry_1c_sbhd_cached_indirect( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_i = sid * stride_i_s + bid * stride_i_b; const uint64_t offset_o = sid * stride_o_s + bid * stride_o_b; const int64_t offset_f = pos * size_f; Op::template apply_1c( p_output + offset_o, p_input + offset_i, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } } template __global__ void kn_entry_2c_sbhd_cached_indirect( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_ix = sid * stride_ix_s + bid * stride_ix_b; const uint64_t offset_iy = sid * stride_iy_s + bid * stride_iy_b; const uint64_t offset_ox = sid * stride_ox_s + bid * stride_ox_b; const uint64_t offset_oy = sid * stride_oy_s + bid * stride_oy_b; const int64_t offset_f = pos * size_f; Op::template apply_2c( p_output_x + offset_ox, p_output_y + offset_oy, p_input_x + offset_ix, p_input_y + offset_iy, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } } template __global__ void kn_entry_1c_sbhd_cached_indirect_inplace( scalar_t* __restrict__ p_inout, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_s, const int32_t stride_b, const int32_t stride_h, const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset = sid * stride_s + bid * stride_b; const int64_t offset_f = pos * size_f; Op::template apply_1c( p_inout + offset, p_inout + offset, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_h, stride_d, stride_h, stride_d); } } template __global__ void kn_entry_2c_sbhd_cached_indirect_inplace( scalar_t* __restrict__ p_inout_x, scalar_t* __restrict__ p_inout_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_x_s, const int32_t stride_x_b, const int32_t stride_x_h, const int32_t stride_x_d, const int32_t stride_y_s, const int32_t stride_y_b, const int32_t stride_y_h, const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_x = sid * stride_x_s + bid * stride_x_b; const uint64_t offset_y = sid * stride_y_s + bid * stride_y_b; const int64_t offset_f = pos * size_f; Op::template apply_2c( p_inout_x + offset_x, p_inout_y + offset_y, p_inout_x + offset_x, p_inout_y + offset_y, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_x_h, stride_x_d, stride_y_h, stride_y_d, stride_x_h, stride_x_d, stride_y_h, stride_y_d); } } template __global__ void kn_entry_1c_sbhd_cached_indirect2( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer_0[ib_idx] + p_indirect_buffer_1[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_i = sid * stride_i_s + bid * stride_i_b; const uint64_t offset_o = sid * stride_o_s + bid * stride_o_b; const int64_t offset_f = pos * size_f; Op::template apply_1c( p_output + offset_o, p_input + offset_i, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } } template __global__ void kn_entry_2c_sbhd_cached_indirect2( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer_0[ib_idx] + p_indirect_buffer_1[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_ix = sid * stride_ix_s + bid * stride_ix_b; const uint64_t offset_iy = sid * stride_iy_s + bid * stride_iy_b; const uint64_t offset_ox = sid * stride_ox_s + bid * stride_ox_b; const uint64_t offset_oy = sid * stride_oy_s + bid * stride_oy_b; const int64_t offset_f = pos * size_f; Op::template apply_2c( p_output_x + offset_ox, p_output_y + offset_oy, p_input_x + offset_ix, p_input_y + offset_iy, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_ix_h, stride_ix_d, stride_iy_h, stride_iy_d, stride_ox_h, stride_ox_d, stride_oy_h, stride_oy_d); } } template __global__ void kn_entry_1c_sbhd_cached_indirect2_inplace( scalar_t* __restrict__ p_inout, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_s, const int32_t stride_b, const int32_t stride_h, const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer_0[ib_idx] + p_indirect_buffer_1[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset = sid * stride_s + bid * stride_b; const int64_t offset_f = pos * size_f; Op::template apply_1c( p_inout + offset, p_inout + offset, p_cos + offset_f, p_sin + offset_f, size_h, size_d, size_f, stride_h, stride_d, stride_h, stride_d); } } template __global__ void kn_entry_2c_sbhd_cached_indirect2_inplace( scalar_t* __restrict__ p_inout_x, scalar_t* __restrict__ p_inout_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_x_s, const int32_t stride_x_b, const int32_t stride_x_h, const int32_t stride_x_d, const int32_t stride_y_s, const int32_t stride_y_b, const int32_t stride_y_h, const int32_t stride_y_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t ib_idx = sid * gridDim.y + bid; const int64_t pos = p_indirect_buffer_0[ib_idx] + p_indirect_buffer_1[ib_idx]; if ((pos >= 0) && (pos < max_position)) { const uint64_t offset_x = sid * stride_x_s + bid * stride_x_b; const uint64_t offset_y = sid * stride_y_s + bid * stride_y_b; const int64_t offset_f = pos * size_f; Op::template apply_2c( p_inout_x + offset_x, p_inout_y + offset_y, p_inout_x + offset_x, p_inout_y + offset_y, p_cos + offset_f, p_sin + offset_f, size_h_x, size_h_y, size_d, size_f, stride_x_h, stride_x_d, stride_y_h, stride_y_d, stride_x_h, stride_x_d, stride_y_h, stride_y_d); } } template __global__ void kn_entry_1c_thd_uncached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const int32_t* __restrict__ p_cu_seqlens, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_t, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_t, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t tid = sid + p_cu_seqlens[bid]; if (tid < p_cu_seqlens[bid + 1]) { const int32_t offset_i = tid * stride_i_t; const int32_t offset_o = tid * stride_o_t; const int32_t offset_f = sid * size_f; Op::template apply_1c( p_output + offset_o, p_input + offset_i, p_freqs + offset_f, size_h, size_d, size_f, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } } template __global__ void kn_entry_1c_thd_uncached_inplace( scalar_t* __restrict__ p_inout, const int32_t* __restrict__ p_cu_seqlens, const scalar_f_t* __restrict__ p_freqs, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_t, const int32_t stride_h, const int32_t stride_d) { const uint64_t sid = blockIdx.x; const uint64_t bid = blockIdx.y; const uint64_t tid = sid + p_cu_seqlens[bid]; if (tid < p_cu_seqlens[bid + 1]) { const int32_t offset = tid * stride_t; const int32_t offset_f = sid * size_f; Op::template apply_1c( p_inout + offset, p_inout + offset, p_freqs + offset_f, size_h, size_d, size_f, stride_h, stride_d, stride_h, stride_d); } } template __global__ void kn_entry_1c_2d_cached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos_h, const scalar_f_t* __restrict__ p_sin_h, const scalar_f_t* __restrict__ p_cos_w, const scalar_f_t* __restrict__ p_sin_w, const int32_t img_width, const int32_t size_h, const int32_t size_d, const int32_t stride_i_b, const int32_t stride_i_s, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_b, const int32_t stride_o_s, const int32_t stride_o_h, const int32_t stride_o_d) { const uint64_t Hid = blockIdx.x; const uint64_t Wid = blockIdx.y; const uint64_t sid = Hid * img_width + Wid; const uint64_t bid = blockIdx.z; const uint64_t size_half_d = size_d >> 1; const int offset_h_i = bid * stride_i_b + sid * stride_i_s; const int offset_h_o = bid * stride_o_b + sid * stride_o_s; const int offset_h_f = Hid * size_half_d; Op::template apply_1c( p_output + offset_h_o, p_input + offset_h_i, p_cos_h + offset_h_f, p_sin_h + offset_h_f, size_h, size_half_d, size_half_d, stride_i_h, stride_i_d, stride_o_h, stride_o_d); const int offset_w_i = offset_h_i + size_half_d * stride_i_d; const int offset_w_o = offset_h_o + size_half_d * stride_o_d; const int offset_w_f = Wid * size_half_d; Op::template apply_1c( p_output + offset_w_o, p_input + offset_w_i, p_cos_w + offset_w_f, p_sin_w + offset_w_f, size_h, size_half_d, size_half_d, stride_i_h, stride_i_d, stride_o_h, stride_o_d); } template __global__ void kn_entry_1c_2d_cached_inplace( scalar_t* __restrict__ p_inout, const scalar_f_t* __restrict__ p_cos_h, const scalar_f_t* __restrict__ p_sin_h, const scalar_f_t* __restrict__ p_cos_w, const scalar_f_t* __restrict__ p_sin_w, const int32_t img_width, const int32_t size_h, const int32_t size_d, const int32_t stride_b, const int32_t stride_s, const int32_t stride_h, const int32_t stride_d) { const uint64_t Hid = blockIdx.x; const uint64_t Wid = blockIdx.y; const uint64_t sid = Hid * img_width + Wid; const uint64_t bid = blockIdx.z; const uint64_t size_half_d = size_d >> 1; const int offset_h = bid * stride_b + sid * stride_s; const int offset_h_f = Hid * size_half_d; Op::template apply_1c( p_inout + offset_h, p_inout + offset_h, p_cos_h + offset_h_f, p_sin_h + offset_h_f, size_h, size_half_d, size_half_d, stride_h, stride_d, stride_h, stride_d); const int offset_w = offset_h + size_half_d * stride_d; const int offset_w_f = Wid * size_half_d; Op::template apply_1c( p_inout + offset_w, p_inout + offset_w, p_cos_w + offset_w_f, p_sin_w + offset_w_f, size_h, size_half_d, size_half_d, stride_h, stride_d, stride_h, stride_d); } // ===================================================================================================================== // Dispatches // #define LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES(ROTATE_STYLE, STRIDE_0, ...) \ if constexpr ((ROTATE_STYLE) != ROTATE_STYLE_GPTJ) \ { \ constexpr bool Stride0Eq1 = false; \ __VA_ARGS__; \ } \ else if ((STRIDE_0) == 1) \ { \ constexpr bool Stride0Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride0Eq1 = false; \ __VA_ARGS__; \ } #define LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES(ROTATE_STYLE, STRIDE_0, STRIDE_1, ...) \ if constexpr ((ROTATE_STYLE) != ROTATE_STYLE_GPTJ) \ { \ constexpr bool Stride0Eq1 = false; \ constexpr bool Stride1Eq1 = false; \ __VA_ARGS__; \ } \ else if ((STRIDE_0) == 1) \ { \ constexpr bool Stride0Eq1 = true; \ if ((STRIDE_1) == 1) \ { \ constexpr bool Stride1Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride1Eq1 = false; \ __VA_ARGS__; \ } \ } \ else \ { \ constexpr bool Stride0Eq1 = false; \ if ((STRIDE_1) == 1) \ { \ constexpr bool Stride1Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride1Eq1 = false; \ __VA_ARGS__; \ } \ } #define LAUNCH_KERNEL_STRIDE_EQUAL_1_4_STRIDES(ROTATE_STYLE, STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, ...) \ if constexpr ((ROTATE_STYLE) != ROTATE_STYLE_GPTJ) \ { \ constexpr bool Stride0Eq1 = false; \ constexpr bool Stride1Eq1 = false; \ constexpr bool Stride2Eq1 = false; \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ else if ((STRIDE_0) == 1) \ { \ constexpr bool Stride0Eq1 = true; \ if ((STRIDE_1) == 1) \ { \ constexpr bool Stride1Eq1 = true; \ if ((STRIDE_2) == 1) \ { \ constexpr bool Stride2Eq1 = true; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ else \ { \ constexpr bool Stride2Eq1 = false; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ } \ else \ { \ constexpr bool Stride1Eq1 = false; \ if ((STRIDE_2) == 1) \ { \ constexpr bool Stride2Eq1 = true; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ else \ { \ constexpr bool Stride2Eq1 = false; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ } \ } \ else \ { \ constexpr bool Stride0Eq1 = false; \ if ((STRIDE_1) == 1) \ { \ constexpr bool Stride1Eq1 = true; \ if ((STRIDE_2) == 1) \ { \ constexpr bool Stride2Eq1 = true; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ else \ { \ constexpr bool Stride2Eq1 = false; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ } \ else \ { \ constexpr bool Stride1Eq1 = false; \ if ((STRIDE_2) == 1) \ { \ constexpr bool Stride2Eq1 = true; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ else \ { \ constexpr bool Stride2Eq1 = false; \ if ((STRIDE_3) == 1) \ { \ constexpr bool Stride3Eq1 = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool Stride3Eq1 = false; \ __VA_ARGS__; \ } \ } \ } \ } template void dispatch_1c_sbhd_uncached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_freqs, const int32_t size_s, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_s == stride_o_s); assert(stride_i_b == stride_o_b); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_sbhd_uncached_inplace<<>>( p_output, p_freqs, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d);); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_sbhd_uncached<<>>( p_output, p_input, p_freqs, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d, stride_o_s, stride_o_b, stride_o_h, stride_o_d); ); } } template void dispatch_2c_sbhd_uncached( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_freqs, const int32_t size_s, const int32_t size_b, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); if ((p_output_x == p_input_x) && (p_output_y == p_input_y)) { assert(stride_ix_s == stride_ox_s); assert(stride_ix_b == stride_ox_b); assert(stride_ix_h == stride_ox_h); assert(stride_ix_d == stride_ox_d); assert(stride_iy_s == stride_oy_s); assert(stride_iy_b == stride_oy_b); assert(stride_iy_h == stride_oy_h); assert(stride_iy_d == stride_oy_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_uncached_inplace<<>>( p_output_x, p_output_y, p_freqs, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_4_STRIDES( RotateStyle, stride_ox_d, stride_oy_d, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_uncached<<>>( p_output_x, p_output_y, p_input_x, p_input_y, p_freqs, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d, stride_ox_s, stride_ox_b, stride_ox_h, stride_ox_d, stride_oy_s, stride_oy_b, stride_oy_h, stride_oy_d); ); } } template void dispatch_1c_sbhd_cached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_s, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_s == stride_o_s); assert(stride_i_b == stride_o_b); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_sbhd_cached_inplace<<>>( p_output, p_cos, p_sin, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_sbhd_cached<<>>( p_output, p_input, p_cos, p_sin, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d, stride_o_s, stride_o_b, stride_o_h, stride_o_d); ); } } template void dispatch_2c_sbhd_cached( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int32_t size_s, const int32_t size_b, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); if ((p_output_x == p_input_x) && (p_output_y == p_input_y)) { assert(stride_ix_s == stride_ox_s); assert(stride_ix_b == stride_ox_b); assert(stride_ix_h == stride_ox_h); assert(stride_ix_d == stride_ox_d); assert(stride_iy_s == stride_oy_s); assert(stride_iy_b == stride_oy_b); assert(stride_iy_h == stride_oy_h); assert(stride_iy_d == stride_oy_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached_inplace<<>>( p_output_x, p_output_y, p_cos, p_sin, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_4_STRIDES( RotateStyle, stride_ox_d, stride_oy_d, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached<<>>( p_output_x, p_output_y, p_input_x, p_input_y, p_cos, p_sin, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d, stride_ox_s, stride_ox_b, stride_ox_h, stride_ox_d, stride_oy_s, stride_oy_b, stride_oy_h, stride_oy_d); ); } } template void dispatch_1c_sbhd_cached_indirect( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_s, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_s == stride_o_s); assert(stride_i_b == stride_o_b); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_sbhd_cached_indirect_inplace<<>>( p_output, p_cos, p_sin, p_indirect_buffer, max_position, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_sbhd_cached_indirect<<>>( p_output, p_input, p_cos, p_sin, p_indirect_buffer, max_position, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d, stride_o_s, stride_o_b, stride_o_h, stride_o_d); ); } } template void dispatch_2c_sbhd_cached_indirect( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer, const int32_t max_position, const int32_t size_s, const int32_t size_b, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); if ((p_output_x == p_input_x) && (p_output_y == p_input_y)) { assert(stride_ix_s == stride_ox_s); assert(stride_ix_b == stride_ox_b); assert(stride_ix_h == stride_ox_h); assert(stride_ix_d == stride_ox_d); assert(stride_iy_s == stride_oy_s); assert(stride_iy_b == stride_oy_b); assert(stride_iy_h == stride_oy_h); assert(stride_iy_d == stride_oy_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached_indirect_inplace<<>>( p_output_x, p_output_y, p_cos, p_sin, p_indirect_buffer, max_position, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_4_STRIDES( RotateStyle, stride_ox_d, stride_oy_d, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached_indirect<<>>( p_output_x, p_output_y, p_input_x, p_input_y, p_cos, p_sin, p_indirect_buffer, max_position, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d, stride_ox_s, stride_ox_b, stride_ox_h, stride_ox_d, stride_oy_s, stride_oy_b, stride_oy_h, stride_oy_d); ); } } template void dispatch_1c_sbhd_cached_indirect2( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_s, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_s, const int32_t stride_i_b, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_s, const int32_t stride_o_b, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_s == stride_o_s); assert(stride_i_b == stride_o_b); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_sbhd_cached_indirect2_inplace<<>>( p_output, p_cos, p_sin, p_indirect_buffer_0, p_indirect_buffer_1, max_position, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_sbhd_cached_indirect2<<>>( p_output, p_input, p_cos, p_sin, p_indirect_buffer_0, p_indirect_buffer_1, max_position, size_h, size_d, size_f, stride_i_s, stride_i_b, stride_i_h, stride_i_d, stride_o_s, stride_o_b, stride_o_h, stride_o_d); ); } } template void dispatch_2c_sbhd_cached_indirect2( scalar_t* __restrict__ p_output_x, scalar_t* __restrict__ p_output_y, const scalar_t* __restrict__ p_input_x, const scalar_t* __restrict__ p_input_y, const scalar_f_t* __restrict__ p_cos, const scalar_f_t* __restrict__ p_sin, const int64_t* __restrict__ p_indirect_buffer_0, const int64_t* __restrict__ p_indirect_buffer_1, const int32_t max_position, const int32_t size_s, const int32_t size_b, const int32_t size_h_x, const int32_t size_h_y, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_ix_s, const int32_t stride_ix_b, const int32_t stride_ix_h, const int32_t stride_ix_d, const int32_t stride_iy_s, const int32_t stride_iy_b, const int32_t stride_iy_h, const int32_t stride_iy_d, const int32_t stride_ox_s, const int32_t stride_ox_b, const int32_t stride_ox_h, const int32_t stride_ox_d, const int32_t stride_oy_s, const int32_t stride_oy_b, const int32_t stride_oy_h, const int32_t stride_oy_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); if ((p_output_x == p_input_x) && (p_output_y == p_input_y)) { assert(stride_ix_s == stride_ox_s); assert(stride_ix_b == stride_ox_b); assert(stride_ix_h == stride_ox_h); assert(stride_ix_d == stride_ox_d); assert(stride_iy_s == stride_oy_s); assert(stride_iy_b == stride_oy_b); assert(stride_iy_h == stride_oy_h); assert(stride_iy_d == stride_oy_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached_indirect2_inplace<<>>( p_output_x, p_output_y, p_cos, p_sin, p_indirect_buffer_0, p_indirect_buffer_1, max_position, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_4_STRIDES( RotateStyle, stride_ox_d, stride_oy_d, stride_ix_d, stride_iy_d, kn_entry_2c_sbhd_cached_indirect2<<>>( p_output_x, p_output_y, p_input_x, p_input_y, p_cos, p_sin, p_indirect_buffer_0, p_indirect_buffer_1, max_position, size_h_x, size_h_y, size_d, size_f, stride_ix_s, stride_ix_b, stride_ix_h, stride_ix_d, stride_iy_s, stride_iy_b, stride_iy_h, stride_iy_d, stride_ox_s, stride_ox_b, stride_ox_h, stride_ox_d, stride_oy_s, stride_oy_b, stride_oy_h, stride_oy_d); ); } } template void dispatch_1c_thd_uncached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const int32_t* __restrict__ p_cu_seqlens, const scalar_f_t* __restrict__ p_freqs, const int32_t size_max_s, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t size_f, // size of last dimension of freqs. const int32_t stride_i_t, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_t, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(size_max_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_t == stride_o_t); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_thd_uncached_inplace<<>>( p_output, p_cu_seqlens, p_freqs, size_h, size_d, size_f, stride_i_t, stride_i_h, stride_i_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_thd_uncached<<>>( p_output, p_input, p_cu_seqlens, p_freqs, size_h, size_d, size_f, stride_i_t, stride_i_h, stride_i_d, stride_o_t, stride_o_h, stride_o_d); ); } } template void dispatch_1c_2d_cached( scalar_t* __restrict__ p_output, const scalar_t* __restrict__ p_input, const scalar_f_t* __restrict__ p_cos_h, const scalar_f_t* __restrict__ p_sin_h, const scalar_f_t* __restrict__ p_cos_w, const scalar_f_t* __restrict__ p_sin_w, const int img_height, const int img_width, const int32_t size_b, const int32_t size_h, const int32_t size_d, const int32_t stride_i_b, const int32_t stride_i_s, const int32_t stride_i_h, const int32_t stride_i_d, const int32_t stride_o_b, const int32_t stride_o_s, const int32_t stride_o_h, const int32_t stride_o_d) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const dim3 grid(img_height, img_width, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); if (p_output == p_input) { assert(stride_i_s == stride_o_s); assert(stride_i_b == stride_o_b); assert(stride_i_h == stride_o_h); assert(stride_i_d == stride_o_d); LAUNCH_KERNEL_STRIDE_EQUAL_1_1_STRIDES( RotateStyle, stride_i_d, kn_entry_1c_2d_cached_inplace<<>>( p_output, p_cos_h, p_sin_h, p_cos_w, p_sin_w, img_width, size_h, size_d, stride_i_b, stride_i_s, stride_i_h, stride_i_d); ); } else { LAUNCH_KERNEL_STRIDE_EQUAL_1_2_STRIDES( RotateStyle, stride_o_d, stride_i_d, kn_entry_1c_2d_cached<<>>( p_output, p_input, p_cos_h, p_sin_h, p_cos_w, p_sin_w, img_width, size_h, size_d, stride_i_b, stride_i_s, stride_i_h, stride_i_d, stride_o_b, stride_o_s, stride_o_h, stride_o_d); ); } } #define DISPATCH_ROPE_TYPES_PARAMS(TYPE0, TYPE1, ROTATE_STYLE, REUSE_FREQS_FRONT_PART, NOPE_FIRST, NAME, ...) \ switch ((TYPE0)) { \ case at::ScalarType::Float: { \ using scalar_t_0 = float; \ switch ((TYPE1)) \ { \ case at::ScalarType::Float: { \ using scalar_t_1 = float; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_1 = at::Half; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_1 = at::BFloat16; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ default: \ TORCH_CHECK(false, NAME " does't support ", \ toString((TYPE0)), " with ", toString((TYPE1)), "."); \ } \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_0 = at::Half; \ switch ((TYPE1)) \ { \ case at::ScalarType::Float: { \ using scalar_t_1 = float; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_1 = at::Half; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_1 = at::BFloat16; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ default: \ TORCH_CHECK(false, NAME " does't support ", \ toString((TYPE0)), " with ", toString((TYPE1)), "."); \ } \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_0 = at::BFloat16; \ switch ((TYPE1)) \ { \ case at::ScalarType::Float: { \ using scalar_t_1 = float; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::Half: { \ using scalar_t_1 = at::Half; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t_1 = at::BFloat16; \ if ((REUSE_FREQS_FRONT_PART)) \ { \ constexpr bool ReuseFreqsFrontPart = true; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ else \ { \ constexpr bool ReuseFreqsFrontPart = false; \ if ((ROTATE_STYLE) == ROTATE_STYLE_NEOX) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_NEOX; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else if ((ROTATE_STYLE) == ROTATE_STYLE_GPTJ) \ { \ constexpr int32_t RotateStyle = ROTATE_STYLE_GPTJ; \ if ((NOPE_FIRST)) \ { \ constexpr bool NopeFirst = true; \ __VA_ARGS__; \ } \ else \ { \ constexpr bool NopeFirst = false; \ __VA_ARGS__; \ } \ } \ else \ { \ TORCH_CHECK(false, NAME " does't support rotate type ", \ std::to_string((ROTATE_STYLE)), "."); \ } \ } \ break; \ } \ default: \ TORCH_CHECK(false, NAME " does't support ", \ toString((TYPE0)), " with ", toString((TYPE1)), "."); \ } \ break; \ } \ default: \ TORCH_CHECK(false, NAME " does't support ", toString((TYPE0)), "."); \ }