diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 8cd291b48e..121f27d13e 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -18,6 +18,8 @@ repositories { dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation project(':opensearch-ml-common') + implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" + implementation "org.opensearch:common-utils:${common_utils_version}" implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' implementation group: 'org.reflections', name: 'reflections', version: '0.9.12' implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1' @@ -34,6 +36,7 @@ dependencies { testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.3.1' implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.1' implementation platform("ai.djl:bom:0.19.0") implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo' implementation group: 'ai.djl', name: 'api' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java index 98241b2e76..2f9d1d272e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java @@ -14,21 +14,56 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.io.ObjectInputFilter; import java.util.Base64; @UtilityClass public class ModelSerDeSer { - // Welcome list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries. + // Accept list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries. public static final String[] ACCEPT_CLASS_PATTERNS = { "java.lang.*", "java.util.*", "java.time.*", + "org.tribuo.*", + "com.oracle.labs.mlrg.olcut.provenance.*", + "com.oracle.labs.mlrg.olcut.util.*", + "[I", + "[Z", + "[J", + "[C", + "[D", + "[F", + "[Ljava.lang.*", + "[Lorg.tribuo.*", + "[Llibsvm.*", + "[[I", + "[[Z", + "[[J", + "[[C", + "[[D", + "[[F", + "[[Ljava.lang.*", + "[[Lorg.tribuo.*", + "[[Llibsvm.*", "org.opensearch.ml.*", - "*org.tribuo.*", "libsvm.*", - "com.oracle.labs.*", - "[*", - "com.amazon.randomcutforest.*" + }; + + public static final String[] REJECT_CLASS_PATTERNS = { + "java.util.logging.*", + "java.util.zip.*", + "java.util.jar.*", + "java.util.random.*", + "java.util.spi.*", + "java.util.stream.*", + "java.util.regex.*", + "java.util.concurrent.*", + "java.util.function.*", + "java.util.prefs.*", + "java.time.zone.*", + "java.time.format.*", + "java.time.temporal.*", + "java.time.chrono.*", }; public static String serializeToBase64(Object model) { @@ -47,11 +82,15 @@ public static byte[] serialize(Object model) { } } + // This method has been tested in K-means, Linear Regression, Logistic regression, Anomaly Detection and Random Cut Forest summarization and passed. public static Object deserialize(byte[] modelBin) { try (ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin); ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream)){ // Validate the model class type to avoid deserialization attack. - validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS); + validatingObjectInputStream + .accept(ACCEPT_CLASS_PATTERNS) + .reject(REJECT_CLASS_PATTERNS) + .setObjectInputFilter(ObjectInputFilter.Config.createFilter("maxdepth=20;maxrefs=5000;maxbytes=10000000;maxarray=100000")); return validatingObjectInputStream.readObject(); } catch (IOException | ClassNotFoundException e) { throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java index 0b44dad07f..f734420a26 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVMTest.java @@ -24,11 +24,13 @@ import org.opensearch.ml.common.input.parameter.ad.AnomalyDetectionLibSVMParams; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.engine.utils.ModelSerDeSer; import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.Feature; import org.tribuo.anomaly.Event; import org.tribuo.anomaly.example.AnomalyDataGenerator; +import org.tribuo.common.libsvm.LibSVMModel; import java.util.ArrayList; import java.util.Iterator; @@ -117,6 +119,13 @@ public void train() { Assert.assertNotNull(model.getContent()); } + @Test + public void testModelSerDeSer() { + MLModel model = anomalyDetection.train(trainDataFrameInput); + LibSVMModel deserializedModel = (LibSVMModel) ModelSerDeSer.deserialize(model); + Assert.assertNotNull(deserializedModel); + } + @Test public void trainWithFullParams() { AnomalyDetectionLibSVMParams parameters = AnomalyDetectionLibSVMParams.builder().gamma(gamma).nu(nu).cost(1.0).coeff(0.01).epsilon(0.001).degree(1).kernelType(AnomalyDetectionLibSVMParams.ADKernelType.LINEAR).build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarizeTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarizeTest.java index 914633e22c..40a91068b8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarizeTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarizeTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.clustering; +import com.amazon.randomcutforest.returntypes.SampleSummary; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -17,6 +18,7 @@ import org.opensearch.ml.common.input.parameter.clustering.RCFSummarizeParams; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.engine.utils.ModelSerDeSer; import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; @@ -61,6 +63,13 @@ public void predictWithTrivalModelExpectBoNorminalOutput() { predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1)); } + @Test + public void testModelSerDeSer() { + MLModel model = rcfSummarize.train(trainDataFrameInput); + SampleSummary deserializedModel = ((SerializableSummary) ModelSerDeSer.deserialize(model)).getSummary(); + Assert.assertNotNull(deserializedModel); + } + @Test public void trainAndPredictWithRegularInputExpectNotNullOutput() { RCFSummarizeParams parameters = RCFSummarizeParams.builder() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java index f17b9dfb9b..7891e15b16 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java @@ -17,6 +17,8 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.engine.utils.ModelSerDeSer; +import org.tribuo.classification.Label; import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LogisticRegressionHelper.constructLogisticRegressionTrainDataFrame; @@ -109,6 +111,14 @@ public void predict() { Assert.assertEquals(2, predictions.size()); } + @Test + public void testModelSerDeSer() { + LogisticRegression classification = new LogisticRegression(parameters); + MLModel model = classification.train(trainDataFrameInput); + org.tribuo.Model