Skip to content

Commit

Permalink
threading_partition_by_uid
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Apr 18, 2024
1 parent fb85fb2 commit 56fcadc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
50 changes: 41 additions & 9 deletions nbs/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
"import warnings\n",
"from typing import Dict, List, Optional, Union\n",
"\n",
"from functools import partial, wraps\n",
"from multiprocessing import cpu_count\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from threading import Lock\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tenacity import (\n",
Expand Down Expand Up @@ -805,25 +810,52 @@
" return func(self, **kwargs, num_partitions=1)\n",
" df = kwargs.pop('df')\n",
" X_df = kwargs.pop('X_df', None)\n",
" id_col = kwargs['id_col']\n",
" id_col = kwargs.pop('id_col')\n",
" uids = df['unique_id'].unique()\n",
" results_df = []\n",
" for uids_split in np.array_split(uids, num_partitions):\n",
" df_uids = df.query('unique_id in @uids_split')\n",
" df_lock = Lock()\n",
" X_df_lock = Lock()\n",
"\n",
" def partition_by_uid_single(uids_split, func, df, X_df, id_col, **kwargs):\n",
" with df_lock:\n",
" df_uids = df.query('unique_id in @uids_split')\n",
" if X_df is not None:\n",
" X_df_uids = X_df.query('unique_id in @uids_split')\n",
" with X_df_lock:\n",
" X_df_uids = X_df.query('unique_id in @uids_split')\n",
" else:\n",
" X_df_uids = None\n",
" df_uids = remove_unused_categories(df_uids, col=id_col)\n",
" X_df_uids = remove_unused_categories(X_df_uids, col=id_col)\n",
" kwargs_uids = {'df': df_uids, **kwargs}\n",
" if X_df_uids is not None:\n",
" kwargs_uids['X_df'] = X_df_uids\n",
" results_uids = func(self, **kwargs_uids, num_partitions=1)\n",
" results_df.append(results_uids)\n",
" results_df = pd.concat(results_df).reset_index(drop=True)\n",
"\n",
" kwargs_uids['id_col'] = id_col\n",
" results_uids = func(self, **kwargs_uids, num_partitions=1) \n",
"\n",
" return results_uids\n",
"\n",
" fpartition_by_uid_single = partial(partition_by_uid_single, \n",
" func=func, \n",
" df=df, \n",
" X_df=X_df, \n",
" id_col=id_col, \n",
" **kwargs)\n",
"\n",
" num_processes = min(num_partitions, cpu_count())\n",
" uids_splits = np.array_split(uids, num_processes)\n",
"\n",
" with ThreadPoolExecutor(max_workers=num_processes) as executor:\n",
" results_uids = [\n",
" executor.submit(fpartition_by_uid_single, uids_split)\n",
" for uids_split in uids_splits\n",
" ]\n",
" \n",
" results_df = pd.concat([result_df.result() for result_df in results_uids]).reset_index(drop=True)\n",
"\n",
" return results_df\n",
" return wrapper"
"\n",
" return wrapper\n",
" "
]
},
{
Expand Down
48 changes: 41 additions & 7 deletions nixtlats/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
import warnings
from typing import Dict, List, Optional, Union

from functools import partial, wraps
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

import numpy as np
import pandas as pd
from tenacity import (
Expand Down Expand Up @@ -751,23 +756,52 @@ def wrapper(self, num_partitions, **kwargs):
return func(self, **kwargs, num_partitions=1)
df = kwargs.pop("df")
X_df = kwargs.pop("X_df", None)
id_col = kwargs["id_col"]
id_col = kwargs.pop("id_col")
uids = df["unique_id"].unique()
results_df = []
for uids_split in np.array_split(uids, num_partitions):
df_uids = df.query("unique_id in @uids_split")
df_lock = Lock()
X_df_lock = Lock()

def partition_by_uid_single(uids_split, func, df, X_df, id_col, **kwargs):
    # Run `func` (the wrapped client method) on the subset of series whose
    # unique_id is in `uids_split`, with partitioning disabled
    # (num_partitions=1). Designed to be submitted to a ThreadPoolExecutor,
    # one call per id-split; closes over `self`, `df_lock`, and `X_df_lock`
    # from the enclosing wrapper.
    #
    # NOTE(review): df_lock/X_df_lock serialize the (read-only) .query calls
    # across worker threads, so the filtering step itself runs sequentially.
    # Presumably the locks guard against concurrent-read issues on the shared
    # DataFrames — confirm they are actually needed, since they cost
    # parallelism here.
    with df_lock:
        # Filter the target frame to this partition's ids.
        df_uids = df.query("unique_id in @uids_split")
    if X_df is not None:
        with X_df_lock:
            # Same filtering for the exogenous-features frame, when given.
            X_df_uids = X_df.query("unique_id in @uids_split")
    else:
        X_df_uids = None
    # Drop category levels not present in this partition so downstream
    # grouping/encoding only sees this split's ids.
    df_uids = remove_unused_categories(df_uids, col=id_col)
    # NOTE(review): called even when X_df_uids is None — assumes
    # remove_unused_categories tolerates None input; verify in its definition.
    X_df_uids = remove_unused_categories(X_df_uids, col=id_col)
    # Rebuild the keyword arguments for the wrapped call, substituting the
    # partitioned frames for the originals.
    kwargs_uids = {"df": df_uids, **kwargs}
    if X_df_uids is not None:
        kwargs_uids["X_df"] = X_df_uids

    # id_col was popped from kwargs by the wrapper; restore it explicitly.
    kwargs_uids["id_col"] = id_col
    # num_partitions=1 prevents the wrapped method from re-partitioning
    # (and thus recursing back into this machinery).
    results_uids = func(self, **kwargs_uids, num_partitions=1)

    return results_uids

fpartition_by_uid_single = partial(
partition_by_uid_single,
func=func,
df=df,
X_df=X_df,
id_col=id_col,
**kwargs
)

num_processes = min(num_partitions, cpu_count())
uids_splits = np.array_split(uids, num_processes)

with ThreadPoolExecutor(max_workers=num_processes) as executor:
results_uids = [
executor.submit(fpartition_by_uid_single, uids_split)
for uids_split in uids_splits
]

results_df = pd.concat(
[result_df.result() for result_df in results_uids]
).reset_index(drop=True)

return results_df

return wrapper
Expand Down

0 comments on commit 56fcadc

Please sign in to comment.