-
Notifications
You must be signed in to change notification settings - Fork 6
/
predict.py
56 lines (45 loc) · 2.15 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import yaml
import h5py
import argparse
import pathlib
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from train import LightningTrainer
class LightningTester(LightningTrainer):
    """Run a trained ``LightningTrainer`` over the test split and write predictions to HDF5.

    Usage: call ``_load_save_file(path)`` before ``Trainer.test(model)`` so that
    ``self.f`` exists when ``test_step`` runs. The file is closed automatically
    in ``on_test_end``.
    """

    def __init__(self, config):
        super().__init__(config)

    def _load_save_file(self, save_path):
        """Open ``save_path`` as a writable HDF5 file used by ``test_step``.

        The file is opened in ``'w'`` mode, so an existing file at ``save_path``
        is truncated. Any file previously opened by this method is closed first
        so its handle is not leaked.
        """
        previous = getattr(self, 'f', None)
        if previous is not None:
            previous.close()  # h5py's close() is a no-op on an already-closed file
        self.f = h5py.File(save_path, 'w')

    def test_step(self, batch, batch_idx):
        """Predict one test batch and store per-class confidence and label in ``self.f``.

        Writes two datasets, ``<label_path>/conf`` (max softmax probability over
        dim 1) and ``<label_path>/pred`` (argmax over dim 1), where
        ``<label_path>`` is ``self.train_dataset.label_paths[batch_idx]``.
        """
        # NOTE(review): the meaning of rpz/fea and the third element of
        # batch['teacher'] is defined by the dataset/base class; not visible here.
        rpz, fea, _ = batch['teacher']
        _, output = self(self.teacher, fea, rpz)
        # Class scores live on dim 1: keep the winning class and its probability.
        conf, pred = torch.max(output.softmax(1), dim=1)
        conf = conf.detach().cpu().numpy()
        pred = pred.detach().cpu().numpy()
        # NOTE(review): batch_idx counts *batches*, so this maps one label path
        # per batch; it is a per-sample key only when batch_size == 1 — confirm config.
        key = self.train_dataset.label_paths[batch_idx]
        self.f.create_dataset(os.path.join(key, 'conf'), data=conf)
        self.f.create_dataset(os.path.join(key, 'pred'), data=pred)

    def on_test_end(self):
        """Close the HDF5 output file so all written predictions are flushed to disk."""
        super().on_test_end()
        save_file = getattr(self, 'f', None)
        if save_file is not None:
            save_file.close()

    def test_dataloader(self):
        """Return a non-shuffled DataLoader over ``self.train_dataset`` switched to the test split.

        Loader kwargs come from ``config['train_dataloader']``, with ``shuffle``
        forced to ``False`` on a copy so the shared config dict is not mutated.
        """
        loader_kwargs = dict(self.config['train_dataloader'], shuffle=False)
        self.train_dataset.split = 'test'
        return DataLoader(dataset=self.train_dataset, **loader_kwargs)
if __name__ == '__main__':
    # Placeholder default; pass --checkpoint_path to point at a real checkpoint.
    CKPT_PATH = "your/ckpt/file/path/mycheckpoint.ckpt"

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', default='config/predict.yaml')
    parser.add_argument('--dataset_config_path', default='config/semantickitti.yaml')
    parser.add_argument('--checkpoint_path', default=CKPT_PATH)
    parser.add_argument('--save_dir', default='your/save/dir')
    args = parser.parse_args()

    # Load the run config, then merge the dataset config into its 'dataset' section.
    with open(args.config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)
    with open(args.dataset_config_path, 'r') as dataset_config_file:
        config['dataset'].update(yaml.safe_load(dataset_config_file))

    wandb_logger = WandbLogger(
        config=config,
        save_dir=config['trainer']['default_root_dir'],
        **config['logger'],
    )
    trainer = Trainer(logger=wandb_logger, **config['trainer'])
    model = LightningTester.load_from_checkpoint(args.checkpoint_path, config=config)

    pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    # Output filename kept as-is for compatibility with downstream consumers.
    model._load_save_file(os.path.join(args.save_dir, 'training_results.h5'))
    try:
        trainer.test(model)
    finally:
        # Ensure the HDF5 file is flushed and closed even if testing fails.
        save_file = getattr(model, 'f', None)
        if save_file is not None:
            save_file.close()