From 800ae2a290aeead90ce7282ffa192ecfdcda112c Mon Sep 17 00:00:00 2001 From: FedericoGarza Date: Tue, 7 Nov 2023 20:28:30 +0000 Subject: [PATCH] fix: increase max_wait_time --- nbs/timegpt.ipynb | 43 ++++++++++++++++++++++++++++++++----------- nixtlats/timegpt.py | 12 +++++++----- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/nbs/timegpt.ipynb b/nbs/timegpt.ipynb index 3ff0ae19..66d0dbcb 100644 --- a/nbs/timegpt.ipynb +++ b/nbs/timegpt.ipynb @@ -159,7 +159,7 @@ " date_features_to_one_hot: Union[bool, List[str]] = True,\n", " max_retries: int = 6,\n", " retry_interval: int = 10,\n", - " max_wait_time: int = 60,\n", + " max_wait_time: int = 6 * 60,\n", " ):\n", " self.client = client\n", " self.h = h\n", @@ -561,7 +561,7 @@ " environment: Optional[str] = None,\n", " max_retries: int = 6,\n", " retry_interval: int = 10,\n", - " max_wait_time: int = 60,\n", + " max_wait_time: int = 6 * 60,\n", " ):\n", " \"\"\"\n", " Constructs all the necessary attributes for the TimeGPT object.\n", @@ -580,12 +580,14 @@ " The interval in seconds between consecutive retry attempts. \n", " This is the waiting period before the client tries to call the API again after a failed attempt. \n", " Default value is 10 seconds, meaning the client waits for 10 seconds between retries.\n", - " max_wait_time : int, (default=60)\n", + " max_wait_time : int, (default=360)\n", " The maximum total time in seconds that the client will spend on all retry attempts before giving up. \n", " This sets an upper limit on the cumulative waiting time for all retry attempts. \n", " If this time is exceeded, the client will stop retrying and raise an exception. \n", - " Default value is 60 seconds, meaning the client will cease retrying if the total time \n", - " spent on retries exceeds 60 seconds.\n", + " Default value is 360 seconds, meaning the client will cease retrying if the total time \n", + " spent on retries exceeds 360 seconds. \n", + " The client throws a ReadTimeout error after 60 seconds of inactivity. If you want to \n", + " catch these errors, use max_wait_time >> 60. \n", " \"\"\"\n", " if environment is None:\n", " environment = \"https://dashboard.nixtla.io/api\"\n", @@ -1236,12 +1238,22 @@ "outputs": [], "source": [ "#| hide\n", - "from time import time\n", + "from itertools import product\n", + "from time import time, sleep\n", "from unittest.mock import patch\n", "\n", "from requests.exceptions import HTTPError" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from httpx import ReadTimeout" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1250,17 +1262,26 @@ "source": [ "#| hide\n", "# test resilience of api calls\n", - "def mock_api_call(*args, **kwargs):\n", + "sleep_seconds = 5\n", + "\n", + "def raise_read_timeout_error(*args, **kwargs):\n", + " print(f'raising ReadTimeout error after {sleep_seconds} seconds')\n", + " sleep(sleep_seconds)\n", + " raise ReadTimeout\n", + " \n", + "def raise_http_error(*args, **kwargs):\n", + " print('raising HTTP error')\n", " raise HTTPError(response=dict(status_code=503))\n", " \n", "combs = [\n", - " (4, 5, 60),\n", + " (4, 5, 30),\n", " (10, 1, 5),\n", "]\n", - "for max_retries, retry_interval, max_wait_time in combs:\n", + "side_effects = [raise_read_timeout_error, raise_http_error]\n", + "for (max_retries, retry_interval, max_wait_time), side_effect in product(combs, side_effects):\n", " mock_timegpt = TimeGPT(token=os.environ['TIMEGPT_TOKEN'], max_retries=max_retries, retry_interval=retry_interval, max_wait_time=max_wait_time)\n", " init_time = time()\n", - " with patch('nixtlats.client.Nixtla.timegpt_multi_series', side_effect=mock_api_call):\n", + " with patch('nixtlats.client.Nixtla.timegpt_multi_series', side_effect=side_effect):\n", " test_fail(\n", " lambda: mock_timegpt.forecast(df=df, h=12, time_col='timestamp', target_col='value'),\n", " )\n", @@ -1268,7 +1289,7 @@ " approx_epected_time = min((max_retries - 1) * retry_interval, max_wait_time)\n", " upper_expected_time = min(max_retries * retry_interval, max_wait_time)\n", " assert total_mock_time >= approx_epected_time\n", - " assert total_mock_time - upper_expected_time <= 5 # preprocessing time before the first api call shoulb be less than 5 seconds" + " assert total_mock_time - upper_expected_time - (max_retries - 1) * sleep_seconds <= sleep_seconds # preprocessing time before the first api call shoulb be less than 60 seconds" ] }, { diff --git a/nixtlats/timegpt.py b/nixtlats/timegpt.py index 02e7d409..543d41da 100644 --- a/nixtlats/timegpt.py +++ b/nixtlats/timegpt.py @@ -95,7 +95,7 @@ def __init__( date_features_to_one_hot: Union[bool, List[str]] = True, max_retries: int = 6, retry_interval: int = 10, - max_wait_time: int = 60, + max_wait_time: int = 6 * 60, ): self.client = client self.h = h @@ -529,7 +529,7 @@ def __init__( environment: Optional[str] = None, max_retries: int = 6, retry_interval: int = 10, - max_wait_time: int = 60, + max_wait_time: int = 6 * 60, ): """ Constructs all the necessary attributes for the TimeGPT object. @@ -548,12 +548,14 @@ def __init__( The interval in seconds between consecutive retry attempts. This is the waiting period before the client tries to call the API again after a failed attempt. Default value is 10 seconds, meaning the client waits for 10 seconds between retries. - max_wait_time : int, (default=60) + max_wait_time : int, (default=360) The maximum total time in seconds that the client will spend on all retry attempts before giving up. This sets an upper limit on the cumulative waiting time for all retry attempts. If this time is exceeded, the client will stop retrying and raise an exception. - Default value is 60 seconds, meaning the client will cease retrying if the total time - spent on retries exceeds 60 seconds. + Default value is 360 seconds, meaning the client will cease retrying if the total time + spent on retries exceeds 360 seconds. + The client throws a ReadTimeout error after 60 seconds of inactivity. If you want to + catch these errors, use max_wait_time >> 60. """ if environment is None: environment = "https://dashboard.nixtla.io/api"