diff --git a/tests/unit/lib/test_datachain_bootstrap.py b/tests/unit/lib/test_datachain_bootstrap.py index 4ecbdaa3a..7ecd27f86 100644 --- a/tests/unit/lib/test_datachain_bootstrap.py +++ b/tests/unit/lib/test_datachain_bootstrap.py @@ -4,27 +4,26 @@ from datachain.lib.udf import Mapper -class MyMapper(Mapper): - DEFAULT_VALUE = 84 - BOOTSTRAP_VALUE = 1452 - TEARDOWN_VALUE = 98763 - - def __init__(self): - super().__init__() - self.value = MyMapper.DEFAULT_VALUE - self._had_teardown = False +def test_udf(): + class MyMapper(Mapper): + DEFAULT_VALUE = 84 + BOOTSTRAP_VALUE = 1452 + TEARDOWN_VALUE = 98763 - def process(self, *args) -> int: - return self.value + def __init__(self): + super().__init__() + self.value = MyMapper.DEFAULT_VALUE + self._had_teardown = False - def setup(self): - self.value = MyMapper.BOOTSTRAP_VALUE + def process(self, *args) -> int: + return self.value - def teardown(self): - self.value = MyMapper.TEARDOWN_VALUE + def setup(self): + self.value = MyMapper.BOOTSTRAP_VALUE + def teardown(self): + self.value = MyMapper.TEARDOWN_VALUE -def test_udf(): vals = ["a", "b", "c", "d", "e", "f"] chain = DataChain.from_values(key=vals) @@ -35,10 +34,30 @@ def test_udf(): assert udf.value == MyMapper.TEARDOWN_VALUE -@pytest.mark.skip(reason="Skip until tests module will be importer for unit-tests") -def test_udf_parallel(): +@pytest.mark.xdist_group(name="tmpfile") +def test_udf_parallel(test_session_tmpfile): vals = ["a", "b", "c", "d", "e", "f"] - chain = DataChain.from_values(key=vals) + + class MyMapper(Mapper): + DEFAULT_VALUE = 84 + BOOTSTRAP_VALUE = 1452 + TEARDOWN_VALUE = 98763 + + def __init__(self): + super().__init__() + self.value = MyMapper.DEFAULT_VALUE + self._had_teardown = False + + def process(self, *args) -> int: + return self.value + + def setup(self): + self.value = MyMapper.BOOTSTRAP_VALUE + + def teardown(self): + self.value = MyMapper.TEARDOWN_VALUE + + chain = DataChain.from_values(key=vals, session=test_session_tmpfile) res = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res"))