3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.cpp
@@ -88,10 +88,11 @@ Val* newScalar(ValType vtype, DataType dtype) {
switch (dtype) {
case DataType::Bool:
return IrBuilder::create<Bool>();
case DataType::Double:
case DataType::Float:
case DataType::Half:
case DataType::BFloat16:
return IrBuilder::create<Float>();
case DataType::Double:
return IrBuilder::create<Double>();
case DataType::Int32:
case DataType::Int:
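The hunk above routes every reduced-precision floating dtype (Float, Half, BFloat16) to the new `Float` IR scalar, while `Double` keeps its own node. A minimal standalone sketch of that mapping, using hypothetical stand-in types rather than the real IR classes:

```cpp
#include <memory>
#include <stdexcept>

enum class DataType { Bool, Float, Half, BFloat16, Double, Int };

struct Scalar { virtual ~Scalar() = default; };
struct FloatVal : Scalar {};   // stand-in for IrBuilder::create<Float>()
struct DoubleVal : Scalar {};  // stand-in for IrBuilder::create<Double>()

std::unique_ptr<Scalar> newScalar(DataType dtype) {
  switch (dtype) {
    // All three reduced-precision dtypes share one Float node, presumably
    // because their math is carried out in single precision.
    case DataType::Float:
    case DataType::Half:
    case DataType::BFloat16:
      return std::make_unique<FloatVal>();
    case DataType::Double:
      return std::make_unique<DoubleVal>();
    default:
      throw std::runtime_error("unsupported dtype");
  }
}
```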
47 changes: 45 additions & 2 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -170,9 +170,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

void initStringStreamFormat(std::stringstream& ss) {
const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
ss.imbue(std::locale("C"));
ss << std::scientific << std::setprecision(digits);
ss << std::scientific;
setPrecision(ss);
}

// By default, use double-precision format
template <typename FloatingPointScalarType = Double>
void setPrecision(std::stringstream& ss) {
const int digits = std::numeric_limits<
typename FloatingPointScalarType::ScalarType>::max_digits10;
ss << std::setprecision(digits);
}

// Generates the kernel function declaration
@@ -405,6 +413,32 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const Float* f) final {
const auto def = f->definition();
const bool has_alloc = alloc_map_.find(f) != alloc_map_.end();
if (def != nullptr && !has_alloc) {
code_ << "(" << gen(def) << ")";
} else if (f->isConst()) {
auto val = *f->value();
// note: the default inf/nan output is not valid in generated kernels and
// is replaced with the macros `NAN`, `POS_INFINITY` and `NEG_INFINITY`.
if (std::isinf(val)) {
if (val > 0) {
code_ << "POS_INFINITY";
} else {
code_ << "NEG_INFINITY";
}
} else if (std::isnan(val)) {
code_ << "NAN";
} else {
setPrecision<Float>(code_);
code_ << val << "f";
}
} else {
code_ << varName(f);
}
}

void handle(const Double* d) final {
const auto def = d->definition();
const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
@@ -423,6 +457,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
} else if (std::isnan(val)) {
code_ << "NAN";
} else {
setPrecision<Double>(code_);
code_ << val;
}
} else {
@@ -902,6 +937,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
exponent = int_exp;
}
}
} else if (auto val_float = dynamic_cast<Float*>(rhs)) {
if (val_float->isConst()) {
auto fp_exp = val_float->value().value();
float int_exp = 0;
if (std::modf(fp_exp, &int_exp) == 0) {
exponent = int_exp;
}
}
}

if (!exponent.has_value()) {
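The new `Float` handler mirrors the existing `Double` one: constants print at `max_digits10` precision (so the decimal literal round-trips to the exact binary value), get an `f` suffix to stay single precision in the kernel, and inf/nan go through kernel-side macros. A self-contained sketch of that literal-printing rule, outside the real generator:

```cpp
#include <cmath>
#include <iomanip>
#include <limits>
#include <locale>
#include <sstream>
#include <string>

std::string genFloatLiteral(float val) {
  if (std::isinf(val)) {
    return val > 0 ? "POS_INFINITY" : "NEG_INFINITY"; // kernel-side macros
  }
  if (std::isnan(val)) {
    return "NAN";
  }
  std::stringstream ss;
  ss.imbue(std::locale("C")); // no locale-dependent separators in kernels
  ss << std::scientific
     << std::setprecision(std::numeric_limits<float>::max_digits10) << val
     << "f"; // keep the literal single precision
  return ss.str();
}

// genFloatLiteral(0.1f) yields "1.000000015e-01f": enough digits to
// reconstruct the exact float, which a plain "1.0e-01f" would not be.
```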
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -48,6 +48,9 @@ void Val::dispatch(T handler, Val* val) {
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
ptr(handler)->handle(val->as<Float>());
return;
case DataType::Double:
ptr(handler)->handle(val->as<Double>());
return;
@@ -272,6 +275,9 @@ void Val::constDispatch(T handler, const Val* val) {
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
return;
case DataType::Float:
ptr(handler)->handle(val->as<Float>());
return;
case DataType::Double:
ptr(handler)->handle(val->as<Double>());
return;
@@ -507,6 +513,9 @@ void Val::mutatorDispatch(T mutator, Val* val) {
case DataType::Bool:
ptr(mutator)->mutate(val->as<Bool>());
return;
case DataType::Float:
ptr(mutator)->mutate(val->as<Float>());
return;
case DataType::Double:
ptr(mutator)->mutate(val->as<Double>());
return;
@@ -821,6 +830,9 @@ void OptInDispatch::unhandled(Statement* stmt) {
void OptOutConstDispatch::handle(const Bool* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const Float* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const Double* stmt) {
unhandled(stmt);
}
@@ -980,6 +992,9 @@ void OptOutDispatch::unhandled(Statement*) {}
void OptOutDispatch::handle(Bool* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(Float* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(Double* stmt) {
unhandled(stmt);
}
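These dispatch changes all follow one pattern: a switch on the runtime `DataType` recovers the static type before calling the typed `handle`/`mutate` overload, and the `OptOut*` base classes route every overload to `unhandled` so concrete visitors override only the cases they care about. A condensed sketch of the pattern with placeholder classes:

```cpp
enum class DataType { Bool, Float, Double };

struct Val {
  explicit Val(DataType t) : dtype(t) {}
  virtual ~Val() = default;
  DataType dtype;
};
struct BoolVal : Val { BoolVal() : Val(DataType::Bool) {} };
struct FloatVal : Val { FloatVal() : Val(DataType::Float) {} };
struct DoubleVal : Val { DoubleVal() : Val(DataType::Double) {} };

struct OptOutDispatch {
  virtual ~OptOutDispatch() = default;
  virtual void unhandled(Val*) {} // default: silently skip
  virtual void handle(BoolVal* v) { unhandled(v); }
  virtual void handle(FloatVal* v) { unhandled(v); }
  virtual void handle(DoubleVal* v) { unhandled(v); }

  // The switch is the only place that downcasts; adding a dtype means
  // adding one case here plus one default handler above.
  void dispatch(Val* v) {
    switch (v->dtype) {
      case DataType::Bool:   handle(static_cast<BoolVal*>(v));   break;
      case DataType::Float:  handle(static_cast<FloatVal*>(v));  break;
      case DataType::Double: handle(static_cast<DoubleVal*>(v)); break;
    }
  }
};

// A visitor interested only in Float overrides a single overload:
struct FloatCounter : OptOutDispatch {
  int count = 0;
  void handle(FloatVal*) override { ++count; }
};
```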
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/dispatch.h
@@ -62,7 +62,10 @@ class TensorDomain;
class TensorView;

class Bool;
class Double;
template <DataType DT>
class FloatingPoint;
using Float = FloatingPoint<DataType::Float>;
using Double = FloatingPoint<DataType::Double>;
class Int;
class ComplexDouble;
class NamedScalar;
@@ -135,6 +138,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const TensorDomain* stmt);
virtual void handle(const TensorView* stmt);
virtual void handle(const Bool* stmt);
virtual void handle(const Float* stmt);
virtual void handle(const Double* stmt);
virtual void handle(const Int* stmt);
virtual void handle(const ComplexDouble* stmt);
@@ -200,6 +204,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {

// Vals
virtual void handle(Bool* stmt);
virtual void handle(Float* stmt);
virtual void handle(Double* stmt);
virtual void handle(Int* stmt);
virtual void handle(ComplexDouble* stmt);
@@ -309,6 +314,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {

// Vals
virtual void mutate(Bool*);
virtual void mutate(Float*);
virtual void mutate(Double*);
virtual void mutate(Int*);
virtual void mutate(ComplexDouble*);
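The forward declarations above reveal the design: `Float` and `Double` are no longer separate classes but aliases of one `FloatingPoint<DataType>` template, so each dtype remains a distinct type for overload resolution while the implementation is shared. A hedged sketch of how such a template might map the enum tag to its C++ scalar type (the real class presumably also stores an optional constant value):

```cpp
#include <type_traits>

enum class DataType { Float, Double };

template <DataType DT>
class FloatingPoint {
 public:
  // Select the C++ scalar type from the dtype tag at compile time.
  using ScalarType =
      std::conditional_t<DT == DataType::Float, float, double>;
  static constexpr DataType kDataType = DT;
};

using Float = FloatingPoint<DataType::Float>;
using Double = FloatingPoint<DataType::Double>;

// Distinct instantiations are distinct types, so handle(const Float*)
// and handle(const Double*) remain separate overloads.
static_assert(!std::is_same_v<Float, Double>);
static_assert(std::is_same_v<Float::ScalarType, float>);
static_assert(std::is_same_v<Double::ScalarType, double>);
```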
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -198,6 +198,9 @@ void Fusion::addInput(Val* input) {
TORCH_CHECK(
!input->isConst(),
"Immediate scalar value cannot be added as an input. It is not necessary to pass it as an input.");
TORCH_CHECK(
!input->isA<Float>(),
"Using Float as an input is not supported as there is no scalar float type in PyTorch.");
}

inputs_.push_back(input);
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -146,6 +146,10 @@ class ConstCheck : private OptOutConstDispatch {
is_const_ = is_const_ && b->isConst();
}

void handle(const Float* f) final {
is_const_ = is_const_ && f->isConst();
}

void handle(const Double* d) final {
is_const_ = is_const_ && d->isConst();
}
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -3,6 +3,9 @@
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>

namespace torch {
namespace jit {
namespace fuser {
@@ -45,6 +48,7 @@ IR_BUILDER_INSTANTIATE(IterDomain)
IR_BUILDER_INSTANTIATE(TensorDomain)
IR_BUILDER_INSTANTIATE(TensorView)
IR_BUILDER_INSTANTIATE(Bool)
IR_BUILDER_INSTANTIATE(Float)
IR_BUILDER_INSTANTIATE(Double)
IR_BUILDER_INSTANTIATE(Int)
IR_BUILDER_INSTANTIATE(ComplexDouble)
@@ -80,6 +84,8 @@ Val* IrBuilder::newResult(DataType dtype) {
switch (dtype) {
case DataType::Bool:
return IrBuilder::create<Bool>(c10::nullopt);
case DataType::Float:
return IrBuilder::create<Float>(c10::nullopt);
case DataType::Double:
return IrBuilder::create<Double>(c10::nullopt);
case DataType::Int:
18 changes: 2 additions & 16 deletions torch/csrc/jit/codegen/cuda/ir_builder.h
@@ -1,8 +1,7 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder_passkey.h>

namespace torch {
namespace jit {
@@ -14,20 +13,7 @@ class Kernel;
}

class IrCloner;

// Passkey for builder to register properties with statements, and to call
// functions in IrContainer
class TORCH_CUDA_CU_API IrBuilderPasskey {
friend class IrBuilder;

public:
// TODO: Collapse ir_container and Kernel once Kernel inherits from
// IrContainer
IrContainer* const ir_container_ = nullptr;

private:
explicit IrBuilderPasskey(IrContainer* ir_container);
};
class IrContainer;

//! IR builder interface
class TORCH_CUDA_CU_API IrBuilder {
30 changes: 30 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder_passkey.h
@@ -0,0 +1,30 @@
#pragma once

#include <c10/macros/Export.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class IrContainer;

// Passkey for builder to register properties with statements, and to call
// functions in IrContainer
class TORCH_CUDA_CU_API IrBuilderPasskey {
friend class IrBuilder;

public:
// TODO: Collapse ir_container and Kernel once Kernel inherits from
// IrContainer
IrContainer* const ir_container_ = nullptr;

private:
explicit IrBuilderPasskey(IrContainer* ir_container)
: ir_container_(ir_container) {}
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
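The new header isolates the passkey idiom: `IrBuilderPasskey` has a private constructor and names `IrBuilder` as a friend, so any method that takes a passkey parameter is effectively callable only through the builder, without making the builder a friend of every IR class. A minimal standalone illustration of the idiom:

```cpp
class Builder; // the only class allowed to mint keys

class Passkey {
  friend class Builder;
  Passkey() = default; // private: nobody else can construct one
};

class Widget {
 public:
  // Publicly visible, but callable only by holders of a Passkey.
  void registerWith(Passkey, int id) { id_ = id; }

 private:
  int id_ = 0;
};

class Builder {
 public:
  static void build(Widget& w, int id) {
    w.registerWith(Passkey{}, id); // Builder can create the key
  }
};
```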
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -68,6 +68,10 @@ void IrCloner::handle(const Bool* b) {
clone_ = IrBuilder::clone(b, this);
}

void IrCloner::handle(const Float* f) {
clone_ = IrBuilder::clone(f, this);
}

void IrCloner::handle(const Double* d) {
clone_ = IrBuilder::clone(d, this);
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -63,6 +63,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
void handle(const IterDomain*) override;

void handle(const Bool*) override;
void handle(const Float*) override;
void handle(const Double*) override;
void handle(const Int*) override;
void handle(const ComplexDouble*) override;
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -44,6 +44,17 @@ class IrNodeLabel : private OptInConstDispatch {
}
}

void handle(const Float* f) override {
if (f->isSymbolic()) {
label_ << "f" << f->name();
} else {
if (detail_level_ >= DetailLevel::Explicit) {
label_ << "f" << f->name() << "=";
}
label_ << *f->value();
}
}

void handle(const Double* d) override {
if (d->isSymbolic()) {
label_ << "d" << d->name();
@@ -363,6 +374,10 @@ void IrGraphGenerator::handle(const Bool* b) {
printValue(b, IrNodeLabel::gen(b, detail_level_));
}

void IrGraphGenerator::handle(const Float* f) {
printValue(f, IrNodeLabel::gen(f, detail_level_));
}

void IrGraphGenerator::handle(const Double* d) {
printValue(d, IrNodeLabel::gen(d, detail_level_));
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -77,6 +77,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
void handle(const IterDomain*) override;

void handle(const Bool*) override;
void handle(const Float*) override;
void handle(const Double*) override;
void handle(const Int*) override;
void handle(const ComplexDouble*) override;