diff --git a/aryaxai/common/types.py b/aryaxai/common/types.py index 0db178c..e253e78 100644 --- a/aryaxai/common/types.py +++ b/aryaxai/common/types.py @@ -10,7 +10,7 @@ class ProjectConfig(TypedDict): feature_exclude: Optional[List[str]] drop_duplicate_uid: Optional[bool] handle_errors: Optional[bool] - + feature_encodings: Optional[dict] class DataConfig(TypedDict): tags: List[str] diff --git a/aryaxai/core/project.py b/aryaxai/core/project.py index c2c2e2f..1c262d3 100755 --- a/aryaxai/core/project.py +++ b/aryaxai/core/project.py @@ -534,7 +534,8 @@ def upload_data( "pred_label": "", "feature_exclude": [], "drop_duplicate_uid: "", - "handle_errors": False + "handle_errors": False, + "feature_encodings": Dict[str, str] # {"feature_name":"labelencode | countencode | onehotencode"} }, defaults to None :return: response @@ -625,6 +626,19 @@ def upload_file_and_return_path() -> str: feature for feature in column_names if feature not in feature_exclude ] + feature_encodings = config.get("feature_encodings", None) + if feature_encodings: + Validate.value_against_list( + "feature_encodings_feature", + list(feature_encodings.keys()), + column_names, + ) + Validate.value_against_list( + "feature_encodings_feature", + list(feature_encodings.values()), + ["labelencode", "countencode", "onehotencode"], + ) + payload = { "project_name": self.project_name, "project_type": config["project_type"], @@ -639,7 +653,7 @@ def upload_file_and_return_path() -> str: "handle_errors": config.get("handle_errors", False), "feature_exclude": feature_exclude, "feature_include": feature_include, - "feature_encodings": {}, + "feature_encodings": feature_encodings, "feature_actual_used": [], }, } @@ -2097,13 +2111,12 @@ def model_inference( event_id=run_model_res["event_id"], ) - download_tag_payload = { - "project_name": self.project_name, - "tag": f"{tag}_{model}_Inference", - } + auth_token = self.api_client.get_auth_token() + + uri = f"{DOWNLOAD_TAG_DATA_URI}?project_name={self.project_name}&tag={tag}_{model}_Inference&token={auth_token}" - tag_data = self.api_client.request( - "POST", DOWNLOAD_TAG_DATA_URI, download_tag_payload + tag_data = self.api_client.base_request( + "GET", uri ) tag_data_df = pd.read_csv(io.StringIO(tag_data.text))