From 2e118ffe64fd170849922852b1c999e311055fcc Mon Sep 17 00:00:00 2001 From: Ramin Azarmehr Date: Mon, 31 Oct 2022 14:37:51 -0400 Subject: [PATCH] Fix the type cast issue with Binary Ops (#158) * Fix data type issues for logaddexp and logaddexp2 ops * Fix the type cast issue with Binary Ops * Move several ops out of Blocklist in TestConsistency * Move good ops from block list to allow list --- .../ATen/native/mps/operations/BinaryOps.mm | 184 +++--------------- test/test_mps.py | 49 ++--- 2 files changed, 45 insertions(+), 188 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 2c50a5efdc3c6..6edb7bdb2a3de 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -72,38 +72,21 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor; // this type inference is only required at the time of graph creation - const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type()); - - // Condition - - // 1. Division operation - // 2. Inputs are not float - bool div_condition = op_name.rfind("div", 0) == 0 - && (!(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half)); - - auto compute_type = ScalarType::Float; - - if(div_condition) { - - if(output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half) - compute_type = output_.scalar_type(); - - primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type); - secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type); + ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type()); + // Integer input must be cast to float if output is float + if (isIntegralType(common_dtype, true) && isFloatingType(output.scalar_type())) { + common_dtype = output_.scalar_type(); + } + if (self.scalar_type() != common_dtype) { + primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype); } - else { - if (self.scalar_type() != common_dtype) { - primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype); - } - if (other.scalar_type() != common_dtype) { - secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); - } + if (other.scalar_type() != common_dtype) { + secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); } newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor); // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor // Output tensor should have been promoted but it remains an int32 tensor - - if ((div_condition && compute_type != output_.scalar_type()) || - output_.scalar_type() != common_dtype) { + if (output_.scalar_type() != common_dtype) { newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type()); } } @@ -319,139 +302,26 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) { TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - - if (&output != &self) { - output.resize_(self.sizes());; - } - - // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = "log_base_e_out_mps:" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor* ePowXTensor = [mpsGraph exponentWithTensor:xTensor - name:nil]; - MPSGraphTensor* ePowYTensor = [mpsGraph exponentWithTensor:yTensor - name:nil]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:ePowXTensor - secondaryTensor:ePowYTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:sumTensor - name:nil]; - - newCachedGraph->inputTensor_ = xTensor; - newCachedGraph->otherTensor_ = yTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - - } + mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil] + name:nil]; + return [mpsGraph logarithmWithTensor:sumTensor name:nil]; + }; + mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block); +} TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - - if (&output != &self) { - output.resize_(self.sizes());; - } - - // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = "log_base_two_out_mps:" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor* twoPowXTensor = [mpsGraph exponentBase2WithTensor:xTensor - name:nil]; - MPSGraphTensor* twoPowYTensor = [mpsGraph exponentBase2WithTensor:yTensor - name:nil]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:twoPowXTensor - secondaryTensor:twoPowYTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph logarithmBase2WithTensor:sumTensor - name:nil]; - - newCachedGraph->inputTensor_ = xTensor; - newCachedGraph->otherTensor_ = yTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil] + name:nil]; + return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil]; + }; + mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block); } } // namespace native diff --git a/test/test_mps.py b/test/test_mps.py index 040a230d8503d..276c595727a58 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7490,7 +7490,7 @@ class TestConsistency(TestCase): '__getitem__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__radd__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__rand__': ['b8', 'i16', 'i32', 'i64', 'u8'], - '__rdiv__': ['f16', 'f32', 'i16', 'i32', 'u8'], + '__rdiv__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__rmatmul__': ['f32'], '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'], @@ -7500,6 +7500,7 @@ class TestConsistency(TestCase): 'masked.argmin': ['i16', 'i64', 'u8'], 'masked.log_softmax': ['f32'], 'masked.logaddexp': ['f32'], + 'masked.logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'masked.norm': ['f16', 'f32'], 'masked.normalize': ['f16', 'f32'], 'masked.softmax': ['f32'], @@ -7509,7 +7510,7 @@ class TestConsistency(TestCase): 'abs': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'], 'acos': ['b8', 'f32', 'i16', 'i32', 'u8'], 'acosh': ['b8', 'f32', 'i16', 'i32', 'u8'], - 'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], + 'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'addbmm': ['f32'], 'addcdiv': ['f32'], 'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'], @@ -7524,7 +7525,6 @@ class TestConsistency(TestCase): 'argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'amax': ['f32'], 'amix': ['f32'], - 'logsumexp': ['f32'], 'mean': ['f32'], 'sum': ['f32'], 'asin': ['b8', 'f32', 'i16', 'i32', 'u8'], @@ -7549,6 +7549,9 @@ class TestConsistency(TestCase): 'ceil': ['f32', 'int32', 'int64', 'f16'], 'char': ['b8', 'u8'], 'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'clone': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'column_stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'combinations': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -7566,6 +7569,7 @@ class TestConsistency(TestCase): 'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], 'diff': ['f16', 'f32', 'i16', 'i32', 'i64'], 'dist': ['f32'], + 'div': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'dot': ['f32', 'i16', 'i32', 'i64', 'u8'], 'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'erf': ['b8', 'f32', 'i16', 'i32', 'u8'], @@ -7601,15 +7605,20 @@ class TestConsistency(TestCase): 'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'], 'log2': ['b8', 'f32', 'i16', 'i32', 'u8'], 'log_softmax': ['f32'], - 'logaddexp': ['f32'], - 'logaddexp2': ['f32'], + 'logaddexp': ['f16', 'f32'], + 'logaddexp2': ['f16', 'f32'], + 'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'masked_fill': ['f16', 'i16', 'i32', 'i64'], 'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'matmul': ['f32'], 'mm': ['f32'], 'mv': ['f32'], + 'mul': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'nn.functional.adaptive_max_pool1d': ['f32'], 'nn.functional.adaptive_max_pool2d': ['f32'], @@ -7639,7 +7648,7 @@ class TestConsistency(TestCase): 'nn.functional.hinge_embedding_loss': ['f32'], 'nn.functional.huber_loss': ['f32'], 'nn.functional.instance_norm': ['f32'], - 'nn.functional.kl_div': ['f32'], + 'nn.functional.kl_div': ['f32', 'i16', 'i32', 'i64'], 'nn.functional.l1_loss': ['f16', 'f32'], 'nn.functional.leaky_relu': ['f32'], 'nn.functional.linear': ['f32'], @@ -7706,7 +7715,7 @@ class TestConsistency(TestCase): 'square': ['f16', 'f32'], 'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sub': ['f32', 'i16', 'i32', 'i64'], + 'sub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'svd': ['f32'], 't': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -7720,7 +7729,7 @@ class TestConsistency(TestCase): 'tril_indices': ['i32', 'i64'], 'triu': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'triu_indices': ['i32', 'i64'], - 'true_divide': ['b8', 'f16', 'f32', 'i16', 'u8'], + 'true_divide': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'trunc': ['f32'], 'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -7730,12 +7739,6 @@ class TestConsistency(TestCase): 'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'nonzero': ['f32', 'i16', 'i32', 'i64']} @@ -7922,7 +7925,6 @@ class TestConsistency(TestCase): 'masked.sum': [torch.bool], # Functions that hard crash - 'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64], 'nn.functional.nll_loss': [torch.float32], 'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32], 'std': [torch.float16], @@ -7968,10 +7970,6 @@ class TestConsistency(TestCase): # These were moved from ALLOWLIST to BLOCK as they are not working # locally - '__radd__': ['torch.bool', 'torch.uint8'], - '__rmul__': ['torch.uint8'], - 'add': ['torch.bool', 'torch.uint8'], - 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'diag': ['torch.int64'], 'diagflat': ['torch.int64'], @@ -8015,18 +8013,15 @@ class TestConsistency(TestCase): 'log_softmaxdtype': None, 'trapezoid': None, 'eq': None, - 'mul': None, 'inner': None, 'take_along_dim': None, # New block list ops that need investigation - '__rdiv__': ['torch.bool', 'torch.int64'], '__rpow__': ['torch.float32', 'torch.int16', 'torch.int32', 'torch.uint8'], '_masked.amax': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], '_masked.amin': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], '_masked.argmax': ['torch.float16', 'torch.float32', 'torch.int32'], '_masked.argmin': ['torch.float16', 'torch.float32', 'torch.int32'], - '_masked.logsumexp': ['torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], '_masked.mean': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], '_masked.prod': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], '_masked.std': ['torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], @@ -8039,9 +8034,6 @@ class TestConsistency(TestCase): 'bernoulli': ['torch.float32'], 'byte': ['torch.float16', 'torch.float32'], 'char': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64'], - 'clamp': ['torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'clamp_max': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'clamp_min': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'constant_pad_nd': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'count_nonzero': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'diff': ['torch.bool', 'torch.uint8'], @@ -8063,10 +8055,6 @@ class TestConsistency(TestCase): 'int': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int64', 'torch.uint8'], 'linalg.eigvals': ['torch.float32'], 'linalg.multi_dot': ['torch.float32'], - 'logical_and': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'logical_or': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'logical_xor': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'logsumexp': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'matmul': ['torch.uint8'], 'mean': ['torch.float16', 'torch.float32'], 'native_layer_norm': ['torch.float32'], @@ -8103,7 +8091,6 @@ class TestConsistency(TestCase): 'scatter_add': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'scatter': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'short': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'sub': ['torch.float16', 'torch.uint8'], 'sum': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], 'tan': ['torch.float32'], 'tensor_split': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], @@ -8179,7 +8166,7 @@ def get_samples(): if op.name == "nn.functional.conv2d" and dtype == torch.float32: atol = 1e-4 rtol = 3e-5 - elif op.name == "add" and dtype == torch.float16: + elif (op.name == "add" or op.name == "sub") and dtype == torch.float16: atol = 1e-2 rtol = 1e-2 else: