-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom CUDA extensions #135
Merged
Merged
Changes from all commits
Commits
Show all changes
83 commits
Select commit
Hold shift + click to select a range
c9ee639
Custom CUDA extensionsCo-authored-by: Mark Saroufim <marksaroufim@met…
zou3519 e93fd38
ci fix
msaroufim 1eaa35f
add continue on ci fail
msaroufim 24992d0
add verbose pip logging
msaroufim bd8eed1
update
msaroufim 5d9ff0a
remove pip upgrade
msaroufim ede55b8
remove torch from requirements
msaroufim 97857eb
add python -m
msaroufim 931c97f
readd upgrade
msaroufim c11b8c2
make everything python pip install
msaroufim b46297c
typo
msaroufim 2531971
bla
msaroufim cfb5abf
add torch again to dependencies
msaroufim 7f9a7a1
upd
msaroufim 4d67630
udpate
msaroufim 3951369
one more
msaroufim 2b274f3
bla
msaroufim ed55d4d
uipda
msaroufim faebb78
bla
msaroufim 6a40627
bla
msaroufim 18ac42c
add venv for pytest
msaroufim ccac280
skip test if no cuda
msaroufim b229d99
push
msaroufim 9b44bc7
fix test
msaroufim fd34023
fix test
msaroufim e3cf0a8
fix skip tests
msaroufim bb113e2
reuse pip install .
msaroufim 3427156
add more debugging tools
msaroufim 2c6a667
yolo
msaroufim 03aba8f
revert yolo change
msaroufim 074ef41
fix
msaroufim 04db944
remove custom install
msaroufim bca7b5c
print torch version
msaroufim 327cd2b
print torch version
msaroufim d860891
teest'
msaroufim 250dba1
try pip3?
msaroufim 15c5e41
remove pip3
msaroufim 626f1c2
reshuffling torch install
msaroufim e21ce5f
no deps?
msaroufim ed83c17
revert nodeps
msaroufim ea93b84
ad
msaroufim 1197e99
bla
msaroufim 55b0c26
bla
msaroufim c2e39b3
Merge branch 'main' into msaroufim/manylinux
cpuhrsch 77c6645
lazy
msaroufim 685a4f7
Merge branch 'main' into msaroufim/manylinux
msaroufim 6a1c04c
use conda and cuda toolkit
msaroufim 7ee1e92
merge
msaroufim b4c54b6
fix
msaroufim 7225bae
push
msaroufim 8d019fa
update
msaroufim e9ccf97
1 more try
msaroufim ee0d473
push
msaroufim 41ee0ca
bla
msaroufim 8935acd
bla
msaroufim 36350b8
conda init
msaroufim 794e365
clowntown
msaroufim de11d56
this works locally
msaroufim 54eb2d1
update
msaroufim bd0347a
remove cache
msaroufim 0c51485
push
msaroufim 2ee80d7
alternate cuda install command
msaroufim 1b3bc78
bla
msaroufim 9f0dc35
bla
msaroufim b71570f
update
msaroufim 1868f7d
push
msaroufim d94aed3
bla
msaroufim bd207d4
bla
msaroufim d3169bd
bla
msaroufim b9f8631
bla
msaroufim e38e69e
bla
msaroufim 065b93e
bla
msaroufim 83cb732
yolo
msaroufim 857845c
yolo
msaroufim c0077d7
yolo
msaroufim 1dfd56d
yolo
msaroufim fbc32f0
update
msaroufim efb3890
Merge branch 'main' into msaroufim/manylinux
msaroufim 5f015c7
yolo
msaroufim 0b6abdd
yolo
msaroufim 0d516e5
yolo
msaroufim 756a294
yolo
msaroufim 0ba0006
yolo
msaroufim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[build-system] | ||
requires = ["setuptools", "wheel", "ninja", "torch"] | ||
build-backend = "setuptools.build_meta" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import TestCase | ||
from torch.testing._internal.optests import opcheck | ||
import torchao | ||
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 | ||
import unittest | ||
|
||
|
||
# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): | ||
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) | ||
class TestOps(TestCase): | ||
def _create_tensors_with_iou(self, N, iou_thresh): | ||
# force last box to have a pre-defined iou with the first box | ||
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], | ||
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh, | ||
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh | ||
# Adjust the threshold upward a bit with the intent of creating | ||
# at least one box that exceeds (barely) the threshold and so | ||
# should be suppressed. | ||
boxes = torch.rand(N, 4) * 100 | ||
boxes[:, 2:] += boxes[:, :2] | ||
boxes[-1, :] = boxes[0, :] | ||
x0, y0, x1, y1 = boxes[-1].tolist() | ||
iou_thresh += 1e-5 | ||
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh | ||
scores = torch.rand(N) | ||
return boxes, scores | ||
|
||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") | ||
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") | ||
def test_nms(self): | ||
iou = 0.2 | ||
boxes, scores = self._create_tensors_with_iou(1000, iou) | ||
boxes = boxes.cuda() | ||
scores = scores.cuda() | ||
|
||
# smoke test | ||
_ = torchao.ops.nms(boxes, scores, iou) | ||
|
||
# comprehensive testing | ||
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] | ||
opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/AccumulateType.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
#include <torch/library.h> | ||
|
||
namespace torchao { | ||
|
||
namespace { | ||
|
||
#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ | ||
for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ | ||
i += (blockDim.x * gridDim.x)) | ||
|
||
#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int) | ||
|
||
template <typename integer> | ||
constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { | ||
return (n + m - 1) / m; | ||
} | ||
|
||
int const threadsPerBlock = sizeof(unsigned long long) * 8; | ||
|
||
template <typename T> | ||
__device__ inline bool devIoU( | ||
T const* const a, | ||
T const* const b, | ||
const float threshold) { | ||
T left = max(a[0], b[0]), right = min(a[2], b[2]); | ||
T top = max(a[1], b[1]), bottom = min(a[3], b[3]); | ||
T width = max(right - left, (T)0), height = max(bottom - top, (T)0); | ||
using acc_T = at::acc_type<T, /*is_cuda=*/true>; | ||
acc_T interS = (acc_T)width * height; | ||
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); | ||
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]); | ||
return (interS / (Sa + Sb - interS)) > threshold; | ||
} | ||
|
||
template <typename T> | ||
__global__ void nms_kernel_impl( | ||
int n_boxes, | ||
double iou_threshold, | ||
const T* dev_boxes, | ||
unsigned long long* dev_mask) { | ||
const int row_start = blockIdx.y; | ||
const int col_start = blockIdx.x; | ||
|
||
if (row_start > col_start) | ||
return; | ||
|
||
const int row_size = | ||
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); | ||
const int col_size = | ||
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); | ||
|
||
__shared__ T block_boxes[threadsPerBlock * 4]; | ||
if (threadIdx.x < col_size) { | ||
block_boxes[threadIdx.x * 4 + 0] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; | ||
block_boxes[threadIdx.x * 4 + 1] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; | ||
block_boxes[threadIdx.x * 4 + 2] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; | ||
block_boxes[threadIdx.x * 4 + 3] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; | ||
} | ||
__syncthreads(); | ||
|
||
if (threadIdx.x < row_size) { | ||
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; | ||
const T* cur_box = dev_boxes + cur_box_idx * 4; | ||
int i = 0; | ||
unsigned long long t = 0; | ||
int start = 0; | ||
if (row_start == col_start) { | ||
start = threadIdx.x + 1; | ||
} | ||
for (i = start; i < col_size; i++) { | ||
if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) { | ||
t |= 1ULL << i; | ||
} | ||
} | ||
const int col_blocks = ceil_div(n_boxes, threadsPerBlock); | ||
dev_mask[cur_box_idx * col_blocks + col_start] = t; | ||
} | ||
} | ||
|
||
at::Tensor nms_kernel( | ||
const at::Tensor& dets, | ||
const at::Tensor& scores, | ||
double iou_threshold) { | ||
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); | ||
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); | ||
|
||
TORCH_CHECK( | ||
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); | ||
TORCH_CHECK( | ||
dets.size(1) == 4, | ||
"boxes should have 4 elements in dimension 1, got ", | ||
dets.size(1)); | ||
TORCH_CHECK( | ||
scores.dim() == 1, | ||
"scores should be a 1d tensor, got ", | ||
scores.dim(), | ||
"D"); | ||
TORCH_CHECK( | ||
dets.size(0) == scores.size(0), | ||
"boxes and scores should have same number of elements in ", | ||
"dimension 0, got ", | ||
dets.size(0), | ||
" and ", | ||
scores.size(0)) | ||
|
||
at::cuda::CUDAGuard device_guard(dets.device()); | ||
|
||
if (dets.numel() == 0) { | ||
return at::empty({0}, dets.options().dtype(at::kLong)); | ||
} | ||
|
||
auto order_t = std::get<1>( | ||
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); | ||
auto dets_sorted = dets.index_select(0, order_t).contiguous(); | ||
|
||
int dets_num = dets.size(0); | ||
|
||
const int col_blocks = ceil_div(dets_num, threadsPerBlock); | ||
|
||
at::Tensor mask = | ||
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); | ||
|
||
dim3 blocks(col_blocks, col_blocks); | ||
dim3 threads(threadsPerBlock); | ||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
||
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||
dets_sorted.scalar_type(), "nms_kernel", [&] { | ||
nms_kernel_impl<scalar_t><<<blocks, threads, 0, stream>>>( | ||
dets_num, | ||
iou_threshold, | ||
dets_sorted.data_ptr<scalar_t>(), | ||
(unsigned long long*)mask.data_ptr<int64_t>()); | ||
}); | ||
|
||
at::Tensor mask_cpu = mask.to(at::kCPU); | ||
unsigned long long* mask_host = | ||
(unsigned long long*)mask_cpu.data_ptr<int64_t>(); | ||
|
||
std::vector<unsigned long long> remv(col_blocks); | ||
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); | ||
|
||
at::Tensor keep = | ||
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); | ||
int64_t* keep_out = keep.data_ptr<int64_t>(); | ||
|
||
int num_to_keep = 0; | ||
for (int i = 0; i < dets_num; i++) { | ||
int nblock = i / threadsPerBlock; | ||
int inblock = i % threadsPerBlock; | ||
|
||
if (!(remv[nblock] & (1ULL << inblock))) { | ||
keep_out[num_to_keep++] = i; | ||
unsigned long long* p = mask_host + i * col_blocks; | ||
for (int j = nblock; j < col_blocks; j++) { | ||
remv[j] |= p[j]; | ||
} | ||
} | ||
} | ||
|
||
AT_CUDA_CHECK(cudaGetLastError()); | ||
return order_t.index( | ||
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) | ||
.to(order_t.device(), keep.scalar_type())}); | ||
} | ||
|
||
} // namespace | ||
|
||
TORCH_LIBRARY_IMPL(torchao, CUDA, m) { | ||
m.impl("torchao::nms", &nms_kernel); | ||
} | ||
|
||
} // namespace torchao |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#include <torch/extension.h> | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#include <ATen/core/dispatch/Dispatcher.h> | ||
#include <torch/library.h> | ||
#include <torch/types.h> | ||
|
||
TORCH_LIBRARY_FRAGMENT(torchao, m) { | ||
m.impl_abstract_pystub("torchao.ops"); | ||
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: | ||
""" | ||
See https://pytorch.org/vision/main/generated/torchvision.ops.nms.html | ||
""" | ||
return torch.ops.torchao.nms.default(boxes, scores, iou_threshold) | ||
|
||
|
||
# Defines the meta kernel / fake kernel / abstract impl | ||
@torch.library.impl_abstract("torchao::nms") | ||
def _(dets, scores, iou_threshold): | ||
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") | ||
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") | ||
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") | ||
torch._check( | ||
dets.size(0) == scores.size(0), | ||
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", | ||
) | ||
ctx = torch._custom_ops.get_ctx() | ||
num_to_keep = ctx.create_unbacked_symint() | ||
return dets.new_empty(num_to_keep, dtype=torch.long) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to update README? for building from source, is it
python setup.py developer
now? I spent some time to understand whypip install -e .
does not work anymore and find this PRThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argh, lemme update now