/*******************************************************************************
* Copyright 2017-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "utils/parallel.hpp"

#include "conv/ref_conv.hpp"

namespace conv {

// Direct (non-Winograd) reference forward convolution on f32 buffers.
//
// For every output point (g, mb, oc, od, oh, ow) it computes
//   acc = sum over (kd, kh, kw, ic) of
//         ((src - src_zp) * src_scale) * (wei - wei_zp)
//   res = acc * wei_scale (+ bias if FLAG_BIA), then maybe_post_ops()
//   dst = res * (1 / dst_scale) + dst_zp, then maybe_round()
//
// Kernel offsets assume dense layouts: src as [mb][ic][id][ih][iw] (per
// group slice) and weights as [g][oc][ic][kd][kh][kw]. Dst and bias offsets
// come from dst_off_f() / bia_off_f().
// NOTE: the f32 result depends on the summation order (kd, kh, kw, ic);
// restructuring the loops can change the reference values.
void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    // Quantization attribute memories (scales / zero points) per argument.
    const dnn_mem_t &src_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    const dnn_mem_t &wei_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    const dnn_mem_t &dst_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
    const dnn_mem_t &src_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    const dnn_mem_t &wei_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
    const dnn_mem_t &dst_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    // A scale participates only when its attribute is not the default one.
    const bool has_src_scale = !prb->attr.scales.get(DNNL_ARG_SRC).is_def();
    const bool has_wei_scale = !prb->attr.scales.get(DNNL_ARG_WEIGHTS).is_def();
    const bool has_dst_scale = !prb->attr.scales.get(DNNL_ARG_DST).is_def();
    // A mask > 0 selects per-channel indexing below; 0 means a single value.
    const int wei_scale_mask = prb->attr.scales.get_mask(
            DNNL_ARG_WEIGHTS, dnnl_convolution, wei_m.ndims(), prb->has_groups);
    const int src_scale_mask = prb->attr.scales.get_mask(
            DNNL_ARG_SRC, dnnl_convolution, src_m.ndims(), prb->has_groups);
    const int dst_scale_mask = prb->attr.scales.get_mask(
            DNNL_ARG_DST, dnnl_convolution, dst_m.ndims(), prb->has_groups);

    // Src/dst scales must be either a single value or one of the masks below.
    // NOTE(review): the src check accepts mask 3, while the kernel indexes src
    // scales by channel only (g * ICG + ic), ignoring mb; confirm this is the
    // intended mask.
    assert(IMPLICATION(
            has_src_scale, src_scales.nelems() == 1 || src_scale_mask == 3));
    assert(IMPLICATION(
            has_dst_scale, dst_scales.nelems() == 1 || dst_scale_mask == 2));

    const bool has_src_zp = !prb->attr.zero_points.get(DNNL_ARG_SRC).is_def();
    const bool has_wei_zp
            = !prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).is_def();
    const bool has_dst_zp = !prb->attr.zero_points.get(DNNL_ARG_DST).is_def();
    // Zero-point masks are derived from the policy; > 0 means per-channel.
    const int src_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_SRC).policy);
    const int wei_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).policy);
    const int dst_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_DST).policy);

    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // Distance between adjacent kernel taps: prb->d* + 1, so a stored
    // dilation of 0 means taps are adjacent.
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    // Only a common (mask 0) weights zero point is applied; any other mask
    // leaves wei_zp at 0.
    // NOTE(review): presumably per-channel weights zero points are rejected
    // before reaching this point; confirm.
    int wei_zp = (has_wei_zp && (wei_zp_mask == 0)) ? wei_zps.get_elem(0) : 0;

    // Accumulates into `d` the receptive field of output point (od, oh, ow)
    // for output channel oc of group g. Taps landing in padding are skipped.
    // The src scale and zero point are applied per input channel inside the
    // sum.
    auto ker = [&](float &d, int64_t g, int64_t mb, int64_t oc, int64_t od,
                       int64_t oh, int64_t ow) {
        // Base of the (mb, group g) src slice and of the (g, oc) filter.
        const float *__restrict src_loc
                = (const float *)src_m + (mb * IC + g * ICG) * ID * IH * IW;
        const float *__restrict wei_loc
                = (const float *)wei_m + (g * OCG + oc) * ICG * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            const int64_t id = od * SD - PD + kd * DD;
            if (id < 0 || id >= ID) continue;
            for (int64_t kh = 0; kh < KH; ++kh) {
                const int64_t ih = oh * SH - PH + kh * DH;
                if (ih < 0 || ih >= IH) continue;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    const int64_t iw = ow * SW - PW + kw * DW;
                    if (iw < 0 || iw >= IW) continue;

                    for (int64_t ic = 0; ic < ICG; ++ic) {
                        int64_t src_off = ((ic * ID + id) * IH + ih) * IW + iw;
                        int64_t wei_off = ((ic * KD + kd) * KH + kh) * KW + kw;
                        float src_scale = 1.f;
                        if (has_src_scale)
                            src_scale = src_scales.get_elem(
                                    src_scale_mask > 0 ? g * ICG + ic : 0);
                        int src_zp = has_src_zp ? src_zps.get_elem(
                                             src_zp_mask > 0 ? g * ICG + ic : 0)
                                                : 0;
                        const float s = src_loc[src_off];
                        const float w = wei_loc[wei_off];
                        const float d_tmp
                                = ((s - src_zp) * src_scale) * (w - wei_zp);
                        d += d_tmp;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, OCG, OD, OH, OW,
            [&](int64_t g, int64_t mb, int64_t oc, int64_t od, int64_t oh,
                    int64_t ow) {
                const size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
                float &dst = ((float *)dst_m)[dst_off];

                float conv_res = 0;
                ker(conv_res, g, mb, oc, od, oh, ow);

                // Scales are applied as:
                //    dst = wei_scale * conv((src - zp_src) * src_scale,
                //                           wei - zp_wei)
                // src_scale is already folded in per input channel by ker();
                // the dst scale is applied below as a multiplication by its
                // reciprocal, after bias and post-ops.
                float wei_scale = 1.f, dst_scale = 1.f;
                if (has_wei_scale)
                    wei_scale = wei_scales.get_elem(
                            wei_scale_mask > 0 ? g * OCG + oc : 0);
                if (has_dst_scale)
                    dst_scale = 1.f
                            / dst_scales.get_elem(
                                    dst_scale_mask > 0 ? g * OCG + oc : 0);

                conv_res *= wei_scale;

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = bia_off_f(prb, g, oc);
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals
                        = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

                // `dst` still holds its previous value here; presumably
                // maybe_post_ops() reads it for a sum post-op (confirm).
                maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);

                int dst_zp = has_dst_zp
                        ? dst_zps.get_elem(dst_zp_mask > 0 ? g * OCG + oc : 0)
                        : 0;
                dst = conv_res * dst_scale + dst_zp;
                // Presumably emulates the conversion to the destination data
                // type (prb->dst_dt()); confirm in maybe_round().
                maybe_round(
                        prb->attr, DNNL_ARG_DST, dst, dst_off, prb->dst_dt());
            });
}

// Direct (non-Winograd) reference backward-by-data convolution on f32
// buffers.
//
// For every diff_src point (g, mb, ic, id, ih, iw) it computes
//   acc = sum over valid (kd, kh, kw) taps and oc of
//         ((diff_dst - src_zp) * src_scale) * ((wei - wei_zp) * wei_scale)
//   res = acc (+ bias indexed by g * ICG + ic if FLAG_BIA), then post-ops
//   diff_src = res * (1 / dst_scale) + dst_zp
// The src zero point is indexed by output channel and applied to diff_dst.
// The dst zero point and the bias are indexed by input channel.
// NOTE(review): this presumably lets the same kernel serve deconvolution
// forward; confirm with the callers.
//
// Two kernels produce the same sum in the same order:
//  - ker_fast precomputes, per spatial dimension, the valid (output, tap)
//    pairs when every kernel extent is <= precompute_size;
//  - ker tests every tap directly otherwise.
void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    const dnn_mem_t &src_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    const dnn_mem_t &wei_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    const dnn_mem_t &dst_scales
            = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
    const dnn_mem_t &src_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    const dnn_mem_t &wei_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
    const dnn_mem_t &dst_zps
            = args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    // Only single-value src and dst scales are supported on this path.
    const bool has_src_scale = !prb->attr.scales.get(DNNL_ARG_SRC).is_def();
    const bool has_wei_scale = !prb->attr.scales.get(DNNL_ARG_WEIGHTS).is_def();
    const bool has_dst_scale = !prb->attr.scales.get(DNNL_ARG_DST).is_def();
    assert(IMPLICATION(has_src_scale, src_scales.nelems() == 1));
    assert(IMPLICATION(has_dst_scale, dst_scales.nelems() == 1));
    float src_scale = has_src_scale ? src_scales.get_elem(0) : 1.f;
    // The dst scale divides the result; keep its reciprocal.
    float dst_scale = has_dst_scale ? 1.f / dst_scales.get_elem(0) : 1.f;
    const int wei_scale_mask = prb->attr.scales.get_mask(
            DNNL_ARG_WEIGHTS, dnnl_convolution, wei_m.ndims(), prb->has_groups);

    const bool has_src_zp = !prb->attr.zero_points.get(DNNL_ARG_SRC).is_def();
    const bool has_wei_zp
            = !prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).is_def();
    const bool has_dst_zp = !prb->attr.zero_points.get(DNNL_ARG_DST).is_def();
    const int src_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_SRC).policy);
    const int wei_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).policy);
    const int dst_zp_mask = attr_t::get_default_mask(
            prb->attr.zero_points.get(DNNL_ARG_DST).policy);

    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // Distance between adjacent kernel taps (stored dilation + 1).
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    // Only a common (mask 0) weights zero point is applied; otherwise 0.
    int wei_zp = (has_wei_zp && (wei_zp_mask == 0)) ? wei_zps.get_elem(0) : 0;

    // Weights scale for input channel `ic` of group `g`. It does not depend
    // on the output channel or on the kernel tap, so each kernel fetches it
    // once per diff_src point instead of once per (tap, oc) iteration.
    auto get_wei_scale = [&](int64_t g, int64_t ic) -> float {
        if (!has_wei_scale) return 1.f;
        return wei_scales.get_elem(wei_scale_mask > 0 ? g * ICG + ic : 0);
    };

    // Adds into `ds` the contributions of all OCG output channels of group
    // `g` at output point (od, oh, ow) through kernel tap (kd, kh, kw).
    // Both kernels use it, so they accumulate in exactly the same order.
    auto accumulate_oc = [&](float &ds, int64_t g, float wei_scale,
                                 const float *__restrict diff_dst_loc,
                                 const float *__restrict wei_loc, int64_t od,
                                 int64_t oh, int64_t ow, int64_t kd,
                                 int64_t kh, int64_t kw) {
        for (int64_t oc = 0; oc < OCG; ++oc) {
            const int64_t diff_dst_off = ((oc * OD + od) * OH + oh) * OW + ow;
            const int64_t wei_off
                    = ((oc * ICG * KD + kd) * KH + kh) * KW + kw;
            int src_zp = has_src_zp
                    ? src_zps.get_elem(src_zp_mask > 0 ? g * OCG + oc : 0)
                    : 0;
            float diff_dst_val
                    = (diff_dst_loc[diff_dst_off] - src_zp) * src_scale;
            float wei_val = (wei_loc[wei_off] - wei_zp) * wei_scale;
            ds += diff_dst_val * wei_val;
        }
    };

    enum { precompute_size = 16 };
    const bool fast = MAX3(KD, KH, KW) <= precompute_size;

    /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
    // For input position i, collects in increasing k the taps k whose output
    // position o = (i - k * D + P) / S is an exact, in-range integer.
    auto precompute_ok
            = [](int64_t i, int64_t O, int64_t K, int64_t S, int64_t P,
                      int64_t D, int64_t &num, int64_t *_o, int64_t *_k) {
                  assert(K <= precompute_size);
                  num = 0;
                  for (int64_t k = 0; k < K; ++k) {
                      int64_t o = i - k * D + P;
                      if (o < 0 || o % S) continue;
                      o /= S;
                      if (o >= O) continue;
                      _k[num] = k;
                      _o[num] = o;
                      ++num;
                  }
              };

    auto ker_fast = [&](float &ds, int64_t g, int64_t mb, int64_t ic,
                            int64_t id, int64_t ih, int64_t iw) {
        int64_t kd[precompute_size], od[precompute_size], num_d;
        int64_t kh[precompute_size], oh[precompute_size], num_h;
        int64_t kw[precompute_size], ow[precompute_size], num_w;
        precompute_ok(id, OD, KD, SD, PD, DD, num_d, od, kd);
        precompute_ok(ih, OH, KH, SH, PH, DH, num_h, oh, kh);
        precompute_ok(iw, OW, KW, SW, PW, DW, num_w, ow, kw);

        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;
        const float wei_scale = get_wei_scale(g, ic);

        for_(int64_t d = 0; d < num_d; ++d)
        for_(int64_t h = 0; h < num_h; ++h)
        for (int64_t w = 0; w < num_w; ++w) {
            accumulate_oc(ds, g, wei_scale, diff_dst_loc, wei_loc, od[d],
                    oh[h], ow[w], kd[d], kh[h], kw[w]);
        }
    };

    auto ker = [&](float &ds, int64_t g, int64_t mb, int64_t ic, int64_t id,
                       int64_t ih, int64_t iw) {
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;
        const float wei_scale = get_wei_scale(g, ic);

        for (int64_t kd = 0; kd < KD; ++kd) {
            // Output position reached through this tap; it must be an exact
            // multiple of the stride and inside the output.
            int64_t od = id - kd * DD + PD;
            if (od < 0 || od % SD || od >= OD * SD) continue;
            od /= SD;
            for (int64_t kh = 0; kh < KH; ++kh) {
                int64_t oh = ih - kh * DH + PH;
                if (oh < 0 || oh % SH || oh >= OH * SH) continue;
                oh /= SH;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    int64_t ow = iw - kw * DW + PW;
                    if (ow < 0 || ow % SW || ow >= OW * SW) continue;
                    ow /= SW;
                    accumulate_oc(ds, g, wei_scale, diff_dst_loc, wei_loc, od,
                            oh, ow, kd, kh, kw);
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, ICG, ID, IH, IW,
            [&](int64_t g, int64_t mb, int64_t ic, int64_t id, int64_t ih,
                    int64_t iw) {
                size_t src_off = src_off_f(prb, mb, g, ic, id, ih, iw);
                float &ds = ((float *)diff_src_m)[src_off];
                float conv_res = 0;
                if (fast)
                    ker_fast(conv_res, g, mb, ic, id, ih, iw);
                else
                    ker(conv_res, g, mb, ic, id, ih, iw);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = (size_t)g * ICG + ic;
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals = prepare_po_vals(
                        diff_src_m, args, v_po_masks, src_off);

                maybe_post_ops(prb->attr, conv_res, ds, v_po_vals);

                int dst_zp = has_dst_zp
                        ? dst_zps.get_elem(dst_zp_mask > 0 ? g * ICG + ic : 0)
                        : 0;
                ds = conv_res * dst_scale + dst_zp;
            });
}

void compute_ref_bwd_weights(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    auto compute_bounds
            = [](int64_t I, int64_t O, int64_t k, int64_t S, int64_t P,
                      int64_t D, int64_t &o_s, int64_t &o_e) {
                  const float tmp = P - k * D;
                  o_s = MAX2(0, ceilf(tmp / S));
                  o_e = MIN2(O, ceilf((I + tmp) / S));
              };

    auto ker = [&](float &dw, int64_t g, int64_t oc, int64_t ic, int64_t kd,
                       int64_t kh, int64_t kw) {
        int64_t od_s, od_e, oh_s, oh_e, ow_s, ow_e;
        compute_bounds(ID, OD, kd, SD, PD, DD, od_s, od_e);
        compute_bounds(IH, OH, kh, SH, PH, DH, oh_s, oh_e);
        compute_bounds(IW, OW, kw, SW, PW, DW, ow_s, ow_e);
        const int64_t id_s = kd * DD - PD;
        const int64_t ih_s = kh * DH - PH;
        const int64_t iw_s = kw * DW - PW;

        for (int64_t mb = 0; mb < MB; ++mb) {
            const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                    + (mb * OC + g * OCG + oc) * OD * OH * OW;
            const float *__restrict src_loc = (const float *)src_m
                    + (mb * IC + g * ICG + ic) * ID * IH * IW;

            for_(int64_t od = od_s; od < od_e; ++od)
            for_(int64_t oh = oh_s; oh < oh_e; ++oh)
            for (int64_t ow = ow_s; ow < ow_e; ++ow) {
                const int64_t id = od * SD + id_s;
                const int64_t ih = oh * SH + ih_s;
                const int64_t iw = ow * SW + iw_s;

                size_t diff_dst_off = (od * OH + oh) * OW + ow;
                size_t src_off = (id * IH + ih) * IW + iw;
                dw += diff_dst_loc[diff_dst_off] * src_loc[src_off];
            }
        }
    };

    benchdnn_parallel_nd(G, OCG, ICG, KD, KH, KW,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                size_t wei_off = wei_off_f(prb, g, oc, ic, kd, kh, kw);
                float &dw = ((float *)diff_wei_m)[wei_off];
                dw = 0;
                ker(dw, g, oc, ic, kd, kh, kw);
            });
}

void compute_ref_bwd_bias(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc;
    const int64_t OCG = OC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;

    benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) {
        size_t bia_off = bia_off_f(prb, g, oc);
        double sum = 0;

        for_(int64_t mb = 0; mb < MB; ++mb)
        for_(int64_t od = 0; od < OD; ++od)
        for_(int64_t oh = 0; oh < OH; ++oh)
        for (int64_t ow = 0; ow < OW; ++ow) {
            size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
            sum += ((float *)diff_dst_m)[dst_off];
        }
        ((float *)diff_bia_m)[bia_off] = (float)sum;
    });
}

// Direct reference for the backward-by-weights direction: diff_weights are
// always computed; diff_bias only when the direction carries FLAG_BIA.
void compute_ref_direct_bwd_w(const prb_t *prb, const args_t &args) {
    compute_ref_bwd_weights(prb, args);
    const bool with_bias = prb->dir & FLAG_BIA;
    if (with_bias) compute_ref_bwd_bias(prb, args);
}

// Forward reference entry point. A supplied reference primitive takes
// precedence; otherwise f32 Winograd problems use the Winograd reference and
// everything else the direct reference.
void compute_ref_fwd(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (!prim_ref) {
        const bool is_f32_wino
                = prb->alg == WINO && prb->get_dt(SRC) == dnnl_f32;
        if (is_f32_wino)
            compute_wino_ref_fwd(prb, args);
        else
            compute_ref_direct_fwd(prb, args);
        return;
    }
    SAFE_V(execute_and_wait(prim_ref, args));
}

// Backward-by-data reference entry point. A supplied reference primitive
// takes precedence; otherwise f32 Winograd problems use the Winograd
// reference and everything else the direct reference.
void compute_ref_bwd_d(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (!prim_ref) {
        const bool is_f32_wino
                = prb->alg == WINO && prb->get_dt(SRC) == dnnl_f32;
        if (is_f32_wino)
            compute_wino_ref_bwd_d(prb, args);
        else
            compute_ref_direct_bwd_d(prb, args);
        return;
    }
    SAFE_V(execute_and_wait(prim_ref, args));
}

// Backward-by-weights reference entry point. A supplied reference primitive
// takes precedence; otherwise f32 Winograd problems use the Winograd
// reference and everything else the direct reference.
void compute_ref_bwd_w(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (!prim_ref) {
        const bool is_f32_wino
                = prb->alg == WINO && prb->get_dt(SRC) == dnnl_f32;
        if (is_f32_wino)
            compute_wino_ref_bwd_w(prb, args);
        else
            compute_ref_direct_bwd_w(prb, args);
        return;
    }
    SAFE_V(execute_and_wait(prim_ref, args));
}

// Top-level reference dispatcher: routes the problem to the forward,
// backward-by-data or backward-by-weights reference based on prb->dir.
// Directions matching none of these are left untouched.
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prb->dir & FLAG_FWD) {
        compute_ref_fwd(prb, args, prim_ref);
        return;
    }
    if (prb->dir == BWD_D) {
        compute_ref_bwd_d(prb, args, prim_ref);
        return;
    }
    const bool is_bwd_wei = (prb->dir & FLAG_BWD) && (prb->dir & FLAG_WEI);
    if (is_bwd_wei) compute_ref_bwd_w(prb, args, prim_ref);
}

} // namespace conv
