diff --git a/anyway/models.py b/anyway/models.py index 80f248b9..d2a0773a 100755 --- a/anyway/models.py +++ b/anyway/models.py @@ -1231,7 +1231,7 @@ def get_street_by_street_name(yishuv_symbol: int, name: str) -> int: ) if res is None: raise ValueError(f"{name}: could not find street in yishuv:{yishuv_symbol}") - return res + return res.street @staticmethod def get_streets_by_yishuv(yishuv_symbol: int) -> List[dict]: diff --git a/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py b/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py index 16764a5d..f4cfc76f 100644 --- a/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py +++ b/anyway/widgets/all_locations_widgets/injured_count_by_accident_year_widget.py @@ -32,7 +32,7 @@ def __init__(self, request_params: RequestParams): def generate_items(self) -> None: res1 = get_accidents_stats( table_obj=InvolvedMarkerView, - filters=get_injured_filters(self.request_params), + filters=get_injured_filters(self.request_params.location_info), group_by=("accident_year", "injury_severity"), count="injury_severity", start_time=self.request_params.start_time, diff --git a/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py b/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py index 96bfedf8..a3c3d1f7 100644 --- a/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py +++ b/anyway/widgets/all_locations_widgets/injured_count_by_severity_widget.py @@ -9,7 +9,7 @@ get_accidents_stats, join_strings, get_location_text, - get_involved_marker_view_location_filters, + get_injured_filters, ) from anyway.backend_constants import BE_CONST from flask_babel import _ @@ -40,7 +40,7 @@ def get_injured_count_by_severity( start_time: datetime.date, end_time: datetime.date, ): - filters = get_involved_marker_view_location_filters(resolution, location_info) + filters = get_injured_filters(location_info) filters["injury_severity"] = [ InjurySeverity.KILLED.value, InjurySeverity.SEVERE_INJURED.value, diff --git a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py index f0b9d880..35624732 100644 --- a/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py +++ b/anyway/widgets/all_locations_widgets/killed_and_injured_count_per_age_group_widget_utils.py @@ -15,6 +15,7 @@ from anyway.widgets.widget_utils import ( get_expression_for_fields, add_resolution_location_accuracy_filter, + remove_loc_text_fields_from_filter, ) # RequestParams is not hashable, so we can't use functools.lru_cache @@ -35,19 +36,20 @@ def filter_and_group_injured_count_per_age_group( ) -> Dict[str, Dict[int, int]]: start_time = request_params.start_time end_time = request_params.end_time - cache_key = None # prevent pylint warning + cache_key = tuple(request_params.location_info.values()) +\ + (start_time, end_time) - if request_params.resolution == BE_CONST.ResolutionCategories.STREET: - accident_yishuv_name = request_params.location_info["yishuv_name"] - street1_hebrew = request_params.location_info["street1_hebrew"] - cache_key = (accident_yishuv_name, street1_hebrew, start_time, end_time) + # if request_params.resolution == BE_CONST.ResolutionCategories.STREET: + # accident_yishuv_name = request_params.location_info["yishuv_name"] + # street1_hebrew = request_params.location_info["street1_hebrew"] + # cache_key = (accident_yishuv_name, street1_hebrew, start_time, end_time) - elif request_params.resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD: - road_number = request_params.location_info["road1"] - road_segment_id = request_params.location_info["road_segment_id"] - cache_key = (road_number, road_segment_id, start_time, end_time) + # elif request_params.resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD: + # road_number = request_params.location_info["road1"] + # road_segment_id = request_params.location_info["road_segment_id"] + # cache_key = (road_number, road_segment_id, start_time, end_time) - if cache_dict.get(cache_key): + if cache_key in cache_dict: return cache_dict.get(cache_key) query = KilledAndInjuredCountPerAgeGroupWidgetUtils.create_query_for_killed_and_injured_count_per_age_group( @@ -115,6 +117,7 @@ def create_query_for_killed_and_injured_count_per_age_group( loc_filter = adapt_location_fields_to_involve_table(location_info) loc_filter = add_resolution_location_accuracy_filter(loc_filter, resolution) + loc_filter = remove_loc_text_fields_from_filter(loc_filter) loc_ex = get_expression_for_fields(loc_filter, InvolvedMarkerView) query = ( diff --git a/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py b/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py index 250533b4..44c19f4a 100644 --- a/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py +++ b/anyway/widgets/road_segment_widgets/accident_count_by_driver_type_widget.py @@ -29,7 +29,7 @@ def generate_items(self) -> None: @staticmethod def count_accidents_by_driver_type(request_params: RequestParams): - filters = get_injured_filters(request_params) + filters = get_injured_filters(request_params.location_info) filters["involved_type"] = [ consts.InvolvedType.DRIVER.value, consts.InvolvedType.INJURED_DRIVER.value, diff --git a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py index 3a390c0a..3acfabaf 100644 --- a/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py +++ b/anyway/widgets/urban_widgets/injured_accidents_with_pedestrians_widget.py @@ -30,11 +30,11 @@ def __init__(self, request_params: RequestParams): self.rank = 18 self.information = "Injured and killed pedestrians by severity and year" - def validate_parameters(self, yishuv_name, street1_hebrew): + def validate_parameters(self, yishuv_symbol, street1): # TODO: validate each parameter and display message accordingly return ( - yishuv_name is not None - and street1_hebrew is not None + yishuv_symbol is not None + and street1 is not None and self.request_params.years_ago is not None ) @@ -53,10 +53,10 @@ def convert_to_dict(query_results): def generate_items(self) -> None: try: - yishuv_name = self.request_params.location_info.get("yishuv_name") - street1_hebrew = self.request_params.location_info.get("street1_hebrew") + yishuv_symbol = self.request_params.location_info.get("yishuv_symbol") + street1 = self.request_params.location_info.get("street1") - # if not self.validate_parameters(yishuv_name, street1_hebrew): + # if not self.validate_parameters(yishuv_symbol, street1_hebrew): # # TODO: this will fail since there is no news_flash_obj in request_params # logging.exception(f"Could not validate parameters yishuv_name + street1_hebrew in widget : {self.name}") # return None @@ -74,7 +74,7 @@ def generate_items(self) -> None: func.count().label("count"), ) .filter(loc_ex) - .filter(InvolvedMarkerView.accident_yishuv_name == yishuv_name) + .filter(InvolvedMarkerView.accident_yishuv_symbol == yishuv_symbol) .filter( InvolvedMarkerView.injury_severity.in_( [ @@ -87,8 +87,8 @@ def generate_items(self) -> None: .filter(InvolvedMarkerView.injured_type == InjuredType.PEDESTRIAN.value) .filter( or_( - InvolvedMarkerView.street1_hebrew == street1_hebrew, - InvolvedMarkerView.street2_hebrew == street1_hebrew, + InvolvedMarkerView.street1 == street1, + InvolvedMarkerView.street2 == street1, ) ) .filter( diff --git a/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py b/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py index 2766aeac..0f6f0735 100644 --- a/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py +++ b/anyway/widgets/urban_widgets/severe_fatal_count_by_vehicle_by_year_widget.py @@ -22,14 +22,14 @@ def __init__(self, request_params: RequestParams): def generate_items(self) -> None: self.items = SevereFatalCountByVehicleByYearWidget.separate_data( - self.request_params.location_info["yishuv_name"], + self.request_params.location_info["yishuv_symbol"], self.request_params.start_time, self.request_params.end_time, self.request_params.resolution, ) @staticmethod - def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]: + def separate_data(yishuv_symbol, start_time, end_time, resolution) -> Dict[str, Any]: output = { "e_bikes": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -39,7 +39,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]: InjurySeverity.SEVERE_INJURED.value, ], "involve_vehicle_type": VehicleType.ELECTRIC_BIKE.value, - "accident_yishuv_name": yishuv, + "accident_yishuv_symbol": yishuv_symbol, }, group_by="accident_year", count="accident_year", @@ -55,7 +55,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]: InjurySeverity.SEVERE_INJURED.value, ], "involve_vehicle_type": VehicleType.BIKE.value, - "accident_yishuv_name": yishuv, + "accident_yishuv_symbol": yishuv_symbol, }, group_by="accident_year", count="accident_year", @@ -71,7 +71,7 @@ def separate_data(yishuv, start_time, end_time, resolution) -> Dict[str, Any]: InjurySeverity.SEVERE_INJURED.value, ], "involve_vehicle_type": VehicleType.ELECTRIC_SCOOTER.value, - "accident_yishuv_name": yishuv, + "accident_yishuv_symbol": yishuv_symbol, }, group_by="accident_year", count="accident_year", diff --git a/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py b/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py index a4b951af..83b5a810 100644 --- a/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py +++ b/anyway/widgets/urban_widgets/small_motor_severe_fatal_count_by_year_widget.py @@ -21,14 +21,14 @@ def __init__(self, request_params: RequestParams): def generate_items(self) -> None: self.items = SmallMotorSevereFatalCountByYearWidget.get_motor_stats( - self.request_params.location_info["yishuv_name"], + self.request_params.location_info["yishuv_symbol"], self.request_params.start_time, self.request_params.end_time, self.request_params.resolution, ) @staticmethod - def get_motor_stats(location_info, start_time, end_time, resolution): + def get_motor_stats(yishuv_symbol, start_time, end_time, resolution): count_by_year = get_accidents_stats( table_obj=InvolvedMarkerView, filters={ @@ -37,7 +37,7 @@ def get_motor_stats(location_info, start_time, end_time, resolution): InjurySeverity.SEVERE_INJURED.value, ], "involve_vehicle_type": VehicleCategory.BICYCLE_AND_SMALL_MOTOR.get_codes(), - "accident_yishuv_name": location_info, + "accident_yishuv_symbol": yishuv_symbol, }, group_by="accident_year", count="accident_year", diff --git a/anyway/widgets/urban_widgets/urban_crosswalk_widget.py b/anyway/widgets/urban_widgets/urban_crosswalk_widget.py index 69ddd532..1da26d9d 100644 --- a/anyway/widgets/urban_widgets/urban_crosswalk_widget.py +++ b/anyway/widgets/urban_widgets/urban_crosswalk_widget.py @@ -10,7 +10,6 @@ from anyway.widgets.widget_utils import get_accidents_stats -# TODO: pretty sure there are errors in this widget, for example, is_included returns self.items class UrbanCrosswalkWidget(UrbanWidget): name: str = "urban_accidents_by_cross_location" files = [__file__] @@ -21,15 +20,14 @@ def __init__(self, request_params: RequestParams): def generate_items(self) -> None: self.items = UrbanCrosswalkWidget.get_crosswalk( - self.request_params.location_info["yishuv_name"], - self.request_params.location_info["street1_hebrew"], + self.request_params.location_info, self.request_params.start_time, self.request_params.end_time, self.request_params.resolution, ) @staticmethod - def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, Any]: + def get_crosswalk(location_info: dict, start_time, end_time, resolution) -> Dict[str, Any]: cross_output = { "with_crosswalk": get_accidents_stats( table_obj=InvolvedMarkerView, @@ -39,11 +37,11 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, InjurySeverity.SEVERE_INJURED.value, ], "cross_location": CrossCategory.CROSSWALK.get_codes(), - "accident_yishuv_name": yishuv, - "street1_hebrew": street, + "accident_yishuv_symbol": location_info["yishuv_symbol"], + "street1": location_info["street"], }, - group_by="street1_hebrew", - count="street1_hebrew", + group_by="street1", + count="street1", start_time=start_time, end_time=end_time, resolution=resolution, @@ -56,20 +54,22 @@ def get_crosswalk(yishuv, street, start_time, end_time, resolution) -> Dict[str, InjurySeverity.SEVERE_INJURED.value, ], "cross_location": CrossCategory.NONE.get_codes(), - "accident_yishuv_name": yishuv, - "street1_hebrew": street, + "accident_yishuv_symbol": location_info["yishuv_symbol"], + "street1": location_info["street"], }, - group_by="street1_hebrew", - count="street1_hebrew", + group_by="street1", + count="street1", start_time=start_time, end_time=end_time, resolution=resolution, ), } if not cross_output["with_crosswalk"]: - cross_output["with_crosswalk"] = [{"street1_hebrew": street, "count": 0}] + cross_output["with_crosswalk"] = [{"street1_hebrew": location_info["street1_hebrew"], + "count": 0}] if not cross_output["without_crosswalk"]: - cross_output["without_crosswalk"] = [{"street1_hebrew": street, "count": 0}] + cross_output["without_crosswalk"] = [{"street1_hebrew": location_info["street1_hebrew"], + "count": 0}] return cross_output @staticmethod diff --git a/anyway/widgets/urban_widgets/urban_widget.py b/anyway/widgets/urban_widgets/urban_widget.py index 3a701270..957658a8 100644 --- a/anyway/widgets/urban_widgets/urban_widget.py +++ b/anyway/widgets/urban_widgets/urban_widget.py @@ -14,8 +14,8 @@ def __init__(self, request_params: RequestParams): def is_urban(request_params: RequestParams) -> bool: return ( request_params is not None - and "yishuv_name" in request_params.location_info - and "street1_hebrew" in request_params.location_info + and "yishuv_symbol" in request_params.location_info + and "street1" in request_params.location_info ) @staticmethod diff --git a/anyway/widgets/widget_utils.py b/anyway/widgets/widget_utils.py index 75a13368..80d7f21b 100644 --- a/anyway/widgets/widget_utils.py +++ b/anyway/widgets/widget_utils.py @@ -24,9 +24,7 @@ def get_query(table_obj, filters, start_time, end_time): - if "road_segment_name" in filters and "road_segment_id" in filters: - filters = copy.copy(filters) - filters.pop("road_segment_name") + filters = remove_loc_text_fields_from_filter(filters) query = db.session.query(table_obj) if start_time: query = query.filter(getattr(table_obj, "accident_timestamp") >= start_time) @@ -48,6 +46,20 @@ def get_query(table_obj, filters, start_time, end_time): return query +def remove_loc_text_fields_from_filter(filters: dict) -> dict: + def remove_first_if_both_exist(d: dict, first: str, second: str) -> dict: + if first in d and second in d: + d.pop(first) + + res = copy.copy(filters) + remove_first_if_both_exist(res, "road_segment_name", "road_segment_id") + remove_first_if_both_exist(res, "yishuv_name", "yishuv_symbol") + remove_first_if_both_exist(res, "street1_hebrew", "street1") + remove_first_if_both_exist(res, "street2_hebrew", "street2") + remove_first_if_both_exist(res, "accident_yishuv_name", "accident_yishuv_symbol") + return res + + def get_expression_for_fields(filters: dict, table_obj): op_other, op_segment = None, None if "road_segment_id" not in filters.keys(): @@ -196,29 +208,28 @@ def gen_entity_labels(entity: Type[LabeledCode]) -> dict: return res -def get_involved_marker_view_location_filters( - resolution: BE_CONST.ResolutionCategories, location_info: LocationInfo -): - filters = {} - if resolution == BE_CONST.ResolutionCategories.STREET: - filters["accident_yishuv_name"] = location_info.get("yishuv_name") - filters["street1_hebrew"] = location_info.get("street1_hebrew") - elif resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD: - filters["road1"] = location_info.get("road1") - filters["road_segment_id"] = location_info.get("road_segment_id") - return filters - - -def get_injured_filters(request_params: RequestParams): - new_filters = get_involved_marker_view_location_filters( - request_params.resolution, request_params.location_info - ) - for curr_filter, curr_values in request_params.location_info.items(): - if curr_filter in ["region_hebrew", "district_hebrew", "yishuv_name"]: +# def get_involved_marker_view_location_filters( +# resolution: BE_CONST.ResolutionCategories, location_info: LocationInfo +# ): +# filters = {} +# if resolution == BE_CONST.ResolutionCategories.STREET: +# filters["accident_yishuv_name"] = location_info.get("yishuv_name") +# filters["street1_hebrew"] = location_info.get("street1_hebrew") +# elif resolution == BE_CONST.ResolutionCategories.SUBURBAN_ROAD: +# filters["road1"] = location_info.get("road1") +# filters["road_segment_id"] = location_info.get("road_segment_id") +# return filters + + +def get_injured_filters(location_info: dict): + new_filters = copy.copy(location_info) + for curr_filter, curr_value in location_info.items(): + if curr_filter in ["region_hebrew", "district_hebrew", "yishuv_name", "yishuv_symbol"]: new_filter_name = "accident_" + curr_filter - new_filters[new_filter_name] = curr_values + new_filters[new_filter_name] = curr_value + new_filters.pop(curr_filter) - new_filters["injury_severity"] = [1, 2, 3, 4, 5] + new_filters["injury_severity"] = [1, 2, 3] return new_filters @@ -294,10 +305,10 @@ def get_involved_counts( .order_by(table.accident_year) ) filters = add_resolution_location_accuracy_filter(location_info, table) - if "yishuv_symbol" in location_info: + filters = remove_loc_text_fields_from_filter(filters) + if "yishuv_symbol" in filters: filters["accident_yishuv_symbol"] = filters["yishuv_symbol"] filters.pop("yishuv_symbol") - filters.pop("yishuv_name", None) ex = get_expression_for_fields(filters, table) query = query.filter(ex).group_by(table.accident_year) diff --git a/tests/test_infographics_utils.py b/tests/test_infographics_utils.py index 6e20a5a0..adfdbc12 100644 --- a/tests/test_infographics_utils.py +++ b/tests/test_infographics_utils.py @@ -8,7 +8,7 @@ get_filter_expression_raw, get_filter_expression, get_expression_for_non_road_segment_fields, - + remove_loc_text_fields_from_filter, ) from anyway.backend_constants import AccidentSeverity from anyway.models import AccidentMarkerView, RoadJunctionKM, RoadSegments, InvolvedMarkerView @@ -213,6 +213,31 @@ def test_get_expression_for_non_road_segment_fields(self): self.assertIn(" AND ", str(actual), "4") + def test_remove_names_from_filters(self): + self.assertEqual({}, remove_loc_text_fields_from_filter({}), "1") + expected = {"a": 1, "yishuv_symbol": 2} + test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv"} + actual = remove_loc_text_fields_from_filter(test) + self.assertEqual(expected, actual, "2") + expected = {"a": 1, "yishuv_symbol": 2, + "street1": 3, "street2": 4, + "road_segment_id": 17} + test = {"a": 1, "yishuv_symbol": 2, "yishuv_name": "yishuv", + "street1": 3, "street1_hebrew": "Hebrew", + "street2": 4, "street2_hebrew": "Hebrew2", + "road_segment_name": "seg name", "road_segment_id": 17} + actual = remove_loc_text_fields_from_filter(test) + self.assertEqual(expected, actual, "3") + expected = {"a": 1, "accident_yishuv_symbol": 2, + "street1_hebrew": "Hebrew", "street2": 4, + "road_segment_id": 17} + test = {"a": 1, "accident_yishuv_symbol": 2, "accident_yishuv_name": "yishuv", + "street1_hebrew": "Hebrew", + "street2": 4, "street2_hebrew": "Hebrew2", + "road_segment_name": "seg name", "road_segment_id": 17} + actual = remove_loc_text_fields_from_filter(test) + self.assertEqual(expected, actual, "3") + if __name__ == '__main__': unittest.main()