diff --git a/Cargo.toml b/Cargo.toml index e3f6176..413f955 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,11 +22,12 @@ intel-mkl = ["ndarray-linalg/intel-mkl"] transform = ["ndarray-linalg"] [dependencies] -ndarray = { version = "0.15", default-features = false } +ndarray = { version = "0.15", default-features = false, features = ["serde"] } ndarray-stats = { version = "0.5", default-features = false } ndarray-linalg = { version = "0.14", default-features = false, optional = true } noisy_float = { version = "0.2", default-features = false } num-traits = { version = "0.2", default-features = false } +serde = { version = "1", default-features = false, features = ["alloc"] } [dev-dependencies] ndarray = { version = "0.15", features = ["approx"] } @@ -40,3 +41,4 @@ ndarray-linalg = { version = "0.14", features = ["intel-mkl"] } # it is a dependency of intel-mkl-tool, we pin it to temporary solve # https://github.com/rust-math/intel-mkl-src/issues/68 anyhow = "<1.0.49" +serde_json = "1" diff --git a/src/core/colour_models.rs b/src/core/colour_models.rs index 41d34c5..91c5614 100644 --- a/src/core/colour_models.rs +++ b/src/core/colour_models.rs @@ -60,6 +60,8 @@ pub struct Generic5; /// ColourModel trait, this trait reports base parameters for different colour /// models pub trait ColourModel { + const NAME: &'static str; + /// Number of colour channels for a type. fn channels() -> usize { 3 @@ -816,48 +818,89 @@ where } } -impl ColourModel for RGB {} -impl ColourModel for HSV {} -impl ColourModel for HSI {} -impl ColourModel for HSL {} -impl ColourModel for YCrCb {} -impl ColourModel for CIEXYZ {} -impl ColourModel for CIELAB {} -impl ColourModel for CIELUV {} +impl ColourModel for RGB { + const NAME: &'static str = "RGB"; +} + +impl ColourModel for HSV { + const NAME: &'static str = "HSV"; +} + +impl ColourModel for HSI { + const NAME: &'static str = "HSI"; +} + +impl ColourModel for HSL { + const NAME: &'static str = "HSL"; +} + +impl ColourModel for YCrCb { + const NAME: &'static str = "YCrCb"; +} + +impl ColourModel for CIEXYZ { + const NAME: &'static str = "CIEXYZ"; +} + +impl ColourModel for CIELAB { + const NAME: &'static str = "CIELAB"; +} + +impl ColourModel for CIELUV { + const NAME: &'static str = "CIELUV"; +} impl ColourModel for Gray { + const NAME: &'static str = "Gray"; + fn channels() -> usize { 1 } } impl ColourModel for Generic1 { + const NAME: &'static str = "Generic1"; + fn channels() -> usize { 1 } } + impl ColourModel for Generic2 { + const NAME: &'static str = "Generic2"; + fn channels() -> usize { 2 } } + impl ColourModel for Generic3 { + const NAME: &'static str = "Generic3"; + fn channels() -> usize { 3 } } + impl ColourModel for Generic4 { + const NAME: &'static str = "Generic4"; + fn channels() -> usize { 4 } } + impl ColourModel for Generic5 { + const NAME: &'static str = "Generic5"; + fn channels() -> usize { 5 } } impl ColourModel for RGBA { + const NAME: &'static str = "RGBA"; + fn channels() -> usize { 4 } diff --git a/src/core_serde.rs b/src/core_serde.rs new file mode 100644 index 0000000..1309bd5 --- /dev/null +++ b/src/core_serde.rs @@ -0,0 +1,188 @@ +use crate::core::{ColourModel, ImageBase}; +use ndarray::{Data, DataOwned}; +use serde::de; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt; +use std::marker::PhantomData; + +impl Serialize for ImageBase +where + A: Serialize, + T: Data, + C: ColourModel, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("ImageBase", 2)?; + state.serialize_field("data", &self.data)?; + state.serialize_field("model", C::NAME)?; + state.end() + } +} + +struct ImageBaseVisitor { + _marker_a: PhantomData, + _marker_b: PhantomData, +} + +enum ImageBaseField { + Data, + Model, +} + +impl ImageBaseVisitor { + pub fn new() -> Self { + ImageBaseVisitor { + _marker_a: PhantomData, + _marker_b: PhantomData, + } + } +} + +static IMAGE_BASE_FIELDS: &[&str] = &["data", "model"]; + +impl<'de, A, T, C> Deserialize<'de> for ImageBase +where + A: Deserialize<'de>, + T: DataOwned, + C: ColourModel, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct("ImageBase", IMAGE_BASE_FIELDS, ImageBaseVisitor::new()) + } +} + +impl<'de> Deserialize<'de> for ImageBaseField { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ImageBaseFieldVisitor; + + impl<'de> de::Visitor<'de> for ImageBaseFieldVisitor { + type Value = ImageBaseField; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str(r#""data" or "model""#) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "data" => Ok(ImageBaseField::Data), + "model" => Ok(ImageBaseField::Model), + other => Err(de::Error::unknown_field(other, IMAGE_BASE_FIELDS)), + } + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: de::Error, + { + match value { + b"data" => Ok(ImageBaseField::Data), + b"model" => Ok(ImageBaseField::Model), + other => Err(de::Error::unknown_field( + &format!("{:?}", other), + IMAGE_BASE_FIELDS, + )), + } + } + } + + deserializer.deserialize_identifier(ImageBaseFieldVisitor) + } +} + +impl<'de, A, T, C> de::Visitor<'de> for ImageBaseVisitor +where + A: Deserialize<'de>, + T: DataOwned, + C: ColourModel, +{ + type Value = ImageBase; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("ndarray-vision Image representation") + } + + fn visit_seq(self, mut visitor: V) -> Result + where + V: de::SeqAccess<'de>, + { + let data = visitor + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let model = visitor + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + Ok(ImageBase { data, model }) + } + + fn visit_map(self, mut visitor: V) -> Result + where + V: de::MapAccess<'de>, + { + let mut data = None; + let mut model: Option<&str> = None; + while let Some(key) = visitor.next_key()? { + match key { + ImageBaseField::Data => { + if data.is_some() { + return Err(de::Error::duplicate_field("data")); + } + data = Some(visitor.next_value()?); + } + ImageBaseField::Model => { + if model.is_some() { + return Err(de::Error::duplicate_field("model")); + } + model = Some(visitor.next_value()?); + } + } + } + let data = data.ok_or_else(|| de::Error::missing_field("data"))?; + let model = model.ok_or_else(|| de::Error::missing_field("model"))?; + if model.to_lowercase() == C::NAME.to_lowercase() { + Ok(ImageBase { + data, + model: PhantomData, + }) + } else { + Err(de::Error::invalid_value( + de::Unexpected::Str(model), + &C::NAME, + )) + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use crate::core::{Image, RGB}; + + #[test] + fn serialize_image_base() { + const EXPECTED: &str = r#"{"data":{"v":1,"dim":[2,3,3],"data":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]},"model":"RGB"}"#; + let i = Image::::new(2, 3); + let actual = serde_json::to_string(&i).expect("Serialized image"); + assert_eq!(actual, EXPECTED); + } + + #[test] + fn deserialize_image_base() { + const EXPECTED: &str = r#"{"data":{"v":1,"dim":[2,3,3],"data":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]},"model":"RGB"}"#; + let actual: Image = serde_json::from_str(EXPECTED).expect("Deserialized image"); + assert_eq!(actual.model, PhantomData); + } +} diff --git a/src/lib.rs b/src/lib.rs index 538eaa1..fd20b3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,7 @@ /// The core of `ndarray-vision` contains the `Image` type and colour models pub mod core; +mod core_serde; /// Image enhancement intrinsics and algorithms #[cfg(feature = "enhancement")] pub mod enhancement;