/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <iterator>
#include <type_traits>

#include <cub/agent/single_pass_scan_operators.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_exchange.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <cub/config.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>


CUB_NAMESPACE_BEGIN


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

template <int _BLOCK_THREADS,
          int _ITEMS_PER_THREAD,
          BlockLoadAlgorithm _LOAD_ALGORITHM,
          CacheLoadModifier _LOAD_MODIFIER,
          BlockScanAlgorithm _SCAN_ALGORITHM>
struct AgentThreeWayPartitionPolicy
{
  constexpr static int BLOCK_THREADS                 = _BLOCK_THREADS;
  constexpr static int ITEMS_PER_THREAD              = _ITEMS_PER_THREAD;
  constexpr static BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM;
  constexpr static CacheLoadModifier LOAD_MODIFIER   = _LOAD_MODIFIER;
  constexpr static BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM;
};


/**
 * \brief Implements a device-wide three-way partitioning
 *
 * Splits input data into three parts based on the selection functors. If the
 * first functor selects an item, the algorithm places it in the first part.
 * Otherwise, if the second functor selects an item, the algorithm places it in
 * the second part. If both functors don't select an item, the algorithm places
 * it into the unselected part.
 */
template <typename PolicyT,
          typename InputIteratorT,
          typename FirstOutputIteratorT,
          typename SecondOutputIteratorT,
          typename UnselectedOutputIteratorT,
          typename SelectFirstPartOp,
          typename SelectSecondPartOp,
          typename OffsetT>
struct AgentThreeWayPartition
{
  //---------------------------------------------------------------------
  // Types and constants
  //---------------------------------------------------------------------

  // The input value type
  using InputT = typename std::iterator_traits<InputIteratorT>::value_type;

  // Tile status descriptor interface type
  using ScanTileStateT = cub::ScanTileState<OffsetT>;

  // Constants
  constexpr static int BLOCK_THREADS = PolicyT::BLOCK_THREADS;
  constexpr static int ITEMS_PER_THREAD = PolicyT::ITEMS_PER_THREAD;
  constexpr static int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;

  using WrappedInputIteratorT = typename std::conditional<
    std::is_pointer<InputIteratorT>::value,
    cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
    InputIteratorT>::type;

  // Parameterized BlockLoad type for input data
  using BlockLoadT = cub::BlockLoad<InputT,
                                    BLOCK_THREADS,
                                    ITEMS_PER_THREAD,
                                    PolicyT::LOAD_ALGORITHM>;

  // Parameterized BlockScan type
  using BlockScanT =
    cub::BlockScan<OffsetT, BLOCK_THREADS, PolicyT::SCAN_ALGORITHM>;

  // Callback type for obtaining tile prefix during block scan
  using TilePrefixCallbackOpT =
    cub::TilePrefixCallbackOp<OffsetT, cub::Sum, ScanTileStateT>;

  // Item exchange type
  using ItemExchangeT = InputT[TILE_ITEMS];

  // Shared memory type for this thread block
  union _TempStorage
  {
    struct ScanStorage
    {
      // Smem needed for tile scanning
      typename BlockScanT::TempStorage scan;

      // Smem needed for cooperative prefix callback
      typename TilePrefixCallbackOpT::TempStorage prefix;
    } scan_storage;

    // Smem needed for loading items
    typename BlockLoadT::TempStorage load_items;

    // Smem needed for compacting items (allows non POD items in this union)
    cub::Uninitialized<ItemExchangeT> raw_exchange;
  };

  // Alias wrapper allowing storage to be unioned
  struct TempStorage : cub::Uninitialized<_TempStorage> {};


  //---------------------------------------------------------------------
  // Per-thread fields
  //---------------------------------------------------------------------

  _TempStorage&                        temp_storage;       ///< Reference to temp_storage
  WrappedInputIteratorT                d_in;               ///< Input items
  FirstOutputIteratorT                 d_first_part_out;
  SecondOutputIteratorT                d_second_part_out;
  UnselectedOutputIteratorT            d_unselected_out;
  SelectFirstPartOp                    select_first_part_op;
  SelectSecondPartOp                   select_second_part_op;
  OffsetT                              num_items;          ///< Total number of input items


  //---------------------------------------------------------------------
  // Constructor
  //---------------------------------------------------------------------

  // Constructor
  __device__ __forceinline__
  AgentThreeWayPartition(TempStorage &temp_storage,
                         InputIteratorT d_in,
                         FirstOutputIteratorT d_first_part_out,
                         SecondOutputIteratorT d_second_part_out,
                         UnselectedOutputIteratorT d_unselected_out,
                         SelectFirstPartOp select_first_part_op,
                         SelectSecondPartOp select_second_part_op,
                         OffsetT num_items)
      : temp_storage(temp_storage.Alias())
      , d_in(d_in)
      , d_first_part_out(d_first_part_out)
      , d_second_part_out(d_second_part_out)
      , d_unselected_out(d_unselected_out)
      , select_first_part_op(select_first_part_op)
      , select_second_part_op(select_second_part_op)
      , num_items(num_items)
  {}

  //---------------------------------------------------------------------
  // Utility methods for initializing the selections
  //---------------------------------------------------------------------

  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void
  Initialize(OffsetT num_tile_items,
             InputT (&items)[ITEMS_PER_THREAD],
             OffsetT (&first_items_selection_flags)[ITEMS_PER_THREAD],
             OffsetT (&second_items_selection_flags)[ITEMS_PER_THREAD])
  {
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
      // Out-of-bounds items are selection_flags
      first_items_selection_flags[ITEM]  = 1;
      second_items_selection_flags[ITEM] = 1;

      if (!IS_LAST_TILE ||
          (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items))
      {
        first_items_selection_flags[ITEM] = select_first_part_op(items[ITEM]);
        second_items_selection_flags[ITEM] =
          first_items_selection_flags[ITEM]
            ? 0
            : select_second_part_op(items[ITEM]);
      }
    }
  }

  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void Scatter(
    InputT          (&items)[ITEMS_PER_THREAD],
    OffsetT         (&first_items_selection_flags)[ITEMS_PER_THREAD],
    OffsetT         (&first_items_selection_indices)[ITEMS_PER_THREAD],
    OffsetT         (&second_items_selection_flags)[ITEMS_PER_THREAD],
    OffsetT         (&second_items_selection_indices)[ITEMS_PER_THREAD],
    int             num_tile_items,
    int             num_first_tile_selections,
    int             num_second_tile_selections,
    OffsetT         num_first_selections_prefix,
    OffsetT         num_second_selections_prefix,
    OffsetT         num_rejected_prefix)
  {
    CTA_SYNC();

    int first_item_end = num_first_tile_selections;
    int second_item_end = first_item_end + num_second_tile_selections;

    // Scatter items to shared memory (rejections first)
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
      int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM;

      if (!IS_LAST_TILE || (item_idx < num_tile_items))
      {
        int local_scatter_offset = 0;

        if (first_items_selection_flags[ITEM])
        {
          local_scatter_offset = first_items_selection_indices[ITEM]
                               - num_first_selections_prefix;
        }
        else if (second_items_selection_flags[ITEM])
        {
          local_scatter_offset = first_item_end +
                                 second_items_selection_indices[ITEM] -
                                 num_second_selections_prefix;
        }
        else
        {
          // Medium item
          int local_selection_idx =
           (first_items_selection_indices[ITEM] - num_first_selections_prefix)
         + (second_items_selection_indices[ITEM] - num_second_selections_prefix);
          local_scatter_offset = second_item_end + item_idx - local_selection_idx;
        }

        temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
      }
    }

    CTA_SYNC();

    // Gather items from shared memory and scatter to global
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
      int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x;

      if (!IS_LAST_TILE || (item_idx < num_tile_items))
      {
        InputT item = temp_storage.raw_exchange.Alias()[item_idx];

        if (item_idx < first_item_end)
        {
          d_first_part_out[num_first_selections_prefix + item_idx] = item;
        }
        else if (item_idx < second_item_end)
        {
          d_second_part_out[num_second_selections_prefix + item_idx - first_item_end] = item;
        }
        else
        {
          int rejection_idx = item_idx - second_item_end;
          d_unselected_out[num_rejected_prefix + rejection_idx] = item;
        }
      }
    }
  }


  //---------------------------------------------------------------------
  // Cooperatively scan a device-wide sequence of tiles with other CTAs
  //---------------------------------------------------------------------


  /**
   * Process first tile of input (dynamic chained scan).
   * Returns the running count of selections (including this tile)
   *
   * @param num_tile_items Number of input items comprising this tile
   * @param tile_offset Tile offset
   * @param first_tile_state Global tile state descriptor
   * @param second_tile_state Global tile state descriptor
   */
  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void
  ConsumeFirstTile(int num_tile_items,
                   OffsetT tile_offset,
                   ScanTileStateT &first_tile_state,
                   ScanTileStateT &second_tile_state,
                   OffsetT &first_items,
                   OffsetT &second_items)
  {
    InputT items[ITEMS_PER_THREAD];

    OffsetT first_items_selection_flags[ITEMS_PER_THREAD];
    OffsetT first_items_selection_indices[ITEMS_PER_THREAD];

    OffsetT second_items_selection_flags[ITEMS_PER_THREAD];
    OffsetT second_items_selection_indices[ITEMS_PER_THREAD];

    // Load items
    if (IS_LAST_TILE)
    {
      BlockLoadT(temp_storage.load_items)
        .Load(d_in + tile_offset, items, num_tile_items);
    }
    else
    {
      BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);
    }

    // Initialize selection_flags
    Initialize<IS_LAST_TILE>(
      num_tile_items,
      items,
      first_items_selection_flags,
      second_items_selection_flags);

    CTA_SYNC();

    // Exclusive scan of selection_flags
    BlockScanT(temp_storage.scan_storage.scan)
      .ExclusiveSum(first_items_selection_flags,
                    first_items_selection_indices,
                    first_items);

    if (threadIdx.x == 0)
    {
      // Update tile status if this is not the last tile
      if (!IS_LAST_TILE)
      {
        first_tile_state.SetInclusive(0, first_items);
      }
    }

    CTA_SYNC();

    // Exclusive scan of selection_flags
    BlockScanT(temp_storage.scan_storage.scan)
      .ExclusiveSum(second_items_selection_flags,
                    second_items_selection_indices,
                    second_items);

    if (threadIdx.x == 0)
    {
      // Update tile status if this is not the last tile
      if (!IS_LAST_TILE)
      {
        second_tile_state.SetInclusive(0, second_items);
      }
    }

    // Discount any out-of-bounds selections
    if (IS_LAST_TILE)
    {
      first_items -= (TILE_ITEMS - num_tile_items);
      second_items -= (TILE_ITEMS - num_tile_items);
    }

    // Scatter flagged items
    Scatter<IS_LAST_TILE>(
      items,
      first_items_selection_flags,
      first_items_selection_indices,
      second_items_selection_flags,
      second_items_selection_indices,
      num_tile_items,
      first_items,
      second_items,
      // all the prefixes equal to 0 because it's the first tile
      0, 0, 0);
  }


  /**
   * Process subsequent tile of input (dynamic chained scan).
   * Returns the running count of selections (including this tile)
   *
   * @param num_tile_items Number of input items comprising this tile
   * @param tile_idx Tile index
   * @param tile_offset Tile offset
   * @param first_tile_state Global tile state descriptor
   * @param second_tile_state Global tile state descriptor
   */
  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void
  ConsumeSubsequentTile(int num_tile_items,
                        int tile_idx,
                        OffsetT tile_offset,
                        ScanTileStateT &first_tile_state,
                        ScanTileStateT &second_tile_state,
                        OffsetT &num_first_items_selections,
                        OffsetT &num_second_items_selections)
  {
    InputT items[ITEMS_PER_THREAD];

    OffsetT first_items_selection_flags[ITEMS_PER_THREAD];
    OffsetT first_items_selection_indices[ITEMS_PER_THREAD];

    OffsetT second_items_selection_flags[ITEMS_PER_THREAD];
    OffsetT second_items_selection_indices[ITEMS_PER_THREAD];

    // Load items
    if (IS_LAST_TILE)
    {
      BlockLoadT(temp_storage.load_items).Load(
        d_in + tile_offset, items, num_tile_items);
    }
    else
    {
      BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);
    }

    // Initialize selection_flags
    Initialize<IS_LAST_TILE>(
      num_tile_items,
      items,
      first_items_selection_flags,
      second_items_selection_flags);

    CTA_SYNC();

    // Exclusive scan of values and selection_flags
    TilePrefixCallbackOpT first_prefix_op(first_tile_state,
                                          temp_storage.scan_storage.prefix,
                                          cub::Sum(),
                                          tile_idx);

    BlockScanT(temp_storage.scan_storage.scan)
      .ExclusiveSum(first_items_selection_flags,
                    first_items_selection_indices,
                    first_prefix_op);

    num_first_items_selections                  = first_prefix_op.GetInclusivePrefix();
    OffsetT num_first_items_in_tile_selections  = first_prefix_op.GetBlockAggregate();
    OffsetT num_first_items_selections_prefix   = first_prefix_op.GetExclusivePrefix();

    CTA_SYNC();

    TilePrefixCallbackOpT second_prefix_op(second_tile_state,
                                           temp_storage.scan_storage.prefix,
                                           cub::Sum(),
                                           tile_idx);
    BlockScanT(temp_storage.scan_storage.scan)
      .ExclusiveSum(second_items_selection_flags,
                    second_items_selection_indices,
                    second_prefix_op);

    num_second_items_selections                  = second_prefix_op.GetInclusivePrefix();
    OffsetT num_second_items_in_tile_selections  = second_prefix_op.GetBlockAggregate();
    OffsetT num_second_items_selections_prefix   = second_prefix_op.GetExclusivePrefix();

    OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS)
                                  - num_first_items_selections_prefix
                                  - num_second_items_selections_prefix;

    // Discount any out-of-bounds selections. There are exactly
    // TILE_ITEMS - num_tile_items elements like that because we
    // marked them as selected in Initialize method.
    if (IS_LAST_TILE)
    {
      const int num_discount = TILE_ITEMS - num_tile_items;

      num_first_items_selections          -= num_discount;
      num_first_items_in_tile_selections  -= num_discount;
      num_second_items_selections         -= num_discount;
      num_second_items_in_tile_selections -= num_discount;
    }

    // Scatter flagged items
    Scatter<IS_LAST_TILE>(
      items,
      first_items_selection_flags,
      first_items_selection_indices,
      second_items_selection_flags,
      second_items_selection_indices,
      num_tile_items,
      num_first_items_in_tile_selections,
      num_second_items_in_tile_selections,
      num_first_items_selections_prefix,
      num_second_items_selections_prefix,
      num_rejected_prefix);
  }


  /**
   * Process a tile of input
   */
  template <bool IS_LAST_TILE>
  __device__ __forceinline__ void ConsumeTile(
    int                 num_tile_items,
    int                 tile_idx,
    OffsetT             tile_offset,
    ScanTileStateT&     first_tile_state,
    ScanTileStateT&     second_tile_state,
    OffsetT&            first_items,
    OffsetT&            second_items)
  {
    if (tile_idx == 0)
    {
      ConsumeFirstTile<IS_LAST_TILE>(num_tile_items,
                                     tile_offset,
                                     first_tile_state,
                                     second_tile_state,
                                     first_items,
                                     second_items);
    }
    else
    {
      ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items,
                                          tile_idx,
                                          tile_offset,
                                          first_tile_state,
                                          second_tile_state,
                                          first_items,
                                          second_items);
    }
  }


  /**
   * Scan tiles of items as part of a dynamic chained scan
   *
   * @tparam NumSelectedIteratorT
   *   Output iterator type for recording number of items selection_flags
   *
   * @param num_tiles
   *   Total number of input tiles
   *
   * @param first_tile_state
   *   Global tile state descriptor
   *
   * @param second_tile_state
   *   Global tile state descriptor
   *
   * @param d_num_selected_out
   *   Output total number selection_flags
   */
  template <typename NumSelectedIteratorT>
  __device__ __forceinline__ void
  ConsumeRange(int num_tiles,
               ScanTileStateT &first_tile_state,
               ScanTileStateT &second_tile_state,
               NumSelectedIteratorT d_num_selected_out)
  {
    // Blocks are launched in increasing order, so just assign one tile per block
    // Current tile index
    int tile_idx = static_cast<int>((blockIdx.x * gridDim.y) + blockIdx.y);

    // Global offset for the current tile
    OffsetT tile_offset = tile_idx * TILE_ITEMS;

    OffsetT num_first_selections;
    OffsetT num_second_selections;

    if (tile_idx < num_tiles - 1)
    {
      // Not the last tile (full)
      ConsumeTile<false>(TILE_ITEMS,
                         tile_idx,
                         tile_offset,
                         first_tile_state,
                         second_tile_state,
                         num_first_selections,
                         num_second_selections);
    }
    else
    {
      // The last tile (possibly partially-full)
      OffsetT num_remaining = num_items - tile_offset;

      ConsumeTile<true>(num_remaining,
                        tile_idx,
                        tile_offset,
                        first_tile_state,
                        second_tile_state,
                        num_first_selections,
                        num_second_selections);

      if (threadIdx.x == 0)
      {
        // Output the total number of items selection_flags
        d_num_selected_out[0] = num_first_selections;
        d_num_selected_out[1] = num_second_selections;
      }
    }
  }

};


CUB_NAMESPACE_END
