Commit a9c1e82
support dp for class_center_sample, margin_cross_entropy
add unittest

fix unittest

fix unittest
GuoxiaWang committed Feb 28, 2022
1 parent dcfe198 commit a9c1e82
Showing 4 changed files with 63 additions and 17 deletions.
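
What the change amounts to: passing group=False to either API now skips all collective communication, so the ops run per-rank under plain data parallelism. Below is a minimal sketch of the updated calls, assuming a CUDA-enabled Paddle build that includes this commit; the sizes and margin values are illustrative, not taken from the commit.

import numpy as np
import paddle

# Illustrative sizes, not from the commit.
num_classes, num_samples, batch_size = 20, 6, 10

label = paddle.to_tensor(
    np.random.randint(0, num_classes, (batch_size, ), dtype='int64'))

# group=False (new in this commit): sample class centers locally,
# with no cross-rank communication.
remapped_label, sampled_centers = paddle.nn.functional.class_center_sample(
    label, num_classes, num_samples, group=False)

# group=False likewise keeps margin_cross_entropy single-rank.
# NOTE: margin_cross_entropy targets GPU builds; run on a CUDA device.
logits = paddle.uniform([batch_size, num_classes], min=-0.99, max=0.99)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, label,
    margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0,  # ArcFace-style margins
    return_softmax=True, reduction=None, group=False)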
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/test_class_center_sample_op.py
@@ -243,7 +243,21 @@ def test_empty_label():
                     label, self.num_classes, self.num_samples)
                 print(remapped_label, sampled_class_index)
 
+        def test_group_value():
+            for place in self.places:
+                with paddle.fluid.dygraph.guard(place):
+                    label_np = np.random.randint(
+                        0,
+                        self.num_classes, (self.batch_size, ),
+                        dtype=self.dtype)
+                    label = paddle.to_tensor(label_np)
+
+                    remapped_label, sampled_class_index = paddle.nn.functional.class_center_sample(
+                        label, self.num_classes, self.num_samples)
+                    print(remapped_label, sampled_class_index)
+
         self.assertRaises(ValueError, test_empty_label)
+        self.assertRaises(ValueError, test_group_value)
 
 
 if __name__ == '__main__':
23 changes: 23 additions & 0 deletions (second test file; filename not shown)
@@ -400,8 +400,31 @@ def test_label_type():
                     return_softmax=True,
                     reduction=None)
 
+        def test_group_value():
+            for place in self.places:
+                with paddle.fluid.dygraph.guard(place):
+                    labels_np = np.random.randint(
+                        0, self.num_class, (self.batch_dim, ), dtype="int64")
+                    logits_np = np.random.uniform(
+                        -0.99, 0.99,
+                        [self.batch_dim, self.num_class]).astype(self.dtype)
+                    labels = paddle.to_tensor(labels_np)
+                    logits = paddle.to_tensor(logits_np)
+
+                    loss, softmax = paddle.nn.functional.margin_cross_entropy(
+                        logits,
+                        labels,
+                        margin1=self.margin1,
+                        margin2=self.margin2,
+                        margin3=self.margin3,
+                        scale=self.scale,
+                        return_softmax=True,
+                        reduction=None,
+                        group=False)
+
         self.assertRaises(ValueError, test_dim)
         self.assertRaises(NotImplementedError, test_label_type)
+        self.assertRaises(ValueError, test_group_value)
 
 
 if __name__ == '__main__':
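
Both test files use the same pattern: each invalid configuration is wrapped in a zero-argument closure and passed to self.assertRaises, so the failing call runs inside the assertion rather than at definition time. A self-contained sketch of that pattern with a stand-in validator (check_group is hypothetical, not Paddle's):

import unittest

def check_group(group):
    # Stand-in for the validation added in this commit: accept only
    # False, None, or a (hypothetical) collective Group object.
    if group is not False and group is not None and not hasattr(group, 'is_member'):
        raise ValueError('Expected group is False, None or instance of '
                         'Group (got group: {})'.format(group))

class TestGroupValue(unittest.TestCase):
    def test_group_value(self):
        def bad_group():
            check_group(group='not-a-group')  # invalid on purpose
        self.assertRaises(ValueError, bad_group)

if __name__ == '__main__':
    unittest.main()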
20 changes: 12 additions & 8 deletions python/paddle/nn/functional/common.py
@@ -1731,18 +1731,22 @@ class centers and the shape of sampled_class_center will be [num_positive_class_
         #Tensor(shape=[7], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
         #       [0, 1, 2, 3, 5, 7, 8])
     """
-    if group is not None and not group.is_member():
+    if group != False and group is not None and not group.is_member():
         raise ValueError(
             'Expected group is False, None or instance of paddle.distributed.collective.Group \
              (got group: {})'.format(group))
         return
 
-    ring_id = 0 if group is None else group.id
+    ring_id = 0
     rank = 0
     nranks = 1
-    if core.is_compiled_with_dist():
-        parallel_env = paddle.distributed.ParallelEnv()
-        global_rank = parallel_env.rank
-        rank = global_rank if group is None else group.get_group_rank(
-            global_rank)
-        nranks = parallel_env.world_size if group is None else group.nranks
+    if group != False:
+        if core.is_compiled_with_dist():
+            parallel_env = paddle.distributed.ParallelEnv()
+            global_rank = parallel_env.rank
+            rank = global_rank if group is None else group.get_group_rank(
+                global_rank)
+            nranks = parallel_env.world_size if group is None else group.nranks
 
     if num_samples > num_classes:
         raise ValueError(
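
After this hunk, group effectively accepts three values: None (use the default global group), a paddle.distributed.collective.Group instance (communicate within that group), and the new False (no communication; rank stays 0 and nranks stays 1). A single-process sketch of the False path follows; the other two forms need a multi-process launch and are only noted in comments.

import numpy as np
import paddle

num_classes, num_samples, batch_size = 20, 6, 10
label = paddle.to_tensor(
    np.random.randint(0, num_classes, (batch_size, ), dtype='int64'))

# group=False: the branch above leaves rank=0, nranks=1, so sampling
# is purely local to this process.
remapped, centers = paddle.nn.functional.class_center_sample(
    label, num_classes, num_samples, group=False)

# group=None would use the default global group, and a Group instance
# from paddle.distributed.new_group(...) would scope the collective to
# that group; both require a multi-process distributed launch.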
23 changes: 14 additions & 9 deletions python/paddle/nn/functional/loss.py
@@ -1297,24 +1297,29 @@ def margin_cross_entropy(logits,
"""

assert reduction in ['mean', 'sum', 'none', None]
if group is not None and not group.is_member():
if group != False and group is not None and not group.is_member():
raise ValueError(
'Expected group is False, None or instance of paddle.distributed.collective.Group \
(got group: {})'.format(group))
return

ring_id = 0 if group is None else group.id
ring_id = 0
rank = 0
nranks = 1
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks
if group != False:
ring_id = 0 if group is None else group.id
if core.is_compiled_with_dist():
parallel_env = paddle.distributed.ParallelEnv()
global_rank = parallel_env.rank
rank = global_rank if group is None else group.get_group_rank(
global_rank)
nranks = parallel_env.world_size if group is None else group.nranks

input_dims = len(list(logits.shape))
label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError(
'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
'Expected input_dims - 1 = label_dims or input_dims == label_dims\
(got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
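
The corrected message at the end of this hunk documents the rank check: label may have the same rank as logits or exactly one less, in which case it is unsqueezed internally. A short sketch exercising both accepted shapes together with the new group=False mode, assuming a CUDA-enabled build (margin_cross_entropy is a GPU op); sizes are illustrative:

import numpy as np
import paddle

batch_dim, num_class = 8, 50
logits = paddle.uniform([batch_dim, num_class], min=-0.99, max=0.99)
label_1d = paddle.to_tensor(
    np.random.randint(0, num_class, (batch_dim, ), dtype='int64'))
label_2d = label_1d.unsqueeze(-1)  # same labels, rank == input_dims

# Both shapes pass the input_dims/label_dims check shown above.
for label in (label_1d, label_2d):
    loss, softmax = paddle.nn.functional.margin_cross_entropy(
        logits, label,
        margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0,
        return_softmax=True, reduction=None, group=False)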
