From 6c962a84b73ada488c28614c7665d1db3175e80f Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 26 Sep 2022 13:44:13 +0200 Subject: [PATCH 1/2] Adding transformers as backend --- merlin/models/tf/transformers/__init__.py | 0 requirements/transformers.txt | 1 + setup.py | 4 +++- tests/unit/tf/transformers/__init__.py | 18 ++++++++++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 merlin/models/tf/transformers/__init__.py create mode 100644 requirements/transformers.txt create mode 100644 tests/unit/tf/transformers/__init__.py diff --git a/merlin/models/tf/transformers/__init__.py b/merlin/models/tf/transformers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/transformers.txt b/requirements/transformers.txt new file mode 100644 index 0000000000..747b7aa97a --- /dev/null +++ b/requirements/transformers.txt @@ -0,0 +1 @@ +transformers \ No newline at end of file diff --git a/setup.py b/setup.py index e18e7d873c..57d23a3e81 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ def read_requirements(filename): _dev = read_requirements("requirements/dev.txt") _docs = read_requirements("requirements/docs.txt") _nvt = read_requirements("requirements/nvtabular.txt") +_transformers = read_requirements("requirements/transformers.txt") requirements = { "base": read_requirements("requirements/base.txt"), @@ -49,12 +50,13 @@ def read_requirements(filename): "lightfm": read_requirements("requirements/lightfm.txt"), "implicit": read_requirements("requirements/implicit.txt"), "xgboost": read_requirements("requirements/xgboost.txt"), + "transformers": _transformers, "nvtabular": _nvt, "dev": _dev, "docs": _docs, } dev_requirements = { - "tensorflow-dev": requirements["tensorflow"] + _dev + _nvt, + "tensorflow-dev": requirements["tensorflow"] + _transformers + _dev + _nvt, "pytorch-dev": requirements["pytorch"] + _dev + _nvt, "implicit-dev": requirements["implicit"] + _dev + _nvt, "lightfm-dev": requirements["lightfm"] + _dev + _nvt, diff --git a/tests/unit/tf/transformers/__init__.py b/tests/unit/tf/transformers/__init__.py new file mode 100644 index 0000000000..1a6745db85 --- /dev/null +++ b/tests/unit/tf/transformers/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +pytest.importorskip("transformers") From c06bdaee4768ef613adec190fa85dcb5a9948181 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 26 Sep 2022 13:47:23 +0200 Subject: [PATCH 2/2] Adding transformers to conftest to add a automatic pytest-marker --- pyproject.toml | 1 + tests/conftest.py | 2 ++ tests/unit/tf/transformers/test_block.py | 4 ++++ 3 files changed, 7 insertions(+) create mode 100644 tests/unit/tf/transformers/test_block.py diff --git a/pyproject.toml b/pyproject.toml index 75c695b774..b1faa0a4cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ markers = [ "lightfm", "implicit", "xgboost", + "transformers", "example", "integration", "unit", diff --git a/tests/conftest.py b/tests/conftest.py index 28491f8bcf..4517fd7bcd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,3 +104,5 @@ def pytest_collection_modifyitems(items): item.add_marker(pytest.mark.xgboost) if "/datasets/" in path: item.add_marker(pytest.mark.datasets) + if "/transformers/" in path: + item.add_marker(pytest.mark.transformers) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py new file mode 100644 index 0000000000..17e3e32dde --- /dev/null +++ b/tests/unit/tf/transformers/test_block.py @@ -0,0 +1,4 @@ +def test_import(): + import transformers + + assert transformers is not None