From 6d62da6fe839b08c81d13871d8bfd15d63f27f85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Mon, 10 Apr 2023 17:22:01 +0700 Subject: [PATCH] feat: implement real multi-platform support (OFF, OBF,...) - delete server_domain in image, product_insight and prediction tables - add server_type field to image and prediction tables - use ProductIdentifier (barcode + server_type) instead of barcode in codebase See https://github.com/openfoodfacts/robotoff/issues/894 --- .env | 4 +- .github/workflows/container-deploy.yml | 8 +- .../interactions-product-opener.md | 2 - doc/how-to-guides/test-and-debug.md | 7 +- doc/references/api.yml | 79 ++++-- docker-compose.yml | 4 +- robotoff/app/api.py | 194 +++++++------ robotoff/app/core.py | 50 ++-- robotoff/app/events.py | 2 + robotoff/app/schema.py | 8 +- robotoff/brands.py | 10 +- robotoff/cli/insights.py | 13 +- robotoff/cli/logos.py | 16 +- robotoff/cli/main.py | 129 +++++---- robotoff/images.py | 44 ++- robotoff/insights/annotate.py | 65 +++-- robotoff/insights/extraction.py | 20 +- robotoff/insights/importer.py | 147 +++++----- robotoff/insights/question.py | 58 ++-- robotoff/logos.py | 38 ++- robotoff/metrics.py | 23 +- robotoff/models.py | 42 ++- robotoff/mongo.py | 11 - robotoff/off.py | 256 ++++++++---------- robotoff/prediction/category/__init__.py | 5 +- robotoff/prediction/category/matcher.py | 3 +- .../category/neural/category_classifier.py | 15 +- .../keras_category_classifier_3_0/__init__.py | 35 ++- robotoff/prediction/ocr/core.py | 15 +- robotoff/products.py | 58 ++-- robotoff/scheduler/__init__.py | 46 ++-- robotoff/scheduler/latent.py | 14 +- robotoff/settings.py | 92 ++++--- robotoff/slack.py | 67 +++-- robotoff/spellcheck/__init__.py | 3 +- robotoff/types.py | 55 +++- robotoff/workers/tasks/__init__.py | 23 +- robotoff/workers/tasks/import_image.py | 118 ++++---- robotoff/workers/tasks/product_updated.py | 50 ++-- scripts/insert_image_predictions.py | 20 +- scripts/insert_images.py | 22 +- scripts/remove_duplicates.py | 10 +- .../insights/test_category_import.py | 27 +- .../insights/test_process_insights.py | 15 +- tests/integration/models_utils.py | 12 +- tests/integration/test_api.py | 12 +- tests/integration/test_core_integration.py | 58 ++-- tests/integration/test_import_image.py | 9 +- tests/integration/test_logos.py | 30 +- tests/unit/insights/test_importer.py | 32 +-- tests/unit/insights/test_question.py | 11 +- .../neural/test_category_classifier.py | 5 +- tests/unit/test_logos.py | 9 +- tests/unit/test_models.py | 4 +- tests/unit/test_settings.py | 27 +- tests/unit/test_slack.py | 55 ++-- .../workers/tasks/test_product_updated.py | 28 +- 57 files changed, 1284 insertions(+), 931 deletions(-) delete mode 100644 robotoff/mongo.py diff --git a/.env b/.env index 4ca9b10da7..49f5796b49 100644 --- a/.env +++ b/.env @@ -23,10 +23,10 @@ ROBOTOFF_INSTANCE=dev # Overwrites the Product Opener domain used. 
If empty, the domain will # be inferred from `ROBOTOFF_INSTANCE` -ROBOTOFF_DOMAIN=openfoodfacts.net +ROBOTOFF_TLD=net # if you want to connect to a Product Opener dev instance on localhost, use: -# STATIC_OFF_DOMAIN=http://openfoodfacts.localhost +# STATIC_DOMAIN=http://openfoodfacts.localhost # ROBOTOFF_SCHEME=http # for dev scheme is http # for dev only on localhost diff --git a/.github/workflows/container-deploy.yml b/.github/workflows/container-deploy.yml index 053f8cc6af..0d1395526e 100644 --- a/.github/workflows/container-deploy.yml +++ b/.github/workflows/container-deploy.yml @@ -25,7 +25,7 @@ jobs: echo "SSH_PROXY_HOST=ovh2.openfoodfacts.org" >> $GITHUB_ENV echo "SSH_USERNAME=off" >> $GITHUB_ENV echo "ROBOTOFF_INSTANCE=dev" >> $GITHUB_ENV - echo "ROBOTOFF_DOMAIN=openfoodfacts.net" >> $GITHUB_ENV + echo "ROBOTOFF_TLD=net" >> $GITHUB_ENV echo "MONGO_URI=mongodb://10.1.0.200:27017" >> $GITHUB_ENV echo "INFLUXDB_HOST=10.1.0.200" >> $GITHUB_ENV - name: Set various variable for production deployment @@ -35,7 +35,7 @@ jobs: echo "SSH_PROXY_HOST=ovh2.openfoodfacts.org" >> $GITHUB_ENV echo "SSH_USERNAME=off" >> $GITHUB_ENV echo "ROBOTOFF_INSTANCE=prod" >> $GITHUB_ENV - echo "ROBOTOFF_DOMAIN=openfoodfacts.org" >> $GITHUB_ENV + echo "ROBOTOFF_TLD=org" >> $GITHUB_ENV echo "MONGO_URI=mongodb://213.36.253.195:27017" >> $GITHUB_ENV echo "INFLUXDB_HOST=10.1.0.201" >> $GITHUB_ENV - name: Wait for container build workflow @@ -111,7 +111,7 @@ jobs: # Set app variables echo "ROBOTOFF_INSTANCE=${{ env.ROBOTOFF_INSTANCE }}" >> .env - echo "ROBOTOFF_DOMAIN=${{ env.ROBOTOFF_DOMAIN }}" >> .env + echo "ROBOTOFF_TLD=${{ env.ROBOTOFF_TLD }}" >> .env echo "REDIS_HOST=redis.robotoff_default" >> .env echo "POSTGRES_HOST=postgres.robotoff_default" >> .env echo "POSTGRES_DB=postgres" >> .env @@ -132,7 +132,7 @@ jobs: echo "INFLUXDB_AUTH_TOKEN=${{ secrets.INFLUXDB_AUTH_TOKEN }}" >> .env echo "SLACK_TOKEN=${{ secrets.SLACK_TOKEN }}" >> .env echo "GUNICORN_NUM_WORKERS=8" - echo "EVENTS_API_URL=https://event.${{ env.ROBOTOFF_DOMAIN }}" >> .env + echo "EVENTS_API_URL=https://event.openfoodfacts.${{ env.ROBOTOFF_TLD }}" >> .env # TODO: remove this url when we have a proper server running for this purpose echo "IMAGE_MODERATION_SERVICE_URL=https://amathjourney.com/api/off-annotation/flag-image" diff --git a/doc/explanations/interactions-product-opener.md b/doc/explanations/interactions-product-opener.md index 794ae75b51..d94402518c 100644 --- a/doc/explanations/interactions-product-opener.md +++ b/doc/explanations/interactions-product-opener.md @@ -14,7 +14,6 @@ Product Opener calls `POST /api/v1/webhook/product` whenever a product is update - `barcode`: the barcode of product - `action`: either `updated` or `deleted` -- `server_domain`: the server domain (ex: `api.openfoodfacts.org`) After receiving a `product_update` webhook call, Robotoff does the following [^product_update]: @@ -29,7 +28,6 @@ Product Opener calls `POST /api/v1/images/import` whenever an new image is uploa - `barcode`: the barcode of product - `image_url`: the URL of the image - `ocr_url`: the URL of the OCR result (JSON file) -- `server_domain`: the server domain (ex: `api.openfoodfacts.org`) After receiving a `import_image` webhook call, Robotoff does the following [^image_import]: diff --git a/doc/how-to-guides/test-and-debug.md b/doc/how-to-guides/test-and-debug.md index f905821917..ad9b4631d7 100644 --- a/doc/how-to-guides/test-and-debug.md +++ b/doc/how-to-guides/test-and-debug.md @@ -66,11 +66,8 @@ Write test cases every time you write a new 
feature, to test a feature or to und
 There are even cases where automated tests are your only chance to test your code. For example: when you write code to post notifications on a Slack channel, you can only test them by writing a unit test case.
 
-There are instances when Robotoff tries to connect to MongoDB via Open Food Facts server. For local testing we do not yet provide a standarized approach to add a MongoDB Docker in the same network and configure Robotoff to use it.
-
-In such cases you will have to mock the function which calls MongoDB. Feel free to reuse the existing test cases.
-
-To identify parts of the code where Robotoff connects to MongoDB or to Open Food Facts server (the part you should mock), keep an eye for variables like `server_url`, `server_domain` or `settings.OFF_SERVER_DOMAIN`.
+There are instances when Robotoff tries to connect to MongoDB via the Open Food Facts server. To disable this
+feature (it is disabled by default in local environments), set `DISABLE_PRODUCT_CHECK=1` in your `.env`.
 
 # Debugging guide
 
diff --git a/doc/references/api.yml b/doc/references/api.yml
index b9155f2d2f..0cddb89a61 100644
--- a/doc/references/api.yml
+++ b/doc/references/api.yml
@@ -5,8 +5,6 @@ info:
     Robotoff provides a simple API allowing consumers to fetch predictions
     and annotate them.
     All endpoints must be prefixed with `/api/v1`. The full URL is `https://robotoff.openfoodfacts.org/api/v1/{endpoint}`.
-
-    Robotoff can interact with all Openfoodfacts products: Openfoodfacts, Openbeautyfacts, etc. and all environments (production, development, pro). The `server_domain` field should be used to specify the product/environment: `api.openfoodfacts.org` for OFF-prod, `api.openfoodfacts.net` for OFF-dev, `api.openbeautyfacts.org` for OBF-prod,...
   contact: {}
   version: "1.0"
 servers:
@@ -30,7 +28,7 @@ paths:
             default: 1
             minimum: 1
         - $ref: "#/components/parameters/barcode_path"
-        - $ref: "#/components/parameters/server_domain"
+        - $ref: "#/components/parameters/server_type"
         - $ref: "#/components/parameters/lang"
       responses:
         "200":
@@ -57,7 +55,7 @@ paths:
       parameters:
         - $ref: "#/components/parameters/lang"
         - $ref: "#/components/parameters/count"
-        - $ref: "#/components/parameters/server_domain"
+        - $ref: "#/components/parameters/server_type"
         - $ref: "#/components/parameters/insight_types"
         - $ref: "#/components/parameters/country"
         - $ref: "#/components/parameters/brands"
@@ -110,7 +108,7 @@ paths:
       parameters:
         - $ref: "#/components/parameters/lang"
         - $ref: "#/components/parameters/count"
-        - $ref: "#/components/parameters/server_domain"
+        - $ref: "#/components/parameters/server_type"
         - $ref: "#/components/parameters/insight_types"
         - $ref: "#/components/parameters/country"
         - $ref: "#/components/parameters/brands"
@@ -151,7 +149,7 @@ paths:
       parameters:
         - $ref: "#/components/parameters/lang"
         - $ref: "#/components/parameters/count"
-        - $ref: "#/components/parameters/server_domain"
+        - $ref: "#/components/parameters/server_type"
         - $ref: "#/components/parameters/insight_types"
         - $ref: "#/components/parameters/country"
         - $ref: "#/components/parameters/brands"
@@ -180,7 +178,7 @@ paths:
             type: number
             default: 25
             minimum: 1
-        - $ref: "#/components/parameters/server_domain"
+        - $ref: "#/components/parameters/server_type"
         - $ref: "#/components/parameters/insight_type"
         - $ref: "#/components/parameters/country"
         - $ref: "#/components/parameters/page"
@@ -226,7 +224,7 @@ paths:
         - $ref: "#/components/parameters/insight_type"
         - $ref: "#/components/parameters/country"
         - $ref: "#/components/parameters/value_tag"
-        - 
$ref: "#/components/parameters/server_domain" + - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/count" - $ref: "#/components/parameters/predictor" responses: @@ -249,6 +247,7 @@ paths: summary: Get all insights for a specific product parameters: - $ref: "#/components/parameters/barcode_path" + - $ref: "#/components/parameters/server_type" responses: "200": description: "" @@ -330,6 +329,7 @@ paths: tags: - Insights parameters: + - $ref: "#/components/parameters/server_type" - $ref: "#/components/parameters/value_tag" - $ref: "#/components/parameters/insight_types" - name: barcode @@ -450,12 +450,7 @@ paths: format: uri description: URL of the OCR JSON file associated with the image server_domain: - type: string - description: | - The server domain associated with the image/product. - If the server domain does not match with the server configuration, - the import task will be rejected by Robotoff. - example: "api.openfoodfacts.org" + $ref: "#/components/schemas/ServerDomainParameter" required: - "barcode" - "image_url" @@ -472,12 +467,16 @@ paths: status: type: string description: | - status of the import operation, either `scheduled` if it was - successfully scheduled or rejected if the `server_domain` did - not match Robotoff configured server domain. + status of the import operation, always `scheduled` enum: - - "rejected" - "scheduled" + "400": + description: "HTTP Bad Request error, if the `server_domain` parameter is invalid" + content: + application/json: + schema: + type: object + /images/logos: get: @@ -517,6 +516,7 @@ paths: Search for logos detected using the universal-logo-detector model that meet some criteria (annotation status, annotated, type,...) parameters: + - $ref: "#/components/parameters/server_type" - name: count description: Number of results to return in: query @@ -561,11 +561,6 @@ paths: schema: type: boolean default: false - - name: server_domain - in: query - description: The server domain - schema: - type: string - name: annotated description: The annotation status of the logo. If not provided, both annotated and non-annotated logos are returned @@ -660,6 +655,17 @@ paths: description: The barcode of the product to categorize minLength: 1 example: 0748162621021 + server_type: + type: string + description: | + The server type (=project) to use, such as 'off' (Open Food Facts), 'obf' (Open Beauty Facts),... + Only 'off' is currently supported for category prediction + default: 'off' + enum: + - 'off' + - 'obf' + - 'opff' + - 'opf' deepest_only: type: boolean description: | @@ -875,6 +881,20 @@ components: id: 3cd5aecd-edcc-4237-87d0-6595fc4e53c9 type: label barcode: 9782012805866 + ServerDomainParameter: + description: | + The server domain associated with the image/product. + + If the `server_domain` top level domain does not match the server configuration, + an HTTP 400 error will be raised + type: string + example: "api.openfoodfacts.org" + enum: + - "api.openfoodfacts.org" + - "api.openbeautyfacts.org" + - "api.openproductfacts.org" + - "api.openpetfoodfacts.org" + - "api.pro.openfoodfacts.org" parameters: lang: name: lang @@ -891,13 +911,18 @@ components: type: integer default: 25 minimum: 1 - server_domain: - name: server_domain + server_type: + name: server_type in: query - description: The server domain + description: The server type (=project) to use, such as 'off' (Open Food Facts), 'obf' (Open Beauty Facts),... 
schema: type: string - default: api.openfoodfacts.org + default: 'off' + enum: + - 'off' + - 'obf' + - 'opff' + - 'opf' insight_types: name: insight_types in: query diff --git a/docker-compose.yml b/docker-compose.yml index d792806091..76340400bf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,9 +14,9 @@ x-robotoff-base-env: &robotoff-base-env LOG_LEVEL: ROBOTOFF_INSTANCE: - ROBOTOFF_DOMAIN: + ROBOTOFF_TLD: ROBOTOFF_SCHEME: - STATIC_OFF_DOMAIN: + STATIC_DOMAIN: GUNICORN_NUM_WORKERS: ROBOTOFF_UPDATED_PRODUCT_WAIT: REDIS_HOST: diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 40db778fa1..1b48d6164c 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -3,6 +3,7 @@ import functools import hashlib import io +import re import tempfile import uuid from typing import Literal, Optional @@ -58,7 +59,6 @@ generate_image_path, get_barcode_from_url, get_product, - get_server_type, ) from robotoff.prediction.category import predict_category from robotoff.prediction.object_detection import ObjectDetectionModelRegistry @@ -71,6 +71,8 @@ JSONType, NeuralCategoryClassifierModel, PredictionType, + ProductIdentifier, + ServerType, ) from robotoff.utils import get_image_from_url, get_logger, http_session from robotoff.utils.i18n import TranslationStore @@ -93,6 +95,21 @@ TRANSLATION_STORE.load() +def get_server_type_from_req( + req: falcon.Request, default: ServerType = ServerType.off +) -> ServerType: + """Get `ServerType` value from POST x-www-form-urlencoded or GET requests.""" + server_type_str = req.get_param("server_type") + + if server_type_str is None: + return default + + try: + return ServerType[server_type_str] + except KeyError: + raise falcon.HTTPBadRequest(f"invalid `server_type`: {server_type_str}") + + def _get_skip_voted_on( auth: Optional[OFFAuthentication], device_id: str ) -> SkipVotedOn: @@ -114,12 +131,12 @@ def _get_skip_voted_on( class ProductInsightResource: def on_get(self, req: falcon.Request, resp: falcon.Response, barcode: str): - server_domain: Optional[str] = req.get_param("server_domain") response: JSONType = {} + server_type = get_server_type_from_req(req) insights = [ insight.to_dict() for insight in get_insights( - barcode=barcode, server_domain=server_domain, limit=None + barcode=barcode, server_type=server_type, limit=None ) ] @@ -157,7 +174,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): value_tag: str = req.get_param("value_tag") brands = req.get_param_as_list("brands") or None predictor = req.get_param("predictor") - server_domain: Optional[str] = req.get_param("server_domain") + server_type = get_server_type_from_req(req) if keep_types: # Limit the number of types to prevent slow SQL queries @@ -169,9 +186,9 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): get_insights_ = functools.partial( get_insights, + server_type=server_type, keep_types=keep_types, country=country, - server_domain=server_domain, value_tag=value_tag, brands=brands, annotated=annotated, @@ -201,18 +218,18 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): insight_type: Optional[str] = req.get_param("type") country: Optional[str] = req.get_param("country") value_tag: Optional[str] = req.get_param("value_tag") - server_domain: Optional[str] = req.get_param("server_domain") count: int = req.get_param_as_int("count", min_value=1, default=25) predictor = req.get_param("predictor") + server_type = get_server_type_from_req(req) keep_types = [insight_type] if insight_type else None get_insights_ = functools.partial( 
get_insights, + server_type=server_type, keep_types=keep_types, country=country, value_tag=value_tag, order_by="random", - server_domain=server_domain, predictor=predictor, ) @@ -315,12 +332,13 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): def spellcheck(self, req: falcon.Request, resp: falcon.Response): text = req.get_param("text") + server_type = get_server_type_from_req(req) if text is None: barcode = req.get_param("barcode") if barcode is None: raise falcon.HTTPBadRequest("text or barcode is required.") - product = get_product(barcode) or {} + product = get_product(ProductIdentifier(barcode, server_type)) or {} text = product.get("ingredients_text_fr") if text is None: resp.media = {"status": "not_found"} @@ -361,6 +379,7 @@ def spellcheck(self, req: falcon.Request, resp: falcon.Response): class NutrientPredictorResource: def on_get(self, req: falcon.Request, resp: falcon.Response): ocr_url = req.get_param("ocr_url", required=True) + server_type = get_server_type_from_req(req) if not ocr_url.endswith(".json"): raise falcon.HTTPBadRequest("a JSON file is expected") @@ -372,7 +391,9 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): try: predictions = extract_ocr_predictions( - barcode, ocr_url, [PredictionType.nutrient] + ProductIdentifier(barcode, server_type), + ocr_url, + [PredictionType.nutrient], ) except requests.exceptions.RequestException: @@ -396,13 +417,16 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): class OCRInsightsPredictorResource: def on_get(self, req: falcon.Request, resp: falcon.Response): ocr_url = req.get_param("ocr_url", required=True) + server_type = get_server_type_from_req(req) barcode = get_barcode_from_url(ocr_url) if barcode is None: raise falcon.HTTPBadRequest(f"invalid OCR URL: {ocr_url}") try: insights = extract_ocr_predictions( - barcode, ocr_url, DEFAULT_OCR_PREDICTION_TYPES + ProductIdentifier(barcode, server_type), + ocr_url, + DEFAULT_OCR_PREDICTION_TYPES, ) except requests.exceptions.RequestException: @@ -430,8 +454,14 @@ class CategoryPredictorResource: def on_post(self, req: falcon.Request, resp: falcon.Response): """Predict categories using neural categorizer and matching algorithm for a specific product.""" - predictors: list[str] = req.media.get("predictors") or ["neural", "matcher"] + server_type: ServerType = ServerType[req.media.get("server_type", "off")] + if server_type != ServerType.off: + raise falcon.HTTPBadRequest( + f"category predictor is only available for 'off' server type (here: '{server_type.name}')" + ) + + predictors: list[str] = req.media.get("predictors") or ["neural", "matcher"] neural_model_name = None if (neural_model_name_str := req.media.get("neural_model_name")) is not None: neural_model_name = NeuralCategoryClassifierModel[neural_model_name_str] @@ -439,11 +469,13 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): if "barcode" in req.media: # Fetch product from DB barcode: str = req.media["barcode"] - product = get_product(barcode) or {} + product = get_product(ProductIdentifier(barcode, server_type)) or {} if not product: raise falcon.HTTPNotFound(description=f"product {barcode} not found") + product_id = ProductIdentifier(barcode, server_type) else: product = req.media["product"] + product_id = ProductIdentifier("NULL", server_type) if "matcher" in predictors: if "lang" not in req.media: raise falcon.HTTPBadRequest( @@ -455,6 +487,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): resp.media = predict_category( product, + product_id, 
neural_predictor="neural" in predictors, matcher_predictor="matcher" in predictors, deepest_only=req.media.get("deepest_only", False), @@ -486,23 +519,19 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): barcode = req.get_param("barcode", required=True) image_url = req.get_param("image_url", required=True) ocr_url = req.get_param("ocr_url", required=True) - server_domain = req.get_param("server_domain", required=True) - - if server_domain != settings.BaseURLProvider.server_domain(): - logger.info("Rejecting image import from %s", server_domain) - resp.media = { - "status": "rejected", - } - return + server_domain: str = req.get_param( + "server_domain", default="api.openfoodfacts.org" + ) + check_server_domain(server_domain) + server_type = ServerType.get_from_server_domain(server_domain) enqueue_job( run_import_image_job, high_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, + product_id=ProductIdentifier(barcode, server_type), image_url=image_url, ocr_url=ocr_url, - server_domain=server_domain, ) resp.media = { "status": "scheduled", @@ -538,18 +567,12 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): inserts = [] for prediction in req.media["predictions"]: - server_domain: str = prediction.get( - "server_domain", settings.BaseURLProvider.server_domain() - ) - server_type: str = get_server_type(server_domain).name source_image = generate_image_path( prediction["barcode"], prediction.pop("image_id") ) inserts.append( { "timestamp": timestamp, - "server_domain": server_domain, - "server_type": server_type, "source_image": source_image, **prediction, } @@ -565,12 +588,12 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): model_name: Optional[str] = req.get_param("model_name") type_: Optional[str] = req.get_param("type") model_version: Optional[str] = req.get_param("model_version") - server_domain: Optional[str] = req.get_param("server_domain") barcode: Optional[str] = req.get_param("barcode") min_confidence: Optional[float] = req.get_param_as_float("min_confidence") random: bool = req.get_param_as_bool("random", default=True) + server_type = get_server_type_from_req(req) - where_clauses = [] + where_clauses = [ImageModel.server_type == server_type.name] if model_name is not None: where_clauses.append(ImagePrediction.model_name == model_name) @@ -581,9 +604,6 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): if type_ is not None: where_clauses.append(ImagePrediction.type == type_) - if server_domain: - where_clauses.append(ImageModel.server_domain == server_domain) - if min_confidence is not None: where_clauses.append(ImagePrediction.max_confidence >= min_confidence) @@ -682,6 +702,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): class ImageLogoSearchResource: def on_get(self, req: falcon.Request, resp: falcon.Response): + server_type = get_server_type_from_req(req) count: int = req.get_param_as_int( "count", min_value=1, max_value=2000, default=25 ) @@ -691,7 +712,6 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): taxonomy_value: Optional[str] = req.get_param("taxonomy_value") min_confidence: Optional[float] = req.get_param_as_float("min_confidence") random: bool = req.get_param_as_bool("random", default=False) - server_domain: Optional[str] = req.get_param("server_domain") annotated: Optional[bool] = req.get_param_as_bool("annotated") if type_ is None and (value is not None or taxonomy_value is not None): @@ -710,22 +730,15 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): 
"for label type" ) - where_clauses = [] + where_clauses = [ImageModel.server_type == server_type.name] if annotated is not None: where_clauses.append(LogoAnnotation.annotation_value.is_null(not annotated)) if min_confidence is not None: where_clauses.append(LogoAnnotation.score >= min_confidence) - join_image_prediction = False - join_image_model = False - - if server_domain: - where_clauses.append(ImageModel.server_domain == server_domain) - join_image_model = True - if barcode is not None: - where_clauses.append(LogoAnnotation.barcode == barcode) + where_clauses.append(ImageModel.barcode == barcode) if type_ is not None: where_clauses.append(LogoAnnotation.annotation_type == type_) @@ -738,13 +751,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): where_clauses.append(LogoAnnotation.taxonomy_value == taxonomy_value) query = LogoAnnotation.select() - join_image_prediction = join_image_prediction or join_image_model - - if join_image_prediction: - query = query.join(ImagePrediction) - - if join_image_model: - query = query.join(ImageModel) + query = query.join(ImagePrediction).join(ImageModel) if where_clauses: query = query.where(*where_clauses) @@ -796,7 +803,7 @@ def on_put(self, req: falcon.Request, resp: falcon.Response, logo_id: int): ) with db.atomic(): - logo = LogoAnnotation.get_or_none(id=logo_id) + logo: Optional[LogoAnnotation] = LogoAnnotation.get_or_none(id=logo_id) if logo is None: resp.status = falcon.HTTP_404 return @@ -811,8 +818,9 @@ def on_put(self, req: falcon.Request, resp: falcon.Response, logo_id: int): username=auth.get_username() or "unknown", completed_at=datetime.datetime.utcnow(), ) + server_type = ServerType[logo.image_prediction.image.server_type] generate_insights_from_annotated_logos( - annotated_logos, settings.BaseURLProvider.server_domain(), auth + annotated_logos, auth, server_type ) resp.status = falcon.HTTP_204 @@ -876,9 +884,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): raise falcon.HTTPForbidden( description="authentication is required to annotate logos" ) - server_domain = req.media.get( - "server_domain", settings.BaseURLProvider.server_domain() - ) + server_type: ServerType = ServerType[req.media.get("server_type", "off")] annotations = req.media["annotations"] completed_at = datetime.datetime.utcnow() annotation_logos = [] @@ -915,7 +921,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): high_queue, {"result_ttl": 0, "timeout": "5m"}, logo_ids=logo_ids, - server_domain=server_domain, + server_type=server_type, auth=auth, ) resp.media = {"annotated": len(annotated_logos)} @@ -998,6 +1004,23 @@ def on_get( } +SERVER_DOMAIN_REGEX = re.compile( + r"api(\.pro)?\.open(food|beauty|product|petfood)facts\.(org|net)" +) + + +def check_server_domain(server_domain: str): + if not SERVER_DOMAIN_REGEX.fullmatch(server_domain): + raise falcon.HTTPBadRequest(f"invalid `server_domain`: {server_domain}") + + tld = server_domain.rsplit(".", maxsplit=1)[-1] + instance_tld = settings._get_tld() + if tld != instance_tld: + raise falcon.HTTPBadRequest( + f"invalid `server_domain`, expected '{instance_tld}' tld, got '{tld}'" + ) + + class WebhookProductResource: """This handles requests from product opener that act as webhooks on product update or deletion. 
@@ -1006,17 +1029,13 @@ class WebhookProductResource: def on_post(self, req: falcon.Request, resp: falcon.Response): barcode = req.get_param("barcode", required=True) action = req.get_param("action", required=True) - server_domain = req.get_param("server_domain", required=True) - if server_domain != settings.BaseURLProvider.server_domain(): - logger.info("Rejecting webhook event from {}".format(server_domain)) - resp.media = { - "status": "rejected", - } - return - + server_domain: str = req.get_param("server_domain", required=True) + check_server_domain(server_domain) logger.info( - "New webhook event received for product {} (action: {}, " - "domain: {})".format(barcode, action, server_domain) + "New webhook event received for product %s (action: %s, domain: %s)", + barcode, + action, + server_domain, ) if action not in ("updated", "deleted"): raise falcon.HTTPBadRequest( @@ -1024,22 +1043,25 @@ def on_post(self, req: falcon.Request, resp: falcon.Response): description="action must be one of " "`deleted`, `updated`", ) - if action == "updated": + server_type = ServerType.get_from_server_domain(server_domain) + product_id = ProductIdentifier(barcode, server_type) + + # Only add the update insight job to the queue for Open Food Facts, + # as we don't have MongoDB connection for other projects yet + if action == "updated" and server_type == ServerType.off: enqueue_in_job( update_insights_job, high_queue, settings.UPDATED_PRODUCT_WAIT, job_kwargs={"result_ttl": 0}, - barcode=barcode, - server_domain=server_domain, + product_id=product_id, ) elif action == "deleted": enqueue_job( delete_product_insights_job, high_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, - server_domain=server_domain, + product_id=product_id, ) resp.media = { @@ -1057,10 +1079,10 @@ def on_get(self, req: falcon.Request, resp: falcon.Response, barcode: str): response: JSONType = {} count: int = req.get_param_as_int("count", min_value=1, default=25) lang: str = req.get_param("lang", default="en") + server_type = get_server_type_from_req(req) # If the device_id is not provided as a request parameter, we use the # hash of the IPs as a backup. 
device_id = device_id_from_request(req) - server_domain: Optional[str] = req.get_param("server_domain") auth: Optional[OFFAuthentication] = parse_auth(req) @@ -1069,8 +1091,8 @@ def on_get(self, req: falcon.Request, resp: falcon.Response, barcode: str): insights = list( get_insights( barcode=barcode, + server_type=server_type, keep_types=keep_types, - server_domain=server_domain, limit=count, order_by="n_votes", avoid_voted_on=_get_skip_voted_on(auth, device_id), @@ -1132,10 +1154,10 @@ def get_questions_resource_on_get( country: Optional[str] = req.get_param("country") value_tag: str = req.get_param("value_tag") brands = req.get_param_as_list("brands") or None - server_domain: Optional[str] = req.get_param("server_domain") reserved_barcode: Optional[bool] = req.get_param_as_bool( "reserved_barcode", default=False ) + server_type = get_server_type_from_req(req) # filter by annotation campaigns campaigns: Optional[list[str]] = req.get_param_as_list("campaigns") or None @@ -1168,9 +1190,9 @@ def get_questions_resource_on_get( get_insights_ = functools.partial( get_insights, + server_type=server_type, keep_types=keep_types, country=country, - server_domain=server_domain, value_tag=value_tag, brands=brands, order_by=order_by, @@ -1236,9 +1258,11 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): annotated = req.get_param_as_bool("annotated", blank_as_true=False) value_tag = req.get_param("value_tag") count = req.get_param_as_int("count", min_value=0, max_value=10_000) + server_type = get_server_type_from_req(req) get_insights_ = functools.partial( get_insights, + server_type=server_type, barcode=barcode, keep_types=keep_types, annotated=annotated, @@ -1291,13 +1315,13 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): "with_predictions", default=False ) barcode: Optional[str] = req.get_param("barcode") - server_domain = settings.BaseURLProvider.server_domain() + server_type = get_server_type_from_req(req) get_images_ = functools.partial( get_images, with_predictions=with_predictions, barcode=barcode, - server_domain=server_domain, + server_type=server_type, ) offset: int = (page - 1) * count @@ -1325,7 +1349,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): "insight_types", required=False ) brands = req.get_param_as_list("brands") or None - server_domain: Optional[str] = req.get_param("server_domain") + server_type = get_server_type_from_req(req) if keep_types: # Limit the number of types to prevent slow SQL queries @@ -1336,10 +1360,10 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): brands = brands[:10] query_parameters = { - "server_domain": server_domain, "keep_types": keep_types, "value_tag": value_tag, "barcode": barcode, + "server_type": server_type, } get_predictions_ = functools.partial(get_predictions, **query_parameters) @@ -1368,10 +1392,10 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): count: int = req.get_param_as_int("count", min_value=1, default=25) insight_type: str = req.get_param("type") country: Optional[str] = req.get_param("country") - server_domain: Optional[str] = req.get_param("server_domain") reserved_barcode: Optional[bool] = req.get_param_as_bool( "reserved_barcode", default=False ) + server_type = get_server_type_from_req(req) # filter by annotation campaigns campaigns: Optional[list[str]] = req.get_param_as_list("campaigns") or None if campaigns is None: @@ -1383,11 +1407,11 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): get_insights_ = functools.partial( get_insights, 
+ server_type=server_type, keep_types=[insight_type] if insight_type else None, group_by_value_tag=True, limit=count, country=country, - server_domain=server_domain, automatically_processable=False, reserved_barcode=reserved_barcode, campaigns=campaigns, @@ -1417,13 +1441,13 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): with_logo: Optional[bool] = req.get_param_as_bool("with_logo", default=False) barcode: Optional[str] = req.get_param("barcode") type: Optional[str] = req.get_param("type") - server_domain: Optional[str] = req.get_param("server_domain") + server_type = get_server_type_from_req(req) query_parameters = { "with_logo": with_logo, "barcode": barcode, "type": type, - "server_domain": server_domain, + "server_type": server_type, } get_image_predictions_ = functools.partial( @@ -1450,21 +1474,21 @@ class LogoAnnotationCollection: def on_get(self, req: falcon.Request, resp: falcon.Response): response: JSONType = {} barcode: Optional[str] = req.get_param("barcode") + server_type = get_server_type_from_req(req) keep_types: Optional[list[str]] = req.get_param_as_list("types", required=False) value_tag: str = req.get_param("value_tag") page: int = req.get_param_as_int("page", min_value=1, default=1) count: int = req.get_param_as_int("count", min_value=1, default=25) - server_domain: Optional[str] = req.get_param("server_domain") if keep_types: # Limit the number of types to prevent slow SQL queries keep_types = keep_types[:10] query_parameters = { - "server_domain": server_domain, "barcode": barcode, "keep_types": keep_types, "value_tag": value_tag, + "server_type": server_type, } get_annotation_ = functools.partial(get_logo_annotation, **query_parameters) diff --git a/robotoff/app/core.py b/robotoff/app/core.py index 867a294427..30186e6343 100644 --- a/robotoff/app/core.py +++ b/robotoff/app/core.py @@ -6,7 +6,6 @@ import peewee from peewee import JOIN, SQL, fn -from robotoff import settings from robotoff.app import events from robotoff.insights.annotate import ( ALREADY_ANNOTATED_RESULT, @@ -26,6 +25,7 @@ ) from robotoff.off import OFFAuthentication from robotoff.taxonomy import match_taxonomized_value +from robotoff.types import ServerType from robotoff.utils import get_logger from robotoff.utils.text import get_tag @@ -67,6 +67,7 @@ def _add_vote_exclusions( def get_insights( barcode: Optional[str] = None, + server_type: ServerType = ServerType.off, keep_types: Optional[list[str]] = None, country: Optional[str] = None, brands: Optional[list[str]] = None, @@ -74,7 +75,6 @@ def get_insights( annotation: Optional[int] = None, order_by: Optional[Literal["random", "popularity", "n_votes", "confidence"]] = None, value_tag: Optional[str] = None, - server_domain: Optional[str] = None, reserved_barcode: Optional[bool] = None, as_dict: bool = False, limit: Optional[int] = 25, @@ -92,6 +92,8 @@ def get_insights( parameter. 
:param barcode: only keep insights with this barcode, defaults to None + :param server_type: the server type of the insights, defaults to + ServerType.off :param keep_types: only keep insights that have any of the these types, defaults to None :param country: only keep insights with this country, defaults to None @@ -106,8 +108,6 @@ def get_insights( decreasing confidence score (confidence) or don't order results (None), defaults to None :param value_tag: only keep insights with this value_tag, defaults to None - :param server_domain: Only keep insights with this server domain, defaults - to `BaseUrlProvider.server_domain()` :param reserved_barcode: only keep insights with reserved barcodes (True) or without reserved barcode (False), defaults to None :param as_dict: if True, return results as dict instead of ProductInsight @@ -132,10 +132,7 @@ def get_insights( - a iterable of objects or dict (if `as_dict=True`) containing product count for each `value_tag`, if `group_by_value_tag=True` """ - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - where_clauses = [ProductInsight.server_domain == server_domain] + where_clauses = [ProductInsight.server_type == server_type.name] if annotated is not None: where_clauses.append(ProductInsight.annotation.is_null(not annotated)) @@ -217,17 +214,14 @@ def get_insights( def get_images( + server_type: ServerType, with_predictions: Optional[bool] = False, barcode: Optional[str] = None, - server_domain: Optional[str] = None, offset: Optional[int] = None, count: bool = False, limit: Optional[int] = 25, ) -> Iterable[ImageModel]: - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - where_clauses = [ImageModel.server_domain == server_domain] + where_clauses = [ImageModel.server_type == server_type.name] if barcode: where_clauses.append(ImageModel.barcode == barcode) @@ -250,18 +244,15 @@ def get_images( def get_predictions( + server_type: ServerType, barcode: Optional[str] = None, keep_types: Optional[list[str]] = None, value_tag: Optional[str] = None, - server_domain: Optional[str] = None, limit: Optional[int] = 25, offset: Optional[int] = None, count: bool = False, ) -> Iterable[Prediction]: - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - where_clauses = [Prediction.server_domain == server_domain] + where_clauses = [Prediction.server_type == server_type.name] if barcode: where_clauses.append(Prediction.barcode == barcode) @@ -286,10 +277,10 @@ def get_predictions( def get_image_predictions( + server_type: ServerType, with_logo: Optional[bool] = False, barcode: Optional[str] = None, type: Optional[str] = None, - server_domain: Optional[str] = None, offset: Optional[int] = None, count: bool = False, limit: Optional[int] = 25, @@ -297,11 +288,8 @@ def get_image_predictions( query = ImagePrediction.select() - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - query = query.switch(ImagePrediction).join(ImageModel) - where_clauses = [ImagePrediction.image.server_domain == server_domain] + where_clauses = [ImagePrediction.image.server_type == server_type.name] if barcode: where_clauses.append(ImagePrediction.image.barcode == barcode) @@ -437,29 +425,27 @@ def save_annotation( result = annotate(insight, annotation, update, data=data, auth=auth) username = auth.get_username() if auth else "unknown annotator" events.event_processor.send_async( - "question_answered", username, device_id, insight.barcode + 
"question_answered", + username, + device_id, + insight.barcode, + insight.server_type, ) return result def get_logo_annotation( + server_type: ServerType, barcode: Optional[str] = None, keep_types: Optional[list[str]] = None, value_tag: Optional[str] = None, - server_domain: Optional[str] = None, limit: Optional[int] = 25, offset: Optional[int] = None, count: bool = False, ) -> Iterable[LogoAnnotation]: - - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - query = LogoAnnotation.select().join(ImagePrediction).join(ImageModel) - where_clauses = [ - LogoAnnotation.image_prediction.image.server_domain == server_domain - ] + where_clauses = [ImageModel.server_type == server_type.name] if barcode: where_clauses.append(LogoAnnotation.barcode == barcode) diff --git a/robotoff/app/events.py b/robotoff/app/events.py index c873618573..febc691275 100644 --- a/robotoff/app/events.py +++ b/robotoff/app/events.py @@ -50,12 +50,14 @@ def send_event( user_id: str, device_id: str, barcode: Optional[str] = None, + server_type: Optional[str] = None, ): event = { "event_type": event_type, "user_id": user_id, "device_id": device_id, "barcode": barcode, + "server_type": server_type, } logger.debug("Event: %s", event) response = requests.post(api_url, json=event) diff --git a/robotoff/app/schema.py b/robotoff/app/schema.py index 7f39433075..0200e72f2e 100644 --- a/robotoff/app/schema.py +++ b/robotoff/app/schema.py @@ -1,4 +1,4 @@ -from robotoff.types import JSONType, NeuralCategoryClassifierModel +from robotoff.types import JSONType, NeuralCategoryClassifierModel, ServerType IMAGE_PREDICTION_IMPORTER_SCHEMA: JSONType = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -10,7 +10,6 @@ "items": { "type": "object", "properties": { - "server_domain": {"type": "string"}, "barcode": {"type": "string"}, "image_id": {"type": "string"}, "model_name": {"type": "string"}, @@ -161,6 +160,11 @@ "required": ["value", "type", "logo_id"], }, }, + "server_type": { + "type": "string", + "enum": [server_type.name for server_type in ServerType], + "default": ServerType.off, + }, }, "required": ["annotations"], } diff --git a/robotoff/brands.py b/robotoff/brands.py index 99c11cf0b6..10922b7fe1 100644 --- a/robotoff/brands.py +++ b/robotoff/brands.py @@ -5,6 +5,7 @@ from robotoff import settings from robotoff.products import ProductDataset from robotoff.taxonomy import TaxonomyType, get_taxonomy +from robotoff.types import ServerType from robotoff.utils import ( dump_json, dump_text, @@ -98,12 +99,13 @@ def keep_brand_from_taxonomy( def generate_brand_list( threshold: int, + server_type: ServerType, min_length: Optional[int] = None, blacklisted_brands: Optional[set[str]] = None, ) -> list[tuple[str, str]]: min_length = min_length or 0 brand_taxonomy = get_taxonomy(TaxonomyType.brand.name) - url = settings.BaseURLProvider.world() + "/brands.json" + url = settings.BaseURLProvider.world(server_type) + "/brands.json" brand_count_list = http_session.get(url).json()["tags"] brand_count = {tag["id"]: tag for tag in brand_count_list} @@ -130,7 +132,11 @@ def dump_taxonomy_brands( min_length: Optional[int] = None, blacklisted_brands: Optional[set[str]] = None, ): - filtered_brands = generate_brand_list(threshold, min_length, blacklisted_brands) + # Only support OFF for now + server_type = ServerType.off + filtered_brands = generate_brand_list( + threshold, server_type, min_length, blacklisted_brands + ) line_iter = ("{}||{}".format(key, name) for key, name in filtered_brands) 
dump_text(settings.OCR_TAXONOMY_BRANDS_PATH, line_iter) diff --git a/robotoff/cli/insights.py b/robotoff/cli/insights.py index b41431ab23..94b1a138fd 100644 --- a/robotoff/cli/insights.py +++ b/robotoff/cli/insights.py @@ -13,7 +13,7 @@ from robotoff.off import get_barcode_from_path from robotoff.prediction.ocr import OCRResult, extract_predictions from robotoff.prediction.ocr.core import ocr_content_iter -from robotoff.types import Prediction, PredictionType +from robotoff.types import Prediction, PredictionType, ProductIdentifier, ServerType from robotoff.utils import get_logger, gzip_jsonl_iter, jsonl_iter logger = get_logger(__name__) @@ -22,10 +22,11 @@ def run_from_ocr_archive( input_path: Path, prediction_type: PredictionType, + server_type: ServerType, output: Optional[Path] = None, ): predictions = tqdm.tqdm( - generate_from_ocr_archive(input_path, prediction_type), desc="OCR" + generate_from_ocr_archive(input_path, prediction_type, server_type), desc="OCR" ) output_f: _io._TextIOBase need_decoding = False @@ -47,8 +48,7 @@ def run_from_ocr_archive( def generate_from_ocr_archive( - input_path: Path, - prediction_type: PredictionType, + input_path: Path, prediction_type: PredictionType, server_type: ServerType ) -> Iterable[Prediction]: json_iter = ( gzip_jsonl_iter(input_path) @@ -74,7 +74,10 @@ def generate_from_ocr_archive( continue yield from extract_predictions( - ocr_result, prediction_type, barcode=barcode, source_image=source_image + ocr_result, + prediction_type, + product_id=ProductIdentifier(barcode=barcode, server_type=server_type), + source_image=source_image, ) diff --git a/robotoff/cli/logos.py b/robotoff/cli/logos.py index 7833989a6a..a5aabd55a1 100644 --- a/robotoff/cli/logos.py +++ b/robotoff/cli/logos.py @@ -7,6 +7,7 @@ from robotoff.logos import filter_logos from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation, db from robotoff.off import generate_image_path +from robotoff.types import ServerType from robotoff.utils import get_logger, jsonl_iter logger = get_logger(__name__) @@ -15,11 +16,12 @@ TYPE = "object_detection" -def get_seen_set() -> set[tuple[str, str]]: +def get_seen_set(server_type: ServerType) -> set[tuple[str, str]]: seen_set: set[tuple[str, str]] = set() for item in ( ImagePrediction.select(ImagePrediction.model_name, ImageModel.source_image) .join(ImageModel) + .where(ImageModel.server_type == server_type.name) .tuples() .iterator() ): @@ -28,10 +30,14 @@ def get_seen_set() -> set[tuple[str, str]]: def import_logos( - data_path: pathlib.Path, model_name: str, model_version: str, batch_size: int + data_path: pathlib.Path, + model_name: str, + model_version: str, + batch_size: int, + server_type: ServerType, ) -> int: logger.info("Loading seen set...") - seen_set = get_seen_set() + seen_set = get_seen_set(server_type) logger.info("Seen set loaded") inserted = 0 @@ -48,7 +54,9 @@ def import_logos( if key in seen_set: continue - image_instance = ImageModel.get_or_none(source_image=source_image) + image_instance = ImageModel.get_or_none( + source_image=source_image, server_type=server_type.name + ) if image_instance is None: logger.info("Unknown image in DB: %s", source_image) diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index c7a4e1104e..6d87503726 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -12,6 +12,8 @@ NeuralCategoryClassifierModel, ObjectDetectionModel, PredictionType, + ProductIdentifier, + ServerType, WorkerQueue, ) @@ -46,11 +48,13 @@ def run_worker( @app.command() def 
regenerate_ocr_insights( - barcode: str = typer.Argument(..., help="Barcode of the product") + barcode: str = typer.Argument(..., help="Barcode of the product"), + server_type: ServerType = typer.Option( + ServerType.off, help="Server type of the product" + ), ) -> None: """Regenerate OCR predictions/insights for a specific product and import them.""" - from robotoff import settings from robotoff.insights import importer from robotoff.insights.extraction import ( DEFAULT_OCR_PREDICTION_TYPES, @@ -63,7 +67,8 @@ def regenerate_ocr_insights( logger = get_logger() - product = get_product(barcode, ["images"]) + product_id = ProductIdentifier(barcode, server_type) + product = get_product(product_id, ["images"]) if product is None: raise ValueError(f"product not found: {barcode}") @@ -72,15 +77,13 @@ def regenerate_ocr_insights( if not image_id.isdigit(): continue - ocr_url = generate_json_ocr_url(barcode, image_id) + ocr_url = generate_json_ocr_url(product_id, image_id) predictions += extract_ocr_predictions( - barcode, ocr_url, DEFAULT_OCR_PREDICTION_TYPES + product_id, ocr_url, DEFAULT_OCR_PREDICTION_TYPES ) with db: - import_result = importer.import_insights( - predictions, settings.BaseURLProvider.server_domain() - ) + import_result = importer.import_insights(predictions, server_type) logger.info(import_result) @@ -92,6 +95,9 @@ def generate_ocr_predictions( prediction_type: PredictionType = typer.Argument( ..., help="Type of the predictions to generate (label, brand,...)" ), + server_type: ServerType = typer.Option( + ServerType.off, help="Server type of the archive" + ), output: Optional[Path] = typer.Option( None, help="File to write output to, stdout if not specified. Gzipped output are supported.", @@ -104,7 +110,7 @@ def generate_ocr_predictions( from robotoff.utils import get_logger get_logger() - insights.run_from_ocr_archive(input_path, prediction_type, output) + insights.run_from_ocr_archive(input_path, prediction_type, server_type, output) @app.command() @@ -137,6 +143,9 @@ def download_dataset(minify: bool = False) -> None: @app.command() def categorize( barcode: str, + server_type: ServerType = typer.Option( + ServerType.off, help="Server type of the product" + ), deepest_only: bool = False, model_name: NeuralCategoryClassifierModel = typer.Option( NeuralCategoryClassifierModel.keras_image_embeddings_3_0, @@ -159,20 +168,23 @@ def categorize( get_logger(level=logging.DEBUG) - product = get_product(barcode) + product_id = ProductIdentifier(barcode, server_type) + product = get_product(product_id) if product is None: - print(f"Product {barcode} not found") + print(f"{product_id} not found") return predictions, _ = CategoryClassifier( get_taxonomy(TaxonomyType.category.name, offline=True) - ).predict(product, deepest_only, threshold=threshold, model_name=model_name) + ).predict( + product, product_id, deepest_only, threshold=threshold, model_name=model_name + ) if predictions: for prediction in predictions: print(f"{prediction.value_tag}: {prediction.confidence}") else: - print(f"Nothing predicted for product {barcode}") + print(f"Nothing predicted for {product_id}") @app.command() @@ -191,6 +203,9 @@ def import_insights( generate_from: Optional[pathlib.Path] = typer.Option( None, help="Input path of the OCR archive, is incompatible with --input-path" ), + server_type: ServerType = typer.Option( + ServerType.off, help="Server type of the product" + ), ) -> None: """Import insights from a prediction JSONL archive (with --input-path option), or generate them on the fly from an OCR 
archive (with
     --generate-from option)."""
     import tqdm
     from more_itertools import chunked
 
-    from robotoff import settings
     from robotoff.cli.insights import generate_from_ocr_archive, insights_iter
     from robotoff.insights import importer
     from robotoff.models import db
@@ -211,7 +225,9 @@ def import_insights(
         if prediction_type is None:
             sys.exit("Required option: --prediction-type")
 
-        predictions = generate_from_ocr_archive(generate_from, prediction_type)
+        predictions = generate_from_ocr_archive(
+            generate_from, prediction_type, server_type
+        )
     elif input_path is not None:
         logger.info(f"Importing insights from {input_path}")
         predictions = insights_iter(input_path)
@@ -224,9 +240,7 @@ def import_insights(
     ):
         # Create a new transaction for every batch
         with db.atomic():
-            import_results = importer.import_insights(
-                prediction_batch, settings.BaseURLProvider.server_domain()
-            )
+            import_results = importer.import_insights(prediction_batch, server_type)
             logger.info(import_results)
 
 
@@ -236,6 +250,9 @@ def refresh_insights(
         None,
         help="Refresh a specific product. If not provided, all products are updated",
     ),
+    server_type: ServerType = typer.Option(
+        ServerType.off, help="Server type of the product"
+    ),
     batch_size: int = typer.Option(
         100, help="Number of products to send in a worker tasks"
     ),
@@ -249,7 +266,6 @@ def refresh_insights(
     from more_itertools import chunked
     from peewee import fn
 
-    from robotoff import settings
     from robotoff.insights.importer import refresh_insights as refresh_insights_
     from robotoff.models import Prediction as PredictionModel
     from robotoff.models import db
@@ -260,23 +276,24 @@ def refresh_insights(
     logger = get_logger()
 
     if barcode is not None:
-        logger.info(f"Refreshing product {barcode}")
+        product_id = ProductIdentifier(barcode, server_type)
+        logger.info(f"Refreshing {product_id}")
         with db:
-            imported = refresh_insights_(
-                barcode, settings.BaseURLProvider.server_domain()
-            )
+            imported = refresh_insights_(product_id)
         logger.info(f"Refreshed insights: {imported}")
     else:
         logger.info("Launching insight refresh on full database")
         with db:
-            barcodes = [
-                barcode
-                for (barcode,) in PredictionModel.select(
+            product_ids = [
+                ProductIdentifier(barcode, server_type)
+                for (barcode,) in PredictionModel.select(
                     fn.Distinct(PredictionModel.barcode)
-                ).tuples()
+                )
+                .where(PredictionModel.server_type == server_type.name)
+                .tuples()
             ]
 
-            batches = list(chunked(barcodes, batch_size))
+            batches = list(chunked(product_ids, batch_size))
             confirm = typer.confirm(
                 f"{len(batches)} jobs are going to be launched, confirm?"
             )
         if not confirm:
             return
 
         logger.info("Adding refresh_insights jobs in queue...")
-        for barcode_batch in tqdm.tqdm(batches, desc="barcode batch"):
+        for product_id_batch in tqdm.tqdm(batches, desc="barcode batch"):
             enqueue_job(
                 refresh_insights_job,
                 low_queue,
                 job_kwargs={"result_ttl": 0, "timeout": "5m"},
-                barcodes=barcode_batch,
-                server_domain=settings.BaseURLProvider.server_domain(),
+                product_ids=product_id_batch,
             )
 
 
 @app.command()
 def import_images_in_db(
+    server_type: ServerType = typer.Option(
+        ServerType.off, help="Server type of the product"
+    ),
     batch_size: int = typer.Option(
         500, help="Number of items to send in a worker tasks"
     ),
@@ -306,10 +325,9 @@ def import_images_in_db(
     import tqdm
     from more_itertools import chunked
 
-    from robotoff import settings
     from robotoff.models import ImageModel, db
     from robotoff.off import generate_image_path
-    from robotoff.products import get_product_store
+    from robotoff.products import DBProductStore, get_product_store
     from robotoff.utils import get_logger
     from robotoff.workers.queues import enqueue_job, low_queue
     from robotoff.workers.tasks.import_image import save_image_job
@@ -319,18 +337,25 @@ def import_images_in_db(
     with db:
         logger.info("Fetching existing images in DB...")
         existing_images = set(
-            ImageModel.select(ImageModel.barcode, ImageModel.image_id).tuples()
+            ImageModel.select(ImageModel.barcode, ImageModel.image_id)
+            .where(ImageModel.server_type == server_type.name)
+            .tuples()
         )
 
-        store = get_product_store()
-        to_add = []
+        store: DBProductStore = get_product_store(server_type)
+        to_add: list[tuple[ProductIdentifier, str]] = []
         for product in tqdm.tqdm(
             store.iter_product(projection=["images", "code"]), desc="product"
         ):
             barcode = product.barcode
             for image_id in (id_ for id_ in product.images.keys() if id_.isdigit()):
                 if (barcode, image_id) not in existing_images:
-                    to_add.append((barcode, generate_image_path(barcode, image_id)))
+                    to_add.append(
+                        (
+                            ProductIdentifier(barcode, server_type),
+                            generate_image_path(barcode, image_id),
+                        )
+                    )
 
     batches = list(chunked(to_add, batch_size))
     if typer.confirm(
@@ -342,12 +367,15 @@ def import_images_in_db(
             low_queue,
             job_kwargs={"result_ttl": 0},
             batch=batch,
-            server_domain=settings.BaseURLProvider.server_domain(),
+            server_type=server_type,
         )
 
 
 @app.command()
 def run_object_detection_model(
+    server_type: ServerType = typer.Option(
+        ServerType.off, help="Server type of the product"
+    ),
     model_name: ObjectDetectionModel = typer.Argument(
         ..., help="Name of the object detection model"
     ),
@@ -370,7 +398,6 @@ def run_object_detection_model(
     import tqdm
     from peewee import JOIN
 
-    from robotoff import settings
     from robotoff.models import ImageModel, ImagePrediction, db
     from robotoff.off import generate_image_url
     from robotoff.utils import text_file_iter
@@ -412,15 +439,17 @@ def run_object_detection_model(
             ),
         )
         .where(
-            ImagePrediction.model_name.is_null()
-            & (ImageModel.deleted == False)  # noqa: E712
+            # parentheses are required here: `&` binds tighter than `==`
+            (ImageModel.server_type == server_type.name)
+            & ImagePrediction.model_name.is_null()
+            & (ImageModel.deleted == False)  # noqa: E712
         )
         .tuples()
     )
     if limit:
         query = query.limit(limit)
     image_urls = [
-        generate_image_url(barcode, image_id)
+        generate_image_url(ProductIdentifier(barcode, server_type), image_id)
         for barcode, image_id in query
         if barcode.isdigit()
     ]
 
     if typer.confirm(f"{len(image_urls)} jobs are going to be launched, confirm?"):
         for image_url in tqdm.tqdm(image_urls, desc="image"):
             barcode = 
get_barcode_from_url(image_url) + if barcode is None: + raise RuntimeError() enqueue_job( func, low_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, + product_id=ProductIdentifier(barcode, server_type), image_url=image_url, - server_domain=settings.BaseURLProvider.server_domain(), ) @@ -507,6 +537,7 @@ def add_logo_to_ann( def refresh_logo_nearest_neighbors( day_offset: int = typer.Option(7, help="Number of days since last refresh", min=1), batch_size: int = typer.Option(500, help="Number of logos to process at once"), + server_type: ServerType = typer.Option(ServerType.off, help="Server type"), ): """Refresh each logo nearest neighbors if the last refresh is more than `day_offset` days old.""" @@ -522,7 +553,7 @@ def refresh_logo_nearest_neighbors( logger.info("Starting refresh of logo nearest neighbors") with db.connection_context(): - refresh_nearest_neighbors(day_offset, batch_size) + refresh_nearest_neighbors(server_type, day_offset, batch_size) @app.command() @@ -593,6 +624,7 @@ def import_logos( batch_size: int = typer.Option( 1024, help="Number of predictions to insert in DB in a single SQL transaction" ), + server_type: ServerType = typer.Option(ServerType.off, help="Server type"), ) -> None: """Import object detection predictions for universal-logo-detector model. @@ -626,6 +658,7 @@ def import_logos( ObjectDetectionModel.universal_logo_detector ], batch_size, + server_type, ) logger.info("%s image predictions created", imported) @@ -668,6 +701,9 @@ def import_image_webhook( help="URL of the image to import to the output file, can either have .jsonl or .jsonl.gz as " "extension", ), + server_domain: str = typer.Option( + "api.openfoodfacts.net", help="Server domain to use for image import" + ), ) -> None: """Import an image in Robotoff by calling POST /api/v1/images/import. @@ -677,7 +713,6 @@ def import_image_webhook( import os from robotoff.off import get_barcode_from_url - from robotoff.settings import BaseURLProvider from robotoff.utils import get_logger, http_session logger = get_logger() @@ -696,7 +731,7 @@ def import_image_webhook( "barcode": barcode, "image_url": image_url, "ocr_url": ocr_url, - "server_domain": BaseURLProvider.server_domain(), + "server_domain": server_domain, }, ) if not r.ok: diff --git a/robotoff/images.py b/robotoff/images.py index dff2626cf3..650681a62c 100644 --- a/robotoff/images.py +++ b/robotoff/images.py @@ -3,39 +3,37 @@ from typing import Optional from robotoff.models import ImageModel -from robotoff.off import generate_image_path, generate_image_url, get_server_type -from robotoff.settings import BaseURLProvider -from robotoff.types import JSONType +from robotoff.off import generate_image_path, generate_image_url +from robotoff.types import JSONType, ProductIdentifier from robotoff.utils import get_image_from_url, get_logger, http_session logger = get_logger(__name__) def save_image( - barcode: str, + product_id: ProductIdentifier, source_image: str, image_url: str, images: Optional[JSONType], - server_domain: str, ) -> Optional[ImageModel]: """Save imported image details in DB. 
-    :param barcode: barcode of the product
+    :param product_id: identifier of the product
     :param source_image: source image, in the format '/325/543/254/5234/1.jpg'
     :param image_url: URL of the image, only used to get image size if
         images is None
     :param images: image dict mapping image ID to image metadata, as
         returned by Product Opener API, is None if product validity check
         is disabled (`DISABLE_PRODUCT_CHECK=True`)
-    :param server_domain: the server domain to use, default to
-        BaseURLProvider.server_domain()
     :return: this function returns either:

     - the ImageModel of the image if it already exists in DB
     - None if the image is not raw (non-digit image ID), if it's not
       referenced in `images` or if no size information is available
     - the created ImageModel otherwise
     """
-    if existing_image_model := ImageModel.get_or_none(source_image=source_image):
+    if existing_image_model := ImageModel.get_or_none(
+        source_image=source_image, server_type=product_id.server_type.name
+    ):
         logger.info(
             f"Image {source_image} already exists in DB, returning existing image"
         )
@@ -49,7 +47,7 @@ def save_image(

     if images is not None:
         if image_id not in images:
-            logger.info("Unknown image for product %s: %s", barcode, source_image)
+            logger.info("Unknown image for %s: %s", product_id, source_image)
             return None

         image = images[image_id]
@@ -87,44 +85,42 @@ def save_image(
         height = image.height

     image_model = ImageModel.create(
-        barcode=barcode,
+        barcode=product_id.barcode,
         image_id=image_id,
         width=width,
         height=height,
         source_image=source_image,
         uploaded_at=uploaded_at,
-        server_domain=server_domain,
-        server_type=get_server_type(server_domain).name,
+        server_type=product_id.server_type.name,
     )
     if image_model is not None:
         logger.info("New image %s created in DB", image_model.id)
     return image_model


-def refresh_images_in_db(
-    barcode: str, images: JSONType, server_domain: Optional[str] = None
-):
+def refresh_images_in_db(product_id: ProductIdentifier, images: JSONType):
     """Make sure all raw images present in `images` exist in the DB `image` table.
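
    Sketch of a call with a trimmed `images` payload (only the numeric
    keys "1" and "2" are treated as raw images, "front_fr" is skipped):

        from robotoff.types import ProductIdentifier, ServerType

        product_id = ProductIdentifier("3560070880973", ServerType.off)
        images = {"1": {}, "2": {}, "front_fr": {}}  # metadata trimmed
        refresh_images_in_db(product_id, images)
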
- :param barcode: barcode of the product + :param product_id: identifier of the product :param images: image dict mapping image ID to image metadata, as returned by Product Opener API - :param server_domain: the server domain to use, default to - BaseURLProvider.server_domain() """ - server_domain = server_domain or BaseURLProvider.server_domain() image_ids = [image_id for image_id in images.keys() if image_id.isdigit()] existing_image_ids = set( image_id for (image_id,) in ImageModel.select(ImageModel.image_id) - .where(ImageModel.barcode == barcode, ImageModel.image_id.in_(image_ids)) + .where( + ImageModel.barcode == product_id.barcode, + ImageModel.server_type == product_id.server_type.name, + ImageModel.image_id.in_(image_ids), + ) .tuples() .iterator() ) missing_image_ids = set(image_ids) - existing_image_ids for missing_image_id in missing_image_ids: - source_image = generate_image_path(barcode, missing_image_id) - image_url = generate_image_url(barcode, missing_image_id) + source_image = generate_image_path(product_id.barcode, missing_image_id) + image_url = generate_image_url(product_id, missing_image_id) logger.debug("Creating missing image %s in DB", source_image) - save_image(barcode, source_image, image_url, images, server_domain) + save_image(product_id, source_image, image_url, images) diff --git a/robotoff/insights/annotate.py b/robotoff/insights/annotate.py index f5760138ca..b3387fb156 100644 --- a/robotoff/insights/annotate.py +++ b/robotoff/insights/annotate.py @@ -197,7 +197,8 @@ def process_annotation( ) -> AnnotationResult: emb_code: str = insight.value - product = get_product(insight.barcode, ["emb_codes"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["emb_codes"]) if product is None: return MISSING_PRODUCT_RESULT @@ -213,9 +214,8 @@ def process_annotation( emb_codes.append(emb_code) update_emb_codes( - insight.barcode, + product_id, emb_codes, - server_domain=insight.server_domain, insight_id=insight.id, auth=auth, ) @@ -241,7 +241,8 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["labels_tags"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["labels_tags"]) if product is None: return MISSING_PRODUCT_RESULT @@ -252,10 +253,9 @@ def process_annotation( return ALREADY_ANNOTATED_RESULT add_label_tag( - insight.barcode, + product_id, insight.value_tag, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -269,10 +269,10 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - barcode = insight.barcode + product_id = insight.get_product_id() lang = insight.data["lang"] field_name = "ingredients_text_{}".format(lang) - product = get_product(barcode, [field_name]) + product = get_product(product_id, [field_name]) if product is None: return MISSING_PRODUCT_RESULT @@ -283,9 +283,8 @@ def process_annotation( if expected_ingredients != original_ingredients: logger.warning( - "ingredients have changed since spellcheck insight " - "creation (product %s)", - barcode, + "ingredients have changed since spellcheck insight " "creation (%s)", + product_id, ) return AnnotationResult( status_code=AnnotationStatus.error_updated_product.value, @@ -294,7 +293,7 @@ def process_annotation( ) save_ingredients( - barcode, + product_id, corrected, lang=lang, insight_id=insight.id, @@ -311,7 +310,8 
@@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["categories_tags"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["categories_tags"]) if product is None: return MISSING_PRODUCT_RESULT @@ -323,10 +323,9 @@ def process_annotation( category_tag = insight.value_tag add_category( - insight.barcode, + product_id, category_tag, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -340,7 +339,8 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["quantity"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["quantity"]) if product is None: return MISSING_PRODUCT_RESULT @@ -351,10 +351,9 @@ def process_annotation( return ALREADY_ANNOTATED_RESULT update_quantity( - insight.barcode, + product_id, insight.value, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -368,7 +367,8 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["expiration_date"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["expiration_date"]) if product is None: return MISSING_PRODUCT_RESULT @@ -379,10 +379,9 @@ def process_annotation( return ALREADY_ANNOTATED_RESULT update_expiration_date( - insight.barcode, + product_id, insight.value, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -396,16 +395,16 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["brands_tags"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["brands_tags"]) if product is None: return MISSING_PRODUCT_RESULT add_brand( - insight.barcode, + product_id, insight.value, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) @@ -420,7 +419,8 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["stores_tags"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["stores_tags"]) if product is None: return MISSING_PRODUCT_RESULT @@ -431,10 +431,9 @@ def process_annotation( return ALREADY_ANNOTATED_RESULT add_store( - insight.barcode, + product_id, insight.value, insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -448,16 +447,16 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["code"]) + product_id = insight.get_product_id() + product = get_product(product_id, ["code"]) if product is None: return MISSING_PRODUCT_RESULT add_packaging( - insight.barcode, + product_id, insight.data["element"], insight_id=insight.id, - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT @@ -471,7 +470,8 @@ def process_annotation( data: Optional[dict] = None, auth: Optional[OFFAuthentication] = None, ) -> AnnotationResult: - product = get_product(insight.barcode, ["code"]) + product_id = 
insight.get_product_id() + product = get_product(product_id, ["code"]) if product is None: return MISSING_PRODUCT_RESULT @@ -486,11 +486,10 @@ def process_annotation( ) image_key = "nutrition_{}".format(insight.value_tag) select_rotate_image( - barcode=insight.barcode, + product_id=product_id, image_id=image_id, image_key=image_key, rotate=insight.data.get("rotation"), - server_domain=insight.server_domain, auth=auth, ) return UPDATED_ANNOTATION_RESULT diff --git a/robotoff/insights/extraction.py b/robotoff/insights/extraction.py index f79170e2a5..5edf86b2e2 100644 --- a/robotoff/insights/extraction.py +++ b/robotoff/insights/extraction.py @@ -11,7 +11,12 @@ ObjectDetectionModelRegistry, ) from robotoff.prediction.ocr.core import get_ocr_result -from robotoff.types import ObjectDetectionModel, Prediction, PredictionType +from robotoff.types import ( + ObjectDetectionModel, + Prediction, + PredictionType, + ProductIdentifier, +) from robotoff.utils import get_logger, http_session logger = get_logger(__name__) @@ -92,12 +97,12 @@ def run_object_detection_model( def get_predictions_from_product_name( - barcode: str, product_name: str + product_id: ProductIdentifier, product_name: str ) -> list[Prediction]: predictions_all = [] for prediction_type in PRODUCT_NAME_PREDICTION_TYPES: predictions = ocr.extract_predictions( - product_name, prediction_type, barcode=barcode + product_name, prediction_type, product_id=product_id ) for prediction in predictions: prediction.data["source"] = "product_name" @@ -110,7 +115,9 @@ def get_predictions_from_product_name( def extract_ocr_predictions( - barcode: str, ocr_url: str, prediction_types: Iterable[PredictionType] + product_id: ProductIdentifier, + ocr_url: str, + prediction_types: Iterable[PredictionType], ) -> list[Prediction]: logger.info("Generating OCR predictions from OCR %s", ocr_url) @@ -123,7 +130,10 @@ def extract_ocr_predictions( for prediction_type in prediction_types: predictions_all += ocr.extract_predictions( - ocr_result, prediction_type, barcode=barcode, source_image=source_image + ocr_result, + prediction_type, + product_id=product_id, + source_image=source_image, ) return predictions_all diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index 58c1e25e6e..93a04168fd 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -13,7 +13,6 @@ from robotoff.insights.normalize import normalize_emb_code from robotoff.models import Prediction as PredictionModel from robotoff.models import ProductInsight, batch_insert -from robotoff.off import get_server_type from robotoff.prediction.ocr.packaging import SHAPE_ONLY_EXCLUDE_SET from robotoff.products import ( DBProductStore, @@ -37,7 +36,9 @@ Prediction, PredictionImportResult, PredictionType, + ProductIdentifier, ProductInsightImportResult, + ServerType, ) from robotoff.utils import get_logger, text_file_iter from robotoff.utils.cache import CachedStore @@ -167,14 +168,14 @@ def convert_bounding_box_absolute_to_relative( def get_existing_insight( - insight_type: InsightType, barcode: str, server_domain: str + insight_type: InsightType, product_id: ProductIdentifier ) -> list[ProductInsight]: """Get all insights for specific product and `insight_type`.""" return list( ProductInsight.select().where( ProductInsight.type == insight_type.name, - ProductInsight.barcode == barcode, - ProductInsight.server_domain == server_domain, + ProductInsight.barcode == product_id.barcode, + ProductInsight.server_type == product_id.server_type.name, ) ) @@ -275,9 
+276,8 @@ def get_required_prediction_types() -> set[PredictionType]: @classmethod def import_insights( cls, - barcode: str, + product_id: ProductIdentifier, predictions: list[Prediction], - server_domain: str, product_store: DBProductStore, ) -> ProductInsightImportResult: """Import insights, this is the main method. @@ -303,7 +303,7 @@ def import_insights( inserts = 0 to_create, to_update, to_delete = cls.generate_insights( - barcode, predictions, server_domain, product_store + product_id, predictions, product_store ) to_delete_ids = [insight.id for insight in to_delete] if to_delete_ids: @@ -325,7 +325,7 @@ def import_insights( for field_name in ( key for key in insight.__data__.keys() - if key not in ("id", "barcode", "type", "server_domain", "server_type") + if key not in ("id", "barcode", "type", "server_type") ): if getattr(insight, field_name) != getattr( reference_insight, field_name @@ -342,16 +342,15 @@ def import_insights( insight_created_ids=created_ids, insight_deleted_ids=to_delete_ids, insight_updated_ids=updated_ids, - barcode=barcode, + product_id=product_id, type=cls.get_type(), ) @classmethod def generate_insights( cls, - barcode: str, + product_id: ProductIdentifier, predictions: list[Prediction], - server_domain: str, product_store: DBProductStore, ) -> tuple[ list[ProductInsight], @@ -365,17 +364,14 @@ def generate_insights( (and implemented in sub-classes). """ timestamp = datetime.datetime.utcnow() - server_type = get_server_type(server_domain).name - product = product_store[barcode] - references = get_existing_insight(cls.get_type(), barcode, server_domain) + product = product_store[product_id] + references = get_existing_insight(cls.get_type(), product_id) # If `DISABLE_PRODUCT_CHECK` is False (default, production settings), we # stop the import process and delete all associated insights if product is None and not settings.DISABLE_PRODUCT_CHECK: - logger.info( - f"Product {barcode} not found in DB, deleting existing insights" - ) + logger.info("%s not found in DB, deleting existing insights", product_id) return [], [], references predictions = cls.sort_predictions(predictions) @@ -419,16 +415,10 @@ def generate_insights( to_create, to_update, to_delete = cls.get_insight_update(candidates, references) for insight in to_create: - cls.add_fields(insight, product, timestamp, server_domain, server_type) + cls.add_fields(insight, product, timestamp) - for insight, reference_insight in to_update: - cls.add_fields( - insight, - product, - timestamp, - reference_insight.server_domain, - reference_insight.server_type, - ) + for insight, _ in to_update: + cls.add_fields(insight, product, timestamp) return (to_create, to_update, to_delete) @@ -623,14 +613,10 @@ def add_fields( insight: ProductInsight, product: Optional[Product], timestamp: datetime.datetime, - server_domain: str, - server_type: str, ): """Add mandatory insight fields.""" barcode = insight.barcode insight.reserved_barcode = is_reserved_barcode(barcode) - insight.server_domain = server_domain - insight.server_type = server_type insight.id = str(uuid.uuid4()) insight.timestamp = timestamp insight.n_votes = 0 @@ -1255,49 +1241,45 @@ def is_valid_product_prediction( return True -def create_prediction_model( - prediction: Prediction, - server_domain: str, - timestamp: datetime.datetime, -): +def create_prediction_model(prediction: Prediction, timestamp: datetime.datetime): prediction_dict = prediction.to_dict() prediction_dict.pop("id") - return { - **prediction_dict, - "timestamp": timestamp, - "server_domain": 
server_domain, - } + return {**prediction_dict, "timestamp": timestamp} def import_product_predictions( barcode: str, + server_type: ServerType, product_predictions_iter: Iterable[Prediction], - server_domain: str, ): """Import predictions for a specific product. - If a prediction already exists in DB (same (barcode, type, server_domain, + If a prediction already exists in DB (same (barcode, type, source_image, value, value_tag, predictor, automatic_processing)), it won't be imported. :param barcode: Barcode of the product. All `product_predictions` must - have the same barcode. + have the same barcode. + :param server_type: the server type (project) of the product, all + `product_predictions` must have the same `server_type`. :param product_predictions_iter: Iterable of Predictions. - :param server_domain: The server domain associated with the predictions. :return: The number of items imported in DB. """ timestamp = datetime.datetime.utcnow() existing_predictions = set( PredictionModel.select( PredictionModel.type, - PredictionModel.server_domain, + PredictionModel.server_type, PredictionModel.source_image, PredictionModel.value_tag, PredictionModel.value, PredictionModel.predictor, PredictionModel.automatic_processing, ) - .where(PredictionModel.barcode == barcode) + .where( + PredictionModel.barcode == barcode, + PredictionModel.server_type == server_type.name, + ) .tuples() ) @@ -1305,11 +1287,11 @@ def import_product_predictions( # when we could decide to replace old predictions of the same key. # It's not yet implemented. to_import = ( - create_prediction_model(prediction, server_domain, timestamp) + create_prediction_model(prediction, timestamp) for prediction in product_predictions_iter if ( prediction.type, - server_domain, + prediction.server_type.name, prediction.source_image, prediction.value_tag, prediction.value, @@ -1335,22 +1317,25 @@ def import_product_predictions( def import_insights( predictions: Iterable[Prediction], - server_domain: str, + server_type: ServerType, product_store: Optional[DBProductStore] = None, ) -> InsightImportResult: """Import predictions and generate (and import) insights from these predictions. :param predictions: an iterable of Predictions to import + :param server_type: the server type (project) of the product + :param product_store: a ProductStore to use, by defaults + DBProductStore (MongoDB-based product store) is used. """ if product_store is None: - product_store = get_product_store() + product_store = get_product_store(server_type) updated_prediction_types_by_barcode, prediction_import_results = import_predictions( - predictions, product_store, server_domain + predictions, product_store, server_type ) product_insight_import_results = import_insights_for_products( - updated_prediction_types_by_barcode, server_domain, product_store + updated_prediction_types_by_barcode, product_store, server_type ) return InsightImportResult( product_insight_import_results=product_insight_import_results, @@ -1360,15 +1345,15 @@ def import_insights( def import_insights_for_products( prediction_types_by_barcode: dict[str, set[PredictionType]], - server_domain: str, product_store: DBProductStore, + server_type: ServerType, ) -> list[ProductInsightImportResult]: """Re-compute insights for products with new predictions. 
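
    Sketch of a typical call (barcode and prediction type are illustrative):

        from robotoff.products import get_product_store
        from robotoff.types import PredictionType, ServerType

        results = import_insights_for_products(
            {"3560070880973": {PredictionType.category}},
            product_store=get_product_store(ServerType.off),
            server_type=ServerType.off,
        )
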
:param prediction_types_by_barcode: a dict that associates each barcode - with a set of prediction type that were updated - :param server_domain: The server domain associated with the predictions + with a set of prediction type that were updated :param product_store: The product store to use + :param server_type: the server type (project) of the product :return: Number of imported insights """ @@ -1384,7 +1369,7 @@ def import_insights_for_products( predictions = [ Prediction(**p) for p in get_product_predictions( - selected_barcodes, list(required_prediction_types) + selected_barcodes, server_type, list(required_prediction_types) ) ] @@ -1392,19 +1377,23 @@ def import_insights_for_products( sorted(predictions, key=operator.attrgetter("barcode")), operator.attrgetter("barcode"), ): + product_id = ProductIdentifier(barcode, server_type) try: - with Lock(name=f"robotoff:import:{barcode}", expire=60, timeout=10): + with Lock( + name=f"robotoff:import:{product_id.server_type.name}:{product_id.barcode}", + expire=60, + timeout=10, + ): result = importer.import_insights( - barcode, + product_id, list(product_predictions), - server_domain, product_store, ) import_results.append(result) except LockedResourceException: logger.info( - "Couldn't acquire insight import lock, skipping insight import for product %s", - barcode, + "Couldn't acquire insight import lock, skipping insight import for %s", + product_id, ) continue return import_results @@ -1413,13 +1402,12 @@ def import_insights_for_products( def import_predictions( predictions: Iterable[Prediction], product_store: DBProductStore, - server_domain: str, + server_type: ServerType, ) -> tuple[dict[str, set[PredictionType]], list[PredictionImportResult]]: """Check validity and import provided Prediction. :param predictions: the Predictions to import :param product_store: The product store to use - :param server_domain: The server domain associated with the predictions :return: dict associating each barcode with prediction types that where updated in order to re-compute associated insights """ @@ -1429,7 +1417,7 @@ def import_predictions( if ( # If product validity check is disable, all predictions are valid settings.DISABLE_PRODUCT_CHECK - or is_valid_product_prediction(p, product_store[p.barcode]) # type: ignore + or is_valid_product_prediction(p, product_store[ProductIdentifier(p.barcode, server_type)]) # type: ignore ) ] @@ -1441,10 +1429,12 @@ def import_predictions( ): product_predictions_group = list(product_predictions_iter) predictions_imported = import_product_predictions( - barcode, product_predictions_group, server_domain + barcode, server_type, product_predictions_group ) predictions_import_results.append( - PredictionImportResult(created=predictions_imported, barcode=barcode) + PredictionImportResult( + created=predictions_imported, barcode=barcode, server_type=server_type + ) ) updated_prediction_types_by_barcode[barcode] = set( prediction.type for prediction in product_predictions_group @@ -1453,8 +1443,7 @@ def import_predictions( def refresh_insights( - barcode: str, - server_domain: str, + product_id: ProductIdentifier, product_store: Optional[DBProductStore] = None, ) -> list[InsightImportResult]: """Refresh all insights for specific product. @@ -1467,15 +1456,17 @@ def refresh_insights( predictions. It's useful to refresh insights after an Product Opener update (some insights may be invalid). - :param barcode: Barcode of the product. - :param server_domain: The server domain associated with the predictions. 
+    :param product_id: identifier of the product
     :param product_store: The product store to use, defaults to None
     :return: the list of insight import results
     """
     if product_store is None:
-        product_store = get_product_store()
+        product_store = get_product_store(product_id.server_type)

-    predictions = [Prediction(**p) for p in get_product_predictions([barcode])]
+    predictions = [
+        Prediction(**p)
+        for p in get_product_predictions([product_id.barcode], product_id.server_type)
+    ]
     prediction_types = set(p.type for p in predictions)

     import_results = []
@@ -1483,9 +1474,8 @@ def refresh_insights(
         required_prediction_types = importer.get_required_prediction_types()
         if prediction_types >= required_prediction_types:
             import_result = importer.import_insights(
-                barcode,
+                product_id,
                 [p for p in predictions if p.type in required_prediction_types],
-                server_domain,
                 product_store,
             )
             import_results.append(import_result)
@@ -1493,9 +1483,14 @@ def get_product_predictions(
 def get_product_predictions(
-    barcodes: list[str], prediction_types: Optional[list[str]] = None
+    barcodes: list[str],
+    server_type: ServerType,
+    prediction_types: Optional[list[str]] = None,
 ) -> Iterator[dict]:
-    where_clauses = [PredictionModel.barcode.in_(barcodes)]
+    where_clauses = [
+        PredictionModel.barcode.in_(barcodes),
+        PredictionModel.server_type == server_type.name,
+    ]
     if prediction_types is not None:
         where_clauses.append(PredictionModel.type.in_(prediction_types))
diff --git a/robotoff/insights/question.py b/robotoff/insights/question.py
index fa309e96e6..9b014b2126 100644
--- a/robotoff/insights/question.py
+++ b/robotoff/insights/question.py
@@ -4,11 +4,10 @@

 from robotoff import settings
 from robotoff.models import ProductInsight
-from robotoff.mongo import MONGO_CLIENT_CACHE
 from robotoff.off import generate_image_url
 from robotoff.products import get_product
 from robotoff.taxonomy import Taxonomy, TaxonomyType, get_taxonomy
-from robotoff.types import InsightType, JSONType
+from robotoff.types import InsightType, JSONType, ProductIdentifier
 from robotoff.utils import get_logger, load_json
 from robotoff.utils.i18n import TranslationStore

@@ -23,13 +22,13 @@


 def generate_selected_images(
-    images: JSONType, barcode: str
+    images: JSONType, product_id: ProductIdentifier
 ) -> dict[str, dict[str, dict[str, str]]]:
     """Generate the same `selected_images` field as returned by Product Opener API.

     :param images: the `images` data of the product
-    :param barcode: the product barcode
+    :param product_id: identifier of the product
     :return: the `selected_images` data
     """
     selected_images: dict[str, dict[str, dict[str, str]]] = {
@@ -60,7 +59,7 @@ def generate_selected_images(
         ):
             if image_size in available_image_sizes:
                 image_url = generate_image_url(
-                    barcode, f"{key}.{revision_id}.{image_size}"
+                    product_id, f"{key}.{revision_id}.{image_size}"
                 )
                 selected_images[image_type].setdefault(field_name, {})
                 selected_images[image_type][field_name][language] = image_url
@@ -69,7 +68,7 @@


 def get_source_image_url(
-    barcode: str, field_types: Optional[list[str]] = None
+    product_id: ProductIdentifier, field_types: Optional[list[str]] = None
 ) -> Optional[str]:
     """Generate the URL of a generic image to display for an insight.

@@ -77,7 +76,7 @@ def get_source_image_url(
     language of the following types ("front", "ingredients", "nutrition"),
     and use this image to generate the image URL.
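
    Sketch of a call (illustrative barcode):

        from robotoff.types import ProductIdentifier, ServerType

        product_id = ProductIdentifier("3560070880973", ServerType.off)
        # only consider the front image; None is returned if no suitable
        # selected image exists
        url = get_source_image_url(product_id, field_types=["front"])
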
- :param barcode: the barcode of the product + :param product_id: identifier of the product :param field_types: the image field types to check. If not provided, we use ["front", "ingredients", "nutrition"] :return: The image URL or None if no suitable image has been found @@ -85,12 +84,12 @@ def get_source_image_url( if field_types is None: field_types = ["front", "ingredients", "nutrition"] - product: Optional[JSONType] = get_product(barcode, ["images"]) + product: Optional[JSONType] = get_product(product_id, ["images"]) if product is None or "images" not in product: return None - selected_images = generate_selected_images(product["images"], barcode) + selected_images = generate_selected_images(product["images"], product_id) for key in field_types: if key in selected_images: @@ -143,6 +142,7 @@ def __init__( self.insight_id: str = str(insight.id) self.insight_type: str = str(insight.type) self.barcode: str = insight.barcode + self.server_type: str = insight.server_type self.ref_image_url: Optional[str] = ref_image_url self.source_image_url: Optional[str] = source_image_url self.value_tag: Optional[str] = value_tag @@ -153,6 +153,7 @@ def get_type(self): def serialize(self) -> JSONType: serial = { "barcode": self.barcode, + "server_type": self.server_type, "type": self.get_type(), "value": self.value, "question": self.question, @@ -177,6 +178,7 @@ def __init__(self, insight: ProductInsight, ref_image_url: Optional[str]): self.insight_id: str = str(insight.id) self.insight_type: str = str(insight.type) self.barcode: str = insight.barcode + self.server_type: str = insight.server_type self.corrected: str = insight.data["corrected"] self.text: str = insight.data["text"] self.corrections: list[JSONType] = insight.data["corrections"] @@ -189,6 +191,7 @@ def get_type(self): def serialize(self) -> JSONType: serial = { "barcode": self.barcode, + "server_type": self.server_type, "type": self.get_type(), "insight_id": self.insight_id, "insight_type": self.insight_type, @@ -220,7 +223,7 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: taxonomy: Taxonomy = get_taxonomy(TaxonomyType.category.name) localized_value: str = taxonomy.get_localized_name(insight.value_tag, lang) localized_question = self.translation_store.gettext(lang, self.question) - source_image_url = get_source_image_url(insight.barcode) + source_image_url = get_source_image_url(insight.get_product_id()) return AddBinaryQuestion( question=localized_question, value=localized_value, @@ -231,13 +234,13 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: @staticmethod def generate_selected_images( - images: JSONType, barcode: str + images: JSONType, product_id: ProductIdentifier ) -> dict[str, dict[str, dict[str, str]]]: """Generate the same `selected_images` field as returned by Product Opener API. 
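
    Sketch with a trimmed `images` dict, assuming the usual Product Opener
    layout (a revision number plus per-size entries):

        from robotoff.types import ProductIdentifier, ServerType

        product_id = ProductIdentifier("3560070880973", ServerType.off)
        images = {
            "front_fr": {
                "rev": "42",
                "sizes": {"100": {}, "200": {}, "400": {}, "full": {}},
            }
        }
        selected = generate_selected_images(images, product_id)
        # e.g. selected["front"]["display"]["fr"] is the URL built from
        # generate_image_url(product_id, "front_fr.42.400")
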
:param images: the `images` data of the product - :param barcode: the product barcode + :param product_id: the server type (project) of the product :return: the `selected_images` data """ selected_images: dict[str, dict[str, dict[str, str]]] = { @@ -268,7 +271,7 @@ def generate_selected_images( ): if image_size in available_image_sizes: image_url = generate_image_url( - barcode, f"{key}.{revision_id}.{image_size}" + product_id, f"{key}.{revision_id}.{image_size}" ) selected_images[image_type].setdefault(field_name, {}) selected_images[image_type][field_name][language] = image_url @@ -283,9 +286,10 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: localized_question = self.translation_store.gettext(lang, self.question) source_image_url = None + server_type = insight.get_product_id().server_type if insight.source_image: source_image_url = settings.BaseURLProvider.image_url( - get_display_image(insight.source_image) + server_type, get_display_image(insight.source_image) ) return AddBinaryQuestion( @@ -308,9 +312,10 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: localized_question = self.translation_store.gettext(lang, self.question) source_image_url = None + server_type = insight.get_product_id().server_type if insight.source_image: source_image_url = settings.BaseURLProvider.image_url( - get_display_image(insight.source_image) + server_type, get_display_image(insight.source_image) ) return AddBinaryQuestion( @@ -349,9 +354,10 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: localized_question = self.translation_store.gettext(lang, self.question) source_image_url = None + server_type = insight.get_product_id().server_type if insight.source_image: source_image_url = settings.BaseURLProvider.image_url( - get_display_image(insight.source_image) + server_type, get_display_image(insight.source_image) ) return AddBinaryQuestion( @@ -372,12 +378,13 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: if insight.predictor in ("curated-list", "taxonomy", "whitelisted-brands"): # Use front image as default for flashtext-brand insights source_image_url = get_source_image_url( - insight.barcode, field_types=["front"] + insight.get_product_id(), field_types=["front"] ) if source_image_url is None and insight.source_image: + server_type = insight.get_product_id().server_type source_image_url = settings.BaseURLProvider.image_url( - get_display_image(insight.source_image) + server_type, get_display_image(insight.source_image) ) return AddBinaryQuestion( @@ -391,15 +398,15 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: class IngredientSpellcheckQuestionFormatter(QuestionFormatter): def format_question(self, insight: ProductInsight, lang: str) -> Question: - ref_image_url = self.get_ingredient_image_url(insight.barcode, lang) + ref_image_url = self.get_ingredient_image_url(insight.get_product_id(), lang) return IngredientSpellcheckQuestion( insight=insight, ref_image_url=ref_image_url ) - def get_ingredient_image_url(self, barcode: str, lang: str) -> Optional[str]: - mongo_client = MONGO_CLIENT_CACHE.get() - collection = mongo_client.off.products - product = collection.find_one({"code": barcode}, ["images"]) + def get_ingredient_image_url( + self, product_id: ProductIdentifier, lang: str + ) -> Optional[str]: + product = get_product(product_id, ["images"]) if product is None: return None @@ -410,7 +417,7 @@ def get_ingredient_image_url(self, barcode: str, lang: str) -> 
Optional[str]: if field_name in images: image = images[field_name] image_id = "ingredients_{}.{}.full".format(lang, image["rev"]) - return generate_image_url(barcode, image_id) + return generate_image_url(product_id, image_id) return None @@ -422,9 +429,10 @@ def format_question(self, insight: ProductInsight, lang: str) -> Question: localized_question = self.translation_store.gettext(lang, self.question) source_image_url = None + product_id = insight.get_product_id() if insight.source_image: source_image_url = settings.BaseURLProvider.image_url( - get_display_image(insight.source_image) + product_id.server_type, get_display_image(insight.source_image) ) return AddBinaryQuestion( diff --git a/robotoff/logos.py b/robotoff/logos.py index 45c53942fe..7bc31596a8 100644 --- a/robotoff/logos.py +++ b/robotoff/logos.py @@ -26,6 +26,7 @@ LogoLabelType, Prediction, PredictionType, + ServerType, ) from robotoff.utils import get_logger from robotoff.utils.text import get_tag @@ -301,8 +302,8 @@ def get_weights(dist: np.ndarray, weights: str = "uniform"): def import_logo_insights( logos: list[LogoAnnotation], - server_domain: str, thresholds: dict[LogoLabelType, float], + server_type: ServerType, default_threshold: float = 0.1, notify: bool = True, ) -> InsightImportResult: @@ -342,8 +343,8 @@ def import_logo_insights( # Add a filter on barcode to speed-up filtering & (PredictionModel.barcode.in_([logo.barcode for logo in logos])) ).execute() - predictions = predict_logo_predictions(selected_logos, logo_probs) - import_result = import_insights(predictions, server_domain) + predictions = predict_logo_predictions(selected_logos, logo_probs, server_type) + import_result = import_insights(predictions, server_type) if notify: for logo, probs in zip(selected_logos, logo_probs): @@ -353,7 +354,7 @@ def import_logo_insights( def generate_insights_from_annotated_logos_job( - logo_ids: list[int], server_domain: str, auth: OFFAuthentication + logo_ids: list[int], auth: OFFAuthentication, server_type: ServerType ): """Wrap generate_insights_from_annotated_logos function into a python-rq compatible job.""" @@ -361,11 +362,11 @@ def generate_insights_from_annotated_logos_job( logos = list(LogoAnnotation.select().where(LogoAnnotation.id.in_(logo_ids))) if logos: - generate_insights_from_annotated_logos(logos, server_domain, auth) + generate_insights_from_annotated_logos(logos, auth, server_type) def generate_insights_from_annotated_logos( - logos: list[LogoAnnotation], server_domain: str, auth: OFFAuthentication + logos: list[LogoAnnotation], auth: OFFAuthentication, server_type: ServerType ) -> int: """Generate and apply insights from annotated logos.""" predictions = [] @@ -381,6 +382,7 @@ def generate_insights_from_annotated_logos( "is_annotation": True, # it's worth restating it }, confidence=1.0, + server_type=server_type, ) if prediction is None: @@ -390,7 +392,7 @@ def generate_insights_from_annotated_logos( prediction.source_image = logo.source_image predictions.append(prediction) - import_result = import_insights(predictions, server_domain) + import_result = import_insights(predictions, server_type) if import_result.created_predictions_count(): logger.info(import_result) @@ -399,10 +401,12 @@ def generate_insights_from_annotated_logos( insight_import_result.insight_created_ids for insight_import_result in import_result.product_insight_import_results ): - insight = ProductInsight.get_or_none(id=created_id) + insight: Optional[ProductInsight] = ProductInsight.get_or_none(id=created_id) if insight: logger.info( 
- "Annotating insight %s (product: %s)", insight.id, insight.barcode + "Annotating insight %s (%s)", + insight.id, + insight.get_product_id(), ) annotation_result = annotate(insight, 1, auth=auth) annotated += int(annotation_result == UPDATED_ANNOTATION_RESULT) @@ -411,7 +415,9 @@ def generate_insights_from_annotated_logos( def predict_logo_predictions( - logos: list[LogoAnnotation], logo_probs: list[dict[LogoLabelType, float]] + logos: list[LogoAnnotation], + logo_probs: list[dict[LogoLabelType, float]], + server_type: ServerType, ) -> list[Prediction]: predictions = [] @@ -436,6 +442,7 @@ def predict_logo_predictions( "logo_id": logo.id, "bounding_box": logo.bounding_box, }, + server_type=server_type, ) if prediction is not None: @@ -451,6 +458,7 @@ def generate_prediction( logo_value: Optional[str], data: dict, confidence: float, + server_type: ServerType, automatic_processing: Optional[bool] = False, ) -> Optional[Prediction]: """Generate a Prediction from a logo. @@ -485,10 +493,13 @@ def generate_prediction( predictor="universal-logo-detector", data=data, confidence=confidence, + server_type=server_type, ) -def refresh_nearest_neighbors(day_offset: int = 7, batch_size: int = 500): +def refresh_nearest_neighbors( + server_type: ServerType, day_offset: int = 7, batch_size: int = 500 +): """Refresh each logo nearest neighbors if the last refresh is more than `day_offset` days old.""" sql_query = """ @@ -527,10 +538,7 @@ def refresh_nearest_neighbors(day_offset: int = 7, batch_size: int = 500): else: logos = [embedding.logo for embedding in logo_embeddings] import_logo_insights( - logos, - thresholds=thresholds, - server_domain=settings.BaseURLProvider.server_domain(), - notify=False, + logos, thresholds=thresholds, server_type=server_type, notify=False ) logger.info("refresh of logo nearest neighbors finished") diff --git a/robotoff/metrics.py b/robotoff/metrics.py index 4ea29bbc9e..c4b29be064 100644 --- a/robotoff/metrics.py +++ b/robotoff/metrics.py @@ -11,6 +11,7 @@ from robotoff import settings from robotoff.models import ProductInsight, with_db +from robotoff.types import ServerType from robotoff.utils import get_logger, http_session logger = get_logger(__name__) @@ -97,24 +98,28 @@ def ensure_influx_database(): logger.exception("Error on ensure_influx_database") -def get_product_count(country_tag: str) -> int: +def get_product_count(server_type: ServerType, country_tag: str) -> int: """Return the number of products in Product Opener for a specific country. 
:param country_tag: ISO 2-letter country code :return: the number of products currently in Product Opener """ r = http_session.get( - settings.BaseURLProvider.country(country_tag) + "/3.json?fields=null", + settings.BaseURLProvider.country(server_type, country_tag) + + "/3.json?fields=null", auth=settings._off_request_auth, ).json() return int(r["count"]) def save_facet_metrics(): + # Only support for off for now + server_type = ServerType.off inserts = [] target_datetime = datetime.datetime.now() product_counts = { - country_tag: get_product_count(country_tag) for country_tag in COUNTRY_TAGS + country_tag: get_product_count(server_type, country_tag) + for country_tag in COUNTRY_TAGS } for country_tag in COUNTRY_TAGS: @@ -122,10 +127,11 @@ def save_facet_metrics(): for url_path in URL_PATHS: inserts += generate_metrics_from_path( - country_tag, url_path, target_datetime, count + server_type, country_tag, url_path, target_datetime, count ) inserts += generate_metrics_from_path( + server_type, country_tag, "/entry-date/{}/contributors?json=1".format( # get contribution metrics for the previous day @@ -135,7 +141,9 @@ def save_facet_metrics(): facet="contributors", ) - inserts += generate_metrics_from_path("world", "/countries?json=1", target_datetime) + inserts += generate_metrics_from_path( + server_type, "world", "/countries?json=1", target_datetime + ) client = get_influx_client() if client is not None: write_client = client.write_api(write_options=SYNCHRONOUS) @@ -147,6 +155,7 @@ def get_facet_name(url: str) -> str: def generate_metrics_from_path( + server_type: ServerType, country_tag: str, path: str, target_datetime: datetime.datetime, @@ -154,7 +163,7 @@ def generate_metrics_from_path( facet: Optional[str] = None, ) -> list[dict]: inserts: list[dict] = [] - url = settings.BaseURLProvider.country(country_tag + "-en") + path + url = settings.BaseURLProvider.country(server_type, country_tag + "-en") + path if facet is None: facet = get_facet_name(url) @@ -216,6 +225,7 @@ def save_insight_metrics(): - automatic_processing - predictor - reserved_barcode + - server_type """ target_datetime = datetime.datetime.now() @@ -233,6 +243,7 @@ def generate_insight_metrics(target_datetime: datetime.datetime) -> list[dict]: ProductInsight.automatic_processing, ProductInsight.predictor, ProductInsight.reserved_barcode, + ProductInsight.server_type, ] inserts = [] query_results = ( diff --git a/robotoff/models.py b/robotoff/models.py index e32e2c0e7f..51dbac2f7c 100644 --- a/robotoff/models.py +++ b/robotoff/models.py @@ -9,6 +9,7 @@ from playhouse.shortcuts import model_to_dict from robotoff import settings +from robotoff.types import ProductIdentifier, ServerType db = PostgresqlExtDatabase( settings.POSTGRES_DB, @@ -52,9 +53,11 @@ def batch_insert(model_cls, data: Iterable[dict], batch_size=100) -> int: def crop_image_url( - source_image: str, bounding_box: tuple[float, float, float, float] + server_type: ServerType, + source_image: str, + bounding_box: tuple[float, float, float, float], ) -> str: - base_url = settings.BaseURLProvider.image_url(source_image) + base_url = settings.BaseURLProvider.image_url(server_type, source_image) y_min, x_min, y_max, x_max = bounding_box base_robotoff_url = settings.BaseURLProvider.robotoff() return f"{base_robotoff_url}/api/v1/images/crop?image_url={base_url}&y_min={y_min}&x_min={x_min}&y_max={y_max}&x_max={x_max}" @@ -138,13 +141,10 @@ class ProductInsight(BaseModel): # pre-determined threshold. 
automatic_processing = peewee.BooleanField(default=False, index=True) - server_domain = peewee.TextField( - null=True, help_text="server domain linked to the insight", index=True - ) server_type = peewee.CharField( null=True, max_length=10, - help_text="project associated with the server_domain, " + help_text="project associated with the insight, " "one of 'off', 'obf', 'opff', 'opf'", index=True, ) @@ -165,6 +165,9 @@ class ProductInsight(BaseModel): # Confidence score of the insight, may be null confidence = peewee.FloatField(null=True, index=True) + def get_product_id(self) -> ProductIdentifier: + return ProductIdentifier(self.barcode, ServerType[self.server_type]) + class Prediction(BaseModel): barcode = peewee.CharField(max_length=100, null=False, index=True) @@ -175,11 +178,19 @@ class Prediction(BaseModel): value = peewee.TextField(null=True) source_image = peewee.TextField(null=True, index=True) automatic_processing = peewee.BooleanField(null=True) - server_domain = peewee.TextField( - help_text="server domain linked to the insight", index=True - ) predictor = peewee.CharField(max_length=100, null=True) confidence = peewee.FloatField(null=True, index=False) + server_type = peewee.CharField( + null=False, + max_length=10, + help_text="project associated with the insight, " + "one of 'off', 'obf', 'opff', 'opf'", + index=True, + default="off", + ) + + def get_product_id(self) -> ProductIdentifier: + return ProductIdentifier(self.barcode, ServerType[self.server_type]) class AnnotationVote(BaseModel): @@ -211,12 +222,14 @@ class ImageModel(BaseModel): width = peewee.IntegerField(null=False, index=True) height = peewee.IntegerField(null=False, index=True) deleted = peewee.BooleanField(null=False, index=True, default=False) - server_domain = peewee.TextField(null=True, index=True) server_type = peewee.CharField(null=True, max_length=10, index=True) class Meta: table_name = "image" + def get_product_id(self) -> ProductIdentifier: + return ProductIdentifier(self.barcode, ServerType[self.server_type]) + class ImagePrediction(BaseModel): """Table to store computer vision predictions (object detection, @@ -279,7 +292,14 @@ class Meta: constraints = [peewee.SQL("UNIQUE(image_prediction_id, index)")] def get_crop_image_url(self) -> str: - return crop_image_url(self.source_image, self.bounding_box) + return crop_image_url( + self.get_server_type(), + self.source_image, + self.bounding_box, + ) + + def get_server_type(self) -> ServerType: + return ServerType[self.image_prediction.image.server_type] class LogoEmbedding(BaseModel): diff --git a/robotoff/mongo.py b/robotoff/mongo.py deleted file mode 100644 index 61e936f088..0000000000 --- a/robotoff/mongo.py +++ /dev/null @@ -1,11 +0,0 @@ -from pymongo import MongoClient - -from robotoff import settings -from robotoff.utils.cache import CachedStore - - -def get_mongo_client() -> MongoClient: - return MongoClient(settings.MONGO_URI, serverSelectionTimeoutMS=10_000) - - -MONGO_CLIENT_CACHE = CachedStore(get_mongo_client, expiration_interval=None) diff --git a/robotoff/off.py b/robotoff/off.py index 15e293c4c2..e95acbe777 100644 --- a/robotoff/off.py +++ b/robotoff/off.py @@ -1,14 +1,14 @@ """Interacting with OFF server to eg. 
update products or get infos """ -import enum import re from pathlib import Path -from typing import Optional, Union +from typing import Optional from urllib.parse import urlparse import requests from robotoff import settings +from robotoff.types import ProductIdentifier, ServerType from robotoff.utils import get_logger, http_session logger = get_logger(__name__) @@ -65,21 +65,6 @@ def get_username(self) -> Optional[str]: return None -class ServerType(enum.Enum): - off = 1 - obf = 2 - opff = 3 - opf = 4 - - -API_URLS: dict[ServerType, str] = { - ServerType.off: settings.BaseURLProvider.world(), - ServerType.obf: "https://world.openbeautyfacts.org", - ServerType.opf: "https://world.openproductfacts.org", - ServerType.opff: "https://world.openpetfoodfacts.org", -} - - BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$") @@ -112,47 +97,16 @@ def get_barcode_from_path(path: str) -> Optional[str]: return barcode or None -def get_product_image_select_url(server: Union[ServerType, str]) -> str: - return "{}/cgi/product_image_crop.pl".format(get_base_url(server)) +def get_product_image_select_url(server_type: ServerType) -> str: + base_url = settings.BaseURLProvider.api(server_type) + return f"{base_url}/cgi/product_image_crop.pl" -def get_api_product_url(server: Union[ServerType, str]) -> str: +def get_api_product_url(server_type: ServerType) -> str: # V2 of API is required to have proper ingredient nesting # for product categorization - return "{}/api/v2/product".format(get_base_url(server)) - - -def get_base_url(server: Union[ServerType, str]) -> str: - if isinstance(server, str): - server = server.replace("api", "world") - # get scheme, https on prod, but http in dev - scheme = settings._get_default_scheme() - return f"{scheme}://{server}" - else: - if server not in API_URLS: - raise ValueError("unsupported server type: {}".format(server)) - - return API_URLS[server] - - -def get_server_type(server_domain: str) -> ServerType: - """Return the server type (off, obf, opff, opf) associated with the server - domain, or None if the server_domain was not recognized.""" - server_split = server_domain.split(".") - - if len(server_split) == 3: - subdomain, domain, tld = server_split - - if domain == "openfoodfacts": - return ServerType.off - elif domain == "openbeautyfacts": - return ServerType.obf - elif domain == "openpetfoodfacts": - return ServerType.opff - elif domain == "openproductsfacts": - return ServerType.opf - - raise ValueError("unknown server domain: {}".format(server_domain)) + base_url = settings.BaseURLProvider.api(server_type) + return f"{base_url}/api/v2/product" def split_barcode(barcode: str) -> list[str]: @@ -177,19 +131,21 @@ def generate_json_path(barcode: str, image_id: str) -> str: return "/{}/{}.json".format("/".join(splitted_barcode), image_id) -def generate_json_ocr_url(barcode: str, image_id: str) -> str: +def generate_json_ocr_url(product_id: ProductIdentifier, image_id: str) -> str: return ( - settings.BaseURLProvider.static() - + f"/images/products{generate_json_path(barcode, image_id)}" + settings.BaseURLProvider.static(product_id.server_type) + + f"/images/products{generate_json_path(product_id.barcode, image_id)}" ) -def generate_image_url(barcode: str, image_id: str) -> str: - return settings.BaseURLProvider.image_url(generate_image_path(barcode, image_id)) +def generate_image_url(product_id: ProductIdentifier, image_id: str) -> str: + return settings.BaseURLProvider.image_url( + product_id.server_type, generate_image_path(product_id.barcode, image_id) + ) -def 
is_valid_image(barcode: str, image_id: str) -> bool: - product = get_product(barcode, fields=["images"]) +def is_valid_image(product_id: ProductIdentifier, image_id: str) -> bool: + product = get_product(product_id, fields=["images"]) if product is None: return False @@ -204,17 +160,15 @@ def off_credentials() -> dict[str, str]: def get_product( - barcode: str, + product_id: ProductIdentifier, fields: Optional[list[str]] = None, - server: Optional[Union[ServerType, str]] = None, timeout: Optional[int] = 10, ) -> Optional[dict]: fields = fields or [] - if server is None: - server = ServerType.off - - url = get_api_product_url(server) + "/{}.json".format(barcode) + url = get_api_product_url(product_id.server_type) + "/{}.json".format( + product_id.barcode + ) if fields: # requests escape comma in URLs, as expected, but openfoodfacts server @@ -236,7 +190,10 @@ def get_product( def add_category( - barcode: str, category: str, insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + category: str, + insight_id: Optional[str] = None, + **kwargs, ): comment = "[robotoff] Adding category '{}'".format(category) @@ -244,15 +201,18 @@ def add_category( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "add_categories": category, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def update_quantity( - barcode: str, quantity: str, insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + quantity: str, + insight_id: Optional[str] = None, + **kwargs, ): comment = "[robotoff] Updating quantity to '{}'".format(quantity) @@ -260,15 +220,18 @@ def update_quantity( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "quantity": quantity, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def update_emb_codes( - barcode: str, emb_codes: list[str], insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + emb_codes: list[str], + insight_id: Optional[str] = None, + **kwargs, ): emb_codes_str = ",".join(emb_codes) @@ -278,15 +241,18 @@ def update_emb_codes( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "emb_codes": emb_codes_str, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def update_expiration_date( - barcode: str, expiration_date: str, insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + expiration_date: str, + insight_id: Optional[str] = None, + **kwargs, ): comment = "[robotoff] Adding expiration date '{}'".format(expiration_date) @@ -294,15 +260,18 @@ def update_expiration_date( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "expiration_date": expiration_date, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def add_label_tag( - barcode: str, label_tag: str, insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + label_tag: str, + insight_id: Optional[str] = None, + **kwargs, ): comment = "[robotoff] Adding label tag '{}'".format(label_tag) @@ -310,43 +279,56 @@ def add_label_tag( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, 
"add_labels": label_tag, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) -def add_brand(barcode: str, brand: str, insight_id: Optional[str] = None, **kwargs): +def add_brand( + product_id: ProductIdentifier, + brand: str, + insight_id: Optional[str] = None, + **kwargs, +): comment = "[robotoff] Adding brand '{}'".format(brand) if insight_id: comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "add_brands": brand, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) -def add_store(barcode: str, store: str, insight_id: Optional[str] = None, **kwargs): +def add_store( + product_id: ProductIdentifier, + store: str, + insight_id: Optional[str] = None, + **kwargs, +): comment = "[robotoff] Adding store '{}'".format(store) if insight_id: comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "add_stores": store, "comment": comment, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def add_packaging( - barcode: str, packaging: dict, insight_id: Optional[str] = None, **kwargs + product_id: ProductIdentifier, + packaging: dict, + insight_id: Optional[str] = None, + **kwargs, ): shape_value_tag = packaging["shape"]["value_tag"] comment = f"[robotoff] Updating/adding packaging elements '{shape_value_tag}'" @@ -367,11 +349,13 @@ def add_packaging( "fields": "none", "comment": comment, } - update_product_v3(barcode, body, **kwargs) + update_product_v3( + product_id.barcode, body, server_type=product_id.server_type, **kwargs + ) def save_ingredients( - barcode: str, + product_id: ProductIdentifier, ingredient_text: str, insight_id: Optional[str] = None, lang: Optional[str] = None, @@ -389,23 +373,21 @@ def save_ingredients( comment += ", ID: {}".format(insight_id) params = { - "code": barcode, + "code": product_id.barcode, "comment": comment, ingredient_key: ingredient_text, } - update_product(params, **kwargs) + update_product(params, server_type=product_id.server_type, **kwargs) def update_product( params: dict, - server_domain: Optional[str] = None, + server_type: ServerType, auth: Optional[OFFAuthentication] = None, timeout: Optional[int] = 15, ): - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - url = f"{get_base_url(server_domain)}/cgi/product_jqm2.pl" + base_url = settings.BaseURLProvider.api(server_type) + url = f"{base_url}/cgi/product_jqm2.pl" comment = params.get("comment") cookies = None @@ -448,14 +430,12 @@ def update_product( def update_product_v3( barcode: str, body: dict, - server_domain: Optional[str] = None, + server_type: ServerType, auth: Optional[OFFAuthentication] = None, timeout: Optional[int] = 15, ): - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - url = f"{get_base_url(server_domain)}/api/v3/product/{barcode}" + base_url = settings.BaseURLProvider.api(server_type) + url = f"{base_url}/api/v3/product/{barcode}" comment = body.get("comment") cookies = None @@ -493,14 +473,20 @@ def update_product_v3( raise ValueError("Errors during product update: %s", str(json["errors"])) -def move_to(barcode: str, to: ServerType, timeout: Optional[int] = 10) -> bool: - if get_product(barcode, server=to) is not None: +def move_to( + product_id: ProductIdentifier, to: ServerType, timeout: Optional[int] = 10 +) -> 
bool: + if ( + get_product(ProductIdentifier(barcode=product_id.barcode, server_type=to)) + is not None + ): return False - url = "{}/cgi/product_jqm.pl".format(settings.BaseURLProvider.world()) + base_url = settings.BaseURLProvider.api(product_id.server_type) + url = f"{base_url}/cgi/product_jqm.pl" params = { "type": "edit", - "code": barcode, + "code": product_id.barcode, "new_code": str(to), **off_credentials(), } @@ -510,24 +496,23 @@ def move_to(barcode: str, to: ServerType, timeout: Optional[int] = 10) -> bool: def delete_image_pipeline( - barcode: str, + product_id: ProductIdentifier, image_id: str, auth: OFFAuthentication, - server_domain: Optional[str] = None, ) -> None: """Delete an image and unselect all selected images that have this image as image ID. - :param barcode: barcode of the product + :param product_id: identifier of the product :param image_id: ID of the image to delete (number) :param auth: user authentication data :param server_domain: the server domain to use, default to BaseURLProvider.server_domain() """ - product = get_product(barcode, ["images"], server_domain) + product = get_product(product_id, ["images"]) if product is None: - logger.info("Product %s not found, cannot delete image %s", barcode, image_id) + logger.info("%s not found, cannot delete image %s", product_id, image_id) return None to_delete = False @@ -544,30 +529,27 @@ def delete_image_pipeline( to_unselect.append(image_field) if to_delete: - logger.info( - "Sending deletion request for image %s of product %s", image_id, barcode - ) - delete_image(barcode, image_id, auth, server_domain) + logger.info("Sending deletion request for image %s of %s", image_id, product_id) + delete_image(product_id, image_id, auth) for image_field in to_unselect: logger.info( - "Sending unselect request for image %s of product %s", image_field, barcode + "Sending unselect request for image %s of %s", image_field, product_id ) - unselect_image(barcode, image_field, auth, server_domain) + unselect_image(product_id, image_field, auth) logger.info("Image deletion pipeline completed") def unselect_image( - barcode: str, + product_id: ProductIdentifier, image_field: str, auth: OFFAuthentication, - server_domain: Optional[str] = None, timeout: Optional[int] = 15, ) -> requests.Response: """Unselect an image. - :param barcode: barcode of the product + :param product_id: identifier of the product :param image_field: field name of the image to unselect, ex: front_fr :param auth: user authentication data :param server_domain: the server domain to use, default to @@ -575,13 +557,11 @@ def unselect_image( :param timeout: request timeout value in seconds, defaults to 15s :return: the request Response """ - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - url = f"{get_base_url(server_domain)}/cgi/product_image_unselect.pl" + base_url = settings.BaseURLProvider.api(product_id.server_type) + url = f"{base_url}/cgi/product_image_unselect.pl" cookies = None params = { - "code": barcode, + "code": product_id.barcode, "id": image_field, } @@ -606,30 +586,26 @@ def unselect_image( def delete_image( - barcode: str, + product_id: ProductIdentifier, image_id: str, auth: OFFAuthentication, - server_domain: Optional[str] = None, timeout: Optional[int] = 15, ) -> requests.Response: """Delete an image on Product Opener. 
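
    Sketch of a call (placeholder credentials; OFFAuthentication is
    assumed here to accept username/password keyword arguments):

        from robotoff.off import OFFAuthentication, delete_image
        from robotoff.types import ProductIdentifier, ServerType

        product_id = ProductIdentifier("3560070880973", ServerType.off)
        auth = OFFAuthentication(username="user", password="password")
        delete_image(product_id, "2", auth)
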
- :param barcode: barcode of the product + :param product_id: identifier of the product :param image_id: ID of the image to delete (number) :param auth: user authentication data - :param server_domain: the server domain to use, default to - BaseURLProvider.server_domain() :param timeout: request timeout (in seconds), defaults to 15 :return: the requests Response """ - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - url = f"{get_base_url(server_domain)}/cgi/product_image_move.pl" + base_url = settings.BaseURLProvider.api(product_id.server_type) + url = f"{base_url}/cgi/product_image_move.pl" cookies = None params = { "type": "edit", - "code": barcode, + "code": product_id.barcode, "imgids": image_id, "action": "process", "move_to_override": "trash", @@ -661,21 +637,17 @@ def delete_image( def select_rotate_image( - barcode: str, + product_id: ProductIdentifier, image_id: str, image_key: Optional[str] = None, rotate: Optional[int] = None, - server_domain: Optional[str] = None, auth: Optional[OFFAuthentication] = None, timeout: Optional[int] = 15, ): - if server_domain is None: - server_domain = settings.BaseURLProvider.server_domain() - - url = get_product_image_select_url(server_domain) + url = get_product_image_select_url(product_id.server_type) cookies = None params = { - "code": barcode, + "code": product_id.barcode, "imgid": image_id, } diff --git a/robotoff/prediction/category/__init__.py b/robotoff/prediction/category/__init__.py index 4ee0700332..704455f848 100644 --- a/robotoff/prediction/category/__init__.py +++ b/robotoff/prediction/category/__init__.py @@ -2,7 +2,7 @@ from typing import Optional from robotoff.taxonomy import TaxonomyType, get_taxonomy -from robotoff.types import JSONType, NeuralCategoryClassifierModel +from robotoff.types import JSONType, NeuralCategoryClassifierModel, ProductIdentifier from .matcher import predict_by_lang from .neural.category_classifier import CategoryClassifier @@ -10,6 +10,7 @@ def predict_category( product: dict, + product_id: ProductIdentifier, neural_predictor: bool, matcher_predictor: bool, deepest_only: bool, @@ -45,7 +46,7 @@ def predict_category( taxonomy = get_taxonomy(TaxonomyType.category.name) if neural_predictor: predictions, debug = CategoryClassifier(taxonomy).predict( - product, deepest_only, threshold, neural_model_name + product, product_id, deepest_only, threshold, neural_model_name ) response["neural"] = { "predictions": [ diff --git a/robotoff/prediction/category/matcher.py b/robotoff/prediction/category/matcher.py index d936221afc..9a786aa461 100644 --- a/robotoff/prediction/category/matcher.py +++ b/robotoff/prediction/category/matcher.py @@ -10,7 +10,7 @@ from robotoff import settings from robotoff.products import ProductDataset from robotoff.taxonomy import TaxonomyType, get_taxonomy -from robotoff.types import Prediction, PredictionType +from robotoff.types import Prediction, PredictionType, ServerType from robotoff.utils import dump_json, get_logger, load_json from robotoff.utils.text import ( get_lemmatizing_nlp, @@ -440,6 +440,7 @@ def predict_from_dataset( for product in product_stream.iter(): for prediction in predict(product): prediction.barcode = product["code"] + prediction.server_type = ServerType.off yield prediction diff --git a/robotoff/prediction/category/neural/category_classifier.py b/robotoff/prediction/category/neural/category_classifier.py index b8bed87b93..6722af259f 100644 --- a/robotoff/prediction/category/neural/category_classifier.py +++ 
b/robotoff/prediction/category/neural/category_classifier.py
@@ -7,6 +7,7 @@
     NeuralCategoryClassifierModel,
     Prediction,
     PredictionType,
+    ProductIdentifier,
 )
 from robotoff.utils import get_logger
 
@@ -16,7 +17,11 @@
 
 
 def create_prediction(
-    category: str, confidence: float, model_version: str, **kwargs
+    category: str,
+    confidence: float,
+    model_version: str,
+    product_id: ProductIdentifier,
+    **kwargs
 ) -> Prediction:
     """Create a Prediction.
 
@@ -35,6 +40,8 @@ def create_prediction(
         automatic_processing=False,
         predictor="neural",
         confidence=confidence,
+        barcode=product_id.barcode,
+        server_type=product_id.server_type,
     )
 
 
@@ -53,6 +60,7 @@ def __init__(self, category_taxonomy: Taxonomy):
     def predict(
         self,
         product: dict,
+        product_id: ProductIdentifier,
         deepest_only: bool = False,
         threshold: Optional[float] = None,
         model_name: Optional[NeuralCategoryClassifierModel] = None,
@@ -115,7 +123,7 @@ def predict(
             # Otherwise we fetch OCR texts from Product Opener
             # Only fetch OCR texts if it's required by the model
             ocr_texts = (
-                keras_category_classifier_3_0.fetch_ocr_texts(product)
+                keras_category_classifier_3_0.fetch_ocr_texts(product, product_id)
                 if keras_category_classifier_3_0.model_input_flags[model_name].get(
                     "add_ingredients_ocr_tags", True
                 )
@@ -126,7 +134,7 @@ def predict(
             triton_stub = get_triton_inference_stub()
             image_embeddings = (
                 keras_category_classifier_3_0.generate_image_embeddings(
-                    product, triton_stub
+                    product, product_id, triton_stub
                 )
                 if keras_category_classifier_3_0.model_input_flags[model_name].get(
                     "add_image_embeddings", True
@@ -173,6 +181,7 @@ def predict(
                 category_id,
                 score,
                 model_name.value,
+                product_id=product_id,
                 above_threshold=above_threshold,
                 # We need to set a higher priority (=lower digit) if
                 # above_threshold is True, as otherwise a deepest
diff --git a/robotoff/prediction/category/neural/keras_category_classifier_3_0/__init__.py b/robotoff/prediction/category/neural/keras_category_classifier_3_0/__init__.py
index 0a67937127..8c05b0ed9e 100644
--- a/robotoff/prediction/category/neural/keras_category_classifier_3_0/__init__.py
+++ b/robotoff/prediction/category/neural/keras_category_classifier_3_0/__init__.py
@@ -15,7 +15,7 @@
     generate_clip_embedding_request,
     serialize_byte_tensor,
 )
-from robotoff.types import JSONType, NeuralCategoryClassifierModel
+from robotoff.types import JSONType, NeuralCategoryClassifierModel, ProductIdentifier
 from robotoff.utils import get_image_from_url, get_logger, http_session, load_json
 
 from .preprocessing import (
@@ -38,7 +38,7 @@
 
 
 def fetch_cached_image_embeddings(
-    barcode: str, image_ids: list[str]
+    product_id: ProductIdentifier, image_ids: list[str]
 ) -> dict[str, np.ndarray]:
     """Fetch image embeddings cached in DB for a product and specific image
     IDs.
 
     Only the embeddings existing in DB are returned, image IDs that were
     not found are ignored.
 
- :param barcode: the product barcode + :param product_id: identifier of the product :param image_ids: a list of image IDs to fetch :return: a dict mapping image IDs to CLIP image embedding """ @@ -55,7 +55,8 @@ def fetch_cached_image_embeddings( ImageEmbedding.select(ImageModel.image_id, ImageEmbedding.embedding) .join(ImageModel) .where( - ImageModel.barcode == barcode, + ImageModel.barcode == product_id.barcode, + ImageModel.server_type == product_id.server_type.name, ImageModel.image_id.in_(image_ids), ) .tuples() @@ -66,10 +67,12 @@ def fetch_cached_image_embeddings( return cached_embeddings -def save_image_embeddings(barcode: str, embeddings: dict[str, np.ndarray]): +def save_image_embeddings( + product_id: ProductIdentifier, embeddings: dict[str, np.ndarray] +): """Save computed image embeddings in ImageEmbedding table. - :param barcode: barcode of the product + :param product_id: identifier of the product :param embeddings: a dict mapping image ID to image embedding """ image_id_to_model_id = { @@ -78,7 +81,8 @@ def save_image_embeddings(barcode: str, embeddings: dict[str, np.ndarray]): ImageModel.id, ImageModel.image_id ) .where( - ImageModel.barcode == barcode, + ImageModel.barcode == product_id.barcode, + ImageModel.server_type == product_id.server_type.name, ImageModel.image_id.in_(list(embeddings.keys())), ) .tuples() @@ -100,7 +104,9 @@ def save_image_embeddings(barcode: str, embeddings: dict[str, np.ndarray]): @with_db -def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]: +def generate_image_embeddings( + product: JSONType, product_id: ProductIdentifier, stub +) -> Optional[np.ndarray]: """Generate image embeddings using CLIP model for the `MAX_IMAGE_EMBEDDING` most recent images. @@ -123,8 +129,7 @@ def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]: # Convert image IDs back to string image_ids = [str(image_id) for image_id in image_ids_int] if image_ids: - barcode = product["code"] - embeddings_by_id = fetch_cached_image_embeddings(barcode, image_ids) + embeddings_by_id = fetch_cached_image_embeddings(product_id, image_ids) logger.debug("%d embeddings fetched from DB", len(embeddings_by_id)) missing_embedding_ids = set(image_ids) - set(embeddings_by_id) @@ -137,7 +142,7 @@ def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]: # Images are resized to 224x224, so there is no need to # fetch the full-sized image, the 400px resized # version is enough - generate_image_url(barcode, f"{image_id}.400"), + generate_image_url(product_id, f"{image_id}.400"), error_raise=False, session=http_session, ) @@ -162,10 +167,10 @@ def generate_image_embeddings(product: JSONType, stub) -> Optional[np.ndarray]: non_null_image_by_ids, stub ) # Make sure all image IDs are in image table - refresh_images_in_db(barcode, product.get("images", {})) + refresh_images_in_db(product_id, product.get("images", {})) # Save embeddings in embeddings.image_embeddings table for future # use - save_image_embeddings(barcode, computed_embeddings_by_id) + save_image_embeddings(product_id, computed_embeddings_by_id) # Merge cached and newly-computed image embeddings embeddings_by_id |= computed_embeddings_by_id @@ -197,7 +202,7 @@ def _generate_image_embeddings( } -def fetch_ocr_texts(product: JSONType) -> list[str]: +def fetch_ocr_texts(product: JSONType, product_id: ProductIdentifier) -> list[str]: """Fetch all image OCRs from Product Opener and return a list of the detected texts, one string per image.""" barcode = 
product.get("code") @@ -207,7 +212,7 @@ def fetch_ocr_texts(product: JSONType) -> list[str]: ocr_texts = [] image_ids = (id_ for id_ in product.get("images", {}).keys() if id_.isdigit()) for image_id in image_ids: - ocr_url = generate_json_ocr_url(barcode, image_id) + ocr_url = generate_json_ocr_url(product_id, image_id) ocr_result = get_ocr_result(ocr_url, http_session, error_raise=False) if ocr_result: ocr_texts.append(ocr_result.get_full_text_contiguous()) diff --git a/robotoff/prediction/ocr/core.py b/robotoff/prediction/ocr/core.py index 615c8ab66a..faf9d256aa 100644 --- a/robotoff/prediction/ocr/core.py +++ b/robotoff/prediction/ocr/core.py @@ -4,7 +4,7 @@ import requests -from robotoff.types import JSONType, Prediction, PredictionType +from robotoff.types import JSONType, Prediction, PredictionType, ProductIdentifier from robotoff.utils import get_logger, jsonl_iter, jsonl_iter_fp from .brand import find_brands @@ -89,22 +89,27 @@ def get_ocr_result( def extract_predictions( content: Union[OCRResult, str], prediction_type: PredictionType, - barcode: Optional[str] = None, + product_id: Optional[ProductIdentifier] = None, source_image: Optional[str] = None, ) -> list[Prediction]: """Extract predictions from OCR using for provided prediction type. :param content: OCR output to extract predictions from. - :param barcode: Barcode to add to each prediction, defaults to None. + :param prediction_type: type of the prediction to extract. + :param product_id: identifier of the product (barcode + server type) to + add to each prediction, defaults to None. :param source_image: `source_image`to add to each prediction, defaults to - None. + None. :return: The generated predictions. """ if prediction_type in PREDICTION_TYPE_TO_FUNC: predictions = PREDICTION_TYPE_TO_FUNC[prediction_type](content) for prediction in predictions: - prediction.barcode = barcode prediction.source_image = source_image + if product_id is not None: + prediction.barcode = product_id.barcode + prediction.server_type = product_id.server_type + return predictions else: raise ValueError(f"unknown prediction type: {prediction_type}") diff --git a/robotoff/products.py b/robotoff/products.py index cda9817400..bc7c30a5ce 100644 --- a/robotoff/products.py +++ b/robotoff/products.py @@ -14,12 +14,23 @@ from pymongo import MongoClient from robotoff import settings -from robotoff.mongo import MONGO_CLIENT_CACHE -from robotoff.types import JSONType +from robotoff.types import JSONType, ProductIdentifier, ServerType from robotoff.utils import get_logger, gzip_jsonl_iter, http_session, jsonl_iter logger = get_logger(__name__) +MONGO_SELECTION_TIMEOUT_MS = 10_0000 + + +@functools.cache +def get_mongo_client(server_type: ServerType) -> Optional[MongoClient]: + if server_type != ServerType.off: + return None + + return MongoClient( + settings.MONGO_URI, serverSelectionTimeoutMS=MONGO_SELECTION_TIMEOUT_MS + ) + def get_image_id(image_path: str) -> Optional[str]: """Return the image ID from an image path. 
@@ -473,21 +484,31 @@ def __iter__(self) -> Iterator[Product]:
 
 
 class DBProductStore(ProductStore):
-    def __init__(self, client: MongoClient):
+    def __init__(self, server_type: ServerType, client: Optional[MongoClient]):
         self.client = client
-        self.db = self.client.off
-        self.collection = self.db.products
+        self.server_type = server_type
+
+        if self.client is None:
+            self.db = None
+            self.collection = None
+        else:
+            self.db = self.client[server_type.name]
+            self.collection = self.db.products
 
     def __len__(self):
-        return len(self.collection.estimated_document_count())
+        if self.collection is None:
+            return 0
+        return self.collection.estimated_document_count()
 
     def get_product(
-        self, barcode: str, projection: Optional[list[str]] = None
+        self, product_id: ProductIdentifier, projection: Optional[list[str]] = None
     ) -> Optional[JSONType]:
-        return self.collection.find_one({"code": barcode}, projection)
+        if self.collection is None:
+            return None
+        return self.collection.find_one({"code": product_id.barcode}, projection)
 
-    def __getitem__(self, barcode: str) -> Optional[Product]:
-        product = self.get_product(barcode)
+    def __getitem__(self, product_id: ProductIdentifier) -> Optional[Product]:
+        product = self.get_product(product_id)
 
         if product:
             return Product(product)
@@ -495,10 +516,12 @@ def __getitem__(self, barcode: str) -> Optional[Product]:
         return None
 
     def __iter__(self):
-        yield from self.iter()
+        if self.collection is not None:
+            yield from self.iter()
 
     def iter_product(self, projection: Optional[list[str]] = None):
-        yield from (Product(p) for p in self.collection.find(projection=projection))
+        if self.collection is not None:
+            yield from (Product(p) for p in self.collection.find(projection=projection))
 
 
 @functools.cache
@@ -509,20 +532,19 @@ def get_min_product_store() -> ProductStore:
     return ps
 
 
-def get_product_store() -> DBProductStore:
-    mongo_client = MONGO_CLIENT_CACHE.get()
-    return DBProductStore(client=mongo_client)
+def get_product_store(server_type: ServerType) -> DBProductStore:
+    mongo_client = get_mongo_client(server_type)
+    return DBProductStore(server_type, client=mongo_client)
 
 
 def get_product(
-    barcode: str, projection: Optional[list[str]] = None
+    product_id: ProductIdentifier, projection: Optional[list[str]] = None
 ) -> Optional[JSONType]:
     """Get product from MongoDB.
 
- :param barcode: barcode of the product to fetch + :param product_id: identifier of the product to fetch :param projection: list of fields to retrieve, if not provided all fields are queried :return: the product as a dict or None if it was not found """ - mongo_client = MONGO_CLIENT_CACHE.get() - return mongo_client.off.products.find_one({"code": barcode}, projection) + return get_product_store(product_id.server_type).get_product(product_id, projection) diff --git a/robotoff/scheduler/__init__.py b/robotoff/scheduler/__init__.py index 8419042e2e..2ce8df768d 100644 --- a/robotoff/scheduler/__init__.py +++ b/robotoff/scheduler/__init__.py @@ -30,6 +30,7 @@ get_min_product_store, has_dataset_changed, ) +from robotoff.types import ServerType from robotoff.utils import get_logger from .latent import generate_quality_facets @@ -43,6 +44,7 @@ def process_insights(): with db.connection_context(): processed = 0 + insight: ProductInsight for insight in ( ProductInsight.select() .where( @@ -54,7 +56,7 @@ def process_insights(): ): try: logger.info( - "Annotating insight %s (product: %s)", insight.id, insight.barcode + "Annotating insight %s (%s)", insight.id, insight.get_product_id() ) annotation_result = annotate(insight, 1, update=True) processed += 1 @@ -69,9 +71,9 @@ def process_insights(): # continue to the next one # Note: annotator already rolled-back the transaction logger.exception( - f"exception {e} while handling annotation of insight %s (product) %s", + f"exception {e} while handling annotation of insight %s (%s)", insight.id, - insight.barcode, + insight.get_product_id(), ) logger.info("%d insights processed", processed) @@ -81,6 +83,8 @@ def refresh_insights(with_deletion: bool = False): deleted = 0 updated = 0 product_store = get_min_product_store() + # Only OFF is currently supported + server_type = ServerType.off datetime_threshold = datetime.datetime.utcnow().replace( hour=0, minute=0, second=0, microsecond=0 @@ -95,21 +99,23 @@ def refresh_insights(with_deletion: bool = False): ) return + insight: ProductInsight for insight in ( ProductInsight.select() .where( ProductInsight.annotation.is_null(), ProductInsight.timestamp <= datetime_threshold, - ProductInsight.server_domain == settings.BaseURLProvider.server_domain(), + ProductInsight.server_type == server_type.name, ) .iterator() ): + product_id = insight.get_product_id() product: Product = product_store[insight.barcode] if product is None: if with_deletion: # Product has been deleted from OFF - logger.info("Product with barcode {} deleted".format(insight.barcode)) + logger.info("%s deleted", product_id) deleted += 1 insight.delete_instance() else: @@ -126,27 +132,30 @@ def update_insight_attributes(product: Product, insight: ProductInsight) -> bool to_update = False if insight.brands != product.brands_tags: logger.info( - "Updating brand {} -> {} ({})".format( - insight.brands, product.brands_tags, product.barcode - ) + "Updating brand %s -> %s (%s)", + insight.brands, + product.brands_tags, + insight.get_product_id(), ) to_update = True insight.brands = product.brands_tags if insight.countries != product.countries_tags: logger.info( - "Updating countries {} -> {} ({})".format( - insight.countries, product.countries_tags, product.barcode - ) + "Updating countries %s -> %s (%s)", + insight.countries, + product.countries_tags, + insight.get_product_id(), ) to_update = True insight.countries = product.countries_tags if insight.unique_scans_n != product.unique_scans_n: logger.info( - "Updating unique scan count {} -> {} 
({})".format( - insight.unique_scans_n, product.unique_scans_n, product.barcode - ) + "Updating unique scan count %s -> %s (%s)", + insight.unique_scans_n, + product.unique_scans_n, + insight.get_product_id(), ) to_update = True insight.unique_scans_n = product.unique_scans_n @@ -160,6 +169,7 @@ def update_insight_attributes(product: Product, insight: ProductInsight) -> bool @with_db def mark_insights(): marked = 0 + insight: ProductInsight for insight in ( ProductInsight.select() .where( @@ -170,8 +180,9 @@ def mark_insights(): .iterator() ): logger.info( - "Marking insight {} as processable automatically " - "(product: {})".format(insight.id, insight.barcode) + "Marking insight %s as processable automatically (%s)", + insight.id, + insight.get_product_id(), ) insight.process_after = datetime.datetime.utcnow() + datetime.timedelta( minutes=10 @@ -218,7 +229,8 @@ def generate_insights(): with db: import_result = import_insights( product_predictions_iter, - server_domain=settings.BaseURLProvider.server_domain(), + # Currently the JSONL dataset is OFF-only + server_type=ServerType.off, ) logger.info(import_result) diff --git a/robotoff/scheduler/latent.py b/robotoff/scheduler/latent.py index 479de4c86a..504a4fc97d 100644 --- a/robotoff/scheduler/latent.py +++ b/robotoff/scheduler/latent.py @@ -1,3 +1,5 @@ +from pymongo.collection import Collection + from robotoff.models import Prediction, with_db from robotoff.products import ( DBProductStore, @@ -5,7 +7,7 @@ is_nutrition_image, is_valid_image, ) -from robotoff.types import PredictionType +from robotoff.types import PredictionType, ProductIdentifier, ServerType from robotoff.utils import get_logger logger = get_logger(__name__) @@ -22,8 +24,10 @@ def generate_quality_facets(): @with_db def generate_fiber_quality_facet() -> None: - product_store: DBProductStore = get_product_store() - collection = product_store.collection + # Use ServerType.off as fiber quality facet is only for OFF + server_type = ServerType.off + product_store: DBProductStore = get_product_store(server_type) + collection: Collection = product_store.collection added = 0 seen_set: set[str] = set() @@ -33,6 +37,7 @@ def generate_fiber_quality_facet() -> None: Prediction.type == PredictionType.nutrient_mention.name, Prediction.data["mentions"].contains("fiber"), Prediction.source_image.is_null(False), + Prediction.server_type == server_type.name, ) .iterator() ): @@ -41,8 +46,9 @@ def generate_fiber_quality_facet() -> None: if barcode in seen_set: continue + product_id = ProductIdentifier(barcode, server_type) product = product_store.get_product( - barcode, ["nutriments", "data_quality_tags", "images"] + product_id, ["nutriments", "data_quality_tags", "images"] ) if product is None: diff --git a/robotoff/settings.py b/robotoff/settings.py index 457deb6a61..dd86d3fe7e 100644 --- a/robotoff/settings.py +++ b/robotoff/settings.py @@ -7,6 +7,8 @@ from sentry_sdk.integrations import Integration from sentry_sdk.integrations.logging import LoggingIntegration +from robotoff.types import ServerType + # Robotoff instance gives the environment, either `prod` or `dev` # (`dev` by default). @@ -32,10 +34,10 @@ def _get_default_scheme() -> str: return os.environ.get("ROBOTOFF_SCHEME", "https") -def _get_default_domain(): - # `ROBOTOFF_DOMAIN` can be used to overwrite the Product Opener domain used. 
- # If empty, the domain will be inferred from `ROBOTOFF_INSTANCE` - return os.environ.get("ROBOTOFF_DOMAIN", "openfoodfacts.%s" % _instance_tld()) +def _get_tld(): + # `ROBOTOFF_TLD` can be used to overwrite the Product Opener top level domain used. + # If empty, the tld will be inferred from `ROBOTOFF_INSTANCE` + return os.environ.get("ROBOTOFF_TLD", _instance_tld()) class BaseURLProvider(object): @@ -46,20 +48,20 @@ class BaseURLProvider(object): @staticmethod def _get_url( + base_domain: str, prefix: Optional[str] = "world", - domain: Optional[str] = None, + tld: Optional[str] = None, scheme: Optional[str] = None, ): + tld = _get_tld() if tld is None else tld data = { - "domain": _get_default_domain(), + "domain": f"{base_domain}.{tld}", "scheme": _get_default_scheme(), } if prefix: data["prefix"] = prefix if scheme: data["scheme"] = scheme - if domain: - data["domain"] = domain if "prefix" in data: return "%(scheme)s://%(prefix)s.%(domain)s" % data @@ -67,57 +69,57 @@ def _get_url( return "%(scheme)s://%(domain)s" % data @staticmethod - def server_domain(): - """Return the server domain: `api.openfoodfacts.*`""" - return "api." + _get_default_domain() + def server_domain(server_type: ServerType) -> str: + """Return the server domain: `api.*.*`""" + return f"api.{server_type.value}.{_get_tld()}" @staticmethod - def world(): - return BaseURLProvider._get_url(prefix="world") + def world(server_type: ServerType): + return BaseURLProvider._get_url(prefix="world", base_domain=server_type.value) @staticmethod def robotoff() -> str: - return BaseURLProvider._get_url(prefix="robotoff") + return BaseURLProvider._get_url( + prefix="robotoff", base_domain=ServerType.off.value + ) @staticmethod - def api() -> str: - return BaseURLProvider._get_url(prefix="api") + def api(server_type: ServerType) -> str: + return BaseURLProvider._get_url(prefix="api", base_domain=server_type.value) @staticmethod - def static() -> str: + def static(server_type: ServerType) -> str: # locally we may want to change it, give environment a chance - static_domain = os.environ.get("STATIC_OFF_DOMAIN", "") - if static_domain: - if "://" in static_domain: - scheme, static_domain = static_domain.split("://", 1) + base_domain = os.environ.get("STATIC_DOMAIN", "") + if base_domain: + if "://" in base_domain: + scheme, base_domain = base_domain.split("://", 1) else: scheme = _get_default_scheme() return BaseURLProvider._get_url( - prefix=None, scheme=scheme, domain=static_domain + prefix=None, scheme=scheme, base_domain=base_domain ) - return BaseURLProvider._get_url(prefix="static") + return BaseURLProvider._get_url(prefix="static", base_domain=server_type.value) @staticmethod - def image_url(image_path: str) -> str: - # If STATIC_OFF_DOMAIN is defined, used the custom static domain - # configured - # Otherwise use images.openfoodfacts.{net,org} as proxy server - prefix = ( - BaseURLProvider.static() - if os.environ.get("STATIC_OFF_DOMAIN") - else BaseURLProvider._get_url(prefix="images") + def image_url(server_type: ServerType, image_path: str) -> str: + prefix = BaseURLProvider._get_url( + prefix="images", base_domain=server_type.value ) return prefix + f"/images/products{image_path}" @staticmethod - def country(country_code: str) -> str: - return BaseURLProvider._get_url(prefix=country_code) + def country(server_type: ServerType, country_code: str) -> str: + return BaseURLProvider._get_url( + prefix=country_code, base_domain=server_type.value + ) @staticmethod def event_api() -> str: return os.environ.get( - 
"EVENTS_API_URL", BaseURLProvider._get_url(prefix="events") + "EVENTS_API_URL", + BaseURLProvider._get_url(prefix="events", base_domain=ServerType.off.value), ) @@ -135,18 +137,24 @@ def event_api() -> str: # Products JSONL -JSONL_DATASET_URL = BaseURLProvider.static() + "/data/openfoodfacts-products.jsonl.gz" +JSONL_DATASET_URL = ( + BaseURLProvider.static(ServerType.off) + "/data/openfoodfacts-products.jsonl.gz" +) TAXONOMY_URLS = { - "category": BaseURLProvider.static() + "/data/taxonomies/categories.full.json", - "ingredient": BaseURLProvider.static() + "/data/taxonomies/ingredients.full.json", - "label": BaseURLProvider.static() + "/data/taxonomies/labels.full.json", - "brand": BaseURLProvider.static() + "/data/taxonomies/brands.full.json", - "packaging_shape": BaseURLProvider.static() + "category": BaseURLProvider.static(ServerType.off) + + "/data/taxonomies/categories.full.json", + "ingredient": BaseURLProvider.static(ServerType.off) + + "/data/taxonomies/ingredients.full.json", + "label": BaseURLProvider.static(ServerType.off) + + "/data/taxonomies/labels.full.json", + "brand": BaseURLProvider.static(ServerType.off) + + "/data/taxonomies/brands.full.json", + "packaging_shape": BaseURLProvider.static(ServerType.off) + "/data/taxonomies/packaging_shapes.full.json", - "packaging_material": BaseURLProvider.static() + "packaging_material": BaseURLProvider.static(ServerType.off) + "/data/taxonomies/packaging_materials.full.json", - "packaging_recycling": BaseURLProvider.static() + "packaging_recycling": BaseURLProvider.static(ServerType.off) + "/data/taxonomies/packaging_recycling.full.json", } diff --git a/robotoff/slack.py b/robotoff/slack.py index 7dc92e3891..83cb377b37 100644 --- a/robotoff/slack.py +++ b/robotoff/slack.py @@ -7,8 +7,14 @@ from requests.exceptions import JSONDecodeError from robotoff import settings -from robotoff.models import LogoAnnotation, ProductInsight, crop_image_url -from robotoff.types import InsightType, JSONType, LogoLabelType, Prediction +from robotoff.models import ImageModel, LogoAnnotation, ProductInsight, crop_image_url +from robotoff.types import ( + InsightType, + JSONType, + LogoLabelType, + Prediction, + ProductIdentifier, +) from robotoff.utils import get_logger, http_session logger = get_logger(__name__) @@ -28,7 +34,10 @@ class NotifierInterface: # for a notifier might choose to only implements a few def notify_image_flag( - self, predictions: list[Prediction], source_image: str, barcode: str + self, + predictions: list[Prediction], + source_image: str, + product_id: ProductIdentifier, ): pass @@ -129,9 +138,12 @@ def _dispatch(self, function_name: str, *args, **kwargs): fn(*args, **kwargs) def notify_image_flag( - self, predictions: list[Prediction], source_image: str, barcode: str + self, + predictions: list[Prediction], + source_image: str, + product_id: ProductIdentifier, ): - self._dispatch("notify_image_flag", predictions, source_image, barcode) + self._dispatch("notify_image_flag", predictions, source_image, product_id) def notify_automatic_processing(self, insight: ProductInsight): self._dispatch("notify_automatic_processing", insight) @@ -152,20 +164,29 @@ def __init__(self, service_url): self.service_url = service_url.rstrip("/") def notify_image_flag( - self, predictions: list[Prediction], source_image: str, barcode: str + self, + predictions: list[Prediction], + source_image: str, + product_id: ProductIdentifier, ): """Send image to the moderation server so that a human can moderate it""" if not predictions: return - image_url = 
settings.BaseURLProvider.image_url(source_image) + image_url = settings.BaseURLProvider.image_url( + product_id.server_type, source_image + ) image_id = int(source_image.rsplit("/", 1)[-1].split(".", 1)[0]) params = {"imgid": image_id, "url": image_url} try: - http_session.put(f"{self.service_url}/{barcode}", data=params) + http_session.put(f"{self.service_url}/{product_id.barcode}", data=params) except Exception: logger.exception( "Error while notifying image to moderation service", - extra={"params": params, "url": image_url, "barcode": barcode}, + extra={ + "params": params, + "url": image_url, + "barcode": product_id.barcode, + }, ) @@ -189,7 +210,10 @@ def __init__(self, slack_token: str): self.slack_token = slack_token def notify_image_flag( - self, predictions: list[Prediction], source_image: str, barcode: str + self, + predictions: list[Prediction], + source_image: str, + product_id: ProductIdentifier, ): """Sends alerts to Slack channels for flagged images.""" if not predictions: @@ -212,8 +236,10 @@ def notify_image_flag( match_text = flagged.data["text"] text += f"type: {flag_type}\nlabel: *{label}*, match: {match_text}\n" - edit_url = f"{settings.BaseURLProvider.world()}/cgi/product.pl?type=edit&code={barcode}" - image_url = settings.BaseURLProvider.image_url(source_image) + edit_url = f"{settings.BaseURLProvider.world(product_id.server_type)}/cgi/product.pl?type=edit&code={product_id.barcode}" + image_url = settings.BaseURLProvider.image_url( + product_id.server_type, source_image + ) full_text = f"{text}\n <{image_url}|Image> -- <{edit_url}|*Edit*>" message = _slack_message_block(full_text, with_image=image_url) @@ -221,15 +247,19 @@ def notify_image_flag( self._post_message(message, slack_channel, **self.COLLAPSE_LINKS_PARAMS) def notify_automatic_processing(self, insight: ProductInsight): - product_url = f"{settings.BaseURLProvider.world()}/product/{insight.barcode}" + product_url = f"{settings.BaseURLProvider.world(insight.server_type)}/product/{insight.barcode}" if insight.source_image: if insight.data and "bounding_box" in insight.data: image_url = crop_image_url( - insight.source_image, insight.data.get("bounding_box") + insight.server_type, + insight.source_image, + insight.data.get("bounding_box"), ) else: - image_url = settings.BaseURLProvider.image_url(insight.source_image) + image_url = settings.BaseURLProvider.image_url( + insight.server_type, insight.source_image + ) metadata_text = f"(<{product_url}|product>, <{image_url}|source image>)" else: metadata_text = f"(<{product_url}|product>)" @@ -269,12 +299,13 @@ def send_logo_notification( ) ) ) - barcode = logo.barcode - base_off_url = settings.BaseURLProvider.world() + image_model: ImageModel = logo.image_prediction.image + product_id = image_model.get_product_id() + base_off_url = settings.BaseURLProvider.world(product_id.server_type) text = ( f"Prediction for <{crop_url}|image> " f"(, " - f"<{base_off_url}/product/{barcode}|product>):\n{prob_text}" + f"<{base_off_url}/product/{product_id.barcode}|product>):\n{prob_text}" ) self._post_message(_slack_message_block(text), self.ROBOTOFF_ALERT_CHANNEL) diff --git a/robotoff/spellcheck/__init__.py b/robotoff/spellcheck/__init__.py index f70da3ae81..6a4656c4cd 100644 --- a/robotoff/spellcheck/__init__.py +++ b/robotoff/spellcheck/__init__.py @@ -8,7 +8,7 @@ from robotoff.spellcheck.patterns import PatternsSpellchecker from robotoff.spellcheck.percentages import PercentagesSpellchecker from robotoff.spellcheck.vocabulary import VocabularySpellchecker -from 
robotoff.types import JSONType, Prediction, PredictionType
+from robotoff.types import JSONType, Prediction, PredictionType, ServerType
 
 SPELLCHECKERS = {
     "elasticsearch": ElasticSearchSpellchecker,
@@ -81,6 +81,7 @@ def generate_insights(
                 type=PredictionType.ingredient_spellcheck,
                 data=insight,
                 barcode=product["code"],
+                server_type=ServerType.off,
             )
             insights_count += 1
 
diff --git a/robotoff/types.py b/robotoff/types.py
index 537b652d4a..1526b06f6c 100644
--- a/robotoff/types.py
+++ b/robotoff/types.py
@@ -151,6 +151,35 @@ class InsightType(str, enum.Enum):
     nutrition_table_structure = "nutrition_table_structure"
 
 
+class ServerType(enum.Enum):
+    off = "openfoodfacts"
+    obf = "openbeautyfacts"
+    opff = "openpetfoodfacts"
+    opf = "openproductsfacts"
+    off_pro = "openfoodfacts-pro"
+
+    def __str__(self) -> str:
+        return self.name
+
+    def __repr__(self) -> str:
+        return self.name
+
+    @classmethod
+    def get_from_server_domain(cls, server_domain: str) -> "ServerType":
+        subdomain, base_domain, tld = server_domain.rsplit(".", maxsplit=2)
+
+        if subdomain == "api.pro":
+            if base_domain == "openfoodfacts":
+                return cls.off_pro
+            raise ValueError("pro platform is only available for Open Food Facts")
+
+        for server_type in cls:
+            if base_domain == server_type.value:
+                return server_type
+
+        raise ValueError(f"no ServerType matched for server_domain {server_domain}")
+
+
 @dataclasses.dataclass
 class Prediction:
     type: PredictionType
@@ -162,9 +191,9 @@ class Prediction:
     barcode: Optional[str] = None
     timestamp: Optional[datetime.datetime] = None
     source_image: Optional[str] = None
-    server_domain: Optional[str] = None
     id: Optional[int] = None
     confidence: Optional[float] = None
+    server_type: ServerType = ServerType.off
 
     def to_dict(self) -> dict[str, Any]:
         return dataclasses.asdict(self, dict_factory=dict_factory)
@@ -173,12 +202,31 @@ def to_dict(self) -> dict[str, Any]:
 def dict_factory(*args, **kwargs):
     d = dict(*args, **kwargs)
     for key, value in d.items():
-        if isinstance(value, PredictionType):
+        if isinstance(value, (PredictionType, ServerType)):
             d[key] = value.name
 
     return d
 
 
+@dataclasses.dataclass
+class ProductIdentifier:
+    """Dataclass to uniquely identify a product across all Open*Facts
+    projects, with:
+
+    - the product barcode
+    - the project specified by the ServerType
+    """
+
+    barcode: str
+    server_type: ServerType
+
+    def __repr__(self) -> str:
+        return "<Product %s | %s>" % (self.barcode, self.server_type.name)
+
+    def __hash__(self) -> int:
+        return hash((self.barcode, self.server_type))
+
+
 @enum.unique
 class ElasticSearchIndex(str, enum.Enum):
     product = "product"
@@ -190,7 +238,7 @@ class ProductInsightImportResult:
     insight_created_ids: list[uuid.UUID]
     insight_updated_ids: list[uuid.UUID]
     insight_deleted_ids: list[uuid.UUID]
-    barcode: str
+    product_id: ProductIdentifier
     type: InsightType
 
 
@@ -198,6 +246,7 @@ class PredictionImportResult:
 class PredictionImportResult:
     created: int
     barcode: str
+    server_type: ServerType
 
 
 @dataclasses.dataclass
diff --git a/robotoff/workers/tasks/__init__.py b/robotoff/workers/tasks/__init__.py
index c437cd64a0..64fd1bfc40 100644
--- a/robotoff/workers/tasks/__init__.py
+++ b/robotoff/workers/tasks/__init__.py
@@ -1,6 +1,7 @@
 from robotoff.insights.importer import refresh_insights
 from robotoff.models import Prediction, ProductInsight, with_db
 from robotoff.products import fetch_dataset, has_dataset_changed
+from robotoff.types import ProductIdentifier
 from robotoff.utils import get_logger
 
 from .import_image import run_import_image_job  # noqa: F401
@@ -18,27 +19,27 @@ def 
download_product_dataset_job(): @with_db -def delete_product_insights_job(barcode: str, server_domain: str): +def delete_product_insights_job(product_id: ProductIdentifier): """This job is triggered by Product Opener via /api/v1/webhook/product when the given product has been removed from the database - in this case we must delete all of the associated predictions and insights that have not been annotated. """ - logger.info("Product %s deleted, deleting associated insights...", barcode) + logger.info("%s deleted, deleting associated insights...", product_id) deleted_predictions = ( Prediction.delete() .where( - Prediction.barcode == barcode, - Prediction.server_domain == server_domain, + Prediction.barcode == product_id.barcode, + Prediction.server_type == product_id.server_type.name, ) .execute() ) deleted_insights = ( ProductInsight.delete() .where( - ProductInsight.barcode == barcode, + ProductInsight.barcode == product_id.barcode, + ProductInsight.server_type == product_id.server_type.name, ProductInsight.annotation.is_null(), - ProductInsight.server_domain == server_domain, ) .execute() ) @@ -50,11 +51,9 @@ def delete_product_insights_job(barcode: str, server_domain: str): @with_db -def refresh_insights_job(barcodes: list[str], server_domain: str): - logger.info( - f"Refreshing insights for {len(barcodes)} products, server_domain: {server_domain}" - ) - for barcode in barcodes: - import_results = refresh_insights(barcode, server_domain) +def refresh_insights_job(product_ids: list[ProductIdentifier]): + logger.info(f"Refreshing insights for {len(product_ids)} products") + for product_id in product_ids: + import_results = refresh_insights(product_id) for import_result in import_results: logger.info(import_result) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 99d03303d6..3eaca5f537 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -34,16 +34,20 @@ from robotoff.products import get_product_store from robotoff.slack import NotifierFactory from robotoff.triton import generate_clip_embedding -from robotoff.types import JSONType, ObjectDetectionModel, PredictionType +from robotoff.types import ( + JSONType, + ObjectDetectionModel, + PredictionType, + ProductIdentifier, + ServerType, +) from robotoff.utils import get_image_from_url, get_logger, http_session from robotoff.workers.queues import enqueue_job, high_queue logger = get_logger(__name__) -def run_import_image_job( - barcode: str, image_url: str, ocr_url: str, server_domain: str -): +def run_import_image_job(product_id: ProductIdentifier, image_url: str, ocr_url: str): """This job is triggered every time there is a new OCR image available for processing by Robotoff, via /api/v1/images/import. @@ -54,22 +58,20 @@ def run_import_image_job( 3. Triggers the 'object_detection' task 4. Stores the imported image metadata in the Robotoff DB. 
""" - logger.info( - f"Running `import_image` for product {barcode} ({server_domain}), image {image_url}" - ) + logger.info("Running `import_image` for %s, image %s", product_id, image_url) source_image = get_source_from_url(image_url) - product = get_product_store()[barcode] + product = get_product_store(product_id.server_type)[product_id] if product is None and not settings.DISABLE_PRODUCT_CHECK: logger.info( - "Product %s does not exist during image import (%s)", barcode, source_image + "%s does not exist during image import (%s)", + product_id, + source_image, ) return product_images: Optional[JSONType] = getattr(product, "images", None) with db: - image_model = save_image( - barcode, source_image, image_url, product_images, server_domain - ) + image_model = save_image(product_id, source_image, image_url, product_images) if image_model is None: # The image is invalid, no need to perform image extraction jobs @@ -81,7 +83,7 @@ def run_import_image_job( image_id = Path(source_image).stem if product_images is not None and image_id not in product_images: # It happens when the image has been deleted after Robotoff import - logger.info("Unknown image for product %s: %s", barcode, source_image) + logger.info("Unknown image for %s: %s", product_id, source_image) image_model.deleted = True ImageModel.bulk_update([image_model], fields=["deleted"]) return @@ -90,10 +92,9 @@ def run_import_image_job( import_insights_from_image, high_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, + product_id=product_id, image_url=image_url, ocr_url=ocr_url, - server_domain=server_domain, ) # The two following tasks take longer than the previous one, so it # shouldn't be an issue to launch tasks concurrently (and we still have @@ -103,9 +104,8 @@ def run_import_image_job( run_logo_object_detection, high_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, + product_id=product_id, image_url=image_url, - server_domain=server_domain, ) # Nutrition table detection is not used at the moment, and every new image # is costly in CPU (as we perform object detection) @@ -114,14 +114,13 @@ def run_import_image_job( # run_nutrition_table_object_detection, # high_queue, # job_kwargs={"result_ttl": 0}, - # barcode=barcode, + # product_id=product_id, # image_url=image_url, - # server_domain=server_domain, # ) def import_insights_from_image( - barcode: str, image_url: str, ocr_url: str, server_domain: str + product_id: ProductIdentifier, image_url: str, ocr_url: str ): image = get_image_from_url(image_url, error_raise=False, session=http_session) @@ -131,7 +130,7 @@ def import_insights_from_image( source_image = get_source_from_url(image_url) predictions = extract_ocr_predictions( - barcode, ocr_url, DEFAULT_OCR_PREDICTION_TYPES + product_id, ocr_url, DEFAULT_OCR_PREDICTION_TYPES ) if any( prediction.value_tag == "en:nutriscore" @@ -142,51 +141,48 @@ def import_insights_from_image( run_nutriscore_object_detection, high_queue, job_kwargs={"result_ttl": 0}, - barcode=barcode, + product_id=product_id, image_url=image_url, - server_domain=server_domain, ) NotifierFactory.get_notifier().notify_image_flag( [p for p in predictions if p.type == PredictionType.image_flag], source_image, - barcode, + product_id, ) with db: - import_result = import_insights(predictions, server_domain) + import_result = import_insights(predictions, server_type=product_id.server_type) logger.info(import_result) -def save_image_job(batch: list[tuple[str, str]], server_domain: str): +def save_image_job(batch: list[tuple[ProductIdentifier, str]], 
server_type: ServerType):
     """Save a batch of images in DB.
 
-    :param batch: a batch of (barcode, source_image) tuples
-    :param server_domain: the server domain to use
+    :param batch: a batch of (product_id, source_image) tuples
+    :param server_type: the server type (project) of the products
     """
-    product_store = get_product_store()
+    product_store = get_product_store(server_type)
     with db.connection_context():
-        for barcode, source_image in batch:
-            product = product_store[barcode]
+        for product_id, source_image in batch:
+            product = product_store[product_id]
             if product is None and not settings.DISABLE_PRODUCT_CHECK:
                 continue
             with db.atomic():
-                image_url = generate_image_url(barcode, Path(source_image).stem)
+                image_url = generate_image_url(product_id, Path(source_image).stem)
                 save_image(
-                    barcode,
+                    product_id,
                     source_image,
                     image_url,
                     getattr(product, "images", None),
-                    server_domain,
                 )
 
 
-def run_nutrition_table_object_detection(
-    barcode: str, image_url: str, server_domain: str
-):
+def run_nutrition_table_object_detection(product_id: ProductIdentifier, image_url: str):
     logger.info(
-        f"Running nutrition table object detection for product {barcode} "
-        f"({server_domain}), image {image_url}"
+        "Running nutrition table object detection for %s, image %s",
+        product_id,
+        image_url,
     )
 
     image = get_image_from_url(image_url, error_raise=False, session=http_session)
@@ -198,7 +194,9 @@ def run_nutrition_table_object_detection(
     source_image = get_source_from_url(image_url)
 
     with db:
-        if image_model := ImageModel.get_or_none(source_image=source_image):
+        if image_model := ImageModel.get_or_none(
+            source_image=source_image, server_type=product_id.server_type.name
+        ):
             run_object_detection_model(
                 ObjectDetectionModel.nutrition_table, image, image_model
             )
@@ -215,10 +213,11 @@
     }
 
 
-def run_nutriscore_object_detection(barcode: str, image_url: str, server_domain: str):
+def run_nutriscore_object_detection(product_id: ProductIdentifier, image_url: str):
     logger.info(
-        f"Running nutriscore object detection for product {barcode} "
-        f"({server_domain}), image {image_url}"
+        "Running nutriscore object detection for %s, image %s", product_id, image_url
    )
 
     image = get_image_from_url(image_url, error_raise=False, session=http_session)
@@ -230,7 +229,11 @@ def run_nutriscore_object_detection(barcode: str, image_url: str, server_domain:
     source_image = get_source_from_url(image_url)
 
     with db:
-        if (image_model := ImageModel.get_or_none(source_image=source_image)) is None:
+        if (
+            image_model := ImageModel.get_or_none(
+                source_image=source_image, server_type=product_id.server_type.name
+            )
+        ) is None:
             logger.info("Missing image in DB for image %s", source_image)
             return
 
@@ -260,31 +263,28 @@ def run_nutriscore_object_detection(barcode: str, image_url: str, server_domain:
     with db:
         prediction = Prediction(
             type=PredictionType.label,
-            barcode=barcode,
+            barcode=product_id.barcode,
             source_image=source_image,
             value_tag=label_tag,
             automatic_processing=False,
-            server_domain=server_domain,
+            server_type=product_id.server_type,
             predictor="nutriscore",
            data={"bounding_box": result["bounding_box"]},
             confidence=score,
         )
-        import_result = import_insights([prediction], server_domain)
+        import_result = import_insights([prediction], product_id.server_type)
         logger.info(import_result)
 
 
-def run_logo_object_detection(barcode: str, image_url: str, server_domain: str):
+def run_logo_object_detection(product_id: ProductIdentifier, image_url: str):
     """Detect logos
 using the universal logo detector model and generate
     logo-related predictions.
 
-    :param barcode: Product barcode
+    :param product_id: identifier of the product
     :param image_url: URL of the image to use
-    :param server_domain: The server domain associated with the image
     """
-    logger.info(
-        f"Running logo object detection for product {barcode} "
-        f"({server_domain}), image {image_url}"
-    )
+    logger.info("Running logo object detection for %s, image %s", product_id, image_url)
 
     image = get_image_from_url(image_url, error_raise=False, session=http_session)
 
@@ -295,7 +295,11 @@ def run_logo_object_detection(barcode: str, image_url: str, server_domain: str):
     source_image = get_source_from_url(image_url)
 
     with db:
-        if (image_model := ImageModel.get_or_none(source_image=source_image)) is None:
+        if (
+            image_model := ImageModel.get_or_none(
+                source_image=source_image, server_type=product_id.server_type.name
+            )
+        ) is None:
             logger.info("Missing image in DB for image %s", source_image)
             return
 
@@ -342,7 +346,7 @@ def run_logo_object_detection(barcode: str, image_url: str, server_domain: str):
             high_queue,
             job_kwargs={"result_ttl": 0},
             image_prediction_id=image_prediction.id,
-            server_domain=server_domain,
+            server_type=product_id.server_type,
         )
 
 
@@ -370,7 +374,7 @@ def save_logo_embeddings(logos: list[LogoAnnotation], image: Image.Image):
 
 
 @with_db
-def process_created_logos(image_prediction_id: int, server_domain: str):
+def process_created_logos(image_prediction_id: int, server_type: ServerType):
     logo_embeddings = list(
         LogoEmbedding.select()
         .join(LogoAnnotation)
@@ -397,4 +401,4 @@ def process_created_logos(image_prediction_id: int, server_domain: str):
 
     logos = [embedding.logo for embedding in logo_embeddings]
     thresholds = get_logo_confidence_thresholds()
-    import_logo_insights(logos, thresholds=thresholds, server_domain=server_domain)
+    import_logo_insights(logos, thresholds=thresholds, server_type=server_type)
diff --git a/robotoff/workers/tasks/product_updated.py b/robotoff/workers/tasks/product_updated.py
index c9cb39f6ed..30f2dafc98 100644
--- a/robotoff/workers/tasks/product_updated.py
+++ b/robotoff/workers/tasks/product_updated.py
@@ -3,20 +3,19 @@
 from robotoff.insights.extraction import get_predictions_from_product_name
 from robotoff.insights.importer import import_insights, refresh_insights
 from robotoff.models import with_db
-from robotoff.off import ServerType, get_server_type
 from robotoff.prediction.category.matcher import predict as predict_category_matcher
 from robotoff.prediction.category.neural.category_classifier import CategoryClassifier
 from robotoff.products import get_product
 from robotoff.redis import Lock, LockedResourceException
 from robotoff.taxonomy import TaxonomyType, get_taxonomy
-from robotoff.types import JSONType
+from robotoff.types import JSONType, ProductIdentifier, ServerType
 from robotoff.utils import get_logger
 
 logger = get_logger(__name__)
 
 
 @with_db
-def update_insights_job(barcode: str, server_domain: str):
+def update_insights_job(product_id: ProductIdentifier):
     """This job is triggered by the webhook API, when product information has
     been updated.
 
@@ -25,45 +24,49 @@ def update_insights_job(barcode: str, server_domain: str):
     1. Generate new predictions related to the product's category and name.
     2. Regenerate all insights from the product's associated predictions.
""" - logger.info("Running `update_insights` for product %s (%s)", barcode, server_domain) + logger.info("Running `update_insights` for %s", product_id) + + if product_id.server_type != ServerType.off: + # We don't have yet MongoDB connection between Robotoff and MongoDB of + # projects other than Open Food Facts, so abort the update job here + return try: with Lock( - name=f"robotoff:product_update_job:{barcode}", expire=300, timeout=10 + name=f"robotoff:product_update_job:{product_id.server_type.name}:{product_id.barcode}", + expire=300, + timeout=10, ): # We handle concurrency thanks to the lock as the task will fetch # product from MongoDB at the time it runs, it's not worth # reprocessing with another task arriving concurrently. # The expire is there only in case the lock is not released # (process killed) - product_dict = get_product(barcode) + product_dict = get_product(product_id) if product_dict is None: - logger.info("Updated product does not exist: %s", barcode) + logger.info("Updated product does not exist: %s", product_id) return - updated_product_predict_insights(barcode, product_dict, server_domain) + updated_product_predict_insights(product_id, product_dict) logger.info("Refreshing insights...") - import_results = refresh_insights(barcode, server_domain) + import_results = refresh_insights(product_id) for import_result in import_results: logger.info(import_result) except LockedResourceException: logger.info( - f"Couldn't acquire product_update lock, skipping product_update for product {barcode}" + "Couldn't acquire product_update lock, skipping product_update for product %s", + product_id, ) -def add_category_insight(barcode: str, product: JSONType, server_domain: str): +def add_category_insight(product_id: ProductIdentifier, product: JSONType): """Predict categories for product and import predicted category insight. 
-    :param barcode: product barcode
+    :param product_id: identifier of the product
     :param product: product as retrieved from application
-    :param server_domain: the server the product belongs to
-    :return: True if at least one category insight was imported
     """
-    if get_server_type(server_domain) != ServerType.off:
-        return
-
     logger.info("Predicting product categories...")
     # predict category using matching algorithm on product name
     product_predictions = predict_category_matcher(product)
@@ -72,7 +75,7 @@
     try:
         neural_predictions, _ = CategoryClassifier(
             get_taxonomy(TaxonomyType.category.name)
-        ).predict(product)
+        ).predict(product, product_id)
         product_predictions += neural_predictions
     except requests.exceptions.HTTPError as e:
         resp = e.response
@@ -84,22 +87,23 @@
         return
 
     for prediction in product_predictions:
-        prediction.barcode = barcode
+        prediction.barcode = product_id.barcode
+        prediction.server_type = product_id.server_type
 
-    import_result = import_insights(product_predictions, server_domain)
+    import_result = import_insights(product_predictions, product_id.server_type)
     logger.info(import_result)
 
 
 def updated_product_predict_insights(
-    barcode: str, product: JSONType, server_domain: str
+    product_id: ProductIdentifier, product: JSONType
 ) -> None:
-    add_category_insight(barcode, product, server_domain)
+    add_category_insight(product_id, product)
     product_name = product.get("product_name")
 
     if not product_name:
         return
 
     logger.info("Generating predictions from product name...")
-    predictions_all = get_predictions_from_product_name(barcode, product_name)
-    import_result = import_insights(predictions_all, server_domain)
+    predictions_all = get_predictions_from_product_name(product_id, product_name)
+    import_result = import_insights(predictions_all, product_id.server_type)
     logger.info(import_result)
diff --git a/scripts/insert_image_predictions.py b/scripts/insert_image_predictions.py
index 1cb65df9a5..1380498266 100644
--- a/scripts/insert_image_predictions.py
+++ b/scripts/insert_image_predictions.py
@@ -6,6 +6,7 @@
 from robotoff import settings
 from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation, db
 from robotoff.off import generate_image_path
+from robotoff.types import ServerType
 from robotoff.utils import get_logger, jsonl_iter
 
 logger = get_logger()
@@ -15,13 +16,15 @@
 MODEL_NAME = "universal-logo-detector"
 MODEL_VERSION = "tf-universal-logo-detector-1.0"
 TYPE = "object_detection"
+SERVER_TYPE = ServerType.off
 
 
-def get_seen_set() -> set[tuple[str, str]]:
+def get_seen_set(server_type: ServerType) -> set[tuple[str, str]]:
     seen_set: set[tuple[str, str]] = set()
     for prediction in (
         ImagePrediction.select(ImagePrediction.model_name, ImageModel.source_image)
         .join(ImageModel)
+        .where(ImageModel.server_type == server_type.name)
         .iterator()
     ):
         seen_set.add((prediction.model_name, prediction.image.source_image))
@@ -29,10 +32,15 @@ def get_seen_set() -> set[tuple[str, str]]:
     return seen_set
 
 
-def insert_batch(data_path: pathlib.Path, model_name: str, model_version: str) -> int:
+def insert_batch(
+    data_path: pathlib.Path,
+    model_name: str,
+    model_version: str,
+    server_type: ServerType,
+) -> int:
     timestamp = datetime.datetime.utcnow()
     logger.info("Loading seen set...")
-    seen_set = get_seen_set()
+    seen_set = get_seen_set(server_type)
     logger.info("Seen set loaded")
     inserted = 0
 
@@ -44,7 +52,9 @@ def 
insert_batch(data_path: pathlib.Path, model_name: str, model_version: str) - if key in seen_set: continue - image_instance = ImageModel.get_or_none(source_image=source_image) + image_instance = ImageModel.get_or_none( + source_image=source_image, server_type=server_type.name + ) if image_instance is None: logger.warning("Unknown image in DB: {}".format(source_image)) @@ -83,7 +93,7 @@ def main(): logger.info("Starting image prediction import...") with db: - inserted = insert_batch(DATA_PATH, MODEL_NAME, MODEL_VERSION) + inserted = insert_batch(DATA_PATH, MODEL_NAME, MODEL_VERSION, SERVER_TYPE) logger.info("{} image predictions inserted".format(inserted)) diff --git a/scripts/insert_images.py b/scripts/insert_images.py index e91ed604df..0421272eaa 100644 --- a/scripts/insert_images.py +++ b/scripts/insert_images.py @@ -1,23 +1,26 @@ import tqdm -from robotoff import settings from robotoff.models import ImageModel, db from robotoff.off import generate_image_path, generate_image_url from robotoff.products import Product, ProductDataset +from robotoff.types import ProductIdentifier, ServerType from robotoff.utils import get_logger from robotoff.workers.tasks.import_image import save_image logger = get_logger() - +SERVER_TYPE = ServerType.off ds = ProductDataset.load() saved = 0 seen_set = set( ( (x.barcode, x.image_id) - for x in ImageModel.select(ImageModel.barcode, ImageModel.image_id).iterator() + for x in ImageModel.select(ImageModel.barcode, ImageModel.image_id) + .where(ImageModel.server_type == SERVER_TYPE.name) + .iterator() ) ) + with db: product: Product for product in tqdm.tqdm( @@ -26,6 +29,7 @@ if product.barcode is None: continue + product_id = ProductIdentifier(product.barcode, SERVER_TYPE) for image_id in product.images.keys(): if not image_id.isdigit(): continue @@ -34,18 +38,12 @@ continue source_image = generate_image_path(product.barcode, str(image_id)) - image_url = generate_image_url(product.barcode, str(image_id)) + image_url = generate_image_url(product_id, str(image_id)) try: - save_image( - product.barcode, - source_image, - image_url, - product.images, - settings.BaseURLProvider.server_domain(), - ) + save_image(product_id, source_image, image_url, product.images) except Exception as e: - logger.info("Exception for product {}\n{}".format(product.barcode, e)) + logger.info("Exception for %s\n%s", product_id, e) saved += 1 diff --git a/scripts/remove_duplicates.py b/scripts/remove_duplicates.py index 8e75b71030..a21a1d03d3 100644 --- a/scripts/remove_duplicates.py +++ b/scripts/remove_duplicates.py @@ -1,6 +1,10 @@ from peewee import fn from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation, db +from robotoff.types import ServerType + +SERVER_TYPE = ServerType.off + with db.connection_context(): with db.atomic(): @@ -10,6 +14,7 @@ ImageModel.source_image, fn.COUNT(ImageModel.id).alias("count"), ) + .where(ImageModel.server_type == SERVER_TYPE.name) .group_by(ImageModel.source_image) .having(fn.COUNT(ImageModel.id) > 1) .dicts() @@ -25,7 +30,10 @@ set( item[0] for item in ImageModel.select(ImageModel.id) - .where(ImageModel.source_image == source_image) + .where( + ImageModel.source_image == source_image, + ImageModel.server_type == SERVER_TYPE.name, + ) .tuples() ) ) diff --git a/tests/integration/insights/test_category_import.py b/tests/integration/insights/test_category_import.py index 5b0ad51f5a..51c0c226f1 100644 --- a/tests/integration/insights/test_category_import.py +++ b/tests/integration/insights/test_category_import.py @@ -1,15 +1,16 @@ 
 import pytest

-from robotoff import settings
 from robotoff.insights.importer import import_insights
 from robotoff.models import ProductInsight
 from robotoff.products import Product
-from robotoff.types import Prediction, PredictionType
+from robotoff.types import Prediction, PredictionType, ProductIdentifier, ServerType

 from ..models_utils import PredictionFactory, ProductInsightFactory, clean_db

 insight_id1 = "94371643-c2bc-4291-a585-af2cb1a5270a"
-barcode1 = "00001"
+DEFAULT_BARCODE = "00001"
+DEFAULT_SERVER_TYPE = ServerType.off
+DEFAULT_PRODUCT_ID = ProductIdentifier(DEFAULT_BARCODE, DEFAULT_SERVER_TYPE)


 @pytest.fixture(autouse=True)
@@ -19,7 +20,7 @@ def _set_up_and_tear_down(peewee_db):
         clean_db()
         # a category already exists
         PredictionFactory(
-            barcode=barcode1,
+            barcode=DEFAULT_BARCODE,
             type="category",
             value_tag="en:salmons",
             automatic_processing=False,
@@ -27,7 +28,7 @@ def _set_up_and_tear_down(peewee_db):
         )
         ProductInsightFactory(
             id=insight_id1,
-            barcode=barcode1,
+            barcode=DEFAULT_BARCODE,
             type="category",
             value_tag="en:salmons",
             predictor="matcher",
@@ -40,7 +41,7 @@ def _set_up_and_tear_down(peewee_db):

 def matcher_prediction(category):
     return Prediction(
-        barcode=barcode1,
+        barcode=DEFAULT_BARCODE,
         type=PredictionType.category,
         value_tag=category,
         data={
@@ -54,7 +55,7 @@ def matcher_prediction(category):

 def neural_prediction(category, confidence=0.7, auto=False):
     return Prediction(
-        barcode=barcode1,
+        barcode=DEFAULT_BARCODE,
         type=PredictionType.category,
         value_tag=category,
         data={"lang": "xx"},
@@ -71,15 +72,13 @@ class TestCategoryImporter:
     """

     def fake_product_store(self):
-        return {barcode1: Product({"categories_tags": ["en:fish"]})}
+        return {DEFAULT_PRODUCT_ID: Product({"categories_tags": ["en:fish"]})}

     def _run_import(self, predictions, product_store=None):
         if product_store is None:
             product_store = self.fake_product_store()
         return import_insights(
-            predictions,
-            server_domain=settings.BaseURLProvider.server_domain(),
-            product_store=product_store,
+            predictions, DEFAULT_SERVER_TYPE, product_store=product_store
         )

     @pytest.mark.parametrize(
@@ -142,7 +141,6 @@ def test_import_one_different_value_tag(self, predictions):
         assert ProductInsight.select().count() == 1
         inserted = ProductInsight.get(ProductInsight.id != insight_id1)
         assert inserted.value_tag == "en:smoked-salmons"
-        assert inserted.server_domain == settings.BaseURLProvider.server_domain()
         assert not inserted.automatic_processing

     def test_import_auto(self):
@@ -156,7 +154,6 @@ def test_import_auto(self):
         assert ProductInsight.select().count() == 1
         inserted = ProductInsight.get(ProductInsight.id != insight_id1)
         assert inserted.value_tag == "en:smoked-salmons"
-        assert inserted.server_domain == settings.BaseURLProvider.server_domain()
         assert inserted.automatic_processing

     @pytest.mark.parametrize(
@@ -170,7 +167,9 @@ def test_import_auto(self):
     )
     def test_import_product_not_in_store(self, predictions):
         # we should not create insight for non existing products !
-        import_result = self._run_import(predictions, product_store={barcode1: None})
+        import_result = self._run_import(
+            predictions, product_store={DEFAULT_PRODUCT_ID: None}
+        )
         assert import_result.created_insights_count() == 0
         assert import_result.updated_insights_count() == 0
         assert import_result.deleted_insights_count() == 0
diff --git a/tests/integration/insights/test_process_insights.py b/tests/integration/insights/test_process_insights.py
index 194422344a..e36ece508f 100644
--- a/tests/integration/insights/test_process_insights.py
+++ b/tests/integration/insights/test_process_insights.py
@@ -2,12 +2,14 @@

 import pytest

-from robotoff import settings
 from robotoff.models import ProductInsight
 from robotoff.scheduler import process_insights
+from robotoff.types import ServerType

 from ..models_utils import ProductInsightFactory, clean_db

+DEFAULT_SERVER_TYPE = ServerType.off
+

 @pytest.fixture(autouse=True)
 def _set_up_and_tear_down(peewee_db):
@@ -34,6 +36,7 @@ def _create_insight(**kwargs):
             "automatic_processing": True,
             "process_after": datetime.utcnow() - timedelta(minutes=12),
             "n_votes": 3,
+            "server_type": DEFAULT_SERVER_TYPE.name,
         },
         **kwargs,
     )
@@ -69,8 +72,8 @@ def test_process_insight_category(mocker, peewee_db):
             "add_categories": "en:Salmons",
             "comment": f"[robotoff] Adding category 'en:Salmons', ID: {id1}",
         },
+        server_type=DEFAULT_SERVER_TYPE,
         auth=None,
-        server_domain=settings.BaseURLProvider.server_domain(),
     )


@@ -149,8 +152,8 @@ def raise_for_salmons(params, *args, **kwargs):
             "add_categories": "en:Salmons",
             "comment": f"[robotoff] Adding category 'en:Salmons', ID: {id1}",
         },
+        server_type=DEFAULT_SERVER_TYPE,
         auth=None,
-        server_domain=settings.BaseURLProvider.server_domain(),
     )

     with peewee_db:
@@ -166,8 +169,8 @@ def raise_for_salmons(params, *args, **kwargs):
             "add_categories": "en:Tuna",
             "comment": f"[robotoff] Adding category 'en:Tuna', ID: {id2}",
         },
+        server_type=DEFAULT_SERVER_TYPE,
         auth=None,
-        server_domain=settings.BaseURLProvider.server_domain(),
     )
     # we add only two calls
     assert mock.call_count == 2
@@ -205,8 +208,8 @@ def test_process_insight_same_product(mocker, peewee_db):
             "add_categories": "en:Big fish",
             "comment": f"[robotoff] Adding category 'en:Big fish', ID: {id2}",
         },
+        server_type=DEFAULT_SERVER_TYPE,
         auth=None,
-        server_domain=settings.BaseURLProvider.server_domain(),
     )
     mock.assert_any_call(
         {
             "add_categories": "en:Smoked Salmon",
             "comment": f"[robotoff] Adding category 'en:Smoked Salmon', ID: {id3}",
         },
+        server_type=DEFAULT_SERVER_TYPE,
         auth=None,
-        server_domain=settings.BaseURLProvider.server_domain(),
     )
diff --git a/tests/integration/models_utils.py b/tests/integration/models_utils.py
index f60b6f80e4..17ca08c90a 100644
--- a/tests/integration/models_utils.py
+++ b/tests/integration/models_utils.py
@@ -11,7 +11,7 @@
 import numpy as np
 from factory_peewee import PeeweeModelFactory

-from robotoff import models, settings
+from robotoff import models
 from robotoff.models import (
     AnnotationVote,
     ImageModel,
@@ -47,10 +47,6 @@ class Meta:
     brands: list[str] = []
     n_votes = 0
     value_tag = "en:seeds"
-    # we use a lazy function for settings can change in a test
-    server_domain: str = factory.LazyFunction(
-        lambda: settings.BaseURLProvider.server_domain()
-    )
     server_type = "off"
     unique_scans_n = 10
     annotation = None
@@ -67,9 +63,6 @@ class Meta:
     data: dict[str, Any] = {}
     timestamp = factory.LazyFunction(datetime.utcnow)
     value_tag = "en:seeds"
-    server_domain = factory.LazyFunction(
-        lambda: settings.BaseURLProvider.server_domain()
-    )
     automatic_processing = None
     predictor = None
     confidence: Optional[float] = None
@@ -97,9 +90,6 @@ class Meta:
     source_image = factory.Sequence(lambda n: f"/images/{n:02}.jpg")
     width = 400
     height = 400
-    server_domain = factory.LazyFunction(
-        lambda: settings.BaseURLProvider.server_domain()
-    )
     server_type = "off"
diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py
index ea72278bb5..fd2c8f6db3 100644
--- a/tests/integration/test_api.py
+++ b/tests/integration/test_api.py
@@ -5,11 +5,11 @@
 import pytest
 from falcon import testing

-from robotoff import settings
 from robotoff.app import events
 from robotoff.app.api import api
 from robotoff.models import AnnotationVote, LogoAnnotation, ProductInsight
 from robotoff.off import OFFAuthentication
+from robotoff.types import ProductIdentifier, ServerType

 from .models_utils import (
     AnnotationVoteFactory,
@@ -23,6 +23,8 @@

 insight_id = "94371643-c2bc-4291-a585-af2cb1a5270a"
 DEFAULT_BARCODE = "1"
+DEFAULT_SERVER_TYPE = ServerType.off
+DEFAULT_PRODUCT_ID = ProductIdentifier(DEFAULT_BARCODE, DEFAULT_SERVER_TYPE)


 @pytest.fixture(autouse=True)
@@ -71,6 +73,7 @@ def test_random_question(client, mocker):
                 "question": "Does the product belong to this category?",
                 "insight_id": insight_id,
                 "insight_type": "category",
+                "server_type": "off",
                 "source_image_url": "https://images.openfoodfacts.net/images/products/1/ingredients_fr.51.400.jpg",
             }
         ],
@@ -108,6 +111,7 @@ def test_popular_question(client, mocker):
                 "question": "Does the product belong to this category?",
                 "insight_id": insight_id,
                 "insight_type": "category",
+                "server_type": "off",
             }
         ],
         "status": "found",
@@ -196,6 +200,7 @@ def test_barcode_question(client, mocker):
                 "question": "Does the product belong to this category?",
                 "insight_id": insight_id,
                 "insight_type": "category",
+                "server_type": "off",
             }
         ],
         "status": "found",
@@ -557,10 +562,9 @@ def test_annotate_insight_anonymous_then_authenticated(client, mocker, peewee_db
     assert insight.get("completed_at") <= datetime.utcnow()
     # update was done
     add_category.assert_called_once_with(
-        "1",  # barcode
+        DEFAULT_PRODUCT_ID,
         "en:seeds",  # category_tag
         insight_id=uuid.UUID(insight_id),
-        server_domain=settings.BaseURLProvider.server_domain(),
         auth=OFFAuthentication(username="a", password="b"),
     )
@@ -633,6 +637,7 @@ def test_annotation_event(client, monkeypatch, httpserver):
         "user_id": "a",
         "device_id": "test-device",
         "barcode": "1",
+        "server_type": "off",
     }
     httpserver.expect_oneshot_request(
         "/", method="POST", json=expected_event
@@ -644,6 +649,7 @@ def test_annotation_event(client, monkeypatch, httpserver):
             "insight_id": insight_id,
             "annotation": 0,
             "device_id": "test-device",
+            "server_type": "off",
         },
         headers={
             "Authorization": "Basic " + base64.b64encode(b"a:b").decode("ascii")
diff --git a/tests/integration/test_core_integration.py b/tests/integration/test_core_integration.py
index 7f1cfa0bfd..c3550480e7 100644
--- a/tests/integration/test_core_integration.py
+++ b/tests/integration/test_core_integration.py
@@ -7,6 +7,7 @@
     get_logo_annotation,
     get_predictions,
 )
+from robotoff.types import ServerType

 from .models_utils import (
     ImageModelFactory,
@@ -17,6 +18,8 @@
     clean_db,
 )

+DEFAULT_SERVER_TYPE = ServerType.off
+

 @pytest.fixture(autouse=True)
 def _set_up_and_tear_down(peewee_db):
@@ -43,12 +46,14 @@ def test_get_image_predictions():
     image_prediction4 = ImagePredictionFactory(image__barcode="123", type="category")

     # test with "barcode" filter
-    data = list(get_image_predictions(barcode="123"))
+    data = list(get_image_predictions(DEFAULT_SERVER_TYPE, barcode="123"))
     assert len(data) == 2
     assert prediction_ids(data) == {image_prediction3.id, image_prediction4.id}

     # test filter with "barcode" and "with_logo=True"
-    data = list(get_image_predictions(barcode="123", with_logo=True))
+    data = list(
+        get_image_predictions(DEFAULT_SERVER_TYPE, barcode="123", with_logo=True)
+    )
     assert len(data) == 3
     assert prediction_ids(data) == {
         image_prediction1.id,
@@ -57,16 +62,20 @@ def test_get_image_predictions():
     }

     # test filter with "with_logo=True"
-    data = list(get_image_predictions(with_logo=True))
+    data = list(get_image_predictions(DEFAULT_SERVER_TYPE, with_logo=True))
     assert len(data) == 4  # we have them all

     # test filter with "type=label" and "with_logo=True"
-    data = list(get_image_predictions(type="label", with_logo=True))
+    data = list(
+        get_image_predictions(DEFAULT_SERVER_TYPE, type="label", with_logo=True)
+    )
     assert len(data) == 2
     assert prediction_ids(data) == {image_prediction2.id, image_prediction3.id}

     # test filter with "type=label" and "with_logo=False"
-    data = list(get_image_predictions(type="label", with_logo=False))
+    data = list(
+        get_image_predictions(DEFAULT_SERVER_TYPE, type="label", with_logo=False)
+    )
     assert len(data) == 1
     assert prediction_ids(data) == {image_prediction3.id}
@@ -85,7 +94,7 @@ def test_get_predictions():
         barcode="456", keep_types="label", value_tag="en:eu-organic"
     )

-    actual_prediction1 = get_predictions(barcode="123")
+    actual_prediction1 = get_predictions(DEFAULT_SERVER_TYPE, barcode="123")
     actual_items1 = [item.to_dict() for item in actual_prediction1]
     actual_items1.sort(key=lambda d: d["id"])
     assert len(actual_items1) == 3
@@ -99,12 +108,12 @@ def test_get_predictions():
     assert actual_items1[2]["id"] == prediction3.id

     # test that as we have no "brand" prediction, returned list is empty
-    actual_prediction2 = get_predictions(keep_types=["brand"])
+    actual_prediction2 = get_predictions(DEFAULT_SERVER_TYPE, keep_types=["brand"])
     assert list(actual_prediction2) == []

     # test that predictions are filtered based on "value_tag=en:eu-organic",
     # returns only "en:eu-organic" predictions
-    actual_prediction3 = get_predictions(value_tag="en:eu-organic")
+    actual_prediction3 = get_predictions(DEFAULT_SERVER_TYPE, value_tag="en:eu-organic")
     actual_items3 = [item.to_dict() for item in actual_prediction3]
     actual_items3.sort(key=lambda d: d["id"])
     assert len(actual_items3) == 2
@@ -116,14 +125,19 @@ def test_get_predictions():

     # test that we can filter "barcode", "value_tag", "keep_types" prediction
     actual_prediction4 = get_predictions(
-        barcode="123", value_tag="en:eu-organic", keep_types=["category"]
+        DEFAULT_SERVER_TYPE,
+        barcode="123",
+        value_tag="en:eu-organic",
+        keep_types=["category"],
     )
     actual_items4 = [item.to_dict() for item in actual_prediction4]
     assert actual_items4[0]["id"] == prediction3.id
     assert len(actual_items4) == 1

     # test to filter results with "label" and "category" prediction
-    actual_prediction5 = get_predictions(keep_types=["label", "category"])
+    actual_prediction5 = get_predictions(
+        DEFAULT_SERVER_TYPE, keep_types=["label", "category"]
+    )
     actual_items5 = [item.to_dict() for item in actual_prediction5]
     assert len(actual_items5) == 4
@@ -137,7 +151,7 @@ def test_get_images():

     # test with "barcode" filter
-    image_model_data = get_images(barcode="123", server_type=DEFAULT_SERVER_TYPE)
+    image_model_data = get_images(barcode="123", server_type=DEFAULT_SERVER_TYPE)
     image_model_items = [item.to_dict() for item in image_model_data]
     assert len(image_model_items) == 1
@@ -145,7 +159,9 @@
     assert image_model_items[0]["barcode"] == "123"

     # test filter with "barcode" and "with_predictions=True"
-    image_model_data = get_images(barcode="123", with_predictions=True)
+    image_model_data = get_images(
+        barcode="123", with_predictions=True, server_type=DEFAULT_SERVER_TYPE
+    )
     image_model_items = [item.to_dict() for item in image_model_data]
     image_model_items.sort(key=lambda d: d["id"])
     assert len(image_model_items) == 2
     assert image_model_items[0]["id"] == image_model1.id
     assert image_model_items[1]["id"] == image_model3.id

     # test filter with "with_predictions=True"
-    image_model_data = get_images(with_predictions=True)
+    image_model_data = get_images(
+        with_predictions=True, server_type=DEFAULT_SERVER_TYPE
+    )
     image_model_items = [item.to_dict() for item in image_model_data]
     image_model_items.sort(key=lambda d: d["id"])
     assert len(image_model_items) == 3
@@ -162,7 +180,9 @@
     assert image_model_items[2]["id"] == image_model3.id

     # test filter with "barcode" and "with_predictions=True"
-    image_model_data = get_images(barcode="456", with_predictions=True)
+    image_model_data = get_images(
+        barcode="456", with_predictions=True, server_type=DEFAULT_SERVER_TYPE
+    )
     image_model_items = [item.to_dict() for item in image_model_data]
     assert len(image_model_items) == 1
     assert image_model_items[0]["id"] == image_model2.id
@@ -238,19 +258,19 @@ def test_get_logo_annotation():

     # tests for "barcode"

-    annotation_data = get_logo_annotation(barcode="123")
+    annotation_data = get_logo_annotation(DEFAULT_SERVER_TYPE, barcode="123")
     annotation_data_items = [item.to_dict() for item in annotation_data]
     assert annotation_data_items[0]["id"] == annotation_123.id
     assert annotation_data_items[0]["image_prediction"]["image"]["barcode"] == "123"
     assert annotation_data_items[0]["annotation_type"] == "brand"

-    annotation_data = get_logo_annotation(barcode="789")
+    annotation_data = get_logo_annotation(DEFAULT_SERVER_TYPE, barcode="789")
     annotation_data_items = [item.to_dict() for item in annotation_data]
     assert annotation_data_items[0]["id"] == annotation_789.id
     assert annotation_data_items[0]["image_prediction"]["image"]["barcode"] == "789"
     assert annotation_data_items[0]["annotation_type"] == "dairies"

-    annotation_data = get_logo_annotation(barcode="396")
+    annotation_data = get_logo_annotation(DEFAULT_SERVER_TYPE, barcode="396")
     annotation_data_items = [item.to_dict() for item in annotation_data]
     assert annotation_data_items[0]["id"] == annotation_396.id
     assert annotation_data_items[0]["image_prediction"]["image"]["barcode"] == "396"
@@ -258,7 +278,7 @@ def test_get_logo_annotation():

     # test for "keep_types"

-    annotation_data = get_logo_annotation(keep_types=["dairies"])
+    annotation_data = get_logo_annotation(DEFAULT_SERVER_TYPE, keep_types=["dairies"])
     annotation_data_items = [item.to_dict() for item in annotation_data]
     annotation_data_items.sort(key=lambda d: d["id"])
     assert annotation_data_items[0]["annotation_type"] == "dairies"
@@ -267,7 +287,7 @@ def test_get_logo_annotation():

     # tests for "value_tag"

-    annotation_data = get_logo_annotation(value_tag="cheese")
+    annotation_data = get_logo_annotation(DEFAULT_SERVER_TYPE, value_tag="cheese")
     annotation_data_items = [item.to_dict() for item in annotation_data]
     assert annotation_data_items[0]["id"] == annotation_295.id
     assert annotation_data_items[0]["annotation_value_tag"] == "cheese"
diff --git a/tests/integration/test_import_image.py b/tests/integration/test_import_image.py
index ff3c30c21d..5723f50b1f 100644
--- a/tests/integration/test_import_image.py
+++ b/tests/integration/test_import_image.py
@@ -2,9 +2,8 @@
 import pytest
 from PIL import Image

-from robotoff import settings
 from robotoff.models import LogoEmbedding
-from robotoff.types import InsightImportResult
+from robotoff.types import InsightImportResult, ServerType
 from robotoff.workers.tasks.import_image import (
     process_created_logos,
     save_logo_embeddings,
@@ -17,6 +16,8 @@
     clean_db,
 )

+DEFAULT_SERVER_TYPE = ServerType.off
+

 @pytest.fixture(autouse=True)
 def _set_up_and_tear_down(peewee_db):
@@ -88,9 +89,7 @@ def test_process_created_logos(peewee_db, mocker):
         for i in range(5)
     ]
     logo_embeddings = [LogoEmbeddingFactory(logo=logo) for logo in logos]
-    process_created_logos(
-        image_prediction.id, server_domain=settings.BaseURLProvider.server_domain()
-    )
+    process_created_logos(image_prediction.id, DEFAULT_SERVER_TYPE)
     add_logos_to_ann_mock.assert_called()
     embedding_args = add_logos_to_ann_mock.mock_calls[0].args[1]
     assert sorted(embedding_args, key=lambda x: x.logo_id) == logo_embeddings
diff --git a/tests/integration/test_logos.py b/tests/integration/test_logos.py
index d2e85e24e5..812770f6cf 100644
--- a/tests/integration/test_logos.py
+++ b/tests/integration/test_logos.py
@@ -2,23 +2,25 @@

 import robotoff.insights.importer
 import robotoff.taxonomy
-from robotoff import settings
 from robotoff.logos import generate_insights_from_annotated_logos_job
 from robotoff.models import Prediction, ProductInsight
 from robotoff.off import OFFAuthentication
 from robotoff.products import Product
+from robotoff.types import ProductIdentifier, ServerType

 from .models_utils import LogoAnnotationFactory

+DEFAULT_SERVER_TYPE = ServerType.off

-def _fake_store(monkeypatch, barcode):
+
+def _fake_store(monkeypatch, product_id: ProductIdentifier):
     monkeypatch.setattr(
         robotoff.insights.importer,
         "get_product_store",
-        lambda: {
-            barcode: Product(
+        lambda server_type: {
+            product_id: Product(
                 {
-                    "code": barcode,  # needed to validate brand/label
+                    "code": product_id.barcode,  # needed to validate brand/label
                     # needed to validate image
                     "images": {
                         "2": {"rev": 1, "uploaded_t": datetime.utcnow().timestamp()}
@@ -31,7 +33,7 @@ def _fake_store(monkeypatch, product_id: ProductIdentifier):

 def test_generate_insights_from_annotated_logos_job(peewee_db, monkeypatch, mocker):
     barcode = "0000000000001"
-    _fake_store(monkeypatch, barcode)
+    _fake_store(monkeypatch, ProductIdentifier(barcode, DEFAULT_SERVER_TYPE))
     mocker.patch(
         "robotoff.brands.get_brand_prefix", return_value={("Etorki", "0000000xxxxxx")}
     )
@@ -56,14 +58,19 @@ def test_generate_insights_from_annotated_logos_job(peewee_db, monkeypatch, mock
     start = datetime.utcnow()
     generate_insights_from_annotated_logos_job(
         [ann.id],
-        settings.BaseURLProvider.server_domain(),
         OFFAuthentication(username=username, password=username),
+        server_type=DEFAULT_SERVER_TYPE,
     )
     end = datetime.utcnow()

     # we generate a prediction
     with peewee_db:
-        predictions = list(Prediction.select().where(Prediction.barcode == barcode))
+        predictions = list(
+            Prediction.select().where(
+                Prediction.barcode == barcode,
+                Prediction.server_type == DEFAULT_SERVER_TYPE,
+            )
+        )
     assert len(predictions) == 1
     (prediction,) = predictions
     assert prediction.type == "brand"
@@ -79,11 +86,15 @@ def test_generate_insights_from_annotated_logos_job(peewee_db, monkeypatch, mock
     assert prediction.predictor == "universal-logo-detector"
     assert start <= prediction.timestamp <= end
     assert prediction.automatic_processing is False
+    assert prediction.server_type == DEFAULT_SERVER_TYPE.name

     # We check that this prediction in turn generates an insight
     with peewee_db:
         insights = list(
-            ProductInsight.select().where(ProductInsight.barcode == barcode)
+            ProductInsight.select().where(
+                ProductInsight.barcode == barcode,
+                ProductInsight.server_type == DEFAULT_SERVER_TYPE,
+            )
         )
     assert len(insights) == 1
     (insight,) = insights
@@ -103,4 +114,5 @@ def test_generate_insights_from_annotated_logos_job(peewee_db, monkeypatch, mock
     assert insight.username == "a"
     assert insight.annotation == 1
     assert insight.annotated_result == 2
+    assert insight.server_type == DEFAULT_SERVER_TYPE.name
     assert isinstance(insight.completed_at, datetime)
diff --git a/tests/unit/insights/test_importer.py b/tests/unit/insights/test_importer.py
index f228ad1e73..6ff412352b 100644
--- a/tests/unit/insights/test_importer.py
+++ b/tests/unit/insights/test_importer.py
@@ -27,11 +27,14 @@
     InsightType,
     Prediction,
     PredictionType,
+    ProductIdentifier,
     ProductInsightImportResult,
+    ServerType,
 )

 DEFAULT_BARCODE = "3760094310634"
-DEFAULT_SERVER_DOMAIN = "api.openfoodfacts.org"
+DEFAULT_SERVER_TYPE = ServerType.off
+DEFAULT_PRODUCT_ID = ProductIdentifier(DEFAULT_BARCODE, DEFAULT_SERVER_TYPE)
 # 2022-02-08 16:07
 DEFAULT_UPLOADED_T = "1644332825"
@@ -453,7 +456,6 @@ def test_generate_insights_no_predictions(self, mocker):
         assert CategoryImporter.generate_insights(
             DEFAULT_BARCODE,
             [],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(),
         ) == ([], [], [])
         get_existing_insight_mock.assert_called_once()
@@ -471,7 +473,6 @@ def test_generate_insights_no_predictions_with_existing_insight(self, mocker):
         assert CategoryImporter.generate_insights(
             DEFAULT_BARCODE,
             [],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(),
         ) == ([], [], [existing_insight])
         get_existing_insight_mock.assert_called_once()
@@ -489,7 +490,6 @@ def test_generate_insights_missing_product_no_references(self, mocker):
                     data={},
                 )
             ],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(),
         ) == ([], [], [])
         get_existing_insight_mock.assert_called_once()
@@ -509,7 +509,6 @@ def test_generate_insights_missing_product_with_reference(self, mocker):
                     data={},
                 )
             ],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(),
         )
         assert generated == ([], [], [reference])
@@ -553,7 +552,6 @@ def get_insight_update(cls, candidates, references):
         generated = FakeImporter.generate_insights(
             DEFAULT_BARCODE,
             [prediction],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(
                 data={
                     DEFAULT_BARCODE: Product(
@@ -583,7 +581,6 @@ def get_insight_update(cls, candidates, references):
         assert created_insight.value_tag == "tag2"
         assert created_insight.data == {"k": "v"}
         assert created_insight.barcode == DEFAULT_BARCODE
-        assert created_insight.server_domain == DEFAULT_SERVER_DOMAIN
         assert created_insight.server_type == "off"
         assert created_insight.process_after is not None
         uuid.UUID(created_insight.id)
@@ -615,7 +612,6 @@ def get_insight_update(cls, candidates, references):
         generated = FakeImporter.generate_insights(
             DEFAULT_BARCODE,
             [prediction],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(
                 data={DEFAULT_BARCODE: Product({"code": DEFAULT_BARCODE})}
             ),
@@ -637,7 +633,6 @@ def get_required_prediction_types():
             FakeImporter.import_insights(
                 DEFAULT_BARCODE,
                 [Prediction(type=PredictionType.label)],
-                DEFAULT_SERVER_DOMAIN,
                 product_store=FakeProductStore(),
             )
@@ -648,9 +643,7 @@ def get_required_prediction_types():
                 return {PredictionType.label}

             @classmethod
-            def generate_insights(
-                cls, barcode, predictions, server_domains, product_store
-            ):
+            def generate_insights(cls, barcode, predictions, product_store):
                 return (
                     [
                         ProductInsight(
@@ -676,7 +669,6 @@ def generate_insights(
         import_result = FakeImporter.import_insights(
             DEFAULT_BARCODE,
             [Prediction(type=PredictionType.label)],
-            DEFAULT_SERVER_DOMAIN,
             product_store=FakeProductStore(),
         )
         assert len(import_result.insight_created_ids) == 1
@@ -1266,8 +1258,8 @@ def test_import_insights_no_element(self, mocker):
         product_store = FakeProductStore()
         import_insights_for_products(
             {DEFAULT_BARCODE: {PredictionType.category}},
-            DEFAULT_SERVER_DOMAIN,
             product_store=product_store,
+            server_type=DEFAULT_SERVER_TYPE,
         )
         get_product_predictions_mock.assert_called_once()
         import_insights_mock.assert_not_called()
@@ -1277,11 +1269,13 @@ def test_import_insights_single_product(self, mocker):
             "barcode": DEFAULT_BARCODE,
             "type": PredictionType.category.name,
             "data": {},
+            "server_type": DEFAULT_SERVER_TYPE,
         }
         prediction = Prediction(
             barcode=DEFAULT_BARCODE,
             type=PredictionType.category,
             data={},
+            server_type=DEFAULT_SERVER_TYPE,
         )
         get_product_predictions_mock = mocker.patch(
             "robotoff.insights.importer.get_product_predictions",
@@ -1292,19 +1286,19 @@ def test_import_insights_single_product(self, mocker):
         import_insights_mock = mocker.patch(
             "robotoff.insights.importer.InsightImporter.import_insights",
             return_value=ProductInsightImportResult(
-                [], [], [], DEFAULT_BARCODE, InsightType.category
+                [], [], [], DEFAULT_PRODUCT_ID, InsightType.category
             ),
         )
         product_store = FakeProductStore()
         import_result = import_insights_for_products(
             {DEFAULT_BARCODE: {PredictionType.category}},
-            DEFAULT_SERVER_DOMAIN,
             product_store=product_store,
+            server_type=DEFAULT_SERVER_TYPE,
         )
         assert len(import_result) == 1
         get_product_predictions_mock.assert_called_once()
         import_insights_mock.assert_called_once_with(
-            DEFAULT_BARCODE, [prediction], DEFAULT_SERVER_DOMAIN, product_store
+            DEFAULT_PRODUCT_ID, [prediction], product_store
         )

     def test_import_insights_type_mismatch(self, mocker):
@@ -1322,14 +1316,14 @@ def test_import_insights_type_mismatch(self, mocker):
         import_insights_mock = mocker.patch(
             "robotoff.insights.importer.InsightImporter.import_insights",
             return_value=ProductInsightImportResult(
-                [], [], [], DEFAULT_BARCODE, InsightType.image_orientation
+                [], [], [], DEFAULT_PRODUCT_ID, InsightType.image_orientation
             ),
         )
         product_store = FakeProductStore()
         import_results = import_insights_for_products(
             {DEFAULT_BARCODE: {PredictionType.image_orientation}},
-            DEFAULT_SERVER_DOMAIN,
             product_store=product_store,
+            server_type=DEFAULT_SERVER_TYPE,
         )
         assert len(import_results) == 0
         assert not get_product_predictions_mock.called
diff --git a/tests/unit/insights/test_question.py b/tests/unit/insights/test_question.py
index 9eec8ae132..4fdb3bef2f 100644
--- a/tests/unit/insights/test_question.py
+++ b/tests/unit/insights/test_question.py
@@ -13,14 +13,14 @@
 from robotoff.models import ProductInsight
 from robotoff.off import split_barcode
 from robotoff.settings import TEST_DATA_DIR
-from robotoff.types import InsightType
+from robotoff.types import InsightType, ProductIdentifier, ServerType
 from robotoff.utils.i18n import TranslationStore


 def _reset_envvar(monkeypatch):
     monkeypatch.setenv("ROBOTOFF_INSTANCE", "dev")
     monkeypatch.delenv("ROBOTOFF_SCHEME", raising=False)
-    monkeypatch.delenv("ROBOTOFF_DOMAIN", raising=False)
+    monkeypatch.delenv("ROBOTOFF_TLD", raising=False)


 @pytest.mark.parametrize(
@@ -42,7 +42,8 @@ def test_generate_selected_images(monkeypatch):
     IMAGE_DATA = json.load(f)

     selected_images = generate_selected_images(
-        IMAGE_DATA["product"]["images"], IMAGE_DATA["code"]
+        IMAGE_DATA["product"]["images"],
+        ProductIdentifier(IMAGE_DATA["code"], ServerType.off),
     )

     assert selected_images["front"] == {
@@ -102,6 +103,7 @@ def generate_insight(
     value: Optional[str] = None,
     value_tag: Optional[str] = None,
     add_source_image: bool = False,
+    server_type: ServerType = ServerType.off,
 ) -> ProductInsight:
     barcode = "1111111111"
     return ProductInsight(
@@ -112,6 +114,7 @@ def generate_insight(
         source_image=f"/{'/'.join(split_barcode(barcode))}/1.jpg"
         if add_source_image
         else None,
+        server_type=server_type.name,
     )
@@ -174,6 +177,7 @@ def test_category_question_formatter(
         "question": expected_question_str,
         "insight_id": str(insight.id),
         "insight_type": InsightType.category.name,
+        "server_type": ServerType.off.name,
         "source_image_url": "https://images.openfoodfacts.net/images/products/111/111/111/1/front_fr.10.400.jpg",
     }
@@ -227,6 +231,7 @@ def test_label_question_formatter(
         "question": expected_question_str,
         "insight_id": str(insight.id),
         "insight_type": InsightType.label.name,
+        "server_type": ServerType.off.name,
         "source_image_url": "https://images.openfoodfacts.net/images/products/111/111/111/1/1.400.jpg",
     }
diff --git a/tests/unit/prediction/category/neural/test_category_classifier.py b/tests/unit/prediction/category/neural/test_category_classifier.py
index d7e651f237..c2df4a3e89 100644
--- a/tests/unit/prediction/category/neural/test_category_classifier.py
+++ b/tests/unit/prediction/category/neural/test_category_classifier.py
@@ -3,8 +3,10 @@

 from robotoff.prediction.category.neural.category_classifier import CategoryClassifier
 from robotoff.triton import serialize_byte_tensor
+from robotoff.types import ProductIdentifier, ServerType

 MODEL_VERSION = "category-classifier"
+DEFAULT_PRODUCT_ID = ProductIdentifier("123", ServerType.off)


 class GRPCResponse:
@@ -45,7 +47,7 @@ def test_predict_ingredients_only(mocker, data, category_taxonomy):
         return_value=MockStub(GRPCResponse(["en:meats"], [0.99])),
     )
     classifier = CategoryClassifier(category_taxonomy)
-    predictions, debug = classifier.predict(data)
+    predictions, debug = classifier.predict(data, DEFAULT_PRODUCT_ID)
     assert debug == {
         "inputs": {
             "ingredients_tags": [""],
@@ -102,6 +104,7 @@ def test_predict(
     )
     predictions, _ = classifier.predict(
         {"ingredients_tags": ["ingredient1"], "product_name": "Test Product"},
+        DEFAULT_PRODUCT_ID,
         deepest_only,
     )
diff --git a/tests/unit/test_logos.py b/tests/unit/test_logos.py
index 7d566da215..cf464e3406 100644
--- a/tests/unit/test_logos.py
+++ b/tests/unit/test_logos.py
@@ -1,7 +1,7 @@
 import pytest

 from robotoff.logos import compute_iou, generate_prediction
-from robotoff.types import Prediction, PredictionType
+from robotoff.types import Prediction, PredictionType, ServerType


 @pytest.mark.parametrize(
@@ -61,7 +61,12 @@ def test_generate_prediction(
 ):
     assert (
         generate_prediction(
-            logo_type, logo_value, data, confidence, automatic_processing
+            logo_type,
+            logo_value,
+            data,
+            confidence,
+            ServerType.off,
+            automatic_processing,
         )
         == prediction
     )
diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py
index 0f4bbafd60..35bd080606 100644
--- a/tests/unit/test_models.py
+++ b/tests/unit/test_models.py
@@ -1,5 +1,6 @@
 from robotoff import settings
 from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation
+from robotoff.types import ServerType


 def test_crop_image_url(monkeypatch):
@@ -15,6 +16,7 @@ def test_crop_image_url(monkeypatch):
                 source_image="/123/1.jpg",
                 width=20,
                 height=20,
+                server_type=ServerType.off.name,
             ),
         ),
         bounding_box=(1, 1, 2, 2),
@@ -24,5 +26,5 @@
     assert logo_annotation.get_crop_image_url() == (
         f"{settings.BaseURLProvider.robotoff()}/api/v1/images/crop"
-        + f"?image_url={settings.BaseURLProvider.image_url('/123/1.jpg')}&y_min=1&x_min=1&y_max=2&x_max=2"
+        + f"?image_url={settings.BaseURLProvider.image_url(ServerType.off, '/123/1.jpg')}&y_min=1&x_min=1&y_max=2&x_max=2"
     )
diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py
index 4e132b74cc..eb580c7bd5 100644
--- a/tests/unit/test_settings.py
+++ b/tests/unit/test_settings.py
@@ -1,12 +1,27 @@
 import pytest

 from robotoff import settings
+from robotoff.types import ServerType  # noqa: F401


 @pytest.mark.parametrize(
     "instance,got_url,want_url",
     [
-        ("prod", "settings.BaseURLProvider.world()", "https://world.openfoodfacts.org"),
+        (
+            "prod",
+            "settings.BaseURLProvider.world(ServerType.off)",
+            "https://world.openfoodfacts.org",
+        ),
+        (
+            "prod",
+            "settings.BaseURLProvider.world(ServerType.obf)",
+            "https://world.openbeautyfacts.org",
+        ),
+        (
+            "dev",
+            "settings.BaseURLProvider.world(ServerType.opf)",
+            "https://world.openproductsfacts.net",
+        ),
         (
             "prod",
             "settings.BaseURLProvider.robotoff()",
@@ -14,15 +29,19 @@
         ),
         (
             "prod",
-            "settings.BaseURLProvider.country('fr')",
+            "settings.BaseURLProvider.country(ServerType.off, 'fr')",
             "https://fr.openfoodfacts.org",
         ),
-        ("dev", "settings.BaseURLProvider.world()", "https://world.openfoodfacts.net"),
+        (
+            "dev",
+            "settings.BaseURLProvider.world(ServerType.off)",
+            "https://world.openfoodfacts.net",
+        ),
     ],
 )
 def test_base_url_provider(monkeypatch, instance, got_url, want_url):
     monkeypatch.setenv("ROBOTOFF_INSTANCE", instance)
-    monkeypatch.delenv("ROBOTOFF_DOMAIN", raising=False)  # force defaults to apply
+    monkeypatch.delenv("ROBOTOFF_TLD", raising=False)  # force defaults to apply
     monkeypatch.delenv("ROBOTOFF_SCHEME", raising=False)  # force defaults to apply

     assert eval(got_url) == want_url
diff --git a/tests/unit/test_slack.py b/tests/unit/test_slack.py
index 867552ecb0..617bde8707 100644
--- a/tests/unit/test_slack.py
+++ b/tests/unit/test_slack.py
@@ -6,7 +6,11 @@

 from robotoff import settings, slack
 from robotoff.models import ImageModel, ImagePrediction, LogoAnnotation, ProductInsight
-from robotoff.types import Prediction, PredictionType
+from robotoff.types import Prediction, PredictionType, ProductIdentifier, ServerType
+
+DEFAULT_BARCODE = "123"
+DEFAULT_SERVER_TYPE = ServerType.off
+DEFAULT_PRODUCT_ID = ProductIdentifier(DEFAULT_BARCODE, DEFAULT_SERVER_TYPE)


 class MockSlackResponse:
@@ -67,7 +71,7 @@ def test_notify_image_flag_no_prediction(mocker):
     notifier.notify_image_flag(
         [],
         "/source_image",
-        "123",
+        DEFAULT_PRODUCT_ID,
     )
     # wont publish anything
     assert not mock.called
@@ -96,22 +100,26 @@ def test_notify_image_flag_public(mocker, monkeypatch):
             )
         ],
         "/source_image/2.jpg",
-        "123",
+        DEFAULT_PRODUCT_ID,
     )

     mock_slack.assert_called_once_with(
         slack_notifier.POST_MESSAGE_URL,
         data=PartialRequestMatcher(
-            f"type: SENSITIVE\nlabel: *flagged*, match: bad_word\n\n <{settings.BaseURLProvider.image_url('/source_image/2.jpg')}|Image> -- <{settings.BaseURLProvider.world()}/cgi/product.pl?type=edit&code=123|*Edit*>",
+            f"type: SENSITIVE\nlabel: *flagged*, match: bad_word\n\n <{settings.BaseURLProvider.image_url(DEFAULT_SERVER_TYPE, '/source_image/2.jpg')}|Image> -- <{settings.BaseURLProvider.world(DEFAULT_SERVER_TYPE)}/cgi/product.pl?type=edit&code=123|*Edit*>",
             slack_notifier.ROBOTOFF_PUBLIC_IMAGE_ALERT_CHANNEL,
-            settings.BaseURLProvider.image_url("/source_image/2.jpg"),
+            settings.BaseURLProvider.image_url(
+                DEFAULT_SERVER_TYPE, "/source_image/2.jpg"
+            ),
         ),
     )

     mock_image_moderation.assert_called_once_with(
         "http://images.org/123",
         data={
             "imgid": 2,
-            "url": settings.BaseURLProvider.image_url("/source_image/2.jpg"),
+            "url": settings.BaseURLProvider.image_url(
+                DEFAULT_SERVER_TYPE, "/source_image/2.jpg"
+            ),
         },
     )
@@ -139,22 +147,26 @@ def test_notify_image_flag_private(mocker, monkeypatch):
             )
         ],
         "/source_image/2.jpg",
-        "123",
+        DEFAULT_PRODUCT_ID,
     )

     mock_slack.assert_called_once_with(
         slack_notifier.POST_MESSAGE_URL,
         data=PartialRequestMatcher(
-            f"type: label_annotation\nlabel: *face*, score: 0.8\n\n <{settings.BaseURLProvider.image_url('/source_image/2.jpg')}|Image> -- <{settings.BaseURLProvider.world()}/cgi/product.pl?type=edit&code=123|*Edit*>",
+            f"type: label_annotation\nlabel: *face*, score: 0.8\n\n <{settings.BaseURLProvider.image_url(DEFAULT_SERVER_TYPE, '/source_image/2.jpg')}|Image> -- <{settings.BaseURLProvider.world(DEFAULT_SERVER_TYPE)}/cgi/product.pl?type=edit&code=123|*Edit*>",
             slack_notifier.ROBOTOFF_PRIVATE_IMAGE_ALERT_CHANNEL,
-            settings.BaseURLProvider.image_url("/source_image/2.jpg"),
+            settings.BaseURLProvider.image_url(
+                DEFAULT_SERVER_TYPE, "/source_image/2.jpg"
+            ),
         ),
     )

     mock_image_moderation.assert_called_once_with(
         "http://images.org/123",
         data={
             "imgid": 2,
-            "url": settings.BaseURLProvider.image_url("/source_image/2.jpg"),
+            "url": settings.BaseURLProvider.image_url(
+                DEFAULT_SERVER_TYPE, "/source_image/2.jpg"
+            ),
         },
     )
@@ -167,21 +179,22 @@ def test_notify_automatic_processing_weight(mocker, monkeypatch):

     notifier = slack.SlackNotifier("")

-    print(settings.BaseURLProvider.image_url("/image/1"))
+    print(settings.BaseURLProvider.image_url(DEFAULT_SERVER_TYPE, "/image/1"))
     notifier.notify_automatic_processing(
         ProductInsight(
-            barcode="123",
+            barcode=DEFAULT_BARCODE,
             source_image="/image/1",
             type="weight",
             value="200g",
             data={"raw": "en:200g"},
+            server_type=DEFAULT_SERVER_TYPE,
         )
     )

     mock.assert_called_once_with(
         notifier.POST_MESSAGE_URL,
         data=PartialRequestMatcher(
-            f"The `200g` weight was automatically added to product 123 (<{settings.BaseURLProvider.world()}/product/123|product>, <{settings.BaseURLProvider.image_url('/image/1')}|source image>)",
+            f"The `200g` weight was automatically added to product 123 (<{settings.BaseURLProvider.world(DEFAULT_SERVER_TYPE)}/product/123|product>, <{settings.BaseURLProvider.image_url(DEFAULT_SERVER_TYPE, '/image/1')}|source image>)",
             notifier.ROBOTOFF_ALERT_CHANNEL,
         ),
     )
@@ -197,14 +210,18 @@ def test_notify_automatic_processing_label(mocker, monkeypatch):

     notifier.notify_automatic_processing(
         ProductInsight(
-            barcode="123", source_image="/image/1", type="label", value_tag="en:vegan"
+            barcode=DEFAULT_BARCODE,
+            source_image="/image/1",
+            type="label",
+            value_tag="en:vegan",
+            server_type=DEFAULT_SERVER_TYPE,
         )
     )

     mock.assert_called_once_with(
         notifier.POST_MESSAGE_URL,
         data=PartialRequestMatcher(
-            f"The `en:vegan` label was automatically added to product 123 (<{settings.BaseURLProvider.world()}/product/123|product>, <{settings.BaseURLProvider.image_url('/image/1')}|source image>)",
+            f"The `en:vegan` label was automatically added to product 123 (<{settings.BaseURLProvider.world(DEFAULT_SERVER_TYPE)}/product/123|product>, <{settings.BaseURLProvider.image_url(DEFAULT_SERVER_TYPE, '/image/1')}|source image>)",
             notifier.ROBOTOFF_ALERT_CHANNEL,
         ),
     )
@@ -218,11 +235,15 @@ def test_noop_slack_notifier_logging(caplog):
         LogoAnnotation(
             image_prediction=ImagePrediction(
                 image=ImageModel(
-                    barcode="123", source_image="/123/1.jpg", width=10, height=10
+                    barcode=DEFAULT_BARCODE,
+                    source_image="/123/1.jpg",
+                    width=10,
+                    height=10,
+                    server_type=DEFAULT_SERVER_TYPE.name,
                 ),
             ),
             bounding_box=(1, 1, 2, 2),
-            barcode="123",
+            barcode=DEFAULT_BARCODE,
             source_image="/123/1.jpg",
         ),
         {},
diff --git a/tests/unit/workers/tasks/test_product_updated.py b/tests/unit/workers/tasks/test_product_updated.py
index 6240ccbb1b..f3581c16c8 100644
--- a/tests/unit/workers/tasks/test_product_updated.py
+++ b/tests/unit/workers/tasks/test_product_updated.py
@@ -1,5 +1,10 @@
-from robotoff import settings
-from robotoff.types import InsightImportResult, Prediction, PredictionType
+from robotoff.types import (
+    InsightImportResult,
+    Prediction,
+    PredictionType,
+    ProductIdentifier,
+    ServerType,
+)
 from robotoff.workers.tasks.product_updated import add_category_insight

 # TODO: refactor function under test to make it easier to test
@@ -18,23 +23,24 @@ def test_add_category_insight_no_insights(mocker):
     import_insights_mock = mocker.patch(
         "robotoff.workers.tasks.product_updated.import_insights"
     )
-    imported = add_category_insight(
-        "123", {"code": "123"}, settings.BaseURLProvider.world()
-    )
+    imported = add_category_insight(
+        ProductIdentifier("123", ServerType.off), {"code": "123"}
+    )
     assert not import_insights_mock.called
     assert not imported


 def test_add_category_insight_with_ml_insights(mocker):
+    barcode = "123"
+    product_id = ProductIdentifier(barcode, ServerType.off)
     expected_prediction = Prediction(
-        barcode="123",
+        barcode=product_id.barcode,
         type=PredictionType.category,
         value_tag="en:chicken",
         data={"lang": "xx"},
         automatic_processing=True,
         predictor="neural",
         confidence=0.9,
+        server_type=product_id.server_type,
     )
     mocker.patch(
         "robotoff.workers.tasks.product_updated.predict_category_matcher",
         return_value=[],
     )
     import_insights_mock = mocker.patch(
         "robotoff.workers.tasks.product_updated.import_insights",
         return_value=InsightImportResult(),
     )
-    server_domain = settings.BaseURLProvider.world()
-    add_category_insight("123", {"code": "123"}, server_domain)
+    add_category_insight(product_id, {"code": "123"})

     import_insights_mock.assert_called_once_with(
         [
             Prediction(
-                barcode="123",
+                barcode=product_id.barcode,
                 type=PredictionType.category,
                 value_tag="en:chicken",
                 data={"lang": "xx"},
                 automatic_processing=True,
                 predictor="neural",
                 confidence=0.9,
-            )
+                server_type=product_id.server_type,
+            ),
         ],
-        server_domain,
+        ServerType.off,
     )
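
Note for reviewers: throughout this patch, call sites that previously took a barcode plus a server_domain string now take a single ProductIdentifier pairing the barcode with a ServerType member. A minimal sketch of the pattern follows; the real definitions live in robotoff/types.py, and the enum values, the frozen dataclass, and the world_url() helper below are illustrative assumptions rather than the actual implementation:

    import enum
    from dataclasses import dataclass


    class ServerType(enum.Enum):
        # One member per Open*Facts project (assumed values; the patch also
        # covers opff for Open Pet Food Facts).
        off = "openfoodfacts"
        obf = "openbeautyfacts"
        opf = "openproductsfacts"


    @dataclass(frozen=True)
    class ProductIdentifier:
        # frozen=True makes instances hashable, which the tests rely on when
        # they key fake product stores by ProductIdentifier.
        barcode: str
        server_type: ServerType


    def world_url(server_type: ServerType, tld: str = "net") -> str:
        # Hypothetical helper mirroring BaseURLProvider.world(server_type);
        # the TLD ("net" for dev, "org" for prod) now comes from ROBOTOFF_TLD.
        return f"https://world.{server_type.value}.{tld}"


    product_id = ProductIdentifier("3760094310634", ServerType.off)
    assert product_id.barcode == "3760094310634"
    assert world_url(product_id.server_type, "org") == "https://world.openfoodfacts.org"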
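
A related convention visible in the scripts and fixtures above: the Postgres-backed models persist server_type as the enum name (the string "off"), which is why queries filter on ServerType.off.name. A small runnable sketch of that convention, using a trimmed-down stand-in for the real ImageModel (only the two fields needed for the example are included):

    import enum

    import peewee


    class ServerType(enum.Enum):
        off = "openfoodfacts"


    db = peewee.SqliteDatabase(":memory:")


    class ImageModel(peewee.Model):
        # Stand-in for robotoff.models.ImageModel; the real model has many
        # more fields (source_image, width, height, ...).
        barcode = peewee.CharField()
        server_type = peewee.CharField()  # stores the enum *name*, e.g. "off"

        class Meta:
            database = db


    db.create_tables([ImageModel])
    ImageModel.create(barcode="00001", server_type=ServerType.off.name)

    # Filter the same way scripts/insert_images.py does after this patch:
    query = ImageModel.select().where(ImageModel.server_type == ServerType.off.name)
    assert query.count() == 1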