Skip to content

Commit

Permalink
Re-definition found for builtin function (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jun 29, 2021
1 parent 441e116 commit 4340f35
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions tests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def compute(self):
metric(i)
state_dict = metric.state_dict()

sum = i * (i + 1) / 2
assert state_dict["x"] == sum * worldsize
assert metric.x == sum
exp_sum = i * (i + 1) / 2
assert state_dict["x"] == exp_sum * worldsize
assert metric.x == exp_sum
assert metric.c == (i + 1)
assert state_dict["c"] == metric.c * worldsize

Expand Down
12 changes: 6 additions & 6 deletions tests/regression/test_cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def _multi_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine):
sk_target = target.view(-1, num_targets).numpy()
result_array = sk_fn(sk_target, sk_preds)
col = np.diagonal(result_array)
sum = col.sum()
col_sum = col.sum()
if reduction == 'sum':
to_return = sum
to_return = col_sum
elif reduction == 'mean':
mean = sum / len(col)
mean = col_sum / len(col)
to_return = mean
else:
to_return = col
Expand All @@ -62,11 +62,11 @@ def _single_target_sk_metric(preds, target, reduction, sk_fn=sk_cosine):
sk_target = target.view(-1).numpy()
result_array = sk_fn(np.expand_dims(sk_preds, axis=0), np.expand_dims(sk_target, axis=0))
col = np.diagonal(result_array)
sum = col.sum()
col_sum = col.sum()
if reduction == 'sum':
to_return = sum
to_return = col_sum
elif reduction == 'mean':
mean = sum / len(col)
mean = col_sum / len(col)
to_return = mean
else:
to_return = col
Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

def __repr__(self) -> Optional[str]:
repr = super().__repr__()[:-2]
repr_str = super().__repr__()[:-2]
if self.prefix:
repr += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
if self.postfix:
repr += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
return repr + "\n)"
repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
return repr_str + "\n)"
6 changes: 3 additions & 3 deletions torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ class MatrixSquareRoot(Function):
"""

@staticmethod
def forward(ctx: Any, input: Tensor) -> Tensor:
def forward(ctx: Any, input_data: Tensor) -> Tensor:
import scipy

# TODO: update whenever pytorch gets an matrix square root function
# Issue: https://github.com/pytorch/pytorch/issues/9983
m = input.detach().cpu().numpy().astype(np.float_)
m = input_data.detach().cpu().numpy().astype(np.float_)
scipy_res, _ = scipy.linalg.sqrtm(m, disp=False)
sqrtm = torch.from_numpy(scipy_res.real).to(input)
sqrtm = torch.from_numpy(scipy_res.real).to(input_data)
ctx.save_for_backward(sqrtm)
return sqrtm

Expand Down

0 comments on commit 4340f35

Please sign in to comment.