mode_some_npu_bugs_2
Baibaifan committed May 12, 2021
1 parent cb4523d commit 19cf4ee
Showing 12 changed files with 144 additions and 29 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -1228,6 +1228,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
// will be executed and a warning will be given at the same time.
if (SupportGPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
} else if (SupportNPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
} else {
expected_kernel_key.place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
@@ -1299,6 +1301,8 @@ void OperatorWithKernel::TransferInplaceVarsBack(
auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto original_dims = original_tensor->dims();
original_tensor->ShareDataWith(*transformed_tensor);
// In order to solve the problem that the output dims of the NPU reshape
// operator are not changed when it runs inplace.
if (type_ != "reshape2" && type_ != "reshape2_grad") {
original_tensor->Resize(original_dims);
}
8 changes: 8 additions & 0 deletions paddle/fluid/framework/operator.h
@@ -154,6 +154,7 @@ class OperatorBase {
std::string DebugString() const { return DebugStringEx(nullptr); }

virtual bool SupportGPU() const { return false; }
virtual bool SupportNPU() const { return false; }

const std::string& Type() const { return type_; }

@@ -490,6 +491,13 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
bool SupportNPU() const override {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_npu_place(kern_pair.first.place_);
});
}
bool SupportsMKLDNN(proto::VarType::Type data_type) const;

bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
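Taken together with the ChooseKernel() change above, an operator that registers an NPU kernel now keeps the NPU device context instead of silently falling back to CPU. A rough sketch of the intended user-visible behavior (assumptions: a PADDLE_WITH_ASCEND_CL build, one visible Ascend device, and that paddle.set_device accepts an 'npu:0' string on this branch):

import paddle

paddle.set_device('npu:0')  # assumption: NPU device strings are accepted here
x = paddle.ones([2, 3], dtype='float32')
y = paddle.ones([2, 3], dtype='float32')
z = paddle.add(x, y)  # elementwise_add has an NPU kernel, so its place stays on the NPU
print(z.place)  # ops without an NPU kernel fall back to CPUPlace with a one-time warning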
10 changes: 5 additions & 5 deletions paddle/fluid/operators/collective/recv_v2_op_npu.cc
@@ -27,11 +27,11 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto x = ctx.Output<framework::LoDTensor>("Out");
x->mutable_data<T>(x->dims(), ctx.GetPlace());
void* ptr = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
int numel = x->numel();
HcclDataType dtype = platform::ToHCCLDataType(x->type());
auto out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(out->dims(), ctx.GetPlace());
void* ptr = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));
int numel = out->numel();
HcclDataType dtype = platform::ToHCCLDataType(out->type());

int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
5 changes: 5 additions & 0 deletions paddle/fluid/operators/lookup_table_v2_op_npu.cc
@@ -29,6 +29,11 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
auto *table_t = ctx.Input<framework::LoDTensor>("W");

// It seems CANN 20.1 accepts int64, but CANN 20.2+ does not.
PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32,
platform::errors::Unimplemented(
"The index of LookupTableV2 should be int32."));

auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(
table_var->IsType<framework::LoDTensor>(), true,
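Because of this check, code that feeds int64 ids to an embedding lookup on an NPU build now has to cast them first. A minimal sketch (hedged: int64 ids keep working on CPU/GPU builds, and exact dtype support of the CPU kernel may differ):

import paddle

ids = paddle.to_tensor([[1, 3], [0, 2]], dtype='int64')
table = paddle.randn([10, 4], dtype='float32')

# On NPU builds of this branch, lookup_table_v2 enforces int32 indices,
# so cast explicitly before the lookup.
ids32 = paddle.cast(ids, 'int32')
out = paddle.nn.functional.embedding(ids32, table)
print(out.shape)  # [2, 2, 4]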
1 change: 0 additions & 1 deletion paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -36,7 +36,6 @@ class AdamNPUKernel : public framework::OpKernel<T> {
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
99 changes: 93 additions & 6 deletions python/paddle/distributed/collective.py
@@ -25,6 +25,7 @@
from ..fluid.data_feeder import check_dtype
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
from ..fluid.dygraph import layers
from ..fluid.dygraph.parallel import prepare_context
import paddle
from .fleet import fleet
@@ -875,6 +876,84 @@ def _mp_allreduce(tensor,
raise NotImplementedError("No support _mp_allreduce in dygraph mode.")


class _Linear(layers.Layer):
"""
Linear
"""

def __init__(self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None):
super(_Linear, self).__init__()
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.weight = self.create_parameter(
shape=[in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.bias = self.create_parameter(
shape=[out_features],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.name = name

def forward(self, input):
out = _linear(
x=input, weight=self.weight, bias=self.bias, name=self.name)
return out

def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'in_features={}, out_features={}, dtype={}{}'.format(
self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)


def _linear(x, weight, bias=None, name=None):
"""
Function Linear
"""
if in_dygraph_mode():
pre_bias = _varbase_creator(dtype=x.dtype)
core.ops.matmul(x, weight, pre_bias, 'transpose_X', False,
'transpose_Y', False, "alpha", 1)
return dygraph_utils._append_bias_in_dygraph(
pre_bias, bias, axis=len(x.shape) - 1)
else:
helper = LayerHelper('linear', **locals())
dtype = x.dtype

check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'linear')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')

inputs = {'X': [x], 'Y': [weight]}
attrs = {
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
}
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='matmul_v2', inputs=inputs, outputs={'Out': tmp}, attrs=attrs)
if bias is not None:
res = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_add',
inputs={'X': [tmp],
'Y': [bias]},
outputs={'Out': [res]},
attrs={'axis': len(x.shape) - 1})
else:
res = tmp
return res
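For reference, the composition _linear builds (matmul with transpose_X=False, transpose_Y=False, alpha=1, followed by a bias add broadcast over the last axis) is an ordinary affine transform. A minimal sketch of the equivalent computation (CPU is fine for the math; only the lowered ops differ on NPU builds):

import paddle

x = paddle.randn([4, 8], dtype='float32')
w = paddle.randn([8, 16], dtype='float32')
b = paddle.zeros([16], dtype='float32')

# Same result _linear produces: x @ w plus a broadcasted bias.
y = paddle.matmul(x, w) + b
print(y.shape)  # [4, 16]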


def _parallel_linear(x,
num_rows,
num_cols,
@@ -900,12 +979,20 @@ def _parallel_linear(x,
else:
x = _c_identity(x, group=group)

linear = paddle.nn.Linear(
num_rows,
num_cols,
weight_attr=param_attr,
bias_attr=bias_attr,
name=name)
if core.is_compiled_with_npu():
linear = _Linear(
num_rows,
num_cols,
weight_attr=param_attr,
bias_attr=bias_attr,
name=name)
else:
linear = paddle.nn.Linear(
num_rows,
num_cols,
weight_attr=param_attr,
bias_attr=bias_attr,
name=name)

linear_out = linear(x)
startup_block = paddle.static.default_startup_program().global_block()
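A small usage sketch of the branch above (hedged: it assumes this branch's private _Linear helper is importable from paddle.distributed.collective; the constructor signature matches paddle.nn.Linear either way, so callers do not change):

import paddle
from paddle.fluid import core
from paddle.distributed.collective import _Linear  # private helper added in this commit

layer_cls = _Linear if core.is_compiled_with_npu() else paddle.nn.Linear
fc = layer_cls(8, 16)
out = fc(paddle.randn([4, 8], dtype='float32'))
print(out.shape)  # [4, 16]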
@@ -402,13 +402,18 @@ def get_grad_device(grad_name, shard):
return shard.global_param2device[base_name]


def get_first_check_finite_and_unscale_op_idx(block):
def get_first_check_finite_and_unscale_op_idx(block, raise_error=True):

for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx

raise ValueError("check_finite_and_unscale does not exist in block")
if raise_error:
raise ValueError(
"amp is turned on but check_finite_and_unscale op does not exist in main block"
)

return -1


def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
@@ -298,7 +298,7 @@ def minimize_impl(self,
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
main_block, raise_error=self.user_defined_strategy.amp)
insert_reduce_ops(
main_block,
first_optimize_op_index,
@@ -309,14 +309,15 @@ def minimize_impl(self,
use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
main_block, raise_error=self.user_defined_strategy.amp)
if first_optimize_op_index >= 0:
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)

# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
4 changes: 3 additions & 1 deletion python/paddle/fluid/dataset.py
@@ -251,9 +251,11 @@ def set_use_var(self, var_list):
slot_var.type = "float"
elif var.dtype == core.VarDesc.VarType.INT64:
slot_var.type = "uint64"
elif var.dtype == core.VarDesc.VarType.INT32:
slot_var.type = "uint32"
else:
raise ValueError(
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
"Currently, fluid.dataset only supports dtype=float32, dtype=int32 and dtype=int64"
)

def set_hdfs_config(self, fs_name, fs_ugi):
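With this addition, an int32 feed variable can be registered on a dataset and is mapped to the "uint32" slot type internally. A short hedged sketch using the fluid dataset API assumed available on this branch:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
label = fluid.data(name='label', shape=[None, 1], dtype='int32')  # now accepted
dataset = fluid.DatasetFactory().create_dataset('InMemoryDataset')
dataset.set_use_var([label])  # previously raised ValueError for int32 vars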
4 changes: 2 additions & 2 deletions python/paddle/fluid/layers/nn.py
@@ -14772,7 +14772,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
the size of the last shard will be less than the calculated `shard_size`

Args:
input (Tensor): Input indices with data type int64. It's last dimension must be 1.
input (Tensor): Input indices with data type int64 or int32. Its last dimension must be 1.
index_num (int): An integer defining the range of the index.
nshards (int): The number of shards.
shard_id (int): The index of the current shard.
Expand All @@ -14793,7 +14793,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
print(shard_label)
# [[-1], [1]]
"""
check_variable_and_dtype(input, 'input', ['int64'], 'shard_index')
check_variable_and_dtype(input, 'input', ['int64', 'int32'], 'shard_index')
op_type = 'shard_index'
helper = LayerHelper(op_type, **locals())
if shard_id < 0 or shard_id >= nshards:
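A short sketch mirroring the docstring example, now with int32 indices (hedged: assumes dygraph on this branch; with index_num=20 and nshards=2 the first shard covers ids 0-9, so 16 maps to ignore_value):

import paddle
import paddle.fluid as fluid

label = paddle.to_tensor([[16], [1]], dtype='int32')
shard_label = fluid.layers.shard_index(label, index_num=20, nshards=2, shard_id=0)
print(shard_label.numpy())  # [[-1], [1]]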
8 changes: 6 additions & 2 deletions python/paddle/fluid/optimizer.py
@@ -4200,13 +4200,16 @@ def _add_op_device_attr_for_op(self, op, idx, block):
op.type == 'elementwise_div'):
device = "gpu:all"
op._set_attr(self._op_device_key, device)
elif op.type == "alloc_float_status":
op._set_attr(self._op_device_key, "gpu:all")
else:
other_known_ops = [
'update_loss_scaling',
'reduce_any',
'concat',
'sum',
'check_finite_and_unscale',
'alloc_float_status',
]
assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \
@@ -4272,8 +4275,9 @@ def _check_validation(self, block):
"{} has not been set.".format(op.type))
if device == "gpu:all": continue
dev_type = device.split(':')[0]
assert dev_type == "gpu", ("Now only gpu devices are supported "
"for pipeline parallelism.")
assert dev_type == "gpu" or dev_type == 'npu', (
"Now only gpu and npu devices are supported "
"for pipeline parallelism.")
if not device in device_list:
device_list.append(device)
return device_list
@@ -41,7 +41,7 @@ def setUp(self):
vocab = 10
dim = 20
w = np.ones([vocab, dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
out = np.ones([bsz, seqlen, dim]).astype(self.dtype)

self.inputs = {
Expand Down
