From e9e357d066aa4889389d1b024d56297a9ce3e078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ma=C3=ABl=20Kerbiriou?= Date: Sat, 5 Oct 2024 11:43:09 +0200 Subject: [PATCH] support more Pillow image modes for encode/decode --- Cargo.toml | 3 +- pillow_jxl/JpegXLImagePlugin.py | 2 +- src/decode.rs | 115 ++++++++++++++++++++----------- src/encode.rs | 118 +++++++++++++++++++++++++------- 4 files changed, 172 insertions(+), 66 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b064d7..5c17e97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version="0.22.0", features = ["extension-module"] } -jpegxl-rs = { version="0.11.0" } +jpegxl-rs = { version = "0.11.0", default-features = false } +bytemuck = { version = "1.18.0" } [features] # Enables parallel processing support by enabling the "rayon" feature of jpeg-decoder. diff --git a/pillow_jxl/JpegXLImagePlugin.py b/pillow_jxl/JpegXLImagePlugin.py index 158738c..84db642 100644 --- a/pillow_jxl/JpegXLImagePlugin.py +++ b/pillow_jxl/JpegXLImagePlugin.py @@ -7,7 +7,7 @@ from pillow_jxl import Decoder, Encoder -_VALID_JXL_MODES = {"RGB", "RGBA", "L", "LA"} +_VALID_JXL_MODES = {"RGB", "RGBA", "L", "LA", "F", "I;16"} def _accept(data): diff --git a/src/decode.rs b/src/decode.rs index 56ca34c..3662f18 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,12 +1,15 @@ use std::borrow::Cow; +use std::u8; -use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use jpegxl_rs::decode::{Data, Metadata, Pixels}; use jpegxl_rs::parallel::threads_runner::ThreadsRunner; use jpegxl_rs::{decoder_builder, DecodeError}; +use bytemuck::*; + // it works even if the item is not documented: #[pyclass(module = "pillow_jxl")] @@ -24,49 +27,86 @@ struct ImageInfo { } impl ImageInfo { - fn from(item: Metadata) -> ImageInfo { + fn from(item: &Metadata, mode: String) -> ImageInfo { ImageInfo { - mode: Self::mode(item.num_color_channels, item.has_alpha_channel), + mode: mode, width: item.width, height: item.height, num_channels: item.num_color_channels, has_alpha_channel: item.has_alpha_channel, } } +} - fn mode(num_channels: u32, has_alpha_channel: bool) -> String { - match (num_channels, has_alpha_channel) { - (1, false) => "L".to_string(), - (1, true) => "LA".to_string(), - (3, false) => "RGB".to_string(), - (3, true) => "RGBA".to_string(), - _ => panic!("Unsupported number of channels"), +fn mode_8_bits(info: &Metadata, pixel_format: &'static str) -> PyResult<&'static str> { + let mode = match (info.num_color_channels, info.has_alpha_channel) { + (3, false) => "RGB", + (3, true) => "RGBA", + (1, false) => "L", + (1, true) => "LA", + (channels, has_alpha) => { + return Err(PyValueError::new_err(format!( + "Unsupported number of channels for {pixel_format}: {channels}, has_alpha: {has_alpha}" + ))) } - } + }; + Ok(mode) } -pub fn convert_pixels(pixels: Pixels) -> Vec { +pub fn convert_pixels(pixels: Pixels, info: &Metadata) -> PyResult<(Vec, &'static str)> { let mut result = Vec::new(); - match pixels { - Pixels::Uint8(pixels) => { - for pixel in pixels { - result.push(pixel); - } + let mode = match (pixels, info.num_color_channels, info.has_alpha_channel) { + (Pixels::Uint8(pixels), _, _) => { + // 8 bits RGB(A) and L(A) + result.extend_from_slice(&pixels); + mode_8_bits(info, "Uint8") } - Pixels::Uint16(pixels) => { - for pixel in pixels { - result.push((pixel >> 8) as u8); - result.push(pixel as u8); - } + (Pixels::Uint16(pixels), 1, false) => { + // 16 bits: I;16 + result.extend_from_slice( + try_cast_slice(&pixels).map_err(|e| PyValueError::new_err(e.to_string()))?, + ); + Ok("I;16") } - Pixels::Float(pixels) => { + (Pixels::Uint16(pixels), _, _) => { + // RGB(A) and LA must be converted to 8 bits + result.reserve(pixels.len()); + result.extend(pixels.into_iter().map(|pixel| (pixel >> 8) as u8)); + mode_8_bits(info, "Uint16") + } + (Pixels::Float(pixels), 1, false) => { + // 32 bits: F + result.extend_from_slice( + try_cast_slice(&pixels).map_err(|e| PyValueError::new_err(e.to_string()))?, + ); + Ok("F") + } + (Pixels::Float(pixels), _, _) => { + // RGB(A) and LA must be converted to 8 bits + result.reserve(pixels.len()); + result.extend(pixels.into_iter().map(|pixel| (pixel * 255.0) as u8)); + mode_8_bits(info, "Float") + } + (Pixels::Float16(pixels), 1, false) => { + // Convert to f32 (F) + result.reserve(pixels.len() * 4); for pixel in pixels { - result.push((pixel * 255.0) as u8); + result.extend_from_slice(&f32::from(pixel).to_ne_bytes()); } + Ok("F") } - Pixels::Float16(_) => panic!("Float16 is not supported yet"), - } - result + (Pixels::Float16(pixels), _, _) => { + // RGB(A) and LA must be converted to 8 bits + result.reserve(pixels.len()); + result.extend( + pixels + .into_iter() + .map(|pixel| (f32::from(pixel) * 255.0) as u8), + ); + mode_8_bits(info, "Float16") + } + }?; + Ok((result, mode)) } #[pyclass(module = "pillow_jxl")] @@ -112,21 +152,14 @@ impl Decoder { .parallel_runner(¶llel_runner) .build() .map_err(to_pyjxlerror)?; - let (info, img) = decoder.reconstruct(&data).map_err(to_pyjxlerror)?; - let (jpeg, img) = match img { - Data::Jpeg(x) => (true, x), - Data::Pixels(x) => (false, convert_pixels(x)), - }; - let icc_profile: Vec = match &info.icc_profile { - Some(x) => x.to_vec(), - None => Vec::new(), + let (metadata, img) = decoder.reconstruct(&data).map_err(to_pyjxlerror)?; + let (jpeg, (img, mode)) = match img { + Data::Jpeg(x) => (true, (x, "cf_jpeg")), + Data::Pixels(x) => (false, convert_pixels(x, &metadata)?), }; - Ok(( - jpeg, - ImageInfo::from(info), - Cow::Owned(img), - Cow::Owned(icc_profile), - )) + let info = ImageInfo::from(&metadata, mode.to_string()); + let icc_profile = metadata.icc_profile.unwrap_or_else(|| Vec::new()); + Ok((jpeg, info, Cow::Owned(img), Cow::Owned(icc_profile))) } } diff --git a/src/encode.rs b/src/encode.rs index 70fa771..c1af69c 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -3,14 +3,15 @@ use std::borrow::Cow; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; -use jpegxl_rs::encode::{ColorEncoding, EncoderFrame, EncoderResult, EncoderSpeed, Metadata}; +use bytemuck::cast_slice; + +use jpegxl_rs::encode::{ColorEncoding, EncoderFrame, EncoderSpeed, JxlEncoder, Metadata}; use jpegxl_rs::parallel::threads_runner::ThreadsRunner; use jpegxl_rs::{encoder_builder, EncodeError}; #[pyclass(module = "pillow_jxl")] pub struct Encoder { - num_channels: u32, - has_alpha: bool, + pixel_type: PixelType, lossless: bool, quality: f32, decoding_speed: i64, @@ -34,11 +35,25 @@ impl Encoder { use_original_profile: bool, num_threads: isize, ) -> PyResult { - let (num_channels, has_alpha) = match mode { - "RGBA" => (4, true), - "RGB" => (3, false), - "LA" => (2, true), - "L" => (1, false), + let pixel_type = match mode { + "RGBA" => PixelType::Uint8 { + num_channels: 4, + has_alpha: true, + }, + "RGB" => PixelType::Uint8 { + num_channels: 3, + has_alpha: false, + }, + "LA" => PixelType::Uint8 { + num_channels: 2, + has_alpha: true, + }, + "L" => PixelType::Uint8 { + num_channels: 1, + has_alpha: false, + }, + "F" => PixelType::Float32, + "I;16" => PixelType::Uint16, _ => { return Err(PyValueError::new_err( "Only RGB, RGBA, L, LA are supported.", @@ -61,8 +76,7 @@ impl Encoder { }; Ok(Self { - num_channels, - has_alpha, + pixel_type, lossless, quality, decoding_speed, @@ -91,7 +105,7 @@ impl Encoder { fn __repr__(&self) -> PyResult { Ok(format!( "Encoder(has_alpha={}, lossless={}, quality={}, decoding_speed={}, effort={}, num_threads={})", - self.has_alpha, self.lossless, self.quality, self.decoding_speed, self.effort, self.num_threads + self.pixel_type.has_alpha(), self.lossless, self.quality, self.decoding_speed, self.effort, self.num_threads )) } } @@ -119,18 +133,17 @@ impl Encoder { let mut encoder = encoder_builder() .parallel_runner(¶llel_runner) .jpeg_quality(self.quality) - .has_alpha(self.has_alpha) + .has_alpha(self.pixel_type.has_alpha()) .lossless(self.lossless) .use_container(self.use_container) .decoding_speed(self.decoding_speed) .build() .map_err(to_pyjxlerror)?; encoder.uses_original_profile = self.use_original_profile; - encoder.color_encoding = match self.num_channels { - 1 | 2 => ColorEncoding::SrgbLuma, - 3 | 4 => ColorEncoding::Srgb, - _ => return Err(PyValueError::new_err("Invalid num channels")), - }; + encoder.color_encoding = self + .pixel_type + .color_encoding() + .ok_or_else(|| PyValueError::new_err("Invalid pixel type"))?; encoder.speed = match self.effort { 1 => EncoderSpeed::Lightning, 2 => EncoderSpeed::Thunder, @@ -143,10 +156,9 @@ impl Encoder { 9 => EncoderSpeed::Tortoise, _ => return Err(PyValueError::new_err("Invalid effort")), }; - let buffer: EncoderResult = match jpeg_encode { - true => encoder.encode_jpeg(&data).map_err(to_pyjxlerror)?, + let buffer: Vec = match jpeg_encode { + true => encoder.encode_jpeg(&data).map_err(to_pyjxlerror)?.data, false => { - let frame = EncoderFrame::new(data).num_channels(self.num_channels); if let Some(exif_data) = exif { encoder .add_metadata(&Metadata::Exif(exif_data), true) @@ -162,15 +174,75 @@ impl Encoder { .add_metadata(&Metadata::Jumb(jumb_data), true) .map_err(to_pyjxlerror)? } - encoder - .encode_frame(&frame, width, height) + self.pixel_type + .encode_frame(&mut encoder, &data, width, height) .map_err(to_pyjxlerror)? } }; - Ok(Cow::Owned(buffer.data)) + Ok(Cow::Owned(buffer)) } } fn to_pyjxlerror(e: EncodeError) -> PyErr { PyRuntimeError::new_err(e.to_string()) } + +/// Represents the pixels type that can be found in PIL images +enum PixelType { + Uint8 { num_channels: u32, has_alpha: bool }, + Uint16, + Float32, +} + +impl PixelType { + fn has_alpha(&self) -> bool { + match self { + PixelType::Uint8 { has_alpha, .. } => *has_alpha, + _ => false, + } + } + fn color_encoding(&self) -> Option { + match self { + PixelType::Uint8 { + num_channels: 1 | 2, + .. + } => Some(ColorEncoding::SrgbLuma), + PixelType::Uint8 { + num_channels: 3 | 4, + .. + } => Some(ColorEncoding::Srgb), + PixelType::Uint8 { .. } => None, + PixelType::Uint16 => Some(ColorEncoding::SrgbLuma), + //FIXME: float pixels are meant to be linear, but who knows what pillow experimental modes are doing? + PixelType::Float32 => Some(ColorEncoding::LinearSrgbLuma), + } + } + fn encode_frame( + &self, + encoder: &mut JxlEncoder, + data: &[u8], + width: u32, + height: u32, + ) -> Result, EncodeError> { + match self { + PixelType::Uint8 { num_channels, .. } => { + let frame = EncoderFrame::new(data).num_channels(*num_channels); + encoder + .encode_frame::(&frame, width, height) + .map(|buf| buf.data) + } + PixelType::Uint16 => { + let frame = EncoderFrame::new(cast_slice(data)).num_channels(1); + encoder + .encode_frame::(&frame, width, height) + .map(|buf| buf.data) + } + PixelType::Float32 => { + let frame = EncoderFrame::new(cast_slice(data)).num_channels(1); + encoder + .encode_frame::(&frame, width, height) + .map(|buf| buf.data) + } + } + } +}