Skip to content

Commit 1a8670a

Browse files
abhudevkulinseth
authored andcommitted
Add error messages for int64 non-available ops (#80)
* Add error messages for int64 non-available ops * Move warning to common code
1 parent ce8d70d commit 1a8670a

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

aten/src/ATen/native/mps/operations/Activation.mm

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ Tensor relu_mps(const Tensor& self) {
416416
using namespace mps;
417417
using CachedGraph = MPSUnaryCachedGraph;
418418
TORCH_CHECK(output.is_mps());
419+
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support sigmoid op with int64 input")
419420

420421
if(output.numel() == 0) {
421422
return;

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha,
2828
const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock)
2929
{
30+
3031
// it's possible to receive empty tensors here
3132
if (self.numel() == 0 || other.numel() == 0) {
3233
return;
@@ -203,6 +204,9 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp
203204

204205
#define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \
205206
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \
207+
TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \
208+
(#func_stub == "power" || #func_stub == "atan2")), \
209+
"MPS does not support ", #func_stub, " op with int64 input") \
206210
mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \
207211
^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \
208212
MPSGraph* mpsGraph = cachedGraph->graph(); \

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
134134

135135
TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output)
136136
{
137+
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input")
137138
using namespace mps;
138139
if (!output.is_same_size(self)) {
139140
output.resize_(self.sizes());

0 commit comments

Comments
 (0)