CUDA reduction: allow outputs to have different strides #42649
Conversation
@@ -219,18 +229,27 @@ __global__ void reduce_kernel(R reduction) {
  reduction.template run<output_vec_size>();
}

template <typename index_t>
static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
template <typename index_t, int num_outputs=1>
should you have a static_assert that num_outputs is 1 or 2? Nothing else is supported, and OffsetCalculator hardcodes 2 in the else case.
static_assert added.
@@ -344,13 +365,13 @@ struct ReduceOp {
  extern __shared__ char shared_memory[];
  index_t output_idx = config.output_idx<output_vec_size>();
  index_t input_idx = config.input_idx();
  auto base_offsets1 = output_calc.get(output_idx)[1];
  auto base_offsets_input = output_calc.get(output_idx)[num_outputs];
num_outputs and config.num_outputs are totally different things, right? num_outputs is the number of output tensors, usually 1 but 2 for min/max operations, and config.num_outputs is the number of outputs produced by a single thread?
Yes, they are totally different things. And config.num_outputs, I think, is iter.num_output_elements(), which is the numel of the output tensor.
I renamed config.num_outputs to config.num_output_elements.
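To make the distinction concrete, here is a small Python-level sketch; the variable names are purely for exposition and are not PyTorch APIs:

import torch

x = torch.randn(4, 5)

# min along a dim produces two output tensors: values and indices
values, indices = x.min(dim=0)

num_outputs = 2                       # number of output tensors: 1 for sum, 2 for min/max
num_output_elements = values.numel()  # numel of each output (5 here); per the discussion
                                      # above, this is what iter.num_output_elements() reports
print(num_outputs, num_output_elements)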
base_offsets[i] = output_calc.get(output_idx + i)[0];
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
base_offsets[i] = output_calc.get(output_idx + i);
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i][0]);
this is suspicious - how is this going to work for multiple outputs? Or multiple outputs case never takes this path?
IIUC, this is used for storing intermediate accumulation data. And for accumulation, it stores arg_t, not output_scalar_t, and when accumulating using the output tensor, it only uses the first output tensor.
aten/src/ATen/native/cuda/Reduce.cuh (Outdated)
if (noutputs >= 1) {
  auto res0 = (T1*)((char*)dst[0] + base_offset);
  auto res0 = (T1*)((char*)dst[0] + base_offset[0]);
  *res0 = x.first;
}
if (noutputs >= 2) {
  // base offset is computed assuming element size being sizeof(T1), so we need to make a
is this comment out of date now?
I guess so because TensorIterator gives you the stride data in terms of bytes.
You are right, it is out of date now. Deleted from the code.
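As an aside for readers, a tiny Python-level illustration of element strides versus byte strides (illustrative only; it does not touch the TensorIterator code path):

import torch

t = torch.randn(4, 5)

# Python-level strides are expressed in elements ...
elem_strides = t.stride()                                       # (5, 1)
# ... whereas TensorIterator hands out offsets/strides in bytes, so no extra
# sizeof(T) scaling is needed once you are working in bytes.
byte_strides = tuple(s * t.element_size() for s in t.stride())  # (20, 4)
print(elem_strides, byte_strides)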
test/test_torch.py (Outdated)
for perm3 in permutations:
    input_ = torch.randn(10, 10, 10, 1, dtype=dtype, device=device).expand(10, 10, 10, 10).permute(*perm1)
    expect1, expect2 = input_.min(dim=0, keepdim=True)
    out1 = expect1.permute(perm2).clone().contiguous().permute(rev_perm(perm2))
nit: don't you really just need empty_like(expect1) here, and not a permuted clone of expect?
It could be empty_like(expect1.permute(perm2)).contiguous().permute(rev_perm(perm2)). I need to do a permute and then permute it back, so that the shape is unchanged but the strides are permuted.
I replaced it with torch.empty_like(expect1.permute(perm2), memory_format=torch.contiguous_format)
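For readers following along, a standalone sketch of the permute/contiguous/permute-back trick discussed above (tensor shape and the permutation are illustrative, not the actual test code):

import torch

perm = (1, 0, 2)
rev_perm = tuple(sorted(range(len(perm)), key=lambda i: perm[i]))  # inverse permutation

expect = torch.randn(2, 3, 4)
# Permute, allocate contiguously in the permuted layout, then permute back:
# the result has the same shape as `expect` but permuted (non-contiguous) strides.
out = torch.empty_like(expect.permute(perm), memory_format=torch.contiguous_format).permute(rev_perm)

print(expect.shape, out.shape)        # both torch.Size([2, 3, 4])
print(expect.stride(), out.stride())  # (12, 4, 1) vs (4, 8, 1)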
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@jeffdaily we reverted this PR because it was causing consistent ROCm failures in EmbeddingBag tests https://app.circleci.com/pipelines/github/pytorch/pytorch/199835/workflows/6267d613-5d88-4f1f-a221-607208390a30/jobs/6678621, do you know what might be causing those?
@ngimel taking a look now.
It makes sense that the failure is happening in …
@ngimel I was unable to determine why this test was failing for ROCm 3.5, but the good news is that the test is passing for ROCm 3.7. That release only recently became available and we are in the process of upgrading our CI images to ROCm 3.7. In the meantime, I believe you could try and resubmit this PR if you were to conditionally skip this failing test. Since it passes for ROCm 3.7, I suggest adding the following to torch/testing/_internal/common_utils.py:

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 1210acdcfd..66ddadad64 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -333,6 +333,7 @@ TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
 TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1'
 TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
 TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'
+HIP_VERSION = 0.0 if torch.version.hip is None else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])

 # Enables tests that are slow to run (disabled by default)
 TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
@@ -387,6 +388,17 @@ def skipIfRocm(fn):
             fn(*args, **kwargs)
     return wrapper

+def skipIfRocmVersionLt(version):
+    def wrap(fn):
+        @wraps(fn)
+        def wrapped_fn(*args, **kwargs):
+            if TEST_WITH_ROCM and HIP_VERSION < version:
+                raise unittest.SkipTest("test doesn't work with ROCm stack < %s" % version)
+            else:
+                fn(*args, **kwargs)
+        return wrapped_fn
+    return wrap
+

Then you can decorate this test with @skipIfRocmVersionLt(3.7).
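For illustration only, here is a self-contained sketch of how such a decorator would be used; the flag values, class, and test name below are made up and stand in for the real ones in common_utils.py and test_torch.py:

import unittest
from functools import wraps

TEST_WITH_ROCM = False   # stand-in for the flag defined in common_utils.py
HIP_VERSION = 3.7        # stand-in for the version parsed from torch.version.hip

def skipIfRocmVersionLt(version):
    def wrap(fn):
        @wraps(fn)
        def wrapped_fn(*args, **kwargs):
            if TEST_WITH_ROCM and HIP_VERSION < version:
                raise unittest.SkipTest("test doesn't work with ROCm stack < %s" % version)
            return fn(*args, **kwargs)
        return wrapped_fn
    return wrap

class TestEmbeddingBagSketch(unittest.TestCase):
    @skipIfRocmVersionLt(3.7)    # skipped only when running on ROCm < 3.7
    def test_embedding_bag_half(self):
        self.assertTrue(True)    # placeholder body

if __name__ == "__main__":
    unittest.main()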
I suppose the greater concern is not just skipping a test, but that the functionality enabled by this PR is simply broken for the float16 type on ROCm 3.5, in which case we would expect additional tests that use the max operator to fail until our CI is updated to ROCm 3.7.
@ngimel please attempt to reland this PR now that ROCm CI images have moved to 3.7.
Fixes #42364
Benchmark:
https://github.com/zasdfgbnm/things/blob/master/2020Q3/min-benchmark.ipynb
Before
After
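For context, a minimal sketch of the kind of call this PR enables: a CUDA min/max reduction whose two out= tensors have different strides. The shapes and the strided layout below are illustrative (not the benchmark script), and a CUDA device is assumed.

import torch

x = torch.randn(32, 64, 128, device="cuda")

# values is contiguous; indices has the same shape but transposed (different) strides
values = torch.empty(64, 128, device="cuda")
indices = torch.empty(128, 64, dtype=torch.long, device="cuda").t()

torch.min(x, dim=0, out=(values, indices))

# Check against the straightforward computation
ref_values, ref_indices = x.min(dim=0)
assert torch.equal(values, ref_values)
assert torch.equal(indices, ref_indices)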