Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Refactor llms code using OOP, and model registry #65

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
442 changes: 12 additions & 430 deletions src/llms/llms.py

Large diffs are not rendered by default.

Empty file added src/llms/llms_v2/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions src/llms/llms_v2/base_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from openai import OpenAI
import os
import json
import re
import time
from datetime import datetime
from src.context.simple_context import SimpleContextManager
import logging
from abc import ABC, abstractmethod


class BaseLLMKernel(ABC):
    """Abstract base class for AIOS LLM kernels.

    Concrete subclasses implement :meth:`load_llm_and_tokenizer` (to set
    ``self.model`` / ``self.tokenizer``) and :meth:`process` (to serve an
    agent process request).
    """

    def __init__(self,
        llm_name: str,
        max_gpu_memory: dict = None,
        eval_device: str = None,
        max_new_tokens: int = 256,
        log_mode: str = "console"
        ):
        """Load the model config, set up logging, then load the model.

        Args:
            llm_name: name of the model; selects ``src/llm_config/<name>.json``.
            max_gpu_memory: per-device memory budget for open-sourced models.
            eval_device: device to run evaluation on (e.g. "cuda:0").
            max_new_tokens: generation length cap.
            log_mode: "console" or "file".
        """
        print("Initialize AIOS powered by LLM: {}".format(llm_name))
        self.config = self.load_config(llm_name)
        self.max_gpu_memory = max_gpu_memory
        self.eval_device = eval_device

        self.log_mode = log_mode
        self.MAX_NEW_TOKENS = max_new_tokens
        self.logger = self.setup_logger()
        self.context_manager = SimpleContextManager()

        # Config-derived attributes must be assigned BEFORE calling
        # load_llm_and_tokenizer(): subclasses read self.model_name there.
        self.open_sourced = self.config["open_sourced"]
        self.model_type = self.config["model_type"]
        self.model_name = self.config["model_name"]

        # Placeholders; load_llm_and_tokenizer() populates them.
        # (BUG FIX: these were previously assigned AFTER loading, which
        # clobbered the freshly loaded model and tokenizer with None.)
        self.model = None
        self.tokenizer = None
        self.load_llm_and_tokenizer()
        print("AIOS LLM successfully loaded. ")

    def convert_map(self, map: dict) -> dict:
        """Return a copy of *map* with its keys converted from str to int.

        Used to restore integer device ids after JSON round-tripping
        (JSON object keys are always strings).
        """
        return {int(k): v for k, v in map.items()}

    def load_config(self, llm_name):
        """Load the per-model JSON config from src/llm_config/<llm_name>.json."""
        config_file = os.path.join(
            os.getcwd(), "src", "llm_config", "{}.json".format(llm_name)
        )
        with open(config_file, "r") as f:
            return json.load(f)

    def setup_logger(self):
        """Create the kernel logger, emitting to console or a timestamped file."""
        logger = logging.getLogger(f"FIFO Scheduler Logger")
        date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Don't propagate to the root logger; this logger owns its handlers.
        logger.propagate = False

        # Remove stale handlers so repeated construction doesn't duplicate output.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)

        # Two log modes are supported: console and file.
        if self.log_mode == "console":
            handler = logging.StreamHandler()
            handler.setLevel(logging.INFO)
        else:
            assert self.log_mode == "file"
            log_dir = os.path.join(os.getcwd(), "logs", "scheduler")
            os.makedirs(log_dir, exist_ok=True)
            log_file = os.path.join(log_dir, f"{date_time}.txt")
            handler = logging.FileHandler(log_file)
            handler.setLevel(logging.INFO)

        logger.addHandler(handler)  # enabled when run in a simulated shell
        return logger

    @abstractmethod
    def load_llm_and_tokenizer(self) -> None:
        """Load model (and tokenizer, if any) per self.config into self.model/self.tokenizer."""
        raise NotImplementedError

    def address_request(self,
            agent_process,
            temperature=0.0
        ):
        """Serve an agent process request by delegating to :meth:`process`."""
        self.process(agent_process)
        return

    @abstractmethod
    def process(self,
                agent_process,
                temperature=0.0) -> None:
        """Run inference for *agent_process* and record its response/status/timing."""
        raise NotImplementedError

5 changes: 5 additions & 0 deletions src/llms/llms_v2/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from transformers import AutoModelForCausalLM

# Maps a model-type string from the per-model JSON config to the
# Transformers auto-class used to load it.
MODEL_CLASS = dict(
    causal_lm=AutoModelForCausalLM,
)
55 changes: 55 additions & 0 deletions src/llms/llms_v2/gemini_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import re

import torch
from .constant import MODEL_CLASS
from .base_llm import BaseLLMKernel
import time
from transformers import AutoTokenizer
from ...utils.utils import get_from_env


class GeminiLLM(BaseLLMKernel):
    """LLM kernel backed by Google's Gemini API (gemini-pro only)."""

    def __init__(self, llm_name: str,
                 max_gpu_memory: dict = None,
                 eval_device: str = None,
                 max_new_tokens: int = 256,
                 log_mode: str = "console"):
        super().__init__(llm_name,
                         max_gpu_memory,
                         eval_device,
                         max_new_tokens,
                         log_mode)

    def load_llm_and_tokenizer(self) -> None:
        """Configure the Gemini client; no local tokenizer is needed."""
        assert self.model_name == "gemini-pro"
        # Only the import itself is guarded: configuration/model errors
        # should surface as-is, not be mislabeled as an ImportError.
        try:
            import google.generativeai as genai
        except ImportError as err:
            raise ImportError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`."
            ) from err
        gemini_api_key = get_from_env("GEMINI_API_KEY")
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel(self.model_name)
        self.tokenizer = None

    def process(self,
                agent_process,
                temperature=0.0) -> None:
        """Generate a response for *agent_process* via the Gemini API.

        Raises:
            IndexError: if the API returns no usable candidate text.
        """
        assert re.search(r'gemini', self.model_name, re.IGNORECASE)
        agent_process.set_status("executing")
        agent_process.set_start_time(time.time())
        prompt = agent_process.prompt
        outputs = self.model.generate_content(
            prompt
        )
        try:
            result = outputs.candidates[0].content.parts[0].text
            agent_process.set_response(result)
        except IndexError as err:
            raise IndexError(
                f"{self.model_name} can not generate a valid result, please try again"
            ) from err
        agent_process.set_status("done")
        agent_process.set_end_time(time.time())
        return
31 changes: 31 additions & 0 deletions src/llms/llms_v2/gpt_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import re
from .base_llm import BaseLLMKernel
import time
from openai import OpenAI

class GPTLLM(BaseLLMKernel):
    """LLM kernel backed by the OpenAI chat-completions API."""

    def load_llm_and_tokenizer(self) -> None:
        # API-backed model: no local weights or tokenizer to load.
        self.model = OpenAI()
        self.tokenizer = None

    def process(self,
                agent_process,
                temperature=0.0
                ):
        """Generate a chat completion for *agent_process* and record it.

        BUG FIXES vs. the original:
        - Renamed from ``gpt_process`` to ``process`` so the abstract
          ``BaseLLMKernel.process`` is actually implemented (the class was
          otherwise uninstantiable and ``address_request`` would fail).
        - Removed the trailing comma after ``agent_process.prompt`` which
          made ``prompt`` a 1-tuple instead of a string.
        - The accepted ``temperature`` argument is now forwarded to the API
          instead of being silently ignored.
        """
        assert re.search(r'gpt', self.model_name, re.IGNORECASE)
        agent_process.set_status("executing")
        agent_process.set_start_time(time.time())
        prompt = agent_process.prompt
        print(f"Prompt: {prompt}")
        response = self.model.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            temperature=temperature
        )
        time.sleep(2)  # crude rate limiting; set according to your requests-per-minute quota
        agent_process.set_response(response.choices[0].message.content)
        agent_process.set_status("done")
        agent_process.set_end_time(time.time())
        return

    # Backward-compatible alias for callers using the old method name.
    gpt_process = process
10 changes: 10 additions & 0 deletions src/llms/llms_v2/model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .gpt_llm import GPTLLM
from .open_llm import OpenLLM
from .gemini_llm import GeminiLLM
from .redblock_llm import RedBlockLLM

# Registry of closed-source (API-backed) LLM kernels, keyed by model name.
MODEL_REGISTRY = {
    'bedrock/anthropic.claude-3-haiku-20240307-v1:0': RedBlockLLM,
    'gemini-pro': GeminiLLM,
    'gpt-3.5': GPTLLM,
    'gpt-4': GPTLLM,
}
Loading