-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
79 lines (63 loc) · 2.37 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
from einops import rearrange
def generate_predicted_videos(
    outputs,
    videos_patch,
    bool_masked_pos,
    batch_size,
    input_size,
    patch_size,
    tublet_size,
    num_frames,
    num_channels=1,
):
    """Reassemble full videos from patch tokens, splicing model outputs into masked slots.

    Args:
        outputs: predictions for the masked patches; reshaped to
            ``(-1, tublet_size * patch_size * patch_size * num_channels)``.
        videos_patch: ground-truth patch tokens, shape ``(B, num_patches, patch_dim)``.
        bool_masked_pos: boolean mask, shape ``(B, num_patches)``; True = masked.
        batch_size: B.
        input_size: spatial height/width of the video (assumed square).
        patch_size: spatial patch edge length.
        tublet_size: temporal patch (tubelet) depth.
        num_frames: total frames T; must be divisible by ``tublet_size``.
        num_channels: channels C per pixel (default 1, matching the original
            hard-coded single-channel layout — backward compatible).

    Returns:
        Tensor of shape ``(B, C, T, input_size, input_size)``.
    """
    # Copy so the caller's ground-truth patches are never mutated.
    predicted_patch = videos_patch.clone()
    predicted_patch[bool_masked_pos] = outputs.reshape(
        [-1, tublet_size * patch_size * patch_size * num_channels]
    ).to(torch.float32)

    t = num_frames // tublet_size
    h = input_size // patch_size
    w = input_size // patch_size
    # Equivalent of einops rearrange
    # "b (t h w) (p0 p1 p2 c) -> b c (t p0) (h p1) (w p2)"
    # expressed with plain torch ops: expose all factor axes, move them into
    # (b, c, t, p0, h, p1, w, p2) order, then merge the patch axes back in.
    grid = predicted_patch.reshape(
        batch_size, t, h, w, tublet_size, patch_size, patch_size, num_channels
    )
    predicted_videos = grid.permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(
        batch_size,
        num_channels,
        t * tublet_size,
        h * patch_size,
        w * patch_size,
    )
    return predicted_videos
def compute_eval_metrics(videos, predicted_videos, obs_frames=8):
    """Compute divergence metrics between ground-truth and predicted videos.

    Both inputs have shape ``(B, C, T, H, W)``. Metrics with an "a" prefix
    are computed over all frames from ``obs_frames`` onward; "f"-prefixed
    metrics use only the final frame. Each frame is flattened to an
    ``(N, H, W)`` stack before being handed to the divergence helpers.

    Returns:
        Tuple ``(d_akl, d_arkl, d_ajs, d_fkl, d_frkl, d_fjs)`` of
        (KL, reverse KL, JS) for the future frames and the last frame.
    """
    _, _, _, H, W = videos.shape

    gt_future = videos[:, :, obs_frames:].reshape(-1, H, W)
    pred_future = predicted_videos[:, :, obs_frames:].reshape(-1, H, W)
    d_akl = kl_divergence(gt_future, pred_future)
    d_arkl = kl_divergence(pred_future, gt_future)
    d_ajs = js_divergence(gt_future, pred_future)

    # Last frame only; the -1: slice keeps the frame axis without unsqueeze.
    gt_last = videos[:, :, -1:].reshape(-1, H, W)
    pred_last = predicted_videos[:, :, -1:].reshape(-1, H, W)
    d_fkl = kl_divergence(gt_last, pred_last)
    d_frkl = kl_divergence(pred_last, gt_last)
    d_fjs = js_divergence(gt_last, pred_last)

    return d_akl, d_arkl, d_ajs, d_fkl, d_frkl, d_fjs
def kl_divergence(p, q, eps=1e-10):
    """Return KL(p || q) between per-sample spatial distributions.

    Inputs are non-negative ``(N, H, W)`` tensors treated as unnormalized
    densities; each sample is normalized over its ``(H, W)`` plane before
    the divergence is computed with ``batchmean`` reduction.

    Args:
        p: target distribution tensor, shape ``(N, H, W)``.
        q: approximating distribution tensor, shape ``(N, H, W)``.
        eps: smoothing constant guarding against log(0) / division by zero.

    Returns:
        Scalar tensor: mean KL divergence over the N samples.

    Note: never mutates ``p`` or ``q``.
    """
    # BUG FIX: the original used in-place `p += eps` / `q += eps`, which
    # silently mutated the caller's tensors — compute_eval_metrics and
    # js_divergence reuse those tensors for subsequent metrics, so every
    # divergence after the first was computed on corrupted inputs.
    # Out-of-place addition binds fresh tensors instead.
    p = p + eps
    q = q + eps
    p = p / (torch.sum(p, dim=(1, 2), keepdim=True) + eps)
    q = q / (torch.sum(q, dim=(1, 2), keepdim=True) + eps)
    p = torch.clamp(p, min=eps, max=1 - eps)
    q = torch.clamp(q, min=eps, max=1 - eps)
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    return nn.KLDivLoss(reduction="batchmean")((q + eps).log(), p)
def js_divergence(p, q, eps=1e-10):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    Computed as the average of KL(p || m) and KL(q || m), where m is the
    midpoint mixture of the two inputs; normalization and smoothing are
    delegated to ``kl_divergence``.
    """
    mixture = (p + q) / 2
    return (kl_divergence(p, mixture, eps) + kl_divergence(q, mixture, eps)) / 2
if __name__ == "__main__":
    # Smoke test: random (B, C, T, H, W) videos through the metric pipeline.
    ground_truth = torch.rand([4, 1, 20, 80, 80])
    predictions = torch.rand([4, 1, 20, 80, 80])
    metrics = compute_eval_metrics(ground_truth, predictions, obs_frames=8)
    print(*metrics)