Skip to content

Commit

Permalink
fix: pylint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nitchandak committed May 2, 2024
1 parent ea73d0b commit 111b820
Show file tree
Hide file tree
Showing 15 changed files with 96 additions and 108 deletions.
3 changes: 1 addition & 2 deletions gemini/sample-apps/agent-assist/backend/src/apis/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def chatbot_entry(data: dict = {}) -> dict:
with open("data/static/oe_examples/logs.json", "w") as f:
json.dump(logs, f)

result = run_orchestrator(query, chat_history_string)
return result
run_orchestrator(query, chat_history_string)


def process_history(chat_history):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def get_customer_management_data():
jsonify(
{
"total_active_customers": total_active_customers,
"average_satisfaction_score": float(
"{:.3}".format(average_satisfaction_score)
),
"average_satisfaction_score": float(f"{average_satisfaction_score:.3}"),
"total_lapsed_customers": total_lapsed_customers,
"chart_data": chart_data,
}
Expand Down Expand Up @@ -76,11 +74,7 @@ def get_metrics_data(data: list, start_date: str, end_date: str):

if policy_start_date is None:
continue
if (
policy["current_policy"]
and policy_start_date >= start_date
and policy_start_date <= end_date
):
if policy["current_policy"] and start_date <= policy_start_date <= end_date:
total_active_customers += 1

if total_ratings != 0:
Expand All @@ -106,11 +100,7 @@ def get_lapsed_customers(data: list, start_date: str, end_date: str):
policy_end_date = policy["policy_end_date"]
if policy_end_date is None:
continue
if (
policy["current_policy"] is None
and policy_end_date >= start_date
and policy_end_date <= end_date
):
if policy["current_policy"] is None and start_date <= policy_end_date <= end_date:
total_lapsed_customers += 1

return total_lapsed_customers
Expand All @@ -131,7 +121,7 @@ def get_chart_data(data, start_date, end_date):
start_date = datetime.strptime(start_date, "%Y-%m-%d")
end_date = datetime.strptime(end_date, "%Y-%m-%d")
month_list = [
datetime.strptime("%2.2d-%2.2d" % (year, month), "%Y-%m").strftime("%b-%y")
datetime.strptime(f"{year:02}-{month:02}", "%Y-%m").strftime("%b-%y")
for year in range(start_date.year, end_date.year + 1)
for month in range(
start_date.month if year == start_date.year else 1,
Expand All @@ -156,11 +146,7 @@ def get_chart_data(data, start_date, end_date):
"satisfaction_score": policy["satisfaction_score"],
"count": 1,
}
if (
policy["current_policy"]
and policy_start_date >= start_date
and policy_start_date <= end_date
):
if policy["current_policy"] and start_date <= policy_start_date <= end_date:
if month in month_data:
month_data[month]["active_customers"] += 1
else:
Expand Down
3 changes: 1 addition & 2 deletions gemini/sample-apps/agent-assist/backend/src/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def handle_chatbot(data):
"""Handles the chatbot."""
print(data)
emit("chat", ["Generating..."])
chatbot_response = chatbot.chatbot_entry(data)
print(chatbot_response)
chatbot.chatbot_entry(data)
emit("chat", ["Done"])


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from src.chatbot_dir.agents.search_agent.preprocessing.table.text_bison import TextBison


def processTable(table_df_string: str) -> str:
def process_table(table_df_string: str) -> str:
"""Processes a table in dataframe string format using TextBison.
Args:
Expand All @@ -18,6 +18,6 @@ def processTable(table_df_string: str) -> str:
str: The processed table in dataframe string format.
"""
tb = TextBison()
PROMPT = PROMPT_FOR_TABLE.format(table_df_string)
df_string = tb.generate_response(PROMPT)
prompt = PROMPT_FOR_TABLE.format(table_df_string)
df_string = tb.generate_response(prompt)
return df_string
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,41 @@
from img2table.document import PDF
from img2table.ocr import TesseractOCR
from src.chatbot_dir.agents.search_agent.preprocessing.table.process_function import (
processTable,
process_table,
)

# Function to process the PDF tables and save the extracted text to files.


def process_pdf_tables(DOCUMENT_PATH: str, POLICY_NAME: str) -> None:
def process_pdf_tables(document_path: str, policy_name: str) -> None:
"""Processes the PDF tables and saves the extracted text to files.
Args:
DOCUMENT_PATH (str): The path to the PDF document.
POLICY_NAME (str): The name of the policy to which the PDF document belongs.
document_path (str): The path to the PDF document.
policy_name (str): The name of the policy to which the PDF document belongs.
"""

OUTPUT_PATH = f"data/static/table_text/{POLICY_NAME}/"
output_path = f"data/static/table_text/{policy_name}/"

pdf = PDF(src=DOCUMENT_PATH)
pdf = PDF(src=document_path)

ocr = TesseractOCR(lang="eng")

pdf_tables = pdf.extract_tables(ocr=ocr)

for idx, pdf_table in pdf_tables.items():
try:
os.makedirs(OUTPUT_PATH + str(idx))
os.makedirs(output_path + str(idx))
except OSError:
pass
if not pdf_table:
continue
for jdx, table in enumerate(pdf_table):
table_df_string = table.df.to_string()
table_string = processTable(table_df_string)
table_string = process_table(table_df_string)
print(table_string)
with open(OUTPUT_PATH + f"{idx}/table_df_{jdx}.txt", "w") as f:
with open(output_path + f"{idx}/table_df_{jdx}.txt", "w") as f:
f.write(table_df_string)

with open(OUTPUT_PATH + f"{idx}/table_string_{jdx}.txt", "w") as f:
with open(output_path + f"{idx}/table_string_{jdx}.txt", "w") as f:
f.write(table_string)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class TextBison:
A class to interact with the Text Bison model from Vertex AI.
Args:
PROJECT_ID (str): The project ID of the Vertex AI project.
LOCATION (str): The location of the Vertex AI project.
project_id (str): The project ID of the Vertex AI project.
location (str): The location of the Vertex AI project.
max_output_tokens (int): The maximum number of tokens to generate.
temperature (float): The temperature controls the randomness of the generated text.
top_p (float): Top-p nucleus sampling.
Expand All @@ -30,8 +30,8 @@ class TextBison:

def __init__(
self,
PROJECT_ID=config["PROJECT_ID"],
LOCATION=config["LOCATION"],
project_id=config["PROJECT_ID"],
location=config["LOCATION"],
max_output_tokens: int = 8192,
temperature: float = 0.1,
top_p: float = 0.8,
Expand All @@ -41,27 +41,27 @@ def __init__(
Initialize the TextBison class.
Args:
PROJECT_ID (str): The project ID of the Vertex AI project.
LOCATION (str): The location of the Vertex AI project.
project_id (str): The project ID of the Vertex AI project.
location (str): The location of the Vertex AI project.
max_output_tokens (int): The maximum number of tokens to generate.
temperature (float): The temperature controls the randomness of the generated text.
top_p (float): Top-p nucleus sampling.
top_k (int): Top-k nucleus sampling.
"""
self.PROJECT_ID = PROJECT_ID
self.LOCATION = LOCATION
self.project_id = project_id
self.location = location
self.parameters = {
"max_output_tokens": max_output_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
}

vertexai.init(project=self.PROJECT_ID, location=self.LOCATION)
vertexai.init(project=self.project_id, location=self.location)

self.model = TextGenerationModel.from_pretrained(config["text_bison_model"])

def generate_response(self, PROMPT: str) -> str:
def generate_response(self, prompt: str) -> str:
"""
Generate a response using the Text Bison model.
Expand All @@ -73,5 +73,5 @@ def generate_response(self, PROMPT: str) -> str:
"""
print("running tb.generate_response")
parameters = self.parameters
response = self.model.predict(PROMPT, **parameters)
response = self.model.predict(prompt, **parameters)
return response.text
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@


class TextBison:
"""
Initializes the TextBison class for text generation.
Args:
PROJECT_ID: GCP Project ID.
LOCATION: GCP Region. Defaults to "us-central1".
"""

def __init__(
self,
PROJECT_ID=config["PROJECT_ID"],
LOCATION=config["LOCATION"],
project_id=config["PROJECT_ID"],
location=config["LOCATION"],
max_output_tokens=2048,
temperature=0.05,
top_p=0.8,
Expand All @@ -27,27 +35,27 @@ def __init__(
"""Initializes the TextBison class.
Args:
PROJECT_ID (str): The Google Cloud project ID.
LOCATION (str): The Google Cloud region where the model is deployed.
project_id (str): The Google Cloud project ID.
location (str): The Google Cloud region where the model is deployed.
max_output_tokens (int): The maximum number of tokens to generate.
temperature (float): The temperature to use for sampling.
top_p (float): The top-p value to use for sampling.
top_k (int): The top-k value to use for sampling.
"""
self.PROJECT_ID = PROJECT_ID
self.LOCATION = LOCATION
self.project_id = project_id
self.location = location
self.parameters = {
"max_output_tokens": max_output_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
}

vertexai.init(project=self.PROJECT_ID, location=self.LOCATION)
vertexai.init(project=self.project_id, location=self.location)

self.model = TextGenerationModel.from_pretrained(config["text_bison_model"])

def generate_response(self, PROMPT):
def generate_response(self, prompt):
"""Generates a response to a given PROMPT.
Args:
Expand All @@ -57,5 +65,5 @@ def generate_response(self, PROMPT):
str: The generated response.
"""
parameters = self.parameters
response = self.model.predict(PROMPT, **parameters)
response = self.model.predict(prompt, **parameters)
return response.text
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def generate_answer(question: str) -> str:
with open("data/likes.json", "rb") as f:
df2 = pd.DataFrame(json.load(f))

PROMPT = SQL_PROMPT.format(question=question)
answer = tb.generate_response(PROMPT=PROMPT)
prompt = SQL_PROMPT.format(question=question)
answer = tb.generate_response(prompt=prompt)
answer = answer.replace("<SQL>", "")
sql_query = answer.replace("</SQL>", "")
sql_query = sql_query.strip()
Expand All @@ -65,8 +65,8 @@ def generate_answer(question: str) -> str:
answer_df = ps.sqldf(sql_query, locals())
print(answer_df)
temp_df = answer_df.astype(str)
PROMPT = FINAL_ANSWER_PROMPT.format(question=question, df=temp_df)
answer_natural_language = tb.generate_response(PROMPT)
prompt = FINAL_ANSWER_PROMPT.format(question=question, df=temp_df)
answer_natural_language = tb.generate_response(prompt)
print("answer_natural_language : ", answer_natural_language)
return answer_natural_language

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def create_sales_pitch(prompt: str, policy_name: str) -> str:
return response


def generate_email(PROMPT: str, chat_history: str) -> tuple[str, str]:
def generate_email(prompt: str, chat_history: str) -> tuple[str, str]:
"""Generate email function to handle queries related to generating emails.
Args:
Expand All @@ -82,7 +82,7 @@ def generate_email(PROMPT: str, chat_history: str) -> tuple[str, str]:
Returns:
tuple[str, str]: A tuple containing the email subject and body.
"""
return mail_component(query=PROMPT, chat_history=chat_history)
return mail_component(query=prompt, chat_history=chat_history)


def send_email(email_id: str, subject: str, body: str) -> None:
Expand Down
10 changes: 5 additions & 5 deletions gemini/sample-apps/agent-assist/backend/src/utils/cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,26 @@ def __init__(self):
Initializes the Calendar class.
"""
self.self_email = config["company_email"]
self.SCOPES = [config["CALENDAR_SCOPE"]]
self.scopes = [config["CALENDAR_SCOPE"]]
self.creds = None
if os.path.exists("cal_token.json"):
self.creds = Credentials.from_authorized_user_file(
"cal_token.json", self.SCOPES
"cal_token.json", self.scopes
)
if not self.creds or not self.creds.valid:
if self.creds and self.creds.expired and self.creds.refresh_token:
self.creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
"keys/credentials_desktop.json", self.SCOPES
"keys/credentials_desktop.json", self.scopes
)
self.creds = flow.run_local_server(port=0)
with open("cal_token.json", "w") as token:
token.write(self.creds.to_json())
try:
self.service = build("calendar", "v3", credentials=self.creds)
except HttpError as error:
print("An error occurred: %s" % error)
print(f"An error occurred: {error}")

def create_event(
self, email: list[str], start_date_time: str, end_date_time: str
Expand Down Expand Up @@ -97,7 +97,7 @@ def create_event(
.insert(calendarId="primary", body=event, sendUpdates="all")
.execute()
)
print("Event created: %s" % (event.get("htmlLink")))
print(f"Event created: {event.get('htmlLink')}")

# print(event)
return event
Expand Down
Loading

0 comments on commit 111b820

Please sign in to comment.