#ifndef MPSCNNShaders_h
#define MPSCNNShaders_h

// Metal Shading Language source for the custom (non-MPS) kernels, kept as a
// raw string -- presumably compiled at runtime by the host (not visible here).
//
// Kernels are specialized through the function constants declared at the top
// of the library:
//   ushort_arg_0 .. ushort_arg_11 -> function_constant(0) .. (11)
//   float_arg_0, float_arg_1      -> function_constant(12), (13)
// Each kernel reads its own subset of these; the `*_is_arr` / `*_is_tex`
// booleans derived from them choose between texture2d_array and texture2d
// bindings for the same texture slot.
static const char* PT_METAL_SHADERS = R"PT_METAL_SHADERS(
#include <metal_stdlib>
using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];
constant ushort ushort_arg_10[[function_constant(10)]];
constant ushort ushort_arg_11[[function_constant(11)]];
constant float float_arg_0 [[function_constant(12)]];
constant float float_arg_1 [[function_constant(13)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }

enum broadcastOp {
    Add,
    Sub,
    Mul,
    Div,
};

// Elementwise binary op where a size-1 width/height in an input broadcasts
// (its stride along that axis becomes 0).
void elementwise_broadcast_nonarray(texture2d<half, access::read> in0,
                                    texture2d<half, access::read> in1,
                                    texture2d<half, access::write> out,
                                    ushort2 gid,
                                    broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if(op == Add) {
        out.write(in0.read(gid0) + in1.read(gid1), gid);
    } else if(op == Sub) {
        out.write(in0.read(gid0) - in1.read(gid1), gid);
    } else if(op == Mul) {
        out.write(in0.read(gid0) * in1.read(gid1), gid);
    } else if(op == Div) {
        out.write(in0.read(gid0) / in1.read(gid1), gid);
    }
}

void elementwise_broadcast(texture2d_array<half, access::read> in0,
                           texture2d_array<half, access::read> in1,
                           texture2d_array<half, access::write> out,
                           ushort3 gid,
                           broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if(op == Add) {
        out.write(in0.read(gid0, gid.z) + in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Sub) {
        out.write(in0.read(gid0, gid.z) - in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Mul) {
        out.write(in0.read(gid0, gid.z) * in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Div) {
        out.write(in0.read(gid0, gid.z) / in1.read(gid1, gid.z), gid.xy, gid.z);
    }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Add);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Add);
}

kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Sub);
}

kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Sub);
}

kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Mul);
}

kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Mul);
}

kernel void elementwise_div_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Div);
}

kernel void elementwise_div(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Div);
}

// Packs a float NCHW buffer into 4-channel half texels; channels >= C are zero.
// Function constants: ushort_arg_0 = C, ushort_arg_1 = H, ushort_arg_2 = W.
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                          \
    if ((c_) < C) {                                                             \
        trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)];  \
    } else {                                                                    \
        trns[idx] = 0.0h;                                                       \
    }
    half4 trns;
    CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 trns;
#define CHW_TO_CHWP4(idx, c, h, w)                                  \
    if ((c) < C) {                                                  \
        trns[idx] = in[int(c) * H * W + int(h) * W + int(w)];       \
    } else {                                                        \
        trns[idx] = 0.0h;                                           \
    }
    CHW_TO_CHWP4(0, 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy);
}

// Inverse of copy_nchw_to_metal: unpacks 4-channel half texels into a float
// NCHW buffer, skipping padded channels >= C.
kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
    half4 cs = in.read(gid.xy, gid.z);
#define CHWP4_TO_CHW(idx, n, c_, h, w)                                          \
    if ((c_) < C) {                                                             \
        out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx];   \
    }
    CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 cs = in.read(gid.xy);
#define CHWP4_TO_CHW(idx, c, h, w)                                  \
    if ((c) < C) {                                                  \
        out[int(c) * H * W + int(h) * W + int(w)] = cs[idx];        \
    }
    CHWP4_TO_CHW(0, 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy(texture2d_array<half, access::read> in[[texture(0)]],
                 texture2d_array<half, access::write> out[[texture(1)]],
                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z);
}

kernel void copy_nonarray(texture2d<half, access::read> in[[texture(0)]],
                          texture2d<half, access::write> out[[texture(1)]],
                          ushort2 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    out.write(in.read(gid), gid);
}

// Copies slices into `out` shifted by offset_buf[0] slices.
kernel void copy_offset(texture2d_array<half, access::read> in[[texture(0)]],
                        texture2d_array<half, access::write> out[[texture(1)]],
                        constant ushort* offset_buf[[buffer(0)]],
                        ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z + offset_buf[0]);
}

kernel void copy_offset_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                 texture2d_array<half, access::write> out[[texture(1)]],
                                 constant ushort* offset_buf[[buffer(0)]],
                                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_), gid_, gid.z + offset_buf[0]);
}

constant bool store_features_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool store_features_out_is_tex = !store_features_out_is_arr;
kernel void store_features(texture2d_array<half, access::read> in[[texture(0)]],
                           texture2d<half, access::write> out_tex[[texture(1), function_constant(store_features_out_is_tex)]],
                           texture2d_array<half, access::write> out_arr[[texture(1), function_constant(store_features_out_is_arr)]],
                           constant ushort* offset_buf[[buffer(0)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;
    if (store_features_out_is_arr)
        out_arr.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_, gid.z);
    else
        out_tex.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_);
}

constant bool append_features_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_in_is_tex = !append_features_in_is_arr;
kernel void append_features(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_in_is_tex)]],
                            texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_in_is_arr)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            constant ushort* offset_buf[[buffer(0)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;

    half4 intex;
    if (append_features_in_is_arr) {
        intex = in_arr.read(gid_, inz);
    }
    else {
        intex = in_tex.read(gid_);
    }
    out.write(intex, gid_, outz);
}

constant bool prev_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool prev_is_tex = !prev_is_arr;
constant bool append_features_off_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_off_in_is_tex = !append_features_off_in_is_arr;
// Appends features whose first channel is not 4-aligned in `out`:
// offset_buf[5] is the number of channels already occupied in the shared
// slice, so each output texel mixes the tail of the previous input slice with
// the head of the current one.
kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_off_in_is_tex)]],
                                texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_off_in_is_arr)]],
                                texture2d<half, access::read> prev_tex[[texture(1), function_constant(prev_is_tex)]],
                                texture2d_array<half, access::read> prev_arr[[texture(1), function_constant(prev_is_arr)]],
                                texture2d_array<half, access::write> out[[texture(2)]],
                                constant ushort* offset_buf[[buffer(0)]],
                                ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;

    half4 outtex;
    if (prev_is_arr)
        outtex = prev_arr.read(gid_, batch);
    else
        outtex = prev_tex.read(gid_);

    // Branch on this kernel's own function constant: it is the one that
    // decides whether in_arr or in_tex is bound.
    half4 intex1;
    if (append_features_off_in_is_arr)
        intex1 = in_arr.read(gid_, inz);
    else
        intex1 = in_tex.read(gid_);

    if (feature == 0) {
        if (offset_buf[5] == 1)
            outtex.yzw = intex1.xyz;
        else if (offset_buf[5] == 2)
            outtex.zw = intex1.xy;
        else
            outtex.w = intex1.x;
        out.write(outtex, gid_, outz);
        return;
    }
    half4 intex0;
    if (append_features_off_in_is_arr)
        intex0 = in_arr.read(gid_, inz-1);
    else
        intex0 = intex1;

    if (offset_buf[5] == 1) {
        outtex.x = intex0.w;
        outtex.yzw = intex1.xyz;
    }
    else if (offset_buf[5] == 2) {
        outtex.xy = intex0.zw;
        outtex.zw = intex1.xy;
    }
    else {
        outtex.xyz = intex0.yzw;
        outtex.w = intex1.x;
    }

    out.write(outtex, gid_, outz);
}

constant bool clamp_is_arr = (ushort_arg_1 > 1 || ushort_arg_0 > 4);
constant bool clamp_is_tex = !clamp_is_arr;
// Clamps every element to [float_arg_0, float_arg_1].
kernel void clamp(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(clamp_is_arr)]],
                  texture2d<half, access::read> in_tex[[texture(0), function_constant(clamp_is_tex)]],
                  texture2d_array<half, access::write> out_arr[[texture(1), function_constant(clamp_is_arr)]],
                  texture2d<half, access::write> out_tex[[texture(1), function_constant(clamp_is_tex)]],
                  ushort3 gid[[thread_position_in_grid]]) {
    const ushort w = clamp_is_arr? out_arr.get_width() : out_tex.get_width();
    const ushort h = clamp_is_arr? out_arr.get_height() : out_tex.get_height();
    if (gid.x >= w || gid.y >= h) {
        return;
    }
    const float4 min_(float_arg_0, float_arg_0, float_arg_0, float_arg_0);
    const float4 max_(float_arg_1, float_arg_1, float_arg_1, float_arg_1);
    ushort2 gid_ = gid.xy;
    if(clamp_is_arr){
        float4 value = (float4)in_arr.read(gid_, gid.z);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_arr.write(clamped, gid_, gid.z);
    } else {
        float4 value = (float4)in_tex.read(gid_);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_tex.write(clamped, gid_);
    }
}

constant bool hardswish_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardswish_is_tex = !hardswish_is_arr;
// hardswish(x) = 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 otherwise.
kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(hardswish_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(hardswish_is_tex)]],
                      ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardswish_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 mask1 = half4(value < 3.0);
        half4 mask2 = half4(value > -3.0);
        half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 mask1 = half4(value < 3.0);
        half4 mask2 = half4(value > -3.0);
        half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
        out_tex.write(outval, gid_);
    }
}

constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardshrink_is_tex = !hardshrink_is_arr;
// hardshrink(x) = x when |x| > lambda (float_arg_0), else 0.
kernel void hardshrink(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(hardshrink_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(hardshrink_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half lambda = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardshrink_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 mask1 = half4(value <= lambda);
        half4 mask2 = half4(value >= -lambda);
        half4 outval = (1 - mask1)*value + (1 - mask2)*value;
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 mask1 = half4(value <= lambda);
        half4 mask2 = half4(value >= -lambda);
        half4 outval = (1 - mask1)*value + (1 - mask2)*value;
        out_tex.write(outval, gid_);
    }
}

constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
// leaky_relu(x) = x for x >= 0, negative_slope (float_arg_0) * x otherwise.
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(leaky_relu_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(leaky_relu_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half negative_slope = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (leaky_relu_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 is_negative = half4(value < 0.0);
        half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 is_negative = half4(value < 0.0);
        half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
        out_tex.write(outval, gid_);
    }
}

constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool out_is_tex = !out_is_arr;
constant bool in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool in_is_tex = !in_is_arr;
// Reflection padding: ushort_arg_0/1 = output H/W, ushort_arg_8..11 =
// left/right/top/bottom padding. Out-of-range positions mirror back into the input.
kernel void reflection_pad2d(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(in_is_arr)]],
                             texture2d<half, access::read> in_tex[[texture(0), function_constant(in_is_tex)]],
                             texture2d_array<half, access::write> out_arr[[texture(1), function_constant(out_is_arr)]],
                             texture2d<half, access::write> out_tex[[texture(1), function_constant(out_is_tex)]],
                             ushort3 gid[[thread_position_in_grid]]) {
    const ushort H2 = ushort_arg_0;
    const ushort W2 = ushort_arg_1;
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }

    const ushort pad_left = ushort_arg_8;
    const ushort pad_right = ushort_arg_9;
    const ushort pad_top = ushort_arg_10;
    const ushort pad_bottom = ushort_arg_11;

    const ushort2 out_size = ushort2(W2, H2);
    const ushort xoff_pre  = 2*max(pad_left - gid.x, 0);
    const ushort xoff_post = 2*max(gid.x - (out_size.x - 1 - pad_right), 0);
    const ushort yoff_pre  = 2*max(pad_top - gid.y, 0);
    const ushort yoff_post = 2*max(gid.y - (out_size.y - 1 - pad_bottom), 0);
    ushort2 inpos = ushort2(
        gid.x + xoff_pre - xoff_post - pad_left,
        gid.y + yoff_pre - yoff_post - pad_top);

    half4 intex;
    if (in_is_arr) {
        intex = in_arr.read(inpos, gid.z);
    } else {
        intex = in_tex.read(inpos);
    }

    if (out_is_arr) {
        out_arr.write(intex, gid.xy, gid.z);
    } else {
        out_tex.write(intex, gid.xy);
    }
}

constant bool reshape_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool reshape_out_is_tex = !reshape_out_is_arr;
constant bool reshape_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool reshape_in_is_tex = !reshape_in_is_arr;
// Output shape: ushort_arg_0..3 = H2, W2, C2, (N2). Input shape:
// ushort_arg_4..7 = H1, W1, C1, N1. Elements past the input's numel become 0.
kernel void reshape(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(reshape_in_is_arr)]],
                    texture2d<half, access::read> in_tex[[texture(0), function_constant(reshape_in_is_tex)]],
                    texture2d_array<half, access::write> out_arr[[texture(1), function_constant(reshape_out_is_arr)]],
                    texture2d<half, access::write> out_tex[[texture(1), function_constant(reshape_out_is_tex)]],
                    ushort3 gid[[thread_position_in_grid]]) {
    const ushort H2 = ushort_arg_0;
    const ushort W2 = ushort_arg_1;
    const ushort C2 = ushort_arg_2;
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const ushort H1 = ushort_arg_4;
    const ushort W1 = ushort_arg_5;
    const ushort C1 = ushort_arg_6;
    const ushort N1 = ushort_arg_7;

    const size_t numel1 = H1 * W1 * C1 * N1;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2; // image index
    const ushort s2 = gid.z - n2 * slices2; // slice offset
    half4 value;
    for (int idx = 0; idx < 4; ++idx){
        // we compute the "linear index" of the output element,
        // and convert it to the equivalent "linear index" of the input element.
        ushort offset = 4 * s2 + idx;
        size_t linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if(linear_idx >= numel1){
            value[idx] = 0;
            continue;
        }
        auto x1 = linear_idx % W1;
        auto y1 = ((int)(linear_idx/W1)) % H1;
        auto s1 = ((int)(linear_idx/W1/H1) % C1);
        auto n1 = ((int)(linear_idx/W1/H1/C1) % N1);
        auto z1 = (int)s1 / 4 + n1 * slices1;
        auto pos = s1 % 4;
        if(reshape_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }

    }
    if(reshape_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}

constant bool transpose_in_is_arr = (ushort_arg_3 > 1 || ushort_arg_4 > 4);
constant bool transpose_in_is_tex = !transpose_in_is_arr;
constant bool transpose_out_is_arr = (ushort_arg_5 > 1 || ushort_arg_6 > 4);
constant bool transpose_out_is_tex = !transpose_out_is_arr;
// Swaps dims ushort_arg_0 and ushort_arg_1 of a tensor with ushort_arg_2 dims.
// inSizeBuffer / outSizeBuffer hold the per-dim sizes of input and output.
kernel void transpose(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(transpose_in_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(transpose_in_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(transpose_out_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(transpose_out_is_tex)]],
                      constant ushort* inSizeBuffer [[buffer(0)]],
                      constant ushort* outSizeBuffer [[buffer(1)]],
                      ushort3 gid[[thread_position_in_grid]]) {

    const ushort dim0 = ushort_arg_0;
    const ushort dim1 = ushort_arg_1;
    const ushort dim = ushort_arg_2;
    const ushort N1 = ushort_arg_3;
    const ushort C1 = ushort_arg_4;
    const ushort N2 = ushort_arg_5;
    const ushort C2 = ushort_arg_6;
    ushort W1,W2,H1,H2;
    if(transpose_in_is_arr) {
        W1 = in_arr.get_width();
        H1 = in_arr.get_height();
    } else {
        W1 = in_tex.get_width();
        H1 = in_tex.get_height();
    }
    if(transpose_out_is_arr) {
        W2 = out_arr.get_width();
        H2 = out_arr.get_height();
    } else {
        W2 = out_tex.get_width();
        H2 = out_tex.get_height();
    }
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const size_t numel = H2 * W2 * C2 * N2;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2;
    const ushort s2 = gid.z - n2 * slices2;
    half4 value;
    // Per-dim index of the current element: dim j is stored in
    // threadIndexBufferLower[j] when j <= 3 and in threadIndexBufferUpper[j-3]
    // when j > 3.
    ushort4 threadIndexBufferLower{1, 1, 1, 1};
    ushort4 threadIndexBufferUpper{1, 1, 1, 1};
    for (int idx = 0; idx < 4; ++idx){
        ushort offset = 4 * s2 + idx;
        size_t linear_idx2 = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if(linear_idx2 >= numel) {
            value[idx] = 0;
            continue;
        }

        ushort d2 = 0;
        for(int j = dim-1; j>=0; --j){
            d2  = outSizeBuffer[j];
            if(j > 3) {
                threadIndexBufferUpper[j-3] = linear_idx2 % d2;
            } else {
                threadIndexBufferLower[j] = linear_idx2 % d2;
            }
            linear_idx2 /= d2;
        }

        // swap dims. Dim 3 lives in the lower buffer, so every "lower" test
        // must be `<= 3`; testing only `> 3` on one side would send the pair
        // (3, >3) into the final branch and index the lower ushort4 out of
        // bounds.
        ushort tmp;
        if(dim0 > 3) {
            tmp = threadIndexBufferUpper[dim0-3];
        } else {
            tmp = threadIndexBufferLower[dim0];
        }
        if(dim0 > 3 && dim1 > 3) {
            threadIndexBufferUpper[dim0-3] = threadIndexBufferUpper[dim1-3];
        } else if (dim0 > 3 && dim1 <= 3) {
            threadIndexBufferUpper[dim0-3] = threadIndexBufferLower[dim1];
        } else if (dim0 <= 3 && dim1 > 3) {
            threadIndexBufferLower[dim0] = threadIndexBufferUpper[dim1-3];
        } else {
            threadIndexBufferLower[dim0] = threadIndexBufferLower[dim1];
        }
        if(dim1 > 3) {
            threadIndexBufferUpper[dim1-3] = tmp;
        } else {
            threadIndexBufferLower[dim1] = tmp;
        }

        // Re-linearize against the input sizes. The running stride is a
        // size_t: a ushort stride wraps once the trailing dims exceed 65535
        // elements (e.g. 3x256x256).
        size_t linear_idx1 = 0;
        size_t m = 1;
        ushort d1 = 0;
        for(int k = dim-1; k>=0; --k) {
            if(k > 3) {
                d1 = threadIndexBufferUpper[k-3];
            } else {
                d1 = threadIndexBufferLower[k];
            }
            linear_idx1 += d1 * m;
            m *= inSizeBuffer[k];
        }

        auto x1 = linear_idx1 % W1;
        auto y1 = ((int)(linear_idx1/W1)) % H1;
        auto c1 = ((int)(linear_idx1/W1/H1) % C1);
        auto n1 = ((int)(linear_idx1/W1/H1/C1) % N1);
        auto z1 = (int)c1 / 4 + n1 * slices1;
        auto pos = c1 % 4;
        if(transpose_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }
    }
    if(transpose_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}

constant bool split_channels_in_is_arr = (ushort_arg_0 > 4);
constant bool split_channels_in_is_tex = !split_channels_in_is_arr;
constant bool split_channels_out1_is_arr = (ushort_arg_1 > 4);
constant bool split_channels_out1_is_tex = !split_channels_out1_is_arr;
constant bool split_channels_out2_is_arr = (ushort_arg_2 > 4);
constant bool split_channels_out2_is_tex = !(split_channels_out2_is_arr);
// A naive implementation to split the input texture into two on channel dimension:
// out1 receives the first ushort_arg_1 channels, out2 the next ushort_arg_2.
kernel void split_channels(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(split_channels_in_is_arr)]],
                           texture2d<half, access::read> in_tex[[texture(0), function_constant(split_channels_in_is_tex)]],
                           texture2d_array<half, access::write> out1_arr[[texture(1), function_constant(split_channels_out1_is_arr)]],
                           texture2d<half, access::write> out1_tex[[texture(1), function_constant(split_channels_out1_is_tex)]],
                           texture2d_array<half, access::write> out2_arr[[texture(2), function_constant(split_channels_out2_is_arr)]],
                           texture2d<half, access::write> out2_tex[[texture(2), function_constant(split_channels_out2_is_tex)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort W,H;
    if(split_channels_in_is_arr) {
        W = in_arr.get_width();
        H = in_arr.get_height();
    } else {
        W = in_tex.get_width();
        H = in_tex.get_height();
    }
    if(gid.x >= W || gid.y >= H){
        return;
    }
    const ushort C1 = ushort_arg_1;
    const ushort s1 = divRoundUp(C1, 4);
    const ushort c_offset = C1 % 4;
    // Number of slices out2 actually has.
    const ushort s2 = divRoundUp(ushort_arg_2, 4);
    half4 tmp1(0.0, 0.0, 0.0, 0.0);
    half4 tmp2(0.0, 0.0, 0.0, 0.0);
    half4 in41 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z) : in_tex.read(gid.xy);
    // The last input slice has no successor; treat the missing channels as
    // zero padding instead of reading past the end of the array.
    half4 in42 = (split_channels_in_is_arr && uint(gid.z) + 1 < in_arr.get_array_size())
                     ? in_arr.read(gid.xy, gid.z + 1)
                     : half4(0, 0, 0, 0);
    if(gid.z < s1 - 1) {
        if(split_channels_out1_is_arr) {
            out1_arr.write(in41, gid.xy, gid.z);
        }
    }
    else if(gid.z == s1 - 1) {
        if(c_offset == 0){
            if(split_channels_out1_is_arr) {
                out1_arr.write(in41, gid.xy, gid.z);
            } else {
                out1_tex.write(in41, gid.xy);
            }
            return;
        } else if(c_offset == 1) {
            tmp1.x = in41.x;
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2) {
            tmp1.xy = in41.xy;
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp1.xyz = in41.xyz;
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if(split_channels_out1_is_arr) {
            out1_arr.write(tmp1, gid.xy, gid.z);
        } else {
            out1_tex.write(tmp1, gid.xy);
        }
        if(split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, 0);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    }
    else {
        // out2 slice produced by this thread. When C1 is not 4-aligned,
        // out2 slice 0 was already written by the gid.z == s1 - 1 thread.
        const ushort z2 = (c_offset == 0) ? (gid.z - s1) : (gid.z - s1 + 1);
        if (z2 >= s2) {
            // Trailing input slice holding only padding for out2: writing it
            // would overwrite slice 0 of a texture2d out2 (racing with the
            // thread above) or overrun a texture2d_array out2.
            return;
        }
        if (c_offset == 0) {
            if(split_channels_out2_is_arr) {
                out2_arr.write(in41, gid.xy, z2);
            } else {
                out2_tex.write(in41, gid.xy);
            }
            return;
        } else if (c_offset == 1 ){
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2){
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if(split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, z2);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    }
}

constant bool ra_has_in_arr = (ushort_arg_3 > 1 ||  ushort_arg_2 > 4);
constant bool ra_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool ra_has_in_tex = (!ra_has_in_arr);
constant bool ra_has_out_tex = (!ra_has_out_arr);
// RoIAlign: ushort_arg_0 = spatial_scale * 10000, ushort_arg_1 = sampling
// ratio (0 = adaptive), ushort_arg_2 = C. rois[n] holds (x0, y0, x1, y1).
kernel void roi_align(texture2d_array<half, access::sample> ina[[texture(0), function_constant(ra_has_in_arr)]],
                      texture2d<half, access::sample> in[[texture(0), function_constant(ra_has_in_tex)]],
                      texture2d_array<half, access::write> outa[[texture(1), function_constant(ra_has_out_arr)]],
                      texture2d<half, access::write> out[[texture(1), function_constant(ra_has_out_tex)]],
                      constant half4* rois[[buffer(0)]],
                      ushort3 gid[[thread_position_in_grid]]) {

    ushort out_width, out_height;
    if (ra_has_out_arr) {
        out_width = outa.get_width();
        out_height = outa.get_height();
    } else {
        out_width = out.get_width();
        out_height = out.get_height();
    }
    if (gid.x >= out_width || gid.y >= out_height) {
        return;
    }
    const half spatial_scale = half(ushort_arg_0) / 10000;
    const ushort sampling_ratio = ushort_arg_1;
    const ushort C = ushort_arg_2;
    const ushort pw = gid.x;
    const ushort ph = gid.y;
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z % divRoundUp(C, 4);

    const half4 roi_scaled = rois[n] * spatial_scale;
    const half roi_start_w = roi_scaled[0];
    const half roi_start_h = roi_scaled[1];
    const half roi_end_w = roi_scaled[2];
    const half roi_end_h = roi_scaled[3];

    // Force malformed ROIs to be 1x1
    const half roi_width = max(roi_end_w - roi_start_w, (half)1.);
    const half roi_height = max(roi_end_h - roi_start_h, (half)1.);

    const half bin_size_h = static_cast<float>(roi_height) / static_cast<float>(out_height);
    const half bin_size_w = static_cast<float>(roi_width) / static_cast<float>(out_width);

    const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<float>(out_height));
    const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<float>(out_width));

    const half count = roi_bin_grid_h * roi_bin_grid_w;
    half4 output_val = 0.0;

    constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
            // Shift the pixel by 0.5. This is critical to achieve high accuracy.
            const half y =
            roi_start_h + ph * bin_size_h + (iy+0.5) * bin_size_h / static_cast<float>(roi_bin_grid_h);
            const half x =
            roi_start_w + pw * bin_size_w + (ix+0.5) * bin_size_w / static_cast<float>(roi_bin_grid_w);
            if (ra_has_in_arr) {
                output_val += ina.sample(s2, float2(x, y), c);
            } else {
                output_val += in.sample(s2, float2(x, y));
            }
        }
    }
    output_val /= count;
    if (ra_has_out_arr) {
        outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
    } else {
        out.write(static_cast<half4>(output_val), gid.xy);
    }
}

)PT_METAL_SHADERS";

#endif /* MPSCNNShaders_h */