Skip to content

Commit

Permalink
Refactor alert save logic
Browse files Browse the repository at this point in the history
- Add typing support
- Break into functions
  • Loading branch information
thenav56 committed Apr 26, 2024
1 parent 2b2722e commit 3c6e6af
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 139 deletions.
331 changes: 199 additions & 132 deletions apps/cap_feed/formats/cap_xml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import datetime
import logging
from xml.etree.ElementTree import Element as XmlElement

from django.contrib.gis.geos import GEOSGeometry
from django.db import IntegrityError
from django.utils import timezone

Expand All @@ -17,8 +22,12 @@
AlertInfoAreaGeocode,
AlertInfoAreaPolygon,
AlertInfoParameter,
Feed,
ProcessedAlert,
)
from main.managers import BulkCreateManager

logger = logging.getLogger(__name__)


def find_element(element, ns, tag):
Expand All @@ -28,137 +37,195 @@ def find_element(element, ns, tag):
return None


def create_alert(
feed: Feed,
url: str,
alert_root: XmlElement,
ns: dict,
) -> Alert | None:
alert_status = find_element(alert_root, ns, 'cap:status')
if alert_status != 'Actual':
return

# TODO: Properly handle reportOptionalMemberAccess
return Alert.objects.create(
feed=feed,
country=feed.country,
url=url,
identifier=alert_root.find('cap:identifier', ns).text, # type: ignore[reportOptionalMemberAccess]
sender=alert_root.find('cap:sender', ns).text, # type: ignore[reportOptionalMemberAccess]
sent=convert_datetime(alert_root.find('cap:sent', ns).text), # type: ignore[reportOptionalMemberAccess]
msg_type=alert_root.find('cap:msgType', ns).text, # type: ignore[reportOptionalMemberAccess]
source=find_element(alert_root, ns, 'cap:source'),
scope=alert_root.find('cap:scope', ns).text, # type: ignore[reportOptionalMemberAccess]
restriction=find_element(alert_root, ns, 'cap:restriction'),
addresses=find_element(alert_root, ns, 'cap:addresses'),
references=find_element(alert_root, ns, 'cap:references'),
code=find_element(alert_root, ns, 'cap:code'),
note=find_element(alert_root, ns, 'cap:note'),
incidents=find_element(alert_root, ns, 'cap:incidents'),
status=alert_status,
)


def create_alert_info(
alert: Alert,
alert_info_entry: XmlElement,
expire_time: datetime.datetime | None,
ns: dict,
) -> AlertInfo:
# TODO: Properly handle reportOptionalMemberAccess

return AlertInfo.objects.create(
alert=alert,
language=('en-US' if (x := alert_info_entry.find('cap:language', ns)) is None else x.text),
category=alert_info_entry.find('cap:category', ns).text, # type: ignore[reportOptionalMemberAccess]
event=alert_info_entry.find('cap:event', ns).text, # type: ignore[reportOptionalMemberAccess]
response_type=find_element(alert_info_entry, ns, 'cap:responseType'),
urgency=alert_info_entry.find('cap:urgency', ns).text, # type: ignore[reportOptionalMemberAccess]
severity=alert_info_entry.find('cap:severity', ns).text, # type: ignore[reportOptionalMemberAccess]
certainty=alert_info_entry.find('cap:certainty', ns).text, # type: ignore[reportOptionalMemberAccess]
audience=find_element(alert_info_entry, ns, 'cap:audience'),
effective=(alert.sent if (x := alert_info_entry.find('cap:effective', ns)) is None else x.text),
onset=convert_datetime(find_element(alert_info_entry, ns, 'cap:onset')),
sender_name=find_element(alert_info_entry, ns, 'cap:senderName'),
headline=find_element(alert_info_entry, ns, 'cap:headline'),
description=find_element(alert_info_entry, ns, 'cap:description'),
instruction=find_element(alert_info_entry, ns, 'cap:instruction'),
web=find_element(alert_info_entry, ns, 'cap:web'),
contact=find_element(alert_info_entry, ns, 'cap:contact'),
expires=expire_time,
)


def process_alert_info(
alert: Alert,
alert_info_entry: XmlElement,
mgr: BulkCreateManager,
ns: dict,
) -> tuple[AlertInfo | None, list[GEOSGeometry]]:
expire_time = convert_datetime(find_element(alert_info_entry, ns, 'cap:expires'))
if expire_time is not None and expire_time < timezone.now():
return None, []

polygons = []
alert_info = create_alert_info(alert, alert_info_entry, expire_time, ns)

# navigate alert info parameter
for alert_info_parameter_entry in alert_info_entry.findall('cap:parameter', ns):
mgr.add(
AlertInfoParameter(
alert_info=alert_info,
value_name=alert_info_parameter_entry.find('cap:valueName', ns).text, # type: ignore[reportOptionalMemberAccess] # noqa: E501
value=alert_info_parameter_entry.find('cap:value', ns).text, # type: ignore[reportOptionalMemberAccess]
)
)

# navigate alert info area
for alert_info_area_entry in alert_info_entry.findall('cap:area', ns):
alert_info_area = AlertInfoArea.objects.create(
alert_info=alert_info,
area_desc=alert_info_area_entry.find('cap:areaDesc', ns).text, # type: ignore[reportOptionalMemberAccess]
altitude=find_element(alert_info_entry, ns, 'cap:altitude'),
ceiling=find_element(alert_info_entry, ns, 'cap:ceiling'),
)

# navigate alert info area circle
for alert_info_area_circle_entry in alert_info_area_entry.findall('cap:circle', ns):
mgr.add(
AlertInfoAreaCircle(
alert_info_area=alert_info_area,
value=alert_info_area_circle_entry.text,
)
)

# navigate info area geocode
for alert_info_area_geocode_entry in alert_info_area_entry.findall('cap:geocode', ns):
mgr.add(
AlertInfoAreaGeocode(
alert_info_area=alert_info_area,
value_name=alert_info_area_geocode_entry.find('cap:valueName', ns).text, # type: ignore[reportOptionalMemberAccess] # noqa: E501
value=alert_info_area_geocode_entry.find('cap:value', ns).text, # type: ignore[reportOptionalMemberAccess] # noqa: E501
)
)

# navigate alert info area polygon
for alert_info_area_polygon_entry in alert_info_area_entry.findall('cap:polygon', ns):
if alert_info_area_polygon_entry is not None and alert_info_area_polygon_entry.text:
alert_info_area_polygon = AlertInfoAreaPolygon(
alert_info_area=alert_info_area,
value=alert_info_area_polygon_entry.text.strip(),
)
mgr.add(alert_info_area_polygon)
if parsed_polygon := alert_info_area_polygon.value_geojson:
polygons.append(parsed_polygon)
return alert_info, polygons


def process_alert(
url: str,
alert_root: XmlElement,
feed: Feed,
ns: dict,
) -> Alert | None:
alert = create_alert(feed, url, alert_root, ns)
if alert is None:
return

mgr = BulkCreateManager()
alert_has_valid_info = False
tagged_admin1s_id = set()

# navigate alert info
for alert_info_entry in alert_root.findall('cap:info', ns):
alert_info, alert_info_polygons = process_alert_info(alert, alert_info_entry, mgr, ns)
if not alert_info:
continue

alert_has_valid_info = True

# XXX: Do we need to check circles as well?
# check polygon intersection with admin1s
for polygon in alert_info_polygons:
possible_admin1s = Admin1.objects.filter(
country=alert.country,
# TODO: Check for performance issues
geometry__intersects=polygon,
).exclude(id__in=tagged_admin1s_id)
for admin1_id in possible_admin1s.values_list('id', flat=True):
tagged_admin1s_id.add(admin1_id)
mgr.add(
AlertAdmin1(
alert=alert,
admin1_id=admin1_id,
)
)

if alert_has_valid_info:
if not tagged_admin1s_id:
if unknown_admin1 := Admin1.objects.filter(country=alert.country, name='Unknown').first():
mgr.add(
AlertAdmin1(
alert=alert,
admin1=unknown_admin1,
)
)

alert.info_has_been_added()
alert.save()

mgr.done()
if mrg_summary := mgr.summary(ignore_empty=True):
logger.info(f"DB ops summary for alert: {alert.pk}: {str(mrg_summary)}")
return alert


def get_alert(url, alert_root, feed, ns) -> bool:
alert = None
try:
# register alert
alert = Alert()
alert.feed = feed
alert.country = feed.country
alert.url = url
alert.identifier = alert_root.find('cap:identifier', ns).text
alert.sender = alert_root.find('cap:sender', ns).text
alert.sent = convert_datetime(alert_root.find('cap:sent', ns).text)
alert.status = alert_root.find('cap:status', ns).text
if alert.status != 'Actual':
return False
alert.msg_type = alert_root.find('cap:msgType', ns).text
alert.source = find_element(alert_root, ns, 'cap:source')
alert.scope = alert_root.find('cap:scope', ns).text
alert.restriction = find_element(alert_root, ns, 'cap:restriction')
alert.addresses = find_element(alert_root, ns, 'cap:addresses')
alert.references = find_element(alert_root, ns, 'cap:references')
alert.code = find_element(alert_root, ns, 'cap:code')
alert.note = find_element(alert_root, ns, 'cap:note')
alert.references = find_element(alert_root, ns, 'cap:references')
alert.incidents = find_element(alert_root, ns, 'cap:incidents')

alert_has_valid_info = False
alert_matched_admin1 = False
# navigate alert info
for alert_info_entry in alert_root.findall('cap:info', ns):
alert_info = AlertInfo()
alert_info.alert = alert
alert_info.language = 'en-US' if (x := alert_info_entry.find('cap:language', ns)) is None else x.text
alert_info.category = alert_info_entry.find('cap:category', ns).text
alert_info.event = alert_info_entry.find('cap:event', ns).text
alert_info.response_type = find_element(alert_info_entry, ns, 'cap:responseType')
alert_info.urgency = alert_info_entry.find('cap:urgency', ns).text
alert_info.severity = alert_info_entry.find('cap:severity', ns).text
alert_info.certainty = alert_info_entry.find('cap:certainty', ns).text
alert_info.audience = find_element(alert_info_entry, ns, 'cap:audience')
alert_info.effective = alert.sent if (x := alert_info_entry.find('cap:effective', ns)) is None else x.text
alert_info.onset = convert_datetime(find_element(alert_info_entry, ns, 'cap:onset'))
expire_time = convert_datetime(find_element(alert_info_entry, ns, 'cap:expires'))
if expire_time is not None:
alert_info.expires = expire_time
if alert_info.expires < timezone.now():
continue
alert_info.sender_name = find_element(alert_info_entry, ns, 'cap:senderName')
alert_info.headline = find_element(alert_info_entry, ns, 'cap:headline')
alert_info.description = find_element(alert_info_entry, ns, 'cap:description')
alert_info.instruction = find_element(alert_info_entry, ns, 'cap:instruction')
alert_info.web = find_element(alert_info_entry, ns, 'cap:web')
alert_info.contact = find_element(alert_info_entry, ns, 'cap:contact')

alert.save()
alert_info.save()
alert_has_valid_info = True

# navigate alert info parameter
for alert_info_parameter_entry in alert_info_entry.findall('cap:parameter', ns):
alert_info_parameter = AlertInfoParameter()
alert_info_parameter.alert_info = alert_info
alert_info_parameter.value_name = alert_info_parameter_entry.find('cap:valueName', ns).text
alert_info_parameter.value = alert_info_parameter_entry.find('cap:value', ns).text
alert_info_parameter.save()

# navigate alert info area
for alert_info_area_entry in alert_info_entry.findall('cap:area', ns):
alert_info_area = AlertInfoArea()
alert_info_area.alert_info = alert_info
alert_info_area.area_desc = alert_info_area_entry.find('cap:areaDesc', ns).text
alert_info_area.altitude = find_element(alert_info_entry, ns, 'cap:altitude')
alert_info_area.ceiling = find_element(alert_info_entry, ns, 'cap:ceiling')
alert_info_area.save()

# navigate alert info area polygon
polygons = []
for alert_info_area_polygon_entry in alert_info_area_entry.findall('cap:polygon', ns):
if alert_info_area_polygon_entry is not None and alert_info_area_polygon_entry.text:
alert_info_area_polygon = AlertInfoAreaPolygon()
alert_info_area_polygon.alert_info_area = alert_info_area
alert_info_area_polygon.value = alert_info_area_polygon_entry.text.strip()
alert_info_area_polygon.save()
if parsed_polygon := alert_info_area_polygon.value_geojson:
polygons.append(parsed_polygon)

# check polygon intersection with admin1s
for polygon in polygons:
possible_admin1s = Admin1.objects.filter(
country=alert.country,
# TODO: Check for performance issues
geometry__intersects=polygon,
)
for admin1 in possible_admin1s:
if AlertAdmin1.objects.filter(alert=alert, admin1=admin1).exists():
continue
# TODO: Use bulk manager
alert_admin1 = AlertAdmin1.objects.create(
alert=alert,
admin1=admin1,
)
alert_matched_admin1 = True

# navigate alert info area circle
for alert_info_area_circle_entry in alert_info_area_entry.findall('cap:circle', ns):
alert_info_area_circle = AlertInfoAreaCircle()
alert_info_area_circle.alert_info_area = alert_info_area
alert_info_area_circle.value = alert_info_area_circle_entry.text
alert_info_area_circle.save()

# navigate info area geocode
for alert_info_area_geocode_entry in alert_info_area_entry.findall('cap:geocode', ns):
alert_info_area_geocode = AlertInfoAreaGeocode()
alert_info_area_geocode.alert_info_area = alert_info_area
alert_info_area_geocode.value_name = alert_info_area_geocode_entry.find('cap:valueName', ns).text
alert_info_area_geocode.value = alert_info_area_geocode_entry.find('cap:value', ns).text
alert_info_area_geocode.save()

if alert_has_valid_info:
if not alert_matched_admin1:
unknown_admin1 = Admin1.objects.filter(country=alert.country, name='Unknown').first()
if unknown_admin1:
alert_admin1 = AlertAdmin1()
alert_admin1.alert = alert
alert_admin1.admin1 = unknown_admin1
alert_admin1.save()
alert_matched_admin1 = True

alert.info_has_been_added()
alert.save()
alert = process_alert(url, alert_root, feed, ns)
if alert:
return True

except AttributeError as e:
log_attributeerror(feed, e, url)
except IntegrityError as e:
Expand All @@ -168,9 +235,9 @@ def get_alert(url, alert_root, feed, ns) -> bool:
log_valueerror(feed, e, url)
finally:
if alert is not None:
processed_alert = ProcessedAlert()
processed_alert.url = alert.url
processed_alert.feed = alert.feed
processed_alert.save()

# TODO: Do we need this?
ProcessedAlert.objects.create(
url=alert.url,
feed=alert.feed,
)
return False
12 changes: 8 additions & 4 deletions apps/cap_feed/formats/format_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging

from apps.cap_feed.models import Alert
from main.sentry import SentryTag

from .atom import get_alerts_atom
from .nws_us import get_alerts_nws_us
from .rss import get_alerts_rss

logger = logging.getLogger(__name__)


def get_alerts(feed, all_alert_urls=set()):
# track new alerts
Expand All @@ -14,7 +18,7 @@ def get_alerts(feed, all_alert_urls=set()):
# track if poll was valid
valid_poll = False

print(f'Processing feed: {feed}')
logger.info(f'Processing feed: {feed}')

SentryTag.set_tags({SentryTag.Tag.FEED: feed.pk})
try:
Expand All @@ -27,10 +31,10 @@ def get_alerts(feed, all_alert_urls=set()):
case "nws_us":
alert_urls, polled_alerts_count, valid_poll = get_alerts_nws_us(feed, ns)
case _:
print(f'Format not supported: {feed}')
logger.error(f'Format not supported: {feed}')
alert_urls, polled_alerts_count, valid_poll = set(), 0, True
except Exception as e:
print(f'Error getting alerts from {feed.url}: {e}')
except Exception:
logger.error(f'Error getting alerts from {feed.url}', exc_info=True)
else:
if valid_poll:
# alerts that are in the database and have not expired but are no longer available -
Expand Down
Loading

0 comments on commit 3c6e6af

Please sign in to comment.