9 | 9 | "Licensed under the MIT License."
10 | 10 | ]
11 | 11 | },
| 12 | + { |
| 13 | + "cell_type": "markdown", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | + "" |
| 17 | + ] |
| 18 | + }, |
12 | 19 | {
13 | 20 | "cell_type": "markdown",
14 | 21 | "metadata": {},
66 | 73 | "import numpy as np\n",
67 | 74 | "import pandas as pd\n",
68 | 75 | "from sklearn import datasets\n",
| 76 | + "from sklearn.model_selection import train_test_split\n", |
69 | 77 | "\n",
70 | 78 | "import azureml.core\n",
71 | 79 | "from azureml.core.experiment import Experiment\n",
72 | 80 | "from azureml.core.workspace import Workspace\n",
73 | | - "from azureml.train.automl import AutoMLConfig" |
| 81 | + "from azureml.train.automl import AutoMLConfig, constants" |
74 | 82 | ]
75 | 83 | },
76 | 84 | {
106 | 114 | "source": [
107 | 115 | "## Data\n",
108 | 116 | "\n",
109 | | - "This uses scikit-learn's [load_digits](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) method." |
| 117 | + "This uses scikit-learn's [load_iris](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html) method." |
110 | 118 | ]
111 | 119 | },
112 | 120 | {
115 | 123 | "metadata": {},
116 | 124 | "outputs": [],
117 | 125 | "source": [
118 | | - "digits = datasets.load_digits()\n", |
| 126 | + "iris = datasets.load_iris()\n", |
| 127 | + "X_train, X_test, y_train, y_test = train_test_split(iris.data, \n", |
| 128 | + " iris.target, \n", |
| 129 | + " test_size=0.2, \n", |
| 130 | + " random_state=0)\n", |
119 | 131 | "\n",
120 | | - "# Exclude the first 100 rows from training so that they can be used for test.\n", |
121 | | - "X_train = digits.data[100:,:]\n", |
122 | | - "y_train = digits.target[100:]" |
| 132 | + "# Convert the X_train and X_test to pandas DataFrame and set column names,\n", |
| 133 | + "# This is needed for initializing the input variable names of ONNX model, \n", |
| 134 | + "# and the prediction with the ONNX model using the inference helper.\n", |
| 135 | + "X_train = pd.DataFrame(X_train, columns=['c1', 'c2', 'c3', 'c4'])\n", |
| 136 | + "X_test = pd.DataFrame(X_test, columns=['c1', 'c2', 'c3', 'c4'])" |
123 | 137 | ]
124 | 138 | },
125 | 139 | {
155 | 169 | " primary_metric = 'AUC_weighted',\n",
156 | 170 | " iteration_timeout_minutes = 60,\n",
157 | 171 | " iterations = 10,\n",
158 | | - " verbosity = logging.INFO,\n", |
| 172 | + " verbosity = logging.INFO, \n", |
159 | 173 | " X = X_train, \n",
160 | 174 | " y = y_train,\n",
| 175 | + " preprocess=True,\n", |
161 | 176 | " enable_onnx_compatible_models=True,\n",
162 | 177 | " path = project_folder)"
163 | 178 | ]
253 | 268 | "onnx_fl_path = \"./best_model.onnx\"\n",
254 | 269 | "OnnxConverter.save_onnx_model(onnx_mdl, onnx_fl_path)"
255 | 270 | ]
| 271 | + }, |
| 272 | + { |
| 273 | + "cell_type": "markdown", |
| 274 | + "metadata": {}, |
| 275 | + "source": [ |
| 276 | + "### Predict with the ONNX model, using onnxruntime package" |
| 277 | + ] |
| 278 | + }, |
| 279 | + { |
| 280 | + "cell_type": "code", |
| 281 | + "execution_count": null, |
| 282 | + "metadata": {}, |
| 283 | + "outputs": [], |
| 284 | + "source": [ |
| 285 | + "import sys\n", |
| 286 | + "import json\n", |
| 287 | + "from azureml.automl.core.onnx_convert import OnnxConvertConstants\n", |
| 288 | + "\n", |
| 289 | + "if sys.version_info < OnnxConvertConstants.OnnxIncompatiblePythonVersion:\n", |
| 290 | + " python_version_compatible = True\n", |
| 291 | + "else:\n", |
| 292 | + " python_version_compatible = False\n", |
| 293 | + "\n", |
| 294 | + "try:\n", |
| 295 | + " import onnxruntime\n", |
| 296 | + " from azureml.automl.core.onnx_convert import OnnxInferenceHelper \n", |
| 297 | + " onnxrt_present = True\n", |
| 298 | + "except ImportError:\n", |
| 299 | + " onnxrt_present = False\n", |
| 300 | + "\n", |
| 301 | + "def get_onnx_res(run):\n", |
| 302 | + " res_path = '_debug_y_trans_converter.json'\n", |
| 303 | + " run.download_file(name=constants.MODEL_RESOURCE_PATH_ONNX, output_file_path=res_path)\n", |
| 304 | + " with open(res_path) as f:\n", |
| 305 | + " onnx_res = json.load(f)\n", |
| 306 | + " return onnx_res\n", |
| 307 | + "\n", |
| 308 | + "if onnxrt_present and python_version_compatible: \n", |
| 309 | + " mdl_bytes = onnx_mdl.SerializeToString()\n", |
| 310 | + " onnx_res = get_onnx_res(best_run)\n", |
| 311 | + "\n", |
| 312 | + " onnxrt_helper = OnnxInferenceHelper(mdl_bytes, onnx_res)\n", |
| 313 | + " pred_onnx, pred_prob_onnx = onnxrt_helper.predict(X_test)\n", |
| 314 | + "\n", |
| 315 | + " print(pred_onnx)\n", |
| 316 | + " print(pred_prob_onnx)\n", |
| 317 | + "else:\n", |
| 318 | + " if not python_version_compatible:\n", |
| 319 | + " print('Please use Python version 3.6 to run the inference helper.') \n", |
| 320 | + " if not onnxrt_present:\n", |
| 321 | + " print('Please install the onnxruntime package to do the prediction with ONNX model.')" |
| 322 | + ] |
| 323 | + }, |
| 324 | + { |
| 325 | + "cell_type": "code", |
| 326 | + "execution_count": null, |
| 327 | + "metadata": {}, |
| 328 | + "outputs": [], |
| 329 | + "source": [] |
256 | 330 | }
257 | 331 | ],
258 | 332 | "metadata": {