Skip to content

Commit

Permalink
Merge pull request #64 from intelligentnode/63-create-keras-agent
Browse files Browse the repository at this point in the history
63 create keras agent
  • Loading branch information
intelligentnode authored Mar 2, 2024
2 parents 348eddd + a1caeb1 commit 1051bb5
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 4 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ wrapper = RemoteImageModel(your_api_key, provider)
results = wrapper.generate_images(image_input)
```

# The Repository Setup
## Keras Agent
Load gemma or mistral models offline using keras agent, [check the docs](https://docs.intellinode.ai/docs/python/flows/kagent).

# Repository Setup
1. Install the requirements.
```shell
pip install -r requirements.txt
Expand Down
4 changes: 3 additions & 1 deletion instructions/run_integration_text.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ python3 -m unittest intelli.test.integration.test_chatbot_with_data
# basic flow
python3 -m unittest intelli.test.integration.test_flow_sequence
# map flow
python3 -m unittest intelli.test.integration.test_flow_map
python3 -m unittest intelli.test.integration.test_flow_map
# keras nlp
python3 -m unittest intelli.test.integration.test_keras_agent
2 changes: 2 additions & 0 deletions intelli/.example.env
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ REPLICATE_API_KEY=
MISTRAL_API_KEY=
INTELLI_ONE_KEY=
GEMINI_API_KEY=
KAGGLE_USERNAME=
KAGGLE_API_KEY=
133 changes: 133 additions & 0 deletions intelli/flow/agents/kagent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from intelli.flow.agents.agent import BasicAgent
from intelli.flow.input.agent_input import AgentInput, TextAgentInput, ImageAgentInput
import os

class KerasAgent(BasicAgent):
def __init__(self, agent_type, provider="", mission="", model_params={}, options=None, log=False, external=False):
super().__init__()

# set the parameters
self.type = agent_type
self.provider = provider
self.mission = mission
self.model_params = model_params
self.options = options if options is not None else {}
self.log = log
self.external = external
try:
import keras_nlp
import keras
self.nlp_manager = keras_nlp
self.keras_manager = keras
if not external:
self.model = self.load_model()
except ImportError as e:
raise ImportError("keras_nlp is not installed or model is not supported.") from e

def set_keras_model(self, model, model_params):
if not self.external:
raise Exception("Initiate the agent with external flag to set the model.")

if not hasattr(model, "generate"):
raise ValueError("The provided model does not have a 'generate' method, which is required for this agent.")

self.model = model
self.model_params = model_params

def update_model_params(self, model_params):

self.model_params = model_params

def load_model(self):
"""
Dynamically load a model based on `model_params`.
This example demonstrates loading Gemma models, but you should add similar logic for other models.
"""

model_param = self.model_params["model"]

# set the username and password
if "KAGGLE_USERNAME" in self.model_params:
os.environ["KAGGLE_USERNAME"] = self.model_params["KAGGLE_USERNAME"]
os.environ["KAGGLE_KEY"] = self.model_params["KAGGLE_KEY"]

if "gemma" in model_param:
print("start gemma model")
return self.nlp_manager.models.GemmaCausalLM.from_preset(model_param)
elif "mistral" in model_param:
print("start mistral model")
return self.nlp_manager.models.MistralCausalLM.from_preset(model_param)
# ------------------------------------------------------------------ #
# Add similar conditions for models like Mistral, RoBERTa, or BERT #
# ------------------------------------------------------------------ #
else:
raise Exception("The received model not supported in this version.")

def execute(self, agent_input: AgentInput):
"""
Execute the agent task based on input.
"""

if not isinstance(agent_input, TextAgentInput):
raise ValueError("This agent requires a TextAgentInput.")

max_length = self.model_params.get("max_length", 64)
model_input = agent_input.desc if not self.mission else self.mission + ": " + agent_input.desc

if hasattr(self.model, "generate"):
if self.log:
print("Call the model generate with input: ", model_input)

generated_output = self.model.generate(model_input, max_length=max_length)

if isinstance(generated_output, str) and generated_output.startswith(model_input):
generated_output = generated_output.replace(model_input, "", 1).strip()

return generated_output
else:
raise NotImplementedError("Model does not support text generation.")

def fine_tune_model_with_lora(self, fine_tuning_config):
"""
Finetunes the model as per the provided config.
"""
print("Fine tuning model...")

# rank=4 replaces the weights matrix of relevant layers with the product AxB of two matrices of rank 4,
# which reduces the number of trainable parameters.
lora_rank = fine_tuning_config.get("lora_rank", 4)

# Enable lora for the model learning
self.model.backbone.enable_lora(rank=lora_rank)

# Set the preprocessor sequence_length
self.model.preprocessor.sequence_length = fine_tuning_config.get("sequence_length", 512)

# Use AdamW optimizer.
learning_rate = fine_tuning_config.get("learning_rate", 0.001)
weight_decay = fine_tuning_config.get("weight_decay", 0.004)
beta_1 = fine_tuning_config.get("beta_1", 0.9)
beta_2 = fine_tuning_config.get("beta_2", 0.999)

optimizer = self.keras_manager.optimizers.AdamW(
learning_rate=learning_rate,
weight_decay=weight_decay,
beta_1=beta_1,
beta_2=beta_2
)

# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Compile the model
self.model.compile(
loss=self.keras_manager.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[self.keras_manager.metrics.SparseCategoricalAccuracy()],
)

# Fit using input dataset, epochs and batch size
dataset = fine_tuning_config.get('dataset')
epochs = fine_tuning_config.get('epochs', 3)
batch_size = fine_tuning_config.get('batch_size', 1)
self.model.fit(dataset, epochs=epochs, batch_size=batch_size)
48 changes: 48 additions & 0 deletions intelli/test/integration/test_keras_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import unittest
from intelli.flow.types import *
from intelli.flow.agents.agent import Agent
from intelli.flow.agents.kagent import KerasAgent
from intelli.flow.input.task_input import TextTaskInput
from intelli.flow.sequence_flow import SequenceFlow
from intelli.flow.tasks.task import Task
from dotenv import load_dotenv
# set keras back
os.environ["KERAS_BACKEND"] = "jax"
# load env
load_dotenv()

class TestKerasFlows(unittest.TestCase):
def setUp(self):
self.kaggle_username = os.getenv("KAGGLE_USERNAME")
self.kaggle_pass = os.getenv("KAGGLE_API_KEY")

def test_blog_post_flow(self):
print("---- start simple blog post flow ----")

# Define agents
gemma_model_params = {
"model": "gemma_instruct_2b_en",
"max_length": 64,
"KAGGLE_USERNAME": self.kaggle_username,
"KAGGLE_KEY": self.kaggle_pass,
}
gemma_agent = KerasAgent(agent_type="text",
mission="write blog posts",
model_params=gemma_model_params)

# Define tasks
task1 = Task(
TextTaskInput("blog post about electric cars"), gemma_agent, log=True
)

# Start SequenceFlow
flow = SequenceFlow([task1], log=True)
final_result = flow.start()

print("Final result:", final_result)
self.assertIsNotNone(final_result)


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="intelli",
version="0.1.6",
version="0.2.0",
author="Intellinode",
author_email="admin@intellinode.ai",
description="Create chatbots and AI agent workflows. Intelli provide unified layer to connect your data with multiple AI models like OpenAI, Gemini, and Mistral.",
Expand All @@ -21,6 +21,7 @@
"python-dotenv==1.0.1", "networkx==3.2.1",
],
extras_require={
'visual': ["matplotlib==3.6.0"],
"visual": ["matplotlib==3.6.0"],
"offline": ["keras-nlp", "keras>=3"],
}
)

0 comments on commit 1051bb5

Please sign in to comment.