⚡️ Speed up function flatten_grouping by 330% #3303

Open
wants to merge 2 commits into
base: dev

misrasaurabh1

📄 330% (3.30x) speedup for flatten_grouping in dash/_grouping.py

⏱️ Runtime : 8.92 milliseconds → 2.07 milliseconds (best of 51 runs)
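
As a quick check on the headline figure, the two runtimes above can be plugged into a relative-speedup formula. The sketch below assumes the percentage is computed as (old - new) / new; that is an assumption about the reporting convention, not something stated in this PR, and `relative_speedup` is a hypothetical helper name.

```python
def relative_speedup(old_ms: float, new_ms: float) -> float:
    """Return how much faster the new runtime is, as a fraction of the new runtime.

    Assumed convention: (old - new) / new. A value of 3.3 means "330% faster".
    """
    return (old_ms - new_ms) / new_ms


# Runtimes quoted in the PR description, in milliseconds.
speedup = relative_speedup(8.92, 2.07)
print(f"{speedup:.2f}x faster ({speedup:.0%})")
```

Under that assumption the result lands close to the reported 330% (3.30x); any small gap is presumably due to rounding of the displayed runtimes.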

📝 Explanation and details

Here is an optimized version of the provided code, focusing on reducing function call and memory overhead, inlining and shortcutting where safe, and avoiding repetitive work.
Key optimizations:

  • Avoid unnecessary list comprehensions and intermediate lists where possible by favoring the use of local variables and iterative approaches for flatten_grouping.
  • Move schema validation out of the recursion: validate once at the top level of flatten_grouping where possible, so substructures are not re-validated on every recursive call.
  • Reduce attribute/tuple lookups and repeated isinstance checks.
  • Micro-optimize recursion: Tailor the recursive structure to minimize temporary list creation.
  • Minimize tuple concatenation in validate_grouping by reusing a growing list for paths.
  • Avoid set/schema conversions on every recursive call in dicts.
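
The path-reuse idea in the list above can be sketched as follows. This is a simplified, hypothetical validator (named `validate` here, raising plain `ValueError`), not the actual `validate_grouping` from dash; it only illustrates pushing and popping on one shared list instead of building a new tuple with `path + (i,)` at every level.

```python
def validate(grouping, schema, path=None):
    """Check that `grouping` has the same nested shape as `schema`.

    Hypothetical sketch: `path` is a single list that is mutated in place
    (append before descending, pop after) rather than copied per call.
    """
    if path is None:
        path = []

    if isinstance(schema, (tuple, list)):
        if not isinstance(grouping, (tuple, list)):
            raise ValueError(f"expected a sequence at path {tuple(path)}")
        if len(grouping) != len(schema):
            raise ValueError(f"expected length {len(schema)} at path {tuple(path)}")
        for i in range(len(schema)):
            path.append(i)          # descend: extend the shared path
            validate(grouping[i], schema[i], path)
            path.pop()              # ascend: restore the shared path
    elif isinstance(schema, dict):
        if not isinstance(grouping, dict):
            raise ValueError(f"expected a dict at path {tuple(path)}")
        if set(grouping) != set(schema):
            raise ValueError(f"unexpected keys at path {tuple(path)}")
        for key in schema:
            path.append(key)
            validate(grouping[key], schema[key], path)
            path.pop()
    # Anything else in the schema is treated as a scalar slot: nothing to check.


validate([1, (2, 3)], [0, (0, 0)])  # matching shapes: returns without raising
```

The path is only converted to a tuple when an error message is actually built, so the common, valid case does no tuple allocation at all.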

Summary of changes and performance justifications:

  • flatten_grouping is now iterative and uses an explicit stack, reducing Python call stack depth and temporary list creation.
  • Elements are appended to a result list in reverse order, which is what popping from a stack naturally produces, and the list is reversed once at the end to restore the original order.
  • Dict and tuple/list types are checked using type() is ... for speed over isinstance(), since structure is known via schema.
  • validate_grouping uses index-based iteration to avoid tuple unpacking and leverages direct key traversal for dicts.
  • All original logic and error handling are preserved for 1:1 behavior.
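
The stack-based approach summarized above can be sketched as below. Both functions are simplified illustrations written for this description, not the code in `dash/_grouping.py`: schema validation is omitted, `isinstance` is used instead of `type() is ...`, and the names `flatten_recursive` and `flatten_iterative` are hypothetical.

```python
def flatten_recursive(grouping, schema=None):
    """Reference version: recursion plus nested list comprehensions."""
    if schema is None:
        schema = grouping
    if isinstance(schema, (tuple, list)):
        return [
            g
            for group_el, schema_el in zip(grouping, schema)
            for g in flatten_recursive(group_el, schema_el)
        ]
    if isinstance(schema, dict):
        return [g for k in schema for g in flatten_recursive(grouping[k], schema[k])]
    return [grouping]


def flatten_iterative(grouping, schema=None):
    """Same result using an explicit stack instead of recursion."""
    if schema is None:
        schema = grouping
    result = []
    stack = [(grouping, schema)]
    while stack:
        value, sch = stack.pop()
        if sch is None:
            sch = value  # mirror the "grouping is its own schema" rule
        if isinstance(sch, (tuple, list)):
            # Children are pushed left-to-right, so they pop right-to-left.
            stack.extend(zip(value, sch))
        elif isinstance(sch, dict):
            stack.extend((value[k], sch[k]) for k in sch)
        else:
            result.append(value)  # scalar leaf, collected in reverse order
    result.reverse()  # one reversal restores left-to-right order
    return result


print(flatten_iterative({"x": (1, 2), "y": [3, 4]}))
```

Because leaves are appended in the order they are popped, the single `result.reverse()` at the end is what keeps the output identical to the recursive version.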

This approach should lower CPU time by making fewer recursive calls and repeating less computation, especially for large and deeply nested structures.
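
One point worth checking when reviewing the `type() is ...` change mentioned in the summary: exact type checks and `isinstance()` disagree for subclasses such as named tuples or `OrderedDict`. The sketch below only demonstrates that general Python behavior; `Point` and the two helper functions are hypothetical names, and whether such subclasses occur in real groupings is not established by this PR.

```python
from collections import OrderedDict, namedtuple

# Hypothetical tuple subclass used only for illustration.
Point = namedtuple("Point", ["x", "y"])


def is_sequence_isinstance(value):
    """Subclass-friendly check: accepts tuple/list and anything derived from them."""
    return isinstance(value, (tuple, list))


def is_sequence_exact(value):
    """Exact check: accepts only the built-in tuple and list types themselves."""
    t = type(value)
    return t is tuple or t is list


p = Point(1, 2)
print(is_sequence_isinstance(p), is_sequence_exact(p))
print(isinstance(OrderedDict(), dict), type(OrderedDict()) is dict)
```

If subclass instances can reach the function, the exact check would treat them as scalars rather than flattening them, so a regression test with such an input would help confirm the 1:1 behavior claim.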

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 69 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests Details
import pytest  # used for our unit tests
from dash._grouping import flatten_grouping
# function to test
from dash.exceptions import InvalidCallbackReturnValue


class SchemaTypeValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_type):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected type: {expected_type}
                Received value of type {type(value)}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_type):
        if not isinstance(value, expected_type):
            raise SchemaTypeValidationError(value, full_schema, path, expected_type)


class SchemaLengthValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_len):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected length: {expected_len}
                Received value of length {len(value)}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_len):
        if len(value) != expected_len:
            raise SchemaLengthValidationError(value, full_schema, path, expected_len)


class SchemaKeysValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_keys):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected keys: {expected_keys}
                Received value with keys {set(value.keys())}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_keys):
        if set(value.keys()) != set(expected_keys):
            raise SchemaKeysValidationError(value, full_schema, path, expected_keys)


# unit tests

# ------------------- BASIC TEST CASES -------------------

def test_scalar_value():
    # Scalar input, no schema
    codeflash_output = flatten_grouping(42)
    codeflash_output = flatten_grouping("foo")
    codeflash_output = flatten_grouping(None)
    codeflash_output = flatten_grouping(3.14)

def test_flat_list():
    # Flat list of scalars
    codeflash_output = flatten_grouping([1, 2, 3])
    codeflash_output = flatten_grouping(["a", "b", "c"])

def test_flat_tuple():
    # Flat tuple of scalars
    codeflash_output = flatten_grouping((1, 2, 3))

def test_flat_dict():
    # Flat dict of scalars
    codeflash_output = flatten_grouping({"a": 1, "b": 2})

def test_nested_list():
    # Nested list of scalars
    codeflash_output = flatten_grouping([[1, 2], [3, 4]])

def test_nested_tuple():
    # Nested tuple of scalars
    codeflash_output = flatten_grouping(((1, 2), (3, 4)))

def test_nested_dict():
    # Nested dict of scalars
    codeflash_output = flatten_grouping({"a": {"x": 1, "y": 2}, "b": {"x": 3, "y": 4}})

def test_mixed_nested_structures():
    # Mixed list, tuple, dict
    grouping = [{"x": (1, 2)}, {"x": (3, 4)}]
    codeflash_output = flatten_grouping(grouping)

def test_schema_with_tuple_as_scalar():
    # The schema mirrors the tuple structure, so the tuple is validated against it and flattened element-wise
    grouping = (1, 2)
    schema = (1, 2)
    codeflash_output = flatten_grouping(grouping, schema)

    grouping = ((1, 2), (3, 4))
    schema = ((0, 0), (0, 0))
    codeflash_output = flatten_grouping(grouping, schema)

def test_schema_with_dict_as_scalar():
    # The schema mirrors the dict structure, so the dict is validated against it and flattened in schema key order
    grouping = {"a": 1, "b": 2}
    schema = {"a": 0, "b": 0}
    codeflash_output = flatten_grouping(grouping, schema)

    grouping = {"foo": {"x": 1, "y": 2}, "bar": {"x": 3, "y": 4}}
    schema = {"foo": {"x": 0, "y": 0}, "bar": {"x": 0, "y": 0}}
    codeflash_output = flatten_grouping(grouping, schema)

def test_schema_with_mixed_types():
    # Schema with mixed types
    grouping = ({"a": 1, "b": 2}, [3, 4])
    schema = ({"a": 0, "b": 0}, [0, 0])
    codeflash_output = flatten_grouping(grouping, schema)

# ------------------- EDGE TEST CASES -------------------

def test_empty_list():
    # Empty list
    codeflash_output = flatten_grouping([])

def test_empty_tuple():
    # Empty tuple
    codeflash_output = flatten_grouping(())

def test_empty_dict():
    # Empty dict
    codeflash_output = flatten_grouping({})

def test_empty_nested_structures():
    # Nested empty structures
    codeflash_output = flatten_grouping([[], []])
    codeflash_output = flatten_grouping({"a": [], "b": []})
    codeflash_output = flatten_grouping([{}, {}])

def test_schema_empty_list():
    # Empty list with schema
    codeflash_output = flatten_grouping([], [])

def test_schema_empty_dict():
    # Empty dict with schema
    codeflash_output = flatten_grouping({}, {})







def test_schema_with_none():
    # None as scalar
    grouping = (None, None)
    schema = (0, 0)
    codeflash_output = flatten_grouping(grouping, schema)

def test_deeply_nested_structures():
    # Deeply nested
    grouping = [[[[[1]]]]]
    codeflash_output = flatten_grouping(grouping)
    grouping = {"a": {"b": {"c": {"d": 5}}}}
    codeflash_output = flatten_grouping(grouping)

def test_schema_with_empty_nested_structures():
    # Schema with empty nested structures
    grouping = [[], []]
    schema = [[], []]
    codeflash_output = flatten_grouping(grouping, schema)

# ------------------- LARGE SCALE TEST CASES -------------------

def test_large_flat_list():
    # Large flat list
    data = list(range(1000))
    codeflash_output = flatten_grouping(data)

def test_large_nested_list():
    # Large nested list (10 lists of 100 elements)
    data = [list(range(i*100, (i+1)*100)) for i in range(10)]
    expected = list(range(1000))
    codeflash_output = flatten_grouping(data)

def test_large_flat_dict():
    # Large flat dict
    data = {str(i): i for i in range(1000)}
    expected = list(range(1000))
    codeflash_output = flatten_grouping(data)

def test_large_nested_dict():
    # Large nested dict: 10 dicts of 100 elements each
    data = {str(i): {str(j): i*100 + j for j in range(100)} for i in range(10)}
    expected = [i*100 + j for i in range(10) for j in range(100)]
    codeflash_output = flatten_grouping(data)

def test_large_mixed_structure():
    # Large mixed structure: list of dicts of tuples
    data = [
        {str(j): (i*10 + j, i*10 + j + 0.5) for j in range(10)}
        for i in range(10)
    ]
    # schema is needed to treat tuples as tuples, not as iterables
    schema = [
        {str(j): (0, 0.0) for j in range(10)}
        for i in range(10)
    ]
    expected = []
    for i in range(10):
        for j in range(10):
            expected.extend([i*10 + j, i*10 + j + 0.5])
    codeflash_output = flatten_grouping(data, schema)

def test_large_deeply_nested():
    # Large deeply nested structure (depth 5, width 4)
    def make_nested(val, depth):
        if depth == 0:
            return val
        return [make_nested(val, depth-1) for _ in range(4)]
    data = make_nested(7, 5)
    # Should produce 4^5 = 1024 leaves, all 7
    codeflash_output = flatten_grouping(data)




import pytest  # used for our unit tests
from dash._grouping import flatten_grouping
# function to test
from dash.exceptions import InvalidCallbackReturnValue


class SchemaTypeValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_type):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected type: {expected_type}
                Received value of type {type(value)}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_type):
        if not isinstance(value, expected_type):
            raise SchemaTypeValidationError(value, full_schema, path, expected_type)


class SchemaLengthValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_len):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected length: {expected_len}
                Received value of length {len(value)}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_len):
        if len(value) != expected_len:
            raise SchemaLengthValidationError(value, full_schema, path, expected_len)


class SchemaKeysValidationError(InvalidCallbackReturnValue):
    def __init__(self, value, full_schema, path, expected_keys):
        super().__init__(
            msg=f"""
                Schema: {full_schema}
                Path: {repr(path)}
                Expected keys: {expected_keys}
                Received value with keys {set(value.keys())}:
                    {repr(value)}
                """
        )

    @classmethod
    def check(cls, value, full_schema, path, expected_keys):
        if set(value.keys()) != set(expected_keys):
            raise SchemaKeysValidationError(value, full_schema, path, expected_keys)

# unit tests

# ---------------------------
# 1. Basic Test Cases
# ---------------------------

def test_flatten_scalar():
    # Scalar value should return a single-element list
    codeflash_output = flatten_grouping(42)
    codeflash_output = flatten_grouping('foo')
    codeflash_output = flatten_grouping(None)
    codeflash_output = flatten_grouping(3.14)

def test_flatten_tuple_of_scalars():
    # Tuple of scalars should flatten to a list of those scalars
    codeflash_output = flatten_grouping((1, 2, 3))
    codeflash_output = flatten_grouping(('a', 'b', 'c'))

def test_flatten_list_of_scalars():
    # List of scalars should flatten to a list of those scalars
    codeflash_output = flatten_grouping([10, 20, 30])
    codeflash_output = flatten_grouping(['x', 'y', 'z'])

def test_flatten_nested_tuple():
    # Nested tuple should flatten all scalars in order
    grouping = ((1, 2), 3)
    expected = [1, 2, 3]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_nested_list():
    # Nested list should flatten all scalars in order
    grouping = [[4, 5], 6]
    expected = [4, 5, 6]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_tuple_of_lists():
    # Tuple containing lists
    grouping = ([1, 2], [3, 4])
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_of_scalars():
    # Dict of scalars should flatten to values in key order
    grouping = {'a': 1, 'b': 2}
    expected = [1, 2]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_of_lists():
    # Dict with list values should flatten all values in key order
    grouping = {'x': [1, 2], 'y': [3, 4]}
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_of_dicts():
    # Nested dicts
    grouping = {'outer': {'a': 1, 'b': 2}, 'other': {'c': 3, 'd': 4}}
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_mixed_structures():
    # Mixed dict/list/tuple nesting
    grouping = {'x': (1, 2), 'y': [3, 4]}
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_with_explicit_schema():
    # Use explicit schema to treat tuple as scalar
    grouping = ((1, 2), 3)
    schema = ('a', 'b')  # schema of two scalars
    # Should treat (1,2) as a scalar if schema is not a tuple
    codeflash_output = flatten_grouping((1, 2), 'a')
    # But with schema as tuple, should flatten
    codeflash_output = flatten_grouping(grouping, schema)

# ---------------------------
# 2. Edge Test Cases
# ---------------------------

def test_flatten_empty_tuple():
    # Empty tuple should return empty list
    codeflash_output = flatten_grouping(())

def test_flatten_empty_list():
    # Empty list should return empty list
    codeflash_output = flatten_grouping([])

def test_flatten_empty_dict():
    # Empty dict should return empty list
    codeflash_output = flatten_grouping({})

def test_flatten_deeply_nested():
    # Deeply nested structure
    grouping = (((1,),),)
    expected = [1]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_with_none_values():
    # None as a value should be included
    grouping = (None, (1, None))
    expected = [None, 1, None]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_with_non_string_keys():
    # Dict with integer keys
    grouping = {0: 1, 1: 2}
    expected = [1, 2]
    codeflash_output = flatten_grouping(grouping)




def test_flatten_tuple_of_dicts():
    # Tuple containing dicts as elements
    grouping = ({'a': 1}, {'b': 2})
    expected = [1, 2]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_of_tuples():
    # Dict containing tuples as values
    grouping = {'x': (1, 2), 'y': (3, 4)}
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_tuple_with_empty_elements():
    # Tuple with empty list/dict/tuple as elements
    grouping = ([], {}, ())
    expected = []
    codeflash_output = flatten_grouping(grouping)

def test_flatten_dict_with_empty_values():
    # Dict with empty values
    grouping = {'a': [], 'b': {}}
    expected = []
    codeflash_output = flatten_grouping(grouping)

def test_flatten_grouping_is_its_own_schema():
    # If schema is None, grouping is its own schema
    grouping = ((1, 2), {'a': 3, 'b': 4})
    expected = [1, 2, 3, 4]
    codeflash_output = flatten_grouping(grouping)


def test_flatten_large_flat_list():
    # Large flat list
    grouping = list(range(1000))
    expected = list(range(1000))
    codeflash_output = flatten_grouping(grouping)

def test_flatten_large_nested_list():
    # Large nested list of lists
    grouping = [[i, i+1] for i in range(0, 1000, 2)]
    expected = list(range(1000))
    codeflash_output = flatten_grouping(grouping)

def test_flatten_large_dict_of_scalars():
    # Large dict of scalars
    grouping = {str(i): i for i in range(1000)}
    expected = [i for i in range(1000)]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_large_dict_of_lists():
    # Large dict of lists
    grouping = {str(i): [i, i+1] for i in range(0, 1000, 2)}
    expected = [i for i in range(1000)]
    codeflash_output = flatten_grouping(grouping)

def test_flatten_large_mixed_structure():
    # Large mixed structure: dict of tuples of lists
    grouping = {str(i): ([i], (i+1,)) for i in range(500)}
    expected = []
    for i in range(500):
        expected.extend([i, i+1])
    codeflash_output = flatten_grouping(grouping)

def test_flatten_large_with_explicit_schema():
    # Large structure with explicit schema
    grouping = [(i, i+1) for i in range(0, 1000, 2)]
    schema = [(0, 0)] * 500  # tuple of two scalars, repeated 500 times
    expected = list(range(1000))
    codeflash_output = flatten_grouping(grouping, schema)

To edit these changes git checkout codeflash/optimize-flatten_grouping-max6hy2z and push.

Codeflash

Contributor Checklist

  • I have run the tests locally and they passed. (refer to testing section in contributing)
  • I have added tests, or extended existing tests, to cover any new features or bugs fixed in this PR

optionals

  • I have added entry in the CHANGELOG.md
  • If this PR needs a follow-up in dash docs, community thread, I have mentioned the relevant URLS as follows
    • this GitHub #PR number updates the dash docs
    • here is the show and tell thread in Plotly Dash community

codeflash-ai bot and others added 2 commits May 21, 2025 00:02