diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index da0f974a7490..19baf610d820 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -24,6 +24,7 @@ import "spark/connect/commands.proto"; import "spark/connect/expressions.proto"; import "spark/connect/relations.proto"; import "spark/connect/types.proto"; +import "spark/connect/ml.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -36,6 +37,7 @@ message Plan { oneof op_type { Relation root = 1; Command command = 2; + MlCommand ml_command = 3; } } @@ -261,6 +263,9 @@ message ExecutePlanResponse { // Special case for executing SQL commands. SqlCommandResult sql_command_result = 5; + // ML command response + MlCommandResponse ml_command_result = 100; + // Support arbitrary result objects. google.protobuf.Any extension = 999; } diff --git a/connector/connect/common/src/main/protobuf/spark/connect/ml.proto b/connector/connect/common/src/main/protobuf/spark/connect/ml.proto new file mode 100644 index 000000000000..1b6764a08583 --- /dev/null +++ b/connector/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package spark.connect; + +import "spark/connect/expressions.proto"; +import "spark/connect/relations.proto"; +import "spark/connect/ml_common.proto"; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + + +// MlEvaluator represents an ML Evaluator +message MlEvaluator { + // The name of the evaluator in the registry + string name = 1; + // param settings for the evaluator + MlParams params = 2; + // unique id of the evaluator + string uid = 3; +} + + +// An MlCommand is a type container that has exactly one ML command set +message MlCommand { + oneof ml_command_type { + // call `estimator.fit` and return a model + Fit fit = 1; + // get model attribute + FetchModelAttr fetch_model_attr = 2; + // get model summary attribute + FetchModelSummaryAttr fetch_model_summary_attr = 3; + // load model + LoadModel load_model = 4; + // save model + SaveModel save_model = 5; + // call `evaluator.evaluate` + Evaluate evaluate = 6; + // save estimator or transformer + SaveStage save_stage = 7; + // load estimator or transformer + LoadStage load_stage = 8; + // save evaluator + SaveEvaluator save_evaluator = 9; + // load evaluator + LoadEvaluator load_evaluator = 10; + // copy model, returns a new model reference id + CopyModel copy_model = 11; + // delete the server-side model object by model reference id + DeleteModel delete_model = 12; + } + + message Fit { + MlStage estimator = 1; + Relation dataset = 2; + } + + message Evaluate { + MlEvaluator evaluator = 1; + } + + message LoadModel { + string name = 1; + string path = 2; + } + + message SaveModel { + ModelRef model_ref = 1; + string path = 2; // saving path + bool overwrite = 3; + map<string, string> options = 4; // saving options + } + + message LoadStage { + string name = 1; + string path = 2; + MlStage.StageType type = 3; + } + + message SaveStage { + MlStage stage = 1; + string path = 2; // saving path + bool overwrite = 3; + map<string, string> options = 4; // saving options + } + + message LoadEvaluator { + string name = 1; + string path = 2; + } + + message SaveEvaluator { + MlEvaluator evaluator = 1; + string path = 2; // saving path + bool overwrite = 3; + map<string, string> options = 4; // saving options + } + + message FetchModelAttr { + ModelRef model_ref = 1; + string name = 2; + } + + message FetchModelSummaryAttr { + ModelRef model_ref = 1; + string name = 2; + MlParams params = 3; + + // Evaluation dataset used to compute + // the summary attribute. + // If not set, get the attribute from + // model.summary (i.e. the summary on the training dataset) + optional Relation evaluation_dataset = 4; + } + + message CopyModel { + ModelRef model_ref = 1; + } + + message DeleteModel { + ModelRef model_ref = 1; + } +} + + +message MlCommandResponse { + oneof ml_command_response_type { + Expression.Literal literal = 1; + ModelInfo model_info = 2; + Vector vector = 3; + Matrix matrix = 4; + MlStage stage = 5; + ModelRef model_ref = 6; + } + message ModelInfo { + ModelRef model_ref = 1; + string model_uid = 2; + MlParams params = 3; + } +} diff --git a/connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto b/connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto new file mode 100644 index 000000000000..338167df27f5 --- /dev/null +++ b/connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package spark.connect; + +import "spark/connect/expressions.proto"; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + + +// MlParams stores param settings for +// ML Estimator / Transformer / Model / Evaluator +message MlParams { + // user-supplied params + map<string, ParamValue> params = 1; + // default params + map<string, ParamValue> default_params = 2; + + message ParamValue { + oneof param_value_type { + Expression.Literal literal = 1; + Vector vector = 2; + Matrix matrix = 3; + } + } +} + +// MlStage stores ML stage data (Estimator or Transformer) +message MlStage { + // The name of the stage in the registry + string name = 1; + // param settings for the stage + MlParams params = 2; + // unique id of the stage + string uid = 3; + StageType type = 4; + enum StageType { + STAGE_TYPE_UNSPECIFIED = 0; + STAGE_TYPE_ESTIMATOR = 1; + STAGE_TYPE_TRANSFORMER = 2; + } +} + +// ModelRef represents a reference to a server-side `Model` instance +message ModelRef { + // The ID is used to look up the model instance on the server side. + string id = 1; +} + +message Vector { + oneof one_of { + Dense dense = 1; + Sparse sparse = 2; + } + message Dense { + repeated double value = 1; + } + message Sparse { + int32 size = 1; + repeated double index = 2; + repeated double value = 3; + } +} + +message Matrix { + oneof one_of { + Dense dense = 1; + Sparse sparse = 2; + } + message Dense { + int32 num_rows = 1; + int32 num_cols = 2; + repeated double value = 3; + bool is_transposed = 4; + } + message Sparse { + int32 num_rows = 1; + int32 num_cols = 2; + repeated double colptr = 3; + repeated double row_index = 4; + repeated double value = 5; + bool is_transposed = 6; + } +} diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index aba965082ea2..5840496394a1 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -23,6 +23,7 @@ import "google/protobuf/any.proto"; import "spark/connect/expressions.proto"; import "spark/connect/types.proto"; import "spark/connect/catalog.proto"; +import "spark/connect/ml_common.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -83,6 +84,9 @@ message Relation { // Catalog API (experimental / unstable) Catalog catalog = 200; + // ML relation + MlRelation ml_relation = 300; + // This field is used to mark extensions to the protocol. When plugins generate arbitrary // relations they can add them here. During the planning the correct resolution is done.
google.protobuf.Any extension = 998; @@ -90,6 +94,40 @@ message Relation { } } +message MlRelation { + oneof ml_relation_type { + ModelTransform model_transform = 1; + FeatureTransform feature_transform = 2; + ModelAttr model_attr = 3; + ModelSummaryAttr model_summary_attr = 4; + } + message ModelTransform { + Relation input = 1; + ModelRef model_ref = 2; + MlParams params = 3; + } + message FeatureTransform { + Relation input = 1; + MlStage transformer = 2; + } + message ModelAttr { + ModelRef model_ref = 1; + string name = 2; + } + message ModelSummaryAttr { + ModelRef model_ref = 1; + string name = 2; + MlParams params = 3; + + // Evaluation dataset used to compute + // the summary attribute. + // If not set, get the attribute from + // model.summary (i.e. the summary on the training dataset) + optional Relation evaluation_dataset = 4; + } +} + + // Used for testing purposes only. message Unknown {} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/AlgorithmRegisty.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/AlgorithmRegisty.scala new file mode 100644 index 000000000000..9585f9554802 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/AlgorithmRegisty.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
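
To make the new messages concrete, here is a rough, non-authoritative sketch of how a client could assemble a `Fit` command against the generated `org.apache.spark.connect.proto` classes; the table name `train_df`, the uid, and the `maxIter` setting are illustrative placeholders, not part of this change:

```scala
import org.apache.spark.connect.proto

// Hypothetical construction of MlCommand.Fit for LogisticRegression(maxIter = 10).
val maxIter = proto.MlParams.ParamValue.newBuilder()
  .setLiteral(proto.Expression.Literal.newBuilder().setInteger(10))
  .build()

val estimator = proto.MlStage.newBuilder()
  .setName("LogisticRegression") // resolved on the server via AlgorithmRegistry
  .setUid("LogisticRegression_placeholder_uid")
  .setType(proto.MlStage.StageType.STAGE_TYPE_ESTIMATOR)
  .setParams(proto.MlParams.newBuilder().putParams("maxIter", maxIter))
  .build()

val trainingData = proto.Relation.newBuilder()
  .setRead(proto.Read.newBuilder()
    .setNamedTable(proto.Read.NamedTable.newBuilder().setUnparsedIdentifier("train_df")))
  .build()

// The command is carried by Plan.ml_command and answered with an MlCommandResponse
// carrying a ModelRef (see MLHandler below).
val plan = proto.Plan.newBuilder()
  .setMlCommand(proto.MlCommand.newBuilder()
    .setFit(proto.MlCommand.Fit.newBuilder()
      .setEstimator(estimator)
      .setDataset(trainingData)))
  .build()
```
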
+ */ + +package org.apache.spark.ml.connect + +import org.apache.spark.connect.proto +import org.apache.spark.ml +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.classification.TrainingSummary +import org.apache.spark.ml.util.{HasTrainingSummary, MLWriter} +import org.apache.spark.sql.DataFrame + +object AlgorithmRegistry { + + def get(name: String): Algorithm = { + name match { + case "LogisticRegression" => new LogisticRegressionAlgorithm + case _ => + throw new IllegalArgumentException() + } + } + +} + +abstract class Algorithm { + + def initiateEstimator(uid: String): Estimator[_] + + def getModelAttr(model: Model[_], name: String): + Option[Either[proto.MlCommandResponse, DataFrame]] = { + + name match { + case "hasSummary" => + if (model.isInstanceOf[HasTrainingSummary[_]]) { + Some(Left( + Serializer.serializeResponseValue(model.asInstanceOf[HasTrainingSummary[_]].hasSummary) + )) + } else None + case "toString" => + Some(Left(Serializer.serializeResponseValue(model.toString))) + case _ => None + } + } + + def getModelSummaryAttr( + model: Model[_], + name: String, + datasetOpt: Option[DataFrame]): Either[proto.MlCommandResponse, DataFrame] + + def loadModel(path: String): Model[_] + + def loadEstimator(path: String): Estimator[_] + + protected def getEstimatorWriter(estimator: Estimator[_]): MLWriter + + protected def getModelWriter(model: Model[_]): MLWriter + + def _save( + writer: MLWriter, + path: String, + overwrite: Boolean, + options: Map[String, String]): Unit = { + if (overwrite) { + writer.overwrite() + } + options.foreach { case (k, v) => writer.option(k, v) } + writer.save(path) + } + + def saveModel( + model: Model[_], + path: String, + overwrite: Boolean, + options: Map[String, String]): Unit = { + _save(getModelWriter(model), path, overwrite, options) + } + + def saveEstimator( + estimator: Estimator[_], + path: String, + overwrite: Boolean, + options: Map[String, String]): Unit = { + _save(getEstimatorWriter(estimator), path, overwrite, options) + } +} + +class LogisticRegressionAlgorithm extends Algorithm { + + override def initiateEstimator(uid: String): Estimator[_] = { + new ml.classification.LogisticRegression(uid) + } + + override def loadModel(path: String): Model[_] = { + ml.classification.LogisticRegressionModel.load(path) + } + + override def loadEstimator(path: String): Estimator[_] = { + ml.classification.LogisticRegression.load(path) + } + + protected override def getModelWriter(model: Model[_]): MLWriter = { + model.asInstanceOf[ml.classification.LogisticRegressionModel].write + } + + protected override def getEstimatorWriter(estimator: Estimator[_]): MLWriter = { + estimator.asInstanceOf[ml.classification.LogisticRegression].write + } + + override def getModelAttr( + model: Model[_], + name: String): Option[Either[proto.MlCommandResponse, DataFrame]] = { + + super.getModelAttr(model, name).orElse { + val lorModel = model.asInstanceOf[ml.classification.LogisticRegressionModel] + + name match { + case "numClasses" => Some(Left(Serializer.serializeResponseValue(lorModel.numClasses))) + case "numFeatures" => Some(Left(Serializer.serializeResponseValue(lorModel.numFeatures))) + case "intercept" => Some(Left(Serializer.serializeResponseValue(lorModel.intercept))) + case "interceptVector" => + Some(Left(Serializer.serializeResponseValue(lorModel.interceptVector))) + case "coefficients" => Some(Left(Serializer.serializeResponseValue(lorModel.coefficients))) + case "coefficientMatrix" => +
Some(Left(Serializer.serializeResponseValue(lorModel.coefficientMatrix))) + case _ => + None + } + } + } + + override def getModelSummaryAttr( + model: Model[_], + name: String, + datasetOpt: Option[DataFrame]): Either[proto.MlCommandResponse, DataFrame] = { + val lorModel = model.asInstanceOf[ml.classification.LogisticRegressionModel] + + val summary = datasetOpt match { + case Some(dataset) => lorModel.evaluate(dataset) + case None => lorModel.summary + } + val attrValueOpt = if (lorModel.numClasses <= 2) { + SummaryUtils.getBinaryClassificationSummaryAttr(summary.asBinary, name) + } else { + SummaryUtils.getClassificationSummaryAttr(summary, name) + } + attrValueOpt + .orElse(if (datasetOpt.isEmpty) { + SummaryUtils.getTrainingSummaryAttr(summary.asInstanceOf[TrainingSummary], name) + } else None) + .orElse { + val lorSummary = summary + name match { + case "probabilityCol" => + Some(Left(Serializer.serializeResponseValue(lorSummary.probabilityCol))) + case "featuresCol" => + Some(Left(Serializer.serializeResponseValue(lorSummary.featuresCol))) + case _ => + throw new IllegalArgumentException() + } + } + .get + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLCache.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLCache.scala new file mode 100644 index 000000000000..15c4582155b8 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLCache.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.connect + +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.ml.Model + +/** + * This class is for managing server side object that is used by spark connect client side code. 
+ */ +class ObjectCache[T]( + val objectMap: ConcurrentHashMap[String, T] = new ConcurrentHashMap[String, T]() +) { + def register(obj: T): String = { + val objectId = UUID.randomUUID().toString.takeRight(12) + objectMap.put(objectId, obj) + objectId + } + + def get(id: String): T = objectMap.get(id) + + def remove(id: String): T = objectMap.remove(id) +} + +class ModelCache( + val cachedModel: ObjectCache[Model[_]] = new ObjectCache[Model[_]](), + val modelToHandlerMap: ConcurrentHashMap[String, Algorithm] = + new ConcurrentHashMap[String, Algorithm]()) { + def register(model: Model[_], algorithm: Algorithm): String = { + val refId = cachedModel.register(model) + modelToHandlerMap.put(refId, algorithm) + refId + } + + def get(refId: String): (Model[_], Algorithm) = { + (cachedModel.get(refId), modelToHandlerMap.get(refId)) + } + + def remove(refId: String): Unit = { + cachedModel.remove(refId) + modelToHandlerMap.remove(refId) + } +} + +case class MLCache(modelCache: ModelCache = new ModelCache()) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLHandler.scala new file mode 100644 index 000000000000..bf543fdb1b49 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLHandler.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
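
A brief sketch of how these caches are intended to be used; the `ObjectCache[String]` instance is only a stand-in to show the id-based lifecycle, since real entries are `Model[_]` objects registered by `MLHandler`:

```scala
import org.apache.spark.ml.connect.{MLCache, ObjectCache}

// ObjectCache hands out a short random id for each registered object.
val demoCache = new ObjectCache[String]()
val id = demoCache.register("some server-side object")
assert(demoCache.get(id) == "some server-side object")
demoCache.remove(id)

// ModelCache additionally remembers which Algorithm handler owns each model, so a
// SessionHolder-scoped MLCache can resolve a ModelRef back to (model, algorithm):
//   val refId = mlCache.modelCache.register(model, algo)   // after estimator.fit
//   val (model, algo) = mlCache.modelCache.get(refId)      // transform / attribute fetch
//   mlCache.modelCache.remove(refId)                       // MlCommand.DeleteModel
val mlCache = MLCache()
```
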
+ */ + +package org.apache.spark.ml.connect + +import scala.collection.JavaConverters._ +import scala.language.existentials + +import org.apache.spark.connect.proto +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.connect.ConnectSqlUtil +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter +import org.apache.spark.sql.connect.service.SessionHolder + +object MLHandler { + + def handleMlCommand( + sessionHolder: SessionHolder, + mlCommand: proto.MlCommand): proto.MlCommandResponse = { + mlCommand.getMlCommandTypeCase match { + case proto.MlCommand.MlCommandTypeCase.FIT => + val fitCommandProto = mlCommand.getFit + val estimatorProto = fitCommandProto.getEstimator + assert(estimatorProto.getType == proto.MlStage.StageType.STAGE_TYPE_ESTIMATOR) + + val algoName = fitCommandProto.getEstimator.getName + val algo = AlgorithmRegistry.get(algoName) + + val estimator = algo.initiateEstimator(estimatorProto.getUid) + MLUtils.setInstanceParams(estimator, estimatorProto.getParams) + val dataset = ConnectSqlUtil.parseRelationProto(fitCommandProto.getDataset, sessionHolder) + val model = estimator.fit(dataset).asInstanceOf[Model[_]] + val refId = sessionHolder.mlCache.modelCache.register(model, algo) + + proto.MlCommandResponse + .newBuilder() + .setModelRef( + proto.ModelRef.newBuilder().setId(refId) + ) + .build() + + case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR => + val getModelAttrProto = mlCommand.getFetchModelAttr + val (model, algo) = + sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRef.getId) + algo.getModelAttr(model, getModelAttrProto.getName).get.left.get + + case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR => + val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr + val (model, algo) = + sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRef.getId) + // Create a copied model to avoid concurrently modify model params. 
+ val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]] + MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams) + + val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) { + val evalDF = ConnectSqlUtil.parseRelationProto( + getModelSummaryAttrProto.getEvaluationDataset, + sessionHolder) + Some(evalDF) + } else None + + algo + .getModelSummaryAttr(copiedModel, getModelSummaryAttrProto.getName, datasetOpt) + .left + .get + + case proto.MlCommand.MlCommandTypeCase.LOAD_MODEL => + val loadModelProto = mlCommand.getLoadModel + val algo = AlgorithmRegistry.get(loadModelProto.getName) + val model = algo.loadModel(loadModelProto.getPath) + val refId = sessionHolder.mlCache.modelCache.register(model, algo) + + proto.MlCommandResponse + .newBuilder() + .setModelInfo( + proto.MlCommandResponse.ModelInfo.newBuilder() + .setModelRef(proto.ModelRef.newBuilder().setId(refId)) + .setModelUid(model.uid) + .setParams(MLUtils.convertInstanceParamsToProto(model)) + ) + .build() + + case proto.MlCommand.MlCommandTypeCase.SAVE_MODEL => + val saveModelProto = mlCommand.getSaveModel + val (model, algo) = + sessionHolder.mlCache.modelCache.get(saveModelProto.getModelRef.getId) + algo.saveModel( + model, + saveModelProto.getPath, + saveModelProto.getOverwrite, + saveModelProto.getOptionsMap.asScala.toMap) + proto.MlCommandResponse + .newBuilder() + .setLiteral(LiteralValueProtoConverter.toLiteralProto(null)) + .build() + + case proto.MlCommand.MlCommandTypeCase.LOAD_STAGE => + val loadStageProto = mlCommand.getLoadStage + val name = loadStageProto.getName + loadStageProto.getType match { + case proto.MlStage.StageType.STAGE_TYPE_ESTIMATOR => + val algo = AlgorithmRegistry.get(name) + val estimator = algo.loadEstimator(loadStageProto.getPath) + + proto.MlCommandResponse + .newBuilder() + .setStage( + proto.MlStage + .newBuilder() + .setName(name) + .setType(proto.MlStage.StageType.STAGE_TYPE_ESTIMATOR) + .setUid(estimator.uid) + .setParams(MLUtils.convertInstanceParamsToProto(estimator))) + .build() + case _ => + throw new UnsupportedOperationException() + } + + case proto.MlCommand.MlCommandTypeCase.SAVE_STAGE => + val saveStageProto = mlCommand.getSaveStage + val stageProto = saveStageProto.getStage + + stageProto.getType match { + case proto.MlStage.StageType.STAGE_TYPE_ESTIMATOR => + val name = stageProto.getName + val algo = AlgorithmRegistry.get(name) + val estimator = algo.initiateEstimator(stageProto.getUid) + MLUtils.setInstanceParams(estimator, stageProto.getParams) + algo.saveEstimator( + estimator, + saveStageProto.getPath, + saveStageProto.getOverwrite, + saveStageProto.getOptionsMap.asScala.toMap) + proto.MlCommandResponse + .newBuilder() + .setLiteral(LiteralValueProtoConverter.toLiteralProto(null)) + .build() + + case _ => + throw new UnsupportedOperationException() + } + + case proto.MlCommand.MlCommandTypeCase.COPY_MODEL => + val copyModelProto = mlCommand.getCopyModel + val (model, algo) = + sessionHolder.mlCache.modelCache.get(copyModelProto.getModelRef.getId) + val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]] + val refId = sessionHolder.mlCache.modelCache.register(copiedModel, algo) + proto.MlCommandResponse + .newBuilder() + .setModelRef(proto.ModelRef.newBuilder().setId(refId)) + .build() + + case proto.MlCommand.MlCommandTypeCase.DELETE_MODEL => + val modelRefId = mlCommand.getDeleteModel.getModelRef.getId + sessionHolder.mlCache.modelCache.remove(modelRefId) + proto.MlCommandResponse + .newBuilder() + 
.setLiteral(LiteralValueProtoConverter.toLiteralProto(null)) + .build() + + case _ => + throw new IllegalArgumentException() + } + } + + def transformMLRelation( + mlRelationProto: proto.MlRelation, + sessionHolder: SessionHolder): DataFrame = { + mlRelationProto.getMlRelationTypeCase match { + case proto.MlRelation.MlRelationTypeCase.MODEL_TRANSFORM => + val modelTransformRelationProto = mlRelationProto.getModelTransform + val (model, _) = + sessionHolder.mlCache.modelCache.get(modelTransformRelationProto.getModelRef.getId) + // Create a copied model to avoid concurrently modify model params. + val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]] + MLUtils.setInstanceParams(copiedModel, modelTransformRelationProto.getParams) + val inputDF = + ConnectSqlUtil.parseRelationProto(modelTransformRelationProto.getInput, sessionHolder) + copiedModel.transform(inputDF) + + case proto.MlRelation.MlRelationTypeCase.MODEL_ATTR => + val modelAttrProto = mlRelationProto.getModelAttr + val (model, algo) = + sessionHolder.mlCache.modelCache.get(modelAttrProto.getModelRef.getId) + algo.getModelAttr(model, modelAttrProto.getName).get.right.get + + case proto.MlRelation.MlRelationTypeCase.MODEL_SUMMARY_ATTR => + val modelSummaryAttr = mlRelationProto.getModelSummaryAttr + val (model, algo) = + sessionHolder.mlCache.modelCache.get(modelSummaryAttr.getModelRef.getId) + // Create a copied model to avoid concurrently modify model params. + val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]] + MLUtils.setInstanceParams(copiedModel, modelSummaryAttr.getParams) + + val datasetOpt = if (modelSummaryAttr.hasEvaluationDataset) { + val evalDF = + ConnectSqlUtil.parseRelationProto(modelSummaryAttr.getEvaluationDataset, sessionHolder) + Some(evalDF) + } else { + None + } + algo.getModelSummaryAttr(copiedModel, modelSummaryAttr.getName, datasetOpt).right.get + + case _ => + throw new IllegalArgumentException() + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLUtils.scala new file mode 100644 index 000000000000..45375c31af92 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/MLUtils.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
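
As a hedged illustration of the round trip the handler supports: `MlCommand.Fit` registers the model and returns a `ModelRef`, and that reference is then embedded into an `MlRelation.ModelTransform` so predictions come back as an ordinary relation. The `refId` below is assumed to have been obtained from a previous `handleMlCommand` call:

```scala
import org.apache.spark.connect.proto

// Wrap a server-side model reference into a relation that, when planned,
// runs `copiedModel.transform(input)` inside MLHandler.transformMLRelation.
def modelTransformRelation(refId: String, input: proto.Relation): proto.Relation = {
  proto.Relation.newBuilder()
    .setMlRelation(proto.MlRelation.newBuilder()
      .setModelTransform(proto.MlRelation.ModelTransform.newBuilder()
        .setModelRef(proto.ModelRef.newBuilder().setId(refId))
        .setInput(input)))
    .build()
}
```
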
+ */ + +package org.apache.spark.ml.connect + +import org.apache.spark.connect.proto +import org.apache.spark.ml.linalg.{Matrix, Vector} +import org.apache.spark.ml.param.Params +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter + +object MLUtils { + + def setInstanceParams(instance: Params, paramsProto: proto.MlParams): Unit = { + import scala.collection.JavaConverters._ + paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) => + val paramDef = instance.getParam(paramName) + val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto) + instance.set(paramDef, paramValue) + } + paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) => + val paramDef = instance.getParam(paramName) + val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto) + instance._setDefault(paramDef -> paramValue) + } + } + + def parseParamValue(paramType: Class[_], paramValueProto: proto.MlParams.ParamValue): Any = { + paramValueProto.getParamValueTypeCase match { + case proto.MlParams.ParamValue.ParamValueTypeCase.VECTOR => + Serializer.deserializeVector(paramValueProto.getVector) + case proto.MlParams.ParamValue.ParamValueTypeCase.MATRIX => + Serializer.deserializeMatrix(paramValueProto.getMatrix) + case proto.MlParams.ParamValue.ParamValueTypeCase.LITERAL => + val value = LiteralValueProtoConverter.toCatalystValue(paramValueProto.getLiteral) + _convertParamValue(paramType, value) + case _ => + throw new IllegalArgumentException() + } + } + + def paramValueToProto(paramValue: Any): proto.MlParams.ParamValue = { + paramValue match { + case v: Vector => + proto.MlParams.ParamValue.newBuilder() + .setVector(Serializer.serializeVector(v)) + .build() + case m: Matrix => + proto.MlParams.ParamValue.newBuilder() + .setMatrix(Serializer.serializeMatrix(m)) + .build() + case _ => + val literalProto = LiteralValueProtoConverter.toLiteralProto(paramValue) + proto.MlParams.ParamValue.newBuilder() + .setLiteral(literalProto) + .build() + } + } + + def _convertParamValue(paramType: Class[_], value: Any): Any = { + // Some cases the param type might be mismatched with the value type. + // Because in python side we only have int / float type for numeric params. + // e.g.: + // param type is Int but client sends a Long type. + // param type is Long but client sends a Int type. + // param type is Float but client sends a Double type. + // param type is Array[Int] but client sends a Array[Long] type. + // param type is Array[Float] but client sends a Array[Double] type. + // param type is Array[Array[Int]] but client sends a Array[Array[Long]] type. + // param type is Array[Array[Float]] but client sends a Array[Array[Double]] type. 
+ if (paramType == classOf[Byte]) { + value.asInstanceOf[java.lang.Number].byteValue() + } else if (paramType == classOf[Short]) { + value.asInstanceOf[java.lang.Number].shortValue() + } else if (paramType == classOf[Int]) { + value.asInstanceOf[java.lang.Number].intValue() + } else if (paramType == classOf[Long]) { + value.asInstanceOf[java.lang.Number].longValue() + } else if (paramType == classOf[Float]) { + value.asInstanceOf[java.lang.Number].floatValue() + } else if (paramType == classOf[Double]) { + value.asInstanceOf[java.lang.Number].doubleValue() + } else if (paramType.isArray) { + val compType = paramType.getComponentType + value.asInstanceOf[Array[_]].map { e => + _convertParamValue(compType, e) + } + } else { + value + } + } + + def convertInstanceParamsToProto(instance: Params): proto.MlParams = { + val builder = proto.MlParams.newBuilder() + instance.params.foreach { param => + val name = param.name + val valueOpt = instance.get(param) + val defaultValueOpt = instance.getDefault(param) + + if (valueOpt.isDefined) { + val valueProto = paramValueToProto(valueOpt.get) + builder.putParams(name, valueProto) + } + if (defaultValueOpt.isDefined) { + val defaultValueProto = paramValueToProto(defaultValueOpt.get) + builder.putDefaultParams(name, defaultValueProto) + } + } + builder.build() + } + +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/Serializer.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/Serializer.scala new file mode 100644 index 000000000000..0a254f9da667 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/Serializer.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
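
A small sketch of the numeric narrowing described in the `_convertParamValue` comment above; the literal values are arbitrary examples:

```scala
import org.apache.spark.ml.connect.MLUtils

// The Python client only carries 64-bit ints and floats, so an Int param
// (e.g. maxIter) arrives as a Long and a Float param may arrive as a Double;
// the server narrows the value to the declared param type before calling set().
assert(MLUtils._convertParamValue(classOf[Int], 10L) == 10)
assert(MLUtils._convertParamValue(classOf[Float], 0.5d) == 0.5f)

// Array-typed params are converted element by element, e.g. Array[Long] -> Array[Int].
val converted = MLUtils._convertParamValue(classOf[Array[Int]], Array(4L, 3L, 2L))
assert(converted.asInstanceOf[Array[_]].toSeq == Seq(4, 3, 2))
```
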
+ */ + +package org.apache.spark.ml.connect + +import org.apache.spark.connect.proto +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.sql.connect.common.LiteralValueProtoConverter + +object Serializer { + + def serializeResponseValue(data: Any): proto.MlCommandResponse = { + data match { + case v: Vector => + val vectorProto = serializeVector(v) + proto.MlCommandResponse + .newBuilder() + .setVector(vectorProto) + .build() + case v: Matrix => + val matrixProto = serializeMatrix(v) + proto.MlCommandResponse + .newBuilder() + .setMatrix(matrixProto) + .build() + case _: Byte | _: Short | _: Int | _: Long | _: Float | _: Double | _: Boolean | _: String | + _: Array[_] => + proto.MlCommandResponse + .newBuilder() + .setLiteral(LiteralValueProtoConverter.toLiteralProto(data)) + .build() + } + } + + def serializeVector(data: Vector): proto.Vector = { + // TODO: Support sparse + val values = data.toArray + val denseBuilder = proto.Vector.Dense.newBuilder() + for (i <- 0 until values.length) { + denseBuilder.addValue(values(i)) + } + proto.Vector.newBuilder().setDense(denseBuilder).build() + } + + def serializeMatrix(data: Matrix): proto.Matrix = { + // TODO: Support sparse + // TODO: optimize transposed case + val denseBuilder = proto.Matrix.Dense.newBuilder() + val values = data.toArray + for (i <- 0 until values.length) { + denseBuilder.addValue(values(i)) + } + denseBuilder.setNumCols(data.numCols) + denseBuilder.setNumRows(data.numRows) + denseBuilder.setIsTransposed(false) + proto.Matrix.newBuilder().setDense(denseBuilder).build() + } + + def deserializeVector(protoValue: proto.Vector): Vector = { + // TODO: Support sparse + Vectors.dense( + protoValue.getDense.getValueList.stream().mapToDouble(_.doubleValue()).toArray + ) + } + + def deserializeMatrix(protoValue: proto.Matrix): Matrix = { + // TODO: Support sparse + val denseProto = protoValue.getDense + Matrices.dense( + denseProto.getNumRows, + denseProto.getNumCols, + denseProto.getValueList.stream().mapToDouble(_.doubleValue()).toArray + ) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/SummaryUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/SummaryUtils.scala new file mode 100644 index 000000000000..59aee69f2df4 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/ml/connect/SummaryUtils.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
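
A minimal round-trip sketch for the serializer above (dense encodings only, matching the current TODOs):

```scala
import org.apache.spark.ml.connect.Serializer
import org.apache.spark.ml.linalg.{Matrices, Vectors}

// Vectors round-trip through proto.Vector.Dense.
val v = Vectors.dense(1.0, 0.0, 3.0)
assert(Serializer.deserializeVector(Serializer.serializeVector(v)) == v)

// Matrices are flattened column-major via toArray and rebuilt with
// Matrices.dense(numRows, numCols, values); is_transposed is always false for now.
val m = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))
assert(Serializer.deserializeMatrix(Serializer.serializeMatrix(m)) == m)
```
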
+ */ + +package org.apache.spark.ml.connect + +import org.apache.spark.connect.proto +import org.apache.spark.ml.classification.{BinaryClassificationSummary, ClassificationSummary, TrainingSummary} +import org.apache.spark.sql.DataFrame + +object SummaryUtils { + + def getClassificationSummaryAttr( + summary: ClassificationSummary, + name: String): Option[Either[proto.MlCommandResponse, DataFrame]] = { + name match { + case "predictions" => Some(Right(summary.predictions)) + case "predictionCol" => Some(Left(Serializer.serializeResponseValue(summary.predictionCol))) + case "labelCol" => Some(Left(Serializer.serializeResponseValue(summary.labelCol))) + case "weightCol" => Some(Left(Serializer.serializeResponseValue(summary.weightCol))) + case "labels" => Some(Left(Serializer.serializeResponseValue(summary.labels))) + case "truePositiveRateByLabel" => + Some(Left(Serializer.serializeResponseValue(summary.truePositiveRateByLabel))) + case "falsePositiveRateByLabel" => + Some(Left(Serializer.serializeResponseValue(summary.falsePositiveRateByLabel))) + case "precisionByLabel" => + Some(Left(Serializer.serializeResponseValue(summary.precisionByLabel))) + case "recallByLabel" => + Some(Left(Serializer.serializeResponseValue(summary.recallByLabel))) + // TODO: Support beta params. + case "fMeasureByLabel" => + Some(Left(Serializer.serializeResponseValue(summary.fMeasureByLabel))) + case "accuracy" => Some(Left(Serializer.serializeResponseValue(summary.accuracy))) + case "weightedTruePositiveRate" => + Some(Left(Serializer.serializeResponseValue(summary.weightedTruePositiveRate))) + case "weightedFalsePositiveRate" => + Some(Left(Serializer.serializeResponseValue(summary.weightedFalsePositiveRate))) + case "weightedRecall" => Some(Left(Serializer.serializeResponseValue(summary.weightedRecall))) + case "weightedPrecision" => + Some(Left(Serializer.serializeResponseValue(summary.weightedPrecision))) + // TODO: Support beta params. 
+ case "weightedFMeasure" => + Some(Left(Serializer.serializeResponseValue(summary.weightedFMeasure))) + case _ => None + } + } + + def getBinaryClassificationSummaryAttr( + summary: BinaryClassificationSummary, + name: String): Option[Either[proto.MlCommandResponse, DataFrame]] = { + getClassificationSummaryAttr(summary, name).orElse(name match { + case "scoreCol" => Some(Left(Serializer.serializeResponseValue(summary.scoreCol))) + case "roc" => Some(Right(summary.roc)) + case "areaUnderROC" => Some(Left(Serializer.serializeResponseValue(summary.areaUnderROC))) + case "pr" => Some(Right(summary.pr)) + case "fMeasureByThreshold" => Some(Right(summary.fMeasureByThreshold)) + case "precisionByThreshold" => Some(Right(summary.precisionByThreshold)) + case "recallByThreshold" => Some(Right(summary.recallByThreshold)) + case _ => None + }) + } + + def getTrainingSummaryAttr( + summary: TrainingSummary, + name: String): Option[Either[proto.MlCommandResponse, DataFrame]] = { + name match { + case "objectiveHistory" => + Some(Left(Serializer.serializeResponseValue(summary.objectiveHistory))) + case "totalIterations" => + Some(Left(Serializer.serializeResponseValue(summary.totalIterations))) + case _ => None + } + } + +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ConnectSqlUtil.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ConnectSqlUtil.scala new file mode 100644 index 000000000000..acd408eead37 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ConnectSqlUtil.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect + +import org.apache.spark.connect.proto +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.service.SessionHolder + +object ConnectSqlUtil { + def parseRelationProto( + relationProto: proto.Relation, + sessionHolder: SessionHolder + ): DataFrame = { + val relationalPlanner = new SparkConnectPlanner(sessionHolder) + val plan = relationalPlanner.transformRelation(relationProto) + Dataset.ofRows(sessionHolder.session, plan) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6de08862cd7e..17d6a231793a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -31,7 +31,8 @@ import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand} import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.ml.{functions => MLFunctions} -import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} +import org.apache.spark.ml.connect.MLHandler +import org.apache.spark.sql.{Column, Dataset, Encoders} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -44,7 +45,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry -import org.apache.spark.sql.connect.service.SparkConnectStreamHandler +import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectStreamHandler} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -61,7 +62,9 @@ final case class InvalidCommandInput( private val cause: Throwable = null) extends Exception(message, cause) -class SparkConnectPlanner(val session: SparkSession) { +class SparkConnectPlanner(val sessionHolder: SessionHolder) { + val session = sessionHolder.session + private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) @@ -127,6 +130,10 @@ class SparkConnectPlanner(val session: SparkSession) { // Catalog API (internal-only) case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog) + // ML relation + case proto.Relation.RelTypeCase.ML_RELATION => + MLHandler.transformMLRelation(rel.getMlRelation, sessionHolder).logicalPlan + // Handle plugins for Spark Connect Relation types. 
case proto.Relation.RelTypeCase.EXTENSION => transformRelationPlugin(rel.getExtension) @@ -1768,7 +1775,7 @@ class SparkConnectPlanner(val session: SparkSession) { */ private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = { // Transform the input plan into the logical plan. - val planner = new SparkConnectPlanner(session) + val planner = new SparkConnectPlanner(sessionHolder) val plan = planner.transformRelation(writeOperation.getInput) // And create a Dataset from the plan. val dataset = Dataset.ofRows(session, logicalPlan = plan) @@ -1839,7 +1846,7 @@ class SparkConnectPlanner(val session: SparkSession) { */ def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = { // Transform the input plan into the logical plan. - val planner = new SparkConnectPlanner(session) + val planner = new SparkConnectPlanner(sessionHolder) val plan = planner.transformRelation(writeOperation.getInput) // And create a Dataset from the plan. val dataset = Dataset.ofRows(session, logicalPlan = plan) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 4697a1fd7d42..e2fc83a54df7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -23,7 +23,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode} @@ -33,12 +33,12 @@ private[connect] class SparkConnectAnalyzeHandler( extends Logging { def handle(request: proto.AnalyzePlanRequest): Unit = { - val session = + val sessionHolder = SparkConnectService .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId) - .session - session.withActive { - val response = process(request, session) + + sessionHolder.session.withActive { + val response = process(request, sessionHolder) responseObserver.onNext(response) responseObserver.onCompleted() } @@ -46,8 +46,9 @@ private[connect] class SparkConnectAnalyzeHandler( def process( request: proto.AnalyzePlanRequest, - session: SparkSession): proto.AnalyzePlanResponse = { - lazy val planner = new SparkConnectPlanner(session) + sessionHolder: SessionHolder): proto.AnalyzePlanResponse = { + lazy val planner = new SparkConnectPlanner(sessionHolder) + val session = sessionHolder.session val builder = proto.AnalyzePlanResponse.newBuilder() request.getAnalyzeCase match { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala old mode 100755 new mode 100644 index cd353b6ff609..249989598782 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -41,6 +41,7 @@ import 
org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} import org.apache.spark.internal.Logging +import org.apache.spark.ml.connect.MLCache import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT @@ -210,7 +211,11 @@ class SparkConnectService(debug: Boolean) * @param userId * @param session */ -case class SessionHolder(userId: String, sessionId: String, session: SparkSession) +case class SessionHolder( + userId: String, + sessionId: String, + session: SparkSession, + mlCache: MLCache = MLCache()) /** * Static instance of the SparkConnectService. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 335b871d499b..20148770ee23 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -26,7 +26,8 @@ import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.ml.connect.MLHandler +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto @@ -43,23 +44,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp extends Logging { def handle(v: ExecutePlanRequest): Unit = { - val session = + val sessionHolder = SparkConnectService .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) - .session + val session = sessionHolder.session session.withActive { v.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v) - case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v) + case proto.Plan.OpTypeCase.COMMAND => handleCommand(sessionHolder, v) + case proto.Plan.OpTypeCase.ROOT => handlePlan(sessionHolder, v) + case proto.Plan.OpTypeCase.ML_COMMAND => + handleMlCommand(sessionHolder, v) case _ => throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.") } } } - private def handlePlan(session: SparkSession, request: ExecutePlanRequest): Unit = { + private def handleMlCommand(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = { + val mlResultProto = MLHandler.handleMlCommand(sessionHolder, request.getPlan.getMlCommand) + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionHolder.sessionId) + .setMlCommandResult(mlResultProto) + .build()) + } + + private def handlePlan(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = { // Extract the plan from the request and convert it to a logical plan - val planner = new SparkConnectPlanner(session) + val session = sessionHolder.session + val planner = new SparkConnectPlanner(sessionHolder) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) responseObserver.onNext( 
SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema)) @@ -73,9 +87,9 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp responseObserver.onCompleted() } - private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = { + private def handleCommand(sessionHolder: SessionHolder, request: ExecutePlanRequest): Unit = { val command = request.getPlan.getCommand - val planner = new SparkConnectPlanner(session) + val planner = new SparkConnectPlanner(sessionHolder) planner.process(command, request.getSessionId, responseObserver) responseObserver.onCompleted() } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index e20a6159cc8a..979e1721851e 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.test.SharedSparkSession @@ -59,6 +60,9 @@ import org.apache.spark.util.Utils */ // scalastyle:on class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession { + + def sessionHolder: SessionHolder = SessionHolder("user1", "session1", spark) + val url = "jdbc:h2:mem:testdb0" var conn: java.sql.Connection = null @@ -164,7 +168,7 @@ class ProtoToParsedPlanTestSuite extends SparkFunSuite with SharedSparkSession { val name = fileName.stripSuffix(".proto.bin") test(name) { val relation = readRelation(file) - val planner = new SparkConnectPlanner(spark) + val planner = new SparkConnectPlanner(sessionHolder) val catalystPlan = analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker) val actual = normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index b6b214c839dc..7d3335c2d246 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProj import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto +import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -43,6 +44,8 @@ import org.apache.spark.unsafe.types.UTF8String */ 
trait SparkConnectPlanTest extends SharedSparkSession { + def sessionHolder: SessionHolder = SessionHolder("user1", "session1", spark) + class MockObserver extends StreamObserver[proto.ExecutePlanResponse] { override def onNext(value: ExecutePlanResponse): Unit = {} override def onError(t: Throwable): Unit = {} @@ -50,11 +53,11 @@ trait SparkConnectPlanTest extends SharedSparkSession { } def transform(rel: proto.Relation): logical.LogicalPlan = { - new SparkConnectPlanner(spark).transformRelation(rel) + new SparkConnectPlanner(sessionHolder).transformRelation(rel) } def transform(cmd: proto.Command): Unit = { - new SparkConnectPlanner(spark).process(cmd, "clientId", new MockObserver()) + new SparkConnectPlanner(sessionHolder).process(cmd, "clientId", new MockObserver()) } def readRel: proto.Relation = @@ -104,7 +107,7 @@ trait SparkConnectPlanTest extends SharedSparkSession { class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { test("Simple Limit") { - assertThrows[IndexOutOfBoundsException] { + assertThrows[NullPointerException] { new SparkConnectPlanner(None.orNull) .transformRelation( proto.Relation.newBuilder @@ -115,7 +118,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { test("InvalidInputs") { // No Relation Set - intercept[IndexOutOfBoundsException]( + intercept[NullPointerException]( new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build())) intercept[InvalidPlanInput]( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index c36ba76f9845..7c15f73a3ba0 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ -import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService} +import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectAnalyzeHandler, SparkConnectService} import org.apache.spark.sql.test.SharedSparkSession /** @@ -37,6 +37,8 @@ import org.apache.spark.sql.test.SharedSparkSession */ class SparkConnectServiceSuite extends SharedSparkSession { + def sessionHolder: SessionHolder = SessionHolder("user1", "session1", spark) + test("Test schema in analyze response") { withTable("test") { spark.sql(""" @@ -64,7 +66,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setSchema(proto.AnalyzePlanRequest.Schema.newBuilder().setPlan(plan).build()) .build() - val response1 = handler.process(request1, spark) + val response1 = handler.process(request1, sessionHolder) assert(response1.hasSchema) assert(response1.getSchema.getSchema.hasStruct) val schema = response1.getSchema.getSchema.getStruct @@ -85,7 +87,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .setExplainMode(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE) .build()) .build() - val response2 = handler.process(request2, spark) + val response2 = handler.process(request2, sessionHolder) assert(response2.hasExplain) assert(response2.getExplain.getExplainString.size > 
0) @@ -93,7 +95,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setIsLocal(proto.AnalyzePlanRequest.IsLocal.newBuilder().setPlan(plan).build()) .build() - val response3 = handler.process(request3, spark) + val response3 = handler.process(request3, sessionHolder) assert(response3.hasIsLocal) assert(!response3.getIsLocal.getIsLocal) @@ -101,7 +103,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setIsStreaming(proto.AnalyzePlanRequest.IsStreaming.newBuilder().setPlan(plan).build()) .build() - val response4 = handler.process(request4, spark) + val response4 = handler.process(request4, sessionHolder) assert(response4.hasIsStreaming) assert(!response4.getIsStreaming.getIsStreaming) @@ -109,7 +111,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setTreeString(proto.AnalyzePlanRequest.TreeString.newBuilder().setPlan(plan).build()) .build() - val response5 = handler.process(request5, spark) + val response5 = handler.process(request5, sessionHolder) assert(response5.hasTreeString) val treeString = response5.getTreeString.getTreeString assert(treeString.contains("root")) @@ -120,7 +122,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setInputFiles(proto.AnalyzePlanRequest.InputFiles.newBuilder().setPlan(plan).build()) .build() - val response6 = handler.process(request6, spark) + val response6 = handler.process(request6, sessionHolder) assert(response6.hasInputFiles) assert(response6.getInputFiles.getFilesCount === 0) } @@ -291,7 +293,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .build()) .build() - val response = handler.process(request, spark) + val response = handler.process(request, sessionHolder) assert(response.getExplain.getExplainString.contains("Parsed Logical Plan")) assert(response.getExplain.getExplainString.contains("Analyzed Logical Plan")) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 39fc90fd0022..1fb02337e17a 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne .build())) .build() - new SparkConnectPlanner(spark).process(plan, "clientId", new MockObserver()) + new SparkConnectPlanner(sessionHolder).process(plan, "clientId", new MockObserver()) assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin")) } } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index c31a9362cd7f..239eee79ec6e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -783,6 +783,8 @@ def __hash__(self): "pyspark.ml.connect.functions", # ml unittests "pyspark.ml.tests.connect.test_connect_function", + # ml classification algorithm tests + "pyspark.ml.connect.classification", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala index b818be30583c..908dacc96424 100644 --- 
a/mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,6 +24,7 @@ import java.util.NoSuchElementException import scala.annotation.varargs import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.reflect.ClassTag import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) - extends Serializable { +class Param[T: ClassTag]( + val parent: String, val name: String, val doc: String, val isValid: T => Boolean +) extends Serializable { + + // Generic type T is erased when compiling, + // but spark connect ML needs T type information, + // so use classTag to preserve the T type. + val paramValueClassTag = implicitly[ClassTag[T]] def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) = this(parent.uid, name, doc, isValid) @@ -793,6 +800,10 @@ trait Params extends Identifiable with Serializable { this } + private[spark] def _setDefault(paramPairs: ParamPair[_]*): this.type = { + setDefault(paramPairs: _*) + } + /** * Gets the default value of a parameter. */ diff --git a/mllib/core/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala b/mllib/core/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala index 9f3428db484c..6f430f1482ba 100644 --- a/mllib/core/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala +++ b/mllib/core/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType /** * Abstraction for multiclass classification results for a given model. */ -private[classification] trait ClassificationSummary extends Serializable { +private[spark] trait ClassificationSummary extends Serializable { /** * Dataframe output by the model's `transform` method. @@ -147,7 +147,7 @@ private[classification] trait ClassificationSummary extends Serializable { /** * Abstraction for training results. */ -private[classification] trait TrainingSummary { +private[spark] trait TrainingSummary { /** * objective function (scaled loss + regularization) at each iteration. @@ -167,7 +167,7 @@ private[classification] trait TrainingSummary { /** * Abstraction for binary classification results for a given model. 
*/ -private[classification] trait BinaryClassificationSummary extends ClassificationSummary { +private[spark] trait BinaryClassificationSummary extends ClassificationSummary { private val sparkSession = predictions.sparkSession import sparkSession.implicits._ diff --git a/mllib/core/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/core/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 1ad5f7a442da..1e6986b5098f 100644 --- a/mllib/core/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/core/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -108,9 +108,6 @@ private void init() { myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); - List validStrings = Arrays.asList("a", "b"); - myStringParam_ = new Param<>(this, "myStringParam", "this is a string param", - ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index 34c3aa9c62cf..fe3d4c547ffb 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -17,6 +17,7 @@ from abc import ABCMeta, abstractmethod +import os import copy import threading diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index c09a510d76b6..b1f38afd64bf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -94,6 +94,7 @@ from pyspark.sql.functions import udf, when from pyspark.sql.types import ArrayType, DoubleType from pyspark.storagelevel import StorageLevel +from pyspark.ml.connect.utils import try_remote_ml_class if TYPE_CHECKING: @@ -1121,10 +1122,201 @@ def getUpperBoundsOnIntercepts(self) -> Vector: return self.getOrDefault(self.upperBoundsOnIntercepts) +class _LogisticRegressionCommon(ProbabilisticClassifier, _LogisticRegressionParams): + @overload + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + threshold: float = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + maxBlockSizeInMB: float = ..., + ) -> "LogisticRegression": + ... + + @overload + def setParams( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + thresholds: Optional[List[float]] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + maxBlockSizeInMB: float = ..., + ) -> "LogisticRegression": + ... 
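Note: the `setParams` overloads above, together with the setter methods that follow, are being hoisted into the shared `_LogisticRegressionCommon` mixin so that the JVM-backed `LogisticRegression` and the new Spark Connect client estimator expose one and the same keyword-only surface. A minimal usage sketch of that shared surface (parameter values are illustrative only):

from pyspark.ml.classification import LogisticRegression

# The same chained-setter API is available whether the class resolves to the
# legacy JVM wrapper or to pyspark.ml.connect.classification.LogisticRegression,
# since both now inherit it from _LogisticRegressionCommon.
lor = LogisticRegression(weightCol="weight")
lor.setMaxIter(30).setRegParam(0.01).setTol(1e-8)
lor.setParams(featuresCol="features", labelCol="label")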
+ + @keyword_only + @since("1.3.0") + def setParams( + self, + *, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxIter: int = 100, + regParam: float = 0.0, + elasticNetParam: float = 0.0, + tol: float = 1e-6, + fitIntercept: bool = True, + threshold: float = 0.5, + thresholds: Optional[List[float]] = None, + probabilityCol: str = "probability", + rawPredictionCol: str = "rawPrediction", + standardization: bool = True, + weightCol: Optional[str] = None, + aggregationDepth: int = 2, + family: str = "auto", + lowerBoundsOnCoefficients: Optional[Matrix] = None, + upperBoundsOnCoefficients: Optional[Matrix] = None, + lowerBoundsOnIntercepts: Optional[Vector] = None, + upperBoundsOnIntercepts: Optional[Vector] = None, + maxBlockSizeInMB: float = 0.0, + ) -> "LogisticRegression": + """ + setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ + aggregationDepth=2, family="auto", \ + lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \ + lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \ + maxBlockSizeInMB=0.0): + Sets params for logistic regression. + If the threshold and thresholds Params are both set, they must be equivalent. + """ + kwargs = self._input_kwargs + self._set(**kwargs) + self._checkThresholdConsistency() + return self + + @since("2.1.0") + def setFamily(self, value: str) -> "LogisticRegression": + """ + Sets the value of :py:attr:`family`. + """ + return self._set(family=value) + + @since("2.3.0") + def setLowerBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression": + """ + Sets the value of :py:attr:`lowerBoundsOnCoefficients` + """ + return self._set(lowerBoundsOnCoefficients=value) + + @since("2.3.0") + def setUpperBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression": + """ + Sets the value of :py:attr:`upperBoundsOnCoefficients` + """ + return self._set(upperBoundsOnCoefficients=value) + + @since("2.3.0") + def setLowerBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression": + """ + Sets the value of :py:attr:`lowerBoundsOnIntercepts` + """ + return self._set(lowerBoundsOnIntercepts=value) + + @since("2.3.0") + def setUpperBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression": + """ + Sets the value of :py:attr:`upperBoundsOnIntercepts` + """ + return self._set(upperBoundsOnIntercepts=value) + + def setMaxIter(self, value: int) -> "LogisticRegression": + """ + Sets the value of :py:attr:`maxIter`. + """ + return self._set(maxIter=value) + + def setRegParam(self, value: float) -> "LogisticRegression": + """ + Sets the value of :py:attr:`regParam`. + """ + return self._set(regParam=value) + + def setTol(self, value: float) -> "LogisticRegression": + """ + Sets the value of :py:attr:`tol`. + """ + return self._set(tol=value) + + def setElasticNetParam(self, value: float) -> "LogisticRegression": + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + return self._set(elasticNetParam=value) + + def setFitIntercept(self, value: bool) -> "LogisticRegression": + """ + Sets the value of :py:attr:`fitIntercept`. + """ + return self._set(fitIntercept=value) + + def setStandardization(self, value: bool) -> "LogisticRegression": + """ + Sets the value of :py:attr:`standardization`. 
+ """ + return self._set(standardization=value) + + def setWeightCol(self, value: str) -> "LogisticRegression": + """ + Sets the value of :py:attr:`weightCol`. + """ + return self._set(weightCol=value) + + def setAggregationDepth(self, value: int) -> "LogisticRegression": + """ + Sets the value of :py:attr:`aggregationDepth`. + """ + return self._set(aggregationDepth=value) + + @since("3.1.0") + def setMaxBlockSizeInMB(self, value: float) -> "LogisticRegression": + """ + Sets the value of :py:attr:`maxBlockSizeInMB`. + """ + return self._set(maxBlockSizeInMB=value) + + +@try_remote_ml_class @inherit_doc class LogisticRegression( _JavaProbabilisticClassifier["LogisticRegressionModel"], - _LogisticRegressionParams, + _LogisticRegressionCommon, JavaMLWritable, JavaMLReadable["LogisticRegression"], ): @@ -1138,11 +1330,13 @@ class LogisticRegression( -------- >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> bdf = sc.parallelize([ - ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), - ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), - ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() + >>> from pyspark.ml.functions import array_to_vector + >>> bdf = spark.createDataFrame([ + ... (1.0, 1.0, [0.0, 5.0]), + ... (0.0, 2.0, [1.0, 2.0]), + ... (1.0, 3.0, [2.0, 1.0]), + ... (0.0, 4.0, [3.0, 3.0]), + ... ], ["label", "weight", "features"]).withColumn("features", array_to_vector("features")) >>> blor = LogisticRegression(weightCol="weight") >>> blor.getRegParam() 0.0 @@ -1322,198 +1516,11 @@ def __init__( self.setParams(**kwargs) self._checkThresholdConsistency() - @overload - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - regParam: float = ..., - elasticNetParam: float = ..., - tol: float = ..., - fitIntercept: bool = ..., - threshold: float = ..., - probabilityCol: str = ..., - rawPredictionCol: str = ..., - standardization: bool = ..., - weightCol: Optional[str] = ..., - aggregationDepth: int = ..., - family: str = ..., - lowerBoundsOnCoefficients: Optional[Matrix] = ..., - upperBoundsOnCoefficients: Optional[Matrix] = ..., - lowerBoundsOnIntercepts: Optional[Vector] = ..., - upperBoundsOnIntercepts: Optional[Vector] = ..., - maxBlockSizeInMB: float = ..., - ) -> "LogisticRegression": - ... - - @overload - def setParams( - self, - *, - featuresCol: str = ..., - labelCol: str = ..., - predictionCol: str = ..., - maxIter: int = ..., - regParam: float = ..., - elasticNetParam: float = ..., - tol: float = ..., - fitIntercept: bool = ..., - thresholds: Optional[List[float]] = ..., - probabilityCol: str = ..., - rawPredictionCol: str = ..., - standardization: bool = ..., - weightCol: Optional[str] = ..., - aggregationDepth: int = ..., - family: str = ..., - lowerBoundsOnCoefficients: Optional[Matrix] = ..., - upperBoundsOnCoefficients: Optional[Matrix] = ..., - lowerBoundsOnIntercepts: Optional[Vector] = ..., - upperBoundsOnIntercepts: Optional[Vector] = ..., - maxBlockSizeInMB: float = ..., - ) -> "LogisticRegression": - ... 
- - @keyword_only - @since("1.3.0") - def setParams( - self, - *, - featuresCol: str = "features", - labelCol: str = "label", - predictionCol: str = "prediction", - maxIter: int = 100, - regParam: float = 0.0, - elasticNetParam: float = 0.0, - tol: float = 1e-6, - fitIntercept: bool = True, - threshold: float = 0.5, - thresholds: Optional[List[float]] = None, - probabilityCol: str = "probability", - rawPredictionCol: str = "rawPrediction", - standardization: bool = True, - weightCol: Optional[str] = None, - aggregationDepth: int = 2, - family: str = "auto", - lowerBoundsOnCoefficients: Optional[Matrix] = None, - upperBoundsOnCoefficients: Optional[Matrix] = None, - lowerBoundsOnIntercepts: Optional[Vector] = None, - upperBoundsOnIntercepts: Optional[Vector] = None, - maxBlockSizeInMB: float = 0.0, - ) -> "LogisticRegression": - """ - setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ - aggregationDepth=2, family="auto", \ - lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \ - lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \ - maxBlockSizeInMB=0.0): - Sets params for logistic regression. - If the threshold and thresholds Params are both set, they must be equivalent. - """ - kwargs = self._input_kwargs - self._set(**kwargs) - self._checkThresholdConsistency() - return self - def _create_model(self, java_model: "JavaObject") -> "LogisticRegressionModel": return LogisticRegressionModel(java_model) - @since("2.1.0") - def setFamily(self, value: str) -> "LogisticRegression": - """ - Sets the value of :py:attr:`family`. - """ - return self._set(family=value) - - @since("2.3.0") - def setLowerBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression": - """ - Sets the value of :py:attr:`lowerBoundsOnCoefficients` - """ - return self._set(lowerBoundsOnCoefficients=value) - - @since("2.3.0") - def setUpperBoundsOnCoefficients(self, value: Matrix) -> "LogisticRegression": - """ - Sets the value of :py:attr:`upperBoundsOnCoefficients` - """ - return self._set(upperBoundsOnCoefficients=value) - - @since("2.3.0") - def setLowerBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression": - """ - Sets the value of :py:attr:`lowerBoundsOnIntercepts` - """ - return self._set(lowerBoundsOnIntercepts=value) - - @since("2.3.0") - def setUpperBoundsOnIntercepts(self, value: Vector) -> "LogisticRegression": - """ - Sets the value of :py:attr:`upperBoundsOnIntercepts` - """ - return self._set(upperBoundsOnIntercepts=value) - - def setMaxIter(self, value: int) -> "LogisticRegression": - """ - Sets the value of :py:attr:`maxIter`. - """ - return self._set(maxIter=value) - - def setRegParam(self, value: float) -> "LogisticRegression": - """ - Sets the value of :py:attr:`regParam`. - """ - return self._set(regParam=value) - - def setTol(self, value: float) -> "LogisticRegression": - """ - Sets the value of :py:attr:`tol`. - """ - return self._set(tol=value) - - def setElasticNetParam(self, value: float) -> "LogisticRegression": - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - return self._set(elasticNetParam=value) - - def setFitIntercept(self, value: bool) -> "LogisticRegression": - """ - Sets the value of :py:attr:`fitIntercept`. 
- """ - return self._set(fitIntercept=value) - - def setStandardization(self, value: bool) -> "LogisticRegression": - """ - Sets the value of :py:attr:`standardization`. - """ - return self._set(standardization=value) - - def setWeightCol(self, value: str) -> "LogisticRegression": - """ - Sets the value of :py:attr:`weightCol`. - """ - return self._set(weightCol=value) - - def setAggregationDepth(self, value: int) -> "LogisticRegression": - """ - Sets the value of :py:attr:`aggregationDepth`. - """ - return self._set(aggregationDepth=value) - - @since("3.1.0") - def setMaxBlockSizeInMB(self, value: float) -> "LogisticRegression": - """ - Sets the value of :py:attr:`maxBlockSizeInMB`. - """ - return self._set(maxBlockSizeInMB=value) - +@try_remote_ml_class class LogisticRegressionModel( _JavaProbabilisticClassificationModel[Vector], _LogisticRegressionParams, diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py index 7612e0caa28e..54f51694c6f1 100644 --- a/python/pyspark/ml/connect/__init__.py +++ b/python/pyspark/ml/connect/__init__.py @@ -16,3 +16,4 @@ # """Spark Connect Python Client - ML module""" + diff --git a/python/pyspark/ml/connect/base.py b/python/pyspark/ml/connect/base.py new file mode 100644 index 000000000000..cf7f80a561fc --- /dev/null +++ b/python/pyspark/ml/connect/base.py @@ -0,0 +1,349 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
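Note: the reworked `LogisticRegression` doctest earlier in this file now builds its input with `spark.createDataFrame` plus `array_to_vector` instead of `sc.parallelize(...).toDF()`, since a Spark Connect client has no `SparkContext` and the vector column can instead be assembled by the `array_to_vector` expression rather than from local `Vectors.dense` literals. A small sketch of that input-building pattern, assuming an active `spark` session as in the doctest (row values and column names are illustrative):

from pyspark.ml.functions import array_to_vector

bdf = spark.createDataFrame(
    [(1.0, 1.0, [0.0, 5.0]), (0.0, 2.0, [1.0, 2.0])],
    ["label", "weight", "features"],
).withColumn("features", array_to_vector("features"))  # array<double> column -> vector column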
+# + +from abc import ABCMeta, abstractmethod + +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.ml import Estimator, Model, Predictor, PredictionModel +from pyspark.ml.wrapper import _PredictorParams +from pyspark.ml.util import MLWritable, MLWriter, MLReadable, MLReader +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.ml_pb2 as ml_pb2 +import pyspark.sql.connect.proto.ml_common_pb2 as ml_common_pb2 +from pyspark.ml.connect.serializer import ( + deserialize_response_value, + serialize_ml_params, + set_instance_params_from_proto, +) +from pyspark.sql.connect import session as pyspark_session +from pyspark.sql.connect.plan import LogicalPlan + +from pyspark.ml.util import inherit_doc +from pyspark.ml.util import HasTrainingSummary as PySparkHasTrainingSummary + + +class ModelRef: + + def __init__(self, ref_id): + self.ref_id = ref_id + + def to_proto(self): + return ml_common_pb2.ModelRef(id=self.ref_id) + + @classmethod + def from_proto(cls, model_ref_pb: ml_common_pb2.ModelRef): + return ModelRef(ref_id=model_ref_pb.id) + + def __del__(self): + client = pyspark_session._active_spark_session.client + del_model_proto = ml_pb2.MlCommand.DeleteModel( + model_ref=self.to_proto(), + ) + req = client._execute_plan_request_with_metadata() + req.plan.ml_command.delete_model.CopyFrom(del_model_proto) + client._execute_ml(req) + + +@inherit_doc +class ClientEstimator(Estimator, metaclass=ABCMeta): + + @classmethod + @abstractmethod + def _algo_name(cls): + raise NotImplementedError() + + @classmethod + @abstractmethod + def _model_class(cls): + raise NotImplementedError() + + def _fit(self, dataset: DataFrame) -> Model: + client = dataset.sparkSession.client + dataset_relation = dataset._plan.plan(client) + estimator_proto = ml_common_pb2.MlStage( + name=self._algo_name(), + params=serialize_ml_params(self, client), + uid=self.uid, + type=ml_common_pb2.MlStage.STAGE_TYPE_ESTIMATOR, + ) + fit_command_proto = ml_pb2.MlCommand.Fit( + estimator=estimator_proto, + dataset=dataset_relation, + ) + req = client._execute_plan_request_with_metadata() + req.plan.ml_command.fit.CopyFrom(fit_command_proto) + + resp = client._execute_ml(req) + model_ref = deserialize_response_value(resp, client) + model = self._model_class()() + model._resetUid(self.uid) + model.model_ref = model_ref + return self._copyValues(model) + + +@inherit_doc +class ClientPredictor(Predictor, ClientEstimator, _PredictorParams, metaclass=ABCMeta): + pass + + +@inherit_doc +class ClientModel(Model, metaclass=ABCMeta): + + model_ref: ModelRef = None + + @classmethod + @abstractmethod + def _algo_name(cls): + raise NotImplementedError() + + def _get_model_attr(self, name): + client = pyspark_session._active_spark_session.client + model_attr_command_proto = ml_pb2.MlCommand.FetchModelAttr( + model_ref=self.model_ref.to_proto(), + name=name + ) + req = client._execute_plan_request_with_metadata() + req.plan.ml_command.fetch_model_attr.CopyFrom(model_attr_command_proto) + + resp = client._execute_ml(req) + return deserialize_response_value(resp, client) + + def _get_model_attr_dataframe(self, name) -> DataFrame: + session = pyspark_session._active_spark_session + plan = _ModelAttrRelationPlan( + self, name + ) + return DataFrame.withPlan(plan, session) + + def _transform(self, dataset: DataFrame) -> DataFrame: + session = dataset.sparkSession + plan = _ModelTransformRelationPlan(dataset._plan, self) + return DataFrame.withPlan(plan, session) + + def copy(self, extra=None): + copied_model = 
super(ClientModel, self).copy(extra) + + client = pyspark_session._active_spark_session.client + copy_model_proto = ml_pb2.MlCommand.CopyModel( + model_ref=self.model_ref.to_proto() + ) + req = client._execute_plan_request_with_metadata() + req.plan.ml_command.copy_model.CopyFrom(copy_model_proto) + + resp = client._execute_ml(req) + new_model_ref = deserialize_response_value(resp, client) + + copied_model.model_ref = new_model_ref + + return copied_model + + +@inherit_doc +class ClientPredictionModel(PredictionModel, ClientModel, _PredictorParams): + @property # type: ignore[misc] + def numFeatures(self) -> int: + return self._get_model_attr("numFeatures") + + def predict(self, value) -> float: + # TODO: support this. + raise NotImplementedError() + + +class _ModelTransformRelationPlan(LogicalPlan): + def __init__(self, child, model): + super().__init__(child) + self.model = model + + def plan(self, session: "SparkConnectClient") -> pb2.Relation: + assert self._child is not None + plan = self._create_proto_relation() + plan.ml_relation.model_transform.input.CopyFrom(self._child.plan(session)) + plan.ml_relation.model_transform.model_ref.CopyFrom(self.model.model_ref.to_proto()) + plan.ml_relation.model_transform.params.CopyFrom(serialize_ml_params(self.model, session)) + + return plan + + +class _ModelAttrRelationPlan(LogicalPlan): + def __init__(self, model, name): + super().__init__(None) + self.model = model + self.name = name + + def plan(self, session: "SparkConnectClient") -> pb2.Relation: + assert self._child is None + plan = self._create_proto_relation() + plan.ml_relation.model_attr.model_ref.CopyFrom(self.model.model_ref.to_proto) + plan.ml_relation.model_attr.name = self.name + plan.ml_relation.model_attr.params.CopyFrom(serialize_ml_params(self.model, session)) + return plan + + +class _ModelSummaryAttrRelationPlan(LogicalPlan): + def __init__(self, child, model, name): + super().__init__(child) + self.model = model + self.name = name + + def plan(self, session: "SparkConnectClient") -> pb2.Relation: + plan = self._create_proto_relation() + if self._child is not None: + plan.ml_relation.model_summary_attr.evaluation_dataset.CopyFrom(self._child.plan(session)) + plan.ml_relation.model_summary_attr.model_ref.CopyFrom(self.model.model_ref.to_proto()) + plan.ml_relation.model_summary_attr.name = self.name + plan.ml_relation.model_summary_attr.params.CopyFrom(serialize_ml_params(self.model, session)) + return plan + + +class ClientModelSummary(metaclass=ABCMeta): + def __init__(self, model, dataset): + self.model = model + self.dataset = dataset + + def _get_summary_attr_dataframe(self, name): + session = pyspark_session._active_spark_session + plan = _ModelSummaryAttrRelationPlan( + (self.dataset._plan if self.dataset is not None else None), + self.model, name + ) + return DataFrame.withPlan(plan, session) + + def _get_summary_attr(self, name): + client = pyspark_session._active_spark_session.client + + model_summary_attr_command_proto = ml_pb2.MlCommand.FetchModelSummaryAttr( + model_ref=self.model.model_ref.to_proto(), + name=name, + params=serialize_ml_params(self.model, client), + evaluation_dataset=(self.dataset._plan.plan(client) if self.dataset is not None else None) + ) + req = client._execute_plan_request_with_metadata() + req.plan.ml_command.fetch_model_summary_attr.CopyFrom(model_summary_attr_command_proto) + + resp = client._execute_ml(req) + return deserialize_response_value(resp, client) + + +@inherit_doc +class HasTrainingSummary(ClientModel, metaclass=ABCMeta): + + 
@property # type: ignore[misc] + def hasSummary(self) -> bool: + return self._get_model_attr("hasSummary") + + hasSummary.__doc__ = PySparkHasTrainingSummary.hasSummary.__doc__ + + @abstractmethod + def summary(self): + raise NotImplementedError() + + summary.__doc__ = PySparkHasTrainingSummary.summary.__doc__ + + +HasTrainingSummary.__doc__ = PySparkHasTrainingSummary.__doc__ + + +@inherit_doc +class ClientMLWriter(MLWriter): + + def __init__(self, instance: "ClientMLWritable"): + super(ClientMLWriter, self).__init__() + self.instance = instance + + def save(self, path: str) -> None: + client = pyspark_session._active_spark_session.client + req = client._execute_plan_request_with_metadata() + + if isinstance(self.instance, ClientModel): + save_cmd_proto = ml_pb2.MlCommand.SaveModel( + model_ref=self.instance.model_ref.to_proto(), + path=path, + overwrite=self.shouldOverwrite, + options=self.optionMap + ) + req.plan.ml_command.save_model.CopyFrom(save_cmd_proto) + elif isinstance(self.instance, Estimator): + stage_pb = ml_common_pb2.MlStage( + name=self.instance._algo_name(), + params=serialize_ml_params(self.instance, client), + uid=self.instance.uid, + type=ml_common_pb2.MlStage.STAGE_TYPE_ESTIMATOR, + ) + save_cmd_proto = ml_pb2.MlCommand.SaveStage( + stage=stage_pb, + path=path, + overwrite=self.shouldOverwrite, + options=self.optionMap + ) + req.plan.ml_command.save_stage.CopyFrom(save_cmd_proto) + else: + raise NotImplementedError() + + client._execute_ml(req) + + +@inherit_doc +class ClientMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self) -> ClientMLWriter: + """Returns an MLWriter instance for this ML instance.""" + return ClientMLWriter(self) + + +@inherit_doc +class ClientMLReader(MLReader): + + def __init__(self, clazz): + self.clazz = clazz + + def load(self, path: str): + client = pyspark_session._active_spark_session.client + req = client._execute_plan_request_with_metadata() + + name = self.clazz._algo_name() + if issubclass(self.clazz, ClientModel): + load_model_proto = ml_pb2.MlCommand.LoadModel( + name=name, + path=path + ) + req.plan.ml_command.load_model.CopyFrom(load_model_proto) + resp = client._execute_ml(req) + return deserialize_response_value(resp, client, clazz=self.clazz) + + elif issubclass(self.clazz, ClientEstimator): + load_estimator_proto = ml_pb2.MlCommand.LoadStage( + name=name, + path=path, + type=ml_common_pb2.MlStage.STAGE_TYPE_ESTIMATOR + ) + req.plan.ml_command.load_stage.CopyFrom(load_estimator_proto) + resp = client._execute_ml(req) + return deserialize_response_value(resp, client, clazz=self.clazz) + else: + raise NotImplementedError() + + +@inherit_doc +class ClientMLReadable(MLReadable): + + @classmethod + def read(cls) -> ClientMLReader: + """Returns an MLReader instance for this class.""" + return ClientMLReader(cls) diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py new file mode 100644 index 000000000000..81f3406d9b36 --- /dev/null +++ b/python/pyspark/ml/connect/classification.py @@ -0,0 +1,510 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Type, + TypeVar, + Union, + cast, + overload, + TYPE_CHECKING, +) + +from pyspark.sql import DataFrame +from pyspark.ml.classification import ( + Classifier, + ProbabilisticClassifier, + ProbabilisticClassifier, + _LogisticRegressionParams, + _LogisticRegressionCommon, + ClassificationModel, + ProbabilisticClassificationModel +) +from pyspark.ml.linalg import (Matrix, Vector) +from pyspark import keyword_only, since, SparkContext, inheritable_thread_target +from pyspark.ml.connect.base import ( + ClientEstimator, + ClientModel, + HasTrainingSummary, + ClientModelSummary, + ClientPredictor, + ClientPredictionModel, + ClientMLWritable, + ClientMLReadable, +) +from abc import ABCMeta, abstractmethod +from pyspark.ml.util import inherit_doc +from pyspark.ml.classification import ( + LogisticRegression as PySparkLogisticRegression, + LogisticRegressionModel as PySparkLogisticRegressionModel, + _ClassificationSummary as _PySparkClassificationSummary, + _TrainingSummary as _PySparkTrainingSummary, + _BinaryClassificationSummary as _PySparkBinaryClassificationSummary, + LogisticRegressionSummary as PySparkLogisticRegressionSummary, + LogisticRegressionTrainingSummary as PySparkLogisticRegressionTrainingSummary, + BinaryLogisticRegressionSummary as PySparkBinaryLogisticRegressionSummary, + BinaryLogisticRegressionTrainingSummary as PySparkBinaryLogisticRegressionTrainingSummary, +) + + +@inherit_doc +class _ClientClassifier(Classifier, ClientPredictor, metaclass=ABCMeta): + pass + + +@inherit_doc +class _ClientProbabilisticClassifier( + ProbabilisticClassifier, _ClientClassifier, metaclass=ABCMeta +): + pass + + +@inherit_doc +class _ClientClassificationModel(ClassificationModel, ClientPredictionModel): + @property # type: ignore[misc] + def numClasses(self) -> int: + return self._get_model_attr("numClasses") + + def predictRaw(self, value: Vector) -> Vector: + # TODO: support this. + raise NotImplementedError() + + +@inherit_doc +class _ClientProbabilisticClassificationModel( + ProbabilisticClassificationModel, _ClientClassificationModel +): + def predictProbability(self, value: Vector) -> Vector: + # TODO: support this. 
+ raise NotImplementedError() + + +@inherit_doc +class LogisticRegression( + _ClientProbabilisticClassifier, + _LogisticRegressionCommon, + ClientMLWritable, + ClientMLReadable, +): + _input_kwargs: Dict[str, Any] + + @overload + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + threshold: float = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + maxBlockSizeInMB: float = ..., + ): + ... + + @overload + def __init__( + self, + *, + featuresCol: str = ..., + labelCol: str = ..., + predictionCol: str = ..., + maxIter: int = ..., + regParam: float = ..., + elasticNetParam: float = ..., + tol: float = ..., + fitIntercept: bool = ..., + thresholds: Optional[List[float]] = ..., + probabilityCol: str = ..., + rawPredictionCol: str = ..., + standardization: bool = ..., + weightCol: Optional[str] = ..., + aggregationDepth: int = ..., + family: str = ..., + lowerBoundsOnCoefficients: Optional[Matrix] = ..., + upperBoundsOnCoefficients: Optional[Matrix] = ..., + lowerBoundsOnIntercepts: Optional[Vector] = ..., + upperBoundsOnIntercepts: Optional[Vector] = ..., + maxBlockSizeInMB: float = ..., + ): + ... + + @keyword_only + def __init__( + self, + *, + featuresCol: str = "features", + labelCol: str = "label", + predictionCol: str = "prediction", + maxIter: int = 100, + regParam: float = 0.0, + elasticNetParam: float = 0.0, + tol: float = 1e-6, + fitIntercept: bool = True, + threshold: float = 0.5, + thresholds: Optional[List[float]] = None, + probabilityCol: str = "probability", + rawPredictionCol: str = "rawPrediction", + standardization: bool = True, + weightCol: Optional[str] = None, + aggregationDepth: int = 2, + family: str = "auto", + lowerBoundsOnCoefficients: Optional[Matrix] = None, + upperBoundsOnCoefficients: Optional[Matrix] = None, + lowerBoundsOnIntercepts: Optional[Vector] = None, + upperBoundsOnIntercepts: Optional[Vector] = None, + maxBlockSizeInMB: float = 0.0, + ): + """ + __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \ + aggregationDepth=2, family="auto", \ + lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \ + lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \ + maxBlockSizeInMB=0.0): + If the threshold and thresholds Params are both set, they must be equivalent. 
+ """ + super(LogisticRegression, self).__init__() + kwargs = self._input_kwargs + self.setParams(**kwargs) + self._checkThresholdConsistency() + + @classmethod + def _algo_name(cls): + return "LogisticRegression" + + @classmethod + def _model_class(cls): + return LogisticRegressionModel + + +LogisticRegression.__doc__ = PySparkLogisticRegression.__doc__ + + +@inherit_doc +class LogisticRegressionModel( + _ClientProbabilisticClassificationModel, + _LogisticRegressionParams, + HasTrainingSummary, + ClientMLWritable, + ClientMLReadable, +): + @classmethod + def _algo_name(cls): + return "LogisticRegression" + + @property # type: ignore[misc] + def coefficients(self) -> Vector: + return self._get_model_attr("coefficients") + + @property # type: ignore[misc] + def intercept(self) -> float: + return self._get_model_attr("intercept") + + @property # type: ignore[misc] + def coefficientMatrix(self) -> Matrix: + return self._get_model_attr("coefficientMatrix") + + @property # type: ignore[misc] + def interceptVector(self) -> Vector: + return self._get_model_attr("interceptVector") + + def evaluate(self, dataset): + if self.numClasses <= 2: + return BinaryLogisticRegressionSummary(self, dataset) + else: + return LogisticRegressionSummary(self, dataset) + + # TODO: Move this method to common interface shared by connect code and legacy code + @property # type: ignore[misc] + def summary(self) -> "LogisticRegressionTrainingSummary": + if self.hasSummary: + if self.numClasses <= 2: + return BinaryLogisticRegressionTrainingSummary(self, None) + else: + return LogisticRegressionTrainingSummary(self, None) + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + + +LogisticRegressionModel.__doc__ = PySparkLogisticRegressionModel.__doc__ + + +@inherit_doc +class _ClassificationSummary(ClientModelSummary): + + @property # type: ignore[misc] + def predictions(self) -> DataFrame: + return self._get_summary_attr_dataframe("predictions") + + predictions.__doc__ = _PySparkClassificationSummary.predictions.__doc__ + + @property # type: ignore[misc] + def predictionCol(self) -> str: + return self._get_summary_attr("predictionCol") + + predictionCol.__doc__ = _PySparkClassificationSummary.predictionCol.__doc__ + + @property # type: ignore[misc] + def labelCol(self) -> str: + return self._get_summary_attr("labelCol") + + labelCol.__doc__ = _PySparkClassificationSummary.labelCol.__doc__ + + @property # type: ignore[misc] + def weightCol(self) -> str: + return self._get_summary_attr("weightCol") + + weightCol.__doc__ = _PySparkClassificationSummary.weightCol.__doc__ + + @property + def labels(self) -> List[str]: + return self._get_summary_attr("labels") + + labels.__doc__ = _PySparkClassificationSummary.labels.__doc__ + + @property # type: ignore[misc] + def truePositiveRateByLabel(self) -> List[float]: + return self._get_summary_attr("truePositiveRateByLabel") + + truePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.truePositiveRateByLabel.__doc__ + + @property # type: ignore[misc] + def falsePositiveRateByLabel(self) -> List[float]: + return self._get_summary_attr("falsePositiveRateByLabel") + + falsePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.falsePositiveRateByLabel.__doc__ + + @property # type: ignore[misc] + def precisionByLabel(self) -> List[float]: + return self._get_summary_attr("precisionByLabel") + + precisionByLabel.__doc__ = _PySparkClassificationSummary.precisionByLabel.__doc__ + + @property # type: ignore[misc] + def 
recallByLabel(self) -> List[float]: + return self._get_summary_attr("recallByLabel") + + recallByLabel.__doc__ = _PySparkClassificationSummary.recallByLabel.__doc__ + + @property # type: ignore[misc] + def fMeasureByLabel(self, beta: float = 1.0) -> List[float]: + # TODO: support this. + raise NotImplementedError() + + fMeasureByLabel.__doc__ = _PySparkClassificationSummary.fMeasureByLabel.__doc__ + + @property # type: ignore[misc] + def accuracy(self) -> float: + return self._get_summary_attr("accuracy") + + accuracy.__doc__ = _PySparkClassificationSummary.accuracy.__doc__ + + @property # type: ignore[misc] + def weightedTruePositiveRate(self) -> float: + return self._get_summary_attr("weightedTruePositiveRate") + + weightedTruePositiveRate.__doc__ = _PySparkClassificationSummary.weightedTruePositiveRate.__doc__ + + @property # type: ignore[misc] + def weightedFalsePositiveRate(self) -> float: + return self._get_summary_attr("weightedFalsePositiveRate") + + weightedFalsePositiveRate.__doc__ = _PySparkClassificationSummary.weightedFalsePositiveRate.__doc__ + + @property # type: ignore[misc] + def weightedRecall(self) -> float: + return self._get_summary_attr("weightedRecall") + + weightedRecall.__doc__ = _PySparkClassificationSummary.weightedRecall.__doc__ + + @property # type: ignore[misc] + def weightedPrecision(self) -> float: + return self._get_summary_attr("weightedPrecision") + + weightedPrecision.__doc__ = _PySparkClassificationSummary.weightedPrecision.__doc__ + + def weightedFMeasure(self, beta: float = 1.0) -> float: + # TODO: support this. + raise NotImplementedError() + + weightedFMeasure.__doc__ = _PySparkClassificationSummary.weightedFMeasure.__doc__ + + +@inherit_doc +class _TrainingSummary(ClientModelSummary): + + @property # type: ignore[misc] + def objectiveHistory(self) -> List[float]: + return self._get_summary_attr("objectiveHistory") + + objectiveHistory.__doc__ = _PySparkTrainingSummary.objectiveHistory.__doc__ + + @property # type: ignore[misc] + def totalIterations(self) -> int: + return self._get_summary_attr("totalIterations") + + totalIterations.__doc__ = _PySparkTrainingSummary.totalIterations.__doc__ + + +@inherit_doc +class _BinaryClassificationSummary(_ClassificationSummary): + + @property # type: ignore[misc] + def scoreCol(self) -> str: + return self._get_summary_attr("scoreCol") + + scoreCol.__doc__ = _PySparkBinaryClassificationSummary.scoreCol.__doc__ + + @property + def roc(self) -> DataFrame: + return self._get_summary_attr_dataframe("roc") + + roc.__doc__ = _PySparkBinaryClassificationSummary.roc.__doc__ + + @property # type: ignore[misc] + def areaUnderROC(self) -> float: + return self._get_summary_attr("areaUnderROC") + + areaUnderROC.__doc__ = _PySparkBinaryClassificationSummary.areaUnderROC.__doc__ + + @property # type: ignore[misc] + def pr(self) -> DataFrame: + return self._get_summary_attr_dataframe("pr") + + pr.__doc__ = _PySparkBinaryClassificationSummary.pr.__doc__ + + @property # type: ignore[misc] + def fMeasureByThreshold(self) -> DataFrame: + return self._get_summary_attr_dataframe("fMeasureByThreshold") + + fMeasureByThreshold.__doc__ = _PySparkBinaryClassificationSummary.fMeasureByThreshold.__doc__ + + @property # type: ignore[misc] + def precisionByThreshold(self) -> DataFrame: + return self._get_summary_attr_dataframe("precisionByThreshold") + + precisionByThreshold.__doc__ = _PySparkBinaryClassificationSummary.precisionByThreshold.__doc__ + + @property # type: ignore[misc] + def recallByThreshold(self) -> DataFrame: + return 
self._get_summary_attr_dataframe("recallByThreshold") + + recallByThreshold.__doc__ = _PySparkBinaryClassificationSummary.recallByThreshold.__doc__ + + +@inherit_doc +class LogisticRegressionSummary(_ClassificationSummary): + + @property # type: ignore[misc] + def probabilityCol(self) -> str: + return self._get_summary_attr("probabilityCol") + + probabilityCol.__doc__ = PySparkLogisticRegressionSummary.probabilityCol.__doc__ + + @property # type: ignore[misc] + def featuresCol(self) -> str: + return self._get_summary_attr("featuresCol") + + featuresCol.__doc__ = PySparkLogisticRegressionSummary.featuresCol.__doc__ + + +LogisticRegressionSummary.__doc__ = PySparkLogisticRegressionSummary.__doc__ + + +@inherit_doc +class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary): + pass + + +LogisticRegressionTrainingSummary.__doc__ = PySparkLogisticRegressionTrainingSummary.__doc__ + + +@inherit_doc +class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary): + pass + + +BinaryLogisticRegressionSummary.__doc__ = PySparkBinaryLogisticRegressionSummary.__doc__ + + +@inherit_doc +class BinaryLogisticRegressionTrainingSummary( + BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary +): + pass + + +BinaryLogisticRegressionTrainingSummary.__doc__ = PySparkBinaryLogisticRegressionTrainingSummary.__doc__ + + +def _test() -> None: + import os + import sys + import doctest + from pyspark.sql import SparkSession as PySparkSession + import pyspark.ml.connect.classification + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.ml.connect.classification.__dict__.copy() + + globs["spark"] = ( + PySparkSession.builder.appName("ml.connect.classification tests") + .remote("local[4]") + .getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.connect.classification, + globs=globs, + optionflags=doctest.ELLIPSIS + | doctest.NORMALIZE_WHITESPACE + | doctest.IGNORE_EXCEPTION_DETAIL, + ) + + globs["spark"].stop() + + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/ml/connect/serializer.py b/python/pyspark/ml/connect/serializer.py new file mode 100644 index 000000000000..38b1c00c2b2b --- /dev/null +++ b/python/pyspark/ml/connect/serializer.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
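Note: the summary classes above never materialize results on the client: scalar attributes are fetched through the `FetchModelSummaryAttr` command, while DataFrame-valued attributes are exposed as lazy `model_summary_attr` relations. A hedged sketch of how a trained connect model's summary is expected to be consumed (`train_df` stands in for a connect DataFrame with binary labels and is illustrative only):

from pyspark.ml.connect.classification import LogisticRegression

model = LogisticRegression(maxIter=10).fit(train_df)
summary = model.summary      # BinaryLogisticRegressionTrainingSummary when numClasses <= 2
print(summary.accuracy)      # scalar, served via MlCommand.FetchModelSummaryAttr
summary.roc.show()           # DataFrame, planned lazily as an ml_relation.model_summary_attr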
+# + +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.ml_pb2 as ml_pb2 +import pyspark.sql.connect.proto.ml_common_pb2 as ml_common_pb2 + +from pyspark.sql.connect.expressions import LiteralExpression + +from pyspark.ml.linalg import Vector, Vectors, Matrices, Matrix + + +def deserialize_response_value(ml_command_result: ml_pb2.MlCommandResponse, client, **kwargs): + from pyspark.ml.connect.base import ModelRef + + if ml_command_result.HasField("literal"): + return LiteralExpression._to_value(ml_command_result.literal) + + if ml_command_result.HasField("model_info"): + model_info = ml_command_result.model_info + assert "clazz" in kwargs + clazz = kwargs["clazz"] + + model = clazz() + model._resetUid(model_info.model_uid) + set_instance_params_from_proto(model, model_info.params) + model.model_ref = ModelRef.from_proto(model_info.model_ref) + return model + + if ml_command_result.HasField("model_ref"): + return ModelRef.from_proto(ml_command_result.model_ref) + + if ml_command_result.HasField("vector"): + vector_pb = ml_command_result.vector + return deserialize_vector(vector_pb) + + if ml_command_result.HasField("matrix"): + matrix_pb = ml_command_result.matrix + return deserialize_matrix(matrix_pb) + + if ml_command_result.HasField("stage"): + assert "clazz" in kwargs + clazz = kwargs["clazz"] + stage_pb = ml_command_result.stage + stage = clazz() + stage._resetUid(stage_pb.uid) + set_instance_params_from_proto(stage, stage_pb.params) + return stage + + raise ValueError() + + +def deserialize_vector(vector_pb): + # TODO: support sparse + # TODO: support large vector + if vector_pb.HasField("dense"): + return Vectors.dense(vector_pb.dense.value) + raise ValueError() + + +def serialize_vector(vector): + return ml_common_pb2.Vector( + dense=ml_common_pb2.Vector.Dense( + value=vector.toArray() + ) + ) + + +def deserialize_matrix(matrix_pb): + # TODO: support sparse, is_transposed + # TODO: support large matrix + if matrix_pb.HasField("dense") and not matrix_pb.dense.is_transposed: + return Matrices.dense( + matrix_pb.dense.num_rows, + matrix_pb.dense.num_cols, + matrix_pb.dense.value, + ) + raise ValueError() + + +def serialize_matrix(matrix): + # TODO: support sparse, is_transposed + # TODO: support large matrix + return ml_common_pb2.Matrix( + dense=ml_common_pb2.Matrix.Dense( + num_rows=matrix.numRows, + num_cols=matrix.numCols, + value=matrix.toArray(), + is_transposed=False + ) + ) + + +def deserialize_param_value(value_pb: ml_common_pb2.MlParams.ParamValue): + if value_pb.HasField("literal"): + return LiteralExpression._to_value(value_pb.literal) + if value_pb.HasField("vector"): + return deserialize_vector(value_pb.vector) + if value_pb.HasField("matrix"): + return deserialize_vector(value_pb.matrix) + + +def serialize_param_value(value, client): + if isinstance(value, Vector): + return ml_common_pb2.MlParams.ParamValue( + vector= serialize_vector(value) + ) + if isinstance(value, Matrix): + return ml_common_pb2.MlParams.ParamValue( + matrix=serialize_matrix(value) + ) + return ml_common_pb2.MlParams.ParamValue( + literal=LiteralExpression._from_value(value).to_plan(client).literal + ) + + +def set_instance_params_from_proto(instance, params_proto): + instance._set(**{ + k: deserialize_param_value(v_pb) + for k, v_pb in params_proto.params.items() + }) + instance._setDefault(**{ + k: deserialize_param_value(v_pb) + for k, v_pb in params_proto.params.items() + }) + + +def serialize_ml_params(instance, client): + def gen_pb2_map(param_value_dict): + return { + 
k.name: serialize_param_value(v, client) + for k, v in param_value_dict.items() + } + + result = ml_common_pb2.MlParams( + params=gen_pb2_map(instance._paramMap), + default_params=gen_pb2_map(instance._defaultParamMap), + ) + return result + diff --git a/python/pyspark/ml/connect/utils.py b/python/pyspark/ml/connect/utils.py new file mode 100644 index 000000000000..41425e48a2d6 --- /dev/null +++ b/python/pyspark/ml/connect/utils.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.utils import is_remote +import os + + +def _get_remote_ml_class(cls): + remote_module = "pyspark.ml.connect." + cls.__module__[len("pyspark.ml."):] + cls_name = cls.__name__ + m = __import__(remote_module, fromlist=[cls_name]) + remote_cls = getattr(m, cls_name) + return remote_cls + + +def try_remote_ml_class(x): + + @classmethod + def patched__new__(cls, *args, **kwargs): + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + remote_cls = _get_remote_ml_class(cls) + return remote_cls(*args[1:], **kwargs) + + obj = object.__new__(cls) + obj.__init__(*args[1:], **kwargs) + return obj + + x.__new__ = patched__new__ + return x + + +def try_remote_ml_classmethod(fn): + + def patched_fn(cls, *args, **kwargs): + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + remote_cls = _get_remote_ml_class(cls) + original_fn_name = fn.__name__ + return getattr(remote_cls, original_fn_name)(*args, **kwargs) + return fn(cls, *args, **kwargs) + + return patched_fn diff --git a/python/pyspark/ml/tests/connect/test_connect_algorithm.py b/python/pyspark/ml/tests/connect/test_connect_algorithm.py new file mode 100644 index 000000000000..7def37c4b5eb --- /dev/null +++ b/python/pyspark/ml/tests/connect/test_connect_algorithm.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
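Note: `try_remote_ml_class` above patches `__new__` rather than wrapping `__init__`, so the switch to the connect implementation happens at construction time and existing imports keep working. A sketch of the expected behaviour (the printed module name is what the patched dispatch should yield; setting `PYSPARK_NO_NAMESPACE_SHARE` opts out, as in utils.py):

from pyspark.ml.classification import LogisticRegression

lor = LogisticRegression(maxIter=5)
# With a Spark Connect session active, is_remote() is True and the patched
# __new__ returns an instance of pyspark.ml.connect.classification.LogisticRegression;
# otherwise the plain JVM-backed estimator is constructed as before.
print(type(lor).__module__)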
+# + +import unittest + +from pyspark.ml.tests.test_algorithms import LogisticRegressionTest +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class LogisticRegressionParityTest(LogisticRegressionTest, ReusedConnectTestCase): + + def test_binomial_logistic_regression_with_bound(self): + super().test_binomial_logistic_regression_with_bound() + + def test_multinomial_logistic_regression_with_bound(self): + super().test_multinomial_logistic_regression_with_bound() + + def test_logistic_regression_with_threshold(self): + super().test_logistic_regression_with_threshold() + + +if __name__ == "__main__": + from pyspark.ml.tests.connect.test_connect_algorithm import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index fb2507fe0852..b97ed93039be 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -30,6 +30,7 @@ from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel from pyspark.ml.fpm import FPGrowth from pyspark.ml.linalg import Matrices, Vectors, DenseVector +from pyspark.ml.functions import array_to_vector from pyspark.ml.recommendation import ALS from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression from pyspark.sql import Row @@ -41,13 +42,13 @@ def test_binomial_logistic_regression_with_bound(self): df = self.spark.createDataFrame( [ - (1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.0, 3.0, Vectors.dense(2.0, 1.0)), - (0.0, 4.0, Vectors.dense(3.0, 3.0)), + (1.0, 1.0, [0.0, 5.0]), + (0.0, 2.0, [1.0, 2.0]), + (1.0, 3.0, [2.0, 1.0]), + (0.0, 4.0, [3.0, 3.0]), ], ["label", "weight", "features"], - ) + ).withColumn("features", array_to_vector("features")) lor = LogisticRegression( regParam=0.01, diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 5d1f89cbc13b..c7a0ca435979 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -42,6 +42,8 @@ from pyspark.sql.utils import is_remote from pyspark.util import VersionUtils +from pyspark.ml.connect.utils import try_remote_ml_classmethod + if TYPE_CHECKING: from py4j.java_gateway import JavaGateway, JavaObject from pyspark.ml._typing import PipelineStage @@ -376,6 +378,7 @@ class JavaMLReadable(MLReadable[RL]): """ @classmethod + @try_remote_ml_classmethod def read(cls) -> JavaMLReader[RL]: """Returns an MLReader instance for this class.""" return JavaMLReader(cls) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 53fa97372a75..0011b149be3b 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -854,6 +854,31 @@ def _execute(self, req: pb2.ExecutePlanRequest) -> None: except grpc.RpcError as rpc_error: self._handle_error(rpc_error) + def _execute_ml(self, req: pb2.ExecutePlanRequest): + """ + Execute the passed ML command request `req` and return ML response result + Parameters + ---------- + req : pb2.ExecutePlanRequest + Proto representation of the plan. 
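Note: since `JavaMLReadable.read` is now routed through `try_remote_ml_classmethod`, persistence keeps the familiar entry points under Spark Connect: writes are translated into `SaveModel`/`SaveStage` commands and reads into `LoadModel`/`LoadStage`. A hedged sketch of the expected round trip (the path and the fitted `model` are illustrative; `train_df` is a connect DataFrame as above):

from pyspark.ml.connect.classification import LogisticRegression, LogisticRegressionModel

model = LogisticRegression(maxIter=10).fit(train_df)
model.write().overwrite().save("/tmp/lor_model")                 # MlCommand.SaveModel
loaded = LogisticRegressionModel.read().load("/tmp/lor_model")   # MlCommand.LoadModel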
+ """ + logger.info("Execute ML") + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + if b.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request: " + f"{b.client_id} != {self._session_id}" + ) + assert b.HasField("ml_command_result") + return b.ml_command_result + except grpc.RpcError as rpc_error: + self._handle_error(rpc_error) + def _execute_and_fetch( self, req: pb2.ExecutePlanRequest ) -> Tuple[ diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 365573448930..f6c63c7f7156 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -34,10 +34,11 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 +from pyspark.sql.connect.proto import ml_pb2 as spark_dot_connect_dot_ml__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x89\x0e\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 
\x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1a\x35\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04planB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\xb4\n\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06resultB\x08\n\x06result"\xd1\x01\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 
\x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xfb\t\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a`\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06valuesB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful2\xed\x02\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto"\xaf\x01\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommand\x12\x39\n\nml_command\x18\x03 \x01(\x0b\x32\x18.spark.connect.MlCommandH\x00R\tmlCommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x89\x0e\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 
\x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1a\x35\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04planB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\xb4\n\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t 
\x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06resultB\x08\n\x06result"\xd1\x01\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xcb\n\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12N\n\x11ml_command_result\x18\x64 \x01(\x0b\x32 .spark.connect.MlCommandResponseH\x00R\x0fmlCommandResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a`\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 
\x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06valuesB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 
\x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful2\xed\x02\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -637,110 +638,110 @@ DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" - _PLAN._serialized_start = 191 - _PLAN._serialized_end = 307 - _USERCONTEXT._serialized_start = 309 - _USERCONTEXT._serialized_end = 431 - _ANALYZEPLANREQUEST._serialized_start = 434 - _ANALYZEPLANREQUEST._serialized_end = 2235 - _ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1384 - _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1433 - _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1436 - _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1751 - _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1579 - _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1751 - _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1753 - _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1806 - _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1808 - _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1858 - _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1860 - _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1914 - _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 1916 - _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 1969 - _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 1971 - _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 1985 - _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 1987 - _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 2028 - _ANALYZEPLANREQUEST_SAMESEMANTICS._serialized_start = 2030 - _ANALYZEPLANREQUEST_SAMESEMANTICS._serialized_end = 2151 - _ANALYZEPLANREQUEST_SEMANTICHASH._serialized_start = 2153 - _ANALYZEPLANREQUEST_SEMANTICHASH._serialized_end = 2208 - _ANALYZEPLANRESPONSE._serialized_start = 2238 - _ANALYZEPLANRESPONSE._serialized_end = 3570 - _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 3098 - _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 3155 - _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 3157 - _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 3205 - _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 3207 - _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 3252 - _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 3254 - _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 3290 - _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 3292 - _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 3340 - _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 3342 - _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 3376 - _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 3378 - _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 3418 - _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 3420 - _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 3479 - 
_ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_start = 3481 - _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_end = 3520 - _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_start = 3522 - _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_end = 3560 - _EXECUTEPLANREQUEST._serialized_start = 3573 - _EXECUTEPLANREQUEST._serialized_end = 3782 - _EXECUTEPLANRESPONSE._serialized_start = 3785 - _EXECUTEPLANRESPONSE._serialized_end = 5060 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4291 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4362 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4364 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4425 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4428 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4945 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4523 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4855 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4732 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4855 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4857 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4945 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4947 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 5043 - _KEYVALUE._serialized_start = 5062 - _KEYVALUE._serialized_end = 5127 - _CONFIGREQUEST._serialized_start = 5130 - _CONFIGREQUEST._serialized_end = 6158 - _CONFIGREQUEST_OPERATION._serialized_start = 5350 - _CONFIGREQUEST_OPERATION._serialized_end = 5848 - _CONFIGREQUEST_SET._serialized_start = 5850 - _CONFIGREQUEST_SET._serialized_end = 5902 - _CONFIGREQUEST_GET._serialized_start = 5904 - _CONFIGREQUEST_GET._serialized_end = 5929 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5931 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5994 - _CONFIGREQUEST_GETOPTION._serialized_start = 5996 - _CONFIGREQUEST_GETOPTION._serialized_end = 6027 - _CONFIGREQUEST_GETALL._serialized_start = 6029 - _CONFIGREQUEST_GETALL._serialized_end = 6077 - _CONFIGREQUEST_UNSET._serialized_start = 6079 - _CONFIGREQUEST_UNSET._serialized_end = 6106 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6108 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6142 - _CONFIGRESPONSE._serialized_start = 6160 - _CONFIGRESPONSE._serialized_end = 6282 - _ADDARTIFACTSREQUEST._serialized_start = 6285 - _ADDARTIFACTSREQUEST._serialized_end = 7156 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6672 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6725 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6727 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6838 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6840 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6933 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6936 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7129 - _ADDARTIFACTSRESPONSE._serialized_start = 7159 - _ADDARTIFACTSRESPONSE._serialized_end = 7347 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7266 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7347 - _SPARKCONNECTSERVICE._serialized_start = 7350 - _SPARKCONNECTSERVICE._serialized_end = 7715 + _PLAN._serialized_start = 216 + _PLAN._serialized_end = 391 + _USERCONTEXT._serialized_start = 393 + _USERCONTEXT._serialized_end = 515 + _ANALYZEPLANREQUEST._serialized_start = 518 + _ANALYZEPLANREQUEST._serialized_end = 2319 + 
_ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1468 + _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1517 + _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1520 + _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1835 + _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1663 + _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1835 + _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1837 + _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1890 + _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1892 + _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1942 + _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1944 + _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1998 + _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 2000 + _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 2053 + _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 2055 + _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 2069 + _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 2071 + _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 2112 + _ANALYZEPLANREQUEST_SAMESEMANTICS._serialized_start = 2114 + _ANALYZEPLANREQUEST_SAMESEMANTICS._serialized_end = 2235 + _ANALYZEPLANREQUEST_SEMANTICHASH._serialized_start = 2237 + _ANALYZEPLANREQUEST_SEMANTICHASH._serialized_end = 2292 + _ANALYZEPLANRESPONSE._serialized_start = 2322 + _ANALYZEPLANRESPONSE._serialized_end = 3654 + _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 3182 + _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 3239 + _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 3241 + _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 3289 + _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 3291 + _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 3336 + _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 3338 + _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 3374 + _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 3376 + _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 3424 + _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 3426 + _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 3460 + _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 3462 + _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 3502 + _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 3504 + _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 3563 + _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_start = 3565 + _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_end = 3604 + _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_start = 3606 + _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_end = 3644 + _EXECUTEPLANREQUEST._serialized_start = 3657 + _EXECUTEPLANREQUEST._serialized_end = 3866 + _EXECUTEPLANRESPONSE._serialized_start = 3869 + _EXECUTEPLANRESPONSE._serialized_end = 5224 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4455 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4526 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4528 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4589 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4592 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 5109 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4687 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 5019 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4896 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 5019 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 5021 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 5109 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start 
= 5111 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 5207 + _KEYVALUE._serialized_start = 5226 + _KEYVALUE._serialized_end = 5291 + _CONFIGREQUEST._serialized_start = 5294 + _CONFIGREQUEST._serialized_end = 6322 + _CONFIGREQUEST_OPERATION._serialized_start = 5514 + _CONFIGREQUEST_OPERATION._serialized_end = 6012 + _CONFIGREQUEST_SET._serialized_start = 6014 + _CONFIGREQUEST_SET._serialized_end = 6066 + _CONFIGREQUEST_GET._serialized_start = 6068 + _CONFIGREQUEST_GET._serialized_end = 6093 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 6095 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 6158 + _CONFIGREQUEST_GETOPTION._serialized_start = 6160 + _CONFIGREQUEST_GETOPTION._serialized_end = 6191 + _CONFIGREQUEST_GETALL._serialized_start = 6193 + _CONFIGREQUEST_GETALL._serialized_end = 6241 + _CONFIGREQUEST_UNSET._serialized_start = 6243 + _CONFIGREQUEST_UNSET._serialized_end = 6270 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6272 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6306 + _CONFIGRESPONSE._serialized_start = 6324 + _CONFIGRESPONSE._serialized_end = 6446 + _ADDARTIFACTSREQUEST._serialized_start = 6449 + _ADDARTIFACTSREQUEST._serialized_end = 7320 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6836 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6889 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6891 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 7002 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 7004 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 7097 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 7100 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7293 + _ADDARTIFACTSRESPONSE._serialized_start = 7323 + _ADDARTIFACTSRESPONSE._serialized_end = 7511 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7430 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7511 + _SPARKCONNECTSERVICE._serialized_start = 7514 + _SPARKCONNECTSERVICE._serialized_end = 7879 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 4c020308d9a9..eb66939a0796 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -42,6 +42,7 @@ import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import pyspark.sql.connect.proto.commands_pb2 import pyspark.sql.connect.proto.expressions_pb2 +import pyspark.sql.connect.proto.ml_pb2 import pyspark.sql.connect.proto.relations_pb2 import pyspark.sql.connect.proto.types_pb2 import sys @@ -65,31 +66,49 @@ class Plan(google.protobuf.message.Message): ROOT_FIELD_NUMBER: builtins.int COMMAND_FIELD_NUMBER: builtins.int + ML_COMMAND_FIELD_NUMBER: builtins.int @property def root(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ... @property def command(self) -> pyspark.sql.connect.proto.commands_pb2.Command: ... + @property + def ml_command(self) -> pyspark.sql.connect.proto.ml_pb2.MlCommand: ... def __init__( self, *, root: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., command: pyspark.sql.connect.proto.commands_pb2.Command | None = ..., + ml_command: pyspark.sql.connect.proto.ml_pb2.MlCommand | None = ..., ) -> None: ... 
def HasField( self, field_name: typing_extensions.Literal[ - "command", b"command", "op_type", b"op_type", "root", b"root" + "command", + b"command", + "ml_command", + b"ml_command", + "op_type", + b"op_type", + "root", + b"root", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "command", b"command", "op_type", b"op_type", "root", b"root" + "command", + b"command", + "ml_command", + b"ml_command", + "op_type", + b"op_type", + "root", + b"root", ], ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["op_type", b"op_type"] - ) -> typing_extensions.Literal["root", "command"] | None: ... + ) -> typing_extensions.Literal["root", "command", "ml_command"] | None: ... global___Plan = Plan @@ -1053,6 +1072,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): SESSION_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int + ML_COMMAND_RESULT_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int OBSERVED_METRICS_FIELD_NUMBER: builtins.int @@ -1064,6 +1084,9 @@ class ExecutePlanResponse(google.protobuf.message.Message): def sql_command_result(self) -> global___ExecutePlanResponse.SqlCommandResult: """Special case for executing SQL commands.""" @property + def ml_command_result(self) -> pyspark.sql.connect.proto.ml_pb2.MlCommandResponse: + """ML command response""" + @property def extension(self) -> google.protobuf.any_pb2.Any: """Support arbitrary result objects.""" @property @@ -1087,6 +1110,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): session_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ..., + ml_command_result: pyspark.sql.connect.proto.ml_pb2.MlCommandResponse | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., metrics: global___ExecutePlanResponse.Metrics | None = ..., observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics] @@ -1102,6 +1126,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"extension", "metrics", b"metrics", + "ml_command_result", + b"ml_command_result", "response_type", b"response_type", "schema", @@ -1119,6 +1145,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"extension", "metrics", b"metrics", + "ml_command_result", + b"ml_command_result", "observed_metrics", b"observed_metrics", "response_type", @@ -1133,7 +1161,9 @@ class ExecutePlanResponse(google.protobuf.message.Message): ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["response_type", b"response_type"] - ) -> typing_extensions.Literal["arrow_batch", "sql_command_result", "extension"] | None: ... + ) -> typing_extensions.Literal[ + "arrow_batch", "sql_command_result", "ml_command_result", "extension" + ] | None: ... global___ExecutePlanResponse = ExecutePlanResponse diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.py b/python/pyspark/sql/connect/proto/ml_common_pb2.py new file mode 100644 index 000000000000..3e386246354e --- /dev/null +++ b/python/pyspark/sql/connect/proto/ml_common_pb2.py @@ -0,0 +1,212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: spark/connect/ml_common.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1dspark/connect/ml_common.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa3\x04\n\x08MlParams\x12;\n\x06params\x18\x01 \x03(\x0b\x32#.spark.connect.MlParams.ParamsEntryR\x06params\x12Q\n\x0e\x64\x65\x66\x61ult_params\x18\x02 \x03(\x0b\x32*.spark.connect.MlParams.DefaultParamsEntryR\rdefaultParams\x1a]\n\x0bParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.MlParams.ParamValueR\x05value:\x02\x38\x01\x1a\x64\n\x12\x44\x65\x66\x61ultParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.MlParams.ParamValueR\x05value:\x02\x38\x01\x1a\xc1\x01\n\nParamValue\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12/\n\x06vector\x18\x02 \x01(\x0b\x32\x15.spark.connect.VectorH\x00R\x06vector\x12/\n\x06matrix\x18\x03 \x01(\x0b\x32\x15.spark.connect.MatrixH\x00R\x06matrixB\x12\n\x10param_value_type"\xf5\x01\n\x07MlStage\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x10\n\x03uid\x18\x03 \x01(\tR\x03uid\x12\x34\n\x04type\x18\x04 \x01(\x0e\x32 .spark.connect.MlStage.StageTypeR\x04type"]\n\tStageType\x12\x1a\n\x16STAGE_TYPE_UNSPECIFIED\x10\x00\x12\x18\n\x14STAGE_TYPE_ESTIMATOR\x10\x01\x12\x1a\n\x16STAGE_TYPE_TRANSFORMER\x10\x02"\x1a\n\x08ModelRef\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id"\xe8\x01\n\x06Vector\x12\x33\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.spark.connect.Vector.DenseH\x00R\x05\x64\x65nse\x12\x36\n\x06sparse\x18\x02 \x01(\x0b\x32\x1c.spark.connect.Vector.SparseH\x00R\x06sparse\x1a\x1d\n\x05\x44\x65nse\x12\x14\n\x05value\x18\x01 \x03(\x01R\x05value\x1aH\n\x06Sparse\x12\x12\n\x04size\x18\x01 \x01(\x05R\x04size\x12\x14\n\x05index\x18\x02 \x03(\x01R\x05index\x12\x14\n\x05value\x18\x03 \x03(\x01R\x05valueB\x08\n\x06one_of"\xaa\x03\n\x06Matrix\x12\x33\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.spark.connect.Matrix.DenseH\x00R\x05\x64\x65nse\x12\x36\n\x06sparse\x18\x02 \x01(\x0b\x32\x1c.spark.connect.Matrix.SparseH\x00R\x06sparse\x1ax\n\x05\x44\x65nse\x12\x19\n\x08num_rows\x18\x01 \x01(\x05R\x07numRows\x12\x19\n\x08num_cols\x18\x02 \x01(\x05R\x07numCols\x12\x14\n\x05value\x18\x03 
\x03(\x01R\x05value\x12#\n\ris_transposed\x18\x04 \x01(\x08R\x0cisTransposed\x1a\xae\x01\n\x06Sparse\x12\x19\n\x08num_rows\x18\x01 \x01(\x05R\x07numRows\x12\x19\n\x08num_cols\x18\x02 \x01(\x05R\x07numCols\x12\x16\n\x06\x63olptr\x18\x03 \x03(\x01R\x06\x63olptr\x12\x1b\n\trow_index\x18\x04 \x03(\x01R\x08rowIndex\x12\x14\n\x05value\x18\x05 \x03(\x01R\x05value\x12#\n\ris_transposed\x18\x06 \x01(\x08R\x0cisTransposedB\x08\n\x06one_ofB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' +) + + +_MLPARAMS = DESCRIPTOR.message_types_by_name["MlParams"] +_MLPARAMS_PARAMSENTRY = _MLPARAMS.nested_types_by_name["ParamsEntry"] +_MLPARAMS_DEFAULTPARAMSENTRY = _MLPARAMS.nested_types_by_name["DefaultParamsEntry"] +_MLPARAMS_PARAMVALUE = _MLPARAMS.nested_types_by_name["ParamValue"] +_MLSTAGE = DESCRIPTOR.message_types_by_name["MlStage"] +_MODELREF = DESCRIPTOR.message_types_by_name["ModelRef"] +_VECTOR = DESCRIPTOR.message_types_by_name["Vector"] +_VECTOR_DENSE = _VECTOR.nested_types_by_name["Dense"] +_VECTOR_SPARSE = _VECTOR.nested_types_by_name["Sparse"] +_MATRIX = DESCRIPTOR.message_types_by_name["Matrix"] +_MATRIX_DENSE = _MATRIX.nested_types_by_name["Dense"] +_MATRIX_SPARSE = _MATRIX.nested_types_by_name["Sparse"] +_MLSTAGE_STAGETYPE = _MLSTAGE.enum_types_by_name["StageType"] +MlParams = _reflection.GeneratedProtocolMessageType( + "MlParams", + (_message.Message,), + { + "ParamsEntry": _reflection.GeneratedProtocolMessageType( + "ParamsEntry", + (_message.Message,), + { + "DESCRIPTOR": _MLPARAMS_PARAMSENTRY, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlParams.ParamsEntry) + }, + ), + "DefaultParamsEntry": _reflection.GeneratedProtocolMessageType( + "DefaultParamsEntry", + (_message.Message,), + { + "DESCRIPTOR": _MLPARAMS_DEFAULTPARAMSENTRY, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlParams.DefaultParamsEntry) + }, + ), + "ParamValue": _reflection.GeneratedProtocolMessageType( + "ParamValue", + (_message.Message,), + { + "DESCRIPTOR": _MLPARAMS_PARAMVALUE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlParams.ParamValue) + }, + ), + "DESCRIPTOR": _MLPARAMS, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlParams) + }, +) +_sym_db.RegisterMessage(MlParams) +_sym_db.RegisterMessage(MlParams.ParamsEntry) +_sym_db.RegisterMessage(MlParams.DefaultParamsEntry) +_sym_db.RegisterMessage(MlParams.ParamValue) + +MlStage = _reflection.GeneratedProtocolMessageType( + "MlStage", + (_message.Message,), + { + "DESCRIPTOR": _MLSTAGE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlStage) + }, +) +_sym_db.RegisterMessage(MlStage) + +ModelRef = _reflection.GeneratedProtocolMessageType( + "ModelRef", + (_message.Message,), + { + "DESCRIPTOR": _MODELREF, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ModelRef) + }, +) +_sym_db.RegisterMessage(ModelRef) + +Vector = _reflection.GeneratedProtocolMessageType( + "Vector", + (_message.Message,), + { + "Dense": _reflection.GeneratedProtocolMessageType( + "Dense", + (_message.Message,), + { + "DESCRIPTOR": _VECTOR_DENSE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Vector.Dense) + }, + ), + "Sparse": _reflection.GeneratedProtocolMessageType( + "Sparse", + (_message.Message,), + { + 
"DESCRIPTOR": _VECTOR_SPARSE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Vector.Sparse) + }, + ), + "DESCRIPTOR": _VECTOR, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Vector) + }, +) +_sym_db.RegisterMessage(Vector) +_sym_db.RegisterMessage(Vector.Dense) +_sym_db.RegisterMessage(Vector.Sparse) + +Matrix = _reflection.GeneratedProtocolMessageType( + "Matrix", + (_message.Message,), + { + "Dense": _reflection.GeneratedProtocolMessageType( + "Dense", + (_message.Message,), + { + "DESCRIPTOR": _MATRIX_DENSE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Matrix.Dense) + }, + ), + "Sparse": _reflection.GeneratedProtocolMessageType( + "Sparse", + (_message.Message,), + { + "DESCRIPTOR": _MATRIX_SPARSE, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Matrix.Sparse) + }, + ), + "DESCRIPTOR": _MATRIX, + "__module__": "spark.connect.ml_common_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Matrix) + }, +) +_sym_db.RegisterMessage(Matrix) +_sym_db.RegisterMessage(Matrix.Dense) +_sym_db.RegisterMessage(Matrix.Sparse) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" + _MLPARAMS_PARAMSENTRY._options = None + _MLPARAMS_PARAMSENTRY._serialized_options = b"8\001" + _MLPARAMS_DEFAULTPARAMSENTRY._options = None + _MLPARAMS_DEFAULTPARAMSENTRY._serialized_options = b"8\001" + _MLPARAMS._serialized_start = 82 + _MLPARAMS._serialized_end = 629 + _MLPARAMS_PARAMSENTRY._serialized_start = 238 + _MLPARAMS_PARAMSENTRY._serialized_end = 331 + _MLPARAMS_DEFAULTPARAMSENTRY._serialized_start = 333 + _MLPARAMS_DEFAULTPARAMSENTRY._serialized_end = 433 + _MLPARAMS_PARAMVALUE._serialized_start = 436 + _MLPARAMS_PARAMVALUE._serialized_end = 629 + _MLSTAGE._serialized_start = 632 + _MLSTAGE._serialized_end = 877 + _MLSTAGE_STAGETYPE._serialized_start = 784 + _MLSTAGE_STAGETYPE._serialized_end = 877 + _MODELREF._serialized_start = 879 + _MODELREF._serialized_end = 905 + _VECTOR._serialized_start = 908 + _VECTOR._serialized_end = 1140 + _VECTOR_DENSE._serialized_start = 1027 + _VECTOR_DENSE._serialized_end = 1056 + _VECTOR_SPARSE._serialized_start = 1058 + _VECTOR_SPARSE._serialized_end = 1130 + _MATRIX._serialized_start = 1143 + _MATRIX._serialized_end = 1569 + _MATRIX_DENSE._serialized_start = 1262 + _MATRIX_DENSE._serialized_end = 1382 + _MATRIX_SPARSE._serialized_start = 1385 + _MATRIX_SPARSE._serialized_end = 1559 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.pyi b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi new file mode 100644 index 000000000000..3c5411788009 --- /dev/null +++ b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi @@ -0,0 +1,447 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file + +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import pyspark.sql.connect.proto.expressions_pb2 +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class MlParams(google.protobuf.message.Message): + """MlParams stores param settings for + ML Estimator / Transformer / Model / Evaluator + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class ParamsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + @property + def value(self) -> global___MlParams.ParamValue: ... + def __init__( + self, + *, + key: builtins.str = ..., + value: global___MlParams.ParamValue | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + class DefaultParamsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + @property + def value(self) -> global___MlParams.ParamValue: ... + def __init__( + self, + *, + key: builtins.str = ..., + value: global___MlParams.ParamValue | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + class ParamValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + LITERAL_FIELD_NUMBER: builtins.int + VECTOR_FIELD_NUMBER: builtins.int + MATRIX_FIELD_NUMBER: builtins.int + @property + def literal(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal: ... + @property + def vector(self) -> global___Vector: ... 
+ @property + def matrix(self) -> global___Matrix: ... + def __init__( + self, + *, + literal: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ..., + vector: global___Vector | None = ..., + matrix: global___Matrix | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "literal", + b"literal", + "matrix", + b"matrix", + "param_value_type", + b"param_value_type", + "vector", + b"vector", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "literal", + b"literal", + "matrix", + b"matrix", + "param_value_type", + b"param_value_type", + "vector", + b"vector", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["param_value_type", b"param_value_type"] + ) -> typing_extensions.Literal["literal", "vector", "matrix"] | None: ... + + PARAMS_FIELD_NUMBER: builtins.int + DEFAULT_PARAMS_FIELD_NUMBER: builtins.int + @property + def params( + self, + ) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___MlParams.ParamValue]: + """user-supplied params""" + @property + def default_params( + self, + ) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___MlParams.ParamValue]: + """default params""" + def __init__( + self, + *, + params: collections.abc.Mapping[builtins.str, global___MlParams.ParamValue] | None = ..., + default_params: collections.abc.Mapping[builtins.str, global___MlParams.ParamValue] + | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "default_params", b"default_params", "params", b"params" + ], + ) -> None: ... + +global___MlParams = MlParams + +class MlStage(google.protobuf.message.Message): + """MlStage stores ML stage data (Estimator or Transformer)""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class _StageType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _StageTypeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[MlStage._StageType.ValueType], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + STAGE_TYPE_UNSPECIFIED: MlStage._StageType.ValueType # 0 + STAGE_TYPE_ESTIMATOR: MlStage._StageType.ValueType # 1 + STAGE_TYPE_TRANSFORMER: MlStage._StageType.ValueType # 2 + + class StageType(_StageType, metaclass=_StageTypeEnumTypeWrapper): ... + STAGE_TYPE_UNSPECIFIED: MlStage.StageType.ValueType # 0 + STAGE_TYPE_ESTIMATOR: MlStage.StageType.ValueType # 1 + STAGE_TYPE_TRANSFORMER: MlStage.StageType.ValueType # 2 + + NAME_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + UID_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + name: builtins.str + """The name of the stage in the registry""" + @property + def params(self) -> global___MlParams: + """param settings for the stage""" + uid: builtins.str + """unique id of the stage""" + type: global___MlStage.StageType.ValueType + def __init__( + self, + *, + name: builtins.str = ..., + params: global___MlParams | None = ..., + uid: builtins.str = ..., + type: global___MlStage.StageType.ValueType = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["params", b"params"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "name", b"name", "params", b"params", "type", b"type", "uid", b"uid" + ], + ) -> None: ... 
+ +global___MlStage = MlStage + +class ModelRef(google.protobuf.message.Message): + """ModelRef represents a reference to server side `Model` instance""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + id: builtins.str + """The ID is used to lookup the model instance in server side.""" + def __init__( + self, + *, + id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["id", b"id"]) -> None: ... + +global___ModelRef = ModelRef + +class Vector(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class Dense(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VALUE_FIELD_NUMBER: builtins.int + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__( + self, + *, + value: collections.abc.Iterable[builtins.float] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["value", b"value"]) -> None: ... + + class Sparse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SIZE_FIELD_NUMBER: builtins.int + INDEX_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + size: builtins.int + @property + def index( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + def __init__( + self, + *, + size: builtins.int = ..., + index: collections.abc.Iterable[builtins.float] | None = ..., + value: collections.abc.Iterable[builtins.float] | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "index", b"index", "size", b"size", "value", b"value" + ], + ) -> None: ... + + DENSE_FIELD_NUMBER: builtins.int + SPARSE_FIELD_NUMBER: builtins.int + @property + def dense(self) -> global___Vector.Dense: ... + @property + def sparse(self) -> global___Vector.Sparse: ... + def __init__( + self, + *, + dense: global___Vector.Dense | None = ..., + sparse: global___Vector.Sparse | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "dense", b"dense", "one_of", b"one_of", "sparse", b"sparse" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dense", b"dense", "one_of", b"one_of", "sparse", b"sparse" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["one_of", b"one_of"] + ) -> typing_extensions.Literal["dense", "sparse"] | None: ... + +global___Vector = Vector + +class Matrix(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class Dense(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NUM_ROWS_FIELD_NUMBER: builtins.int + NUM_COLS_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + IS_TRANSPOSED_FIELD_NUMBER: builtins.int + num_rows: builtins.int + num_cols: builtins.int + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + is_transposed: builtins.bool + def __init__( + self, + *, + num_rows: builtins.int = ..., + num_cols: builtins.int = ..., + value: collections.abc.Iterable[builtins.float] | None = ..., + is_transposed: builtins.bool = ..., + ) -> None: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "is_transposed", + b"is_transposed", + "num_cols", + b"num_cols", + "num_rows", + b"num_rows", + "value", + b"value", + ], + ) -> None: ... + + class Sparse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NUM_ROWS_FIELD_NUMBER: builtins.int + NUM_COLS_FIELD_NUMBER: builtins.int + COLPTR_FIELD_NUMBER: builtins.int + ROW_INDEX_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + IS_TRANSPOSED_FIELD_NUMBER: builtins.int + num_rows: builtins.int + num_cols: builtins.int + @property + def colptr( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + @property + def row_index( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: ... + is_transposed: builtins.bool + def __init__( + self, + *, + num_rows: builtins.int = ..., + num_cols: builtins.int = ..., + colptr: collections.abc.Iterable[builtins.float] | None = ..., + row_index: collections.abc.Iterable[builtins.float] | None = ..., + value: collections.abc.Iterable[builtins.float] | None = ..., + is_transposed: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "colptr", + b"colptr", + "is_transposed", + b"is_transposed", + "num_cols", + b"num_cols", + "num_rows", + b"num_rows", + "row_index", + b"row_index", + "value", + b"value", + ], + ) -> None: ... + + DENSE_FIELD_NUMBER: builtins.int + SPARSE_FIELD_NUMBER: builtins.int + @property + def dense(self) -> global___Matrix.Dense: ... + @property + def sparse(self) -> global___Matrix.Sparse: ... + def __init__( + self, + *, + dense: global___Matrix.Dense | None = ..., + sparse: global___Matrix.Sparse | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "dense", b"dense", "one_of", b"one_of", "sparse", b"sparse" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "dense", b"dense", "one_of", b"one_of", "sparse", b"sparse" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["one_of", b"one_of"] + ) -> typing_extensions.Literal["dense", "sparse"] | None: ... + +global___Matrix = Matrix diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py new file mode 100644 index 000000000000..86698430a578 --- /dev/null +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -0,0 +1,304 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: spark/connect/ml.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2 +from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 +from pyspark.sql.connect.proto import ml_common_pb2 as spark_dot_connect_dot_ml__common__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x1dspark/connect/ml_common.proto"d\n\x0bMlEvaluator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x10\n\x03uid\x18\x03 \x01(\tR\x03uid"\xfe\x13\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12S\n\x10\x66\x65tch_model_attr\x18\x02 \x01(\x0b\x32\'.spark.connect.MlCommand.FetchModelAttrH\x00R\x0e\x66\x65tchModelAttr\x12i\n\x18\x66\x65tch_model_summary_attr\x18\x03 \x01(\x0b\x32..spark.connect.MlCommand.FetchModelSummaryAttrH\x00R\x15\x66\x65tchModelSummaryAttr\x12\x43\n\nload_model\x18\x04 \x01(\x0b\x32".spark.connect.MlCommand.LoadModelH\x00R\tloadModel\x12\x43\n\nsave_model\x18\x05 \x01(\x0b\x32".spark.connect.MlCommand.SaveModelH\x00R\tsaveModel\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x43\n\nsave_stage\x18\x07 \x01(\x0b\x32".spark.connect.MlCommand.SaveStageH\x00R\tsaveStage\x12\x43\n\nload_stage\x18\x08 \x01(\x0b\x32".spark.connect.MlCommand.LoadStageH\x00R\tloadStage\x12O\n\x0esave_evaluator\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.SaveEvaluatorH\x00R\rsaveEvaluator\x12O\n\x0eload_evaluator\x18\n \x01(\x0b\x32&.spark.connect.MlCommand.LoadEvaluatorH\x00R\rloadEvaluator\x12\x43\n\ncopy_model\x18\x0b \x01(\x0b\x32".spark.connect.MlCommand.CopyModelH\x00R\tcopyModel\x12I\n\x0c\x64\x65lete_model\x18\x0c \x01(\x0b\x32$.spark.connect.MlCommand.DeleteModelH\x00R\x0b\x64\x65leteModel\x1an\n\x03\x46it\x12\x34\n\testimator\x18\x01 \x01(\x0b\x32\x16.spark.connect.MlStageR\testimator\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a\x44\n\x08\x45valuate\x12\x38\n\tevaluator\x18\x01 \x01(\x0b\x32\x1a.spark.connect.MlEvaluatorR\tevaluator\x1a\x33\n\tLoadModel\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xfa\x01\n\tSaveModel\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x12\x1c\n\toverwrite\x18\x03 \x01(\x08R\toverwrite\x12I\n\x07options\x18\x04 \x03(\x0b\x32/.spark.connect.MlCommand.SaveModel.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1ai\n\tLoadStage\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x12\x34\n\x04type\x18\x03 \x01(\x0e\x32 .spark.connect.MlStage.StageTypeR\x04type\x1a\xf2\x01\n\tSaveStage\x12,\n\x05stage\x18\x01 
\x01(\x0b\x32\x16.spark.connect.MlStageR\x05stage\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x12\x1c\n\toverwrite\x18\x03 \x01(\x08R\toverwrite\x12I\n\x07options\x18\x04 \x03(\x0b\x32/.spark.connect.MlCommand.SaveStage.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x37\n\rLoadEvaluator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\x86\x02\n\rSaveEvaluator\x12\x38\n\tevaluator\x18\x01 \x01(\x0b\x32\x1a.spark.connect.MlEvaluatorR\tevaluator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x12\x1c\n\toverwrite\x18\x03 \x01(\x08R\toverwrite\x12M\n\x07options\x18\x04 \x03(\x0b\x32\x33.spark.connect.MlCommand.SaveEvaluator.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aZ\n\x0e\x46\x65tchModelAttr\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x1a\xf6\x01\n\x15\x46\x65tchModelSummaryAttr\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12K\n\x12\x65valuation_dataset\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x11\x65valuationDataset\x88\x01\x01\x42\x15\n\x13_evaluation_dataset\x1a\x41\n\tCopyModel\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x1a\x43\n\x0b\x44\x65leteModel\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRefB\x11\n\x0fml_command_type"\x97\x04\n\x11MlCommandResponse\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12K\n\nmodel_info\x18\x02 \x01(\x0b\x32*.spark.connect.MlCommandResponse.ModelInfoH\x00R\tmodelInfo\x12/\n\x06vector\x18\x03 \x01(\x0b\x32\x15.spark.connect.VectorH\x00R\x06vector\x12/\n\x06matrix\x18\x04 \x01(\x0b\x32\x15.spark.connect.MatrixH\x00R\x06matrix\x12.\n\x05stage\x18\x05 \x01(\x0b\x32\x16.spark.connect.MlStageH\x00R\x05stage\x12\x36\n\tmodel_ref\x18\x06 \x01(\x0b\x32\x17.spark.connect.ModelRefH\x00R\x08modelRef\x1a\x8f\x01\n\tModelInfo\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x1b\n\tmodel_uid\x18\x02 \x01(\tR\x08modelUid\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\x1a\n\x18ml_command_response_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' +) + + +_MLEVALUATOR = DESCRIPTOR.message_types_by_name["MlEvaluator"] +_MLCOMMAND = DESCRIPTOR.message_types_by_name["MlCommand"] +_MLCOMMAND_FIT = _MLCOMMAND.nested_types_by_name["Fit"] +_MLCOMMAND_EVALUATE = _MLCOMMAND.nested_types_by_name["Evaluate"] +_MLCOMMAND_LOADMODEL = _MLCOMMAND.nested_types_by_name["LoadModel"] +_MLCOMMAND_SAVEMODEL = _MLCOMMAND.nested_types_by_name["SaveModel"] +_MLCOMMAND_SAVEMODEL_OPTIONSENTRY = _MLCOMMAND_SAVEMODEL.nested_types_by_name["OptionsEntry"] +_MLCOMMAND_LOADSTAGE = _MLCOMMAND.nested_types_by_name["LoadStage"] +_MLCOMMAND_SAVESTAGE = _MLCOMMAND.nested_types_by_name["SaveStage"] +_MLCOMMAND_SAVESTAGE_OPTIONSENTRY = _MLCOMMAND_SAVESTAGE.nested_types_by_name["OptionsEntry"] +_MLCOMMAND_LOADEVALUATOR = _MLCOMMAND.nested_types_by_name["LoadEvaluator"] +_MLCOMMAND_SAVEEVALUATOR = _MLCOMMAND.nested_types_by_name["SaveEvaluator"] +_MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY = 
_MLCOMMAND_SAVEEVALUATOR.nested_types_by_name[ + "OptionsEntry" +] +_MLCOMMAND_FETCHMODELATTR = _MLCOMMAND.nested_types_by_name["FetchModelAttr"] +_MLCOMMAND_FETCHMODELSUMMARYATTR = _MLCOMMAND.nested_types_by_name["FetchModelSummaryAttr"] +_MLCOMMAND_COPYMODEL = _MLCOMMAND.nested_types_by_name["CopyModel"] +_MLCOMMAND_DELETEMODEL = _MLCOMMAND.nested_types_by_name["DeleteModel"] +_MLCOMMANDRESPONSE = DESCRIPTOR.message_types_by_name["MlCommandResponse"] +_MLCOMMANDRESPONSE_MODELINFO = _MLCOMMANDRESPONSE.nested_types_by_name["ModelInfo"] +MlEvaluator = _reflection.GeneratedProtocolMessageType( + "MlEvaluator", + (_message.Message,), + { + "DESCRIPTOR": _MLEVALUATOR, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlEvaluator) + }, +) +_sym_db.RegisterMessage(MlEvaluator) + +MlCommand = _reflection.GeneratedProtocolMessageType( + "MlCommand", + (_message.Message,), + { + "Fit": _reflection.GeneratedProtocolMessageType( + "Fit", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_FIT, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.Fit) + }, + ), + "Evaluate": _reflection.GeneratedProtocolMessageType( + "Evaluate", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_EVALUATE, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.Evaluate) + }, + ), + "LoadModel": _reflection.GeneratedProtocolMessageType( + "LoadModel", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_LOADMODEL, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.LoadModel) + }, + ), + "SaveModel": _reflection.GeneratedProtocolMessageType( + "SaveModel", + (_message.Message,), + { + "OptionsEntry": _reflection.GeneratedProtocolMessageType( + "OptionsEntry", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_SAVEMODEL_OPTIONSENTRY, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveModel.OptionsEntry) + }, + ), + "DESCRIPTOR": _MLCOMMAND_SAVEMODEL, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveModel) + }, + ), + "LoadStage": _reflection.GeneratedProtocolMessageType( + "LoadStage", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_LOADSTAGE, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.LoadStage) + }, + ), + "SaveStage": _reflection.GeneratedProtocolMessageType( + "SaveStage", + (_message.Message,), + { + "OptionsEntry": _reflection.GeneratedProtocolMessageType( + "OptionsEntry", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_SAVESTAGE_OPTIONSENTRY, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveStage.OptionsEntry) + }, + ), + "DESCRIPTOR": _MLCOMMAND_SAVESTAGE, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveStage) + }, + ), + "LoadEvaluator": _reflection.GeneratedProtocolMessageType( + "LoadEvaluator", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_LOADEVALUATOR, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.LoadEvaluator) + }, + ), + "SaveEvaluator": _reflection.GeneratedProtocolMessageType( + "SaveEvaluator", + (_message.Message,), + { + "OptionsEntry": _reflection.GeneratedProtocolMessageType( + "OptionsEntry", + (_message.Message,), + { + 
"DESCRIPTOR": _MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveEvaluator.OptionsEntry) + }, + ), + "DESCRIPTOR": _MLCOMMAND_SAVEEVALUATOR, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.SaveEvaluator) + }, + ), + "FetchModelAttr": _reflection.GeneratedProtocolMessageType( + "FetchModelAttr", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_FETCHMODELATTR, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.FetchModelAttr) + }, + ), + "FetchModelSummaryAttr": _reflection.GeneratedProtocolMessageType( + "FetchModelSummaryAttr", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_FETCHMODELSUMMARYATTR, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.FetchModelSummaryAttr) + }, + ), + "CopyModel": _reflection.GeneratedProtocolMessageType( + "CopyModel", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_COPYMODEL, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.CopyModel) + }, + ), + "DeleteModel": _reflection.GeneratedProtocolMessageType( + "DeleteModel", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMAND_DELETEMODEL, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand.DeleteModel) + }, + ), + "DESCRIPTOR": _MLCOMMAND, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommand) + }, +) +_sym_db.RegisterMessage(MlCommand) +_sym_db.RegisterMessage(MlCommand.Fit) +_sym_db.RegisterMessage(MlCommand.Evaluate) +_sym_db.RegisterMessage(MlCommand.LoadModel) +_sym_db.RegisterMessage(MlCommand.SaveModel) +_sym_db.RegisterMessage(MlCommand.SaveModel.OptionsEntry) +_sym_db.RegisterMessage(MlCommand.LoadStage) +_sym_db.RegisterMessage(MlCommand.SaveStage) +_sym_db.RegisterMessage(MlCommand.SaveStage.OptionsEntry) +_sym_db.RegisterMessage(MlCommand.LoadEvaluator) +_sym_db.RegisterMessage(MlCommand.SaveEvaluator) +_sym_db.RegisterMessage(MlCommand.SaveEvaluator.OptionsEntry) +_sym_db.RegisterMessage(MlCommand.FetchModelAttr) +_sym_db.RegisterMessage(MlCommand.FetchModelSummaryAttr) +_sym_db.RegisterMessage(MlCommand.CopyModel) +_sym_db.RegisterMessage(MlCommand.DeleteModel) + +MlCommandResponse = _reflection.GeneratedProtocolMessageType( + "MlCommandResponse", + (_message.Message,), + { + "ModelInfo": _reflection.GeneratedProtocolMessageType( + "ModelInfo", + (_message.Message,), + { + "DESCRIPTOR": _MLCOMMANDRESPONSE_MODELINFO, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommandResponse.ModelInfo) + }, + ), + "DESCRIPTOR": _MLCOMMANDRESPONSE, + "__module__": "spark.connect.ml_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlCommandResponse) + }, +) +_sym_db.RegisterMessage(MlCommandResponse) +_sym_db.RegisterMessage(MlCommandResponse.ModelInfo) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" + _MLCOMMAND_SAVEMODEL_OPTIONSENTRY._options = None + _MLCOMMAND_SAVEMODEL_OPTIONSENTRY._serialized_options = b"8\001" + _MLCOMMAND_SAVESTAGE_OPTIONSENTRY._options = None + _MLCOMMAND_SAVESTAGE_OPTIONSENTRY._serialized_options = b"8\001" + _MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY._options = None + 
_MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY._serialized_options = b"8\001" + _MLEVALUATOR._serialized_start = 136 + _MLEVALUATOR._serialized_end = 236 + _MLCOMMAND._serialized_start = 239 + _MLCOMMAND._serialized_end = 2797 + _MLCOMMAND_FIT._serialized_start = 1141 + _MLCOMMAND_FIT._serialized_end = 1251 + _MLCOMMAND_EVALUATE._serialized_start = 1253 + _MLCOMMAND_EVALUATE._serialized_end = 1321 + _MLCOMMAND_LOADMODEL._serialized_start = 1323 + _MLCOMMAND_LOADMODEL._serialized_end = 1374 + _MLCOMMAND_SAVEMODEL._serialized_start = 1377 + _MLCOMMAND_SAVEMODEL._serialized_end = 1627 + _MLCOMMAND_SAVEMODEL_OPTIONSENTRY._serialized_start = 1569 + _MLCOMMAND_SAVEMODEL_OPTIONSENTRY._serialized_end = 1627 + _MLCOMMAND_LOADSTAGE._serialized_start = 1629 + _MLCOMMAND_LOADSTAGE._serialized_end = 1734 + _MLCOMMAND_SAVESTAGE._serialized_start = 1737 + _MLCOMMAND_SAVESTAGE._serialized_end = 1979 + _MLCOMMAND_SAVESTAGE_OPTIONSENTRY._serialized_start = 1569 + _MLCOMMAND_SAVESTAGE_OPTIONSENTRY._serialized_end = 1627 + _MLCOMMAND_LOADEVALUATOR._serialized_start = 1981 + _MLCOMMAND_LOADEVALUATOR._serialized_end = 2036 + _MLCOMMAND_SAVEEVALUATOR._serialized_start = 2039 + _MLCOMMAND_SAVEEVALUATOR._serialized_end = 2301 + _MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY._serialized_start = 1569 + _MLCOMMAND_SAVEEVALUATOR_OPTIONSENTRY._serialized_end = 1627 + _MLCOMMAND_FETCHMODELATTR._serialized_start = 2303 + _MLCOMMAND_FETCHMODELATTR._serialized_end = 2393 + _MLCOMMAND_FETCHMODELSUMMARYATTR._serialized_start = 2396 + _MLCOMMAND_FETCHMODELSUMMARYATTR._serialized_end = 2642 + _MLCOMMAND_COPYMODEL._serialized_start = 2644 + _MLCOMMAND_COPYMODEL._serialized_end = 2709 + _MLCOMMAND_DELETEMODEL._serialized_start = 2711 + _MLCOMMAND_DELETEMODEL._serialized_end = 2778 + _MLCOMMANDRESPONSE._serialized_start = 2800 + _MLCOMMANDRESPONSE._serialized_end = 3335 + _MLCOMMANDRESPONSE_MODELINFO._serialized_start = 3164 + _MLCOMMANDRESPONSE_MODELINFO._serialized_end = 3307 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi new file mode 100644 index 000000000000..f5916f81ef7a --- /dev/null +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -0,0 +1,728 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file + +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import pyspark.sql.connect.proto.expressions_pb2 +import pyspark.sql.connect.proto.ml_common_pb2 +import pyspark.sql.connect.proto.relations_pb2 +import sys + +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class MlEvaluator(google.protobuf.message.Message): + """MlEvaluator represents a ML Evaluator""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + UID_FIELD_NUMBER: builtins.int + name: builtins.str + """The name of the evaluator in the registry""" + @property + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: + """param settings for the evaluator""" + uid: builtins.str + """unique id of the evaluator""" + def __init__( + self, + *, + name: builtins.str = ..., + params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., + uid: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["params", b"params"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["name", b"name", "params", b"params", "uid", b"uid"], + ) -> None: ... + +global___MlEvaluator = MlEvaluator + +class MlCommand(google.protobuf.message.Message): + """a MlCommand is a type container that has exactly one ML command set""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class Fit(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ESTIMATOR_FIELD_NUMBER: builtins.int + DATASET_FIELD_NUMBER: builtins.int + @property + def estimator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlStage: ... + @property + def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ... + def __init__( + self, + *, + estimator: pyspark.sql.connect.proto.ml_common_pb2.MlStage | None = ..., + dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal["dataset", b"dataset", "estimator", b"estimator"], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["dataset", b"dataset", "estimator", b"estimator"], + ) -> None: ... + + class Evaluate(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EVALUATOR_FIELD_NUMBER: builtins.int + @property + def evaluator(self) -> global___MlEvaluator: ... + def __init__( + self, + *, + evaluator: global___MlEvaluator | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["evaluator", b"evaluator"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["evaluator", b"evaluator"] + ) -> None: ... 
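# --- Illustrative sketch (editor-added, not part of this patch) --------------
# Shows how a client might assemble an MlCommand.Fit from the generated
# classes above. The registry name "LogisticRegression", the uid, and the
# "maxIter" param key are hypothetical placeholders; the literal assumes
# Expression.Literal exposes an `integer` field, as in the Spark Connect
# expressions proto.
from pyspark.sql.connect.proto import expressions_pb2, ml_common_pb2, ml_pb2, relations_pb2

_estimator = ml_common_pb2.MlStage(
    name="LogisticRegression",                        # hypothetical registry name
    uid="LogisticRegression_0001",                    # hypothetical uid
    type=ml_common_pb2.MlStage.STAGE_TYPE_ESTIMATOR,
    params=ml_common_pb2.MlParams(
        params={
            "maxIter": ml_common_pb2.MlParams.ParamValue(
                literal=expressions_pb2.Expression.Literal(integer=10)
            )
        }
    ),
)
_fit_command = ml_pb2.MlCommand(
    fit=ml_pb2.MlCommand.Fit(
        estimator=_estimator,
        dataset=relations_pb2.Relation(),             # placeholder training relation
    )
)
assert _fit_command.WhichOneof("ml_command_type") == "fit"
# -----------------------------------------------------------------------------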
+ + class LoadModel(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + name: builtins.str + path: builtins.str + def __init__( + self, + *, + name: builtins.str = ..., + path: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["name", b"name", "path", b"path"] + ) -> None: ... + + class SaveModel(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class OptionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + MODEL_REF_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + OVERWRITE_FIELD_NUMBER: builtins.int + OPTIONS_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + path: builtins.str + """saving path""" + overwrite: builtins.bool + @property + def options( + self, + ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """saving options""" + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + path: builtins.str = ..., + overwrite: builtins.bool = ..., + options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "model_ref", + b"model_ref", + "options", + b"options", + "overwrite", + b"overwrite", + "path", + b"path", + ], + ) -> None: ... + + class LoadStage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + name: builtins.str + path: builtins.str + type: pyspark.sql.connect.proto.ml_common_pb2.MlStage.StageType.ValueType + def __init__( + self, + *, + name: builtins.str = ..., + path: builtins.str = ..., + type: pyspark.sql.connect.proto.ml_common_pb2.MlStage.StageType.ValueType = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "name", b"name", "path", b"path", "type", b"type" + ], + ) -> None: ... + + class SaveStage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class OptionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + STAGE_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + OVERWRITE_FIELD_NUMBER: builtins.int + OPTIONS_FIELD_NUMBER: builtins.int + @property + def stage(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlStage: ... 
+ path: builtins.str + """saving path""" + overwrite: builtins.bool + @property + def options( + self, + ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """saving options""" + def __init__( + self, + *, + stage: pyspark.sql.connect.proto.ml_common_pb2.MlStage | None = ..., + path: builtins.str = ..., + overwrite: builtins.bool = ..., + options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["stage", b"stage"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "options", b"options", "overwrite", b"overwrite", "path", b"path", "stage", b"stage" + ], + ) -> None: ... + + class LoadEvaluator(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + name: builtins.str + path: builtins.str + def __init__( + self, + *, + name: builtins.str = ..., + path: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["name", b"name", "path", b"path"] + ) -> None: ... + + class SaveEvaluator(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class OptionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + EVALUATOR_FIELD_NUMBER: builtins.int + PATH_FIELD_NUMBER: builtins.int + OVERWRITE_FIELD_NUMBER: builtins.int + OPTIONS_FIELD_NUMBER: builtins.int + @property + def evaluator(self) -> global___MlEvaluator: ... + path: builtins.str + """saving path""" + overwrite: builtins.bool + @property + def options( + self, + ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """saving options""" + def __init__( + self, + *, + evaluator: global___MlEvaluator | None = ..., + path: builtins.str = ..., + overwrite: builtins.bool = ..., + options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["evaluator", b"evaluator"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "evaluator", + b"evaluator", + "options", + b"options", + "overwrite", + b"overwrite", + "path", + b"path", + ], + ) -> None: ... + + class FetchModelAttr(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + name: builtins.str + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + name: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref", "name", b"name"] + ) -> None: ... 
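# --- Illustrative sketch (editor-added, not part of this patch) --------------
# Persisting a fitted model by reference with SaveModel. The model id, the
# path, and the "format" option key are hypothetical placeholders; `options`
# is the map<string, string> of saving options declared above.
from pyspark.sql.connect.proto import ml_common_pb2, ml_pb2

_save_command = ml_pb2.MlCommand(
    save_model=ml_pb2.MlCommand.SaveModel(
        model_ref=ml_common_pb2.ModelRef(id="model-1"),  # hypothetical server-side id
        path="/tmp/lr_model",                            # hypothetical saving path
        overwrite=True,
        options={"format": "parquet"},                   # hypothetical saving option
    )
)
assert _save_command.WhichOneof("ml_command_type") == "save_model"
# -----------------------------------------------------------------------------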
+ + class FetchModelSummaryAttr(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + EVALUATION_DATASET_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + name: builtins.str + @property + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ... + @property + def evaluation_dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: + """Evaluation dataset that it uses to computes + the summary attribute + If not set, get attributes from + model.summary (i.e. the summary on training dataset) + """ + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + name: builtins.str = ..., + params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., + evaluation_dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_evaluation_dataset", + b"_evaluation_dataset", + "evaluation_dataset", + b"evaluation_dataset", + "model_ref", + b"model_ref", + "params", + b"params", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_evaluation_dataset", + b"_evaluation_dataset", + "evaluation_dataset", + b"evaluation_dataset", + "model_ref", + b"model_ref", + "name", + b"name", + "params", + b"params", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["_evaluation_dataset", b"_evaluation_dataset"], + ) -> typing_extensions.Literal["evaluation_dataset"] | None: ... + + class CopyModel(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> None: ... + + class DeleteModel(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> None: ... 
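# --- Illustrative sketch (editor-added, not part of this patch) --------------
# Requesting a model attribute and dispatching on the MlCommandResponse oneof.
# `send` stands in for the client-side transport, which is outside this diff;
# the attribute name "coefficients" is a hypothetical placeholder.
from pyspark.sql.connect.proto import ml_common_pb2, ml_pb2


def _fetch_coefficients(model_id: str, send) -> ml_common_pb2.Vector:
    cmd = ml_pb2.MlCommand(
        fetch_model_attr=ml_pb2.MlCommand.FetchModelAttr(
            model_ref=ml_common_pb2.ModelRef(id=model_id),
            name="coefficients",
        )
    )
    resp: ml_pb2.MlCommandResponse = send(cmd)
    kind = resp.WhichOneof("ml_command_response_type")
    if kind != "vector":
        raise ValueError(f"unexpected ML command response type: {kind}")
    return resp.vector
# -----------------------------------------------------------------------------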
+ + FIT_FIELD_NUMBER: builtins.int + FETCH_MODEL_ATTR_FIELD_NUMBER: builtins.int + FETCH_MODEL_SUMMARY_ATTR_FIELD_NUMBER: builtins.int + LOAD_MODEL_FIELD_NUMBER: builtins.int + SAVE_MODEL_FIELD_NUMBER: builtins.int + EVALUATE_FIELD_NUMBER: builtins.int + SAVE_STAGE_FIELD_NUMBER: builtins.int + LOAD_STAGE_FIELD_NUMBER: builtins.int + SAVE_EVALUATOR_FIELD_NUMBER: builtins.int + LOAD_EVALUATOR_FIELD_NUMBER: builtins.int + COPY_MODEL_FIELD_NUMBER: builtins.int + DELETE_MODEL_FIELD_NUMBER: builtins.int + @property + def fit(self) -> global___MlCommand.Fit: + """call `estimator.fit` and returns a model""" + @property + def fetch_model_attr(self) -> global___MlCommand.FetchModelAttr: + """get model attribute""" + @property + def fetch_model_summary_attr(self) -> global___MlCommand.FetchModelSummaryAttr: + """get model summary attribute""" + @property + def load_model(self) -> global___MlCommand.LoadModel: + """load model""" + @property + def save_model(self) -> global___MlCommand.SaveModel: + """save model""" + @property + def evaluate(self) -> global___MlCommand.Evaluate: + """call `evaluator.evaluate`""" + @property + def save_stage(self) -> global___MlCommand.SaveStage: + """save estimator or transformer""" + @property + def load_stage(self) -> global___MlCommand.LoadStage: + """load estimator or transformer""" + @property + def save_evaluator(self) -> global___MlCommand.SaveEvaluator: + """save estimator""" + @property + def load_evaluator(self) -> global___MlCommand.LoadEvaluator: + """load estimator""" + @property + def copy_model(self) -> global___MlCommand.CopyModel: + """copy model, returns new model reference id""" + @property + def delete_model(self) -> global___MlCommand.DeleteModel: + """delete server side model object by model reference id""" + def __init__( + self, + *, + fit: global___MlCommand.Fit | None = ..., + fetch_model_attr: global___MlCommand.FetchModelAttr | None = ..., + fetch_model_summary_attr: global___MlCommand.FetchModelSummaryAttr | None = ..., + load_model: global___MlCommand.LoadModel | None = ..., + save_model: global___MlCommand.SaveModel | None = ..., + evaluate: global___MlCommand.Evaluate | None = ..., + save_stage: global___MlCommand.SaveStage | None = ..., + load_stage: global___MlCommand.LoadStage | None = ..., + save_evaluator: global___MlCommand.SaveEvaluator | None = ..., + load_evaluator: global___MlCommand.LoadEvaluator | None = ..., + copy_model: global___MlCommand.CopyModel | None = ..., + delete_model: global___MlCommand.DeleteModel | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "copy_model", + b"copy_model", + "delete_model", + b"delete_model", + "evaluate", + b"evaluate", + "fetch_model_attr", + b"fetch_model_attr", + "fetch_model_summary_attr", + b"fetch_model_summary_attr", + "fit", + b"fit", + "load_evaluator", + b"load_evaluator", + "load_model", + b"load_model", + "load_stage", + b"load_stage", + "ml_command_type", + b"ml_command_type", + "save_evaluator", + b"save_evaluator", + "save_model", + b"save_model", + "save_stage", + b"save_stage", + ], + ) -> builtins.bool: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "copy_model", + b"copy_model", + "delete_model", + b"delete_model", + "evaluate", + b"evaluate", + "fetch_model_attr", + b"fetch_model_attr", + "fetch_model_summary_attr", + b"fetch_model_summary_attr", + "fit", + b"fit", + "load_evaluator", + b"load_evaluator", + "load_model", + b"load_model", + "load_stage", + b"load_stage", + "ml_command_type", + b"ml_command_type", + "save_evaluator", + b"save_evaluator", + "save_model", + b"save_model", + "save_stage", + b"save_stage", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["ml_command_type", b"ml_command_type"] + ) -> typing_extensions.Literal[ + "fit", + "fetch_model_attr", + "fetch_model_summary_attr", + "load_model", + "save_model", + "evaluate", + "save_stage", + "load_stage", + "save_evaluator", + "load_evaluator", + "copy_model", + "delete_model", + ] | None: ... + +global___MlCommand = MlCommand + +class MlCommandResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class ModelInfo(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + MODEL_UID_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + model_uid: builtins.str + @property + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + model_uid: builtins.str = ..., + params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal["model_ref", b"model_ref", "params", b"params"], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "model_ref", b"model_ref", "model_uid", b"model_uid", "params", b"params" + ], + ) -> None: ... + + LITERAL_FIELD_NUMBER: builtins.int + MODEL_INFO_FIELD_NUMBER: builtins.int + VECTOR_FIELD_NUMBER: builtins.int + MATRIX_FIELD_NUMBER: builtins.int + STAGE_FIELD_NUMBER: builtins.int + MODEL_REF_FIELD_NUMBER: builtins.int + @property + def literal(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal: ... + @property + def model_info(self) -> global___MlCommandResponse.ModelInfo: ... + @property + def vector(self) -> pyspark.sql.connect.proto.ml_common_pb2.Vector: ... + @property + def matrix(self) -> pyspark.sql.connect.proto.ml_common_pb2.Matrix: ... + @property + def stage(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlStage: ... + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + def __init__( + self, + *, + literal: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ..., + model_info: global___MlCommandResponse.ModelInfo | None = ..., + vector: pyspark.sql.connect.proto.ml_common_pb2.Vector | None = ..., + matrix: pyspark.sql.connect.proto.ml_common_pb2.Matrix | None = ..., + stage: pyspark.sql.connect.proto.ml_common_pb2.MlStage | None = ..., + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + ) -> None: ... 
+ def HasField( + self, + field_name: typing_extensions.Literal[ + "literal", + b"literal", + "matrix", + b"matrix", + "ml_command_response_type", + b"ml_command_response_type", + "model_info", + b"model_info", + "model_ref", + b"model_ref", + "stage", + b"stage", + "vector", + b"vector", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "literal", + b"literal", + "matrix", + b"matrix", + "ml_command_response_type", + b"ml_command_response_type", + "model_info", + b"model_info", + "model_ref", + b"model_ref", + "stage", + b"stage", + "vector", + b"vector", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal[ + "ml_command_response_type", b"ml_command_response_type" + ], + ) -> typing_extensions.Literal[ + "literal", "model_info", "vector", "matrix", "stage", "model_ref" + ] | None: ... + +global___MlCommandResponse = MlCommandResponse diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index aa6d39cd4f06..38369dc4931f 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -33,14 +33,20 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catalog__pb2 +from pyspark.sql.connect.proto import ml_common_pb2 as spark_dot_connect_dot_ml__common__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 
\x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"[\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id"\x86\x01\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x30\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryR\x04\x61rgs\x1a\x37\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\xf0\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xc6\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x81\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 
\x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xef\x01\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x65\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 
\x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\x82\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc"\xcb\x01\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schemaB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1dspark/connect/ml_common.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 
\x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12=\n\x0bml_relation\x18\xac\x02 \x01(\x0b\x32\x19.spark.connect.MlRelationH\x00R\nmlRelation\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\xe3\x07\n\nMlRelation\x12S\n\x0fmodel_transform\x18\x01 \x01(\x0b\x32(.spark.connect.MlRelation.ModelTransformH\x00R\x0emodelTransform\x12Y\n\x11\x66\x65\x61ture_transform\x18\x02 \x01(\x0b\x32*.spark.connect.MlRelation.FeatureTransformH\x00R\x10\x66\x65\x61tureTransform\x12\x44\n\nmodel_attr\x18\x03 \x01(\x0b\x32#.spark.connect.MlRelation.ModelAttrH\x00R\tmodelAttr\x12Z\n\x12model_summary_attr\x18\x04 \x01(\x0b\x32*.spark.connect.MlRelation.ModelSummaryAttrH\x00R\x10modelSummaryAttr\x1a\xa6\x01\n\x0eModelTransform\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x34\n\tmodel_ref\x18\x02 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x1a{\n\x10\x46\x65\x61tureTransform\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x0btransformer\x18\x02 \x01(\x0b\x32\x16.spark.connect.MlStageR\x0btransformer\x1aU\n\tModelAttr\x12\x34\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x1a\xf1\x01\n\x10ModelSummaryAttr\x12\x34\n\tmodel_ref\x18\x01 
\x01(\x0b\x32\x17.spark.connect.ModelRefR\x08modelRef\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12K\n\x12\x65valuation_dataset\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x11\x65valuationDataset\x88\x01\x01\x42\x15\n\x13_evaluation_datasetB\x12\n\x10ml_relation_type"\t\n\x07Unknown"[\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id"\x86\x01\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x30\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryR\x04\x61rgs\x1a\x37\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\xf0\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 
\x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xc6\x04\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x81\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 
\x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xef\x01\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x65\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\x82\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc"\xcb\x01\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schemaB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _RELATION = DESCRIPTOR.message_types_by_name["Relation"] +_MLRELATION = DESCRIPTOR.message_types_by_name["MlRelation"] +_MLRELATION_MODELTRANSFORM = 
_MLRELATION.nested_types_by_name["ModelTransform"] +_MLRELATION_FEATURETRANSFORM = _MLRELATION.nested_types_by_name["FeatureTransform"] +_MLRELATION_MODELATTR = _MLRELATION.nested_types_by_name["ModelAttr"] +_MLRELATION_MODELSUMMARYATTR = _MLRELATION.nested_types_by_name["ModelSummaryAttr"] _UNKNOWN = DESCRIPTOR.message_types_by_name["Unknown"] _RELATIONCOMMON = DESCRIPTOR.message_types_by_name["RelationCommon"] _SQL = DESCRIPTOR.message_types_by_name["SQL"] @@ -111,6 +117,57 @@ ) _sym_db.RegisterMessage(Relation) +MlRelation = _reflection.GeneratedProtocolMessageType( + "MlRelation", + (_message.Message,), + { + "ModelTransform": _reflection.GeneratedProtocolMessageType( + "ModelTransform", + (_message.Message,), + { + "DESCRIPTOR": _MLRELATION_MODELTRANSFORM, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlRelation.ModelTransform) + }, + ), + "FeatureTransform": _reflection.GeneratedProtocolMessageType( + "FeatureTransform", + (_message.Message,), + { + "DESCRIPTOR": _MLRELATION_FEATURETRANSFORM, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlRelation.FeatureTransform) + }, + ), + "ModelAttr": _reflection.GeneratedProtocolMessageType( + "ModelAttr", + (_message.Message,), + { + "DESCRIPTOR": _MLRELATION_MODELATTR, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlRelation.ModelAttr) + }, + ), + "ModelSummaryAttr": _reflection.GeneratedProtocolMessageType( + "ModelSummaryAttr", + (_message.Message,), + { + "DESCRIPTOR": _MLRELATION_MODELSUMMARYATTR, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlRelation.ModelSummaryAttr) + }, + ), + "DESCRIPTOR": _MLRELATION, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.MlRelation) + }, +) +_sym_db.RegisterMessage(MlRelation) +_sym_db.RegisterMessage(MlRelation.ModelTransform) +_sym_db.RegisterMessage(MlRelation.FeatureTransform) +_sym_db.RegisterMessage(MlRelation.ModelAttr) +_sym_db.RegisterMessage(MlRelation.ModelSummaryAttr) + Unknown = _reflection.GeneratedProtocolMessageType( "Unknown", (_message.Message,), @@ -696,120 +753,130 @@ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" _PARSE_OPTIONSENTRY._options = None _PARSE_OPTIONSENTRY._serialized_options = b"8\001" - _RELATION._serialized_start = 165 - _RELATION._serialized_end = 2709 - _UNKNOWN._serialized_start = 2711 - _UNKNOWN._serialized_end = 2720 - _RELATIONCOMMON._serialized_start = 2722 - _RELATIONCOMMON._serialized_end = 2813 - _SQL._serialized_start = 2816 - _SQL._serialized_end = 2950 - _SQL_ARGSENTRY._serialized_start = 2895 - _SQL_ARGSENTRY._serialized_end = 2950 - _READ._serialized_start = 2953 - _READ._serialized_end = 3449 - _READ_NAMEDTABLE._serialized_start = 3095 - _READ_NAMEDTABLE._serialized_end = 3156 - _READ_DATASOURCE._serialized_start = 3159 - _READ_DATASOURCE._serialized_end = 3436 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414 - _PROJECT._serialized_start = 3451 - _PROJECT._serialized_end = 3568 - _FILTER._serialized_start = 3570 - _FILTER._serialized_end = 3682 - _JOIN._serialized_start = 3685 - _JOIN._serialized_end = 4156 - _JOIN_JOINTYPE._serialized_start = 3948 - _JOIN_JOINTYPE._serialized_end = 4156 - _SETOPERATION._serialized_start = 4159 - _SETOPERATION._serialized_end = 4638 - 
_SETOPERATION_SETOPTYPE._serialized_start = 4475 - _SETOPERATION_SETOPTYPE._serialized_end = 4589 - _LIMIT._serialized_start = 4640 - _LIMIT._serialized_end = 4716 - _OFFSET._serialized_start = 4718 - _OFFSET._serialized_end = 4797 - _TAIL._serialized_start = 4799 - _TAIL._serialized_end = 4874 - _AGGREGATE._serialized_start = 4877 - _AGGREGATE._serialized_end = 5459 - _AGGREGATE_PIVOT._serialized_start = 5216 - _AGGREGATE_PIVOT._serialized_end = 5327 - _AGGREGATE_GROUPTYPE._serialized_start = 5330 - _AGGREGATE_GROUPTYPE._serialized_end = 5459 - _SORT._serialized_start = 5462 - _SORT._serialized_end = 5622 - _DROP._serialized_start = 5625 - _DROP._serialized_end = 5766 - _DEDUPLICATE._serialized_start = 5769 - _DEDUPLICATE._serialized_end = 5940 - _LOCALRELATION._serialized_start = 5942 - _LOCALRELATION._serialized_end = 6031 - _SAMPLE._serialized_start = 6034 - _SAMPLE._serialized_end = 6307 - _RANGE._serialized_start = 6310 - _RANGE._serialized_end = 6455 - _SUBQUERYALIAS._serialized_start = 6457 - _SUBQUERYALIAS._serialized_end = 6571 - _REPARTITION._serialized_start = 6574 - _REPARTITION._serialized_end = 6716 - _SHOWSTRING._serialized_start = 6719 - _SHOWSTRING._serialized_end = 6861 - _STATSUMMARY._serialized_start = 6863 - _STATSUMMARY._serialized_end = 6955 - _STATDESCRIBE._serialized_start = 6957 - _STATDESCRIBE._serialized_end = 7038 - _STATCROSSTAB._serialized_start = 7040 - _STATCROSSTAB._serialized_end = 7141 - _STATCOV._serialized_start = 7143 - _STATCOV._serialized_end = 7239 - _STATCORR._serialized_start = 7242 - _STATCORR._serialized_end = 7379 - _STATAPPROXQUANTILE._serialized_start = 7382 - _STATAPPROXQUANTILE._serialized_end = 7546 - _STATFREQITEMS._serialized_start = 7548 - _STATFREQITEMS._serialized_end = 7673 - _STATSAMPLEBY._serialized_start = 7676 - _STATSAMPLEBY._serialized_end = 7985 - _STATSAMPLEBY_FRACTION._serialized_start = 7877 - _STATSAMPLEBY_FRACTION._serialized_end = 7976 - _NAFILL._serialized_start = 7988 - _NAFILL._serialized_end = 8122 - _NADROP._serialized_start = 8125 - _NADROP._serialized_end = 8259 - _NAREPLACE._serialized_start = 8262 - _NAREPLACE._serialized_end = 8558 - _NAREPLACE_REPLACEMENT._serialized_start = 8417 - _NAREPLACE_REPLACEMENT._serialized_end = 8558 - _TODF._serialized_start = 8560 - _TODF._serialized_end = 8648 - _WITHCOLUMNSRENAMED._serialized_start = 8651 - _WITHCOLUMNSRENAMED._serialized_end = 8890 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890 - _WITHCOLUMNS._serialized_start = 8892 - _WITHCOLUMNS._serialized_end = 9011 - _HINT._serialized_start = 9014 - _HINT._serialized_end = 9146 - _UNPIVOT._serialized_start = 9149 - _UNPIVOT._serialized_end = 9476 - _UNPIVOT_VALUES._serialized_start = 9406 - _UNPIVOT_VALUES._serialized_end = 9465 - _TOSCHEMA._serialized_start = 9478 - _TOSCHEMA._serialized_end = 9584 - _REPARTITIONBYEXPRESSION._serialized_start = 9587 - _REPARTITIONBYEXPRESSION._serialized_end = 9790 - _MAPPARTITIONS._serialized_start = 9793 - _MAPPARTITIONS._serialized_end = 9923 - _GROUPMAP._serialized_start = 9926 - _GROUPMAP._serialized_end = 10129 - _COLLECTMETRICS._serialized_start = 10132 - _COLLECTMETRICS._serialized_end = 10268 - _PARSE._serialized_start = 10271 - _PARSE._serialized_end = 10659 - _PARSE_OPTIONSENTRY._serialized_start = 3356 - _PARSE_OPTIONSENTRY._serialized_end = 3414 - _PARSE_PARSEFORMAT._serialized_start = 10560 - _PARSE_PARSEFORMAT._serialized_end = 10648 + _RELATION._serialized_start = 196 + 
_RELATION._serialized_end = 2803 + _MLRELATION._serialized_start = 2806 + _MLRELATION._serialized_end = 3801 + _MLRELATION_MODELTRANSFORM._serialized_start = 3159 + _MLRELATION_MODELTRANSFORM._serialized_end = 3325 + _MLRELATION_FEATURETRANSFORM._serialized_start = 3327 + _MLRELATION_FEATURETRANSFORM._serialized_end = 3450 + _MLRELATION_MODELATTR._serialized_start = 3452 + _MLRELATION_MODELATTR._serialized_end = 3537 + _MLRELATION_MODELSUMMARYATTR._serialized_start = 3540 + _MLRELATION_MODELSUMMARYATTR._serialized_end = 3781 + _UNKNOWN._serialized_start = 3803 + _UNKNOWN._serialized_end = 3812 + _RELATIONCOMMON._serialized_start = 3814 + _RELATIONCOMMON._serialized_end = 3905 + _SQL._serialized_start = 3908 + _SQL._serialized_end = 4042 + _SQL_ARGSENTRY._serialized_start = 3987 + _SQL_ARGSENTRY._serialized_end = 4042 + _READ._serialized_start = 4045 + _READ._serialized_end = 4541 + _READ_NAMEDTABLE._serialized_start = 4187 + _READ_NAMEDTABLE._serialized_end = 4248 + _READ_DATASOURCE._serialized_start = 4251 + _READ_DATASOURCE._serialized_end = 4528 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4448 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4506 + _PROJECT._serialized_start = 4543 + _PROJECT._serialized_end = 4660 + _FILTER._serialized_start = 4662 + _FILTER._serialized_end = 4774 + _JOIN._serialized_start = 4777 + _JOIN._serialized_end = 5248 + _JOIN_JOINTYPE._serialized_start = 5040 + _JOIN_JOINTYPE._serialized_end = 5248 + _SETOPERATION._serialized_start = 5251 + _SETOPERATION._serialized_end = 5730 + _SETOPERATION_SETOPTYPE._serialized_start = 5567 + _SETOPERATION_SETOPTYPE._serialized_end = 5681 + _LIMIT._serialized_start = 5732 + _LIMIT._serialized_end = 5808 + _OFFSET._serialized_start = 5810 + _OFFSET._serialized_end = 5889 + _TAIL._serialized_start = 5891 + _TAIL._serialized_end = 5966 + _AGGREGATE._serialized_start = 5969 + _AGGREGATE._serialized_end = 6551 + _AGGREGATE_PIVOT._serialized_start = 6308 + _AGGREGATE_PIVOT._serialized_end = 6419 + _AGGREGATE_GROUPTYPE._serialized_start = 6422 + _AGGREGATE_GROUPTYPE._serialized_end = 6551 + _SORT._serialized_start = 6554 + _SORT._serialized_end = 6714 + _DROP._serialized_start = 6717 + _DROP._serialized_end = 6858 + _DEDUPLICATE._serialized_start = 6861 + _DEDUPLICATE._serialized_end = 7032 + _LOCALRELATION._serialized_start = 7034 + _LOCALRELATION._serialized_end = 7123 + _SAMPLE._serialized_start = 7126 + _SAMPLE._serialized_end = 7399 + _RANGE._serialized_start = 7402 + _RANGE._serialized_end = 7547 + _SUBQUERYALIAS._serialized_start = 7549 + _SUBQUERYALIAS._serialized_end = 7663 + _REPARTITION._serialized_start = 7666 + _REPARTITION._serialized_end = 7808 + _SHOWSTRING._serialized_start = 7811 + _SHOWSTRING._serialized_end = 7953 + _STATSUMMARY._serialized_start = 7955 + _STATSUMMARY._serialized_end = 8047 + _STATDESCRIBE._serialized_start = 8049 + _STATDESCRIBE._serialized_end = 8130 + _STATCROSSTAB._serialized_start = 8132 + _STATCROSSTAB._serialized_end = 8233 + _STATCOV._serialized_start = 8235 + _STATCOV._serialized_end = 8331 + _STATCORR._serialized_start = 8334 + _STATCORR._serialized_end = 8471 + _STATAPPROXQUANTILE._serialized_start = 8474 + _STATAPPROXQUANTILE._serialized_end = 8638 + _STATFREQITEMS._serialized_start = 8640 + _STATFREQITEMS._serialized_end = 8765 + _STATSAMPLEBY._serialized_start = 8768 + _STATSAMPLEBY._serialized_end = 9077 + _STATSAMPLEBY_FRACTION._serialized_start = 8969 + _STATSAMPLEBY_FRACTION._serialized_end = 9068 + _NAFILL._serialized_start = 9080 + _NAFILL._serialized_end = 9214 + 
_NADROP._serialized_start = 9217 + _NADROP._serialized_end = 9351 + _NAREPLACE._serialized_start = 9354 + _NAREPLACE._serialized_end = 9650 + _NAREPLACE_REPLACEMENT._serialized_start = 9509 + _NAREPLACE_REPLACEMENT._serialized_end = 9650 + _TODF._serialized_start = 9652 + _TODF._serialized_end = 9740 + _WITHCOLUMNSRENAMED._serialized_start = 9743 + _WITHCOLUMNSRENAMED._serialized_end = 9982 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9915 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9982 + _WITHCOLUMNS._serialized_start = 9984 + _WITHCOLUMNS._serialized_end = 10103 + _HINT._serialized_start = 10106 + _HINT._serialized_end = 10238 + _UNPIVOT._serialized_start = 10241 + _UNPIVOT._serialized_end = 10568 + _UNPIVOT_VALUES._serialized_start = 10498 + _UNPIVOT_VALUES._serialized_end = 10557 + _TOSCHEMA._serialized_start = 10570 + _TOSCHEMA._serialized_end = 10676 + _REPARTITIONBYEXPRESSION._serialized_start = 10679 + _REPARTITIONBYEXPRESSION._serialized_end = 10882 + _MAPPARTITIONS._serialized_start = 10885 + _MAPPARTITIONS._serialized_end = 11015 + _GROUPMAP._serialized_start = 11018 + _GROUPMAP._serialized_end = 11221 + _COLLECTMETRICS._serialized_start = 11224 + _COLLECTMETRICS._serialized_end = 11360 + _PARSE._serialized_start = 11363 + _PARSE._serialized_end = 11751 + _PARSE_OPTIONSENTRY._serialized_start = 4448 + _PARSE_OPTIONSENTRY._serialized_end = 4506 + _PARSE_PARSEFORMAT._serialized_start = 11652 + _PARSE_PARSEFORMAT._serialized_end = 11740 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 6ae4a323f6f7..ebcca0512229 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -42,6 +42,7 @@ import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import pyspark.sql.connect.proto.catalog_pb2 import pyspark.sql.connect.proto.expressions_pb2 +import pyspark.sql.connect.proto.ml_common_pb2 import pyspark.sql.connect.proto.types_pb2 import sys import typing @@ -105,6 +106,7 @@ class Relation(google.protobuf.message.Message): FREQ_ITEMS_FIELD_NUMBER: builtins.int SAMPLE_BY_FIELD_NUMBER: builtins.int CATALOG_FIELD_NUMBER: builtins.int + ML_RELATION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int UNKNOWN_FIELD_NUMBER: builtins.int @property @@ -197,6 +199,9 @@ class Relation(google.protobuf.message.Message): def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog: """Catalog API (experimental / unstable)""" @property + def ml_relation(self) -> global___MlRelation: + """ML relation""" + @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. When plugins generate arbitrary relations they can add them here. During the planning the correct resolution is done. @@ -249,6 +254,7 @@ class Relation(google.protobuf.message.Message): freq_items: global___StatFreqItems | None = ..., sample_by: global___StatSampleBy | None = ..., catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ..., + ml_relation: global___MlRelation | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., unknown: global___Unknown | None = ..., ) -> None: ... 
@@ -299,6 +305,8 @@ class Relation(google.protobuf.message.Message): b"local_relation", "map_partitions", b"map_partitions", + "ml_relation", + b"ml_relation", "offset", b"offset", "parse", @@ -396,6 +404,8 @@ class Relation(google.protobuf.message.Message): b"local_relation", "map_partitions", b"map_partitions", + "ml_relation", + b"ml_relation", "offset", b"offset", "parse", @@ -491,12 +501,211 @@ class Relation(google.protobuf.message.Message): "freq_items", "sample_by", "catalog", + "ml_relation", "extension", "unknown", ] | None: ... global___Relation = Relation +class MlRelation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class ModelTransform(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + MODEL_REF_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: ... + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + @property + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ... + def __init__( + self, + *, + input: global___Relation | None = ..., + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "input", b"input", "model_ref", b"model_ref", "params", b"params" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "input", b"input", "model_ref", b"model_ref", "params", b"params" + ], + ) -> None: ... + + class FeatureTransform(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + TRANSFORMER_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: ... + @property + def transformer(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlStage: ... + def __init__( + self, + *, + input: global___Relation | None = ..., + transformer: pyspark.sql.connect.proto.ml_common_pb2.MlStage | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal["input", b"input", "transformer", b"transformer"], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal["input", b"input", "transformer", b"transformer"], + ) -> None: ... + + class ModelAttr(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... + name: builtins.str + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + name: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref", "name", b"name"] + ) -> None: ... + + class ModelSummaryAttr(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + EVALUATION_DATASET_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ModelRef: ... 
+ name: builtins.str + @property + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ... + @property + def evaluation_dataset(self) -> global___Relation: + """Evaluation dataset that it uses to computes + the summary attribute + If not set, get attributes from + model.summary (i.e. the summary on training dataset) + """ + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ModelRef | None = ..., + name: builtins.str = ..., + params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., + evaluation_dataset: global___Relation | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_evaluation_dataset", + b"_evaluation_dataset", + "evaluation_dataset", + b"evaluation_dataset", + "model_ref", + b"model_ref", + "params", + b"params", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_evaluation_dataset", + b"_evaluation_dataset", + "evaluation_dataset", + b"evaluation_dataset", + "model_ref", + b"model_ref", + "name", + b"name", + "params", + b"params", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["_evaluation_dataset", b"_evaluation_dataset"], + ) -> typing_extensions.Literal["evaluation_dataset"] | None: ... + + MODEL_TRANSFORM_FIELD_NUMBER: builtins.int + FEATURE_TRANSFORM_FIELD_NUMBER: builtins.int + MODEL_ATTR_FIELD_NUMBER: builtins.int + MODEL_SUMMARY_ATTR_FIELD_NUMBER: builtins.int + @property + def model_transform(self) -> global___MlRelation.ModelTransform: ... + @property + def feature_transform(self) -> global___MlRelation.FeatureTransform: ... + @property + def model_attr(self) -> global___MlRelation.ModelAttr: ... + @property + def model_summary_attr(self) -> global___MlRelation.ModelSummaryAttr: ... + def __init__( + self, + *, + model_transform: global___MlRelation.ModelTransform | None = ..., + feature_transform: global___MlRelation.FeatureTransform | None = ..., + model_attr: global___MlRelation.ModelAttr | None = ..., + model_summary_attr: global___MlRelation.ModelSummaryAttr | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "feature_transform", + b"feature_transform", + "ml_relation_type", + b"ml_relation_type", + "model_attr", + b"model_attr", + "model_summary_attr", + b"model_summary_attr", + "model_transform", + b"model_transform", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "feature_transform", + b"feature_transform", + "ml_relation_type", + b"ml_relation_type", + "model_attr", + b"model_attr", + "model_summary_attr", + b"model_summary_attr", + "model_transform", + b"model_transform", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["ml_relation_type", b"ml_relation_type"] + ) -> typing_extensions.Literal[ + "model_transform", "feature_transform", "model_attr", "model_summary_attr" + ] | None: ... 
+ +global___MlRelation = MlRelation + class Unknown(google.protobuf.message.Message): """Used for testing purposes only.""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 57da3b5af606..23882994fa7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -187,7 +187,7 @@ private[sql] object Dataset { * @since 1.6.0 */ @Stable -class Dataset[T] private[sql]( +class Dataset[T] private[spark]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends Serializable {
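Note on the two non-generated pieces above: the Dataset primary constructor's visibility is widened from private[sql] to private[spark], which makes it reachable from anywhere under org.apache.spark, presumably so the server-side ML handling introduced by this patch can build Datasets directly; and the regenerated relations_pb2 modules now expose the MlRelation message plus the ml_relation member of Relation's rel_type oneof. As a minimal, illustrative sketch (not part of this patch), the new plan node can be populated from Python using only the module path and field names shown in this diff; the attribute name "coefficients" below is an arbitrary placeholder:

    from pyspark.sql.connect.proto import relations_pb2

    # Request a model attribute by name through the new MlRelation.ModelAttr
    # message ("coefficients" is an illustrative name, not defined by this diff).
    rel = relations_pb2.Relation(
        ml_relation=relations_pb2.MlRelation(
            model_attr=relations_pb2.MlRelation.ModelAttr(name="coefficients")
        )
    )

    # The new oneof members resolve as expected.
    assert rel.WhichOneof("rel_type") == "ml_relation"
    assert rel.ml_relation.WhichOneof("ml_relation_type") == "model_attr"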