diff --git a/examples/classification.py b/examples/classification.py index a87d516..c7428b4 100644 --- a/examples/classification.py +++ b/examples/classification.py @@ -52,8 +52,8 @@ def preprocess_image(image_path, target_size=(256, 256), crop_size=(224, 224)): def get_top_k_predictions(output, k=5): # Get top k predictions - top_k_indices = np.argsort(output[0])[-k:][::-1] - top_k_scores = output[0][top_k_indices] + top_k_indices = np.argsort(output[0].flatten())[-k:][::-1] + top_k_scores = output[0].flatten()[top_k_indices] return top_k_indices, top_k_scores