-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
2669 improve location accuracy for street newsflashes #2708
base: dev
Are you sure you want to change the base?
Changes from all commits
c28639d
b1abe45
b65c013
4f82f89
135ce08
eb374ca
9f6999a
afa99af
6c02aca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from openai import OpenAI | ||
import json | ||
import tiktoken | ||
|
||
from langchain.output_parsers import EnumOutputParser | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_openai import ChatOpenAI | ||
import langchain | ||
from enum import Enum | ||
from anyway import secrets | ||
|
||
# The OpenAI key is read once, at import time, from the project's secrets store.
api_key = secrets.get("OPENAI_API_KEY")
# Plain OpenAI SDK client; ask_gpt() sends its chat-completion requests through it.
client = OpenAI(api_key=api_key)

# NOTE(review): this turns on LangChain's global debug flag for every importer
# of this module - confirm it is meant to stay enabled outside development.
langchain.debug = True
# LangChain chat model used by match_streets_with_langchain(); temperature is 0.
model = ChatOpenAI(api_key=api_key, temperature=0)
|
||
|
||
def match_streets_with_langchain(street_names: list, location: str) -> Enum:
    """Ask the LangChain model which of ``street_names`` is mentioned in ``location``.

    :param street_names: candidate street names; the list is NOT modified
    :param location: free-text location string to search for a street name
    :return: member of a dynamically built ``Streets`` enum whose value is the
        chosen street name, or ``"-"`` when the model reports no match
    """
    # NOTE(review): reviewers asked for type hints (added here) and whether this
    # function is still used - confirm before relying on it.

    # Build a new list instead of appending to the argument: the original
    # mutated the caller's list, so "-" leaked into it on every call.
    candidates = [*street_names, "-"]
    streets_enum = Enum("Streets", {name: name for name in candidates})

    parser = EnumOutputParser(enum=streets_enum)
    format_instructions = parser.get_format_instructions()
    prompt = PromptTemplate(
        template="Return the street that is mentioned in the location string. if non matches return '-'.\nstreets: {streets}\n"
        + "location_string:{location}\n{format_instructions}\n",
        input_variables=["streets", "location"],
        partial_variables={"format_instructions": format_instructions},
    )

    # prompt -> chat model -> enum parser
    chain = prompt | model | parser

    return chain.invoke({"streets": candidates, "location": location})
|
||
|
||
def count_tokens_for_prompt(messages, model):
    """Approximate the number of tokens a list of chat messages occupies.

    :param messages: iterable of dicts, each with ``"role"`` and ``"content"`` keys
    :param model: model name passed to ``tiktoken.encoding_for_model``
    :return: sum over all messages of the encoded ``"role: content"`` length
        plus a fixed per-message overhead
    """
    encoding = tiktoken.encoding_for_model(model)
    # approx overhead for each message (role + delimiters)
    per_message_overhead = 4
    return sum(
        len(encoding.encode(f"{msg['role']}: {msg['content']}")) + per_message_overhead
        for msg in messages
    )
|
||
|
||
def count_tokens(text, model):
    """Return how many tokens ``text`` encodes to with the tokenizer of ``model``.

    :param text: string to tokenize
    :param model: model name passed to ``tiktoken.encoding_for_model``
    :return: number of tokens
    """
    return len(tiktoken.encoding_for_model(model).encode(text))
|
||
|
||
def ask_gpt(system_message, user_message, model="gpt-4o"):
    """Send one system + user message pair to the chat-completions API in JSON mode.

    :param system_message: content of the ``system`` message
    :param user_message: content of the ``user`` message
    :param model: OpenAI model name
    :return: the ``message`` object of the first returned choice
    """
    import logging

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        response_format={"type": "json_object"},
        model=model,
        messages=messages,
    )
    # Token accounting is diagnostic only: log it instead of printing, skip the
    # work when debug logging is off, and never let it discard the completion.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        try:
            logging.debug(f"tokens for prompt: {count_tokens_for_prompt(messages, model)}")
        except KeyError:
            # tiktoken has no encoding registered for this model name
            logging.debug(f"could not count prompt tokens for model {model}")
    return completion.choices[0].message
|
||
|
||
def ask_ai_about_street_matching(streets, location_string, model="gpt-4-turbo"):
    """Ask the model which of ``streets`` is mentioned in ``location_string``.

    :param streets: candidate street names
    :param location_string: free-text location description
    :param model: OpenAI model name forwarded to ``ask_gpt``
    :return: tuple ``(answer, matched)``: ``answer`` is the ``"street"`` field of
        the model's JSON reply (``"-"`` when the field is missing or there were
        no candidates), ``matched`` is True only if ``answer`` is one of ``streets``
    """
    # Nothing to choose from: skip the paid API round-trip entirely.
    if not streets:
        return "-", False
    streets = list(streets)

    # ensure_ascii=False keeps non-ASCII (e.g. Hebrew) names readable in the
    # prompt instead of \uXXXX escapes, which also cost far more tokens.
    system_message = """
    Given a list of streets, return the name of the street that is mentioned in the provided location string.
    Return the name exactly as appears in list.
    If no match is found, return "-".
    Return json with field "street" and your answer.
    Select one of the following options:
    """ + json.dumps(streets + ["-"], ensure_ascii=False)
    user_message = json.dumps(
        {"streets": streets, "location": location_string}, ensure_ascii=False
    )
    reply = ask_gpt(system_message, user_message, model)
    # A reply without the "street" field is treated as "no match" rather than a crash.
    result = json.loads(reply.content).get("street", "-")
    return result, result in streets
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from anyway.parsers.resolution_fields import ResolutionFields as RF | ||
from anyway import secrets | ||
from anyway.models import AccidentMarkerView, RoadSegments | ||
from anyway.llm import ask_ai_about_street_matching | ||
from sqlalchemy import not_ | ||
import pandas as pd | ||
from sqlalchemy.orm import load_only | ||
|
@@ -176,19 +177,7 @@ def get_bounding_box(latitude, longitude, distance_in_km): | |
return final_loc | ||
|
||
|
||
def get_db_matching_location(db, latitude, longitude, resolution, road_no=None): | ||
""" | ||
extracts location from db by closest geo point to location found, using road number if provided and limits to | ||
requested resolution | ||
:param db: the DB | ||
:param latitude: location latitude | ||
:param longitude: location longitude | ||
:param resolution: wanted resolution | ||
:param road_no: road number if there is | ||
:return: a dict containing all the geo fields stated in | ||
resolution dict, with values filled according to resolution | ||
""" | ||
# READ MARKERS FROM DB | ||
def read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no=None): | ||
geod = Geodesic.WGS84 | ||
relevant_fields = RF.get_possible_fields(resolution) | ||
markers = db.get_markers_for_location_extraction() | ||
|
@@ -222,6 +211,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None): | |
markers["dist_point"] = markers.apply( | ||
lambda x: geod.Inverse(latitude, longitude, x["latitude"], x["longitude"])["s12"], axis=1 | ||
).replace({np.nan: None}) | ||
return markers | ||
|
||
|
||
def get_db_matching_location(db, latitude, longitude, resolution, road_no=None): | ||
""" | ||
extracts location from db by closest geo point to location found, using road number if provided and limits to | ||
requested resolution | ||
:param db: the DB | ||
:param latitude: location latitude | ||
:param longitude: location longitude | ||
:param resolution: wanted resolution | ||
:param road_no: road number if there is | ||
:return: a dict containing all the geo fields stated in | ||
resolution dict, with values filled according to resolution | ||
""" | ||
# READ MARKERS FROM DB | ||
relevant_fields = RF.get_possible_fields(resolution) | ||
markers = read_markers_and_distance_from_location(db, latitude, longitude, resolution, road_no) | ||
|
||
most_fit_loc = ( | ||
markers.loc[markers["dist_point"] == markers["dist_point"].min()].iloc[0].to_dict() | ||
|
@@ -240,6 +247,24 @@ def get_db_matching_location(db, latitude, longitude, resolution, road_no=None): | |
return final_loc | ||
|
||
|
||
def read_n_closest_streets(db, n, latitude, longitude, road_no=None):
    """Return up to ``n`` distinct ``street1_hebrew`` values, nearest first.

    :param db: the DB, forwarded to ``read_markers_and_distance_from_location``
    :param n: maximum number of street names to return
    :param latitude: location latitude
    :param longitude: location longitude
    :param road_no: road number if there is
    :return: list of ``street1_hebrew`` values ordered by ascending ``dist_point``
    """
    markers_with_distance = read_markers_and_distance_from_location(
        db, latitude, longitude, BE_CONST.ResolutionCategories.STREET, road_no
    )
    # After sorting by distance, drop_duplicates keeps the first (nearest)
    # marker of every street.
    nearest_per_street = (
        markers_with_distance.sort_values(by="dist_point")
        .drop_duplicates(subset="street1_hebrew")
        .head(n)
    )
    return nearest_per_street["street1_hebrew"].tolist()
|
||
|
||
def set_accident_resolution(accident_row): | ||
""" | ||
set the resolution of the accident | ||
|
@@ -282,11 +307,12 @@ def reverse_geocode_extract(latitude, longitude): | |
try: | ||
gmaps = googlemaps.Client(key=secrets.get("GOOGLE_MAPS_KEY")) | ||
geocode_result = gmaps.reverse_geocode((latitude, longitude)) | ||
|
||
print(geocode_result) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you remove all prints in this code? |
||
# if we got no results, move to next iteration of location string | ||
if not geocode_result: | ||
return None | ||
except Exception as _: | ||
logging.info(_) | ||
logging.info("exception in gmaps") | ||
return None | ||
# logging.info(geocode_result) | ||
|
@@ -539,6 +565,42 @@ def extract_geo_features(db, newsflash: NewsFlash, use_existing_coordinates_only | |
if location_from_db is not None: | ||
update_location_fields(newsflash, location_from_db) | ||
try_find_segment_id(newsflash) | ||
logging.debug(newsflash.resolution) | ||
if newsflash.resolution == BE_CONST.ResolutionCategories.STREET: | ||
try_improve_street_identification(newsflash) | ||
|
||
|
||
def try_improve_street_identification(
    newsflash, num_of_closest_streets=20, num_of_streets_for_first_try=5
):
    """Ask the LLM to pick the street mentioned in the newsflash among nearby streets.

    The closest streets are offered to the model first; if it finds no match
    there, the remaining candidates are offered in a second request. When a
    candidate is chosen and differs from the current value,
    ``newsflash.street1_hebrew`` is updated in place.

    :param newsflash: object with ``id``, ``lat``, ``lon``, ``location`` and
        ``street1_hebrew`` attributes
    :param num_of_closest_streets: how many nearby streets to read in total
    :param num_of_streets_for_first_try: how many of them go into the first request
    :return: None
    """
    # Imported inside the function, as in the original code.
    from anyway.parsers import news_flash_db_adapter

    db = news_flash_db_adapter.init_db()
    all_closest_streets = read_n_closest_streets(
        db, num_of_closest_streets, newsflash.lat, newsflash.lon
    )
    # No candidates at all: there is nothing to ask the model about.
    if not all_closest_streets:
        logging.debug(f"no candidate streets found for newsflash {newsflash.id}")
        return

    streets_for_first_try = all_closest_streets[:num_of_streets_for_first_try]
    streets_for_second_try = all_closest_streets[num_of_streets_for_first_try:]

    result, result_in_input = ask_ai_about_street_matching(
        streets_for_first_try, newsflash.location
    )
    logging.debug(f"result of 1st try {result}")
    if not result_in_input:
        logging.debug(f"street matching failed first try for newsflash {newsflash.id}")
        # Skip the second paid request when there are no further candidates.
        if not streets_for_second_try:
            logging.debug(f"no streets left for second try for newsflash {newsflash.id}")
            return
        result, result_in_input = ask_ai_about_street_matching(
            streets_for_second_try, newsflash.location
        )
        logging.debug(f"result of 2nd try {result}")

    if not result_in_input:
        logging.debug(f"street matching failed second try for newsflash {newsflash.id}")
        return

    if result == newsflash.street1_hebrew:
        logging.debug("street matching succeeded, street not changed")
    else:
        logging.debug(
            f"street matching succeeded, street updated for {newsflash.id} "
            f"from {newsflash.street1_hebrew} to {result}"
        )
        newsflash.street1_hebrew = result
|
||
|
||
def update_location_fields(newsflash, location_from_db): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ Flask-Login==0.5.0 | |
Flask-SQLAlchemy==2.4.1 | ||
flask-restx==0.5.1 | ||
Jinja2==3.1.4 | ||
SQLAlchemy==1.3.17 | ||
SQLAlchemy==1.4 | ||
Werkzeug==2.0.3 | ||
alembic==1.4.2 | ||
attrs==23.1.0 | ||
|
@@ -53,3 +53,7 @@ swifter==1.3.4 | |
telebot==0.0.5 | ||
selenium==4.11.2 | ||
apache-airflow-client==2.6.2 | ||
openai==1.45.0 | ||
langchain==0.2.16 | ||
langchain_openai==0.1.25 | ||
python-dotenv | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest moving the secret lookup into a function, to avoid test failures when the tests run from forks.