Skip to content

Commit 355436f

Browse files

Committed (author name lost in page extraction)

Commit message: add support for torch.Tensor.size

1 parent 6793994 · commit 355436f

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

vllm/model_executor/model_optimizer/fused_op_generator_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def arg_schema_type(n: torch.fx.node.Argument,
7272
ty = n.type.__name__
7373
elif n.meta.get('type') and n.meta.get('type').__name__ != 'FakeTensor':
7474
ty = n.meta.get('type').__name__
75-
print(f"meta type {ty}")
7675
if ty == 'Size':
7776
return 'std::vector<int64_t>' if add_prefix else 'int[]'
7877
else:
@@ -84,7 +83,6 @@ def arg_schema_type(n: torch.fx.node.Argument,
8483
if add_prefix and ty in builtin_types:
8584
return builtin_types[ty]
8685

87-
print(f"arg_schema_type {ty}")
8886
if ty == "SymInt" and add_prefix:
8987
return "int64_t"
9088

vllm/model_executor/model_optimizer/naive_fused_op_generator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def make_fused_op(
320320
f"{arg_schema_type(inp, True)}" for inp in inputs.values()
321321
]
322322
logger.debug("fused op argument types: %s", arg_types)
323-
print(f"fused op argument types: {str(arg_types)}")
324323
for i, name in enumerate(inputs.keys()):
325324
# Don't use const refs here so inputs can be deleted when no
326325
# longer needed.
@@ -370,6 +369,11 @@ def make_fused_op(
370369

371370
for n, fn in zip(nodes, fn_names):
372371
return_type = extract_node_type(n)
372+
373+
# Total hack
374+
if n.op == 'call_method':
375+
return_type = "Size"
376+
373377
input_types = [argument_type_str(inp) for inp in n.args]
374378
comment_str = f" // ({', '.join(input_types)}) -> {return_type}"
375379

@@ -388,7 +392,10 @@ def make_fused_op(
388392
f"{self.sanitize(n.args[0].name, '::')}.")
389393
first_arg = 1
390394

391-
if node_function_target(n).startswith("torch.ops._C"):
395+
# First check is total hack here
396+
if fn == 'size':
397+
call_str = call_str + "sizes("
398+
elif node_function_target(n).startswith("torch.ops._C"):
392399
call_str = call_str + f"{self.sanitize(fn, '::')}.call("
393400
else:
394401
call_str = call_str + f"{self.sanitize(fn, '::')}("

vllm/model_executor/model_optimizer/register.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def register_defaults():
103103
logger.debug("REGISTER DEFAULTS")
104104
# Note: methods need to be supported via function object and not name.
105105
register_fusable(torch.Tensor.to)
106+
register_fusable(torch.Tensor.size, is_trivial=True)
106107
register_fusable(torch.Tensor.transpose, is_trivial=True)
107108
register_fusable(torch.Tensor.numel, is_trivial=True)
108109
register_fusable('_operator.add')

0 commit comments

Comments (0)