// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$CHANNEL_SUBTILE = 1
$assert CHANNEL_TILE % CHANNEL_SUBTILE == 0
$CHANNEL_ROUND = 1
$assert MIDDLE_PASS_TILE <= LAST_PASS_TILE
$assert FIRST_PASS_TILE >= 1
$assert MIDDLE_PASS_TILE >= 1
$assert LAST_PASS_TILE >= 1
$assert ACCUMULATORS >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack/dwconv.h"
#include "xnnpack/math.h"


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION]
$PARAMS = {"LINEAR": "struct xnn_f32_default_params", "MINMAX": "union xnn_f32_minmax_params"}[ACTIVATION]
// Multipass f32 depthwise-convolution microkernel (template).
//
// For each of output_width output pixels the kernel walks the taps in passes,
// keeping one partial sum per channel in buffer between passes:
//   - first pass:  ${FIRST_PASS_TILE} taps; each accumulator starts from a value read from the
//                  packed weights and the partial sum is stored to buffer;
//   - middle pass: ${MIDDLE_PASS_TILE} taps per iteration, repeated while more than
//                  ${LAST_PASS_TILE} taps remain; partial sums are updated in buffer;
//   - last pass:   always reads ${LAST_PASS_TILE} input pointers, finishes the sum and
//                  writes it to output (clamped to [min, max] in the MINMAX variant).
//
// channels         - number of channels per pixel; must be non-zero.
// output_width     - number of output pixels to produce; must be non-zero.
// input            - array of input pointers, consumed one tile per pass.
// weights          - packed weights; reading restarts from the beginning for every pixel.
//                    NOTE(review): the leading value(s) read in the first pass look like
//                    the bias - confirm against the weight-packing routine.
// output           - destination for the finished sums.
// input_stride     - byte offset applied to input after each output pixel.
// output_increment - byte offset applied to output after each output pixel.
// input_offset     - byte offset added to every input pointer that differs from zero.
// zero             - sentinel pointer that is used as-is, without input_offset.
//                    NOTE(review): presumably a zero-filled padding row - confirm with caller.
// kernel_size      - total number of taps; must be greater than ${FIRST_PASS_TILE}.
// buffer           - scratch storage for partial sums; channels floats are written and re-read.
// params           - activation parameters; only the MINMAX variant reads them
//                    (scalar.min and scalar.max).
void xnn_f32_dwconv${SUFFIX}_ukernel_${FIRST_PASS_TILE}f${MIDDLE_PASS_TILE}m${LAST_PASS_TILE}l${CHANNEL_TILE}c${CHANNEL_SUBTILE}s${CHANNEL_ROUND}r__${"wasm" if WASM else "scalar"}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
    size_t channels,
    size_t output_width,
    const float** input,
    const float* weights,
    float* output,
    intptr_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const float* zero,
    size_t kernel_size,
    float* buffer,
    const ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(channels != 0);
  assert(output_width != 0);
  assert(kernel_size > ${FIRST_PASS_TILE});

  $if ACTIVATION == "MINMAX":
    const float vmin = params->scalar.min;
    const float vmax = params->scalar.max;
  do {
    // Weights are consumed sequentially by the three passes and rewound for every pixel.
    const float* w = weights;

    // First pass to process ${FIRST_PASS_TILE} inputs.
    {
      float* b = buffer;
      $for K in range(FIRST_PASS_TILE):
        const float* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
        }
      input += ${FIRST_PASS_TILE};

      // Process c channels and write to buffer.
      $if CHANNEL_TILE == 1:
        for (size_t c = channels; c >= 1; c -= 1) {
          float vacc0p0 = w[0];

          $for K in range(FIRST_PASS_TILE):
            const float vi${K} = *i${K}++;
            const float vk${K} = w[${K+1}];
            $if 1 <= K < ACCUMULATORS:
              float vacc0p${K} = vi${K} * vk${K};
            $else:
              vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

          w += ${FIRST_PASS_TILE + 1};

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          *b++ = vacc0p0;
        }
      $else:
        size_t c = round_up_po2(channels, ${CHANNEL_ROUND});
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $for C in range(CHANNEL_TILE):
            float vacc${C}p0 = w[${C}];

          $for K in range(FIRST_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              const float vi${K}x${C} = i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              const float vk${K}x${C} = w[${(K + 1) * CHANNEL_TILE + C}];
            $for C in range(CHANNEL_TILE):
              $if 1 <= K < ACCUMULATORS:
                float vacc${C}p${K} = vi${K}x${C} * vk${K}x${C};
              $else:
                vacc${C}p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x${C}, vk${K}x${C}, vacc${C}p${K % ACCUMULATORS});

          w += ${(FIRST_PASS_TILE + 1) * CHANNEL_TILE};

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  $for C in range(CHANNEL_TILE):
                    vacc${C}p${A} = vacc${C}p${A} + vacc${C}p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          $for C in range(CHANNEL_TILE):
            b[${C}] = vacc${C}p0;
          b += ${CHANNEL_TILE};
        }


        $if CHANNEL_TILE == 2:
          if (c != 0) {
            float vacc0p0 = w[0];
            $for K in range(FIRST_PASS_TILE):

              const float vi${K}x0 = i${K}[0];
              i${K} += 1;

              const float vk${K}x0 = w[${(K + 1)}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K}x0 * vk${K}x0;
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x0, vk${K}x0, vacc0p${K % ACCUMULATORS});

            w += ${(FIRST_PASS_TILE + 1)};

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            b[0] = vacc0p0;
            b += 1;
          }
        $else:
          for (; c != 0; c --) {
            float vacc0p0 = w[0];
            $for K in range(FIRST_PASS_TILE):

              const float vi${K}x0 = i${K}[0];
              i${K} += 1;

              const float vk${K}x0 = w[${(K + 1)}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K}x0 * vk${K}x0;
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x0, vk${K}x0, vacc0p${K % ACCUMULATORS});

            w += ${(FIRST_PASS_TILE + 1)};

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            b[0] = vacc0p0;
            b += 1;
          }
    }

    // Middle pass to process ${MIDDLE_PASS_TILE} inputs in each iteration.
    for (size_t ks = kernel_size - ${FIRST_PASS_TILE}; ks > ${LAST_PASS_TILE}; ks -= ${MIDDLE_PASS_TILE}) {
      float* b = buffer;
      $for K in range(MIDDLE_PASS_TILE):
        const float* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
        }
      input += ${MIDDLE_PASS_TILE};

      $if CHANNEL_TILE == 1:
        for (size_t c = channels; c >= 1; c -= 1) {
          float vacc0p0 = *b;

          $for K in range(MIDDLE_PASS_TILE):
            const float vi${K} = *i${K}++;
            const float vk${K} = w[${K}];
            $if 1 <= K < ACCUMULATORS:
              float vacc0p${K} = vi${K} * vk${K};
            $else:
              vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          w += ${MIDDLE_PASS_TILE * CHANNEL_TILE};
          *b++ = vacc0p0;
        }
      $else:
        size_t c = round_up_po2(channels, ${CHANNEL_ROUND});
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $for C in range(CHANNEL_TILE):
            float vacc${C}p0 = b[${C}];

          // Accumulate all ${MIDDLE_PASS_TILE} taps of this pass: the tap count has to match
          // the input pointers loaded above and the weight advance below.
          $for K in range(MIDDLE_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              const float vi${K}x${C} = i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              const float vk${K}x${C} = w[${K * CHANNEL_TILE + C}];
            $for C in range(CHANNEL_TILE):
              $if 1 <= K < ACCUMULATORS:
                float vacc${C}p${K} = vi${K}x${C} * vk${K}x${C};
              $else:
                vacc${C}p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x${C}, vk${K}x${C}, vacc${C}p${K % ACCUMULATORS});

          w += ${MIDDLE_PASS_TILE * CHANNEL_TILE};

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  $for C in range(CHANNEL_TILE):
                    vacc${C}p${A} = vacc${C}p${A} + vacc${C}p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          $for C in range(CHANNEL_TILE):
            b[${C}] = vacc${C}p0;
          b += ${CHANNEL_TILE};
        }

        $if CHANNEL_TILE == 2:
          if (c != 0) {
            float vacc0p0 = b[0];

            $for K in range(MIDDLE_PASS_TILE):

              const float vi${K}x0 = i${K}[0];
              i${K} += 1;

              const float vk${K}x0 = w[${K}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K}x0 * vk${K}x0;
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x0, vk${K}x0, vacc0p${K % ACCUMULATORS});

            w += ${MIDDLE_PASS_TILE};

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            b[0] = vacc0p0;
            b += 1;
          }
        $else:
          for (; c != 0; c --) {
            float vacc0p0 = b[0];

            $for K in range(MIDDLE_PASS_TILE):

              const float vi${K}x0 = i${K}[0];
              i${K} += 1;

              const float vk${K}x0 = w[${K}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K}x0 * vk${K}x0;
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x0, vk${K}x0, vacc0p${K % ACCUMULATORS});

            w += ${MIDDLE_PASS_TILE};

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            b[0] = vacc0p0;
            b += 1;
          }
    }

    // Last pass to process up to ${LAST_PASS_TILE} inputs.
    {
      float* b = buffer;
      $for K in range(0, LAST_PASS_TILE):
        const float* i${K} = input[${K}];
        assert(i${K} != NULL);
        if XNN_UNPREDICTABLE(i${K} != zero) {
          i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
        }

      $if CHANNEL_TILE > 1:
        size_t c = channels;
        for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
          $for C in range(CHANNEL_TILE):
            float vacc${C}p0 = b[${C}];
          b += ${CHANNEL_TILE};

          $for K in range(LAST_PASS_TILE):

            $for C in range(CHANNEL_TILE):
              const float vi${K}x${C} = i${K}[${C}];
            i${K} += ${CHANNEL_TILE};

            $for C in range(CHANNEL_TILE):
              const float vk${K}x${C} = w[${K * CHANNEL_TILE + C}];
            $for C in range(CHANNEL_TILE):
              $if 1 <= K < ACCUMULATORS:
                float vacc${C}p${K} = vi${K}x${C} * vk${K}x${C};
              $else:
                vacc${C}p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x${C}, vk${K}x${C}, vacc${C}p${K % ACCUMULATORS});

          w += ${(LAST_PASS_TILE) * CHANNEL_TILE};

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  $for C in range(CHANNEL_TILE):
                    vacc${C}p${A} = vacc${C}p${A} + vacc${C}p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          $if ACTIVATION == "MINMAX":
            // Clamp the finished sums to [vmin, vmax] before storing.
            $for C in range(CHANNEL_TILE):
              float vacc${C} = ${MAX_F32}(vacc${C}p0, vmin);

            $for C in range(CHANNEL_TILE):
              vacc${C} = ${MIN_F32}(vacc${C}, vmax);

            $for C in range(CHANNEL_TILE):
              output[${C}] = vacc${C};
          $else:
            $for C in range(CHANNEL_TILE):
              output[${C}] = vacc${C}p0;
          output += ${CHANNEL_TILE};
        }
        $if CHANNEL_TILE == 2:
          if (c != 0) {
            // Exactly one channel is left here, so b, the input pointers and w are
            // read in place without being advanced.
            float vacc0p0 = *b;

            $for K in range(LAST_PASS_TILE):
              const float vi${K} = *i${K};
              const float vk${K} = w[${K}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K} * vk${K};
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            $if ACTIVATION == "MINMAX":
              float vacc0 = ${MAX_F32}(vacc0p0, vmin);
              vacc0 = ${MIN_F32}(vacc0, vmax);
              *output++ = vacc0;
            $else:
              *output++ = vacc0p0;
          }
        $else:
          for (; c != 0; c --) {
            float vacc0p0 = *b++;

            $for K in range(LAST_PASS_TILE):
              const float vi${K} = *i${K}++;
              const float vk${K} = w[${K}];
              $if 1 <= K < ACCUMULATORS:
                float vacc0p${K} = vi${K} * vk${K};
              $else:
                vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});
            w += ${LAST_PASS_TILE};

            $if ACCUMULATORS > 1:
              // Add up all accumulators to vacc0p0
              $ACC_SLICE = 1
              $while ACC_SLICE < ACCUMULATORS:
                $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                  $if A + ACC_SLICE < ACCUMULATORS:
                    vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
                $ACC_SLICE *= 2

            $if ACTIVATION == "MINMAX":
              float vacc0 = ${MAX_F32}(vacc0p0, vmin);
              vacc0 = ${MIN_F32}(vacc0, vmax);
              *output++ = vacc0;
            $else:
              *output++ = vacc0p0;
          }
      $else:
        for (size_t c = channels; c >= 1; c -= 1) {
          float vacc0p0 = *b++;

          $for K in range(LAST_PASS_TILE):
            const float vi${K} = *i${K}++;
            const float vk${K} = w[${K}];
            $if 1 <= K < ACCUMULATORS:
              float vacc0p${K} = vi${K} * vk${K};
            $else:
              vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS});

          w += ${LAST_PASS_TILE * CHANNEL_TILE};

          $if ACCUMULATORS > 1:
            // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
            $ACC_SLICE = 1
            $while ACC_SLICE < ACCUMULATORS:
              $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
                $if A + ACC_SLICE < ACCUMULATORS:
                  vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE};
              $ACC_SLICE *= 2

          $if ACTIVATION == "MINMAX":
            float vacc0 = ${MAX_F32}(vacc0p0, vmin);
            vacc0 = ${MIN_F32}(vacc0, vmax);
            *output++ = vacc0;
          $else:
            *output++ = vacc0p0;
        }

    }
    // Step to the next output pixel; both adjustments are byte offsets. The last pass
    // does not advance input itself, so input_stride is applied to the pointer as
    // left by the middle pass.
    input = (const float**) ((uintptr_t) input + input_stride);
    output = (float*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}
