Skip to content

Commit

Permalink
IGNITE-11244: [ML] Improve model loading from directory instead
Browse files Browse the repository at this point in the history
full path to file with model

This closes apache#6065
  • Loading branch information
zaleslaw authored and ybabak committed Feb 11, 2019
1 parent cca4491 commit 816f435
Show file tree
Hide file tree
Showing 96 changed files with 385 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
*/
public class DecisionTreeFromSparkExample {
/** Path to Spark DT model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/dt/data" +
"/part-00000-86bc0f70-df49-48b3-8356-9a26f9a6eb0f-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/dt";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
*/
public class DecisionTreeRegressionFromSparkExample {
/** Path to Spark Decision tree regression model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/dtreg/data" +
"/part-00000-366f6ff2-698b-4bdd-8b1c-de87e11b3d1b-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/dtreg";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@
*/
public class GBTFromSparkExample {
/** Path to Spark LogReg model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/gbt/data" +
"/part-00000-ea23dcda-6344-4b1f-9716-fbedf7caba2d-c000.snappy.parquet";

/** Spark model metadata path. */
private static final String SPARK_MDL_METADATA_PATH = "examples/src/main/resources/models/spark/serialized/gbt/treesMetadata" +
"/part-00000-9033203a-e1e6-4d24-9900-be8a4396710b-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/gbt";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand All @@ -67,8 +62,8 @@ public static void main(String[] args) throws FileNotFoundException {

IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];

ModelsComposition mdl = (ModelsComposition)SparkModelParser.parseWithMetadata(
SPARK_MDL_PATH, SPARK_MDL_METADATA_PATH,
ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
SPARK_MDL_PATH,
SupportedSparkModels.GRADIENT_BOOSTED_TREES
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@
*/
public class GBTRegressionFromSparkExample {
/** Path to Spark GBT Regression model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/gbtreg/data" +
"/part-00000-db4215b8-888b-4944-b933-7897869a29d3-c000.snappy.parquet";

/** Spark model metadata path. */
private static final String SPARK_MDL_METADATA_PATH = "examples/src/main/resources/models/spark/serialized/gbtreg/treesMetadata" +
"/part-00000-999806a9-1326-48b3-bad7-07c343405928-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/gbtreg";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand All @@ -68,8 +63,8 @@ public static void main(String[] args) throws FileNotFoundException {

IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];

ModelsComposition mdl = (ModelsComposition)SparkModelParser.parseWithMetadata(
SPARK_MDL_PATH, SPARK_MDL_METADATA_PATH,
ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
SPARK_MDL_PATH,
SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
*/
public class KMeansFromSparkExample {
/** Path to Spark KMeans model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/kmeans/data" +
"/part-00000-e1f2c475-c65a-4b9e-879e-de4afd4f65bc-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/kmeans";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
*/
public class LinearRegressionFromSparkExample {
/** Path to Spark linear regression model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/linreg/data" +
"/part-00000-1ff2d09d-6cdf-4ad3-bddd-7cad8378429d-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/linreg";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
*/
public class LogRegFromSparkExample {
/** Path to Spark LogReg model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/logreg/data" +
"/part-00000-7551081d-c0a8-4ed7-afe4-a464aabc7f80-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/logreg";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
*/
public class RandomForestFromSparkExample {
/** Path to Spark Random Forest model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/rf/data" +
"/part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/rf";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
*/
public class RandomForestRegressionFromSparkExample {
/** Path to Spark Random Forest regression model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/rfreg/data" +
"/part-00000-06273895-4b81-4a77-823e-dfd32d1560eb-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/rfreg";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
*/
public class SVMFromSparkExample {
/** Path to Spark SVM model. */
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/svm/data" +
"/part-00000-b3d800e2-a36c-4948-8e65-29c9f5c9c5b2-c000.snappy.parquet";
public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/svm";

/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -62,14 +63,94 @@

/** Parser of Spark models. */
public class SparkModelParser {
/**
* Load model from parquet (presented as a directory).
*
* @param pathToMdl Path to directory with saved model.
* @param parsedSparkMdl Parsed spark model.
*/
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
File mdlDir = IgniteUtils.resolveIgnitePath(pathToMdl);

if (mdlDir == null)
throw new IllegalArgumentException("Directory not found or empty [directory_path=" + pathToMdl + "]");

if (!mdlDir.isDirectory())
throw new IllegalArgumentException("Spark Model Parser supports loading from directory only. " +
"The specified path " + pathToMdl + " is not the path to directory.");

String[] files = mdlDir.list();
if (files.length == 0) throw new IllegalArgumentException("Directory contain 0 files and sub-directories [directory_path=" + pathToMdl + "]");

if (Arrays.stream(files).noneMatch("data"::equals))
throw new IllegalArgumentException("Directory should contain data sub-directory [directory_path=" + pathToMdl + "]");

if (Arrays.stream(files).noneMatch("metadata"::equals))
throw new IllegalArgumentException("Directory should contain metadata sub-directory [directory_path=" + pathToMdl + "]");

String pathToData = pathToMdl + File.separator + "data";
File dataDir = IgniteUtils.resolveIgnitePath(pathToData);

File[] dataParquetFiles = dataDir.listFiles((dir, name) -> name.matches("^part-.*\\.snappy\\.parquet$"));
if (dataParquetFiles.length == 0)
throw new IllegalArgumentException("Directory should contain parquet file " +
"with model [directory_path=" + pathToData + "]");

if (dataParquetFiles.length > 1)
throw new IllegalArgumentException("Directory should contain only one parquet file " +
"with model [directory_path=" + pathToData + "]");

String pathToMdlFile = dataParquetFiles[0].getPath();

String pathToMetadata = pathToMdl + File.separator + "metadata";
File metadataDir = IgniteUtils.resolveIgnitePath(pathToMetadata);
String[] metadataFiles = metadataDir.list();

if (Arrays.stream(metadataFiles).noneMatch("part-00000"::equals))
throw new IllegalArgumentException("Directory should contain json file with model metadata " +
"with name part-00000 [directory_path=" + pathToMetadata + "]");

if (shouldContainTreeMetadataSubDirectory(parsedSparkMdl)) {
if (Arrays.stream(files).noneMatch("treesMetadata"::equals))
throw new IllegalArgumentException("Directory should contain treeMetadata sub-directory [directory_path=" + pathToMdl + "]");

String pathToTreesMetadata = pathToMdl + File.separator + "treesMetadata";
File treesMetadataDir = IgniteUtils.resolveIgnitePath(pathToTreesMetadata);

File[] treesMetadataParquetFiles = treesMetadataDir.listFiles((dir, name) -> name.matches("^part-.*\\.snappy\\.parquet$"));
if (treesMetadataParquetFiles.length == 0)
throw new IllegalArgumentException("Directory should contain parquet file " +
"with model treesMetadata [directory_path=" + pathToTreesMetadata + "]");

if (treesMetadataParquetFiles.length > 1)
throw new IllegalArgumentException("Directory should contain only one parquet file " +
"with model [directory_path=" + pathToTreesMetadata + "]");

String pathToTreesMetadataFile = treesMetadataParquetFiles[0].getPath();

return parseDataWithMetadata(pathToMdlFile, pathToTreesMetadataFile, parsedSparkMdl);
} else
return parseData(pathToMdlFile, parsedSparkMdl);

}

/**
* @param parsedSparkMdl Parsed spark model.
*/
private static boolean shouldContainTreeMetadataSubDirectory(SupportedSparkModels parsedSparkMdl) {
return parsedSparkMdl == SupportedSparkModels.GRADIENT_BOOSTED_TREES
|| parsedSparkMdl == SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION;
}


/**
* Load model from parquet file.
*
* @param pathToMdl Hadoop path to model saved from Spark.
* @param parsedSparkMdl One of supported Spark models to parse it.
* @return Instance of parsedSparkMdl model.
*/
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) {
private static Model parseData(String pathToMdl, SupportedSparkModels parsedSparkMdl) {
File mdlRsrc = IgniteUtils.resolveIgnitePath(pathToMdl);
if (mdlRsrc == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");
Expand Down Expand Up @@ -98,6 +179,39 @@ public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl)
}
}


/**
* Load model and its metadata from parquet files.
*
* @param pathToMdl Hadoop path to model saved from Spark.
* @param pathToMetaData Hadoop path to metadata saved from Spark.
* @param parsedSparkMdl One of supported Spark models to parse it.
* @return Instance of parsedSparkMdl model.
*/
private static Model parseDataWithMetadata(String pathToMdl, String pathToMetaData,
SupportedSparkModels parsedSparkMdl) {
File mdlRsrc1 = IgniteUtils.resolveIgnitePath(pathToMdl);
if (mdlRsrc1 == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");

String ignitePathToMdl = mdlRsrc1.getPath();

File mdlRsrc2 = IgniteUtils.resolveIgnitePath(pathToMetaData);
if (mdlRsrc2 == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMetaData + "]");

String ignitePathToMdlMetaData = mdlRsrc2.getPath();

switch (parsedSparkMdl) {
case GRADIENT_BOOSTED_TREES:
return loadGBTClassifierModel(ignitePathToMdl, ignitePathToMdlMetaData);
case GRADIENT_BOOSTED_TREES_REGRESSION:
return loadGBTRegressionModel(ignitePathToMdl, ignitePathToMdlMetaData);
default:
throw new UnsupportedSparkModelException(ignitePathToMdl);
}
}

/**
* Load Random Forest Regression model.
*
Expand Down Expand Up @@ -163,38 +277,6 @@ private static Model loadKMeansModel(String pathToMdl) {
return new KMeansModel(centers, new EuclideanDistance());
}

/**
* Load model and its metadata from parquet files.
*
* @param pathToMdl Hadoop path to model saved from Spark.
* @param pathToMetaData Hadoop path to metadata saved from Spark.
* @param parsedSparkMdl One of supported Spark models to parse it.
* @return Instance of parsedSparkMdl model.
*/
public static Model parseWithMetadata(String pathToMdl, String pathToMetaData,
SupportedSparkModels parsedSparkMdl) {
File mdlRsrc1 = IgniteUtils.resolveIgnitePath(pathToMdl);
if (mdlRsrc1 == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMdl + "]");

String ignitePathToMdl = mdlRsrc1.getPath();

File mdlRsrc2 = IgniteUtils.resolveIgnitePath(pathToMetaData);
if (mdlRsrc2 == null)
throw new IllegalArgumentException("Resource not found [resource_path=" + pathToMetaData + "]");

String ignitePathToMdlMetaData = mdlRsrc2.getPath();

switch (parsedSparkMdl) {
case GRADIENT_BOOSTED_TREES:
return loadGBTClassifierModel(ignitePathToMdl, ignitePathToMdlMetaData);
case GRADIENT_BOOSTED_TREES_REGRESSION:
return loadGBTRegressionModel(ignitePathToMdl, ignitePathToMdlMetaData);
default:
throw new UnsupportedSparkModelException(ignitePathToMdl);
}
}

/**
* Load GDB Regression model.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.ignite.ml.sparkmodelparser;

import org.junit.runner.RunWith;
import org.junit.runners.Suite;

/** Test suite for all module tests. */
@RunWith(Suite.class)
@Suite.SuiteClasses({
SparkModelParserTest.class
})
public class IgniteMLSparkModelParserTestSuite {
// No-op.
}
Loading

0 comments on commit 816f435

Please sign in to comment.