feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Merged: 59 commits, Dec 12, 2024
Changes from 1 commit

Commits (59)
40349a8
support weight-stripped engine and REFIT_IDENTICAL flag
zewenli98 Sep 19, 2024
5d7c677
refactor with new design
zewenli98 Sep 20, 2024
82b7ddc
lint
zewenli98 Oct 1, 2024
9f6a771
small fix
zewenli98 Oct 1, 2024
7ea3c0f
remove make_refittable
zewenli98 Oct 1, 2024
bf7553b
fast refit -> slow refit
zewenli98 Oct 2, 2024
46e9bc8
fix np.bool_, group_norm
zewenli98 Oct 2, 2024
d783fdd
add immutable_weights
zewenli98 Oct 2, 2024
160588e
skip engine caching for non-refittable engines, slow refit -> fast refit
zewenli98 Oct 2, 2024
493f981
refactored, there are 3 types of engines
zewenli98 Oct 5, 2024
f204104
fix and add tests
zewenli98 Oct 5, 2024
4663c83
fix issues #3206 #3217
zewenli98 Oct 8, 2024
c57ab06
small fix
zewenli98 Oct 15, 2024
402c9b0
resolve comments
zewenli98 Oct 15, 2024
d8e59da
WIP: cache weight-stripped engine
zewenli98 Oct 22, 2024
e8811fd
Merge branch 'main' into weight_stripped_engine
zewenli98 Oct 31, 2024
f2e3f00
redesigned hash func and add constant mapping to fast refit
zewenli98 Nov 4, 2024
31af308
refactor and add tests
zewenli98 Nov 6, 2024
1ae33f4
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 6, 2024
90bf679
update
zewenli98 Nov 6, 2024
a8a34f6
increase ENGINE_CACHE_SIZE
zewenli98 Nov 6, 2024
285bc90
skip some tests
zewenli98 Nov 7, 2024
2d152cf
fix tests
zewenli98 Nov 7, 2024
d461608
try fixing cumsum
zewenli98 Nov 8, 2024
d57b885
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 8, 2024
23d68d5
fix windows cross compile, TODO: whether windows support stripping en…
zewenli98 Nov 8, 2024
a928f67
CI debug test 1
zewenli98 Nov 13, 2024
02625ca
CI debug test 2
zewenli98 Nov 14, 2024
c462e40
CI debug test 3
zewenli98 Nov 16, 2024
9ba33b5
Merge branch 'main' into weight_stripped_engine
Nov 19, 2024
3d68039
reduce -n to 4 for converter tests on CI
zewenli98 Nov 20, 2024
2e7ef3b
reduce -n to 4 for converter tests on CI
zewenli98 Nov 20, 2024
9ff165c
simplify test_different_args_dont_share_cached_engine
zewenli98 Nov 22, 2024
8ca8e2d
reduce -n to 2
zewenli98 Nov 22, 2024
f9f2a70
reduce -n to 1
zewenli98 Nov 22, 2024
c69c61a
revert -n back to 4 and chunk converter
zewenli98 Nov 23, 2024
05b560d
change to opt-in feature
zewenli98 Nov 28, 2024
7feea97
fix conflict
zewenli98 Nov 28, 2024
d1521c3
fix typo
zewenli98 Nov 28, 2024
5a193a2
Merge branch 'main' into weight_stripped_engine
Nov 29, 2024
0b345be
small fix
zewenli98 Dec 3, 2024
6754481
Merge branch 'main' into weight_stripped_engine
zewenli98 Dec 6, 2024
4a7e957
update to manylinux2_28-builder
zewenli98 Dec 10, 2024
6e840ba
remove cuda12.6 tests
zewenli98 Dec 10, 2024
9a8473a
remove one_user_validator for native_layer_norm
zewenli98 Dec 10, 2024
6a07767
clear tests
zewenli98 Dec 10, 2024
ed3424a
remove the whole chunk
zewenli98 Dec 10, 2024
ef54239
add cuda12.6 back and export D_GLIBCXX_USE_CXX11_ABI=1
zewenli98 Dec 10, 2024
f166562
fix env
zewenli98 Dec 10, 2024
80aae71
fix container
zewenli98 Dec 10, 2024
676c9ce
fix env
zewenli98 Dec 11, 2024
bf2edc6
fix env
zewenli98 Dec 11, 2024
627d510
fix env
zewenli98 Dec 11, 2024
b393b6f
fix env
zewenli98 Dec 11, 2024
78d72b6
fix env
zewenli98 Dec 11, 2024
a5d3c18
export USE_CXX11_ABI=1 for cuda12.6
zewenli98 Dec 11, 2024
4f02da8
remove chunk
zewenli98 Dec 11, 2024
7d7423a
resolve comments
zewenli98 Dec 12, 2024
9f76304
Merge branch 'main' into weight_stripped_engine
zewenli98 Dec 12, 2024
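The feature this PR lands pairs two ideas: serializing an engine without its weights, and refitting the weights back at load time (optionally requiring byte-identical weights via TensorRT's REFIT_IDENTICAL flag). A rough, library-agnostic sketch of that strip-and-refit roundtrip (every name below is hypothetical, invented for illustration; this is not the Torch-TensorRT or TensorRT API):

```python
# Toy illustration only: a "compiled engine" is modeled as a dict mapping
# layer name -> weight list. Stripping keeps only the structural skeleton
# (names and sizes); refitting restores runnable weights into it.

def strip_weights(engine):
    """Return a weight-stripped skeleton plus the weights kept separately."""
    skeleton = {name: len(w) for name, w in engine.items()}  # sizes, no values
    weights = {name: list(w) for name, w in engine.items()}
    return skeleton, weights

def refit(skeleton, weights):
    """Rebuild a runnable engine by fitting weights back into the skeleton."""
    for name, size in skeleton.items():
        # REFIT_IDENTICAL-style check: shapes must match the original exactly
        assert len(weights[name]) == size, f"shape mismatch for {name}"
    return {name: list(weights[name]) for name in skeleton}

engine = {"conv1": [0.1, 0.2], "fc": [0.3]}
skeleton, weights = strip_weights(engine)
refitted = refit(skeleton, weights)
assert refitted == engine  # roundtrip recovers the original engine
```

The stripped skeleton is what would be cached or shipped; it is small because the bulk of a real engine's size is its weights.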
clear tests
zewenli98 committed Dec 10, 2024
commit 6a077675113cacb1a1030101ff35ac8cfce677ed
41 changes: 9 additions & 32 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -39,19 +39,13 @@ def test_three_ways_to_compile(self):
         )
         gm1_output = gm1(*example_inputs)

-        # 2. Compile with torch_trt.compile using dynamo backend
-        gm2 = torch_trt.compile(
-            pyt_model, ir="dynamo", inputs=example_inputs, **settings
-        )
-        gm2_output = gm2(*example_inputs)
-
-        # 3. Compile with torch.compile using tensorrt backend
-        gm3 = torch.compile(
+        # 2. Compile with torch.compile using tensorrt backend
+        gm2 = torch.compile(
             pyt_model,
             backend="tensorrt",
             options=settings,
         )
-        gm3_output = gm3(*example_inputs)
+        gm2_output = gm2(*example_inputs)

         pyt_model_output = pyt_model(*example_inputs)

@@ -63,14 +57,9 @@ def test_three_ways_to_compile(self):
             gm1_output, gm2_output, 1e-2, 1e-2
         ), "gm2_output is not correct"

-        assert torch.allclose(
-            gm2_output, gm3_output, 1e-2, 1e-2
-        ), "gm3_output is not correct"
-
     def test_three_ways_to_compile_weight_stripped_engine(self):
         pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
         example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
-        exp_program = torch.export.export(pyt_model, example_inputs)

         settings = {
             "use_python_runtime": False,

@@ -82,36 +71,24 @@ def test_three_ways_to_compile_weight_stripped_engine(self):
             "refit_identical_engine_weights": False,
         }

-        # 1. Compile with torch_trt.dynamo.compile
-        gm1 = torch_trt.dynamo.compile(
-            exp_program,
-            example_inputs,
-            **settings,
-        )
-        gm1_output = gm1(*example_inputs)
-
-        # 2. Compile with torch_trt.compile using dynamo backend
-        gm2 = torch_trt.compile(
+        # 1. Compile with torch_trt.compile using dynamo backend
+        gm1 = torch_trt.compile(
             pyt_model, ir="dynamo", inputs=example_inputs, **settings
         )
-        gm2_output = gm2(*example_inputs)
+        gm1_output = gm1(*example_inputs)

-        # 3. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True
-        # gm3 = torch.compile(
+        # 2. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True
+        # gm2 = torch.compile(
         #     pyt_model,
         #     backend="tensorrt",
         #     options=settings,
         # )
-        # gm3_output = gm3(*example_inputs)
+        # gm2_output = gm2(*example_inputs)

         assertions.assertEqual(
             gm1_output.sum(), 0, msg="gm1_output should be all zeros"
         )

-        assertions.assertEqual(
-            gm2_output.sum(), 0, msg="gm2_output should be all zeros"
-        )
-
     def test_weight_stripped_engine_sizes(self):
         pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
         example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
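Several commits in this PR ("skip engine caching for non-refittable engines, slow refit -> fast refit", "redesigned hash func and add constant mapping to fast refit") concern caching weight-stripped engines. The enabling idea is that the cache key must hash the graph's structure but not its weight values, so two models that differ only in weights can share one cached stripped engine. A toy sketch under that assumption (the data layout and function name are invented for illustration, not the actual Torch-TensorRT hash function):

```python
import hashlib

def structure_hash(graph):
    """Cache key over op names and shapes only; weight values are ignored.

    `graph` is a hypothetical list of (op_name, shape, weights) tuples.
    """
    h = hashlib.sha256()
    for op_name, shape, _weights in graph:  # deliberately skip the weights
        h.update(op_name.encode())
        h.update(repr(shape).encode())
    return h.hexdigest()

g1 = [("conv", (3, 3), [0.1, 0.2]), ("relu", (), [])]
g2 = [("conv", (3, 3), [9.9, 8.8]), ("relu", (), [])]  # same graph, new weights
g3 = [("conv", (5, 5), [0.1, 0.2])]                    # different structure

assert structure_hash(g1) == structure_hash(g2)  # cache hit: refit suffices
assert structure_hash(g1) != structure_hash(g3)  # cache miss: recompile
```

A cache hit means the stored weight-stripped engine can be reused and only a (fast) refit with the new weights is needed, which is the behavior the tests in this file exercise.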