
port masked_select from TH to ATen and optimize perf on CPU #33269

Closed
wants to merge 6 commits

Conversation

mingfeima
Collaborator

This PR ports masked_select from TH to ATen and optimizes its performance on CPU with TensorIterator.

#33053

  1. single socket run: up to 5.4x speedup;
  2. single core run: up to 1.16x speedup.

@mingfeima
Collaborator Author

In order to parallelize masked_select with TensorIterator, extra metadata that records the offset of each selected element in the output tensor is needed: masked_prefix_sum. It is currently calculated with std::partial_sum, since PyTorch is limited to C++14. Once PyTorch upgrades to C++17, we can use std::exclusive_scan from the parallel STL and get better performance on CPU. The CUDA implementation already uses thrust::exclusive_scan, but the CPU counterpart depends on C++17.
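
Not the kernel code from this PR, but a minimal standalone C++14 sketch of the idea: an inclusive prefix sum over the flattened mask gives each selected element its destination offset in the output, and the C++17 std::exclusive_scan alternative is noted in a comment. All names and values here are illustrative.

```cpp
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Illustrative mask, widened to int64_t so the running sum cannot overflow.
  std::vector<int64_t> mask = {1, 0, 1, 1, 0, 1};
  std::vector<int64_t> mask_prefix_sum(mask.size());

  // C++14: inclusive scan (sequential). For a true element i, its output slot
  // is mask_prefix_sum[i] - 1, and the last entry equals the output numel.
  std::partial_sum(mask.begin(), mask.end(), mask_prefix_sum.begin());

  // C++17 alternative (parallel STL, needs <execution>):
  // std::exclusive_scan(std::execution::par, mask.begin(), mask.end(),
  //                     mask_prefix_sum.begin(), int64_t{0});
  // which yields the output offsets directly and can run in parallel.

  for (std::size_t i = 0; i < mask.size(); ++i) {
    if (mask[i]) {
      std::printf("input[%zu] -> output[%lld]\n",
                  i, static_cast<long long>(mask_prefix_sum[i] - 1));
    }
  }
  return 0;
}
```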

Below is a performance comparison between the original implementation and this PR. The benchmark code is available at op_bench-py; to reproduce, run ./run.sh masked_select.py. The test machine is a Xeon Gold 6148 (Skylake) with 2×20 cores @ 2.40GHz.

The unit is ms per iteration; the lower the better.
The benchmark uses mask = input.ge(0.0), so roughly 50% of the input elements are selected.

single socket run

| input size  | original (ms) | this PR (ms) | speedup |
|-------------|---------------|--------------|---------|
| [128 1000]  | 0.817         | 0.295        | 2.77    |
| [256 1000]  | 1.899         | 0.413        | 4.60    |
| [512 1000]  | 3.747         | 0.700        | 5.35    |
| [1024 1000] | 6.696         | 1.238        | 5.41    |

single core run

| input size  | original (ms) | this PR (ms) | speedup |
|-------------|---------------|--------------|---------|
| [128 1000]  | 0.812         | 0.703        | 1.16    |
| [256 1000]  | 1.620         | 1.421        | 1.14    |
| [512 1000]  | 3.246         | 2.823        | 1.15    |
| [1024 1000] | 6.488         | 5.675        | 1.14    |

@dr-ci

dr-ci bot commented Feb 13, 2020

💊 CI failures summary and remediations

As of commit 065825b (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 1/4 non-CircleCI failure(s)

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (1/3)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_definitions.py 
Auto-merging .circleci/cimodel/data/pytorch_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/caffe2_build_definitions.py 
Auto-merging .circleci/cimodel/data/caffe2_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_definitions.py 
Auto-merging .circleci/cimodel/data/binary_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (2/3)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_definitions.py 
Auto-merging .circleci/cimodel/data/pytorch_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/pytorch_build_data.py 
Auto-merging .circleci/cimodel/data/pytorch_build_data.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/caffe2_build_definitions.py 
Auto-merging .circleci/cimodel/data/caffe2_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_definitions.py 
Auto-merging .circleci/cimodel/data/binary_build_definitions.py 
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/binary_build_data.py 
Auto-merging .circleci/cimodel/data/binary_build_data.py 
Automatic merge failed; fix conflicts and then commit the result. 

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_build (3/3)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Error generating file
Retry attempt 3: 
THCTensorMaskedShort.cu
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(901): error C2993: 'T': illegal type for non-type template parameter '__formal'
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(909): note: see reference to class template instantiation 'thrust::cuda_cub::cub::BinaryOpHasIdxParam<T,BinaryOp>' being compiled
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(901): error C2065: '__T4': undeclared identifier
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(901): error C2923: 'std::_Select<__formal>::_Apply': '__T4' is not a valid template type argument for parameter '<unnamed-symbol>'
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(901): error C2062: type 'unknown-type' unexpected
-- Removing C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedShort.cu.obj
C:/Jenkins/Miniconda3/Library/bin/cmake.exe -E remove C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedShort.cu.obj
CMake Error at torch_cuda_generated_THCTensorMaskedShort.cu.obj.Release.cmake:281 (message):
  Error generating file
  C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedShort.cu.obj


ated_file:STRING=C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedFloat.cu.obj -D generated_cubin_file:STRING=C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedFloat.cu.obj.cubin.txt -P C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMaskedFloat.cu.obj.Release.cmake" 
-- Removing C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedFloat.cu.obj
C:/Jenkins/Miniconda3/Library/bin/cmake.exe -E remove C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMaskedFloat.cu.obj
-- Generating dependency file: C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMaskedFloat.cu.obj.NVCC-depend
gle_include -IC:/Users/circleci/project/aten/src/ATen/.. -IC:/Users/circleci/project/build/caffe2/aten/src/ATen -IC:/Users/circleci/project/c10/cuda/../.. -IC:/Users/circleci/project/c10/../ "-IC:/Program Files/NVIDIA Corporation/NvToolsExt/include" -IC:/Users/circleci/project/caffe2/../torch/csrc/api -IC:/Users/circleci/project/caffe2/../torch/csrc/api/include -IC:/Users/circleci/project/build/third_party/ideep/mkl-dnn/include -IC:/Users/circleci/project/third_party/ideep/mkl-dnn/src/../include
THCTensorMaskedFloat.cu 
-- Generating temporary cmake readable file: C:/Users/circleci/project/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMaskedFloat.cu.obj.depend.tmp

ci.pytorch.org: 1 failed



@VitalyFedyunin VitalyFedyunin self-requested a review February 13, 2020 21:19
@yf225 yf225 added the module: porting (Issues related to porting TH/THNN legacy to ATen native) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Feb 19, 2020
@kurtamohler
Collaborator

Nice job, looks good to me overall. I think you need to rebase to the current master to fix the conflict.

@mingfeima mingfeima force-pushed the indexing/masked_select branch from 3532e7c to c560e14 Compare February 21, 2020 00:29
@mingfeima
Collaborator Author

rebased

@kurtamohler kurtamohler self-requested a review February 21, 2020 19:30
@VitalyFedyunin
Contributor

Feb 24 08:29:38 ======================================================================
Feb 24 08:29:38 ERROR [0.008s]: test_fmod_cpu_float32 (__main__.TestTorchDeviceTypeCPU)
Feb 24 08:29:38 ----------------------------------------------------------------------
Feb 24 08:29:38 Traceback (most recent call last):
Feb 24 08:29:38   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 197, in instantiated_test
Feb 24 08:29:38     result = test(self, device_arg, dtype)
Feb 24 08:29:38   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 398, in only_fn
Feb 24 08:29:38     return fn(slf, device, *args, **kwargs)
Feb 24 08:29:38   File "test_torch.py", line 13577, in test_fmod
Feb 24 08:29:38     _test_fmod_with_size_tensor(size, device, dtype)
Feb 24 08:29:38   File "test_torch.py", line 13556, in _test_fmod_with_size_tensor
Feb 24 08:29:38     a = torch.rand(size=size, device=device, dtype=dtype)
Feb 24 08:29:38 RuntimeError: _th_uniform_ not supported on CPUType for Byte

@mingfeima
Collaborator Author

test_fmod_cpu_float32

Sorry, is this related to this one?

@VitalyFedyunin
Contributor

test_fmod_cpu_float32

Sorry, is this related to this one?

Oops, wrong browser tab, sorry

@kurtamohler
Collaborator

kurtamohler commented Feb 25, 2020

Hey @mingfeima , are you sure your performance comparison between your implementation and the old TH implementation is correct? I was running some measurements on my cuda implementation, and I decided to collect the cpu performance data as well to make sure my changes won't negatively affect yours. Here's the performance I see with a single CPU core, using your commit 1e3b5c4.

| dtype         | tensor_size | mask_true_ratio | TH CPU time (s) | ATen CPU (mingfeima) time (s) | speedup (mingfeima ATen over TH) |
|---------------|-------------|-----------------|-----------------|-------------------------------|----------------------------------|
| torch.float64 | 1000        | 0               | 0.0000146       | 0.0001218                     | 0.1198686371                     |
| torch.float64 | 1000        | 0.25            | 0.0000182       | 0.0001524                     | 0.1194225722                     |
| torch.float64 | 1000        | 0.5             | 0.0000176       | 0.0001544                     | 0.1139896373                     |
| torch.float64 | 1000        | 0.75            | 0.0000174       | 0.0001538                     | 0.1131339402                     |
| torch.float64 | 1000        | 1               | 0.0000172       | 0.0001504                     | 0.1143617021                     |
| torch.float64 | 10000       | 0               | 0.0000524       | 0.0006164                     | 0.08500973394                    |
| torch.float64 | 10000       | 0.25            | 0.0000722       | 0.0007122                     | 0.101376018                      |
| torch.float64 | 10000       | 0.5             | 0.000076        | 0.0007412                     | 0.1025364274                     |
| torch.float64 | 10000       | 0.75            | 0.000072        | 0.000738                      | 0.09756097561                    |
| torch.float64 | 10000       | 1               | 0.0000586       | 0.0007264                     | 0.08067180617                    |
| torch.float64 | 100000      | 0               | 0.000439        | 0.0055718                     | 0.07878961915                    |
| torch.float64 | 100000      | 0.25            | 0.0006324       | 0.006367                      | 0.09932464269                    |
| torch.float64 | 100000      | 0.5             | 0.0007926       | 0.0067212                     | 0.1179253705                     |
| torch.float64 | 100000      | 0.75            | 0.0006554       | 0.0066458                     | 0.09861867646                    |
| torch.float64 | 100000      | 1               | 0.0004654       | 0.0065112                     | 0.07147683991                    |
| torch.float64 | 1000000     | 0               | 0.0042768       | 0.0559348                     | 0.07646045038                    |
| torch.float64 | 1000000     | 0.25            | 0.0063654       | 0.064207                      | 0.09913872319                    |
| torch.float64 | 1000000     | 0.5             | 0.0083026       | 0.0674632                     | 0.1230685766                     |
| torch.float64 | 1000000     | 0.75            | 0.0067032       | 0.0665706                     | 0.1006930988                     |
| torch.float64 | 1000000     | 1               | 0.0050948       | 0.0653446                     | 0.07796818712                    |
| torch.float64 | 10000000    | 0               | 0.0424076       | 0.577658                      | 0.07341298831                    |
| torch.float64 | 10000000    | 0.25            | 0.0647738       | 0.6789826                     | 0.09539832096                    |
| torch.float64 | 10000000    | 0.5             | 0.0916282       | 0.7228486                     | 0.1267598775                     |
| torch.float64 | 10000000    | 0.75            | 0.0796012       | 0.7180316                     | 0.110860302                      |
| torch.float64 | 10000000    | 1               | 0.0646096       | 0.7097978                     | 0.09102535962                    |

My performance measurement script is here, perhaps there's some mistake in it: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/master/masked_select/masked_select-perf.py

I suspect there might be an issue with how your script is measuring time: https://github.com/mingfeima/op_bench-py/blob/master/masked_select.py#L20

The time.time() call uses a syscall, which could take significantly longer to finish than the masked_select call itself. It's usually more accurate to call time.time() only twice, surrounding the loop, and then divide the total time by the number of loop iterations, like this:

start_time = time.time()
for _ in range(num_iters):
    torch.masked_select(a, mask)  # the operation being timed
total_time = (time.time() - start_time) / num_iters

Or maybe the difference is just that your script might not be ensuring a single-core run with torch.set_num_threads(1). I'm not sure.

@kurtamohler
Collaborator

Actually, I think your performance measurement script is sufficient. When I run it, I get something that seems to be similar to my measurements, showing a performance decrease.

TH CPU masked_select:

input size: [128 1000]; output size: [63941]: time = 1.027 ms
input size: [256 1000]; output size: [127483]: time = 2.225 ms
input size: [512 1000]; output size: [255929]: time = 4.103 ms
input size: [1024 1000]; output size: [511179]: time = 7.888 ms

ATen CPU masked_select:

input size: [128 1000]; output size: [63941]: time = 5.091 ms
input size: [256 1000]; output size: [127483]: time = 6.737 ms
input size: [512 1000]; output size: [255929]: time = 10.314 ms
input size: [1024 1000]; output size: [511179]: time = 17.129 ms

TH CPU masked_select (modified to use only a single core):

input size: [128 1000]; output size: [63941]: time = 1.030 ms
input size: [256 1000]; output size: [127483]: time = 2.098 ms
input size: [512 1000]; output size: [255929]: time = 4.241 ms
input size: [1024 1000]; output size: [511179]: time = 8.834 ms

ATen CPU masked_select (modified to use only a single core):

input size: [128 1000]; output size: [63941]: time = 8.616 ms
input size: [256 1000]; output size: [127483]: time = 17.228 ms
input size: [512 1000]; output size: [255929]: time = 34.527 ms
input size: [1024 1000]; output size: [511179]: time = 69.079 ms

@mingfeima
Collaborator Author

@kurtamohler What type of CPU are you using? And what about your environment settings?

@mingfeima
Collaborator Author

mingfeima commented Feb 26, 2020

Post my machine configuration:

Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                80
On-line CPU(s) list:   0-79
Thread(s) per core:    2
Core(s) per socket:    20
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 85
Model name:            Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Stepping:              4
CPU MHz:               1000.125
CPU max MHz:           3700.0000
CPU min MHz:           1000.0000
BogoMIPS:              4800.00
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              1024K
L3 cache:              28160K
NUMA node0 CPU(s):     0-19,40-59
NUMA node1 CPU(s):     20-39,60-79
  • original log
### this refers to ATen implementation
(pytorch-mingfei) [mingfeim@mlt-skx090 op_bench-py]$ ./run.sh masked_select.py

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
input size: [128 1000]; output size: [63941]: time = 0.295 ms
input size: [256 1000]; output size: [127483]: time = 0.413 ms
input size: [512 1000]; output size: [255929]: time = 0.700 ms
input size: [1024 1000]; output size: [511179]: time = 1.238 ms

### using OMP_NUM_THREADS=1
input size: [128 1000]; output size: [63941]: time = 0.703 ms
input size: [256 1000]; output size: [127483]: time = 1.421 ms
input size: [512 1000]; output size: [255929]: time = 2.823 ms
input size: [1024 1000]; output size: [511179]: time = 5.675 ms

### this refers to original TH implementation
(pytorch-cuda) [mingfeim@mlt-skx090 op_bench-py]$ ./run.sh masked_select.py

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
input size: [128 1000]; output size: [63941]: time = 0.817 ms
input size: [256 1000]; output size: [127483]: time = 1.899 ms
input size: [512 1000]; output size: [255929]: time = 3.747 ms
input size: [1024 1000]; output size: [511179]: time = 6.696 ms

### using OMP_NUM_THREADS=1
input size: [128 1000]; output size: [63941]: time = 0.812 ms
input size: [256 1000]; output size: [127483]: time = 1.620 ms
input size: [512 1000]; output size: [255929]: time = 3.246 ms
input size: [1024 1000]; output size: [511179]: time = 6.488 ms
  • This is the log after updating the benchmark script to use only two time.time() calls, to verify the overhead of time.time(). Usually the minimal timed segment should be at the ms level for benchmarking purposes. The variance here is related more to Python than to time().
### this refers to ATen implementation
(pytorch-mingfei) [mingfeim@mlt-skx089 op_bench-py]$ ./run.sh masked_select.py

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
input size: [128 1000]; output size: [63941]: time = 0.320 ms
input size: [256 1000]; output size: [127483]: time = 0.447 ms
input size: [512 1000]; output size: [255929]: time = 0.706 ms
input size: [1024 1000]; output size: [511179]: time = 1.255 ms

### using OMP_NUM_THREADS=1
input size: [128 1000]; output size: [63941]: time = 0.698 ms
input size: [256 1000]; output size: [127483]: time = 1.416 ms
input size: [512 1000]; output size: [255929]: time = 2.814 ms
input size: [1024 1000]; output size: [511179]: time = 5.651 ms


### this refers to original TH implementation
(pytorch-cuda) [mingfeim@mlt-skx089 op_bench-py]$ ./run.sh masked_select.py

### using KMP_AFFINITY=granularity=fine,compact,1,0
### using KMP_BLOCKTIME=1
### using numactl --physcpubind=0-19 --membind=0


### using OMP_NUM_THREADS=20
input size: [128 1000]; output size: [63941]: time = 0.874 ms
input size: [256 1000]; output size: [127483]: time = 1.900 ms
input size: [512 1000]; output size: [255929]: time = 3.711 ms
input size: [1024 1000]; output size: [511179]: time = 6.727 ms

### using OMP_NUM_THREADS=1
input size: [128 1000]; output size: [63941]: time = 0.860 ms
input size: [256 1000]; output size: [127483]: time = 1.612 ms
input size: [512 1000]; output size: [255929]: time = 3.228 ms
input size: [1024 1000]; output size: [511179]: time = 6.450 ms

Since this is not an in-place op, it is almost a must to configure jemalloc (tcmalloc also works); otherwise page clearing will be a huge source of interference. (Also, jemalloc is not NUMA aware... be careful on a dual-socket machine.)

At the algorithm level, it is possible that this ATen implementation introduces a slowdown on a single core with larger tensors. The prefix sum requires an additional memory footprint and thus more memory bandwidth, and if L1/L2/LLC fails to hold it, you will see a drop.

A more efficient way here would be to write two sets of kernels, sequential and parallel. The parallel one goes like this (I actually copied the TH CUDA implementation). The sequential one could use nonzero to mark the offsets of the input tensor and feed TensorIterator with a re-strided input (instead of a re-strided output), so the numel of the TensorIterator would be only the TRUEs in the mask.
However, this is not a very perf-critical op, and also to avoid complexity in testing, I did not do it (and neither does the CUDA code).
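
For illustration only, a minimal plain-C++ sketch (not ATen code) of the sequential alternative described above, under the assumption of flat float buffers of equal length: record the indices of the true mask elements first (the role nonzero would play), then gather just those elements, so the inner loop length equals the number of TRUEs rather than the full input size. The function name is hypothetical.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: sequential masked_select over flat buffers.
// Assumes input.size() == mask.size().
std::vector<float> masked_select_seq(const std::vector<float>& input,
                                     const std::vector<uint8_t>& mask) {
  // Pass 1: record the input offsets of the selected elements
  // (this is the role at::nonzero would play in the idea described above).
  std::vector<int64_t> idx;
  for (int64_t i = 0; i < static_cast<int64_t>(input.size()); ++i) {
    if (mask[i]) {
      idx.push_back(i);
    }
  }
  // Pass 2: gather only the selected elements; the loop length is the
  // number of TRUEs in the mask, not the full input size.
  std::vector<float> out(idx.size());
  for (std::size_t j = 0; j < idx.size(); ++j) {
    out[j] = input[idx[j]];
  }
  return out;
}
```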

@kurtamohler
Collaborator

kurtamohler commented Feb 26, 2020

My CPU is an AMD Ryzen Threadripper 2970WX, which has 32 KB of L1d and 512 KB of L2 per core, and 8 MB of L3 per 3 cores. But I don't think the difference between our cache sizes can be the explanation, because the measurements I posted show this slowdown even at relatively small tensor sizes, down to 1000 elements, which should fit within L1 for both of us.

Is the TH version using a different amount of memory than your ATen version?

I'm not really sure what jemalloc is. I guess it's just one of the implementations of malloc available as a kernel module, right? I'm guessing my malloc is not jemalloc, which might explain our machines' performance difference?


auto shape = _self.sizes().vec();
auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(_mask);
int64_t numel = mask_long.sum().item().toLong();
Collaborator

While working on my cuda version of masked_select, I realized that we can reuse the final element from the cumulative sum result to get numel, rather than running a whole other sum calculation.
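
As a rough sketch of that suggestion (plain C++, not the ATen code in this diff): the inclusive prefix sum over the mask already ends with the total count of selected elements, so its last entry can serve as numel instead of a separate sum over the mask. The function name is illustrative.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative: reuse the final element of the inclusive scan as numel.
int64_t masked_numel(const std::vector<int64_t>& mask) {
  std::vector<int64_t> prefix(mask.size());
  std::partial_sum(mask.begin(), mask.end(), prefix.begin());
  // The last entry of the inclusive scan is the number of selected
  // elements, so no additional sum over the mask is needed.
  return prefix.empty() ? 0 : prefix.back();
}
```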

@kurtamohler
Collaborator

@VitalyFedyunin , when you get a chance, could you review this PR? I've ported the CUDA version, but it depends on some of mingfeima's changes here, so I think I have to wait until this PR is merged before creating my PR.

kurtamohler added a commit that referenced this pull request Mar 11, 2020
Issue #33054

Depends on PR #33269.

[ghstack-poisoned]
Contributor

@facebook-github-bot facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@anjali411 anjali411 self-requested a review April 8, 2020 19:07
@ngimel
Collaborator

ngimel commented May 14, 2020

Ok, cool, so AMD perf issues are resolved.

Collaborator

@ngimel ngimel left a comment

Generally looks good, just some small comments.

@mingfeima mingfeima force-pushed the indexing/masked_select branch from 8f8102f to 065825b Compare May 19, 2020 05:11
@mingfeima
Collaborator Author

@ngimel thanks for the comments, all fixed!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in fe66bdb.

gchanan added a commit to gchanan/pytorch that referenced this pull request Jul 22, 2020
…ytorch#33269)"

This reverts commit fe66bdb.

This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.
gchanan added a commit that referenced this pull request Jul 22, 2020
…33269)"

This reverts commit fe66bdb.

This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.
gchanan added a commit that referenced this pull request Jul 22, 2020
…33269)"

This reverts commit fe66bdb.

This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.

ghstack-source-id: 15804225fba97e1ecd87bcb31f9b278d1df2685b
Pull Request resolved: #41828
facebook-github-bot pushed a commit that referenced this pull request Jul 22, 2020
…33269)" (#41828)

Summary:
Pull Request resolved: #41828

This reverts commit fe66bdb.

This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.

Test Plan: Imported from OSS

Reviewed By: orionr

Differential Revision: D22657473

Pulled By: malfet

fbshipit-source-id: 95a806cedf1a3f4df91e6a21de1678252b117489
malfet pushed a commit that referenced this pull request Jul 22, 2020
…33269)" (#41829)

This reverts commit fe66bdb.

This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.
ngimel pushed a commit to ngimel/pytorch that referenced this pull request Jul 22, 2020
facebook-github-bot pushed a commit that referenced this pull request Aug 4, 2020
Summary:
This fixes #41473 for discontiguous input, mask and out. Tests to follow. Reverting #33269 is not a great solution because I'm told masked_select was needed for printing complex tensors.
cc gchanan , zou3519, ezyang

Pull Request resolved: #41841

Reviewed By: mruberry

Differential Revision: D22706943

Pulled By: ngimel

fbshipit-source-id: 413d7fd3f3308b184de04fd56b8a9aaabcad22fc
Labels
Merged · module: porting (Issues related to porting TH/THNN legacy to ATen native) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

9 participants