-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mthreads/master #298
base: master
Are you sure you want to change the base?
Mthreads/master #298
Conversation
Signed-off-by: Jian Li <jian.li@mthreads.com>
config: {BLOCK_M: 8, num_warps: 8} will cause the number of registers within a single thread to be exceeded when the tensor shape is 4096 * 2304, so reduce BLOCK_M to 4 to supprot cumsum.
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
- Torch_musa does not support fp64 input type, so CPU is used as a reference
- Does not support test_accuracy_groupnorm - Some use cases have accuracy issues in test_embedding
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
Signed-off-by: Jian Li <jian.li@mthreads.com>
Modify the function parameter type declaration so that it can run in python 3.8 --------- Co-authored-by: zhengyang <zhengyang@baai.ac.cn>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Add _weight_norm op, while the original _weight_norm op changed to _weight_norm_interface op.
* add Ops & UT & Bench * add full zero ones Ops & UT & Bench * split normal op * [Operator] init slice&select scatter * code format * PR comment * split test_special_ops * add K-S test * split special perf * Exponential added. (#138) * exponential added. * Added K-S tests to exponential_, fp64 corrected. * aligned with aten prototype * Exponential_ uses uint64 offsets in Triton kernel. * Update pyproject config for new test dependencies. * resolve conflict * Use int64 indexing when needed & fix argmax (#146) 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max; 2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size * test for op * test for op * Making libentry thread safe (#136) * libentry now is lock protected. * Add multithreading tests for libentry. * polish code. * add argparse * fix desc * fix num * Update test_specific_ops.py * split UT files * fix * fix * [Operator] Optimize CrossEntropyLoss (#131) reimplement cross_entropy_loss forward and backward support; indices/probabilities/weight/reduction/ignore_index/label_smoothing; perform better than torch eager on large scale tensors * Exponential added. (#138) * exponential added. * Added K-S tests to exponential_, fp64 corrected. * aligned with aten prototype * Exponential_ uses uint64 offsets in Triton kernel. * Update pyproject config for new test dependencies. * Use int64 indexing when needed & fix argmax (#146) 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max; 2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size * Making libentry thread safe (#136) * libentry now is lock protected. * Add multithreading tests for libentry. * polish code. * [Test] Test for op (#151) * [chore] solve slice&select scatter's test cases * [fix] fix slice&select scatter's test cases * [chore] remove out-of-range indices in select_scatter's test cases * [chore] simplify slice_scatter's test cases * [fix] Added range that is deleted by mistake * Merge branch 'master' into slice&select_scatter * [chore] reformat * [fix] typo * [chore] Considering perf, pause the replacement of some aTen operators * slice_scatter * select_scatter * index_select * [fix] Add libentry in op.cumsum * [fix] Del slice&select scatter's perf tests * [Chore] Add pytest mark for slice&select scatter's test * [Fix] Correct slice_scatter test * [Fix] Replace CPU Tensor --------- Co-authored-by: Bowen12992 <zhangbluestars@gmail.com> Co-authored-by: Tongxin Bai <waffle.bai@gmail.com> Co-authored-by: Clement Chan <iclementine@outlook.com> Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com> Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
* benchmark fix * add seven new testing parameters * move shapes info to yaml file * Added the BenchmarkMetrics & BenchmarkResult abstraction
* [Bugfix] Handle negative input dimensions in 'cat' operator Co-authored-by: 2niuhe<tang.kang1@zte.com.cn>
* Add Script to Calculate Summary Information for Benchmark Results
* specializing slice_scatter. WIP. * polish and refine 2d_inner cases. * fix slice_scatter error on 1d inputs. * test slice_scatter fallback
* Relocate select and slice benchmarks to test_select_and_slice_perf.py * sort keys for summary result * clean cuda cache after benchmark * fix repeat_interleave * modify format for summary info
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
review done.
@@ -99,7 +99,7 @@ def weight_norm_input_fn(shape, dtype, device): | |||
weight_norm_interface_input_fn, | |||
), | |||
("weight_norm", torch._weight_norm, weight_norm_input_fn), | |||
("vector_norm", torch.linalg.vector_norm, unary_input_fn), | |||
# ("vector_norm", torch.linalg.vector_norm, unary_input_fn), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not support vector_norm and var_mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
poor performance recently, maybe requires more optimization work in compiler, and we prefer to regard it as not supported yet.
# Complex Operations | ||
("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn), | ||
("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn), | ||
# ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same reason as vector_norm
src/flag_gems/ops/weightnorm.py
Outdated
@@ -236,7 +236,7 @@ def norm_kernel( | |||
v_shape0, | |||
v_shape1, | |||
v_shape2, | |||
eps, | |||
eps: tl.constexpr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recommend not to hint eps as tl.constexpr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not our change but a code sync issue, previous version in flaggems marked "eps" as tl.constexpr, we will sync the change, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
tests/test_norm_ops.py
Outdated
@@ -74,6 +74,7 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype): | |||
gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW) | |||
|
|||
|
|||
@pytest.mark.skip("triton_musa unsupport") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please figure out the reason why not support LayerNorm, cause group norm with similar algorithm is supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it may be a compiler bug and under locating, we will support this op later after the bugs fixed.
@@ -137,6 +138,7 @@ def test_accuracy_cross_entropy_loss_indices( | |||
gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim]) | |||
|
|||
|
|||
@pytest.mark.skip("random error") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the absolute difference between result and reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
already fixed in our latest triton but in this time test, we still provided elder version of triton, just make it skip in this time test, is it okay for you?
value_tensor = torch.tensor(value, device="cuda", dtype=dtype) | ||
ref_out_tensor = torch.fill(ref_x, value_tensor) | ||
value_tensor = torch.tensor(value, device="musa", dtype=dtype) | ||
ref_value_tensor = to_reference(value_tensor, False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this could be merged to master as a bug fix.
tests/accuracy_utils.py
Outdated
@@ -127,6 +127,15 @@ def to_reference(inp, upcast=False): | |||
return ref_inp | |||
|
|||
|
|||
def to_reference_gpu(inp, upcast=False, device='musa'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it used in test code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
confirmed, should be removed, we will fix it.
@@ -1075,7 +1075,8 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None) | |||
|
|||
assert isinstance(scalar_fn, JITFunction) | |||
self._scalar_fn = scalar_fn | |||
self._scalar_fn_cache_key = scalar_fn.cache_key | |||
# FIXME: cache_key is too long and make open file failed. | |||
self._scalar_fn_cache_key = scalar_fn.cache_key[:33] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if slicing will bring risk. theoretically, there might exist a small probability that two keys have the same prefix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it risks but very affects very little, the collision probability is
src/flag_gems/testing/__init__.py
Outdated
@@ -9,6 +9,11 @@ | |||
torch.bfloat16: 0.016, | |||
} | |||
|
|||
RESOLUTION_DROPOUT = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it used in the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will remove it along with "to_reference_gpu"
src/flag_gems/ops/rsub.py
Outdated
return y - x * alpha | ||
|
||
|
||
def rsub(A, B, *, alpha=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually aten::rsub calls the sub kernel. now that you reimplemented it, just register it into the library;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"rsub.py" will be removed.
Signed-off-by: machuanjiang <chuanjiang.ma@mthreads.com>
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. just fix epsilon and we could start testing.
src/flag_gems/ops/weightnorm.py
Outdated
@@ -274,7 +274,7 @@ def norm_bwd_kernel( | |||
v_shape0, | |||
v_shape1, | |||
v_shape2, | |||
eps, | |||
eps: tl.constexpr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eps hint
Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
1. one test in special_op test change the device type from cuda to musa Signed-off-by: chuanjiang.ma <chuanjiang.ma@mthreads.com>
PR Category
Operator | OP Test | Benchmark
Type of Change
Bug Fix | Performance Optimization | Refactor
Description
mthreads musa backend compatible modification
Issue
N/A
Progress
Performance
N/A