diff --git a/tidbcloudy/baseURL.py b/tidbcloudy/baseURL.py deleted file mode 100644 index 6c14a7a..0000000 --- a/tidbcloudy/baseURL.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - - -class V1BETA(Enum): - HOST = "https://api.tidbcloud.com/api/v1beta/" - - -class V1BETA1(Enum): - BILLING = "https://billing.tidbapi.com/v1beta1/" diff --git a/tidbcloudy/cluster.py b/tidbcloudy/cluster.py index fe113c3..dc3d4a1 100644 --- a/tidbcloudy/cluster.py +++ b/tidbcloudy/cluster.py @@ -8,7 +8,6 @@ from .backup import Backup from .util.log import log from .util.timestamp import timestamp_to_string -from .util.page import Page # noinspection PyShadowingBuiltins @@ -28,7 +27,7 @@ class Cluster(TiDBCloudyBase, TiDBCloudyContextualBase): def _update_info_from_server(self): path = "projects/{}/clusters/{}".format(self.project_id, self.id) - resp = self.context.call_get(path=path) + resp = self.context.call_get(server="v1beta", path=path) self.assign_object(resp) def wait_for_available(self, *, timeout_sec: int = None, interval_sec: int = 10) -> bool: @@ -76,7 +75,7 @@ def update(self, config: Union[UpdateClusterConfig, dict], update_from_server: b path = "projects/{}/clusters/{}".format(self.project_id, self.id) if isinstance(config, UpdateClusterConfig): config = config.to_object() - self.context.call_patch(path=path, json=config) + self.context.call_patch(server="v1beta", path=path, json=config) log("Cluster id={} has been updated".format(self.id)) if update_from_server: self._update_info_from_server() @@ -84,20 +83,20 @@ def update(self, config: Union[UpdateClusterConfig, dict], update_from_server: b def pause(self): path = "projects/{}/clusters/{}".format(self.project_id, self.id) config = {"config": {"paused": True}} - self.context.call_patch(path=path, json=config) + self.context.call_patch(server="v1beta", path=path, json=config) self._update_info_from_server() log("Cluster id={} status={}".format(self.id, self.status.cluster_status.value)) def resume(self): path = "projects/{}/clusters/{}".format(self.project_id, self.id) config = {"config": {"paused": False}} - self.context.call_patch(path=path, json=config) + self.context.call_patch(server="v1beta", path=path, json=config) self._update_info_from_server() log("Cluster id={} status={}".format(self.id, self.status.cluster_status.value)) def delete(self): path = "projects/{}/clusters/{}".format(self.project_id, self.id) - self.context.call_delete(path=path) + self.context.call_delete(server="v1beta", path=path) log("Cluster id={} has been deleted".format(self.id)) def create_backup(self, *, name: str, description: str = None) -> Backup: @@ -115,7 +114,7 @@ def create_backup(self, *, name: str, description: str = None) -> Backup: config = {"name": name} if description is not None: config["description"] = description - resp = self.context.call_post(path=path, json=config) + resp = self.context.call_post(server="v1beta", path=path, json=config) return self.get_backup(resp["id"]) def delete_backup(self, backup_id: str): @@ -167,10 +166,10 @@ def list_backups(self, *, page: int = None, page_size: int = None) -> Page[Backu query["page"] = page if page_size is not None: query["page_size"] = page_size - resp = self.context.call_get(path=path, params=query) + resp = self.context.call_get(server="v1beta", path=path, params=query) return Page( [Backup.from_object(self.context, {"cluster_id": self.id, "project_id": self.project_id, **backup}) for - backup in resp["items"]], page, page_size, resp["total"]) + backup in resp["items"]], page, page_size, resp["total"]) def get_backup(self, backup_id: str) -> Backup: """ @@ -183,7 +182,7 @@ def get_backup(self, backup_id: str) -> Backup: """ path = "projects/{}/clusters/{}/backups/{}".format(self.project_id, self.id, backup_id) - resp = self.context.call_get(path=path) + resp = self.context.call_get(server="v1beta", path=path) return Backup.from_object(self.context, {"cluster_id": self.id, "project_id": self.project_id, **resp}) def connect(self, type: str, database: str, password: str): diff --git a/tidbcloudy/context.py b/tidbcloudy/context.py index a4f17bc..8d476c2 100644 --- a/tidbcloudy/context.py +++ b/tidbcloudy/context.py @@ -1,20 +1,23 @@ import httpx -from tidbcloudy.baseURL import V1BETA from tidbcloudy.exception import TiDBCloudResponseException class Context: - def __init__(self, public_key: str, private_key: str): + + def __init__(self, public_key: str, private_key: str, server_config: dict): """ Args: public_key: your public key to access to TiDB Cloud private_key: your private key to access to TiDB Cloud + server_config: the server configuration dict to access to TiDB Cloud """ self._client = httpx.Client() self._client.auth = httpx.DigestAuth(public_key, private_key) + self._server_config = server_config - def _call_api(self, method: str, path: str, base_url: str, **kwargs) -> dict: + def _call_api(self, method: str, path: str, server: str, **kwargs) -> dict: + base_url = self._server_config.get(server) if base_url[-1] != "/": base_url += "/" try: @@ -27,26 +30,26 @@ def _call_api(self, method: str, path: str, base_url: str, **kwargs) -> dict: except httpx.HTTPStatusError as exc: raise TiDBCloudResponseException(status=exc.response.status_code, message=exc.response.text) - def call_get(self, path: str, base_url: str = V1BETA.HOST.value, + def call_get(self, server: str, path: str, *, params: dict = None) -> dict: - resp = self._call_api(method="GET", path=path, base_url=base_url, params=params) + resp = self._call_api(method="GET", path=path, server=server, params=params) return resp - def call_post(self, path: str, base_url: str = V1BETA.HOST.value, + def call_post(self, server: str, path: str, *, data: dict = None, json: dict = None) -> dict: - resp = self._call_api(method="POST", path=path, base_url=base_url, data=data, json=json) + resp = self._call_api(method="POST", path=path, server=server, data=data, json=json) return resp - def call_patch(self, path: str, base_url: str = V1BETA.HOST.value, + def call_patch(self, server: str, path: str, *, data: dict = None, json: dict = None) -> dict: - resp = self._call_api(method="PATCH", path=path, base_url=base_url, data=data, json=json) + resp = self._call_api(method="PATCH", path=path, server=server, data=data, json=json) return resp - def call_delete(self, path: str, base_url: str = V1BETA.HOST.value) -> dict: - resp = self._call_api(method="DELETE", base_url=base_url, path=path) + def call_delete(self, server: str, path: str) -> dict: + resp = self._call_api(method="DELETE", server=server, path=path) return resp diff --git a/tidbcloudy/project.py b/tidbcloudy/project.py index af6aba9..a986100 100644 --- a/tidbcloudy/project.py +++ b/tidbcloudy/project.py @@ -51,7 +51,7 @@ def create_cluster(self, config: Union[CreateClusterConfig, dict]) -> Cluster: if isinstance(config, CreateClusterConfig): config = config.to_object() path = "projects/{}/clusters".format(self.id) - resp = self.context.call_post(path=path, json=config) + resp = self.context.call_post(server="v1beta", path=path, json=config) return Cluster(context=self.context, id=resp["id"], project_id=self.id) def update_cluster(self, cluster_id: str, config: Union[UpdateClusterConfig, dict]): @@ -112,7 +112,7 @@ def get_cluster(self, cluster_id: str) -> Cluster: """ path = "projects/{}/clusters/{}".format(self.id, cluster_id) - resp = self.context.call_get(path=path) + resp = self.context.call_get(server="v1beta", path=path) return Cluster.from_object(self.context, resp) def iter_clusters(self, page_size: int = 10) -> Iterator[Cluster]: @@ -168,7 +168,7 @@ def list_clusters(self, page: int = None, page_size: int = None) -> Page[Cluster query["page"] = page if page_size is not None: query["page_size"] = page_size - resp = self.context.call_get(path=path, params=query) + resp = self.context.call_get(server="v1beta", path=path, params=query) return Page( [Cluster.from_object(self.context, item) for item in resp["items"]], page, page_size, resp["total"]) @@ -192,7 +192,7 @@ def create_restore(self, *, name: str, backup_id: str, cluster_config: Union[Cre "backup_id": backup_id, "config": cluster_config["config"] } - resp = self.context.call_post(path=path, json=create_config) + resp = self.context.call_post(server="v1beta", path=path, json=create_config) return Restore(context=self.context, id=resp["id"], cluster_id=resp["cluster_id"]) def get_restore(self, restore_id: str) -> Restore: @@ -205,7 +205,7 @@ def get_restore(self, restore_id: str) -> Restore: """ path = "projects/{}/restores/{}".format(self.id, restore_id) - resp = self.context.call_get(path=path) + resp = self.context.call_get(server="v1beta", path=path) return Restore.from_object(self.context, resp) def list_restores(self, *, page: int = None, page_size: int = None) -> Page[Restore]: @@ -224,7 +224,7 @@ def list_restores(self, *, page: int = None, page_size: int = None) -> Page[Rest query["page"] = page if page_size is not None: query["page_size"] = page_size - resp = self.context.call_get(path=path, params=query) + resp = self.context.call_get(server="v1beta", path=path, params=query) return Page( [Restore.from_object(self.context, item) for item in resp["items"]], page, page_size, resp["total"] @@ -273,7 +273,7 @@ def create_aws_cmek(self, config: List[Tuple[str, str]]) -> None: "kms_arn": kms_arn }) path = f"projects/{self.id}/aws-cmek" - self.context.call_post(path=path, json=payload) + self.context.call_post(server="v1beta", path=path, json=payload) def list_aws_cmek(self) -> Page[ProjectAWSCMEK]: """ @@ -292,7 +292,7 @@ def list_aws_cmek(self) -> Page[ProjectAWSCMEK]: print(cmek) """ path = f"projects/{self.id}/aws-cmek" - resp = self.context.call_get(path=path) + resp = self.context.call_get(server="v1beta", path=path) total = len(resp["items"]) return Page( [ProjectAWSCMEK.from_object(self.context, item) for item in resp["items"]], diff --git a/tidbcloudy/tidbcloud.py b/tidbcloudy/tidbcloud.py index cf9efdb..e812328 100644 --- a/tidbcloudy/tidbcloud.py +++ b/tidbcloudy/tidbcloud.py @@ -1,16 +1,22 @@ from typing import Iterator, List -from tidbcloudy.baseURL import V1BETA1 from tidbcloudy.context import Context from tidbcloudy.project import Project from tidbcloudy.specification import BillingMonthSummary, CloudSpecification from tidbcloudy.util.page import Page from tidbcloudy.util.timestamp import get_current_year_month +SERVER_CONFIG_DEFAULT = { + "v1beta": "https://api.tidbcloud.com/api/v1beta/", + "billing": "https://billing.tidbapi.com/v1beta1/" +} + class TiDBCloud: - def __init__(self, public_key: str, private_key: str): - self._context = Context(public_key, private_key) + def __init__(self, public_key: str, private_key: str, server_config: dict = None): + if server_config is None: + server_config = SERVER_CONFIG_DEFAULT + self._context = Context(public_key, private_key, server_config) def create_project(self, name: str, aws_cmek_enabled: bool = False, update_from_server: bool = False) -> Project: """ @@ -35,7 +41,7 @@ def create_project(self, name: str, aws_cmek_enabled: bool = False, update_from_ "name": name, "aws_cmek_enabled": aws_cmek_enabled } - resp = self._context.call_post(path="projects", json=config) + resp = self._context.call_post(server="v1beta", path="projects", json=config) project_id = resp["id"] if update_from_server: return self.get_project(project_id=project_id, update_from_server=True) @@ -91,7 +97,7 @@ def list_projects(self, page: int = None, page_size: int = None) -> Page[Project query["page"] = page if page_size is not None: query["page_size"] = page_size - resp = self._context.call_get(path="projects", params=query) + resp = self._context.call_get(server="v1beta", path="projects", params=query) return Page( [Project.from_object(self._context, item) for item in resp["items"]], page, page_size, resp["total"]) @@ -137,7 +143,7 @@ def list_provider_regions(self) -> List[CloudSpecification]: print(spec) # This is a CloudSpecification object """ - resp = self._context.call_get(path="clusters/provider/regions") + resp = self._context.call_get(server="v1beta", path="clusters/provider/regions") return [CloudSpecification.from_object(obj=item) for item in resp["items"]] def get_monthly_bill(self, month: str) -> BillingMonthSummary: @@ -159,7 +165,7 @@ def get_monthly_bill(self, month: str) -> BillingMonthSummary: if "-" not in month and len(month) == 6: month = f"{month[:4]}-{month[4:]}" path = f"bills/{month}" - resp = self._context.call_get(path=path, base_url=V1BETA1.BILLING.value) + resp = self._context.call_get(server="billing", path=path) return BillingMonthSummary.from_object(self._context, resp) def get_current_month_bill(self) -> BillingMonthSummary: