Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support union types X | Y syntax for HfArgumentParser for Python 3.10+ #23126

Merged
merged 3 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dataclasses
import json
import sys
import types
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
Expand Down Expand Up @@ -159,7 +160,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
aliases = [aliases]

origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is Union:
if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
if str not in field.type.__args__ and (
len(field.type.__args__) != 2 or type(None) not in field.type.__args__
):
Expand Down Expand Up @@ -245,10 +246,23 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
type_hints: Dict[str, type] = get_type_hints(dtype)
except NameError:
raise RuntimeError(
f"Type resolution failed for f{dtype}. Try declaring the class in global scope or "
f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
"removing line of `from __future__ import annotations` which opts in Postponed "
"Evaluation of Annotations (PEP 563)"
)
except TypeError as ex:
# Remove this block when we drop Python 3.9 support
if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
python_version = ".".join(map(str, sys.version_info[:3]))
raise RuntimeError(
f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
"line of `from __future__ import annotations` which opts in union types as "
"`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
"support Python versions that lower than 3.10, you need to use "
"`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
"`X | None`."
) from ex
raise

for field in dataclasses.fields(dtype):
if not field.init:
Expand Down
75 changes: 55 additions & 20 deletions tests/utils/test_hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import json
import os
import sys
import tempfile
import unittest
from argparse import Namespace
Expand All @@ -37,6 +38,9 @@
from typing_extensions import Literal


PEP604 = sys.version_info >= (3, 10)
XuehaiPan marked this conversation as resolved.
Show resolved Hide resolved


def list_field(default=None, metadata=None):
    """Build a dataclass field whose default comes from a factory.

    Args:
        default: Value the factory should produce (typically a list). A shallow
            copy is returned on every factory call so that callers never share,
            and therefore never accidentally mutate, the same object.
        metadata: Optional mapping forwarded unchanged to `dataclasses.field`.

    Returns:
        The `dataclasses.Field` created by `field(...)`.
    """
    # Local import: only part of the file's import block is visible from here.
    from copy import copy

    # `copy(None)` is `None`, so the no-default case behaves as before.
    return field(default_factory=lambda: copy(default), metadata=metadata)

Expand All @@ -62,6 +66,15 @@ class WithDefaultBoolExample:
opt: Optional[bool] = None


# Declared only when PEP604 is true: PEP 604 counterpart of WithDefaultBoolExample,
# spelling the optional field as `bool | None` instead of `Optional[bool]`.
if PEP604:

@dataclass
class WithDefaultBoolExamplePep604:
foo: bool = False
baz: bool = True
opt: bool | None = None
XuehaiPan marked this conversation as resolved.
Show resolved Hide resolved


class BasicEnum(Enum):
titi = "titi"
toto = "toto"
Expand Down Expand Up @@ -98,6 +111,17 @@ class OptionalExample:
des: Optional[List[int]] = list_field(default=[])


# Declared only when PEP604 is true: PEP 604 counterpart of OptionalExample, using
# `X | None` and builtin `list[...]` generics in place of `Optional[...]` / `List[...]`.
if PEP604:

@dataclass
class OptionalExamplePep604:
foo: int | None = None
bar: float | None = field(default=None, metadata={"help": "help message"})
baz: str | None = None
ces: list[str] | None = list_field(default=[])
des: list[int] | None = list_field(default=[])


@dataclass
class ListExample:
foo_int: List[int] = list_field(default=[])
Expand Down Expand Up @@ -167,31 +191,36 @@ def test_with_default(self):
self.argparsersEqual(parser, expected)

def test_with_default_bool(self):
parser = HfArgumentParser(WithDefaultBoolExample)

expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
# A boolean no_* argument always has to come after its "default: True" regular counterpart
# and its default must be set to False
expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
dataclass_types = [WithDefaultBoolExample]
if PEP604:
dataclass_types.append(WithDefaultBoolExamplePep604)

for dataclass_type in dataclass_types:
parser = HfArgumentParser(dataclass_type)
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))

args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))

args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))

args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))

args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))

def test_with_enum(self):
parser = HfArgumentParser(MixedTypeEnumExample)
Expand Down Expand Up @@ -266,21 +295,27 @@ def test_with_list(self):
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))

def test_with_optional(self):
parser = HfArgumentParser(OptionalExample)

expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=None, type=int)
expected.add_argument("--bar", default=None, type=float, help="help message")
expected.add_argument("--baz", default=None, type=str)
expected.add_argument("--ces", nargs="+", default=[], type=str)
expected.add_argument("--des", nargs="+", default=[], type=int)
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
dataclass_types = [OptionalExample]
if PEP604:
dataclass_types.append(OptionalExamplePep604)

for dataclass_type in dataclass_types:
parser = HfArgumentParser(dataclass_type)

self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))

args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))

def test_with_required(self):
parser = HfArgumentParser(RequiredExample)
Expand Down