Add support for uniform RNG (#1986)
zasdfgbnm authored Sep 26, 2022
1 parent eb1dad1 commit f262d9c
Showing 13 changed files with 149 additions and 7 deletions.
18 changes: 18 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -453,6 +453,24 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
return out;
}

// TENSOR FACTORIES
TensorView* uniform(
const std::vector<Val*>& shape,
Val* low,
Val* high,
DataType dtype) {
auto n = shape.size();
auto out = TensorViewBuilder()
.ndims(n)
.dtype(dtype)
.contiguity(std::vector<bool>(n, true))
.shape(shape)
.build();
IrBuilder::create<RNGOp>(
RNGOpType::UniformRange, out, dtype, std::vector<Val*>{low, high});
return out;
}

TensorView* rand_like(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
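The new uniform() factory mirrors rand() above, but threads the two range scalars through to the RNGOp as parameters. A minimal usage sketch, following the same pattern as the FusionUniform_CUDA test added below (variable names are illustrative):

// Sketch: a fusion whose 1-D output is filled with values drawn
// uniformly from [low, high]; the extent and bounds are runtime scalars.
Fusion fusion;
FusionGuard fg(&fusion);
Int* size = IrBuilder::create<Int>(); // runtime extent of the output
Double* low = IrBuilder::create<Double>(); // lower bound of the range
Double* high = IrBuilder::create<Double>(); // upper bound of the range
fusion.addInput(size);
fusion.addInput(low);
fusion.addInput(high);
fusion.addOutput(uniform({size}, low, high, DataType::Float));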
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
@@ -121,12 +121,20 @@ TORCH_CUDA_CU_API WelfordResult Welford(
// import IrBuilder just for this one interface.
Int* init_N = nullptr);

// TENSOR FACTORIES
// RNG OPERATIONS
TORCH_CUDA_CU_API TensorView* rand(
const std::vector<Val*>& shape,
DataType dtype);
TORCH_CUDA_CU_API Val* rand_like(Val*);
TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);

TORCH_CUDA_CU_API TensorView* uniform(
const std::vector<Val*>& shape,
Val* low,
Val* high,
DataType dtype);

// TENSOR FACTORIES
TORCH_CUDA_CU_API TensorView* full(
const std::vector<Val*>& shape,
Val* fill_value,
12 changes: 11 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -794,7 +794,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
code_ << "f";
}
code_ << "(rng_result, rng_component" << rop->name() << ");\n";
code_ << "(rng_result, rng_component" << rop->name();
switch (op_type) {
case RNGOpType::UniformRange: {
auto parameters = rop->getParameters();
TORCH_INTERNAL_ASSERT(parameters.size() == 2);
code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
break;
}
default:;
}
code_ << ");\n";
}

std::string genBinaryOp(
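With this change, a UniformRange op appends its two range scalars after the philox component, so for a Float-typed op the generated kernel call looks roughly like the line below (the destination and scalar names are illustrative; rng_uniform_rangef is the device helper added in random_numbers.cu further down):

T0[i0] = rng_uniform_rangef(rng_result, rng_component1, d2, d3);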
11 changes: 11 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -233,6 +233,7 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
RNGOpType type,
Val* out,
DataType dtype,
std::vector<Val*> parameters = {},
int rng_offset = 0,
Val* philox_index = nullptr);

@@ -254,6 +255,14 @@
rng_offset_ = val;
}

const std::vector<Val*>& getParameters() const {
return parameters_;
}

const std::vector<Val*>& getShape() const {
return shape_;
}

Val* getPhiloxIndex() const {
return philox_index_;
}
@@ -267,6 +276,8 @@
private:
const RNGOpType rng_op_type_;
const DataType dtype_;
std::vector<Val*> parameters_;
std::vector<Val*> shape_;
int rng_offset_ = -1;
// The index used to feed philox's subsequence and component
Val* philox_index_ = nullptr;
9 changes: 7 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -481,14 +481,19 @@ void IrPrinter::handle(const RNGOp* rop) {

os_ << rop->getRNGOpType() << "({";
bool first = true;
for (auto i : rop->inputs()) {
for (auto i : rop->getShape()) {
if (!first) {
os_ << ", ";
}
handle(i);
first = false;
}
os_ << "}, " << rop->dtype() << ")";
os_ << "}";
for (auto i : rop->getParameters()) {
os_ << ", ";
handle(i);
}
os_ << ", " << rop->dtype() << ")";

indent_size_--;

19 changes: 18 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -441,25 +441,34 @@ RNGOp::RNGOp(
RNGOpType type,
Val* out,
DataType dtype,
std::vector<Val*> parameters,
int rng_offset,
Val* philox_index)
: Expr(passkey, ExprType::RNGOp),
rng_op_type_(type),
dtype_(dtype),
parameters_(std::move(parameters)),
rng_offset_(rng_offset),
philox_index_(philox_index) {
if (out->isA<TensorView>()) {
for (auto id : out->as<TensorView>()->getRootDomain()) {
addInput(id->extent());
shape_.emplace_back(id->extent());
}
}
for (auto v : shape_) {
addInput(v);
}
for (auto v : parameters_) {
addInput(v);
}
addOutput(out);
}

RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
rng_op_type_(src->rng_op_type_),
dtype_(src->dtype()),
parameters_(ir_cloner->clone(src->parameters_)),
rng_offset_(src->rng_offset_),
philox_index_(ir_cloner->clone(src->philox_index_)) {}

@@ -477,6 +486,14 @@ bool RNGOp::sameAs(const Statement* other) const {
if (dtype_ != other_op->dtype_) {
return false;
}
if (parameters_.size() != other_op->parameters_.size()) {
return false;
}
for (auto i : c10::irange(parameters_.size())) {
if (!parameters_[i]->sameAs(other_op->parameters_[i])) {
return false;
}
}
if (getRNGOffset() != other_op->getRNGOffset()) {
return false;
}
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -266,13 +266,18 @@ struct SubstituteInExpr : public OptInDispatch {
}

void handle(RNGOp* rng_expr) final {
std::vector<Val*> substituted_params;
for (auto v : rng_expr->getParameters()) {
substituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
}
auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
: rng_expr->output(0);
expr_ = IrBuilder::create<RNGOp>(
rng_expr->container(),
rng_expr->getRNGOpType(),
out,
rng_expr->dtype(),
substituted_params,
rng_expr->getRNGOffset(),
rng_expr->getPhiloxIndex());
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -109,6 +109,7 @@ void IndexLowering::handle(const RNGOp* rop) {
rop->getRNGOpType(),
out,
rop->dtype(),
rop->getParameters(),
rop->getRNGOffset(),
philox_index);

8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -214,8 +214,13 @@ void OptOutMutator::mutate(TernaryOp* top) {

void OptOutMutator::mutate(RNGOp* rop) {
Val* out = maybeMutated(rop->output(0));
auto& parameters = rop->getParameters();
std::vector<Val*> mutated_parameters;
for (auto v : parameters) {
mutated_parameters.emplace_back(maybeMutated(v));
}

if (out == rop->output(0)) {
if (out == rop->output(0) && mutated_parameters == parameters) {
return;
}

@@ -227,6 +232,7 @@
rop_type,
out,
rop->dtype(),
mutated_parameters,
rop->getRNGOffset(),
rop->getPhiloxIndex());
}
20 changes: 20 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
@@ -67,3 +67,23 @@ __device__ double rng_uniform(const uint4& rng_result, int rng_component) {
__device__ float rng_uniformf(const uint4& rng_result, int rng_component) {
return uniformf((&rng_result.x)[rng_component]);
}

__device__ double rng_uniform_range(
const uint4& rng_result,
int rng_component,
double from,
double to) {
auto range = to - from;
auto uniform01 = rng_uniform(rng_result, rng_component);
return from + range * uniform01;
}

__device__ float rng_uniform_rangef(
const uint4& rng_result,
int rng_component,
float from,
float to) {
auto range = to - from;
auto uniform01 = rng_uniformf(rng_result, rng_component);
return from + range * uniform01;
}
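Both helpers apply the same affine transform to a uniform [0, 1) sample u (the interval documented for the base Uniform op in type.h below):

result = from + (to - from) * u
// e.g. from = -1, to = 1, u = 0.25 => result = -0.5

With from = -1 and to = 1 this reduces to 2*u - 1, matching the reference computation generate_uniform(...) * 2 - 1 in the new test below.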
36 changes: 36 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -329,5 +329,41 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
}

TEST_F(NVFuserTest, FusionUniform_CUDA) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);
Int* size_val = IrBuilder::create<Int>();
Double* low = IrBuilder::create<Double>();
Double* high = IrBuilder::create<Double>();
fusion->addInput(size_val);
fusion->addInput(low);
fusion->addInput(high);
TensorView* tv0 = uniform({size_val}, low, high, DataType::Float);
TensorView* tv1 = uniform({size_val}, low, high, DataType::Double);
fusion->addOutput(tv0);
fusion->addOutput(tv1);
FusionExecutorCache fec(std::move(fusion_ptr));
for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
at::manual_seed(0);
auto cg_outputs = fec.runFusionWithInputs({size, -1.0, 1.0});
at::manual_seed(0);
auto ref0 = generate_uniform(size, kFloat) * 2 - 1;
auto ref1 = generate_uniform(size, kDouble) * 2 - 1;
testValidate(
fec.fusion(),
cg_outputs,
{size, -1.0, 1.0},
{ref0, ref1},
__LINE__,
__FILE__);
}
}
} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -673,6 +673,8 @@ static const char* rng_op_type_inline_op2string(RNGOpType t) {
switch (t) {
case RNGOpType::Uniform:
return "rng_uniform";
case RNGOpType::UniformRange:
return "rng_uniform_range";
default:
break;
}
@@ -711,6 +713,8 @@ static const char* rng_op_type2string(RNGOpType t) {
switch (t) {
case RNGOpType::Uniform:
return "rng_uniform";
case RNGOpType::UniformRange:
return "rng_uniform_range";
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType");
}
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/type.h
@@ -248,7 +248,8 @@ enum class BinaryOpType {
};

enum class RNGOpType {
Uniform,
Uniform, // Uniform in [0, 1)
UniformRange, // Uniform in [low, high]
};

// Return if output of operator should be a boolean
