Skip to content

Commit

Permalink
Revert in-progress changes to switch to a new Kernel IR hierarchy
Browse files Browse the repository at this point in the history
This reverts commit fc09c1b5a7240701da093406753908eba6f41e1d.
  • Loading branch information
tlemo committed Jul 24, 2020
1 parent 58d9f1a commit 72aec1d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 77 deletions.
28 changes: 26 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <cstdint>
#include <deque>
Expand Down Expand Up @@ -59,7 +58,7 @@ class IrCloner;
* Basically beinng able to succienctly traverse down the inhereitance stack of
* a Statment at runtime. This is currently implemented in dispatch.h
*/
class TORCH_CUDA_API Statement : public NonCopyable, public PolymorphicBase {
class TORCH_CUDA_API Statement {
friend void swap(Fusion&, Fusion&) noexcept;

public:
Expand All @@ -68,6 +67,8 @@ class TORCH_CUDA_API Statement : public NonCopyable, public PolymorphicBase {
// Cloning constructor
Statement(const Statement* src, IrCloner* ir_cloner);

virtual ~Statement() = default;

// Dispatch functions, definitions in dispatch.cpp
template <typename T>
static void dispatch(T handler, Statement*);
Expand Down Expand Up @@ -103,6 +104,29 @@ class TORCH_CUDA_API Statement : public NonCopyable, public PolymorphicBase {
// Make sure this is an Expr and return it as an Expr*
Expr* asExpr();

// Replacement for static_cast<T*>(ptr): ptr->as<T>()
template <class T>
T* as() {
#ifdef NDEBUG
auto downcast_ptr = static_cast<T*>(this);
#else
auto downcast_ptr = dynamic_cast<T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
return downcast_ptr;
}

template <class T>
const T* as() const {
#ifdef NDEBUG
auto downcast_ptr = static_cast<const T*>(this);
#else
auto downcast_ptr = dynamic_cast<const T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
return downcast_ptr;
}

// Return the fusion this statement belongs to
Fusion* fusion() const {
return fusion_;
Expand Down
32 changes: 0 additions & 32 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,6 @@ namespace jit {
namespace fuser {
namespace kir {

#if 0 // TODO: switch to the new class hierarchy

// Base class for Kernel IR nodes
class TORCH_CUDA_API Node : public NonCopyable, public PolymorphicBase {};

// A generic value (scalar or tensor)
class TORCH_CUDA_API Val : public Node {
public:
explicit Val(ValType vtype, DataType dtype = DataType::Null)
: vtype_(vtype), dtype_(dtype) {}

private:
const ValType vtype_;
const DataType dtype_;
};

// A computation, with inputs and outputs
//
// TODO: rename to Statement/Operation?
//
class TORCH_CUDA_API Expr : public Node {
public:
explicit Expr(ExprType type) : type_(type) {}

private:
ExprType type_ = ExprType::Invalid;
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;
};

#endif

#if 0

class TORCH_CUDA_API NamedScalar : public Val {
Expand Down
43 changes: 0 additions & 43 deletions torch/csrc/jit/codegen/cuda/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,49 +98,6 @@ TORCH_CUDA_API bool haveSameShape(
const std::shared_ptr<c10::TensorType>& lhs,
const std::shared_ptr<c10::TensorType>& rhs);


// Simple mixin for suppressing copy & move operations
class NonCopyable {
public:
NonCopyable() = default;

// No copy/move semantics
NonCopyable(const NonCopyable&) = delete;
NonCopyable& operator=(const NonCopyable&) = delete;
};

// A generic root for a hierarchy of polymorphic classes:
// - It ensures virtual destructors
// - Provides the base->as<Derived>() notation
class PolymorphicBase {
public:
virtual ~PolymorphicBase() = default;

// Replacement for static_cast<T*>(ptr): ptr->as<T>()
// (checked in DEBUG builds)
template <class T>
T* as() {
#ifdef NDEBUG
auto downcast_ptr = static_cast<T*>(this);
#else
auto downcast_ptr = dynamic_cast<T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
return downcast_ptr;
}

template <class T>
const T* as() const {
#ifdef NDEBUG
auto downcast_ptr = static_cast<const T*>(this);
#else
auto downcast_ptr = dynamic_cast<const T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
return downcast_ptr;
}
};

} // namespace fuser
} // namespace jit
} // namespace torch

0 comments on commit 72aec1d

Please sign in to comment.