Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
broken test (#4754)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored Aug 19, 2022
1 parent e3eff5b commit 5b2fc1d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tests/test_tga.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,31 @@ def test_full_context_block(self):
obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
agent.observe(obs)
batch = agent.batchify([agent.observation])
self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
self.assertEqual(
agent._get_batch_context(batch)[0].tolist(), [5, 4, 6, 7] * 256
)

# observe 1 more obs, context is the same (truncation)
agent.observe(obs)
batch = agent.batchify([agent.observation])
self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
self.assertEqual(
agent._get_batch_context(batch)[0].tolist(), [5, 4, 6, 7] * 256
)

# Now, set agent's beam_block_full_context
args += ['--beam-block-full-context', 'true']
agent2 = create_agent(pp.parse_args(args), True)
agent2.observe(obs)
batch = agent2.batchify([agent2.observation])
self.assertEqual(agent2._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
self.assertEqual(
agent2._get_batch_context(batch)[0].tolist(), [5, 4, 6, 7] * 256
)

# observe 1 more obs, context is larger now
agent2.observe(obs)
batch = agent2.batchify([agent2.observation])
self.assertEqual(
agent2._get_context(batch, 0).tolist(),
agent2._get_batch_context(batch)[0].tolist(),
[5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
) # 3 is end token.

Expand Down

0 comments on commit 5b2fc1d

Please sign in to comment.