From d23750966b5a00fc7766b4b18418ba988db28154 Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Mon, 20 Jan 2025 03:49:14 -0800
Subject: [PATCH] internal change

PiperOrigin-RevId: 717476919
---
 .../python/dataset/transformations/packing.py | 29 +++++++++++++++++++
 .../dataset/transformations/packing_test.py   | 18 ++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/grain/_src/python/dataset/transformations/packing.py b/grain/_src/python/dataset/transformations/packing.py
index fef00253..a1f28aaa 100644
--- a/grain/_src/python/dataset/transformations/packing.py
+++ b/grain/_src/python/dataset/transformations/packing.py
@@ -328,6 +328,35 @@ def __iter__(self) -> dataset.DatasetIterator:
         meta_features=self._meta_features,
     )
 
+  @classmethod
+  def first_row_to_pack_element(
+      cls,
+      element_feature_lengths: PyTree[int],
+      num_packing_bins: int,
+      length_struct: PyTree[int],
+      first_free_cell_per_row: PyTree[int],
+  ) -> int | None:
+    """Returns the first row to pack an element into or None if it can't fit.
+
+    The logic is the same as the one used in the packing iterator.
+
+    Args:
+      element_feature_lengths: The lengths of each feature in the element.
+      num_packing_bins: The number of packing bins.
+      length_struct: The max length of each feature.
+      first_free_cell_per_row: The first free cell per row.
+
+    Returns:
+      The row chosen by `PackedBatch.can_add_at_row`, or None if there is none.
+    """
+    # `.row` is already the chosen row or None, so return it unchanged.
+    return packing_packed_batch.PackedBatch.can_add_at_row(
+        element_feature_lengths,
+        num_packing_bins,
+        length_struct,
+        first_free_cell_per_row,
+    ).row
+
 
 class FirstFitPackDatasetIterator(dataset.DatasetIterator):
   """Iterator for the first-fit packing transformation."""
diff --git a/grain/_src/python/dataset/transformations/packing_test.py b/grain/_src/python/dataset/transformations/packing_test.py
index 18557751..b174aded 100644
--- a/grain/_src/python/dataset/transformations/packing_test.py
+++ b/grain/_src/python/dataset/transformations/packing_test.py
@@ -1016,6 +1016,24 @@ def test_nested_feature(self, mark_as_meta_feature: bool):
     ):
       _ = next(iter(ld))
 
+  def test_first_row_to_pack_element(self):
+    element = {"a": 1, "b": 1}
+    num_packing_bins = 2
+    length_struct = {"a": 10, "b": 10}
+    first_free_cell_per_row = {
+        "a": np.zeros(num_packing_bins, dtype=np.int64),
+        "b": np.zeros(num_packing_bins, dtype=np.int64),
+    }
+    self.assertEqual(
+        packing.FirstFitPackIterDataset.first_row_to_pack_element(
+            element,
+            num_packing_bins,
+            length_struct,
+            first_free_cell_per_row,
+        ),
+        0,
+    )
+
 
 if __name__ == "__main__":
   absltest.main()