Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions async_upnp_client/advertisement.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,26 @@ def __init__(

def _on_data(self, request_line: str, headers: CaseInsensitiveDict) -> None:
"""Handle data."""
if headers.get("MAN") == SSDP_DISCOVER:
if headers.get_lower("man") == SSDP_DISCOVER:
# Ignore discover packets.
return
if "NTS" not in headers:

notification_sub_type = headers.get_lower("nts")
if notification_sub_type is None:
_LOGGER.debug("Got non-advertisement packet: %s, %s", request_line, headers)
return

_LOGGER.debug(
"Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s",
headers.get("_remote_addr", ""),
headers.get("NT", "<no NT>"),
headers.get("NTS", "<no NTS>"),
headers.get("USN", "<no USN>"),
headers.get("location", ""),
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s",
headers.get_lower("_remote_addr", ""),
headers.get_lower("nt", "<no NT>"),
headers.get_lower("nts", "<no NTS>"),
headers.get_lower("usn", "<no USN>"),
headers.get_lower("location", ""),
)

headers["_source"] = SsdpSource.ADVERTISEMENT
notification_sub_type = headers["NTS"]
if notification_sub_type == NotificationSubType.SSDP_ALIVE:
if self.async_on_alive:
coro = self.async_on_alive(headers)
Expand Down
17 changes: 11 additions & 6 deletions async_upnp_client/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,20 @@ def _on_data(
# pylint: disable=too-many-branches
assert self._transport

if request_line != "M-SEARCH * HTTP/1.1" or headers.get("MAN") != SSDP_DISCOVER:
if (
request_line != "M-SEARCH * HTTP/1.1"
or headers.get_lower("man") != SSDP_DISCOVER
):
return

remote_addr = headers["_remote_addr"]
remote_addr = headers.get_lower("_remote_addr")
_LOGGER.debug("Received M-SEARCH from: %s, headers: %s", remote_addr, headers)

loop = asyncio.get_running_loop()
if "MX" in headers:
mx_header = headers.get_lower("mx")
if mx_header is not None:
try:
delay = int(headers["MX"])
delay = int(mx_header)
_LOGGER.debug("Deferring response for %d seconds", delay)
except ValueError:
delay = 0
Expand All @@ -501,8 +505,9 @@ def _on_data(

def _deferred_on_data(self, headers: CaseInsensitiveDict) -> None:
# Determine how we should respond, page 1.3.2 of UPnP-arch-DeviceArchitecture-v2.0.
remote_addr = headers["_remote_addr"]
search_target = headers["st"].lower()
remote_addr = headers.get_lower("_remote_addr")
st_header: str = headers.get_lower("st", "")
search_target = st_header.lower()
if search_target == SSDP_ST_ALL:
# 3 + 2d + k (d: embedded device, k: service)
# global: ST: upnp:rootdevice
Expand Down
42 changes: 21 additions & 21 deletions async_upnp_client/ssdp_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
def valid_search_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this search is usable."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
st = headers.get("st") # type: Optional[str]
location = headers.get("location", "") # type: str
udn = headers.get_lower("_udn") # type: Optional[str]
st = headers.get_lower("st") # type: Optional[str]
location = headers.get_lower("location", "") # type: str
return bool(
udn
and st
Expand All @@ -60,10 +60,10 @@ def valid_search_headers(headers: CaseInsensitiveDict) -> bool:
def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this advertisement is usable for connecting to a device."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
nt = headers.get("nt") # type: Optional[str]
nts = headers.get("nts") # type: Optional[str]
location = headers.get("location", "") # type: str
udn = headers.get_lower("_udn") # type: Optional[str]
nt = headers.get_lower("nt") # type: Optional[str]
nts = headers.get_lower("nts") # type: Optional[str]
location = headers.get_lower("location", "") # type: str
return bool(
udn
and nt
Expand All @@ -81,22 +81,22 @@ def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool:
def valid_byebye_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this advertisement has required headers for byebye."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
nt = headers.get("nt") # type: Optional[str]
nts = headers.get("nts") # type: Optional[str]
udn = headers.get_lower("_udn") # type: Optional[str]
nt = headers.get_lower("nt") # type: Optional[str]
nts = headers.get_lower("nts") # type: Optional[str]
return bool(udn and nt and nts)


def extract_valid_to(headers: CaseInsensitiveDict) -> datetime:
"""Extract/create valid to."""
cache_control = headers.get("cache-control", "")
cache_control = headers.get_lower("cache-control", "")
match = CACHE_CONTROL_RE.search(cache_control)
if match:
max_age = int(match[1])
uncache_after = timedelta(seconds=max_age)
else:
uncache_after = DEFAULT_MAX_AGE
timestamp: datetime = headers["_timestamp"]
timestamp: datetime = headers.get_lower("_timestamp")
return timestamp + uncache_after


Expand Down Expand Up @@ -247,7 +247,7 @@ def ip_version_from_location(location: str) -> Optional[int]:

def location_changed(ssdp_device: SsdpDevice, headers: CaseInsensitiveDict) -> bool:
"""Test if location changed for device."""
new_location = headers.get("location", "")
new_location = headers.get_lower("location", "")
if not new_location:
return False

Expand Down Expand Up @@ -295,14 +295,14 @@ def see_search(
_LOGGER.debug("Received invalid search headers: %s", headers)
return False, None, None, None

udn = headers["_udn"]
udn = headers.get_lower("_udn")
is_new_device = udn not in self.devices

ssdp_device, new_location = self._see_device(headers)
if not ssdp_device:
return False, None, None, None

search_target: SearchTarget = headers["ST"]
search_target: SearchTarget = headers.get_lower("st")
is_new_service = (
search_target not in ssdp_device.advertisement_headers
and search_target not in ssdp_device.search_headers
Expand Down Expand Up @@ -339,14 +339,14 @@ def see_advertisement(
_LOGGER.debug("Received invalid advertisement headers: %s", headers)
return False, None, None

udn = headers["_udn"]
udn = headers.get_lower("_udn")
is_new_device = udn not in self.devices

ssdp_device, new_location = self._see_device(headers)
if not ssdp_device:
return False, None, None

notification_type: NotificationType = headers["NT"]
notification_type: NotificationType = headers.get_lower("nt")
is_new_service = (
notification_type not in ssdp_device.advertisement_headers
and notification_type not in ssdp_device.search_headers
Expand All @@ -356,7 +356,7 @@ def see_advertisement(
"See new service: %s, type: %s", ssdp_device, notification_type
)

notification_sub_type: NotificationSubType = headers["NTS"]
notification_sub_type: NotificationSubType = headers.get_lower("nts")
propagate = (
notification_sub_type == NotificationSubType.SSDP_UPDATE
or is_new_device
Expand Down Expand Up @@ -407,8 +407,8 @@ def _see_device(
new_location = location_changed(ssdp_device, headers)

# Update device.
ssdp_device.add_location(headers["location"], valid_to)
ssdp_device.last_seen = headers["_timestamp"]
ssdp_device.add_location(headers.get_lower("location"), valid_to)
ssdp_device.last_seen = headers.get_lower("_timestamp")
if not self.next_valid_to or self.next_valid_to > ssdp_device.valid_to:
self.next_valid_to = ssdp_device.valid_to

Expand All @@ -433,7 +433,7 @@ def unsee_advertisement(
del self.devices[udn]

# Update device before propagating it
notification_type: NotificationType = headers["NT"]
notification_type: NotificationType = headers.get_lower("nt")
if notification_type in ssdp_device.advertisement_headers:
ssdp_device.advertisement_headers[notification_type].replace(headers)
else:
Expand Down
6 changes: 3 additions & 3 deletions async_upnp_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def as_lower_dict(self) -> Dict[str, Any]:
"""Return the underlying dict in lowercase."""
return {k.lower(): v for k, v in self._data.items()}

def get_lower(self, lower_key: str) -> Any:
def get_lower(self, lower_key: str, default: Any = None) -> Any:
"""Get a lower case key."""
data_key = self._case_map.get(lower_key, _SENTINEL)
if data_key is not _SENTINEL:
return self._data[data_key]
return None
return self._data.get(data_key, default)
return default

def replace(self, new_data: abcMapping) -> None:
"""Replace the underlying dict."""
Expand Down