Skip to content

Commit

Permalink
Split out connection code
Browse files Browse the repository at this point in the history
  • Loading branch information
epenet committed Feb 21, 2022
1 parent 08d7265 commit 382f21d
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 105 deletions.
152 changes: 152 additions & 0 deletions samsungtvws/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
SamsungTVWS - Samsung Smart TV WS API wrapper
Copyright (C) 2019 Xchwarze
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1335 USA
"""
import json
import logging
import ssl
import time

import websocket

from . import exceptions, helper

_LOGGING = logging.getLogger(__name__)


class SamsungTVWSBaseConnection:
_URL_FORMAT = "ws://{host}:{port}/api/v2/channels/{app}?name={name}"
_SSL_URL_FORMAT = (
"wss://{host}:{port}/api/v2/channels/{app}?name={name}&token={token}"
)

def __init__(
self,
host,
*,
token=None,
token_file=None,
port=8001,
timeout=None,
key_press_delay=1,
name="SamsungTvRemote",
endpoint="samsung.remote.control",
):
self.host = host
self.token = token
self.token_file = token_file
self.port = port
self.timeout = None if timeout == 0 else timeout
self.key_press_delay = key_press_delay
self.name = name
self.endpoint = endpoint
self.connection = None

def _is_ssl_connection(self):
return self.port == 8002

def _format_websocket_url(self, app):
params = {
"host": self.host,
"port": self.port,
"app": app,
"name": helper.serialize_string(self.name),
"token": self._get_token(),
}

if self._is_ssl_connection():
return self._SSL_URL_FORMAT.format(**params)
else:
return self._URL_FORMAT.format(**params)

def _get_token(self):
if self.token_file is not None:
try:
with open(self.token_file, "r") as token_file:
return token_file.readline()
except:
return ""
else:
return self.token

def _set_token(self, token):
_LOGGING.info("New token %s", token)
if self.token_file is not None:
_LOGGING.debug("Save token to file: %s", token)
with open(self.token_file, "w") as token_file:
token_file.write(token)
else:
self.token = token

def _check_for_token(self, response):
if response.get("data") and response["data"].get("token"):
token = response["data"].get("token")
_LOGGING.debug("Got token %s", token)
self._set_token(token)


class SamsungTVWSConnection(SamsungTVWSBaseConnection):
def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.close()

def open(self):
url = self._format_websocket_url(self.endpoint)
sslopt = {"cert_reqs": ssl.CERT_NONE} if self._is_ssl_connection() else {}

_LOGGING.debug("WS url %s", url)
# Only for debug use!
# websocket.enableTrace(True)
connection = websocket.create_connection(
url,
self.timeout,
sslopt=sslopt,
# Use 'connection' for fix websocket-client 0.57 bug
# header={'Connection': 'Upgrade'}
connection="Connection: Upgrade",
)

response = helper.process_api_response(connection.recv())
self._check_for_token(response)

if response["event"] != "ms.channel.connect":
self.close()
raise exceptions.ConnectionFailure(response)

return connection

def close(self):
if self.connection:
self.connection.close()

self.connection = None
_LOGGING.debug("Connection closed.")

def send_command(self, command, key_press_delay=None):
if self.connection is None:
self.connection = self.open()

payload = json.dumps(command)
self.connection.send(payload)

delay = self.key_press_delay if key_press_delay is None else key_press_delay
time.sleep(delay)
123 changes: 21 additions & 102 deletions samsungtvws/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,18 @@
Boston, MA 02110-1335 USA
"""
import json
import logging
import time
import ssl
import websocket
from typing import overload

import requests
from . import art
from . import exceptions
from . import shortcuts
from . import helper

from . import art, connection, exceptions, helper, shortcuts

_LOGGING = logging.getLogger(__name__)


class SamsungTVWS:
_URL_FORMAT = "ws://{host}:{port}/api/v2/channels/{app}?name={name}"
_SSL_URL_FORMAT = (
"wss://{host}:{port}/api/v2/channels/{app}?name={name}&token={token}"
)
class SamsungTVWS(connection.SamsungTVWSConnection):
_REST_URL_FORMAT = "{protocol}://{host}:{port}/api/v2/{route}"

def __init__(
Expand All @@ -50,37 +43,15 @@ def __init__(
key_press_delay=1,
name="SamsungTvRemote",
):
self.host = host
self.token = token
self.token_file = token_file
self.port = port
self.timeout = None if timeout == 0 else timeout
self.key_press_delay = key_press_delay
self.name = name
self.connection = None

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.close()

def _is_ssl_connection(self):
return self.port == 8002

def _format_websocket_url(self, app):
params = {
"host": self.host,
"port": self.port,
"app": app,
"name": helper.serialize_string(self.name),
"token": self._get_token(),
}

if self._is_ssl_connection():
return self._SSL_URL_FORMAT.format(**params)
else:
return self._URL_FORMAT.format(**params)
super().__init__(
host,
token=token,
token_file=token_file,
port=port,
timeout=timeout,
key_press_delay=key_press_delay,
name=name,
)

def _format_rest_url(self, route=""):
params = {
Expand All @@ -92,34 +63,8 @@ def _format_rest_url(self, route=""):

return self._REST_URL_FORMAT.format(**params)

def _get_token(self):
if self.token_file is not None:
try:
with open(self.token_file, "r") as token_file:
return token_file.readline()
except:
return ""
else:
return self.token

def _set_token(self, token):
_LOGGING.info("New token %s", token)
if self.token_file is not None:
_LOGGING.debug("Save token to file: %s", token)
with open(self.token_file, "w") as token_file:
token_file.write(token)
else:
self.token = token

def _ws_send(self, command, key_press_delay=None):
if self.connection is None:
self.connection = self.open("samsung.remote.control")

payload = json.dumps(command)
self.connection.send(payload)

delay = self.key_press_delay if key_press_delay is None else key_press_delay
time.sleep(delay)
return super().send_command(command, key_press_delay)

def _rest_request(self, target, method="GET"):
url = self._format_rest_url(target)
Expand All @@ -137,40 +82,14 @@ def _rest_request(self, target, method="GET"):
"TV unreachable or feature not supported on this model."
)

# endpoint is kept here for compatibility - can be removed in v2
@overload
def open(self, endpoint):
url = self._format_websocket_url(endpoint)
sslopt = {"cert_reqs": ssl.CERT_NONE} if self._is_ssl_connection() else {}

_LOGGING.debug("WS url %s", url)
# Only for debug use!
# websocket.enableTrace(True)
connection = websocket.create_connection(
url,
self.timeout,
sslopt=sslopt,
# Use 'connection' for fix websocket-client 0.57 bug
# header={'Connection': 'Upgrade'}
connection="Connection: Upgrade",
)

response = helper.process_api_response(connection.recv())
if response.get("data") and response.get("data").get("token"):
token = response.get("data").get("token")
_LOGGING.debug("Got token %s", token)
self._set_token(token)

if response["event"] != "ms.channel.connect":
self.close()
raise exceptions.ConnectionFailure(response)

return connection

def close(self):
if self.connection:
self.connection.close()
return super().open()

self.connection = None
_LOGGING.debug("Connection closed.")
# required for overloading - can be removed in v2
def open(self):
return super().open()

def send_key(self, key, times=1, key_press_delay=None, cmd="Click"):
for _ in range(times):
Expand Down
10 changes: 7 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@

@pytest.fixture(autouse=True)
def override_time_sleep():
"""Open a websocket connection."""
with patch("samsungtvws.remote.time.sleep"):
"""Ignore time sleep in tests."""
with patch("samsungtvws.connection.time.sleep"), patch(
"samsungtvws.remote.time.sleep"
):
yield


@pytest.fixture(name="connection")
def get_connection():
"""Open a websocket connection."""
connection = Mock()
with patch("samsungtvws.remote.websocket.create_connection") as connection_class:
with patch(
"samsungtvws.connection.websocket.create_connection"
) as connection_class:
connection_class.return_value = connection
yield connection

0 comments on commit 382f21d

Please sign in to comment.