Skip to content

Commit

Permalink
Fix inference of required keys in TypedDicts (#571)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com>
  • Loading branch information
a-gardner1 and mauvilsa authored Sep 13, 2024
1 parent 88c0387 commit 27137b5
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ v4.33.0 (2024-09-??)
Added
^^^^^
- Support for Python 3.13.
- Support for `NotRequired` and `Required` annotations for `TypedDict` keys (`#571
<https://github.com/omni-us/jsonargparse/pull/571>`__)

Fixed
^^^^^
Expand Down
6 changes: 4 additions & 2 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,10 @@ Some notes about this support are:
:ref:`parsing-paths` and :ref:`parsing-urls`.

- ``Dict``, ``Mapping``, ``MutableMapping``, ``MappingProxyType``,
``OrderedDict`` and ``TypedDict`` are supported but only with ``str`` or
``int`` keys. For more details see :ref:`dict-items`.
``OrderedDict``, and ``TypedDict`` are supported but only with ``str`` or
``int`` keys. ``Required`` and ``NotRequired`` are also supported for
fine-grained specification of required/optional ``TypedDict`` keys.
For more details see :ref:`dict-items`.

- ``Tuple``, ``Set`` and ``MutableSet`` are supported even though they can't be
represented in json distinguishable from a list. Each ``Tuple`` element
Expand Down
55 changes: 48 additions & 7 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import os
import re
import sys
from argparse import ArgumentError
from collections import OrderedDict, abc, defaultdict
from contextlib import contextmanager, suppress
Expand Down Expand Up @@ -88,6 +89,21 @@


Literal = typing_extensions_import("Literal")
NotRequired = typing_extensions_import("NotRequired")
Required = typing_extensions_import("Required")
TypedDict = typing_extensions_import("TypedDict")
_TypedDictMeta = typing_extensions_import("_TypedDictMeta")


def _capture_typing_extension_shadows(name: str, *collections) -> None:
"""
Ensure different origins for types in typing_extensions are captured.
"""
current_module = sys.modules[__name__]
typehint = getattr(current_module, name)
if getattr(typehint, "__module__", None) == "typing_extensions" and hasattr(__import__("typing"), name):
for collection in collections:
collection.add(getattr(__import__("typing"), name))


root_types = {
Expand Down Expand Up @@ -124,6 +140,8 @@
OrderedDict,
Callable,
abc.Callable,
NotRequired,
Required,
}

leaf_types = {
Expand Down Expand Up @@ -160,9 +178,20 @@
callable_origin_types = {Callable, abc.Callable}

literal_types = {Literal}
if getattr(Literal, "__module__", None) == "typing_extensions" and hasattr(__import__("typing"), "Literal"):
root_types.add(__import__("typing").Literal)
literal_types.add(__import__("typing").Literal)
_capture_typing_extension_shadows("Literal", root_types, literal_types)

not_required_types = {NotRequired}
_capture_typing_extension_shadows("NotRequired", root_types, not_required_types)

required_types = {Required}
_capture_typing_extension_shadows("Required", root_types, required_types)
not_required_required_types = not_required_types.union(required_types)

typed_dict_types = {TypedDict}
_capture_typing_extension_shadows("TypedDict", typed_dict_types)

typed_dict_meta_types = {_TypedDictMeta}
_capture_typing_extension_shadows("_TypedDictMeta", typed_dict_meta_types)

subclass_arg_parser: ContextVar = ContextVar("subclass_arg_parser")
allow_default_instance: ContextVar = ContextVar("allow_default_instance", default=False)
Expand Down Expand Up @@ -889,11 +918,18 @@ def adapt_typehints(
else:
kwargs["prev_val"] = None
val[k] = adapt_typehints(v, subtypehints[1], **kwargs)
if get_import_path(typehint.__class__) == "typing._TypedDictMeta":
if type(typehint) in typed_dict_meta_types:
if typehint.__total__:
missing_keys = typehint.__annotations__.keys() - val.keys()
if missing_keys:
raise_unexpected_value(f"Missing required keys: {missing_keys}", val)
required_keys = {
k for k, v in typehint.__annotations__.items() if get_typehint_origin(v) not in not_required_types
}
else:
required_keys = {
k for k, v in typehint.__annotations__.items() if get_typehint_origin(v) in required_types
}
missing_keys = required_keys - val.keys()
if missing_keys:
raise_unexpected_value(f"Missing required keys: {missing_keys}", val)
extra_keys = val.keys() - typehint.__annotations__.keys()
if extra_keys:
raise_unexpected_value(f"Unexpected keys: {extra_keys}", val)
Expand All @@ -904,6 +940,11 @@ def adapt_typehints(
elif typehint_origin is OrderedDict:
val = dict(val) if serialize else OrderedDict(val)

# TypedDict NotRequired and Required
elif typehint_origin in not_required_required_types:
assert len(subtypehints) == 1, "(Not)Required requires a single type argument"
val = adapt_typehints(val, subtypehints[0], **adapt_kwargs)

# Callable
elif typehint_origin in callable_origin_types or typehint in callable_origin_types:
if serialize:
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def get_typehint_origin(typehint):
typehint_class = get_import_path(typehint.__class__)
if typehint_class == "types.UnionType":
return Union
if typehint_class == "typing._TypedDictMeta":
if typehint_class in {"typing._TypedDictMeta", "typing_extensions._TypedDictMeta"}:
return dict
return getattr(typehint, "__origin__", None)

Expand Down
66 changes: 60 additions & 6 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
Type,
Union,
)

if sys.version_info >= (3, 8):
from typing import TypedDict
from unittest import mock
from warnings import catch_warnings

Expand All @@ -37,6 +34,9 @@
from jsonargparse._typehints import (
ActionTypeHint,
Literal,
NotRequired,
Required,
TypedDict,
get_all_subclass_paths,
get_subclass_types,
is_optional,
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_mapping_nested_without_args(parser):
assert {"b": {"c": 2}} == parser.parse_args(['--map={"b": {"c": 2}}']).map


@pytest.mark.skipif(sys.version_info < (3, 8), reason="TypedDict introduced in python 3.8")
@pytest.mark.skipif(not TypedDict, reason="TypedDict introduced in python 3.8 or backported in typing_extensions")
def test_typeddict_without_arg(parser):
parser.add_argument("--typeddict", type=TypedDict("MyDict", {}))
assert {} == parser.parse_args(["--typeddict={}"])["typeddict"]
Expand All @@ -513,7 +513,7 @@ def test_typeddict_without_arg(parser):
ctx.match("Expected a <class 'dict'>")


@pytest.mark.skipif(sys.version_info < (3, 8), reason="TypedDict introduced in python 3.8")
@pytest.mark.skipif(not TypedDict, reason="TypedDict introduced in python 3.8 or backported in typing_extensions")
def test_typeddict_with_args(parser):
parser.add_argument("--typeddict", type=TypedDict("MyDict", {"a": int}))
assert {"a": 1} == parser.parse_args(["--typeddict={'a': 1}"])["typeddict"]
Expand All @@ -532,7 +532,7 @@ def test_typeddict_with_args(parser):
ctx.match("Expected a <class 'dict'>")


@pytest.mark.skipif(sys.version_info < (3, 8), reason="TypedDict introduced in python 3.8")
@pytest.mark.skipif(not TypedDict, reason="TypedDict introduced in python 3.8 or backported in typing_extensions")
def test_typeddict_with_args_ntotal(parser):
parser.add_argument("--typeddict", type=TypedDict("MyDict", {"a": int}, total=False))
assert {"a": 1} == parser.parse_args(["--typeddict={'a': 1}"])["typeddict"]
Expand All @@ -548,6 +548,60 @@ def test_typeddict_with_args_ntotal(parser):
ctx.match("Expected a <class 'dict'>")


@pytest.mark.skipif(not NotRequired, reason="NotRequired introduced in python 3.11 or backported in typing_extensions")
def test_not_required_support(parser):
assert ActionTypeHint.is_supported_typehint(NotRequired[Any])


@pytest.mark.skipif(not NotRequired, reason="NotRequired introduced in python 3.11 or backported in typing_extensions")
def test_typeddict_with_not_required_arg(parser):
parser.add_argument("--typeddict", type=TypedDict("MyDict", {"a": int, "b": NotRequired[int]}))
assert {"a": 1} == parser.parse_args(["--typeddict={'a': 1}"])["typeddict"]
assert {"a": 1, "b": 2} == parser.parse_args(["--typeddict={'a': 1, 'b': 2}"])["typeddict"]
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":1, "b":2, "c": 3}'])
ctx.match("Unexpected keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"b":2}'])
ctx.match("Missing required keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(["--typeddict={}"])
ctx.match("Missing required keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":"x"}'])
ctx.match("Expected a <class 'int'>")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":1, "b":"x"}'])
ctx.match("Expected a <class 'int'>")


@pytest.mark.skipif(not Required, reason="Required introduced in python 3.11 or backported in typing_extensions")
def test_required_support(parser):
assert ActionTypeHint.is_supported_typehint(Required[Any])


@pytest.mark.skipif(not Required, reason="Required introduced in python 3.11 or backported in typing_extensions")
def test_typeddict_with_required_arg(parser):
parser.add_argument("--typeddict", type=TypedDict("MyDict", {"a": Required[int], "b": int}, total=False))
assert {"a": 1} == parser.parse_args(["--typeddict={'a': 1}"])["typeddict"]
assert {"a": 1, "b": 2} == parser.parse_args(["--typeddict={'a': 1, 'b': 2}"])["typeddict"]
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":1, "b":2, "c": 3}'])
ctx.match("Unexpected keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"b":2}'])
ctx.match("Missing required keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(["--typeddict={}"])
ctx.match("Missing required keys")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":"x"}'])
ctx.match("Expected a <class 'int'>")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--typeddict={"a":1, "b":"x"}'])
ctx.match("Expected a <class 'int'>")


def test_mapping_proxy_type(parser):
parser.add_argument("--mapping", type=MappingProxyType)
cfg = parser.parse_args(['--mapping={"x":1}'])
Expand Down

0 comments on commit 27137b5

Please sign in to comment.