Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Will not land] Cherry-pick of changes to ProtoRS #815

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions torchdata/dataloader2/communication/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pickle
import threading
import time

import torch

Expand All @@ -30,10 +31,51 @@
"SpawnThreadForDataPipeline",
]

TIME_SLEEP_BETWEEN_CHECKING_DIFFERENT_QUEUES = 0.00000001

# TODO(VitalyFedyunin): Find better names to the two functions below as they are separate thread/process/work-items
# TODO(VitalyFedyunin): Can combine Multiple and Single functions by checking size of pipes_and_queues and deciding block/non-block.


def MultipleDataPipesToQueuesLoop(pipes_and_queues, call_locally_fn=None):
    """Drive several DataPipe-behind-queue event loops inside a single worker.

    Args:
        pipes_and_queues: sized iterable of (source_datapipe, req_queue, res_queue)
            triples, one per DataPipe served by this worker.
        call_locally_fn: not supported here (kept only for signature parity with
            ``DataPipeToQueuesLoop``); passing a non-None value raises.

    Raises:
        NotImplementedError: if ``call_locally_fn`` is provided.
    """
    if call_locally_fn is not None:
        raise NotImplementedError("MultipleDataPipesToQueuesLoop does not support call_locally_fn")
    torch.set_num_threads(1)

    # Single shared counter cell: every iterator observes the same reset tally,
    # so a reset only proceeds once all pipes have requested it.
    resets_counter = [0]

    iterators = [
        DataPipeToQueuesLoopIterator(
            source_datapipe,
            req_queue,
            res_queue,
            blocking_request_get=False,
            resets_to_proceed=len(pipes_and_queues),
            resets_counter=resets_counter,
        )
        for source_datapipe, req_queue, res_queue in pipes_and_queues
    ]

    # TODO(VitalyFedyunin): Maybe better way to combine iterators
    for _ in zip(*iterators):
        # TODO(VitalyFedyunin): Check python MP implementation why this sleep impacts queues statuses
        # This magical sleep allows mp queue messages to travel faster
        time.sleep(TIME_SLEEP_BETWEEN_CHECKING_DIFFERENT_QUEUES)


def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_locally_fn=None):
    """Serve a single DataPipe over a request/response queue pair, blocking on requests.

    Args:
        source_datapipe: the DataPipe to serve.
        req_queue: queue the client writes requests into.
        res_queue: queue this loop writes responses into.
        call_locally_fn: optional hook applied once to ``source_datapipe`` before
            serving; its return value replaces the datapipe.
    """
    if call_locally_fn is not None:
        # Apply the hook exactly once and serve its result; invoking it a second
        # time would duplicate any side effects it performs.
        source_datapipe = call_locally_fn(source_datapipe)
    torch.set_num_threads(1)
    for _ in DataPipeToQueuesLoopIterator(source_datapipe, req_queue, res_queue, blocking_request_get=True):
        pass


def DataPipeToQueuesLoopIterator(
source_datapipe, req_queue, res_queue, blocking_request_get=True, resets_to_proceed=1, resets_counter=[]
):
if isinstance(source_datapipe, IterDataPipe):
pipe_type = communication.iter
protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
Expand All @@ -43,13 +85,14 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_locally_fn=
else:
raise Exception("Only supports IterDataPipe or MapDataPipe, got", source_datapipe)

# torch.utils.data.graph_settings.apply_sharding(source_datapipe, self.num_workers, worker_id)

torch.set_num_threads(1)
for _ in pipe_type.DataPipeBehindQueues(
source_datapipe, protocol_type(req_queue, res_queue), blocking_request_get=True
source_datapipe,
protocol_type(req_queue, res_queue),
blocking_request_get=blocking_request_get,
resets_to_proceed=resets_to_proceed,
resets_counter=resets_counter,
):
pass
yield True


def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe, call_locally_fn=None):
Expand All @@ -61,6 +104,19 @@ def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe, call_locally_fn=N
return process, req_queue, res_queue


def SpawnProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, call_locally_fn=None):
    """Create one (datapipe, request_queue, response_queue) triple per datapipe and a
    single worker process running ``MultipleDataPipesToQueuesLoop`` over all of them.

    The returned process is not started; the caller owns its lifecycle.
    """
    pipes_and_queues = [
        (dp, multiprocessing_ctx.Queue(), multiprocessing_ctx.Queue()) for dp in datapipes
    ]
    worker = multiprocessing_ctx.Process(
        target=MultipleDataPipesToQueuesLoop, args=(pipes_and_queues, call_locally_fn)
    )
    return worker, pipes_and_queues


def SpawnThreadForDataPipeline(datapipe):
r"""
Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with DataPipeToQueuesLoop as target,
Expand Down
39 changes: 29 additions & 10 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import types

from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import communication

# from torchdata.dataloader2.communication.iter import InvalidStateResetRequired

DEFAULT_NON_BLOCKING_SLEEP = 0.001

__all__ = [
Expand All @@ -34,19 +37,13 @@ class NotAvailable(Exception):
class InvalidStateResetRequired(Exception):
    """
    Returned by DataPipe when it is expecting to get reset request,
    for example RouterDataPipe expecting all workers to request reset.
    """

    pass


class TerminateRequired(Exception):
    """
    Returned by DataPipe when it is expecting to get terminate request,
    for example it got terminate request from other source and at the process
    of stopping.
    """

    pass


Expand Down Expand Up @@ -106,7 +103,9 @@ def reset_iterator(self):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
def DataPipeBehindQueues(
source_datapipe, protocol, full_stop=False, blocking_request_get=False, resets_to_proceed=1, resets_counter=[]
):
"""
Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue
If raise_stop is true, raises exception when StopIteration received from the source_datapipe
Expand All @@ -127,6 +126,18 @@ def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_re
source_datapipe.reset_iterator()
protocol.response_reset_iterator()

# if resets_to_proceed > 1:
# print(os.getpid(), f"Received one of {resets_to_proceed} reset requests, waiting others to unblock")
# if resets_counter[0] == resets_to_proceed:
# resets_counter[0] = 0
# resets_counter[0] += 1
# while resets_counter[0] < resets_to_proceed:
# print(os.getpid(), "waiting for reset counters", resets_counter)
# yield True
# time.sleep(1)
# print(os.getpid(), f" Collected {resets_to_proceed} resets")
# # resets_counter[0] = 0

elif isinstance(request, communication.messages.TerminateRequest):
forever = False
protocol.response_terminate()
Expand All @@ -145,7 +156,8 @@ def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_re
else:
yield True
break
except InvalidStateResetRequired:
except (InvalidStateResetRequired, RuntimeError):
print(os.getpid(), "Non blocking failed with Invalid state")
protocol.response_invalid_state()
if full_stop:
forever = False
Expand Down Expand Up @@ -191,11 +203,18 @@ def nonblocking_next(self):
self.protocol.request_next()
try:
response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
# response = self.protocol.get_response_next(block=True)
except communication.protocol.EmptyQueue:
raise NotAvailable
# print(os.getpid(), "got response from q", response)
if isinstance(response, communication.messages.StopIterationResponse):
self._stop_iteration = True
raise StopIteration
if isinstance(response, communication.messages.InvalidStateResponse):
raise NotAvailable
self._stop_iteration = True
raise communication.iter.InvalidStateResetRequired
if isinstance(response, communication.messages.TerminateResponse):
# This will happen with terminate sent directly into the queue on exit
self._stop_iteration = True
raise communication.iter.TerminateRequired
return response.value
4 changes: 3 additions & 1 deletion torchdata/dataloader2/communication/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def nonblocking_getitem(self, index):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
def DataPipeBehindQueues(
source_datapipe, protocol, full_stop=False, blocking_request_get=False, resets_to_proceed=1, resets_counter=[]
):
"""
Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue
If raise_stop is true, raises exception when StopIteration received from the source_datapipe
Expand Down
24 changes: 21 additions & 3 deletions torchdata/dataloader2/communication/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os

from torchdata.dataloader2 import communication


def _protocol_log(*args):
if False:
print(os.getpid(), args)


class Protocol:
__slots__ = ("request_queue", "response_queue")

Expand All @@ -26,6 +33,7 @@ def __init__(self, request_queue, response_queue):
self.request_queue = request_queue
self.response_queue = response_queue
self._req_sent = None
_protocol_log("Clinet", self, self.request_queue, self.response_queue)

def can_take_request(self):
return self._req_sent is None
Expand Down Expand Up @@ -55,6 +63,7 @@ def __init__(self, request_queue, response_queue):
self.request_queue = request_queue
self.response_queue = response_queue
self._req_received = None
_protocol_log("----> Server", self, self.request_queue, self.response_queue)

def have_pending_request(self):
return self._req_received is not None
Expand All @@ -67,6 +76,7 @@ def get_new_request(self, block=False):
except Exception: # TODO(625): Catch only timeout exceptions
raise EmptyQueue("queue is empty")
self._req_received = response
# _protocol_log("Server received request", response)
return response
# TODO(626): Validate supported requests

Expand Down Expand Up @@ -150,6 +160,7 @@ def response_reset_iterator(self):
if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
raise Exception("Replaying with reset status to other type of message")
self.response_queue.put(communication.messages.ResetIteratorResponse())
_protocol_log("server repried by ", communication.messages.ResetIteratorResponse())
self._req_received = None

def response_next(self, value):
Expand All @@ -174,36 +185,43 @@ def response_invalid_state(self):
class IterDataPipeQueueProtocolClient(ProtocolClient):
def request_reset_iterator(self):
    """Enqueue a ResetIteratorRequest to the server.

    Only one request may be in flight at a time; raises if a previous
    request is still awaiting its response.
    """
    if not self.can_take_request():
        raise Exception(os.getpid(), "Can not reset while we are still waiting response for previous request")
    request = communication.messages.ResetIteratorRequest()
    self.request_queue.put(request)
    self.request_sent(request)
    _protocol_log("request_reset_iterator")

def request_next(self):
    """Enqueue a GetNextRequest asking the server for the next element.

    Raises if a previous request is still awaiting its response.
    """
    if not self.can_take_request():
        raise Exception("Can not request next item while we are still waiting response for previous request")
    request = communication.messages.GetNextRequest()
    self.request_queue.put(request)
    self.request_sent(request)
    _protocol_log("request_next")

def get_response_reset_iterator(self, block=False):
    """Pop the reset acknowledgement from the response queue.

    Raises EmptyQueue when no response is available, and a generic
    Exception if the received message is not a ResetIteratorResponse.
    """
    # _protocol_log("get_response_reset_iterator")
    try:
        response = self.response_queue.get(block=block)
    except Exception:  # TODO(627): Catch only timeout exceptions
        raise EmptyQueue("queue is empty")
    self.request_served(response)

    _protocol_log("get_response_reset_iterator OK")
    if not isinstance(response, communication.messages.ResetIteratorResponse):
        raise Exception("Invalid response received")

def get_response_next(self, block=False, timeout=None):
    """Pop the response to a previously submitted GetNextRequest.

    Raises EmptyQueue when nothing arrives within the (optional) timeout,
    and a generic Exception when called with no request in flight.
    Returns the raw response message.
    """
    # _protocol_log("get_response_next")
    if not self.waiting_for_response():
        raise Exception("Can not expect any response without submitted request")
    try:
        response = self.response_queue.get(block=block, timeout=timeout)
    except Exception:  # TODO(628): Catch only timeout exceptions
        raise EmptyQueue("queue is empty")
    self.request_served(response)

    _protocol_log("get_response_next OK")
    # TODO(629): Add possible response types validation here
    return response
10 changes: 9 additions & 1 deletion torchdata/dataloader2/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.


import pickle
from typing import List, Type

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
Expand All @@ -13,9 +14,16 @@
from torchdata.datapipes.map import MapDataPipe


__all__ = ["find_dps", "replace_dp", "remove_dp"]
__all__ = ["find_dps", "replace_dp", "remove_dp", "clone_datapipe"]

# Make a copy of the graph
def clone_datapipe(datapipe: DataPipe) -> DataPipe:
    """Return an independent copy of ``datapipe`` via a pickle round-trip."""
    # TODO(VitalyFedyunin): Unify it with all dill operations
    serialized = pickle.dumps(datapipe)
    return pickle.loads(serialized)


# In case that there will be multiple datapipe needs to be adapted
def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
Expand Down
Loading