From 72048d662799a70e9d3b3f988218e018b5d09843 Mon Sep 17 00:00:00 2001 From: Hoa Nguyen Date: Wed, 4 Sep 2024 17:06:57 -0700 Subject: [PATCH] Fix hash collisions among data points when using timing_characterization_client; fix #1525 Currently, there is a hash collision of two kArrayIndex data points when running timing_characterization_client. This is caused by the fact that the timing client relies on hashing the op to map a sample to its result, and the hash function did not take into account the element count of every operand. So, if there are two data points that only differ in element counts, the data point of only one of the two samples will be recorded in the output. This change fixes the problem by adding element counts as part of the hash. PiperOrigin-RevId: 671152475 --- .../delay_model/delay_model_utils.py | 20 ++- .../delay_model/delay_model_utils_test.py | 145 ++++++++++++++---- 2 files changed, 132 insertions(+), 33 deletions(-) diff --git a/xls/estimators/delay_model/delay_model_utils.py b/xls/estimators/delay_model/delay_model_utils.py index 5fd050e954..6e12b5a91e 100644 --- a/xls/estimators/delay_model/delay_model_utils.py +++ b/xls/estimators/delay_model/delay_model_utils.py @@ -15,6 +15,7 @@ from collections.abc import Mapping import dataclasses +import math from typing import Sequence from xls.estimators.delay_model import delay_model_pb2 @@ -42,10 +43,17 @@ def get_data_point_key(data_point: delay_model_pb2.DataPoint) -> str: def get_sample_spec_key(spec: SampleSpec) -> str: """Returns a key that can be used to represent a sample spec in a dict.""" + # This dictionary stores the element counts for array operands. For non-array + # operands, the element count is 0. 
+ array_operand_element_counts = { + e.operand_number: math.prod(e.element_counts) + for e in spec.point.operand_element_counts + } bit_count_strs = [] - for bit_count in spec.point.operand_widths: + for operand_idx, operand_width in enumerate(spec.point.operand_widths): operand = delay_model_pb2.Operation.Operand( - bit_count=bit_count, element_count=0 + bit_count=operand_width, + element_count=array_operand_element_counts.get(operand_idx, 0), ) bit_count_strs.append(str(operand)) key = ( @@ -71,4 +79,12 @@ def data_point_to_sample_spec( point.operand_widths.extend( [operand.bit_count for operand in data_point.operation.operands] ) + for operand_idx, operand in enumerate(data_point.operation.operands): + if operand.element_count: + point.operand_element_counts.append( + delay_model_pb2.OperandElementCounts( + operand_number=operand_idx, + element_counts=[operand.element_count], + ) + ) return SampleSpec(op_samples, point) diff --git a/xls/estimators/delay_model/delay_model_utils_test.py b/xls/estimators/delay_model/delay_model_utils_test.py index ad10797f0e..48a8ee8202 100644 --- a/xls/estimators/delay_model/delay_model_utils_test.py +++ b/xls/estimators/delay_model/delay_model_utils_test.py @@ -22,6 +22,7 @@ def _create_sample_spec( op_name: str, result_width: int, operand_widths: list[int], + element_counts: list[int], specialization: delay_model_pb2.SpecializationKind = delay_model_pb2.SpecializationKind.NO_SPECIALIZATION, ) -> delay_model_utils.SampleSpec: """Convenience function for tests to create SampleSpec protos.""" @@ -32,6 +33,12 @@ def _create_sample_spec( point = delay_model_pb2.Parameterization() point.result_width = result_width point.operand_widths.extend(operand_widths) + for operand_idx, element_count in enumerate(element_counts): + if element_count: + operand_element_counts = delay_model_pb2.OperandElementCounts() + operand_element_counts.operand_number = operand_idx + operand_element_counts.element_counts.append(element_count) + 
point.operand_element_counts.append(operand_element_counts) return delay_model_utils.SampleSpec(op_samples, point) @@ -39,6 +46,7 @@ def _create_data_point( op_name: str, result_width: int, operand_widths: list[int], + element_counts: list[int], specialization: delay_model_pb2.SpecializationKind = delay_model_pb2.SpecializationKind.NO_SPECIALIZATION, ) -> delay_model_pb2.DataPoint: """Convenience function for tests to create DataPoint protos.""" @@ -47,9 +55,10 @@ def _create_data_point( if specialization != delay_model_pb2.SpecializationKind.NO_SPECIALIZATION: dp.operation.specialization = specialization dp.operation.bit_count = result_width - for width in operand_widths: + for width, count in zip(operand_widths, element_counts): operand = delay_model_pb2.Operation.Operand() operand.bit_count = width + operand.element_count = count dp.operation.operands.append(operand) return dp @@ -59,46 +68,78 @@ class DelayModelUtilsTest(parameterized.TestCase): def test_sample_spec_key_equal_for_same_sample(self): self.assertEqual( delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 16, [8, 8]) + _create_sample_spec('kAnd', 16, [8, 8], [0, 0]) ), delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 16, [8, 8]) + _create_sample_spec('kAnd', 16, [8, 8], [0, 0]) ), ) def test_sample_spec_key_differs_for_op(self): self.assertNotEqual( delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 16, [8, 8]) + _create_sample_spec('kAnd', 16, [8, 8], [0, 0]) ), delay_model_utils.get_sample_spec_key( - _create_sample_spec('kOr', 16, [8, 8]) + _create_sample_spec('kOr', 16, [8, 8], [0, 0]) ), ) def test_sample_spec_key_differs_for_result_width(self): self.assertNotEqual( delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 16, [8, 8]) + _create_sample_spec('kAnd', 16, [8, 8], [0, 0]) ), delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 8, [8, 8]) + _create_sample_spec('kAnd', 8, [8, 8], [0, 0]) ), ) 
@parameterized.named_parameters( - ('different_operand_counts', [8, 8, 8], [8, 8]), - ('different_widths_same_count', [8, 8], [8, 4]), + ('different_operand_counts', [8, 8, 8], [0, 0, 0], [8, 8], [0, 0]), + ('different_widths_same_count', [8, 8], [0, 0], [8, 4], [0, 0]), ) def test_sample_spec_key_differs_for_operand_widths( - self, left_operand_widths, right_operand_widths + self, + left_operand_widths, + left_element_counts, + right_operand_widths, + right_element_counts, + ): + self.assertNotEqual( + delay_model_utils.get_sample_spec_key( + _create_sample_spec( + 'kAnd', 8, left_operand_widths, left_element_counts + ) + ), + delay_model_utils.get_sample_spec_key( + _create_sample_spec( + 'kAnd', 8, right_operand_widths, right_element_counts + ) + ), + ) + + @parameterized.named_parameters( + ('different_element_counts', [8, 8], [0, 0], [8, 8], [0, 16]), + ('different_widths_same_element_count', [8, 1], [0, 32], [8, 8], [0, 32]), + ) + def test_sample_spec_key_differs_for_element_counts( + self, + left_operand_widths, + left_element_counts, + right_operand_widths, + right_element_counts, ): self.assertNotEqual( delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 8, left_operand_widths) + _create_sample_spec( + 'kArrayIndex', 8, left_operand_widths, left_element_counts + ) ), delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 8, right_operand_widths) + _create_sample_spec( + 'kArrayIndex', 8, right_operand_widths, right_element_counts + ) ), ) @@ -119,73 +160,109 @@ def test_sample_spec_key_differs_for_specialization( ): self.assertNotEqual( delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 4, [4], left_specialization) + _create_sample_spec('kAnd', 4, [4], [0], left_specialization) ), delay_model_utils.get_sample_spec_key( - _create_sample_spec('kAnd', 4, [4], right_specialization) + _create_sample_spec('kAnd', 4, [4], [0], right_specialization) ), ) @parameterized.named_parameters( - ('no_operand_widths', 
'kOr', 32, []), - ('no_specialization', 'kOr', 32, [16, 16]), + ('no_operand_widths', 'kOr', 32, [], []), + ('no_specialization', 'kOr', 32, [16, 16], [0, 0]), ( 'specialization', 'kOr', 32, [16, 16], + [0, 0], delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND, ), + ('non_zero_element_count', 'kArrayIndex', 8, [8, 8], [0, 16]), ) def test_data_point_to_sample_spec( self, op_name, result_width, operand_widths, + element_counts, specialization=delay_model_pb2.SpecializationKind.NO_SPECIALIZATION, ): self.assertEqual( delay_model_utils.data_point_to_sample_spec( _create_data_point( - op_name, result_width, operand_widths, specialization + op_name, + result_width, + operand_widths, + element_counts, + specialization, ) ), _create_sample_spec( - op_name, result_width, operand_widths, specialization + op_name, + result_width, + operand_widths, + element_counts, + specialization, ), ) - def test_get_data_point_key(self): + @parameterized.named_parameters( + ( + 'speicialization', + 'kOr', + 32, + [16, 16], + [0, 0], + delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND, + ), + ('non_zero_element_count', 'kArrayIndex', 8, [8, 8], [0, 16]), + ) + def test_get_data_point_key( + self, + op_name, + result_width, + operand_widths, + element_counts, + specialization=delay_model_pb2.SpecializationKind.NO_SPECIALIZATION, + ): self.assertEqual( delay_model_utils.get_data_point_key( _create_data_point( - 'kOr', - 32, - [16, 16], - delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND, + op_name, + result_width, + operand_widths, + element_counts, + specialization, ) ), delay_model_utils.get_sample_spec_key( _create_sample_spec( - 'kOr', - 32, - [16, 16], - delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND, + op_name, + result_width, + operand_widths, + element_counts, + specialization, ) ), ) def test_map_data_points_by_key(self): - point1 = _create_data_point('kOr', 32, []) - point2 = _create_data_point('kOr', 32, [16, 16]) + point1 = _create_data_point('kOr', 32, [], 
[]) + point2 = _create_data_point('kOr', 32, [16, 16], [0, 0]) point3 = _create_data_point( 'kAdd', 32, [16, 16], + [0, 0], delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND, ) - mapping = delay_model_utils.map_data_points_by_key([point1, point2, point3]) - self.assertLen(mapping, 3) + point4 = _create_data_point('kArrayIndex', 8, [8, 8], [0, 16]) + point5 = _create_data_point('kArrayIndex', 8, [8, 8], [0, 1]) + mapping = delay_model_utils.map_data_points_by_key( + [point1, point2, point3, point4, point5] + ) + self.assertLen(mapping, 5) self.assertEqual( mapping[delay_model_utils.get_data_point_key(point1)], point1 ) @@ -195,6 +272,12 @@ def test_map_data_points_by_key(self): self.assertEqual( mapping[delay_model_utils.get_data_point_key(point3)], point3 ) + self.assertEqual( + mapping[delay_model_utils.get_data_point_key(point4)], point4 + ) + self.assertEqual( + mapping[delay_model_utils.get_data_point_key(point5)], point5 + ) if __name__ == '__main__':