Skip to content

Commit

Permalink
feat: save Gemini prediction in price_tag_predictions table
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Dec 17, 2024
1 parent a0c4741 commit bac7c3a
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 155 deletions.
4 changes: 2 additions & 2 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from open_prices.api.utils import get_source_from_request
from open_prices.common.authentication import CustomAuthentication
from open_prices.common.constants import PriceTagStatus
from open_prices.common.gemini import handle_bulk_labels
from open_prices.proofs.ml import extract_from_price_tags
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import store_file

Expand Down Expand Up @@ -125,7 +125,7 @@ def upload(self, request: Request) -> Response:
def process_with_gemini(self, request: Request) -> Response:
files = request.FILES.getlist("files")
sample_files = [PIL.Image.open(file.file) for file in files]
res = handle_bulk_labels(sample_files)
res = extract_from_price_tags(sample_files)
return Response(res, status=status.HTTP_200_OK)


Expand Down
141 changes: 0 additions & 141 deletions open_prices/common/gemini.py
Original file line number Diff line number Diff line change
@@ -1,141 +0,0 @@
import enum
import json

import google.generativeai as genai
import typing_extensions as typing
from django.conf import settings

genai.configure(api_key=settings.GOOGLE_GEMINI_API_KEY)
model = genai.GenerativeModel(model_name="gemini-1.5-flash")


# TODO: what about orther categories ?
class Products(enum.Enum):
OTHER = "other"
APPLES = "en:apples"
APRICOTS = "en:apricots"
ARTICHOKES = "en:artichokes"
ASPARAGUS = "en:asparagus"
AUBERGINES = "en:aubergines"
AVOCADOS = "en:avocados"
BANANAS = "en:bananas"
BEET = "en:beet"
BERRIES = "en:berries"
BLACKBERRIES = "en:blackberries"
BLUEBERRIES = "en:blueberries"
BOK_CHOY = "en:bok-choy"
BROCCOLI = "en:broccoli"
CABBAGES = "en:cabbages"
CARROTS = "en:carrots"
CAULIFLOWERS = "en:cauliflowers"
CELERY = "en:celery"
CELERY_STALK = "en:celery-stalk"
CEP_MUSHROOMS = "en:cep-mushrooms"
CHANTERELLES = "en:chanterelles"
CHERRIES = "en:cherries"
CHERRY_TOMATOES = "en:cherry-tomatoes"
CHICKPEAS = "en:chickpeas"
CHIVES = "en:chives"
CLEMENTINES = "en:clementines"
COCONUTS = "en:coconuts"
CRANBERRIES = "en:cranberries"
CUCUMBERS = "en:cucumbers"
DATES = "en:dates"
ENDIVES = "en:endives"
FIGS = "en:figs"
GARLIC = "en:garlic"
GINGER = "en:ginger"
GRAPEFRUITS = "en:grapefruits"
GRAPES = "en:grapes"
GREEN_BEANS = "en:green-beans"
KIWIS = "en:kiwis"
KAKIS = "en:kakis"
LEEKS = "en:leeks"
LEMONS = "en:lemons"
LETTUCES = "en:lettuces"
LIMES = "en:limes"
LYCHEES = "en:lychees"
MANDARIN_ORANGES = "en:mandarin-oranges"
MANGOES = "en:mangoes"
MELONS = "en:melons"
MUSHROOMS = "en:mushrooms"
NECTARINES = "en:nectarines"
ONIONS = "en:onions"
ORANGES = "en:oranges"
PAPAYAS = "en:papayas"
PASSION_FRUITS = "en:passion-fruits"
PEACHES = "en:peaches"
PEARS = "en:pears"
PEAS = "en:peas"
PEPPERS = "en:peppers"
PINEAPPLE = "en:pineapple"
PLUMS = "en:plums"
POMEGRANATES = "en:pomegranates"
POMELOS = "en:pomelos"
POTATOES = "en:potatoes"
PUMPKINS = "en:pumpkins"
RADISHES = "en:radishes"
RASPBERRIES = "en:raspberries"
RHUBARBS = "en:rhubarbs"
SCALLIONS = "en:scallions"
SHALLOTS = "en:shallots"
SPINACHS = "en:spinachs"
SPROUTS = "en:sprouts"
STRAWBERRIES = "en:strawberries"
TOMATOES = "en:tomatoes"
TURNIP = "en:turnip"
WATERMELONS = "en:watermelons"
WALNUTS = "en:walnuts"
ZUCCHINI = "en:zucchini"


# TODO: what about other origins ?
class Origin(enum.Enum):
FRANCE = "en:france"
ITALY = "en:italy"
SPAIN = "en:spain"
POLAND = "en:poland"
CHINA = "en:china"
BELGIUM = "en:belgium"
MOROCCO = "en:morocco"
PERU = "en:peru"
PORTUGAL = "en:portugal"
MEXICO = "en:mexico"
OTHER = "other"
UNKNOWN = "unknown"


class Unit(enum.Enum):
KILOGRAM = "KILOGRAM"
UNIT = "UNIT"


class Label(typing.TypedDict):
product: Products
price: float
origin: Origin
unit: Unit
organic: bool
barcode: str


class Labels(typing.TypedDict):
labels: list[Label]


def handle_bulk_labels(images):
response = model.generate_content(
[
"Here are "
+ str(len(images))
+ " pictures containing a label. For each picture of a label, please extract all the following attributes: the product category matching product name, the origin category matching country of origin, the price, is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode. I expect a list of "
+ str(len(images))
+ " labels in your reply, no more, no less. If you cannot decode an attribute, set it to an empty string"
]
+ images,
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Labels
),
)
vals = json.loads(response.text)
return vals
8 changes: 6 additions & 2 deletions open_prices/proofs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
PROOF_PREDICTION_OBJECT_DETECTION_TYPE = "OBJECT_DETECTION"
PROOF_PREDICTION_CLASSIFICATION_TYPE = "CLASSIFICATION"
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE = "RECEIPT_EXTRACTION"
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION"
PROOF_PREDICTION_LIST = [
PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
PROOF_PREDICTION_CLASSIFICATION_TYPE,
PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE,
PROOF_PREDICTION_PRICE_TAG_EXTRACTION_TYPE,
]

PROOF_TYPE_CHOICES = [(key, key) for key in PROOF_PREDICTION_LIST]

PRICE_TAG_EXTRACTION_TYPE = "PRICE_TAG_EXTRACTION"

PRICE_TAG_PREDICTION_TYPE_CHOICES = [
(PRICE_TAG_EXTRACTION_TYPE, PRICE_TAG_EXTRACTION_TYPE)
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Generated by Django 5.1.4 on 2024-12-17 14:01

import django.db.models.deletion
import django.utils.timezone
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("proofs", "0007_pricetag"),
]

operations = [
migrations.AlterField(
model_name="proofprediction",
name="type",
field=models.CharField(
choices=[
("OBJECT_DETECTION", "OBJECT_DETECTION"),
("CLASSIFICATION", "CLASSIFICATION"),
("RECEIPT_EXTRACTION", "RECEIPT_EXTRACTION"),
],
max_length=20,
verbose_name="The type of the prediction",
),
),
migrations.CreateModel(
name="PriceTagPrediction",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"type",
models.CharField(
choices=[("PRICE_TAG_EXTRACTION", "PRICE_TAG_EXTRACTION")],
help_text="The type of the prediction",
max_length=20,
),
),
(
"model_name",
models.CharField(
help_text="The name of the model that generated the prediction",
max_length=30,
),
),
(
"model_version",
models.CharField(
help_text="The specific version of the model that generated the prediction",
max_length=30,
),
),
(
"created",
models.DateTimeField(
default=django.utils.timezone.now,
help_text="When the prediction was created in DB",
),
),
(
"data",
models.JSONField(
default=dict,
help_text="a dict representing the data of the prediction. This field is model-specific.",
),
),
(
"price_tag",
models.ForeignKey(
help_text="The price tag this prediction belongs to",
on_delete=django.db.models.deletion.CASCADE,
related_name="predictions",
to="proofs.pricetag",
),
),
],
options={
"verbose_name": "Price Tag Prediction",
"verbose_name_plural": "Price Tag Predictions",
"db_table": "price_tag_predictions",
},
),
]
Loading

0 comments on commit bac7c3a

Please sign in to comment.