diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index c217ba8b2c74f..0e0d4c4da0e35 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -453,6 +453,24 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
   return out;
 }
 
+// TENSOR FACTORIES
+TensorView* uniform(
+    const std::vector<Val*>& shape,
+    Val* low,
+    Val* high,
+    DataType dtype) {
+  auto n = shape.size();
+  auto out = TensorViewBuilder()
+                 .ndims(n)
+                 .dtype(dtype)
+                 .contiguity(std::vector<bool>(n, true))
+                 .shape(shape)
+                 .build();
+  IrBuilder::create<RNGOp>(
+      RNGOpType::UniformRange, out, dtype, std::vector<Val*>{low, high});
+  return out;
+}
+
 TensorView* rand_like(TensorView* v) {
   TORCH_CHECK(
       isFloatingPointType(v->dtype()),
diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h
index 8b6702ab1d372..66344c74880c0 100644
--- a/torch/csrc/jit/codegen/cuda/arith.h
+++ b/torch/csrc/jit/codegen/cuda/arith.h
@@ -121,12 +121,20 @@ TORCH_CUDA_CU_API WelfordResult Welford(
     // import IrBuilder just for this one interface.
     Int* init_N = nullptr);
 
-// TENSOR FACTORIES
+// RNG OPERATIONS
 TORCH_CUDA_CU_API TensorView* rand(
     const std::vector<Val*>& shape,
     DataType dtype);
 TORCH_CUDA_CU_API Val* rand_like(Val*);
 TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);
+
+TORCH_CUDA_CU_API TensorView* uniform(
+    const std::vector<Val*>& shape,
+    Val* low,
+    Val* high,
+    DataType dtype);
+
+// TENSOR FACTORIES
 TORCH_CUDA_CU_API TensorView* full(
     const std::vector<Val*>& shape,
     Val* fill_value,
diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 8f21dac5bee8b..ac78ec2fb3bdc 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -794,7 +794,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
       code_ << "f";
     }
-    code_ << "(rng_result, rng_component" << rop->name() << ");\n";
+    code_ << "(rng_result, rng_component" << rop->name();
+    switch (op_type) {
+      case RNGOpType::UniformRange: {
+        auto parameters = rop->getParameters();
+        TORCH_INTERNAL_ASSERT(parameters.size() == 2);
+        code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
+        break;
+      }
+      default:;
+    }
+    code_ << ");\n";
   }
 
   std::string genBinaryOp(
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 7904f4e9ec869..aa8793366a326 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -233,6 +233,7 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
       RNGOpType type,
       Val* out,
       DataType dtype,
+      std::vector<Val*> parameters = {},
       int rng_offset = 0,
       Val* philox_index = nullptr);
 
@@ -254,6 +255,14 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
     rng_offset_ = val;
   }
 
+  const std::vector<Val*>& getParameters() const {
+    return parameters_;
+  }
+
+  const std::vector<Val*>& getShape() const {
+    return shape_;
+  }
+
   Val* getPhiloxIndex() const {
     return philox_index_;
   }
@@ -267,6 +276,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
  private:
   const RNGOpType rng_op_type_;
   const DataType dtype_;
+  std::vector<Val*> parameters_;
+  std::vector<Val*> shape_;
   int rng_offset_ = -1;
   // The index used to feed philox's subsequence and component
   Val* philox_index_ = nullptr;
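Taken together, the changes above give the front end a factory that samples U(low, high) directly instead of rescaling a [0, 1) tensor after the fact. A minimal usage sketch (it mirrors the FusionUniform_CUDA test added later in this diff; the Fusion/FusionGuard scaffolding is the usual nvFuser setup, and the variable names here are illustrative, not part of the patch):

    // Build a single-op fusion whose output is uniform in the given range.
    Fusion fusion;
    FusionGuard fg(&fusion);

    Int* size = IrBuilder::create<Int>();       // symbolic extent
    Double* low = IrBuilder::create<Double>();  // runtime scalar: lower bound
    Double* high = IrBuilder::create<Double>(); // runtime scalar: upper bound
    fusion.addInput(size);
    fusion.addInput(low);
    fusion.addInput(high);

    // uniform() builds the contiguous output TensorView and attaches an
    // RNGOp of type UniformRange whose parameters are {low, high}.
    TensorView* tv = uniform({size}, low, high, DataType::Float);
    fusion.addOutput(tv);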
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index 4258d8c6b1377..ec55ad013d016 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -481,14 +481,19 @@ void IrPrinter::handle(const RNGOp* rop) {
   os_ << rop->getRNGOpType() << "({";
   bool first = true;
-  for (auto i : rop->inputs()) {
+  for (auto i : rop->getShape()) {
     if (!first) {
       os_ << ", ";
     }
     handle(i);
     first = false;
   }
-  os_ << "}, " << rop->dtype() << ")";
+  os_ << "}";
+  for (auto i : rop->getParameters()) {
+    os_ << ", ";
+    handle(i);
+  }
+  os_ << ", " << rop->dtype() << ")";
   indent_size_--;
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 0d8d04a89c888..3319bf28a18a9 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -441,18 +441,26 @@ RNGOp::RNGOp(
     RNGOpType type,
     Val* out,
     DataType dtype,
+    std::vector<Val*> parameters,
     int rng_offset,
     Val* philox_index)
     : Expr(passkey, ExprType::RNGOp),
       rng_op_type_(type),
       dtype_(dtype),
+      parameters_(std::move(parameters)),
       rng_offset_(rng_offset),
       philox_index_(philox_index) {
   if (out->isA<TensorView>()) {
     for (auto id : out->as<TensorView>()->getRootDomain()) {
-      addInput(id->extent());
+      shape_.emplace_back(id->extent());
     }
   }
+  for (auto v : shape_) {
+    addInput(v);
+  }
+  for (auto v : parameters_) {
+    addInput(v);
+  }
   addOutput(out);
 }
 
@@ -460,6 +468,7 @@ RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
     : Expr(src, ir_cloner),
       rng_op_type_(src->rng_op_type_),
       dtype_(src->dtype()),
+      parameters_(ir_cloner->clone(src->parameters_)),
       rng_offset_(src->rng_offset_),
       philox_index_(ir_cloner->clone(src->philox_index_)) {}
 
@@ -477,6 +486,14 @@ bool RNGOp::sameAs(const Statement* other) const {
   if (dtype_ != other_op->dtype_) {
     return false;
   }
+  if (parameters_.size() != other_op->parameters_.size()) {
+    return false;
+  }
+  for (auto i : c10::irange(parameters_.size())) {
+    if (!parameters_[i]->sameAs(other_op->parameters_[i])) {
+      return false;
+    }
+  }
   if (getRNGOffset() != other_op->getRNGOffset()) {
     return false;
   }
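One detail worth noting in the RNGOp constructor above: inputs() are now registered in two groups, all shape extents first, then all parameters, and getShape()/getParameters() expose those groups without slicing inputs() by index. A hypothetical checker (not part of the diff; assumes the Expr::input() accessor) makes the invariant explicit:

    // Sketch: the two accessor views stay aligned with the flat input list.
    void checkRNGOpInputs(const RNGOp* rop) {
      size_t ndims = rop->getShape().size();
      for (size_t i = 0; i < ndims; i++) {
        TORCH_INTERNAL_ASSERT(rop->input(i) == rop->getShape()[i]);
      }
      for (size_t i = 0; i < rop->getParameters().size(); i++) {
        TORCH_INTERNAL_ASSERT(
            rop->input(ndims + i) == rop->getParameters()[i]);
      }
    }

This split is also what lets the IrPrinter change print the shape inside braces and the parameters after them, e.g. something like rng_uniform_range({i1}, d2, d3, float), where i1/d2/d3 are made-up placeholder names.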
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index c90acf17b8b33..dba5ee10adabb 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -266,6 +266,10 @@ struct SubstituteInExpr : public OptInDispatch {
   }
 
   void handle(RNGOp* rng_expr) final {
+    std::vector<Val*> substituted_params;
+    for (auto v : rng_expr->getParameters()) {
+      substituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
+    }
     auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
                                                        : rng_expr->output(0);
     expr_ = IrBuilder::create<RNGOp>(
@@ -273,6 +277,7 @@ struct SubstituteInExpr : public OptInDispatch {
         rng_expr->getRNGOpType(),
         out,
         rng_expr->dtype(),
+        substituted_params,
         rng_expr->getRNGOffset(),
         rng_expr->getPhiloxIndex());
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
index 4719a5fd7bfdf..cde580c1be9fc 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -109,6 +109,7 @@ void IndexLowering::handle(const RNGOp* rop) {
       rop->getRNGOpType(),
       out,
       rop->dtype(),
+      rop->getParameters(),
      rop->getRNGOffset(),
       philox_index);
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index e4f4d4a0e89ac..12a3de15f4a7f 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -214,8 +214,13 @@ void OptOutMutator::mutate(TernaryOp* top) {
 void OptOutMutator::mutate(RNGOp* rop) {
   Val* out = maybeMutated(rop->output(0));
+  auto& parameters = rop->getParameters();
+  std::vector<Val*> mutated_parameters;
+  for (auto v : parameters) {
+    mutated_parameters.emplace_back(maybeMutated(v));
+  }
 
-  if (out == rop->output(0)) {
+  if (out == rop->output(0) && mutated_parameters == parameters) {
     return;
   }
@@ -227,6 +232,7 @@ void OptOutMutator::mutate(RNGOp* rop) {
       rop_type,
       out,
       rop->dtype(),
+      mutated_parameters,
       rop->getRNGOffset(),
       rop->getPhiloxIndex());
 }
diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
index 96cec63f8d9ee..75d39e7c0c4b6 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
@@ -67,3 +67,23 @@ __device__ double rng_uniform(const uint4& rng_result, int rng_component) {
 __device__ float rng_uniformf(const uint4& rng_result, int rng_component) {
   return uniformf((&rng_result.x)[rng_component]);
 }
+
+__device__ double rng_uniform_range(
+    const uint4& rng_result,
+    int rng_component,
+    double from,
+    double to) {
+  auto range = to - from;
+  auto uniform01 = rng_uniform(rng_result, rng_component);
+  return from + range * uniform01;
+}
+
+__device__ float rng_uniform_rangef(
+    const uint4& rng_result,
+    int rng_component,
+    float from,
+    float to) {
+  auto range = to - from;
+  auto uniform01 = rng_uniformf(rng_result, rng_component);
+  return from + range * uniform01;
+}
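The two device helpers are the standard affine remap of a [0, 1) sample: for u01 in [0, 1), from + (to - from) * u01 lands in [from, to) mathematically, although floating-point rounding can occasionally round the result up to `to` itself. A standalone host-side sketch of the same formula (plain C++ for illustration, not part of the diff):

    #include <cassert>
    #include <cstdio>

    // Same transform as rng_uniform_rangef, run on the host.
    float uniform_range(float u01, float from, float to) {
      return from + (to - from) * u01;
    }

    int main() {
      assert(uniform_range(0.0f, -1.0f, 1.0f) == -1.0f); // low end inclusive
      assert(uniform_range(0.5f, -1.0f, 1.0f) == 0.0f);  // midpoint maps to 0
      std::printf("%f\n", uniform_range(0.999f, -1.0f, 1.0f)); // approaches high
      return 0;
    }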
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
index d7a2ce72896d8..5fc61fe6a368d 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -329,5 +329,41 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
   TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
 }
 
+TEST_F(NVFuserTest, FusionUniform_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  Int* size_val = IrBuilder::create<Int>();
+  Double* low = IrBuilder::create<Double>();
+  Double* high = IrBuilder::create<Double>();
+  fusion->addInput(size_val);
+  fusion->addInput(low);
+  fusion->addInput(high);
+  TensorView* tv0 = uniform({size_val}, low, high, DataType::Float);
+  TensorView* tv1 = uniform({size_val}, low, high, DataType::Double);
+  fusion->addOutput(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+
+  for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
+    at::manual_seed(0);
+    auto cg_outputs = fec.runFusionWithInputs({size, -1.0, 1.0});
+
+    at::manual_seed(0);
+    auto ref0 = generate_uniform(size, kFloat) * 2 - 1;
+    auto ref1 = generate_uniform(size, kDouble) * 2 - 1;
+
+    testValidate(
+        fec.fusion(),
+        cg_outputs,
+        {size, -1.0, 1.0},
+        {ref0, ref1},
+        __LINE__,
+        __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index a333eb4d87ee2..3b8f380683ed2 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -673,6 +673,8 @@ static const char* rng_op_type_inline_op2string(RNGOpType t) {
   switch (t) {
     case RNGOpType::Uniform:
       return "rng_uniform";
+    case RNGOpType::UniformRange:
+      return "rng_uniform_range";
     default:
       break;
   }
@@ -711,6 +713,8 @@ static const char* rng_op_type2string(RNGOpType t) {
   switch (t) {
     case RNGOpType::Uniform:
       return "rng_uniform";
+    case RNGOpType::UniformRange:
+      return "rng_uniform_range";
     default:
       TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType");
   }
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index de4af398820fc..4aa894113e993 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -248,7 +248,8 @@ enum class BinaryOpType {
 };
 
 enum class RNGOpType {
-  Uniform,
+  Uniform, // Uniform in [0, 1)
+  UniformRange, // Uniform in [low, high]
 };
 
 // Return if output of operator should be a boolean
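End to end, a Float-typed UniformRange op now lowers as follows: rng_op_type_inline_op2string() picks "rng_uniform_range", needFloatSuffix() appends "f", and the codegen change forwards the two lowered parameters after the philox component, so the emitted kernel statement looks roughly like this (illustrative only; tensor, index, and scalar names are generated per kernel and the ones below are made up):

    // Hypothetical generated CUDA statement for uniform({i}, d2, d3, Float):
    T0[i5] = rng_uniform_rangef(rng_result, rng_component0, d2, d3);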