54 changes: 52 additions & 2 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -744,6 +744,51 @@ Tensor relu_mps(const Tensor& self) {
return erfTensor;
}

MPSGraphTensor* tanh(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  // Builds 0.5 * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)));
  // the caller multiplies by x to complete the tanh GELU approximation.
  auto dataType = [inputTensor dataType];
  const float SQRT2_PI = 0.797884523868560791015625f; // sqrt(2 / pi)
  const float VAL = 0.044715f;
  MPSGraphTensor* onef = [mpsGraph constantWithScalar:1.0f
                                                shape:@[@1]
                                             dataType:dataType];
  MPSGraphTensor* halff = [mpsGraph constantWithScalar:0.5f
                                                 shape:@[@1]
                                              dataType:dataType];
  MPSGraphTensor* sqrt2_pi = [mpsGraph constantWithScalar:SQRT2_PI
                                                    shape:@[@1]
                                                 dataType:dataType];
  MPSGraphTensor* valf = [mpsGraph constantWithScalar:VAL
                                                shape:@[@1]
                                             dataType:dataType];

  // x^2
  MPSGraphTensor* erfTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
                                                        secondaryTensor:inputTensor
                                                                   name:nil];
  // x^3
  erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor
                                        secondaryTensor:inputTensor
                                                   name:nil];
  // 0.044715 * x^3
  erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor
                                        secondaryTensor:valf
                                                   name:nil];
  // x + 0.044715 * x^3
  erfTensor = [mpsGraph additionWithPrimaryTensor:erfTensor
                                  secondaryTensor:inputTensor
                                             name:nil];
  // sqrt(2 / pi) * (x + 0.044715 * x^3)
  erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor
                                        secondaryTensor:sqrt2_pi
                                                   name:nil];
  erfTensor = [mpsGraph tanhWithTensor:erfTensor
                                  name:nil];
  // 0.5 * (1 + tanh(...))
  erfTensor = [mpsGraph additionWithPrimaryTensor:erfTensor
                                  secondaryTensor:onef
                                             name:nil];
  erfTensor = [mpsGraph multiplicationWithPrimaryTensor:erfTensor
                                        secondaryTensor:halff
                                                   name:nil];

  return erfTensor;
}

TORCH_IMPL_FUNC(gelu_out_mps) (
const Tensor& self, c10::string_view approximate, const Tensor& output
) {
@@ -767,7 +812,7 @@ Tensor relu_mps(const Tensor& self) {
MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
-    string key = "gelu_out_mps" + getTensorsStringKey({self});
+    string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -782,7 +827,12 @@ Tensor relu_mps(const Tensor& self) {
getMPSDataType(self.scalar_type()),
getMPSShape(self));

-          MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor);
+          MPSGraphTensor* outputTensor = nil;
+          if (approximate == "tanh") {
+            outputTensor = tanh(mpsGraph, inputTensor);
+          } else {
+            outputTensor = normcdf(mpsGraph, inputTensor);
+          }
outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
secondaryTensor:inputTensor
name:nil];
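The helper above assembles the tanh approximation of GELU node by node. For reference, a minimal Python sketch of the same closed form, checked against PyTorch's built-in `approximate='tanh'` path on CPU (the name `gelu_tanh_reference` is ours, for illustration only):

```python
import math

import torch

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)));
    # the MPSGraph helper builds the 0.5 * (1 + tanh(...)) factor and
    # gelu_out_mps multiplies it by x afterwards.
    k = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(k * (x + 0.044715 * x.pow(3))))

x = torch.randn(2, 8, 4, 5)
expected = torch.nn.functional.gelu(x, approximate='tanh')
assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-5)
```

The cache-key change matters too: appending `approximate` keeps the erf-based and tanh-based graphs from sharing one cache entry when both variants run in the same process.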
12 changes: 11 additions & 1 deletion test/test_mps.py
@@ -2197,6 +2197,17 @@ def helper(dtype):
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")

    def test_gelu_tanh(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            x = cpu_x.detach().clone().to('mps')

            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)

        helper((2, 8, 4, 5))


class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -9315,7 +9326,6 @@ class TestConsistency(TestCase):
'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
'trace': [torch.int64],
'normalnumber_mean': [torch.float16, torch.float32],
-        'nn.functional.gelu': [torch.float32],
'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
'multinomial': [torch.float32],
'floor_divide': [torch.int16, torch.int32, torch.int64],
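Dropping `nn.functional.gelu` from this expected-failure list suggests float32 GELU is now expected to match CPU in the consistency suite. A quick usage sketch, assuming a macOS build where the MPS backend is available:

```python
import torch

if torch.backends.mps.is_available():
    x = torch.randn(2, 8, 4, 5, device='mps')
    exact = torch.nn.functional.gelu(x)                       # erf-based GELU
    approx = torch.nn.functional.gelu(x, approximate='tanh')  # new tanh path
    # The two formulations agree closely but not bit-for-bit.
    print((exact - approx).abs().max().item())
```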