Skip to content

Commit

Permalink
unskip test_udf_parallel (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Sep 12, 2024
1 parent 027162c commit b9e6b8e
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions tests/unit/lib/test_datachain_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,26 @@
from datachain.lib.udf import Mapper


class MyMapper(Mapper):
DEFAULT_VALUE = 84
BOOTSTRAP_VALUE = 1452
TEARDOWN_VALUE = 98763

def __init__(self):
super().__init__()
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False
def test_udf():
class MyMapper(Mapper):
DEFAULT_VALUE = 84
BOOTSTRAP_VALUE = 1452
TEARDOWN_VALUE = 98763

def process(self, *args) -> int:
return self.value
def __init__(self):
super().__init__()
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False

def setup(self):
self.value = MyMapper.BOOTSTRAP_VALUE
def process(self, *args) -> int:
return self.value

def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE
def setup(self):
self.value = MyMapper.BOOTSTRAP_VALUE

def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE

def test_udf():
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_values(key=vals)

Expand All @@ -35,10 +34,30 @@ def test_udf():
assert udf.value == MyMapper.TEARDOWN_VALUE


@pytest.mark.skip(reason="Skip until tests module will be importer for unit-tests")
def test_udf_parallel():
@pytest.mark.xdist_group(name="tmpfile")
def test_udf_parallel(test_session_tmpfile):
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_values(key=vals)

class MyMapper(Mapper):
DEFAULT_VALUE = 84
BOOTSTRAP_VALUE = 1452
TEARDOWN_VALUE = 98763

def __init__(self):
super().__init__()
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False

def process(self, *args) -> int:
return self.value

def setup(self):
self.value = MyMapper.BOOTSTRAP_VALUE

def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE

chain = DataChain.from_values(key=vals, session=test_session_tmpfile)

res = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res"))

Expand Down

0 comments on commit b9e6b8e

Please sign in to comment.