[Zero-Dim] Support paddle.max output 0D, test=allcase (#53242)
zhwesky2010 authored Apr 24, 2023
1 parent ddd7203 commit 9f9cd91
Showing 19 changed files with 88 additions and 41 deletions.
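For context, the behavior this commit enables can be sketched in a few lines (a minimal dygraph sketch mirroring the cases added to test_zero_dim_tensor.py below; the printed shapes describe the intended result, not captured output):

```python
import paddle

# Fully reducing an N-D tensor now produces a 0-D tensor (shape []),
# and its gradient keeps the input's shape.
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.max(x)
out.backward()
print(out.shape)     # [] (previously [1])
print(x.grad.shape)  # [3, 5]

# Reducing a 1-D tensor along axis 0 likewise yields a 0-D tensor.
y = paddle.rand([5])
print(paddle.max(y, axis=0).shape)  # []
```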
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -54,9 +54,10 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(
reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));

REGISTER_OPERATOR(
reduce_max,
@@ -1335,7 +1335,7 @@ void max_grad(const Tensor& x,
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
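The loop above now starts at axis 0 because, under reduce_all, the gradient has to be expanded along every axis of x; starting at 1 silently dropped the first axis. A rough dygraph illustration of the intended result (arbitrary values; this is not the composite rule itself):

```python
import paddle

x = paddle.to_tensor([[1.0, 3.0], [2.0, 0.5]], stop_gradient=False)
out = paddle.max(x)  # reduce_all: reduces over axes 0 and 1
out.backward()

# The gradient is 1 at the position of the maximum and 0 elsewhere,
# so it must span both axes of x.
print(x.grad)
# [[0., 1.],
#  [0., 0.]]
```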
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_ops.yaml
@@ -744,7 +744,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : OriginReduceInferMeta
func : ReduceIntArrayAxisInferMeta
kernel :
func : max
backward : max_grad
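For reference, the shape rule that the switched-in ReduceIntArrayAxisInferMeta is expected to follow for max can be sketched as below (reduced_shape is a hypothetical helper, not Paddle code; the key point is that reducing all axes without keepdim now yields a rank-0 shape):

```python
def reduced_shape(shape, axis=None, keepdim=False):
    # axis=None or an empty/full axis list means "reduce everything".
    if axis is None or len(axis) == 0 or len(axis) == len(shape):
        return [1] * len(shape) if keepdim else []
    axes = [a % len(shape) for a in axis]  # normalize negative axes
    if keepdim:
        return [1 if i in axes else d for i, d in enumerate(shape)]
    return [d for i, d in enumerate(shape) if i not in axes]

print(reduced_shape([3, 5]))                          # []
print(reduced_shape([5], axis=[0]))                   # []
print(reduced_shape([3, 5], axis=[1], keepdim=True))  # [3, 1]
```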
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -89,6 +89,7 @@ PD_REGISTER_KERNEL(add_n,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}

PD_REGISTER_KERNEL(add_n_array,
@@ -99,4 +100,5 @@ PD_REGISTER_KERNEL(add_n_array,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}
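The two float16 registrations presumably let add_n pick a float16 kernel on CPU (paddle.add_n is, for instance, what the gradient clipping code further down uses to sum norms). A hedged usage sketch, assuming the dygraph call dispatches straight to this CPU kernel:

```python
import paddle

paddle.set_device('cpu')
xs = [paddle.ones([3], dtype='float16') for _ in range(2)]
out = paddle.add_n(xs)  # should now resolve to the float16 CPU kernel
print(out.dtype, out.numpy())  # paddle.float16, [2. 2. 2.]
```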
1 change: 1 addition & 0 deletions paddle/phi/kernels/funcs/selected_rows_functor.cc
@@ -395,6 +395,7 @@ template struct SelectedRowsAddToTensor<phi::CPUContext, float>;
template struct SelectedRowsAddToTensor<phi::CPUContext, double>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int>;
template struct SelectedRowsAddToTensor<phi::CPUContext, int64_t>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::float16>;
template struct SelectedRowsAddToTensor<phi::CPUContext, phi::dtype::bfloat16>;

#ifdef PADDLE_WITH_XPU
18 changes: 9 additions & 9 deletions paddle/phi/kernels/funcs/unsqueeze.h
@@ -105,32 +105,32 @@ inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,

inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
const DDim& in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
int output_rank = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_rank = in_dims.size();
std::vector<int64_t> output_shape(output_rank, 0);

// Validity Check: rank range.
PADDLE_ENFORCE_LE(
output_size,
output_rank,
6,
phi::errors::InvalidArgument("The output "
"tensor's rank should be less than 6."));

for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
int cur = axis < 0 ? axis + cur_output_rank + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE_GE(
cur,
0,
phi::errors::InvalidArgument("The insert dimension value should "
"not be less than 0"));
PADDLE_ENFORCE_LE(cur,
cur_output_size,
cur_output_rank,
phi::errors::InvalidArgument(
"The insert dimension value shoule not be larger "
"than the dimension size of input tensor"));
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
for (int i = cur_output_rank; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
@@ -139,11 +139,11 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
cur_output_rank++;
}

// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
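The hunk above only renames output_size to output_rank; the insertion logic is unchanged. For readers skimming the diff, a rough Python transliteration of what GetUnsqueezeShape computes (get_unsqueeze_shape is a sketch for illustration, not the C++ itself; the extra slack slot is a convenience of the sketch):

```python
def get_unsqueeze_shape(unsqz_dims, in_dims):
    # Mark each inserted axis with 1, shifting previously inserted axes
    # right when needed, then fill the remaining slots with in_dims.
    output_rank = len(in_dims) + len(unsqz_dims)
    output_shape = [0] * (output_rank + 1)  # slack slot for the shift step
    cur_output_rank = len(in_dims)
    for axis in unsqz_dims:
        cur = axis + cur_output_rank + 1 if axis < 0 else axis
        for i in range(cur_output_rank, cur - 1, -1):
            if output_shape[i] == 1:
                output_shape[i + 1] = 1
                output_shape[i] = 0
        output_shape[cur] = 1
        cur_output_rank += 1
    result, in_idx = [], 0
    for out_idx in range(output_rank):
        if output_shape[out_idx] == 0:
            result.append(in_dims[in_idx])
            in_idx += 1
        else:
            result.append(1)
    return result

print(get_unsqueeze_shape([0], []))          # [1]   (unsqueezing a 0-D shape)
print(get_unsqueeze_shape([1, -1], [3, 5]))  # [3, 1, 5, 1]
```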
6 changes: 4 additions & 2 deletions paddle/phi/kernels/onednn/reduce_kernel_impl.h
@@ -102,8 +102,10 @@ void ReduceKernel(const Context& dev_ctx,
reduction_p->execute(astream, reduction_args);
astream.wait();

out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
const auto reshape_dims = out->dims().size() != 0
? vectorize<int64_t>(out->dims())
: std::vector<int64_t>{1};
out->set_mem_desc(dst_memory_p->get_desc().reshape(reshape_dims));
}
}

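The guard above substitutes a rank-1 shape when the output is 0-D, presumably because a oneDNN memory descriptor cannot be reshaped to an empty dims list. The fallback itself is trivial (md_reshape_dims is a hypothetical helper, not a oneDNN call):

```python
def md_reshape_dims(out_dims):
    # Keep the real dims, but fall back to [1] for a 0-D output.
    return list(out_dims) if len(out_dims) != 0 else [1]

print(md_reshape_dims([3, 5]))  # [3, 5]
print(md_reshape_dims([]))      # [1]
```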
@@ -242,7 +242,7 @@ def unscale_method(self, optimizer):
paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
)
self._found_inf = is_found_inf.numpy()[0]
self._found_inf = int(is_found_inf)


class MixPrecisionScaler:
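This is the pattern repeated across the remaining Python hunks: values that used to arrive as 1-element tensors or arrays are now 0-D, so indexing with [0] is replaced by an explicit scalar conversion. A small sketch of the equivalent conversions (paddle.full with shape=[] is used here only to build a 0-D tensor):

```python
import paddle

flag = paddle.full([], 2.5)   # a 0-D tensor, as these results now are
print(float(flag))            # 2.5
print(int(flag))              # 2
print(flag.numpy().item())    # 2.5, the form used in several tests below
# flag.numpy()[0] would raise IndexError: a 0-D ndarray cannot be indexed
```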
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/math_op_patch.py
@@ -179,7 +179,7 @@ def _ndim_(var):

@property
def _size_(var):
return np.prod(var.shape)
return int(np.prod(var.shape))

@property
def _T_(var):
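The int(...) cast matters once shapes can be []: np.prod over an empty shape returns a NumPy float, not a Python int. For example:

```python
import numpy as np

print(np.prod([]))           # 1.0, a numpy float (the size of a 0-D tensor)
print(int(np.prod([])))      # 1
print(int(np.prod([3, 5])))  # 15
```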
@@ -212,7 +212,7 @@ def test_LR_state_dict(self):
adam_test.set_dict(opt_state)
self.assertEqual(
adam_test._learning_rate.best_loss,
adam3._learning_rate.best_loss.numpy()[0],
adam3._learning_rate.best_loss,
"best_loss is different before and after set_dict",
)
self.assertEqual(
@@ -275,7 +275,7 @@ def test_LinearLrWarmup(self):
t = lr()

np.testing.assert_allclose(
t.numpy()[0].item(), right_result[i], rtol=1e-05
t.numpy().item(), right_result[i], rtol=1e-05
)

with self.assertRaises(TypeError):
@@ -342,7 +342,7 @@ def test_StepDecay(self):
right_result = step_decay(
epoch, learning_rate, step_size, decay_rate
)
fluid_result = scheduler().numpy()[0]
fluid_result = scheduler().numpy().item()
scheduler.epoch()
self.assertAlmostEqual(
right_result,
@@ -371,7 +371,7 @@ def test_LambdaDecay(self):

for epoch in range(30):
right_result = lambda_decay(epoch, learning_rate, lr_lambda)
fluid_result = scheduler().numpy()[0]
fluid_result = scheduler().numpy().item()
scheduler.epoch()
self.assertAlmostEqual(
right_result,
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -208,7 +208,7 @@ def _test_dygraph(self, place, kwargs):
self.assertEqual(
scheduler.cooldown_counter, scheduler1.cooldown_counter
)
self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
self.assertEqual(scheduler.best, scheduler1.best)
self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
49 changes: 44 additions & 5 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -219,17 +219,19 @@ def test_dygraph_reduce(self):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return

x = paddle.rand([3, 5])
# 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
@@ -240,6 +242,21 @@ def test_dygraph_reduce(self):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])

# 3) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
out.retain_grads()
out.backward()

self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])

paddle.enable_static()

def test_static_reduce(self):
@@ -284,16 +301,19 @@ def test_static_reduce(self):
np.testing.assert_allclose(res[2], np.array(1.0))
np.testing.assert_allclose(res[3], np.array(1.0))

# 2) x is ND
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
paddle.max,
]:
return

# 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
@@ -309,6 +329,25 @@
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (3, 5))

# 3) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
paddle.static.append_backward(out)

fetch_list = [out]
if block.has_var(x.grad_name):
fetch_list.extend([out.grad_name, x.grad_name])

res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
if len(res) > 1:
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (5,))

paddle.disable_static()


9 changes: 7 additions & 2 deletions python/paddle/hapi/progressbar.py
@@ -81,8 +81,13 @@ def convert_uint16_to_float(in_list):

for i, (k, val) in enumerate(values):
if k == "loss":
val = val if isinstance(val, (list, np.ndarray)) else [val]
if isinstance(val[0], np.uint16):
if isinstance(val, list):
scalar_val = val[0]
elif isinstance(val, np.ndarray):
scalar_val = val.item()
else:
scalar_val = val
if isinstance(scalar_val, np.uint16):
values[i] = ("loss", list(convert_uint16_to_float(val)))

if current_num:
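The extra branches handle a loss reported as a 0-D ndarray, which cannot be indexed with [0]; for instance:

```python
import numpy as np

loss = np.array(0.25, dtype=np.float16)  # a 0-D ndarray
print(loss.item())  # 0.25, safe scalar extraction
# loss[0] would raise IndexError: the array is 0-dimensional
```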
2 changes: 1 addition & 1 deletion python/paddle/nn/clip.py
@@ -700,7 +700,7 @@ def _dygraph_clip(self, params_grads):
global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var)
max_global_norm = paddle.full(
shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
shape=[], dtype=global_norm_var.dtype, fill_value=self.clip_norm
)

need_clip = False
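With shape=[], max_global_norm becomes a 0-D tensor, consistent with global_norm_var, which is itself 0-D after the reduce changes above. For comparison:

```python
import paddle

print(paddle.full(shape=[], fill_value=2.0).shape)   # []   (0-D)
print(paddle.full(shape=[1], fill_value=2.0).shape)  # [1]  (1-D, one element)
```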
4 changes: 2 additions & 2 deletions python/paddle/nn/quant/lsq.py
@@ -178,7 +178,7 @@ def __init__(
s_attr = ParamAttr(
name=self._scale_name, initializer=Constant(1.0), trainable=True
)
self.s = self.create_parameter(shape=[1], attr=s_attr, dtype='float32')
self.s = self.create_parameter(shape=[], attr=s_attr, dtype='float32')
self.s.stop_gradient = False

if not self.symmetric:
@@ -189,7 +189,7 @@ def __init__(
name=self._beta_name, initializer=Constant(0.0), trainable=True
)
self.beta = self.create_parameter(
shape=[1], attr=beta_attr, dtype='float32'
shape=[], attr=beta_attr, dtype='float32'
)
self.beta.stop_gradient = False

5 changes: 1 addition & 4 deletions test/autograd/utils.py
@@ -26,10 +26,7 @@
# Finite Difference Utils
##########################################################
def _product(t):
if isinstance(t, int):
return t
else:
return np.product(t)
return int(np.product(t))


def _get_item(t, idx):
2 changes: 1 addition & 1 deletion test/dygraph_to_static/seq2seq_dygraph_model.py
@@ -407,7 +407,7 @@ def beam_search(self, inputs):
parent_ids = []

for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
if paddle.sum(1 - beam_finished).numpy()[0] == 0:
if paddle.sum(1 - beam_finished) == 0:
break
step_input = self._merge_batch_beams(step_input)
new_dec_hidden, new_dec_cell = [], []
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_for_enumerate.py
@@ -28,7 +28,7 @@
def for_in_range(x):
z = paddle.tensor.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
for i in range(x.numpy().item()):
z = z + i
return z

4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_sentiment.py
@@ -342,7 +342,7 @@ def train(args, to_static):

model.train()
avg_cost, prediction, acc = model(doc, label)
loss_data.append(avg_cost.numpy()[0])
loss_data.append(float(avg_cost))

avg_cost.backward()
sgd_optimizer.minimize(avg_cost)
@@ -358,7 +358,7 @@
"step: %d, ave loss: %f, speed: %f steps/s"
% (
batch_id,
avg_cost.numpy()[0],
float(avg_cost),
args.log_step / used_time,
)
)