#version 450 core
#define PRECISION ${PRECISION}
#define FORMAT ${FORMAT}

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

/*
 * Output Image
 */
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
 * Input Buffer
 */
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;

/*
 * Params Buffer
 */
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 oextents;
  ivec2 iextents;
  vec2 scale;
}
uBlock;

/*
 * Local Work Group Size
 */
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Upsamples uInput to the uOutput with scale according to uBlock params,
 * using the equation for bilinear upsampling/interpolation
 * along the height and width plane.
 * align_true ~ align_corners=True, it means that each of the 4 output
 * corner texels are treated in interpolation as if they were squarely
 * aligned with the 4 input corner texels, if the two textures were overlaid.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  if (any(greaterThan(pos, uBlock.oextents.xyz))) {
    return;
  }
  vec2 pos_interp = vec2(pos.xy) * uBlock.iextents.xy /
      clamp(uBlock.oextents.xy - 1, vec2(1, 1), uBlock.oextents.xy - 1);

  // 4 input texels used for bilinear interpolation, naming by PyTorch
  // Tensor coordinate space where the "top" is x = 0 and "left" is y = 0,
  // Vulkan reversed
  ivec3 in_pos_topleft = ivec3(floor(pos_interp.x), floor(pos_interp.y), pos.z);
  ivec3 in_pos_bottomleft =
      ivec3(floor(pos_interp.x), ceil(pos_interp.y), pos.z);
  ivec3 in_pos_topright = ivec3(ceil(pos_interp.x), floor(pos_interp.y), pos.z);
  ivec3 in_pos_bottomright =
      ivec3(ceil(pos_interp.x), ceil(pos_interp.y), pos.z);

  vec2 alpha = pos_interp - in_pos_topleft.xy;

  const vec4 top_val_interp =
      (texelFetch(uInput, in_pos_topleft, 0) * (1 - alpha.x)) +
      (texelFetch(uInput, in_pos_topright, 0) * alpha.x);
  const vec4 bot_val_interp =
      (texelFetch(uInput, in_pos_bottomleft, 0) * (1 - alpha.x)) +
      (texelFetch(uInput, in_pos_bottomright, 0) * alpha.x);

  imageStore(
      uOutput,
      pos,
      (top_val_interp * (1 - alpha.y)) + (bot_val_interp * alpha.y));
}
