From 4c31b577e61cebdbd51b65e443dc3edcc9c85975 Mon Sep 17 00:00:00 2001 From: Tengz Date: Tue, 9 Apr 2024 09:52:12 +0300 Subject: [PATCH 1/5] wip: async_support --- pyproject.toml | 3 +- threatx_api_client/__init__.py | 94 ++++++++++++++++++++++++---------- 2 files changed, 68 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b70673a..9c76f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ packages = [{include = "threatx_api_client"}] [tool.poetry.dependencies] python = "^3.8" requests = "^2.31.0" +aiohttp = "^3.9.3" [tool.poetry.group.dev.dependencies] ruff = "^0.0.288" @@ -40,7 +41,7 @@ select = [ "C90", "I", "N", - "D", +# "D", "PERF", "PL", "FURB", diff --git a/threatx_api_client/__init__.py b/threatx_api_client/__init__.py index 39f752e..e7dd52b 100644 --- a/threatx_api_client/__init__.py +++ b/threatx_api_client/__init__.py @@ -6,10 +6,13 @@ TXAPIIncorrectTokenError, TXAPIResponseError, ) +import aiohttp +import asyncio class Client: """Main API Client class.""" + def __init__(self, api_env, api_key): """Main Client class initializer.""" self.host_parts = { @@ -24,42 +27,74 @@ def __init__(self, api_env, api_key): self.api_env = api_env self.api_key = api_key + self.http_connector = aiohttp.TCPConnector(limit_per_host=100, limit=0, ttl_dns_cache=300) + self.parallel_requests = 10 + self.session_token = self.__login() - def __get_api_env_host(self, api_env): - if api_env not in self.host_parts: - raise TXAPIIncorrectEnvironmentError(f"TX API Env '{api_env}' not found!") + def __get_api_env_host(self): + if self.api_env not in self.host_parts: + raise TXAPIIncorrectEnvironmentError(f"TX API Env '{self.api_env}' not found!") - part = (f"-{self.host_parts.get(api_env)}" - if self.host_parts.get(api_env) else "") + part = (f"-{self.host_parts.get(self.api_env)}" + if self.host_parts.get(self.api_env) else "") return f"https://provision{part}.threatx.io" def __generate_api_link(self, api_ver: int): - return f"{self.__get_api_env_host(self.api_env)}/{self.api_path}/v{api_ver}" - - def __process_response(self, url: str, available_commands: list, payload: dict): - payload_command = payload.get("command") - - if payload_command not in available_commands: - raise TXAPIIncorrectCommandError(payload_command) - - auth = {"token": self.session_token} - response: dict = requests.post(url, json={**auth, **payload}).json() - - response_data = response.get("Ok") + return f"/{self.api_path}/v{api_ver}" - if response_data: - return response_data - - if response.get("Error") == "Token Expired. Please re-authenticate.": - self.session_token = self.__login() - return self.__process_response(url, available_commands, payload) - else: - raise TXAPIResponseError(response.get("Error")) + # async def post(self, url: str, available_commands, post_payload: dict): + # if post_payload.get("command") not in available_commands: + # raise TXAPIIncorrectCommandError(post_payload.get("command")) + # + # async with asyncio.Semaphore(self.parallel_requests): + # async with self.http_session.post(url, json={"token": self.session_token, **post_payload}) as response: + # response_data = await response.json() + # + # if response_data: + # result_responses.append(response_data) + # + # if response_data.get("Error") == "Token Expired. Please re-authenticate.": + # self.session_token = self.__login() + # return self.__process_response(url, available_commands, post_payload) + # else: + # raise TXAPIResponseError(response_data.get("Error")) + + async def __process_response(self, url: str, available_commands: list, payloads): + http_session = aiohttp.ClientSession( + base_url=self.__get_api_env_host(), + connector=self.http_connector + ) + semaphore = asyncio.Semaphore(self.parallel_requests) + responses = [] + + async def post(post_payload: dict): + if post_payload.get("command") not in available_commands: + raise TXAPIIncorrectCommandError(post_payload.get("command")) + + async with semaphore: + async with http_session.post(url, json={"token": self.session_token, **post_payload}) as raw_response: + response = await raw_response.json() + response_ok_data = response.get("Ok") + response_error_data = response.get("Error") + + if response_ok_data: + responses.append(response_ok_data) + return None + + if response_error_data == "Token Expired. Please re-authenticate.": + self.session_token = self.__login() + return post(post_payload) + else: + raise TXAPIResponseError(response_error_data) + + await asyncio.gather(*(post(payload) for payload in payloads)) + await http_session.close() + return responses def __login(self): - url = f"{self.__generate_api_link(1)}/login" + url = f"{self.__get_api_env_host()}{self.__generate_api_link(1)}/login" if not self.api_key: raise TXAPIIncorrectTokenError("Please provide TX API Key.") @@ -308,7 +343,7 @@ def list_whitelist(self, payload): return self.__process_response(url, available_commands, payload) - def list_blacklist(self, payload): + def list_blacklist(self, payloads): """Get blacklist IPs. Method allows to get customer blacklisted IPs. @@ -319,7 +354,10 @@ def list_blacklist(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete(self.__process_response(url, available_commands, payloads)) + self.http_connector.close() + return results def list_blocklist(self, payload): """Get blocklisted IPs. From b661a1dbf9c90c7bd571043adb9c783869f272b9 Mon Sep 17 00:00:00 2001 From: Tengz Date: Tue, 9 Apr 2024 10:41:44 +0300 Subject: [PATCH 2/5] refactor: optimizations, get rid of requests --- pyproject.toml | 1 - threatx_api_client/__init__.py | 115 +++++++++++++++------------------ 2 files changed, 53 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9c76f43..a43dc90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,6 @@ packages = [{include = "threatx_api_client"}] [tool.poetry.dependencies] python = "^3.8" -requests = "^2.31.0" aiohttp = "^3.9.3" [tool.poetry.group.dev.dependencies] diff --git a/threatx_api_client/__init__.py b/threatx_api_client/__init__.py index e7dd52b..61d44ab 100644 --- a/threatx_api_client/__init__.py +++ b/threatx_api_client/__init__.py @@ -1,4 +1,6 @@ -import requests +import asyncio + +import aiohttp from threatx_api_client.exceptions import ( TXAPIIncorrectCommandError, @@ -6,8 +8,6 @@ TXAPIIncorrectTokenError, TXAPIResponseError, ) -import aiohttp -import asyncio class Client: @@ -27,10 +27,9 @@ def __init__(self, api_env, api_key): self.api_env = api_env self.api_key = api_key - self.http_connector = aiohttp.TCPConnector(limit_per_host=100, limit=0, ttl_dns_cache=300) self.parallel_requests = 10 - self.session_token = self.__login() + self.session_token = self.__get_session_token() def __get_api_env_host(self): if self.api_env not in self.host_parts: @@ -44,69 +43,62 @@ def __get_api_env_host(self): def __generate_api_link(self, api_ver: int): return f"/{self.api_path}/v{api_ver}" - # async def post(self, url: str, available_commands, post_payload: dict): - # if post_payload.get("command") not in available_commands: - # raise TXAPIIncorrectCommandError(post_payload.get("command")) - # - # async with asyncio.Semaphore(self.parallel_requests): - # async with self.http_session.post(url, json={"token": self.session_token, **post_payload}) as response: - # response_data = await response.json() - # - # if response_data: - # result_responses.append(response_data) - # - # if response_data.get("Error") == "Token Expired. Please re-authenticate.": - # self.session_token = self.__login() - # return self.__process_response(url, available_commands, post_payload) - # else: - # raise TXAPIResponseError(response_data.get("Error")) - - async def __process_response(self, url: str, available_commands: list, payloads): - http_session = aiohttp.ClientSession( - base_url=self.__get_api_env_host(), - connector=self.http_connector + def __init_http_session(self): + self.http_session = aiohttp.ClientSession( + base_url=self.__get_api_env_host() ) - semaphore = asyncio.Semaphore(self.parallel_requests) - responses = [] - - async def post(post_payload: dict): - if post_payload.get("command") not in available_commands: - raise TXAPIIncorrectCommandError(post_payload.get("command")) - - async with semaphore: - async with http_session.post(url, json={"token": self.session_token, **post_payload}) as raw_response: - response = await raw_response.json() - response_ok_data = response.get("Ok") - response_error_data = response.get("Error") - - if response_ok_data: - responses.append(response_ok_data) - return None - - if response_error_data == "Token Expired. Please re-authenticate.": - self.session_token = self.__login() - return post(post_payload) - else: - raise TXAPIResponseError(response_error_data) - - await asyncio.gather(*(post(payload) for payload in payloads)) - await http_session.close() + + async def __post(self, session, path: str, post_payload: dict): + async with asyncio.Semaphore(self.parallel_requests): + async with session.post(path, json=post_payload) as raw_response: + response = await raw_response.json() + response_ok_data = response.get("Ok") + response_error_data = response.get("Error") + + if response_ok_data: + return response_ok_data + + if response_error_data == "Token Expired. Please re-authenticate.": + self.session_token = self.__get_session_token() + return self.__post(session, path, post_payload) + else: + raise TXAPIResponseError(response_error_data) + + async def __process_response(self, path: str, available_commands: list, payloads): + for payload in payloads: + if payload.get("command") not in available_commands: + raise TXAPIIncorrectCommandError(payload.get("command")) + + async with aiohttp.ClientSession(base_url=self.__get_api_env_host()) as session: + responses = await asyncio.gather(*( + self.__post( + session, + path, + {"token": self.session_token, **payload}) for payload in payloads + )) + return responses - def __login(self): - url = f"{self.__get_api_env_host()}{self.__generate_api_link(1)}/login" + async def __login(self): + path = f"{self.__generate_api_link(1)}/login" if not self.api_key: raise TXAPIIncorrectTokenError("Please provide TX API Key.") - data = {"command": "login", "api_token": self.api_key} - - response = requests.post(url, json=data).json()["Ok"]["token"] - - if response: - return response - else: - raise TXAPIIncorrectTokenError("TX API Token is not correct!") + async with aiohttp.ClientSession(base_url=self.__get_api_env_host()) as session: + response = await asyncio.gather( + self.__post( + session, + path, + {"command": "login", "api_token": self.api_key} + ) + ) + return response[0]["token"] + + def __get_session_token(self): + loop = asyncio.get_event_loop() + results = loop.run_until_complete(self.__login()) + return results # TODO: Remove this? # def auth(self, payload): @@ -356,7 +348,6 @@ def list_blacklist(self, payloads): loop = asyncio.get_event_loop() results = loop.run_until_complete(self.__process_response(url, available_commands, payloads)) - self.http_connector.close() return results def list_blocklist(self, payload): From 40c9c3b97e12fc5c4024fb8a15f2690c2da064ff Mon Sep 17 00:00:00 2001 From: Tengz Date: Tue, 9 Apr 2024 10:44:35 +0300 Subject: [PATCH 3/5] refactor: clean up --- threatx_api_client/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/threatx_api_client/__init__.py b/threatx_api_client/__init__.py index 61d44ab..fc65986 100644 --- a/threatx_api_client/__init__.py +++ b/threatx_api_client/__init__.py @@ -43,11 +43,6 @@ def __get_api_env_host(self): def __generate_api_link(self, api_ver: int): return f"/{self.api_path}/v{api_ver}" - def __init_http_session(self): - self.http_session = aiohttp.ClientSession( - base_url=self.__get_api_env_host() - ) - async def __post(self, session, path: str, post_payload: dict): async with asyncio.Semaphore(self.parallel_requests): async with session.post(path, json=post_payload) as raw_response: From 856baa7b57d8a45e2480e04293c59178d3e8ca2e Mon Sep 17 00:00:00 2001 From: Tengz Date: Tue, 9 Apr 2024 11:25:47 +0300 Subject: [PATCH 4/5] refactor: backwards compatibility --- threatx_api_client/__init__.py | 242 +++++++++++++++++++++++++-------- 1 file changed, 183 insertions(+), 59 deletions(-) diff --git a/threatx_api_client/__init__.py b/threatx_api_client/__init__.py index fc65986..7c7c3c0 100644 --- a/threatx_api_client/__init__.py +++ b/threatx_api_client/__init__.py @@ -29,6 +29,8 @@ def __init__(self, api_env, api_key): self.parallel_requests = 10 + self.base_url = self.__get_api_env_host() + self.session_token = self.__get_session_token() def __get_api_env_host(self): @@ -46,7 +48,7 @@ def __generate_api_link(self, api_ver: int): async def __post(self, session, path: str, post_payload: dict): async with asyncio.Semaphore(self.parallel_requests): async with session.post(path, json=post_payload) as raw_response: - response = await raw_response.json() + response = await raw_response.json(content_type=None) response_ok_data = response.get("Ok") response_error_data = response.get("Error") @@ -60,11 +62,14 @@ async def __post(self, session, path: str, post_payload: dict): raise TXAPIResponseError(response_error_data) async def __process_response(self, path: str, available_commands: list, payloads): + if isinstance(payloads, dict): + payloads = [payloads] + for payload in payloads: if payload.get("command") not in available_commands: raise TXAPIIncorrectCommandError(payload.get("command")) - async with aiohttp.ClientSession(base_url=self.__get_api_env_host()) as session: + async with aiohttp.ClientSession(base_url=self.base_url) as session: responses = await asyncio.gather(*( self.__post( session, @@ -72,6 +77,9 @@ async def __process_response(self, path: str, available_commands: list, payloads {"token": self.session_token, **payload}) for payload in payloads )) + if len(responses) == 1: + return responses[0] + return responses async def __login(self): @@ -80,7 +88,7 @@ async def __login(self): if not self.api_key: raise TXAPIIncorrectTokenError("Please provide TX API Key.") - async with aiohttp.ClientSession(base_url=self.__get_api_env_host()) as session: + async with aiohttp.ClientSession(base_url=self.base_url) as session: response = await asyncio.gather( self.__post( session, @@ -88,7 +96,13 @@ async def __login(self): {"command": "login", "api_token": self.api_key} ) ) - return response[0]["token"] + + token_value = response[0]["token"] + + if not token_value: + raise TXAPIIncorrectTokenError("TX API Token is not correct!") + + return token_value def __get_session_token(self): loop = asyncio.get_event_loop() @@ -110,7 +124,7 @@ def __get_session_token(self): # # return self.__process_response(url, available_commands, payload) - def api_keys(self, payload: dict): + def api_keys(self, payloads): """API Keys management. Method allows to manage API keys, allowing authorized users to @@ -122,9 +136,13 @@ def api_keys(self, payload: dict): available_commands = ["list", "new", "update", "revoke"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def api_schemas(self, payload): + def api_schemas(self, payloads): """API schemas management. Method allows to manage API schemas. @@ -135,9 +153,13 @@ def api_schemas(self, payload): available_commands = ["save", "list", "delete"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def customers(self, payload): + def customers(self, payloads): """Customers management. Method allows to create, manage and remove customers. @@ -159,9 +181,13 @@ def customers(self, payload): "set_customer_config", # TODO: confirm ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def users(self, payload): + def users(self, payloads): """Users management. Method allows to create, manage and remove users. @@ -179,9 +205,13 @@ def users(self, payload): "get_api_key", # TODO: confirm ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def sites(self, payload): + def sites(self, payloads): """Sites management. Method allows to create, manage and remove sites. @@ -192,9 +222,13 @@ def sites(self, payload): available_commands = ["list", "new", "get", "delete", "update", "unset"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def site_groups(self, payload): + def site_groups(self, payloads): """Site groups management. Method allows to create, manage and remove site groups. @@ -206,9 +240,13 @@ def site_groups(self, payload): available_commands = ["list", "save", "delete"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def templates(self, payload): + def templates(self, payloads): """Templates management. Method allows to create, manage and remove customer templates. @@ -219,9 +257,13 @@ def templates(self, payload): available_commands = ["set", "get", "delete"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def sensors(self, payload): + def sensors(self, payloads): """Sensors information. Method provides information of on-premises deployed sensors and sensor metadata. @@ -232,9 +274,13 @@ def sensors(self, payload): available_commands = ["list", "tags"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def services(self, payload): + def services(self, payloads): """Services information. Method provides information on ThreatX system services and their public IP addresses. @@ -245,9 +291,13 @@ def services(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def entities(self, payload): + def entities(self, payloads): """Entities management. Method allows to list and manage entities. @@ -272,9 +322,13 @@ def entities(self, payload): "count" ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def metrics(self, payload): + def metrics(self, payloads): """Statistical metrics. Method provides statistical metrics on ThreatX system operations. @@ -300,9 +354,13 @@ def metrics(self, payload): "request_stats_hourly_by_endpoint" ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def subscriptions(self, payload): + def subscriptions(self, payloads): """Subscriptions management. Method allows to configure customer notification subscriptions. @@ -315,9 +373,13 @@ def subscriptions(self, payload): available_commands = ["save", "delete", "list", "enable", "disable"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def list_whitelist(self, payload): + def list_whitelist(self, payloads): """Get whitelist IPs. Method allows to get customer whitelisted IPs. @@ -328,7 +390,11 @@ def list_whitelist(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results def list_blacklist(self, payloads): """Get blacklist IPs. @@ -342,10 +408,12 @@ def list_blacklist(self, payloads): available_commands = ["list"] loop = asyncio.get_event_loop() - results = loop.run_until_complete(self.__process_response(url, available_commands, payloads)) + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) return results - def list_blocklist(self, payload): + def list_blocklist(self, payloads): """Get blocklisted IPs. Method allows to get customer blocked IPs. @@ -356,9 +424,13 @@ def list_blocklist(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def list_mutelist(self, payload): + def list_mutelist(self, payloads): """Get mutelisted IPs. Method allows to get customer mutelisted IPs. @@ -369,9 +441,13 @@ def list_mutelist(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def list_ignorelist(self, payload): + def list_ignorelist(self, payloads): """Get ignorelisted IPs. Method allows to get customer ignorelisted IPs. @@ -382,9 +458,13 @@ def list_ignorelist(self, payload): available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def global_tags(self, payload): + def global_tags(self, payloads): """Global tags management. Method allows to create new and provides information of global tags available for use. @@ -395,9 +475,13 @@ def global_tags(self, payload): available_commands = ["new", "list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def actor_tags(self, payload): + def actor_tags(self, payloads): """Actor tags management. Method allows to create, manage and remove actor tags. @@ -408,16 +492,24 @@ def actor_tags(self, payload): available_commands = ["new", "list", "delete"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def features(self, payload): + def features(self, payloads): url = f"{self.__generate_api_link(1)}/features" available_commands = ["list", "query", "save", "delete"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def metrics_tech(self, payload): + def metrics_tech(self, payloads): """API Profiler information. Method provides information of customer API Profiler. @@ -428,9 +520,13 @@ def metrics_tech(self, payload): available_commands = ["list_endpoint_profiles", "list_site_profiles"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def channels(self, payload): + def channels(self, payloads): """Channels management. Method allows to create, manage and remove customer channels. @@ -441,9 +537,13 @@ def channels(self, payload): available_commands = ["new", "list", "update"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def global_settings(self, payload): + def global_settings(self, payloads): """Customer-wide settings. Method allows to get default customer-wide settings applied. @@ -454,16 +554,24 @@ def global_settings(self, payload): available_commands = ["get"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def dns_info(self, payload): + def dns_info(self, payloads): url = f"{self.__generate_api_link(1)}/dnsinfo" available_commands = ["list"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def logs(self, payload): + def logs(self, payloads): """Customer logs. Method allows to get customer logs including audit logs, match events, etc. @@ -489,9 +597,13 @@ def logs(self, payload): # for log in response: # log["customer"] = payload["customer_name"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def logs_v2(self, payload): + def logs_v2(self, payloads): """Customer logs. Method allows to get customer logs including block, match and audit events. @@ -512,9 +624,13 @@ def logs_v2(self, payload): # for log in response: # log["customer"] = payload["customer_name"] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def lists(self, payload): + def lists(self, payloads): """Lists management. Method allows to manage IP addresses within black, block and whitelists. @@ -551,9 +667,13 @@ def lists(self, payload): "ip_to_link", ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results - def rules(self, payload): + def rules(self, payloads): """Rules management. Method allows to create, manage and remove customer rules. @@ -585,4 +705,8 @@ def rules(self, payload): "validate_rule" ] - return self.__process_response(url, available_commands, payload) + loop = asyncio.get_event_loop() + results = loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return results From ae4133364de52e254ef5c853ed12066f591d8c33 Mon Sep 17 00:00:00 2001 From: tengzl33t Date: Wed, 10 Apr 2024 08:28:09 +0300 Subject: [PATCH 5/5] refactor: formatting --- pyproject.toml | 2 +- threatx_api_client/__init__.py | 348 ++++++++++++++------------------- 2 files changed, 152 insertions(+), 198 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a43dc90..19e40c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ select = [ "C90", "I", "N", -# "D", + "D", "PERF", "PL", "FURB", diff --git a/threatx_api_client/__init__.py b/threatx_api_client/__init__.py index 7c7c3c0..631f350 100644 --- a/threatx_api_client/__init__.py +++ b/threatx_api_client/__init__.py @@ -35,7 +35,9 @@ def __init__(self, api_env, api_key): def __get_api_env_host(self): if self.api_env not in self.host_parts: - raise TXAPIIncorrectEnvironmentError(f"TX API Env '{self.api_env}' not found!") + raise TXAPIIncorrectEnvironmentError( + f"TX API Env '{self.api_env}' not found!" + ) part = (f"-{self.host_parts.get(self.api_env)}" if self.host_parts.get(self.api_env) else "") @@ -124,47 +126,52 @@ def __get_session_token(self): # # return self.__process_response(url, available_commands, payload) + def __run_async_processing(self, url, available_commands, payloads): + async_loop = asyncio.get_event_loop() + responses = async_loop.run_until_complete( + self.__process_response(url, available_commands, payloads) + ) + return responses + def api_keys(self, payloads): """API Keys management. Method allows to manage API keys, allowing authorized users to create (and revoke) keys granting automated access to the ThreatX API. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(2)}/apikeys" available_commands = ["list", "new", "update", "revoke"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def api_schemas(self, payloads): """API schemas management. Method allows to manage API schemas. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/apischemas" available_commands = ["save", "list", "delete"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def customers(self, payloads): """Customers management. Method allows to create, manage and remove customers. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/customers" @@ -181,18 +188,16 @@ def customers(self, payloads): "set_customer_config", # TODO: confirm ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def users(self, payloads): """Users management. Method allows to create, manage and remove users. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/users" @@ -205,104 +210,94 @@ def users(self, payloads): "get_api_key", # TODO: confirm ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def sites(self, payloads): """Sites management. Method allows to create, manage and remove sites. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(2)}/sites" available_commands = ["list", "new", "get", "delete", "update", "unset"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def site_groups(self, payloads): """Site groups management. Method allows to create, manage and remove site groups. - Site groups provide access control features similar to UNIX user groups, restricting access to ThreatX sites. - :param dict payload: API payload containing main command and additional parameters. - :return: + Site groups provide access control features similar to UNIX user groups, + restricting access to ThreatX sites. + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/sitegroups" available_commands = ["list", "save", "delete"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def templates(self, payloads): """Templates management. Method allows to create, manage and remove customer templates. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/templates" available_commands = ["set", "get", "delete"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def sensors(self, payloads): """Sensors information. Method provides information of on-premises deployed sensors and sensor metadata. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/sensors" available_commands = ["list", "tags"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def services(self, payloads): """Services information. - Method provides information on ThreatX system services and their public IP addresses. - :param dict payload: API payload containing main command and additional parameters. - :return: + Method provides information on ThreatX system services + and their public IP addresses. + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/services" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def entities(self, payloads): """Entities management. Method allows to list and manage entities. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/entities" @@ -322,18 +317,16 @@ def entities(self, payloads): "count" ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def metrics(self, payloads): """Statistical metrics. Method provides statistical metrics on ThreatX system operations. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/metrics" @@ -354,229 +347,200 @@ def metrics(self, payloads): "request_stats_hourly_by_endpoint" ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def subscriptions(self, payloads): """Subscriptions management. Method allows to configure customer notification subscriptions. - Subscriptions are used to receive notifications related to ThreatX events, delivered either via email, + Subscriptions are used to receive notifications related + to ThreatX events, delivered either via email, webhook, or through a log emitter communicating directly to an analyzer. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/subscriptions" available_commands = ["save", "delete", "list", "enable", "disable"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def list_whitelist(self, payloads): """Get whitelist IPs. Method allows to get customer whitelisted IPs. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/whitelist" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def list_blacklist(self, payloads): """Get blacklist IPs. Method allows to get customer blacklisted IPs. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/blacklist" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def list_blocklist(self, payloads): """Get blocklisted IPs. Method allows to get customer blocked IPs. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/blocklist" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def list_mutelist(self, payloads): """Get mutelisted IPs. Method allows to get customer mutelisted IPs. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/mutelist" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def list_ignorelist(self, payloads): """Get ignorelisted IPs. Method allows to get customer ignorelisted IPs. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/ignorelist" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def global_tags(self, payloads): """Global tags management. - Method allows to create new and provides information of global tags available for use. - :param dict payload: API payload containing main command and additional parameters. - :return: + Method allows to create new and provides information of + global tags available for use. + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/globaltags" available_commands = ["new", "list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def actor_tags(self, payloads): """Actor tags management. Method allows to create, manage and remove actor tags. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/actortags" available_commands = ["new", "list", "delete"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def features(self, payloads): url = f"{self.__generate_api_link(1)}/features" available_commands = ["list", "query", "save", "delete"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def metrics_tech(self, payloads): """API Profiler information. Method provides information of customer API Profiler. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/metrics/tech" available_commands = ["list_endpoint_profiles", "list_site_profiles"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def channels(self, payloads): """Channels management. Method allows to create, manage and remove customer channels. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/channels" available_commands = ["new", "list", "update"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def global_settings(self, payloads): """Customer-wide settings. Method allows to get default customer-wide settings applied. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/globalsettings" available_commands = ["get"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def dns_info(self, payloads): + # TODO: add docs url = f"{self.__generate_api_link(1)}/dnsinfo" available_commands = ["list"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def logs(self, payloads): """Customer logs. Method allows to get customer logs including audit logs, match events, etc. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/logs" @@ -597,18 +561,16 @@ def logs(self, payloads): # for log in response: # log["customer"] = payload["customer_name"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def logs_v2(self, payloads): """Customer logs. Method allows to get customer logs including block, match and audit events. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(2)}/logs" @@ -624,18 +586,16 @@ def logs_v2(self, payloads): # for log in response: # log["customer"] = payload["customer_name"] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def lists(self, payloads): """Lists management. Method allows to manage IP addresses within black, block and whitelists. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/lists" @@ -667,18 +627,16 @@ def lists(self, payloads): "ip_to_link", ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads) def rules(self, payloads): """Rules management. Method allows to create, manage and remove customer rules. - :param dict payload: API payload containing main command and additional parameters. - :return: + :param list[dict]|dict payloads: API payloads or a single payload containing + main command and additional parameters. + :return: responses: API responses + :rtype: list[dict]|dict """ url = f"{self.__generate_api_link(1)}/rules" @@ -705,8 +663,4 @@ def rules(self, payloads): "validate_rule" ] - loop = asyncio.get_event_loop() - results = loop.run_until_complete( - self.__process_response(url, available_commands, payloads) - ) - return results + return self.__run_async_processing(url, available_commands, payloads)