Skip to content

Commit

Permalink
Shorten the conversation tests for speed + fixing position overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Oct 20, 2023
1 parent bc4bbd9 commit e26da79
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions tests/pipelines/test_pipelines_conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def get_test_pipeline(self, model, tokenizer, processor):

def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=2)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
)

# Single list
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=2)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
Expand All @@ -96,7 +96,7 @@ def run_pipeline_test(self, conversation_agent, _):
self.assertEqual(len(conversation_1), 1)
self.assertEqual(len(conversation_2), 1)

outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=2)
self.assertEqual(outputs, [conversation_1, conversation_2])
self.assertEqual(
outputs,
Expand All @@ -118,7 +118,7 @@ def run_pipeline_test(self, conversation_agent, _):

# One conversation with history
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2, max_new_tokens=20)
outputs = conversation_agent(conversation_2, max_new_tokens=2)
self.assertEqual(outputs, conversation_2)
self.assertEqual(
outputs,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipeline_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def data(n):

out = []
if task == "conversational":
for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
for item in pipeline(data(10), batch_size=4, max_new_tokens=2):
out.append(item)
else:
for item in pipeline(data(10), batch_size=4):
Expand Down

0 comments on commit e26da79

Please sign in to comment.