CUDA reduction: allow outputs to have different strides #42649

Closed · 13 commits
Conversation

@zasdfgbnm (Collaborator) commented on Aug 6, 2020

Fixes #42364

Benchmark (in each pair of timings below, the first is for a contiguous input and the second for its transpose, for N = 2^7 through 2^14):
https://github.com/zasdfgbnm/things/blob/master/2020Q3/min-benchmark.ipynb

import torch

print(torch.__version__)
print()

# warm up CUDA before timing
for i in range(100):
    torch.randn(1000, device='cuda')

for e in range(7, 15):
    N = 2 ** e
    input_ = torch.randn(N, N, device='cuda')
    torch.cuda.synchronize()
    %timeit input_.min(dim=0); torch.cuda.synchronize()
    input_ = torch.randn(N, N, device='cuda').t()
    torch.cuda.synchronize()
    %timeit input_.min(dim=0); torch.cuda.synchronize()
    print()
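
For reference, outside IPython the same measurement could be done with CUDA events instead of %timeit (a rough, hypothetical equivalent; its statistics will not match %timeit exactly):

import torch

def time_min_dim0(input_, iters=1000):
    # time input_.min(dim=0) with CUDA events, returning milliseconds per call
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        input_.min(dim=0)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(1024, 1024, device='cuda')
print(time_min_dim0(x), time_min_dim0(x.t()))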

Before

1.7.0a0+5d7c3f9

21.7 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.6 µs ± 773 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

22.5 µs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.2 µs ± 250 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

26.4 µs ± 67 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.9 µs ± 316 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

33 µs ± 474 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
21.1 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

84.2 µs ± 691 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.3 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

181 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
145 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

542 µs ± 753 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
528 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

2.04 ms ± 9.74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.01 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

After

1.7.0a0+9911817

21.4 µs ± 695 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.6 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

22.4 µs ± 153 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.5 µs ± 58.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

26.6 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
20.9 µs ± 675 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

35.4 µs ± 560 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
21.7 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

86.5 µs ± 1.99 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
52.2 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

195 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
153 µs ± 4.46 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

550 µs ± 7.72 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
527 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

2.05 ms ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2 ms ± 4.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

@dr-ci bot commented on Aug 6, 2020

💊 CI failures summary and remediations

As of commit 95e9a8a (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda11.0_build (1/1)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Error generating file
Retry attempt 3: 
C:/Users/circleci/project/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu(236): error: identifier "cusparseScsrmm2" is undefined

C:/Users/circleci/project/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu(259): error: identifier "cusparseDcsrmm2" is undefined

2 errors detected in the compilation of "C:/Users/circleci/project/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu".
SparseCUDABlas.cu
(remainder of the nvcc/CMake build log trimmed)

ci.pytorch.org: 1 failed



@zasdfgbnm requested a review from ngimel on August 6, 2020 01:58
@@ -219,18 +229,27 @@ __global__ void reduce_kernel(R reduction) {
reduction.template run<output_vec_size>();
}

template <typename index_t>
static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
template <typename index_t, int num_outputs=1>
Collaborator (reviewer):
should you have static_assert that num_outputs is 1 or 2? Nothing else is supported, and OffsetCalculator hardcodes 2 in the else case

@zasdfgbnm (Author), Aug 6, 2020:
static_assert added

@@ -344,13 +365,13 @@ struct ReduceOp {
extern __shared__ char shared_memory[];
index_t output_idx = config.output_idx<output_vec_size>();
index_t input_idx = config.input_idx();
auto base_offsets1 = output_calc.get(output_idx)[1];
auto base_offsets_input = output_calc.get(output_idx)[num_outputs];
Collaborator (reviewer):
num_outputs and config.num_outputs are totally different things, right? num_outputs is number of output tensors, usually 1, but 2 for min/max operations, and config.num_outputs is number of outputs produced by a single thread?

@zasdfgbnm (Author):
Yes, they are totally different things. config.num_outputs is, I believe, iter.num_output_elements(), i.e. the numel of the output tensor.

@zasdfgbnm (Author):
I renamed config.num_outputs to config.num_output_elements
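
For intuition, a small Python-level sketch of the distinction (illustrative only; num_output_tensors and num_output_elements below are local names, not the kernel's identifiers):

import torch

x = torch.randn(8, 16)
values, indices = x.min(dim=0)          # min with dim produces 2 output tensors

num_output_tensors = 2                  # what the template parameter num_outputs counts
num_output_elements = values.numel()    # what config.num_output_elements counts: 16 here
print(num_output_tensors, num_output_elements)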

base_offsets[i] = output_calc.get(output_idx + i)[0];
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
base_offsets[i] = output_calc.get(output_idx + i);
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i][0]);
Collaborator (reviewer):
this is suspicious - how is this going to work for multiple outputs? Or does the multiple-outputs case never take this path?

@zasdfgbnm (Author):
IIUC, this is used for storing intermediate accumulation data. For accumulation it stores arg_t, not out_scalar_t, and when accumulating using the output tensor it only uses the first output tensor.

if (noutputs >= 1) {
auto res0 = (T1*)((char*)dst[0] + base_offset);
auto res0 = (T1*)((char*)dst[0] + base_offset[0]);
*res0 = x.first;
}
if (noutputs >= 2) {
// base offset is computed assuming element size being sizeof(T1), so we need to make a
Contributor (reviewer):
is this comment out of date now?

Contributor (reviewer):
I guess so because TensorIterator gives you the stride data in terms of bytes.

@zasdfgbnm (Author):
You are right, it is out of date now. Deleted from the code.
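
As a Python-level aside (not code from the PR): tensor.stride() reports strides in elements, while the byte offsets the kernel works with are those strides multiplied by the element size, so no extra sizeof scaling is needed once the strides are already in bytes.

import torch

x = torch.empty(4, 5, dtype=torch.float16)
print(x.stride())                                    # (5, 1)  -> strides in elements
print([s * x.element_size() for s in x.stride()])    # [10, 2] -> the same strides in bytes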

@ailzhang added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Aug 7, 2020
for perm3 in permutations:
input_ = torch.randn(10, 10, 10, 1, dtype=dtype, device=device).expand(10, 10, 10, 10).permute(*perm1)
expect1, expect2 = input_.min(dim=0, keepdim=True)
out1 = expect1.permute(perm2).clone().contiguous().permute(rev_perm(perm2))
Collaborator (reviewer):
nit: do you really need a permuted clone of expect1 here, rather than just empty_like(expect1)?

@zasdfgbnm (Author):
It could be empty_like(expect1.permute(perm2)).contiguous().permute(rev_perm(perm2))
I need to do a permute and then permute it back so that the shape is unchanged, but the strides are permuted.

@zasdfgbnm (Author):
I replaced it with torch.empty_like(expect1.permute(perm2), memory_format=torch.contiguous_format)
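
As a standalone illustration of that trick (hypothetical example shapes, not the actual test code): permuting, allocating a contiguous tensor of that permuted shape, and permuting back gives a tensor whose shape matches the expected result but whose strides are permuted, i.e. a non-contiguous output.

import torch

expect = torch.randn(3, 4, 5)
perm = (2, 0, 1)       # an arbitrary permutation of the dims
inv_perm = (1, 2, 0)   # its inverse, so the final shape matches expect

out = torch.empty_like(expect.permute(perm),
                       memory_format=torch.contiguous_format).permute(inv_perm)

print(expect.shape, out.shape)        # both torch.Size([3, 4, 5])
print(expect.stride(), out.stride())  # (20, 5, 1) vs (4, 1, 12): permuted strides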

@facebook-github-bot (Contributor) left a comment:
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment:
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zasdfgbnm deleted the different-strides-output-reduction branch on August 12, 2020 20:21
@facebook-github-bot (Contributor):
@ngimel merged this pull request in 7f3f502.

@ngimel (Collaborator) commented on Aug 13, 2020:

@jeffdaily we reverted this PR because it was causing consistent ROCm failures in the EmbeddingBag tests (https://app.circleci.com/pipelines/github/pytorch/pytorch/199835/workflows/6267d613-5d88-4f1f-a221-607208390a30/jobs/6678621). Do you know what might be causing those?

@jeffdaily (Collaborator):
@ngimel taking a look now.

@jeffdaily (Collaborator):
@ngimel I've narrowed down the failure in test_nn.py to this line:

self._test_EmbeddingBag(device, 'max', False, dtype)

If I comment it out, the fp16 test TestNNDeviceTypeCUDA.test_embedding_bag_device_cuda_float16 passes. Still digging.

@ngimel (Collaborator) commented on Aug 14, 2020:

It makes sense that the failure is happening in 'max' mode; the PR changed the max reduction, which is probably called somewhere. Thanks for looking into it!

@jeffdaily (Collaborator):
@ngimel I was unable to determine why this test was failing for ROCm 3.5, but the good news is that the test is passing for ROCm 3.7. That release only recently became available and we are in the process of upgrading our CI images to ROCm 3.7. In the meantime, I believe you could try and resubmit this PR if you were to conditionally skip this failing test. Since it passes for ROCm 3.7, I suggest adding the following to torch/testing/_internal/common_utils.py to allow skipping based on ROCm version.

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 1210acdcfd..66ddadad64 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -333,6 +333,7 @@ TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
 TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1'
 TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
 TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'
+HIP_VERSION = 0.0 if torch.version.hip is None else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
 # Enables tests that are slow to run (disabled by default)
 TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'

@@ -387,6 +388,17 @@ def skipIfRocm(fn):
             fn(*args, **kwargs)
     return wrapper

+def skipIfRocmVersionLt(version):
+    def wrap(fn):
+        @wraps(fn)
+        def wrapped_fn(*args, **kwargs):
+            if TEST_WITH_ROCM and HIP_VERSION < version:
+                raise unittest.SkipTest("test doesn't work with ROCm stack < %s" % version)
+            else:
+                fn(*args, **kwargs)
+        return wrapped_fn
+    return wrap
+

Then you can decorate this test with @skipIfRocmVersionLt(3.7) after importing the decorator: from torch.testing._internal.common_utils import skipIfRocmVersionLt.
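
A rough usage sketch, assuming the diff above has been applied to common_utils.py (the test class and method names here are illustrative, not the actual test_nn.py code):

from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmVersionLt

class TestEmbeddingBagMax(TestCase):
    @skipIfRocmVersionLt(3.7)          # skipped on ROCm < 3.7, runs everywhere else
    def test_embedding_bag_max_half(self):
        # ... body of the failing fp16 'max'-mode test would go here ...
        pass

if __name__ == '__main__':
    run_tests()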

@jeffdaily (Collaborator):
I suppose the greater concern is not just skipping a test, but that the functionality enabled by this PR is simply broken for float16 on ROCm 3.5, in which case we would expect additional tests that use the max operator to fail until our CI is updated to ROCm 3.7.

@jeffdaily (Collaborator):
@ngimel please attempt to reland this PR now that ROCm CI images have moved to 3.7.

Labels: Merged, open source, triaged