
Include scalar params in caching key (#94)
* Include scalar params in caching key

* Add key for softplus backward; add test for scalar params
abhudev authored and kulinseth committed Oct 31, 2022
1 parent 590be21 commit 4169a72
Showing 2 changed files with 8 additions and 6 deletions.
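
For context before the per-file diffs: the MPS graph cache previously keyed softplus only on its input tensor, so two calls that differed only in beta/threshold hashed to the same cached graph, and the second call could silently reuse the first call's graph. A minimal repro sketch in PyTorch, assuming a build with MPS available (tolerances are illustrative):

import torch

# Minimal repro sketch (assumes a PyTorch build with MPS available).
# Before this commit, both calls below could map to the same cached
# MPS graph, because beta/threshold were not part of the cache key.
x = torch.randn(2, 3, device='mps')

out_a = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
out_b = torch.nn.functional.softplus(x, beta=0.5, threshold=0.5)

# With the fix, each (beta, threshold) pair gets its own graph and
# both results agree with the CPU reference.
ref_a = torch.nn.functional.softplus(x.cpu(), beta=1.0, threshold=20.0)
ref_b = torch.nn.functional.softplus(x.cpu(), beta=0.5, threshold=0.5)
assert torch.allclose(out_a.cpu(), ref_a, atol=1e-5)
assert torch.allclose(out_b.cpu(), ref_b, atol=1e-5)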
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -1474,7 +1474,7 @@ Tensor glu_backward_mps (const Tensor& grad_output,
  MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);

  @autoreleasepool {
-    string key = "softplus_out_mps:" + getTensorsStringKey({self});
+    string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());

    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
    if(!cachedGraph) {
@@ -1576,7 +1576,7 @@ Tensor glu_backward_mps (const Tensor& grad_output,
  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
-    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self});
+    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());

    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
    if(!cachedGraph) {
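The fix appends the scalar parameters to the cache key string, so each distinct (beta, threshold) pair maps to its own cached graph. A rough Python sketch of the keying scheme (illustrative only, not the actual MPS implementation; cache_key is a hypothetical name):

import torch

# Hypothetical sketch (not the actual MPS implementation) of the
# keying scheme: fold tensor metadata into the key, as
# getTensorsStringKey() does, then append each scalar, mirroring
# ":" + std::to_string(beta.to<double>()) in the diff above.
def cache_key(op_name, tensors, *scalars):
    tensor_part = ";".join(f"{tuple(t.shape)}:{t.dtype}" for t in tensors)
    scalar_part = ":".join(str(float(s)) for s in scalars)
    return f"{op_name}:{tensor_part}:{scalar_part}"

x = torch.randn(4, 4)
# Distinct scalar values now produce distinct keys, so a graph built
# for one (beta, threshold) pair can no longer be reused for another.
assert cache_key("softplus_out_mps", [x], 1.0, 20.0) != \
       cache_key("softplus_out_mps", [x], 0.5, 0.5)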
10 changes: 6 additions & 4 deletions test/test_mps.py
@@ -3853,12 +3853,12 @@ def helper(shape, dim=0):

    # Test softplus
    def test_softplus(self):
-        def helper(shape):
+        def helper(shape, beta=1, threshold=20):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

-            softplus_result = torch.nn.Softplus(beta=0.5, threshold=0.5)(x)
-            softplus_result_cpu = torch.nn.Softplus(beta=0.5, threshold=0.5)(cpu_x)
+            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
+            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)

            cpu_grad = torch.randn(softplus_result.shape)
            grad = cpu_grad.to('mps')
@@ -3871,7 +3871,9 @@ def helper(shape):

        # Test empty shape too
        for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-            helper(shape)
+            for beta in [0.5, 1, 2, 3, 4]:
+                for threshold in [0.5, 20, 30, 40, 50]:
+                    helper(shape, beta, threshold)

    # Test silu

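The updated test sweeps beta over {0.5, 1, 2, 3, 4} and threshold over {0.5, 20, 30, 40, 50} for every shape, exercising many distinct cache keys back to back, which is exactly the pattern that previously collided. On an MPS machine the test can be run with something like python test/test_mps.py -k test_softplus (the exact runner invocation depends on the repository's test setup).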
