Skip to content

Commit 6f51156

Browse files
anmolsjoshivfdev-5
andauthored
Updated codebase such that torch>=1.3 (#1150)
Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent f096742 commit 6f51156

File tree

4 files changed

+4
-43
lines changed

4 files changed

+4
-43
lines changed

.github/workflows/pytorch-version-tests.yml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,10 @@ jobs:
1111
fail-fast: false
1212
matrix:
1313
python-version: [3.5, 3.6, 3.7, 3.8]
14-
pytorch-version: [1.4.0, 1.3.1, 1.2.0, 1.1.0, 1.0.1]
14+
pytorch-version: [1.4.0, 1.3.1]
1515
exclude:
1616
- pytorch-version: 1.3.1
1717
python-version: 3.8
18-
- pytorch-version: 1.2.0
19-
python-version: 3.8
20-
- pytorch-version: 1.1.0
21-
python-version: 3.8
22-
- pytorch-version: 1.0.1
23-
python-version: 3.8
2418

2519
steps:
2620
- uses: actions/checkout@v2

ignite/distributed/comp_models/native.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,9 @@ def spawn(
290290
"PRODUCT": dist.ReduceOp.PRODUCT,
291291
"MIN": dist.ReduceOp.MIN,
292292
"MAX": dist.ReduceOp.MAX,
293+
"AND": dist.ReduceOp.BAND,
294+
"OR": dist.ReduceOp.BOR,
293295
}
294-
if LooseVersion(torch.__version__) > LooseVersion("1.2.0"):
295-
_reduce_op_map.update(
296-
{"AND": dist.ReduceOp.BAND, "OR": dist.ReduceOp.BOR,}
297-
)
298296

299297
def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
300298
if op not in self._reduce_op_map:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def find_version(*file_paths):
2323
VERSION = find_version("ignite", "__init__.py")
2424

2525
requirements = [
26-
"torch>=1.0,<2",
26+
"torch>=1.3,<2",
2727
]
2828

2929
setup(

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import time
3-
from distutils.version import LooseVersion
43
from unittest.mock import MagicMock, Mock, call
54

65
import numpy as np
@@ -514,36 +513,6 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
514513
_test_run_check_triggered_events()
515514

516515

517-
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.2.0"), reason="No IterableDataset in torch<1.2.0")
518-
def test_engine_with_iterable_dataloader():
519-
520-
from torch.utils.data import DataLoader
521-
522-
def _test(epoch_length=None):
523-
524-
le = 50
525-
num_workers = 4
526-
ds = get_iterable_dataset(0, le)
527-
data_loader = DataLoader(ds, num_workers=num_workers)
528-
529-
counter = [0]
530-
531-
def foo(e, b):
532-
print("{}-{}: {}".format(e.state.epoch, e.state.iteration, b))
533-
counter[0] += 1
534-
535-
engine = Engine(foo)
536-
engine.run(data_loader, epoch_length=epoch_length, max_epochs=5)
537-
538-
epoch_length = le * num_workers if epoch_length is None else epoch_length
539-
assert counter[0] == 5 * epoch_length
540-
541-
_test(epoch_length=20)
542-
543-
# tests issue : https://github.com/pytorch/ignite/issues/1076
544-
_test(epoch_length=None)
545-
546-
547516
def test_engine_random_state():
548517
def random_data_generator():
549518
while True:

0 commit comments

Comments
 (0)