{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Standard flow control and data processing DataPipes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import IterDataPipe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example IterDataPipe\n",
    "class ExampleIterPipe(IterDataPipe):\n",
    "    \"\"\"Toy datapipe that yields the integers 0, 1, ..., range - 1 in order.\n",
    "\n",
    "    Args:\n",
    "        range: number of consecutive integers to produce, starting at 0.\n",
    "    \"\"\"\n",
    "    def __init__(self, range = 20):\n",
    "        # The parameter shadows the builtin `range` only inside `__init__`.\n",
    "        self.range = range\n",
    "    def __iter__(self):\n",
    "        # `self.range` is an int count, so wrap it in the builtin `range(...)`;\n",
    "        # iterating the int directly raises `TypeError: 'int' object is not iterable`.\n",
    "        yield from range(self.range)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch\n",
    "\n",
    "Function: `batch`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    "  - `batch_size: int` desired batch size\n",
    "  - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
    "  - `drop_last: bool = False`\n",
    "\n",
    "Example:\n",
    "\n",
    "Classic batching produces partial batches by default\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2]\n",
      "[3, 4, 5]\n",
      "[6, 7, 8]\n",
      "[9]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To drop incomplete batches add `drop_last` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2]\n",
      "[3, 4, 5]\n",
      "[6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3, drop_last = True)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sequential calling of `batch` produces nested batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0, 1, 2], [3, 4, 5]]\n",
      "[[6, 7, 8], [9, 10, 11]]\n",
      "[[12, 13, 14], [15, 16, 17]]\n",
      "[[18, 19, 20], [21, 22, 23]]\n",
      "[[24, 25, 26], [27, 28, 29]]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(30).batch(3).batch(2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to unbatch source data before applying the new batching rule using `unbatch_level` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
      "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n",
      "[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(30).batch(3).batch(2).batch(10, unbatch_level=-1)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unbatch\n",
    "\n",
    "Function: `unbatch`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    "  - `unbatch_level:int = 1`\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n",
      "0\n",
      "1\n",
      "2\n",
      "6\n",
      "7\n",
      "8\n",
      "3\n",
      "4\n",
      "5\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default unbatching is applied only on the first layer, to unbatch deeper use `unbatch_level` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1]\n",
      "[2, 3]\n",
      "[4, 5]\n",
      "[6, 7]\n",
      "[8, 9]\n",
      "[10, 11]\n",
      "[12, 13]\n",
      "[14, 15]\n",
      "[16, 17]\n",
      "[18, 19]\n",
      "[20, 21]\n",
      "[22, 23]\n",
      "[24, 25]\n",
      "[26, 27]\n",
      "[28, 29]\n",
      "[30, 31]\n",
      "[32, 33]\n",
      "[34, 35]\n",
      "[36, 37]\n",
      "[38, 39]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setting `unbatch_level` to `-1` will unbatch to the lowest level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = -1)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Map\n",
    "\n",
    "Function: `map`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    "  - `nesting_level: int = 0`\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "2\n",
      "4\n",
      "6\n",
      "8\n",
      "10\n",
      "12\n",
      "14\n",
      "16\n",
      "18\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).map(lambda x: x * 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`map` by default applies function to every mini-batch as a whole\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 0, 1, 2]\n",
      "[3, 4, 5, 3, 4, 5]\n",
      "[6, 7, 8, 6, 7, 8]\n",
      "[9, 9]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To apply function on individual items of the mini-batch use `nesting_level` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0, 2, 4], [6, 8, 10]]\n",
      "[[12, 14, 16], [18]]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).batch(2).map(lambda x: x * 2, nesting_level = 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setting `nesting_level` to `-1` will apply `map` function to the lowest level possible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[0, 2, 4], [6, 8, 10]], [[12, 14, 16], [18]]]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).batch(2).batch(2).map(lambda x: x * 2, nesting_level = -1)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter\n",
    "\n",
    "Function: `filter`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    "  - `nesting_level: int = 0`\n",
    "  - `drop_empty_batches = True` whether empty mini-batches are dropped or not.\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "2\n",
      "4\n",
      "6\n",
      "8\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).filter(lambda x: x % 2 == 0)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Classic `filter` by default applies the filter function to every mini-batch as a whole\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2]\n",
      "[3, 4, 5]\n",
      "[6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10)\n",
    "dp = dp.batch(3).filter(lambda x: len(x) > 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can apply filter function on individual elements by setting `nesting_level` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5]\n",
      "[6, 7, 8]\n",
      "[9]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10)\n",
    "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = 1)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a mini-batch ends up with zero elements after filtering, the default behaviour is to drop it from the output. You can override this behaviour using the `drop_empty_batches` argument.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "[5]\n",
      "[6, 7, 8]\n",
      "[9]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10)\n",
    "dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = -1, drop_empty_batches = False)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[0, 1, 2], [3]], [[], [10, 11]]]\n",
      "[[[12, 13, 14], [15, 16, 17]], [[18, 19]]]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(20)\n",
    "dp = dp.batch(3).batch(2).batch(2).filter(lambda x: x < 4 or x > 9 , nesting_level = -1, drop_empty_batches = False)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shuffle\n",
    "\n",
    "Function: `shuffle`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    "  - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before shuffling (see `unbatch`)\n",
    "  - `buffer_size: int = 10000`\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "9\n",
      "4\n",
      "0\n",
      "3\n",
      "7\n",
      "8\n",
      "5\n",
      "6\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).shuffle()\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`shuffle` operates on input mini-batches the same way as on individual items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2]\n",
      "[3, 4, 5]\n",
      "[9]\n",
      "[6, 7, 8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).shuffle()\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To shuffle elements across batches use `shuffle(unbatch_level)` followed by `batch` pattern "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 1, 0]\n",
      "[7, 9, 6]\n",
      "[3, 5, 4]\n",
      "[8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).shuffle(unbatch_level = -1).batch(3)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collate\n",
    "\n",
    "Function: `collate`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0, 1, 2])\n",
      "tensor([3, 4, 5])\n",
      "tensor([6, 7, 8])\n",
      "tensor([9])\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).collate()\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GroupBy\n",
    "\n",
    "Function: `groupby`\n",
    "\n",
    "Usage: `dp.groupby(lambda x: x[0])`\n",
    "\n",
    "Description: Batching items by combining items with same key into same batch \n",
    "\n",
    "Arguments:\n",
    " - `group_key_fn`\n",
    " - `group_size` - yield the resulting group as soon as `group_size` elements are accumulated\n",
    " - `guaranteed_group_size:int = None`\n",
    " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before grouping (see `unbatch`)\n",
    "\n",
    "#### Attention\n",
    "As a datastream can be arbitrarily large, grouping is done on a best-effort basis and there is no guarantee that the same key will never be present in different groups. You can think of it as a local groupby, where the locality is a single DataPipe process/thread."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 3, 6, 9]\n",
      "[1, 4, 7]\n",
      "[5, 2, 8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).shuffle().groupby(lambda x: x % 3)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default group key function is applied to entire input (mini-batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n",
      "[[9]]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x))\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to unnest items from the mini-batches using `unbatch_level` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 3, 6, 9]\n",
      "[1, 4, 7]\n",
      "[2, 5, 8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10).batch(3).groupby(lambda x: x % 3, unbatch_level = 1)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When internal buffer (defined by `buffer_size`) is overfilled, groupby will yield biggest group available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9, 3]\n",
      "[13, 4, 7]\n",
      "[2, 11, 14, 5]\n",
      "[0, 6, 12]\n",
      "[1, 10]\n",
      "[8]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, buffer_size = 5)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`groupby` yields batches of `group_size` elements as soon as they are available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6, 3, 12]\n",
      "[1, 16, 7]\n",
      "[2, 5, 8]\n",
      "[14, 11, 17]\n",
      "[15, 9, 0]\n",
      "[10, 4, 13]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(18).shuffle().groupby(lambda x: x % 3, group_size = 3)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remaining groups must be at least `guaranteed_group_size` big. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[11, 2, 5]\n",
      "[1, 4, 10]\n",
      "[0, 9, 6]\n",
      "[14, 8]\n",
      "[13, 7]\n",
      "[12, 3]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, group_size = 3, guaranteed_group_size = 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Without a defined `group_size`, the function will try to accumulate at least `guaranteed_group_size` elements before yielding the resulting group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 6, 9, 12, 0]\n",
      "[14, 2, 8, 11, 5]\n",
      "[7, 4, 1, 13, 10]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This behaviour becomes noticeable when the data is bigger than the buffer and some groups get evicted before gathering all potential items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 3]\n",
      "[1, 4, 7]\n",
      "[2, 5, 8]\n",
      "[6, 9, 12]\n",
      "[10, 13]\n",
      "[11, 14]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With randomness involved you might end up with incomplete groups (so the next example is expected to fail in most cases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[14, 5, 11]\n",
      "[1, 7, 4, 10]\n",
      "[0, 12, 6]\n",
      "[8, 2]\n",
      "[9, 3]\n"
     ]
    },
    {
     "ename": "Exception",
     "evalue": "('Failed to group items', '[13]')",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mException\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-31-673b9dd7fb43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mdp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExampleIterPipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroupby\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguaranteed_group_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/dataset/pytorch/torch/utils/data/datapipes/iter/grouping.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_remaining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Failed to group items'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbiggest_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mException\u001b[0m: ('Failed to group items', '[13]')"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To avoid this error and drop incomplete groups, use `drop_remaining` argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5, 2, 14]\n",
      "[4, 7, 13, 1, 10]\n",
      "[12, 6, 3, 9]\n",
      "[8, 11]\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6, drop_remaining = True)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zip\n",
    "\n",
    "Function: `zip`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 3)\n",
      "(1, 0)\n",
      "(2, 4)\n",
      "(3, 2)\n",
      "(4, 1)\n"
     ]
    }
   ],
   "source": [
    "_dp = ExampleIterPipe(5).shuffle()\n",
    "dp = ExampleIterPipe(5).zip(_dp)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fork\n",
    "\n",
    "Function: `fork`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "0\n",
      "1\n",
      "0\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(2)\n",
    "dp1, dp2, dp3 = dp.fork(3)\n",
    "for i in dp1 + dp2 + dp3:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demultiplexer\n",
    "\n",
    "Function: `demux`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "4\n",
      "7\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(10)\n",
    "dp1, dp2, dp3 = dp.demux(3, lambda x: x % 3)\n",
    "for i in dp2:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiplexer\n",
    "\n",
    "Function: `mux`\n",
    "\n",
    "Description: \n",
    "\n",
    "Alternatives:\n",
    "\n",
    "Arguments:\n",
    " \n",
    "Example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "0\n",
      "0\n",
      "1\n",
      "10\n",
      "100\n",
      "2\n",
      "20\n",
      "200\n"
     ]
    }
   ],
   "source": [
    "dp1 = ExampleIterPipe(3)\n",
    "dp2 = ExampleIterPipe(3).map(lambda x: x * 10)\n",
    "dp3 = ExampleIterPipe(3).map(lambda x: x * 100)\n",
    "\n",
    "dp = dp1.mux(dp2, dp3)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concat\n",
    "\n",
    "Function: `concat`\n",
    "\n",
    "Description: Returns a DataPipe with elements from the first datapipe followed by elements from the subsequent datapipes\n",
    "\n",
    "Alternatives:\n",
    "    \n",
    "    `dp = dp.concat(dp2, dp3)`\n",
    "    `dp = dp.concat(*datapipes_list)`\n",
    "\n",
    "Example:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "0\n",
      "1\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "dp = ExampleIterPipe(4)\n",
    "dp2 = ExampleIterPipe(3)\n",
    "dp = dp.concat(dp2)\n",
    "for i in dp:\n",
    "    print(i)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.10 64-bit ('dataloader': conda)",
   "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
