Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter numeric only fields in Widget queries #2686

Merged
merged 3 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion anyway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ def get_street_by_street_name(yishuv_symbol: int, name: str) -> int:
)
if res is None:
raise ValueError(f"{name}: could not find street in yishuv:{yishuv_symbol}")
return res
return res.street

@staticmethod
def get_streets_by_yishuv(yishuv_symbol: int) -> List[dict]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, request_params: RequestParams):
def generate_items(self) -> None:
res1 = get_accidents_stats(
table_obj=InvolvedMarkerView,
filters=get_injured_filters(self.request_params),
filters=get_injured_filters(self.request_params.location_info),
group_by=("accident_year", "injury_severity"),
count="injury_severity",
start_time=self.request_params.start_time,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
get_accidents_stats,
join_strings,
get_location_text,
get_involved_marker_view_location_filters,
get_injured_filters,
)
from anyway.backend_constants import BE_CONST
from flask_babel import _
Expand Down Expand Up @@ -40,7 +40,7 @@ def get_injured_count_by_severity(
start_time: datetime.date,
end_time: datetime.date,
):
filters = get_involved_marker_view_location_filters(resolution, location_info)
filters = get_injured_filters(location_info)
filters["injury_severity"] = [
InjurySeverity.KILLED.value,
InjurySeverity.SEVERE_INJURED.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anyway.widgets.widget_utils import (
get_expression_for_fields,
add_resolution_location_accuracy_filter,
remove_loc_text_fields_from_filter,
)

# RequestParams is not hashable, so we can't use functools.lru_cache
Expand All @@ -35,19 +36,20 @@ def filter_and_group_injured_count_per_age_group(
) -> Dict[str, Dict[int, int]]:
start_time = request_params.start_time
end_time = request_params.end_time
cache_key = None # prevent pylint warning
cache_key = tuple(request_params.location_info.values()) +\
(start_time, end_time)

if request_params.resolution == BE_CONST.ResolutionCategories.STREET:
accident_yishuv_name = request_params.location_info["yishuv_name"]
street1_hebrew = request_params.location_info["street1_hebrew"]
cache_key = (accident_yishuv_name, street1_hebrew, start_time, end_time)
# if request_params.resolution == BE_CONST.ResolutionCategories.STREET:
# accident_yishuv_name = request_params.location_info["yishuv_name"]
# street1_hebrew = request_params.location_info["street1_hebrew"]
# cache_key = (accident_yishuv_name, street1_hebrew, start_time, end_time)

elif request_params.resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD:
road_number = request_params.location_info["road1"]
road_segment_id = request_params.location_info["road_segment_id"]
cache_key = (road_number, road_segment_id, start_time, end_time)
# elif request_params.resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD:
# road_number = request_params.location_info["road1"]
# road_segment_id = request_params.location_info["road_segment_id"]
# cache_key = (road_number, road_segment_id, start_time, end_time)

if cache_dict.get(cache_key):
if cache_key in cache_dict:
return cache_dict.get(cache_key)

query = KilledAndInjuredCountPerAgeGroupWidgetUtils.create_query_for_killed_and_injured_count_per_age_group(
Expand Down Expand Up @@ -115,6 +117,7 @@ def create_query_for_killed_and_injured_count_per_age_group(
loc_filter = adapt_location_fields_to_involve_table(location_info)
loc_filter = add_resolution_location_accuracy_filter(loc_filter,
resolution)
loc_filter = remove_loc_text_fields_from_filter(loc_filter)
loc_ex = get_expression_for_fields(loc_filter, InvolvedMarkerView)

query = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def generate_items(self) -> None:

@staticmethod
def count_accidents_by_driver_type(request_params: RequestParams):
filters = get_injured_filters(request_params)
filters = get_injured_filters(request_params.location_info)
filters["involved_type"] = [
consts.InvolvedType.DRIVER.value,
consts.InvolvedType.INJURED_DRIVER.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def __init__(self, request_params: RequestParams):
self.rank = 18
self.information = "Injured and killed pedestrians by severity and year"

def validate_parameters(self, yishuv_name, street1_hebrew):
def validate_parameters(self, yishuv_symbol, street1):
# TODO: validate each parameter and display message accordingly
return (
yishuv_name is not None
and street1_hebrew is not None
yishuv_symbol is not None
and street1 is not None
and self.request_params.years_ago is not None
)

Expand All @@ -53,10 +53,10 @@ def convert_to_dict(query_results):

def generate_items(self) -> None:
try:
yishuv_name = self.request_params.location_info.get("yishuv_name")
street1_hebrew = self.request_params.location_info.get("street1_hebrew")
yishuv_symbol = self.request_params.location_info.get("yishuv_symbol")
street1 = self.request_params.location_info.get("street1")

# if not self.validate_parameters(yishuv_name, street1_hebrew):
# if not self.validate_parameters(yishuv_symbol, street1_hebrew):
# # TODO: this will fail since there is no news_flash_obj in request_params
# logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}")
# return None
Expand All @@ -74,7 +74,7 @@ def generate_items(self) -> None:
func.count().label("count"),
)
.filter(loc_ex)
.filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name)
.filter(InvolvedMarkerView.accident_yishuv_symbol == yishuv_symbol)
.filter(
InvolvedMarkerView.injury_severity.in_(
[
Expand All @@ -87,8 +87,8 @@ def generate_items(self) -> None:
.filter(InvolvedMarkerView.injured_type == InjuredType.PEDESTRIAN.value)
.filter(
or_(
InvolvedMarkerView.street1_hebrew == street1_hebrew,
InvolvedMarkerView.street2_hebrew == street1_hebrew,
InvolvedMarkerView.street1 == street1,
InvolvedMarkerView.street2 == street1,
)
)
.filter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = SevereFatalCountByVehicleByYearWidget.separate_data(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["yishuv_symbol"],
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
def separate_data(yishuv_symbol, start_time, end_time, resolution) -> Dict[str, Any]:
output = {
"e_bikes": get_accidents_stats(
table_obj=InvolvedMarkerView,
Expand All @@ -39,7 +39,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.ELECTRIC_BIKE.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand All @@ -55,7 +55,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.BIKE.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand All @@ -71,7 +71,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]:
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleType.ELECTRIC_SCOOTER.value,
"accident_yishuv_name": yishuv,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = SmallMotorSevereFatalCountByYearWidget.get_motor_stats(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["yishuv_symbol"],
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def get_motor_stats(location_info, start_time, end_time, resolution):
def get_motor_stats(yishuv_symbol, start_time, end_time, resolution):
count_by_year = get_accidents_stats(
table_obj=InvolvedMarkerView,
filters={
Expand All @@ -37,7 +37,7 @@ def get_motor_stats(location_info, start_time, end_time, resolution):
InjurySeverity.SEVERE_INJURED.value,
],
"involve_vehicle_type": VehicleCategory.BICYCLE_AND_SMALL_MOTOR.get_codes(),
"accident_yishuv_name": location_info,
"accident_yishuv_symbol": yishuv_symbol,
},
group_by="accident_year",
count="accident_year",
Expand Down
28 changes: 14 additions & 14 deletions anyway/widgets/urban_widgets/urban_crosswalk_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from anyway.widgets.widget_utils import get_accidents_stats


# TODO: pretty sure there are errors in this widget, for example, is_included returns self.items
class UrbanCrosswalkWidget(UrbanWidget):
name: str = "urban_accidents_by_cross_location"
files = [__file__]
Expand All @@ -21,15 +20,14 @@ def __init__(self, request_params: RequestParams):

def generate_items(self) -> None:
self.items = UrbanCrosswalkWidget.get_crosswalk(
self.request_params.location_info["yishuv_name"],
self.request_params.location_info["street1_hebrew"],
self.request_params.location_info,
self.request_params.start_time,
self.request_params.end_time,
self.request_params.resolution,
)

@staticmethod
def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, Any]:
def get_crosswalk(location_info: dict, start_time, end_time, resolution) -> Dict[str, Any]:
cross_output = {
"with_crosswalk": get_accidents_stats(
table_obj=InvolvedMarkerView,
Expand All @@ -39,11 +37,11 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str,
InjurySeverity.SEVERE_INJURED.value,
],
"cross_location": CrossCategory.CROSSWALK.get_codes(),
"accident_yishuv_name": yishuv,
"street1_hebrew": street,
"accident_yishuv_symbol": location_info["yishuv_symbol"],
"street1": location_info["street"],
},
group_by="street1_hebrew",
count="street1_hebrew",
group_by="street1",
count="street1",
start_time=start_time,
end_time=end_time,
resolution=resolution,
Expand All @@ -56,20 +54,22 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str,
InjurySeverity.SEVERE_INJURED.value,
],
"cross_location": CrossCategory.NONE.get_codes(),
"accident_yishuv_name": yishuv,
"street1_hebrew": street,
"accident_yishuv_symbol": location_info["yishuv_symbol"],
"street1": location_info["street"],
},
group_by="street1_hebrew",
count="street1_hebrew",
group_by="street1",
count="street1",
start_time=start_time,
end_time=end_time,
resolution=resolution,
),
}
if not cross_output["with_crosswalk"]:
cross_output["with_crosswalk"] = [{"street1_hebrew": street, "count": 0}]
cross_output["with_crosswalk"] = [{"street1_hebrew": location_info["street1_hebrew"],
"count": 0}]
if not cross_output["without_crosswalk"]:
cross_output["without_crosswalk"] = [{"street1_hebrew": street, "count": 0}]
cross_output["without_crosswalk"] = [{"street1_hebrew": location_info["street1_hebrew"],
"count": 0}]
ziv17 marked this conversation as resolved.
Show resolved Hide resolved
return cross_output

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions anyway/widgets/urban_widgets/urban_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def __init__(self, request_params: RequestParams):
def is_urban(request_params: RequestParams) -> bool:
return (
request_params is not None
and "yishuv_name" in request_params.location_info
and "street1_hebrew" in request_params.location_info
and "yishuv_symbol" in request_params.location_info
and "street1" in request_params.location_info
)

@staticmethod
Expand Down
63 changes: 37 additions & 26 deletions anyway/widgets/widget_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@


def get_query(table_obj, filters, start_time, end_time):
if "road_segment_name" in filters and "road_segment_id" in filters:
filters = copy.copy(filters)
filters.pop("road_segment_name")
filters = remove_loc_text_fields_from_filter(filters)
query = db.session.query(table_obj)
if start_time:
query = query.filter(getattr(table_obj, "accident_timestamp") >= start_time)
Expand All @@ -48,6 +46,20 @@ def get_query(table_obj, filters, start_time, end_time):
return query


def remove_loc_text_fields_from_filter(filters: dict) -> dict:
def remove_first_if_both_exist(d: dict, first: str, second: str) -> dict:
if first in d and second in d:
d.pop(first)

res = copy.copy(filters)
remove_first_if_both_exist(res, "road_segment_name", "road_segment_id")
remove_first_if_both_exist(res, "yishuv_name", "yishuv_symbol")
remove_first_if_both_exist(res, "street1_hebrew", "street1")
remove_first_if_both_exist(res, "street2_hebrew", "street2")
remove_first_if_both_exist(res, "accident_yishuv_name", "accident_yishuv_symbol")
return res


def get_expression_for_fields(filters: dict, table_obj):
op_other, op_segment = None, None
if "road_segment_id" not in filters.keys():
Expand Down Expand Up @@ -196,29 +208,28 @@ def gen_entity_labels(entity: Type[LabeledCode]) -> dict:
return res


def get_involved_marker_view_location_filters(
resolution: BE_CONST.ResolutionCategories, location_info: LocationInfo
):
filters = {}
if resolution == BE_CONST.ResolutionCategories.STREET:
filters["accident_yishuv_name"] = location_info.get("yishuv_name")
filters["street1_hebrew"] = location_info.get("street1_hebrew")
elif resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD:
filters["road1"] = location_info.get("road1")
filters["road_segment_id"] = location_info.get("road_segment_id")
return filters


def get_injured_filters(request_params: RequestParams):
new_filters = get_involved_marker_view_location_filters(
request_params.resolution, request_params.location_info
)
for curr_filter, curr_values in request_params.location_info.items():
if curr_filter in ["region_hebrew", "district_hebrew", "yishuv_name"]:
# def get_involved_marker_view_location_filters(
# resolution: BE_CONST.ResolutionCategories, location_info: LocationInfo
# ):
# filters = {}
# if resolution == BE_CONST.ResolutionCategories.STREET:
# filters["accident_yishuv_name"] = location_info.get("yishuv_name")
# filters["street1_hebrew"] = location_info.get("street1_hebrew")
# elif resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD:
# filters["road1"] = location_info.get("road1")
# filters["road_segment_id"] = location_info.get("road_segment_id")
# return filters


def get_injured_filters(location_info: dict):
new_filters = copy.copy(location_info)
for curr_filter, curr_value in location_info.items():
if curr_filter in ["region_hebrew", "district_hebrew", "yishuv_name", "yishuv_symbol"]:
new_filter_name = "accident_" + curr_filter
new_filters[new_filter_name] = curr_values
new_filters[new_filter_name] = curr_value
new_filters.pop(curr_filter)

new_filters["injury_severity"] = [1, 2, 3, 4, 5]
new_filters["injury_severity"] = [1, 2, 3]
return new_filters


Expand Down Expand Up @@ -294,10 +305,10 @@ def get_involved_counts(
.order_by(table.accident_year)
)
filters = add_resolution_location_accuracy_filter(location_info, table)
if "yishuv_symbol" in location_info:
filters = remove_loc_text_fields_from_filter(filters)
if "yishuv_symbol" in filters:
filters["accident_yishuv_symbol"] = filters["yishuv_symbol"]
filters.pop("yishuv_symbol")
filters.pop("yishuv_name", None)
ex = get_expression_for_fields(filters, table)
query = query.filter(ex).group_by(table.accident_year)

Expand Down
Loading
Loading