diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 15e0b23d36f..5c24f08c732 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -372,6 +372,21 @@ def _get_allocation_info(self, spec: TensorSpec) -> AllocationDetails: ) return allocation_info + def _save_to_external_constant_map( + self, + fqn: str, + buffer_idx: int, + constant_tag: str, + ) -> None: + """ + Saves external constant to the map. + """ + # buffer data should be in the external_constant_buffer already. + assert buffer_idx < len(self.program_state.external_constant_buffer) + if constant_tag not in self.program_state.external_constant_map: + self.program_state.external_constant_map[constant_tag] = {} + self.program_state.external_constant_map[constant_tag][fqn] = buffer_idx + def _save_new_const_tensor( self, spec: TensorSpec, @@ -403,11 +418,9 @@ def _save_new_const_tensor( buffer_idx = len(self.program_state.external_constant_buffer) self.program_state.external_constant_hash[hashed] = buffer_idx self.program_state.external_constant_buffer.append(buffer_data) - if constant_tag not in self.program_state.external_constant_map: - self.program_state.external_constant_map[constant_tag] = {} - self.program_state.external_constant_map[constant_tag][ - spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. - ] = buffer_idx + self._save_to_external_constant_map( + spec.extra_tensor_info.fully_qualified_name, buffer_idx, constant_tag + ) # Tensor is mutable with initial state. Place into mutable segment elif allocation_info: buffer_idx = len(self.program_state.mutable_buffer) @@ -466,6 +479,19 @@ def _tensor_spec_to_evalue( and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL ): buffer_idx = self.program_state.external_constant_hash.get(hashed, -1) + if buffer_idx != -1: + # This constant already exists in the external_constant_buffer, + # And doesn't need to be duplicated.
However, the fqn is unique + and should be added. ie, we have the case: fqn0->data, fqn1->data. + When buffer_idx == 1, the data is new and added with + `_save_new_const_tensor` below. + assert spec.extra_tensor_info.fully_qualified_name is not None + assert constant_tag is not None + self._save_to_external_constant_map( + spec.extra_tensor_info.fully_qualified_name, + buffer_idx, + constant_tag, + ) else: buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index dcc3544875a..199a667ab64 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1719,6 +1719,83 @@ def forward(self, x): self.assertEqual(external_map["linear.weight"], 0) self.assertEqual(external_map["linear.bias"], 1) + def test_constant_tagged_tensor_dedup(self) -> None: + class ConstantModule(nn.Module): + def __init__(self): + super().__init__() + constant = torch.tensor([1.0, 2.0, 3.0]) + + # Register the same value with two different names as persistent buffers + self.register_buffer("c0", constant.clone(), persistent=True) + self.register_buffer("c1", constant.clone(), persistent=True) + + def forward(self, x): + return x + self.c0 + self.c1 + + model = to_edge( + export(ConstantModule(), (torch.ones(1, 3),), strict=True) + ).to_executorch( + config=ExecutorchBackendConfig( + external_constants=True, + ) + ) + emitter_output = model._emitter_output + # constant_buffer is empty besides the non-constant placeholder 0. + self.assertEqual(len(emitter_output.program.constant_buffer), 1) + # only one item in the external constant buffer. + self.assertEqual(len(emitter_output.external_constant_buffer), 1) + # Setting external_constants=True, saves all constants to the key + # '_default_external_constant'.
+ external_map = emitter_output.external_constant_map[ + "_default_external_constant" + ] + self.assertEqual(len(external_map), 2) + self.assertEqual(external_map["c0"], 0) + self.assertEqual(external_map["c1"], 0) + + def test_constant_tagged_tensor_dedup_2(self) -> None: + class ConstantModule(nn.Module): + def __init__(self): + super().__init__() + constant0_4 = torch.tensor([1.0, 2.0, 3.0]) + constant4_5 = torch.tensor([2.0, 3.0, 4.0]) + + # Register the same value with two different names as persistent buffers + self.register_buffer("c0", constant0_4.clone(), persistent=True) + self.register_buffer("c1", constant0_4.clone(), persistent=True) + self.register_buffer("c2", constant0_4.clone(), persistent=True) + self.register_buffer("c3", constant0_4.clone(), persistent=True) + self.register_buffer("c4", constant4_5.clone(), persistent=True) + self.register_buffer("c5", constant4_5.clone(), persistent=True) + + def forward(self, x): + return x + self.c0 + self.c1 + self.c2 + self.c3 + self.c4 + self.c5 + + model = to_edge( + export(ConstantModule(), (torch.ones(1, 3),), strict=True) + ).to_executorch( + config=ExecutorchBackendConfig( + external_constants=True, + ) + ) + emitter_output = model._emitter_output + # constant_buffer is empty besides the non-constant placeholder 0. + self.assertEqual(len(emitter_output.program.constant_buffer), 1) + # Two items in the external constant buffer. + self.assertEqual(len(emitter_output.external_constant_buffer), 2) + # Setting external_constants=True, saves all constants to the key + # '_default_external_constant'.
+ external_map = emitter_output.external_constant_map[ + "_default_external_constant" + ] + self.assertEqual(len(external_map), 6) + self.assertEqual(external_map["c0"], 0) + self.assertEqual(external_map["c1"], 0) + self.assertEqual(external_map["c2"], 0) + self.assertEqual(external_map["c3"], 0) + self.assertEqual(external_map["c4"], 1) + self.assertEqual(external_map["c5"], 1) + def test_delegate_deduplicate(self) -> None: class SharedModule(torch.nn.Module): def __init__(self):