Skip to content

Commit

Permalink
Cleanlab Studio beta API
Browse files Browse the repository at this point in the history
  • Loading branch information
axl1313 committed Aug 1, 2024
1 parent 4ad5763 commit 8d17e0c
Show file tree
Hide file tree
Showing 11 changed files with 545 additions and 116 deletions.
8 changes: 8 additions & 0 deletions cleanlab_studio/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,11 @@ def __init__(self, filepath: Union[str, pathlib.Path] = "") -> None:
if isinstance(filepath, pathlib.Path):
filepath = str(filepath)
super().__init__(f"File could not be found at {filepath}. Please check the file path.")


class BetaJobError(HandledError):
    """Raised when a beta-API job cannot be created, queried, or completed."""


class DownloadResultsError(HandledError):
    """Raised when results for a completed job cannot be downloaded."""
118 changes: 44 additions & 74 deletions cleanlab_studio/internal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,52 +40,22 @@
pyspark_exists = False

from cleanlab_studio.errors import NotInstalledError
from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed
from cleanlab_studio.internal.api.api_helper import (
check_uuid_well_formed,
construct_headers,
handle_api_error,
)
from cleanlab_studio.internal.types import JSONDict, SchemaOverride
from cleanlab_studio.version import __version__

# Root URL of the Cleanlab Studio API; override via CLEANLAB_API_BASE_URL for
# local or staging deployments.
base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
# Per-service endpoint prefixes derived from the root URL.
cli_base_url = f"{base_url}/cli/v0"
upload_base_url = f"{base_url}/upload/v1"
dataset_base_url = f"{base_url}/datasets"
project_base_url = f"{base_url}/projects"
cleanset_base_url = f"{base_url}/cleansets"
model_base_url = f"{base_url}/v1/deployment"
tlm_base_url = f"{base_url}/v0/trustworthy_llm"


def _construct_headers(
    api_key: Optional[str], content_type: Optional[str] = "application/json"
) -> JSONDict:
    """Build the standard HTTP headers for a Cleanlab Studio API request.

    :param api_key: user's API key; when falsy, no Authorization header is set
    :param content_type: value for the Content-Type header; pass None to omit it
    :return: header dict always containing the Client-Type marker
    """
    headers: JSONDict = {}
    if api_key:
        headers["Authorization"] = f"bearer {api_key}"
    if content_type:
        headers["Content-Type"] = content_type
    headers["Client-Type"] = "python-api"
    return headers


def handle_api_error(res: requests.Response) -> None:
    """Decode *res* as JSON and raise if it describes an API error."""
    body = res.json()
    handle_api_error_from_json(body, res.status_code)


def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] = None) -> None:
    """Inspect a decoded API response body and raise on error payloads.

    :param res_json: JSON body returned by the API
    :param status_code: HTTP status code of the response, when available
    :raises InvalidProjectConfiguration: on a 422 with an
        UNSUPPORTED_PROJECT_CONFIGURATION error code
    :raises APIError: for any other error payload
    """
    # AuthError / UserQuotaError responses carry "code" and "description".
    if "code" in res_json and "description" in res_json:
        # soft quota limit is going away soon, so ignore that code
        if res_json["code"] != "user_soft_quota_exceeded":
            raise APIError(res_json["description"])

    error = res_json.get("error", None)
    if error is not None:
        is_unsupported_config = (
            status_code == 422
            and isinstance(error, dict)
            and error.get("code", None) == "UNSUPPORTED_PROJECT_CONFIGURATION"
        )
        if is_unsupported_config:
            raise InvalidProjectConfiguration(error["description"])
        raise APIError(res_json["error"])
# Root URL of the Cleanlab Studio API; override via CLEANLAB_API_BASE_URL for
# local or staging deployments.
API_BASE_URL = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
# Per-service endpoint prefixes derived from the root URL.
cli_base_url = f"{API_BASE_URL}/cli/v0"
upload_base_url = f"{API_BASE_URL}/upload/v1"
dataset_base_url = f"{API_BASE_URL}/datasets"
project_base_url = f"{API_BASE_URL}/projects"
cleanset_base_url = f"{API_BASE_URL}/cleansets"
model_base_url = f"{API_BASE_URL}/v1/deployment"
tlm_base_url = f"{API_BASE_URL}/v0/trustworthy_llm"


def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
Expand Down Expand Up @@ -134,7 +104,7 @@ def validate_api_key(api_key: str) -> bool:
res = requests.get(
cli_base_url + "/validate",
json=dict(api_key=api_key),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
valid: bool = res.json()["valid"]
Expand All @@ -154,7 +124,7 @@ def initialize_upload(
res = requests.post(
f"{upload_base_url}/file/initialize",
json=dict(size_in_bytes=str(file_size), filename=filename, file_type=file_type),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
upload_id: str = res.json()["upload_id"]
Expand All @@ -169,7 +139,7 @@ def complete_file_upload(api_key: str, upload_id: str, upload_parts: List[JSONDi
res = requests.post(
f"{upload_base_url}/file/complete",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

Expand All @@ -184,7 +154,7 @@ def confirm_upload(
res = requests.post(
f"{upload_base_url}/confirm",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

Expand All @@ -199,7 +169,7 @@ def update_schema(
res = requests.patch(
f"{upload_base_url}/schema",
json=request_json,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

Expand All @@ -209,7 +179,7 @@ def get_ingestion_status(api_key: str, upload_id: str) -> JSONDict:
res = requests.get(
f"{upload_base_url}/total_progress",
params=dict(upload_id=upload_id),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
Expand All @@ -221,7 +191,7 @@ def get_dataset_id(api_key: str, upload_id: str) -> JSONDict:
res = requests.get(
f"{upload_base_url}/dataset_id",
params=dict(upload_id=upload_id),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
Expand All @@ -232,7 +202,7 @@ def get_project_of_cleanset(api_key: str, cleanset_id: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/project",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
project_id: str = res.json()["project_id"]
Expand All @@ -243,7 +213,7 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cli_base_url + f"/projects/{project_id}/label_column",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
label_column: str = res.json()["label_column"]
Expand Down Expand Up @@ -274,7 +244,7 @@ def download_cleanlab_columns(
include_cleanlab_columns=include_cleanlab_columns,
include_project_details=include_project_details,
),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
id_col = get_id_column(api_key, cleanset_id)
Expand Down Expand Up @@ -306,7 +276,7 @@ def download_array(
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/{name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
Expand All @@ -323,7 +293,7 @@ def get_id_column(api_key: str, cleanset_id: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/id_column",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
id_column: str = res.json()["id_column"]
Expand All @@ -334,7 +304,7 @@ def get_dataset_of_project(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cli_base_url + f"/projects/{project_id}/dataset",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
dataset_id: str = res.json()["dataset_id"]
Expand All @@ -345,7 +315,7 @@ def get_dataset_schema(api_key: str, dataset_id: str) -> JSONDict:
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
cli_base_url + f"/datasets/{dataset_id}/schema",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
schema: JSONDict = res.json()["schema"]
Expand All @@ -357,7 +327,7 @@ def get_dataset_details(api_key: str, dataset_id: str, task_type: Optional[str])
res = requests.get(
project_base_url + f"/dataset_details/{dataset_id}",
params=dict(tasktype=task_type),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
dataset_details: JSONDict = res.json()
Expand All @@ -368,7 +338,7 @@ def check_column_diversity(api_key: str, dataset_id: str, column_name: str) -> J
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
dataset_base_url + f"/diversity/{dataset_id}/{column_name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
column_diversity: JSONDict = res.json()
Expand All @@ -379,7 +349,7 @@ def is_valid_multilabel_column(api_key: str, dataset_id: str, column_name: str)
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.get(
dataset_base_url + f"/check_valid_multilabel/{dataset_id}/{column_name}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
multilabel_column: JSONDict = res.json()
Expand Down Expand Up @@ -410,7 +380,7 @@ def clean_dataset(
)
res = requests.post(
project_base_url + f"/clean",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
json=request_json,
)
handle_api_error(res)
Expand All @@ -422,7 +392,7 @@ def get_latest_cleanset_id(api_key: str, project_id: str) -> str:
check_uuid_well_formed(project_id, "project ID")
res = requests.get(
cleanset_base_url + f"/project/{project_id}/latest_cleanset_id",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
cleanset_id = res.json()["cleanset_id"]
Expand All @@ -448,7 +418,7 @@ def get_dataset_id_for_name(
res = requests.get(
dataset_base_url + f"/dataset_id_for_name",
params=dict(dataset_name=dataset_name),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
return cast(Optional[str], res.json().get("dataset_id", None))
Expand All @@ -458,7 +428,7 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.get(
cleanset_base_url + f"/{cleanset_id}/status",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
status: JSONDict = res.json()
Expand All @@ -467,13 +437,13 @@ def get_cleanset_status(api_key: str, cleanset_id: str) -> JSONDict:

def delete_dataset(api_key: str, dataset_id: str) -> None:
check_uuid_well_formed(dataset_id, "dataset ID")
res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=_construct_headers(api_key))
res = requests.delete(dataset_base_url + f"/{dataset_id}", headers=construct_headers(api_key))
handle_api_error(res)


def delete_project(api_key: str, project_id: str) -> None:
check_uuid_well_formed(project_id, "project ID")
res = requests.delete(project_base_url + f"/{project_id}", headers=_construct_headers(api_key))
res = requests.delete(project_base_url + f"/{project_id}", headers=construct_headers(api_key))
handle_api_error(res)


Expand Down Expand Up @@ -528,7 +498,7 @@ def deploy_model(api_key: str, cleanset_id: str, model_name: str) -> str:
check_uuid_well_formed(cleanset_id, "cleanset ID")
res = requests.post(
model_base_url,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
json=dict(cleanset_id=cleanset_id, deployment_name=model_name),
)

Expand All @@ -542,7 +512,7 @@ def get_deployment_status(api_key: str, model_id: str) -> str:
check_uuid_well_formed(model_id, "model ID")
res = requests.get(
f"{model_base_url}/{model_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)
deployment: JSONDict = res.json()
Expand All @@ -555,7 +525,7 @@ def upload_predict_batch(api_key: str, model_id: str, batch: io.StringIO) -> str
url = f"{model_base_url}/{model_id}/upload"
res = requests.post(
url,
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

handle_api_error(res)
Expand All @@ -573,7 +543,7 @@ def start_prediction(api_key: str, model_id: str, query_id: str) -> None:
check_uuid_well_formed(query_id, "query ID")
res = requests.post(
f"{model_base_url}/{model_id}/predict/{query_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

handle_api_error(res)
Expand All @@ -584,7 +554,7 @@ def get_prediction_status(api_key: str, query_id: str) -> Dict[str, str]:
check_uuid_well_formed(query_id, "query ID")
res = requests.get(
f"{model_base_url}/predict/{query_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

Expand All @@ -596,7 +566,7 @@ def get_deployed_model_info(api_key: str, model_id: str) -> Dict[str, str]:
check_uuid_well_formed(model_id, "model ID")
res = requests.get(
f"{model_base_url}/{model_id}",
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)
handle_api_error(res)

Expand Down Expand Up @@ -672,7 +642,7 @@ async def tlm_prompt(
res = await client_session.post(
f"{tlm_base_url}/prompt",
json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

res_json = await res.json()
Expand Down Expand Up @@ -733,7 +703,7 @@ async def tlm_get_confidence_score(
quality=quality_preset,
options=options or {},
),
headers=_construct_headers(api_key),
headers=construct_headers(api_key),
)

res_json = await res.json()
Expand Down
Loading

0 comments on commit 8d17e0c

Please sign in to comment.