
Handle Serde for Custom ScalarUDFImpl traits #8706

Closed
alamb opened this issue Jan 1, 2024 · 13 comments · Fixed by #9395
Labels
enhancement New feature or request

Comments

@alamb
Contributor

alamb commented Jan 1, 2024

Is your feature request related to a problem or challenge?

#8578 added a ScalarUDFImpl trait for implementing ScalarUDF.

@thinkharderdev said: #8578 (comment)

> Nice! It would be very useful to be able to handle serde as well for custom implementations (perhaps in a different PR?). I think this could fit relatively easily into LogicalExtensionCodec.

Describe the solution you'd like

No response

Describe alternatives you've considered

No response

Additional context

No response

alamb added the enhancement label Jan 1, 2024
@yyy1000
Contributor

yyy1000 commented Jan 24, 2024

I'd also like to work on this. 😃

@yyy1000
Contributor

yyy1000 commented Jan 25, 2024

I need some guidance on this. 🤔
It seems ScalarUDF can already handle serde; what would the PR for this issue need to implement?
https://github.com/apache/arrow-datafusion/blob/7a0af5be2323443faa75cc5876651a72c3253af8/datafusion/proto/src/logical_plan/from_proto.rs#L1818-L1826

@alamb
Contributor Author

alamb commented Jan 25, 2024

Maybe @thinkharderdev can comment -- perhaps nothing is needed?

@yyy1000
Contributor

yyy1000 commented Jan 25, 2024

Aha, maybe I can help with other issues first. 😄

@thinkharderdev
Contributor

Hey @yyy1000, I think there is some work to do here. Currently, the serialization for a UDF looks like this:

```rust
ScalarFunctionDefinition::UDF(fun) => Self {
    expr_type: Some(ExprType::ScalarUdfExpr(
        protobuf::ScalarUdfExprNode {
            fun_name: fun.name().to_string(),
            args,
        },
    )),
},
```

That is, we just use the name and assume that wherever the expression is deserialized, there will be a registry in which the scalar function definition can be looked up by name.
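A minimal, std-only Rust sketch of why name-only serialization loses per-instance state. `Udf`, `ScalarUdfExprNode`, `serialize` and `deserialize` here are hypothetical stand-ins, not DataFusion APIs, and the registry is modeled as a plain `HashMap` keyed by function name.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical stand-in for a UDF that carries internal state.
#[derive(Debug, PartialEq)]
struct Udf {
    name: String,
    state: Vec<u8>, // e.g. a serialized compiled regex
}

// Hypothetical stand-in for protobuf::ScalarUdfExprNode: only the name is kept.
struct ScalarUdfExprNode {
    fun_name: String,
}

fn serialize(udf: &Udf) -> ScalarUdfExprNode {
    ScalarUdfExprNode {
        fun_name: udf.name.clone(),
    }
}

// The decoder can only look the name up in a registry; the original state is gone.
fn deserialize(
    node: &ScalarUdfExprNode,
    registry: &HashMap<String, Arc<Udf>>,
) -> Option<Arc<Udf>> {
    registry.get(&node.fun_name).cloned()
}
```

Whatever instance the registry holds is what comes back, so state such as a compiled regex cannot survive the round trip.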

But ideally we would be able to serialize a custom scalar function that has some associated state. For example, a regex scalar function that contains the compiled regex in its struct definition, like:

```rust
#[derive(Debug)]
struct MyRegexUdf {
    compiled_regex: Vec<u8>, // just assume we have some serialization of the regex state machine here
}

impl ScalarUDFImpl for MyRegexUdf {
    // other required methods (as_any, name, signature, return_type) omitted
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // do something with the compiled regex here
        todo!()
    }
}
```

Currently, the mechanism used for this kind of thing is a custom LogicalExtensionCodec. So here I think we would add methods to that trait, like:

```rust
impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
    fn try_encode_scalar_udf(
        &self,
        _node: Arc<dyn ScalarUDFImpl>,
        _buf: &mut Vec<u8>,
    ) -> Result<()> {
        not_impl_err!("LogicalExtensionCodec is not provided")
    }

    fn try_decode_scalar_udf(
        &self,
        _buf: &[u8],
        _ctx: &SessionContext,
    ) -> Result<Arc<dyn ScalarUDFImpl>> {
        not_impl_err!("LogicalExtensionCodec is not provided")
    }
}
```

Then I would be able to define my own UDFs that contain internal state, and define an extension codec like:

```rust
struct MyLogicalExtensionCodec;

impl LogicalExtensionCodec for MyLogicalExtensionCodec {
    fn try_encode_scalar_udf(
        &self,
        node: Arc<dyn ScalarUDFImpl>,
        buf: &mut Vec<u8>,
    ) -> Result<()> {
        // only UDFs this codec knows about can be encoded
        if let Some(regex_udf) = node.as_any().downcast_ref::<MyRegexUdf>() {
            let proto = MyRegexUdfProto {
                compiled_regex: regex_udf.compiled_regex.clone(),
            };
            // assuming MyRegexUdfProto is a prost message, encode returns a
            // prost::EncodeError, so map it into a DataFusionError
            proto
                .encode(buf)
                .map_err(|e| DataFusionError::Internal(e.to_string()))?;
            Ok(())
        } else {
            not_impl_err!("LogicalExtensionCodec is not provided")
        }
    }

    fn try_decode_scalar_udf(
        &self,
        buf: &[u8],
        _ctx: &SessionContext,
    ) -> Result<Arc<dyn ScalarUDFImpl>> {
        if let Ok(proto) = MyRegexUdfProto::decode(buf) {
            Ok(Arc::new(MyRegexUdf {
                compiled_regex: proto.compiled_regex,
            }))
        } else {
            not_impl_err!("LogicalExtensionCodec is not provided")
        }
    }
}
```

However, this doesn't fit well with how serde is currently defined, because there is no way to get a LogicalExtensionCodec into our impl TryFrom<&Expr> for protobuf::LogicalExprNode, which is where we would need it.
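The downcasting codec described in this comment can be sketched without DataFusion or prost. `UdfImpl`, `Codec` and `MyCodec` below are hypothetical simplifications of ScalarUDFImpl and LogicalExtensionCodec; the payload is written as raw bytes instead of a protobuf message, and decoding dispatches on the function name (an assumption, in the spirit of the `name` parameter proposed later in this thread) rather than on whether the bytes happen to parse.

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

// Hypothetical, simplified stand-in for ScalarUDFImpl.
trait UdfImpl: Debug + Send + Sync {
    fn as_any(&self) -> &dyn Any;
    fn name(&self) -> &str;
}

// A UDF carrying internal state (e.g. a serialized compiled regex).
#[derive(Debug, PartialEq)]
struct MyRegexUdf {
    compiled_regex: Vec<u8>,
}

impl UdfImpl for MyRegexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_regex"
    }
}

// Hypothetical stand-in for the codec methods discussed above.
trait Codec {
    fn try_encode_udf(&self, node: &dyn UdfImpl, buf: &mut Vec<u8>) -> Result<(), String>;
    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<dyn UdfImpl>, String>;
}

struct MyCodec;

impl Codec for MyCodec {
    fn try_encode_udf(&self, node: &dyn UdfImpl, buf: &mut Vec<u8>) -> Result<(), String> {
        // Downcast to the concrete type this codec knows how to serialize.
        match node.as_any().downcast_ref::<MyRegexUdf>() {
            Some(udf) => {
                buf.extend_from_slice(&udf.compiled_regex);
                Ok(())
            }
            None => Err(format!("cannot encode UDF {}", node.name())),
        }
    }

    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<dyn UdfImpl>, String> {
        // Rebuild the stateful UDF from its name and payload.
        match name {
            "my_regex" => {
                let udf: Arc<dyn UdfImpl> = Arc::new(MyRegexUdf {
                    compiled_regex: buf.to_vec(),
                });
                Ok(udf)
            }
            _ => Err(format!("cannot decode UDF {name}")),
        }
    }
}
```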

@thinkharderdev
Contributor

Perhaps this could be incorporated into FunctionRegistry, and we could have a serialize_expr(expr: &Expr, registry: &dyn FunctionRegistry), similar to how we do the deserialization?

@yyy1000
Copy link
Contributor

yyy1000 commented Jan 29, 2024

@thinkharderdev Much appreciated! I understand the issue now.
Another question: is there some example code for deserialization that I can refer to when thinking about how to write the serialize_expr function?

@thinkharderdev
Contributor

thinkharderdev commented Feb 26, 2024

> @thinkharderdev Much appreciated! I understand the issue now. Another question: is there some example code for deserialization that I can refer to when thinking about how to write the serialize_expr function?

Edit: On second thought, we can probably just extend LogicalExtensionCodec and PhysicalExtensionCodec.

Hey @yyy1000, I looked at this a bit, and I think what we want here is:

1. Have ScalarUDFExprNode take an optional opaque payload that can contain the serialized function:
```protobuf
message ScalarUDFExprNode {
  string fun_name = 1;
  repeated LogicalExprNode args = 2;
  optional bytes fun_definition = 3;
}
```

We would do something similar for AggregateUDFExprNode and WindowExprNode.

2. Extend LogicalExtensionCodec and PhysicalExtensionCodec to support custom serde for UDF/UDAF/UDWF:
```rust
pub trait LogicalExtensionCodec: Debug + Send + Sync {
    // ... existing methods unchanged

    fn try_decode_udf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<ScalarUDF>>;

    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()>;

    fn try_decode_udaf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<AggregateUDF>>;

    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()>;

    fn try_decode_udwf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<WindowUDF>>;

    fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()>;
}
```
3. Replace impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function:
```rust
pub fn serialize_expr(
    expr: &Expr,
    codec: &dyn LogicalExtensionCodec,
) -> Result<protobuf::LogicalExprNode> {
    // ...
}
```

This would be mostly unchanged from the existing TryFrom implementation, except for the handling of Expr::ScalarFunction/AggregateFunction/WindowFunction. We would handle Expr::ScalarFunction something like this:

```rust
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
    // serialize the arguments recursively with the same codec,
    // so nested UDFs can carry a payload too
    let args = args
        .iter()
        .map(|expr| serialize_expr(expr, codec))
        .collect::<Result<Vec<_>, _>>()?;
    match func_def {
        ScalarFunctionDefinition::BuiltIn(fun) => {
            let fun: protobuf::ScalarFunction = fun.try_into()?;
            protobuf::LogicalExprNode {
                expr_type: Some(ExprType::ScalarFunction(
                    protobuf::ScalarFunctionNode {
                        fun: fun.into(),
                        args,
                    },
                )),
            }
        }
        ScalarFunctionDefinition::UDF(fun) => {
            // let the codec write an optional payload for this UDF
            let mut buf = Vec::new();
            codec.try_encode_udf(fun.as_ref(), &mut buf)?;

            // an empty buffer means "no custom payload": decode by name only
            let fun_definition = if buf.is_empty() {
                None
            } else {
                Some(buf)
            };

            protobuf::LogicalExprNode {
                expr_type: Some(ExprType::ScalarUdfExpr(
                    protobuf::ScalarUdfExprNode {
                        fun_name: fun.name().to_string(),
                        fun_definition,
                        args,
                    },
                )),
            }
        }
        ScalarFunctionDefinition::Name(_) => {
            return Err(Error::NotImplemented(
                "Proto serialization error: Trying to serialize an unresolved function"
                    .to_string(),
            ));
        }
    }
}
```

4. Similarly, in the existing parse_expr, we try to use the extension codec to deserialize the function if fun_definition is present:
```rust
pub fn parse_expr(
    proto: &protobuf::LogicalExprNode,
    registry: &dyn FunctionRegistry,
    codec: &dyn LogicalExtensionCodec,
) -> Result<Expr, Error> {
    // ... handling of other expr types unchanged

    ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
        fun_name,
        fun_definition,
        args,
    }) => {
        // prefer the codec when a serialized definition is present,
        // otherwise look the function up by name in the registry
        let scalar_fn = match fun_definition {
            Some(buf) => codec.try_decode_udf(fun_name, buf)?,
            None => registry.udf(fun_name.as_str())?,
        };
        Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
            scalar_fn,
            args.iter()
                // pass the codec down so nested UDFs are decoded the same way
                .map(|expr| parse_expr(expr, registry, codec))
                .collect::<Result<Vec<_>, Error>>()?,
        )))
    }
}
```
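Putting the serialize and parse halves of this proposal together, here is a self-contained sketch of the round trip under simplifying assumptions: `Udf`, `Expr`, `ExprNode`, `Codec` and `StateCodec` are hypothetical stand-ins, not DataFusion's types. It shows the two properties the proposal relies on: an empty codec buffer becomes `fun_definition: None`, and both directions pass the codec (and, when parsing, the registry) down to nested arguments.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-ins for ScalarUDF, Expr and LogicalExprNode.
#[derive(Debug, Clone, PartialEq)]
struct Udf {
    name: String,
    state: Vec<u8>, // e.g. a serialized compiled regex
}

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Column(String),
    ScalarFunction { func: Udf, args: Vec<Expr> },
}

#[derive(Debug, Clone, PartialEq)]
enum ExprNode {
    Column(String),
    ScalarUdf { fun_name: String, fun_definition: Option<Vec<u8>>, args: Vec<ExprNode> },
}

// Hypothetical codec: writes a UDF's state, and rebuilds a UDF from name + bytes.
trait Codec {
    fn try_encode_udf(&self, udf: &Udf, buf: &mut Vec<u8>) -> Result<(), String>;
    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Udf, String>;
}

struct StateCodec;

impl Codec for StateCodec {
    fn try_encode_udf(&self, udf: &Udf, buf: &mut Vec<u8>) -> Result<(), String> {
        buf.extend_from_slice(&udf.state);
        Ok(())
    }
    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Udf, String> {
        Ok(Udf { name: name.to_string(), state: buf.to_vec() })
    }
}

fn serialize_expr(expr: &Expr, codec: &dyn Codec) -> Result<ExprNode, String> {
    match expr {
        Expr::Column(c) => Ok(ExprNode::Column(c.clone())),
        Expr::ScalarFunction { func, args } => {
            // Recurse with the same codec so nested UDFs are encoded too.
            let args = args.iter().map(|a| serialize_expr(a, codec)).collect::<Result<Vec<_>, _>>()?;
            let mut buf = Vec::new();
            codec.try_encode_udf(func, &mut buf)?;
            // Empty buffer: no payload, so the decoder falls back to the registry.
            let fun_definition = if buf.is_empty() { None } else { Some(buf) };
            Ok(ExprNode::ScalarUdf { fun_name: func.name.clone(), fun_definition, args })
        }
    }
}

fn parse_expr(node: &ExprNode, registry: &HashMap<String, Udf>, codec: &dyn Codec) -> Result<Expr, String> {
    match node {
        ExprNode::Column(c) => Ok(Expr::Column(c.clone())),
        ExprNode::ScalarUdf { fun_name, fun_definition, args } => {
            let func = match fun_definition {
                // Payload present: the codec rebuilds the stateful UDF.
                Some(buf) => codec.try_decode_udf(fun_name, buf)?,
                // No payload: look the function up by name.
                None => registry
                    .get(fun_name)
                    .cloned()
                    .ok_or_else(|| format!("unknown UDF {fun_name}"))?,
            };
            // Pass registry and codec down so nested UDFs are handled the same way.
            let args = args.iter().map(|a| parse_expr(a, registry, codec)).collect::<Result<Vec<_>, _>>()?;
            Ok(Expr::ScalarFunction { func, args })
        }
    }
}
```

In this sketch a UDF with no state never produces a payload, so functions that are already registered by name keep working exactly as before.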

If you don't have the bandwidth to work on this now, let me know. My team can take it up, as we are hoping to use this functionality soon.

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

Thanks so much for your detailed instructions, @thinkharderdev!
I can try to implement it now and will seek your help if needed :)

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

A question I have: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function would mean changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev

@thinkharderdev
Contributor

> A question I have: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function would mean changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev

Yeah, I think so. @alamb do you see any issues with that?

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

Also, I wonder how to deal with other places that don't currently take a LogicalExtensionCodec, like to_bytes:
https://github.com/apache/arrow-datafusion/blob/c439bc73b6a9ba9efa4c8a9b5d2fb6111e660e74/datafusion/proto/src/bytes/mod.rs#L87-L92. I think adding a parameter of type LogicalExtensionCodec to the function would work, but what should be passed when calling this method? Should I initialize an impl LogicalExtensionCodec? 🤔
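One possible answer, sketched under assumptions: give the codec methods default implementations that write no payload, and keep the existing entry point as a thin wrapper that passes a default codec, so current callers don't change. `Codec`, `DefaultCodec`, `PayloadCodec`, `to_bytes_with_codec` and the toy byte layout below are all hypothetical, not the real datafusion-proto API.

```rust
// Hypothetical codec trait with a default method, so existing implementors keep compiling.
trait Codec {
    // Default: write no payload, so the UDF is serialized by name only.
    fn try_encode_udf(&self, _name: &str, _buf: &mut Vec<u8>) -> Result<(), String> {
        Ok(())
    }
}

// Hypothetical default codec relying entirely on the trait default.
struct DefaultCodec;
impl Codec for DefaultCodec {}

// Hypothetical codec that attaches a fixed payload to every UDF.
struct PayloadCodec(Vec<u8>);
impl Codec for PayloadCodec {
    fn try_encode_udf(&self, _name: &str, buf: &mut Vec<u8>) -> Result<(), String> {
        buf.extend_from_slice(&self.0);
        Ok(())
    }
}

// New entry point that takes a codec.
// Toy wire format: name bytes, a 0 separator, then any codec payload.
fn to_bytes_with_codec(udf_name: &str, codec: &dyn Codec) -> Result<Vec<u8>, String> {
    let mut out = udf_name.as_bytes().to_vec();
    out.push(0);
    codec.try_encode_udf(udf_name, &mut out)?;
    Ok(out)
}

// Existing entry point keeps its signature and delegates with the default codec.
fn to_bytes(udf_name: &str) -> Result<Vec<u8>, String> {
    to_bytes_with_codec(udf_name, &DefaultCodec)
}
```

With this shape, only callers that need custom UDF serde have to know about codecs.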

@alamb
Contributor Author

alamb commented Feb 28, 2024

> > A question I have: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function would mean changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev
>
> Yeah, I think so. @alamb do you see any issues with that?

I don't see any specific issue, and I don't think it would affect users of the crate much: I don't think they typically use the protobuf encoding directly, but rather go through higher-level APIs like Expr::to_bytes():

https://github.com/apache/arrow-datafusion/blob/e62240969135e2236d100c8c0c01546a87950a80/datafusion/proto/src/lib.rs#L52-L68
