[dynamo 3.11] fix jump if (not) none #96505

Closed · wants to merge 10 commits
24 changes: 24 additions & 0 deletions test/dynamo/test_misc.py
@@ -4723,6 +4723,30 @@ def fn(x):
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_if_tensor_is_none(self):
        """
        Python 3.11 adds new jump instructions that check whether TOS is
        None. Dynamo does not support these instructions directly, so they
        are rewritten away during bytecode cleaning.
        """

        def f(x, y):
            z = 1
            if x is None:
                z *= 2
            if y is not None:
                z *= 3
            return z

        # TODO remove condition once 3.11 is fully supported
        if sys.version_info < (3, 11):
            opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
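            # x is None and y is not None here, so z = 1 * 2 * 3 = 6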
            self.assertEqual(opt_f(None, torch.ones(2)), 6)

        if sys.version_info >= (3, 11):
            insts = bytecode_transformation.cleaned_instructions(f.__code__)
            for inst in insts:
                self.assertNotIn("_NONE", inst.opname)

    @unittest.skipIf(sys.version_info < (3, 11), "requires Python 3.11+")
    def test_py311_jump_offset(self):
        new_inst = bytecode_transformation.create_instruction
26 changes: 26 additions & 0 deletions torch/_dynamo/bytecode_transformation.py
@@ -349,6 +349,28 @@ def remove_load_call_method(instructions: List[Instruction]):
    return instructions


def remove_jump_if_none(instructions: List[Instruction]):
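    """
    Replace the Python 3.11 POP_JUMP_{FORWARD,BACKWARD}_IF_{NOT_}NONE
    instructions with an equivalent sequence that the rest of dynamo
    already handles, e.g.

        POP_JUMP_FORWARD_IF_NONE target

    becomes

        LOAD_CONST None
        IS_OP 0
        POP_JUMP_FORWARD_IF_TRUE target
    """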
    new_insts = []
    for inst in instructions:
        new_insts.append(inst)
        if "_NONE" in inst.opname:
            is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
            is_op.argval = is_op.arg
            jump_op = create_instruction(
                "POP_JUMP_FORWARD_IF_TRUE"
                if "FORWARD" in inst.opname
                else "POP_JUMP_BACKWARD_IF_TRUE",
                target=inst.target,
            )
            # rewrite inst in-place to LOAD_CONST None so that other
            # instructions whose jump target is inst still land at the
            # start of the replacement sequence
            inst.opcode = dis.opmap["LOAD_CONST"]
            inst.opname = "LOAD_CONST"
            inst.arg = None
            inst.argval = None
            new_insts.extend([is_op, jump_op])
    instructions[:] = new_insts


def explicit_super(code: types.CodeType, instructions: List[Instruction]):
"""convert super() with no args into explicit arg form"""
cell_and_free = (code.co_cellvars or tuple()) + (code.co_freevars or tuple())
@@ -642,6 +664,10 @@ def cleaned_instructions(code, safe=False):
    if not safe:
        if sys.version_info < (3, 11):
            remove_load_call_method(instructions)
        else:
            remove_jump_if_none(instructions)
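            # remove_jump_if_none lengthens the instruction list, so offsets
            # and jump targets must be recomputed afterwards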
            update_offsets(instructions)
            devirtualize_jumps(instructions)
        explicit_super(code, instructions)
    return instructions

13 changes: 0 additions & 13 deletions torch/_dynamo/symbolic_convert.py
@@ -408,14 +408,6 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
    return decorator


def is_none(x):
    return x is None


def is_not_none(x):
    return x is not None


class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
@@ -1581,11 +1573,6 @@ def SWAP(self, inst):
    POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)

    POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_FORWARD_IF_NONE = generic_jump(is_none, False)
    POP_JUMP_BACKWARD_IF_NONE = generic_jump(is_none, False)

    def CACHE(self, inst):
        pass
