Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom CUDA extensions #135

Merged
merged 83 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
c9ee639
Custom CUDA extensionsCo-authored-by: Mark Saroufim <marksaroufim@met…
zou3519 Apr 15, 2024
e93fd38
ci fix
msaroufim Apr 15, 2024
1eaa35f
add continue on ci fail
msaroufim Apr 15, 2024
24992d0
add verbose pip logging
msaroufim Apr 15, 2024
bd8eed1
update
msaroufim Apr 15, 2024
5d9ff0a
remove pip upgrade
msaroufim Apr 15, 2024
ede55b8
remove torch from requirements
msaroufim Apr 15, 2024
97857eb
add python -m
msaroufim Apr 15, 2024
931c97f
readd upgrade
msaroufim Apr 15, 2024
c11b8c2
make everything python pip install
msaroufim Apr 15, 2024
b46297c
typo
msaroufim Apr 15, 2024
2531971
bla
msaroufim Apr 15, 2024
cfb5abf
add torch again to dependencies
msaroufim Apr 15, 2024
7f9a7a1
upd
msaroufim Apr 15, 2024
4d67630
udpate
msaroufim Apr 15, 2024
3951369
one more
msaroufim Apr 15, 2024
2b274f3
bla
msaroufim Apr 15, 2024
ed55d4d
uipda
msaroufim Apr 15, 2024
faebb78
bla
msaroufim Apr 15, 2024
6a40627
bla
msaroufim Apr 15, 2024
18ac42c
add venv for pytest
msaroufim Apr 15, 2024
ccac280
skip test if no cuda
msaroufim Apr 15, 2024
b229d99
push
msaroufim Apr 15, 2024
9b44bc7
fix test
msaroufim Apr 15, 2024
fd34023
fix test
msaroufim Apr 15, 2024
e3cf0a8
fix skip tests
msaroufim Apr 15, 2024
bb113e2
reuse pip install .
msaroufim Apr 15, 2024
3427156
add more debugging tools
msaroufim Apr 15, 2024
2c6a667
yolo
msaroufim Apr 16, 2024
03aba8f
revert yolo change
msaroufim Apr 16, 2024
074ef41
fix
msaroufim Apr 16, 2024
04db944
remove custom install
msaroufim Apr 16, 2024
bca7b5c
print torch version
msaroufim Apr 16, 2024
327cd2b
print torch version
msaroufim Apr 16, 2024
d860891
teest'
msaroufim Apr 16, 2024
250dba1
try pip3?
msaroufim Apr 16, 2024
15c5e41
remove pip3
msaroufim Apr 16, 2024
626f1c2
reshuffling torch install
msaroufim Apr 16, 2024
e21ce5f
no deps?
msaroufim Apr 16, 2024
ed83c17
revert nodeps
msaroufim Apr 16, 2024
ea93b84
ad
msaroufim Apr 16, 2024
1197e99
bla
msaroufim Apr 16, 2024
55b0c26
bla
msaroufim Apr 16, 2024
c2e39b3
Merge branch 'main' into msaroufim/manylinux
cpuhrsch Apr 16, 2024
77c6645
lazy
msaroufim Apr 16, 2024
685a4f7
Merge branch 'main' into msaroufim/manylinux
msaroufim Apr 16, 2024
6a1c04c
use conda and cuda toolkit
msaroufim Apr 16, 2024
7ee1e92
merge
msaroufim Apr 16, 2024
b4c54b6
fix
msaroufim Apr 16, 2024
7225bae
push
msaroufim Apr 16, 2024
8d019fa
update
msaroufim Apr 16, 2024
e9ccf97
1 more try
msaroufim Apr 16, 2024
ee0d473
push
msaroufim Apr 16, 2024
41ee0ca
bla
msaroufim Apr 16, 2024
8935acd
bla
msaroufim Apr 16, 2024
36350b8
conda init
msaroufim Apr 16, 2024
794e365
clowntown
msaroufim Apr 16, 2024
de11d56
this works locally
msaroufim Apr 17, 2024
54eb2d1
update
msaroufim Apr 17, 2024
bd0347a
remove cache
msaroufim Apr 17, 2024
0c51485
push
msaroufim Apr 17, 2024
2ee80d7
alternate cuda install command
msaroufim Apr 17, 2024
1b3bc78
bla
msaroufim Apr 17, 2024
9f0dc35
bla
msaroufim Apr 17, 2024
b71570f
update
msaroufim Apr 17, 2024
1868f7d
push
msaroufim Apr 17, 2024
d94aed3
bla
msaroufim Apr 17, 2024
bd207d4
bla
msaroufim Apr 17, 2024
d3169bd
bla
msaroufim Apr 17, 2024
b9f8631
bla
msaroufim Apr 17, 2024
e38e69e
bla
msaroufim Apr 17, 2024
065b93e
bla
msaroufim Apr 17, 2024
83cb732
yolo
msaroufim Apr 17, 2024
857845c
yolo
msaroufim Apr 17, 2024
c0077d7
yolo
msaroufim Apr 17, 2024
1dfd56d
yolo
msaroufim Apr 17, 2024
fbc32f0
update
msaroufim Apr 25, 2024
efb3890
Merge branch 'main' into msaroufim/manylinux
msaroufim Apr 25, 2024
5f015c7
yolo
msaroufim Apr 25, 2024
0b6abdd
yolo
msaroufim Apr 25, 2024
0d516e5
yolo
msaroufim Apr 25, 2024
756a294
yolo
msaroufim Apr 25, 2024
0ba0006
yolo
msaroufim Apr 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ jobs:
pip install ${{ matrix.torch-spec }}
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install .
python setup.py install
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to update README? for building from source, is it python setup.py developer now? I spent some time to understand why pip install -e . does not work anymore and find this PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argh, lemme update now

pytest test --verbose -s
5 changes: 3 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ expecttest
parameterized
packaging
transformers
bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
matplotlib # needed for triton benchmarking
pandas # also for triton benchmarking
transformers #for galore testing
transformers #for galore testing
ninja
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel", "ninja", "torch"]
build-backend = "setuptools.build_meta"
58 changes: 57 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# LICENSE file in the root directory of this source tree.

import os
import glob
from datetime import datetime

from setuptools import find_packages, setup

current_date = datetime.now().strftime("%Y.%m.%d")


def read_requirements(file_path):
with open(file_path, "r") as file:
return file.read().splitlines()
Expand All @@ -22,6 +22,60 @@ def read_requirements(file_path):
# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else "0.1"

import torch

from torch.utils.cpp_extension import (
CppExtension,
CUDAExtension,
BuildExtension,
CUDA_HOME,
)


def get_extensions():
debug_mode = os.getenv('DEBUG', '0') == '1'
if debug_mode:
print("Compiling in debug mode")

# TODO: And cudatoolkit is available
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
]
}
if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))

if use_cuda:
sources += cuda_sources

ext_modules = [
extension(
"torchao._C",
sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]

return ext_modules

setup(
name=package_name,
Expand All @@ -31,10 +85,12 @@ def read_requirements(file_path):
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions(),
install_requires=read_requirements("requirements.txt"),
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch-labs/ao",
cmdclass={"build_ext": BuildExtension},
)
46 changes: 46 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.optests import opcheck
import torchao
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
import unittest


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
class TestOps(TestCase):
def _create_tensors_with_iou(self, N, iou_thresh):
# force last box to have a pre-defined iou with the first box
# let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
# then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
# we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
# Adjust the threshold upward a bit with the intent of creating
# at least one box that exceeds (barely) the threshold and so
# should be suppressed.
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
iou_thresh += 1e-5
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower")
def test_nms(self):
iou = 0.2
boxes, scores = self._create_tensors_with_iou(1000, iou)
boxes = boxes.cuda()
scores = scores.cuda()

# smoke test
_ = torchao.ops.nms(boxes, scores, iou)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils)


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
autoquant,
)
from . import dtypes
import torch
from . import _C
from . import ops

__all__ = [
"dtypes",
Expand Down
181 changes: 181 additions & 0 deletions torchao/csrc/cuda/nms.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

namespace torchao {

namespace {

#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \
for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
i += (blockDim.x * gridDim.x))

#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int)

template <typename integer>
constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

int const threadsPerBlock = sizeof(unsigned long long) * 8;

template <typename T>
__device__ inline bool devIoU(
T const* const a,
T const* const b,
const float threshold) {
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
acc_T interS = (acc_T)width * height;
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold;
}

template <typename T>
__global__ void nms_kernel_impl(
int n_boxes,
double iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;

if (row_start > col_start)
return;

const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

__shared__ T block_boxes[threadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();

if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) {
t |= 1ULL << i;
}
}
const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}

at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");

TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0))

at::cuda::CUDAGuard device_guard(dets.device());

if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}

auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();

int dets_num = dets.size(0);

const int col_blocks = ceil_div(dets_num, threadsPerBlock);

at::Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.scalar_type(), "nms_kernel", [&] {
nms_kernel_impl<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();

int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;

if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}

AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}

} // namespace

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::nms", &nms_kernel);
}

} // namespace torchao
3 changes: 3 additions & 0 deletions torchao/csrc/init.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
8 changes: 8 additions & 0 deletions torchao/csrc/nms.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>

TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
}
23 changes: 23 additions & 0 deletions torchao/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
from torch import Tensor

def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""
See https://pytorch.org/vision/main/generated/torchvision.ops.nms.html
"""
return torch.ops.torchao.nms.default(boxes, scores, iou_threshold)


# Defines the meta kernel / fake kernel / abstract impl
@torch.library.impl_abstract("torchao::nms")
def _(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
torch._check(
dets.size(0) == scores.size(0),
lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
)
ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long)
Loading