{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "from cutlass.cute.runtime import from_dlpack\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to the TensorSSA in CuTe DSL\n",
    "\n",
    "This tutorial introduces what is the `TensorSSA` and why we need it. We also give some examples to show how to use `TensorSSA`.\n",
    "\n",
    "## What is TensorSSA\n",
    "\n",
    "`TensorSSA` is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.\n",
    "\n",
    "## Why TensorSSA\n",
    "\n",
    "`TensorSSA` encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like `+`, `-`, `*`, `/`, `[]`, etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.\n",
    "\n",
    "It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.\n",
    "\n",
    "## When to use TensorSSA\n",
    "\n",
    "`TensorSSA` is primarily used in the following scenarios:\n",
    "\n",
    "### Load from memory and store to memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
      "b_vec: tensor_value<vector<12xf32> o (3, 4)>\n",
      "tensor(raw_ptr(0x0000000006cff170: f32, generic, align<4>) o (3,4):(4,1), data=\n",
      "       [[ 2.000000,  2.000000,  2.000000,  2.000000, ],\n",
      "        [ 2.000000,  2.000000,  2.000000,  2.000000, ],\n",
      "        [ 2.000000,  2.000000,  2.000000,  2.000000, ]])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
    "    \"\"\"\n",
    "    Load data from memory and store the result to memory.\n",
    "\n",
    "    :param res: The destination tensor to store the result.\n",
    "    :param a: The source tensor to be loaded.\n",
    "    :param b: The source tensor to be loaded.\n",
    "    \"\"\"\n",
    "    a_vec = a.load()\n",
    "    print(f\"a_vec: {a_vec}\")      # prints `a_vec: vector<12xf32> o (3, 4)`\n",
    "    b_vec = b.load()\n",
    "    print(f\"b_vec: {b_vec}\")      # prints `b_vec: vector<12xf32> o (3, 4)`\n",
    "    res.store(a_vec + b_vec)\n",
    "    cute.print_tensor(res)\n",
    "\n",
    "a = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
    "b = np.ones(12).reshape((3, 4)).astype(np.float32)\n",
    "c = np.zeros(12).reshape((3, 4)).astype(np.float32)\n",
    "load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Register-Level Tensor Operations\n",
    "\n",
    "When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>\n",
      "tensor(raw_ptr(0x00000000071acaf0: f32, generic, align<4>) o (4,3):(3,1), data=\n",
      "       [[ 3.000000,  4.000000,  5.000000, ],\n",
      "        [ 9.000000,  10.000000,  11.000000, ],\n",
      "        [ 15.000000,  16.000000,  17.000000, ],\n",
      "        [ 21.000000,  22.000000,  23.000000, ]])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):\n",
    "    \"\"\"\n",
    "    Apply slice operation on the src tensor and store the result to the dst tensor.\n",
    "\n",
    "    :param src: The source tensor to be sliced.\n",
    "    :param dst: The destination tensor to store the result.\n",
    "    :param indices: The indices to slice the source tensor.\n",
    "    \"\"\"\n",
    "    src_vec = src.load()\n",
    "    dst_vec = src_vec[indices]\n",
    "    print(f\"{src_vec} -> {dst_vec}\")\n",
    "    if isinstance(dst_vec, cute.TensorSSA):\n",
    "        dst.store(dst_vec)\n",
    "        cute.print_tensor(dst)\n",
    "    else:\n",
    "        dst[0] = dst_vec\n",
    "        cute.print_tensor(dst)\n",
    "\n",
    "def slice_1():\n",
    "    src_shape = (4, 2, 3)\n",
    "    dst_shape = (4, 3)\n",
    "    indices = (None, 1, None)\n",
    "\n",
    "    \"\"\"\n",
    "    a:\n",
    "    [[[ 0.  1.  2.]\n",
    "      [ 3.  4.  5.]]\n",
    "\n",
    "     [[ 6.  7.  8.]\n",
    "      [ 9. 10. 11.]]\n",
    "\n",
    "     [[12. 13. 14.]\n",
    "      [15. 16. 17.]]\n",
    "\n",
    "     [[18. 19. 20.]\n",
    "      [21. 22. 23.]]]\n",
    "    \"\"\"\n",
    "    a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
    "    dst = np.random.randn(*dst_shape).astype(np.float32)\n",
    "    apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
    "\n",
    "slice_1()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor_value<vector<24xf32> o (4, 2, 3)> -> ?\n",
      "tensor(raw_ptr(0x00000000013cbbe0: f32, generic, align<4>) o (1):(1), data=\n",
      "       [ 10.000000, ])\n"
     ]
    }
   ],
   "source": [
    "def slice_2():\n",
    "    src_shape = (4, 2, 3)\n",
    "    dst_shape = (1,)\n",
    "    indices = 10\n",
    "    a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)\n",
    "    dst = np.random.randn(*dst_shape).astype(np.float32)\n",
    "    apply_slice(from_dlpack(a), from_dlpack(dst), indices)\n",
    "\n",
    "slice_2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arithmetic Operations\n",
    "\n",
    "As we mentioned earlier, there're many tensor operations whose operands are `TensorSSA`. And they are all element-wise operations. We give some examples below.\n",
    "\n",
    "### Binary Operations\n",
    "\n",
    "For binary operations, the LHS operand is `TensorSSA` and the RHS operand can be either `TensorSSA` or `Numeric`. When the RHS is `Numeric`, it will be broadcast to a `TensorSSA`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 3.000000, ],\n",
      "       [ 3.000000, ],\n",
      "       [ 3.000000, ])\n",
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [-1.000000, ],\n",
      "       [-1.000000, ],\n",
      "       [-1.000000, ])\n",
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ])\n",
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 0.500000, ],\n",
      "       [ 0.500000, ],\n",
      "       [ 0.500000, ])\n",
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 0.000000, ],\n",
      "       [ 0.000000, ],\n",
      "       [ 0.000000, ])\n",
      "tensor(raw_ptr(0x00000000074f0e70: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 1.000000, ],\n",
      "       [ 1.000000, ],\n",
      "       [ 1.000000, ])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def binary_op_1(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
    "    a_vec = a.load()\n",
    "    b_vec = b.load()\n",
    "\n",
    "    add_res = a_vec + b_vec\n",
    "    res.store(add_res)\n",
    "    cute.print_tensor(res)        # prints [3.000000, 3.000000, 3.000000]\n",
    "\n",
    "    sub_res = a_vec - b_vec\n",
    "    res.store(sub_res)\n",
    "    cute.print_tensor(res)        # prints [-1.000000, -1.000000, -1.000000]\n",
    "\n",
    "    mul_res = a_vec * b_vec\n",
    "    res.store(mul_res)\n",
    "    cute.print_tensor(res)        # prints [2.000000, 2.000000, 2.000000]\n",
    "\n",
    "    div_res = a_vec / b_vec\n",
    "    res.store(div_res)\n",
    "    cute.print_tensor(res)        # prints [0.500000, 0.500000, 0.500000]\n",
    "\n",
    "    floor_div_res = a_vec // b_vec\n",
    "    res.store(floor_div_res)\n",
    "    cute.print_tensor(res)        # prints [0.000000, 0.000000, 0.000000]\n",
    "\n",
    "    mod_res = a_vec % b_vec\n",
    "    res.store(mod_res)\n",
    "    cute.print_tensor(res)        # prints [1.000000, 1.000000, 1.000000]\n",
    "\n",
    "\n",
    "a = np.empty((3,), dtype=np.float32)\n",
    "a.fill(1.0)\n",
    "b = np.empty((3,), dtype=np.float32)\n",
    "b.fill(2.0)\n",
    "res = np.empty((3,), dtype=np.float32)\n",
    "binary_op_1(from_dlpack(res), from_dlpack(a), from_dlpack(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 3.000000, ],\n",
      "       [ 3.000000, ],\n",
      "       [ 3.000000, ])\n",
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [-1.000000, ],\n",
      "       [-1.000000, ],\n",
      "       [-1.000000, ])\n",
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ])\n",
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 0.500000, ],\n",
      "       [ 0.500000, ],\n",
      "       [ 0.500000, ])\n",
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 0.000000, ],\n",
      "       [ 0.000000, ],\n",
      "       [ 0.000000, ])\n",
      "tensor(raw_ptr(0x0000000007828ed0: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 1.000000, ],\n",
      "       [ 1.000000, ],\n",
      "       [ 1.000000, ])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):\n",
    "    a_vec = a.load()\n",
    "\n",
    "    add_res = a_vec + c\n",
    "    res.store(add_res)\n",
    "    cute.print_tensor(res)        # prints [3.000000, 3.000000, 3.000000]\n",
    "\n",
    "    sub_res = a_vec - c\n",
    "    res.store(sub_res)\n",
    "    cute.print_tensor(res)        # prints [-1.000000, -1.000000, -1.000000]\n",
    "\n",
    "    mul_res = a_vec * c\n",
    "    res.store(mul_res)\n",
    "    cute.print_tensor(res)        # prints [2.000000, 2.000000, 2.000000]\n",
    "\n",
    "    div_res = a_vec / c\n",
    "    res.store(div_res)\n",
    "    cute.print_tensor(res)        # prints [0.500000, 0.500000, 0.500000]\n",
    "\n",
    "    floor_div_res = a_vec // c\n",
    "    res.store(floor_div_res)\n",
    "    cute.print_tensor(res)        # prints [0.000000, 0.000000, 0.000000]\n",
    "\n",
    "    mod_res = a_vec % c\n",
    "    res.store(mod_res)\n",
    "    cute.print_tensor(res)        # prints [1.000000, 1.000000, 1.000000]\n",
    "\n",
    "a = np.empty((3,), dtype=np.float32)\n",
    "a.fill(1.0)\n",
    "c = 2.0\n",
    "res = np.empty((3,), dtype=np.float32)\n",
    "binary_op_2(from_dlpack(res), from_dlpack(a), c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[False  True False]\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
    "    a_vec = a.load()\n",
    "    b_vec = b.load()\n",
    "\n",
    "    gt_res = a_vec > b_vec\n",
    "    res.store(gt_res)\n",
    "\n",
    "    \"\"\"\n",
    "    ge_res = a_ >= b_   # [False, True, False]\n",
    "    lt_res = a_ < b_    # [True, False, True]\n",
    "    le_res = a_ <= b_   # [True, False, True]\n",
    "    eq_res = a_ == b_   # [False, False, False]\n",
    "    \"\"\"\n",
    "\n",
    "a = np.array([1, 2, 3], dtype=np.float32)\n",
    "b = np.array([2, 1, 4], dtype=np.float32)\n",
    "res = np.empty((3,), dtype=np.bool_)\n",
    "binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
    "print(res)     # prints [False, True, False]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3 0 7]\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):\n",
    "    a_vec = a.load()\n",
    "    b_vec = b.load()\n",
    "\n",
    "    xor_res = a_vec ^ b_vec\n",
    "    res.store(xor_res)\n",
    "\n",
    "    # or_res = a_vec | b_vec\n",
    "    # res.store(or_res)     # prints [3, 2, 7]\n",
    "\n",
    "    # and_res = a_vec & b_vec\n",
    "    # res.store(and_res)      # prints [0, 2, 0]\n",
    "\n",
    "a = np.array([1, 2, 3], dtype=np.int32)\n",
    "b = np.array([2, 2, 4], dtype=np.int32)\n",
    "res = np.empty((3,), dtype=np.int32)\n",
    "binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))\n",
    "print(res)     # prints [3, 0, 7]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Unary Operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ],\n",
      "       [ 2.000000, ])\n",
      "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
      "       [-0.756802, ],\n",
      "       [-0.756802, ],\n",
      "       [-0.756802, ])\n",
      "tensor(raw_ptr(0x0000000007fbd180: f32, generic, align<4>) o (3):(1), data=\n",
      "       [ 16.000000, ],\n",
      "       [ 16.000000, ],\n",
      "       [ 16.000000, ])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def unary_op_1(res: cute.Tensor, a: cute.Tensor):\n",
    "    a_vec = a.load()\n",
    "\n",
    "    sqrt_res = cute.math.sqrt(a_vec)\n",
    "    res.store(sqrt_res)\n",
    "    cute.print_tensor(res)        # prints [2.000000, 2.000000, 2.000000]\n",
    "\n",
    "    sin_res = cute.math.sin(a_vec)\n",
    "    res.store(sin_res)\n",
    "    cute.print_tensor(res)        # prints [-0.756802, -0.756802, -0.756802]\n",
    "\n",
    "    exp2_res = cute.math.exp2(a_vec)\n",
    "    res.store(exp2_res)\n",
    "    cute.print_tensor(res)        # prints [16.000000, 16.000000, 16.000000]\n",
    "\n",
    "a = np.array([4.0, 4.0, 4.0], dtype=np.float32)\n",
    "res = np.empty((3,), dtype=np.float32)\n",
    "unary_op_1(from_dlpack(res), from_dlpack(a))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Reduction Operation\n",
    "\n",
    "The `TensorSSA`'s `reduce` method applies a specified reduction operation (`ReductionOp.ADD`, `ReductionOp.MUL`, `ReductionOp.MAX`, `ReductionOp.MIN`) starting with an initial value, and performs this reduction along the dimensions specified by the `reduction_profile.`. The result is typically a new `TensorSSA` with reduced dimensions or a scalar value if reduces across all axes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21.000000\n",
      "tensor(raw_ptr(0x00007ffd1ea2bca0: f32, rmem, align<32>) o (2):(1), data=\n",
      "       [ 6.000000, ],\n",
      "       [ 15.000000, ])\n",
      "tensor(raw_ptr(0x00007ffd1ea2bcc0: f32, rmem, align<32>) o (3):(1), data=\n",
      "       [ 6.000000, ],\n",
      "       [ 8.000000, ],\n",
      "       [ 10.000000, ])\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def reduction_op(a: cute.Tensor):\n",
    "    \"\"\"\n",
    "    Apply reduction operation on the src tensor.\n",
    "\n",
    "    :param src: The source tensor to be reduced.\n",
    "    \"\"\"\n",
    "    a_vec = a.load()\n",
    "    red_res = a_vec.reduce(\n",
    "        cute.ReductionOp.ADD,\n",
    "        0.0,\n",
    "        reduction_profile=0\n",
    "    )\n",
    "    cute.printf(red_res)        # prints 21.000000\n",
    "\n",
    "    red_res = a_vec.reduce(\n",
    "        cute.ReductionOp.ADD,\n",
    "        0.0,\n",
    "        reduction_profile=(None, 1)\n",
    "    )\n",
    "    # We can't print the TensorSSA directly at this point, so we store it to a new Tensor and print it.\n",
    "    res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
    "    res.store(red_res)\n",
    "    cute.print_tensor(res)        # prints [6.000000, 15.000000]\n",
    "\n",
    "    red_res = a_vec.reduce(\n",
    "        cute.ReductionOp.ADD,\n",
    "        1.0,\n",
    "        reduction_profile=(1, None)\n",
    "    )\n",
    "    res = cute.make_fragment(red_res.shape, cutlass.Float32)\n",
    "    res.store(red_res)\n",
    "    cute.print_tensor(res)        # prints [6.000000, 8.000000, 10.000000]\n",
    "\n",
    "\n",
    "a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)\n",
    "reduction_op(from_dlpack(a))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
