Skip to content

Commit

Permalink
[DROOLS-6625] Validate input data (apache#3864)
Browse files Browse the repository at this point in the history
* [DROOLS-6625] Managing missing "required" input data

* [DROOLS-6625] Managing best-effort conversion of input data

* [DROOLS-6625] Managing invalid values - TODO: integration tests

* [DROOLS-6625] Managing invalid values

* [DROOLS-6625] Managing missing values

* [DROOLS-6625] Validate input data

* [DROOLS-6625] Fix merge with base branch

* [DROOLS-6625] Fix as per PR suggestion
  • Loading branch information
gitgabrio authored Oct 4, 2021
1 parent b09e0f0 commit 48e09ee
Show file tree
Hide file tree
Showing 33 changed files with 1,409 additions and 338 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
import org.kie.pmml.api.exceptions.KieEnumException;

/**
* @see <a href=http://dmg.org/pmml/v4-4/MiningSchema.html#xsdType_MISSING-VALUE-TREATMENT-METHOD>MISSING-VALUE_TREATMENT-METHOD</a>
* @see
* <a href=http://dmg.org/pmml/v4-4/MiningSchema.html#xsdType_MISSING-VALUE-TREATMENT-METHOD>MISSING-VALUE_TREATMENT-METHOD</a>
*/
public enum MISSING_VALUE_TREATMENT_METHOD {

ASSOCIATION_RULES("associationRules"),
SEQUENCES("sequences"),
CLASSIFICATION("classification"),
REGRESSION("regression"),
CLUSTERING("clustering"),
TIME_SERIES("timeSeries"),
MIXED("mixed");
AS_IS("asIs"),
AS_MEAN("asMean"),
AS_MODE("asMode"),
AS_MEDIAN("asMedian"),
AS_VALUE("asValue"),
RETURN_INVALID("returnInvalid");

private String name;

Expand All @@ -40,7 +40,8 @@ public enum MISSING_VALUE_TREATMENT_METHOD {
}

public static MISSING_VALUE_TREATMENT_METHOD byName(String name) {
return Arrays.stream(MISSING_VALUE_TREATMENT_METHOD.values()).filter(value -> Objects.equals(name, value.name)).findFirst().orElseThrow(() -> new KieEnumException("Failed to find MINING_FUNCTION with name: " + name));
return Arrays.stream(MISSING_VALUE_TREATMENT_METHOD.values()).filter(value -> Objects.equals(name,
value.name)).findFirst().orElseThrow(() -> new KieEnumException("Failed to find MISSING_VALUE_TREATMENT_METHOD with name: " + name));
}

public String getName() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.kie.pmml.api.exceptions;

/**
* <code>RuntimeException</code>s to be wrapping to <b>unchecked</b> ones at <i>customer</i> API boundaries
*/
public class KiePMMLInputDataException extends KiePMMLException {

private static final long serialVersionUID = -6638828457762000141L;

public KiePMMLInputDataException(String message, Throwable cause) {
super(message, cause);
}

public KiePMMLInputDataException(Throwable cause) {
super(cause);
}

public KiePMMLInputDataException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ public abstract class KiePMMLModel extends AbstractKiePMMLComponent implements P
protected MINING_FUNCTION miningFunction;
protected String targetField;
protected Map<String, Object> outputFieldsMap = new HashMap<>();
protected Map<String, Object> missingValueReplacementMap = new HashMap<>();
protected List<MiningField> miningFields = new ArrayList<>();
protected List<OutputField> outputFields = new ArrayList<>();
protected List<KiePMMLOutputField> kiePMMLOutputFields = new ArrayList<>();
Expand Down Expand Up @@ -75,10 +74,6 @@ public Map<String, Object> getOutputFieldsMap() {
return Collections.unmodifiableMap(outputFieldsMap);
}

public Map<String, Object> getMissingValueReplacementMap() {
return Collections.unmodifiableMap(missingValueReplacementMap);
}

/**
* Method to retrieve the <b>package</b> name to be used inside kiebase/package attribute of
* kmodule.xml and to use for package creation inside PMMLAssemblerService
Expand Down Expand Up @@ -118,7 +113,8 @@ public void setKiePMMLTargets(List<KiePMMLTarget> kiePMMLTargets) {
}

public List<KiePMMLOutputField> getKiePMMLOutputFields() {
return kiePMMLOutputFields != null ? Collections.unmodifiableList(kiePMMLOutputFields) : Collections.emptyList();
return kiePMMLOutputFields != null ? Collections.unmodifiableList(kiePMMLOutputFields) :
Collections.emptyList();
}

public KiePMMLTransformationDictionary getTransformationDictionary() {
Expand All @@ -131,7 +127,8 @@ public KiePMMLLocalTransformations getLocalTransformations() {

public Map<String, Double> getProbabilityMap() {
final LinkedHashMap<String, Double> probabilityResultMap = getProbabilityResultMap();
return probabilityResultMap != null ? Collections.unmodifiableMap(getFixedProbabilityMap(probabilityResultMap)) : Collections.emptyMap();
return probabilityResultMap != null ?
Collections.unmodifiableMap(getFixedProbabilityMap(probabilityResultMap)) : Collections.emptyMap();
}

public Object getPredictedDisplayValue() {
Expand Down Expand Up @@ -205,12 +202,9 @@ public Builder<T> withTargetField(String targetField) {
}

public Builder<T> withOutputFieldsMap(Map<String, Object> outputFieldsMap) {
toBuild.outputFieldsMap.putAll(outputFieldsMap);
return this;
}

public Builder<T> withMissingValueReplacementMap(Map<String, Object> missingValueReplacementMap) {
toBuild.missingValueReplacementMap.putAll(missingValueReplacementMap);
if (outputFieldsMap != null) {
toBuild.outputFieldsMap.putAll(outputFieldsMap);
}
return this;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.kie.pmml.api.enums.MINING_FUNCTION;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.models.MiningField;
import org.kie.pmml.commons.model.KiePMMLExtension;
import org.kie.pmml.commons.model.KiePMMLModel;
import org.kie.pmml.commons.model.KiePMMLOutputField;
Expand Down Expand Up @@ -80,5 +81,12 @@ public Builder withLocalTransformations(final KiePMMLLocalTransformations localT
toBuild.localTransformations = localTransformations;
return this;
}

public Builder withMiningFields(final List<MiningField> miningFields) {
if (miningFields != null) {
toBuild.miningFields = miningFields;
}
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package org.kie.pmml.compiler.commons.builders;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
Expand All @@ -29,11 +26,9 @@
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.ThisExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.utils.Pair;
import org.dmg.pmml.Field;
import org.dmg.pmml.Model;
import org.dmg.pmml.TransformationDictionary;
import org.kie.pmml.api.enums.DATA_TYPE;
import org.kie.pmml.api.enums.MINING_FUNCTION;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
Expand All @@ -60,7 +55,8 @@ private KiePMMLModelCodegenUtils() {
}

/**
* Initialize the given <code>ClassOrInterfaceDeclaration</code> with all the <b>common</b> code needed to generate a <code>KiePMMLModel</code>
* Initialize the given <code>ClassOrInterfaceDeclaration</code> with all the <b>common</b> code needed to
* generate a <code>KiePMMLModel</code>
* @param modelTemplate
* @param fields
* @param transformationDictionary
Expand All @@ -70,10 +66,12 @@ public static void init(final ClassOrInterfaceDeclaration modelTemplate,
final List<Field<?>> fields,
final TransformationDictionary transformationDictionary,
final Model pmmlModel) {
final ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor().orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
final ConstructorDeclaration constructorDeclaration =
modelTemplate.getDefaultConstructor().orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
final String name = pmmlModel.getModelName();
final String generatedClassName = getSanitizedClassName(name);
final List<MiningField> miningFields = ModelUtils.convertToKieMiningFieldList(pmmlModel.getMiningSchema(), fields);
final List<MiningField> miningFields = ModelUtils.convertToKieMiningFieldList(pmmlModel.getMiningSchema(),
fields);
final List<OutputField> outputFields = ModelUtils.convertToKieOutputFieldList(pmmlModel.getOutput(), fields);
final Expression miningFunctionExpression;
if (pmmlModel.getMiningFunction() != null) {
Expand All @@ -83,17 +81,22 @@ public static void init(final ClassOrInterfaceDeclaration modelTemplate,
miningFunctionExpression = new NullLiteralExpr();
}
final PMML_MODEL pmmlModelEnum = PMML_MODEL.byName(pmmlModel.getClass().getSimpleName());
final NameExpr pmmlMODELExpression = new NameExpr(pmmlModelEnum.getClass().getName() + "." + pmmlModelEnum.name());
final NameExpr pmmlMODELExpression =
new NameExpr(pmmlModelEnum.getClass().getName() + "." + pmmlModelEnum.name());
String targetFieldName = getTargetFieldName(fields, pmmlModel).orElse(null);
final Expression targetFieldExpression;
if (targetFieldName != null) {
targetFieldExpression = new StringLiteralExpr(targetFieldName);
} else {
targetFieldExpression = new NullLiteralExpr();
}
Map<String, Pair<DATA_TYPE, String>> missingValueReplacements = getMissingValueReplacementsMap(fields, pmmlModel);
setKiePMMLModelConstructor(generatedClassName, constructorDeclaration, name, miningFields, outputFields, missingValueReplacements);
addTransformationsInClassOrInterfaceDeclaration(modelTemplate, transformationDictionary, pmmlModel.getLocalTransformations());
setKiePMMLModelConstructor(generatedClassName,
constructorDeclaration,
name,
miningFields,
outputFields);
addTransformationsInClassOrInterfaceDeclaration(modelTemplate, transformationDictionary,
pmmlModel.getLocalTransformations());
final BlockStmt body = constructorDeclaration.getBody();
CommonCodegenUtils.setAssignExpressionValue(body, "pmmlMODEL", pmmlMODELExpression);
CommonCodegenUtils.setAssignExpressionValue(body, "miningFunction", miningFunctionExpression);
Expand All @@ -106,19 +109,4 @@ public static void init(final ClassOrInterfaceDeclaration modelTemplate,
CommonCodegenUtils.setAssignExpressionValue(body, "kiePMMLOutputFields", getCreatedKiePMMLOutputFieldsExpr);
}
}

static Map<String, Pair<DATA_TYPE, String>> getMissingValueReplacementsMap(final List<Field<?>> fields, Model pmmlModel) {
Map<String, DATA_TYPE> dataTypeMap = fields.stream()
.collect(Collectors.toMap(i -> i.getName().getValue(),
i -> DATA_TYPE.byName(i.getDataType().value()),
(prevDataType, newDataType) -> newDataType));
return pmmlModel.getMiningSchema() == null
? Collections.emptyMap()
: pmmlModel.getMiningSchema().getMiningFields().stream()
.filter(mf -> mf.getMissingValueReplacement() instanceof String)
.collect(Collectors.toMap(
mf -> mf.getName().getValue(),
mf -> new Pair<>(dataTypeMap.get(mf.getName().getValue()), (String) mf.getMissingValueReplacement())
));
}
}
Loading

0 comments on commit 48e09ee

Please sign in to comment.