-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #64 from intelligentnode/63-create-keras-agent
63 create keras agent
- Loading branch information
Showing
6 changed files
with
193 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,5 @@ REPLICATE_API_KEY= | |
MISTRAL_API_KEY= | ||
INTELLI_ONE_KEY= | ||
GEMINI_API_KEY= | ||
KAGGLE_USERNAME= | ||
KAGGLE_API_KEY= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from intelli.flow.agents.agent import BasicAgent | ||
from intelli.flow.input.agent_input import AgentInput, TextAgentInput, ImageAgentInput | ||
import os | ||
|
||
class KerasAgent(BasicAgent):
    """Flow agent that generates text with a KerasNLP causal language model.

    The agent either loads a preset model itself (Gemma or Mistral, selected by
    the ``"model"`` entry of ``model_params``) or, when created with
    ``external=True``, waits for a ready model to be supplied through
    :meth:`set_keras_model`.
    """

    def __init__(self, agent_type, provider="", mission="", model_params=None, options=None, log=False, external=False):
        """Store the agent settings and, unless ``external`` is set, load the model.

        Args:
            agent_type: Stored as ``self.type``.
            provider: Stored as ``self.provider``.
            mission: Optional text prepended to every input as ``"<mission>: "``.
            model_params: Dict of model settings. Keys read by this class are
                ``"model"``, ``"max_length"``, ``"KAGGLE_USERNAME"`` and
                ``"KAGGLE_KEY"``. ``None`` means an empty dict.
            options: Extra options dict; ``None`` means an empty dict.
            log: When true, ``execute`` prints the prompt sent to the model.
            external: When true, no model is loaded here; call
                :meth:`set_keras_model` before executing.

        Raises:
            ImportError: If ``keras_nlp`` or ``keras`` cannot be imported.
        """
        super().__init__()

        # set the parameters
        self.type = agent_type
        self.provider = provider
        self.mission = mission
        # A fresh dict per instance: a ``{}`` default in the signature would be
        # shared by every agent created without explicit params.
        self.model_params = model_params if model_params is not None else {}
        self.options = options if options is not None else {}
        self.log = log
        self.external = external
        # Always defined so later calls can detect "no model yet" explicitly.
        self.model = None

        # Keep the try body limited to the imports so that errors raised while
        # loading the model are not re-labelled as a missing dependency.
        try:
            import keras_nlp
            import keras
        except ImportError as e:
            raise ImportError(
                "keras_nlp and keras are required by KerasAgent but could not be imported."
            ) from e

        self.nlp_manager = keras_nlp
        self.keras_manager = keras

        if not external:
            self.model = self.load_model()

    def set_keras_model(self, model, model_params):
        """Attach an externally created model to an agent built with ``external=True``.

        Args:
            model: Object exposing a ``generate`` method.
            model_params: Replacement parameters dict; ``None`` means empty.

        Raises:
            Exception: If the agent was not created with ``external=True``.
            ValueError: If ``model`` has no ``generate`` attribute.
        """
        if not self.external:
            raise Exception("Initiate the agent with external flag to set the model.")

        if not hasattr(model, "generate"):
            raise ValueError("The provided model does not have a 'generate' method, which is required for this agent.")

        self.model = model
        self.model_params = model_params if model_params is not None else {}

    def update_model_params(self, model_params):
        """Replace the stored model parameters; ``None`` resets them to an empty dict."""
        self.model_params = model_params if model_params is not None else {}

    def load_model(self):
        """Load a preset KerasNLP model chosen by ``model_params["model"]``.

        Preset names containing ``"gemma"`` load ``GemmaCausalLM``; names
        containing ``"mistral"`` load ``MistralCausalLM``.

        Returns:
            The model returned by the matching ``from_preset`` call.

        Raises:
            ValueError: If ``model_params`` has no ``"model"`` entry.
            Exception: If the preset name matches no supported model family.
        """
        model_param = self.model_params.get("model")
        if not model_param:
            raise ValueError("model_params must include a 'model' preset name.")

        # Export the Kaggle credentials when provided. Each value is checked
        # individually because os.environ only accepts strings: a missing key or
        # a None value (e.g. an unset variable read with os.getenv) would raise.
        for credential in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
            value = self.model_params.get(credential)
            if value:
                os.environ[credential] = str(value)

        if "gemma" in model_param:
            print("start gemma model")
            return self.nlp_manager.models.GemmaCausalLM.from_preset(model_param)
        elif "mistral" in model_param:
            print("start mistral model")
            return self.nlp_manager.models.MistralCausalLM.from_preset(model_param)
        # ---------------------------------------------------------- #
        # Add similar conditions here to support other model families #
        # ---------------------------------------------------------- #
        else:
            raise Exception("The received model not supported in this version.")

    def execute(self, agent_input: AgentInput):
        """Generate text for ``agent_input`` and return it.

        The prompt is ``agent_input.desc``, prefixed with ``"<mission>: "`` when
        a mission is set. When the model returns a string that starts with the
        prompt, the prompt is removed and the remainder is stripped.

        Args:
            agent_input: Must be a ``TextAgentInput``.

        Returns:
            The output of ``model.generate`` (prompt removed when echoed).

        Raises:
            ValueError: If ``agent_input`` is not a ``TextAgentInput``.
            RuntimeError: If no model has been loaded or set yet.
            NotImplementedError: If the model has no ``generate`` attribute.
        """
        if not isinstance(agent_input, TextAgentInput):
            raise ValueError("This agent requires a TextAgentInput.")

        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError("No model is available; call set_keras_model() before execute().")

        if not hasattr(model, "generate"):
            raise NotImplementedError("Model does not support text generation.")

        max_length = self.model_params.get("max_length", 64)
        model_input = agent_input.desc if not self.mission else self.mission + ": " + agent_input.desc

        if self.log:
            print("Call the model generate with input: ", model_input)

        generated_output = model.generate(model_input, max_length=max_length)

        # Drop the echoed prompt so only the newly generated text is returned.
        if isinstance(generated_output, str) and generated_output.startswith(model_input):
            generated_output = generated_output[len(model_input):].strip()

        return generated_output

    def fine_tune_model_with_lora(self, fine_tuning_config):
        """Fine-tune the current model with LoRA using ``fine_tuning_config``.

        Keys read from ``fine_tuning_config`` (defaults in parentheses):
        ``dataset`` (required), ``lora_rank`` (4), ``sequence_length`` (512),
        ``learning_rate`` (0.001), ``weight_decay`` (0.004), ``beta_1`` (0.9),
        ``beta_2`` (0.999), ``epochs`` (3) and ``batch_size`` (1).

        Raises:
            RuntimeError: If no model has been loaded or set yet.
            ValueError: If the config has no ``dataset``.
        """
        model = getattr(self, "model", None)
        if model is None:
            raise RuntimeError("No model is available; load or set a model before fine tuning.")

        # Validate up front so the model is not reconfigured for a run that
        # cannot start.
        dataset = fine_tuning_config.get('dataset')
        if dataset is None:
            raise ValueError("fine_tuning_config must include a 'dataset'.")

        print("Fine tuning model...")

        # rank=4 replaces the weights matrix of relevant layers with the product AxB of two matrices of rank 4,
        # which reduces the number of trainable parameters.
        lora_rank = fine_tuning_config.get("lora_rank", 4)

        # Enable lora for the model learning
        model.backbone.enable_lora(rank=lora_rank)

        # Set the preprocessor sequence_length
        model.preprocessor.sequence_length = fine_tuning_config.get("sequence_length", 512)

        # Use AdamW optimizer.
        optimizer = self.keras_manager.optimizers.AdamW(
            learning_rate=fine_tuning_config.get("learning_rate", 0.001),
            weight_decay=fine_tuning_config.get("weight_decay", 0.004),
            beta_1=fine_tuning_config.get("beta_1", 0.9),
            beta_2=fine_tuning_config.get("beta_2", 0.999),
        )

        # Exclude layernorm and bias terms from decay.
        optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

        # Compile the model
        model.compile(
            loss=self.keras_manager.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=optimizer,
            weighted_metrics=[self.keras_manager.metrics.SparseCategoricalAccuracy()],
        )

        # Fit using input dataset, epochs and batch size
        epochs = fine_tuning_config.get('epochs', 3)
        batch_size = fine_tuning_config.get('batch_size', 1)
        model.fit(dataset, epochs=epochs, batch_size=batch_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import unittest | ||
from intelli.flow.types import * | ||
from intelli.flow.agents.agent import Agent | ||
from intelli.flow.agents.kagent import KerasAgent | ||
from intelli.flow.input.task_input import TextTaskInput | ||
from intelli.flow.sequence_flow import SequenceFlow | ||
from intelli.flow.tasks.task import Task | ||
from dotenv import load_dotenv | ||
# set the Keras backend | ||
os.environ["KERAS_BACKEND"] = "jax" | ||
# load env | ||
load_dotenv() | ||
|
||
class TestKerasFlows(unittest.TestCase):
    """Integration test that runs a one-task SequenceFlow through a KerasAgent."""

    def setUp(self):
        # Kaggle credentials come from the environment (or .env via load_dotenv);
        # either may be None when the variable is not set.
        self.kaggle_username = os.getenv("KAGGLE_USERNAME")
        self.kaggle_pass = os.getenv("KAGGLE_API_KEY")

    def test_blog_post_flow(self):
        """Run a single blog-post task and check that the flow returns a result."""
        print("---- start simple blog post flow ----")

        # Define agents
        gemma_model_params = {
            "model": "gemma_instruct_2b_en",
            "max_length": 64,
        }
        # Pass the credentials only when both are set: handing None values to
        # the agent makes it assign None into os.environ, which raises TypeError.
        if self.kaggle_username and self.kaggle_pass:
            gemma_model_params["KAGGLE_USERNAME"] = self.kaggle_username
            gemma_model_params["KAGGLE_KEY"] = self.kaggle_pass

        gemma_agent = KerasAgent(agent_type="text",
                                 mission="write blog posts",
                                 model_params=gemma_model_params)

        # Define tasks
        task1 = Task(
            TextTaskInput("blog post about electric cars"), gemma_agent, log=True
        )

        # Start SequenceFlow
        flow = SequenceFlow([task1], log=True)
        final_result = flow.start()

        print("Final result:", final_result)
        self.assertIsNotNone(final_result)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters