Skip to content

Commit

Permalink
Merge pull request #649 from flairNLP/fix-a-bug-with-attribute-defaul…
Browse files Browse the repository at this point in the history
…ts-values

Fix a bug with attribute defaults and add `default_factory` parameter
  • Loading branch information
MaxDall authored Oct 24, 2024
2 parents 9eea357 + 129a1ec commit 92b348a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 12 deletions.
38 changes: 28 additions & 10 deletions src/fundus/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,35 @@ def __repr__(self):


class Attribute(RegisteredFunction):
def __init__(self, func: Callable[[object], Any], priority: Optional[int], validate: bool):
def __init__(
self,
func: Callable[[object], Any],
priority: Optional[int],
validate: bool,
default_factory: Optional[Callable[[], Any]],
):
self.validate = validate
self.default_factory = default_factory
super(Attribute, self).__init__(func=func, priority=priority)

@functools.cached_property
def __default__(self):
if self.default_factory is not None:
return self.default_factory()

annotation = self.__annotations__["return"]
origin = get_origin(annotation)
args = get_args(annotation)

if not (origin or args):
try:
default = annotation()
except TypeError:
default = None
elif callable(origin):
default = origin()
default = annotation()
elif origin == Union:
if type(None) in args:
default = None
else:
raise NotImplementedError(f"Unsupported args {args}")
raise NotImplementedError(f"Cannot determine default for {origin!r} with args {args!r}")
elif isinstance(origin, type):
default = origin()
else:
raise NotImplementedError(f"Unsupported origin {origin}")
return default
Expand All @@ -122,8 +129,15 @@ def wrapper(func):
return wrapper(cls)


def attribute(cls=None, /, *, priority: Optional[int] = None, validate: bool = True):
return _register(cls, factory=Attribute, priority=priority, validate=validate)
def attribute(
cls=None,
/,
*,
priority: Optional[int] = None,
validate: bool = True,
default_factory: Optional[Callable[[], Any]] = None,
):
return _register(cls, factory=Attribute, priority=priority, validate=validate, default_factory=default_factory)


def function(cls=None, /, *, priority: Optional[int] = None):
Expand Down Expand Up @@ -232,6 +246,10 @@ def parse(self, html: str, error_handling: Literal["suppress", "catch", "raise"]
except Exception as err:
if error_handling == "suppress":
parsed_data[attribute_name] = func.__default__
logger.info(
f"Couldn't parse attribute {attribute_name!r} for "
f"{self.precomputed.meta.get('og:url')!r}: {err}"
)
elif error_handling == "catch":
parsed_data[attribute_name] = err
elif error_handling == "raise":
Expand Down
2 changes: 1 addition & 1 deletion src/fundus/scraping/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __getattr__(self, item: str):

@property
def plaintext(self) -> Optional[str]:
return str(self.body) or None
return str(self.body) or None if not isinstance(self.body, Exception) else None

@property
def lang(self) -> Optional[str]:
Expand Down
50 changes: 49 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union

import lxml.html
import pytest
Expand Down Expand Up @@ -75,6 +75,54 @@ def unvalidated(self) -> str:
assert (funcs := list(unvalidated)) != [parser.unvalidated]
assert funcs[0].__func__ == parser.unvalidated.__func__

def test_default_values_for_attributes(self):
class Parser(BaseParser):
@attribute
def test_optional(self) -> Optional[str]:
raise Exception

@attribute
def test_collection(self) -> Tuple[str, ...]:
raise Exception

@attribute
def test_nested_collection(self) -> List[Tuple[str, str]]:
raise Exception

@attribute(default_factory=lambda: "This is a default")
def test_default_factory(self) -> Union[str, bool]:
raise Exception

@attribute
def test_boolean(self) -> bool:
raise Exception

parser = Parser()

default_values = {attr.__name__: attr.__default__ for attr in parser.attributes()}

expected_values: Dict[str, Any] = {
"test_optional": None,
"test_collection": tuple(),
"test_nested_collection": list(),
"test_default_factory": "This is a default",
"test_boolean": False,
"free_access": False,
}

for name, value in default_values.items():
assert value == expected_values[name]

class ParserWithUnion(BaseParser):
@attribute
def this_should_fail(self) -> Union[str, bool]:
raise Exception

parser_with_union = ParserWithUnion()

with pytest.raises(NotImplementedError):
default_values = {attr.__name__: attr.__default__ for attr in parser_with_union.attributes()}


class TestParserProxy:
def test_empty_proxy(self, empty_parser_proxy):
Expand Down

0 comments on commit 92b348a

Please sign in to comment.