diff --git a/async_upnp_client/advertisement.py b/async_upnp_client/advertisement.py index 66ef7836..e9569f68 100644 --- a/async_upnp_client/advertisement.py +++ b/async_upnp_client/advertisement.py @@ -66,24 +66,26 @@ def __init__( def _on_data(self, request_line: str, headers: CaseInsensitiveDict) -> None: """Handle data.""" - if headers.get("MAN") == SSDP_DISCOVER: + if headers.get_lower("man") == SSDP_DISCOVER: # Ignore discover packets. return - if "NTS" not in headers: + + notification_sub_type = headers.get_lower("nts") + if notification_sub_type is None: _LOGGER.debug("Got non-advertisement packet: %s, %s", request_line, headers) return - _LOGGER.debug( - "Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s", - headers.get("_remote_addr", ""), - headers.get("NT", ""), - headers.get("NTS", ""), - headers.get("USN", ""), - headers.get("location", ""), - ) + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug( + "Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s", + headers.get_lower("_remote_addr", ""), + headers.get_lower("nt", ""), + headers.get_lower("nts", ""), + headers.get_lower("usn", ""), + headers.get_lower("location", ""), + ) headers["_source"] = SsdpSource.ADVERTISEMENT - notification_sub_type = headers["NTS"] if notification_sub_type == NotificationSubType.SSDP_ALIVE: if self.async_on_alive: coro = self.async_on_alive(headers) diff --git a/async_upnp_client/server.py b/async_upnp_client/server.py index 07657067..f8c5fef9 100644 --- a/async_upnp_client/server.py +++ b/async_upnp_client/server.py @@ -482,16 +482,20 @@ def _on_data( # pylint: disable=too-many-branches assert self._transport - if request_line != "M-SEARCH * HTTP/1.1" or headers.get("MAN") != SSDP_DISCOVER: + if ( + request_line != "M-SEARCH * HTTP/1.1" + or headers.get_lower("man") != SSDP_DISCOVER + ): return - remote_addr = headers["_remote_addr"] + remote_addr = headers.get_lower("_remote_addr") _LOGGER.debug("Received M-SEARCH from: %s, headers: %s", remote_addr, headers) loop = asyncio.get_running_loop() - if "MX" in headers: + mx_header = headers.get_lower("mx") + if mx_header is not None: try: - delay = int(headers["MX"]) + delay = int(mx_header) _LOGGER.debug("Deferring response for %d seconds", delay) except ValueError: delay = 0 @@ -501,8 +505,9 @@ def _on_data( def _deferred_on_data(self, headers: CaseInsensitiveDict) -> None: # Determine how we should respond, page 1.3.2 of UPnP-arch-DeviceArchitecture-v2.0. - remote_addr = headers["_remote_addr"] - search_target = headers["st"].lower() + remote_addr = headers.get_lower("_remote_addr") + st_header: str = headers.get_lower("st", "") + search_target = st_header.lower() if search_target == SSDP_ST_ALL: # 3 + 2d + k (d: embedded device, k: service) # global: ST: upnp:rootdevice diff --git a/async_upnp_client/ssdp_listener.py b/async_upnp_client/ssdp_listener.py index 173ff498..1cee0699 100644 --- a/async_upnp_client/ssdp_listener.py +++ b/async_upnp_client/ssdp_listener.py @@ -41,9 +41,9 @@ def valid_search_headers(headers: CaseInsensitiveDict) -> bool: """Validate if this search is usable.""" # pylint: disable=invalid-name - udn = headers.get("_udn") # type: Optional[str] - st = headers.get("st") # type: Optional[str] - location = headers.get("location", "") # type: str + udn = headers.get_lower("_udn") # type: Optional[str] + st = headers.get_lower("st") # type: Optional[str] + location = headers.get_lower("location", "") # type: str return bool( udn and st @@ -60,10 +60,10 @@ def valid_search_headers(headers: CaseInsensitiveDict) -> bool: def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool: """Validate if this advertisement is usable for connecting to a device.""" # pylint: disable=invalid-name - udn = headers.get("_udn") # type: Optional[str] - nt = headers.get("nt") # type: Optional[str] - nts = headers.get("nts") # type: Optional[str] - location = headers.get("location", "") # type: str + udn = headers.get_lower("_udn") # type: Optional[str] + nt = headers.get_lower("nt") # type: Optional[str] + nts = headers.get_lower("nts") # type: Optional[str] + location = headers.get_lower("location", "") # type: str return bool( udn and nt @@ -81,22 +81,22 @@ def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool: def valid_byebye_headers(headers: CaseInsensitiveDict) -> bool: """Validate if this advertisement has required headers for byebye.""" # pylint: disable=invalid-name - udn = headers.get("_udn") # type: Optional[str] - nt = headers.get("nt") # type: Optional[str] - nts = headers.get("nts") # type: Optional[str] + udn = headers.get_lower("_udn") # type: Optional[str] + nt = headers.get_lower("nt") # type: Optional[str] + nts = headers.get_lower("nts") # type: Optional[str] return bool(udn and nt and nts) def extract_valid_to(headers: CaseInsensitiveDict) -> datetime: """Extract/create valid to.""" - cache_control = headers.get("cache-control", "") + cache_control = headers.get_lower("cache-control", "") match = CACHE_CONTROL_RE.search(cache_control) if match: max_age = int(match[1]) uncache_after = timedelta(seconds=max_age) else: uncache_after = DEFAULT_MAX_AGE - timestamp: datetime = headers["_timestamp"] + timestamp: datetime = headers.get_lower("_timestamp") return timestamp + uncache_after @@ -247,7 +247,7 @@ def ip_version_from_location(location: str) -> Optional[int]: def location_changed(ssdp_device: SsdpDevice, headers: CaseInsensitiveDict) -> bool: """Test if location changed for device.""" - new_location = headers.get("location", "") + new_location = headers.get_lower("location", "") if not new_location: return False @@ -295,14 +295,14 @@ def see_search( _LOGGER.debug("Received invalid search headers: %s", headers) return False, None, None, None - udn = headers["_udn"] + udn = headers.get_lower("_udn") is_new_device = udn not in self.devices ssdp_device, new_location = self._see_device(headers) if not ssdp_device: return False, None, None, None - search_target: SearchTarget = headers["ST"] + search_target: SearchTarget = headers.get_lower("st") is_new_service = ( search_target not in ssdp_device.advertisement_headers and search_target not in ssdp_device.search_headers @@ -339,14 +339,14 @@ def see_advertisement( _LOGGER.debug("Received invalid advertisement headers: %s", headers) return False, None, None - udn = headers["_udn"] + udn = headers.get_lower("_udn") is_new_device = udn not in self.devices ssdp_device, new_location = self._see_device(headers) if not ssdp_device: return False, None, None - notification_type: NotificationType = headers["NT"] + notification_type: NotificationType = headers.get_lower("nt") is_new_service = ( notification_type not in ssdp_device.advertisement_headers and notification_type not in ssdp_device.search_headers @@ -356,7 +356,7 @@ def see_advertisement( "See new service: %s, type: %s", ssdp_device, notification_type ) - notification_sub_type: NotificationSubType = headers["NTS"] + notification_sub_type: NotificationSubType = headers.get_lower("nts") propagate = ( notification_sub_type == NotificationSubType.SSDP_UPDATE or is_new_device @@ -407,8 +407,8 @@ def _see_device( new_location = location_changed(ssdp_device, headers) # Update device. - ssdp_device.add_location(headers["location"], valid_to) - ssdp_device.last_seen = headers["_timestamp"] + ssdp_device.add_location(headers.get_lower("location"), valid_to) + ssdp_device.last_seen = headers.get_lower("_timestamp") if not self.next_valid_to or self.next_valid_to > ssdp_device.valid_to: self.next_valid_to = ssdp_device.valid_to @@ -433,7 +433,7 @@ def unsee_advertisement( del self.devices[udn] # Update device before propagating it - notification_type: NotificationType = headers["NT"] + notification_type: NotificationType = headers.get_lower("nt") if notification_type in ssdp_device.advertisement_headers: ssdp_device.advertisement_headers[notification_type].replace(headers) else: diff --git a/async_upnp_client/utils.py b/async_upnp_client/utils.py index 19e3658e..b78da815 100644 --- a/async_upnp_client/utils.py +++ b/async_upnp_client/utils.py @@ -41,12 +41,12 @@ def as_lower_dict(self) -> Dict[str, Any]: """Return the underlying dict in lowercase.""" return {k.lower(): v for k, v in self._data.items()} - def get_lower(self, lower_key: str) -> Any: + def get_lower(self, lower_key: str, default: Any = None) -> Any: """Get a lower case key.""" data_key = self._case_map.get(lower_key, _SENTINEL) if data_key is not _SENTINEL: - return self._data[data_key] - return None + return self._data.get(data_key, default) + return default def replace(self, new_data: abcMapping) -> None: """Replace the underlying dict."""