Skip to content

Commit

Permalink
Test namedtuples sent to device correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanbreitsch committed Apr 29, 2020
1 parent 59fa3cc commit 7ad55ab
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/models/test_cpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import namedtuple
import platform

import pytest
Expand Down Expand Up @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse():
assert batch[1][0]['b'].device.index == 0
assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor'

# namedtuple of tensor
BatchType = namedtuple('BatchType', ['a', 'b'])
batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
batch = trainer.transfer_batch_to_gpu(batch, 0)
assert batch[0].a.device.index == 0
assert batch[0].a.type() == 'torch.cuda.FloatTensor'


def test_simple_cpu(tmpdir):
"""Verify continue training session on CPU."""
Expand Down

0 comments on commit 7ad55ab

Please sign in to comment.