diff --git a/spond/spond.py b/spond/spond.py index b95f3a9..7874fd7 100644 --- a/spond/spond.py +++ b/spond/spond.py @@ -6,6 +6,10 @@ import aiohttp +class AuthenticationError(Exception): + pass + + class Spond: def __init__(self, username, password): self.username = username @@ -31,7 +35,10 @@ async def login(self): data = {"email": self.username, "password": self.password} async with self.clientsession.post(login_url, json=data) as r: login_result = await r.json() - self.token = login_result["loginToken"] + self.token = login_result.get("loginToken", None) + if self.token is None: + err_msg = f"Login failed. Response received: {login_result}" + raise AuthenticationError(err_msg) api_chat_url = f"{self.api_url}chat" headers = { @@ -43,6 +50,19 @@ async def login(self): self.chat_url = result["url"] self.auth = result["auth"] + def require_authentication(func: callable): + async def wrapper(self, *args, **kwargs): + if not self.token: + try: + await self.login() + except AuthenticationError as e: + await self.clientsession.close() + raise e + return await func(self, *args, **kwargs) + + return wrapper + + @require_authentication async def get_groups(self): """ Get all groups. @@ -53,13 +73,12 @@ async def get_groups(self): list of dict Groups; each group is a dict. """ - if not self.token: - await self.login() url = f"{self.api_url}groups/" async with self.clientsession.get(url, headers=self.auth_headers) as r: self.groups = await r.json() return self.groups + @require_authentication async def get_group(self, uid) -> dict: """ Get a group by unique ID. @@ -74,8 +93,7 @@ async def get_group(self, uid) -> dict: ------- Details of the group. """ - if not self.token: - await self.login() + if not self.groups: await self.get_groups() for group in self.groups: @@ -83,6 +101,7 @@ async def get_group(self, uid) -> dict: return group raise IndexError + @require_authentication async def get_person(self, user) -> dict: """ Get a member or guardian by matching various identifiers. @@ -98,8 +117,6 @@ async def get_person(self, user) -> dict: ------- Member or guardian's details. """ - if not self.token: - await self.login() if not self.groups: await self.get_groups() for group in self.groups: @@ -126,14 +143,14 @@ async def get_person(self, user) -> dict: return guardian raise IndexError + @require_authentication async def get_messages(self): - if not self.token: - await self.login() url = f"{self.chat_url}/chats/?max=10" headers = {"auth": self.auth} async with self.clientsession.get(url, headers=headers) as r: return await r.json() + @require_authentication async def _continue_chat(self, chat_id, text): """ Send a given text in an existing given chat. @@ -152,14 +169,13 @@ async def _continue_chat(self, chat_id, text): dict Result of the sending. """ - if not self.token: - await self.login() url = f"{self.chat_url}/messages" data = {"chatId": chat_id, "text": text, "type": "TEXT"} headers = {"auth": self.auth} r = await self.clientsession.post(url, json=data, headers=headers) return await r.json() + @require_authentication async def send_message(self, text, user=None, group_uid=None, chat_id=None): """ Start a new chat or continue an existing one. @@ -192,8 +208,6 @@ async def send_message(self, text, user=None, group_uid=None, chat_id=None): "error": "wrong usage, group_id and user_id needed or continue chat with chat_id" } - if not self.token: - await self.login() user_obj = await self.get_person(user) if user_obj: user_uid = user_obj["profile"]["id"] @@ -210,6 +224,7 @@ async def send_message(self, text, user=None, group_uid=None, chat_id=None): r = await self.clientsession.post(url, json=data, headers=headers) return await r.json() + @require_authentication async def get_events( self, group_id: Optional[str] = None, @@ -259,8 +274,6 @@ async def get_events( list of dict Events; each event is a dict. """ - if not self.token: - await self.login() url = ( f"{self.api_url}sponds/?" f"max={max_events}" @@ -281,6 +294,7 @@ async def get_events( self.events = await r.json() return self.events + @require_authentication async def get_event(self, uid) -> dict: """ Get an event by unique ID. @@ -295,8 +309,6 @@ async def get_event(self, uid) -> dict: ------- Details of the event. """ - if not self.token: - await self.login() if not self.events: await self.get_events() for event in self.events: @@ -304,6 +316,7 @@ async def get_event(self, uid) -> dict: return event raise IndexError + @require_authentication async def update_event(self, uid, updates: dict): """ Updates an existing event. @@ -320,8 +333,6 @@ async def update_event(self, uid, updates: dict): json results of post command """ - if not self.token: - await self.login() if not self.events: await self.get_events() for event in self.events: