Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Reduce] optimize and unify reduce operator to a single place #311

Merged
merged 17 commits into from
Jul 13, 2023

Conversation

xinli-git
Copy link
Collaborator

This changes optimizes the reduce operator and unifies both reduce and reduce_f16 into a single class / implementation.

On Llama, 255 tokens (RTX3090)
Before: 6.173s
After: 5.295s

Contribution from reduce:
before: 732ms (13.6% of total)
after: 126ms (2.7% of total)

Copy link
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xinli-git ! You are getting more professional in using hidet script :)

python/hidet/graph/ops/reduce/resolve.py Outdated Show resolved Hide resolved
python/hidet/graph/ops/reduce/reduce.py Outdated Show resolved Hide resolved
@xinli-git xinli-git merged commit 571800a into hidet-org:main Jul 13, 2023
@xinli-git xinli-git deleted the reduce_unify branch August 21, 2023 02:49
vadiklyutiy pushed a commit that referenced this pull request Jul 22, 2024
…. Fixed (#311)

when device is None, `device_from_torch` returns 'cpu' by default. As a
result it copies some of the tensors to 'cpu' device. The correct
behaviours is to not to move the tensor at all.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail,.com>
vadiklyutiy pushed a commit that referenced this pull request Jul 23, 2024
…. Fixed (#311)

when device is None, `device_from_torch` returns 'cpu' by default. As a
result it copies some of the tensors to 'cpu' device. The correct
behaviours is to not to move the tensor at all.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail,.com>
vadiklyutiy pushed a commit that referenced this pull request Dec 26, 2024
…. Fixed (#311)

when device is None, `device_from_torch` returns 'cpu' by default. As a
result it copies some of the tensors to 'cpu' device. The correct
behaviours is to not to move the tensor at all.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail,.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants