diff --git a/a_sync/iter.pyx b/a_sync/iter.pyx index 94d9a09c..8f4f6e31 100644 --- a/a_sync/iter.pyx +++ b/a_sync/iter.pyx @@ -2,6 +2,7 @@ import asyncio import functools import inspect import logging +import re import sys import weakref from copy import deepcopy @@ -186,6 +187,18 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]): and any(pattern in attr_value.__doc__ for pattern in _FORMAT_PATTERNS) } + def match_placeholders(original_format_string: str, input_string: str) -> bool: + # Copy the format_string to ensure the original isn't modified + pattern = original_format_string.format(cls=r'(.+)', obj=r'(.+)') + + # Use anchors to ensure the entire string is matched + full_pattern = f'^{pattern}$' + + # Check if the input string matches the pattern + match = re.match(full_pattern, input_string) + + return bool(match) # Return True if it matches, otherwise False + cdef str function_name cdef object function_obj for function_name, function_obj in functions_to_redefine.items(): @@ -200,19 +213,16 @@ class _AwaitableAsyncIterableMixin(AsyncIterable[T]): redefined_function_obj = None if hasattr(_AwaitableAsyncIterableMixin, function_name): base_definition = getattr(_AwaitableAsyncIterableMixin, function_name) - if function_obj.__doc__ == base_definition.__doc__: + if base_definition.__code__ == function_obj.__code__ and match_placeholders(base_definition.__doc__, function_obj.__doc__): redefined_function_obj = deepcopy(base_definition) elif cls.__name__ != "ASyncIterable" and hasattr(ASyncIterable, function_name): base_definition = getattr(ASyncIterable, function_name) - if function_obj.__doc__ == base_definition.__doc__: + if base_definition.__code__ == function_obj.__code__ and match_placeholders(base_definition.__doc__, function_obj.__doc__): redefined_function_obj = deepcopy(base_definition) elif cls.__name__ not in ("ASyncIterable", "ASyncIterator") and hasattr(ASyncIterator, function_name): base_definition = getattr(ASyncIterator, function_name) - if function_obj.__doc__ == base_definition.__doc__: + if base_definition.__code__ == function_obj.__code__ and match_placeholders(base_definition.__doc__, function_obj.__doc__): redefined_function_obj = deepcopy(base_definition) - - if redefined_function_obj is None: - redefined_function_obj = deepcopy(function_obj) redefined_function_obj.__doc__ = function_obj.__doc__.format( cls=cls.__name__,