diff --git a/.gitignore b/.gitignore index 1dc7b5b9e..2dac4bd7b 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ docs/_build # License copied to conda build_dir pkg/conda/LICENSE +env/* \ No newline at end of file diff --git a/gramex/handlers/mlhandler.py b/gramex/handlers/mlhandler.py index ac35a2c9a..1fbe1d47a 100644 --- a/gramex/handlers/mlhandler.py +++ b/gramex/handlers/mlhandler.py @@ -21,6 +21,7 @@ from tornado.gen import coroutine from tornado.web import HTTPError from sklearn.metrics import get_scorer +from sklearn.model_selection import cross_val_predict, cross_val_score op = os.path MLCLASS_MODULES = [ @@ -40,7 +41,8 @@ 'pipeline': True, 'nums': [], 'cats': [], - 'target_col': None + 'target_col': None, + 'cv': True, } ACTIONS = ['predict', 'score', 'append', 'train', 'retrain'] DEFAULT_TEMPLATE = op.join(op.dirname(__file__), '..', 'apps', 'mlhandler', 'template.html') @@ -112,14 +114,23 @@ def setup(cls, data=None, model={}, config_dir='', **kwargs): data = cls._filtercols(data) data = cls._filterrows(data) cls.model = cls._assemble_pipeline(data, mclass=mclass, params=params) - # train the model target = data[target_col] train = data[[c for c in data if c != target_col]] + # cross validation + cls.cross_validation(train,target) gramex.service.threadpool.submit( _fit, cls.model, train, target, cls.model_path, cls.name) cls.config_store.flush() - + + @classmethod + def cross_validation(cls,train,target): + cv = cls.get_opt('cv',True) + if cv: + CVscore = cross_val_score(cls.model.steps[-1][1], X=train, y=target, cv=cv) + CVavg = sum(CVscore)/len(CVscore) + print('Cross Validation Score : ',CVavg) + @classmethod def load_data(cls, default=pd.DataFrame()): try: @@ -351,6 +362,7 @@ def _train(self, data=None): target = data[target_col] train = data[[c for c in data if c != target_col]] self.model = self._assemble_pipeline(data, force=True) + self.cross_validation(train,target) _fit(self.model, train, target, self.model_path) return {'score': self.model.score(train, target)}