feat: support newer azure deployments #478

Merged
merged 1 commit on Sep 26, 2024
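
This PR removes the version gate that made newer client versions raise NotImplementedError for Azure endpoints and supports them directly instead: the constructor now records whether base_url points at an Azure AI deployment (_is_azure), a new _maybe_override_model helper maps any requested model to 'azureai' in forecast, detect_anomalies, and cross_validation, a missing feature_contributions key in Azure responses produces a warning rather than a hard failure, and the now-unused packaging dependency is dropped from setup.py.

A minimal usage sketch (the endpoint URL and API key below are placeholders, not part of this diff):

from nixtla import NixtlaClient

# Any base_url containing 'ai.azure' puts the client in Azure mode.
client = NixtlaClient(
    base_url='https://my-deployment.eastus2.inference.ai.azure.com',
    api_key='<azure-api-key>',
)
# In Azure mode the requested model is overridden to 'azureai', so this
# behaves as if model='azureai' had been passed:
# fcst_df = client.forecast(df=df, h=12, model='timegpt-1')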
80 changes: 51 additions & 29 deletions nbs/src/nixtla_client.ipynb
@@ -686,16 +686,8 @@
" max_retries=max_retries, retry_interval=retry_interval, max_wait_time=max_wait_time\n",
" )\n",
" self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}\n",
" if 'ai.azure' in base_url:\n",
" from packaging.version import Version\n",
"\n",
" import nixtla\n",
"\n",
" if Version(nixtla.__version__) > Version(\"0.5.2\"):\n",
" raise NotImplementedError(\n",
" \"This version doesn't support Azure endpoints, please install \"\n",
" \"an earlier version with: `pip install 'nixtla<=0.5.2'`\"\n",
" )\n",
" self._is_azure = 'ai.azure' in base_url\n",
" if self._is_azure:\n",
" self.supported_models = ['azureai']\n",
" else:\n",
" self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
@@ -707,13 +699,13 @@
" if np.issubdtype(v.dtype, np.floating):\n",
" v_cont = np.ascontiguousarray(v, dtype=np.float32)\n",
" d[k] = np.nan_to_num(v_cont, \n",
" nan=np.nan, \n",
" posinf=np.finfo(np.float32).max, \n",
" neginf=np.finfo(np.float32).min,\n",
" copy=False)\n",
" nan=np.nan, \n",
" posinf=np.finfo(np.float32).max, \n",
" neginf=np.finfo(np.float32).min,\n",
" copy=False,\n",
" )\n",
" else:\n",
" d[k] = np.ascontiguousarray(v)\n",
"\n",
" elif isinstance(v, dict):\n",
" ensure_contiguous_arrays(v) \n",
"\n",
@@ -802,6 +794,11 @@
" ]).T\n",
" return resp\n",
"\n",
" def _maybe_override_model(self, model: str) -> str:\n",
" if self._is_azure:\n",
" model = 'azureai'\n",
" return model\n",
"\n",
" def _get_model_params(self, model: str, freq: str) -> Tuple[int, int]:\n",
" key = (model, freq)\n",
" if key not in self._model_params:\n",
@@ -832,12 +829,26 @@
" )\n",
"\n",
" def _maybe_assign_feature_contributions(\n",
" self,\n",
" feature_contributions: Optional[List[List[float]]],\n",
" x_cols: List[str],\n",
" out_df: DataFrame,\n",
" insample_feat_contributions: Optional[List[List[float]]],\n",
" self,\n",
" expected_contributions: bool,\n",
" resp: Dict[str, Any],\n",
" x_cols: List[str],\n",
" out_df: DataFrame,\n",
" insample_feat_contributions: Optional[List[List[float]]],\n",
" ) -> None:\n",
" if not expected_contributions:\n",
" return\n",
" if 'feature_contributions' not in resp:\n",
" if self._is_azure:\n",
" warnings.warn(\n",
" \"feature_contributions aren't implemented in Azure yet.\"\n",
" )\n",
" return\n",
" else:\n",
" raise RuntimeError(\n",
" 'feature_contributions expected in response but not found'\n",
" )\n",
" feature_contributions = resp['feature_contributions']\n",
" if feature_contributions is None:\n",
" return \n",
" shap_cols = x_cols + [\"base_value\"]\n",
@@ -1025,6 +1036,7 @@
" )\n",
" self.__dict__.pop('weights_x', None)\n",
" self.__dict__.pop('feature_contributions', None)\n",
" model = self._maybe_override_model(model)\n",
" logger.info('Validating inputs...')\n",
" df, X_df, drop_id = self._run_validations(\n",
" df=df,\n",
@@ -1110,9 +1122,11 @@
" in_sample_payload = _forecast_payload_to_in_sample(payload)\n",
" logger.info('Calling Historical Forecast Endpoint...')\n",
" in_sample_resp = self._make_request_with_retries(\n",
" client, 'v2/historic_forecast', in_sample_payload,\n",
" client, 'v2/historic_forecast', in_sample_payload\n",
" )\n",
" insample_feat_contributions = in_sample_resp.get(\n",
" 'feature_contributions', None\n",
" )\n",
" insample_feat_contributions = in_sample_resp['feature_contributions']\n",
" else:\n",
" payloads = _partition_series(payload, num_partitions, h)\n",
" resp = self._make_partitioned_requests(client, 'v2/forecast', payloads)\n",
@@ -1122,9 +1136,11 @@
" ]\n",
" logger.info('Calling Historical Forecast Endpoint...')\n",
" in_sample_resp = self._make_partitioned_requests(\n",
" client, 'v2/historic_forecast', in_sample_payloads,\n",
" client, 'v2/historic_forecast', in_sample_payloads\n",
" )\n",
" insample_feat_contributions = in_sample_resp.get(\n",
" 'feature_contributions', None\n",
" )\n",
" insample_feat_contributions = in_sample_resp['feature_contributions']\n",
"\n",
" # assemble result\n",
" out = ufp.make_future_dataframe(\n",
@@ -1149,17 +1165,19 @@
" )\n",
" in_sample_df = ufp.drop_columns(in_sample_df, target_col)\n",
" out = ufp.vertical_concat([in_sample_df, out])\n",
" self._maybe_assign_feature_contributions(feature_contributions=resp['feature_contributions'], \n",
" x_cols=x_cols, \n",
" out_df=out[[id_col, time_col, 'TimeGPT']],\n",
" insample_feat_contributions=insample_feat_contributions)\n",
" self._maybe_assign_feature_contributions(\n",
" expected_contributions=feature_contributions,\n",
" resp=resp,\n",
" x_cols=x_cols,\n",
" out_df=out[[id_col, time_col, 'TimeGPT']],\n",
" insample_feat_contributions=insample_feat_contributions,\n",
" )\n",
" if add_history:\n",
" sort_idxs = ufp.maybe_compute_sort_indices(out, id_col=id_col, time_col=time_col)\n",
" if sort_idxs is not None:\n",
" out = ufp.take_rows(out, sort_idxs)\n",
" if hasattr(self, 'feature_contributions'):\n",
" self.feature_contributions = ufp.take_rows(self.feature_contributions, sort_idxs)\n",
"\n",
" out = _maybe_drop_id(df=out, id_col=id_col, drop=drop_id)\n",
" self._maybe_assign_weights(weights=resp['weights_x'], df=df, x_cols=x_cols)\n",
" return out\n",
@@ -1251,6 +1269,8 @@
" num_partitions=num_partitions,\n",
" )\n",
" self.__dict__.pop('weights_x', None)\n",
" model = self._maybe_override_model(model)\n",
" logger.info('Validating inputs...')\n",
" df, _, drop_id = self._run_validations(\n",
" df=df,\n",
" X_df=None,\n",
@@ -1433,6 +1453,8 @@
" model=model,\n",
" num_partitions=num_partitions,\n",
" )\n",
" model = self._maybe_override_model(model)\n",
" logger.info('Validating inputs...')\n",
" df, _, drop_id = self._run_validations(\n",
" df=df,\n",
" X_df=None,\n",
2 changes: 2 additions & 0 deletions nixtla/_modidx.py
@@ -52,6 +52,8 @@
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._maybe_assign_weights': ( 'src/nixtla_client.html#nixtlaclient._maybe_assign_weights',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._maybe_override_model': ( 'src/nixtla_client.html#nixtlaclient._maybe_override_model',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._run_validations': ( 'src/nixtla_client.html#nixtlaclient._run_validations',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient.cross_validation': ( 'src/nixtla_client.html#nixtlaclient.cross_validation',
61 changes: 35 additions & 26 deletions nixtla/nixtla_client.py
@@ -612,16 +612,8 @@ def __init__(
max_wait_time=max_wait_time,
)
self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}
if "ai.azure" in base_url:
from packaging.version import Version

import nixtla

if Version(nixtla.__version__) > Version("0.5.2"):
raise NotImplementedError(
"This version doesn't support Azure endpoints, please install "
"an earlier version with: `pip install 'nixtla<=0.5.2'`"
)
self._is_azure = "ai.azure" in base_url
if self._is_azure:
self.supported_models = ["azureai"]
else:
self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]
@@ -643,7 +635,6 @@ def ensure_contiguous_arrays(d: Dict[str, Any]) -> None:
)
else:
d[k] = np.ascontiguousarray(v)

elif isinstance(v, dict):
ensure_contiguous_arrays(v)

@@ -737,6 +728,11 @@ def _make_partitioned_requests(
).T
return resp

def _maybe_override_model(self, model: str) -> str:
if self._is_azure:
model = "azureai"
return model

Comment on lines +731 to +735
Member: nice
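
For reference, the helper's behavior in isolation, as a minimal sketch (it relies on the _is_azure flag set from base_url in __init__):

# Sketch: the override applies only when the client targets Azure.
client._is_azure = True
assert client._maybe_override_model('timegpt-1') == 'azureai'
client._is_azure = False
assert client._maybe_override_model('timegpt-1') == 'timegpt-1'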

def _get_model_params(self, model: str, freq: str) -> Tuple[int, int]:
key = (model, freq)
if key not in self._model_params:
@@ -766,11 +762,23 @@ def _maybe_assign_weights(

def _maybe_assign_feature_contributions(
self,
feature_contributions: Optional[List[List[float]]],
expected_contributions: bool,
resp: Dict[str, Any],
x_cols: List[str],
out_df: DataFrame,
insample_feat_contributions: Optional[List[List[float]]],
) -> None:
if not expected_contributions:
return
if "feature_contributions" not in resp:
if self._is_azure:
warnings.warn("feature_contributions aren't implemented in Azure yet.")
return
else:
raise RuntimeError(
"feature_contributions expected in response but not found"
)
feature_contributions = resp["feature_contributions"]
if feature_contributions is None:
return
shap_cols = x_cols + ["base_value"]
@@ -959,6 +967,7 @@ def forecast(
)
self.__dict__.pop("weights_x", None)
self.__dict__.pop("feature_contributions", None)
model = self._maybe_override_model(model)
logger.info("Validating inputs...")
df, X_df, drop_id = self._run_validations(
df=df,
@@ -1046,13 +1055,11 @@ def forecast(
in_sample_payload = _forecast_payload_to_in_sample(payload)
logger.info("Calling Historical Forecast Endpoint...")
in_sample_resp = self._make_request_with_retries(
client,
"v2/historic_forecast",
in_sample_payload,
client, "v2/historic_forecast", in_sample_payload
)
insample_feat_contributions = in_sample_resp.get(
"feature_contributions", None
)
insample_feat_contributions = in_sample_resp[
"feature_contributions"
]
else:
payloads = _partition_series(payload, num_partitions, h)
resp = self._make_partitioned_requests(client, "v2/forecast", payloads)
@@ -1062,13 +1069,11 @@
]
logger.info("Calling Historical Forecast Endpoint...")
in_sample_resp = self._make_partitioned_requests(
client,
"v2/historic_forecast",
in_sample_payloads,
client, "v2/historic_forecast", in_sample_payloads
)
insample_feat_contributions = in_sample_resp.get(
"feature_contributions", None
)
insample_feat_contributions = in_sample_resp[
"feature_contributions"
]

# assemble result
out = ufp.make_future_dataframe(
@@ -1094,7 +1099,8 @@
in_sample_df = ufp.drop_columns(in_sample_df, target_col)
out = ufp.vertical_concat([in_sample_df, out])
self._maybe_assign_feature_contributions(
feature_contributions=resp["feature_contributions"],
expected_contributions=feature_contributions,
resp=resp,
x_cols=x_cols,
out_df=out[[id_col, time_col, "TimeGPT"]],
insample_feat_contributions=insample_feat_contributions,
@@ -1109,7 +1115,6 @@
self.feature_contributions = ufp.take_rows(
self.feature_contributions, sort_idxs
)

out = _maybe_drop_id(df=out, id_col=id_col, drop=drop_id)
self._maybe_assign_weights(weights=resp["weights_x"], df=df, x_cols=x_cols)
return out
@@ -1201,6 +1206,8 @@ def detect_anomalies(
num_partitions=num_partitions,
)
self.__dict__.pop("weights_x", None)
model = self._maybe_override_model(model)
logger.info("Validating inputs...")
Member: cool

df, _, drop_id = self._run_validations(
df=df,
X_df=None,
@@ -1385,6 +1392,8 @@ def cross_validation(
model=model,
num_partitions=num_partitions,
)
model = self._maybe_override_model(model)
logger.info("Validating inputs...")
df, _, drop_id = self._run_validations(
df=df,
X_df=None,
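The Azure-aware handling of feature contributions, restated as a self-contained sketch (the function name is illustrative, not part of the package):

import warnings
from typing import Any, Dict, List, Optional

def extract_feature_contributions(
    resp: Dict[str, Any], is_azure: bool
) -> Optional[List[List[float]]]:
    # Azure responses may omit the key entirely: warn and skip instead of failing.
    if 'feature_contributions' not in resp:
        if is_azure:
            warnings.warn("feature_contributions aren't implemented in Azure yet.")
            return None
        raise RuntimeError('feature_contributions expected in response but not found')
    # The key can also be present but null; callers treat that as nothing to assign.
    return resp['feature_contributions']

The same reasoning motivates switching the in-sample lookup from in_sample_resp['feature_contributions'] to .get('feature_contributions', None): the historic-forecast response from an Azure deployment may not carry the key at all.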
1 change: 0 additions & 1 deletion setup.py
@@ -44,7 +44,6 @@
"fastcore",
"httpx",
"orjson",
"packaging",
"pandas",
"pydantic",
"tenacity",