Skip to content

Commit

Permalink
Merge pull request #50 from csarofeen/cleanup_warnings
Browse files Browse the repository at this point in the history
Cleanup a few compiler warnings
  • Loading branch information
tlemo authored Jun 1, 2020
2 parents f9c4b52 + 51e830a commit 26d13cc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
14 changes: 7 additions & 7 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace fuser {

// Will return a new value of type val with the DataType dtype, if it's a
// tensorview it will propagate the shape information from val.
TORCH_CUDA_API Val* newValLike(const Val* const val, DataType dtype) {
TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
switch (val->getValType().value()) {
case (ValType::TensorView):
return static_cast<const TensorView* const>(val)->newForOutput(dtype);
return val->as<TensorView>()->newForOutput(dtype);
case (ValType::NamedScalar):
case (ValType::Scalar):
switch (dtype) {
Expand All @@ -39,7 +39,7 @@ TORCH_CUDA_API Val* newValLike(const Val* const val, DataType dtype) {
val->getDataType().value());
}

TORCH_CUDA_API Val* newValLike(const Val* const val) {
TORCH_CUDA_API Val* newValLike(const Val* val) {
return newValLike(val, val->getDataType().value());
}

Expand Down Expand Up @@ -112,7 +112,7 @@ TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
}

TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1) {
return castOp(dtype, static_cast<Val*>(v1))->as<TensorView>();
return castOp(dtype, v1->as<Val>())->as<TensorView>();
}

// UNARY OPERATIONS
Expand All @@ -124,7 +124,7 @@ TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
}

TORCH_CUDA_API TensorView* unaryOp(UnaryOpType type, TensorView* v1) {
return unaryOp(type, static_cast<Val*>(v1))->as<TensorView>();
return unaryOp(type, v1->as<Val>())->as<TensorView>();
}

TORCH_CUDA_API Val* neg(Val* v) {
Expand Down Expand Up @@ -551,7 +551,7 @@ TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
}

TORCH_CUDA_API TensorView* threshold(TensorView* in, Val* thresh, Val* value) {
return threshold(static_cast<Val*>(in), thresh, value)->as<TensorView>();
return threshold(in->as<Val>(), thresh, value)->as<TensorView>();
}

TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
Expand All @@ -572,7 +572,7 @@ TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
}

TORCH_CUDA_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) {
return clamp(static_cast<Val*>(in), min_val, max_val)->as<TensorView>();
return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
}

} // namespace fuser
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ void Statement::dispatch(T handler, Statement* stmt) {
}

template <typename T>
void Val::constDispatch(T handler, const Val* const val) {
void Val::constDispatch(T handler, const Val* val) {
switch (*(val->getValType())) {
case ValType::Scalar:
switch (*(val->getDataType())) {
case DataType::Bool:
ptr(handler)->handle(static_cast<const Bool* const>(val));
ptr(handler)->handle(static_cast<const Bool*>(val));
return;
case DataType::Float:
ptr(handler)->handle(static_cast<const Float*>(val));
return;
case DataType::Half:
ptr(handler)->handle(static_cast<const Half* const>(val));
ptr(handler)->handle(static_cast<const Half*>(val));
return;
case DataType::Int:
ptr(handler)->handle(static_cast<const Int*>(val));
Expand Down Expand Up @@ -190,10 +190,10 @@ void Expr::constDispatch(T handler, const Expr* expr) {
ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
return;
case ExprType::TernaryOp:
ptr(handler)->handle(static_cast<const TernaryOp* const>(expr));
ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
return;
case ExprType::ReductionOp:
ptr(handler)->handle(static_cast<const ReductionOp* const>(expr));
ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
return;
case ExprType::ForLoop:
ptr(handler)->handle(static_cast<const ForLoop*>(expr));
Expand Down

0 comments on commit 26d13cc

Please sign in to comment.