Skip to content

Commit

Permalink
Merge pull request #80 from trustimaging/slowdown
Browse files Browse the repository at this point in the history
Major improvements to inversion performance
  • Loading branch information
ccuetom authored Jun 3, 2024
2 parents ed0db01 + 6354f23 commit 223c2d7
Show file tree
Hide file tree
Showing 36 changed files with 2,595 additions and 12,615 deletions.
2 changes: 1 addition & 1 deletion mosaic/comms/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,7 +1883,7 @@ async def process_msg(self, sender_id, msg):
sender_id, method, msg.reply,
**msg.kwargs)

if comms_method is not False:
if comms_method is not False and future is not None:
try:
await future
except asyncio.CancelledError:
Expand Down
1 change: 0 additions & 1 deletion mosaic/comms/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
4. We return the compressed result
"""

if isinstance(payload, pickle.PickleBuffer):
payload = memoryview(payload)

Expand Down
88 changes: 1 addition & 87 deletions mosaic/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def set_result(self, result):
if not isinstance(result, (tuple, dict)):
result = (result,)

min_size = 1024**2
min_size = 1024**1
if isinstance(result, tuple):
async def store(_value):
return await self.runtime.put(_value, reply=True)
Expand Down Expand Up @@ -272,91 +272,6 @@ def add_profile(self, profile, **kwargs):
kwargs['tessera_id'] = self.tessera_id
return super().add_profile(profile, **kwargs)

async def __prepare_args(self):
    """
    Prepare the arguments of the task for execution.

    Each positional and keyword argument falls into one of three cases:

    * plain values are recorded as ``'ready'`` and stored as they are;
    * awaitables that are not yet ``'done'`` are added to the pending set
      and get a done-callback that schedules ``_set_arg_done`` /
      ``_set_kwarg_done`` on ``self.loop``;
    * awaitables that are already ``'done'`` have their result awaited
      here, and it is stored unless the awaitable is a ``TaskDone``.

    Returns
    -------
    Future
        ``self._ready_future``, returned once ``self._check_ready()`` has
        been awaited.

    """

    # NOTE(review): the block nesting below was reconstructed from a listing
    # whose indentation had been stripped -- confirm in particular that
    # ``TaskDone`` arguments are meant to be tracked as pending too.

    def done_callback(setter_name, key, awaitable):
        # Bind ``key`` and ``awaitable`` now: the callback fires after the
        # loop variables below have moved on to other arguments.
        def _callback(fut):
            self.loop.run(getattr(self, setter_name), fut, key, awaitable)

        return _callback

    async def await_result(values_name, key, awaitable):
        result = await awaitable.result()
        # ``TaskDone`` results are awaited but never stored as a value.
        if isinstance(awaitable, TaskDone):
            return None, key, result
        return getattr(self, values_name), key, result

    awaitable_args = []

    for index, arg in enumerate(self.args):
        if type(arg) in types.awaitable_types:
            self._args_state[index] = arg.state

            if arg.state != 'done':
                if not isinstance(arg, TaskDone):
                    self._args_value[index] = None
                self._args_pending.add(arg)

                arg.add_done_callback(done_callback('_set_arg_done', index, arg))

            else:
                awaitable_args.append(
                    await_result('_args_value', index, arg)
                )

        else:
            self._args_state[index] = 'ready'
            self._args_value[index] = arg

    for key, value in self.kwargs.items():
        if type(value) in types.awaitable_types:
            self._kwargs_state[key] = value.state

            if value.state != 'done':
                if not isinstance(value, TaskDone):
                    self._kwargs_value[key] = None
                self._kwargs_pending.add(value)

                value.add_done_callback(done_callback('_set_kwarg_done', key, value))

            else:
                awaitable_args.append(
                    await_result('_kwargs_value', key, value)
                )

        else:
            self._kwargs_state[key] = 'ready'
            self._kwargs_value[key] = value

    # Store results in whichever order they complete.
    for task in asyncio.as_completed(awaitable_args):
        attr, key, result = await task
        if attr is not None:
            attr[key] = result

    await self._check_ready()

    return self._ready_future

async def prepare_args(self):
"""
Prepare the arguments of the task for execution.
Expand Down Expand Up @@ -973,7 +888,6 @@ def _deserialisation_helper(cls, state):
if instance.state == 'done':
instance.set_done()

# TODO Unsure about the need for this
# Synchronise the task state, in case something has happened between
# the moment it was pickled and the moment it was re-registered on
# this side
Expand Down
20 changes: 14 additions & 6 deletions mosaic/file_manipulation/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _write_dataset(name, obj, group):
dataset.attrs['is_str'] = isinstance(flat_obj[0], str)


def read(obj, lazy=True, filter=None):
def read(obj, lazy=True, filter=None, only=None):
if isinstance(obj, h5py.Group):
if filter is None:
filter = {}
Expand All @@ -137,6 +137,8 @@ def read(obj, lazy=True, filter=None):
if obj.attrs.get('is_array'):
data = []
for key in sorted(obj.keys()):
if only is not None and key not in only:
continue
try:
value = read(obj[key], lazy=lazy, filter=filter)
except FilterException:
Expand All @@ -145,6 +147,8 @@ def read(obj, lazy=True, filter=None):
else:
data = {}
for key in obj.keys():
if only is not None and key not in only:
continue
try:
value = read(obj[key], lazy=lazy, filter=filter)
except FilterException:
Expand Down Expand Up @@ -212,7 +216,7 @@ class on its own,
or as a context manager,
>>> with HDF5(...) as file:
>>> file.write(...)
>>> file.dump(...)
If a particular version is given, the filename will be generated without checks. If no version is given,
the ``path`` will be checked for the latest available version of the file.
Expand Down Expand Up @@ -258,11 +262,15 @@ def __init__(self, *args, **kwargs):

file_parameter = camel_case(parameter)
version = kwargs.pop('version', None)
version_start = kwargs.pop('version_start', 0)
extension = kwargs.pop('extension', '.h5')

if version is None or version < 0:
version = 0
filename = _abs_filename('%s-%s%s' % (project_name, file_parameter, extension), path)
version = version_start
if version > 0:
filename = _abs_filename('%s-%s-%05d%s' % (project_name, file_parameter, version, extension), path)
else:
filename = _abs_filename('%s-%s%s' % (project_name, file_parameter, extension), path)
while os.path.exists(filename):
version += 1
filename = _abs_filename('%s-%s-%05d%s' % (project_name, file_parameter, version, extension), path)
Expand Down Expand Up @@ -297,9 +305,9 @@ def file(self):
def close(self):
    """
    Close the underlying file object held in ``self._file``.

    """
    self._file.close()

def load(self, lazy=True, filter=None):
def load(self, lazy=True, filter=None, only=None):
group = self._file['/']
description = read(group, lazy=lazy, filter=filter)
description = read(group, lazy=lazy, filter=filter, only=only)
return Struct(description)

def dump(self, description):
Expand Down
6 changes: 3 additions & 3 deletions mosaic/runtime/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def start_node(*args, **extra_kwargs):

self._comms.start_heartbeat(node_proxy.uid)

self.logger.info('Listening at <NODE:0 | WORKER:0:0-0:%d>' % num_workers)
self.logger.info('Listening at <NODE:0 | WORKER:0-%d>' % num_workers)

async def init_cluster(self, **kwargs):
"""
Expand Down Expand Up @@ -233,8 +233,8 @@ async def wait_for(proxy):
self.logger.debug('Started heartbeat with node %s' % node_uid)

self.logger.info('Listening at <NODE:%d-%d | '
'WORKER:0:0-%d:%d address=%s>' % (0, num_nodes, num_nodes, num_workers,
', '.join(node_list)))
'WORKER:0-%d address=%s>' % (0, num_nodes, num_workers,
', '.join(node_list)))

def init_file(self, runtime_config):
runtime_id = self.uid
Expand Down
10 changes: 5 additions & 5 deletions mosaic/runtime/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ async def init_workers(self, **kwargs):
for worker_index in range(self._num_workers)}
allowed_cpus = sum([len(c) for c in available_cpus.values()])

# Eliminate cores corresponding to hyperthreading
for node_index, node_cpus in available_cpus.items():
node_cpus = [each for each in node_cpus if each < num_cpus]
available_cpus[node_index] = node_cpus

total_cpus = sum([len(c) for c in available_cpus.values()])
worker_cpus = {}
worker_nodes = {}
if total_cpus <= allowed_cpus:
# Eliminate cores corresponding to hyperthreading
for node_index, node_cpus in available_cpus.items():
node_cpus = [each for each in node_cpus if each < num_cpus]
available_cpus[node_index] = node_cpus

node_ids = list(available_cpus.keys())
num_nodes = len(available_cpus)
num_cpus_per_node = min([len(cpus) for cpus in available_cpus.values()])
Expand Down
17 changes: 10 additions & 7 deletions mosaic/runtime/runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

import os
import gc
import zmq
import zmq.asyncio
import psutil
Expand Down Expand Up @@ -216,9 +215,6 @@ async def init(self, **kwargs):
else:
self._local_warehouse = self._remote_warehouse

# Start maintenance loop
self._loop.interval(self.maintenance, interval=0.5)

# Connect to monitor
monitor_address = kwargs.get('monitor_address', None)
monitor_port = kwargs.get('monitor_port', None)
Expand All @@ -235,6 +231,13 @@ async def init(self, **kwargs):
if profile:
self.set_profiler()

# Start maintenance loop
if self.uid == 'head' or 'worker' in self.uid:
maintenance_interval = max(0.5, min(len(self._workers)*0.5, 60))
else:
maintenance_interval = 0.5
self._loop.interval(self.maintenance, interval=maintenance_interval)

async def init_warehouse(self, **kwargs):
"""
Init warehouse process.
Expand Down Expand Up @@ -1025,7 +1028,8 @@ async def put(self, obj, publish=False, reply=False):
else:
warehouse = self._local_warehouse
warehouse_obj = WarehouseObject(obj)
self._warehouse_cache[warehouse_obj.uid] = obj
if self.uid != 'head':
self._warehouse_cache[warehouse_obj.uid] = obj

await warehouse.put_remote(obj=obj, uid=warehouse_obj.uid,
publish=publish, reply=reply or publish)
Expand All @@ -1045,7 +1049,7 @@ async def get(self, uid, cache=True):
-------
"""
if not cache:
if self.uid == 'head' or not cache:
return await self._local_warehouse.get_remote(uid=uid, reply=True)

obj_uid = uid.uid if hasattr(uid, 'uid') else uid
Expand Down Expand Up @@ -1198,7 +1202,6 @@ async def maintenance(self):
self._maintenance_msgs = {}

await asyncio.gather(*tasks)
gc.collect()

def maintenance_queue(self, fun):
"""
Expand Down
19 changes: 9 additions & 10 deletions mosaic/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sizeof(obj, seen=None):
return 0
try:
if hasattr(obj, 'nbytes') and isinstance(obj.nbytes, int):
size = obj.nbytes
return obj.nbytes
else:
size = sys.getsizeof(obj)
except Exception:
Expand Down Expand Up @@ -84,18 +84,17 @@ async def remote_sizeof(obj, seen=None, pending=False):
if isinstance(obj, asyncio.Future):
return 0
if isinstance(obj, mosaic.types.awaitable_types):
size = await obj.size(pending=pending)
return await obj.size(pending=pending)
else:
try:
if hasattr(obj, 'nbytes') and isinstance(obj.nbytes, int):
return obj.nbytes if not pending else 0
else:
size = sys.getsizeof(obj)
except Exception:
size = sys.getsizeof(obj)
if pending:
size = 0
else:
try:
if hasattr(obj, 'nbytes') and isinstance(obj.nbytes, int):
size = obj.nbytes
else:
size = sys.getsizeof(obj)
except Exception:
size = sys.getsizeof(obj)
if seen is None:
seen = set()
obj_id = id(obj)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
'mscript=mosaic.cli.mscript:go',
'mprof=mosaic.cli.mprof:go',
'findomp=mosaic.cli.findomp:go',
'plot=stride.cli.plot:go',
]
},
zip_safe=False,
Expand Down
Loading

0 comments on commit 223c2d7

Please sign in to comment.