
triton2 support + torch >= 1.14 gmm bias support #171

Merged · 53 commits into master · Jan 10, 2023

Conversation

@puririshi98 (Contributor) commented on Jan 6, 2023

triton 2.0 issues:

_________________________________ test_triton __________________________________
    @onlyCUDA
    @onlyTriton
    def test_triton():
        x = torch.rand(100, device='cuda')
        y = torch.rand(100, device='cuda')
>       assert torch.allclose(x + y, add(x, y))
test/test_triton.py:37: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/test_triton.py:28: in add
    add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (tensor([0.6357, 0.4071, 0.4647, 0.2352, 0.8715, 0.7922, 0.7041, 0.1047, 0.1574,
        0.3703, 0.3483, 0.2265, 0.943...0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], device='cuda:0'), 100)
kwargs = {'BLOCK_SIZE': 1024}
    def launcher(*args, **kwargs):
>       return self.run(*args, grid=grid, **kwargs)
E       TypeError: add_kernel() got an unexpected keyword argument 'BLOCK_SIZE'
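This failure reflects Triton 2.0 dropping the old `**meta`-style kernels: launch meta-parameters such as `BLOCK_SIZE` must now appear explicitly in the kernel's own signature (annotated `tl.constexpr` in a real `@triton.jit` kernel), so a keyword the signature does not declare raises `TypeError` at launch. A minimal GPU-free stand-in (function names hypothetical, plain Python in place of `triton.jit`) reproduces the mismatch:

```python
# Triton 1.x style: meta-parameters were swallowed by **meta and never
# appeared in the signature. In Triton 2.0 there is no **meta, so the
# BLOCK_SIZE keyword has nowhere to go.
def add_kernel_old(x_ptr, y_ptr, out_ptr, n_elements):
    pass

# Triton 2.0 style: every meta-parameter is declared explicitly
# (in a real kernel: `BLOCK_SIZE: tl.constexpr`).
def add_kernel_new(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE):
    return BLOCK_SIZE

try:
    add_kernel_old(None, None, None, 100, BLOCK_SIZE=1024)
except TypeError as err:
    print(type(err).__name__)  # TypeError, matching the traceback above

print(add_kernel_new(None, None, None, 100, BLOCK_SIZE=1024))  # 1024
```

The fix in the kernel source is therefore mechanical: move each meta-parameter into the signature with a `tl.constexpr` annotation and keep the call site unchanged.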


__________________________ test_fused_scatter_reduce ___________________________
    @onlyCUDA
    @onlyTriton
    def test_fused_scatter_reduce():
        x = torch.randn(5, 4, device='cuda')
        index = torch.tensor([0, 1, 0, 1, 0], device='cuda')
    
>       out = fused_scatter_reduce(x, index, dim_size=2,
                                   reduce_list=['sum', 'mean'])
test/ops/test_scatter_reduce.py:13: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.8/dist-packages/pyg_lib-0.1.0-py3.8-linux-x86_64.egg/pyg_lib/ops/scatter_reduce.py:133: in fused_scatter_reduce
    fused_scatter_reduce_kernel[grid](
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (tensor([[-0.1390, -0.5071,  0.1290, -0.1594],
        [-1.2358, -0.3644, -0.5571, -0.9715],
        [-0.3111, -0.7264...a:0'), tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0'), 4, 2, 20)
kwargs = {'BLOCK_SIZE': 256, 'REDUCE_LIST': ['sum', 'mean', 'none', 'none']}
    def launcher(*args, **kwargs):
>       return self.run(*args, grid=grid, **kwargs)
E       TypeError: fused_scatter_reduce_kernel() got an unexpected keyword argument 'REDUCE_LIST'
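The second failure is the same signature mismatch, compounded by `REDUCE_LIST` being a Python list: Triton `tl.constexpr` meta-parameters are expected to be hashable scalars, not lists. One possible workaround (an assumption for illustration, not necessarily what this PR does) is to unroll the list into one scalar argument per slot. A plain-Python stand-in with hypothetical names:

```python
# Each reduce op becomes its own scalar meta-parameter instead of one
# list-valued kwarg, so every argument is an individually declared scalar.
def fused_scatter_reduce_kernel(x_ptr, out_ptr, BLOCK_SIZE,
                                REDUCE0, REDUCE1, REDUCE2, REDUCE3):
    return (REDUCE0, REDUCE1, REDUCE2, REDUCE3)

ops = fused_scatter_reduce_kernel(None, None, BLOCK_SIZE=256,
                                  REDUCE0='sum', REDUCE1='mean',
                                  REDUCE2='none', REDUCE3='none')
print(ops)  # ('sum', 'mean', 'none', 'none')
```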

torch >= 1.14 issue:

            if biases is not None:
>               outs += torch.nested.as_nested_tensor(biases)
E               RuntimeError: add_ does not support broadcasting when given a NestedTensor
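Here the in-place `add_` refuses to broadcast a nested bias tensor across `outs`. One workaround (an assumption, not confirmed as this PR's actual fix) is to add each group's bias to its component explicitly, so no broadcast across the nested dimension is needed. A torch-free plain-list stand-in of that per-group pattern:

```python
# outs: one 2-D block of rows per group (ragged row counts, like a
# NestedTensor); biases: one bias vector per group. Add the group's bias
# row-wise inside each group rather than broadcasting across groups.
def add_bias_per_group(outs, biases):
    return [[[v + b for v, b in zip(row, bias)] for row in group]
            for group, bias in zip(outs, biases)]

outs = [[[1.0, 2.0], [3.0, 4.0]],   # group 0: 2 rows
        [[5.0, 6.0]]]               # group 1: 1 row
biases = [[0.5, 0.5], [1.0, -1.0]]
print(add_bias_per_group(outs, biases))
# [[[1.5, 2.5], [3.5, 4.5]], [[6.0, 5.0]]]
```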

@puririshi98 changed the title from "triton2 support" to "triton2 support + torch >= 1.14 gmm bias support" on Jan 6, 2023
@codecov-commenter commented on Jan 6, 2023

Codecov Report

Merging #171 (b0f2a24) into master (ef340dd) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##           master     #171   +/-   ##
=======================================
  Coverage   93.14%   93.14%           
=======================================
  Files          23       23           
  Lines         729      729           
=======================================
  Hits          679      679           
  Misses         50       50           


@puririshi98 (Contributor, Author) commented:

finally got it working, triton is a headache

7 passed

@puririshi98 puririshi98 enabled auto-merge (squash) January 6, 2023 23:45
@puririshi98 puririshi98 merged commit a045a47 into master Jan 10, 2023
@puririshi98 puririshi98 deleted the triton2_support branch January 10, 2023 13:10