2020
2121from pyspark import since , keyword_only
2222from pyspark .ml import Estimator , Model
23- from pyspark .ml .base import HasNumFeaturesModel
2423from pyspark .ml .param .shared import *
2524from pyspark .ml .regression import DecisionTreeModel , DecisionTreeRegressionModel , \
2625 RandomForestParams , TreeEnsembleModels , TreeEnsembleParams
@@ -66,6 +65,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
6665 DenseVector([5.5...])
6766 >>> model.intercept
6867 -2.68...
68+ >>> model.numFeatures
69+ 1
6970 >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
7071 >>> result = model.transform(test0).head()
7172 >>> result.prediction
@@ -93,6 +94,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
9394 True
9495 >>> model.intercept == model2.intercept
9596 True
97+ >>> model.numFeatures == model2.numFeatures
98+ True
9699
97100 .. versionadded:: 1.3.0
98101 """
@@ -215,7 +218,7 @@ def _checkThresholdConsistency(self):
215218 " threshold (%g) and thresholds (equivalent to %g)" % (t2 , t ))
216219
217220
218- class LogisticRegressionModel (HasNumFeaturesModel , JavaModel , JavaMLWritable , JavaMLReadable ):
221+ class LogisticRegressionModel (JavaModel , JavaMLWritable , JavaMLReadable ):
219222 """
220223 .. note:: Experimental
221224
@@ -240,6 +243,14 @@ def intercept(self):
240243 """
241244 return self ._call_java ("intercept" )
242245
246+ @property
247+ @since ("2.0.0" )
248+ def numFeatures (self ):
249+ """
250+ Number of features the model was trained on.
251+ """
252+ return self ._call_java ("numFeatures" )
253+
243254 @property
244255 @since ("2.0.0" )
245256 def summary (self ):
@@ -525,6 +536,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
525536 1
526537 >>> model.featureImportances
527538 SparseVector(1, {0: 1.0})
539+ >>> model.numFeatures
540+ 1
528541 >>> print(model.toDebugString)
529542 DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
530543 >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -549,8 +562,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
549562 >>> model2 = DecisionTreeClassificationModel.load(model_path)
550563 >>> model.featureImportances == model2.featureImportances
551564 True
552- >>> model.numFeatures
553- 1
565+ >>> model.numFeatures == model2.numFeatures
566+ True
554567
555568 .. versionadded:: 1.4.0
556569 """
@@ -600,8 +613,7 @@ def _create_model(self, java_model):
600613
601614
602615@inherit_doc
603- class DecisionTreeClassificationModel (HasNumFeaturesModel , DecisionTreeModel , JavaMLWritable ,
604- JavaMLReadable ):
616+ class DecisionTreeClassificationModel (DecisionTreeModel , JavaMLWritable , JavaMLReadable ):
605617 """
606618 .. note:: Experimental
607619
@@ -631,6 +643,14 @@ def featureImportances(self):
631643 """
632644 return self ._call_java ("featureImportances" )
633645
646+ @property
647+ @since ("2.0.0" )
648+ def numFeatures (self ):
649+ """
650+ Number of features the model was trained on.
651+ """
652+ return self ._call_java ("numFeatures" )
653+
634654
635655@inherit_doc
636656class RandomForestClassifier (JavaEstimator , HasFeaturesCol , HasLabelCol , HasPredictionCol , HasSeed ,
@@ -672,6 +692,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
672692 >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
673693 >>> model.transform(test1).head().prediction
674694 1.0
695+ >>> model.numFeatures
696+ 1
675697 >>> model.trees
676698 [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
677699 >>> rfc_path = temp_path + "/rfc"
@@ -734,8 +756,7 @@ def _create_model(self, java_model):
734756 return RandomForestClassificationModel (java_model )
735757
736758
737- class RandomForestClassificationModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable ,
738- HasNumFeaturesModel ):
759+ class RandomForestClassificationModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable ):
739760 """
740761 .. note:: Experimental
741762
@@ -759,6 +780,14 @@ def featureImportances(self):
759780 """
760781 return self ._call_java ("featureImportances" )
761782
783+ @property
784+ @since ("2.0.0" )
785+ def numFeatures (self ):
786+ """
787+ Number of features the model was trained on.
788+ """
789+ return self ._call_java ("numFeatures" )
790+
762791 @property
763792 @since ("2.0.0" )
764793 def trees (self ):
@@ -811,6 +840,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
811840 1.0
812841 >>> model.totalNumNodes
813842 15
843+ >>> model.numFeatures
844+ 1
814845 >>> print(model.toDebugString)
815846 GBTClassificationModel (uid=...)...with 5 trees...
816847 >>> gbtc_path = temp_path + "gbtc"
@@ -892,7 +923,7 @@ def getLossType(self):
892923 return self .getOrDefault (self .lossType )
893924
894925
895- class GBTClassificationModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable , HasNumFeaturesModel ):
926+ class GBTClassificationModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable ):
896927 """
897928 .. note:: Experimental
898929
@@ -916,6 +947,14 @@ def featureImportances(self):
916947 """
917948 return self ._call_java ("featureImportances" )
918949
950+ @property
951+ @since ("2.0.0" )
952+ def numFeatures (self ):
953+ """
954+ Number of features the model was trained on.
955+ """
956+ return self ._call_java ("numFeatures" )
957+
919958 @property
920959 @since ("2.0.0" )
921960 def trees (self ):
@@ -961,6 +1000,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
9611000 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
9621001 >>> model.transform(test1).head().prediction
9631002 1.0
1003+ >>> model.numFeatures
1004+ 2
9641005 >>> nb_path = temp_path + "/nb"
9651006 >>> nb.save(nb_path)
9661007 >>> nb2 = NaiveBayes.load(nb_path)
@@ -979,7 +1020,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
9791020 >>> result.prediction
9801021 0.0
9811022 >>> model.numFeatures == model2.numFeatures
982- 2
1023+ True
9831024
9841025 .. versionadded:: 1.5.0
9851026 """
@@ -1052,7 +1093,7 @@ def getModelType(self):
10521093 return self .getOrDefault (self .modelType )
10531094
10541095
1055- class NaiveBayesModel (JavaModel , JavaMLWritable , JavaMLReadable , HasNumFeaturesModel ):
1096+ class NaiveBayesModel (JavaModel , JavaMLWritable , JavaMLReadable ):
10561097 """
10571098 .. note:: Experimental
10581099
@@ -1077,6 +1118,14 @@ def theta(self):
10771118 """
10781119 return self ._call_java ("theta" )
10791120
1121+ @property
1122+ @since ("2.0.0" )
1123+ def numFeatures (self ):
1124+ """
1125+ Number of features the model was trained on.
1126+ """
1127+ return self ._call_java ("numFeatures" )
1128+
10801129
10811130@inherit_doc
10821131class MultilayerPerceptronClassifier (JavaEstimator , HasFeaturesCol , HasLabelCol , HasPredictionCol ,
@@ -1102,6 +1151,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
11021151 [2, 2, 2]
11031152 >>> model.weights.size
11041153 12
1154+ >>> model.numFeatures
1155+ 2
11051156 >>> testDF = spark.createDataFrame([
11061157 ... (Vectors.dense([1.0, 0.0]),),
11071158 ... (Vectors.dense([0.0, 0.0]),)], ["features"])
@@ -1255,8 +1306,7 @@ def getInitialWeights(self):
12551306 return self .getOrDefault (self .initialWeights )
12561307
12571308
1258- class MultilayerPerceptronClassificationModel (JavaModel , JavaMLWritable , JavaMLReadable ,
1259- HasNumFeaturesModel ):
1309+ class MultilayerPerceptronClassificationModel (JavaModel , JavaMLWritable , JavaMLReadable ):
12601310 """
12611311 .. note:: Experimental
12621312
@@ -1281,6 +1331,14 @@ def weights(self):
12811331 """
12821332 return self ._call_java ("weights" )
12831333
1334+ @property
1335+ @since ("2.0.0" )
1336+ def numFeatures (self ):
1337+ """
1338+ Number of features the model was trained on.
1339+ """
1340+ return self ._call_java ("numFeatures" )
1341+
12841342
12851343class OneVsRestParams (HasFeaturesCol , HasLabelCol , HasPredictionCol ):
12861344 """
0 commit comments