#pragma once #include #include #include #include #include // This file contains autogenerated LazyTensor Non Native IR nodes namespace torch { namespace lazy { class Scalar : public TsNode { public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::prim::Constant); } Scalar(const at::Scalar& value, const at::ScalarType& type) : TsNode( Scalar::ClassOpKind(), OpList{}, compute_shape_scalar(value, type), /* num_outputs */ 1, torch::lazy::MHash(value, type)), value(value), type(type) { } std::string ToString() const override { std::stringstream ss; ss << TsNode::ToString(); ss << ", value=" << value; ss << ", type=" << type; return ss.str(); } bool CanBeReused(const at::Scalar& value, const at::ScalarType& type) const; torch::lazy::TSOpVector Lower( std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const override; at::Scalar value; at::ScalarType type; }; class Expand : public TsNode { public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::expand); } Expand(const torch::lazy::Value& input, const ::std::vector& size, const bool& is_scalar_expand) : TsNode( Expand::ClassOpKind(), OpList{input}, [&](){ return compute_shape_expand(operand(0), size, is_scalar_expand)[0]; }, /* num_outputs */ 1, torch::lazy::MHash(size, is_scalar_expand)), size(size), is_scalar_expand(is_scalar_expand) { } std::string ToString() const override { std::stringstream ss; ss << TsNode::ToString(); ss << ", size=" << size; ss << ", is_scalar_expand=" << is_scalar_expand; return ss.str(); } bool CanBeReused(const torch::lazy::Value& input, const ::std::vector& size, const bool& is_scalar_expand) const { size_t i = 0; return (operand(i++) == input && this->size == size && this->is_scalar_expand == is_scalar_expand); } torch::lazy::TSOpVector Lower( std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const override; ::std::vector size; bool is_scalar_expand; }; class Cast : public TsNode { public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(ltc_cast); } Cast(const torch::lazy::Value& input, const at::ScalarType& dtype, const ::std::optional& stype) : TsNode( Cast::ClassOpKind(), OpList{input}, compute_shape_cast(input, dtype, stype), /* num_outputs */ 1, torch::lazy::MHash(dtype, stype)), dtype(dtype), stype(stype) { } std::string ToString() const override { std::stringstream ss; ss << TsNode::ToString(); ss << ", dtype=" << dtype; if (stype.has_value()) { ss << ", stype=" << stype.value(); } else { ss << ", stype=null"; } return ss.str(); } bool CanBeReused(const torch::lazy::Value& input, const at::ScalarType& dtype, const ::std::optional& stype) const { size_t i = 0; return (operand(i++) == input && this->dtype == dtype && ((!this->stype&&!stype) || (this->stype&&stype && *(this->stype) == *stype))); } torch::lazy::TSOpVector Lower( std::shared_ptr function, torch::lazy::TSLoweringContext* loctx) const override; at::ScalarType dtype; ::std::optional stype; }; } // namespace lazy } // namespace torch