Skip to content

Commit 872d384

Browse files
committed
export numFeatures in ML PredictionModel
1 parent 461b7c6 commit 872d384

File tree

3 files changed

+118
-34
lines changed

3 files changed

+118
-34
lines changed

python/pyspark/ml/base.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -116,16 +116,3 @@ class Model(Transformer):
116116
"""
117117

118118
__metaclass__ = ABCMeta
119-
120-
121-
class HasNumFeaturesModel:
122-
"""
123-
Provides getter of the number of features especially for model class
124-
It should be mixin with JavaModel.
125-
"""
126-
@property
127-
def numFeatures(self):
128-
"""
129-
The number of features used to train the model.
130-
"""
131-
return self._call_java("numFeatures")

python/pyspark/ml/classification.py

Lines changed: 71 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020

2121
from pyspark import since, keyword_only
2222
from pyspark.ml import Estimator, Model
23-
from pyspark.ml.base import HasNumFeaturesModel
2423
from pyspark.ml.param.shared import *
2524
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
2625
RandomForestParams, TreeEnsembleModels, TreeEnsembleParams
@@ -66,6 +65,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
6665
DenseVector([5.5...])
6766
>>> model.intercept
6867
-2.68...
68+
>>> model.numFeatures
69+
1
6970
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
7071
>>> result = model.transform(test0).head()
7172
>>> result.prediction
@@ -93,6 +94,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
9394
True
9495
>>> model.intercept == model2.intercept
9596
True
97+
>>> model.numFeatures == model2.numFeatures
98+
True
9699
97100
.. versionadded:: 1.3.0
98101
"""
@@ -215,7 +218,7 @@ def _checkThresholdConsistency(self):
215218
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
216219

217220

218-
class LogisticRegressionModel(HasNumFeaturesModel, JavaModel, JavaMLWritable, JavaMLReadable):
221+
class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
219222
"""
220223
.. note:: Experimental
221224
@@ -240,6 +243,14 @@ def intercept(self):
240243
"""
241244
return self._call_java("intercept")
242245

246+
@property
247+
@since("2.0.0")
248+
def numFeatures(self):
249+
"""
250+
Number of features the model was trained on.
251+
"""
252+
return self._call_java("numFeatures")
253+
243254
@property
244255
@since("2.0.0")
245256
def summary(self):
@@ -525,6 +536,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
525536
1
526537
>>> model.featureImportances
527538
SparseVector(1, {0: 1.0})
539+
>>> model.numFeatures
540+
1
528541
>>> print(model.toDebugString)
529542
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
530543
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -549,8 +562,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
549562
>>> model2 = DecisionTreeClassificationModel.load(model_path)
550563
>>> model.featureImportances == model2.featureImportances
551564
True
552-
>>> model.numFeatures
553-
1
565+
>>> model.numFeatures == model2.numFeatures
566+
True
554567
555568
.. versionadded:: 1.4.0
556569
"""
@@ -600,8 +613,7 @@ def _create_model(self, java_model):
600613

601614

602615
@inherit_doc
603-
class DecisionTreeClassificationModel(HasNumFeaturesModel, DecisionTreeModel, JavaMLWritable,
604-
JavaMLReadable):
616+
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
605617
"""
606618
.. note:: Experimental
607619
@@ -631,6 +643,14 @@ def featureImportances(self):
631643
"""
632644
return self._call_java("featureImportances")
633645

646+
@property
647+
@since("2.0.0")
648+
def numFeatures(self):
649+
"""
650+
Number of features the model was trained on.
651+
"""
652+
return self._call_java("numFeatures")
653+
634654

635655
@inherit_doc
636656
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
@@ -672,6 +692,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
672692
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
673693
>>> model.transform(test1).head().prediction
674694
1.0
695+
>>> model.numFeatures
696+
1
675697
>>> model.trees
676698
[DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
677699
>>> rfc_path = temp_path + "/rfc"
@@ -734,8 +756,7 @@ def _create_model(self, java_model):
734756
return RandomForestClassificationModel(java_model)
735757

736758

737-
class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable,
738-
HasNumFeaturesModel):
759+
class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
739760
"""
740761
.. note:: Experimental
741762
@@ -759,6 +780,14 @@ def featureImportances(self):
759780
"""
760781
return self._call_java("featureImportances")
761782

783+
@property
784+
@since("2.0.0")
785+
def numFeatures(self):
786+
"""
787+
Number of features the model was trained on.
788+
"""
789+
return self._call_java("numFeatures")
790+
762791
@property
763792
@since("2.0.0")
764793
def trees(self):
@@ -811,6 +840,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
811840
1.0
812841
>>> model.totalNumNodes
813842
15
843+
>>> model.numFeatures
844+
1
814845
>>> print(model.toDebugString)
815846
GBTClassificationModel (uid=...)...with 5 trees...
816847
>>> gbtc_path = temp_path + "gbtc"
@@ -892,7 +923,7 @@ def getLossType(self):
892923
return self.getOrDefault(self.lossType)
893924

894925

895-
class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable, HasNumFeaturesModel):
926+
class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
896927
"""
897928
.. note:: Experimental
898929
@@ -916,6 +947,14 @@ def featureImportances(self):
916947
"""
917948
return self._call_java("featureImportances")
918949

950+
@property
951+
@since("2.0.0")
952+
def numFeatures(self):
953+
"""
954+
Number of features the model was trained on.
955+
"""
956+
return self._call_java("numFeatures")
957+
919958
@property
920959
@since("2.0.0")
921960
def trees(self):
@@ -961,6 +1000,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
9611000
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
9621001
>>> model.transform(test1).head().prediction
9631002
1.0
1003+
>>> model.numFeatures
1004+
2
9641005
>>> nb_path = temp_path + "/nb"
9651006
>>> nb.save(nb_path)
9661007
>>> nb2 = NaiveBayes.load(nb_path)
@@ -979,7 +1020,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
9791020
>>> result.prediction
9801021
0.0
9811022
>>> model.numFeatures == model2.numFeatures
982-
2
1023+
True
9831024
9841025
.. versionadded:: 1.5.0
9851026
"""
@@ -1052,7 +1093,7 @@ def getModelType(self):
10521093
return self.getOrDefault(self.modelType)
10531094

10541095

1055-
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable, HasNumFeaturesModel):
1096+
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
10561097
"""
10571098
.. note:: Experimental
10581099
@@ -1077,6 +1118,14 @@ def theta(self):
10771118
"""
10781119
return self._call_java("theta")
10791120

1121+
@property
1122+
@since("2.0.0")
1123+
def numFeatures(self):
1124+
"""
1125+
Number of features the model was trained on.
1126+
"""
1127+
return self._call_java("numFeatures")
1128+
10801129

10811130
@inherit_doc
10821131
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -1102,6 +1151,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
11021151
[2, 2, 2]
11031152
>>> model.weights.size
11041153
12
1154+
>>> model.numFeatures
1155+
2
11051156
>>> testDF = spark.createDataFrame([
11061157
... (Vectors.dense([1.0, 0.0]),),
11071158
... (Vectors.dense([0.0, 0.0]),)], ["features"])
@@ -1255,8 +1306,7 @@ def getInitialWeights(self):
12551306
return self.getOrDefault(self.initialWeights)
12561307

12571308

1258-
class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable,
1259-
HasNumFeaturesModel):
1309+
class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
12601310
"""
12611311
.. note:: Experimental
12621312
@@ -1281,6 +1331,14 @@ def weights(self):
12811331
"""
12821332
return self._call_java("weights")
12831333

1334+
@property
1335+
@since("2.0.0")
1336+
def numFeatures(self):
1337+
"""
1338+
Number of features the model was trained on.
1339+
"""
1340+
return self._call_java("numFeatures")
1341+
12841342

12851343
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
12861344
"""

0 commit comments

Comments
 (0)