From 0c926d2107aba07babdd6dbbac871d9811247db5 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 8 Mar 2021 11:21:26 -0800 Subject: [PATCH] add test --- tests/test_torch_agent.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_torch_agent.py b/tests/test_torch_agent.py index 513e80b0d16..144ae2c70b2 100644 --- a/tests/test_torch_agent.py +++ b/tests/test_torch_agent.py @@ -1051,3 +1051,17 @@ def get_opt(init_mf, mf): init_model_file, is_finetune = agent._get_init_model(popt, None) self.assertEqual(init_model_file, '{}.checkpoint'.format(mf)) self.assertFalse(is_finetune) + + def test_truncate_metrics(self): + agent = get_agent(model='test_agents/unigram', truncate=5) + obs = { + 'text': "I'll be back. I'll be back. I'll be back.", + 'labels': ["I'll be back. I'll be back. I'll be back."], + 'episode_done': True, + } + obs = agent.observe(obs) + agent.act() + self.assertEqual(agent._local_metrics['context_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['label_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['context_length'][0].value(), 9) + self.assertEqual(agent._local_metrics['label_length'][0].value(), 11)