
Commit dedbe33

committed: reflect the latest version of the dataset
1 parent 586146b commit dedbe33

4 files changed: +162 -88 lines changed


multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/cli.py

Lines changed: 7 additions & 4 deletions

@@ -27,6 +27,11 @@
 
 @app.command()
 def evaluate(
+    *,
+    random_seed: Annotated[
+        int,
+        Option(help="The seed for the random number generator used by the benchmark."),
+    ] = 12345,
     filename: Annotated[
         FilePath,
         Option(
@@ -37,7 +42,7 @@ def evaluate(
 ) -> None:
     """Evaluate the accuracy of the VLM responses."""
     logger.info("Evaluating the accuracy file")
-    run_evaluation(filename=filename, dataset=dataset)
+    run_evaluation(random_seed=random_seed, filename=filename, dataset=dataset)
 
 
 @benchmark_app.command(name="endpoint")
@@ -78,9 +83,7 @@ def _run_benchmark(
     """Run the VL2L benchmark."""
     logger.info("Running VL2L benchmark with settings: {}", settings)
     logger.info("Running VL2L benchmark with dataset: {}", dataset)
-    logger.info(
-        "Running VL2L benchmark with OpenAI API endpoint: {}",
-        endpoint)
+    logger.info("Running VL2L benchmark with OpenAI API endpoint: {}", endpoint)
     logger.info("Running VL2L benchmark with random seed: {}", random_seed)
     test_settings, log_settings = settings.to_lgtype()
     task = ShopifyGlobalCatalogue(
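
Note on the hunks above: the new --random-seed option (default 12345) is threaded from the CLI into run_evaluation, so the random baselines computed during evaluation become reproducible. A minimal sketch of the property being relied on, using only NumPy (the values drawn from are illustrative):

import numpy as np

# Two generators built from the same seed produce identical draws, so the
# random category / is_secondhand / brand baselines are stable across runs.
rng_a = np.random.default_rng(seed=12345)
rng_b = np.random.default_rng(seed=12345)

assert rng_a.choice([True, False]) == rng_b.choice([True, False])
assert rng_a.choice(["brand-a", "brand-b", "brand-c"]) == rng_b.choice(["brand-a", "brand-b", "brand-c"])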

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/evaluation.py

Lines changed: 98 additions & 55 deletions

@@ -11,7 +11,7 @@
 from hiclass.metrics import f1  # type: ignore[import-untyped]
 from loguru import logger
 from pydantic import ValidationError
-from rapidfuzz import fuzz
+from rapidfuzz import fuzz  # type: ignore[import-untyped]
 from sklearn.metrics import f1_score  # type: ignore[import-untyped]
 from tabulate import tabulate
 
@@ -22,24 +22,23 @@
 
 from .schema import ProductMetadata
 
-# Initialize the Generator
-# As of NumPy 1.17+,
-# this isolates the random state,
-# which is safer for reproducibility and parallel processing.
-rng = np.random.default_rng()
+_TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
+_PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
+_PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
+_CATEGORY_SEPARATOR = " > "
 
 
 def get_hierarchical_components(
     predicted_path: str,
     true_path: str,
-    separator: str = " > ",
+    separator: str = _CATEGORY_SEPARATOR,
 ) -> tuple[int, int, int]:
     """Calculates the components for Hierarchical Precision.
 
     Args:
         predicted_path: Categories predicted by the VLM.
         true_path: Ground truth categories.
-        separator: String used to separate each category.
+        separator: The separator used to separate each level of the category.
 
     Returns:
         Tuple of number of intersections,
@@ -58,8 +57,7 @@ def get_hierarchical_components(
     intersection_count = 0
 
     # Iterate through the paths simultaneously
-    for pred_cat, true_cat in zip(
-            predicted_categories, true_categories, strict=False):
+    for pred_cat, true_cat in zip(predicted_categories, true_categories, strict=False):
         if pred_cat == true_cat:
             intersection_count += 1
         else:
@@ -72,12 +70,15 @@ def get_hierarchical_components(
     return intersection_count, pred_length, true_length
 
 
-def calculate_hierarchical_f1(data: list[tuple[str, str]]) -> float:
+def calculate_hierarchical_f1(
+    data: list[tuple[str, str]],
+    separator: str = _CATEGORY_SEPARATOR,
+) -> float:
     """Calculates the aggregate hF scores for a list of samples.
 
     Args:
-        data: A list of tuples, where each tuple is
-            (predicted_path_str, true_path_str).
+        data: A list of tuples, where each tuple is (predicted_path_str, true_path_str).
+        separator: The separator used to split the paths into levels of the category.
 
     Returns:
         F1 score
@@ -89,8 +90,9 @@ def calculate_hierarchical_f1(data: list[tuple[str, str]]) -> float:
     # 1. Aggregate the components across all samples
     for pred_path, true_path in data:
         intersection, pred_len, true_len = get_hierarchical_components(
-            pred_path,
-            true_path,
+            predicted_path=pred_path,
+            true_path=true_path,
+            separator=separator,
         )
 
         total_intersection += intersection
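
For context on the hunks above: get_hierarchical_components counts, per sample, how many category levels of the predicted path agree with the true path (the common prefix, per the visible loop), plus the two path lengths, and calculate_hierarchical_f1 micro-averages those components into precision, recall, and F1. A small self-contained sketch of that computation; the paths and the exact aggregation details are illustrative, reconstructed from the visible code rather than copied from the module:

def hierarchical_f1(pairs: list[tuple[str, str]], separator: str = " > ") -> float:
    """Sketch of a micro-averaged hierarchical F1 over (predicted, true) path pairs."""
    total_inter = total_pred = total_true = 0
    for pred_path, true_path in pairs:
        pred = pred_path.split(separator)
        true = true_path.split(separator)
        inter = 0
        for pred_cat, true_cat in zip(pred, true):
            if pred_cat != true_cat:
                break  # stop at the first level that disagrees
            inter += 1
        total_inter += inter
        total_pred += len(pred)
        total_true += len(true)
    precision = total_inter / total_pred  # sketch only: no zero-division guard
    recall = total_inter / total_true
    return 2 * precision * recall / (precision + recall)


# One sample: the first two levels match, the leaf does not -> hP = hR = 2/3.
print(hierarchical_f1([("Apparel > Shoes > Boots", "Apparel > Shoes > Sneakers")]))  # ~0.667
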
@@ -156,20 +158,25 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
     return f1_score(y_src, y_pred)
 
 
-def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
+def calculate_hiclass_f1(
+    data: list[tuple[str, str]],
+    separator: str = _CATEGORY_SEPARATOR,
+) -> float:
     """Alt method to calculate hierarchical F1.
 
     Args:
-        data: List of tuples of predicted and true values
+        data: List of tuples of predicted and true values
+        separator: The separator used to split the paths into levels of the category.
+
     Returs:
         f1 score
     """
     y_pred_raw = []
     y_true_raw = []
 
     for pred, src in data:
-        path1 = pred.split(" > ")
-        path2 = src.split(" > ")
+        path1 = pred.split(separator)
+        path2 = src.split(separator)
 
         y_pred_raw.append(path1)
         y_true_raw.append(path2)
@@ -182,11 +189,11 @@ def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
     for i in range(len(y_true_raw)):
         # Pad Truth
         pad_len_true = max_len - len(y_true_raw[i])
-        y_true_raw[i] += [""] * pad_len_true
+        y_true_raw[i] += [_TRUE_CATEGORY_PAD] * pad_len_true
 
         # Pad Prediction
         pad_len_pred = max_len - len(y_pred_raw[i])
-        y_pred_raw[i] += [""] * pad_len_pred
+        y_pred_raw[i] += [_PRED_CATEGORY_PAD] * pad_len_pred
 
     # 4. Convert to numpy arrays
     y_true = np.array(y_true_raw)
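
On the padding change above: hiclass works on equal-length label rows, so shorter paths are padded up to max_len. Padding both truth and prediction with the same filler (previously "") can let padded depths count as agreements; giving each side its own sentinel avoids that. A tiny illustration of the padding step with toy paths (not dataset values), mirroring the in-place list padding shown in the diff:

_TRUE_PAD = "<|__TRUE_CATEGORY_PAD__|>"
_PRED_PAD = "<|__PRED_CATEGORY_PAD__|>"

true_paths = [["Apparel", "Shoes"], ["Apparel", "Shoes", "Sneakers"]]
pred_paths = [["Apparel"], ["Apparel", "Shoes", "Sneakers"]]

max_len = max(len(path) for path in true_paths + pred_paths)
for paths, pad in ((true_paths, _TRUE_PAD), (pred_paths, _PRED_PAD)):
    for path in paths:
        path += [pad] * (max_len - len(path))  # pad in place, as in the diff

# With distinct sentinels, the padded depths of the first sample disagree
# instead of matching "for free" as two identical empty strings would.
print(true_paths[0])
print(pred_paths[0])
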
@@ -196,8 +203,9 @@ def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
     return f1(y_true, y_pred)
 
 
-def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
+def run_evaluation(random_seed: int, filename: FilePath, dataset: DatasetCLI) -> None:
     """Main function to run the evaluation."""
+    rng = np.random.default_rng(seed=random_seed)
     with Path.open(filename) as f:
         model_output = json.load(f)
 
@@ -207,26 +215,43 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
         split="+".join(dataset.split),
     )
 
+    num_unparsable_responses = 0
     category_dataset_pred_src = []
     category_rand_pred_src = []
     is_secondhand_pred_src = []
     is_secondhand_rand_pred_src = []
     brand_pred_src = []
 
+    all_possible_brands = set()
+
     for elem in model_output:
         idx = elem["qsl_idx"]
         response = bytes.fromhex(elem["data"]).decode("utf-8")
+        ground_truth_item = original_data[idx]
+        all_possible_brands.add(ground_truth_item["ground_truth_brand"])
         try:
             pred_item = ProductMetadata.model_validate_json(response)
         except ValidationError:
-            logger.exception(
+            num_unparsable_responses += 1
+            pred_item = ProductMetadata(
+                category=_CATEGORY_SEPARATOR.join(
+                    [_PRED_CATEGORY_PAD]
+                    * len(
+                        ground_truth_item["ground_truth_category"].split(
+                            _CATEGORY_SEPARATOR,
+                        ),
+                    ),
+                ),
+                brand=_PRED_BRAND_PAD,
+                is_secondhand=rng.choice([True, False], size=1).tolist()[0],
+            )
+            logger.error(
                 "Response\n{}\n(for the sample at index {}) cannot be validated against"
-                " the expected schema\n{}\n. Thus, this submission result is invalid.",
+                " the expected schema. Overwriting this response into \n{}\n",
                 response,
                 idx,
-                json.dumps(ProductMetadata.model_json_schema(), indent=2),
+                pred_item,
             )
-        ground_truth_item = original_data[idx]
         category_dataset_pred_src.append(
             (pred_item.category, ground_truth_item["ground_truth_category"]),
         )
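
On the fallback above: previously an unparsable response was only logged as invalidating the submission; now it is counted and replaced by a deliberately non-matching prediction, built from a category of pad tokens with the same depth as the ground truth (so length accounting stays consistent while nothing intersects), a pad brand, and a seeded coin flip for is_secondhand. A compact sketch of that construction; the ProductMetadata field set here is a minimal stand-in inferred from the diff, not the project's actual schema module:

import numpy as np
from pydantic import BaseModel


class ProductMetadata(BaseModel):  # minimal stand-in for the project's schema
    category: str
    brand: str
    is_secondhand: bool


rng = np.random.default_rng(seed=12345)
separator = " > "
ground_truth_category = "Apparel > Shoes > Sneakers"  # illustrative ground truth
depth = len(ground_truth_category.split(separator))

fallback = ProductMetadata(
    category=separator.join(["<|__PRED_CATEGORY_PAD__|>"] * depth),  # same depth, zero overlap
    brand="<|__PRED_BRAND_PAD__|>",  # can never equal a real ground-truth brand
    is_secondhand=bool(rng.choice([True, False])),  # seeded coin flip
)
print(fallback)
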
@@ -236,48 +261,66 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
                 ground_truth_item["ground_truth_is_secondhand"],
             ),
         )
+        brand_pred_src.append(
+            (pred_item.brand, ground_truth_item["ground_truth_brand"]),
+        )
         # random category selection
         # Uniform distribution is the default
-        rand_cat = rng.choice(ground_truth_item["potential_product_categories"],
-                              size=1).tolist()[0]
-        category_rand_pred_src.append((rand_cat,
-                                       ground_truth_item["ground_truth_category"]))
-
+        rand_cat = rng.choice(ground_truth_item["potential_product_categories"])
+        category_rand_pred_src.append(
+            (rand_cat, ground_truth_item["ground_truth_category"]),
+        )
         # random is_secondhand selection
-        rand_is_secondhand = rng.choice([True, False], size=1).tolist()[0]
-        is_secondhand_rand_pred_src.append((rand_is_secondhand,
-                                            ground_truth_item["ground_truth_is_secondhand"]))
-
-        brand_pred_src.append((pred_item.brand,
-                               ground_truth_item["ground_truth_brand"]))
+        rand_is_secondhand = rng.choice([True, False])
+        is_secondhand_rand_pred_src.append(
+            (rand_is_secondhand, ground_truth_item["ground_truth_is_secondhand"]),
+        )
 
     category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
     hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
     is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
     brand_score = calculate_brand_f1_score(brand_pred_src)
 
     rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
-    rand_hiclass_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
-    rand_is_seconhand_f1_score = calculate_secondhand_f1(
-        is_secondhand_rand_pred_src)
-
-    data = [
-        ["category", category_f1_score, hiclass_f1_score,
-         rand_cat_f1_score, rand_hiclass_f1_score, 0],
-        ["is_secondhand", is_secondhand_f1_score, 0,
-         rand_is_seconhand_f1_score, 0, 0],
-        ["brand", 0, 0, 0, 0, brand_score],
-    ]
+    rand_hiclass_f1_score = calculate_hiclass_f1(category_rand_pred_src)
+    rand_is_seconhand_f1_score = calculate_secondhand_f1(is_secondhand_rand_pred_src)
+    rand_brand_score = calculate_brand_f1_score(
+        [
+            (
+                rng.choice(list(all_possible_brands)),
+                original_data[elem["qsl_idx"]]["ground_truth_brand"],
+            )
+            for elem in model_output
+        ],
+    )
 
     logger.info(
-        "Results:\n{}",
+        "{} responses cannot be parsed against the expected schema. Results:\n{}",
+        num_unparsable_responses,
         tabulate(
-            data,
-            headers=["Fields", "F1 Score",
-                     "HiClass F1 Score",
-                     "F1 Score Random Selection",
-                     "HiClass F1 Score Random Selection",
-                     "Brand F1 Score"],
+            [
+                [
+                    "From accuracy file",
+                    category_f1_score,
+                    hiclass_f1_score,
+                    brand_score,
+                    is_secondhand_f1_score,
+                ],
+                [
+                    "Random selection",
+                    rand_cat_f1_score,
+                    rand_hiclass_f1_score,
+                    rand_brand_score,
+                    rand_is_seconhand_f1_score,
+                ],
+            ],
+            headers=[
+                "Results",
+                "Category hierarchical F1 Score",
+                "Category HiClass F1 Score",
+                "Brand F1 Score",
+                "Is_secondhand F1 Score",
+            ],
             tablefmt="fancy_grid",
         ),
     )
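
Net effect of the last hunk: the summary table now has one row per source of predictions (the accuracy file vs. random selection) and one column per metric, so each score can be read directly against its chance-level baseline, and the count of unparsable responses is reported alongside it. A sketch of the resulting layout with placeholder numbers, not real results:

from tabulate import tabulate

print(
    tabulate(
        [
            ["From accuracy file", 0.81, 0.79, 0.64, 0.92],  # placeholder scores
            ["Random selection", 0.22, 0.20, 0.01, 0.50],  # placeholder scores
        ],
        headers=[
            "Results",
            "Category hierarchical F1 Score",
            "Category HiClass F1 Score",
            "Brand F1 Score",
            "Is_secondhand F1 Score",
        ],
        tablefmt="fancy_grid",
    ),
)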
