Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parallel=False flag to Group() #272

Merged
merged 3 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions supriya/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,18 @@ def move(self, add_action: AddActionLike, target_node: "NodeProxy") -> None:
class GroupProxy(NodeProxy):
identifier: Union["supriya.nonrealtime.Node", int]
provider: "Provider"
parallel: bool = False

def as_add_request(self, add_action, target_node) -> commands.GroupNewRequest:
return commands.GroupNewRequest(
def as_add_request(
self, add_action, target_node
) -> Union[commands.GroupNewRequest, commands.ParallelGroupNewRequest]:

request_method = commands.GroupNewRequest
if self.parallel:
request_method = commands.ParallelGroupNewRequest
return request_method(
items=[
commands.GroupNewRequest.Item(
request_method.Item(
node_id=int(self.identifier),
add_action=add_action,
target_node_id=int(target_node),
Expand Down Expand Up @@ -543,6 +550,7 @@ def add_group(
target_node=None,
add_action=AddAction.ADD_TO_HEAD,
name: Optional[str] = None,
parallel: bool = False,
) -> GroupProxy:
raise NotImplementedError

Expand Down Expand Up @@ -751,13 +759,14 @@ def add_group(
target_node=None,
add_action=AddAction.ADD_TO_HEAD,
name: Optional[str] = None,
parallel: bool = False,
) -> GroupProxy:
if not self.moment:
raise ValueError("No current moment")
identifier = self._resolve_target_node(target_node).add_group(
add_action=add_action
)
proxy = GroupProxy(identifier=identifier, provider=self)
proxy = GroupProxy(identifier=identifier, provider=self, parallel=parallel)
return proxy

def add_synth(
Expand Down Expand Up @@ -938,12 +947,13 @@ def add_group(
target_node=None,
add_action=AddAction.ADD_TO_HEAD,
name: Optional[str] = None,
parallel: bool = False,
) -> GroupProxy:
if not self.moment:
raise ValueError("No current moment")
target_node = self._resolve_target_node(target_node)
identifier = self._server.node_id_allocator.allocate_node_id(1)
proxy = GroupProxy(identifier=identifier, provider=self)
proxy = GroupProxy(identifier=identifier, provider=self, parallel=parallel)
self.moment.node_additions.append((proxy, add_action, target_node))
if name:
self._annotation_map[identifier] = name
Expand Down
33 changes: 26 additions & 7 deletions supriya/realtime/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def _unregister_with_local_server(self):

### PUBLIC METHODS ###

def add_group(self, add_action: Optional[AddActionLike] = None) -> "Group":
def add_group(
self, add_action: Optional[AddActionLike] = None, parallel: bool = False
) -> "Group":
"""
Add a group relative to this node via ``add_action``.

Expand All @@ -315,7 +317,7 @@ def add_group(self, add_action: Optional[AddActionLike] = None) -> "Group":
add_action = AddAction.from_expr(add_action)
if add_action not in self._valid_add_actions:
raise ValueError("Invalid add action: {add_action}")
group = Group()
group = Group(parallel=parallel)
group.allocate(add_action=add_action, target_node=self)
return group

Expand Down Expand Up @@ -544,10 +546,17 @@ class Group(Node, UniqueTreeList):

### INITIALIZER ###

def __init__(self, children=None, name=None, node_id_is_permanent=False):
def __init__(
self,
children=None,
name=None,
node_id_is_permanent=False,
parallel: bool = False,
):
self._control_interface = GroupInterface(client=self)
Node.__init__(self, name=name, node_id_is_permanent=node_id_is_permanent)
UniqueTreeList.__init__(self, children=children, name=name)
self._parallel = parallel

### SPECIAL METHODS ###

Expand Down Expand Up @@ -664,9 +673,12 @@ def _collect_requests_and_synthdefs(self, expr, server, start=0):
requests.append(request)
else:
if isinstance(node, Group):
request = supriya.commands.GroupNewRequest(
request_method = supriya.commands.GroupNewRequest
if node.parallel:
request_method = supriya.commands.ParallelGroupNewRequest
request = request_method(
items=[
supriya.commands.GroupNewRequest.Item(
request_method.Item(
add_action=add_action,
node_id=node,
target_node_id=target_node,
Expand Down Expand Up @@ -759,9 +771,12 @@ def allocate(
self._node_id_is_permanent = bool(node_id_is_permanent)
target_node = Node._expr_as_target(target_node)
server = target_node.server
group_new_request = supriya.commands.GroupNewRequest(
request_method = supriya.commands.GroupNewRequest
if self._parallel:
request_method = supriya.commands.ParallelGroupNewRequest
group_new_request = request_method(
items=[
supriya.commands.GroupNewRequest.Item(
request_method.Item(
add_action=AddAction.from_expr(add_action),
node_id=self,
target_node_id=target_node.node_id,
Expand Down Expand Up @@ -794,6 +809,10 @@ def prepend(self, expr: Node) -> None:
def controls(self) -> GroupInterface:
return self._control_interface

@property
def parallel(self) -> bool:
return self._parallel


class Synth(Node):
"""
Expand Down
6 changes: 4 additions & 2 deletions supriya/realtime/servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,9 @@ def add_bus_group(
bus_group.allocate(server=self)
return bus_group

def add_group(self, add_action: Optional[AddActionLike] = None) -> Group:
def add_group(
self, add_action: Optional[AddActionLike] = None, parallel: bool = False
) -> Group:
"""
Add a group relative to the default group via ``add_action``.

Expand All @@ -986,7 +988,7 @@ def add_group(self, add_action: Optional[AddActionLike] = None) -> Group:
"""
if self.default_group is None:
raise ServerOffline
return self.default_group.add_group(add_action=add_action)
return self.default_group.add_group(add_action=add_action, parallel=parallel)

def add_synth(
self, synthdef=None, add_action: Optional[AddActionLike] = None, **kwargs
Expand Down
11 changes: 11 additions & 0 deletions tests/providers/test_RealtimeProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from supriya.assets.synthdefs import default
from supriya.enums import AddAction, CalculationRate
from supriya.osc.messages import OscBundle, OscMessage
from supriya.providers import (
BufferProxy,
BusGroupProxy,
Expand Down Expand Up @@ -564,3 +565,13 @@ def test_RealtimeProvider_set_node_error(server):
group_proxy["foo"] = 23
with pytest.raises(ValueError):
synth_proxy["foo"] = 23


def test_RealtimeProvider_add_group_parallel(server):
provider = Provider.from_context(server)
with server.osc_protocol.capture() as transcript:
with provider.at(None):
provider.add_group(parallel=True)
assert [(_.label, _.message) for _ in transcript] == [
("S", OscBundle(contents=(OscMessage("/p_new", 1000, 0, 1),)))
]
31 changes: 31 additions & 0 deletions tests/realtime/test_Group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,3 +1918,34 @@ def test_remove_01(server):
assert not synth_b.is_allocated
assert not synth_c.is_allocated
assert not synth_d.is_allocated


def test_allocate_parallel(server):
group = supriya.realtime.Group(parallel=True)
assert group.parallel
with server.osc_protocol.capture() as transcript:
group.allocate(server)
server_state = str(server.query())
assert server_state == normalize(
"""
NODE TREE 0 group
1 group
1000 group
"""
)
assert [(_.label, _.message) for _ in transcript] == [
("S", OscMessage("/p_new", 1000, 0, 1)),
("R", OscMessage("/n_go", 1000, 1, -1, -1, 1, -1, -1)),
]


def test_add_group_parallel(server):
with server.osc_protocol.capture() as transcript:
server.add_group(parallel=True)
server.default_group.add_group(parallel=True)
assert [(_.label, _.message) for _ in transcript] == [
("S", OscMessage("/p_new", 1000, 0, 1)),
("R", OscMessage("/n_go", 1000, 1, -1, -1, 1, -1, -1)),
("S", OscMessage("/p_new", 1001, 0, 1)),
("R", OscMessage("/n_go", 1001, 1, -1, 1000, 1, -1, -1)),
]