From 57c5fea26827eaef15e7a6d53e6e9d3803dcf528 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Thu, 17 Nov 2022 11:08:03 -0800
Subject: [PATCH 1/2] Add Float IR node class

Represents a 32-bit floating-point scalar value. This type is not
supported in PyTorch, so it cannot be used as an input to fusions.
---
 torch/csrc/jit/codegen/cuda/arith.cpp         |  3 +-
 torch/csrc/jit/codegen/cuda/codegen.cpp       | 47 +++++++++++-
 torch/csrc/jit/codegen/cuda/dispatch.cpp      | 15 ++++
 torch/csrc/jit/codegen/cuda/dispatch.h        |  8 ++-
 torch/csrc/jit/codegen/cuda/fusion.cpp        |  3 +
 torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp |  4 ++
 torch/csrc/jit/codegen/cuda/ir_builder.cpp    |  6 ++
 torch/csrc/jit/codegen/cuda/ir_builder.h      | 18 +----
 .../jit/codegen/cuda/ir_builder_passkey.h     | 30 ++++++++
 torch/csrc/jit/codegen/cuda/ir_cloner.cpp     |  4 ++
 torch/csrc/jit/codegen/cuda/ir_cloner.h       |  1 +
 torch/csrc/jit/codegen/cuda/ir_graphviz.cpp   | 15 ++++
 torch/csrc/jit/codegen/cuda/ir_graphviz.h     |  1 +
 .../jit/codegen/cuda/ir_interface_nodes.h     | 41 ++++++++---
 torch/csrc/jit/codegen/cuda/ir_iostream.cpp   | 18 +++++
 torch/csrc/jit/codegen/cuda/ir_iostream.h     |  1 +
 torch/csrc/jit/codegen/cuda/ir_nodes.cpp      | 30 ++------
 torch/csrc/jit/codegen/cuda/kernel.cpp        |  3 -
 torch/csrc/jit/codegen/cuda/kernel_ir.h       | 23 ------
 torch/csrc/jit/codegen/cuda/mutator.cpp       |  2 +
 .../jit/codegen/cuda/runtime/grid_sync.cu     |  4 +-
 .../csrc/jit/codegen/cuda/test/test_gpu3.cpp  | 71 +++++++++++++++++++
 torch/csrc/jit/codegen/cuda/type.h            | 15 ++++
 .../csrc/jit/codegen/cuda/type_promotion.cpp  |  2 +-
 24 files changed, 279 insertions(+), 86 deletions(-)
 create mode 100644 torch/csrc/jit/codegen/cuda/ir_builder_passkey.h

diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 05e3fdb098f9..3c79d639e1fb 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -88,10 +88,11 @@ Val* newScalar(ValType vtype, DataType dtype) {
       switch (dtype) {
         case DataType::Bool:
           return IrBuilder::create<Bool>();
-        case DataType::Double:
         case DataType::Float:
         case DataType::Half:
         case DataType::BFloat16:
+          return IrBuilder::create<Float>();
+        case DataType::Double:
           return IrBuilder::create<Double>();
         case DataType::Int32:
         case DataType::Int:
diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 188167c80b13..6fa5827e573e 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -170,9 +170,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }

   void initStringStreamFormat(std::stringstream& ss) {
-    const int digits = std::numeric_limits<double>::max_digits10;
     ss.imbue(std::locale("C"));
-    ss << std::scientific << std::setprecision(digits);
+    ss << std::scientific;
+    setPrecision(ss);
+  }
+
+  // By default use double precision format
+  template <typename FloatingPointScalarType = Double>
+  void setPrecision(std::stringstream& ss) {
+    const int digits = std::numeric_limits<
+        typename FloatingPointScalarType::ScalarType>::max_digits10;
+    ss << std::setprecision(digits);
   }

   // Generates the kernel function declaration
@@ -405,6 +413,32 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }

+  void handle(const Float* d) final {
+    const auto def = d->definition();
+    const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
+    if (def != nullptr && !has_alloc) {
+      code_ << "(" << gen(def) << ")";
+    } else if (d->isConst()) {
+      auto val = *d->value();
+      // note: default inf/nan doesn't work and should be replaced with macros
+      // `NAN`, `POS_INFINITY` and `NEG_INFINITY` instead.
+      if (std::isinf(val)) {
+        if (val > 0) {
+          code_ << "POS_INFINITY";
+        } else {
+          code_ << "NEG_INFINITY";
+        }
+      } else if (std::isnan(val)) {
+        code_ << "NAN";
+      } else {
+        setPrecision<Float>(code_);
+        code_ << val << "f";
+      }
+    } else {
+      code_ << varName(d);
+    }
+  }
+
   void handle(const Double* d) final {
     const auto def = d->definition();
     const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
@@ -423,6 +457,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
       } else if (std::isnan(val)) {
         code_ << "NAN";
       } else {
+        setPrecision(code_);
         code_ << val;
       }
     } else {
@@ -902,6 +937,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
         exponent = int_exp;
       }
     }
+  } else if (auto val_float = dynamic_cast<const Float*>(rhs)) {
+    if (val_float->isConst()) {
+      auto fp_exp = val_float->value().value();
+      float int_exp = 0;
+      if (std::modf(fp_exp, &int_exp) == 0) {
+        exponent = int_exp;
+      }
+    }
   }

   if (!exponent.has_value()) {
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index 8c2172f3f383..251d688c5c8f 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -48,6 +48,9 @@ void Val::dispatch(T handler, Val* val) {
         case DataType::Bool:
           ptr(handler)->handle(val->as<Bool>());
           return;
+        case DataType::Float:
+          ptr(handler)->handle(val->as<Float>());
+          return;
         case DataType::Double:
           ptr(handler)->handle(val->as<Double>());
           return;
@@ -272,6 +275,9 @@ void Val::constDispatch(T handler, const Val* val) {
         case DataType::Bool:
           ptr(handler)->handle(val->as<Bool>());
           return;
+        case DataType::Float:
+          ptr(handler)->handle(val->as<Float>());
+          return;
         case DataType::Double:
           ptr(handler)->handle(val->as<Double>());
           return;
@@ -507,6 +513,9 @@ void Val::mutatorDispatch(T mutator, Val* val) {
         case DataType::Bool:
           ptr(mutator)->mutate(val->as<Bool>());
           return;
+        case DataType::Float:
+          ptr(mutator)->mutate(val->as<Float>());
+          return;
         case DataType::Double:
           ptr(mutator)->mutate(val->as<Double>());
           return;
@@ -821,6 +830,9 @@ void OptInDispatch::unhandled(Statement* stmt) {
 void OptOutConstDispatch::handle(const Bool* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const Float* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const Double* stmt) {
   unhandled(stmt);
 }
@@ -980,6 +992,9 @@ void OptOutDispatch::unhandled(Statement*) {}
 void OptOutDispatch::handle(Bool* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(Float* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(Double* stmt) {
   unhandled(stmt);
 }
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 49295c16532c..5d9d363554d4 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -62,7 +62,10 @@
 class TensorDomain;
 class TensorView;
 class Bool;
-class Double;
+template <DataType DT>
+class FloatingPoint;
+using Float = FloatingPoint<DataType::Float>;
+using Double = FloatingPoint<DataType::Double>;
 class Int;
 class ComplexDouble;
 class NamedScalar;
@@ -135,6 +138,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const TensorDomain* stmt);
   virtual void handle(const TensorView* stmt);
   virtual void handle(const Bool* stmt);
+  virtual void handle(const Float* stmt);
   virtual void handle(const Double* stmt);
   virtual void handle(const Int* stmt);
   virtual void handle(const ComplexDouble* stmt);
@@ -200,6 +204,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   // Vals
   virtual void handle(Bool* stmt);
+  virtual void handle(Float* stmt);
   virtual void handle(Double* stmt);
   virtual void handle(Int* stmt);
   virtual void handle(ComplexDouble* stmt);
@@ -309,6 +314,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   // Vals
   virtual void mutate(Bool*);
+  virtual void mutate(Float*);
   virtual void mutate(Double*);
   virtual void mutate(Int*);
   virtual void mutate(ComplexDouble*);
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 61d8541a9a29..25e6a9e12d23 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -198,6 +198,9 @@ void Fusion::addInput(Val* input) {
     TORCH_CHECK(
         !input->isConst(),
         "Immediate scalar value cannot be added as an input. It is not necessary to pass it as an input.");
+    TORCH_CHECK(
+        !input->isA<Float>(),
+        "Using Float as an input is not supported as there is no scalar float type in PyTorch.");
   }

   inputs_.push_back(input);
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index 48c6e0959b2f..e260258f2ba0 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -146,6 +146,10 @@ class ConstCheck : private OptOutConstDispatch {
     is_const_ = is_const_ && b->isConst();
   }

+  void handle(const Float* d) final {
+    is_const_ = is_const_ && d->isConst();
+  }
+
   void handle(const Double* d) final {
     is_const_ = is_const_ && d->isConst();
   }
diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp
index 7b58a7d444f7..fe60fa94a2d8 100644
--- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -3,6 +3,9 @@
 #include
 #include
+
+#include
+#include

 namespace torch {
 namespace jit {
 namespace fuser {
@@ -45,6 +48,7 @@
 IR_BUILDER_INSTANTIATE(IterDomain)
 IR_BUILDER_INSTANTIATE(TensorDomain)
 IR_BUILDER_INSTANTIATE(TensorView)
 IR_BUILDER_INSTANTIATE(Bool)
+IR_BUILDER_INSTANTIATE(Float)
 IR_BUILDER_INSTANTIATE(Double)
 IR_BUILDER_INSTANTIATE(Int)
 IR_BUILDER_INSTANTIATE(ComplexDouble)
@@ -80,6 +84,8 @@ Val* IrBuilder::newResult(DataType dtype) {
   switch (dtype) {
     case DataType::Bool:
       return IrBuilder::create<Bool>(c10::nullopt);
+    case DataType::Float:
+      return IrBuilder::create<Float>(c10::nullopt);
     case DataType::Double:
       return IrBuilder::create<Double>(c10::nullopt);
     case DataType::Int:
diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h
index f122232f8fb8..77693fb5a26b 100644
--- a/torch/csrc/jit/codegen/cuda/ir_builder.h
+++ b/torch/csrc/jit/codegen/cuda/ir_builder.h
@@ -1,8 +1,7 @@
 #pragma once

-#include
 #include
-#include
+#include <torch/csrc/jit/codegen/cuda/ir_builder_passkey.h>

 namespace torch {
 namespace jit {
@@ -14,20 +13,7 @@
 class Kernel;
 }

 class IrCloner;
-
-// Passkey for builder to register properties with statements, and to call
-// functions in IrContainer
-class TORCH_CUDA_CU_API IrBuilderPasskey {
-  friend class IrBuilder;
-
- public:
-  // TODO: Collapse ir_container and Kernel once Kernel inherits from
-  // IrContainer
-  IrContainer* const ir_container_ = nullptr;
-
- private:
-  explicit IrBuilderPasskey(IrContainer* ir_container);
-};
+class IrContainer;

 //! IR builder interface
 class TORCH_CUDA_CU_API IrBuilder {
diff --git a/torch/csrc/jit/codegen/cuda/ir_builder_passkey.h b/torch/csrc/jit/codegen/cuda/ir_builder_passkey.h
new file mode 100644
index 000000000000..1a8654807d12
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/ir_builder_passkey.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+class IrContainer;
+
+// Passkey for builder to register properties with statements, and to call
+// functions in IrContainer
+class TORCH_CUDA_CU_API IrBuilderPasskey {
+  friend class IrBuilder;
+
+ public:
+  // TODO: Collapse ir_container and Kernel once Kernel inherits from
+  // IrContainer
+  IrContainer* const ir_container_ = nullptr;
+
+ private:
+  explicit IrBuilderPasskey(IrContainer* ir_container)
+      : ir_container_(ir_container) {}
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
index 1a538f88997d..a528c21f4c97 100644
--- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -68,6 +68,10 @@ void IrCloner::handle(const Bool* b) {
   clone_ = IrBuilder::clone(b, this);
 }

+void IrCloner::handle(const Float* d) {
+  clone_ = IrBuilder::clone(d, this);
+}
+
 void IrCloner::handle(const Double* d) {
   clone_ = IrBuilder::clone(d, this);
 }
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h
index 9e54b074acc7..790e44f9a108 100644
--- a/torch/csrc/jit/codegen/cuda/ir_cloner.h
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -63,6 +63,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const IterDomain*) override;

   void handle(const Bool*) override;
+  void handle(const Float*) override;
   void handle(const Double*) override;
   void handle(const Int*) override;
   void handle(const ComplexDouble*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
index f98f1f8cc788..ed4f2a4d81a3 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -44,6 +44,17 @@ class IrNodeLabel : private OptInConstDispatch {
     }
   }

+  void handle(const Float* f) override {
+    if (f->isSymbolic()) {
+      label_ << "f" << f->name();
+    } else {
+      if (detail_level_ >= DetailLevel::Explicit) {
+        label_ << "f" << f->name() << "=";
+      }
+      label_ << *f->value();
+    }
+  }
+
   void handle(const Double* d) override {
     if (d->isSymbolic()) {
       label_ << "d" << d->name();
@@ -363,6 +374,10 @@ void IrGraphGenerator::handle(const Bool* b) {
   printValue(b, IrNodeLabel::gen(b, detail_level_));
 }

+void IrGraphGenerator::handle(const Float* f) {
+  printValue(f, IrNodeLabel::gen(f, detail_level_));
+}
+
 void IrGraphGenerator::handle(const Double* d) {
   printValue(d, IrNodeLabel::gen(d, detail_level_));
 }
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
index f5e624d06b90..c32ca81cef94 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -77,6 +77,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
   void handle(const IterDomain*) override;

   void handle(const Bool*) override;
+  void handle(const Float*) override;
   void handle(const Double*) override;
   void handle(const Int*) override;
   void handle(const ComplexDouble*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index b6b32e4d7a9f..99614e818b0a 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -4,6 +4,7 @@
 #include
 #include
+#include <torch/csrc/jit/codegen/cuda/ir_builder_passkey.h>
 #include
 #include
@@ -21,7 +22,6 @@
 class WelfordResult;
 class ViewTransform;

 class IrCloner;
-class IrBuilderPasskey;

 //! A Bool value
 //!
@@ -54,20 +54,27 @@ class TORCH_CUDA_CU_API Bool : public Val {
   const c10::optional<bool> maybe_value_;
 };

-//! A Float64 value. This value can be a symbolic value (defined after the
-//! kernel is compiled) or a constant value (inlined into the kernel
+//! A floating-point value. This value can be a symbolic value (defined after
+//! the kernel is compiled) or a constant value (inlined into the kernel
 //! definition).
-class TORCH_CUDA_CU_API Double : public Val {
+template <DataType DT>
+class TORCH_CUDA_CU_API FloatingPoint : public Val {
  public:
-  using ScalarType = double;
+  using ScalarType = typename DataTypeToNativeType<DT>::type;

-  Double(IrBuilderPasskey passkey);
+  FloatingPoint(IrBuilderPasskey passkey)
+      : Val(passkey, ValType::Scalar, DT), maybe_value_{c10::nullopt} {}

-  explicit Double(IrBuilderPasskey passkey, ScalarType value);
+  explicit FloatingPoint(IrBuilderPasskey passkey, ScalarType value)
+      : Val(passkey, ValType::Scalar, DT), maybe_value_{value} {}

-  explicit Double(IrBuilderPasskey passkey, c10::optional<double> value);
+  explicit FloatingPoint(
+      IrBuilderPasskey passkey,
+      c10::optional<ScalarType> value)
+      : Val(passkey, ValType::Scalar, DT), maybe_value_{value} {}

-  Double(const Double* src, IrCloner* ir_cloner);
+  FloatingPoint(const FloatingPoint* src, IrCloner* ir_cloner)
+      : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

   bool isSymbolic() const {
     return !(maybe_value_.has_value());
@@ -79,12 +86,26 @@ class TORCH_CUDA_CU_API Double : public Val {
     return maybe_value_;
   }

-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Statement* other) const override {
+    if (this == other) {
+      return true;
+    }
+    if (!other->isA<Double>()) {
+      return false;
+    }
+    const auto other_val = other->as<FloatingPoint>();
+    if (isConst() && other_val->isConst())
+      return *value() == *(other_val->value());
+    return false;
+  }

  private:
   const c10::optional<ScalarType> maybe_value_;
 };

+using Float = FloatingPoint<DataType::Float>;
+using Double = FloatingPoint<DataType::Double>;
+
 //! An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
 //! inlined literal in the kernel.
 class TORCH_CUDA_CU_API Int : public Val {
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index d17fbb86cefc..17339324c175 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -214,6 +214,24 @@ void IrPrinter::handle(const Bool* b) {
   }
 }

+void IrPrinter::handle(const Float* f) {
+  if (print_inline_ && f->definition() != nullptr) {
+    os_ << "( ";
+    handle(f->definition());
+    os_ << " )";
+    return;
+  }
+
+  if (f->isSymbolic()) {
+    os_ << "f" << varName(f);
+  } else {
+    os_ << "float("
+        << std::setprecision(std::numeric_limits<float>::max_digits10)
+        << *(f->value()) << ")";
+  }
+}
+
 void IrPrinter::handle(const Double* d) {
   if (print_inline_ && d->definition() != nullptr) {
     os_ << "( ";
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h
index dc4d8abd55b6..8f3f6035c33e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.h
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -77,6 +77,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
   void handle(const TensorView*) final;

   void handle(const Bool*) final;
+  void handle(const Float*) final;
   void handle(const Double*) final;
   void handle(const Int*) final;
   void handle(const ComplexDouble*) final;
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 12c634d3e0af..e2a809b297c9 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -44,6 +44,10 @@ class ScalarCheck : OptInConstDispatch {
     same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>());
   }

+  void handle(const Float* d) final {
+    same_ = v1_->as<Float>()->sameAs(v2_->as<Float>());
+  }
+
   void handle(const Double* d) final {
     same_ = v1_->as<Double>()->sameAs(v2_->as<Double>());
   }
@@ -99,32 +103,6 @@ bool Bool::sameAs(const Statement* other) const {
   return false;
 }

-Double::Double(IrBuilderPasskey passkey)
-    : Val(passkey, ValType::Scalar, DataType::Double),
-      maybe_value_{c10::nullopt} {}
-
-Double::Double(IrBuilderPasskey passkey, ScalarType value)
-    : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}
-
-Double::Double(IrBuilderPasskey passkey, c10::optional<double> value)
-    : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}
-
-Double::Double(const Double* src, IrCloner* ir_cloner)
-    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
-
-bool Double::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Double>()) {
-    return false;
-  }
-  const auto other_double = other->as<Double>();
-  if (isConst() && other_double->isConst())
-    return *value() == *(other_double->value());
-  return false;
-}
-
 Int::Int(IrBuilderPasskey passkey)
     : Val(passkey, ValType::Scalar, DataType::Int),
       maybe_value_{c10::nullopt} {}
diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp
index 858e2d59ba58..b25e4f7824a9 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -15,9 +15,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {

-IrBuilderPasskey::IrBuilderPasskey(IrContainer* ir_container)
-    : ir_container_(ir_container) {}
-
 namespace kir {

 namespace {
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h
index 572900445ed3..24f2869fcf05 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -21,29 +21,6 @@ namespace cuda {

 class IrBuilderPasskey;

-// Abstract nodes
-class Val;
-class Expr;
-
-// Values
-class Bool;
-class Double;
-class Int;
-class NamedScalar;
-
-class IterDomain;
-class TensorDomain;
-class TensorView;
-
-// Expressions
-class UnaryOp;
-class BinaryOp;
-class TernaryOp;
-class RNGOp;
-class ReductionOp;
-class WelfordOp;
-class BroadcastOp;
-
 namespace kir {

 class Kernel;
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index 3735e74080ee..9fb69c7ae647 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -47,6 +47,8 @@ void OptOutMutator::registerMutation(Val* val, Val* mutation) {

 void OptOutMutator::mutate(Bool* b) {}

+void OptOutMutator::mutate(Float* d) {}
+
 void OptOutMutator::mutate(Double* d) {}

 void OptOutMutator::mutate(Int* i) {}
diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
index bec24b486b46..1a6d7437d925 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
@@ -108,9 +108,7 @@ __device__ void sync(
       index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
   if (last_block) {
     int64_t finished_val =
-        ((int64_t)(
-            index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim) -
-            1)) *
+        ((int64_t)(index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim) - 1)) *
         ((int64_t)n_entrances);

     unsigned int ns = 8;
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp
index 5023affb773c..b799b5d02a5f 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp
@@ -6870,6 +6870,77 @@ TEST_F(NVFuserTest, FusionIssue2163ReproInvalidAlias_CUDA) {
       fe.kernel(), {cg_output}, aten_inputs, {ref_y}, __LINE__, __FILE__, "");
 }

+// Testing scalar FP types
+TEST_F(NVFuserTest, FusionFloatingPointType_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const float float_val = 0.1f;
+  const double double_val = 0.2;
+
+  {
+    auto tv0 = makeConcreteTensor({2}, DataType::Float);
+    fusion.addInput(tv0);
+
+    auto f2 = IrBuilder::create<Float>(float_val);
+    auto d3 = IrBuilder::create<Double>(double_val);
+
+    // Adding two Floats produces a Float
+    auto f4 = add(f2, f2);
+    TORCH_CHECK(f4->isA<Float>(), "Invalid result: ", f4->toString());
+
+    // Adding a Float and a Double produces a Double
+    auto d5 = add(f2, d3);
+    TORCH_CHECK(d5->isA<Double>(), "Invalid result: ", d5->toString());
+
+    // Adding a Double and a Float produces a Double
+    auto d6 = add(d3, f2);
+    TORCH_CHECK(d6->isA<Double>(), "Invalid result: ", d6->toString());
+
+    // Adding two Doubles produces a Double
+    auto d7 = add(d5, d6);
+    TORCH_CHECK(d7->isA<Double>(), "Invalid result: ", d7->toString());
+
+    // Adding a Float to a Float tensor produces a Float tensor
+    auto tv1 = add(tv0, f4);
+    TORCH_CHECK(
+        tv1->getDataType() == DataType::Float,
+        tv1->toString(),
+        " has an invalid data type: ",
+        tv1->getDataType().value());
+
+    // Adding a Double to a Float tensor still produces a Float tensor
+    auto tv2 = add(tv1, d7);
+    TORCH_CHECK(
+        tv2->getDataType() == DataType::Float,
+        tv2->toString(),
+        " has an invalid data type: ",
+        tv2->getDataType().value());
+
+    fusion.addOutput(tv2);
+  }
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({2}, options);
+
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, inputs);
+  auto cg_outputs = fe.runFusion(inputs);
+
+  auto f2 = float_val;
+  auto d3 = double_val;
+  auto f4 = f2 + f2;
+  auto d5 = f2 + d3;
+  auto d6 = d3 + f2;
+  auto d7 = d5 + d6;
+  auto t1 = t0 + f4;
+  auto t2 = t1 + d7;
+
+  testValidate(&fusion, cg_outputs, inputs, {t2}, __LINE__, __FILE__);
+}
+
 // Test file size should be up to 10K LoC. Create a new file for more tests.
 } // namespace jit
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 7ec0b8ef9fd9..f07df24773a9 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -103,6 +103,21 @@ DataType getTypeFromComplexType(DataType dtype);
 // Return if the datatype is supported on the current device
 TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype);

+template <DataType DT>
+struct DataTypeToNativeType;
+
+#define DEFINE_DATATYPE_TO_NATIVE_TYPE(data_type, native_type) \
+  template <>                                                  \
+  struct DataTypeToNativeType<data_type> {                     \
+    using type = native_type;                                  \
+  };
+
+// TODO: Add more type specializations
+DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::Float, float);
+DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::Double, double);
+
+#undef DEFINE_DATATYPE_TO_NATIVE_TYPE
+
 enum class UnaryOpType {
   Abs,
   Acos,
diff --git a/torch/csrc/jit/codegen/cuda/type_promotion.cpp b/torch/csrc/jit/codegen/cuda/type_promotion.cpp
index bfc3f7451a38..b6f4f5245f27 100644
--- a/torch/csrc/jit/codegen/cuda/type_promotion.cpp
+++ b/torch/csrc/jit/codegen/cuda/type_promotion.cpp
@@ -54,7 +54,7 @@ at::native::ResultTypeState updateResultTypeState(
     const at::native::ResultTypeState& in_state) {
   at::native::ResultTypeState new_state = in_state;
   c10::ScalarType current = scalar;
-  if (c10::isFloatingType(scalar)) {
+  if (scalar == c10::ScalarType::Half || scalar == c10::ScalarType::BFloat16) {
     current = c10::typeMetaToScalarType(at::get_default_dtype());
   }
   new_state.wrappedResult =

From dab6d54bc2283761baed81ae4f3f765859d3a00d Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Thu, 17 Nov 2022 17:37:00 -0800
Subject: [PATCH 2/2] rename

---
 torch/csrc/jit/codegen/cuda/codegen.cpp          | 12 ++++++------
 torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp    |  4 ++--
 torch/csrc/jit/codegen/cuda/ir_cloner.cpp        |  4 ++--
 torch/csrc/jit/codegen/cuda/ir_interface_nodes.h |  2 +-
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 6fa5827e573e..5ef5b67bd9eb 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -413,13 +413,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }

-  void handle(const Float* d) final {
-    const auto def = d->definition();
-    const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
+  void handle(const Float* f) final {
+    const auto def = f->definition();
+    const bool has_alloc = alloc_map_.find(f) != alloc_map_.end();
     if (def != nullptr && !has_alloc) {
       code_ << "(" << gen(def) << ")";
-    } else if (d->isConst()) {
-      auto val = *d->value();
+    } else if (f->isConst()) {
+      auto val = *f->value();
       // note: default inf/nan doesn't work and should be replaced with macros
       // `NAN`, `POS_INFINITY` and `NEG_INFINITY` instead.
       if (std::isinf(val)) {
@@ -435,7 +435,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
         code_ << val << "f";
       }
     } else {
-      code_ << varName(d);
+      code_ << varName(f);
     }
   }
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index e260258f2ba0..42bc9309e5d1 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -146,8 +146,8 @@ class ConstCheck : private OptOutConstDispatch {
     is_const_ = is_const_ && b->isConst();
   }

-  void handle(const Float* d) final {
-    is_const_ = is_const_ && d->isConst();
+  void handle(const Float* f) final {
+    is_const_ = is_const_ && f->isConst();
   }

   void handle(const Double* d) final {
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
index a528c21f4c97..93caef14b5cb 100644
--- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -68,8 +68,8 @@ void IrCloner::handle(const Bool* b) {
   clone_ = IrBuilder::clone(b, this);
 }

-void IrCloner::handle(const Float* d) {
-  clone_ = IrBuilder::clone(d, this);
+void IrCloner::handle(const Float* f) {
+  clone_ = IrBuilder::clone(f, this);
 }

 void IrCloner::handle(const Double* d) {
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 99614e818b0a..12138ac6fb07 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -90,7 +90,7 @@ class TORCH_CUDA_CU_API FloatingPoint : public Val {
     if (this == other) {
       return true;
     }
-    if (!other->isA<Double>()) {
+    if (!other->isA<FloatingPoint>()) {
       return false;
     }
     const auto other_val = other->as<FloatingPoint>();
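
Usage note: the sketch below distills how the new Float node is meant to
behave, based on the FusionFloatingPointType_CUDA test added in PATCH 1/2.
It assumes the usual nvfuser test helpers (makeConcreteTensor, add) are in
scope and is an illustration of the intended semantics, not additional
test code.

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Float is a 32-bit scalar IR node. It cannot be a fusion input, since
  // PyTorch has no scalar float type (see the check in Fusion::addInput).
  Float* f = IrBuilder::create<Float>(0.1f);
  Double* d = IrBuilder::create<Double>(0.2);

  // Scalar promotion: Float + Float -> Float, Float + Double -> Double.
  auto ff = add(f, f); // ff->isA<Float>()
  auto fd = add(f, d); // fd->isA<Double>()

  // A Float constant added to a Float tensor keeps the tensor Float, and
  // codegen emits the constant as a single-precision literal (e.g. "0.1f",
  // via setPrecision<Float> and the "f" suffix in handle(const Float*)).
  auto tv0 = makeConcreteTensor({2}, DataType::Float);
  fusion.addInput(tv0);
  fusion.addOutput(add(tv0, ff));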