Feat add plugin for privacy #49

Merged · 7 commits · Oct 2, 2024
16 changes: 13 additions & 3 deletions litellm_wrapper.py
@@ -3,6 +3,16 @@
from litellm import completion
from typing import List, Dict, Any, Optional

SAFETY_SETTINGS = [
    {"category": cat, "threshold": "BLOCK_NONE"}
    for cat in [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT"
    ]
]

class LiteLLMWrapper:
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
self.api_key = api_key
@@ -14,7 +24,7 @@ class Chat:
class Completions:
@staticmethod
def create(model: str, messages: List[Dict[str, str]], **kwargs):
-                response = completion(model=model, messages=messages, **kwargs)
+                response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS)
# Convert LiteLLM response to match OpenAI response structure
return response

@@ -28,8 +38,8 @@ def list():
# This list can be expanded as needed.
return {
"data": [
{"id": "gpt-3.5-turbo"},
{"id": "gpt-4"},
{"id": "gpt-4o-mini"},
{"id": "gpt-4o"},
{"id": "command-nightly"},
# Add more models as needed
]
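Note for reviewers: the new `SAFETY_SETTINGS` list is forwarded on every completion call, which matters mainly for Gemini-family models routed through LiteLLM (LiteLLM passes `safety_settings` through to Google providers). A minimal sketch of the effect; the model id and prompt below are illustrative, not from this PR:

```python
# Sketch of what the wrapper now sends through LiteLLM; model id and
# prompt are placeholders for illustration only.
from litellm import completion

response = completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "hello"}],
    # Same shape as SAFETY_SETTINGS above: one entry per harm category.
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    ],
)
print(response.choices[0].message.content)
```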
80 changes: 48 additions & 32 deletions optillm.py
@@ -22,7 +22,7 @@
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
-from optillm.z3_solver import Z3SolverSystem
+from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
@@ -44,31 +44,34 @@
# Initialize Flask app
app = Flask(__name__)

-# OpenAI, Azure, or LiteLLM API configuration
-if os.environ.get("OPENAI_API_KEY"):
-    API_KEY = os.environ.get("OPENAI_API_KEY")
-    default_client = OpenAI(api_key=API_KEY)
-elif os.environ.get("AZURE_OPENAI_API_KEY"):
-    API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
-    API_VERSION = os.environ.get("AZURE_API_VERSION")
-    AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE")
-    if API_KEY is not None:
-        default_client = AzureOpenAI(
-            api_key=API_KEY,
-            api_version=API_VERSION,
-            azure_endpoint=AZURE_ENDPOINT,
-        )
-    else:
-        from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-        azure_credential = DefaultAzureCredential()
-        token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
-        default_client = AzureOpenAI(
-            api_version=API_VERSION,
-            azure_endpoint=AZURE_ENDPOINT,
-            azure_ad_token_provider=token_provider
-        )
-else:
-    default_client = LiteLLMWrapper()
+def get_config():
+    API_KEY = None
+    # OpenAI, Azure, or LiteLLM API configuration
+    if os.environ.get("OPENAI_API_KEY"):
+        API_KEY = os.environ.get("OPENAI_API_KEY")
+        default_client = OpenAI(api_key=API_KEY)
+    elif os.environ.get("AZURE_OPENAI_API_KEY"):
+        API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
+        API_VERSION = os.environ.get("AZURE_API_VERSION")
+        AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE")
+        if API_KEY is not None:
+            default_client = AzureOpenAI(
+                api_key=API_KEY,
+                api_version=API_VERSION,
+                azure_endpoint=AZURE_ENDPOINT,
+            )
+        else:
+            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+            azure_credential = DefaultAzureCredential()
+            token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
+            default_client = AzureOpenAI(
+                api_version=API_VERSION,
+                azure_endpoint=AZURE_ENDPOINT,
+                azure_ad_token_provider=token_provider
+            )
+    else:
+        default_client = LiteLLMWrapper()
+    return default_client, API_KEY

# Server configuration
server_config = {
@@ -156,7 +159,7 @@ def execute_single_approach(approach, system_prompt, initial_query, client, model):
elif approach == 'rto':
return round_trip_optimization(system_prompt, initial_query, client, model)
elif approach == 'z3':
-        z3_solver = Z3SolverSystem(system_prompt, client, model)
+        z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
return z3_solver.process_query(initial_query)
elif approach == "self_consistency":
return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
@@ -263,6 +266,14 @@ def check_api_key():
def proxy():
logger.info('Received request to /v1/chat/completions')
data = request.get_json()
+    auth_header = request.headers.get("Authorization")
+    bearer_token = ""
+
+    if auth_header and auth_header.startswith("Bearer "):
+        # Extract the bearer token
+        bearer_token = auth_header.split("Bearer ")[1].strip()
+        logger.debug(f"Intercepted Bearer Token: {bearer_token}")

logger.debug(f'Request data: {data}')

stream = data.get('stream', False)
@@ -281,15 +292,20 @@ def proxy():
model = f"{optillm_approach}-{model}"

base_url = server_config['base_url']

if base_url != "":
client = OpenAI(api_key=API_KEY, base_url=base_url)
else:
client = default_client
default_client, api_key = get_config()

operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')

if bearer_token != "" and bearer_token.startswith("sk-") and model.startswith("gpt"):
api_key = bearer_token
if base_url != "":
client = OpenAI(api_key=api_key, base_url=base_url)
else:
client = OpenAI(api_key=api_key)
else:
client = default_client

try:
if operation == 'SINGLE':
final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
@@ -342,7 +358,7 @@ def proxy():
@app.route('/v1/models', methods=['GET'])
def proxy_models():
logger.info('Received request to /v1/models')

+    default_client, API_KEY = get_config()
try:
if server_config['base_url']:
client = OpenAI(api_key=API_KEY, base_url=server_config['base_url'])
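Taken together, the optillm.py changes defer client construction to `get_config()` at request time and let a caller's own OpenAI key ride through on the `Authorization` header (honored only for `sk-` keys used with `gpt-*` models). A hedged sketch of how a client would exercise this, assuming the proxy listens on localhost port 8000 (adjust to your deployment):

```python
# Sketch of the per-request key flow; URL, port, and key are placeholders.
from openai import OpenAI

client = OpenAI(
    api_key="sk-...",                     # forwarded to proxy() as the Bearer token
    base_url="http://localhost:8000/v1",  # assumed optillm address
)

# The approach prefix ("moa-") is parsed off by parse_combined_approach.
resp = client.chat.completions.create(
    model="moa-gpt-4o-mini",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)
```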
28 changes: 17 additions & 11 deletions optillm/plugins/memory_plugin.py
@@ -50,19 +50,23 @@ def extract_query(text: str) -> Tuple[str, str]:
return query, context

def extract_key_information(text: str, client, model: str) -> List[str]:
# print(f"Prompt : {text}")
prompt = f"""Extract key information from the following text. Provide a list of important facts or concepts, each on a new line:

{text}

Key information:"""

-    response = client.chat.completions.create(
-        model=model,
-        messages=[{"role": "user", "content": prompt}],
-        max_tokens=1000
-    )
-
-    key_info = response.choices[0].message.content.strip().split('\n')
+    try:
+        response = client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=1000
+        )
+        key_info = response.choices[0].message.content.strip().split('\n')
+    except Exception as e:
+        print(f"Error parsing content: {str(e)}")
+        return [], 0

return [info.strip('- ') for info in key_info if info.strip()], response.usage.completion_tokens

@@ -75,14 +79,16 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
chunk_size = 10000
for i in range(0, len(context), chunk_size):
chunk = context[i:i+chunk_size]
# print(f"chunk: {chunk}")
key_info, tokens = extract_key_information(chunk, client, model)
#print(f"key info: {key_info}")
completion_tokens += tokens
for info in key_info:
memory.add(info)

# print(f"query : {query}")
# Retrieve relevant information from memory
relevant_info = memory.get_relevant(query)

# print(f"relevant_info : {relevant_info}")
# Generate response using relevant information
prompt = f"""System: {system_prompt}

Expand All @@ -96,8 +102,8 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
messages=[{"role": "user", "content": prompt}],
max_tokens=1000
)

print(f"response : {response}")
final_response = response.choices[0].message.content.strip()
completion_tokens += response.usage.completion_tokens

print(f"final_response: {final_response}")
return final_response, completion_tokens
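The try/except in `extract_key_information` makes a failed chunk degrade to `([], 0)` instead of aborting the whole memory pass. A quick stub-based illustration; all the stub names here are invented for the example and are not part of the plugin:

```python
# Stub client whose completion call always fails, to show the fallback path.
class _FailingCompletions:
    @staticmethod
    def create(**kwargs):
        raise RuntimeError("simulated API error")

class _StubChat:
    completions = _FailingCompletions()

class _StubClient:
    chat = _StubChat()

key_info, tokens = extract_key_information("some chunk", _StubClient(), "gpt-4o-mini")
assert (key_info, tokens) == ([], 0)  # degrades gracefully instead of raising
```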
135 changes: 135 additions & 0 deletions optillm/plugins/privacy_plugin.py
@@ -0,0 +1,135 @@
import re
from typing import Dict, Tuple

import spacy
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine, OperatorConfig
from presidio_anonymizer.operators import Operator, OperatorType

SLUG = "privacy"

class InstanceCounterAnonymizer(Operator):
    """
    Anonymizer which replaces the entity value
    with an instance counter per entity.
    """

    REPLACING_FORMAT = "<{entity_type}_{index}>"

    def operate(self, text: str, params: Dict = None) -> str:
        """Anonymize the input text."""

        entity_type: str = params["entity_type"]

        # entity_mapping is a dict of dicts containing mappings per entity type
        entity_mapping: Dict[str, Dict[str, str]] = params["entity_mapping"]

        entity_mapping_for_type = entity_mapping.get(entity_type)
        if not entity_mapping_for_type:
            new_text = self.REPLACING_FORMAT.format(
                entity_type=entity_type, index=0
            )
            entity_mapping[entity_type] = {}

        else:
            if text in entity_mapping_for_type:
                return entity_mapping_for_type[text]

            previous_index = self._get_last_index(entity_mapping_for_type)
            new_text = self.REPLACING_FORMAT.format(
                entity_type=entity_type, index=previous_index + 1
            )

        entity_mapping[entity_type][text] = new_text
        return new_text

    @staticmethod
    def _get_last_index(entity_mapping_for_type: Dict) -> int:
        """Get the last index for a given entity type."""

        def get_index(value: str) -> int:
            return int(value.split("_")[-1][:-1])

        indices = [get_index(v) for v in entity_mapping_for_type.values()]
        return max(indices)

    def validate(self, params: Dict = None) -> None:
        """Validate operator parameters."""

        if "entity_mapping" not in params:
            raise ValueError("An input Dict called `entity_mapping` is required.")
        if "entity_type" not in params:
            raise ValueError("An entity_type param is required.")

    def operator_name(self) -> str:
        return "entity_counter"

    def operator_type(self) -> OperatorType:
        return OperatorType.Anonymize

def download_model(model_name):
    if not spacy.util.is_package(model_name):
        print(f"Downloading {model_name} model...")
        spacy.cli.download(model_name)
    else:
        print(f"{model_name} model already downloaded.")

def replace_entities(entity_map, text):
    # Create a reverse mapping of placeholders to entity names
    reverse_map = {}
    for entity_type, entities in entity_map.items():
        for entity_name, placeholder in entities.items():
            reverse_map[placeholder] = entity_name

    # Function to replace placeholders with entity names
    def replace_placeholder(match):
        placeholder = match.group(0)
        return reverse_map.get(placeholder, placeholder)

    # Use regex to find and replace all placeholders
    pattern = r'<[A-Z_]+_\d+>'
    replaced_text = re.sub(pattern, replace_placeholder, text)

    return replaced_text

def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
    # Make sure the spaCy model Presidio relies on is available
    model_name = "en_core_web_lg"
    download_model(model_name)

    analyzer = AnalyzerEngine()
    analyzer_results = analyzer.analyze(text=initial_query, language="en")

    # Create Anonymizer engine and add the custom anonymizer
    anonymizer_engine = AnonymizerEngine()
    anonymizer_engine.add_anonymizer(InstanceCounterAnonymizer)

    # Create a mapping between entity types and counters
    entity_mapping = dict()

    # Anonymize the text
    anonymized_result = anonymizer_engine.anonymize(
        initial_query,
        analyzer_results,
        {
            "DEFAULT": OperatorConfig(
                "entity_counter", {"entity_mapping": entity_mapping}
            )
        },
    )
    # print(f"Anonymized request: {anonymized_result.text}")

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": anonymized_result.text}],
    )

    # print(entity_mapping)
    final_response = response.choices[0].message.content.strip()
    # print(f"response: {final_response}")

    # Map placeholders in the model's answer back to the original entities
    final_response = replace_entities(entity_mapping, final_response)

    return final_response, response.usage.completion_tokens
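For anyone reviewing the plugin end to end, here is a standalone sketch of the anonymize → LLM → de-anonymize round trip it implements, using the pieces defined above. The sample sentence and the exact placeholders are illustrative, since detection depends on Presidio's models:

```python
# Round-trip sketch; the input sentence is made up, and the printed
# placeholders depend on what Presidio actually detects.
text = "John Smith emailed John Smith and Mary Jones."

analyzer = AnalyzerEngine()
engine = AnonymizerEngine()
engine.add_anonymizer(InstanceCounterAnonymizer)

entity_mapping = {}
anonymized = engine.anonymize(
    text,
    analyzer.analyze(text=text, language="en"),
    {"DEFAULT": OperatorConfig("entity_counter", {"entity_mapping": entity_mapping})},
)
print(anonymized.text)
# e.g. "<PERSON_0> emailed <PERSON_0> and <PERSON_1>." Repeated entities
# share one placeholder, so the model's answer stays internally consistent.

print(replace_entities(entity_mapping, anonymized.text))
# Should print the original sentence if detection was exact.
```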
2 changes: 1 addition & 1 deletion optillm/rto.py
@@ -59,7 +59,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, model
c2 = extract_code_from_prompt(c2)

if c1.strip() == c2.strip():
-        return c1
+        return c1, rto_completion_tokens

messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Initial query: {initial_query}\n\nFirst generated code (C1):\n{c1}\n\nSecond generated code (C2):\n{c2}\n\nBased on the initial query and these two different code implementations, generate a final, optimized version of the code. Only respond with the final code, do not return anything else."}]