1111from hiclass .metrics import f1 # type: ignore[import-untyped]
1212from loguru import logger
1313from pydantic import ValidationError
14- from rapidfuzz import fuzz
14+ from rapidfuzz import fuzz # type: ignore[import-untyped]
1515from sklearn .metrics import f1_score # type: ignore[import-untyped]
1616from tabulate import tabulate
1717
2222
2323from .schema import ProductMetadata
2424
25- # Initialize the Generator
26- # As of NumPy 1.17+,
27- # this isolates the random state,
28- # which is safer for reproducibility and parallel processing.
29- rng = np .random .default_rng ()
25+ _TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
26+ _PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
27+ _PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
28+ _CATEGORY_SEPARATOR = " > "
3029
3130
3231def get_hierarchical_components (
3332 predicted_path : str ,
3433 true_path : str ,
35- separator : str = " > " ,
34+ separator : str = _CATEGORY_SEPARATOR ,
3635) -> tuple [int , int , int ]:
3736 """Calculates the components for Hierarchical Precision.
3837
3938 Args:
4039 predicted_path: Categories predicted by the VLM.
4140 true_path: Ground truth categories.
42- separator: String used to separate each category.
41+ separator: The separator used to separate each level of the category.
4342
4443 Returns:
4544 Tuple of number of intersections,
@@ -58,8 +57,7 @@ def get_hierarchical_components(
5857 intersection_count = 0
5958
6059 # Iterate through the paths simultaneously
61- for pred_cat , true_cat in zip (
62- predicted_categories , true_categories , strict = False ):
60+ for pred_cat , true_cat in zip (predicted_categories , true_categories , strict = False ):
6361 if pred_cat == true_cat :
6462 intersection_count += 1
6563 else :
@@ -72,12 +70,15 @@ def get_hierarchical_components(
7270 return intersection_count , pred_length , true_length
7371
7472
75- def calculate_hierarchical_f1 (data : list [tuple [str , str ]]) -> float :
73+ def calculate_hierarchical_f1 (
74+ data : list [tuple [str , str ]],
75+ separator : str = _CATEGORY_SEPARATOR ,
76+ ) -> float :
7677 """Calculates the aggregate hF scores for a list of samples.
7778
7879 Args:
79- data: A list of tuples, where each tuple is
80- (predicted_path_str, true_path_str) .
80+ data: A list of tuples, where each tuple is (predicted_path_str, true_path_str).
81+ separator: The separator used to split the paths into levels of the category.
8182
8283 Returns:
8384 F1 score
@@ -89,8 +90,9 @@ def calculate_hierarchical_f1(data: list[tuple[str, str]]) -> float:
8990 # 1. Aggregate the components across all samples
9091 for pred_path , true_path in data :
9192 intersection , pred_len , true_len = get_hierarchical_components (
92- pred_path ,
93- true_path ,
93+ predicted_path = pred_path ,
94+ true_path = true_path ,
95+ separator = separator ,
9496 )
9597
9698 total_intersection += intersection
@@ -156,20 +158,25 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
156158 return f1_score (y_src , y_pred )
157159
158160
159- def calculate_hiclass_f1 (data : list [tuple [str , str ]]) -> float :
161+ def calculate_hiclass_f1 (
162+ data : list [tuple [str , str ]],
163+ separator : str = _CATEGORY_SEPARATOR ,
164+ ) -> float :
160165 """Alt method to calculate hierarchical F1.
161166
162167 Args:
163- data: List of tuples of predicted and true values
168+ data: List of tuples of predicted and true values
169+ separator: The separator used to split the paths into levels of the category.
170+
164171 Returns:
165172 f1 score
166173 """
167174 y_pred_raw = []
168175 y_true_raw = []
169176
170177 for pred , src in data :
171- path1 = pred .split (" > " )
172- path2 = src .split (" > " )
178+ path1 = pred .split (separator )
179+ path2 = src .split (separator )
173180
174181 y_pred_raw .append (path1 )
175182 y_true_raw .append (path2 )
@@ -182,11 +189,11 @@ def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
182189 for i in range (len (y_true_raw )):
183190 # Pad Truth
184191 pad_len_true = max_len - len (y_true_raw [i ])
185- y_true_raw [i ] += ["" ] * pad_len_true
192+ y_true_raw [i ] += [_TRUE_CATEGORY_PAD ] * pad_len_true
186193
187194 # Pad Prediction
188195 pad_len_pred = max_len - len (y_pred_raw [i ])
189- y_pred_raw [i ] += ["" ] * pad_len_pred
196+ y_pred_raw [i ] += [_PRED_CATEGORY_PAD ] * pad_len_pred
190197
191198 # 4. Convert to numpy arrays
192199 y_true = np .array (y_true_raw )
@@ -196,8 +203,9 @@ def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
196203 return f1 (y_true , y_pred )
197204
198205
199- def run_evaluation (filename : FilePath , dataset : DatasetCLI ) -> None :
206+ def run_evaluation (random_seed : int , filename : FilePath , dataset : DatasetCLI ) -> None :
200207 """Main function to run the evaluation."""
208+ rng = np .random .default_rng (seed = random_seed )
201209 with Path .open (filename ) as f :
202210 model_output = json .load (f )
203211
@@ -207,26 +215,43 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
207215 split = "+" .join (dataset .split ),
208216 )
209217
218+ num_unparsable_responses = 0
210219 category_dataset_pred_src = []
211220 category_rand_pred_src = []
212221 is_secondhand_pred_src = []
213222 is_secondhand_rand_pred_src = []
214223 brand_pred_src = []
215224
225+ all_possible_brands = set ()
226+
216227 for elem in model_output :
217228 idx = elem ["qsl_idx" ]
218229 response = bytes .fromhex (elem ["data" ]).decode ("utf-8" )
230+ ground_truth_item = original_data [idx ]
231+ all_possible_brands .add (ground_truth_item ["ground_truth_brand" ])
219232 try :
220233 pred_item = ProductMetadata .model_validate_json (response )
221234 except ValidationError :
222- logger .exception (
235+ num_unparsable_responses += 1
236+ pred_item = ProductMetadata (
237+ category = _CATEGORY_SEPARATOR .join (
238+ [_PRED_CATEGORY_PAD ]
239+ * len (
240+ ground_truth_item ["ground_truth_category" ].split (
241+ _CATEGORY_SEPARATOR ,
242+ ),
243+ ),
244+ ),
245+ brand = _PRED_BRAND_PAD ,
246+ is_secondhand = rng .choice ([True , False ], size = 1 ).tolist ()[0 ],
247+ )
248+ logger .error (
223249 "Response\n {}\n (for the sample at index {}) cannot be validated against"
224- " the expected schema\n {} \n . Thus, this submission result is invalid. " ,
250+ " the expected schema. Overwriting this response into \n {} \n " ,
225251 response ,
226252 idx ,
227- json . dumps ( ProductMetadata . model_json_schema (), indent = 2 ) ,
253+ pred_item ,
228254 )
229- ground_truth_item = original_data [idx ]
230255 category_dataset_pred_src .append (
231256 (pred_item .category , ground_truth_item ["ground_truth_category" ]),
232257 )
@@ -236,48 +261,66 @@ def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
236261 ground_truth_item ["ground_truth_is_secondhand" ],
237262 ),
238263 )
264+ brand_pred_src .append (
265+ (pred_item .brand , ground_truth_item ["ground_truth_brand" ]),
266+ )
239267 # random category selection
240268 # Uniform distribution is the default
241- rand_cat = rng .choice (ground_truth_item ["potential_product_categories" ],
242- size = 1 ).tolist ()[0 ]
243- category_rand_pred_src .append ((rand_cat ,
244- ground_truth_item ["ground_truth_category" ]))
245-
269+ rand_cat = rng .choice (ground_truth_item ["potential_product_categories" ])
270+ category_rand_pred_src .append (
271+ (rand_cat , ground_truth_item ["ground_truth_category" ]),
272+ )
246273 # random is_secondhand selection
247- rand_is_secondhand = rng .choice ([True , False ], size = 1 ).tolist ()[0 ]
248- is_secondhand_rand_pred_src .append ((rand_is_secondhand ,
249- ground_truth_item ["ground_truth_is_secondhand" ]))
250-
251- brand_pred_src .append ((pred_item .brand ,
252- ground_truth_item ["ground_truth_brand" ]))
274+ rand_is_secondhand = rng .choice ([True , False ])
275+ is_secondhand_rand_pred_src .append (
276+ (rand_is_secondhand , ground_truth_item ["ground_truth_is_secondhand" ]),
277+ )
253278
254279 category_f1_score = calculate_hierarchical_f1 (category_dataset_pred_src )
255280 hiclass_f1_score = calculate_hiclass_f1 (category_dataset_pred_src )
256281 is_secondhand_f1_score = calculate_secondhand_f1 (is_secondhand_pred_src )
257282 brand_score = calculate_brand_f1_score (brand_pred_src )
258283
259284 rand_cat_f1_score = calculate_hierarchical_f1 (category_rand_pred_src )
260- rand_hiclass_f1_score = calculate_hierarchical_f1 (category_rand_pred_src )
261- rand_is_seconhand_f1_score = calculate_secondhand_f1 (
262- is_secondhand_rand_pred_src )
263-
264- data = [
265- [ "category" , category_f1_score , hiclass_f1_score ,
266- rand_cat_f1_score , rand_hiclass_f1_score , 0 ],
267- [ "is_secondhand" , is_secondhand_f1_score , 0 ,
268- rand_is_seconhand_f1_score , 0 , 0 ],
269- [ "brand" , 0 , 0 , 0 , 0 , brand_score ],
270- ]
285+ rand_hiclass_f1_score = calculate_hiclass_f1 (category_rand_pred_src )
286+ rand_is_seconhand_f1_score = calculate_secondhand_f1 (is_secondhand_rand_pred_src )
287+ rand_brand_score = calculate_brand_f1_score (
288+ [
289+ (
290+ rng . choice ( list ( all_possible_brands )) ,
291+ original_data [ elem [ "qsl_idx" ]][ "ground_truth_brand" ],
292+ )
293+ for elem in model_output
294+ ],
295+ )
271296
272297 logger .info (
273- "Results:\n {}" ,
298+ "{} responses cannot be parsed against the expected schema. Results:\n {}" ,
299+ num_unparsable_responses ,
274300 tabulate (
275- data ,
276- headers = ["Fields" , "F1 Score" ,
277- "HiClass F1 Score" ,
278- "F1 Score Random Selection" ,
279- "HiClass F1 Score Random Selection" ,
280- "Brand F1 Score" ],
301+ [
302+ [
303+ "From accuracy file" ,
304+ category_f1_score ,
305+ hiclass_f1_score ,
306+ brand_score ,
307+ is_secondhand_f1_score ,
308+ ],
309+ [
310+ "Random selection" ,
311+ rand_cat_f1_score ,
312+ rand_hiclass_f1_score ,
313+ rand_brand_score ,
314+ rand_is_seconhand_f1_score ,
315+ ],
316+ ],
317+ headers = [
318+ "Results" ,
319+ "Category hierarchical F1 Score" ,
320+ "Category HiClass F1 Score" ,
321+ "Brand F1 Score" ,
322+ "Is_secondhand F1 Score" ,
323+ ],
281324 tablefmt = "fancy_grid" ,
282325 ),
283326 )
0 commit comments