Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2669 improve location accuracy for street newsflashes #2708

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
83 changes: 83 additions & 0 deletions anyway/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from openai import OpenAI
import json
import tiktoken

from langchain.output_parsers import EnumOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
import langchain
from enum import Enum
from anyway import secrets

# OpenAI API key, looked up once at import time through the project's secrets helper.
# NOTE(review): because this runs on import, importing this module presumably needs the
# secret to be resolvable - confirm what secrets.get does when the key is missing.
api_key = secrets.get("OPENAI_API_KEY")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest moving the secret lookup into a function, to avoid test failures when tests run from forks

# Module-level OpenAI client, created at import time; used by ask_gpt.
client = OpenAI(api_key=api_key)

# Switches on langchain's global debug flag for the whole process.
# NOTE(review): presumably intended for development only - confirm before merging.
langchain.debug = True
# Chat model used by the langchain chain in match_streets_with_langchain; temperature is fixed at 0.
model = ChatOpenAI(api_key=api_key, temperature=0)


def match_streets_with_langchain(street_names: list, location: str):
    """
    Ask the langchain chat model which candidate street is mentioned in a location string.

    The candidates, plus a "-" sentinel that stands for "no match", are turned into a
    dynamically built Enum which is handed to EnumOutputParser, and the parser's format
    instructions are embedded in the prompt.

    :param street_names: candidate street names; the list is left unmodified
    :param location: free-text location string in which to look for a street
    :return: the result of invoking the ``prompt | model | parser`` chain
    """
    # NOTE(review): no caller of this function is visible in this change - confirm it is used.
    # Build a new list instead of appending to the argument, so that the "-" sentinel
    # does not leak into the caller's list.
    candidates = [*street_names, "-"]
    Streets = Enum("Streets", {name: name for name in candidates})

    parser = EnumOutputParser(enum=Streets)
    prompt = PromptTemplate(
        template="Return the street that is mentioned in the location string. if non matches return '-'.\nstreets: {streets}\n"
        + "location_string:{location}\n{format_instructions}\n",
        input_variables=["streets", "location"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    chain = prompt | model | parser
    return chain.invoke({"streets": candidates, "location": location})


def count_tokens_for_prompt(messages, model):
    """
    Approximate how many tokens a list of chat messages occupies for the given model.

    Each message is encoded as the text "role: content" and a fixed overhead is added
    per message.

    :param messages: iterable of dicts, each with "role" and "content" keys
    :param model: model name used to select the tiktoken encoding
    :return: the approximate total number of tokens
    """
    encoding = tiktoken.encoding_for_model(model)
    # approx overhead for each message (role + delimiters)
    per_message_overhead = 4
    return sum(
        len(encoding.encode(f"{msg['role']}: {msg['content']}")) + per_message_overhead
        for msg in messages
    )


def count_tokens(text, model):
    """
    Count the tokens of a text under the tiktoken encoding of the given model.

    :param text: the text to tokenize
    :param model: model name used to select the tiktoken encoding
    :return: number of tokens in the encoded text
    """
    return len(tiktoken.encoding_for_model(model).encode(text))


def ask_gpt(system_message: str, user_message: str, model: str = "gpt-4o"):
    """
    Send a system + user message pair to the OpenAI chat completions API.

    The request is made through the module-level ``client`` with ``response_format``
    set to ``json_object``.

    :param system_message: content of the "system" message
    :param user_message: content of the "user" message
    :param model: name of the chat model to query
    :return: the message object of the first choice of the completion
    """
    import logging  # local import: logging is not among this module's file-level imports

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        response_format={"type": "json_object"},
        model=model,
        messages=messages,
    )
    # Token counting is a diagnostic only: skip the tokenization when nobody will read
    # it, and never let it discard a completion that was already received.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        try:
            logging.debug("tokens for prompt: %s", count_tokens_for_prompt(messages, model))
        except KeyError:
            # tiktoken could not map this model name to an encoding
            logging.debug("could not count prompt tokens for model %s", model)
    return completion.choices[0].message


def ask_ai_about_street_matching(streets, location_string, model="gpt-4-turbo"):
    """
    Ask the LLM which of the given streets is mentioned in a location string.

    :param streets: candidate street names
    :param location_string: free-text location in which to look for a street
    :param model: name of the chat model to query
    :return: tuple ``(street, matched)``: ``street`` is the model's answer ("-" when there
             is no usable answer) and ``matched`` tells whether it is one of ``streets``
    """
    no_match = "-"
    streets = list(streets)
    if not streets:
        # Nothing to choose from: the answer is known without paying for an API call.
        return no_match, False

    # ensure_ascii=False keeps non-ASCII (e.g. Hebrew) names readable in the prompt
    # instead of \uXXXX escapes, so the model sees the names exactly as it must return them.
    system_message = (
        "\n"
        "Given a list of streets, return the name of the street that is mentioned in the provided location string.\n"
        "Return the name exactly as appears in list.\n"
        'If no match is found, return "-".\n'
        'Return json with field "street" and your answer.\n'
        "Select one of the following options:\n"
    ) + json.dumps(streets + [no_match], ensure_ascii=False)
    user_message = json.dumps(
        {"streets": streets, "location": location_string}, ensure_ascii=False
    )
    reply = ask_gpt(system_message, user_message, model)
    try:
        parsed = json.loads(reply.content)
    except (TypeError, ValueError):
        # content is missing (TypeError) or is not valid JSON (ValueError)
        return no_match, False
    result = parsed.get("street", no_match) if isinstance(parsed, dict) else no_match
    return result, result in streets
90 changes: 76 additions & 14 deletions anyway/parsers/location_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anyway.parsers.resolution_fields import ResolutionFields as RF
from anyway import secrets
from anyway.models import AccidentMarkerView, RoadSegments
from anyway.llm import ask_ai_about_street_matching
from sqlalchemy import not_
import pandas as pd
from sqlalchemy.orm import load_only
Expand Down Expand Up @@ -176,19 +177,7 @@ def get_bounding_box(latitude, longitude, distance_in_km):
return final_loc


def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
"""
extracts location from db by closest geo point to location found, using road number if provided and limits to
requested resolution
:param db: the DB
:param latitude: location latitude
:param longitude: location longitude
:param resolution: wanted resolution
:param road_no: road number if there is
:return: a dict containing all the geo fields stated in
resolution dict, with values filled according to resolution
"""
# READ MARKERS FROM DB
def read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no=None):
geod = Geodesic.WGS84
relevant_fields = RF.get_possible_fields(resolution)
markers = db.get_markers_for_location_extraction()
Expand Down Expand Up @@ -222,6 +211,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
markers["dist_point"] = markers.apply(
lambda x: geod.Inverse(latitude, longitude, x["latitude"], x["longitude"])["s12"], axis=1
).replace({np.nan: None})
return markers


def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
"""
extracts location from db by closest geo point to location found, using road number if provided and limits to
requested resolution
:param db: the DB
:param latitude: location latitude
:param longitude: location longitude
:param resolution: wanted resolution
:param road_no: road number if there is
:return: a dict containing all the geo fields stated in
resolution dict, with values filled according to resolution
"""
# READ MARKERS FROM DB
relevant_fields = RF.get_possible_fields(resolution)
markers = read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no)

most_fit_loc = (
markers.loc[markers["dist_point"] == markers["dist_point"].min()].iloc[0].to_dict()
Expand All @@ -240,6 +247,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None):
return final_loc


def read_n_closest_streets(db, n, latitude, longitude, road_no=None):
    """
    Return the names of up to n distinct streets whose markers are closest to a point.

    :param db: the DB
    :param n: maximal number of street names to return
    :param latitude: location latitude
    :param longitude: location longitude
    :param road_no: road number if there is
    :return: list of "street1_hebrew" values, ordered by ascending "dist_point"
    """
    markers = read_markers_and_distance_from_location(
        db, latitude, longitude, BE_CONST.ResolutionCategories.STREET, road_no
    )
    # Closest first, one row per street name, then keep only the first n rows.
    closest_unique = (
        markers.sort_values(by="dist_point")
        .drop_duplicates(subset="street1_hebrew")
        .head(n)
    )
    return closest_unique["street1_hebrew"].tolist()


def set_accident_resolution(accident_row):
"""
set the resolution of the accident
Expand Down Expand Up @@ -282,11 +307,12 @@ def reverse_geocode_extract(latitude, longitude):
try:
gmaps = googlemaps.Client(key=secrets.get("GOOGLE_MAPS_KEY"))
geocode_result = gmaps.reverse_geocode((latitude, longitude))

print(geocode_result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove all prints in this code?

# if we got no results, move to next iteration of location string
if not geocode_result:
return None
except Exception as _:
logging.info(_)
logging.info("exception in gmaps")
return None
# logging.info(geocode_result)
Expand Down Expand Up @@ -539,6 +565,42 @@ def extract_geo_features(db, newsflash: NewsFlash, use_existing_coordinates_only
if location_from_db is not None:
update_location_fields(newsflash, location_from_db)
try_find_segment_id(newsflash)
logging.debug(newsflash.resolution)
if newsflash.resolution == BE_CONST.ResolutionCategories.STREET:
try_improve_street_identification(newsflash)


def try_improve_street_identification(newsflash, db=None, max_streets=20, first_try_size=5):
    """
    Try to replace the newsflash's street with a nearby street mentioned in its location text.

    The closest distinct streets to the newsflash coordinates are read from the DB and
    offered to the LLM in two batches: the ``first_try_size`` closest ones first, and the
    remaining ones only when the first batch produced no match. On a match that differs
    from the current value, ``newsflash.street1_hebrew`` is updated in place; this
    function itself does not persist the change.

    :param newsflash: newsflash whose ``lat``, ``lon`` and ``location`` are read and whose
                      ``street1_hebrew`` may be updated
    :param db: DB adapter to read markers from; a new one is initialised when omitted
    :param max_streets: how many of the closest distinct streets to consider in total
    :param first_try_size: how many of those streets go into the first LLM query
    """
    if db is None:
        # NOTE(review): imported locally as in the original code, presumably to avoid a
        # circular import - confirm.
        from anyway.parsers import news_flash_db_adapter

        db = news_flash_db_adapter.init_db()
    all_closest_streets = read_n_closest_streets(db, max_streets, newsflash.lat, newsflash.lon)

    streets_for_first_try = all_closest_streets[:first_try_size]
    streets_for_second_try = all_closest_streets[first_try_size:]

    if not streets_for_first_try:
        # No candidates at all: there is nothing to ask the LLM about.
        logging.debug("no candidate streets found for newsflash %s", newsflash.id)
        return

    result, result_in_input = ask_ai_about_street_matching(
        streets_for_first_try, newsflash.location
    )
    logging.debug("result of 1st try %s", result)
    if not result_in_input:
        logging.debug("street matching failed first try for newsflash %s", newsflash.id)
        if not streets_for_second_try:
            # Fewer streets than the first batch size: a second query would be empty.
            return
        result, result_in_input = ask_ai_about_street_matching(
            streets_for_second_try, newsflash.location
        )
        logging.debug("result of 2nd try %s", result)
        if not result_in_input:
            logging.debug("street matching failed second try for newsflash %s", newsflash.id)
            return

    if result == newsflash.street1_hebrew:
        logging.debug("street matching succeeded, street not changed")
        return
    logging.debug(
        "street matching succeeded, street updated for %s from %s to %s",
        newsflash.id,
        newsflash.street1_hebrew,
        result,
    )
    newsflash.street1_hebrew = result


def update_location_fields(newsflash, location_from_db):
Expand Down
11 changes: 11 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,17 @@ def infographics_pictures(id):
raise Exception("generation failed")


@process.command()
@click.option("--id", type=int)
def street_name(id):
    """Run street identification improvement for the newsflash selected by --id."""
    # The parameter must be called `id` to match the --id option, although it shadows the builtin.
    from anyway.parsers import news_flash_db_adapter
    from anyway.parsers.location_extraction import try_improve_street_identification

    db = news_flash_db_adapter.init_db()
    newsflash = db.get_newsflash_by_id(id).first()
    if newsflash is None:
        # Report a readable CLI error instead of failing later with an AttributeError
        # when the newsflash attributes are accessed.
        raise click.ClickException(f"newsflash with id {id} was not found")
    try_improve_street_identification(newsflash)


# Command group "cache", registered under the `process` group; the function body does nothing.
# NOTE(review): its sub-commands are presumably defined further down the file - not visible here.
@process.group()
def cache():
    pass
Expand Down
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Flask-Login==0.5.0
Flask-SQLAlchemy==2.4.1
flask-restx==0.5.1
Jinja2==3.1.4
SQLAlchemy==1.3.17
SQLAlchemy==1.4
Werkzeug==2.0.3
alembic==1.4.2
attrs==23.1.0
Expand Down Expand Up @@ -53,3 +53,7 @@ swifter==1.3.4
telebot==0.0.5
selenium==4.11.2
apache-airflow-client==2.6.2
openai==1.45.0
langchain==0.2.16
langchain_openai==0.1.25
python-dotenv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is python-dotenv necessary?
Also, is the SQLAlchemy==1.4 change necessary?
(I think we should upgrade the packages, but that should be done in a separate PR with suitable tests.)

Loading