Skip to content

Commit

Permalink
ARRUS-29: Make batch size available in C++ API. (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
pjarosik authored Oct 27, 2021
1 parent 01bafb2 commit fc1f5f1
Show file tree
Hide file tree
Showing 28 changed files with 1,011 additions and 662 deletions.
14 changes: 12 additions & 2 deletions api/python/arrus/devices/us4r.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ def sampling_frequency(self):
"""
Device sampling frequency [Hz].
"""
# TODO use sampling frequency from the us4r device
return 65e6
return self._handle.getSamplingFrequency()

@property
def n_us4oems(self):
return self._handle.getNumberOfUs4OEMs()

def set_kernel_context(self, kernel_context):
self._current_sequence_context = kernel_context
Expand All @@ -107,6 +110,13 @@ def get_probe_model(self):
return arrus.utils.core.convert_to_py_probe_model(
core_model=self._handle.getProbe(0).getModel())

def set_test_pattern(self, pattern):
"""
Sets given test ADC test patter to be run by Us4OEM components.
"""
test_pattern_core = arrus.utils.core.convert_to_test_pattern(pattern)
self._handle.setTestPattern(test_pattern_core)

def _get_dto(self):
probe_model = arrus.utils.core.convert_to_py_probe_model(
core_model=self._handle.getProbe(0).getModel())
Expand Down
1 change: 1 addition & 0 deletions api/python/arrus/ops/us4r.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class TxRxSequence:
ops: typing.List[TxRx]
tgc_curve: np.ndarray
sri: float = None
n_repeats: int = 1

def __post_init__(self):
object.__setattr__(self, "tgc_curve", np.asarray(self.tgc_curve))
Expand Down
9 changes: 6 additions & 3 deletions api/python/arrus/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
kernel_context = self._create_kernel_context(seq, us_device_dto, medium)
raw_seq = arrus.kernels.get_kernel(type(seq))(kernel_context)

batch_size = raw_seq.n_repeats

actual_scheme = dataclasses.replace(scheme, tx_rx_sequence=raw_seq)
core_scheme = arrus.utils.core.convert_to_core_scheme(actual_scheme)
upload_result = self._session_handle.upload(core_scheme)
Expand Down Expand Up @@ -119,7 +121,7 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
"Currently only a sequence with constant number of samples "
"can be accepted.")
n_samples = next(iter(n_samples))
input_shape = self._get_physical_frame_shape(fcm, n_samples, rx_batch_size=1)
input_shape = self._get_physical_frame_shape(fcm, n_samples, rx_batch_size=batch_size)

buffer = arrus.framework.DataBuffer(buffer_handle)

Expand Down Expand Up @@ -147,6 +149,8 @@ def upload(self, scheme: arrus.ops.us4r.Scheme):
dtype=m.dtype, math_pkg=np,
type="locked")
for m in out_metadata]
# Wait for all the initialization done in by the Pipeline.
cp.cuda.Stream.null.synchronize()
user_out_buffer = queue.Queue(maxsize=1)

def buffer_callback(elements):
Expand All @@ -159,7 +163,6 @@ def buffer_callback(elements):
user_out_buffer.put_nowait(user_elements)
except queue.Full:
pass

except Exception as e:
print(f"Exception: {type(e)}")
except:
Expand Down Expand Up @@ -299,7 +302,7 @@ def _get_physical_frame_shape(self, fcm, n_samples, n_channels=32,
# TODO: We assume here, that each frame has the same number of samples!
# This might not be case in further improvements.
n_frames = np.max(fcm.frames) + 1
return n_frames * n_samples * rx_batch_size, n_channels
return n_frames*n_samples*rx_batch_size, n_channels



20 changes: 17 additions & 3 deletions api/python/arrus/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import arrus.exceptions
import arrus.devices.probe

_UINT16_MIN = 0
_UINT16_MAX = 2**16-1


def convert_to_core_sequence(seq):
"""
Expand Down Expand Up @@ -37,8 +40,6 @@ def convert_to_core_sequence(seq):
rx.downsampling_factor,
arrus.core.PairChannelIdx(int(rx.padding[0]), int(rx.padding[1]))
)


core_txrx = arrus.core.TxRx(core_tx, core_rx, op.pri)
arrus.core.TxRxVectorPushBack(core_seq, core_txrx)

Expand All @@ -51,7 +52,13 @@ def convert_to_core_sequence(seq):
"samples are supported only.")

sri = -1 if seq.sri is None else seq.sri
core_seq = arrus.core.TxRxSequence(core_seq, seq.tgc_curve.tolist(), sri)
if seq.n_repeats < _UINT16_MIN or seq.n_repeats > _UINT16_MAX:
raise arrus.exceptions.IllegalArgumentError(
f"Parameter n_repeats should be in range "
f"[{_UINT16_MIN}, {_UINT16_MAX}]"
)
core_seq = arrus.core.TxRxSequence(core_seq, seq.tgc_curve.tolist(), sri,
seq.n_repeats)
return core_seq


Expand Down Expand Up @@ -117,6 +124,13 @@ def convert_to_core_scheme(scheme):
workMode=core_work_mode)


def convert_to_test_pattern(test_pattern_str):
return {
"OFF": arrus.core.Us4OEM.RxTestPattern_OFF,
"RAMP": arrus.core.Us4OEM.RxTestPattern_RAMP
}[test_pattern_str]


def convert_from_tuple(core_tuple):
"""
Converts arrus core tuple to python tuple.
Expand Down
1 change: 1 addition & 0 deletions api/python/wrappers/core.i
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ using namespace arrus::devices;
%include "arrus/core/api/devices/DeviceId.h"
%include "arrus/core/api/devices/Device.h"
%include "arrus/core/api/devices/DeviceWithComponents.h"
%include "arrus/core/api/devices/us4r/Us4OEM.h"
%include "arrus/core/api/devices/us4r/Us4R.h"
%include "arrus/core/api/devices/probe/ProbeModelId.h"
%include "arrus/core/api/devices/probe/ProbeModel.h"
Expand Down
5 changes: 5 additions & 0 deletions arrus/common/logging/impl/Logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

BOOST_LOG_ATTRIBUTE_KEYWORD(severity, "Severity", arrus::LogSeverity)
BOOST_LOG_ATTRIBUTE_KEYWORD(deviceIdLogAttr, "DeviceId", std::string)
BOOST_LOG_ATTRIBUTE_KEYWORD(componentIdLogAttr, "ComponentId", std::string)

namespace arrus {

Expand Down Expand Up @@ -37,6 +38,10 @@ addTextSinkBoostPtr(const boost::shared_ptr<std::ostream> &ostream,
[
expr::stream << "[" << deviceIdLogAttr << "]"
]
<< expr::if_(expr::has_attr(componentIdLogAttr))
[
expr::stream << "[" << componentIdLogAttr << "]"
]
<< " " << severity << ": "
<< expr::smessage;
sink->set_formatter(formatter);
Expand Down
17 changes: 8 additions & 9 deletions arrus/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ set(SRC_FILES
api/arrus.h
api/session.h
api/common.h api/io.h
devices/us4r/validators/RxSettingsValidator.h)
devices/us4r/validators/RxSettingsValidator.h
devices/us4r/probeadapter/Us4OEMDataTransferRegistrar.h)

set_source_files_properties(${SRC_FILES} PROPERTIES COMPILE_FLAGS
"${ARRUS_CPP_STRICT_COMPILE_OPTIONS}")
Expand Down Expand Up @@ -239,8 +240,7 @@ target_compile_definitions(${TARGET_NAME} PRIVATE
if (ARRUS_RUN_TESTS)
find_package(GTest REQUIRED)

set(ARRUS_CORE_DEVICES_TESTS_SRCS
devices/DeviceId.cpp common/logging.cpp)
set(ARRUS_CORE_DEVICES_TESTS_SRCS devices/DeviceId.cpp common/logging.cpp)
# core::devices test
create_core_test(devices/DeviceIdTest.cpp devices/DeviceId.cpp)
create_core_test(devices/utilsTest.cpp)
Expand All @@ -250,26 +250,25 @@ if (ARRUS_RUN_TESTS)
set(US4OEM_FACTORY_IMPL_TEST_DEPS devices/DeviceId.cpp common/logging.cpp devices/us4r/us4oem/Us4OEMImpl.cpp
devices/TxRxParameters.cpp devices/us4r/FrameChannelMappingImpl.cpp)
create_core_test(devices/us4r/us4oem/Us4OEMFactoryImplTest.cpp "${US4OEM_FACTORY_IMPL_TEST_DEPS}")
# # create_core_test(devices/us4r/Us4RFactoryImplTest.cpp devic)
create_core_test(devices/us4r/Us4RSettingsConverterImplTest.cpp devices/DeviceId.cpp)
create_core_test(devices/us4r/external/ius4oem/IUs4OEMInitializerImplTest.cpp)
create_core_test(devices/us4r/commonTest.cpp "devices/us4r/common.cpp;devices/TxRxParameters.cpp")
#

set(US4OEM_IMPL_TEST_DEPS common/logging.cpp devices/us4r/us4oem/Us4OEMImpl.cpp
devices/us4r/common.cpp
devices/TxRxParameters.cpp devices/DeviceId.cpp devices/us4r/FrameChannelMappingImpl.cpp)
devices/us4r/common.cpp devices/TxRxParameters.cpp devices/DeviceId.cpp
devices/us4r/FrameChannelMappingImpl.cpp)
create_core_test(devices/us4r/us4oem/Us4OEMImplTest.cpp "${US4OEM_IMPL_TEST_DEPS}")

set(ADAPTER_IMPL_TEST_DEPS common/logging.cpp devices/us4r/probeadapter/ProbeAdapterImpl.cpp
devices/us4r/common.cpp
devices/TxRxParameters.cpp devices/DeviceId.cpp devices/us4r/FrameChannelMappingImpl.cpp)
create_core_test(devices/us4r/probeadapter/ProbeAdapterImplTest.cpp "${ADAPTER_IMPL_TEST_DEPS}")
create_core_test(devices/us4r/probeadapter/Us4OEMDataTransferRegistrarTest.cpp common/logging.cpp)
create_core_test(devices/probe/ProbeImplTest.cpp
"devices/probe/ProbeImpl.cpp;devices/us4r/FrameChannelMappingImpl.cpp;common/logging.cpp;devices/DeviceId.cpp")
# core::io tests
set(ARRUS_CORE_IO_TEST_DATA ${CMAKE_CURRENT_SOURCE_DIR}/io/test-data)
create_core_test(
io/settingsTest.cpp
create_core_test(io/settingsTest.cpp
"" # no additional source files
"protobuf::libprotobuf;arrus-core"
"-DARRUS_TEST_DATA_PATH=\"${ARRUS_CORE_IO_TEST_DATA}\"")
Expand Down
11 changes: 10 additions & 1 deletion arrus/core/api/devices/us4r/Us4OEM.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@ class Us4OEM : public Device, public TriggerGenerator {
using Handle = std::unique_ptr<Us4OEM>;
using RawHandle = PtrHandle<Us4OEM>;

/**
* Us4OEM ADC test pattern state.
*/
enum class RxTestPattern {
OFF,
/** Ramp (sawtooth data pattern). */
RAMP,
};

~Us4OEM() override = default;

virtual double getSamplingFrequency() = 0;
virtual float getSamplingFrequency() = 0;

virtual float getFPGATemperature() = 0;

Expand Down
15 changes: 15 additions & 0 deletions arrus/core/api/devices/us4r/Us4R.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,26 @@ class Us4R : public DeviceWithComponents {
*/
virtual void setRxSettings(const RxSettings &settings) = 0;

/**
* If active is true, turns off probe's RX data acquisition and turns on test patterns generation.
* Otherwise turns off test patterns generation and turns on probe's RX data acquisition.
*/
virtual void setTestPattern(Us4OEM::RxTestPattern pattern) = 0;

virtual void start() = 0;
virtual void stop() = 0;

/**
* Returns the number of us4OEM modules that are used in this us4R system.
*/
virtual uint8_t getNumberOfUs4OEMs() = 0;

/**
* Returns us4R device sampling frequency.
*/
virtual float getSamplingFrequency() const = 0;


Us4R(Us4R const&) = delete;
Us4R(Us4R const&&) = delete;
void operator=(Us4R const&) = delete;
Expand Down
10 changes: 4 additions & 6 deletions arrus/core/api/ops/us4r/Scheme.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace arrus::ops::us4r {
*/
class Scheme {
public:

/**
* How the scheme should be executed on the us4r-lite device.
*
Expand All @@ -24,7 +23,7 @@ class Scheme {
enum class WorkMode {
/** Trigger generated by us4r, error on overflow. */
ASYNC,
/** Trigger generated by host, no error on overflow. */
/** Trigger generated by host, no error on overflow. DEPRECATED: will be replaced in the future by MANUAL mode */
HOST,
/** New data acquisition and processing is manually triggered by user. */
MANUAL
Expand All @@ -40,11 +39,10 @@ class Scheme {
* @param workMode scheme work mode
*/

Scheme(TxRxSequence txRxSequence, uint16 rxBufferSize,
const framework::DataBufferSpec &outputBuffer,
Scheme(TxRxSequence txRxSequence, uint16 rxBufferSize, const framework::DataBufferSpec &outputBuffer,
WorkMode workMode)
: txRxSequence(std::move(txRxSequence)), rxBufferSize(rxBufferSize),
outputBuffer(outputBuffer), workMode(workMode) {}
: txRxSequence(std::move(txRxSequence)), rxBufferSize(rxBufferSize), outputBuffer(outputBuffer),
workMode(workMode) {}

const TxRxSequence &getTxRxSequence() const {
return txRxSequence;
Expand Down
20 changes: 11 additions & 9 deletions arrus/core/api/ops/us4r/TxRxSequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@ namespace arrus::ops::us4r {
*/
class TxRx {
public:
// TODO(pjarosik) remove default constructor!!! Currently required by py swig wrapper
TxRx()
:tx(std::vector<bool>{}, std::vector<float>{},
Pulse(0, 0, false)),
rx(std::vector<bool>{},
std::make_pair<unsigned int, unsigned int>((unsigned int)0, (unsigned int)0)),
pri(0.0f)
TxRx():tx(std::vector<bool>{}, std::vector<float>{}, Pulse(0, 0, false)),
rx(std::vector<bool>{}, std::make_pair<unsigned int, unsigned int>((unsigned int)0, (unsigned int)0)),
pri(0.0f)
{}

/**
Expand Down Expand Up @@ -62,9 +58,10 @@ class TxRxSequence {
* @param sequence a list of tx/rxs that compose a given sequence
* @param tgcCurve tgc curve to apply
* @param sri sequence repetition interval - the total time that a given sequence should take.
* @param nRepeats - the number of repetitions of a given sequence. Determines the size of the batch
*/
TxRxSequence(std::vector<TxRx> sequence, TGCCurve tgcCurve, float sri = NO_SRI)
: txrxs(std::move(sequence)), tgcCurve(std::move(tgcCurve)), sri(sri) {}
TxRxSequence(std::vector<TxRx> sequence, TGCCurve tgcCurve, float sri = NO_SRI, int16 nRepeats = 1)
: txrxs(std::move(sequence)), tgcCurve(std::move(tgcCurve)), sri(sri), nRepeats(nRepeats) {}

/**
* Returns vector of operations to perform.
Expand Down Expand Up @@ -93,10 +90,15 @@ class TxRxSequence {
}
}

int16 getNRepeats() const {
return nRepeats;
}

private:
std::vector<TxRx> txrxs;
TGCCurve tgcCurve;
std::optional<float> sri;
int16 nRepeats;
};

}
Expand Down
16 changes: 16 additions & 0 deletions arrus/core/common/collections.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
#include <type_traits>
#include <bitset>
#include <stdexcept>
#include <iterator>

#include <gsl/span>
#include <boost/range/combine.hpp>

#include "arrus/core/api/arrus.h"

namespace arrus {

template <typename T> inline T&& identity(T&& t) { return std::forward<T>(t); }

/**
* Returns an array of range [start, end).
*/
Expand Down Expand Up @@ -188,6 +193,17 @@ inline T reduce(InputIt first, InputIt last, T init, BinaryOp binaryOp) {
return result;
}

template <typename T, typename V>
inline V getUnique(const std::vector<T> &input, std::function<V(const T&)> accessor = ::arrus::identity) {
std::unordered_set<V> values;
std::transform(std::begin(input), std::end(input), std::inserter(values, std::end(values)), accessor);
if (values.size() > 1) {
throw IllegalArgumentException("Non unique input values.");
}
// This is the size of a single element produced by this us4oem.
return *std::begin(values);
}

}

#endif //ARRUS_CORE_COMMON_COLLECTIONS_H
6 changes: 5 additions & 1 deletion arrus/core/common/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ std::shared_ptr<LoggerFactory> getLoggerFactory();

Logger::SharedHandle getDefaultLogger();

// Deprecated, prefer using ARRUS_INIT_COMPONENT_LOGGER
#define INIT_ARRUS_DEVICE_LOGGER(logger, devId) \
logger->setAttribute("DeviceId", devId) \
logger->setAttribute("DeviceId", devId)

#define ARRUS_INIT_COMPONENT_LOGGER(logger, componentId) \
logger->setAttribute("ComponentId", componentId)

#define ARRUS_LOG(logger, severity, msg) \
(logger)->log(severity, msg)
Expand Down
Loading

0 comments on commit fc1f5f1

Please sign in to comment.