diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 8e5831cd09946..085c2ddca2f9c 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -358,6 +358,19 @@ Val* getMaximumValue(DataType v) { } // namespace +// TENSOR FACTORIES +TensorView* rand(const std::vector& shape, DataType dtype) { + auto n = shape.size(); + auto out = TensorViewBuilder() + .ndims(n) + .dtype(dtype) + .contiguity(std::vector(n, true)) + .shape(shape) + .build(); + IrBuilder::create(RNGOpType::Uniform, out); + return out; +} + Val* castOp(DataType dtype, Val* v1) { if (v1->getDataType().value() == dtype) { return set(v1); @@ -404,17 +417,6 @@ Val* unaryOp(UnaryOpType type, Val* v1) { TORCH_INTERNAL_ASSERT( type != UnaryOpType::Address, "The reference operator & is not accessible in the Fusion IR"); - - // TODO: We should add the following, but we need to go through schedulers - // and make sure all calls to "fusion->inputs" includes the output of RandLike - // - // If rand like, there isn't a real dependency on the input value, so map it - // to a dummy scalar. if - // - // (type == UnaryOpType::RandLike) { - // v1 = new NamedScalar("__rnd", v1->getDataType().value()); - // } - Val* out = newValLike(v1, v1->getDataType().value()); IrBuilder::create(type, out, v1); return out; @@ -469,28 +471,21 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc) NVFUSER_DEFINE_UNARY_OP(print, Print) #undef NVFUSER_DEFINE_UNARY_OP -Val* randlike(Val* v) { +TensorView* randlike(TensorView* v) { TORCH_CHECK( isFloatingPointType(v->dtype()), "input must have floating point type, but got ", v->dtype()); - auto rand_vals = unaryOp(UnaryOpType::RandLike, v); - return where( - eq(rand_vals, IrBuilder::create(1.0)), - IrBuilder::create(0.0), - rand_vals); + std::vector shape; + shape.reserve(v->getMaybeRFactorDomain().size()); + for (auto id : v->getMaybeRFactorDomain()) { + shape.emplace_back(id->getMaybeExpandedExtent()); + } + return rand(shape, v->dtype()); } -TensorView* randlike(TensorView* v) { - TORCH_CHECK( - isFloatingPointType(v->dtype()), - "input must have floating point type, but got ", - v->dtype()); - auto rand_vals = unaryOp(UnaryOpType::RandLike, v); - return where( - eq(rand_vals, IrBuilder::create(1.0)), - IrBuilder::create(0.0), - rand_vals); +Val* randlike(Val* v) { + return randlike(v->as()); } Val* bitwise_not(Val* v) { diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 7a1efee80f5dc..03ef14b9bdb59 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -121,6 +121,11 @@ TORCH_CUDA_CU_API WelfordResult Welford( // import IrBuilder just for this one interface. Int* init_N = nullptr); +// TENSOR FACTORIES +TORCH_CUDA_CU_API TensorView* rand( + const std::vector& shape, + DataType dtype); + // UNARY OPERATIONS // abs TORCH_CUDA_CU_API Val* abs(Val*); diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 07e6564f27435..5768863efa696 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -706,34 +706,12 @@ class CudaKernelGenerator : private OptOutConstDispatch { } if (!print_inline_) { - if (op_type == UnaryOpType::RandLike) { - auto out_tv = uop->out()->as()->view(); - auto index = genTensorIndex(uop->in()->as()); - int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4; - indent() << "nvfuser_index_t rng_subseq" << uop->name() << " = (" - << index << ") / " << multiple << ";\n"; - indent() << "nvfuser_index_t rng_component" << uop->name() << " = (" - << index << ") % " << multiple << ";\n"; - indent() << "nvfuser_index_t rng_offset" << uop->name() << " = " - << uop->getRNGOffset() << ";\n"; - indent() << "if (rng_subseq != rng_subseq" << uop->name() - << " || rng_offset != rng_offset" << uop->name() << ") {\n"; - indent() << " rng_result = philox(philox_args.seed_, rng_subseq" - << uop->name() << ", philox_offset / 4 + rng_offset" - << uop->name() << ");\n"; - indent() << " rng_subseq = rng_subseq" << uop->name() << ";\n"; - indent() << " rng_offset = rng_offset" << uop->name() << ";\n"; - indent() << "}\n"; - } - indent() << gen(uop->out()); if (!uop->out()->isScalar() && !uop->in()->isScalar()) { code_ << "\n"; indent() << kTab; } code_ << " = "; - } else { - TORCH_INTERNAL_ASSERT(op_type != UnaryOpType::RandLike); } if (auto op = inline_op_str(op_type)) { @@ -762,13 +740,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } - code_ << "("; - if (op_type == UnaryOpType::RandLike) { - code_ << "rng_result, rng_component" << uop->name(); - } else { - code_ << gen(uop->in()); - } - code_ << ")"; + code_ << "(" << gen(uop->in()) << ")"; } if (!print_inline_) { @@ -776,6 +748,35 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + void handle(const RNGOp* rop) final { + // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an + // innermost ID of size 4 (float) or size 2 (double)? + auto out_tv = rop->output(0)->as()->view(); + auto index = genTensorIndex(rop->getPhiloxIndex()->as()); + int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4; + indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index + << ") / " << multiple << ";\n"; + indent() << "nvfuser_index_t rng_component" << rop->name() << " = (" + << index << ") % " << multiple << ";\n"; + indent() << "nvfuser_index_t rng_offset" << rop->name() << " = " + << rop->getRNGOffset() << ";\n"; + indent() << "if (rng_subseq != rng_subseq" << rop->name() + << " || rng_offset != rng_offset" << rop->name() << ") {\n"; + indent() << " rng_result = philox(philox_args.seed_, rng_subseq" + << rop->name() << ", philox_offset / 4 + rng_offset" << rop->name() + << ");\n"; + indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n"; + indent() << " rng_offset = rng_offset" << rop->name() << ";\n"; + indent() << "}\n"; + auto op_type = rop->getRNGOpType(); + indent() << gen(rop->output(0)) << " = " << op_type; + if (needFloatSuffix(op_type) && + rop->output(0)->dtype() == DataType::Float) { + code_ << "f"; + } + code_ << "(rng_result, rng_component" << rop->name() << ");\n"; + } + std::string genBinaryOp( BinaryOpType op_type, DataType data_type, diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 676cb80866ea5..38c8f49541fb6 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -104,6 +104,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::TernaryOp: ptr(handler)->handle(expr->as()); return; + case ExprType::RNGOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::ReductionOp: ptr(handler)->handle(expr->as()); return; @@ -284,6 +287,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::TernaryOp: ptr(handler)->handle(expr->as()); return; + case ExprType::RNGOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::ReductionOp: ptr(handler)->handle(expr->as()); return; @@ -472,6 +478,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::TernaryOp: ptr(mutator)->mutate(expr->as()); return; + case ExprType::RNGOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ReductionOp: ptr(mutator)->mutate(expr->as()); return; @@ -725,6 +734,9 @@ void OptOutConstDispatch::handle(const BinaryOp* stmt) { void OptOutConstDispatch::handle(const TernaryOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const RNGOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const ReductionOp* stmt) { unhandled(stmt); } @@ -875,6 +887,9 @@ void OptOutDispatch::handle(BinaryOp* stmt) { void OptOutDispatch::handle(TernaryOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(RNGOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(ReductionOp* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 5f84ecca40696..add623303bc91 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -71,6 +71,7 @@ class NamedScalar; class UnaryOp; class BinaryOp; class TernaryOp; +class RNGOp; class ReductionOp; class GroupedReductionOp; class WelfordOp; @@ -145,6 +146,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const UnaryOp* stmt); virtual void handle(const BinaryOp* stmt); virtual void handle(const TernaryOp* stmt); + virtual void handle(const RNGOp* stmt); virtual void handle(const ReductionOp* stmt); virtual void handle(const GroupedReductionOp* stmt); virtual void handle(const WelfordOp* stmt); @@ -210,6 +212,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(UnaryOp* stmt); virtual void handle(BinaryOp* stmt); virtual void handle(TernaryOp* stmt); + virtual void handle(RNGOp* stmt); virtual void handle(ReductionOp* stmt); virtual void handle(GroupedReductionOp* stmt); virtual void handle(WelfordOp* stmt); @@ -316,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(UnaryOp*); virtual void mutate(BinaryOp*); virtual void mutate(TernaryOp*); + virtual void mutate(RNGOp*); virtual void mutate(ReductionOp*); virtual void mutate(GroupedReductionOp*); virtual void mutate(WelfordOp*); diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index d3f4ae6715b87..595ff6168433c 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -373,6 +373,18 @@ void Fusion::printMath(bool from_outputs_only) { std::cout << "}\n\n"; } +std::vector Fusion::inputsAndCreated() { + auto result = inputs_; + for (auto expr : exprs()) { + if (expr->inputs().empty()) { + for (auto v : expr->outputs()) { + result.emplace_back(v); + } + } + } + return result; +} + void Fusion::printTransforms() { FUSER_PERF_SCOPE("Fusion::printTransforms"); @@ -531,14 +543,15 @@ Expr* Fusion::definition(const Val* val) const { // Indicate to kernel to set itself up to generate random numbers bool Fusion::isStochastic() { - for (auto expr : exprs()) - if (expr->getExprType() == ExprType::UnaryOp) - if (expr->as()->getUnaryOpType() == UnaryOpType::RandLike) - return true; + for (auto expr : exprs()) { + if (expr->getExprType() == ExprType::RNGOp) { + return true; + } + } return false; } -std::vector Fusion::getTerminatingOutputs() { +std::vector Fusion::getTerminatingOutputs() const { FUSER_PERF_SCOPE("getTerminatingOutputs"); auto is_reachable_to_output = [](Val* val) { diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index 1f25a9661bf87..cf7b035e971f5 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -175,11 +175,13 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer { return inputs_; } + std::vector inputsAndCreated(); + const auto& outputs() const { return outputs_; } - std::vector getTerminatingOutputs(); + std::vector getTerminatingOutputs() const; // Aliasing output to input value, this is a WAR to allow inplace update on // input tensor. diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f59d7d7deaa0e..7d5ebad25282b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -48,6 +48,7 @@ class Expr; class Val; class UnaryOp; class BinaryOp; +class RNGOp; class IterDomain; class IrCloner; class IrContainer; diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index adb523ecfa27d..56f35819cd3ca 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -63,6 +63,7 @@ IR_BUILDER_INSTANTIATE(ViewOp) IR_BUILDER_INSTANTIATE(UnaryOp) IR_BUILDER_INSTANTIATE(BinaryOp) IR_BUILDER_INSTANTIATE(TernaryOp) +IR_BUILDER_INSTANTIATE(RNGOp) IR_BUILDER_INSTANTIATE(ReductionOp) IR_BUILDER_INSTANTIATE(GroupedReductionOp) IR_BUILDER_INSTANTIATE(WelfordOp) diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 638e9d8c5a5f1..7c7535076fda7 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -100,6 +100,10 @@ void IrCloner::handle(const TernaryOp* op) { clone_ = IrBuilder::clone(op, this); } +void IrCloner::handle(const RNGOp* op) { + clone_ = IrBuilder::clone(op, this); +} + void IrCloner::handle(const BroadcastOp* op) { clone_ = IrBuilder::clone(op, this); } diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index a0f5d76f007d8..ce7ed884bbf59 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -71,6 +71,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const UnaryOp*) override; void handle(const BinaryOp*) override; void handle(const TernaryOp*) override; + void handle(const RNGOp*) override; void handle(const BroadcastOp*) override; void handle(const ReductionOp*) override; void handle(const GroupedReductionOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index 941bf22dea763..b637c7bb69695 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -443,6 +443,16 @@ void IrGraphGenerator::handle(const TernaryOp* op) { addArc(op, op->out()); } +void IrGraphGenerator::handle(const RNGOp* op) { + // node + std::stringstream label; + label << op->getRNGOpType(); + printExpr(op, label.str()); + + // inputs & outputs + addArc(op, op->output(0)); +} + void IrGraphGenerator::handle(const BroadcastOp* op) { printExpr(op, "Broadcast"); addArc(op->in(), op); diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index e5bbcac9157dc..29df3e37a089f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -85,6 +85,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch { void handle(const UnaryOp*) override; void handle(const BinaryOp*) override; void handle(const TernaryOp*) override; + void handle(const RNGOp*) override; void handle(const BroadcastOp*) override; void handle(const ReductionOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 129a0a0e79e42..1b61f18219e5c 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -58,23 +58,12 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { return unary_op_type_; } - int getRNGOffset() const { - return rng_offset_; - } - - void setRNGOffset(int val) { - rng_offset_ = val; - } - bool sameAs(const Statement* other) const override; private: const UnaryOpType unary_op_type_; Val* const out_ = nullptr; Val* const in_ = nullptr; - // TODO: pull RNG op out of Unary ops - // https://github.com/csarofeen/pytorch/pull/1892 - int rng_offset_ = -1; }; //! A specialization for Binary operations. Binary operations take in two inputs @@ -110,6 +99,48 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr { Val* const rhs_ = nullptr; }; +//! A specialization for random number generator (RNG) operations. RNG +//! operations take in no tensor input and produce a single output. +class TORCH_CUDA_CU_API RNGOp : public Expr { + public: + RNGOp( + IrBuilderPasskey, + RNGOpType type, + Val* out, + int rng_offset = 0, + Val* philox_index = nullptr); + + RNGOp(const RNGOp* src, IrCloner* ir_cloner); + + RNGOpType getRNGOpType() const { + return rng_op_type_; + } + + int getRNGOffset() const { + return rng_offset_; + } + + void setRNGOffset(int val) { + rng_offset_ = val; + } + + Val* getPhiloxIndex() const { + return philox_index_; + } + + void setPhiloxIndex(Val* index) { + philox_index_ = index; + } + + bool sameAs(const Statement* other) const override; + + private: + const RNGOpType rng_op_type_; + int rng_offset_ = -1; + // The index used to feed philox's subsequence and component + Val* philox_index_ = nullptr; +}; + //! Broadcast in to match out. is_broadcast_dims are relative to out. Where //! is_broadcast_dims.size() == out->nDims(). class TORCH_CUDA_CU_API BroadcastOp : public Expr { @@ -1146,6 +1177,13 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return expanded_extent_; } + Val* getMaybeExpandedExtent() const { + if (hasExpandedExtent()) { + return expandedExtent(); + } + return extent(); + } + //! Dimension padding interface: //! 2 modes are currently supported: //! diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 005eeea8ae21c..36804ecc9d3f2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -393,6 +393,33 @@ void IrPrinter::handle(const TernaryOp* top) { os_ << ";\n"; } +void IrPrinter::handle(const RNGOp* rop) { + bool istvop = ir_utils::isTvOp(rop); + if (!print_inline_) { + indent(); + os_ << rop->output(0); + + // tensor operations tend to be long, break them up into multiple lines + if (istvop) { + os_ << "\n"; + indent_size_++; + indent(); + } + + os_ << " = "; + } else { + checkInlineable(rop); + } + + os_ << rop->getRNGOpType() << "()"; + + if (istvop) + indent_size_--; + + if (!print_inline_) + os_ << ";\n"; +} + void IrPrinter::handle(const ReductionOp* rop) { indent() << rop->out() << "\n"; indent() << " = reduction( " << rop->in() diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 2df1ec2ec230a..6e8379120d9cc 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -85,6 +85,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const UnaryOp*) final; void handle(const BinaryOp*) final; void handle(const TernaryOp*) final; + void handle(const RNGOp*) final; void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; void handle(const WelfordOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index e9a90e0225320..3cf2e7021ad23 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -191,8 +191,7 @@ UnaryOp::UnaryOp( : Expr(passkey, ExprType::UnaryOp), unary_op_type_{type}, out_{out}, - in_{in}, - rng_offset_(rng_offset) { + in_{in} { addOutput(out); addInput(in); } @@ -201,8 +200,7 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), unary_op_type_(src->unary_op_type_), out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - rng_offset_(src->rng_offset_) {} + in_(ir_cloner->clone(src->in_)) {} bool UnaryOp::sameAs(const Statement* other) const { if (this == other) { @@ -293,6 +291,48 @@ bool TernaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } +RNGOp::RNGOp( + IrBuilderPasskey passkey, + RNGOpType type, + Val* out, + int rng_offset, + Val* philox_index) + : Expr(passkey, ExprType::RNGOp), + rng_op_type_(type), + rng_offset_(rng_offset), + philox_index_(philox_index) { + addOutput(out); +} + +RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + rng_op_type_(src->rng_op_type_), + rng_offset_(src->rng_offset_) {} + +bool RNGOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + if (getRNGOpType() != other_op->getRNGOpType()) { + return false; + } + if (getRNGOffset() != other_op->getRNGOffset()) { + return false; + } + if ((philox_index_ == nullptr) != (other_op->philox_index_ == nullptr)) { + return false; + } + if ((philox_index_ != nullptr) && + !philox_index_->sameAs(other_op->philox_index_)) { + return false; + } + return Expr::sameAs(other); +} + BroadcastOp::BroadcastOp( IrBuilderPasskey passkey, Val* out, diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index d8520bf047f41..4906046d56365 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -186,11 +186,7 @@ struct SubstituteInExpr : public OptInDispatch { auto out = reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out(); expr_ = IrBuilder::create( - unary_expr->container(), - unary_expr->getUnaryOpType(), - out, - in, - unary_expr->getRNGOffset()); + unary_expr->container(), unary_expr->getUnaryOpType(), out, in); } void handle(BinaryOp* binary_expr) final { @@ -227,6 +223,17 @@ struct SubstituteInExpr : public OptInDispatch { in3); } + void handle(RNGOp* rng_expr) final { + auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_ + : rng_expr->output(0); + expr_ = IrBuilder::create( + rng_expr->container(), + rng_expr->getRNGOpType(), + out, + rng_expr->getRNGOffset(), + rng_expr->getPhiloxIndex()); + } + void handle(ReductionOp* reduction_expr) final { auto init = reference_->sameAs(reduction_expr->init()) ? substitute_ @@ -720,13 +727,29 @@ class ValReplacementMutator : private OptOutMutator { // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true); - for (auto stmt : stmts) { + + // Some fusions, such as standalone randlike, can have disconnected DAG, so + // we need some mechanism to make sure our replacement set is as complete as + // possible + // TODO: I think we need a more general mechanism to support disconnected + // DAG + std::vector more; + for (auto v : fusion->inputs()) { + if (std::find(stmts.begin(), stmts.end(), v) == stmts.end()) { + more.emplace_back(v); + } + } + auto more_stmts = StmtSort::getStmts(fusion, more, true); + more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); + + for (auto stmt : more_stmts) { mutate(stmt); } } private: using OptOutMutator::mutate; + void mutate(Val* val) final { if (replacement_map_.find(val) == replacement_map_.end()) { return OptOutMutator::mutate(val); @@ -875,8 +898,7 @@ struct ReplaceValInIndexVal : public OptInDispatch { auto inp = last_visited_val_; TORCH_INTERNAL_ASSERT(uop->out()->isA()); auto out = IrBuilder::create(c10::nullopt); - IrBuilder::create( - uop->getUnaryOpType(), out, inp, uop->getRNGOffset()); + IrBuilder::create(uop->getUnaryOpType(), out, inp); last_visited_val_ = out; } diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 6ae4e7374df57..08ba663c9fa63 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -279,8 +279,9 @@ class Inputs : public IterVisitor { } void handle(Val* val) override { - // If there's no definition to val, or val is within the provided inputs - if (val->definition() == nullptr || + // If there's no definition to val, or val is created inside the fusion, or + // val is within the provided inputs + if (val->definition() == nullptr || val->definition()->inputs().empty() || std::find(all_inputs_.begin(), all_inputs_.end(), val) != all_inputs_.end()) { // if not already placed in the inputs @@ -400,7 +401,6 @@ void BackwardVisitor::traverseFrom( } auto vals = AllVals::get(fusion, from); - auto exprs = StmtSort::getExprs(fusion, from); { @@ -843,7 +843,7 @@ std::vector StmtSort::getStmts( } void InputsOf::handle(Val* v) { - if (v->definition() == nullptr) { + if (v->definition() == nullptr || v->definition()->inputs().empty()) { if (grabbed_inputs.emplace(v).second) { ordered_inputs.push_back(v); } diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 2447933d7373a..8adac390dac89 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -109,7 +109,7 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { void traverseAllPaths(Fusion* fusion); //! Get inputs to vals. Possible input vals can be optionally - //! given. If not, vals with no defining expression are returned. + //! given. If not, vals with no producers are returned. static std::vector getInputsTo( const std::vector& vals, const std::vector& inputs = {}); diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index 404e61e7dc527..9e52116049728 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -80,11 +80,9 @@ class KernelIrScanner : private IrVisitor { } } - void handle(UnaryOp* unary_op) final { - if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) { - summary_.max_rng_offsets = - std::max(summary_.max_rng_offsets, unary_op->getRNGOffset()); - } + void handle(RNGOp* rng_op) final { + summary_.max_rng_offsets = + std::max(summary_.max_rng_offsets, rng_op->getRNGOffset()); } void handle(TensorIndex* tensor_index) final { diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index cc41376435c6b..62b245772dd03 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -39,6 +39,7 @@ class TensorView; class UnaryOp; class BinaryOp; class TernaryOp; +class RNGOp; class ReductionOp; class WelfordOp; class BroadcastOp; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 50eec0bc6e550..53b9d172f203f 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -191,11 +191,9 @@ void GpuLower::collectPaddedParallelDims() { void assignRNGOffset(Fusion* fusion) { int counter = 0; for (auto expr : fusion->exprs()) { - if (expr->isA()) { - auto uop = expr->as(); - if (uop->getUnaryOpType() == UnaryOpType::RandLike) { - uop->setRNGOffset(counter++); - } + if (expr->isA()) { + auto rop = expr->as(); + rop->setRNGOffset(counter++); } } } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index ab5eef6b21cff..2c18c6ea15e59 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -92,36 +92,30 @@ void IndexLowering::handle(const kir::ForLoop* for_loop) { active_scope_ = prev_scope; } -// TODO: use a separate IR node to represent rand like -void IndexLowering::lowerRandLike(const UnaryOp* uop) { +void IndexLowering::handle(const RNGOp* rop) { // Write random tensor indices into the consumer // tensor index if the output is a tensor. - auto out_tv = dynamic_cast(uop->out()); + auto out_tv = dynamic_cast(rop->output(0)); TORCH_INTERNAL_ASSERT(out_tv != nullptr, "rand scalar not yet supported"); - // TODO: using in as a placeholder for the random tensor index - // would need to keep this space on the new rand op when separating - // randlike from the unary op. - auto in = SimplifyingIrBuilder::create( + // TensorIndex for philox subsequence and component. + auto philox_index = SimplifyingIrBuilder::create( out_tv, Index::getRandomTensorStridedIndices(out_tv, for_loops_)); // TensorIndex for writing randlike output. - const auto out = lowerDstIndex(uop->out()); + const auto out = lowerDstIndex(out_tv); - pushBack(IrBuilder::create( - UnaryOpType::RandLike, out, in, uop->getRNGOffset())); - GpuLower::current()->propagateExprInfo(uop, back()); + auto lowered = IrBuilder::create( + rop->getRNGOpType(), out, rop->getRNGOffset(), philox_index); + + pushBack(lowered); + GpuLower::current()->propagateExprInfo(rop, back()); } void IndexLowering::handle(const UnaryOp* uop) { - if (uop->getUnaryOpType() == UnaryOpType::RandLike) { - lowerRandLike(uop); - return; - } const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); - pushBack(IrBuilder::create( - uop->getUnaryOpType(), out, in, uop->getRNGOffset())); + pushBack(IrBuilder::create(uop->getUnaryOpType(), out, in)); GpuLower::current()->propagateExprInfo(uop, back()); } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 539c40f0fb6ce..06c19e780f069 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -40,11 +40,10 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const ViewAsScalar*) final; void handle(const UnaryOp*) final; - // TODO: use a separate IR node to represent rand like - void lowerRandLike(const UnaryOp*); void handle(const BinaryOp*) final; void handle(const TernaryOp*) final; + void handle(const RNGOp*) final; void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; void handle(const WelfordOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp index cd9a7f7878660..324bab279b37e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp @@ -14,8 +14,9 @@ void ConcretizedBroadcastDomains::build(Fusion* fusion) { exact_map_ = std::make_unique(fusion); // Initialize the origin map with input broadcast domains + auto inputs = fusion->inputsAndCreated(); for (const auto fusion_input_tv : - ir_utils::filterByType(fusion->inputs())) { + ir_utils::filterByType(inputs)) { for (auto root_id : fusion_input_tv->getRootDomain()) { if (root_id->isBroadcast()) { broadcast_origin_map_.emplace( diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 28da8774daa28..17cd0c34dd123 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -91,6 +91,7 @@ bool isTvOp(const Expr* expr) { (expr->getExprType().value() == ExprType::UnaryOp || expr->getExprType().value() == ExprType::BinaryOp || expr->getExprType().value() == ExprType::TernaryOp || + expr->getExprType().value() == ExprType::RNGOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::GroupedReductionOp || expr->getExprType().value() == ExprType::WelfordOp || @@ -569,8 +570,7 @@ class ReplaceExprInput : private kir::ExprMutator { auto replacement = IrBuilder::create( node->getUnaryOpType(), node->out(), - replaced_inputs.value().at(node->in()), - node->getRNGOffset()); + replaced_inputs.value().at(node->in())); registerReplaceWithPredicate(node, replacement); } } @@ -600,6 +600,11 @@ class ReplaceExprInput : private kir::ExprMutator { } } + void handle(RNGOp* node) final { + // RNGOp has no input + return; + } + void handle(ReductionOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index bfb9d6a2534ef..201d492fdab99 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -136,7 +136,7 @@ void OptOutMutator::mutate(UnaryOp* uop) { auto container = uop->container(); auto uop_type = uop->getUnaryOpType(); container->removeExpr(uop); - IrBuilder::create(container, uop_type, out, in, uop->getRNGOffset()); + IrBuilder::create(container, uop_type, out, in); } void OptOutMutator::mutate(BinaryOp* bop) { @@ -171,6 +171,20 @@ void OptOutMutator::mutate(TernaryOp* top) { IrBuilder::create(container, top_type, out, in1, in2, in3); } +void OptOutMutator::mutate(RNGOp* rop) { + Val* out = maybeMutated(rop->output(0)); + + if (out == rop->output(0)) { + return; + } + + auto container = rop->container(); + auto rop_type = rop->getRNGOpType(); + container->removeExpr(rop); + IrBuilder::create( + container, rop_type, out, rop->getRNGOffset(), rop->getPhiloxIndex()); +} + void OptOutMutator::mutate(ReductionOp* rop) { Val* out = maybeMutated(rop->out()); Val* in = maybeMutated(rop->in()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index 304614361a017..4fa42c55d53f2 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -1005,6 +1005,10 @@ void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( } } +void ComputeAtRootDomainMapBuilder::handle(RNGOp* rop) { + handle(rop->output(0)->as()); +} + void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) { const TensorDomain* td = tv->domain(); const auto rfactor = TensorDomain::noReductions(td->getMaybeRFactorDomain()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index a4a3b5a440e2c..551441d904afe 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -399,6 +399,8 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder mapPointwiseOrReductionOp(top); } + void handle(RNGOp* top) override; + void handle(ReductionOp* op) override { mapPointwiseOrReductionOp(op); } diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu index 4736d8ac6b176..96cec63f8d9ee 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu @@ -46,22 +46,24 @@ __device__ uint4 philox( __device__ float uniformf(unsigned int x) { constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32. - return x * kRanInvM32; + float result = x * kRanInvM32; + return result == 1 ? 0.0f : result; } __device__ double uniform(unsigned int x, unsigned int y) { constexpr double kRan2Pow53Inv = 1.1102230246251565e-16; const unsigned long long z = (unsigned long long)x ^ ((unsigned long long)y << (53 - 32)); - return z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0); + double result = z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0); + return result == 1 ? 0.0 : result; } -__device__ double randLike(const uint4& rng_result, int rng_component) { +__device__ double rng_uniform(const uint4& rng_result, int rng_component) { return uniform( (&rng_result.x)[rng_component * 2], (&rng_result.x)[rng_component * 2 + 1]); } -__device__ float randLikef(const uint4& rng_result, int rng_component) { +__device__ float rng_uniformf(const uint4& rng_result, int rng_component) { return uniformf((&rng_result.x)[rng_component]); } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 7b3edfd74cdb5..ac7f66836a87c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -372,16 +372,16 @@ bool isConnectedFusionGraph(Fusion* fusion) { // Iterate through all used exprs for (auto expr : fusion->exprs()) { TORCH_INTERNAL_ASSERT( - !expr->inputs().empty(), "unknown expr with zero input"); + !expr->outputs().empty(), "unknown expr with zero output"); // Each expr maps all its inputs and // outputs to the same component - auto input0 = expr->inputs()[0]; + auto output0 = expr->output(0); for (auto input : expr->inputs()) { - component_sets.mapEntries(input0, input); + component_sets.mapEntries(output0, input); } for (auto output : expr->outputs()) { - component_sets.mapEntries(input0, output); + component_sets.mapEntries(output0, output); } } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d9ef22a110d37..f063ba82b6816 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -4238,12 +4238,6 @@ TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { /*Inputs Tuple*/ std::make_tuple(std::make_pair(ValType::TensorView, dtype))); }); - - // TODO: why the rand_like test is failing for complex? Is it because each - // complex needs to draw 2 random numbers from the RNG? We need to enable - // this - // TODO: - // Randlike testing is moved to test_gpu_rng.cu } dtypes = {DataType::Int, DataType::Int32, DataType::Bool}; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu index 21360b1510914..bb7f910b2a665 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu @@ -150,7 +150,7 @@ TEST_F(NVFuserTest, FusionRNGSimpleValidateWithCURand_CUDA) { tv2->split(0, 8); tv2->axis(0)->parallelize(ParallelType::TIDx); - tv0->computeAt(tv2, 1); + tv1->computeAt(tv2, 1); auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); at::Tensor t0 = at::zeros({size}, options); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index ef8136c631d40..7430e7e71964c 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -296,6 +296,8 @@ static const char* expr_type2string(ExprType t) { return "BinaryOp"; case ExprType::TernaryOp: return "TernaryOp"; + case ExprType::RNGOp: + return "RNGOp"; case ExprType::ReductionOp: return "ReductionOp"; case ExprType::GroupedReductionOp: @@ -391,6 +393,10 @@ bool needFloatSuffix(UnaryOpType t) { } } +bool needFloatSuffix(RNGOpType t) { + return true; +} + static const char* unary_op_type2string(UnaryOpType t) { switch (t) { case UnaryOpType::Abs: @@ -443,8 +449,6 @@ static const char* unary_op_type2string(UnaryOpType t) { return "not"; case UnaryOpType::Print: return "print"; - case UnaryOpType::RandLike: - return "randLike"; case UnaryOpType::Reciprocal: return "reciprocal"; case UnaryOpType::Relu: @@ -646,6 +650,16 @@ static const char* binary_op_type_inline_op2string(BinaryOpType t) { return nullptr; } +static const char* rng_op_type_inline_op2string(RNGOpType t) { + switch (t) { + case RNGOpType::Uniform: + return "rng_uniform"; + default: + break; + } + return nullptr; +} + std::string stringifyBooleanOp(const BinaryOpType bopt) { switch (bopt) { case BinaryOpType::And: @@ -674,6 +688,15 @@ static const char* ternary_op_type2string(TernaryOpType t) { } } +static const char* rng_op_type2string(RNGOpType t) { + switch (t) { + case RNGOpType::Uniform: + return "rng_uniform"; + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType"); + } +} + static const char* parallel_type2string(ParallelType t) { switch (t) { case ParallelType::BIDz: @@ -986,6 +1009,10 @@ std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) { return out << ternary_op_type2string(totype); } +std::ostream& operator<<(std::ostream& out, const RNGOpType rngtype) { + return out << rng_op_type2string(rngtype); +} + std::ostream& operator<<(std::ostream& out, const ParallelType ptype) { return out << stringifyThread(ptype); } @@ -1069,6 +1096,12 @@ c10::optional inline_op_str(const BinaryOpType botype) { : c10::nullopt; } +c10::optional inline_op_str(const RNGOpType rngtype) { + const char* str = rng_op_type_inline_op2string(rngtype); + return str != nullptr ? c10::optional(std::string(str)) + : c10::nullopt; +} + c10::optional integer_op_str(const BinaryOpType botype) { const char* str = binary_op_integer_op2string(botype); return str != nullptr ? c10::optional(std::string(str)) diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 455a995568349..8f7842788f2be 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -107,6 +107,7 @@ enum class ExprType { UnaryOp, BinaryOp, TernaryOp, + RNGOp, ReductionOp, GroupedReductionOp, BroadcastOp, @@ -169,7 +170,6 @@ enum class UnaryOpType { Log2, BitCast, Neg, - RandLike, Real, Reciprocal, Relu, @@ -242,6 +242,10 @@ enum class BinaryOpType { Xor }; +enum class RNGOpType { + Uniform, +}; + // Return if output of operator should be a boolean bool isIntegerOp(const BinaryOpType bopt); @@ -343,6 +347,7 @@ enum class SwizzleMode { NoSwizzle = 0, Data, Loop }; // float value i.e. sin->sinf bool needFloatSuffix(UnaryOpType t); bool needFloatSuffix(BinaryOpType t); +bool needFloatSuffix(RNGOpType t); ValType promote_type(const ValType& t1, const ValType& t2); DataType promote_type(const DataType& t1, const DataType& t2); @@ -359,6 +364,7 @@ TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ExprType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const UnaryOpType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const BinaryOpType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const TernaryOpType); +TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const RNGOpType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ParallelType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType); @@ -391,6 +397,7 @@ TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType); TORCH_CUDA_CU_API c10::optional inline_op_str(const UnaryOpType); TORCH_CUDA_CU_API c10::optional inline_op_str(const BinaryOpType); +TORCH_CUDA_CU_API c10::optional inline_op_str(const RNGOpType); TORCH_CUDA_CU_API c10::optional integer_op_str(const BinaryOpType); TORCH_CUDA_CU_API c10::optional bool_op_str(const BinaryOpType);