[Hackathon 5th No.27] Add select_scatter API RFC for Paddle #757

Merged
merged 4 commits into from
Nov 29, 2023

YibinLiu666 (Contributor) commented Nov 23, 2023

Add the select_scatter API RFC
PaddlePaddle/Paddle#57262

paddle-bot bot commented Nov 23, 2023

Your PR has been submitted. Thanks for your contribution!
Please check its format and content. For this, you can refer to Template and Demo.

const DenseTensor& value,
int axis,
int index,
DenseTensor* out);
Contributor:

For this part, it would be worth looking into Paddle's existing set_value OP.

Contributor Author:

OK, I'll take another look. I have already implemented one version; if it's convenient, could you take a look at it first?

YibinLiu666 (Contributor Author), Nov 27, 2023:

I followed setitem_static in https://github.com/PaddlePaddle/Paddle/blob/f822c32b3734e971e3b71e2b78dcf16096528d91/python/paddle/base/variable_index.py#L917
and implemented a trial version, changing only _C_ops.set_value_ in dynamic mode to _C_ops.set_value. However, this op does not seem to support PIR mode: running my earlier test code against it raises the error below.
[image: error screenshot]

Contributor Author:

@zoooo0820 Does PIR mode need to be supported?

Contributor:

> @zoooo0820 Does PIR mode need to be supported?

In principle it should be supported, and PIR should already have a corresponding implementation (see paddle/fluid/pir/dialect/operator/ir/pd_api.cc). Please debug it a bit further; if it still fails, post the exact code and error message so we can look at it together.

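YibinLiu666 (Contributor Author) commented Nov 28, 2023

@zoooo0820 My current implementation is below. My understanding is that this operator could already be composed through setitem_static, so what I need to do is simplify the index parsing and call the underlying set_value C op directly. Starting from the setitem code, I modified it as follows: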

def select_scatter(src, values, axis, index):
    """
    Embeds the values of the values tensor into src at the given index of axis.

    Args:
        src (Tensor) : The Destination Tensor.
        values (Tensor) : The tensor to embed into src.
        axis (int) : the dimension to insert the slice into.
        index (int) : the index to select with.

    Returns:
        Tensor, same dtype and shape with src

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.zeros((2,3,4)).astype("float32")
            >>> values = paddle.ones((2,4)).astype("float32")
            >>> res = paddle.select_scatter(x,values,1,1)
            >>> print(res)
            Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [[[0., 0., 0., 0.],
                     [1., 1., 1., 1.],
                     [0., 0., 0., 0.]],
                    [[0., 0., 0., 0.],
                     [1., 1., 1., 1.],
                     [0., 0., 0., 0.]]])

    """
    from ..base.framework import default_main_program

    # step1. Describe the slice src[..., index:index+1, ...] along `axis`
    starts = [index]
    ends = [index + 1]
    steps = [1]
    axes = [axis]
    none_axes = []
    decrease_axes = [axis]
    inputs = {'Input': src}
    attrs = {
        'axes': axes,
        'starts': starts,
        'ends': ends,
        'steps': steps,
        'decrease_axes': decrease_axes,
        'none_axes': none_axes,
    }

    StartsTensorList = None
    EndsTensorList = None
    StepsTensorList = None

    if paddle.utils._contain_var(starts):
        StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
        inputs['StartsTensorList'] = StartsTensorList
        del attrs['starts']

    if paddle.utils._contain_var(ends):
        EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
        inputs['EndsTensorList'] = EndsTensorList
        del attrs['ends']
    if paddle.utils._contain_var(steps):
        StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
        inputs['StepsTensorList'] = StepsTensorList
        del attrs['steps']

    # step2. Parse values
    dtype = src.dtype
    attrs['dtype'] = dtype

    values = values.astype(dtype)
    inputs["ValueTensor"] = values

    if in_dynamic_or_pir_mode():
        return _C_ops.set_value_with_tensor(
            src,
            values,
            starts,
            ends,
            steps,
            axes,
            decrease_axes,
            none_axes,
        )
    else:
        helper = LayerHelper(
            'set_value', **locals()
        )
        if helper.main_program.current_block_idx != 0:
            # not in global block, we should create a global variable.
            output = helper._create_global_variable_for_type_inference(
                dtype=src.dtype
            )
        else:
            output = helper.create_variable_for_type_inference(
                dtype=src.dtype
            )
        cur_block = default_main_program().current_block()
        cur_block.append_op(
            type="set_value",
            inputs=inputs,
            outputs={'Out': output},
            attrs=attrs,
            inplace_map={"Input": "Out"},
        )

        # map var to the new output
        paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
            cur_block.program, src.desc.id(), output
        )
        return output

However, with this code, all of the PIR-mode test cases in the test file from my earlier PR (https://github.com/PaddlePaddle/Paddle/pull/59343/files#diff-dd8e117af37176658197bd7ef61a66708f4f18e5dbf273fdebc6de5eccd4c84c) fail.

[image: error screenshot]
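
The slice attributes that the snippet above hands to set_value describe the region `src[..., index:index+1, ...]` along `axis`. A minimal plain-Python sketch of that bookkeeping follows. The helper name `build_slice_attrs` and the explicit normalization of negative `axis`/`index` are illustrative assumptions, not part of the posted code.

```python
def build_slice_attrs(shape, axis, index):
    """Return set_value-style slice attributes selecting one index along axis.

    Hypothetical helper for illustration only. Negative axis/index are
    normalized first so that ends = index + 1 never wraps around to 0.
    """
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D shape")
    axis = axis + ndim if axis < 0 else axis

    size = shape[axis]
    if not -size <= index < size:
        raise ValueError(f"index {index} out of range for size {size}")
    index = index + size if index < 0 else index

    return {
        "axes": [axis],
        "starts": [index],
        "ends": [index + 1],
        "steps": [1],
        "decrease_axes": [axis],  # the selected axis is dropped from the slice
        "none_axes": [],
    }


attrs = build_slice_attrs((2, 3, 4), axis=1, index=-1)
print(attrs["starts"], attrs["ends"])  # [2] [3]
```

Without the normalization step, `index=-1` would yield `starts=[-1]` and `ends=[0]`, which under ordinary Python slicing rules denotes an empty range.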

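zoooo0820 (Contributor):

@YibinLiu666 At first glance, this does look like the operator has not been properly adapted to PIR. I will investigate and fix the related issue over the next couple of days. For now the PR can ignore the PIR errors; please prioritize the code review and finish the unit tests for the other parts. Once the PIR issue is fixed, we can check whether the errors above are resolved.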

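YibinLiu666 (Contributor Author):

> @YibinLiu666 At first glance, this does look like the operator has not been properly adapted to PIR. I will investigate and fix the related issue over the next couple of days. For now the PR can ignore the PIR errors; please prioritize the code review and finish the unit tests for the other parts. Once the PIR issue is fixed, we can check whether the errors above are resolved.

I have now revised the RFC. @zoooo0820, could you review the RFC first? In the meantime I will skip the PIR-mode tests and revise the code.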


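# 4. Comparative Analysis

PyTorch implements this API in C++, and the Python side calls the C++ interface directly, so performance is good. Although Paddle can implement the API by composing existing operators, using slice-based setitem performs poorly and cannot produce a non-inplace result. The plan is therefore to implement a dedicated C++ kernel when adding Paddle's `select_scatter`.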
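
To make the in-place versus non-inplace distinction concrete, here is a minimal reference sketch of the intended semantics on nested Python lists. It is an illustration only, not Paddle's kernel or PyTorch's implementation; the name `select_scatter_ref` is hypothetical and `axis` is assumed to be non-negative.

```python
import copy


def select_scatter_ref(src, values, axis, index):
    """Return a copy of src with values written at position index along axis."""
    out = copy.deepcopy(src)  # non-inplace: src itself is never modified

    def assign(dst, val, depth):
        if depth == axis:
            # dst is the list along `axis`; replace the selected entry.
            dst[index] = copy.deepcopy(val)
        else:
            # Walk the leading dimensions of dst and val in lockstep.
            for sub_dst, sub_val in zip(dst, val):
                assign(sub_dst, sub_val, depth + 1)

    assign(out, values, 0)
    return out


src = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]  # shape (2, 3, 4)
values = [[1.0] * 4 for _ in range(2)]                   # shape (2, 4)
res = select_scatter_ref(src, values, axis=1, index=1)
print(res[0])     # [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
print(src[0][1])  # [0.0, 0.0, 0.0, 0.0]  (input left untouched)
```

A plain setitem such as `src[0][1] = ...` would instead overwrite `src` itself, which is the in-place behavior the RFC wants to avoid.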
Contributor:

Please briefly describe the state of Paddle's set_value OP in the current-status section. This passage also needs to be revised.

Contributor Author:

Done


**axis** (int) – The dimension of the src Tensor along which to embed values.

**index** (int) – The index to select.
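
The relationship between these parameters and the shape of the values tensor can be sketched as follows. This assumes, as in PyTorch's `select_scatter`, that values must have the shape of src with the `axis` dimension removed; the helper name `expected_values_shape` is hypothetical.

```python
def expected_values_shape(src_shape, axis):
    """Shape values must have: src_shape with the `axis` dimension removed."""
    ndim = len(src_shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for {ndim}-D src")
    axis %= ndim  # map a negative axis onto 0..ndim-1
    return tuple(src_shape[:axis]) + tuple(src_shape[axis + 1:])


print(expected_values_shape((2, 3, 4), axis=1))   # (2, 4)
print(expected_values_shape((2, 3, 4), axis=-1))  # (2, 3)
```

The first call matches the docstring example in this thread, where src has shape (2, 3, 4) and values has shape (2, 4) for `axis=1`.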
Contributor:

For naming, please refer to the API design and naming conventions, paying particular attention to src / name.

Also, regarding data types, all dtypes should be supported by now; please verify this briefly.

Contributor Author:

Done

Contributor:

Please also add the name parameter; you can refer to other APIs.

Contributor Author:

Done

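zoooo0820 (Contributor):

> > @YibinLiu666 At first glance, this does look like the operator has not been properly adapted to PIR. I will investigate and fix the related issue over the next couple of days. For now the PR can ignore the PIR errors; please prioritize the code review and finish the unit tests for the other parts. Once the PIR issue is fixed, we can check whether the errors above are resolved.
>
> I have now revised the RFC. @zoooo0820, could you review the RFC first? In the meantime I will skip the PIR-mode tests and revise the code.

The issue above should now be fixed by PaddlePaddle/Paddle#59457. Please test the PIR case with that PR's changes applied.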

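YibinLiu666 (Contributor Author):

> > > @YibinLiu666 At first glance, this does look like the operator has not been properly adapted to PIR. I will investigate and fix the related issue over the next couple of days. For now the PR can ignore the PIR errors; please prioritize the code review and finish the unit tests for the other parts. Once the PIR issue is fixed, we can check whether the errors above are resolved.
> >
> > I have now revised the RFC. @zoooo0820, could you review the RFC first? In the meantime I will skip the PIR-mode tests and revise the code.
>
> The issue above should now be fixed by PaddlePaddle/Paddle#59457. Please test the PIR case with that PR's changes applied.

OK, I'll test it again.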

zoooo0820 (Contributor) left a comment:

LGTM

zoooo0820 merged commit b61aa47 into PaddlePaddle:master Nov 29, 2023
1 check passed