Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dexterju27 committed Mar 8, 2021
1 parent 3c211a1 commit 0c926d2
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/test_torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,3 +1051,17 @@ def get_opt(init_mf, mf):
init_model_file, is_finetune = agent._get_init_model(popt, None)
self.assertEqual(init_model_file, '{}.checkpoint'.format(mf))
self.assertFalse(is_finetune)

def test_truncate_metrics(self):
agent = get_agent(model='test_agents/unigram', truncate=5)
obs = {
'text': "I'll be back. I'll be back. I'll be back.",
'labels': ["I'll be back. I'll be back. I'll be back."],
'episode_done': True,
}
obs = agent.observe(obs)
agent.act()
self.assertEqual(agent._local_metrics['context_truncate'][0].value(), 1.0)
self.assertEqual(agent._local_metrics['label_truncate'][0].value(), 1.0)
self.assertEqual(agent._local_metrics['context_length'][0].value(), 9)
self.assertEqual(agent._local_metrics['label_length'][0].value(), 11)

0 comments on commit 0c926d2

Please sign in to comment.