Skip to content

Commit

Permalink
Incorporate review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tlemo committed May 28, 2020
1 parent 9684d56 commit 1b47c92
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
9 changes: 5 additions & 4 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ void testGPU_FusionExprEvalConstants() {
auto* a = new Int(7);
auto* b = new Int(3);

checkIntValue(&eval_context, neg(a), -7);
checkIntValue(&eval_context, add(a, b), 10);
checkIntValue(&eval_context, mul(sub(a, b), div(a, b)), 8);
checkIntValue(&eval_context, neg(mul(sub(a, b), div(a, b))), -8);
checkIntValue(&eval_context, mod(a, b), 1);
checkIntValue(&eval_context, ceilDiv(a, b), 3);
}
Expand All @@ -88,7 +89,7 @@ void testGPU_FusionExprEvalBindings() {
auto* a = new Int();
auto* b = new Int();
auto* c = add(a, b);
auto* d = ceilDiv(add(a, b), b);
auto* d = neg(ceilDiv(add(a, b), b));

eval_context.bind(a, 7);
eval_context.bind(b, 3);
Expand All @@ -97,7 +98,7 @@ void testGPU_FusionExprEvalBindings() {
checkIntValue(&eval_context, sub(a, b), 4);
checkIntValue(&eval_context, mod(a, b), 1);
checkIntValue(&eval_context, ceilDiv(a, b), 3);
checkIntValue(&eval_context, d, 4);
checkIntValue(&eval_context, d, -4);

eval_context.bind(a, 2);
eval_context.bind(b, 5);
Expand All @@ -106,7 +107,7 @@ void testGPU_FusionExprEvalBindings() {
checkIntValue(&eval_context, sub(a, b), -3);
checkIntValue(&eval_context, mod(a, b), 2);
checkIntValue(&eval_context, ceilDiv(a, b), 1);
checkIntValue(&eval_context, d, 2);
checkIntValue(&eval_context, d, -2);
}

// Evaluate expressions in a simple IR
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
return out;
}

TORCH_CUDA_API Val* neg(Val* v) {
return unaryOp(UnaryOpType::Neg, v);
}

TORCH_CUDA_API Val* add(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Add, v1, v2);
}
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ TORCH_CUDA_API Val* reductionOp(
Val* init,
Val* v1);

// BINARY OPAERATIONS
// UNARY OPERATIONS
TORCH_CUDA_API Val* neg(Val* v);

// BINARY OPERATIONS
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_API Val* mul(Val* v1, Val* v2);
Expand Down
5 changes: 1 addition & 4 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void ExpressionEvaluator::handle(const UnaryOp* uop) {
if (in.has_value()) {
switch (uop->getUnaryOpType()) {
case UnaryOpType::Neg:
result_ = Int::ScalarType(!*in);
result_ = -*in;
break;
case UnaryOpType::Cast:
result_ = *in;
Expand Down Expand Up @@ -102,9 +102,6 @@ void ExpressionEvaluator::handle(const BinaryOp* bop) {
TORCH_CHECK(*rhs != 0);
result_ = *lhs % *rhs;
break;
case BinaryOpType::LT:
result_ = Int::ScalarType(*lhs < *rhs);
break;
case BinaryOpType::CeilDiv:
TORCH_CHECK(*rhs != 0);
result_ = (*lhs + *rhs - 1) / *rhs;
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ struct TORCH_CUDA_API Statement {
auto downcast_ptr = static_cast<T*>(this);
#else
auto downcast_ptr = dynamic_cast<T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
assert(downcast_ptr != nullptr);
return downcast_ptr;
}

Expand All @@ -113,8 +113,8 @@ struct TORCH_CUDA_API Statement {
auto downcast_ptr = static_cast<const T*>(this);
#else
auto downcast_ptr = dynamic_cast<const T*>(this);
TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
assert(downcast_ptr != nullptr);
return downcast_ptr;
}

Expand Down

0 comments on commit 1b47c92

Please sign in to comment.