Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for hasher role in /lookup endpoint #1729

Merged
merged 4 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hasher-matcher-actioner/.devcontainer/postcreate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ set -e

pip install --editable .[all]

# Find Python packages in opt and install them
# Find Python packages in extensions and install them
for setup_script in "$(find /workspace/extensions -name setup.py)"
do
module_dir="$(dirname "$setup_script")"
Expand Down
2 changes: 1 addition & 1 deletion hasher-matcher-actioner/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from threatexchange.signal_type.pdq import signal as _
## Resume regularly scheduled imports
# Resume regularly scheduled imports

import logging
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def query_index(
try:
signal = signal_type.validate_signal_str(signal)
except Exception as e:
abort(400, f"invalid signal type: {e}")
abort(400, f"invalid signal: {e}")

index = _get_index(signal_type)

Expand Down Expand Up @@ -203,8 +203,10 @@ def lookup_get():
Output:
* List of matching banks
"""
# Url was passed as a query param?
if request.args.get("url", None):
if not current_app.config.get("ROLE_HASHER", False):
abort(403, "Hashing is disabled, missing role")

hashes = hashing.hash_media()
resp = {}
for signal_type in hashes.keys():
Expand All @@ -230,6 +232,9 @@ def lookup_post():
Output:
* List of matching banks
"""
if not current_app.config.get("ROLE_HASHER", False):
abort(403, "Hashing is disabled, missing role")

hashes = hashing.hash_media_post()

resp = {}
Expand Down
222 changes: 38 additions & 184 deletions hasher-matcher-actioner/src/OpenMediaMatch/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from io import BytesIO
import tempfile
import typing as t

from flask.testing import FlaskClient
from flask import Flask
from PIL import Image
import requests

from threatexchange.exchanges.impl.fb_threatexchange_api import (
FBThreatExchangeSignalExchangeAPI,
Expand All @@ -16,13 +20,9 @@
from OpenMediaMatch.tests.utils import (
app,
client,
create_bank,
add_hash_to_bank,
IMAGE_URL_TO_PDQ,
)
from OpenMediaMatch.background_tasks.build_index import build_all_indices
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.storage.postgres import database


def test_status_response(client: FlaskClient):
Expand All @@ -31,194 +31,48 @@ def test_status_response(client: FlaskClient):
assert response.data == b"I-AM-ALIVE"


def test_banks_empty_index(client: FlaskClient):
response = client.get("/c/banks")
assert response.status_code == 200
assert response.json == []


def test_banks_create(client: FlaskClient):
# Must not start with number
post_response = client.post(
"/c/banks",
json={"name": "01_TEST_BANK"},
)
assert post_response.status_code == 400

# Cannot contain lowercase letters
post_response = client.post(
"/c/banks",
json={"name": "my_test_bank"},
)
assert post_response.status_code == 400

post_response = client.post(
"/c/banks",
json={"name": "MY_TEST_BANK_01"},
)
assert post_response.status_code == 201
assert post_response.json == {
"matching_enabled_ratio": 1.0,
"name": "MY_TEST_BANK_01",
}

# Should now be visible on index
response = client.get("/c/banks")
assert response.status_code == 200
assert response.json == [post_response.json]


def test_banks_update(client: FlaskClient):
post_response = client.post(
"/c/banks",
json={"name": "MY_TEST_BANK"},
)
assert post_response.status_code == 201

# check name validation
post_response = client.put(
"/c/bank/MY_TEST_BANK",
json={"name": "1_invalid_name"},
)
assert post_response.status_code == 400

# check update with rename
post_response = client.put(
"/c/bank/MY_TEST_BANK",
json={"name": "MY_TEST_BANK_RENAMED"},
)
assert post_response.status_code == 200
assert post_response.get_json()["name"] == "MY_TEST_BANK_RENAMED"

# check update without rename
post_response = client.put(
"/c/bank/MY_TEST_BANK_RENAMED",
json={"enabled": False},
)
assert post_response.status_code == 200
assert post_response.get_json()["matching_enabled_ratio"] == 0

# check update without ratio
post_response = client.put(
"/c/bank/MY_TEST_BANK_RENAMED",
json={"enabled_ratio": 0.5},
)
assert post_response.status_code == 200
assert post_response.get_json()["matching_enabled_ratio"] == 0.5

# Final test to make sure we only have one bank with proper name and disabled

get_response = client.get("/c/banks")
assert get_response.status_code == 200
json = get_response.get_json()
assert len(json) == 1
assert json[0] == {"name": "MY_TEST_BANK_RENAMED", "matching_enabled_ratio": 0.5}


def test_banks_delete(client: FlaskClient):
post_response = client.post(
"/c/banks",
json={"name": "MY_TEST_BANK"},
)
assert post_response.status_code == 201

# check name validation
post_response = client.delete(
"/c/bank/MY_TEST_BANK",
)
assert post_response.status_code == 200

# deleting non existing bank should succeed
post_response = client.delete(
"/c/bank/MY_TEST_BANK",
)
assert post_response.status_code == 200


def test_banks_add_hash(client: FlaskClient):
bank_name = "NEW_BANK"
create_bank(client, bank_name)

image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"

post_response = client.post(
f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"
)

assert post_response.status_code == 200, str(post_response.get_json())
assert post_response.json == {
"id": 1,
"signals": {
"pdq": "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22"
},
}


def test_banks_delete_hash(client: FlaskClient):
bank_name = "NEW_BANK"
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"

create_bank(client, bank_name)
add_hash_to_bank(client, bank_name, image_url, 1)

post_response = client.delete(f"/c/bank/{bank_name}/content/1")

assert post_response.status_code == 200
assert post_response.json == {"deleted": 1}


def test_banks_add_metadata(client: FlaskClient):
bank_name = "NEW_BANK"
create_bank(client, bank_name)

image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
post_request = f"/c/bank/{bank_name}/content?url={image_url}&content_type=photo"

post_response = client.post(
post_request, json={"metadata": {"invalid_metadata": 5}}
)
assert post_response.status_code == 400, str(post_response.get_json())

post_response = client.post(
post_request,
json={"metadata": {"content_id": "1197433091", "json": {"asdf": {}}}},
)

assert post_response.status_code == 200, str(post_response.get_json())


def test_banks_add_hash_index(app: Flask, client: FlaskClient):
bank_name = "NEW_BANK"
bank_name_2 = "NEW_BANK_2"
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
image_url_2 = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/misc-images/c.png?raw=true"

# Make two banks and add images to each bank
create_bank(client, bank_name)
add_hash_to_bank(client, bank_name, image_url, 1)
create_bank(client, bank_name_2)
add_hash_to_bank(client, bank_name, image_url_2, 2)

def test_lookup_success(app: Flask, client: FlaskClient):
storage = get_storage()
# ensure index is empty to start with
assert storage.get_signal_type_index(PdqSignal) is None

# Build index
build_all_indices(storage, storage, storage)

# Test against first image
post_response = client.get(
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [1]}
# test GET
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
get_resp = client.get(f"/m/lookup?url={image_url}")
assert get_resp.status_code == 200

# Test against second image
post_response = client.get(
f"/m/raw_lookup?signal_type=pdq&signal={IMAGE_URL_TO_PDQ[image_url_2]}"
)
assert post_response.status_code == 200
assert post_response.json == {"matches": [2]}
# test POST with temp file
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
image.save(f, format="JPEG")
files = {"photo": (f.name, f.name, "image/jpeg")}
resp = client.post("/m/lookup", data=files)
assert resp.status_code == 200


def test_lookup_without_role(app: Flask, client: FlaskClient):
# role resets to True in the next test
client.application.config["ROLE_HASHER"] = False

# test GET
image_url = "https://github.com/facebook/ThreatExchange/blob/main/pdq/data/bridge-mods/aaa-orig.jpg?raw=true"
get_resp = client.get(f"/m/lookup?url={image_url}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: I don't see any existing unittests in this file that check against this endpoint at all (the others are using "raw_lookup". Can you also test a lookup is accepted?

This file is getting big enough that it may make sense to break out unittests for the different APIs.

assert get_resp.status_code == 403

# test POST with temp file
with tempfile.NamedTemporaryFile(suffix=".jpg") as f:
# Write a minimal valid JPEG file header
f.write(
b"\xff\xd8\xff\xe0\x00\x10\x4a\x46\x49\x46\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9"
)
f.flush()
files = {"file": (f.name, f.name, "image/jpeg")}
resp = client.post("/m/lookup", data=files)
assert resp.status_code == 403


def test_exchange_api_set_auth(app: Flask, client: FlaskClient):
Expand Down
Loading
Loading