diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index ae894e57..388e0abc 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -7,6 +7,7 @@ import twisted.internet.base from autobahn.twisted.util import sleep +from autobahn.websocket.protocol import ConnectionRequest from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ItemNotFound @@ -127,9 +128,9 @@ def setUp(self): self.proto.sendMessage = self.send_mock = Mock() self.orig_close = self.proto.sendClose - request_mock = Mock() + request_mock = Mock(spec=ConnectionRequest) request_mock.headers = {} - self.proto.ps = PushState(db=db, request=request_mock) + self.proto.ps = PushState.from_request(request=request_mock, db=db) self.proto.sendClose = self.close_mock = Mock() self.proto.transport = self.transport_mock = Mock() self.proto.closeHandshakeTimeout = 0 @@ -140,7 +141,10 @@ def tearDown(self): self.proto.force_retry = self.proto._force_retry def _connect(self): - self.proto.onConnect(None) + req = Mock(spec=ConnectionRequest) + req.headers = {} + req.host = None + self.proto.onConnect(req) def _send_message(self, msg): self.proto.onMessage(json.dumps(msg).encode('utf8'), False) @@ -241,7 +245,6 @@ def test_producer_interface(self): eq_(self.proto.ps._should_stop, True) def test_headers_locate(self): - from autobahn.websocket.protocol import ConnectionRequest req = ConnectionRequest("localhost", {"user-agent": "Me"}, "localhost", "/", {}, 1, "localhost", [], []) @@ -255,7 +258,7 @@ def test_base_tags(self): "rv:1.9.2.3) Gecko/20100401 Firefox/3.6.3 (.NET " "CLR 3.5.30729)"} req.host = "example.com:8080" - ps = PushState(db=self.proto.db, request=req) + ps = PushState.from_request(request=req, db=self.proto.db) eq_(sorted(ps._base_tags), sorted(['ua_os_family:Windows', 'ua_browser_family:Firefox', diff --git a/autopush/websocket.py b/autopush/websocket.py index a2e31eab..737f0704 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -38,6 +38,7 @@ import attr from attr import ( + Factory, attrs, attrib ) @@ -46,6 +47,7 @@ WebSocketServerFactory, WebSocketServerProtocol ) +from autobahn.websocket.protocol import ConnectionRequest # noqa from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ItemNotFound @@ -95,6 +97,7 @@ from autopush.metrics import IMetrics # noqa from autopush.settings import AutopushSettings # noqa from autopush.ssl import AutopushSSLContextFactory +from autopush.types import JSONDict # noqa from autopush.utils import ( parse_user_agent, validate_uaid, @@ -167,109 +170,86 @@ def logging_data(self): @implementer(IProducer) +@attrs(slots=True) class PushState(object): + """Compact storage of a PushProtocolConnection's state""" + + db = attrib() # type: DatabaseManager + _callbacks = attrib(default=Factory(list)) # type: List[Deferred] + + stats = attrib( + default=Factory(SessionStatistics)) # type: SessionStatistics + + _user_agent = attrib(default=None) # type: Optional[str] + _base_tags = attrib(default=Factory(list)) # type: List[str] + raw_agent = attrib(default=Factory(dict)) # type: Optional[Dict[str, str]] + + _should_stop = attrib(default=False) # type: bool + _paused = attrib(default=False) # type: bool + + _uaid_obj = attrib(default=None) # type: Optional[uuid.UUID] + _uaid_hash = attrib(default=None) # type: Optional[str] + + last_ping = attrib(default=0.0) # type: float + check_storage = attrib(default=False) # type: bool + use_webpush = attrib(default=False) # type: bool + router_type = attrib(default=None) # type: Optional[str] + wake_data = attrib(default=None) # type: Optional[JSONDict] + connected_at = attrib(default=Factory(ms_time)) # type: float + ping_time_out = attrib(default=False) # type: bool + + # Message table rotation + message_month = attrib(init=False) # type: str + rotate_message_table = attrib(default=False) # type: bool + + _check_notifications = attrib(default=False) # type: bool + _more_notifications = attrib(default=False) # type: bool + + # Timestamped message handling defaults + scan_timestamps = attrib(default=False) # type: bool + current_timestamp = attrib(default=None) # type: Optional[int] + + # Hanger for common actions we defer + _notification_fetch = attrib(default=None) # type: Optional[Deferred] + _register = attrib(default=None) # type: Optional[Deferred] + + # Reflects Notification's sent that haven't been ack'd This is + # simplepush style by default + updates_sent = attrib(default=Factory(dict)) # type: Dict + + # Track Notification's we don't need to delete separately This is + # simplepush style by default + direct_updates = attrib(default=Factory(dict)) # type: Dict + + # Whether this record should be reset after delivering stored + # messages + _reset_uaid = attrib(default=False) # type: bool + + @classmethod + def from_request(cls, request, **kwargs): + # type: (ConnectionRequest, **Any) -> PushState + return cls( + user_agent=request.headers.get("user-agent"), + stats=SessionStatistics(host=request.host), + **kwargs + ) - __slots__ = [ - '_callbacks', - '_user_agent', - '_base_tags', - '_should_stop', - '_paused', - '_uaid_obj', - '_uaid_hash', - 'raw_agent', - 'last_ping', - 'check_storage', - 'use_webpush', - 'router_type', - 'wake_data', - 'connected_at', - 'db', - 'stats', - - # Table rotation - 'message_month', - 'message', - 'rotate_message_table', - - # Timestamped message handling - 'scan_timestamps', - 'current_timestamp', - - 'ping_time_out', - '_check_notifications', - '_more_notifications', - '_notification_fetch', - '_register', - 'updates_sent', - 'direct_updates', - - '_reset_uaid', - ] - - def __init__(self, db, request): - self._callbacks = [] - self.stats = SessionStatistics() - self.db = db - host = "" - - if request: - self._user_agent = request.headers.get("user-agent") - # Get the name of the server the request asked for. - host = request.host - else: - self._user_agent = None - - self.stats.host = host - self._base_tags = [] - self.raw_agent = {} + def __attrs_post_init__(self): + """Initialize PushState""" if self._user_agent: dd_tags, self.raw_agent = parse_user_agent(self._user_agent) for tag_name, tag_value in dd_tags.items(): setattr(self.stats, tag_name, tag_value) self._base_tags.append("%s:%s" % (tag_name, tag_value)) - if host: - self._base_tags.append("host:%s" % host) + if self.stats.host: + self._base_tags.append("host:%s" % self.stats.host) - db.metrics.increment("client.socket.connect", - tags=self._base_tags or None) - - self._should_stop = False - self._paused = False - self.uaid = None - self.last_ping = 0 - self.check_storage = False - self.use_webpush = False - self.router_type = None - self.wake_data = None - self.connected_at = ms_time() - self.ping_time_out = False + self.db.metrics.increment("client.socket.connect", + tags=self._base_tags or None) # Message table rotation initial settings - self.message_month = db.current_msg_month - self.rotate_message_table = False - - self._check_notifications = False - self._more_notifications = False - - # Timestamp message defaults - self.scan_timestamps = False - self.current_timestamp = None - - # Hanger for common actions we defer - self._notification_fetch = None - self._register = None - - # Reflects Notification's sent that haven't been ack'd - # This is simplepush style by default - self.updates_sent = {} - - # Track Notification's we don't need to delete separately - # This is simplepush style by default - self.direct_updates = {} + self.message_month = self.db.current_msg_month - # Whether this record should be reset after delivering stored - # messages self.reset_uaid = False @property @@ -491,7 +471,7 @@ def nukeConnection(self): def onConnect(self, request): """autobahn onConnect handler for when a connection has started""" track_object(self, msg="onConnect Start") - self.ps = PushState(db=self.db, request=request) + self.ps = PushState.from_request(request=request, db=self.db) # Setup ourself to handle producing the data self.transport.bufferSize = 2 * 1024