Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite ExpressionEvaluator to use IterVisitor #58

Merged
merged 2 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static TensorView* makeDummyTensor(

static void checkIntValue(
const EvaluationContext* eval_context,
const Val* val,
Val* val,
Int::ScalarType expected_value) {
TORCH_CHECK(val->isAnInt());
const auto actual_value = ExpressionEvaluator::evaluate(val, eval_context);
Expand Down Expand Up @@ -148,11 +148,22 @@ void testGPU_FusionExprEvalBindings() {
auto* a = new Int();
auto* b = new Int();
auto* c = add(a, b);
auto* d = neg(ceilDiv(add(a, b), b));
auto* d = neg(ceilDiv(c, b));
auto* e = new Int(0);

// trying to evaluate before binding should give empty results
TORCH_CHECK(!ExpressionEvaluator::evaluate(a, &eval_context).has_value());
TORCH_CHECK(!ExpressionEvaluator::evaluate(d, &eval_context).has_value());

eval_context.bind(a, 7);
eval_context.bind(b, 3);

// can't bind to the results of expressions
ASSERT_ANY_THROW(eval_context.bind(c, 100));

// can't bind to concrete values
ASSERT_ANY_THROW(eval_context.bind(e, 100));

checkIntValue(&eval_context, c, 10);
checkIntValue(&eval_context, sub(a, b), 4);
checkIntValue(&eval_context, mod(a, b), 1);
Expand Down
57 changes: 31 additions & 26 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,76 +38,81 @@ void EvaluationContext::print() const {
}

c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(
const Statement* expr,
Val* val,
const EvaluationContext* context) {
TORCH_CHECK(context != nullptr);
ExpressionEvaluator evaluator(context);
evaluator.OptInConstDispatch::handle(expr);
return evaluator.result_;
evaluator.traverseFrom(context->fusion(), {val}, false);
return evaluator.value(val);
}

void ExpressionEvaluator::handle(const Int* i) {
c10::optional<Int::ScalarType> ExpressionEvaluator::value(
const Statement* stmt) const {
const auto it = values_.find(stmt);
return (it != values_.end()) ? c10::optional<Int::ScalarType>(it->second)
: c10::nullopt;
}

void ExpressionEvaluator::handle(Int* i) {
if (i->value().has_value()) {
result_ = i->value();
values_[i] = *i->value();
} else if (const auto* def = context_->fusion()->origin(i)) {
result_ = evaluate(def, context_);
const auto& def_result = value(def);
if (def_result.has_value()) {
values_[i] = *def_result;
}
} else {
const auto& bound_value = context_->concreteValue(i);
if (bound_value.has_value()) {
result_ = bound_value;
values_[i] = *bound_value;
}
}
}

void ExpressionEvaluator::handle(const NamedScalar* i) {
// nothing to do, leave the result "unknown"
}

void ExpressionEvaluator::handle(const UnaryOp* uop) {
const auto in = evaluate(uop->in(), context_);
void ExpressionEvaluator::handle(UnaryOp* uop) {
const auto in = value(uop->in());
if (in.has_value()) {
switch (uop->getUnaryOpType()) {
case UnaryOpType::Neg:
result_ = -*in;
values_[uop] = -*in;
break;
case UnaryOpType::Cast:
result_ = *in;
values_[uop] = *in;
break;
default:
TORCH_CHECK(!"Unexpected operator type");
}
}
}

void ExpressionEvaluator::handle(const BinaryOp* bop) {
TORCH_CHECK(bop->out()->isAnInt()); // not really needed
const auto lhs = evaluate(bop->lhs(), context_);
const auto rhs = evaluate(bop->rhs(), context_);
void ExpressionEvaluator::handle(BinaryOp* bop) {
const auto lhs = value(bop->lhs());
const auto rhs = value(bop->rhs());
if (lhs.has_value() && rhs.has_value()) {
switch (bop->getBinaryOpType()) {
case BinaryOpType::Add:
result_ = *lhs + *rhs;
values_[bop] = *lhs + *rhs;
break;
case BinaryOpType::Sub:
result_ = *lhs - *rhs;
values_[bop] = *lhs - *rhs;
break;
case BinaryOpType::Mul:
result_ = *lhs * *rhs;
values_[bop] = *lhs * *rhs;
break;
case BinaryOpType::Div:
TORCH_CHECK(*rhs != 0);
result_ = *lhs / *rhs;
values_[bop] = *lhs / *rhs;
break;
case BinaryOpType::Mod:
TORCH_CHECK(*rhs != 0);
result_ = *lhs % *rhs;
values_[bop] = *lhs % *rhs;
break;
case BinaryOpType::CeilDiv:
TORCH_CHECK(*rhs != 0);
result_ = (*lhs + *rhs - 1) / *rhs;
values_[bop] = (*lhs + *rhs - 1) / *rhs;
break;
case BinaryOpType::And:
result_ = Int::ScalarType(*lhs && *rhs);
values_[bop] = Int::ScalarType(*lhs && *rhs);
break;
default:
TORCH_CHECK(!"Unexpected operator type");
Expand Down
22 changes: 11 additions & 11 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <c10/util/Optional.h>

Expand All @@ -20,15 +20,15 @@ namespace fuser {
//
class TORCH_CUDA_API EvaluationContext {
public:
explicit EvaluationContext(const Fusion* fusion) : fusion_(fusion) {}
explicit EvaluationContext(Fusion* fusion) : fusion_(fusion) {}

// Set the concrete value for a Int*
void bind(const Val* value, Int::ScalarType concrete_value);

// Retrieves the concrete value, or nullopt if not set
c10::optional<Int::ScalarType> concreteValue(const Val* value) const;

const Fusion* fusion() const {
Fusion* fusion() const {
return fusion_;
}

Expand All @@ -37,18 +37,18 @@ class TORCH_CUDA_API EvaluationContext {

private:
std::unordered_map<const Val*, Int::ScalarType> bindings_;
const Fusion* fusion_ = nullptr;
Fusion* fusion_ = nullptr;
};

// Evaluates expressions in a Fusion IR, using the passed in
// context (EvaluationContext) to query for concrete_values. The
// evaluation context may override concrete values in the IR as well.
class TORCH_CUDA_API ExpressionEvaluator : private OptInConstDispatch {
class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor {
public:
// Returns the result of the specified expression, or nullopt if
// the result cannot be evaluated
static c10::optional<Int::ScalarType> evaluate(
const Statement* expr,
Val* val,
const EvaluationContext* context);

private:
Expand All @@ -57,15 +57,15 @@ class TORCH_CUDA_API ExpressionEvaluator : private OptInConstDispatch {

~ExpressionEvaluator() override = default;

void handle(const Int*) override;
void handle(const NamedScalar*) override;
c10::optional<Int::ScalarType> value(const Statement* stmt) const;

void handle(const UnaryOp*) override;
void handle(const BinaryOp*) override;
void handle(Int*) override;
void handle(UnaryOp*) override;
void handle(BinaryOp*) override;

private:
const EvaluationContext* context_ = nullptr;
c10::optional<Int::ScalarType> result_;
std::unordered_map<const Statement*, Int::ScalarType> values_;
};

} // namespace fuser
Expand Down