Skip to content

Commit ebdff67

Browse files
support batch comupute for chamferloss
1 parent 4d9d5ff commit ebdff67

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

ppsci/loss/chamfer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@ class ChamferLoss(base.Loss):
4040
Examples:
4141
>>> import paddle
4242
>>> from ppsci.loss import ChamferLoss
43-
>>> output_dict = {"s1": paddle.to_tensor([[.1, .2, .3], [.4, .5, .6]])}
44-
>>> label_dict = {"s1": paddle.to_tensor([[.4, .5, .6], [.1, .2, .3]])}
43+
>>> _ = paddle.seed(42)
44+
>>> batch_point_cloud1 = paddle.rand([2, 100, 3])
45+
>>> batch_point_cloud2 = paddle.rand([2, 50, 3])
46+
>>> output_dict = {"s1": batch_point_cloud1}
47+
>>> label_dict = {"s1": batch_point_cloud2}
4548
>>> weight = {"s1": 0.8}
4649
>>> loss = ChamferLoss(weight=weight)
4750
>>> result = loss(output_dict, label_dict)
4851
>>> print(result)
4952
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
50-
0.)
53+
0.04415882)
5154
"""
5255

5356
def __init__(
@@ -61,16 +64,16 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6164
for key in label_dict:
6265
s1 = output_dict[key]
6366
s2 = label_dict[key]
64-
N1, N2 = s1.shape[0], s2.shape[0]
67+
N1, N2 = s1.shape[1], s2.shape[1]
6568

6669
# [N1, N2, 3]
67-
s1_expand = paddle.expand(s1.reshape([N1, 1, 3]), shape=[N1, N2, 3])
70+
s1_expand = paddle.expand(s1.reshape([-1, N1, 1, 3]), shape=[-1, N1, N2, 3])
6871
# [N1, N2, 3]
69-
s2_expand = paddle.expand(s2.reshape([1, N2, 3]), shape=[N1, N2, 3])
72+
s2_expand = paddle.expand(s2.reshape([-1, 1, N2, 3]), shape=[-1, N1, N2, 3])
7073

71-
dis = ((s1_expand - s2_expand) ** 2).sum(axis=2) # [N1, N2]
72-
loss_s12 = dis.min(axis=1) # [N1]
73-
loss_s21 = dis.min(axis=0) # [N2]
74+
dis = ((s1_expand - s2_expand) ** 2).sum(axis=3) # [B, N1, N2]
75+
loss_s12 = dis.min(axis=2) # [B, N1]
76+
loss_s21 = dis.min(axis=1) # [B, N2]
7477
loss = loss_s12.mean() + loss_s21.mean()
7578

7679
if weight_dict and key in weight_dict:

test/loss/chamfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ def test_chamfer_loss():
1111
"""Test for chamfer distance loss."""
1212
N1 = 100
1313
N2 = 50
14-
output_dict = {"s1": paddle.randn([N1, 3])}
15-
label_dict = {"s1": paddle.randn([N2, 3])}
14+
output_dict = {"s1": paddle.randn([1, N1, 3])}
15+
label_dict = {"s1": paddle.randn([1, N2, 3])}
1616
chamfer_loss = loss.ChamferLoss()
1717
result = chamfer_loss(output_dict, label_dict)
1818

1919
loss_cd_s1 = 0.0
2020
for i in range(N1):
2121
min_i = None
2222
for j in range(N2):
23-
disij = ((output_dict["s1"][i] - label_dict["s1"][j]) ** 2).sum()
23+
disij = ((output_dict["s1"][0, i] - label_dict["s1"][0, j]) ** 2).sum()
2424
if min_i is None or disij < min_i:
2525
min_i = disij
2626
loss_cd_s1 += min_i
@@ -30,7 +30,7 @@ def test_chamfer_loss():
3030
for j in range(N2):
3131
min_j = None
3232
for i in range(N1):
33-
disij = ((output_dict["s1"][i] - label_dict["s1"][j]) ** 2).sum()
33+
disij = ((output_dict["s1"][0, i] - label_dict["s1"][0, j]) ** 2).sum()
3434
if min_j is None or disij < min_j:
3535
min_j = disij
3636
loss_cd_s2 += min_j

0 commit comments

Comments
 (0)