From de07b0ed964ddd831845bce4a757be3e0217909a Mon Sep 17 00:00:00 2001 From: Kurt Shuster Date: Fri, 14 May 2021 18:00:22 -0400 Subject: [PATCH] [Chunk Teacher] Remove exception for specifying non-streaming data (#3653) * remove exception * update chunk teacher tests --- parlai/core/teachers.py | 3 --- tests/test_teachers.py | 40 +++++++++++++++++++--------------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/parlai/core/teachers.py b/parlai/core/teachers.py index b81b1a74249..5040fb147a9 100644 --- a/parlai/core/teachers.py +++ b/parlai/core/teachers.py @@ -2291,9 +2291,6 @@ def __init__(self, opt, shared=None): super().__init__(opt, shared) self.buffersize = self.get_buffersize() - if 'stream' not in opt['datatype']: - raise ValueError('Chunk teacher should be used with streaming. ') - self.set_datasettings(opt) self.dws = int(self.opt.get('distributed_world_size', 1)) diff --git a/tests/test_teachers.py b/tests/test_teachers.py index 88c1c384340..5a0a5a0e5f6 100644 --- a/tests/test_teachers.py +++ b/tests/test_teachers.py @@ -301,27 +301,25 @@ def test_dynamic_batched(self): assert valid['exs'] == 100 assert test['exs'] == 100 - def test_stream_only(self): - with self.assertRaises(ValueError): - valid, test = testing_utils.eval_model( - dict( - task='integration_tests:chunky', - model='parlai.agents.test_agents.test_agents:MockTorchAgent', - batchsize=32, - ), - valid_datatype='valid', - ) - - with self.assertRaises(ValueError): - valid, test = testing_utils.eval_model( - dict( - task='integration_tests:chunky', - model='parlai.agents.test_agents.test_agents:MockTorchAgent', - batchsize=32, - ), - valid_datatype='valid:stream', - test_datatype='test', - ) + def test_non_stream_works(self): + testing_utils.eval_model( + dict( + task='integration_tests:chunky', + model='parlai.agents.test_agents.test_agents:MockTorchAgent', + batchsize=32, + ), + valid_datatype='valid', + ) + + testing_utils.eval_model( + dict( + task='integration_tests:chunky', + model='parlai.agents.test_agents.test_agents:MockTorchAgent', + batchsize=32, + ), + valid_datatype='valid:stream', + test_datatype='test', + ) class CustomEvaluationTeacher(DialogTeacher):