Skip to content

Commit

Permalink
Fix hash collisions among data points when using timing_characterizat…
Browse files Browse the repository at this point in the history
…ion_client; fix #1525

Currently, there is a hash collision of two kArrayIndex data points when running
timing_characterization_client. This caused by the fact that the timing client
relies on hashing the op to map a sample to its result, and the hash function did
not take into account the element count of every operand. So, if there are two
data points that only differ in element counts, the data point of only one of the
two samples will be recorded in the output.

This change fixes the problem by adding element counts part of the hash.

PiperOrigin-RevId: 671152475
  • Loading branch information
hnpl authored and copybara-github committed Sep 5, 2024
1 parent 21e4479 commit 72048d6
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 33 deletions.
20 changes: 18 additions & 2 deletions xls/estimators/delay_model/delay_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from collections.abc import Mapping
import dataclasses
import math
from typing import Sequence

from xls.estimators.delay_model import delay_model_pb2
Expand Down Expand Up @@ -42,10 +43,17 @@ def get_data_point_key(data_point: delay_model_pb2.DataPoint) -> str:

def get_sample_spec_key(spec: SampleSpec) -> str:
"""Returns a key that can be used to represent a sample spec in a dict."""
# This dictionary stores the element counts for array operands. For non-array
# operands, the element count is 0.
array_operand_element_counts = {
e.operand_number: math.prod(e.element_counts)
for e in spec.point.operand_element_counts
}
bit_count_strs = []
for bit_count in spec.point.operand_widths:
for operand_idx, operand_width in enumerate(spec.point.operand_widths):
operand = delay_model_pb2.Operation.Operand(
bit_count=bit_count, element_count=0
bit_count=operand_width,
element_count=array_operand_element_counts.get(operand_idx, 0),
)
bit_count_strs.append(str(operand))
key = (
Expand All @@ -71,4 +79,12 @@ def data_point_to_sample_spec(
point.operand_widths.extend(
[operand.bit_count for operand in data_point.operation.operands]
)
for operand_idx, operand in enumerate(data_point.operation.operands):
if operand.element_count:
point.operand_element_counts.append(
delay_model_pb2.OperandElementCounts(
operand_number=operand_idx,
element_counts=[operand.element_count],
)
)
return SampleSpec(op_samples, point)
145 changes: 114 additions & 31 deletions xls/estimators/delay_model/delay_model_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _create_sample_spec(
op_name: str,
result_width: int,
operand_widths: list[int],
element_counts: list[int],
specialization: delay_model_pb2.SpecializationKind = delay_model_pb2.SpecializationKind.NO_SPECIALIZATION,
) -> delay_model_utils.SampleSpec:
"""Convenience function for tests to create SampleSpec protos."""
Expand All @@ -32,13 +33,20 @@ def _create_sample_spec(
point = delay_model_pb2.Parameterization()
point.result_width = result_width
point.operand_widths.extend(operand_widths)
for operand_idx, element_count in enumerate(element_counts):
if element_count:
operand_element_counts = delay_model_pb2.OperandElementCounts()
operand_element_counts.operand_number = operand_idx
operand_element_counts.element_counts.append(element_count)
point.operand_element_counts.append(operand_element_counts)
return delay_model_utils.SampleSpec(op_samples, point)


def _create_data_point(
op_name: str,
result_width: int,
operand_widths: list[int],
element_counts: list[int],
specialization: delay_model_pb2.SpecializationKind = delay_model_pb2.SpecializationKind.NO_SPECIALIZATION,
) -> delay_model_pb2.DataPoint:
"""Convenience function for tests to create DataPoint protos."""
Expand All @@ -47,9 +55,10 @@ def _create_data_point(
if specialization != delay_model_pb2.SpecializationKind.NO_SPECIALIZATION:
dp.operation.specialization = specialization
dp.operation.bit_count = result_width
for width in operand_widths:
for width, count in zip(operand_widths, element_counts):
operand = delay_model_pb2.Operation.Operand()
operand.bit_count = width
operand.element_count = count
dp.operation.operands.append(operand)
return dp

Expand All @@ -59,46 +68,78 @@ class DelayModelUtilsTest(parameterized.TestCase):
def test_sample_spec_key_equal_for_same_sample(self):
self.assertEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 16, [8, 8])
_create_sample_spec('kAnd', 16, [8, 8], [0, 0])
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 16, [8, 8])
_create_sample_spec('kAnd', 16, [8, 8], [0, 0])
),
)

def test_sample_spec_key_differs_for_op(self):
self.assertNotEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 16, [8, 8])
_create_sample_spec('kAnd', 16, [8, 8], [0, 0])
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kOr', 16, [8, 8])
_create_sample_spec('kOr', 16, [8, 8], [0, 0])
),
)

def test_sample_spec_key_differs_for_result_width(self):
self.assertNotEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 16, [8, 8])
_create_sample_spec('kAnd', 16, [8, 8], [0, 0])
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 8, [8, 8])
_create_sample_spec('kAnd', 8, [8, 8], [0, 0])
),
)

@parameterized.named_parameters(
('different_operand_counts', [8, 8, 8], [8, 8]),
('different_widths_same_count', [8, 8], [8, 4]),
('different_operand_counts', [8, 8, 8], [0, 0, 0], [8, 8], [0, 0]),
('different_widths_same_count', [8, 8], [0, 0], [8, 4], [0, 0]),
)
def test_sample_spec_key_differs_for_operand_widths(
self, left_operand_widths, right_operand_widths
self,
left_operand_widths,
left_element_counts,
right_operand_widths,
right_element_counts,
):
self.assertNotEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec(
'kAnd', 8, left_operand_widths, left_element_counts
)
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec(
'kAnd', 8, right_operand_widths, right_element_counts
)
),
)

@parameterized.named_parameters(
('different_element_counts', [8, 8], [0, 0], [8, 8], [0, 16]),
('different_widths_same_element_count', [8, 1], [0, 32], [8, 8], [0, 32]),
)
def test_sample_spec_key_differs_for_element_counts(
self,
left_operand_widths,
left_element_counts,
right_operand_widths,
right_element_counts,
):
self.assertNotEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 8, left_operand_widths)
_create_sample_spec(
'kArrayIndex', 8, left_operand_widths, left_element_counts
)
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 8, right_operand_widths)
_create_sample_spec(
'kArrayIndex', 8, right_operand_widths, right_element_counts
)
),
)

Expand All @@ -119,73 +160,109 @@ def test_sample_spec_key_differs_for_specialization(
):
self.assertNotEqual(
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 4, [4], left_specialization)
_create_sample_spec('kAnd', 4, [4], [0], left_specialization)
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec('kAnd', 4, [4], right_specialization)
_create_sample_spec('kAnd', 4, [4], [0], right_specialization)
),
)

@parameterized.named_parameters(
('no_operand_widths', 'kOr', 32, []),
('no_specialization', 'kOr', 32, [16, 16]),
('no_operand_widths', 'kOr', 32, [], []),
('no_specialization', 'kOr', 32, [16, 16], [0, 0]),
(
'specialization',
'kOr',
32,
[16, 16],
[0, 0],
delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND,
),
('non_zero_element_count', 'kArrayIndex', 8, [8, 8], [0, 16]),
)
def test_data_point_to_sample_spec(
self,
op_name,
result_width,
operand_widths,
element_counts,
specialization=delay_model_pb2.SpecializationKind.NO_SPECIALIZATION,
):
self.assertEqual(
delay_model_utils.data_point_to_sample_spec(
_create_data_point(
op_name, result_width, operand_widths, specialization
op_name,
result_width,
operand_widths,
element_counts,
specialization,
)
),
_create_sample_spec(
op_name, result_width, operand_widths, specialization
op_name,
result_width,
operand_widths,
element_counts,
specialization,
),
)

def test_get_data_point_key(self):
@parameterized.named_parameters(
(
'speicialization',
'kOr',
32,
[16, 16],
[0, 0],
delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND,
),
('non_zero_element_count', 'kArrayIndex', 8, [8, 8], [0, 16]),
)
def test_get_data_point_key(
self,
op_name,
result_width,
operand_widths,
element_counts,
specialization=delay_model_pb2.SpecializationKind.NO_SPECIALIZATION,
):
self.assertEqual(
delay_model_utils.get_data_point_key(
_create_data_point(
'kOr',
32,
[16, 16],
delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND,
op_name,
result_width,
operand_widths,
element_counts,
specialization,
)
),
delay_model_utils.get_sample_spec_key(
_create_sample_spec(
'kOr',
32,
[16, 16],
delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND,
op_name,
result_width,
operand_widths,
element_counts,
specialization,
)
),
)

def test_map_data_points_by_key(self):
point1 = _create_data_point('kOr', 32, [])
point2 = _create_data_point('kOr', 32, [16, 16])
point1 = _create_data_point('kOr', 32, [], [])
point2 = _create_data_point('kOr', 32, [16, 16], [0, 0])
point3 = _create_data_point(
'kAdd',
32,
[16, 16],
[0, 0],
delay_model_pb2.SpecializationKind.HAS_LITERAL_OPERAND,
)
mapping = delay_model_utils.map_data_points_by_key([point1, point2, point3])
self.assertLen(mapping, 3)
point4 = _create_data_point('kArrayIndex', 8, [8, 8], [0, 16])
point5 = _create_data_point('kArrayIndex', 8, [8, 8], [0, 1])
mapping = delay_model_utils.map_data_points_by_key(
[point1, point2, point3, point4, point5]
)
self.assertLen(mapping, 5)
self.assertEqual(
mapping[delay_model_utils.get_data_point_key(point1)], point1
)
Expand All @@ -195,6 +272,12 @@ def test_map_data_points_by_key(self):
self.assertEqual(
mapping[delay_model_utils.get_data_point_key(point3)], point3
)
self.assertEqual(
mapping[delay_model_utils.get_data_point_key(point4)], point4
)
self.assertEqual(
mapping[delay_model_utils.get_data_point_key(point5)], point5
)


if __name__ == '__main__':
Expand Down

0 comments on commit 72048d6

Please sign in to comment.