Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… into fix_dtype
  • Loading branch information
SkafteNicki committed Sep 2, 2021
2 parents 02a7ebe + 13d5fcd commit 57592af
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
import sys
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence
import yaml

import numpy as np
import pytest
import torch
import yaml
from torch import Tensor, tensor
from torch.multiprocessing import Pool, set_start_method

Expand Down
15 changes: 10 additions & 5 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,12 @@ def device(self) -> "torch.device":
return self._device

def to(self, *args: Any, **kwargs: Any) -> "Metric":
"""Moves the parameters and buffers. Normal dtype casting is not supported by this method
instead use the `set_dtype` method instead.
"""Moves the parameters and buffers.
Normal dtype casting is not supported by this method instead use the `set_dtype` method instead.
"""
out = torch._C._nn._parse_to(*args, **kwargs)
if len(out)==4: # pytorch 1.5 and higher
if len(out) == 4: # pytorch 1.5 and higher
device, dtype, non_blocking, convert_to_format = out
else: # pytorch 1.4 and lower
device, dtype, non_blocking = out
Expand All @@ -431,8 +432,12 @@ def to(self, *args: Any, **kwargs: Any) -> "Metric":

def convert(t: Tensor) -> Tensor:
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking, memory_format=convert_to_format)
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
memory_format=convert_to_format,
)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

self._device = device
Expand Down

0 comments on commit 57592af

Please sign in to comment.