#pragma once #include namespace torch::lazy { // This IR was copied from code-generated output, but the entire _to_copy // operator cannot be trivially code generated since it is only desirable to // capture IR for certain permutations of _to_copy (e.g. dtype), and for the // others it is difficult to even invoke the aten/eager fallback necessitating // directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TsNode { public: static OpKind ClassOpKind() { return OpKind(at::aten::_to_copy); } ToCopy( const torch::lazy::Value& self, const std::optional& dtype, const std::optional& layout, const std::optional& device, const std::optional& pin_memory, const bool& non_blocking, const std::optional& memory_format, std::vector&& shapes) : torch::lazy::TsNode( ClassOpKind(), {self}, std::move(shapes), /* num_outputs */ 1, torch::lazy::MHash( dtype, layout, device, pin_memory, non_blocking, memory_format)), dtype(dtype), layout(layout), device(device), pin_memory(pin_memory), non_blocking(non_blocking), memory_format(memory_format) {} bool CanBeReused( const torch::lazy::Value& self, const std::optional& dtype, const std::optional& layout, const std::optional& device, const std::optional& pin_memory, const bool& non_blocking, const std::optional& memory_format) const { size_t i = 0; return ( operand(i++) == self && this->dtype == dtype && this->layout == layout && this->device == device && this->pin_memory == pin_memory && this->non_blocking == non_blocking && this->memory_format == memory_format); } std::string ToString() const override { std::stringstream ss; ss << torch::lazy::TsNode::ToString(); if (dtype.has_value()) { ss << ", dtype=" << dtype.value(); } else { ss << ", dtype=null"; } if (layout.has_value()) { ss << ", layout=" << layout.value(); } else { ss << ", layout=null"; } if (device.has_value()) { ss << ", device=" << device.value(); } else { ss << ", device=null"; } if (pin_memory.has_value()) { ss << ", pin_memory=" << pin_memory.value(); } else { ss << ", pin_memory=null"; } ss << ", non_blocking=" << non_blocking; if (memory_format.has_value()) { ss << ", memory_format=" << memory_format.value(); } else { ss << ", memory_format=null"; } return ss.str(); } torch::lazy::TSOpVector Lower( std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const override { std::vector arguments; std::vector kwarguments; arguments.reserve(1); kwarguments.reserve(6); size_t i = 0; arguments.emplace_back(loctx->GetOutputOp(operand(i++))); kwarguments.emplace_back("dtype", dtype); kwarguments.emplace_back("layout", layout); kwarguments.emplace_back("device", device); kwarguments.emplace_back("pin_memory", pin_memory); kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); torch::lazy::TSOpVector _to_copy_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; } std::optional dtype; std::optional layout; std::optional device; std::optional pin_memory; bool non_blocking; std::optional memory_format; }; } // namespace torch::lazy