Skip to content

Commit

Permalink
Made it possible to register buffer overflow callback and configure i…
Browse files Browse the repository at this point in the history
…nternal processing buffers.
  • Loading branch information
pjarosik committed Apr 10, 2022
1 parent faf049f commit bf3dd3e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
37 changes: 36 additions & 1 deletion api/python/arrus/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,22 @@ def __init__(self, callback_fn):
def run(self, element):
try:
self._callback_fn(element)
except:
except Exception as e:
print(e)
traceback.print_exc()


class OnBufferOverflowCallback(arrus.core.OnBufferOverflowCallbackWrapper):

def __init__(self, callback_fn):
super().__init__()
self._callback_fn = callback_fn

def run(self):
try:
self._callback_fn()
except Exception as e:
print(e)
traceback.print_exc()


Expand Down Expand Up @@ -66,6 +81,8 @@ def __init__(self, buffer_handle):
self._buffer_handle = buffer_handle
self._callbacks = []
self._register_internal_callback()
self._on_buffer_overflow_callbacks = []
self._register_internal_buffer_overflow_callback()
self.elements = self._wrap_elements()
self.n_elements = len(self.elements)

Expand All @@ -81,6 +98,14 @@ def append_on_new_data_callback(self, callback):
"""
self._callbacks.append(callback)

def append_on_buffer_overflow_callback(self, callback):
"""
Register callback that will be called when buffer overflow occurs.
:param callback: callback function to register
"""
self._on_buffer_overflow_callbacks.append(callback)

def _register_internal_callback(self):
self._callback_wrapper = OnNewDataCallback(self._callback)
arrus.core.registerOnNewDataCallbackFifoLockFreeBuffer(
Expand All @@ -92,6 +117,16 @@ def _callback(self, element):
for cbk in self._callbacks:
cbk(py_element)

def _on_buffer_overflow_callback(self):
for cbk in self._on_buffer_overflow_callbacks:
cbk()

def _register_internal_buffer_overflow_callback(self):
self._overflow_callback_wrapper = OnBufferOverflowCallback(
self._on_buffer_overflow_callback)
arrus.core.registerOnBufferOverflowCallback(
self._buffer_handle, self._overflow_callback_wrapper)

def _wrap_elements(self):
return [DataBufferElement(self._buffer_handle.getElement(i))
for i in range(self._buffer_handle.getNumberOfElements())]
Expand Down
37 changes: 31 additions & 6 deletions api/python/arrus/utils/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def get_bmode_imaging(sequence, grid, placement="/GPU:0",
EnvelopeDetection(),
Transpose(axes=(0, 2, 1)),
ScanConversion(x_grid, z_grid),
LogCompression()
Mean(axis=0),
LogCompression(),
),
placement=placement)
elif isinstance(sequence, arrus.ops.imaging.PwiSequence) \
Expand Down Expand Up @@ -192,16 +193,24 @@ def __init__(self, input_buffer, const_metadata, processing):
# Initialize pipeline.
self.cp = cp
self.input_buffer = self.__register_buffer(input_buffer)
self.gpu_buffer = Buffer(n_elements=2, shape=const_metadata.input_shape,
default_buffer = ProcessingBuffer(size=2, type="locked")

in_buffer_spec = processing.input_buffer
out_buffer_spec = processing.output_buffer
in_buffer_spec = in_buffer_spec if in_buffer_spec is not None else default_buffer
out_buffer_spec = out_buffer_spec if out_buffer_spec is not None else default_buffer

self.gpu_buffer = Buffer(n_elements=in_buffer_spec.size,
shape=const_metadata.input_shape,
dtype=const_metadata.dtype, math_pkg=cp,
type="locked")
type=in_buffer_spec.type)
self.pipeline = processing.pipeline
self.data_stream = cp.cuda.Stream(non_blocking=True)
self.processing_stream = cp.cuda.Stream(non_blocking=True)
self.out_metadata = processing.pipeline.prepare(const_metadata)
self.out_buffers = [Buffer(n_elements=2, shape=m.input_shape,
self.out_buffers = [Buffer(n_elements=out_buffer_spec.size, shape=m.input_shape,
dtype=m.dtype, math_pkg=np,
type="locked")
type=out_buffer_spec.type)
for m in self.out_metadata]
# Wait for all the initialization done in by the Pipeline.
cp.cuda.Stream.null.synchronize()
Expand All @@ -224,6 +233,9 @@ def __init__(self, input_buffer, const_metadata, processing):
self.metadata_extractor = ExtractMetadata()
self.metadata_extractor.prepare(const_metadata)
self.input_buffer.append_on_new_data_callback(self.process)
if processing.on_buffer_overflow_callback is not None:
self.input_buffer.append_on_buffer_overflow_callback(
processing.on_buffer_overflow_callback)

@property
def outputs(self):
Expand Down Expand Up @@ -508,15 +520,28 @@ def set_placement(self, device):
self.filter_pkg = pkgs['filter_pkg']


@dataclasses.dataclass(frozen=True)
class ProcessingBuffer:
size: int
type: str
# TODO: placement


class Processing:
"""
A description of complete data processing run in the arrus.utils.imaging.
"""

def __init__(self, pipeline, callback=None, extract_metadata=False):
def __init__(self, pipeline, callback=None, extract_metadata=False,
input_buffer: ProcessingBuffer=None,
output_buffer: ProcessingBuffer=None,
on_buffer_overflow_callback=None):
self.pipeline = pipeline
self.callback = callback
self.extract_metadata = extract_metadata
self.input_buffer = input_buffer
self.output_buffer = output_buffer
self.on_buffer_overflow_callback = on_buffer_overflow_callback


class Lambda(Operation):
Expand Down
23 changes: 23 additions & 0 deletions api/python/wrappers/core.i
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,29 @@ void registerOnNewDataCallbackFifoLockFreeBuffer(const std::shared_ptr<arrus::fr
};
fifolockfreeBuffer->registerOnNewDataCallback(actualCallback);
}

class OnBufferOverflowCallbackWrapper {
public:
OnBufferOverflowCallbackWrapper() {}
virtual void run() const {}
virtual ~OnBufferOverflowCallbackWrapper() {}
};

void registerOnBufferOverflowCallback(const std::shared_ptr<arrus::framework::Buffer> &buffer, OnBufferOverflowCallbackWrapper& callback) {
auto fifolockfreeBuffer = std::static_pointer_cast<DataBuffer>(buffer);
::arrus::framework::OnOverflowCallback actualCallback = [&]() {
PyGILState_STATE gstate = PyGILState_Ensure();
try {
callback.run();
} catch(const std::exception &e) {
std::cerr << "Exception: " << e.what() << std::endl;
} catch(...) {
std::cerr << "Unhandled exception" << std::endl;
}
PyGILState_Release(gstate);
};
fifolockfreeBuffer->registerOnOverflowCallback(actualCallback);
}
%};

// ------------------------------------------ SESSION
Expand Down

0 comments on commit bf3dd3e

Please sign in to comment.