Skip to content

Commit

Permalink
Merge pull request #1472 from TeslaZhao/develop
Browse files Browse the repository at this point in the history
python pipeline support lod input
  • Loading branch information
bjjwwang authored Nov 8, 2021
2 parents 0376fd4 + 433ff7f commit 54a78cc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
16 changes: 13 additions & 3 deletions python/pipeline/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,20 +1589,27 @@ def proto_tensor_2_numpy(self, tensor):
tensor: one tensor in request.tensors.
Returns:
np.ndnumpy
np_data: np.ndnumpy, the tensor data is converted to numpy.
lod_info: np.ndnumpy, lod info of the tensor data, None default.
"""
if tensor is None or tensor.elem_type is None or tensor.name is None:
_LOGGER.error("input params of tensor is wrong. tensor: {}".format(
tensor))
return None

# Set dim shape
dims = []
if tensor.shape is None:
dims.append(1)
else:
for one_dim in tensor.shape:
dims.append(one_dim)

# Set up 2-d lod tensor
np_lod = None
if len(tensor.lod) > 0:
np_lod = np.array(tensor.lod).astype(int32).reshape(2, -1)

np_data = None
_LOGGER.info("proto_to_numpy, name:{}, type:{}, dims:{}".format(
tensor.name, tensor.elem_type, dims))
Expand Down Expand Up @@ -1648,7 +1655,7 @@ def proto_tensor_2_numpy(self, tensor):
"Sorry, the type {} of tensor {} is not supported.".format(
tensor.elem_type, tensor.name))

return np_data
return np_data, np_lod

def unpack_request_package(self, request):
"""
Expand Down Expand Up @@ -1705,7 +1712,10 @@ def unpack_request_package(self, request):

dict_data[name] = new_string
else:
dict_data[name] = self.proto_tensor_2_numpy(one_tensor)
np_data, np_lod = self.proto_tensor_2_numpy(one_tensor)
dict_data[name] = np_data
if np_lod is not None:
dict_data[name + ".lod"] = np_lod

_LOGGER.info("RequestOp unpack one request. log_id:{}, clientip:{} \
name:{}, method:{}, time:{}"
Expand Down
13 changes: 13 additions & 0 deletions python/pipeline/pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def _pack_request_package(self, feed_dict, pack_tensor_format,
else:
# pack tensor format
for key, value in feed_dict.items():

# skipping the lod feed_var.
# The declare of lod feed_var must be hebind the feed_var.
if ".lod" in key:
continue

one_tensor = req.tensors.add()
one_tensor.name = key

Expand All @@ -114,6 +120,13 @@ def _pack_request_package(self, feed_dict, pack_tensor_format,
for one_dim in value.shape:
one_tensor.shape.append(one_dim)

# set lod info, must be list type.
lod_key = key + ".lod"
if lod_key in feed_dict:
lod_list = feed_dict.get(lod_key)
if lod_list is not None:
one_tensor.lod.extend(lod_list)

# packed into bytes
if use_tensor_bytes is True:
np_bytes = BytesIO()
Expand Down

0 comments on commit 54a78cc

Please sign in to comment.