forked from Eventual-Inc/Daft
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] Map Getter (Eventual-Inc#2255)
Closes Eventual-Inc#2240 --------- Co-authored-by: Sammy Sidhu <samster25@users.noreply.github.com>
- Loading branch information
Showing
16 changed files
with
322 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use common_error::{DaftError, DaftResult}; | ||
|
||
use crate::{ | ||
array::ops::DaftCompare, | ||
datatypes::{logical::MapArray, DaftArrayType}, | ||
DataType, Series, | ||
}; | ||
|
||
fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult<Series> { | ||
let (keys, values) = { | ||
let struct_array = structs.struct_()?; | ||
(struct_array.get("key")?, struct_array.get("value")?) | ||
}; | ||
let mask = keys.equal(key_to_get)?; | ||
let filtered = values.filter(&mask)?; | ||
if filtered.is_empty() { | ||
Ok(Series::full_null("value", values.data_type(), 1)) | ||
} else if filtered.len() == 1 { | ||
Ok(filtered) | ||
} else { | ||
filtered.head(1) | ||
} | ||
} | ||
|
||
impl MapArray { | ||
pub fn map_get(&self, key_to_get: &Series) -> DaftResult<Series> { | ||
let value_type = if let DataType::Map(inner_dtype) = self.data_type() { | ||
match *inner_dtype.clone() { | ||
DataType::Struct(fields) if fields.len() == 2 => { | ||
fields[1].dtype.clone() | ||
} | ||
_ => { | ||
return Err(DaftError::TypeError(format!( | ||
"Expected inner type to be a struct type with two fields: key and value, got {:?}", | ||
inner_dtype | ||
))) | ||
} | ||
} | ||
} else { | ||
return Err(DaftError::TypeError(format!( | ||
"Expected input to be a map type, got {:?}", | ||
self.data_type() | ||
))); | ||
}; | ||
|
||
match key_to_get.len() { | ||
1 => { | ||
let mut result = Vec::with_capacity(self.len()); | ||
for series in self.physical.into_iter() { | ||
match series { | ||
Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?), | ||
_ => result.push(Series::full_null("value", &value_type, 1)), | ||
} | ||
} | ||
Series::concat(&result.iter().collect::<Vec<_>>()) | ||
} | ||
len if len == self.len() => { | ||
let mut result = Vec::with_capacity(len); | ||
for (i, series) in self.physical.into_iter().enumerate() { | ||
match (series, key_to_get.slice(i, i + 1)?) { | ||
(Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?), | ||
_ => result.push(Series::full_null("value", &value_type, 1)), | ||
} | ||
} | ||
Series::concat(&result.iter().collect::<Vec<_>>()) | ||
} | ||
_ => Err(DaftError::ValueError(format!( | ||
"Expected key to have length 1 or length equal to the map length, got {}", | ||
key_to_get.len() | ||
))), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ mod len; | |
mod list; | ||
mod list_agg; | ||
mod log; | ||
mod map; | ||
mod mean; | ||
mod merge_sketch; | ||
mod null; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
use crate::datatypes::DataType; | ||
use crate::series::Series; | ||
use common_error::DaftError; | ||
use common_error::DaftResult; | ||
|
||
impl Series { | ||
pub fn map_get(&self, key: &Series) -> DaftResult<Series> { | ||
match self.data_type() { | ||
DataType::Map(_) => self.map()?.map_get(key), | ||
dt => Err(DaftError::TypeError(format!( | ||
"map.get not implemented for {}", | ||
dt | ||
))), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use crate::ExprRef; | ||
use daft_core::{ | ||
datatypes::{DataType, Field}, | ||
schema::Schema, | ||
series::Series, | ||
}; | ||
|
||
use crate::functions::FunctionExpr; | ||
use common_error::{DaftError, DaftResult}; | ||
|
||
use super::super::FunctionEvaluator; | ||
|
||
pub(super) struct GetEvaluator {} | ||
|
||
impl FunctionEvaluator for GetEvaluator { | ||
fn fn_name(&self) -> &'static str { | ||
"map_get" | ||
} | ||
|
||
fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> { | ||
match inputs { | ||
[input, key] => match (input.to_field(schema), key.to_field(schema)) { | ||
(Ok(input_field), Ok(_)) => match input_field.dtype { | ||
DataType::Map(inner) => match inner.as_ref() { | ||
DataType::Struct(fields) if fields.len() == 2 => { | ||
let value_dtype = &fields[1].dtype; | ||
Ok(Field::new("value", value_dtype.clone())) | ||
} | ||
_ => Err(DaftError::TypeError(format!( | ||
"Expected input map to have struct values with 2 fields, got {}", | ||
inner | ||
))), | ||
}, | ||
_ => Err(DaftError::TypeError(format!( | ||
"Expected input to be a map, got {}", | ||
input_field.dtype | ||
))), | ||
}, | ||
(Err(e), _) | (_, Err(e)) => Err(e), | ||
}, | ||
_ => Err(DaftError::SchemaMismatch(format!( | ||
"Expected 2 input args, got {}", | ||
inputs.len() | ||
))), | ||
} | ||
} | ||
|
||
fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> { | ||
match inputs { | ||
[input, key] => input.map_get(key), | ||
_ => Err(DaftError::ValueError(format!( | ||
"Expected 2 input args, got {}", | ||
inputs.len() | ||
))), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
mod get; | ||
|
||
use get::GetEvaluator; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::{Expr, ExprRef}; | ||
|
||
use super::FunctionEvaluator; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] | ||
pub enum MapExpr { | ||
Get, | ||
} | ||
|
||
impl MapExpr { | ||
#[inline] | ||
pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { | ||
use MapExpr::*; | ||
match self { | ||
Get => &GetEvaluator {}, | ||
} | ||
} | ||
} | ||
|
||
pub fn get(input: ExprRef, key: ExprRef) -> ExprRef { | ||
Expr::Function { | ||
func: super::FunctionExpr::Map(MapExpr::Get), | ||
inputs: vec![input, key], | ||
} | ||
.into() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.