-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wiki] Tutorial #80
Comments
1. How to implement a method with train (classification) and test (OOD) stages?
|
2. Implement a method with test stage: take Gram as an example
from openood.datasets import get_dataloader, get_ood_dataloader
from openood.evaluators import get_evaluator
from openood.networks import get_network
from openood.postprocessors import get_postprocessor
from openood.utils import setup_logger
class TestOODPipeline:
    """Evaluation-only pipeline: measures ID accuracy and OOD detection
    performance of an already-trained network."""

    def __init__(self, config) -> None:
        self.config = config

    def run(self):
        """Set up logging and data, then report accuracy and OOD metrics."""
        # create the output directory and persist the resolved config
        setup_logger(self.config)

        # dataloaders for the in-distribution and out-of-distribution splits
        id_loader_dict = get_dataloader(self.config)
        ood_loader_dict = get_ood_dataloader(self.config)

        # network, evaluator and postprocessor are all built from the config
        net = get_network(self.config.network)
        evaluator = get_evaluator(self.config)
        postprocessor = get_postprocessor(self.config)

        # distance-based postprocessors estimate their statistics here
        postprocessor.setup(net, id_loader_dict, ood_loader_dict)

        separator = u'\u2500' * 70
        print('\n', flush=True)
        print(separator, flush=True)

        # in-distribution classification accuracy
        print('\nStart evaluation...', flush=True)
        acc_metrics = evaluator.eval_acc(net, id_loader_dict['test'],
                                         postprocessor)
        print('\nAccuracy {:.2f}%'.format(100 * acc_metrics['acc']),
              flush=True)
        print(separator, flush=True)

        # OOD detection metrics
        evaluator.eval_ood(net, id_loader_dict, ood_loader_dict, postprocessor)
        print('Completed!', flush=True)
class GRAMPostprocessor(BasePostprocessor):
    """OOD postprocessor scoring samples by deviations of Gram-matrix
    feature statistics from per-class bounds estimated on ID data."""

    def __init__(self, config):
        self.config = config
        self.postprocessor_args = config.postprocessor.postprocessor_args
        # number of ID classes; min/max bounds are kept per class
        self.num_classes = self.config.dataset.num_classes
        # powers p used for the p-th order Gram matrices
        self.powers = self.postprocessor_args.powers
        # populated by setup() from the ID training split
        self.feature_min, self.feature_max = None, None

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """Estimate per-class feature min/max bounds on the ID training data."""
        self.feature_min, self.feature_max = sample_estimator(
            net, id_loader_dict['train'], self.num_classes, self.powers)


@torch.no_grad()
@torch.no_grad()
def sample_estimator(model, train_loader, num_classes, powers):
    """Estimate per-class min/max of Gram-matrix features over ID train data.

    For each class, feature layer and power ``p``, collects the row-sums of
    the p-th order Gram matrices of the feature maps (with a signed p-th
    root applied) and records their element-wise min and max.

    Args:
        model: network supporting ``model(x, return_feature_list=True)``.
        train_loader: dataloader yielding dicts with 'data' and 'label' keys.
        num_classes: number of in-distribution classes.
        powers: iterable of integer powers p for the Gram matrices.

    Returns:
        (mins, maxs): nested lists indexed as [class][layer][power_idx].
    """
    model.eval()

    num_layer = 5  # assumes the backbone exposes 5 feature stages -- TODO confirm
    num_poles_list = powers
    num_poles = len(num_poles_list)
    feature_class = [[[None for x in range(num_poles)]
                      for y in range(num_layer)] for z in range(num_classes)]
    label_list = []
    mins = [[[None for x in range(num_poles)] for y in range(num_layer)]
            for z in range(num_classes)]
    maxs = [[[None for x in range(num_poles)] for y in range(num_layer)]
            for z in range(num_classes)]

    # collect features and compute Gram matrix statistics
    for batch in tqdm(train_loader, desc='Compute min/max'):
        data = batch['data'].cuda()
        label = batch['label']
        _, feature_list = model(data, return_feature_list=True)
        # labels of the *current* batch only (reassigned every iteration)
        label_list = tensor2list(label)
        for layer_idx in range(num_layer):
            for pole_idx, p in enumerate(num_poles_list):
                temp = feature_list[layer_idx].detach()
                temp = temp**p
                temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
                # row-sums of the Gram matrix G = T @ T^T
                temp = ((torch.matmul(temp,
                                      temp.transpose(dim0=2,
                                                     dim1=1)))).sum(dim=2)
                # signed p-th root restores the original scale
                temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
                    temp.shape[0], -1)
                temp = tensor2list(temp)
                for feature, label in zip(temp, label_list):
                    if isinstance(feature_class[label][layer_idx][pole_idx],
                                  type(None)):
                        feature_class[label][layer_idx][pole_idx] = feature
                    else:
                        feature_class[label][layer_idx][pole_idx].extend(
                            feature)

    # reduce the collected features to element-wise min/max per class
    for label in range(num_classes):
        for layer_idx in range(num_layer):
            for poles_idx in range(num_poles):
                feature = torch.tensor(
                    np.array(feature_class[label][layer_idx][poles_idx]))
                current_min = feature.min(dim=0, keepdim=True)[0]
                current_max = feature.max(dim=0, keepdim=True)[0]
                if mins[label][layer_idx][poles_idx] is None:
                    mins[label][layer_idx][poles_idx] = current_min
                    maxs[label][layer_idx][poles_idx] = current_max
                else:
                    mins[label][layer_idx][poles_idx] = torch.min(
                        current_min, mins[label][layer_idx][poles_idx])
                    # BUG FIX: the running maximum must be updated with
                    # current_max; the original used current_min, which
                    # corrupted the upper bound whenever this branch ran.
                    maxs[label][layer_idx][poles_idx] = torch.max(
                        current_max, maxs[label][layer_idx][poles_idx])

    return mins, maxs
# NOTE(review): in the original source `postprocess` is a method of
# GRAMPostprocessor; it is reproduced here exactly as scraped.
def postprocess(self, net: nn.Module, data: Any):
    """Score a batch: return predicted classes and Gram deviation scores."""
    preds, deviations = get_deviations(net, data, self.feature_min,
                                       self.feature_max, self.num_classes,
                                       self.powers)
    return preds, deviations


def get_deviations(model, data, mins, maxs, num_classes, powers):
    """Compute per-sample deviation scores against the stored Gram bounds.

    For every feature layer and power p, accumulates how far each sample's
    Gram feature falls below mins / above maxs of its predicted class,
    normalized by the magnitude of the corresponding bound.

    Args:
        model: network supporting ``model(x, return_feature_list=True)``.
        data: input batch tensor.
        mins, maxs: per-class bounds from ``sample_estimator``.
        num_classes: number of ID classes (unused here, kept for interface).
        powers: iterable of integer powers p.

    Returns:
        (preds, conf): predicted class indices and deviation scores,
        both of length equal to the batch size.
    """
    model.eval()

    num_layer = 5  # must match sample_estimator -- TODO confirm backbone stages
    num_poles_list = powers

    # predictions for the batch
    logits, feature_list = model(data, return_feature_list=True)
    confs = F.softmax(logits, dim=1).cpu().detach().numpy()
    preds = np.argmax(confs, axis=1)
    preds = torch.tensor(preds)
    # (the original built a `pred_list` of unique predictions here, but it
    # was never read afterwards -- dead code, removed)

    batch_size = preds.shape[0]
    # BUG FIX: size the accumulator to the batch instead of a hard-coded
    # 200, so the returned score tensor matches the batch length.
    dev = [0 for _ in range(batch_size)]

    # accumulate normalized violations of the per-class [min, max] bounds
    for layer_idx in range(num_layer):
        for pole_idx, p in enumerate(num_poles_list):
            # Gram row-sums, same transform as in sample_estimator
            temp = feature_list[layer_idx].detach()
            temp = temp**p
            temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
            temp = ((torch.matmul(temp, temp.transpose(dim0=2,
                                                       dim1=1)))).sum(dim=2)
            temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
                temp.shape[0], -1)
            temp = tensor2list(temp)
            # compare each sample against the bounds of its predicted class
            for idx in range(len(temp)):
                dev[idx] += (F.relu(mins[preds[idx]][layer_idx][pole_idx] -
                                    sum(temp[idx])) /
                             torch.abs(mins[preds[idx]][layer_idx][pole_idx] +
                                       10**-6)).sum()
                dev[idx] += (F.relu(
                    sum(temp[idx]) - maxs[preds[idx]][layer_idx][pole_idx]) /
                             torch.abs(maxs[preds[idx]][layer_idx][pole_idx] +
                                       10**-6)).sum()

    # empirical scale factor kept from the original implementation
    conf = [i / 50 for i in dev]
    return preds, torch.tensor(conf)
#!/bin/bash
# sh scripts/ood/gram/7_cifar_test_ood_gram.sh
GPU=1
CPU=1
node=36
jobname=openood

# Uncomment to launch through SLURM instead of running locally:
#srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \
#--cpus-per-task=${CPU} --ntasks-per-node=${GPU} \
#--kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \

# FIX: the PYTHONPATH prefix must be on the same logical command line as
# the python invocation; previously its trailing backslash continued into
# a comment line, so python ran WITHOUT the repo root on PYTHONPATH.
PYTHONPATH='.':$PYTHONPATH \
python main.py \
    --config configs/datasets/objects/cifar10.yml \
    configs/datasets/objects/cifar10_ood.yml \
    configs/networks/resnet18_32x32.yml \
    configs/pipelines/test/test_gram.yml \
    configs/postprocessors/gram.yml \
    --dataset.image_size 32 \
    --network.name resnet18_32x32 \
    --num_workers 8
No description provided.
The text was updated successfully, but these errors were encountered: