From cef8adf7fdaeaad75b8147900a46fa8154551d7c Mon Sep 17 00:00:00 2001
From: ronian526
Date: Fri, 20 Jan 2023 15:08:52 -0800
Subject: [PATCH 1/2] Fix nn.functional.gelu

- fix gelu_out_mps key
- add calculation for gelu with tanh
- remove gelu from blocklist
---
 .../ATen/native/mps/operations/Activation.mm | 54 ++++++++++++++++++-
 test/test_mps.py                             |  1 -
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index dbb591246ca40..b84436bd99f5a 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -744,6 +744,51 @@ Tensor relu_mps(const Tensor& self) {
   return erfTensor;
 }
+
+MPSGraphTensor* tanh (MPSGraph* mpsGraph, MPSGraphTensor *inputTensor) {
+  // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+  auto dataType = [inputTensor dataType];
+  const float SQRT2_PI = 0.797884523868560791015625f;
+  const float VAL = 0.044715f;
+  MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f
+                                                shape: @[@1]
+                                             dataType: dataType];
+  MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f
+                                                 shape: @[@1]
+                                              dataType: dataType];
+  MPSGraphTensor *sqrt2_pi = [mpsGraph constantWithScalar: SQRT2_PI
+                                                    shape: @[@1]
+                                                 dataType: dataType];
+  MPSGraphTensor *valf = [mpsGraph constantWithScalar: VAL
+                                                shape: @[@1]
+                                             dataType: dataType];
+
+  MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+                                                         secondaryTensor: inputTensor
+                                                                    name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                         secondaryTensor: inputTensor
+                                                    name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                         secondaryTensor: valf
+                                                    name: nil];
+  erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                   secondaryTensor: inputTensor
+                                              name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                         secondaryTensor: sqrt2_pi
+                                                    name: nil];
+  erfTensor = [mpsGraph tanhWithTensor: erfTensor
+                                  name: nil];
+  erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                   secondaryTensor: onef
+                                              name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                         secondaryTensor: halff
+                                                    name: nil];
+
+  return erfTensor;
+}
 
 TORCH_IMPL_FUNC(gelu_out_mps) (
   const Tensor& self, c10::string_view approximate, const Tensor& output
   ) {
@@ -767,7 +812,7 @@ Tensor relu_mps(const Tensor& self) {
   MPSStream* stream = getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = "gelu_out_mps" + getTensorsStringKey({self});
+    string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate);
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
       MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -782,7 +827,12 @@ Tensor relu_mps(const Tensor& self) {
                                                      getMPSDataType(self.scalar_type()),
                                                      getMPSShape(self));
 
-        MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor);
+        MPSGraphTensor* outputTensor = nil;
+        if(approximate == "tanh") {
+          outputTensor = tanh(mpsGraph, inputTensor);
+        } else {
+          outputTensor = normcdf(mpsGraph, inputTensor);
+        }
         outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
                                                  secondaryTensor:inputTensor
                                                             name:nil];
diff --git a/test/test_mps.py b/test/test_mps.py
index 4d8bbd06ef0c9..b340354efdf72 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9315,7 +9315,6 @@ class TestConsistency(TestCase):
         'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'trace': [torch.int64],
         'normalnumber_mean': [torch.float16, torch.float32],
-        'nn.functional.gelu': [torch.float32],
         'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'multinomial': [torch.float32],
         'floor_divide': [torch.int16, torch.int32, torch.int64],

From 68839b0ad765ab15c65b90557dc005af88ba79a1 Mon Sep 17 00:00:00 2001
From: ronian526
Date: Mon, 23 Jan 2023 14:41:53 -0800
Subject: [PATCH 2/2] - add test_gelu_tanh test

---
 test/test_mps.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/test/test_mps.py b/test/test_mps.py
index b340354efdf72..ad7075b5b9a15 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2197,6 +2197,17 @@ def helper(dtype):
             e_string = str(e)
             self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")
 
+    def test_gelu_tanh(self):
+        def helper(shape):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            x = cpu_x.detach().clone().to('mps')
+
+            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
+            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
+            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
+
+        helper((2, 8, 4, 5))
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
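
Note for reviewers (not part of the patch): the new `tanh` helper above builds the closed form 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) as an MPSGraph expression. A minimal CPU-only Python sketch of that same formula, checked against PyTorch's own `approximate='tanh'` path, is below; the helper name `gelu_tanh_reference` is illustrative only, and the input shape simply mirrors the one used in `test_gelu_tanh`.

# Illustrative sanity check of the tanh approximation; runs entirely on CPU,
# independent of the MPS graph code in the patch.
import math
import torch

def gelu_tanh_reference(x):
    # 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.randn(2, 8, 4, 5)
expected = torch.nn.functional.gelu(x, approximate='tanh')
assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-6)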