Skip to content

Commit

Permalink
Topic Classifier Changes (#585)
Browse files Browse the repository at this point in the history
* Topic Classifier Changes

* Added test case changes

* Updated the model version

* Applied code formatting
  • Loading branch information
gr8nishan authored Oct 22, 2024
1 parent 997dba6 commit 79f1f9c
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 14 deletions.
4 changes: 4 additions & 0 deletions pebblo/app/pebblo-ui/src/constants/keywordMapping.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ export const KEYWORD_MAPPING = {
"azure-key-id": "Azure Key ID",
"azure-client-secret": "Azure Client Secret",
"google-api-key": "Google API Key",
"harmful": "Harmful",
"medical": "Medical" ,
"financial": "Financial",
"corporate-documents": "Corporate Documents"
};
4 changes: 4 additions & 0 deletions pebblo/reports/enums/keyword_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@
"azure-key-id": "Azure Key ID",
"azure-client-secret": "Azure Client Secret",
"google-api-key": "Google API Key",
"harmful": "Harmful",
"medical": "Medical",
"financial": "Financial",
"corporate-documents": "Corporate Documents",
}
7 changes: 4 additions & 3 deletions pebblo/topic_classifier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
TOPICS_TO_EXCLUDE = ["NORMAL_TEXT"]

# Model paths
TOKENIZER_PATH = "daxa-ai/pebblo-classifier"
CLASSIFIER_PATH = "daxa-ai/pebblo-classifier"
TOKENIZER_PATH = "daxa-ai/pebblo-classifier-v2"
CLASSIFIER_PATH = "daxa-ai/pebblo-classifier-v2"

# Specific model version to use. Revision can be any identifier allowed by
# git e.g. branch name, a tag name, or a commit id
MODEL_REVISION = "5fbbe83dee7ef72c61a8173c4ccf27b19788fc2e" # Pebblo classifier V8
# https://huggingface.co/daxa-ai/pebblo-classifier-v2
MODEL_REVISION = "a9a3816784cd6f5feb5a515e9536de78d64d6d49" # Pebblo classifier v2
4 changes: 4 additions & 0 deletions pebblo/topic_classifier/enums/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@
"INTERNAL_PRODUCT_ROADMAP_AGREEMENT": "internal-product-roadmap-agreement",
"SEXUAL_CONTENT": "sexual-content",
"SEXUAL_INCIDENT_REPORT": "sexual-incident-report",
"HARMFUL": "harmful",
"MEDICAL": "medical",
"FINANCIAL": "financial",
"CORPORATE_DOCUMENTS": "corporate-documents",
}
55 changes: 46 additions & 9 deletions tests/app/service/test_doc_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@
def test_get_classifier_response():
    """Verify ``LoaderHelper._get_classifier_response`` on the sample document.

    The helper is built from the module-level ``app_details`` and ``data``
    fixtures, and the dumped response must equal the expected payload:
    three entities (SSN, credit card, AWS access key) and one topic.
    """
    helper = LoaderHelper(app_details, data=data, load_id=data.get("load_id"))
    response = helper._get_classifier_response(classifier_response_input_doc)

    # NOTE(review): "location" values look like "<start>_<end>" character
    # offsets into the "data" text below -- confirm against the classifier.
    entity_details = {
        "us-ssn": [
            {
                "location": "16_27",
                "confidence_score": "HIGH",
                "entity_group": "pii-identification",
            }
        ],
        "credit-card-number": [
            {
                "location": "102_117",
                "confidence_score": "HIGH",
                "entity_group": "pii-financial",
            }
        ],
        "aws-access-key": [
            {
                "location": "134_154",
                "confidence_score": "HIGH",
                "entity_group": "secrets_and_tokens",
            }
        ],
    }
    # One count per detected entity type, in the same order as the details.
    entity_counts = {name: 1 for name in entity_details}

    expected_output = {
        "data": "Sachin's SSN is 222-85-4836. His passport ID is 5484880UA. His American express credit card number is\n371449635398431. AWS Access Key AKIAQIPT4PDORIRTV6PH. client-secret is de1d4a2d-d9fa-44f1-84bb-4f73c004afda\n",
        "entityCount": 3,
        "entities": entity_counts,
        "entityDetails": entity_details,
        "topicCount": 1,
        "topics": {"financial": 1},
        "topicDetails": {"financial": [{"confidence_score": "MEDIUM"}]},
    }
    assert response.model_dump() == expected_output


Expand All @@ -126,12 +157,15 @@ def test_get_classifier_response_classifier_mode_topic():
output = loader_helper_classifier_mode_topic._get_classifier_response(
classifier_response_input_doc
)

expected_output.update(
{
"data": "Sachin's SSN is 222-85-4836. His passport ID is 5484880UA. His American express credit card number is\n371449635398431. AWS Access Key AKIAQIPT4PDORIRTV6PH. client-secret is de1d4a2d-d9fa-44f1-84bb-4f73c004afda\n",
"entityCount": 0,
"entities": {},
"entityDetails": {},
"topicCount": 1,
"topics": {"financial": 1},
"topicDetails": {"financial": [{"confidence_score": "MEDIUM"}]},
}
)
assert output.model_dump() == expected_output
Expand All @@ -148,31 +182,34 @@ def test_get_classifier_response_anonymize_true():
expected_output.update(
{
"data": "Sachin's SSN is <US_SSN>. His passport ID is 5484880UA. His American express credit card number is\n<CREDIT_CARD>. AWS Access Key <AWS_ACCESS_KEY>. client-secret is de1d4a2d-d9fa-44f1-84bb-4f73c004afda\n",
"entities": {"aws-access-key": 1, "credit-card-number": 1, "us-ssn": 1},
"entityCount": 3,
"entities": {"us-ssn": 1, "credit-card-number": 1, "aws-access-key": 1},
"entityDetails": {
"aws-access-key": [
"us-ssn": [
{
"location": "16_30",
"confidence_score": "HIGH",
"entity_group": "secrets_and_tokens",
"location": "141_163",
"entity_group": "pii-identification",
}
],
"credit-card-number": [
{
"location": "105_124",
"confidence_score": "HIGH",
"entity_group": "pii-financial",
"location": "105_124",
}
],
"us-ssn": [
"aws-access-key": [
{
"location": "141_163",
"confidence_score": "HIGH",
"entity_group": "pii-identification",
"location": "16_30",
"entity_group": "secrets_and_tokens",
}
],
},
"topicCount": 1,
"topics": {"financial": 1},
"topicDetails": {"financial": [{"confidence_score": "MEDIUM"}]},
}
)
assert output.model_dump() == expected_output
35 changes: 35 additions & 0 deletions tests/app/service/test_loader_doc_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@ def test_get_classifier_response(app_loader_helper):
app_loader_helper.classifier_mode = ClassificationMode.ALL.value
app_loader_helper.anonymize_snippets = False
output = app_loader_helper._get_doc_classification(classifier_response_input_doc)
expected_output = {
"data": "Sachin's SSN is 222-85-4836. His passport ID is 5484880UA. His American express credit card number is\n371449635398431. AWS Access Key AKIAQIPT4PDORIRTV6PH. client-secret is de1d4a2d-d9fa-44f1-84bb-4f73c004afda\n",
"entityCount": 3,
"entities": {"us-ssn": 1, "credit-card-number": 1, "aws-access-key": 1},
"entityDetails": {
"us-ssn": [
{
"location": "16_27",
"confidence_score": "HIGH",
"entity_group": "pii-identification",
}
],
"credit-card-number": [
{
"location": "102_117",
"confidence_score": "HIGH",
"entity_group": "pii-financial",
}
],
"aws-access-key": [
{
"location": "134_154",
"confidence_score": "HIGH",
"entity_group": "secrets_and_tokens",
}
],
},
"topicCount": 1,
"topics": {"financial": 1},
"topicDetails": {"financial": [{"confidence_score": "MEDIUM"}]},
}
assert output.model_dump() == expected_output


Expand All @@ -71,9 +102,13 @@ def test_get_classifier_response_classifier_mode_topic(app_loader_helper):
output = app_loader_helper._get_doc_classification(classifier_response_input_doc)
expected_output.update(
{
"data": "Sachin's SSN is 222-85-4836. His passport ID is 5484880UA. His American express credit card number is\n371449635398431. AWS Access Key AKIAQIPT4PDORIRTV6PH. client-secret is de1d4a2d-d9fa-44f1-84bb-4f73c004afda\n",
"entityCount": 0,
"entities": {},
"entityDetails": {},
"topicCount": 1,
"topics": {"financial": 1},
"topicDetails": {"financial": [{"confidence_score": "MEDIUM"}]},
}
)
assert output.model_dump() == expected_output
Expand Down
3 changes: 1 addition & 2 deletions tests/app/test_prompt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def test_app_prompt_success(mock_write_json_to_file):
"entities": {},
"topics": {},
}

assert response.json()["retrieval_data"]["response"] == {
"entities": {"us-ssn": 1},
"topics": {},
"topics": {"medical": 1},
}


Expand Down

0 comments on commit 79f1f9c

Please sign in to comment.