Merge pull request #106 from josephnowak/improve-xarray-from-func
Improving the exec on parallel method
josephnowak authored Jul 26, 2023
2 parents 1a76794 + 4ccfeef commit a656f9a
Showing 2 changed files with 13 additions and 35 deletions.
35 changes: 10 additions & 25 deletions tensordb/clients/base.py
@@ -105,9 +105,7 @@ def exec_on_parallel(
method: Callable,
paths_kwargs: Dict[str, Dict[str, Any]],
max_parallelization: int = None,
compute: bool = False,
client: Client = None,
call_pool: Literal['thread', 'process'] = 'thread',
compute_kwargs: Dict[str, Any] = None,
):
"""
@@ -127,42 +125,28 @@ def exec_on_parallel(
and avoid overflowing the server or any DB.
None means call all in parallel
compute: bool, default True
Indicate if the computation must be delayed or not, read the use on zarr_storage doc for more info
client: Client, default None
Dask client, useful for in combination with the compute parameter
call_pool: Literal['thread', 'process'], default 'thread'
Internally this method use a Pool of process/thread to apply the method on all the tensors
this indicates the kind of pool
compute_kwargs: Dict[str, Any], default None
Parameters of the dask.compute or client.compute
"""
paths = list(paths_kwargs.keys())

call_pool = ThreadPoolExecutor if call_pool == 'thread' else ProcessPoolExecutor
max_parallelization = np.inf if max_parallelization is None else max_parallelization
max_parallelization = min(max_parallelization, len(paths))
compute_kwargs = compute_kwargs or {}
client = dask if client is None else client

with call_pool(max_parallelization) as pool:
for sub_paths in mit.chunked(paths, max_parallelization):
logger.info(f"Processing the following tensors: {sub_paths}")
futures = [
pool.submit(
BaseTensorClient._exec_on_dask,
method,
{"path": path, "compute": compute, **paths_kwargs[path]},
)
for path in sub_paths
]
futures = [future.result() for future in futures]

if not compute:
client.compute(futures, **compute_kwargs)
for sub_paths in mit.chunked(paths, max_parallelization):
logger.info(f"Processing the following tensors: {sub_paths}")
client.compute([
dask.delayed(BaseTensorClient._exec_on_dask)(
func=method,
params={"path": path, **paths_kwargs[path]}
)
for path in sub_paths
], **compute_kwargs)

def exec_on_dag_order(
self,
@@ -238,6 +222,7 @@ def exec_on_dag_order(
if only_on_groups:
# filter the invalid groups
level = [tensor for tensor in level if tensor.dag.group in only_on_groups]

if not level:
continue

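The new body drops the hand-rolled thread/process pool and instead builds one dask.delayed call per path, computing each batch before starting the next. Below is a minimal, standalone sketch of that batching pattern under stated assumptions; it is not tensordb's code: process_path, the example paths, and their kwargs are placeholders, and only dask and more_itertools are assumed to be available.

import dask
import more_itertools as mit


def process_path(path: str, scale: int = 1) -> str:
    # Placeholder for the per-tensor method (e.g. a store or update call).
    return f"{path} processed with scale={scale}"


paths_kwargs = {"tensor_a": {"scale": 1}, "tensor_b": {"scale": 2}, "tensor_c": {}}
max_parallelization = 2

for sub_paths in mit.chunked(list(paths_kwargs), max_parallelization):
    # One delayed call per path; the batch is computed before the next one starts,
    # so at most `max_parallelization` tensors are in flight at any time.
    results, = dask.compute([
        dask.delayed(process_path)(path=path, **paths_kwargs[path])
        for path in sub_paths
    ])
    print(results)

Compared with the previous ThreadPoolExecutor/ProcessPoolExecutor approach, this leaves scheduling to dask and keeps only the batch size as a tuning knob.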
13 changes: 3 additions & 10 deletions tensordb/tests/test_tensor_client.py
@@ -280,15 +280,10 @@ def test_different_client(self, use_local):
assert {'a': 100} == tensor_client.read(path='different_client', name='different_client')

@pytest.mark.parametrize(
'max_parallelization, compute, call_pool',
[
(1, False, 'thread'),
(2, True, 'process'),
(4, False, 'thread'),
(None, True, 'thread'),
]
'max_parallelization',
[1, 2, 4]
)
def test_exec_on_dag_order(self, max_parallelization, compute, call_pool):
def test_exec_on_dag_order(self, max_parallelization):
definitions = [
TensorDefinition(
path='0',
@@ -341,7 +336,6 @@ def test_exec_on_dag_order(self, max_parallelization, compute, call_pool):
self.tensor_client.exec_on_dag_order(
method=self.tensor_client.store,
parallelization_kwargs={
'compute': compute,
'max_parallelization': max_parallelization
}
)
@@ -355,7 +349,6 @@ def test_exec_on_dag_order(self, max_parallelization, compute, call_pool):
tensors_path=['1'],
autofill_dependencies=True,
parallelization_kwargs={
'compute': compute,
'max_parallelization': max_parallelization
}
)
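With compute, client, and call_pool removed from the signature, callers now tune only the batch size (plus optional compute keyword arguments). A hedged usage sketch, assuming an already-configured tensor_client and two existing tensors named 'tensor_a' and 'tensor_b' (both hypothetical), with empty per-path kwargs:

tensor_client.exec_on_parallel(
    method=tensor_client.store,
    paths_kwargs={"tensor_a": {}, "tensor_b": {}},  # per-path kwargs forwarded to `store`
    max_parallelization=2,  # at most two tensors are processed per batch
)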
