Skip to content

Commit 0396bff

Browse files
committed
Address comments
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent e1aef1d commit 0396bff

File tree

4 files changed

+12
-20
lines changed

4 files changed

+12
-20
lines changed

tests/compile/test_decorator.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
257257
return x
258258

259259
@support_torch_compile(no_weak_ref_output=True)
260+
@support_torch_compile(no_weak_ref_output=False)
260261
class B(A):
261262
...
262263

@@ -283,12 +284,8 @@ class C(B):
283284
):
284285
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
285286

286-
with compilation_counter.expect(
287-
num_weakref_output_graphs=1,
288-
# This is 1 instead of 0 because B inherits from A
289-
# and A's __init__ is called which initializes the VllmBackend
290-
# If no_weak_ref_output=False, this value would be 2
291-
) and set_current_vllm_config(vllm_config):
287+
with compilation_counter.expect(num_weakref_output_graphs=0,
288+
) and set_current_vllm_config(vllm_config):
292289
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
293290

294291
# B also has support_torch_compile
@@ -301,12 +298,8 @@ class C(B):
301298
):
302299
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
303300

304-
with compilation_counter.expect(
305-
num_weakref_output_graphs=2,
306-
# C inherits from B which inherits from A
307-
# both B and A's __init__ are called, incrementing the count by 2
308-
# as A has no_weak_ref_output=False
309-
) and set_current_vllm_config(vllm_config):
301+
with compilation_counter.expect(num_weakref_output_graphs=1,
302+
) and set_current_vllm_config(vllm_config):
310303
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
311304

312305
# C has support_torch_compile

vllm/compilation/backends.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -454,9 +454,6 @@ def __init__(
454454
self.compilation_config)
455455

456456
self.no_weak_ref_output = no_weak_ref_output
457-
if not self.no_weak_ref_output:
458-
# used for testing purposes
459-
compilation_counter.num_weakref_output_graphs += 1
460457

461458
# `torch.compile` is JIT compiled, so we don't need to
462459
# do anything here

vllm/compilation/cuda_graph.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,8 @@ def __call__(self, *args, **kwargs):
167167
# any other cuda graph.
168168
output = weak_ref_tensors(output)
169169

170+
compilation_counter.num_weakref_output_graphs += 1
171+
170172
# here we always use weak ref for the output
171173
# to save memory
172174
entry.output = weak_ref_tensors(output)

vllm/compilation/decorators.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,11 @@ def _support_torch_compile(
219219
"""
220220
A decorator to add support for compiling the forward method of a class.
221221
"""
222+
setattr(cls, IGNORE_COMPILE_KEY, False)
223+
224+
# setting as attribute on cls ensures child class will override parent class
225+
setattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, no_weak_ref_output)
226+
222227
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
223228
# support decorating multiple times
224229
return cls
@@ -230,11 +235,6 @@ def _support_torch_compile(
230235

231236
old_init = cls.__init__
232237

233-
setattr(cls, IGNORE_COMPILE_KEY, False)
234-
235-
# setting as attribute on cls ensures child class will override parent class
236-
setattr(cls, LAST_PIECEWISE_GRAPH_WEAKREF_KEY, no_weak_ref_output)
237-
238238
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
239239
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
240240
self.vllm_config = vllm_config

0 commit comments

Comments
 (0)