# Loop while `i < 3`: each iteration multiplies the running `a` value by
# itself and adds 1 to the running `i` value, then returns the final `a` value.
#
# NOTE(review): the SSA-style names (`_0`, `a0`, `a1`, ...) and the
# function-form `torch.mul_` / `torch.add_` calls look like TorchScript's
# printed/serialized code rather than hand-written Python -- confirm before
# hand-editing; presumably this text is regenerated or compared by tooling.
def while_test(a: Tensor,
    i: Tensor) -> Tensor:
  # Loop-carried values, seeded from the arguments.
  a0 = a
  i0 = i
  # Initial loop condition, evaluated once before the first iteration.
  # `bool(...)` of the comparison presumably requires `i` to hold a single
  # element -- TODO confirm against callers.
  _0 = bool(torch.lt(i, 3))
  while _0:
    # Trailing-underscore ops: by PyTorch naming convention these operate in
    # place, so the caller's `a` and `i` would be mutated and `a1` / `i1`
    # would alias `a0` / `i0` -- NOTE(review): confirm aliasing is intended.
    a1 = torch.mul_(a0, a0)
    i1 = torch.add_(i0, 1)
    # Re-evaluate the condition on the updated counter and rebind all
    # loop-carried values in a single tuple assignment.
    _0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1
  # `a0` is the original `a` if the loop body never ran, otherwise the result
  # of the last `torch.mul_` call.
  return a0
