Skip to content

Commit

Permalink
Merge pull request #620 from 3rd-Son/coheremodel
Browse files Browse the repository at this point in the history
swarm -implemented async, stream, astream, batch and abatch
  • Loading branch information
cobycloud authored Oct 9, 2024
2 parents 5a404d2 + 618e4d2 commit a969f4f
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 39 deletions.
124 changes: 96 additions & 28 deletions pkgs/swarmauri/swarmauri/llms/concrete/CohereModel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from typing import List, Dict, Literal
import asyncio
from typing import List, Dict, Literal, AsyncIterator, Iterator
from pydantic import Field
import cohere
from swarmauri_core.typing import SubclassUnion

Expand All @@ -26,45 +27,112 @@ class CohereModel(LLMBase):
]
name: str = "command"
type: Literal["CohereModel"] = "CohereModel"
client: cohere.ClientV2 = Field(default=None, exclude=True)

def __init__(self, **data):
    """Initialize the model and create a Cohere V2 client from the configured api_key."""
    super().__init__(**data)
    # Client is excluded from serialization (see the Field(exclude=True) declaration).
    self.client = cohere.ClientV2(api_key=self.api_key)

def _format_messages(
self, messages: List[SubclassUnion[MessageBase]]
) -> List[Dict[str, str]]:
"""
Cohere utilizes the following roles: CHATBOT, SYSTEM, TOOL, USER
"""
message_properties = ["content", "role"]

messages = [
message.model_dump(include=message_properties) for message in messages
]
formatted_messages = []
for message in messages:
message["message"] = message.pop("content")
if message.get("role") == "assistant":
message["role"] = "chatbot"
message["role"] = message["role"].upper()
logging.info(messages)
return messages
role = message.role
if role == "assistant":
role = "assistant"
formatted_messages.append({"role": role, "content": message.content})
return formatted_messages

def predict(self, conversation, temperature=0.7, max_tokens=256):
    """
    Generate the next assistant message for a conversation synchronously.

    Args:
        conversation: Conversation whose full history is sent to the chat API.
        temperature (float): Sampling temperature forwarded to the API.
        max_tokens (int): Maximum number of tokens to generate.

    Returns:
        The same conversation with the generated AgentMessage appended.
    """
    formatted_messages = self._format_messages(conversation.history)

    response = self.client.chat(
        model=self.name,
        messages=formatted_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # V2 responses carry a list of content parts; the first part holds the text.
    message_content = response.message.content[0].text
    conversation.add_message(AgentMessage(content=message_content))
    return conversation
async def apredict(self, conversation, temperature=0.7, max_tokens=256):
    """
    Async variant of predict.

    The Cohere client is synchronous, so the blocking call is moved to a
    worker thread via asyncio.to_thread to keep the event loop responsive.

    Args:
        conversation: Conversation whose full history is sent to the chat API.
        temperature (float): Sampling temperature forwarded to the API.
        max_tokens (int): Maximum number of tokens to generate.

    Returns:
        The same conversation with the generated AgentMessage appended.
    """
    formatted_messages = self._format_messages(conversation.history)

    response = await asyncio.to_thread(
        self.client.chat,
        model=self.name,
        messages=formatted_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    message_content = response.message.content[0].text
    conversation.add_message(AgentMessage(content=message_content))
    return conversation

def stream(self, conversation, temperature=0.7, max_tokens=256) -> Iterator[str]:
    """
    Stream response tokens for a conversation synchronously.

    Yields each text delta as it arrives; once the stream is exhausted, the
    concatenated response is appended to the conversation as an AgentMessage.
    """
    chat_messages = self._format_messages(conversation.history)

    event_stream = self.client.chat_stream(
        model=self.name,
        messages=chat_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    pieces = []
    for event in event_stream:
        # Only content-delta events carry generated text.
        if not event or event.type != "content-delta":
            continue
        token = event.delta.message.content.text
        pieces.append(token)
        yield token

    conversation.add_message(AgentMessage(content="".join(pieces)))

async def astream(
    self, conversation, temperature=0.7, max_tokens=256
) -> AsyncIterator[str]:
    """
    Asynchronously stream response tokens for a conversation.

    Yields each text delta as it arrives; once the stream is exhausted, the
    concatenated response is appended to the conversation as an AgentMessage.

    Fix: the previous implementation iterated the blocking chat_stream
    generator directly inside the async generator, which blocked the event
    loop while waiting for each network chunk (asyncio.sleep(0) only yielded
    control *after* a chunk arrived). Each next() call now runs in a worker
    thread so other tasks keep running while we wait.
    """
    formatted_messages = self._format_messages(conversation.history)

    stream = await asyncio.to_thread(
        self.client.chat_stream,
        model=self.name,
        messages=formatted_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    collected_content = []
    iterator = iter(stream)
    sentinel = object()
    while True:
        # Blocking network read happens off the event loop thread.
        chunk = await asyncio.to_thread(next, iterator, sentinel)
        if chunk is sentinel:
            break
        if chunk and chunk.type == "content-delta":
            content = chunk.delta.message.content.text
            collected_content.append(content)
            yield content

    full_content = "".join(collected_content)
    conversation.add_message(AgentMessage(content=full_content))

def batch(self, conversations: List, temperature=0.7, max_tokens=256) -> List:
    """Run predict over each conversation sequentially and return the results in order."""
    completed = []
    for conversation in conversations:
        completed.append(
            self.predict(conversation, temperature=temperature, max_tokens=max_tokens)
        )
    return completed

async def abatch(
    self, conversations: List, temperature=0.7, max_tokens=256, max_concurrent=5
) -> List:
    """
    Run apredict concurrently over all conversations and return results in
    input order, with at most max_concurrent API calls in flight at once.
    """
    limiter = asyncio.Semaphore(max_concurrent)

    async def _bounded(conversation):
        # The semaphore caps concurrent requests against the API.
        async with limiter:
            return await self.apredict(
                conversation, temperature=temperature, max_tokens=max_tokens
            )

    return await asyncio.gather(*(_bounded(conv) for conv in conversations))
114 changes: 103 additions & 11 deletions pkgs/swarmauri/tests/unit/llms/CohereModel_unit_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import pytest
import os
import asyncio
from swarmauri.llms.concrete.CohereModel import CohereModel as LLM
from swarmauri.conversations.concrete.Conversation import Conversation

from swarmauri.messages.concrete.HumanMessage import HumanMessage
from swarmauri.messages.concrete.SystemMessage import SystemMessage
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("COHERE_API_KEY")

Expand Down Expand Up @@ -52,7 +49,6 @@ def test_default_name(cohere_model):
def test_no_system_context(cohere_model, model_name):
model = cohere_model
model.name = model_name

conversation = Conversation()

input_data = "Hello"
Expand All @@ -69,18 +65,114 @@ def test_no_system_context(cohere_model, model_name):
def test_preamble_system_context(cohere_model, model_name):
    """A system message constraining the reply to "Jeff" is honored by predict."""
    model = cohere_model
    model.name = model_name
    conversation = Conversation()

    system_context = 'You only respond with the following phrase, "Jeff"'
    system_message = SystemMessage(content=system_context)
    conversation.add_message(system_message)

    input_data = "Hi"
    human_message = HumanMessage(content=input_data)
    conversation.add_message(human_message)

    model.predict(conversation=conversation)
    prediction = conversation.get_last().content
    assert type(prediction) == str
    assert "Jeff" in prediction


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_stream(cohere_model, model_name):
    """Streaming yields str tokens whose concatenation equals the stored reply."""
    model = cohere_model
    model.name = model_name

    conversation = Conversation()
    conversation.add_message(
        HumanMessage(content="Write a short story about a cat.")
    )

    tokens = []
    for token in model.stream(conversation=conversation):
        assert isinstance(token, str)
        tokens.append(token)

    joined = "".join(tokens)
    assert len(joined) > 0
    assert conversation.get_last().content == joined


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
async def test_apredict(cohere_model, model_name):
    """apredict returns the conversation with a string reply appended."""
    model = cohere_model
    model.name = model_name

    conversation = Conversation()
    conversation.add_message(HumanMessage(content="Hello"))

    result = await model.apredict(conversation=conversation)
    reply = result.get_last().content
    assert isinstance(reply, str)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
async def test_astream(cohere_model, model_name):
    """Async streaming yields str tokens whose concatenation equals the stored reply."""
    model = cohere_model
    model.name = model_name

    conversation = Conversation()
    conversation.add_message(
        HumanMessage(content="Write a short story about a dog.")
    )

    tokens = []
    async for token in model.astream(conversation=conversation):
        assert isinstance(token, str)
        tokens.append(token)

    joined = "".join(tokens)
    assert len(joined) > 0
    assert conversation.get_last().content == joined


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
def test_batch(cohere_model, model_name):
    """batch returns one completed conversation per prompt, each ending in a str."""
    model = cohere_model
    model.name = model_name

    conversations = []
    for prompt in ("Hello", "Hi there", "Good morning"):
        conversation = Conversation()
        conversation.add_message(HumanMessage(content=prompt))
        conversations.append(conversation)

    results = model.batch(conversations=conversations)
    assert len(results) == len(conversations)
    for completed in results:
        assert isinstance(completed.get_last().content, str)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.unit
async def test_abatch(cohere_model, model_name):
    """abatch returns one completed conversation per prompt, each ending in a str."""
    model = cohere_model
    model.name = model_name

    conversations = []
    for prompt in ("Hello", "Hi there", "Good morning"):
        conversation = Conversation()
        conversation.add_message(HumanMessage(content=prompt))
        conversations.append(conversation)

    results = await model.abatch(conversations=conversations)
    assert len(results) == len(conversations)
    for completed in results:
        assert isinstance(completed.get_last().content, str)

0 comments on commit a969f4f

Please sign in to comment.