From 7630e36990fac828707251a1a05cf4d5e4ac4fca Mon Sep 17 00:00:00 2001 From: Teddy Crepineau Date: Mon, 12 Feb 2024 12:16:33 +0100 Subject: [PATCH 1/3] feat: added severity classifier for DQ incidents --- conf/openmetadata.yaml | 3 + .../service/OpenMetadataApplication.java | 4 + .../OpenMetadataApplicationConfig.java | 4 + .../TestCaseResolutionStatusRepository.java | 19 +++ .../IncidentSeverityClassifierInterface.java | 44 ++++++ ...cRegressionIncidentSeverityClassifier.java | 127 ++++++++++++++++++ .../resources/openmetadata-secure-test.yaml | 6 +- .../dataQualityConfiguration.json | 18 +++ 8 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java create mode 100644 openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java create mode 100644 openmetadata-spec/src/main/resources/json/schema/configuration/dataQualityConfiguration.json diff --git a/conf/openmetadata.yaml b/conf/openmetadata.yaml index 02cd22ccc7b7..2475b86b804d 100644 --- a/conf/openmetadata.yaml +++ b/conf/openmetadata.yaml @@ -370,3 +370,6 @@ web: permission-policy: enabled: ${WEB_CONF_PERMISSION_POLICY_ENABLED:-false} option: ${WEB_CONF_PERMISSION_POLICY_OPTION:-""} + +dataQualityConfiguration: + severityIncidentClassifier: ${DATA_QUALITY_SEVERITY_INCIDENT_CLASSIFIER:-"org.openmetadata.service.util.incidentSeverityClassifier.LogisticRegressionIncidentSeverityClassifier"} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java index bf73509e087f..418e67086a90 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplication.java @@ -113,6 +113,7 @@ import org.openmetadata.service.socket.SocketAddressFilter; import org.openmetadata.service.socket.WebSocketManager; import org.openmetadata.service.util.MicrometerBundleSingleton; +import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface; import org.openmetadata.service.util.jdbi.DatabaseAuthenticationProviderFactory; import org.quartz.SchedulerException; @@ -136,6 +137,9 @@ public void run(OpenMetadataApplicationConfig catalogConfig, Environment environ NoSuchAlgorithmException { validateConfiguration(catalogConfig); + // Instantiate incident severity classifier + IncidentSeverityClassifierInterface.createInstance(catalogConfig.getDataQualityConfiguration()); + // init for dataSourceFactory DatasourceConfig.initialize(catalogConfig.getDataSourceFactory().getDriverClass()); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplicationConfig.java b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplicationConfig.java index ca145ca63d89..47deae5398f9 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplicationConfig.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/OpenMetadataApplicationConfig.java @@ -23,6 +23,7 @@ import lombok.Getter; import lombok.Setter; import org.openmetadata.schema.api.configuration.apps.AppsPrivateConfiguration; +import org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration; import org.openmetadata.schema.api.configuration.events.EventHandlerConfiguration; import org.openmetadata.schema.api.configuration.pipelineServiceClient.PipelineServiceClientConfiguration; import org.openmetadata.schema.api.fernet.FernetConfiguration; @@ -94,6 +95,9 @@ public class OpenMetadataApplicationConfig extends Configuration { @JsonProperty("web") private OMWebConfiguration webConfiguration = new OMWebConfiguration(); + @JsonProperty("dataQualityConfiguration") + private DataQualityConfiguration dataQualityConfiguration; + @JsonProperty("applications") private AppsPrivateConfiguration appsPrivateConfiguration; diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java index fcf19d488a3b..1711873b23f9 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java @@ -13,12 +13,14 @@ import javax.json.JsonPatch; import javax.ws.rs.core.Response; import org.jdbi.v3.sqlobject.transaction.Transaction; +import org.openmetadata.schema.EntityInterface; import org.openmetadata.schema.api.feed.ResolveTask; import org.openmetadata.schema.entity.feed.Thread; import org.openmetadata.schema.entity.teams.User; import org.openmetadata.schema.tests.TestCase; import org.openmetadata.schema.tests.type.Assigned; import org.openmetadata.schema.tests.type.Resolved; +import org.openmetadata.schema.tests.type.Severity; import org.openmetadata.schema.tests.type.TestCaseResolutionStatus; import org.openmetadata.schema.tests.type.TestCaseResolutionStatusTypes; import org.openmetadata.schema.type.EntityReference; @@ -35,6 +37,7 @@ import org.openmetadata.service.util.JsonUtils; import org.openmetadata.service.util.RestUtil; import org.openmetadata.service.util.ResultList; +import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface; public class TestCaseResolutionStatusRepository extends EntityTimeSeriesRepository { @@ -168,6 +171,8 @@ public TestCaseResolutionStatus createNewRecord( : recordEntity.getSeverity()); } + inferIncidentSeverity(recordEntity); + switch (recordEntity.getTestCaseResolutionStatusType()) { case New -> { // If there is already an existing New incident we'll return it @@ -300,4 +305,18 @@ private void patchTaskAssignee(Thread originalTask, EntityReference newAssignee, FeedRepository feedRepository = Entity.getFeedRepository(); feedRepository.patchThread(null, originalTask.getId(), user, patch); } + + public void inferIncidentSeverity(TestCaseResolutionStatus incident) { + if (incident.getSeverity() != null) { + // If the severity is already set, we don't need to infer it + return; + } + IncidentSeverityClassifierInterface incidentSeverityClassifier = IncidentSeverityClassifierInterface.getInstance(); + EntityReference testCaseReference = incident.getTestCaseReference(); + TestCase testCase = Entity.getEntityByName(testCaseReference.getType(), testCaseReference.getFullyQualifiedName(), "", Include.ALL); + MessageParser.EntityLink entityLink = MessageParser.EntityLink.parse(testCase.getEntityLink()); + EntityInterface entity = Entity.getEntityByName(entityLink.getEntityType(), entityLink.getEntityFQN(), "followers,owner,tags,votes", Include.ALL); + Severity severity = incidentSeverityClassifier.classifyIncidentSeverity(entity); + incident.setSeverity(severity); + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java new file mode 100644 index 000000000000..08d9f6923324 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java @@ -0,0 +1,44 @@ +package org.openmetadata.service.util.incidentSeverityClassifier; + +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.schema.EntityInterface; +import org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration; +import org.openmetadata.schema.tests.type.Severity; +import org.openmetadata.service.Entity; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; + +@Slf4j +public abstract class IncidentSeverityClassifierInterface { + protected static IncidentSeverityClassifierInterface instance; + + public static IncidentSeverityClassifierInterface getInstance() { + if (instance == null) { + LOG.info("Incident severity classifier instance is null. Default to LogisticRegressionClassifier"); + instance = new LogisticRegressionIncidentSeverityClassifier(); + } + return instance; + } + + public static void createInstance(DataQualityConfiguration dataQualityConfiguration) { + instance = getClassifierClass(dataQualityConfiguration.getSeverityIncidentClassifier()); + } + + private static IncidentSeverityClassifierInterface getClassifierClass(String severityClassifierClassString) { + IncidentSeverityClassifierInterface incidentSeverityClassifier; + try { + Class severityClassifierClass = Class.forName(severityClassifierClassString); + Constructor severityClassifierConstructor = severityClassifierClass.getConstructor(); + incidentSeverityClassifier = (IncidentSeverityClassifierInterface) severityClassifierConstructor.newInstance(); + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InstantiationException | + InvocationTargetException e) { + LOG.error("Error occurred while initializing the incident severity classifier. Default to LogisticRegressionClassifier", e); + // If we encounter an error while initializing the incident severity classifier, we default to the logistic regression classifier + incidentSeverityClassifier = new LogisticRegressionIncidentSeverityClassifier(); + } + return incidentSeverityClassifier; + } + + public abstract Severity classifyIncidentSeverity(EntityInterface entity); +} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java new file mode 100644 index 000000000000..b06412c4a03b --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java @@ -0,0 +1,127 @@ +package org.openmetadata.service.util.incidentSeverityClassifier; + +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.schema.EntityInterface; +import org.openmetadata.schema.tests.type.Severity; +import org.openmetadata.schema.type.TagLabel; + +import java.util.Arrays; +import java.util.List; + +@Slf4j +public class LogisticRegressionIncidentSeverityClassifier extends IncidentSeverityClassifierInterface { + // coef. matrix represents the weights of the logistic regression model. It has shape + // (n_feature, n_class) where n_features are respectively: + // - row 0: 'Tier' (1, 2, 3, 4, 5) for an asset + // - row 1: 'HasOwner' 1 if the asset has an owner, 0 otherwise + // - row 2: 'Followers' number of followers of the asset + // - row 3: 'Votes' number of + votes of the asset. + // Coef. matrix was generated using synthetic data. + static final double[][] coefMatrix = { + new double[]{-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022}, + new double[]{65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332}, + new double[]{0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822}, + new double[]{0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118}, + }; + + @Override + public Severity classifyIncidentSeverity(EntityInterface entity) { + double[] vectorX = getVectorX(entity); + if (vectorX.length == 0) { + return null; + } + try { + double[] vectorZ = dotProduct(vectorX); + double[] softmaxVector = softmax(vectorZ); + int predictedClass = argmax(softmaxVector); + switch (predictedClass) { + case 0: + return Severity.Severity1; + case 1: + return Severity.Severity2; + case 2: + return Severity.Severity3; + case 3: + return Severity.Severity4; + case 4: + return Severity.Severity5; + } + } catch (Exception e) { + LOG.error("Error occurred while classifying incident severity", e); + } + return null; + } + + private double[] dotProduct(double[] vectorX) { + // compute the dot product of the input vector and the coef. matrix + double[] result = new double[coefMatrix[0].length]; + for (int i = 0; i < coefMatrix.length; i++) { + int sum = 0; + for (int j = 0; j < vectorX.length; j++) { + sum += vectorX[j] * coefMatrix[j][i]; + } + result[i] = sum; + } + return result; + } + + private double[] softmax(double[] vectorZ) { + // compute the softmax of the z vector + double expSum = Arrays.stream(vectorZ).map(Math::exp).sum(); + double[] softmax = new double[vectorZ.length]; + for (int i = 0; i < vectorZ.length; i++) { + softmax[i] = Math.exp(vectorZ[i]) / expSum; + } + return softmax; + } + + private int argmax(double[] softmaxVector) { + // return the index of the max value in the softmax vector + // (i.e. the predicted class) + int maxIndex = 0; + double argmax = 0; + + for (int i = 0 ; i < softmaxVector.length; i++) { + if (softmaxVector[i] > argmax) { + argmax = softmaxVector[i]; + maxIndex = i; + } + } + return maxIndex; + } + + private double[] getVectorX(EntityInterface entity) { + // get the input vector for the logistic regression model + double hasOwner = entity.getOwner() != null ? 1 : 0; + double followers = entity.getFollowers() != null ? entity.getFollowers().size() : 0; + double votes = entity.getVotes() != null ? entity.getVotes().getUpVotes() : 0; + double tier = entity.getTags() != null ? getTier(entity.getTags()) : 0; + if (tier == 0) { + // if we don't have a tier set we can't run the classifier + return new double[]{}; + } + return new double[]{tier, hasOwner, followers, votes}; + } + + private double getTier(List tags) { + // get the tier of the asset + + for (TagLabel tag : tags) { + if (tag.getName().contains("Tier")) { + switch (tag.getName()) { + case "Tier1": + return 1; + case "Tier2": + return 2; + case "Tier3": + return 3; + case "Tier4": + return 4; + case "Tier5": + return 5; + } + } + } + return 0; + } +} diff --git a/openmetadata-service/src/test/resources/openmetadata-secure-test.yaml b/openmetadata-service/src/test/resources/openmetadata-secure-test.yaml index d8a9cb81d984..332d6c5510b6 100644 --- a/openmetadata-service/src/test/resources/openmetadata-secure-test.yaml +++ b/openmetadata-service/src/test/resources/openmetadata-secure-test.yaml @@ -222,4 +222,8 @@ email: serverPort: "" username: "" password: "" - transportationStrategy: "SMTP_TLS" \ No newline at end of file + transportationStrategy: "SMTP_TLS" + + +dataQualityConfiguration: + severityIncidentClassifier: "org.openmetadata.service.util.incidentSeverityClassifier.LogisticRegressionIncidentSeverityClassifier" diff --git a/openmetadata-spec/src/main/resources/json/schema/configuration/dataQualityConfiguration.json b/openmetadata-spec/src/main/resources/json/schema/configuration/dataQualityConfiguration.json new file mode 100644 index 000000000000..90387739a943 --- /dev/null +++ b/openmetadata-spec/src/main/resources/json/schema/configuration/dataQualityConfiguration.json @@ -0,0 +1,18 @@ +{ + "$id": "https://open-metadata.org/schema/entity/configuration/dataQualityConfiguration.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "DataQualityConfiguration", + "description": "This schema defines the Data Quality Configuration", + "type": "object", + "javaType": "org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration", + "properties": { + "severityIncidentClassifier": { + "description": "Class Name for the severity incident classifier.", + "type": "string" + } + }, + "required": [ + "severityIncidentClassifier" + ], + "additionalProperties": false +} From 047f9f282815184b0b5740ffd588acdf92b7ca14 Mon Sep 17 00:00:00 2001 From: Teddy Crepineau Date: Mon, 12 Feb 2024 12:16:58 +0100 Subject: [PATCH 2/3] feat: added severity classifier tests --- .../dqtests/TestCaseResourceTest.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java index 33febca4c893..a33bdca5aa1c 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java @@ -69,6 +69,7 @@ import org.openmetadata.schema.type.ChangeDescription; import org.openmetadata.schema.type.Column; import org.openmetadata.schema.type.ColumnDataType; +import org.openmetadata.schema.type.TagLabel; import org.openmetadata.schema.type.TaskStatus; import org.openmetadata.service.Entity; import org.openmetadata.service.resources.EntityResourceTest; @@ -77,6 +78,7 @@ import org.openmetadata.service.util.JsonUtils; import org.openmetadata.service.util.ResultList; import org.openmetadata.service.util.TestUtils; +import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface; @TestMethodOrder(MethodOrderer.OrderAnnotation.class) @Slf4j @@ -1565,6 +1567,23 @@ public void unauthorizedTestCaseResolutionFlow(TestInfo test) "Incident with status [Assigned] cannot be moved to [Ack]"); } + @Test + public void testInferSeverity(TestInfo test) { + IncidentSeverityClassifierInterface severityClassifier = IncidentSeverityClassifierInterface.getInstance(); + // TEST_TABLE1 has no tier information, hence severity should be null as the classifier won't be able to infer + Severity severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1); + assertNull(severity); + + List tags = new ArrayList<>(); + tags.add(new TagLabel().withTagFQN("Tier.Tier1").withName("Tier1")); + TEST_TABLE1.setTags(tags); + + // With tier set to Tier1, the severity should be inferred + severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1); + assertNotNull(severity); + + } + public void deleteTestCaseResult(String fqn, Long timestamp, Map authHeaders) throws HttpResponseException { WebTarget target = getCollection().path("/" + fqn + "/testCaseResult/" + timestamp); From e5164149aa0e80661c673838d4f4876b936ae894 Mon Sep 17 00:00:00 2001 From: Teddy Crepineau Date: Mon, 12 Feb 2024 12:17:54 +0100 Subject: [PATCH 3/3] style: ran java linting --- .../TestCaseResolutionStatusRepository.java | 17 +- .../IncidentSeverityClassifierInterface.java | 63 +++--- ...cRegressionIncidentSeverityClassifier.java | 206 +++++++++--------- .../dqtests/TestCaseResourceTest.java | 7 +- 4 files changed, 156 insertions(+), 137 deletions(-) diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java index 1711873b23f9..6eafdd0bf397 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/TestCaseResolutionStatusRepository.java @@ -311,11 +311,22 @@ public void inferIncidentSeverity(TestCaseResolutionStatus incident) { // If the severity is already set, we don't need to infer it return; } - IncidentSeverityClassifierInterface incidentSeverityClassifier = IncidentSeverityClassifierInterface.getInstance(); + IncidentSeverityClassifierInterface incidentSeverityClassifier = + IncidentSeverityClassifierInterface.getInstance(); EntityReference testCaseReference = incident.getTestCaseReference(); - TestCase testCase = Entity.getEntityByName(testCaseReference.getType(), testCaseReference.getFullyQualifiedName(), "", Include.ALL); + TestCase testCase = + Entity.getEntityByName( + testCaseReference.getType(), + testCaseReference.getFullyQualifiedName(), + "", + Include.ALL); MessageParser.EntityLink entityLink = MessageParser.EntityLink.parse(testCase.getEntityLink()); - EntityInterface entity = Entity.getEntityByName(entityLink.getEntityType(), entityLink.getEntityFQN(), "followers,owner,tags,votes", Include.ALL); + EntityInterface entity = + Entity.getEntityByName( + entityLink.getEntityType(), + entityLink.getEntityFQN(), + "followers,owner,tags,votes", + Include.ALL); Severity severity = incidentSeverityClassifier.classifyIncidentSeverity(entity); incident.setSeverity(severity); } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java index 08d9f6923324..9b1b7845d363 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/IncidentSeverityClassifierInterface.java @@ -1,44 +1,51 @@ package org.openmetadata.service.util.incidentSeverityClassifier; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; import lombok.extern.slf4j.Slf4j; import org.openmetadata.schema.EntityInterface; import org.openmetadata.schema.api.configuration.dataQuality.DataQualityConfiguration; import org.openmetadata.schema.tests.type.Severity; -import org.openmetadata.service.Entity; - -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; @Slf4j public abstract class IncidentSeverityClassifierInterface { - protected static IncidentSeverityClassifierInterface instance; + protected static IncidentSeverityClassifierInterface instance; - public static IncidentSeverityClassifierInterface getInstance() { - if (instance == null) { - LOG.info("Incident severity classifier instance is null. Default to LogisticRegressionClassifier"); - instance = new LogisticRegressionIncidentSeverityClassifier(); - } - return instance; + public static IncidentSeverityClassifierInterface getInstance() { + if (instance == null) { + LOG.info( + "Incident severity classifier instance is null. Default to LogisticRegressionClassifier"); + instance = new LogisticRegressionIncidentSeverityClassifier(); } + return instance; + } - public static void createInstance(DataQualityConfiguration dataQualityConfiguration) { - instance = getClassifierClass(dataQualityConfiguration.getSeverityIncidentClassifier()); - } + public static void createInstance(DataQualityConfiguration dataQualityConfiguration) { + instance = getClassifierClass(dataQualityConfiguration.getSeverityIncidentClassifier()); + } - private static IncidentSeverityClassifierInterface getClassifierClass(String severityClassifierClassString) { - IncidentSeverityClassifierInterface incidentSeverityClassifier; - try { - Class severityClassifierClass = Class.forName(severityClassifierClassString); - Constructor severityClassifierConstructor = severityClassifierClass.getConstructor(); - incidentSeverityClassifier = (IncidentSeverityClassifierInterface) severityClassifierConstructor.newInstance(); - } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InstantiationException | - InvocationTargetException e) { - LOG.error("Error occurred while initializing the incident severity classifier. Default to LogisticRegressionClassifier", e); - // If we encounter an error while initializing the incident severity classifier, we default to the logistic regression classifier - incidentSeverityClassifier = new LogisticRegressionIncidentSeverityClassifier(); - } - return incidentSeverityClassifier; + private static IncidentSeverityClassifierInterface getClassifierClass( + String severityClassifierClassString) { + IncidentSeverityClassifierInterface incidentSeverityClassifier; + try { + Class severityClassifierClass = Class.forName(severityClassifierClassString); + Constructor severityClassifierConstructor = severityClassifierClass.getConstructor(); + incidentSeverityClassifier = + (IncidentSeverityClassifierInterface) severityClassifierConstructor.newInstance(); + } catch (ClassNotFoundException + | NoSuchMethodException + | IllegalAccessException + | InstantiationException + | InvocationTargetException e) { + LOG.error( + "Error occurred while initializing the incident severity classifier. Default to LogisticRegressionClassifier", + e); + // If we encounter an error while initializing the incident severity classifier, we default to + // the logistic regression classifier + incidentSeverityClassifier = new LogisticRegressionIncidentSeverityClassifier(); } + return incidentSeverityClassifier; + } - public abstract Severity classifyIncidentSeverity(EntityInterface entity); + public abstract Severity classifyIncidentSeverity(EntityInterface entity); } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java index b06412c4a03b..89177ed76b7b 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/util/incidentSeverityClassifier/LogisticRegressionIncidentSeverityClassifier.java @@ -1,127 +1,127 @@ package org.openmetadata.service.util.incidentSeverityClassifier; +import java.util.Arrays; +import java.util.List; import lombok.extern.slf4j.Slf4j; import org.openmetadata.schema.EntityInterface; import org.openmetadata.schema.tests.type.Severity; import org.openmetadata.schema.type.TagLabel; -import java.util.Arrays; -import java.util.List; - @Slf4j -public class LogisticRegressionIncidentSeverityClassifier extends IncidentSeverityClassifierInterface { - // coef. matrix represents the weights of the logistic regression model. It has shape - // (n_feature, n_class) where n_features are respectively: - // - row 0: 'Tier' (1, 2, 3, 4, 5) for an asset - // - row 1: 'HasOwner' 1 if the asset has an owner, 0 otherwise - // - row 2: 'Followers' number of followers of the asset - // - row 3: 'Votes' number of + votes of the asset. - // Coef. matrix was generated using synthetic data. - static final double[][] coefMatrix = { - new double[]{-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022}, - new double[]{65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332}, - new double[]{0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822}, - new double[]{0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118}, - }; +public class LogisticRegressionIncidentSeverityClassifier + extends IncidentSeverityClassifierInterface { + // coef. matrix represents the weights of the logistic regression model. It has shape + // (n_feature, n_class) where n_features are respectively: + // - row 0: 'Tier' (1, 2, 3, 4, 5) for an asset + // - row 1: 'HasOwner' 1 if the asset has an owner, 0 otherwise + // - row 2: 'Followers' number of followers of the asset + // - row 3: 'Votes' number of + votes of the asset. + // Coef. matrix was generated using synthetic data. + static final double[][] coefMatrix = { + new double[] {-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022}, + new double[] {65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332}, + new double[] {0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822}, + new double[] {0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118}, + }; - @Override - public Severity classifyIncidentSeverity(EntityInterface entity) { - double[] vectorX = getVectorX(entity); - if (vectorX.length == 0) { - return null; - } - try { - double[] vectorZ = dotProduct(vectorX); - double[] softmaxVector = softmax(vectorZ); - int predictedClass = argmax(softmaxVector); - switch (predictedClass) { - case 0: - return Severity.Severity1; - case 1: - return Severity.Severity2; - case 2: - return Severity.Severity3; - case 3: - return Severity.Severity4; - case 4: - return Severity.Severity5; - } - } catch (Exception e) { - LOG.error("Error occurred while classifying incident severity", e); - } - return null; + @Override + public Severity classifyIncidentSeverity(EntityInterface entity) { + double[] vectorX = getVectorX(entity); + if (vectorX.length == 0) { + return null; + } + try { + double[] vectorZ = dotProduct(vectorX); + double[] softmaxVector = softmax(vectorZ); + int predictedClass = argmax(softmaxVector); + switch (predictedClass) { + case 0: + return Severity.Severity1; + case 1: + return Severity.Severity2; + case 2: + return Severity.Severity3; + case 3: + return Severity.Severity4; + case 4: + return Severity.Severity5; + } + } catch (Exception e) { + LOG.error("Error occurred while classifying incident severity", e); } + return null; + } - private double[] dotProduct(double[] vectorX) { - // compute the dot product of the input vector and the coef. matrix - double[] result = new double[coefMatrix[0].length]; - for (int i = 0; i < coefMatrix.length; i++) { - int sum = 0; - for (int j = 0; j < vectorX.length; j++) { - sum += vectorX[j] * coefMatrix[j][i]; - } - result[i] = sum; - } - return result; + private double[] dotProduct(double[] vectorX) { + // compute the dot product of the input vector and the coef. matrix + double[] result = new double[coefMatrix[0].length]; + for (int i = 0; i < coefMatrix.length; i++) { + int sum = 0; + for (int j = 0; j < vectorX.length; j++) { + sum += vectorX[j] * coefMatrix[j][i]; + } + result[i] = sum; } + return result; + } - private double[] softmax(double[] vectorZ) { - // compute the softmax of the z vector - double expSum = Arrays.stream(vectorZ).map(Math::exp).sum(); - double[] softmax = new double[vectorZ.length]; - for (int i = 0; i < vectorZ.length; i++) { - softmax[i] = Math.exp(vectorZ[i]) / expSum; - } - return softmax; + private double[] softmax(double[] vectorZ) { + // compute the softmax of the z vector + double expSum = Arrays.stream(vectorZ).map(Math::exp).sum(); + double[] softmax = new double[vectorZ.length]; + for (int i = 0; i < vectorZ.length; i++) { + softmax[i] = Math.exp(vectorZ[i]) / expSum; } + return softmax; + } - private int argmax(double[] softmaxVector) { - // return the index of the max value in the softmax vector - // (i.e. the predicted class) - int maxIndex = 0; - double argmax = 0; + private int argmax(double[] softmaxVector) { + // return the index of the max value in the softmax vector + // (i.e. the predicted class) + int maxIndex = 0; + double argmax = 0; - for (int i = 0 ; i < softmaxVector.length; i++) { - if (softmaxVector[i] > argmax) { - argmax = softmaxVector[i]; - maxIndex = i; - } - } - return maxIndex; + for (int i = 0; i < softmaxVector.length; i++) { + if (softmaxVector[i] > argmax) { + argmax = softmaxVector[i]; + maxIndex = i; + } } + return maxIndex; + } - private double[] getVectorX(EntityInterface entity) { - // get the input vector for the logistic regression model - double hasOwner = entity.getOwner() != null ? 1 : 0; - double followers = entity.getFollowers() != null ? entity.getFollowers().size() : 0; - double votes = entity.getVotes() != null ? entity.getVotes().getUpVotes() : 0; - double tier = entity.getTags() != null ? getTier(entity.getTags()) : 0; - if (tier == 0) { - // if we don't have a tier set we can't run the classifier - return new double[]{}; - } - return new double[]{tier, hasOwner, followers, votes}; + private double[] getVectorX(EntityInterface entity) { + // get the input vector for the logistic regression model + double hasOwner = entity.getOwner() != null ? 1 : 0; + double followers = entity.getFollowers() != null ? entity.getFollowers().size() : 0; + double votes = entity.getVotes() != null ? entity.getVotes().getUpVotes() : 0; + double tier = entity.getTags() != null ? getTier(entity.getTags()) : 0; + if (tier == 0) { + // if we don't have a tier set we can't run the classifier + return new double[] {}; } + return new double[] {tier, hasOwner, followers, votes}; + } - private double getTier(List tags) { - // get the tier of the asset + private double getTier(List tags) { + // get the tier of the asset - for (TagLabel tag : tags) { - if (tag.getName().contains("Tier")) { - switch (tag.getName()) { - case "Tier1": - return 1; - case "Tier2": - return 2; - case "Tier3": - return 3; - case "Tier4": - return 4; - case "Tier5": - return 5; - } - } + for (TagLabel tag : tags) { + if (tag.getName().contains("Tier")) { + switch (tag.getName()) { + case "Tier1": + return 1; + case "Tier2": + return 2; + case "Tier3": + return 3; + case "Tier4": + return 4; + case "Tier5": + return 5; } - return 0; + } } + return 0; + } } diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java index a33bdca5aa1c..f1e8d30e2597 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/resources/dqtests/TestCaseResourceTest.java @@ -1569,8 +1569,10 @@ public void unauthorizedTestCaseResolutionFlow(TestInfo test) @Test public void testInferSeverity(TestInfo test) { - IncidentSeverityClassifierInterface severityClassifier = IncidentSeverityClassifierInterface.getInstance(); - // TEST_TABLE1 has no tier information, hence severity should be null as the classifier won't be able to infer + IncidentSeverityClassifierInterface severityClassifier = + IncidentSeverityClassifierInterface.getInstance(); + // TEST_TABLE1 has no tier information, hence severity should be null as the classifier won't be + // able to infer Severity severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1); assertNull(severity); @@ -1581,7 +1583,6 @@ public void testInferSeverity(TestInfo test) { // With tier set to Tier1, the severity should be inferred severity = severityClassifier.classifyIncidentSeverity(TEST_TABLE1); assertNotNull(severity); - } public void deleteTestCaseResult(String fqn, Long timestamp, Map authHeaders)