diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 2feeefc4ef9..1ebe9b2224d 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -340,10 +340,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         exir.print_program.pretty_print(program)
 
         deboxed_int_list = []
-        for item in program.execution_plan[0].values[5].val.items:  # pyre-ignore[16]
-            deboxed_int_list.append(
-                program.execution_plan[0].values[item].val.int_val  # pyre-ignore[16]
-            )
+        for item in program.execution_plan[0].values[5].val.items:
+            deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val)
 
         self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))
 
@@ -459,11 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Check the mul operator's stack trace contains f -> g -> h
         self.assertTrue(
             "return torch.mul(x, torch.randn(3, 2))"
-            in program.execution_plan[0]  # pyre-ignore[16]
-            .chains[0]
-            .stacktrace[1]
-            .items[-1]
-            .context
+            in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
         )
         self.assertEqual(
             program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
         )
@@ -616,11 +610,7 @@ def false_fn(y: torch.Tensor) -> torch.Tensor:
             if not isinstance(inst.instr_args, KernelCall):
                 continue
 
-            op = (
-                program.execution_plan[0]
-                .operators[inst.instr_args.op_index]  # pyre-ignore[16]
-                .name
-            )
+            op = program.execution_plan[0].operators[inst.instr_args.op_index].name
 
             if "mm" in op:
                 num_mm += 1
@@ -657,19 +647,13 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # generate the tensor on which this iteration will operate on.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[0]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
             ].name,
             "aten::sym_size",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[1].instr_args.op_index
             ].name,
             "aten::select_copy",
         )
@@ -681,28 +665,19 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # We check here that both of these have been generated.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-5]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index
             ].name,
             "executorch_prim::et_copy_index",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-4]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index
             ].name,
             "executorch_prim::add",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-3]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index
             ].name,
             "executorch_prim::eq",
         )
@@ -716,10 +691,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index
             ].name,
             "executorch_prim::sub",
         )
@@ -1300,9 +1272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # this triggers the actual emission of the graph
         program = program_mul._emitter_output.program
         node = None
-        program.execution_plan[0].chains[0].instructions[  # pyre-ignore[16]
-            0
-        ].instr_args.op_index
+        program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
 
         # Find the multiplication node in the graph that was emitted.
         for node in program_mul.exported_program().graph.nodes:
@@ -1314,7 +1284,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Find the multiplication instruction in the program that was emitted.
         for idx in range(len(program.execution_plan[0].chains[0].instructions)):
             instruction = program.execution_plan[0].chains[0].instructions[idx]
-            op_index = instruction.instr_args.op_index  # pyre-ignore[16]
+            op_index = instruction.instr_args.op_index
             if "mul" in program.execution_plan[0].operators[op_index].name:
                 break
 
@@ -1453,9 +1423,7 @@ def forward(self, x, y):
         exec_prog._emitter_output.program
         self.assertIsNotNone(exec_prog.delegate_map)
         self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
-        self.assertIsNotNone(
-            exec_prog.delegate_map.get("forward").get(0)  # pyre-ignore[16]
-        )
+        self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0))
         self.assertEqual(
             exec_prog.delegate_map.get("forward").get(0).get("name"),
             "BackendWithCompilerExample",
@@ -1568,9 +1536,7 @@ def forward(self, x):
         model = model.to_executorch()
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)
@@ -1611,9 +1577,7 @@ def forward(self, x):
         )
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)
diff --git a/exir/program/TARGETS b/exir/program/TARGETS
index fc73abf1ff7..674d7baa35e 100644
--- a/exir/program/TARGETS
+++ b/exir/program/TARGETS
@@ -1,4 +1,5 @@
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
@@ -43,7 +44,7 @@ python_library(
"//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/verification:verifier", - ], + ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) ) python_library( diff --git a/exir/program/_program.py b/exir/program/_program.py index b136d6cead9..cbfa1105280 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -75,8 +75,24 @@ Val = Any +from typing import Any, Callable + from torch.library import Library +try: + from executorch.exir.program.fb.logger import et_logger +except ImportError: + # Define a stub decorator that does nothing + def et_logger(api_name: str) -> Callable[[Any], Any]: + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + # This is the reserved namespace that is used to register ops to that will # be prevented from being decomposed during to_edge_transform_and_lower. edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP" @@ -957,6 +973,7 @@ def _gen_edge_manager_for_partitioners( return edge_manager +@et_logger("to_edge_transform_and_lower") def to_edge_transform_and_lower( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ @@ -1110,6 +1127,7 @@ def to_edge_with_preserved_ops( ) +@et_logger("to_edge") def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, @@ -1204,8 +1222,10 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: """ Returns the ExportedProgram specified by 'method_name'. """ + return self._edge_programs[method_name] + @et_logger("transform") def transform( self, passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]], @@ -1253,6 +1273,7 @@ def transform( new_programs, copy.deepcopy(self._config_methods), compile_config ) + @et_logger("to_backend") def to_backend( self, partitioner: Union[Partitioner, Dict[str, Partitioner]] ) -> "EdgeProgramManager": @@ -1296,6 +1317,7 @@ def to_backend( new_edge_programs, copy.deepcopy(self._config_methods), config ) + @et_logger("to_executorch") def to_executorch( self, config: Optional[ExecutorchBackendConfig] = None, diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py index 2413e2b4980..f3b6f0ed557 100644 --- a/exir/tests/test_joint_graph.py +++ b/exir/tests/test_joint_graph.py @@ -73,25 +73,21 @@ def forward(self, x, y): # assert that the weight and bias have proper data_buffer_idx and allocation_info self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore - .values[0] - .val.data_buffer_idx, + et.executorch_program.execution_plan[0].values[0].val.data_buffer_idx, 1, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore - .values[1] - .val.data_buffer_idx, + et.executorch_program.execution_plan[0].values[1].val.data_buffer_idx, 2, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore + et.executorch_program.execution_plan[0] .values[0] .val.allocation_info.memory_offset_low, 0, ) self.assertEqual( - et.executorch_program.execution_plan[0] # pyre-ignore + et.executorch_program.execution_plan[0] .values[1] .val.allocation_info.memory_offset_low, 48, @@ -106,7 +102,7 @@ def forward(self, x, y): self.assertTrue(torch.allclose(loss, et_outputs[0])) self.assertTrue( - torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore[6] + 
         )
         self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
         self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
@@ -118,23 +114,17 @@ def forward(self, x, y):
 
         # gradient outputs start at index 1
         self.assertEqual(
-            et.executorch_program.execution_plan[1]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[1].values[0].val.int_val,
             1,
         )
 
         self.assertEqual(
-            et.executorch_program.execution_plan[2]  # pyre-ignore
-            .values[0]
-            .val.string_val,
+            et.executorch_program.execution_plan[2].values[0].val.string_val,
             "linear.weight",
         )
 
         # parameter outputs start at index 3
         self.assertEqual(
-            et.executorch_program.execution_plan[3]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[3].values[0].val.int_val,
             3,
         )
diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py
index 0925a8abc89..318dc085b45 100644
--- a/exir/tests/test_remove_view_copy.py
+++ b/exir/tests/test_remove_view_copy.py
@@ -196,24 +196,14 @@ def test_spec(self) -> None:
         instructions = plan.chains[0].instructions
         self.assertEqual(len(instructions), 7)
 
+        self.assertEqual(instructions[0].instr_args.op_index, 0)  # view @ idx2
+        self.assertEqual(instructions[1].instr_args.op_index, 0)  # view @ idx3
+        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx6
+        self.assertEqual(instructions[3].instr_args.op_index, 0)  # view @ idx7
+        self.assertEqual(instructions[4].instr_args.op_index, 1)  # aten:mul @ idx9
         self.assertEqual(
-            instructions[0].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx2
-        self.assertEqual(
-            instructions[1].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx3
-        self.assertEqual(
-            instructions[2].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx6
-        self.assertEqual(
-            instructions[3].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx7
-        self.assertEqual(
-            instructions[4].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx9
-        self.assertEqual(
-            instructions[5].instr_args.op_index, 2  # pyre-ignore
+            instructions[5].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
         self.assertEqual(
-            instructions[6].instr_args.op_index, 2  # pyre-ignore
+            instructions[6].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
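Note on the OSS fallback: the `try`/`except ImportError` block added to `_program.py` makes `et_logger` a pass-through decorator whenever the internal `executorch.exir.program.fb.logger` module is unavailable, so the decorated APIs (`to_edge`, `transform`, `to_backend`, `to_executorch`, `to_edge_transform_and_lower`) behave identically in OSS builds. A minimal sketch of that behavior follows; the stub is copied from the diff above, while the `Demo` class is a hypothetical stand-in for `EdgeProgramManager`:

```python
from typing import Any, Callable

# Fallback copied verbatim from the diff: when the internal fb logger is
# missing, et_logger degrades to a decorator factory whose wrapper simply
# forwards self, positional, and keyword arguments to the wrapped method.
def et_logger(api_name: str) -> Callable[[Any], Any]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


class Demo:  # hypothetical stand-in for EdgeProgramManager
    @et_logger("transform")  # same shape as the decorated methods in _program.py
    def transform(self, x: int) -> int:
        return x + 1


assert Demo().transform(1) == 2  # behavior is unchanged under the stub
```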