/******************************************************************************
 * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * AgentScanByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan by key.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_discontinuity.cuh"
#include "../config.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"


CUB_NAMESPACE_BEGIN


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentScanByKey
 */

template <int                  POLICY_BLOCK_THREADS,                            ///< Threads per thread block
          int                  POLICY_ITEMS_PER_THREAD = 1,                     ///< Items per thread (per tile of input)
          BlockLoadAlgorithm   POLICY_LOAD_ALGORITHM   = BLOCK_LOAD_DIRECT,     ///< The BlockLoad algorithm to use
          CacheLoadModifier    POLICY_LOAD_MODIFIER    = LOAD_DEFAULT,          ///< Cache load modifier for reading input elements
          BlockScanAlgorithm   POLICY_SCAN_ALGORITHM   = BLOCK_SCAN_WARP_SCANS, ///< The BlockScan algorithm to use
          BlockStoreAlgorithm  POLICY_STORE_ALGORITHM  = BLOCK_STORE_DIRECT>    ///< The BlockStore algorithm to use
struct AgentScanByKeyPolicy
{
    // Algorithm selections, re-exported so that agents can query them
    // through the policy type.
    static const BlockLoadAlgorithm  LOAD_ALGORITHM  = POLICY_LOAD_ALGORITHM;
    static const CacheLoadModifier   LOAD_MODIFIER   = POLICY_LOAD_MODIFIER;
    static const BlockScanAlgorithm  SCAN_ALGORITHM  = POLICY_SCAN_ALGORITHM;
    static const BlockStoreAlgorithm STORE_ALGORITHM = POLICY_STORE_ALGORITHM;

    // Integral tile-shape constants.
    enum
    {
        BLOCK_THREADS    = POLICY_BLOCK_THREADS,
        ITEMS_PER_THREAD = POLICY_ITEMS_PER_THREAD,
    };
};


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentScanByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan by key.
 */
template <
    typename AgentScanByKeyPolicyT,       ///< Parameterized AgentScanByKeyPolicy tuning policy type
    typename KeysInputIteratorT,          ///< Random-access input iterator type (keys)
    typename ValuesInputIteratorT,        ///< Random-access input iterator type (values)
    typename ValuesOutputIteratorT,       ///< Random-access output iterator type (scanned values)
    typename EqualityOp,                  ///< Equality functor type
    typename ScanOpT,                     ///< Scan functor type
    typename InitValueT,                  ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan)
    typename OffsetT>                     ///< Signed integer type for global offsets
struct AgentScanByKey
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
    using InputT = typename std::iterator_traits<ValuesInputIteratorT>::value_type;

    // The output value type -- used as the intermediate accumulator
    // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the
    // input iterator's value type.
    using OutputT =
      typename If<Equals<InitValueT, NullType>::VALUE, InputT, InitValueT>::Type;

    // (segment flag, value) pair that is actually scanned: the flag is carried
    // in `.key`, the running value in `.value` (see ZipValuesAndFlags).
    using SizeValuePairT = KeyValuePair<OffsetT, OutputT>;
    using KeyValuePairT = KeyValuePair<KeyT, OutputT>;
    using ReduceBySegmentOpT = ReduceBySegmentOp<ScanOpT>;

    using ScanTileStateT = ReduceByKeyScanTileState<OutputT, OffsetT>;

    // Constants
    enum
    {
        IS_INCLUSIVE        = Equals<InitValueT, NullType>::VALUE,            // Inclusive scan if no init_value type is provided
        BLOCK_THREADS       = AgentScanByKeyPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = AgentScanByKeyPolicyT::ITEMS_PER_THREAD,
        ITEMS_PER_TILE      = BLOCK_THREADS * ITEMS_PER_THREAD,
    };

    using WrappedKeysInputIteratorT = typename If<IsPointer<KeysInputIteratorT>::VALUE,
        CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
        KeysInputIteratorT>::Type;
    using WrappedValuesInputIteratorT = typename If<IsPointer<ValuesInputIteratorT>::VALUE,
        CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
        ValuesInputIteratorT>::Type;

    // Block-wide collectives. Values are loaded directly as OutputT (the
    // accumulator type), not as InputT.
    using BlockLoadKeysT = BlockLoad<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
    using BlockLoadValuesT = BlockLoad<OutputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
    using BlockStoreValuesT = BlockStore<OutputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::STORE_ALGORITHM>;
    using BlockDiscontinuityKeysT = BlockDiscontinuity<KeyT, BLOCK_THREADS, 1, 1>;

    using TilePrefixCallbackT = TilePrefixCallbackOp<SizeValuePairT, ReduceBySegmentOpT, ScanTileStateT>;
    using BlockScanT = BlockScan<SizeValuePairT, BLOCK_THREADS, AgentScanByKeyPolicyT::SCAN_ALGORITHM, 1, 1>;

    // Temporary storage for the collectives used by ConsumeTile.
    //
    // This is a union: `scan_storage`, `load_keys`, `load_values` and
    // `store_values` alias the same bytes, which is why ConsumeTile places a
    // CTA_SYNC() between the phases that use different members. The three
    // members of ScanStorage do NOT alias each other, so head-flagging, the
    // block scan and the prefix callback can run without an intervening sync.
    union TempStorage
    {
        struct ScanStorage
        {
            typename BlockScanT::TempStorage              scan;
            typename TilePrefixCallbackT::TempStorage     prefix;
            typename BlockDiscontinuityKeysT::TempStorage discontinuity;
        } scan_storage;

        typename BlockLoadKeysT::TempStorage    load_keys;
        typename BlockLoadValuesT::TempStorage  load_values;
        typename BlockStoreValuesT::TempStorage store_values;
    };

    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    TempStorage &                 storage;        ///< Reference to caller-provided temporary storage
    WrappedKeysInputIteratorT     d_keys_in;      ///< Keys input (cache-modified if a raw pointer was given)
    WrappedValuesInputIteratorT   d_values_in;    ///< Values input (cache-modified if a raw pointer was given)
    ValuesOutputIteratorT         d_values_out;   ///< Scanned values output
    InequalityWrapper<EqualityOp> inequality_op;  ///< Wraps the user's equality functor; used for head flagging
    ScanOpT                       scan_op;        ///< User scan operator on values
    ReduceBySegmentOpT            pair_scan_op;   ///< Scan operator on (flag, value) pairs, built from scan_op
    InitValueT                    init_value;     ///< Initial value (NullType for inclusive scan)

    //---------------------------------------------------------------------
    // Block scan utility methods (first tile)
    //---------------------------------------------------------------------

    // Exclusive scan specialization.
    // Scans scan_items in place and returns the block-wide aggregate in
    // tile_aggregate.
    __device__ __forceinline__
    void ScanTile(
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
        SizeValuePairT &tile_aggregate,
        Int2Type<false> /* is_inclusive */)
    {
        BlockScanT(storage.scan_storage.scan)
            .ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
    }

    // Inclusive scan specialization.
    // Scans scan_items in place and returns the block-wide aggregate in
    // tile_aggregate.
    __device__ __forceinline__
    void ScanTile(
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
        SizeValuePairT &tile_aggregate,
        Int2Type<true> /* is_inclusive */)
    {
        BlockScanT(storage.scan_storage.scan)
            .InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate);
    }

    //---------------------------------------------------------------------
    // Block scan utility methods (subsequent tiles)
    //---------------------------------------------------------------------

    // Exclusive scan specialization (with prefix from predecessors).
    // The running prefix is supplied by prefix_op; after the scan the tile
    // aggregate is read back from it.
    __device__ __forceinline__
    void ScanTile(
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
        SizeValuePairT & tile_aggregate,
        TilePrefixCallbackT &prefix_op,
        Int2Type<false> /* is_inclusive */)
    {
        BlockScanT(storage.scan_storage.scan)
            .ExclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op);
        tile_aggregate = prefix_op.GetBlockAggregate();
    }

    // Inclusive scan specialization (with prefix from predecessors).
    // The running prefix is supplied by prefix_op; after the scan the tile
    // aggregate is read back from it.
    __device__ __forceinline__
    void ScanTile(
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD],
        SizeValuePairT & tile_aggregate,
        TilePrefixCallbackT &prefix_op,
        Int2Type<true> /* is_inclusive */)
    {
        BlockScanT(storage.scan_storage.scan)
            .InclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op);
        tile_aggregate = prefix_op.GetBlockAggregate();
    }

    //---------------------------------------------------------------------
    // Zip utility methods
    //---------------------------------------------------------------------

    // Packs each thread's values and head flags into (flag, value) pairs for
    // the pair scan. For the last tile, the first out-of-bounds item (the one
    // whose tile-relative index equals num_remaining) is additionally flagged
    // as a segment head; segment_flags is updated in place for that item.
    template <bool IS_LAST_TILE>
    __device__ __forceinline__
    void ZipValuesAndFlags(
        OffsetT num_remaining,
        OutputT (&values)[ITEMS_PER_THREAD],
        OffsetT (&segment_flags)[ITEMS_PER_THREAD],
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
    {
        // Zip values and segment_flags
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Set segment_flags for the first out-of-bounds item; flags of all
            // other items are left as computed by the caller.
            // The tile-relative index assumes each thread owns ITEMS_PER_THREAD
            // consecutive items.
            if (IS_LAST_TILE &&
                OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM == num_remaining)
                segment_flags[ITEM] = 1;

            scan_items[ITEM].value = values[ITEM];
            scan_items[ITEM].key   = segment_flags[ITEM];
        }
    }

    // Copies the scanned values back out of the (flag, value) pairs.
    __device__ __forceinline__
    void UnzipValues(
        OutputT (&values)[ITEMS_PER_THREAD],
        SizeValuePairT (&scan_items)[ITEMS_PER_THREAD])
    {
        // Unzip values from scan_items
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            values[ITEM] = scan_items[ITEM].value;
        }
    }

    // Exclusive-scan overload (InitValueT is not NullType): folds init_value
    // into every scanned item. Items flagged as segment heads become exactly
    // init_value; all others become scan_op(init_value, item).
    template<bool IsNull=Equals<InitValueT, NullType>::VALUE, typename std::enable_if<!IsNull, int>::type=0>
    __device__ __forceinline__
    void AddInitToScan(
        OutputT (&items)[ITEMS_PER_THREAD],
        OffsetT (&flags)[ITEMS_PER_THREAD])
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            items[ITEM] = flags[ITEM] ? init_value : scan_op(init_value, items[ITEM]);
        }
    }

    // Inclusive-scan overload (InitValueT is NullType): there is no initial
    // value to apply, so this is a no-op.
    template<bool IsNull=Equals<InitValueT, NullType>::VALUE, typename std::enable_if<IsNull, int>::type=0>
    __device__ __forceinline__
    void AddInitToScan(
        OutputT (&/*items*/)[ITEMS_PER_THREAD],
        OffsetT (&/*flags*/)[ITEMS_PER_THREAD])
    {}

    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    // Process a tile of input (dynamic chained scan)
    //
    // Steps: load keys and values, flag segment heads, zip (flag, value)
    // pairs, scan them (seeded by the predecessor tiles' prefix for
    // tile_idx != 0), unzip, apply init_value for exclusive scans, store.
    //
    //   num_remaining  Number of valid items from tile_base to the end of input
    //   tile_idx       Global index of this tile
    //   tile_base      Global offset of this tile's first item
    //   tile_state     Global tile state descriptor shared between tiles
    //
    // IS_LAST_TILE selects the guarded (possibly partially-full) load/store
    // paths. Must be called by every thread of the block: it contains
    // block-wide collectives and CTA_SYNC() barriers.
    template <bool IS_LAST_TILE>
    __device__ __forceinline__
    void ConsumeTile(
        OffsetT          /*num_items*/,
        OffsetT          num_remaining,
        int              tile_idx,
        OffsetT          tile_base,
        ScanTileStateT&  tile_state)
    {
        // Load items
        KeyT           keys[ITEMS_PER_THREAD];
        OutputT        values[ITEMS_PER_THREAD];
        OffsetT        segment_flags[ITEMS_PER_THREAD];
        SizeValuePairT scan_items[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
        {
            // Guarded load: out-of-bounds slots are filled with the tile's
            // first key, because the collectives below are not suffix guarded
            // and will touch every slot.
            BlockLoadKeysT(storage.load_keys)
                .Load(d_keys_in + tile_base,
                      keys,
                      num_remaining,
                      *(d_keys_in + tile_base));
        }
        else
        {
            BlockLoadKeysT(storage.load_keys)
                .Load(d_keys_in + tile_base, keys);
        }

        // load_keys and load_values alias within the TempStorage union
        CTA_SYNC();

        if (IS_LAST_TILE)
        {
            // Guarded load: out-of-bounds slots are filled with the tile's
            // first value, because the collectives below are not suffix
            // guarded and will touch every slot.
            BlockLoadValuesT(storage.load_values)
                .Load(d_values_in + tile_base,
                      values,
                      num_remaining,
                      *(d_values_in + tile_base));
        }
        else
        {
            BlockLoadValuesT(storage.load_values)
                .Load(d_values_in + tile_base, values);
        }

        // load_values and scan_storage alias within the TempStorage union
        CTA_SYNC();

        // first tile
        if (tile_idx == 0)
        {
            // No predecessor key exists for the very first tile
            BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
                .FlagHeads(segment_flags, keys, inequality_op);

            // Zip values and segment_flags
            ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
                                            values,
                                            segment_flags,
                                            scan_items);

            // Inclusive or exclusive (per IS_INCLUSIVE) scan of the
            // (flag, value) pairs
            SizeValuePairT tile_aggregate;
            ScanTile(scan_items, tile_aggregate, Int2Type<IS_INCLUSIVE>());

            if (threadIdx.x == 0)
            {
                // Publish this tile's aggregate as the inclusive prefix of
                // tile 0. Skipped when this is also the last tile.
                if (!IS_LAST_TILE)
                    tile_state.SetInclusive(0, tile_aggregate);

                // NOTE(review): scan_items[].key is not read again in this
                // function (UnzipValues only reads .value).
                scan_items[0].key = 0;
            }
        }
        else
        {
            // Thread 0 fetches the last key of the preceding tile so that a
            // segment spanning the tile boundary is not flagged as a new head.
            // Other threads pass a default-constructed key.
            KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_in[tile_base - 1] : KeyT();
            BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
                .FlagHeads(segment_flags, keys, inequality_op, tile_pred_key);

            // Zip values and segment_flags
            ZipValuesAndFlags<IS_LAST_TILE>(num_remaining,
                                            values,
                                            segment_flags,
                                            scan_items);

            // Inclusive or exclusive (per IS_INCLUSIVE) scan of the
            // (flag, value) pairs, seeded via the tile prefix callback
            SizeValuePairT  tile_aggregate;
            TilePrefixCallbackT prefix_op(tile_state, storage.scan_storage.prefix, pair_scan_op, tile_idx);
            ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type<IS_INCLUSIVE>());
        }

        // scan_storage and store_values alias within the TempStorage union
        CTA_SYNC();

        UnzipValues(values, scan_items);

        // No-op for inclusive scans
        AddInitToScan(values, segment_flags);

        // Store items
        if (IS_LAST_TILE)
        {
            BlockStoreValuesT(storage.store_values)
                .Store(d_values_out + tile_base, values, num_remaining);
        }
        else
        {
            BlockStoreValuesT(storage.store_values)
                .Store(d_values_out + tile_base, values);
        }
    }

    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Construct the agent: binds the temporary storage and the input/output
     * iterators, and derives the inequality and pair-scan operators from
     * \p equality_op and \p scan_op.
     */
    __device__ __forceinline__
    AgentScanByKey(
        TempStorage &         storage,        ///< Reference to temporary storage
        KeysInputIteratorT    d_keys_in,      ///< Keys input
        ValuesInputIteratorT  d_values_in,    ///< Values input
        ValuesOutputIteratorT d_values_out,   ///< Scanned values output
        EqualityOp            equality_op,    ///< Key equality functor
        ScanOpT               scan_op,        ///< Value scan functor
        InitValueT            init_value)     ///< Initial value (NullType for inclusive scan)
    :
        storage(storage),
        d_keys_in(d_keys_in),
        d_values_in(d_values_in),
        d_values_out(d_values_out),
        inequality_op(equality_op),
        scan_op(scan_op),
        pair_scan_op(scan_op),
        init_value(init_value)
    {}

    /**
     * Scan tiles of items as part of a dynamic chained scan
     *
     * Each thread block processes exactly one tile, with global index
     * start_tile + blockIdx.x. Blocks whose tile lies entirely past the end
     * of the input do nothing.
     */
    __device__ __forceinline__ void ConsumeRange(
        OffsetT             num_items,          ///< Total number of input items
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        int                 start_tile)         ///< The starting tile for the current grid
    {
        // Offset by start_tile so that a grid launched with start_tile > 0
        // works on its own tiles instead of re-processing tiles
        // [0, gridDim.x). Identical to the previous behaviour for
        // start_tile == 0.
        int  tile_idx         = start_tile + static_cast<int>(blockIdx.x);
        OffsetT tile_base     = OffsetT(ITEMS_PER_TILE) * tile_idx;
        OffsetT num_remaining = num_items - tile_base;

        if (num_remaining > ITEMS_PER_TILE)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_items,
                               num_remaining,
                               tile_idx,
                               tile_base,
                               tile_state);
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            ConsumeTile<true>(num_items,
                              num_remaining,
                              tile_idx,
                              tile_base,
                              tile_state);
        }
    }
};


CUB_NAMESPACE_END
