Skip to content

Commit

Permalink
added code for new version of the nfa detection algorithm: tree nfa
Browse files Browse the repository at this point in the history
  • Loading branch information
mtailanian committed Jan 22, 2024
1 parent 247a373 commit 6c46263
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 76 deletions.
30 changes: 2 additions & 28 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,5 @@ scikit-learn==1.1.1
tensorboard==2.11.0
pyyaml==6.0
gdown==4.5.3

# seaborn~=0.11.2
# opencv-python~=4.6.0.66
# pytorch-lightning-bolts
# tensorboard
# albumentations
# visdom
# optuna
# tqdm~=4.48.2
# patool~=1.12
# natsort~=7.0.1
# wget~=3.2
# progressbar~=2.5
# rarfile~=4.0
# pyunpack
# Pillow~=9.2.0
# scikit-learn~=1.1.1
# argparse~=1.4.0
# scikit-image~=0.19.3
# hypothesis~=6.0.2
# pyod~=0.8.6
# gym
# pandas~=1.2.0
# Shapely~=1.8.2
# scipy~=1.9.0
# pyyaml==6.0
# setuptools==59.5.0

scikit-image==0.21.0
networkx==3.1
8 changes: 4 additions & 4 deletions src/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def __init__(self, experiment_path):
)


class ModelCheckpointByIoU(ModelCheckpoint):
class ModelCheckpointBymIoU(ModelCheckpoint):
def __init__(self, experiment_path):
super(ModelCheckpointByIoU, self).__init__(
monitor='iou',
super(ModelCheckpointBymIoU, self).__init__(
monitor='miou',
dirpath=str(experiment_path),
mode='max',
filename='best_val_iou_nfa__epoch_{epoch:04d}__iou_{iou:.4f}',
filename='best_val_miou_nfa__epoch_{epoch:04d}__miou_{miou:.4f}',
auto_insert_metric_name=False,
)

Expand Down
59 changes: 37 additions & 22 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

from src.model import UFlow
from src.datamodule import MVTecLightningDatamodule
from src.iou import IoU
from src.miou import mIoU
from src.aupro import AUPRO
from src.nfa import compute_log_nfa_anomaly_score
from src.nfa_block import compute_log_nfa_anomaly_score
from src.nfa_tree import compute_nfa_anomaly_score_tree

warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

Expand Down Expand Up @@ -72,11 +73,10 @@ def reproduce_results(args):
# IoU
flow_model.from_pretrained(Path("models") / "iou" / f"{category}.ckpt")
flow_model.eval()
eval_iou(
eval_miou(
flow_model,
datamodule,
target_size=TARGET_SIZE,
high_precision=args.high_precision
target_size=TARGET_SIZE
)


Expand Down Expand Up @@ -131,18 +131,18 @@ def eval_aupro(model, dataloader, target_size: Union[None, int] = None):
print(f"\t\tAUPRO: {aupro.compute()}")


def eval_iou(model, datamodule, target_size: Union[None, int] = None, high_precision: bool = False):
def eval_miou(model, datamodule, target_size: Union[None, int] = None):

if target_size is None:
target_size = model.input_size

model = model.to(DEVICE)

fair_likelihood_thr = get_fair_threshold(model, datamodule.train_dataloader(), TARGET_FPR, target_size)
nfa_thresholds = list(np.arange(-2, 2, 0.05))
# This would be the code for computing the fair threshold for the case when we do not have an automatic threshold
# fair_likelihood_thr = get_fair_threshold(model, datamodule.train_dataloader(), TARGET_FPR, target_size)

iou_likelihood = IoU(thresholds=[fair_likelihood_thr])
iou_nfa = IoU(thresholds=nfa_thresholds)
nfa_thresholds = list(np.arange(-200, 1001, 20))
miou_metric = mIoU(thresholds=nfa_thresholds)

progress_bar = tqdm(datamodule.val_dataloader())
progress_bar.set_description("\tComputing IoU")
Expand All @@ -152,28 +152,43 @@ def eval_iou(model, datamodule, target_size: Union[None, int] = None, high_preci
with torch.no_grad():
z, _ = model(image)

anomaly_score_likelihood = 1 - model.get_probability(z, target_size)
anomaly_score_nfa = compute_log_nfa_anomaly_score(
z, win_size=5, binomial_probability_thr=0.9, high_precision=high_precision
)
anomaly_score = compute_nfa_anomaly_score_tree(z, target_size=target_size)

# Alternative old computation -------------------------------------------
block_nfa = False
if block_nfa:
anomaly_score = compute_log_nfa_anomaly_score(z, high_precision=True)
# -----------------------------------------------------------------------

if targets.shape[-1] != target_size:
targets = F.interpolate(targets, size=[target_size, target_size], mode="bilinear", align_corners=False)
targets = 1 * (targets > 0.5)

iou_likelihood.update(anomaly_score_likelihood.detach().cpu(), targets.cpu())
iou_nfa.update(anomaly_score_nfa.detach().cpu(), targets.cpu())
miou_metric.update(anomaly_score.detach().cpu(), targets.cpu())

iou_fair = iou_likelihood.compute().numpy()
iou_nfas = iou_nfa.compute().numpy()
mious = miou_metric.compute().numpy()

print(f"\t\tIoU @ log(NFA)=0: {iou_nfas[list(np.around(nfa_thresholds, 2)).index(0)]}")
print(f"\t\tIoU @ oracle-thr: {np.max(iou_nfas)}")
print(f"\t\tIoU @ fair-thr : {iou_fair}")
print(f"\t\tmIoU @ log(NFA)=0: {mious[list(np.around(nfa_thresholds, 2)).index(0)]}")
print(f"\t\tmIoU @ oracle-thr: {np.max(mious)}")


def get_fair_threshold(model, dataloader, target_fpr=0.01, target_size=None):

"""
This is the code used for computing the fair threshold over the likelihoods for the case when we do not have an
automatic thresholds (i.e. we do not have the NFA). This method was used for computing the fair threshold for all
competitors in the paper. It mimics the same rationale of the NFA, allowing at most one false positive per image on
average. as explained in the paper, or this computation only the anomaly free images are used.
Parameters
----------
model:
dataloader
target_fpr
target_size
Returns
-------
The fair threshold, i.e. the threshold that allows at most one false positive per image on average.
"""
if target_size is None:
target_size = model.input_size

Expand Down
64 changes: 64 additions & 0 deletions src/miou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from abc import ABC
import numpy as np
from typing import Any, Callable, Optional, Union, List

import torch
from torch import Tensor
from torchmetrics import Metric

EPS = np.finfo(float).eps


class mIoU(Metric, ABC):
"""
Computes intersection over union metric (or Jaccard Index), for different thresholds
J(A, B) = (A \cap B) / (A \cup B), for each image independently, and then takes the
average for all images for each threshold
"""
full_state_update: bool = False

def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
thresholds: Optional[Union[float, List]] = None
):
super(mIoU, self).__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn
)
if type(thresholds) is not list:
thresholds = [thresholds]
self.thresholds = torch.tensor(thresholds, dtype=torch.float32)

self.add_state("intersection", default=[])
self.add_state("union", default=[])

def update(self, preds: Tensor, target: Tensor):
if len(preds.shape) == 2:
preds = preds.unsqueeze(0)
target = target.unsqueeze(0)

preds, target = preds.view(preds.shape[0], -1), target.view(target.shape[0], -1)

intersections, unions = self.compute_intersection_and_union(preds, target)
self.intersection.append(intersections)
self.union.append(unions)

def compute_intersection_and_union(self, detection: np.array, labels: np.array):
intersections, unions = [], []
labels_thr = (torch.max(labels) - torch.min(labels)) / 2
for thr in self.thresholds:
pred, target = detection > thr, labels > labels_thr
intersections.append(torch.sum(pred & target, dim=1))
unions.append(torch.sum(pred | target, dim=1))
return torch.stack(intersections).T, torch.stack(unions).T

def compute(self):
self.intersection = torch.cat(self.intersection)
self.union = torch.cat(self.union)
return torch.mean((self.intersection + EPS) / (self.union + EPS), dim=0)
4 changes: 2 additions & 2 deletions src/nfa_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

def compute_log_nfa_anomaly_score(
z: List[torch.Tensor],
win_size: int = 7,
binomial_probability_thr: float = 0.5,
win_size: int = 5,
binomial_probability_thr: float = 0.9,
target_size: int = 256,
high_precision: bool = False
):
Expand Down
Loading

0 comments on commit 6c46263

Please sign in to comment.