Skip to content

Commit

Permalink
refactor(Proofs): Add test to make sure we don't run price_tag detect…
Browse files Browse the repository at this point in the history
…ion on RECEIPT proofs. ref #683
  • Loading branch information
raphodn committed Jan 19, 2025
1 parent 4c10f86 commit c388e5b
Showing 1 changed file with 82 additions and 9 deletions.
91 changes: 82 additions & 9 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,11 @@ def test_fetch_and_save_ocr_data_invalid_extension(self):


class MLModelTest(TestCase):
@classmethod
def setUpTestData(cls):
# Create a white blank image with Pillow
cls.image = Image.new("RGB", (100, 100), "white")

def test_run_and_save_proof_prediction_proof_file_not_found(self):
proof = ProofFactory()
# check that we emit an error log
Expand All @@ -390,9 +395,78 @@ def test_run_and_save_proof_prediction_proof_file_not_found(self):
],
)

def test_run_and_save_proof_prediction_proof(self):
# Create a white blank image with Pillow
image = Image.new("RGB", (100, 100), "white")
def test_run_and_save_proof_prediction_for_receipt_proof(self):
predict_proof_type_response = [
("RECEIPT", 0.9786477088928223),
("PRICE_TAG", 0.021345501765608788),
]

# We save the image to a temporary file
with tempfile.TemporaryDirectory() as tmpdirname:
NEW_IMAGE_DIR = Path(tmpdirname)
file_path = NEW_IMAGE_DIR / "1.jpg"
self.image.save(file_path)

# change temporarily settings.IMAGE_DIR
with self.settings(IMAGE_DIR=NEW_IMAGE_DIR):
proof = ProofFactory(
file_path=file_path, type=proof_constants.TYPE_RECEIPT
)

# Patch predict_proof_type to return a fixed response
with (
unittest.mock.patch(
"open_prices.proofs.ml.predict_proof_type",
return_value=predict_proof_type_response,
) as mock_predict_proof_type,
unittest.mock.patch(
"open_prices.proofs.ml.detect_price_tags",
return_value=None,
) as mock_detect_price_tags,
):
run_and_save_proof_prediction(proof, run_price_tag_extraction=False)
mock_predict_proof_type.assert_called_once()
mock_detect_price_tags.assert_not_called()

proof_type_prediction = proof.predictions.filter(
type=proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE
).first()
self.assertIsNotNone(proof_type_prediction)
self.assertEqual(
proof_type_prediction.type,
proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
)

self.assertEqual(
proof_type_prediction.model_name, "price_proof_classification"
)
self.assertEqual(
proof_type_prediction.model_version,
"price_proof_classification-1.0",
)
self.assertEqual(proof_type_prediction.value, "RECEIPT")
self.assertEqual(
proof_type_prediction.max_confidence, 0.9786477088928223
)
self.assertEqual(
proof_type_prediction.data,
{
"prediction": [
{"label": "RECEIPT", "score": 0.9786477088928223},
{"label": "PRICE_TAG", "score": 0.021345501765608788},
]
},
)

price_tag_prediction = proof.predictions.filter(
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE
).first()
self.assertIsNone(price_tag_prediction)

proof_type_prediction.delete()
proof.delete()

def test_run_and_save_proof_prediction_for_price_tag_proof(self):
predict_proof_type_response = [
("SHELF", 0.9786477088928223),
("PRICE_TAG", 0.021345501765608788),
Expand All @@ -409,7 +483,7 @@ def test_run_and_save_proof_prediction_proof(self):
with tempfile.TemporaryDirectory() as tmpdirname:
NEW_IMAGE_DIR = Path(tmpdirname)
file_path = NEW_IMAGE_DIR / "1.jpg"
image.save(file_path)
self.image.save(file_path)

# change temporarily settings.IMAGE_DIR
with self.settings(IMAGE_DIR=NEW_IMAGE_DIR):
Expand Down Expand Up @@ -488,20 +562,17 @@ def test_run_and_save_proof_prediction_proof(self):
proof.delete()

def test_run_and_save_proof_type_prediction_already_exists(self):
image = Image.new("RGB", (100, 100), "white")

proof = ProofFactory()
ProofPredictionFactory(
proof=proof,
type=proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
model_name=PROOF_CLASSIFICATION_MODEL_NAME,
model_version=PROOF_CLASSIFICATION_MODEL_VERSION,
)
result = run_and_save_proof_type_prediction(image, proof)
result = run_and_save_proof_type_prediction(self.image, proof)
self.assertIsNone(result)

def test_run_and_save_price_tag_detection_already_exists(self):
image = Image.new("RGB", (100, 100), "white")
proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG)
ProofPredictionFactory(
proof=proof,
Expand All @@ -523,7 +594,9 @@ def test_run_and_save_price_tag_detection_already_exists(self):
]
},
)
result = run_and_save_price_tag_detection(image, proof, run_extraction=False)
result = run_and_save_price_tag_detection(
self.image, proof, run_extraction=False
)
self.assertIsNone(result)
price_tags = PriceTag.objects.filter(proof=proof).all()
self.assertEqual(len(price_tags), 2)
Expand Down

0 comments on commit c388e5b

Please sign in to comment.