diff --git a/Cargo.lock b/Cargo.lock index d0ed1ef2da3b..a2a5fa951360 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -825,12 +825,6 @@ dependencies = [ "once_cell", "paste", "ron", - "rustpython-ast", - "rustpython-bytecode", - "rustpython-compiler", - "rustpython-compiler-core", - "rustpython-parser", - "rustpython-vm", "serde", "snafu", "statrs", @@ -4361,6 +4355,7 @@ dependencies = [ "common-function", "common-query", "common-recordbatch", + "common-telemetry", "console", "datafusion", "datafusion-common", @@ -4369,6 +4364,7 @@ dependencies = [ "datatypes", "futures", "futures-util", + "paste", "query", "ron", "rustpython-ast", diff --git a/component/script/python/example/calc_rv.py b/component/script/python/example/calc_rv.py index 934503324132..7a1e7b40814a 100644 --- a/component/script/python/example/calc_rv.py +++ b/component/script/python/example/calc_rv.py @@ -44,13 +44,15 @@ def as_table(kline: list): "rv_60d", "rv_90d", "rv_180d" -], -sql="select open_time, close from k_line") +]) def calc_rvs(open_time, close): - from greptime import vector, log, prev, sqrt, datetime, pow, sum + from greptime import vector, log, prev, sqrt, datetime, pow, sum, last + import greptime as g def calc_rv(close, open_time, time, interval): mask = (open_time < time) & (open_time > time - interval) close = close[mask] + open_time = open_time[mask] + close = g.interval(open_time, close, datetime("10m"), lambda x:last(x)) avg_time_interval = (open_time[-1] - open_time[0])/(len(open_time)-1) ref = log(close/prev(close)) @@ -60,10 +62,10 @@ def calc_rv(close, open_time, time, interval): # how to get env var, # maybe through accessing scope and serde then send to remote? timepoint = open_time[-1] - rv_7d = calc_rv(close, open_time, timepoint, datetime("7d")) - rv_15d = calc_rv(close, open_time, timepoint, datetime("15d")) - rv_30d = calc_rv(close, open_time, timepoint, datetime("30d")) - rv_60d = calc_rv(close, open_time, timepoint, datetime("60d")) - rv_90d = calc_rv(close, open_time, timepoint, datetime("90d")) - rv_180d = calc_rv(close, open_time, timepoint, datetime("180d")) + rv_7d = vector([calc_rv(close, open_time, timepoint, datetime("7d"))]) + rv_15d = vector([calc_rv(close, open_time, timepoint, datetime("15d"))]) + rv_30d = vector([calc_rv(close, open_time, timepoint, datetime("30d"))]) + rv_60d = vector([calc_rv(close, open_time, timepoint, datetime("60d"))]) + rv_90d = vector([calc_rv(close, open_time, timepoint, datetime("90d"))]) + rv_180d = vector([calc_rv(close, open_time, timepoint, datetime("180d"))]) return rv_7d, rv_15d, rv_30d, rv_60d, rv_90d, rv_180d diff --git a/component/script/python/example/kline.json b/component/script/python/example/kline.json index 9928fceca4ef..c83e8f271a88 100644 --- a/component/script/python/example/kline.json +++ b/component/script/python/example/kline.json @@ -7,7 +7,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231300, + "open_time": 300, "open": "10107", "high": "10109.34", "low": "10106.71", @@ -16,7 +16,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231360, + "open_time": 900, "open": "10106.79", "high": "10109.27", "low": "10105.92", @@ -25,7 +25,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231420, + "open_time": 1200, "open": "10106.09", "high": "10108.75", "low": "10104.66", @@ -34,7 +34,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231480, + "open_time": 1800, "open": "10108.73", "high": "10109.52", "low": "10106.07", @@ -43,7 +43,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231540, + "open_time": 2400, "open": "10106.38", "high": "10109.48", "low": "10104.81", @@ -52,7 +52,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231600, + "open_time": 3000, "open": "10106.95", "high": "10109.48", "low": "10106.6", @@ -61,7 +61,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231660, + "open_time": 3600, "open": "10107.55", "high": "10109.28", "low": "10104.68", @@ -70,7 +70,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231720, + "open_time": 4200, "open": "10104.68", "high": "10109.18", "low": "10104.14", @@ -79,7 +79,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231780, + "open_time": 4800, "open": "10108.8", "high": "10117.36", "low": "10108.8", @@ -88,7 +88,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231840, + "open_time": 5400, "open": "10115.96", "high": "10119.19", "low": "10115.96", @@ -97,7 +97,7 @@ { "symbol": "BTCUSD", "period": "1", - "open_time": 1581231900, + "open_time": 6000, "open": "10117.08", "high": "10120.73", "low": "10116.96", diff --git a/component/script/python/greptime/__init__.py b/component/script/python/greptime/__init__.py index 8db592523486..febef7f323e1 100644 --- a/component/script/python/greptime/__init__.py +++ b/component/script/python/greptime/__init__.py @@ -1,4 +1,4 @@ from .greptime import coprocessor, copr -from .greptime import vector, log, prev, sqrt, pow, datetime, sum +from .greptime import vector, log, prev, next, first, last, sqrt, pow, datetime, sum, interval from .mock import mock_tester from .cfg import set_conn_addr, get_conn_addr diff --git a/component/script/python/greptime/greptime.py b/component/script/python/greptime/greptime.py index 8ac0a41c3b49..229a7a4db677 100644 --- a/component/script/python/greptime/greptime.py +++ b/component/script/python/greptime/greptime.py @@ -89,6 +89,11 @@ def datatype(self): def filter(self, lst_bool): return self[lst_bool] +def last(lst): + return lst[-1] + +def first(lst): + return lst[0] def prev(lst): ret = np.zeros(len(lst)) @@ -96,35 +101,22 @@ def prev(lst): ret[0] = nan return ret +def next(lst): + ret = np.zeros(len(lst)) + ret[:-1] = lst[1:] + ret[-1] = nan + return ret -def query(sql: str): - pass - - -def interval(arr: list, duration: int, fill, step: None | int = None, explicitOffset=False): +def interval(ts: vector, arr: vector, duration: int, func): """ Note that this is a mock function with same functionailty to the actual Python Coprocessor `arr` is a vector of integral or temporal type. - - `duration` is the length of sliding window - - `step` being the length when sliding window take a step - - `fill` indicate how to fill missing value: - - "prev": use previous - - "post": next - - "linear": linear interpolation, if not possible to interpolate certain types, fallback to prev - - "null": use null - - "none": do not interpolate """ - if step is None: - step = duration - - tot_len = int(np.ceil(len(arr) // step)) - slices = np.zeros((tot_len, int(duration))) - for idx, start in enumerate(range(0, len(arr), step)): - slices[idx] = arr[start:(start + duration)] - return slices + start = np.min(ts) + end = np.max(ts) + masks = [(ts >= i) & (ts <= (i+duration)) for i in range(start, end, duration)] + lst_res = [func(arr[mask]) for mask in masks] + return lst_res def factor(unit: str) -> int: diff --git a/component/script/python/greptime/mock.py b/component/script/python/greptime/mock.py index fed5a21a47ac..b37306827fde 100644 --- a/component/script/python/greptime/mock.py +++ b/component/script/python/greptime/mock.py @@ -4,7 +4,7 @@ """ from typing import Any import numpy as np -from .greptime import i32,i64,f32,f64, vector, interval, query, prev, datetime, log, sum, sqrt, pow, nan, copr, coprocessor +from .greptime import i32,i64,f32,f64, vector, interval, prev, datetime, log, sum, sqrt, pow, nan, copr, coprocessor import inspect import functools diff --git a/component/script/python/test.py b/component/script/python/test.py index e1e32079536a..0aeb42eb1f4d 100644 --- a/component/script/python/test.py +++ b/component/script/python/test.py @@ -26,6 +26,16 @@ def get_db(req:str): return requests.get("http://{}{}".format(get_conn_addr(), req)) if __name__ == "__main__": + with open("component/script/python/example/kline.json", "r") as kline_file: + kline = json.load(kline_file) + table = as_table(kline["result"]) + close = table["close"] + open_time = table["open_time"] + env = {"close":close, "open_time": open_time} + + res = mock_tester(calc_rvs, env=env) + print("Mock result:", [i[0] for i in res]) + exit() if len(sys.argv)!=2: raise Exception("Expect only one address as cmd's args") set_conn_addr(sys.argv[1]) @@ -42,11 +52,6 @@ def get_db(req:str): open_time = table["open_time"] init_table(close, open_time) - # print(repr(close), repr(open_time)) - # print("calc_rv:", calc_rv(close, open_time, open_time[-1]+datetime("10m"), datetime("7d"))) - env = {"close":close, "open_time": open_time} - # print("env:", env) - print("Mock result:", mock_tester(calc_rvs, env=env)) real = calc_rvs() print(real) try: diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index e462b6be173d..d1dde280dc93 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -20,12 +20,6 @@ num = "0.4" num-traits = "0.2" once_cell = "1.10" paste = "1.0" -rustpython-ast = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} -rustpython-bytecode = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} -rustpython-compiler = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} -rustpython-compiler-core = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} -rustpython-parser = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} -rustpython-vm = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} snafu = { version = "0.7", features = ["backtraces"] } statrs = "0.15" diff --git a/src/script/Cargo.toml b/src/script/Cargo.toml index a88007dd2f4f..139046bf1aad 100644 --- a/src/script/Cargo.toml +++ b/src/script/Cargo.toml @@ -15,6 +15,7 @@ python = [ "dep:rustpython-compiler-core", "dep:rustpython-bytecode", "dep:rustpython-ast", + "dep:paste" ] [dependencies] @@ -23,6 +24,7 @@ common-error = {path = "../common/error"} common-function = { path = "../common/function" } common-query = {path = "../common/query"} common-recordbatch = {path = "../common/recordbatch" } +common-telemetry = { path = "../common/telemetry" } console = "0.15" datafusion = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", optional = true} datafusion-common = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2"} @@ -38,8 +40,11 @@ rustpython-compiler = {git = "https://github.com/RustPython/RustPython", optiona rustpython-compiler-core = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} rustpython-parser = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} rustpython-vm = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"} +paste = { version = "1.0", optional = true} snafu = {version = "0.7", features = ["backtraces"]} sql = { path = "../sql" } +tokio = { version = "1.0", features = ["full"] } + [dev-dependencies] catalog = { path = "../catalog" } diff --git a/src/script/src/lib.rs b/src/script/src/lib.rs index b11a118d8cd5..1bb347e91498 100644 --- a/src/script/src/lib.rs +++ b/src/script/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(iterator_try_reduce)] pub mod engine; #[cfg(feature = "python")] pub mod python; diff --git a/src/script/src/python/builtins/mod.rs b/src/script/src/python/builtins/mod.rs index f2f48a374f36..75e05c3bdd6c 100644 --- a/src/script/src/python/builtins/mod.rs +++ b/src/script/src/python/builtins/mod.rs @@ -1,4 +1,5 @@ //! Builtin module contains GreptimeDB builtin udf/udaf + #[cfg(test)] #[allow(clippy::print_stdout)] mod test; @@ -271,6 +272,10 @@ pub(crate) mod greptime_builtin { use common_function::scalars::math::PowFunction; use common_function::scalars::{function::FunctionContext, Function}; + use datafusion::arrow::compute::comparison::{gt_eq_scalar, lt_eq_scalar}; + use datafusion::arrow::datatypes::DataType; + use datafusion::arrow::error::ArrowError; + use datafusion::arrow::scalar::{PrimitiveScalar, Scalar}; use datafusion::physical_plan::expressions; use datafusion_expr::ColumnarValue as DFColValue; use datafusion_physical_expr::math_expressions; @@ -278,16 +283,17 @@ pub(crate) mod greptime_builtin { use datatypes::arrow::array::{ArrayRef, NullArray}; use datatypes::arrow::compute; use datatypes::vectors::{ConstantVector, Float64Vector, Helper, Int64Vector}; - use rustpython_vm::builtins::{PyFloat, PyInt, PyStr}; - use rustpython_vm::function::OptionalArg; - use rustpython_vm::{AsObject, PyObjectRef, PyResult, VirtualMachine}; + use paste::paste; + use rustpython_vm::builtins::{PyFloat, PyFunction, PyInt, PyStr}; + use rustpython_vm::function::{FuncArgs, KwArgs, OptionalArg}; + use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; use crate::python::builtins::{ all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj, type_cast_error, }; - use crate::python::utils::is_instance; use crate::python::utils::PyVectorRef; + use crate::python::utils::{is_instance, py_vec_obj_to_array}; use crate::python::PyVector; #[pyfunction] @@ -655,23 +661,48 @@ pub(crate) mod greptime_builtin { let args = vec![base.as_vector_ref(), arg_pow]; let res = PowFunction::default() .eval(FunctionContext::default(), &args) - .unwrap(); + .map_err(|err| { + vm.new_runtime_error(format!( + "Fail to eval pow() withi given args: {args:?}, Error: {err}" + )) + })?; Ok(res.into()) } - // TODO: prev, sum, pow, sqrt, datetime, slice, and filter(through boolean array) + fn gen_none_array(data_type: DataType, len: usize, vm: &VirtualMachine) -> PyResult { + macro_rules! match_none_array { + ($VAR:ident, $LEN: ident, [$($TY:ident),*]) => { + paste!{ + match $VAR{ + $(DataType::$TY => Arc::new(arrow::array::[<$TY Array>]::from(vec![None;$LEN])), )* + _ => return Err(vm.new_type_error(format!("gen_none_array() does not support {:?}", data_type))) + } + } + }; + } + let ret: ArrayRef = match_none_array!( + data_type, + len, + [Boolean, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64] // We don't support float16 right now, it's not common in usage. + ); + Ok(ret) + } - /// TODO: for now prev(arr)[0] == arr[0], need better fill method #[pyfunction] fn prev(cur: PyVectorRef, vm: &VirtualMachine) -> PyResult { let cur: ArrayRef = cur.to_arrow_array(); if cur.len() == 0 { - return Err( - vm.new_runtime_error("Can't give prev for a zero length array!".to_string()) - ); + let ret = cur.slice(0, 0); + let ret = Helper::try_into_vector(&*ret).map_err(|e| { + vm.new_type_error(format!( + "Can't cast result into vector, result: {:?}, err: {:?}", + ret, e + )) + })?; + return Ok(ret.into()); } let cur = cur.slice(0, cur.len() - 1); // except the last one that is - let fill = cur.slice(0, 1); + let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?; let ret = compute::concatenate::concatenate(&[&*fill, &*cur]).map_err(|err| { vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}")) })?; @@ -684,6 +715,211 @@ pub(crate) mod greptime_builtin { Ok(ret.into()) } + #[pyfunction] + fn next(cur: PyVectorRef, vm: &VirtualMachine) -> PyResult { + let cur: ArrayRef = cur.to_arrow_array(); + if cur.len() == 0 { + let ret = cur.slice(0, 0); + let ret = Helper::try_into_vector(&*ret).map_err(|e| { + vm.new_type_error(format!( + "Can't cast result into vector, result: {:?}, err: {:?}", + ret, e + )) + })?; + return Ok(ret.into()); + } + let cur = cur.slice(1, cur.len() - 1); // except the last one that is + let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?; + let ret = compute::concatenate::concatenate(&[&*cur, &*fill]).map_err(|err| { + vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}")) + })?; + let ret = Helper::try_into_vector(&*ret).map_err(|e| { + vm.new_type_error(format!( + "Can't cast result into vector, result: {:?}, err: {:?}", + ret, e + )) + })?; + Ok(ret.into()) + } + + fn try_scalar_to_value(scalar: &dyn Scalar, vm: &VirtualMachine) -> PyResult { + let ty_error = |s: String| vm.new_type_error(s); + scalar + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ty_error(format!( + "expect scalar to be i64, found{:?}", + scalar.data_type() + )) + })? + .value() + .ok_or_else(|| ty_error("All element is Null in a time series array".to_string())) + } + + /// generate interval time point + fn gen_inteveral( + oldest: &dyn Scalar, + newest: &dyn Scalar, + duration: i64, + vm: &VirtualMachine, + ) -> PyResult>> { + use arrow::datatypes::DataType; + match (oldest.data_type(), newest.data_type()) { + (DataType::Int64, DataType::Int64) => (), + _ => { + return Err(vm.new_type_error(format!( + "Expect int64, found {:?} and {:?}", + oldest.data_type(), + newest.data_type() + ))); + } + } + + let oldest = try_scalar_to_value(oldest, vm)?; + let newest = try_scalar_to_value(newest, vm)?; + if oldest > newest { + return Err(vm.new_value_error(format!("{oldest} is greater than {newest}"))); + } + let ret = if duration > 0 { + (oldest..=newest) + .step_by(duration as usize) + .map(|v| PrimitiveScalar::new(DataType::Int64, Some(v))) + .collect::>() + } else { + return Err(vm.new_value_error(format!("duration: {duration} is not positive number."))); + }; + + Ok(ret) + } + + /// `func`: exec on sliding window slice of given `arr`, expect it to always return a PyVector of one element + /// `ts`: a vector of time stamp, expect to be Monotonous increase + /// `arr`: actual data vector + /// `duration`: the size of sliding window, also is the default step of sliding window's per step + #[pyfunction] + fn interval( + ts: PyVectorRef, + arr: PyVectorRef, + duration: i64, + func: PyRef, + vm: &VirtualMachine, + ) -> PyResult { + // TODO(discord9): change to use PyDict to mimic a table? + // then: table: PyDict, , lambda t: + // ts: PyStr, duration: i64 + // TODO: try to return a PyVector if possible, using concat array in arrow's compute module + // 1. slice them according to duration + let arrow_error = |err: ArrowError| vm.new_runtime_error(format!("Arrow Error: {err:#?}")); + let datatype_error = + |err: datatypes::Error| vm.new_runtime_error(format!("DataType Errors!: {err:#?}")); + let ts: ArrayRef = ts.to_arrow_array(); + let arr: ArrayRef = arr.to_arrow_array(); + let slices = { + let oldest = compute::aggregate::min(&*ts).map_err(arrow_error)?; + let newest = compute::aggregate::max(&*ts).map_err(arrow_error)?; + gen_inteveral(&*oldest, &*newest, duration, vm)? + }; + + let windows = { + slices + .iter() + .zip({ + let mut it = slices.iter(); + it.next(); + it + }) + .map(|(first, second)| { + compute::boolean::and(>_eq_scalar(&*ts, first), <_eq_scalar(&*ts, second)) + .map_err(arrow_error) + }) + .map(|mask| match mask { + Ok(mask) => compute::filter::filter(&*arr, &mask).map_err(arrow_error), + Err(e) => Err(e), + }) + .collect::, _>>()? + }; + + let apply_interval_function = |v: PyResult| match v { + Ok(v) => { + let args = FuncArgs::new(vec![v.into_pyobject(vm)], KwArgs::default()); + let ret = func.invoke(args, vm); + match ret{ + Ok(obj) => match py_vec_obj_to_array(&obj, vm, 1){ + Ok(v) => if v.len()==1{ + Ok(v) + }else{ + Err(vm.new_runtime_error(format!("Expect return's length to be at most one, found to be length of {}.", v.len()))) + }, + Err(err) => Err(vm + .new_runtime_error( + format!("expect `interval()`'s `func` return a PyVector(`vector`) or int/float/bool, found return to be {:?}, error msg: {err}", obj) + ) + ) + } + Err(e) => Err(e), + } + } + Err(e) => Err(e), + }; + + // 2. apply function on each slice + let fn_results = windows + .into_iter() + .map(|window| { + Helper::try_into_vector(window) + .map(PyVector::from) + .map_err(datatype_error) + }) + .map(apply_interval_function) + .collect::, _>>()?; + + // 3. get returen vector and concat them + let ret = fn_results + .into_iter() + .try_reduce(|acc, x| { + compute::concatenate::concatenate(&[acc.as_ref(), x.as_ref()]).map(Arc::from) + }) + .map_err(arrow_error)? + .unwrap_or_else(|| Arc::from(arr.slice(0, 0))); + // 4. return result vector + Ok(Helper::try_into_vector(ret).map_err(datatype_error)?.into()) + } + + /// return first element in a `PyVector` in sliced new `PyVector`, if vector's length is zero, return a zero sized slice instead + #[pyfunction] + fn first(arr: PyVectorRef, vm: &VirtualMachine) -> PyResult { + let arr: ArrayRef = arr.to_arrow_array(); + let ret = match arr.len() { + 0 => arr.slice(0, 0), + _ => arr.slice(0, 1), + }; + let ret = Helper::try_into_vector(&*ret).map_err(|e| { + vm.new_type_error(format!( + "Can't cast result into vector, result: {:?}, err: {:?}", + ret, e + )) + })?; + Ok(ret.into()) + } + + /// return last element in a `PyVector` in sliced new `PyVector`, if vector's length is zero, return a zero sized slice instead + #[pyfunction] + fn last(arr: PyVectorRef, vm: &VirtualMachine) -> PyResult { + let arr: ArrayRef = arr.to_arrow_array(); + let ret = match arr.len() { + 0 => arr.slice(0, 0), + _ => arr.slice(arr.len() - 1, 1), + }; + let ret = Helper::try_into_vector(&*ret).map_err(|e| { + vm.new_type_error(format!( + "Can't cast result into vector, result: {:?}, err: {:?}", + ret, e + )) + })?; + Ok(ret.into()) + } + #[pyfunction] fn datetime(input: &PyStr, vm: &VirtualMachine) -> PyResult { let mut parsed = Vec::new(); diff --git a/src/script/src/python/builtins/test.rs b/src/script/src/python/builtins/test.rs index ff3adcaed5ea..1d310abdcc5c 100644 --- a/src/script/src/python/builtins/test.rs +++ b/src/script/src/python/builtins/test.rs @@ -1,10 +1,25 @@ -use std::sync::Arc; +use std::{collections::HashMap, fs::File, io::Read, path::Path, sync::Arc}; -use arrow::array::PrimitiveArray; +use arrow::{ + array::{Float64Array, Int64Array, PrimitiveArray}, + compute::cast::CastOptions, + datatypes::DataType, +}; +use datatypes::vectors::VectorRef; +use ron::from_str as from_ron_string; use rustpython_vm::class::PyClassImpl; +use rustpython_vm::{ + builtins::{PyFloat, PyInt, PyList}, + convert::ToPyObject, + scope::Scope, + AsObject, PyObjectRef, VirtualMachine, +}; +use serde::{Deserialize, Serialize}; +use super::greptime_builtin; use super::*; use crate::python::utils::format_py_error; +use crate::python::{utils::is_instance, PyVector}; #[test] fn convert_scalar_to_py_obj_and_back() { rustpython_vm::Interpreter::with_init(Default::default(), |vm| { @@ -75,3 +90,384 @@ fn convert_scalar_to_py_obj_and_back() { assert!(expect_err.is_err()); }) } + +#[derive(Debug, Serialize, Deserialize)] +struct TestCase { + input: HashMap, + script: String, + expect: Result, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Var { + value: PyValue, + ty: DataType, +} + +/// for floating number comparsion +const EPS: f64 = 2.0 * f64::EPSILON; + +/// Null element just not supported for now for simplicity with writing test cases +#[derive(Debug, Serialize, Deserialize)] +enum PyValue { + FloatVec(Vec), + FloatVecWithNull(Vec>), + IntVec(Vec), + IntVecWithNull(Vec>), + Int(i64), + Float(f64), + Bool(bool), + Str(String), + /// for test if the length of FloatVec is of the same as `LenFloatVec.0` + LenFloatVec(usize), + /// for test if the length of IntVec is of the same as `LenIntVec.0` + LenIntVec(usize), + /// for test if result is within the bound of err using formula: + /// `(res - value).abs() < (value.abs()* error_percent)` + FloatWithError { + value: f64, + error_percent: f64, + }, +} + +impl PyValue { + /// compare if results is just as expect, not using PartialEq because it is not transtive .e.g. [1,2,3] == len(3) == [4,5,6] + fn just_as_expect(&self, other: &Self) -> bool { + match (self, other) { + (PyValue::FloatVec(a), PyValue::FloatVec(b)) => a + .iter() + .zip(b) + .fold(true, |acc, (x, y)| acc && (x - y).abs() <= EPS), + + (Self::FloatVecWithNull(a), Self::FloatVecWithNull(b)) => a == b, + + (PyValue::IntVec(a), PyValue::IntVec(b)) => a == b, + + (PyValue::Float(a), PyValue::Float(b)) => (a - b).abs() <= EPS, + + (PyValue::Int(a), PyValue::Int(b)) => a == b, + + // for just compare the length of vector + (PyValue::LenFloatVec(len), PyValue::FloatVec(v)) => *len == v.len(), + + (PyValue::LenIntVec(len), PyValue::IntVec(v)) => *len == v.len(), + + (PyValue::FloatVec(v), PyValue::LenFloatVec(len)) => *len == v.len(), + + (PyValue::IntVec(v), PyValue::LenIntVec(len)) => *len == v.len(), + + ( + Self::Float(v), + Self::FloatWithError { + value, + error_percent, + }, + ) + | ( + Self::FloatWithError { + value, + error_percent, + }, + Self::Float(v), + ) => (v - value).abs() < (value.abs() * error_percent), + (_, _) => false, + } + } +} + +fn is_float(ty: &DataType) -> bool { + matches!( + ty, + DataType::Float16 | DataType::Float32 | DataType::Float64 + ) +} + +/// unsigned included +fn is_int(ty: &DataType) -> bool { + matches!( + ty, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + ) +} + +impl PyValue { + fn to_py_obj(&self, vm: &VirtualMachine) -> Result { + let v: VectorRef = match self { + PyValue::FloatVec(v) => { + Arc::new(datatypes::vectors::Float64Vector::from_vec(v.clone())) + } + PyValue::IntVec(v) => Arc::new(datatypes::vectors::Int64Vector::from_vec(v.clone())), + PyValue::Int(v) => return Ok(vm.ctx.new_int(*v).into()), + PyValue::Float(v) => return Ok(vm.ctx.new_float(*v).into()), + Self::Bool(v) => return Ok(vm.ctx.new_bool(*v).into()), + Self::Str(s) => return Ok(vm.ctx.new_str(s.as_str()).into()), + _ => return Err(format!("Unsupported type:{self:#?}")), + }; + let v = PyVector::from(v).to_pyobject(vm); + Ok(v) + } + + fn from_py_obj(obj: &PyObjectRef, vm: &VirtualMachine) -> Result { + if is_instance::(obj, vm) { + let res = obj.payload::().unwrap(); + let res = res.to_arrow_array(); + let ty = res.data_type(); + if is_float(ty) { + let vec_f64 = arrow::compute::cast::cast( + res.as_ref(), + &DataType::Float64, + CastOptions { + wrapped: true, + partial: true, + }, + ) + .map_err(|err| format!("{err:#?}"))?; + assert_eq!(vec_f64.data_type(), &DataType::Float64); + let vec_f64 = vec_f64 + .as_any() + .downcast_ref::() + .ok_or(format!("Can't cast {vec_f64:#?} to Float64Array!"))?; + let ret = vec_f64 + .into_iter() + .map(|v| v.map(|inner| inner.to_owned())) + /* .enumerate() + .map(|(idx, v)| { + v.ok_or(format!( + "No null element expected, found one in {idx} position" + )) + .map(|v| v.to_owned()) + })*/ + .collect::>(); + if ret.iter().all(|x| x.is_some()) { + Ok(Self::FloatVec( + ret.into_iter().map(|i| i.unwrap()).collect(), + )) + } else { + Ok(Self::FloatVecWithNull(ret)) + } + } else if is_int(ty) { + let vec_int = arrow::compute::cast::cast( + res.as_ref(), + &DataType::Int64, + CastOptions { + wrapped: true, + partial: true, + }, + ) + .map_err(|err| format!("{err:#?}"))?; + assert_eq!(vec_int.data_type(), &DataType::Int64); + let vec_i64 = vec_int + .as_any() + .downcast_ref::() + .ok_or(format!("Can't cast {vec_int:#?} to Int64Array!"))?; + let ret: Vec = vec_i64 + .into_iter() + .enumerate() + .map(|(idx, v)| { + v.ok_or(format!( + "No null element expected, found one in {idx} position" + )) + .map(|v| v.to_owned()) + }) + .collect::>()?; + Ok(Self::IntVec(ret)) + } else { + Err(format!("unspupported DataType:{ty:#?}")) + } + } else if is_instance::(obj, vm) { + let res = obj + .to_owned() + .try_into_value::(vm) + .map_err(|err| format_py_error(err, vm).to_string())?; + Ok(Self::Int(res)) + } else if is_instance::(obj, vm) { + let res = obj + .to_owned() + .try_into_value::(vm) + .map_err(|err| format_py_error(err, vm).to_string())?; + Ok(Self::Float(res)) + } else if is_instance::(obj, vm) { + let res = obj.payload::().unwrap(); + let res: Vec = res + .borrow_vec() + .iter() + .map(|obj| { + let res = Self::from_py_obj(obj, vm).unwrap(); + assert!(matches!(res, Self::Float(_) | Self::Int(_))); + match res { + Self::Float(v) => Ok(v), + Self::Int(v) => Ok(v as f64), + _ => Err(format!("Expect only int/float in list, found {res:#?}")), + } + }) + .collect::>()?; + Ok(Self::FloatVec(res)) + } else { + todo!() + } + } +} + +#[test] +fn run_builtin_fn_testcases() { + let loc = Path::new("src/python/builtins/testcases.ron"); + let loc = loc.to_str().expect("Fail to parse path"); + let mut file = File::open(loc).expect("Fail to open file"); + let mut buf = String::new(); + file.read_to_string(&mut buf) + .expect("Fail to read to string"); + let testcases: Vec = from_ron_string(&buf).expect("Fail to convert to testcases"); + let cached_vm = rustpython_vm::Interpreter::with_init(Default::default(), |vm| { + vm.add_native_module("greptime", Box::new(greptime_builtin::make_module)); + // this can be in `.enter()` closure, but for clearity, put it in the `with_init()` + PyVector::make_class(&vm.ctx); + }); + for (idx, case) in testcases.into_iter().enumerate() { + print!("Testcase {idx} ..."); + cached_vm + .enter(|vm| { + let scope = vm.new_scope_with_builtins(); + case.input + .iter() + .try_for_each(|(k, v)| -> Result<(), String> { + let v = PyValue::to_py_obj(&v.value, vm).unwrap(); + set_item_into_scope(&scope, vm, k, v) + }) + .unwrap(); + let code_obj = vm + .compile( + &case.script, + rustpython_vm::compile::Mode::BlockExpr, + "".to_owned(), + ) + .map_err(|err| vm.new_syntax_error(&err)) + .unwrap(); + let res = vm.run_code_obj(code_obj, scope); + match res { + Err(e) => { + let err_res = format_py_error(e, vm).to_string(); + match case.expect{ + Ok(v) => { + println!("\nError:\n{err_res}"); + panic!("Expect Ok: {v:?}, found Error"); + }, + Err(err) => { + if !err_res.contains(&err){ + panic!("Error message not containing, expect {err_res}, found {}", err) + } + } + } + } + Ok(obj) => { + let ser = PyValue::from_py_obj(&obj, vm); + match (ser, case.expect){ + (Ok(real), Ok(expect)) => { + if !(real.just_as_expect(&expect.value)){ + panic!("Not as Expected for code:\n{}\n Real Value is {real:#?}, but expect {expect:#?}", case.script) + } + }, + (Err(real), Err(expect)) => { + if !expect.contains(&real){ + panic!("Expect Err(\"{expect}\"), found {real}") + } + }, + (Ok(real), Err(expect)) => panic!("Expect Err({expect}), found Ok({real:?})"), + (Err(real), Ok(expect)) => panic!("Expect Ok({expect:?}), found Err({real})"), + }; + } + }; + }); + println!(" passed!"); + } +} + +fn set_item_into_scope( + scope: &Scope, + vm: &VirtualMachine, + name: &str, + value: impl ToPyObject, +) -> Result<(), String> { + scope + .locals + .as_object() + .set_item(&name.to_owned(), vm.new_pyobj(value), vm) + .map_err(|err| { + format!( + "Error in setting var {name} in scope: \n{}", + format_py_error(err, vm) + ) + }) +} + +fn set_lst_of_vecs_in_scope( + scope: &Scope, + vm: &VirtualMachine, + arg_names: &[&str], + args: Vec, +) -> Result<(), String> { + let res = arg_names.iter().zip(args).try_for_each(|(name, vector)| { + scope + .locals + .as_object() + .set_item(name.to_owned(), vm.new_pyobj(vector), vm) + .map_err(|err| { + format!( + "Error in setting var {name} in scope: \n{}", + format_py_error(err, vm) + ) + }) + }); + res +} + +#[allow(unused_must_use)] +#[test] +fn test_vm() { + rustpython_vm::Interpreter::with_init(Default::default(), |vm| { + vm.add_native_module("udf_builtins", Box::new(greptime_builtin::make_module)); + // this can be in `.enter()` closure, but for clearity, put it in the `with_init()` + PyVector::make_class(&vm.ctx); + }) + .enter(|vm| { + let values = vec![1.0, 2.0, 3.0]; + let pows = vec![0i8, -1i8, 3i8]; + + let args: Vec = vec![ + Arc::new(datatypes::vectors::Float32Vector::from_vec(values)), + Arc::new(datatypes::vectors::Int8Vector::from_vec(pows)), + ]; + let args: Vec = args.into_iter().map(PyVector::from).collect(); + + let scope = vm.new_scope_with_builtins(); + set_lst_of_vecs_in_scope(&scope, vm, &["values", "pows"], args).unwrap(); + let code_obj = vm + .compile( + r#" +from udf_builtins import * +sin(values)"#, + rustpython_vm::compile::Mode::BlockExpr, + "".to_owned(), + ) + .map_err(|err| vm.new_syntax_error(&err)) + .unwrap(); + let res = vm.run_code_obj(code_obj, scope); + println!("{:#?}", res); + match res { + Err(e) => { + let err_res = format_py_error(e, vm).to_string(); + println!("Error:\n{err_res}"); + } + Ok(obj) => { + let _ser = PyValue::from_py_obj(&obj, vm); + dbg!(_ser); + } + } + }); +} diff --git a/src/script/src/python/builtins/testcases.ron b/src/script/src/python/builtins/testcases.ron index 2bc6e9fb7dc4..1d073effd193 100644 --- a/src/script/src/python/builtins/testcases.ron +++ b/src/script/src/python/builtins/testcases.ron @@ -1,6 +1,6 @@ // This is the file for UDF&UDAF binding from datafusion, // including most test for those function(except ApproxMedian which datafusion didn't implement) -// check src/scalars/py_udf_module/test.rs for more information +// check src/script/builtins/test.rs::run_builtin_fn_testcases() for more information [ // math expressions TestCase( @@ -670,7 +670,10 @@ pow(values, pows)"#, script: r#" from greptime import * pow(values, 1)"#, - expect: Err("TypeError: Can't cast operand of type `int` into `vector`.") + expect: Ok(( + value: FloatVec([ 1.0, 2.0, 3.0]), + ty: Float64 + )) ), TestCase( input: { @@ -781,4 +784,145 @@ from greptime import * sin(num)"#, expect: Err("Can't cast object of type str into vector or scalar") ), + TestCase( + input: {}, + script: r#" +from greptime import * +datetime("7d")"#, + expect: Ok(( + ty: Int64, + value: Int(604800) + )) + ), + TestCase( + input: {}, + script: r#" +from greptime import * +datetime("7dd")"#, + expect: Err("Unknown time unit") + ), + TestCase( + input: {}, + script: r#" +from greptime import * +datetime("d7")"#, + expect: Err("Python Runtime error, error:") + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0]) + ), + "ts": Var( + ty: Int64, + value: IntVec([0, 9, 20]) + ), + }, + script: r#" +from greptime import * +interval(ts, values, 10, lambda x:sum(x))"#, + expect: Ok(( + ty: Float64, + value: FloatVec([3.0, 3.0]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0, 4.0]) + ), + "ts": Var( + ty: Int64, + value: IntVec([0, 9, 19, 20]) + ), + }, + script: r#" +from greptime import * +interval(ts, values, 10, lambda x:last(x))"#, + expect: Ok(( + ty: Float64, + value: FloatVec([2.0, 4.0]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0, 4.0]) + ), + "ts": Var( + ty: Int64, + value: IntVec([0, 9, 19, 20]) + ), + }, + script: r#" +from greptime import * +interval(ts, values, 10, lambda x:first(x))"#, + expect: Ok(( + ty: Float64, + value: FloatVec([1.0, 3.0]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([]) + ) + }, + script: r#" +from greptime import * +prev(values)"#, + expect: Ok(( + ty: Float64, + value: FloatVec([1.0]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0]) + ) + }, + script: r#" +from greptime import * +prev(values)"#, + expect: Ok(( + ty: Float64, + value: FloatVecWithNull([None, Some(1.0), Some(2.0)]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0]) + ) + }, + script: r#" +from greptime import * +next(values)"#, + expect: Ok(( + ty: Float64, + value: FloatVecWithNull([Some(2.0), Some(3.0), None]) + )) + ), + TestCase( + input: { + "values": Var( + ty: Float64, + value: FloatVec([1.0, 2.0, 3.0]) + ) + }, + script: r#" +from greptime import * +sum(prev(values))"#, + expect: Ok(( + ty: Float64, + value: Float(3.0) + )) + ) ] diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs index 7107e512771c..f821ee83c368 100644 --- a/src/script/src/python/coprocessor.rs +++ b/src/script/src/python/coprocessor.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use common_recordbatch::RecordBatch; use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::arrow; -use datatypes::arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray}; +use datatypes::arrow::array::{Array, ArrayRef}; use datatypes::arrow::compute::cast::CastOptions; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datatypes::schema::Schema; @@ -22,8 +22,10 @@ use rustpython_parser::{ }; use rustpython_vm as vm; use rustpython_vm::{class::PyClassImpl, AsObject}; +#[cfg(test)] +use serde::Deserialize; use snafu::{OptionExt, ResultExt}; -use vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyTuple}; +use vm::builtins::{PyBaseExceptionRef, PyTuple}; use vm::scope::Scope; use vm::{Interpreter, PyObjectRef, VirtualMachine}; @@ -31,19 +33,12 @@ use crate::fail_parse_error; use crate::python::builtins::greptime_builtin; use crate::python::coprocessor::parse::{ret_parse_error, DecoratorArgs}; use crate::python::error::{ - ensure, ArrowSnafu, CoprParseSnafu, OtherSnafu, PyCompileSnafu, PyParseSnafu, Result, - TypeCastSnafu, + ensure, ret_other_error_with, ArrowSnafu, CoprParseSnafu, OtherSnafu, PyCompileSnafu, + PyParseSnafu, Result, TypeCastSnafu, }; -use crate::python::utils::format_py_error; +use crate::python::utils::{format_py_error, py_vec_obj_to_array}; use crate::python::{utils::is_instance, PyVector}; -fn ret_other_error_with(reason: String) -> OtherSnafu { - OtherSnafu { reason } -} - -#[cfg(test)] -use serde::Deserialize; - #[cfg_attr(test, derive(Deserialize))] #[derive(Debug, Clone, PartialEq, Eq)] pub struct AnnotationInfo { @@ -337,43 +332,9 @@ fn try_into_py_vector(fetch_args: Vec) -> Result> { Ok(args) } -/// convert a single PyVector or a number(a constant) into a Array(or a constant array) -fn py_vec_to_array_ref(obj: &PyObjectRef, vm: &VirtualMachine, col_len: usize) -> Result { - if is_instance::(obj, vm) { - let pyv = obj.payload::().with_context(|| { - ret_other_error_with(format!("can't cast obj {:?} to PyVector", obj)) - })?; - Ok(pyv.to_arrow_array()) - } else if is_instance::(obj, vm) { - let val = obj - .to_owned() - .try_into_value::(vm) - .map_err(|e| format_py_error(e, vm))?; - - let ret = PrimitiveArray::from_vec(vec![val; col_len]); - Ok(Arc::new(ret) as _) - } else if is_instance::(obj, vm) { - let val = obj - .to_owned() - .try_into_value::(vm) - .map_err(|e| format_py_error(e, vm))?; - let ret = PrimitiveArray::from_vec(vec![val; col_len]); - Ok(Arc::new(ret) as _) - } else if is_instance::(obj, vm) { - let val = obj - .to_owned() - .try_into_value::(vm) - .map_err(|e| format_py_error(e, vm))?; - - let ret = BooleanArray::from_iter(std::iter::repeat(Some(val)).take(5)); - Ok(Arc::new(ret) as _) - } else { - ret_other_error_with(format!("Expect a vector or a constant, found {:?}", obj)).fail() - } -} - /// convert a tuple of `PyVector` or one `PyVector`(wrapped in a Python Object Ref[`PyObjectRef`]) /// to a `Vec` +/// by default, a constant(int/float/bool) gives the a constant array of same length with input args fn try_into_columns( obj: &PyObjectRef, vm: &VirtualMachine, @@ -385,11 +346,11 @@ fn try_into_columns( })?; let cols = tuple .iter() - .map(|obj| py_vec_to_array_ref(obj, vm, col_len)) + .map(|obj| py_vec_obj_to_array(obj, vm, col_len)) .collect::>>()?; Ok(cols) } else { - let col = py_vec_to_array_ref(obj, vm, col_len)?; + let col = py_vec_obj_to_array(obj, vm, col_len)?; Ok(vec![col]) } } diff --git a/src/script/src/python/error.rs b/src/script/src/python/error.rs index c06d20a9b475..e4b5b63d1505 100644 --- a/src/script/src/python/error.rs +++ b/src/script/src/python/error.rs @@ -9,6 +9,10 @@ pub use snafu::ensure; use snafu::{prelude::Snafu, Backtrace}; pub type Result = std::result::Result; +pub(crate) fn ret_other_error_with(reason: String) -> OtherSnafu { + OtherSnafu { reason } +} + #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] pub enum Error { diff --git a/src/script/src/python/test.rs b/src/script/src/python/test.rs index 0e3456a3b349..f34520c7b225 100644 --- a/src/script/src/python/test.rs +++ b/src/script/src/python/test.rs @@ -198,10 +198,13 @@ fn test_calc_rvs() { "rv_180d" ]) def calc_rvs(open_time, close): - from greptime import vector, log, prev, sqrt, datetime, pow, sum + from greptime import vector, log, prev, sqrt, datetime, pow, sum, last + import greptime as g def calc_rv(close, open_time, time, interval): mask = (open_time < time) & (open_time > time - interval) close = close[mask] + open_time = open_time[mask] + close = g.interval(open_time, close, datetime("10m"), lambda x:last(x)) avg_time_interval = (open_time[-1] - open_time[0])/(len(open_time)-1) ref = log(close/prev(close)) @@ -211,12 +214,12 @@ def calc_rvs(open_time, close): # how to get env var, # maybe through accessing scope and serde then send to remote? timepoint = open_time[-1] - rv_7d = calc_rv(close, open_time, timepoint, datetime("7d")) - rv_15d = calc_rv(close, open_time, timepoint, datetime("15d")) - rv_30d = calc_rv(close, open_time, timepoint, datetime("30d")) - rv_60d = calc_rv(close, open_time, timepoint, datetime("60d")) - rv_90d = calc_rv(close, open_time, timepoint, datetime("90d")) - rv_180d = calc_rv(close, open_time, timepoint, datetime("180d")) + rv_7d = vector([calc_rv(close, open_time, timepoint, datetime("7d"))]) + rv_15d = vector([calc_rv(close, open_time, timepoint, datetime("15d"))]) + rv_30d = vector([calc_rv(close, open_time, timepoint, datetime("30d"))]) + rv_60d = vector([calc_rv(close, open_time, timepoint, datetime("60d"))]) + rv_90d = vector([calc_rv(close, open_time, timepoint, datetime("90d"))]) + rv_180d = vector([calc_rv(close, open_time, timepoint, datetime("180d"))]) return rv_7d, rv_15d, rv_30d, rv_60d, rv_90d, rv_180d "#; let close_array = PrimitiveArray::from_slice([ @@ -233,17 +236,8 @@ def calc_rvs(open_time, close): 10120.43, ]); let open_time_array = PrimitiveArray::from_slice([ - 1581231300i64, - 1581231360, - 1581231420, - 1581231480, - 1581231540, - 1581231600, - 1581231660, - 1581231720, - 1581231780, - 1581231840, - 1581231900, + 300i64, 900i64, 1200i64, 1800i64, 2400i64, 3000i64, 3600i64, 4200i64, 4800i64, 5400i64, + 6000i64, ]); let schema = Arc::new(Schema::from(vec![ Field::new("close", DataType::Float32, false), diff --git a/src/script/src/python/testcases.ron b/src/script/src/python/testcases.ron index ad2191211336..a6974d3c8e71 100644 --- a/src/script/src/python/testcases.ron +++ b/src/script/src/python/testcases.ron @@ -1,6 +1,6 @@ // This is the file for python coprocessor's testcases, // including coprocessor parsing test and execute test -// check src/scalars/python/test.rs for more information +// check src/script/python/test.rs::run_ron_testcases() for more information [ ( name: "correct_parse", diff --git a/src/script/src/python/utils.rs b/src/script/src/python/utils.rs index ba571e8bbeb8..810ff842f9e3 100644 --- a/src/script/src/python/utils.rs +++ b/src/script/src/python/utils.rs @@ -1,8 +1,15 @@ +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, BooleanArray, PrimitiveArray}; +use rustpython_vm::builtins::{PyBool, PyFloat, PyInt}; use rustpython_vm::{builtins::PyBaseExceptionRef, PyObjectRef, PyPayload, PyRef, VirtualMachine}; +use snafu::OptionExt; use snafu::{Backtrace, GenerateImplicitData}; use crate::python::error; +use crate::python::error::ret_other_error_with; use crate::python::PyVector; + pub(crate) type PyVectorRef = PyRef; /// use `rustpython`'s `is_instance` method to check if a PyObject is a instance of class. @@ -25,3 +32,41 @@ pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error: backtrace: Backtrace::generate(), } } + +/// convert a single PyVector or a number(a constant)(wrapping in PyObjectRef) into a Array(or a constant array) +pub fn py_vec_obj_to_array( + obj: &PyObjectRef, + vm: &VirtualMachine, + col_len: usize, +) -> Result { + if is_instance::(obj, vm) { + let pyv = obj.payload::().with_context(|| { + ret_other_error_with(format!("can't cast obj {:?} to PyVector", obj)) + })?; + Ok(pyv.to_arrow_array()) + } else if is_instance::(obj, vm) { + let val = obj + .to_owned() + .try_into_value::(vm) + .map_err(|e| format_py_error(e, vm))?; + let ret = PrimitiveArray::from_vec(vec![val; col_len]); + Ok(Arc::new(ret) as _) + } else if is_instance::(obj, vm) { + let val = obj + .to_owned() + .try_into_value::(vm) + .map_err(|e| format_py_error(e, vm))?; + let ret = PrimitiveArray::from_vec(vec![val; col_len]); + Ok(Arc::new(ret) as _) + } else if is_instance::(obj, vm) { + let val = obj + .to_owned() + .try_into_value::(vm) + .map_err(|e| format_py_error(e, vm))?; + + let ret = BooleanArray::from_iter(std::iter::repeat(Some(val)).take(col_len)); + Ok(Arc::new(ret) as _) + } else { + ret_other_error_with(format!("Expect a vector or a constant, found {:?}", obj)).fail() + } +}