Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work on mac #53

Merged
merged 7 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions daisy/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .client import Client
from inspect import signature
from inspect import getfullargspec


class Task:
Expand Down Expand Up @@ -163,15 +163,22 @@ def __init__(
if init_callback_fn is not None:
self.init_callback_fn = init_callback_fn
else:
self.init_callback_fn = lambda context: None
self.init_callback_fn = self._default_init

if len(signature(process_function).parameters) == 0:
args = getfullargspec(process_function).args
if len(args) == 0:
# spawn function
self.spawn_worker_function = process_function
elif len(args) == 1:
# process block function
self.spawn_worker_function = self._process_blocks
else:
self.spawn_worker_function = lambda: self._process_blocks()
raise ValueError(f"daisy does not know what to pass into args: {args}")

def _process_blocks(self):
def _default_init(self, context):
pass

def _process_blocks(self):
client = Client()
while True:
with client.acquire_block() as block:
Expand Down
14 changes: 12 additions & 2 deletions daisy/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os
import queue
import dill

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +28,7 @@ class Worker:
"""

__next_id = multiprocessing.Value("L")
_spawn_function = None

@staticmethod
def get_next_id():
Expand All @@ -49,14 +51,22 @@ def __init__(self, spawn_function, context=None, error_queue=None):

self.start()

@property
def spawn_function(self):
return dill.loads(self._spawn_function)

@spawn_function.setter
def spawn_function(self, value):
self._spawn_function = dill.dumps(value)

def start(self):
"""Start this worker. Note that workers are automatically started when
created. Use this function to re-start a stopped worker."""

if self.process is not None:
return

self.process = multiprocessing.Process(target=lambda: self.__spawn_wrapper())
self.process = multiprocessing.Process(target=self._spawn_wrapper)
self.process.start()

def stop(self):
Expand All @@ -74,7 +84,7 @@ def stop(self):
logger.debug("%s terminated", self)
self.process = None

def __spawn_wrapper(self):
def _spawn_wrapper(self):
"""Thin wrapper around the user-specified spawn function to set
environment variables, redirect output, and to capture exceptions."""

Expand Down
10 changes: 5 additions & 5 deletions daisy/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def set_num_workers(self, num_workers):
logger.debug("current number of workers: %d", len(self.workers))

if diff > 0:
self.__start_workers(diff)
self._start_workers(diff)
elif diff < 0:
self.__stop_workers(-diff)
self._stop_workers(-diff)

def inc_num_workers(self, num_workers):
self.__start_workers(num_workers)
self._start_workers(num_workers)

def stop(self, worker_id=None):
"""Stop all current workers in this pool (``worker_id == None``) or a
Expand Down Expand Up @@ -104,7 +104,7 @@ def check_for_errors(self):
except queue.Empty:
pass

def __start_workers(self, n):
def _start_workers(self, n):

logger.debug("starting %d new workers", n)
new_workers = [
Expand All @@ -113,7 +113,7 @@ def __start_workers(self, n):
]
self.workers.update({worker.worker_id: worker for worker in new_workers})

def __stop_workers(self, n):
def _stop_workers(self, n):

logger.debug("stopping %d workers", n)

Expand Down
2 changes: 1 addition & 1 deletion examples/batch_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _prepare_task(
)

if check_fn is None:
check_fn = lambda b: self._default_check_fn(b)
check_fn = self._default_check_fn

if self.overwrite:
print("Dropping table %s" % self.db_id)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"tqdm",
"funlib.math",
"funlib.geometry",
"dill",
]

[project.optional-dependencies]
Expand All @@ -46,5 +47,6 @@ module = [
"funlib.*",
"tqdm.*",
"pkg_resources.*",
"dill",
]
ignore_missing_imports = true
83 changes: 40 additions & 43 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,48 @@
from daisy.tcp import TCPServer


class TestClient(unittest.TestCase):
def run_test_server(block, conn):
server = TCPServer()
conn.send(server.address)

def run_test_server(self, block, conn):
server = TCPServer()
conn.send(server.address)
# handle first acquire_block message
message = None
for i in range(10):
message = server.get_message(timeout=1)
if message:
break
if not message:
raise Exception("SERVER COULDN'T GET MESSAGE")
try:
assert isinstance(message, AcquireBlock)
message.stream.send_message(SendBlock(block))
except Exception as e:
message.stream.send_message(ExceptionMessage(e))

# handle first acquire_block message
message = None
for i in range(10):
message = server.get_message(timeout=1)
if message:
break
if not message:
raise Exception("SERVER COULDN'T GET MESSAGE")
try:
self.assertTrue(isinstance(message, AcquireBlock))
message.stream.send_message(SendBlock(block))
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
# handle return_block message
message = server.get_message(timeout=1)
try:
assert isinstance(message, ReleaseBlock)
assert message.block.status == daisy.BlockStatus.SUCCESS
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
conn.send(1)
conn.close()

# handle return_block message
message = server.get_message(timeout=1)
try:
self.assertTrue(isinstance(message, ReleaseBlock))
self.assertTrue(message.block.status == daisy.BlockStatus.SUCCESS)
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
conn.send(1)
conn.close()

def test_basic(self):
roi = daisy.Roi((0, 0, 0), (10, 10, 10))
task_id = 1
block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id)
parent_conn, child_conn = mp.Pipe()
server_process = mp.Process(
target=self.run_test_server, args=(block, child_conn)
)
server_process.start()
host, port = parent_conn.recv()
context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1)
client = daisy.Client(context=context)
with client.acquire_block() as block:
block.status = daisy.BlockStatus.SUCCESS
def test_basic():
roi = daisy.Roi((0, 0, 0), (10, 10, 10))
task_id = 1
block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id)
parent_conn, child_conn = mp.Pipe()
server_process = mp.Process(target=run_test_server, args=(block, child_conn))
server_process.start()
host, port = parent_conn.recv()
context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1)
client = daisy.Client(context=context)
with client.acquire_block() as block:
block.status = daisy.BlockStatus.SUCCESS

success = parent_conn.recv()
server_process.join()
self.assertTrue(success)
success = parent_conn.recv()
server_process.join()
assert success
37 changes: 18 additions & 19 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,25 @@
logging.basicConfig(level=logging.DEBUG)


class TestServer(unittest.TestCase):
def process_block(block):
print("Processing block %s" % block)

def test_basic(self):

task = daisy.Task(
"test_server_task",
total_roi=daisy.Roi((0,), (100,)),
read_roi=daisy.Roi((0,), (10,)),
write_roi=daisy.Roi((1,), (8,)),
process_function=lambda b: self.process_block(b),
check_function=None,
read_write_conflict=True,
fit="valid",
num_workers=1,
max_retries=2,
timeout=None,
)
def test_basic():

server = daisy.Server()
server.run_blockwise([task])
task = daisy.Task(
"test_server_task",
total_roi=daisy.Roi((0,), (100,)),
read_roi=daisy.Roi((0,), (10,)),
write_roi=daisy.Roi((1,), (8,)),
process_function=process_block,
check_function=None,
read_write_conflict=True,
fit="valid",
num_workers=1,
max_retries=2,
timeout=None,
)

def process_block(self, block):
print("Processing block %s" % block)
server = daisy.Server()
server.run_blockwise([task])
Loading