Skip to content

Commit

Permalink
Python SDK: Implement logging of images with dtype = U16 and F32
Browse files Browse the repository at this point in the history
  • Loading branch information
emilk committed Aug 5, 2022
1 parent 76946c5 commit 0b19dd2
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 21 deletions.
13 changes: 13 additions & 0 deletions crates/re_log_types/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,19 @@ pub enum TensorDataType {
F32,
}

pub trait TensorDataTypeTrait: Copy + Clone + Send + Sync {
const DTYPE: TensorDataType;
}
impl TensorDataTypeTrait for u8 {
const DTYPE: TensorDataType = TensorDataType::U8;
}
impl TensorDataTypeTrait for u16 {
const DTYPE: TensorDataType = TensorDataType::U16;
}
impl TensorDataTypeTrait for f32 {
const DTYPE: TensorDataType = TensorDataType::F32;
}

/// The data types supported by a [`Tensor`].
#[derive(Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
Expand Down
27 changes: 27 additions & 0 deletions crates/re_sdk_python/python/rerun_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,30 @@ def log_points(name, positions, colors):
positions.astype('float32')

log_points_rs(name, positions, colors)


def log_image(name, image):
# Catch some errors early:
if len(image.shape) < 2 or 3 < len(image.shape):
raise TypeError(f"Expected image, got array of shape {image.shape}")

if len(image.shape) == 3:
depth = image.shape[2]
if depth not in (1, 3, 4):
raise TypeError(
f"Expected image depth of of 1 (gray), 3 (RGB) or 4 (RGBA), got array of shape {image.shape}")

log_tensor(name, image)


def log_tensor(name, image):
if image.dtype == 'uint8':
log_tensor_u8(name, image)
elif image.dtype == 'uint16':
log_tensor_u16(name, image)
elif image.dtype == 'float32':
log_tensor_f32(name, image)
elif image.dtype == 'float64':
log_tensor_f32(name, image.astype('float32'))
else:
raise TypeError(f"Unsupported dtype: {image.dtype}")
47 changes: 30 additions & 17 deletions crates/re_sdk_python/src/python_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ fn rerun_sdk(_py: Python<'_>, m: &PyModule) -> PyResult<()> {

m.add_function(wrap_pyfunction!(log_point2d, m)?)?;
m.add_function(wrap_pyfunction!(log_points_rs, m)?)?;
m.add_function(wrap_pyfunction!(log_image, m)?)?;

m.add_function(wrap_pyfunction!(log_tensor_u8, m)?)?;
m.add_function(wrap_pyfunction!(log_tensor_u16, m)?)?;
m.add_function(wrap_pyfunction!(log_tensor_f32, m)?)?;
Ok(())
}

Expand Down Expand Up @@ -216,19 +219,26 @@ fn log_points_rs(

#[allow(clippy::needless_pass_by_value)]
#[pyfunction]
fn log_image(name: &str, img: numpy::PyReadonlyArrayDyn<'_, u8>) -> PyResult<()> {
match img.shape() {
// NOTE: opencv/numpy uses "height x width" convention
[_, _] | [_, _, 1 | 3 | 4] => {}
_ => {
return Err(PyTypeError::new_err(format!(
"Expected image of dimension of 2 or 3 with a depth of 1 (gray), 3 (RGB) or 4 (RGBA). Got image of shape {:?}", img.shape()
)));
}
};
fn log_tensor_u8(name: &str, img: numpy::PyReadonlyArrayDyn<'_, u8>) {
log_tensor(name, img);
}

#[allow(clippy::needless_pass_by_value)]
#[pyfunction]
fn log_tensor_u16(name: &str, img: numpy::PyReadonlyArrayDyn<'_, u16>) {
log_tensor(name, img);
}

// ----------------
#[allow(clippy::needless_pass_by_value)]
#[pyfunction]
fn log_tensor_f32(name: &str, img: numpy::PyReadonlyArrayDyn<'_, f32>) {
log_tensor(name, img);
}

fn log_tensor<T: TensorDataTypeTrait + numpy::Element + bytemuck::Pod>(
name: &str,
img: numpy::PyReadonlyArrayDyn<'_, T>,
) {
let mut sdk = Sdk::global();

let obj_path = ObjPath::from(name); // TODO(emilk): pass in proper obj path somehow
Expand All @@ -243,8 +253,6 @@ fn log_image(name: &str, img: numpy::PyReadonlyArrayDyn<'_, u8>) -> PyResult<()>
};
let log_msg = LogMsg::DataMsg(data_msg);
sdk.send(log_msg);

Ok(())
}

fn time_point() -> TimePoint {
Expand All @@ -256,10 +264,15 @@ fn time_point() -> TimePoint {
time_point
}

fn to_rerun_tensor(img: &numpy::PyReadonlyArrayDyn<'_, u8>) -> re_log_types::Tensor {
fn to_rerun_tensor<T: TensorDataTypeTrait + numpy::Element + bytemuck::Pod>(
img: &numpy::PyReadonlyArrayDyn<'_, T>,
) -> re_log_types::Tensor {
re_log_types::Tensor {
shape: img.shape().iter().map(|&d| d as u64).collect(),
dtype: TensorDataType::U8,
data: TensorData::Dense(img.to_owned_array().into_raw_vec()),
dtype: T::DTYPE,
// TODO(emilk): avoid double-allocating here
data: TensorData::Dense(bytemuck::allocation::pod_collect_to_vec(
&img.to_owned_array().into_raw_vec(),
)),
}
}
12 changes: 8 additions & 4 deletions crates/re_sdk_python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ def log(args):

print(rerun.info())

if True:
image = cv2.imread('crates/re_viewer/data/logo_dark_mode.png',
cv2.IMREAD_UNCHANGED)
rerun.log_image("logo", image)

if False:
img = cv2.imread('crates/re_viewer/data/logo_dark_mode.png',
cv2.IMREAD_UNCHANGED)
rerun.log_image("logo", img)
depth_img = cv2.imread('depth_image.pgm', cv2.IMREAD_UNCHANGED)
rerun.log_image("depth", depth_img)

if False:
for i in range(64):
Expand All @@ -28,7 +32,7 @@ def log(args):
y = r * math.sin(angle) + 16.0
rerun.log_point2d(f"point2d_{i}", x, y)

if True:
if False:
pos3 = []
for i in range(1000):
angle = 6.28 * i / 64
Expand Down

0 comments on commit 0b19dd2

Please sign in to comment.