From 16688b0e84e79435f118f8dd4885ab4c991ba624 Mon Sep 17 00:00:00 2001 From: franperezlopez <1222398+franperezlopez@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:32:48 +0100 Subject: [PATCH] langchain support (#5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Refactor main classes and functions * add langchain and langsmith support --- Dockerfile | 7 +- bot.py | 137 ------------------- environment.yml | 3 + requirements_dev.txt | 6 + src/bot.py | 238 ++++++++++++++++++++++++++++++++ src/llm/agent.py | 240 ++++++++++++++++++++++++++------- src/llm/places.py | 315 +++++++++++++++++++++++++++---------------- src/llm/prompt.py | 5 +- src/settings.py | 37 +++-- 9 files changed, 677 insertions(+), 311 deletions(-) delete mode 100644 bot.py create mode 100644 requirements_dev.txt create mode 100644 src/bot.py diff --git a/Dockerfile b/Dockerfile index 3a9f041..8f3a086 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,8 +17,8 @@ ARG AZURE_OPENAI_API_KEY ENV AZURE_OPENAI_API_KEY=$AZURE_OPENAI_API_KEY ARG AZURE_OPENAI_API_BASE ENV AZURE_OPENAI_API_BASE=$AZURE_OPENAI_API_BASE -ARG LANGCHAIN_API_KEY -ENV LANGCHAIN_API_KEY=$LANGCHAIN_API_KEY +ARG LANGSMITH_API_KEY +ENV LANGSMITH_API_KEY=$LANGSMITH_API_KEY # Copy environment file @@ -35,8 +35,7 @@ WORKDIR /app # Copy source code COPY src src -COPY bot.py . EXPOSE 80 -CMD ["python", "-m", "bot"] +CMD ["python", "-m", "src.bot"] diff --git a/bot.py b/bot.py deleted file mode 100644 index 8c8711a..0000000 --- a/bot.py +++ /dev/null @@ -1,137 +0,0 @@ -import html -import io -import json -import tempfile -import traceback -import re - -from loguru import logger -from telegram import InputFile, Update -from telegram.constants import ParseMode, ChatAction -from telegram.ext import (ApplicationBuilder, CallbackContext, CommandHandler, - ContextTypes, MessageHandler, Updater, filters) - -from src.llm.agent import build_agent -from src.settings import get_settings - - -settings = get_settings() - -async def call_agent(image_url): - logger.info("Calling agent ...") - agent = build_agent() - event = await agent.create_card(image_url) - logger.info(event) - return event - -async def handle_image_compressed(update: Update, context: ContextTypes.DEFAULT_TYPE): - logger.info("Handling image compressed ...") - - photo = await context.bot.get_file(update.message.photo[-1]) - await _handle_image(update, context, photo) - - -async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE): - # Get the image file from the message - logger.info("Handling image ...") - - # user = update.message.from_user - await update.message.reply_chat_action(action=ChatAction.UPLOAD_PHOTO) - photo = await context.bot.get_file(update.message.document) - await _handle_image(update, context, photo) - -async def _handle_image(update, context, photo): - def _normalize_fn(text: str): - term = "FN:" - idx = text.find(term) - idx_end = text.find("\n", idx) - return text[idx+len(term):idx_end] - - def _normalize_tel(text: str): - for term in ["TEL:", "TEL;"]: - idx = text.find(term) - if idx > -1: - break - if idx == -1: - return "111 222 333" - idx_end = text.find("\n", idx) - sub_text = text[idx+len(term):idx_end] - if sub_text.find(":") > -1: - return sub_text.split(":")[-1] - else: - return ''.join(re.findall('\d', sub_text)) - - # Download the image file and save it to a temporary file - with tempfile.NamedTemporaryFile(delete=True) as f: - image_path = f.name - await photo.download_to_drive(image_path) - - # Process the image and 
generate the ICS file - await update.message.reply_chat_action(action=ChatAction.TYPING) - vcf_data = await call_agent(image_path) - - # Send the card (file) to the user - if vcf_data: - phone_number = _normalize_tel(vcf_data) - first_name = _normalize_fn(vcf_data) - await update.message.reply_contact(phone_number=phone_number, first_name=first_name, vcard=vcf_data) - else: - await update.message.reply_text("No contact found in image") - -async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Send a message when the command /help is issued.""" - logger.info("Help command ...") - await update.message.reply_text("Help for you ...") - - -async def echo_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Echo the user message.""" - logger.info("Echo command ...") - logger.info(update.effective_message.location) - await update.message.reply_text(update.message.text) - - -async def error_handler(update: object, context: ContextTypes.DEFAULT_TYPE) -> None: - logger.exception(context.error) - dev_chat_id = settings.TELEGRAM_DEV_CHAT_ID - if dev_chat_id: - update_str = update.to_dict() if isinstance(update, Update) else str(update) - tb_list = traceback.format_exception(None, context.error, context.error.__traceback__) - tb_string = "".join(tb_list) - message = ( - f"An exception was raised while handling an update\n" - f"
<pre>update = {html.escape(json.dumps(update_str, indent=2, ensure_ascii=False))}" "</pre>\n\n" -            f"
<pre>context.chat_data = {html.escape(str(context.chat_data))}</pre>\n\n" -            f"
<pre>context.user_data = {html.escape(str(context.user_data))}</pre>\n\n" -            f"
{html.escape(tb_string)}" - ) - await context.bot.send_message(chat_id=dev_chat_id, text=message, parse_mode=ParseMode.HTML) - - -def main(): - # Set up the Telegram bot - app = ApplicationBuilder().token(settings.TELEGRAM_TOKEN).build() - - # Set up the message handler for images - app.add_handler(MessageHandler(filters.PHOTO, handle_image_compressed)) - app.add_handler(MessageHandler(filters.Document.IMAGE, handle_image)) - app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo_command)) - app.add_handler(CommandHandler("help", help_command)) - - app.add_error_handler(error_handler) - - # Start the bot - logger.info("Starting bot ...") - logger.info(settings.model_dump()) - - if settings.TELEGRAM_WEBHOOK_URL: - # preferred method for production - app.run_webhook(listen="0.0.0.0", webhook_url=settings.TELEGRAM_WEBHOOK_URL, port=80, - allowed_updates=Update.ALL_TYPES, secret_token=settings.TELEGRAM_SECRET, drop_pending_updates=True) - else: - # preferred method for development - app.run_polling(allowed_updates=Update.ALL_TYPES) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/environment.yml b/environment.yml index 11cbf53..15cf1f8 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,10 @@ dependencies: - pillow=10.1.0 - pip - pip: + - langchain[openai]==0.0.353 - loguru==0.7.2 - piexif==1.1.3 - python-telegram-bot[webhooks]==20.7 - google-search-results==2.4.2 + - randomname==0.2.1 + diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..6c593ea --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,6 @@ +ipykernel +ruff +pytest +pytest-cov +pytest-moq +mypy \ No newline at end of file diff --git a/src/bot.py b/src/bot.py new file mode 100644 index 0000000..fea56c8 --- /dev/null +++ b/src/bot.py @@ -0,0 +1,238 @@ +import html +import json +import re +import tempfile +import traceback +from enum import IntEnum +from typing import Optional + +from loguru import logger +from PIL import Image +from telegram import Location, ReplyKeyboardRemove, Update +from telegram.constants import ChatAction, ParseMode +from telegram.ext import (ApplicationBuilder, CommandHandler, ContextTypes, + ConversationHandler, MessageHandler, filters) + +from src.llm.agent import build_agent +from src.llm.places import EXIFHelper +from src.settings import get_settings + +settings = get_settings() + +class States(IntEnum): + START = 1 + PHOTO = 2 + OK_GPS = 3 + NO_GPS = 4 + CARD = 5 + END = 6 + +PHOTO = 1 +NO_GPS = 2 + +async def call_agent(image_url, detail: str, location: Optional[Location] = None): + logger.info("Calling agent ...") + agent = build_agent() + kwargs = {} + if location: + kwargs["lat"] = location.latitude + kwargs["lon"] = location.longitude + event = await agent.create_card(image_url, detail, **kwargs) + logger.info(event) + return event + + +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + logger.info("Start command ...") + await update.message.reply_text("Bienvenido! 
Para convertir una imagen en una tarjeta de contacto, envíame una imagen.", + reply_markup=ReplyKeyboardRemove()) + + return PHOTO + +async def photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + logger.info("Handling photo ...") + + if len(update.message.photo) > 0: + # compressed image -> ask geolocation + photo = await context.bot.get_file(update.message.photo[-1]) + if context.user_data.get("location"): + await _handle_image(update, context, photo, detail="high", location=context.user_data["location"]) + return ConversationHandler.END + context.chat_data["photo"] = photo + await update.message.reply_text("¿Puedes enviarme tu ubicación?", + reply_markup=ReplyKeyboardRemove()) + return NO_GPS + else: + # uncompressed image -> do card + # TODO: refactor creating TelegramImage class + photo = await context.bot.get_file(update.message.document) + img = Image.open(photo) + lat, lon = EXIFHelper.extract_coordinates(img) + if lat and lon: + await _handle_image(update, context, photo, detail="low") + return ConversationHandler.END + elif context.user_data.get("location"): + await _handle_image(update, context, photo, detail="low", location=context.user_data["location"]) + return ConversationHandler.END + else: + context.chat_data["photo"] = photo + await update.message.reply_text("¿Puedes enviarme tu ubicación?", + reply_markup=ReplyKeyboardRemove()) + return NO_GPS + + +async def location(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int: + logger.info("Handling location ...") + + if update.message.location: + # location -> do card + photo = context.chat_data["photo"] + context.user_data["location"] = update.message.location + await _handle_image(update, context, photo, detail="high", location=update.message.location) + return ConversationHandler.END + else: + # no location -> ask for location + await update.message.reply_text("¿Puedes enviarme tu ubicación?", + reply_markup=ReplyKeyboardRemove()) + return NO_GPS + +# async def handle_image_compressed(update: Update, context: ContextTypes.DEFAULT_TYPE): +# logger.info("Handling image compressed ...") + +# photo = await context.bot.get_file(update.message.photo[-1]) +# await _handle_image(update, context, photo) + + +# async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE): +# # Get the image file from the message +# logger.info("Handling image ...") + +# # user = update.message.from_user +# await update.message.reply_chat_action(action=ChatAction.UPLOAD_PHOTO) +# photo = await context.bot.get_file(update.message.document) +# await _handle_image(update, context, photo) + + +async def _handle_image(update, context, photo, detail: str, location: Optional[Location] = None): + def _normalize_fn(text: str): + term = "FN:" + idx = text.find(term) + idx_end = text.find("\n", idx) + return text[idx + len(term) : idx_end] + + def _normalize_tel(text: str): + for term in ["TEL:", "TEL;"]: + idx = text.find(term) + if idx > -1: + break + if idx == -1: + return "111 222 333" + idx_end = text.find("\n", idx) + sub_text = text[idx + len(term) : idx_end] + if sub_text.find(":") > -1: + return sub_text.split(":")[-1] + else: + return "".join(re.findall("\d", sub_text)) + + # Download the image file and save it to a temporary file + with tempfile.NamedTemporaryFile(delete=True) as f: + image_path = f.name + await photo.download_to_drive(image_path) + + # Process the image and generate the ICS file + await update.message.reply_chat_action(action=ChatAction.TYPING) + vcf_data = await call_agent(image_path, detail, 
location) + + # Send the card (file) to the user + if vcf_data: + phone_number = _normalize_tel(vcf_data) + first_name = _normalize_fn(vcf_data) + await update.message.reply_contact(phone_number=phone_number, first_name=first_name, vcard=vcf_data) + else: + await update.message.reply_text("No se pudo generar la tarjeta.") + + +async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + logger.info("Help command ...") + await update.message.reply_text("Help for you ...") + + +async def echo_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Echo the user message.""" + logger.info(f"Echo command ... {update.message.text}") + await update.message.reply_text(update.message.text) + + +async def error_handler(update: object, context: ContextTypes.DEFAULT_TYPE) -> None: + logger.exception(context.error) + dev_chat_id = settings.TELEGRAM_DEV_CHAT_ID + if dev_chat_id: + update_str = update.to_dict() if isinstance(update, Update) else str(update) + tb_list = traceback.format_exception(None, context.error, context.error.__traceback__) + tb_string = "".join(tb_list) + message = ( + f"An exception was raised while handling an update\n" + f"
<pre>update = {html.escape(json.dumps(update_str, indent=2, ensure_ascii=False))}" + "</pre>\n\n" + f"
<pre>context.chat_data = {html.escape(str(context.chat_data))}</pre>\n\n" + f"
<pre>context.user_data = {html.escape(str(context.user_data))}</pre>\n\n" + f"
<pre>{html.escape(tb_string)}</pre>" + ) + await context.bot.send_message(chat_id=dev_chat_id, text=message, parse_mode=ParseMode.HTML) + + +async def cancel(update, context): + """Cancel the current operation and end the conversation""" + await update.message.reply_text("Operación cancelada.") + return ConversationHandler.END + + +def main(): + # Set up the Telegram bot + app = ApplicationBuilder().token(settings.TELEGRAM_TOKEN).build() + + # Set up the message handler for images + # app.add_handler(MessageHandler(filters.PHOTO, handle_image_compressed)) + # app.add_handler(MessageHandler(filters.Document.IMAGE, handle_image)) + app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo_command)) + # app.add_handler(CommandHandler("help", help_command)) + + # Add conversation handler with the states PHOTO and NO_GPS + conv_handler = ConversationHandler( + entry_points=[CommandHandler("start", start)], + states={ + # GENDER: [MessageHandler(filters.Regex("^(Boy|Girl|Other)$"), gender)], + PHOTO: [MessageHandler(filters.PHOTO | filters.Document.IMAGE, photo)], + # PHOTO: [MessageHandler(filters.ALL, photo)], + NO_GPS: [MessageHandler(filters.LOCATION, location)], + # BIO: [MessageHandler(filters.TEXT & ~filters.COMMAND, bio)], + }, + fallbacks=[CommandHandler("cancel", cancel)], + ) + + app.add_handler(conv_handler) + + app.add_error_handler(error_handler) + + # Start the bot + logger.info("Starting bot ...") + logger.info(settings.model_dump()) + + if settings.TELEGRAM_WEBHOOK_URL: + # preferred method for production + app.run_webhook( + listen="0.0.0.0", + webhook_url=settings.TELEGRAM_WEBHOOK_URL, + port=80, + allowed_updates=Update.ALL_TYPES, + secret_token=settings.TELEGRAM_SECRET, + drop_pending_updates=True, + ) + else: + # preferred method for development + app.run_polling(allowed_updates=Update.ALL_TYPES) + + +if __name__ == "__main__": + main() diff --git a/src/llm/agent.py b/src/llm/agent.py index bb3bcf2..6cbb5d3 100644 --- a/src/llm/agent.py +++ b/src/llm/agent.py @@ -1,78 +1,226 @@ -from functools import cache -from openai import AsyncAzureOpenAI import base64 -from loguru import logger import json +from functools import cache +from typing import Literal, Optional + +import randomname +from langchain.chat_models import AzureChatOpenAI +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from loguru import logger +from PIL import Image -from src.settings import Settings import src.llm.prompt as prompt -from src.llm.places import search +from src.llm.places import PlacesTool +from src.settings import Settings + + +class OpenAITool: + """ + A class that provides tools for generating vision and card responses using the Azure OpenAI API. + + Args: + settings (Settings): An instance of the Settings class containing API configuration. + + Attributes: + _client_vision (AzureChatOpenAI): The chat client used for the vision pass. + _client_agent (AzureChatOpenAI): The chat client used for vCard generation. + + """ -class CardAgent: def __init__(self, settings: Settings) -> None: - self.client = AsyncAzureOpenAI(api_key=settings.AZURE_OPENAI_API_KEY, - api_version=settings.AZURE_OPENAI_API_VERSION, - azure_endpoint=settings.AZURE_OPENAI_API_BASE) + """ + Initializes the tool. + + Args: + settings (Settings): An instance of the Settings class containing the required Azure OpenAI API key, version and endpoint. 
+ + Returns: + None + """ + common_client_args = { + "openai_api_key": settings.AZURE_OPENAI_API_KEY, + "openai_api_version": settings.AZURE_OPENAI_API_VERSION, + "azure_endpoint": settings.AZURE_OPENAI_API_BASE, + "streaming": False, + } + + if settings.LANGSMITH_TRACER: + common_client_args["callbacks"] = [settings.LANGSMITH_TRACER] + + self._client_vision = AzureChatOpenAI( + **common_client_args, + azure_deployment=settings.AZURE_OPENAI_DEPLOYMENT_VISION, + max_tokens=500, + ) + + self._client_agent = AzureChatOpenAI( + **common_client_args, + azure_deployment=settings.AZURE_OPENAI_DEPLOYMENT_AGENT, + ) + self._settings = settings - async def create_card(self, image_path: str) -> str: - image = self._encode_image(image_path) - vision = await self._generate_vision(image) - logger.info(f"vision: {vision}") - if 'venue' in vision: - try: - query = ' '.join([value for key,value in self._normalize_json(vision).items() if 'venue' in key]) - vision = search(image_path, query) - vision = json.dumps(vision) - except: - logger.exception(f"Error parsing vision") - card = await self._generate_card(vision) - logger.debug(f"card: {card}") - card = self._postprocess(card) - return card + self._run_name = "img2card" + + def _image_base64_format(self, image_path): + format = Image.open(image_path).format.lower() + + if "jpg" in format or "jpeg" in format: + return "image/jpeg" + if "png" in format: + return "image/png" + raise ValueError(f"Unsupported image format: {format}") + + def _image_encode(self, image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + @property + def run_name(self) -> Optional[str]: + """ + Get the name of the run. + + Returns: + Optional[str]: The name of the run. + + """ + return self._run_name + + @run_name.setter + def run_name(self, name: Optional[str]) -> None: + """ + Set the name of the run. + + Args: + name (Optional[str]): The name of the run. + + """ + self._run_name = name + + async def generate_vision(self, image_path: str, detail: Literal["low", "high"] = "low") -> str: + """ + Generates a vision response based on the provided card image. + + Args: + card (str): The base64-encoded image of the card. + detail (Literal["low", "high"], optional): The level of detail for the vision response. Defaults to "low". + + Returns: + str: The generated vision response. + + """ + image = self._image_encode(image_path) + format = self._image_base64_format(image_path) - async def _generate_vision(self, card: str) -> str: messages = [ - {'role': 'user', 'content': [{'type': 'text', - 'text': prompt.VISION}, - {'type': 'image_url', - 'image_url': {'url': f"data:image/jpeg;base64,{card}", 'detail': 'low'}}]}, + HumanMessage( + content=[ + {"type": "text", "text": prompt.VISION_TOOL}, + {"type": "image_url", "image_url": {"url": f"data:{format};base64,{image}", "detail": detail}}, + ], + ) ] + result = await self._client_vision.ainvoke(messages, config={"run_name": self.run_name}) + vision_response = result.content - result = await self.client.chat.completions.create(messages=messages, model=self._settings.AZURE_OPENAI_DEPLOYMENT_VISION, - max_tokens=500, stream=False) - vision_response = result.choices[0].message.content return vision_response - async def _generate_card(self, vision: str) -> str: + async def generate_card(self, vision_transcription: str) -> str: + """ + Generates a card response based on the provided vision transcription. 
+ + Args: + vision_transcription (str): The transcription of the vision response. + + Returns: + str: The generated card response. + + """ messages = [ - {'role': 'system', 'content': 'you are an expert in vCard format'}, - {'role': 'assistant', 'content': vision}, - {'role': 'user', 'content': prompt.AGENT}, + SystemMessage(content=prompt.AGENT_SYSTEM), + AIMessage(content=vision_transcription), + HumanMessage(content=prompt.AGENT_TOOL), ] + result = await self._client_agent.ainvoke(messages, config={"run_name": self.run_name}) + card_response = result.content - result = await self.client.chat.completions.create(messages=messages, model="agent", stream=False) - card_response = result.choices[0].message.content return card_response + +class CardAgent: + """ + Represents an agent that creates contact cards from image transcriptions and place search results. + + Args: + settings (Settings): The settings object containing configuration options. + + Attributes: + _settings (Settings): The settings object containing configuration options. + _places (PlacesTool): The tool for searching places. + _llm (OpenAITool): The tool for generating vision transcriptions from images. + + """ + + def __init__(self, settings: Settings) -> None: + self._settings = settings + self._places = PlacesTool(settings) + self._llm = OpenAITool(settings) + + async def create_card( + self, image_path: str, detail: str = "low", lat: Optional[float] = None, lon: Optional[float] = None + ) -> Optional[str]: + """ + Creates a card based on the provided image path and additional details. + + Args: + image_path (str): The path to the image file. + detail (str, optional): The level of detail for generating the vision. Defaults to "low". + lat (float, optional): The latitude coordinate. Defaults to None. + lon (float, optional): The longitude coordinate. Defaults to None. + + Returns: + Optional[str]: The generated card, or None if no matching place is found. + """ + self._llm.run_name = randomname.get_name() + vision_transcription = await self._llm.generate_vision(image_path, detail) + logger.info(f"vision: {vision_transcription}") + if "venue" in vision_transcription: + try: + query = " ".join( + [value for key, value in self._normalize_json(vision_transcription).items() if "venue" in key] + ) + vision_transcription = self._places.search(image_path, query, lat, lon) + if vision_transcription: + vision_transcription = json.dumps(vision_transcription) + else: + # TODO: if detail == "low", try again with detail == "high" + return None + except Exception: + logger.exception("Error parsing vision") + card = await self._llm.generate_card(vision_transcription) + logger.debug(f"card: {card}") + card = self._postprocess(card) + return card + def _postprocess(self, card: str) -> str: BEGIN_CARD = "BEGIN:VCARD" END_CARD = "END:VCARD" idx_start = card.find(BEGIN_CARD) idx_end = card.find(END_CARD) - return card[idx_start:idx_end+len(END_CARD)] - - def _normalize_json(self, text:str) -> dict: - return json.loads(text.strip('`').lstrip('json')) + return card[idx_start : idx_end + len(END_CARD)] + def _normalize_json(self, text: str) -> dict: + return json.loads(text.strip("`").lstrip("json")) - def _encode_image(self, image_path): - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') @cache def build_agent(): + """ + Builds and returns a CardAgent object based on the settings obtained from get_settings(). + + Returns: + CardAgent: The built CardAgent object. 
+ """ from src.settings import get_settings settings = get_settings() - return CardAgent(settings) \ No newline at end of file + return CardAgent(settings) diff --git a/src/llm/places.py b/src/llm/places.py index b28dd97..055498c 100644 --- a/src/llm/places.py +++ b/src/llm/places.py @@ -1,29 +1,19 @@ import base64 import time -from typing import Optional -from serpapi import GoogleSearch +from fractions import Fraction +from typing import Dict, List, Optional, Tuple, Union + +import piexif from loguru import logger from PIL import Image -import piexif -from fractions import Fraction +from serpapi import GoogleSearch + +from src.settings import Settings -from src.settings import get_settings - -ENGINE = "google_local" -DOMAIN = "google.es" -RADIUS = 300 -settings = get_settings() - -def extract_coordinates(img: Image.Image) -> tuple[Optional[float], Optional[float]]: - exif_img = img.info.get('exif') # img.getexif() - if not exif_img: - return None, None - exif_data = piexif.load(exif_img) - # logger.debug(f"exif_data: {exif_data}") - gps_info = exif_data.get('GPS') - if not gps_info: - return None, None - def convert_to_degrees(value): + +class EXIFHelper: + @classmethod + def _convert_to_degrees(cls, value: Union[int, float, Tuple[int, int], Tuple[float, float]]) -> float: if isinstance(value[0], (int, float)): d = Fraction(value[0]) else: @@ -41,96 +31,195 @@ def convert_to_degrees(value): return float(d + (m / 60) + (s / 3600)) - latitude = convert_to_degrees(gps_info[piexif.GPSIFD.GPSLatitude]) - longitude = convert_to_degrees(gps_info[piexif.GPSIFD.GPSLongitude]) - if gps_info[piexif.GPSIFD.GPSLatitudeRef] == 'S': - latitude = -latitude - if gps_info[piexif.GPSIFD.GPSLongitudeRef] == 'W': - longitude = -longitude - return (latitude, longitude) - -def generate_uule_v2(latitude, longitude, radius): - latitude_e7 = int(latitude * 1e7) - longitude_e7 = int(longitude * 1e7) - radius = int(radius * 620) - - timestamp = int(time.time() * 1000000) - - uule_v2_string = f"role:1\nproducer:12\nprovenance:6\ntimestamp:{timestamp}\nlatlng{{\nlatitude_e7:{latitude_e7}\nlongitude_e7:{longitude_e7}\n}}\nradius:{radius}\n" - - uule_v2_string_encoded = base64.b64encode(uule_v2_string.encode()).decode() - - return "a+" + uule_v2_string_encoded - -def search_by_uule(query, uule) -> list[dict]: - def _normalize_distance(local): - distance = local.get('type') - if not distance: - distance = float(local['address'].split('·')[0].strip().split(' ')[0]) + @classmethod + def extract_coordinates(cls, img: Image.Image) -> Tuple[Optional[float], Optional[float]]: + """ + Extracts the latitude and longitude coordinates from the EXIF data of an image. + + Args: + img (PIL.Image.Image): The image from which to extract the coordinates. + + Returns: + Tuple[Optional[float], Optional[float]]: A tuple containing the latitude and longitude coordinates. + Returns (None, None) if the image does not have EXIF data or if the coordinates are not found. 
+ """ + + img_exif = img.info.get("exif") # img.getexif() + if not img_exif: + return None, None + exif = piexif.load(img_exif) + exif_gps = exif.get("GPS") + if not exif_gps: + return None, None + + latitude = cls._convert_to_degrees(exif_gps[piexif.GPSIFD.GPSLatitude]) + longitude = cls._convert_to_degrees(exif_gps[piexif.GPSIFD.GPSLongitude]) + if exif_gps[piexif.GPSIFD.GPSLatitudeRef] == "S": + latitude = -latitude + if exif_gps[piexif.GPSIFD.GPSLongitudeRef] == "W": + longitude = -longitude + return (latitude, longitude) + + +class SerperHelper: + ENGINE = "google_local" + DOMAIN = "google.es" + + @staticmethod + def generate_uule_v2(latitude, longitude, radius) -> str: + """ + Generate a UULE v2 string based on the given latitude, longitude, and radius. + + Args: + latitude (float): The latitude of the location. + longitude (float): The longitude of the location. + radius (float): The radius of the location in kilometers. + + Returns: + str: The UULE v2 string. + + """ + latitude_e7 = int(latitude * 1e7) + longitude_e7 = int(longitude * 1e7) + radius = int(radius * 620) + + timestamp = int(time.time() * 1000000) + + uule_v2_string = f"role:1\nproducer:12\nprovenance:6\ntimestamp:{timestamp}\nlatlng{{\nlatitude_e7:{latitude_e7}\nlongitude_e7:{longitude_e7}\n}}\nradius:{radius}\n" + + uule_v2_string_encoded = base64.b64encode(uule_v2_string.encode()).decode() + + return "a+" + uule_v2_string_encoded + + + + @classmethod + def _normalize_distance(cls, local: Dict) -> Optional[float]: + distance = local.get("type") + try: + if not distance: + distance = float(local["address"].split("·")[0].strip().split(" ")[0]) + else: + distance = float(distance.split(" ")[0]) + return distance + except Exception: + logger.error(f"Error parsing distance {distance} from {local['address']}") + return None + + + @classmethod + def _common_parameters(cls, settings: Settings) -> Dict: + return { + "api_key": settings.SERPAPI_API_KEY, + "engine": cls.ENGINE, + "google_domain": cls.DOMAIN, + "gl": "es", + "hl": "en", + "device": "tablet" + } + + @classmethod + def _search(cls, settings: Settings, additional_args: Dict) -> List[Dict]: + params = cls._common_parameters(settings) | additional_args + search = GoogleSearch(params) + results = search.get_dict() + return results["local_results"] + + @classmethod + def search_by_uule(cls, settings: Settings, query: str, uule: str) -> List[Dict]: + """ + Search for local results on Google based on the given query and uule. + + Args: + query (str): The search query. + uule (str): The uule parameter for location-based search. + + Returns: + List[Dict]: A list of dictionaries representing the local search results. + Each dictionary contains the following keys: + - "title": The title of the local result. + - "place_id": The place ID of the local result. + - "address": The address of the local result. + - "description": The description of the local result (optional). + - "distance": The distance of the local result from the specified location. 
+ + """ + + local_results = cls._search(settings, {"q": query, "uule": uule}) + + locals = [] + for local in local_results: + distance = cls._normalize_distance(local) + if distance: + locals.append( + { + "title": local.get("title"), + "place_id": local.get("place_id"), + "address": local["address"][local["address"].find("·") + 2 :], + "description": local.get("description"), + # 'extra': [local.get('service_options'), local.get('hours')], + "distance": distance, + } + ) + return sorted(locals, key=lambda x: x["distance"]) + + @classmethod + def search_by_place_id(cls, settings: Settings, query: str, place_id: str) -> Dict: + """ + Search for a place by its place ID. + + Args: + query (str): The search query. + place_id (str): The place ID of the location. + + Returns: + dict: A dictionary containing information about the place, including phone number, type, title, and GPS coordinates. + """ + if place_id is None: + logger.warning("No place ID provided") + return {} + + local_results = cls._search(settings, {"q": query, "ludocid": place_id}) + + result = { + "phone": local_results[0].get("phone"), + "type": local_results[0].get("type"), + "title": local_results[0].get("title"), + "gps_coordinates": local_results[0].get("gps_coordinates"), + } + + return result + + +class PlacesTool(): + RADIUS = 300 + + def __init__(self, settings: Settings): + self.settings: Settings = settings + + def search(self, image_path: str, query: str, lat: Optional[float], lon: Optional[float]) -> Optional[Dict]: + """ + Searches for a place based on the given image path and query. + + Args: + image_path (str): The path to the image. + query (str): The search query. + + Returns: + Optional[Dict]: A dictionary representing the found place, or None if no place is found. + """ + if lat and lon: + latitude, longitude = lat, lon else: - distance = float(distance.split(' ')[0]) - return distance - - params = { - "api_key": settings.SERPAPI_API_KEY, - "engine": ENGINE, - "google_domain": DOMAIN, - "gl": "es", - "hl": "en", - "q": query, - "device": "tablet", - "uule": uule - } - - search = GoogleSearch(params) - results = search.get_dict() - - locals = [] - for local in results['local_results']: - distance = _normalize_distance(local) - locals.append({ - 'title': local['title'], - 'place_id': local['place_id'], - 'address': local['address'][local['address'].find('·')+2:], - 'description': local.get('description'), - # 'extra': [local.get('service_options'), local.get('hours')], - 'distance': distance, - }) - return sorted(locals, key=lambda x: x['distance']) - -def search_by_place_id(query, place_id) -> dict: - params = { - "api_key": settings.SERPAPI_API_KEY, - "engine": ENGINE, - "google_domain": DOMAIN, - "q": query, - "gl": "es", - "hl": "en", - "ludocid": place_id, - } - - search = GoogleSearch(params) - results = search.get_dict() - - logger.debug(f"local_results: {results['local_results']}") - - result = { - 'phone': results['local_results'][0]['phone'], - 'type': results['local_results'][0]['type'], - 'title': results['local_results'][0]['title'], - 'gps_coordinates': results['local_results'][0]['gps_coordinates'], - } - - return result - -def search(image_path, query): - img = Image.open(image_path) - latitude, longitude = extract_coordinates(img) - uule = generate_uule_v2(latitude, longitude, RADIUS) - logger.info(f"uule: {uule}") - locals = search_by_uule(query, uule) - logger.info(f"locals:\n{locals}") - if len(locals) > 0: - place = locals[0] | search_by_place_id(query, locals[0]['place_id']) - 
logger.info(f"place:{place}") - return place + img = Image.open(image_path) + latitude, longitude = EXIFHelper.extract_coordinates(img) + if not latitude or not longitude: + return None + uule = SerperHelper.generate_uule_v2(latitude, longitude, self.RADIUS) + logger.info(f"uule: {uule}") + locals = SerperHelper.search_by_uule(self.settings, query, uule) + logger.debug(f"locals:\n{locals}") + if len(locals) > 0: + place = locals[0] | SerperHelper.search_by_place_id(self.settings, query, locals[0].get("place_id")) + logger.debug(f"place: {place}") + return place diff --git a/src/llm/prompt.py b/src/llm/prompt.py index 32208ff..bbcab8a 100644 --- a/src/llm/prompt.py +++ b/src/llm/prompt.py @@ -1,4 +1,5 @@ -VISION = "if image type is photo, return venue name (usually, bigger text) and venue type. if image type is card, captionize elements in the image. always return the answer with json format." +VISION_TOOL = "if image type is a photo, return venue name (usually, bigger text) and venue type (make your best bet). if image type is card, return captionized elements in the image. always return the response in json format. do not make up data." -AGENT = "convert previous json data to document type RFC 6350 using text representation. ignore logo and other decorations. be precise and concise. do not explain your output" +AGENT_SYSTEM = "you are an expert in vCard format" +AGENT_TOOL = "convert previous json data to document type RFC 6350 using text representation. ignore logo and other decorations. be precise and concise. do not explain your output" diff --git a/src/settings.py b/src/settings.py index f73a1fd..5c1f4b7 100644 --- a/src/settings.py +++ b/src/settings.py @@ -1,8 +1,11 @@ +import uuid +from functools import cache, cached_property from typing import Optional, Union -from pydantic import Field + +from langchain.callbacks import LangChainTracer +from pydantic import Field, computed_field from pydantic_settings import BaseSettings -from functools import cache -import uuid + class Settings(BaseSettings): class Config: @@ -12,7 +15,7 @@ class Config: case_sensitive = True extra = "ignore" # allow_mutation = False - + AZURE_OPENAI_API_KEY: str = Field(env="AZURE_OPENAI_API_KEY") AZURE_OPENAI_API_BASE: str = Field(env="AZURE_OPENAI_API_BASE") AZURE_OPENAI_API_VERSION: str = Field(default="2023-07-01-preview", env="AZURE_OPENAI_API_VERSION") @@ -24,17 +27,33 @@ class Config: SERPAPI_API_KEY: Optional[str] = Field(default=None, env="SERPAPI_API_KEY") - LANGCHAIN_TRACING_V2: str = Field(default="true", env="LANGCHAIN_TRACING_V2") - LANGCHAIN_ENDPOINT: str = Field(default="https://api.smith.langchain.com", env="LANGCHAIN_ENDPOINT") - LANGCHAIN_API_KEY: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY") - LANGCHAIN_PROJECT: str = Field(default="img2card", env="LANGCHAIN_PROJECT") + LANGSMITH_ENDPOINT: str = Field(default="https://api.smith.langchain.com", env="LANGSMITH_ENDPOINT") + LANGSMITH_API_KEY: Optional[str] = Field(default=None, env="LANGSMITH_API_KEY") + LANGSMITH_PROJECT: str = Field(default="img2card", env="LANGSMITH_PROJECT") TELEGRAM_TOKEN: str = Field(env="TELEGRAM_TOKEN") - TELEGRAM_SECRET: Optional[str] = Field(default_factory=lambda : str(uuid.uuid4()).replace('-', ''), env="TELEGRAM_SECRET") + TELEGRAM_SECRET: Optional[str] = Field( + default_factory=lambda: str(uuid.uuid4()).replace("-", ""), env="TELEGRAM_SECRET" + ) TELEGRAM_WEBHOOK_URL: Optional[str] = Field(default=None, env="TELEGRAM_WEBHOOK_URL") TELEGRAM_DEV_CHAT_ID: Optional[Union[int, str]] = 
Field(default=None, env="TELEGRAM_DEV_CHAT_ID") + @computed_field + @cached_property + def LANGSMITH_TRACER(self) -> Optional[LangChainTracer]: + return get_langsmith_tracer(self) + @cache def get_settings() -> Settings: return Settings() + + +def get_langsmith_tracer(settings) -> Optional[LangChainTracer]: + if settings.LANGSMITH_API_KEY: + from langsmith import Client + return LangChainTracer( + project_name=settings.LANGSMITH_PROJECT, + client=Client(api_url=settings.LANGSMITH_ENDPOINT, api_key=settings.LANGSMITH_API_KEY), + ) + return None
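
A few self-contained sketches of the trickier helpers follow; they are illustrations, not part of the patch. First, the UULE parameter: the "a+..." string that SerperHelper.generate_uule_v2 hands to SerpApi is plain text under base64. A minimal sketch that re-implements the same encoding so the payload can be inspected; the coordinates are made up, and the 620 radius scaling factor is copied from the patch as-is:

import base64
import time


def generate_uule_v2(latitude: float, longitude: float, radius: float) -> str:
    # Coordinates become E7 fixed-point integers, exactly as in the patch.
    latitude_e7 = int(latitude * 1e7)
    longitude_e7 = int(longitude * 1e7)
    radius_scaled = int(radius * 620)  # scaling factor taken from the patch, unexplained there
    timestamp = int(time.time() * 1000000)
    payload = (
        f"role:1\nproducer:12\nprovenance:6\ntimestamp:{timestamp}\n"
        f"latlng{{\nlatitude_e7:{latitude_e7}\nlongitude_e7:{longitude_e7}\n}}\n"
        f"radius:{radius_scaled}\n"
    )
    return "a+" + base64.b64encode(payload.encode()).decode()


uule = generate_uule_v2(40.5, -3.25, 300)
decoded = base64.b64decode(uule[2:]).decode()  # drop the "a+" prefix
assert "latitude_e7:405000000" in decoded
assert "longitude_e7:-32500000" in decoded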
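EXIFHelper._convert_to_degrees folds the degrees/minutes/seconds rationals stored in EXIF GPS tags into decimal degrees using exact Fraction arithmetic. A worked sketch of the same math, with made-up tag values:

from fractions import Fraction


def to_degrees(value) -> float:
    # Each DMS component is either a plain number or a (numerator,
    # denominator) rational, which is how piexif hands GPS tags back.
    def frac(v):
        return Fraction(v) if isinstance(v, (int, float)) else Fraction(v[0], v[1])

    d, m, s = (frac(v) for v in value)
    return float(d + m / 60 + s / 3600)


# 40 degrees, 25 minutes, 0.48 seconds stored as EXIF rationals:
lat = to_degrees(((40, 1), (25, 1), (48, 100)))
assert abs(lat - 40.4168) < 1e-9  # 40 + 25/60 + 0.48/3600 == 40.4168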
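The FN/TEL slicing in src/bot.py's _handle_image is easy to regress, and requirements_dev.txt now pulls in pytest, so a small test sketch is natural. The helpers are restated here so they are testable in isolation, and the sample vCard is invented. Note the raw string r"\d": the patch's plain "\d" is an invalid escape sequence that newer Python versions warn about:

import re


def normalize_fn(text: str) -> str:
    idx = text.find("FN:")
    return text[idx + len("FN:") : text.find("\n", idx)]


def normalize_tel(text: str) -> str:
    for term in ("TEL:", "TEL;"):
        idx = text.find(term)
        if idx > -1:
            break
    if idx == -1:
        return "111 222 333"  # the patch's placeholder when no TEL line exists
    sub = text[idx + len(term) : text.find("\n", idx)]
    # "TEL;TYPE=work:+34 912 345 678" -> keep the part after the last colon
    return sub.split(":")[-1] if ":" in sub else "".join(re.findall(r"\d", sub))


VCARD = "BEGIN:VCARD\nVERSION:3.0\nFN:Casa Paco\nTEL;TYPE=work:+34 912 345 678\nEND:VCARD\n"


def test_normalize():
    assert normalize_fn(VCARD) == "Casa Paco"
    assert normalize_tel(VCARD) == "+34 912 345 678"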
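CardAgent._postprocess exists because the model can wrap the vCard in prose or code fences; only the BEGIN:VCARD..END:VCARD span survives. The same slice logic, shown on an invented noisy reply:

def postprocess(card: str) -> str:
    # Trim anything before BEGIN:VCARD and after END:VCARD, as the patch does.
    begin, end = "BEGIN:VCARD", "END:VCARD"
    start = card.find(begin)
    return card[start : card.find(end) + len(end)]


raw = "Sure! Here is your card:\n```\nBEGIN:VCARD\nVERSION:3.0\nFN:Casa Paco\nEND:VCARD\n```"
assert postprocess(raw) == "BEGIN:VCARD\nVERSION:3.0\nFN:Casa Paco\nEND:VCARD"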
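Finally, a hedged end-to-end usage sketch showing how the refactored pieces compose. It assumes the environment variables consumed by src/settings.py are set; the image path and coordinates are placeholders:

import asyncio

from src.llm.agent import build_agent


async def main() -> None:
    agent = build_agent()  # cached CardAgent wired with OpenAITool and PlacesTool
    # detail="low" is the cheaper vision pass; pass lat/lon when the photo
    # carries no usable EXIF GPS data (e.g. Telegram-compressed images).
    vcf = await agent.create_card("samples/card.jpg", detail="low", lat=40.4168, lon=-3.7038)
    print(vcf)


asyncio.run(main())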