/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for thread-level epilogue activation functions (GELU_taylor)
*/

#include "../../common/cutlass_unit_test.h"

#include "cutlass/layout/layout.h"
#include "cutlass/epilogue/thread/activation.h"

#include "cutlass/util/host_tensor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Test kernel: applies the activation functor `Func` to one N-element vector per thread.
///
/// Both pointers are reinterpreted as arrays of cutlass::Array<T, N>. The thread with index
/// threadIdx.x reads vector number threadIdx.x from `in`, passes it through a default-constructed
/// `Func`, and stores the result into vector number threadIdx.x of `out`.
///
/// Only threadIdx.x is used for indexing and there is no bounds check, so the launch has to
/// provide exactly one thread per vector held by the buffers.
template <typename T, int N, typename Func>
__global__ void test_Epilogue_thread_activation(T *out, T *in) {

  using Fragment = cutlass::Array<T, N>;

  // One fragment per thread, selected by the thread's x index within the block.
  auto const fragment_idx = threadIdx.x;

  Fragment *src = reinterpret_cast<Fragment *>(in);
  Fragment *dst = reinterpret_cast<Fragment *>(out);

  Func activation;
  dst[fragment_idx] = activation(src[fragment_idx]);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// Reference
//

// Input samples fed to the activation under test (256 values, 4 per row).
// The tests in this file convert each entry to the element type being tested and copy it into
// the source tensor; entry i pairs with GELU_golden_output[i].
static double GELU_golden_input[] = {
    1.587425827980,  1.157652974129,  0.750432848930, -0.965980410576,
    -0.388184845448,  0.014422321692,  0.353164494038,  1.354383468628,
     0.167588576674,  0.272798538208, -0.377032428980,  1.923444747925,
     0.308164477348, -0.341318070889,  0.278338819742, -0.292668998241,
    -1.051743745804, -0.814175724983,  0.112737402320,  1.262938618660,
    -1.582363605499,  0.722016870975,  1.053453564644, -0.659764587879,
     0.734917521477,  0.091274201870,  0.604461073875, -0.219043627381,
    -0.136795744300,  0.960650205612, -1.805408835411,  0.091029644012,
    -1.023343324661,  0.147713735700, -0.499895423651,  1.351878166199,
    -1.631091356277, -0.336171895266, -1.612408638000,  0.090832948685,
    -0.658132910728, -0.326727777719, -1.986387014389,  0.787685871124,
    -1.015677452087, -0.225094825029,  0.876752018929,  0.744826257229,
     0.870290279388, -0.757595360279,  1.510331749916,  0.750012576580,
     0.906444966793, -0.915759027004,  1.260277032852, -0.158465340734,
    -0.109191477299, -0.817102134228,  0.391305118799, -0.524910449982,
     0.351349592209,  0.801979541779,  0.446691334248, -0.741077482700,
     1.205966711044, -0.910210072994,  0.945986449718,  0.784096539021,
     1.670521497726,  0.344931513071, -0.301411420107,  0.309870749712,
    -0.879704594612, -1.951189517975, -0.805817663670, -0.661812782288,
    -0.505914270878, -1.836273789406, -0.381845980883, -0.554707705975,
    -0.375447630882, -0.516645610332,  0.509586095810,  1.087131023407,
     2.664817094803, -1.558295488358, -0.076461032033, -0.504621028900,
     1.327111959457, -1.819981694221,  1.350415468216, -2.074112653732,
     1.501431345940, -1.339013576508,  0.162817999721, -1.473457217216,
     0.357770472765,  0.188413277268,  1.601302266121, -0.653882205486,
     0.856162548065,  0.763102591038, -0.526283502579,  0.581961452961,
     0.089969776571,  1.968745589256,  0.545802056789, -1.168786048889,
     1.206663012505, -0.109096683562, -1.223938226700,  0.744599223137,
    -1.779406785965,  0.766436159611, -0.579044401646, -1.002057313919,
    -0.715845823288, -0.562508940697,  0.886768460274,  2.327786445618,
    -0.148763969541, -0.918884515762, -0.367678701878, -1.105021238327,
    -0.461237311363,  0.158228352666, -0.254040330648,  1.427477598190,
     0.277530491352,  0.046293262392, -0.535557329655, -1.486695051193,
    -0.953706681728, -1.040495038033, -0.314667612314,  0.348172843456,
     0.522773325443,  0.025960063562, -0.482472360134,  1.993084549904,
    -0.253064930439, -0.012146313675, -2.166327714920,  0.398040622473,
    -0.022238900885, -0.443580865860, -0.898376941681, -0.571689844131,
     1.666979670525, -0.831176340580, -0.671057403088,  0.481970995665,
    -1.096243023872, -1.493894338608,  0.596651911736, -0.229505166411,
     1.165976166725,  0.905094027519,  0.049716457725, -1.362933635712,
    -0.366948783398,  1.461613893509, -0.718411505222,  0.895385026932,
    -0.763122260571,  1.329716682434,  1.366570711136, -0.086544901133,
     0.059739742428,  0.940766513348, -0.272854357958, -1.738811373711,
    -0.361239165068,  0.696977972984,  1.288442254066,  1.264815807343,
    -0.573566436768, -1.141678214073,  0.081865988672, -0.886228799820,
    -0.236933603883,  1.050115466118, -0.538952171803,  0.651773929596,
    -0.220034509897, -1.198960781097,  1.247478365898, -0.053529661149,
     0.639809548855,  1.672434806824,  0.511088073254, -1.179364681244,
    -0.730427742004,  0.157630980015,  0.389369845390, -0.925578773022,
    -0.093250080943, -0.391062080860,  0.852983593941,  1.868778109550,
    -1.198786258698,  0.604997038841, -1.482687234879, -2.469333171844,
     0.718807697296, -0.559609353542,  2.187228441238, -2.927527904510,
     0.148535788059, -0.097280368209,  0.674131810665, -1.137645959854,
     0.792729616165, -1.166317462921, -0.498791724443,  1.675866723061,
    -0.137909621000, -0.653263568878, -2.281216144562,  0.296096831560,
     2.002410173416,  1.083609819412,  0.933580815792, -1.504760265350,
     2.185185909271,  0.286121010780, -1.035485863686, -0.216372340918,
    -0.274334043264, -0.849510788918, -1.397169828415, -0.407644748688,
     0.159476816654, -0.170650705695,  0.335193097591, -0.156852483749,
     0.036168430001,  0.858105242252, -1.086121797562,  0.404813349247,
    -0.481496721506, -0.389882832766,  0.020690204576, -0.772020936012,
    -0.758921504021,  0.323482036591,  0.115715265274, -0.811228036880,
    -0.882436633110,  0.176811277866,  1.678015947342,  0.379081040621,
    -0.842976212502,  0.346952259541, -0.545828759670,  1.632800459862
};

// Expected activation results (256 values, 4 per row); entry i is the reference result for
// GELU_golden_input[i]. The tests compare the device output against these values using a
// relative-error bound.
// NOTE(review): the generator of these reference values is not recorded in this file --
// presumably a host-side GELU evaluation; confirm before regenerating the table.
static double GELU_golden_output[] = {
    1.498199582100,  1.014679551125,  0.580462038517, -0.161344811320,
    -0.135453075171,  0.007294139825,  0.225325092673,  1.235459089279,
     0.094946734607,  0.165724009275, -0.133120641112,  1.871103763580,
     0.191376730800, -0.125069886446,  0.169681981206, -0.112644664943,
    -0.154036879539, -0.169163048267,  0.061428427696,  1.132469892502,
    -0.089851818979,  0.552240371704,  0.899579226971, -0.168043658137,
     0.565008401871,  0.048956073821,  0.439583092928, -0.090532489121,
    -0.060955654830,  0.798911273479, -0.064101703465,  0.048816055059,
    -0.156645998359,  0.082529976964, -0.154254898429,  1.232632875443,
    -0.083896033466, -0.123835846782, -0.086161509156,  0.048703473061,
    -0.167972877622, -0.121522113681, -0.046670529991,  0.617986679077,
    -0.157319813967, -0.092503339052,  0.709896743298,  0.574865520000,
     0.703132867813, -0.169963955879,  1.411436080933,  0.580042064190,
     0.741154611111, -0.164741978049,  1.129479527473, -0.069256491959,
    -0.049848672003, -0.169087052345,  0.255214750767, -0.157380074263,
     0.223928079009,  0.632535398006,  0.300378054380, -0.169946283102,
     1.068588852882, -0.165071934462,  0.783203184605,  0.614346146584,
     1.591325283051,  0.219006344676, -0.115003645420,  0.192637458444,
    -0.166712537408, -0.049788996577, -0.169361919165, -0.168130636215,
    -0.155041679740, -0.060888241976, -0.134137839079, -0.160614117980,
    -0.132782235742, -0.156389534473,  0.354075312614,  0.936574816704,
     2.654553413391, -0.092845752835, -0.035900454968, -0.154874503613,
     1.204704761505, -0.062572605908,  1.230982899666, -0.039479542524,
     1.401402950287, -0.120890334249,  0.091938301921, -0.103604510427,
     0.228880971670,  0.108285568655,  1.513783097267, -0.167782157660,
     0.688394129276,  0.593158841133, -0.157540664077,  0.418839782476,
     0.048209801316,  1.920528769493,  0.386099845171, -0.141709372401,
     1.069367766380, -0.049809500575, -0.135230198503,  0.574639260769,
    -0.066881760955,  0.596510827541, -0.162873372436, -0.158483341336,
    -0.169686436653, -0.161375194788,  0.720409095287,  2.304597616196,
    -0.065585561097, -0.164551988244, -0.131098195910, -0.148708447814,
    -0.148663327098,  0.089060656726, -0.101548098028,  1.317959904671,
     0.169103100896,  0.024001283571, -0.158595800400, -0.101909510791,
    -0.162240833044, -0.155090972781, -0.118474565446,  0.221488356590,
     0.365645468235,  0.013248858973, -0.151851043105,  1.946992278099,
    -0.101253561676, -0.006014300976, -0.032804865390,  0.260597169399,
    -0.010922161862, -0.145792976022, -0.165743649006, -0.162226170301,
     1.587365984917, -0.168676435947, -0.168497130275,  0.330191940069,
    -0.149622067809, -0.100989677012,  0.432351946831, -0.093922272325,
     1.023946166039,  0.739726305008,  0.025843897834, -0.117827951908,
    -0.130937814713,  1.356489539146, -0.169726014137,  0.729478538036,
    -0.169943705201,  1.207641005516,  1.249209761620, -0.040288090706,
     0.031292784959,  0.777626037598, -0.107090584934, -0.071350336075,
    -0.129670530558,  0.527676224709,  1.161149263382,  1.134579420090,
    -0.162394225597, -0.144757837057,  0.043603736907, -0.166386902332,
    -0.096278958023,  0.895924389362, -0.158969298005,  0.484089732170,
    -0.090857118368, -0.138206124306,  1.115107178688, -0.025622237474,
     0.472724437714,  1.593463659286,  0.355387806892, -0.140493586659,
    -0.169871479273,  0.088687323034,  0.253673940897, -0.164135158062,
    -0.043161027133, -0.136040985584,  0.685087263584,  1.811169505119,
    -0.138226687908,  0.440080583096, -0.102422207594, -0.016713079065,
     0.549075841904, -0.161096408963,  2.155813455582, -0.005001218989,
     0.083037458360, -0.044870752841,  0.505522191525, -0.145202502608,
     0.623111069202, -0.141991063952, -0.154108211398,  1.597298502922,
    -0.061391282827, -0.167753636837, -0.025704355910,  0.182520583272,
     1.957115054131,  0.932696640491,  0.769961357117, -0.099604383111,
     2.153636932373,  0.175279796124, -0.155551761389, -0.089653611183,
    -0.107515335083, -0.168032020330, -0.113423995674, -0.139319628477,
     0.089841812849, -0.073763631284,  0.211594089866, -0.068651281297,
     0.018605981022,  0.690416753292, -0.150658726692,  0.266040354967,
    -0.151710823178, -0.135800719261,  0.010515870526, -0.169883996248,
    -0.169960290194,  0.202769815922,  0.063187584281, -0.169236257672,
    -0.166577890515,  0.100812792778,  1.599699616432,  0.245525524020,
    -0.168275654316,  0.220552831888, -0.159705042839,  1.549110531807
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Runs GELU_taylor over Array<float, 4> fragments on the device and compares every element
/// with GELU_golden_output using a relative-error bound.
///
/// The kernel is launched as a single block of kN / kV threads, one thread per fragment.
/// The launch and the kernel execution are both checked explicitly so that a CUDA failure is
/// reported as such instead of showing up as a wall of numeric mismatches.
TEST(Epilogue_thread_gelu_taylor, device_f32) {

    int const kN = 256;
    int const kV = 4;

    // The kernel has no tail handling and the loops below index the golden tables directly.
    static_assert(kN % kV == 0, "kN must be a multiple of the vector length kV");
    static_assert(static_cast<int>(sizeof(GELU_golden_input) / sizeof(GELU_golden_input[0])) >= kN,
                  "GELU_golden_input holds fewer than kN entries");
    static_assert(static_cast<int>(sizeof(GELU_golden_output) / sizeof(GELU_golden_output[0])) >= kN,
                  "GELU_golden_output holds fewer than kN entries");

    using Element = float;
    using Func = cutlass::epilogue::thread::GELU_taylor<cutlass::Array<Element, kV>>;

    // Default bound on the relative error; individual elements may override it below.
    double tolerance = 0.005;

    //
    // Construct workspace
    //
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

    for (int i = 0; i < kN; ++i) {
        tensor_Source.host_data(i) = Element(GELU_golden_input[i]);
    }

    tensor_Destination.sync_device();
    tensor_Source.sync_device();

    //
    // Launch the kernel
    //
    dim3 grid(1,1,1);
    dim3 block(kN / kV, 1, 1);

    test_Epilogue_thread_activation<Element, kV, Func><<< grid, block >>>(
        tensor_Destination.device_data(),
        tensor_Source.device_data());

    // A kernel launch returns no status; configuration errors are only visible through
    // cudaGetLastError().
    cudaError_t result = cudaGetLastError();
    ASSERT_EQ(result, cudaSuccess) << "Kernel launch failed: " << cudaGetErrorString(result);

    // Faults raised while the kernel runs surface at the next synchronizing call.
    result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << "Kernel execution failed: " << cudaGetErrorString(result);

    tensor_Destination.sync_host();

    //
    // Verify
    //

    for (int i = 0; i < kN; ++i) {
        Element input = Element(GELU_golden_input[i]);
        Element got = tensor_Destination.host_data(i);
        Element expected = Element(GELU_golden_output[i]);

        // Signed relative error with respect to the golden value.
        double rel_error = (double(got) - double(expected)) / double(expected);

        double tolerance_override = tolerance;

        // Elements that need a looser bound. Their golden outputs are all small in magnitude
        // (between about -0.005 and -0.033), which presumably inflates the relative error.
        switch (i) {
            case 142: tolerance_override = 0.008; break;
            case 203: tolerance_override = 0.03; break;
            case 207: tolerance_override = 0.09; break;
            case 218: tolerance_override = 0.013; break;
        }

        EXPECT_LT(std::abs(rel_error), tolerance_override)
            << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Runs GELU_taylor over Array<cutlass::half_t, 8> fragments on the device and compares every
/// element with GELU_golden_output (converted to half_t) using a relative-error bound.
///
/// The kernel is launched as a single block of kN / kV threads, one thread per fragment.
/// The launch and the kernel execution are both checked explicitly so that a CUDA failure is
/// reported as such instead of showing up as a wall of numeric mismatches.
TEST(Epilogue_thread_gelu_taylor, device_f16) {

    int const kN = 256;
    int const kV = 8;

    // The kernel has no tail handling and the loops below index the golden tables directly.
    static_assert(kN % kV == 0, "kN must be a multiple of the vector length kV");
    static_assert(static_cast<int>(sizeof(GELU_golden_input) / sizeof(GELU_golden_input[0])) >= kN,
                  "GELU_golden_input holds fewer than kN entries");
    static_assert(static_cast<int>(sizeof(GELU_golden_output) / sizeof(GELU_golden_output[0])) >= kN,
                  "GELU_golden_output holds fewer than kN entries");

    using Element = cutlass::half_t;
    using Func = cutlass::epilogue::thread::GELU_taylor<cutlass::Array<Element, kV>>;

    // Default bound on the relative error; individual elements may override it below.
    double tolerance = 0.005;

    //
    // Construct workspace
    //
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
    cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

    for (int i = 0; i < kN; ++i) {
        tensor_Source.host_data(i) = Element(GELU_golden_input[i]);
    }

    tensor_Destination.sync_device();
    tensor_Source.sync_device();

    //
    // Launch the kernel
    //
    dim3 grid(1,1,1);
    dim3 block(kN / kV, 1, 1);

    test_Epilogue_thread_activation<Element, kV, Func><<< grid, block >>>(
        tensor_Destination.device_data(),
        tensor_Source.device_data());

    // A kernel launch returns no status; configuration errors are only visible through
    // cudaGetLastError().
    cudaError_t result = cudaGetLastError();
    ASSERT_EQ(result, cudaSuccess) << "Kernel launch failed: " << cudaGetErrorString(result);

    // Faults raised while the kernel runs surface at the next synchronizing call.
    result = cudaDeviceSynchronize();
    ASSERT_EQ(result, cudaSuccess) << "Kernel execution failed: " << cudaGetErrorString(result);

    tensor_Destination.sync_host();

    //
    // Verify
    //

    for (int i = 0; i < kN; ++i) {
        Element input = Element(GELU_golden_input[i]);
        Element got = tensor_Destination.host_data(i);
        // The golden value is rounded to half precision before the comparison.
        Element expected = Element(GELU_golden_output[i]);

        // Signed relative error with respect to the (rounded) golden value.
        double rel_error = (double(got) - double(expected)) / double(expected);

        double tolerance_override = tolerance;

        // Elements that need a looser bound than the default in half precision.
        switch (i) {
            case 36: tolerance_override = 0.006; break;
            case 77: tolerance_override = 0.009; break;
            case 95: tolerance_override = 0.008; break;
            case 112: tolerance_override = 0.007; break;
            case 171: tolerance_override = 0.006; break;
            case 203: tolerance_override = 0.03; break;
            case 207: tolerance_override = 0.15; break;
        }

        EXPECT_LT(std::abs(rel_error), tolerance_override)
            << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////
