Nullary RNGOp #1892

Merged
merged 31 commits into from
Aug 23, 2022
Changes from 27 commits
Commits
31 commits
c87b048
Nullary RNGOp
zasdfgbnm Aug 5, 2022
ed6d5e7
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 9, 2022
bde2509
fix
zasdfgbnm Aug 9, 2022
e8bd53f
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 9, 2022
fc71f51
fixes
zasdfgbnm Aug 9, 2022
75a7b31
fix
zasdfgbnm Aug 9, 2022
dbc64fd
fix
zasdfgbnm Aug 9, 2022
9ed7831
fix
zasdfgbnm Aug 10, 2022
450d6d6
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 10, 2022
e36ad64
fixes
zasdfgbnm Aug 10, 2022
dd0b2cf
fix
zasdfgbnm Aug 10, 2022
ba21499
fix mutator
zasdfgbnm Aug 10, 2022
183fa6b
fix mutator
zasdfgbnm Aug 10, 2022
614d375
side cleanup
zasdfgbnm Aug 10, 2022
220831f
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 10, 2022
4ec07a6
lint
zasdfgbnm Aug 10, 2022
7664d36
getMaybeExpandedExtent
zasdfgbnm Aug 10, 2022
5022643
revert
zasdfgbnm Aug 10, 2022
e10ce0f
fix
zasdfgbnm Aug 10, 2022
f600632
fix
zasdfgbnm Aug 10, 2022
1036e33
revert
zasdfgbnm Aug 10, 2022
0a36802
name
zasdfgbnm Aug 10, 2022
e353c8a
fix
zasdfgbnm Aug 10, 2022
9958e73
fix sameAs
zasdfgbnm Aug 11, 2022
20affd0
Merge branch 'expand-mutator' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 11, 2022
1cf4694
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 11, 2022
a492088
fix
zasdfgbnm Aug 11, 2022
7fbf70e
Merge branch 'devel' of github.com:csarofeen/pytorch into RNGOp
zasdfgbnm Aug 23, 2022
185e8f9
save
zasdfgbnm Aug 23, 2022
0a5e88b
cleanup
zasdfgbnm Aug 23, 2022
3ae696a
more cleanup
zasdfgbnm Aug 23, 2022
49 changes: 22 additions & 27 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -358,6 +358,19 @@ Val* getMaximumValue(DataType v) {

} // namespace

// TENSOR FACTORIES
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
auto n = shape.size();
auto out = TensorViewBuilder()
.ndims(n)
.dtype(dtype)
.contiguity(std::vector<bool>(n, true))
.shape(shape)
.build();
IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
return out;
}

Val* castOp(DataType dtype, Val* v1) {
if (v1->getDataType().value() == dtype) {
return set(v1);
@@ -404,17 +417,6 @@ Val* unaryOp(UnaryOpType type, Val* v1) {
TORCH_INTERNAL_ASSERT(
type != UnaryOpType::Address,
"The reference operator & is not accessible in the Fusion IR");

// TODO: We should add the following, but we need to go through schedulers
// and make sure all calls to "fusion->inputs" includes the output of RandLike
//
// If rand like, there isn't a real dependency on the input value, so map it
// to a dummy scalar. if
//
// (type == UnaryOpType::RandLike) {
// v1 = new NamedScalar("__rnd", v1->getDataType().value());
// }

Val* out = newValLike(v1, v1->getDataType().value());
IrBuilder::create<UnaryOp>(type, out, v1);
return out;
@@ -469,28 +471,21 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
NVFUSER_DEFINE_UNARY_OP(print, Print)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
TensorView* randlike(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
std::vector<Val*> shape;
shape.reserve(v->getMaybeRFactorDomain().size());
for (auto id : v->getMaybeRFactorDomain()) {
shape.emplace_back(id->getMaybeExpandedExtent());
}
return rand(shape, v->dtype());
}

TensorView* randlike(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
Val* randlike(Val* v) {
return randlike(v->as<TensorView>());
}

Val* bitwise_not(Val* v) {
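For orientation, here is a minimal usage sketch of the new factory and the reworked `randlike()`. It is not part of the diff; it assumes the usual nvFuser headers and namespace, and `buildRandFusion` is just an illustrative name.

```cpp
// Hypothetical usage sketch, not taken from the PR. Assumes the standard
// nvFuser includes and the torch::jit::fuser::cuda namespace.
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>

using namespace torch::jit::fuser::cuda;

void buildRandFusion(Fusion* fusion) {
  FusionGuard fg(fusion);

  // Symbolic extents passed in as scalar fusion inputs.
  Int* d0 = IrBuilder::create<Int>();
  Int* d1 = IrBuilder::create<Int>();
  fusion->addInput(d0);
  fusion->addInput(d1);

  // Nullary: the RNGOp has no tensor inputs, only a shape and a dtype.
  TensorView* tv0 = rand({d0, d1}, DataType::Float);

  // randlike() now forwards the argument's (maybe expanded) extents to
  // rand() instead of emitting UnaryOpType::RandLike.
  TensorView* tv1 = randlike(tv0);

  fusion->addOutput(add(tv0, tv1));
}
```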
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -121,6 +121,11 @@ TORCH_CUDA_CU_API WelfordResult Welford(
// import IrBuilder just for this one interface.
Int* init_N = nullptr);

// TENSOR FACTORIES
TORCH_CUDA_CU_API TensorView* rand(
const std::vector<Val*>& shape,
DataType dtype);

// UNARY OPERATIONS
// abs
TORCH_CUDA_CU_API Val* abs(Val*);
59 changes: 30 additions & 29 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -706,34 +706,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

if (!print_inline_) {
if (op_type == UnaryOpType::RandLike) {
auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
auto index = genTensorIndex(uop->in()->as<kir::TensorIndex>());
int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
indent() << "nvfuser_index_t rng_subseq" << uop->name() << " = ("
<< index << ") / " << multiple << ";\n";
indent() << "nvfuser_index_t rng_component" << uop->name() << " = ("
<< index << ") % " << multiple << ";\n";
indent() << "nvfuser_index_t rng_offset" << uop->name() << " = "
<< uop->getRNGOffset() << ";\n";
indent() << "if (rng_subseq != rng_subseq" << uop->name()
<< " || rng_offset != rng_offset" << uop->name() << ") {\n";
indent() << " rng_result = philox(philox_args.seed_, rng_subseq"
<< uop->name() << ", philox_offset / 4 + rng_offset"
<< uop->name() << ");\n";
indent() << " rng_subseq = rng_subseq" << uop->name() << ";\n";
indent() << " rng_offset = rng_offset" << uop->name() << ";\n";
indent() << "}\n";
}

indent() << gen(uop->out());
if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
code_ << "\n";
indent() << kTab;
}
code_ << " = ";
} else {
TORCH_INTERNAL_ASSERT(op_type != UnaryOpType::RandLike);
}

if (auto op = inline_op_str(op_type)) {
@@ -762,20 +740,43 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

code_ << "(";
if (op_type == UnaryOpType::RandLike) {
code_ << "rng_result, rng_component" << uop->name();
} else {
code_ << gen(uop->in());
}
code_ << ")";
code_ << "(" << gen(uop->in()) << ")";
}

if (!print_inline_) {
code_ << ";\n";
}
}

void handle(const RNGOp* rop) final {
// TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
// innermost ID of size 4 (float) or size 2 (double)?
auto out_tv = rop->output(0)->as<kir::TensorIndex>()->view();
auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index
<< ") / " << multiple << ";\n";
indent() << "nvfuser_index_t rng_component" << rop->name() << " = ("
<< index << ") % " << multiple << ";\n";
indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
<< rop->getRNGOffset() << ";\n";
indent() << "if (rng_subseq != rng_subseq" << rop->name()
<< " || rng_offset != rng_offset" << rop->name() << ") {\n";
indent() << " rng_result = philox(philox_args.seed_, rng_subseq"
<< rop->name() << ", philox_offset / 4 + rng_offset" << rop->name()
<< ");\n";
indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n";
indent() << " rng_offset = rng_offset" << rop->name() << ";\n";
indent() << "}\n";
auto op_type = rop->getRNGOpType();
indent() << gen(rop->output(0)) << " = " << op_type;
if (needFloatSuffix(op_type) &&
rop->output(0)->dtype() == DataType::Float) {
code_ << "f";
}
code_ << "(rng_result, rng_component" << rop->name() << ");\n";
}

std::string genBinaryOp(
BinaryOpType op_type,
DataType data_type,
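The RNGOp handler keeps the Philox bookkeeping that previously lived in the UnaryOp path: one `philox()` call yields four 32-bit results, so the linear element index is split into a subsequence (`index / multiple`) and a component (`index % multiple`), with `multiple` being 4 for float outputs and 2 for double, and the generator is only re-invoked when the subsequence or offset changes. Below is a standalone sketch of that decomposition; `fake_philox` is a placeholder, not the real counter-based generator.

```cpp
// Standalone illustration of the index mapping emitted by handle(const RNGOp*).
// fake_philox is a stand-in for the real Philox generator in the runtime.
#include <array>
#include <cstdint>
#include <iostream>

std::array<float, 4> fake_philox(uint64_t seed, int64_t subseq, int64_t offset) {
  std::array<float, 4> r{};
  for (int i = 0; i < 4; ++i) {
    r[i] = static_cast<float>(
               (seed ^ static_cast<uint64_t>(subseq * 4 + offset + i)) % 1000) /
        1000.0f;
  }
  return r;
}

int main() {
  const int64_t multiple = 4; // 4 components per call for float, 2 for double
  int64_t cached_subseq = -1;
  std::array<float, 4> rng_result{};

  for (int64_t index = 0; index < 8; ++index) {
    const int64_t rng_subseq = index / multiple;
    const int64_t rng_component = index % multiple;
    if (rng_subseq != cached_subseq) {
      // Re-run the generator only when moving to a new subsequence.
      rng_result = fake_philox(/*seed=*/0x1234, rng_subseq, /*offset=*/0);
      cached_subseq = rng_subseq;
    }
    std::cout << index << " -> " << rng_result[rng_component] << "\n";
  }
  return 0;
}
```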
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -104,6 +104,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::TernaryOp:
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::RNGOp:
ptr(handler)->handle(expr->as<RNGOp>());
return;
case ExprType::ReductionOp:
ptr(handler)->handle(expr->as<ReductionOp>());
return;
@@ -278,6 +281,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::TernaryOp:
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::RNGOp:
ptr(handler)->handle(expr->as<RNGOp>());
return;
case ExprType::ReductionOp:
ptr(handler)->handle(expr->as<ReductionOp>());
return;
@@ -460,6 +466,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::TernaryOp:
ptr(mutator)->mutate(expr->as<TernaryOp>());
return;
case ExprType::RNGOp:
ptr(mutator)->mutate(expr->as<RNGOp>());
return;
case ExprType::ReductionOp:
ptr(mutator)->mutate(expr->as<ReductionOp>());
return;
@@ -707,6 +716,9 @@ void OptOutConstDispatch::handle(const BinaryOp* stmt) {
void OptOutConstDispatch::handle(const TernaryOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const RNGOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const ReductionOp* stmt) {
unhandled(stmt);
}
@@ -851,6 +863,9 @@ void OptOutDispatch::handle(BinaryOp* stmt) {
void OptOutDispatch::handle(TernaryOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(RNGOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(ReductionOp* stmt) {
unhandled(stmt);
}
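Wiring `RNGOp` through `Expr::dispatch`, `constDispatch`, and `mutatorDispatch`, with `unhandled()` defaults in the `OptOut*` bases, keeps existing passes compiling while letting any pass opt in by overriding the new hook. A hypothetical const visitor, not taken from the PR, that counts RNG ops in a fusion:

```cpp
// Hypothetical sketch: a const visitor that counts RNGOp expressions by
// overriding the handle(const RNGOp*) hook added in this PR.
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

using namespace torch::jit::fuser::cuda;

class CountRNGOps : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

  static size_t count(Fusion* fusion) {
    CountRNGOps counter;
    for (Expr* expr : fusion->exprs()) {
      counter.handle(expr); // Expr::constDispatch routes RNGOp nodes below
    }
    return counter.count_;
  }

  void handle(const RNGOp*) final {
    ++count_;
  }

 private:
  size_t count_ = 0;
};
```

`Fusion::isStochastic()` in fusion.cpp below performs essentially the same check by switching on `ExprType::RNGOp` directly.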
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -71,6 +71,7 @@ class NamedScalar;
class UnaryOp;
class BinaryOp;
class TernaryOp;
class RNGOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
@@ -143,6 +144,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const UnaryOp* stmt);
virtual void handle(const BinaryOp* stmt);
virtual void handle(const TernaryOp* stmt);
virtual void handle(const RNGOp* stmt);
virtual void handle(const ReductionOp* stmt);
virtual void handle(const GroupedReductionOp* stmt);
virtual void handle(const WelfordOp* stmt);
@@ -206,6 +208,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(UnaryOp* stmt);
virtual void handle(BinaryOp* stmt);
virtual void handle(TernaryOp* stmt);
virtual void handle(RNGOp* stmt);
virtual void handle(ReductionOp* stmt);
virtual void handle(GroupedReductionOp* stmt);
virtual void handle(WelfordOp* stmt);
@@ -310,6 +313,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(UnaryOp*);
virtual void mutate(BinaryOp*);
virtual void mutate(TernaryOp*);
virtual void mutate(RNGOp*);
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
virtual void mutate(WelfordOp*);
23 changes: 18 additions & 5 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -373,6 +373,18 @@ void Fusion::printMath(bool from_outputs_only) {
std::cout << "}\n\n";
}

std::vector<Val*> Fusion::inputsAndCreated() {
auto result = inputs_;
for (auto expr : exprs()) {
if (expr->inputs().empty()) {
for (auto v : expr->outputs()) {
result.emplace_back(v);
}
}
}
return result;
}

void Fusion::printTransforms() {
FUSER_PERF_SCOPE("Fusion::printTransforms");

@@ -531,14 +543,15 @@ Expr* Fusion::definition(const Val* val) const {

// Indicate to kernel to set itself up to generate random numbers
bool Fusion::isStochastic() {
for (auto expr : exprs())
if (expr->getExprType() == ExprType::UnaryOp)
if (expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike)
return true;
for (auto expr : exprs()) {
if (expr->getExprType() == ExprType::RNGOp) {
return true;
}
}
return false;
}

std::vector<Val*> Fusion::getTerminatingOutputs() {
std::vector<Val*> Fusion::getTerminatingOutputs() const {
FUSER_PERF_SCOPE("getTerminatingOutputs");

auto is_reachable_to_output = [](Val* val) {
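`inputsAndCreated()` addresses the fact that an RNGOp output is a root of the expression graph without being a fusion input, so passes that seed traversals from `fusion->inputs()` alone would never reach it. A hypothetical helper illustrating the intended use (`rootTensorViews` is an illustrative name, not from the PR):

```cpp
// Hypothetical sketch: seed a traversal from inputsAndCreated() so tensors
// produced by input-free expressions (such as RNGOp outputs) are included.
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

#include <vector>

using namespace torch::jit::fuser::cuda;

std::vector<TensorView*> rootTensorViews(Fusion* fusion) {
  std::vector<TensorView*> roots;
  for (Val* val : fusion->inputsAndCreated()) {
    if (auto* tv = dynamic_cast<TensorView*>(val)) {
      roots.push_back(tv);
    }
  }
  return roots;
}
```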
4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.h
@@ -175,11 +175,13 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
return inputs_;
}

std::vector<Val*> inputsAndCreated();

const auto& outputs() const {
return outputs_;
}

std::vector<Val*> getTerminatingOutputs();
std::vector<Val*> getTerminatingOutputs() const;

// Aliasing output to input value, this is a WAR to allow inplace update on
// input tensor.
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -48,6 +48,7 @@ class Expr;
class Val;
class UnaryOp;
class BinaryOp;
class RNGOp;
class IterDomain;
class IrCloner;
class IrContainer;
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -63,6 +63,7 @@ IR_BUILDER_INSTANTIATE(ViewOp)
IR_BUILDER_INSTANTIATE(UnaryOp)
IR_BUILDER_INSTANTIATE(BinaryOp)
IR_BUILDER_INSTANTIATE(TernaryOp)
IR_BUILDER_INSTANTIATE(RNGOp)
IR_BUILDER_INSTANTIATE(ReductionOp)
IR_BUILDER_INSTANTIATE(GroupedReductionOp)
IR_BUILDER_INSTANTIATE(WelfordOp)
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -100,6 +100,10 @@ void IrCloner::handle(const TernaryOp* op) {
clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const RNGOp* op) {
clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const BroadcastOp* op) {
clone_ = IrBuilder::clone(op, this);
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -71,6 +71,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
void handle(const UnaryOp*) override;
void handle(const BinaryOp*) override;
void handle(const TernaryOp*) override;
void handle(const RNGOp*) override;
void handle(const BroadcastOp*) override;
void handle(const ReductionOp*) override;
void handle(const GroupedReductionOp*) override;
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -443,6 +443,16 @@ void IrGraphGenerator::handle(const TernaryOp* op) {
addArc(op, op->out());
}

void IrGraphGenerator::handle(const RNGOp* op) {
// node
std::stringstream label;
label << op->getRNGOpType();
printExpr(op, label.str());

// inputs & outputs
addArc(op, op->output(0));
}

void IrGraphGenerator::handle(const BroadcastOp* op) {
printExpr(op, "Broadcast");
addArc(op->in(), op);
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -85,6 +85,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
void handle(const UnaryOp*) override;
void handle(const BinaryOp*) override;
void handle(const TernaryOp*) override;
void handle(const RNGOp*) override;
void handle(const BroadcastOp*) override;
void handle(const ReductionOp*) override;
