Skip to content

Commit

Permalink
fix(utils): fix decimal flag bound (risingwavelabs#3808)
Browse files Browse the repository at this point in the history
  • Loading branch information
Li0k authored Jul 12, 2022
1 parent 6141051 commit 7a1331b
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 28 deletions.
96 changes: 71 additions & 25 deletions src/common/src/util/ordered/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,31 +244,77 @@ mod tests {

#[test]
fn test_ordered_row_deserializer() {
let order_types = vec![OrderType::Descending, OrderType::Ascending];
let serializer = OrderedRowSerializer::new(order_types.clone());
let schema = vec![DataType::Varchar, DataType::Int16];
let row1 = Row(vec![Some(Utf8("abc".to_string())), Some(Int16(5))]);
let row2 = Row(vec![Some(Utf8("abd".to_string())), Some(Int16(5))]);
let row3 = Row(vec![Some(Utf8("abc".to_string())), Some(Int16(6))]);
let rows = vec![row1.clone(), row2.clone(), row3.clone()];
let deserializer = OrderedRowDeserializer::new(schema, order_types.clone());
let mut array = vec![];
for row in &rows {
let mut row_bytes = vec![];
serializer.serialize(row, &mut row_bytes);
array.push(row_bytes);
pub use crate::types::decimal::Decimal;
use crate::types::ScalarImpl::{self, *};
{
// basic
let order_types = vec![OrderType::Descending, OrderType::Ascending];
let serializer = OrderedRowSerializer::new(order_types.clone());
let schema = vec![DataType::Varchar, DataType::Int16];
let row1 = Row(vec![Some(Utf8("abc".to_string())), Some(Int16(5))]);
let row2 = Row(vec![Some(Utf8("abd".to_string())), Some(Int16(5))]);
let row3 = Row(vec![Some(Utf8("abc".to_string())), Some(Int16(6))]);
let rows = vec![row1.clone(), row2.clone(), row3.clone()];
let deserializer = OrderedRowDeserializer::new(schema, order_types.clone());
let mut array = vec![];
for row in &rows {
let mut row_bytes = vec![];
serializer.serialize(row, &mut row_bytes);
array.push(row_bytes);
}
assert_eq!(
deserializer.deserialize(&array[0]).unwrap(),
OrderedRow::new(row1, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[1]).unwrap(),
OrderedRow::new(row2, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[2]).unwrap(),
OrderedRow::new(row3, &order_types)
);
}

{
// decimal

let order_types = vec![OrderType::Descending, OrderType::Ascending];
let serializer = OrderedRowSerializer::new(order_types.clone());
let schema = vec![DataType::Varchar, DataType::Decimal];

let row1 = Row(vec![
Some(Utf8("abc".to_string())),
Some(ScalarImpl::Decimal(Decimal::NaN)),
]);
let row2 = Row(vec![
Some(Utf8("abd".to_string())),
Some(ScalarImpl::Decimal(Decimal::PositiveINF)),
]);
let row3 = Row(vec![
Some(Utf8("abc".to_string())),
Some(ScalarImpl::Decimal(Decimal::NegativeINF)),
]);
let rows = vec![row1.clone(), row2.clone(), row3.clone()];
let deserializer = OrderedRowDeserializer::new(schema, order_types.clone());
let mut array = vec![];
for row in &rows {
let mut row_bytes = vec![];
serializer.serialize(row, &mut row_bytes);
array.push(row_bytes);
}
assert_eq!(
deserializer.deserialize(&array[0]).unwrap(),
OrderedRow::new(row1, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[1]).unwrap(),
OrderedRow::new(row2, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[2]).unwrap(),
OrderedRow::new(row3, &order_types)
);
}
assert_eq!(
deserializer.deserialize(&array[0]).unwrap(),
OrderedRow::new(row1, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[1]).unwrap(),
OrderedRow::new(row2, &order_types)
);
assert_eq!(
deserializer.deserialize(&array[2]).unwrap(),
OrderedRow::new(row3, &order_types)
);
}
}
9 changes: 6 additions & 3 deletions src/utils/memcomparable/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ use serde::de::{

use crate::error::{Error, Result};

const DECIMAL_FLAG_LOW_BOUND: u8 = 0x6;
const DECIMAL_FLAG_UP_BOUND: u8 = 0x23;

/// A structure that deserializes memcomparable bytes into Rust values.
pub struct Deserializer<B: Buf> {
input: MaybeFlip<B>,
Expand Down Expand Up @@ -125,7 +128,7 @@ impl<B: Buf> Deserializer<B> {

fn read_decimal(&mut self) -> Result<Vec<u8>> {
let flag = self.input.get_u8();
if !(0x8..=0x22).contains(&flag) {
if !(DECIMAL_FLAG_LOW_BOUND..=DECIMAL_FLAG_UP_BOUND).contains(&flag) {
return Err(Error::InvalidBytesEncoding(flag));
}
let mut byte_array = vec![flag];
Expand Down Expand Up @@ -517,7 +520,7 @@ impl<B: Buf> Deserializer<B> {
// whether the decimal is negative or not.
let mut neg: bool = false;
let exponent = match byte_array[0] {
0x06 => {
DECIMAL_FLAG_LOW_BOUND => {
// NaN
return Ok((0, 31));
}
Expand Down Expand Up @@ -547,7 +550,7 @@ impl<B: Buf> Deserializer<B> {
(byte_array[0] - 0x17) as i8
}
0x22 => byte_array[1] as i8,
0x23 => {
DECIMAL_FLAG_UP_BOUND => {
// Positive INF
return Ok((0, 30));
}
Expand Down

0 comments on commit 7a1331b

Please sign in to comment.