Skip to content

Commit

Permalink
torchx: use importlib_metadata for Python 3.10+ syntax (#623)
Browse files Browse the repository at this point in the history
Summary:
importlib.metadata in 3.12 breaks compatibility with the dict style interface. This switches TorchX to use importlib_metadata for all versions and switches the code to use the 3.10+ select style interface instead of dict.

This avoids having to pin importlib_metadata<5 such as in pytorch/tutorials#2091

Pull Request resolved: #623

Test Plan: Unit tests on both importlib_metadata 5.0 and 4.1.3

Reviewed By: priyaramani

Differential Revision: D40561638

Pulled By: d4l3k

fbshipit-source-id: 95144406c0e3dcbe203ada3ff3236f7384ab2a5c
  • Loading branch information
d4l3k authored and facebook-github-bot committed Oct 21, 2022
1 parent f9fa2fe commit e9cf74e
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 35 deletions.
1 change: 0 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ classy-vision>=0.6.0
flake8==3.9.0
fsspec[s3]==2022.1.0
hydra-core
importlib-metadata<5.0
ipython
kfp==1.8.9
moto==3.0.2
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pyre-extensions
docstring-parser==0.8.1
importlib-metadata
pyyaml
docker
filelock
Expand Down
3 changes: 0 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def get_nightly_version():
"kubernetes": ["kubernetes>=11"],
"ray": ["ray>=1.12.1"],
"dev": dev_reqs,
':python_version < "3.8"': [
"importlib-metadata",
],
},
# PyPI package information.
classifiers=[
Expand Down
28 changes: 10 additions & 18 deletions torchx/util/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

try:
from importlib import metadata
from importlib.metadata import EntryPoint
except ImportError:
import importlib_metadata as metadata
from importlib_metadata import EntryPoint
from typing import Any, Dict, Optional

import importlib_metadata as metadata
from importlib_metadata import EntryPoint


# pyre-ignore-all-errors[3, 2]
def load(group: str, name: str, default=None):
Expand All @@ -31,18 +28,13 @@ def load(group: str, name: str, default=None):
raises an error.
"""

entrypoints = metadata.entry_points()
entrypoints = metadata.entry_points().select(group=group)

if group not in entrypoints and default:
if name not in entrypoints.names and default is not None:
return default

eps: Dict[str, EntryPoint] = {ep.name: ep for ep in entrypoints[group]}

if name not in eps and default:
return default
else:
ep = eps[name]
return ep.load()
ep = entrypoints[name]
return ep.load()


def _defer_load_ep(ep: EntryPoint) -> object:
Expand Down Expand Up @@ -75,12 +67,12 @@ def load_group(
"""

entrypoints = metadata.entry_points()
entrypoints = metadata.entry_points().select(group=group)

if group not in entrypoints:
if len(entrypoints) == 0:
return default

eps = {}
for ep in entrypoints[group]:
for ep in entrypoints:
eps[ep.name] = _defer_load_ep(ep)
return eps
26 changes: 13 additions & 13 deletions torchx/util/test/entrypoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@
# LICENSE file in the root directory of this source tree.

import unittest

try:
from importlib.metadata import EntryPoint
except ImportError:
from importlib_metadata import EntryPoint
from configparser import ConfigParser
from typing import Dict, List
from typing import List
from unittest.mock import MagicMock, patch

from importlib_metadata import EntryPoint, EntryPoints

from torchx.util.entrypoints import load, load_group


Expand Down Expand Up @@ -61,12 +58,12 @@ def barbaz() -> str:
[ep.grp.missing.mod.test]
baz = torchx.util.test.entrypoints_test.missing_module
"""
_ENTRY_POINTS: Dict[str, List[EntryPoint]] = {
"entrypoints.test": EntryPoint_from_text(_EP_TXT),
"ep.grp.test": EntryPoint_from_text(_EP_GRP_TXT),
"ep.grp.missing.attr.test": EntryPoint_from_text(_EP_GRP_IGN_ATTR_TXT),
"ep.grp.missing.mod.test": EntryPoint_from_text(_EP_GRP_IGN_MOD_TXT),
}
_ENTRY_POINTS: EntryPoints = EntryPoints(
EntryPoint_from_text(_EP_TXT)
+ EntryPoint_from_text(_EP_GRP_TXT)
+ EntryPoint_from_text(_EP_GRP_IGN_ATTR_TXT)
+ EntryPoint_from_text(_EP_GRP_IGN_MOD_TXT)
)

_METADATA_EPS: str = "torchx.util.entrypoints.metadata.entry_points"

Expand All @@ -77,6 +74,9 @@ def test_load(self, mock_md_eps: MagicMock) -> None:
print(type(load("entrypoints.test", "foo")))
self.assertEqual("foobar", load("entrypoints.test", "foo")())

with self.assertRaisesRegex(KeyError, "baz"):
load("entrypoints.test", "baz")()

@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
def test_load_with_default(self, mock_md_eps: MagicMock) -> None:
self.assertEqual("barbaz", load("entrypoints.test", "missing", barbaz)())
Expand All @@ -86,7 +86,7 @@ def test_load_with_default(self, mock_md_eps: MagicMock) -> None:
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
def test_load_group(self, mock_md_eps: MagicMock) -> None:
eps = load_group("ep.grp.test")
self.assertEqual(2, len(eps))
self.assertEqual(2, len(eps), eps)
self.assertEqual("foobar", eps["foo"]())
self.assertEqual("barbaz", eps["bar"]())

Expand Down

0 comments on commit e9cf74e

Please sign in to comment.