forked from StudioCommunity/CustomModules
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
build-in models, save sklearn & keras models
- Loading branch information
1 parent
6bb4c4a
commit a26d650
Showing
15 changed files
with
302 additions
and
261 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import yaml | ||
import json | ||
from sys import version_info | ||
|
||
PYTHON_VERSION = "{major}.{minor}.{micro}".format(major=version_info.major, | ||
minor=version_info.minor, | ||
micro=version_info.micro) | ||
_conda_header = """\ | ||
name: project_environment | ||
channels: | ||
- defaults | ||
""" | ||
|
||
_extra_index_url = "--extra-index-url=https://test.pypi.org/simple" | ||
_alghost_pip = "alghost==0.0.59" | ||
_azureml_defaults_pip = "azureml-defaults" | ||
|
||
# temp solution, would remove later | ||
_data_type_file_name = "data_type.json" | ||
_data_ilearner_file_name = "data.ilearner" | ||
|
||
|
||
def _generate_conda_env(path=None, additional_conda_deps=None, additional_pip_deps=None, | ||
additional_conda_channels=None, install_alghost=True, install_azureml=True): | ||
env = yaml.safe_load(_conda_header) | ||
env["dependencies"] = ["python={}".format(PYTHON_VERSION), "git", "regex"] | ||
pip_deps = ([_extra_index_url, _alghost_pip] if install_alghost else []) + ( | ||
[_azureml_defaults_pip] if install_alghost else []) + ( | ||
additional_pip_deps if additional_pip_deps else []) | ||
if additional_conda_deps is not None: | ||
env["dependencies"] += additional_conda_deps | ||
env["dependencies"].append({"pip": pip_deps}) | ||
if additional_conda_channels is not None: | ||
env["channels"] += additional_conda_channels | ||
|
||
if path is not None: | ||
with open(path, "w") as out: | ||
yaml.safe_dump(env, stream=out, default_flow_style=False) | ||
return None | ||
else: | ||
return env | ||
|
||
|
||
def _generate_ilearner_files(path): | ||
# Dump data_type.json as a work around until SMT deploys | ||
dct = { | ||
"Id": "ILearnerDotNet", | ||
"Name": "ILearner .NET file", | ||
"ShortName": "Model", | ||
"Description": "A .NET serialized ILearner", | ||
"IsDirectory": False, | ||
"Owner": "Microsoft Corporation", | ||
"FileExtension": "ilearner", | ||
"ContentType": "application/octet-stream", | ||
"AllowUpload": False, | ||
"AllowPromotion": False, | ||
"AllowModelPromotion": True, | ||
"AuxiliaryFileExtension": None, | ||
"AuxiliaryContentType": None | ||
} | ||
with open(os.path.join(path, _data_type_file_name), 'w') as fp: | ||
json.dump(dct, fp) | ||
# Dump data.ilearner as a work around until data type design | ||
with open(os.path.join(path, _data_ilearner_file_name), 'w') as fp: | ||
fp.writelines('{}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import yaml | ||
|
||
from builtin_models.environment import _generate_conda_env | ||
from builtin_models.environment import _generate_ilearner_files | ||
|
||
FLAVOR_NAME = "keras" | ||
model_file_name = "model.h5" | ||
conda_file_name = "conda.yaml" | ||
model_spec_file_name = "model_spec.yml" | ||
|
||
def _get_default_conda_env(): | ||
import keras | ||
import tensorflow as tf | ||
|
||
return _generate_conda_env( | ||
additional_pip_deps=[ | ||
"keras=={}".format(keras.__version__), | ||
"tensorflow=={}".format(tf.__version__), | ||
]) | ||
|
||
|
||
def _save_conda_env(path, conda_env=None): | ||
if conda_env is None: | ||
conda_env = _get_default_conda_env() | ||
elif not isinstance(conda_env, dict): | ||
with open(conda_env, "r") as f: # conda_env is a file | ||
conda_env = yaml.safe_load(f) | ||
with open(os.path.join(path, conda_file_name), "w") as f: | ||
yaml.safe_dump(conda_env, stream=f, default_flow_style=False) | ||
|
||
|
||
def _save_model_spec(path): | ||
spec = { | ||
'flavor' : { | ||
'framework' : FLAVOR_NAME | ||
}, | ||
FLAVOR_NAME: { | ||
'model_file_path': model_file_name | ||
}, | ||
'conda': { | ||
'conda_file_path': conda_file_name | ||
}, | ||
} | ||
with open(os.path.join(path, model_spec_file_name), 'w') as fp: | ||
yaml.dump(spec, fp, default_flow_style=False) | ||
|
||
|
||
def load_model_from_local_file(path): | ||
from keras.models import load_model | ||
return load_model(path) | ||
|
||
|
||
def save_model(keras_model, path, conda_env=None): | ||
import keras | ||
|
||
if(not path.endswith('/')): | ||
path += '/' | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
|
||
keras_model.save(os.path.join(path, model_file_name)) | ||
_save_conda_env(path, conda_env) | ||
_save_model_spec(path) | ||
_generate_ilearner_files(path) # temp solution, to remove later | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
import yaml | ||
import pickle | ||
|
||
from builtin_models.environment import _generate_conda_env | ||
from builtin_models.environment import _generate_ilearner_files | ||
|
||
FLAVOR_NAME = "sklearn" | ||
model_file_name = "model.pkl" | ||
conda_file_name = "conda.yaml" | ||
model_spec_file_name = "model_spec.yml" | ||
|
||
def _get_default_conda_env(): | ||
import sklearn | ||
|
||
return _generate_conda_env( | ||
additional_pip_deps=[ | ||
"scikit-learn=={}".format(sklearn.__version__) | ||
]) | ||
|
||
|
||
def _save_conda_env(path, conda_env=None): | ||
if conda_env is None: | ||
conda_env = _get_default_conda_env() | ||
elif not isinstance(conda_env, dict): | ||
with open(conda_env, "r") as f: # conda_env is a file | ||
conda_env = yaml.safe_load(f) | ||
with open(os.path.join(path, conda_file_name), "w") as f: | ||
yaml.safe_dump(conda_env, stream=f, default_flow_style=False) | ||
|
||
|
||
def _save_model_spec(path): | ||
spec = { | ||
'flavor' : { | ||
'framework' : FLAVOR_NAME | ||
}, | ||
FLAVOR_NAME: { | ||
'model_file_path': model_file_name | ||
}, | ||
'conda': { | ||
'conda_file_path': conda_file_name | ||
}, | ||
} | ||
with open(os.path.join(path, model_spec_file_name), 'w') as fp: | ||
yaml.dump(spec, fp, default_flow_style=False) | ||
|
||
|
||
def _save_model(sklearn_model, path): | ||
with open(os.path.join(path, model_file_name), "wb") as fb: | ||
pickle.dump(sklearn_model, fb) | ||
|
||
|
||
def load_model_from_local_file(path): | ||
with open(path, "rb") as f: | ||
return pickle.load(f) | ||
|
||
|
||
def save_model(sklearn_model, path, conda_env=None): | ||
import sklearn | ||
|
||
if(not path.endswith('/')): | ||
path += '/' | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
|
||
_save_model(sklearn_model, path) | ||
_save_conda_env(path, conda_env) | ||
_save_model_spec(path) | ||
_generate_ilearner_files(path) # temp solution, to remove later | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from setuptools import setup | ||
|
||
# python setup.py install | ||
setup( | ||
name="builtin_models", | ||
version="0.0.1", | ||
description="builtin_models", | ||
packages=["builtin_models"], | ||
author="Xin Zou", | ||
license="MIT", | ||
include_package_data=True, | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
# python -m test.builtin_models_test | ||
if __name__ == '__main__': | ||
# keras test | ||
from builtin_models.keras import * | ||
print('---keras test---') | ||
model = load_model_from_local_file('D:/GIT/CustomModules-migu-NewYamlTest2/dstest/model/keras-mnist/model.h5') | ||
print('------') | ||
save_model(model, "./test/outputModels/keras/") | ||
print('********') | ||
|
||
#sklearn test | ||
from builtin_models.sklearn import * | ||
print('---sklearn test---') | ||
model = load_model_from_local_file('D:/GIT/CustomModules-migu-NewYamlTest2/dstest/dstest/sklearn/model/sklearn/model.pkl') | ||
print('------') | ||
save_model(model, "./test/outputModels/sklearn/") | ||
print('********') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.