Skip to content

Commit

Permalink
Tensor factories must set the output shape as their input (#1939)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Aug 28, 2022
1 parent b2fd01e commit 89330aa
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 52 deletions.
35 changes: 15 additions & 20 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,19 +362,6 @@ Val* getMaximumValue(DataType v) {

} // namespace

// TENSOR FACTORIES
// Builds a TensorView filled with uniformly distributed random values.
// shape: symbolic extents (Val*) of each output dimension.
// dtype: element type of the produced tensor.
// Returns the freshly built output TensorView; ownership is managed by the IR.
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
// Rank of the output tensor is taken from the requested shape.
auto n = shape.size();
// All dimensions are marked contiguous since this tensor is created
// from scratch rather than derived from an existing layout.
auto out = TensorViewBuilder()
.ndims(n)
.dtype(dtype)
.contiguity(std::vector<bool>(n, true))
.shape(shape)
.build();
// Attach an RNG op producing uniform random values into `out`.
// NOTE(review): the op's exact seeding/offset semantics are defined by
// RNGOp elsewhere in the file — not visible here.
IrBuilder::create<RNGOp>(RNGOpType::Uniform, out);
return out;
}

Val* castOp(DataType dtype, Val* v1) {
if (v1->getDataType().value() == dtype) {
return set(v1);
Expand Down Expand Up @@ -454,19 +441,27 @@ TensorView* unaryOp(
}

// TENSOR FACTORIES
TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype) {
// Creates a new TensorView whose elements are drawn from a uniform
// random distribution.
// shape: symbolic extent (Val*) of every output dimension.
// dtype: element type of the resulting tensor.
// Returns the output TensorView registered with the current fusion IR.
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
  const auto rank = shape.size();
  // A factory-created tensor owns its layout, so every axis is contiguous.
  TensorView* result = TensorViewBuilder()
                           .ndims(rank)
                           .dtype(dtype)
                           .contiguity(std::vector<bool>(rank, true))
                           .shape(shape)
                           .build();
  // Record the RNG op that defines `result` in the IR.
  IrBuilder::create<RNGOp>(RNGOpType::Uniform, result);
  return result;
}

// Single-bound arange: produces [0, end) by delegating to the
// (start, end) overload with an implicit start of zero.
TensorView* arange(Val* end, DataType dtype) {
  auto* zero = FusionGuard::getCurFusion()->zeroVal();
  return arange(zero, end, dtype);
}

TORCH_CUDA_CU_API TensorView* arange(Val* start, Val* end, DataType dtype) {
// Two-bound arange: produces [start, end) by delegating to the full
// (start, end, step) overload with an implicit unit step.
TensorView* arange(Val* start, Val* end, DataType dtype) {
  auto* unit_step = FusionGuard::getCurFusion()->oneVal();
  return arange(start, end, unit_step, dtype);
}

TORCH_CUDA_CU_API TensorView* arange(
Val* start,
Val* end,
Val* step,
DataType dtype) {
TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
if (isIntegralType(dtype)) {
start = castOp(DataType::Int, start);
end = castOp(DataType::Int, end);
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ void Fusion::printMath(bool from_outputs_only) {
std::vector<Val*> Fusion::inputsAndCreated() {
auto result = inputs_;
for (auto expr : exprs()) {
if (expr->inputs().empty()) {
auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
if (tv_inputs.empty()) {
for (auto v : expr->outputs()) {
result.emplace_back(v);
}
Expand Down
27 changes: 14 additions & 13 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,27 +419,28 @@ void IrPrinter::handle(const TernaryOp* top) {
}

void IrPrinter::handle(const RNGOp* rop) {
bool istvop = ir_utils::isTvOp(rop);
if (!print_inline_) {
indent();
os_ << rop->output(0);

// tensor operations tend to be long, break them up into multiple lines
if (istvop) {
os_ << "\n";
indent_size_++;
indent();
}

os_ << rop->output(0) << "\n";
indent_size_++;
indent();
os_ << " = ";
} else {
checkInlineable(rop);
}

os_ << rop->getRNGOpType() << "()";
os_ << rop->getRNGOpType() << "(";
bool first = true;
for (auto i : rop->inputs()) {
if (!first) {
os_ << ", ";
}
handle(i);
first = false;
}
os_ << ")";

if (istvop)
indent_size_--;
indent_size_--;

if (!print_inline_)
os_ << ";\n";
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ RNGOp::RNGOp(
rng_op_type_(type),
rng_offset_(rng_offset),
philox_index_(philox_index) {
if (out->isA<TensorView>()) {
for (auto id : out->as<TensorView>()->getRootDomain()) {
addInput(id->extent());
}
}
addOutput(out);
}

Expand Down
37 changes: 19 additions & 18 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,32 +106,33 @@ at::Tensor generate_uniform(int64_t size, at::ScalarType dtype) {
} // namespace
TEST_F(NVFuserTest, FusionRNGValidateWithCURand_CUDA) {
for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
for (auto dtype : {kFloat, kDouble}) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);
Int* size_val = IrBuilder::create<Int>();
fusion->addInput(size_val);
TensorView* tv0 = rand({size_val}, aten_to_data_type(dtype));
fusion->addOutput(tv0);
Int* size_val = IrBuilder::create<Int>();
fusion->addInput(size_val);
TensorView* tv0 = rand({size_val}, DataType::Float);
TensorView* tv1 = rand({size_val}, DataType::Double);
fusion->addOutput(tv0);
fusion->addOutput(tv1);
FusionExecutorCache fec(std::move(fusion_ptr));
FusionExecutorCache fec(std::move(fusion_ptr));
at::manual_seed(0);
auto cg_outputs = fec.runFusionWithInputs({size});
auto out = cg_outputs[0];
for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) {
at::manual_seed(0);
auto cg_outputs = fec.runFusionWithInputs({size});
at::manual_seed(0);
auto ref = generate_uniform(size, dtype);
at::manual_seed(0);
auto ref0 = generate_uniform(size, kFloat);
auto ref1 = generate_uniform(size, kDouble);
testValidate(fec.fusion(), {out}, {size}, {ref}, __LINE__, __FILE__);
}
testValidate(
fec.fusion(), cg_outputs, {size}, {ref0, ref1}, __LINE__, __FILE__);
}
}
TEST_F(NVFuserTest, FusionRNGSimpleValidateWithCURand_CUDA) {
TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) {
int64_t size = 128;
auto dtype = kFloat;
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
Expand Down

0 comments on commit 89330aa

Please sign in to comment.