Add support for uniform RNG #1986

Merged 2 commits on Sep 26, 2022
18 changes: 18 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -453,6 +453,24 @@ TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
return out;
}

// TENSOR FACTORIES
TensorView* uniform(
const std::vector<Val*>& shape,
Val* low,
Val* high,
DataType dtype) {
auto n = shape.size();
auto out = TensorViewBuilder()
.ndims(n)
.dtype(dtype)
.contiguity(std::vector<bool>(n, true))
.shape(shape)
.build();
IrBuilder::create<RNGOp>(
RNGOpType::UniformRange, out, dtype, std::vector<Val*>{low, high});
return out;
}

TensorView* rand_like(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
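For orientation, a minimal sketch of how the new factory can be called when building a fusion; it mirrors the FusionUniform_CUDA test added later in this PR, and the variable names are illustrative only:

// Sketch only: a 1-D uniform(low, high) tensor with a symbolic size.
// Assumes an active Fusion (via FusionGuard), as in the test below.
Int* size_val = IrBuilder::create<Int>();
Double* low = IrBuilder::create<Double>();
Double* high = IrBuilder::create<Double>();
fusion->addInput(size_val);
fusion->addInput(low);
fusion->addInput(high);
// The shape is a vector of extents; low/high are carried as extra RNGOp parameters.
TensorView* tv = uniform({size_val}, low, high, DataType::Float);
fusion->addOutput(tv);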
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
@@ -121,12 +121,20 @@ TORCH_CUDA_CU_API WelfordResult Welford(
// import IrBuilder just for this one interface.
Int* init_N = nullptr);

// TENSOR FACTORIES
// RNG OPERATIONS
TORCH_CUDA_CU_API TensorView* rand(
const std::vector<Val*>& shape,
DataType dtype);
TORCH_CUDA_CU_API Val* rand_like(Val*);
TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);

TORCH_CUDA_CU_API TensorView* uniform(
const std::vector<Val*>& shape,
Val* low,
Val* high,
DataType dtype);

// TENSOR FACTORIES
TORCH_CUDA_CU_API TensorView* full(
const std::vector<Val*>& shape,
Val* fill_value,
12 changes: 11 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -794,7 +794,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
code_ << "f";
}
code_ << "(rng_result, rng_component" << rop->name() << ");\n";
code_ << "(rng_result, rng_component" << rop->name();
switch (op_type) {
case RNGOpType::UniformRange: {
auto parameters = rop->getParameters();
TORCH_INTERNAL_ASSERT(parameters.size() == 2);
code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
break;
}
default:;
}
code_ << ");\n";
}

std::string genBinaryOp(
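To make the new switch concrete: for a UniformRange op the generator appends the two range parameters to the call it was already emitting. A hedged sketch of the resulting device-code statement, assuming the float suffix applies (as the rng_uniform_rangef overload in the runtime header suggests); the left-hand side, the component suffix, and the names d2/d3 are placeholders, not actual generated identifiers:

// Illustrative output only (assumed names); the call shape follows from the code above.
T0[i0] = rng_uniform_rangef(rng_result, rng_component3, d2, d3);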
11 changes: 11 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -233,6 +233,7 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
RNGOpType type,
Val* out,
DataType dtype,
std::vector<Val*> parameters = {},
int rng_offset = 0,
Val* philox_index = nullptr);

@@ -254,6 +255,14 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
rng_offset_ = val;
}

const std::vector<Val*>& getParameters() const {
return parameters_;
}

const std::vector<Val*>& getShape() const {
return shape_;
}

Val* getPhiloxIndex() const {
return philox_index_;
}
@@ -267,6 +276,8 @@
private:
const RNGOpType rng_op_type_;
const DataType dtype_;
std::vector<Val*> parameters_;
std::vector<Val*> shape_;
int rng_offset_ = -1;
// The index used to feed philox's subsequence and component
Val* philox_index_ = nullptr;
9 changes: 7 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -481,14 +481,19 @@ void IrPrinter::handle(const RNGOp* rop) {

os_ << rop->getRNGOpType() << "({";
bool first = true;
for (auto i : rop->inputs()) {
for (auto i : rop->getShape()) {
if (!first) {
os_ << ", ";
}
handle(i);
first = false;
}
os_ << "}, " << rop->dtype() << ")";
os_ << "}";
for (auto i : rop->getParameters()) {
os_ << ", ";
handle(i);
}
os_ << ", " << rop->dtype() << ")";

indent_size_--;

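With this change the printer lists the shape extents in braces and then appends the extra parameters before the dtype. A rough, assumed rendering for a float uniform-range op (the output prefix and the value names are placeholders):

// Assumed printed form, derived from the handler above:
// T1 = rng_uniform_range({i0}, d1, d2, float)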
19 changes: 18 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -441,25 +441,34 @@ RNGOp::RNGOp(
RNGOpType type,
Val* out,
DataType dtype,
std::vector<Val*> parameters,
int rng_offset,
Val* philox_index)
: Expr(passkey, ExprType::RNGOp),
rng_op_type_(type),
dtype_(dtype),
parameters_(std::move(parameters)),
rng_offset_(rng_offset),
philox_index_(philox_index) {
if (out->isA<TensorView>()) {
for (auto id : out->as<TensorView>()->getRootDomain()) {
addInput(id->extent());
shape_.emplace_back(id->extent());
}
}
for (auto v : shape_) {
addInput(v);
}
for (auto v : parameters_) {
addInput(v);
}
addOutput(out);
}

RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
rng_op_type_(src->rng_op_type_),
dtype_(src->dtype()),
parameters_(ir_cloner->clone(src->parameters_)),
rng_offset_(src->rng_offset_),
philox_index_(ir_cloner->clone(src->philox_index_)) {}

@@ -477,6 +486,14 @@ bool RNGOp::sameAs(const Statement* other) const {
if (dtype_ != other_op->dtype_) {
return false;
}
if (parameters_.size() != other_op->parameters_.size()) {
return false;
}
for (auto i : c10::irange(parameters_.size())) {
if (!parameters_[i]->sameAs(other_op->parameters_[i])) {
return false;
}
}
if (getRNGOffset() != other_op->getRNGOffset()) {
return false;
}
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -266,13 +266,18 @@ struct SubstituteInExpr : public OptInDispatch {
}

void handle(RNGOp* rng_expr) final {
std::vector<Val*> subsituted_params;
for (auto v : rng_expr->getParameters()) {
subsituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
}
auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_
: rng_expr->output(0);
expr_ = IrBuilder::create<RNGOp>(
rng_expr->container(),
rng_expr->getRNGOpType(),
out,
rng_expr->dtype(),
subsituted_params,
rng_expr->getRNGOffset(),
rng_expr->getPhiloxIndex());
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -109,6 +109,7 @@ void IndexLowering::handle(const RNGOp* rop) {
rop->getRNGOpType(),
out,
rop->dtype(),
rop->getParameters(),
rop->getRNGOffset(),
philox_index);

8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -214,8 +214,13 @@ void OptOutMutator::mutate(TernaryOp* top) {

void OptOutMutator::mutate(RNGOp* rop) {
Val* out = maybeMutated(rop->output(0));
auto& parameters = rop->getParameters();
std::vector<Val*> mutated_parameters;
for (auto v : parameters) {
mutated_parameters.emplace_back(maybeMutated(v));
}

if (out == rop->output(0)) {
if (out == rop->output(0) && mutated_parameters == parameters) {
return;
}

@@ -227,6 +232,7 @@
rop_type,
out,
rop->dtype(),
mutated_parameters,
rop->getRNGOffset(),
rop->getPhiloxIndex());
}
20 changes: 20 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
@@ -67,3 +67,23 @@ __device__ double rng_uniform(const uint4& rng_result, int rng_component) {
__device__ float rng_uniformf(const uint4& rng_result, int rng_component) {
return uniformf((&rng_result.x)[rng_component]);
}

__device__ double rng_uniform_range(
const uint4& rng_result,
int rng_component,
double from,
double to) {
auto range = to - from;
auto uniform01 = rng_uniform(rng_result, rng_component);
return from + range * uniform01;
}

Review comment from the repository owner:
Why do this in codegen and not just in arith?

Reply from @zasdfgbnm (Collaborator, Author), Sep 14, 2022:
For the uniform distribution, yes, it is possible to make it a composite operator. But I still prefer to do it in C++, because I want a unified approach for all distributions, and for the other distributions it makes more sense to handle them in C++. For example, for the normal distribution the curand header does the following:

QUALIFIERS double2
_curand_box_muller_double(unsigned int x0, unsigned int x1,
                          unsigned int y0, unsigned int y1)
{
    double2 result;
    unsigned long long zx = (unsigned long long)x0 ^
        ((unsigned long long)x1 << (53 - 32));
    double u = zx * CURAND_2POW53_INV_DOUBLE + (CURAND_2POW53_INV_DOUBLE/2.0);
    unsigned long long zy = (unsigned long long)y0 ^
        ((unsigned long long)y1 << (53 - 32));
    double v = zy * (CURAND_2POW53_INV_DOUBLE*2.0) + CURAND_2POW53_INV_DOUBLE;
    double s = sqrt(-2.0 * log(u));

#if __CUDA_ARCH__ > 0
    sincospi(v, &result.x, &result.y);
#else
    result.x = sin(v*CURAND_PI_DOUBLE);
    result.y = cos(v*CURAND_PI_DOUBLE);
#endif
    result.x *= s;
    result.y *= s;

    return result;
}

QUALIFIERS double curand_normal_double(curandStatePhilox4_32_10_t *state)
{
    if(state->boxmuller_flag_double != EXTRA_FLAG_NORMAL) {
        uint4 _x;
        _x = curand4(state);
        double2 v = _curand_box_muller_double(_x.x, _x.y, _x.z, _x.w);
        state->boxmuller_extra_double = v.y;
        state->boxmuller_flag_double = EXTRA_FLAG_NORMAL;
        return v.x;
    }
    state->boxmuller_flag_double = 0;
    return state->boxmuller_extra_double;
}


__device__ float rng_uniform_rangef(
const uint4& rng_result,
int rng_component,
float from,
float to) {
auto range = to - from;
auto uniform01 = rng_uniformf(rng_result, rng_component);
return from + range * uniform01;
}
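As a quick check of the mapping, and of why the test below compares against generate_uniform(...) * 2 - 1: with from = -1 and to = 1 the helper reduces to rescaling a plain [0, 1) draw.

// Arithmetic check (not part of the diff):
//   rng_uniform_range(r, c, -1.0, 1.0)
//     = -1.0 + (1.0 - (-1.0)) * rng_uniform(r, c)
//     = 2.0 * rng_uniform(r, c) - 1.0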
36 changes: 36 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -329,5 +329,41 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
}

TEST_F(NVFuserTest, FusionUniform_CUDA) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

Int* size_val = IrBuilder::create<Int>();
Double* low = IrBuilder::create<Double>();
Double* high = IrBuilder::create<Double>();
fusion->addInput(size_val);
fusion->addInput(low);
fusion->addInput(high);
TensorView* tv0 = uniform({size_val}, low, high, DataType::Float);
TensorView* tv1 = uniform({size_val}, low, high, DataType::Double);
fusion->addOutput(tv0);
fusion->addOutput(tv1);

FusionExecutorCache fec(std::move(fusion_ptr));

for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
at::manual_seed(0);
auto cg_outputs = fec.runFusionWithInputs({size, -1.0, 1.0});

at::manual_seed(0);
auto ref0 = generate_uniform(size, kFloat) * 2 - 1;
auto ref1 = generate_uniform(size, kDouble) * 2 - 1;

testValidate(
fec.fusion(),
cg_outputs,
{size, -1.0, 1.0},
{ref0, ref1},
__LINE__,
__FILE__);
}
}

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -673,6 +673,8 @@ static const char* rng_op_type_inline_op2string(RNGOpType t) {
switch (t) {
case RNGOpType::Uniform:
return "rng_uniform";
case RNGOpType::UniformRange:
return "rng_uniform_range";
default:
break;
}
@@ -711,6 +713,8 @@ static const char* rng_op_type2string(RNGOpType t) {
switch (t) {
case RNGOpType::Uniform:
return "rng_uniform";
case RNGOpType::UniformRange:
return "rng_uniform_range";
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType");
}
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/type.h
@@ -248,7 +248,8 @@ enum class BinaryOpType {
};

enum class RNGOpType {
Uniform,
Uniform, // Uniform in [0, 1)
UniformRange, // Uniform in [low, high]
};

// Return if output of operator should be a boolean