Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mthreads/master #298

Open
wants to merge 75 commits into
base: master
Choose a base branch
from
Open

Mthreads/master #298

wants to merge 75 commits into from

Conversation

machuanjiang
Copy link
Collaborator

PR Category

Operator | OP Test | Benchmark

Type of Change

Bug Fix | Performance Optimization | Refactor

Description

mthreads musa backend compatible modification

Issue

N/A

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

N/A

yuzhe-wu and others added 30 commits October 17, 2024 17:30
Signed-off-by: Jian Li <jian.li@mthreads.com>
config: {BLOCK_M: 8, num_warps: 8} will cause the number of registers
within a single thread to be exceeded when the tensor shape is 4096 * 2304,
so reduce BLOCK_M to 4 to supprot cumsum.
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
- Torch_musa does not support fp64 input type, so CPU is used as a reference
- Does not support test_accuracy_groupnorm

- Some use cases have accuracy issues in test_embedding
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
ZaccurLi and others added 23 commits October 30, 2024 19:38
Modify the function parameter type declaration so that it can run in python 3.8

---------

Co-authored-by: zhengyang <zhengyang@baai.ac.cn>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Add _weight_norm op, while the original _weight_norm op changed to _weight_norm_interface op.
* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* [Operator] init slice&select scatter

* code format

* PR comment

* split test_special_ops

* add K-S test

* split special perf

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max;
2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size

* test for op

* test for op

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* [Operator] Optimize CrossEntropyLoss (#131)

reimplement cross_entropy_loss forward and backward
support; indices/probabilities/weight/reduction/ignore_index/label_smoothing; perform better than torch eager on large scale tensors

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* Use int64 indexing when needed & fix argmax (#146)

 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max;
2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* [Test] Test for op (#151)

* [chore] solve slice&select scatter's test cases

* [fix] fix slice&select scatter's test cases

* [chore] remove out-of-range indices in select_scatter's test cases

* [chore] simplify slice_scatter's test cases

* [fix] Added range that is deleted by mistake

* Merge branch 'master' into slice&select_scatter

* [chore] reformat

* [fix] typo

* [chore] Considering perf, pause the replacement of some aTen operators
* slice_scatter
* select_scatter
* index_select

* [fix] Add libentry in op.cumsum

* [fix] Del slice&select scatter's perf tests

* [Chore] Add pytest mark for slice&select scatter's test

* [Fix] Correct slice_scatter test

* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
* benchmark fix

*  add seven new testing parameters

* move shapes info to yaml file

* Added the BenchmarkMetrics & BenchmarkResult  abstraction
* [Bugfix] Handle negative input dimensions in 'cat' operator

Co-authored-by: 2niuhe<tang.kang1@zte.com.cn>
* Add Script to Calculate Summary Information for Benchmark Results
* specializing slice_scatter. WIP.

* polish and refine 2d_inner cases.

* fix slice_scatter error on 1d inputs.

* test slice_scatter fallback
* Relocate select and slice benchmarks to test_select_and_slice_perf.py

* sort keys for summary result

* clean cuda cache after benchmark

* fix repeat_interleave

* modify format for summary info
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review done.

@@ -99,7 +99,7 @@ def weight_norm_input_fn(shape, dtype, device):
weight_norm_interface_input_fn,
),
("weight_norm", torch._weight_norm, weight_norm_input_fn),
("vector_norm", torch.linalg.vector_norm, unary_input_fn),
# ("vector_norm", torch.linalg.vector_norm, unary_input_fn),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not support vector_norm and var_mean?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poor performance recently, maybe requires more optimization work in compiler, and we prefer to regard it as not supported yet.

# Complex Operations
("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn),
("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn),
# ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same reason as vector_norm

@@ -236,7 +236,7 @@ def norm_kernel(
v_shape0,
v_shape1,
v_shape2,
eps,
eps: tl.constexpr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommend not to hint eps as tl.constexpr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not our change but a code sync issue, previous version in flaggems marked "eps" as tl.constexpr, we will sync the change, thanks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@@ -74,6 +74,7 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype):
gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW)


@pytest.mark.skip("triton_musa unsupport")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please figure out the reason why not support LayerNorm, cause group norm with similar algorithm is supported.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be a compiler bug and under locating, we will support this op later after the bugs fixed.

@@ -137,6 +138,7 @@ def test_accuracy_cross_entropy_loss_indices(
gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])


@pytest.mark.skip("random error")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the absolute difference between result and reference?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already fixed in our latest triton but in this time test, we still provided elder version of triton, just make it skip in this time test, is it okay for you?

value_tensor = torch.tensor(value, device="cuda", dtype=dtype)
ref_out_tensor = torch.fill(ref_x, value_tensor)
value_tensor = torch.tensor(value, device="musa", dtype=dtype)
ref_value_tensor = to_reference(value_tensor, False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could be merged to master as a bug fix.

@@ -127,6 +127,15 @@ def to_reference(inp, upcast=False):
return ref_inp


def to_reference_gpu(inp, upcast=False, device='musa'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it used in test code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirmed, should be removed, we will fix it.

@@ -1075,7 +1075,8 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None)

assert isinstance(scalar_fn, JITFunction)
self._scalar_fn = scalar_fn
self._scalar_fn_cache_key = scalar_fn.cache_key
# FIXME: cache_key is too long and make open file failed.
self._scalar_fn_cache_key = scalar_fn.cache_key[:33]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if slicing will bring risk. theoretically, there might exist a small probability that two keys have the same prefix.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it risks but very affects very little, the collision probability is $P≈2^{-128}$, if do insist that should be modified, we will fix it. Anyway, we will deep in again and check the root reason of the failure, sorry about that

@@ -9,6 +9,11 @@
torch.bfloat16: 0.016,
}

RESOLUTION_DROPOUT = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it used in the test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will remove it along with "to_reference_gpu"

return y - x * alpha


def rsub(A, B, *, alpha=1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually aten::rsub calls the sub kernel. now that you reimplemented it, just register it into the library;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"rsub.py" will be removed.

StrongSpoon and others added 3 commits November 21, 2024 12:11
Signed-off-by: machuanjiang <chuanjiang.ma@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. just fix epsilon and we could start testing.

@@ -274,7 +274,7 @@ def norm_bwd_kernel(
v_shape0,
v_shape1,
v_shape2,
eps,
eps: tl.constexpr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eps hint

Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
1. one test in special_op test change the device type from cuda to musa

Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.