Include scalar params in caching key (#94)
* Include scalar params in caching key

* Add key for softplus backward; add test for scalar params
abhudev authored and kulinseth committed on Feb 5, 2023. Parent: e134b91. Commit: e9ffa35.
Showing 2 changed files with 6 additions and 6 deletions.
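
Why the scalars belong in the key: the MPS backend looks up compiled graphs in a string-keyed cache, and getTensorsStringKey encodes only tensor properties such as shape and dtype, not operator scalars. Because beta and threshold are captured when a graph is first built, a later softplus call with the same input signature but different scalar values would get a cache hit and silently run the wrong graph. Below is a minimal sketch of that failure mode in plain Python; graph_cache, build_graph, and softplus_like are invented names for illustration, not PyTorch APIs.

graph_cache = {}

def build_graph(beta, threshold):
    # Stand-in for MPSGraph construction: the scalar values are captured
    # (baked in) at build time, as the cached softplus graphs capture theirs.
    return lambda x: x * beta + threshold

def softplus_like(x, beta, threshold, include_scalars_in_key):
    key = "softplus_out_mps:f32[scalar]"                  # tensor signature only
    if include_scalars_in_key:
        key += ":" + str(beta) + ":" + str(threshold)     # the fix: scalars in the key
    if key not in graph_cache:
        graph_cache[key] = build_graph(beta, threshold)   # cache miss: build
    return graph_cache[key](x)                            # cache hit: reuse as-is

# Buggy key: the second call reuses the graph built with beta=1, threshold=20,
# so its own scalars are ignored.
print(softplus_like(1.0, 1, 20, include_scalars_in_key=False))   # 21.0
print(softplus_like(1.0, 2, 5, include_scalars_in_key=False))    # 21.0, expected 7.0

# Fixed key: different scalars map to different cache entries.
graph_cache.clear()
print(softplus_like(1.0, 1, 20, include_scalars_in_key=True))    # 21.0
print(softplus_like(1.0, 2, 5, include_scalars_in_key=True))     # 7.0

The diff below applies exactly this fix to the real cache: it appends beta and threshold, via std::to_string(...to<double>()), to the existing softplus forward and backward keys.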
aten/src/ATen/native/mps/operations/Activation.mm: 2 additions & 2 deletions
@@ -1419,7 +1419,7 @@ Tensor glu_backward_mps (const Tensor& grad_output,
   MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);
 
   @autoreleasepool {
-    string key = "softplus_out_mps:" + getTensorsStringKey({self});
+    string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());
 
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
@@ -1524,7 +1524,7 @@ Tensor glu_backward_mps (const Tensor& grad_output,
   MPSStream* stream = getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self});
+    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());
 
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
test/test_mps.py: 4 additions & 4 deletions
@@ -4649,7 +4649,7 @@ def helper(shape, dim=0):
 
     # Test softplus
     def test_softplus(self):
-        def helper(shape, beta=0.5, threshold=0.5):
+        def helper(shape, beta=1, threshold=20):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
             x = cpu_x.detach().clone().to('mps').requires_grad_()
 
@@ -4667,9 +4667,9 @@ def helper(shape, beta=0.5, threshold=0.5):
 
         # Test empty shape too
         for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-            helper(shape)
-            helper(shape, beta=0.6, threshold=0.6) # relu path
-            helper(shape, beta=1, threshold=20) # softplus path
+            for beta in [0.5, 1, 2, 3, 4]:
+                for threshold in [0.5, 20, 30, 40, 50]:
+                    helper(shape, beta, threshold)
 
     # Test silu
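
The test previously pinned non-default values (beta=0.5, threshold=0.5) and exercised only two parameter pairs; it now uses PyTorch's actual softplus defaults (beta=1, threshold=20) and sweeps both parameters, so a graph cached for one pair and reused for another would produce a mismatch. A rough standalone equivalent of what the new loop checks, as a sketch assuming an MPS-capable machine rather than the test file itself:

import torch
import torch.nn.functional as F

# Compare MPS softplus against CPU for every (beta, threshold) pair; before
# this fix, pairs after the first could hit a stale cached graph and diverge.
if torch.backends.mps.is_available():
    x = torch.randn(2, 3)
    for beta in [0.5, 1, 2, 3, 4]:
        for threshold in [0.5, 20, 30, 40, 50]:
            cpu_out = F.softplus(x, beta=beta, threshold=threshold)
            mps_out = F.softplus(x.to('mps'), beta=beta, threshold=threshold)
            torch.testing.assert_close(mps_out.cpu(), cpu_out)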