diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 7725ebaad..f3187b9e5 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -355,13 +355,25 @@ ) # isinstance +Dispatcher.register( + isinstance, + ("TensorVariable", "VariableBase"), + {}, + lambda left, right: ConstantVariable.wrap_literal( + isinstance(paddle.to_tensor(0), right.get_py_value(allow_tensor=True)), + left.graph, + ), +) + Dispatcher.register( isinstance, ("VariableBase", "VariableBase"), {}, lambda left, right: ConstantVariable.wrap_literal( - left.get_py_type() == right.get_py_value() - or left.get_py_type() in right.get_py_value(), + isinstance( + left.get_py_value(allow_tensor=True), + right.get_py_value(allow_tensor=True), + ), left.graph, ), )