diff --git a/src/composition/learning_networks/signatures.jl b/src/composition/learning_networks/signatures.jl index b224cd47..d49aace9 100644 --- a/src/composition/learning_networks/signatures.jl +++ b/src/composition/learning_networks/signatures.jl @@ -356,7 +356,7 @@ function fitted_params(signature::Signature; supplement=true) end """ - output_and_report(signature, operation, Xnew) + output_and_report(signature, operation, Xnew...) **Private method.** @@ -375,3 +375,6 @@ function output_and_report(signature, operation, Xnew) report = MLJBase.report(signature_clone; supplement=false) return output, report end +# special case for static transformers with multiple inputs: +output_and_report(signature, operation, Xnew...) = + output_and_report(signature, operation, Xnew) diff --git a/src/operations.jl b/src/operations.jl index efa275ac..7bd4da4d 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -222,9 +222,9 @@ for operation in [:predict, :transform, :inverse_transform] quote - function $operation(model::NetworkComposite, fitresult, Xnew) + function $operation(model::NetworkComposite, fitresult, Xnew...) if $(QuoteNode(operation)) in MLJBase.operations(fitresult) - return output_and_report(fitresult, $(QuoteNode(operation)), Xnew) + return output_and_report(fitresult, $(QuoteNode(operation)), Xnew...) end throw(err_unsupported_operation($operation)) end diff --git a/test/composition/models/network_composite.jl b/test/composition/models/network_composite.jl index 87e064df..fd493a0d 100644 --- a/test/composition/models/network_composite.jl +++ b/test/composition/models/network_composite.jl @@ -645,6 +645,39 @@ end end +# # STATIC MODEL WITH MULTIPLE INPUTS + +mutable struct Balancer <: Static end +MLJBase.transform(::Balancer, _, X, y) = (selectrows(X, 1:2), selectrows(y, 1:2)) + +struct ThinWrapper <: StaticNetworkComposite + balancer +end + +function MLJBase.prefit(wrapper::ThinWrapper, verbosity) + + data = source() # empty source because there is no training data + Xs = first(data) + ys = last(data) + + mach=machine(:balancer) + + output = transform(mach, Xs, ys) + + (; transform = output) + +end + +balancer = Balancer() +wrapper = ThinWrapper(balancer) + +X, y = make_blobs() +mach = machine(wrapper) +Xunder, yunder = transform(mach, X, y) +@test Xunder == selectrows(X, 1:2) +@test yunder == selectrows(y, 1:2) + + # # MACHINE INTEGRATION TESTS