Skip to content

Commit

Permalink
feat: refactor the implementation of baseURL
Browse files Browse the repository at this point in the history
  • Loading branch information
Oreoxmt committed Oct 5, 2023
1 parent da54bb0 commit a2bc4ee
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 45 deletions.
9 changes: 0 additions & 9 deletions tidbcloudy/baseURL.py

This file was deleted.

19 changes: 9 additions & 10 deletions tidbcloudy/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .backup import Backup
from .util.log import log
from .util.timestamp import timestamp_to_string
from .util.page import Page


# noinspection PyShadowingBuiltins
Expand All @@ -28,7 +27,7 @@ class Cluster(TiDBCloudyBase, TiDBCloudyContextualBase):

def _update_info_from_server(self):
path = "projects/{}/clusters/{}".format(self.project_id, self.id)
resp = self.context.call_get(path=path)
resp = self.context.call_get(server="v1beta", path=path)
self.assign_object(resp)

def wait_for_available(self, *, timeout_sec: int = None, interval_sec: int = 10) -> bool:
Expand Down Expand Up @@ -76,28 +75,28 @@ def update(self, config: Union[UpdateClusterConfig, dict], update_from_server: b
path = "projects/{}/clusters/{}".format(self.project_id, self.id)
if isinstance(config, UpdateClusterConfig):
config = config.to_object()
self.context.call_patch(path=path, json=config)
self.context.call_patch(server="v1beta", path=path, json=config)
log("Cluster id={} has been updated".format(self.id))
if update_from_server:
self._update_info_from_server()

def pause(self):
path = "projects/{}/clusters/{}".format(self.project_id, self.id)
config = {"config": {"paused": True}}
self.context.call_patch(path=path, json=config)
self.context.call_patch(server="v1beta", path=path, json=config)
self._update_info_from_server()
log("Cluster id={} status={}".format(self.id, self.status.cluster_status.value))

def resume(self):
path = "projects/{}/clusters/{}".format(self.project_id, self.id)
config = {"config": {"paused": False}}
self.context.call_patch(path=path, json=config)
self.context.call_patch(server="v1beta", path=path, json=config)
self._update_info_from_server()
log("Cluster id={} status={}".format(self.id, self.status.cluster_status.value))

def delete(self):
path = "projects/{}/clusters/{}".format(self.project_id, self.id)
self.context.call_delete(path=path)
self.context.call_delete(server="v1beta", path=path)
log("Cluster id={} has been deleted".format(self.id))

def create_backup(self, *, name: str, description: str = None) -> Backup:
Expand All @@ -115,7 +114,7 @@ def create_backup(self, *, name: str, description: str = None) -> Backup:
config = {"name": name}
if description is not None:
config["description"] = description
resp = self.context.call_post(path=path, json=config)
resp = self.context.call_post(server="v1beta", path=path, json=config)
return self.get_backup(resp["id"])

def delete_backup(self, backup_id: str):
Expand Down Expand Up @@ -167,10 +166,10 @@ def list_backups(self, *, page: int = None, page_size: int = None) -> Page[Backu
query["page"] = page
if page_size is not None:
query["page_size"] = page_size
resp = self.context.call_get(path=path, params=query)
resp = self.context.call_get(server="v1beta", path=path, params=query)
return Page(
[Backup.from_object(self.context, {"cluster_id": self.id, "project_id": self.project_id, **backup}) for
backup in resp["items"]], page, page_size, resp["total"])
backup in resp["items"]], page, page_size, resp["total"])

def get_backup(self, backup_id: str) -> Backup:
"""
Expand All @@ -183,7 +182,7 @@ def get_backup(self, backup_id: str) -> Backup:
"""
path = "projects/{}/clusters/{}/backups/{}".format(self.project_id, self.id, backup_id)
resp = self.context.call_get(path=path)
resp = self.context.call_get(server="v1beta", path=path)
return Backup.from_object(self.context, {"cluster_id": self.id, "project_id": self.project_id, **resp})

def connect(self, type: str, database: str, password: str):
Expand Down
25 changes: 14 additions & 11 deletions tidbcloudy/context.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import httpx

from tidbcloudy.baseURL import V1BETA
from tidbcloudy.exception import TiDBCloudResponseException


class Context:
def __init__(self, public_key: str, private_key: str):

def __init__(self, public_key: str, private_key: str, server_config: dict):
"""
Args:
public_key: your public key to access to TiDB Cloud
private_key: your private key to access to TiDB Cloud
server_config: the server configuration dict to access to TiDB Cloud
"""
self._client = httpx.Client()
self._client.auth = httpx.DigestAuth(public_key, private_key)
self._server_config = server_config

def _call_api(self, method: str, path: str, base_url: str, **kwargs) -> dict:
def _call_api(self, method: str, path: str, server: str, **kwargs) -> dict:
base_url = self._server_config.get(server)
if base_url[-1] != "/":
base_url += "/"
try:
Expand All @@ -27,26 +30,26 @@ def _call_api(self, method: str, path: str, base_url: str, **kwargs) -> dict:
except httpx.HTTPStatusError as exc:
raise TiDBCloudResponseException(status=exc.response.status_code, message=exc.response.text)

def call_get(self, path: str, base_url: str = V1BETA.HOST.value,
def call_get(self, server: str, path: str,
*,
params: dict = None) -> dict:
resp = self._call_api(method="GET", path=path, base_url=base_url, params=params)
resp = self._call_api(method="GET", path=path, server=server, params=params)
return resp

def call_post(self, path: str, base_url: str = V1BETA.HOST.value,
def call_post(self, server: str, path: str,
*,
data: dict = None,
json: dict = None) -> dict:
resp = self._call_api(method="POST", path=path, base_url=base_url, data=data, json=json)
resp = self._call_api(method="POST", path=path, server=server, data=data, json=json)
return resp

def call_patch(self, path: str, base_url: str = V1BETA.HOST.value,
def call_patch(self, server: str, path: str,
*,
data: dict = None,
json: dict = None) -> dict:
resp = self._call_api(method="PATCH", path=path, base_url=base_url, data=data, json=json)
resp = self._call_api(method="PATCH", path=path, server=server, data=data, json=json)
return resp

def call_delete(self, path: str, base_url: str = V1BETA.HOST.value) -> dict:
resp = self._call_api(method="DELETE", base_url=base_url, path=path)
def call_delete(self, server: str, path: str) -> dict:
resp = self._call_api(method="DELETE", server=server, path=path)
return resp
16 changes: 8 additions & 8 deletions tidbcloudy/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_cluster(self, config: Union[CreateClusterConfig, dict]) -> Cluster:
if isinstance(config, CreateClusterConfig):
config = config.to_object()
path = "projects/{}/clusters".format(self.id)
resp = self.context.call_post(path=path, json=config)
resp = self.context.call_post(server="v1beta", path=path, json=config)
return Cluster(context=self.context, id=resp["id"], project_id=self.id)

def update_cluster(self, cluster_id: str, config: Union[UpdateClusterConfig, dict]):
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_cluster(self, cluster_id: str) -> Cluster:
"""
path = "projects/{}/clusters/{}".format(self.id, cluster_id)
resp = self.context.call_get(path=path)
resp = self.context.call_get(server="v1beta", path=path)
return Cluster.from_object(self.context, resp)

def iter_clusters(self, page_size: int = 10) -> Iterator[Cluster]:
Expand Down Expand Up @@ -168,7 +168,7 @@ def list_clusters(self, page: int = None, page_size: int = None) -> Page[Cluster
query["page"] = page
if page_size is not None:
query["page_size"] = page_size
resp = self.context.call_get(path=path, params=query)
resp = self.context.call_get(server="v1beta", path=path, params=query)
return Page(
[Cluster.from_object(self.context, item) for item in resp["items"]],
page, page_size, resp["total"])
Expand All @@ -192,7 +192,7 @@ def create_restore(self, *, name: str, backup_id: str, cluster_config: Union[Cre
"backup_id": backup_id,
"config": cluster_config["config"]
}
resp = self.context.call_post(path=path, json=create_config)
resp = self.context.call_post(server="v1beta", path=path, json=create_config)
return Restore(context=self.context, id=resp["id"], cluster_id=resp["cluster_id"])

def get_restore(self, restore_id: str) -> Restore:
Expand All @@ -205,7 +205,7 @@ def get_restore(self, restore_id: str) -> Restore:
"""
path = "projects/{}/restores/{}".format(self.id, restore_id)
resp = self.context.call_get(path=path)
resp = self.context.call_get(server="v1beta", path=path)
return Restore.from_object(self.context, resp)

def list_restores(self, *, page: int = None, page_size: int = None) -> Page[Restore]:
Expand All @@ -224,7 +224,7 @@ def list_restores(self, *, page: int = None, page_size: int = None) -> Page[Rest
query["page"] = page
if page_size is not None:
query["page_size"] = page_size
resp = self.context.call_get(path=path, params=query)
resp = self.context.call_get(server="v1beta", path=path, params=query)
return Page(
[Restore.from_object(self.context, item) for item in resp["items"]],
page, page_size, resp["total"]
Expand Down Expand Up @@ -273,7 +273,7 @@ def create_aws_cmek(self, config: List[Tuple[str, str]]) -> None:
"kms_arn": kms_arn
})
path = f"projects/{self.id}/aws-cmek"
self.context.call_post(path=path, json=payload)
self.context.call_post(server="v1beta", path=path, json=payload)

def list_aws_cmek(self) -> Page[ProjectAWSCMEK]:
"""
Expand All @@ -292,7 +292,7 @@ def list_aws_cmek(self) -> Page[ProjectAWSCMEK]:
print(cmek)
"""
path = f"projects/{self.id}/aws-cmek"
resp = self.context.call_get(path=path)
resp = self.context.call_get(server="v1beta", path=path)
total = len(resp["items"])
return Page(
[ProjectAWSCMEK.from_object(self.context, item) for item in resp["items"]],
Expand Down
20 changes: 13 additions & 7 deletions tidbcloudy/tidbcloud.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from typing import Iterator, List

from tidbcloudy.baseURL import V1BETA1
from tidbcloudy.context import Context
from tidbcloudy.project import Project
from tidbcloudy.specification import BillingMonthSummary, CloudSpecification
from tidbcloudy.util.page import Page
from tidbcloudy.util.timestamp import get_current_year_month

SERVER_CONFIG_DEFAULT = {
"v1beta": "https://api.tidbcloud.com/api/v1beta/",
"billing": "https://billing.tidbapi.com/v1beta1/"
}


class TiDBCloud:
def __init__(self, public_key: str, private_key: str):
self._context = Context(public_key, private_key)
def __init__(self, public_key: str, private_key: str, server_config: dict = None):
if server_config is None:
server_config = SERVER_CONFIG_DEFAULT
self._context = Context(public_key, private_key, server_config)

def create_project(self, name: str, aws_cmek_enabled: bool = False, update_from_server: bool = False) -> Project:
"""
Expand All @@ -35,7 +41,7 @@ def create_project(self, name: str, aws_cmek_enabled: bool = False, update_from_
"name": name,
"aws_cmek_enabled": aws_cmek_enabled
}
resp = self._context.call_post(path="projects", json=config)
resp = self._context.call_post(server="v1beta", path="projects", json=config)
project_id = resp["id"]
if update_from_server:
return self.get_project(project_id=project_id, update_from_server=True)
Expand Down Expand Up @@ -91,7 +97,7 @@ def list_projects(self, page: int = None, page_size: int = None) -> Page[Project
query["page"] = page
if page_size is not None:
query["page_size"] = page_size
resp = self._context.call_get(path="projects", params=query)
resp = self._context.call_get(server="v1beta", path="projects", params=query)
return Page(
[Project.from_object(self._context, item) for item in resp["items"]],
page, page_size, resp["total"])
Expand Down Expand Up @@ -137,7 +143,7 @@ def list_provider_regions(self) -> List[CloudSpecification]:
print(spec) # This is a CloudSpecification object
"""
resp = self._context.call_get(path="clusters/provider/regions")
resp = self._context.call_get(server="v1beta", path="clusters/provider/regions")
return [CloudSpecification.from_object(obj=item) for item in resp["items"]]

def get_monthly_bill(self, month: str) -> BillingMonthSummary:
Expand All @@ -159,7 +165,7 @@ def get_monthly_bill(self, month: str) -> BillingMonthSummary:
if "-" not in month and len(month) == 6:
month = f"{month[:4]}-{month[4:]}"
path = f"bills/{month}"
resp = self._context.call_get(path=path, base_url=V1BETA1.BILLING.value)
resp = self._context.call_get(server="billing", path=path)
return BillingMonthSummary.from_object(self._context, resp)

def get_current_month_bill(self) -> BillingMonthSummary:
Expand Down

0 comments on commit a2bc4ee

Please sign in to comment.