Skip to content

Commit

Permalink
feat(core): introduce PySessionContext and the API for the availabl…
Browse files Browse the repository at this point in the history
…e function list (#858)
  • Loading branch information
goldmedal committed Oct 31, 2024
1 parent 39e9f0c commit 7c50712
Show file tree
Hide file tree
Showing 14 changed files with 427 additions and 143 deletions.
6 changes: 3 additions & 3 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def __init__(self, manifest_str: str, function_path: str):
self.function_path = function_path

def rewrite(self, sql: str) -> str:
from wren_core import read_remote_function_list, transform_sql
from wren_core import SessionContext

try:
functions = read_remote_function_list(self.function_path)
return transform_sql(self.manifest_str, functions, sql)
session_context = SessionContext(self.manifest_str, self.function_path)
return session_context.transform_sql(sql)
except Exception as e:
raise RewriteError(str(e))

Expand Down
13 changes: 13 additions & 0 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ def validate(data_source: DataSource, rule_name: str, dto: ValidateDTO) -> Respo
)
validator.validate(rule_name, dto.parameters, dto.manifest_str)
return Response(status_code=204)


@router.get("/functions")
def functions() -> Response:
    """Return the functions reported by a manifest-less session context as JSON.

    The session context is built with no MDL and with the remote function list
    path taken from the application config; each function it reports is
    converted with ``to_dict()`` before being serialized.
    """
    from wren_core import SessionContext

    from app.config import get_config

    function_list_path = get_config().remote_function_list_path
    # The MDL argument is None: only the function list is queried here.
    context = SessionContext(None, function_list_path)

    payload = []
    for function in context.get_available_functions():
        payload.append(function.to_dict())

    return JSONResponse(payload)
22 changes: 22 additions & 0 deletions ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,28 @@ def test_query_with_remote_function(manifest_str, postgres: PostgresContainer):

config.set_remote_function_list_path(None)

def test_function_list():
    """Check `/v3/connector/functions` with and without a remote function list.

    With the CSV configured, the endpoint must also report the functions the
    CSV defines (including `add_two`); once the path is cleared, the list
    shrinks back to the smaller count.
    """
    config = get_config()
    config.set_remote_function_list_path(file_path("resource/functions.csv"))
    try:
        response = client.get(
            url="/v3/connector/functions",
        )
        assert response.status_code == 200
        result = response.json()
        assert len(result) == 261
        # Use a default so a missing entry fails as a readable assertion
        # instead of an unhandled StopIteration.
        add_two = next((f for f in result if f["name"] == "add_two"), None)
        assert add_two is not None
        assert add_two["name"] == "add_two"
    finally:
        # Always restore the global config so a failure above cannot leak the
        # remote function list into other tests.
        config.set_remote_function_list_path(None)

    response = client.get(
        url="/v3/connector/functions",
    )
    assert response.status_code == 200
    result = response.json()
    assert len(result) == 258

def _to_connection_info(pg: PostgresContainer):
return {
"host": pg.get_container_host_ip(),
Expand Down
17 changes: 9 additions & 8 deletions wren-core-py/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wren-core-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ csv = "1.3.0"
serde = { version = "1.0.210", features = ["derive"] }
env_logger = "0.11.5"
log = "0.4.22"
tokio = "1.40.0"

[build-dependencies]
pyo3-build-config = "0.21.2"
240 changes: 240 additions & 0 deletions wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::errors::CoreError;
use crate::remote_functions::PyRemoteFunction;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use log::debug;
use pyo3::{pyclass, pymethods, PyErr, PyResult};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use wren_core::logical_plan::utils::map_data_type;
use wren_core::mdl::context::create_ctx_with_mdl;
use wren_core::mdl::function::{
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
RemoteFunction,
};
use wren_core::mdl::manifest::Manifest;
use wren_core::{mdl, AggregateUDF, AnalyzedWrenMDL, ScalarUDF, WindowUDF};

/// The Python wrapper for the Wren Core session context.
///
/// Exposed to Python under the class name `SessionContext`.
#[pyclass(name = "SessionContext")]
#[derive(Clone)]
pub struct PySessionContext {
/// Underlying Wren Core session context; `get_available_functions` lists the
/// scalar, aggregate and window functions found in its state.
ctx: wren_core::SessionContext,
/// Analyzed MDL handed to `mdl::transform_sql`; an `AnalyzedWrenMDL::default()`
/// when no manifest was supplied to the constructor.
mdl: Arc<AnalyzedWrenMDL>,
/// Remote functions read from the optional CSV file given to the constructor.
remote_functions: Vec<RemoteFunction>,
}

impl Hash for PySessionContext {
    /// Feed the analyzed MDL and the remote function list into `state`.
    ///
    /// The `ctx` field does not take part in the hash.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        Hash::hash(&self.mdl, state);
        Hash::hash(&self.remote_functions, state);
    }
}

impl Default for PySessionContext {
fn default() -> Self {
Self {
ctx: wren_core::SessionContext::new(),
mdl: Arc::new(AnalyzedWrenMDL::default()),
remote_functions: vec![],
}
}
}

#[pymethods]
impl PySessionContext {
    /// Create a new session context.
    ///
    /// If `mdl_base64` is provided, the session context is created with the given
    /// MDL (a base64-encoded JSON manifest). Otherwise, an empty MDL is used.
    /// If `remote_functions_path` is provided, the session context is created with
    /// the remote functions defined in that CSV file.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when the remote function list cannot
    /// be read, the manifest cannot be decoded, parsed or analyzed, the Tokio
    /// runtime cannot be started, or the MDL-backed context cannot be created.
    #[new]
    pub fn new(
        mdl_base64: Option<&str>,
        remote_functions_path: Option<&str>,
    ) -> PyResult<Self> {
        // Propagate the failure to Python rather than unwrapping: a Rust panic
        // would not surface as an ordinary, catchable error.
        let remote_functions: Vec<RemoteFunction> =
            Self::read_remote_function_list(remote_functions_path)?
                .into_iter()
                .map(|f| f.into())
                .collect();

        let ctx = wren_core::SessionContext::new();

        // Without a manifest there is nothing to analyze: return the freshly
        // created context together with a default analyzed MDL.
        let Some(mdl_base64) = mdl_base64 else {
            return Ok(Self {
                ctx,
                mdl: Arc::new(AnalyzedWrenMDL::default()),
                remote_functions,
            });
        };

        // base64 -> UTF-8 JSON -> Manifest
        let mdl_json_bytes = BASE64_STANDARD
            .decode(mdl_base64)
            .map_err(CoreError::from)?;
        let mdl_json = String::from_utf8(mdl_json_bytes).map_err(CoreError::from)?;
        let manifest =
            serde_json::from_str::<Manifest>(&mdl_json).map_err(CoreError::from)?;

        let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else {
            return Err(CoreError::new("Failed to analyze manifest").into());
        };
        let analyzed_mdl = Arc::new(analyzed_mdl);

        // `create_ctx_with_mdl` is async, so drive it to completion on a runtime
        // owned by this call. Runtime creation is fallible and must not panic.
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| CoreError::new(&e.to_string()))?;
        let ctx = runtime
            .block_on(create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), false))
            .map_err(CoreError::from)?;

        for remote_function in &remote_functions {
            debug!("Registering remote function: {:?}", remote_function);
            Self::register_remote_function(&ctx, remote_function);
        }

        Ok(Self {
            ctx,
            mdl: analyzed_mdl,
            remote_functions,
        })
    }

    /// Transform the given Wren SQL to the equivalent Planned SQL.
    ///
    /// # Errors
    ///
    /// Any error reported by `mdl::transform_sql` is converted to a Python error.
    pub fn transform_sql(&self, sql: &str) -> PyResult<String> {
        mdl::transform_sql(Arc::clone(&self.mdl), &self.remote_functions, sql)
            .map_err(|e| PyErr::from(CoreError::from(e)))
    }

    /// Get the available functions in the session context.
    ///
    /// The result contains every remote function plus every scalar, aggregate
    /// and window function found in the session state. Entries are keyed by
    /// name and the first one inserted wins, so a remote function takes
    /// precedence over a session function with the same name, and scalar beats
    /// aggregate beats window. The order of the returned list is unspecified.
    pub fn get_available_functions(&self) -> PyResult<Vec<PyRemoteFunction>> {
        let mut functions = self
            .remote_functions
            .iter()
            .map(|f| (f.name.clone(), f.clone().into()))
            .collect::<HashMap<String, PyRemoteFunction>>();

        // Insert a placeholder entry for a session function unless the name is
        // already taken by an earlier entry.
        let mut add_if_absent = |name: &str, function_type: &str| {
            if let Entry::Vacant(slot) = functions.entry(name.to_string()) {
                slot.insert(PyRemoteFunction {
                    function_type: function_type.to_string(),
                    name: name.to_string(),
                    // TODO: get function return type from SessionState
                    return_type: None,
                    param_names: None,
                    param_types: None,
                    description: None,
                });
            }
        };

        // Fetch the session state once and reuse it for all three registries.
        let state = self.ctx.state();
        for name in state.scalar_functions().keys() {
            add_if_absent(name.as_str(), "scalar");
        }
        for name in state.aggregate_functions().keys() {
            add_if_absent(name.as_str(), "aggregate");
        }
        for name in state.window_functions().keys() {
            add_if_absent(name.as_str(), "window");
        }

        // The map is no longer needed, so move the values out instead of cloning.
        Ok(functions.into_values().collect())
    }
}

impl PySessionContext {
fn register_remote_function(
ctx: &wren_core::SessionContext,
remote_function: &RemoteFunction,
) {
match &remote_function.function_type {
FunctionType::Scalar => {
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
FunctionType::Aggregate => {
ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
FunctionType::Window => {
ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
}
}

fn read_remote_function_list(path: Option<&str>) -> PyResult<Vec<PyRemoteFunction>> {
debug!(
"Reading remote function list from {}",
path.unwrap_or("path is not provided")
);
if let Some(path) = path {
Ok(csv::Reader::from_path(path)
.unwrap()
.into_deserialize::<PyRemoteFunction>()
.filter_map(Result::ok)
.collect::<Vec<_>>())
} else {
Ok(vec![])
}
}
}
12 changes: 12 additions & 0 deletions wren-core-py/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ impl From<CoreError> for PyErr {
}
}

impl From<PyErr> for CoreError {
fn from(err: PyErr) -> Self {
CoreError::new(&err.to_string())
}
}

impl From<DecodeError> for CoreError {
fn from(err: DecodeError) -> Self {
CoreError::new(&err.to_string())
Expand All @@ -41,3 +47,9 @@ impl From<serde_json::Error> for CoreError {
CoreError::new(&err.to_string())
}
}

impl From<wren_core::DataFusionError> for CoreError {
    /// Build a `CoreError` from the string rendering of a DataFusion error.
    fn from(err: wren_core::DataFusionError) -> Self {
        let message = err.to_string();
        Self::new(&message)
    }
}
Loading

0 comments on commit 7c50712

Please sign in to comment.