Skip to content

Commit

Permalink
Adds support for transitive rebinding of parameters.
Browse files Browse the repository at this point in the history
If fragment A has two children, B and C, support already existed for overriding C.my_param to depend on B.my_param then B.my_param to depend on A.my_param.

This PR lets you do it the other way around, giving the same final result but reversing the call order so you can first bind B.my_param to A.my_param, then C.my_param to B.my_param.
  • Loading branch information
charlesbaynham committed Dec 11, 2024
1 parent 7adc2b6 commit 9765e57
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 27 deletions.
52 changes: 29 additions & 23 deletions ndscan/experiment/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ def build(self, fragment_path: list[str], *args, **kwargs):
#: attribute names of the respective ParamHandles) to *Param instances.
self._free_params = OrderedDict()

#: Maps own attribute name to the ParamHandles of the rebound parameters in
#: their original subfragment.
self._rebound_subfragment_params = dict()

#: List of (param, store) tuples of parameters set to their defaults after
#: init_params().
self._default_params = []
Expand Down Expand Up @@ -525,36 +521,45 @@ def bind_param(self, param_name: str, source: ParamHandle) -> Any:
``Fluoresce``, binding its intensity and detuning parameters to values and
defaults appropriate for those particular tasks.
Transitive binding is supported: if the source parameter is itself bound /
overridden, this parameter will be bound to the ultimate source.
See :meth:`override_param`, which sets the parameter to a fixed value/store.
:param param_name: The name of the parameter to be bound (i.e.
``self.<param_name>``). Must be a free parameter of this fragment (not
already bound or overridden).
:param source: The parameter to bind to. Must be a free parameter of its
respective owner.
:param source: The parameter to bind to. Can be free, overridden or rebound.
"""
param = self._free_params.get(param_name, None)
assert param is not None, f"Not a free parameter: '{param_name}'"

# We don't support "transitive" binding for parameters that are already bound.
assert source.name in source.owner._free_params, \
"Source parameter is not a free parameter; already bound/overridden?"

own_type = type(self._free_params[param_name])
source_type = type(source.owner._free_params[source.name])
assert own_type == source_type, (
f"Cannot bind {own_type.__name__} '{param_name}' " +
f"to source of type {source_type.__name__}")

# To support "transitive" binding of parameters, follow the chaining of
# rebindings until the top-level source is found.
toplevel_handle = source._get_toplevel_handle()

# Compare the types of the ParamHandles instead of the Params. In the
# case of EnumParams, this will catch an attempt bind the wrong type of
# EnumParam to another EnumParam.
param_handle: ParamHandle = getattr(self, param_name)
own_handle_type = type(param_handle)
toplevel_handle_type = type(toplevel_handle)
assert own_handle_type == toplevel_handle_type, (
f"Cannot bind {own_handle_type.__name__} '{param_name}' " +
f"to source of type {toplevel_handle_type.__name__}")

# Record the new rebinding
del self._free_params[param_name]
source._add_child_handle(param_handle)

handles = self._get_all_handles_for_param(param_name)

for handle in handles:
handle.parameter = source.parameter
# Redirect the Parameter and ParamStores of all the bound handles to the
# new toplevel
toplevel_store = toplevel_handle.get_store()
for handle in toplevel_handle._get_all_handles_for_param():
if toplevel_store is not None:
handle.set_store(toplevel_store)

source.owner._rebound_subfragment_params.setdefault(source.name,
[]).extend(handles)
handle.parameter = toplevel_handle.parameter

return param

Expand Down Expand Up @@ -721,7 +726,8 @@ def get_always_shown_params(self):
return [getattr(self, name) for name in self._free_params.keys()]

def _get_all_handles_for_param(self, name: str) -> list[ParamHandle]:
return [getattr(self, name)] + self._rebound_subfragment_params.get(name, [])
handle: ParamHandle = getattr(self, name)
return handle._get_all_handles_for_param()

def _stringize_path(self) -> str:
return "/".join(self._fragment_path)
Expand Down
43 changes: 42 additions & 1 deletion ndscan/experiment/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from artiq.language import host_only, portable, units
from enum import Enum
from numpy import int32
from typing import Any, TYPE_CHECKING
from typing import Any, TYPE_CHECKING, Optional
from ..utils import eval_param_default, GetDataset

__all__ = [
Expand Down Expand Up @@ -260,6 +260,14 @@ def __init__(self, owner: "Fragment", name: str, parameter):
self._store = None
self._changed_after_use = True

self._parent_handle: ParamHandle | None = None
self._children_handles: list[ParamHandle] = []

def get_store(self) -> Optional[ParamStore]:
"""
"""
return self._store

def set_store(self, store: ParamStore) -> None:
"""
"""
Expand All @@ -275,6 +283,39 @@ def changed_after_use(self) -> bool:
"""
return self._changed_after_use

@host_only
def _get_toplevel_handle(self) -> "ParamHandle":
"""
Get the highest level ParamHandle in the chain of bound parameters
Walks the DAG of bound parameters to find the highest level handle.
That may be this ParamHandle if this parameter is not rebound.
"""
if self._parent_handle is None:
return self
else:
return self._parent_handle._get_toplevel_handle()

@host_only
def _get_all_handles_for_param(self) -> list["ParamHandle"]:
"""
Get all handles that are bound to this handle (including this one)
Walks the DAG of bound parameters to find all child handles.
"""
result = [self]
for child in self._children_handles:
result.extend(child._get_all_handles_for_param())
return result

@host_only
def _add_child_handle(self, rebound_handle: "ParamHandle"):
"""
Mark a new parameter handle as being rebound to this one
"""
rebound_handle._parent_handle = self
self._children_handles.append(rebound_handle)


class FloatParamHandle(ParamHandle):
@portable
Expand Down
130 changes: 127 additions & 3 deletions test/test_experiment_fragment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
Tests for general fragment tree behaviour.
"""

from enum import Enum
from ndscan.experiment import *
from ndscan.experiment.parameters import IntParamStore
from fixtures import (AddOneFragment, MultiReboundAddOneFragment,
ReboundReboundAddOneFragment)
from fixtures import (
AddOneFragment,
MultiReboundAddOneFragment,
ReboundReboundAddOneFragment,
)
from mock_environment import HasEnvironmentCase


Expand Down Expand Up @@ -57,6 +60,24 @@ def test_datasets(self):
self.assertEqual(ddf.bar.get(), 5)


class TransitiveReboundAddOneFragment(ExpFragment):
def build_fragment(self):
self.setattr_fragment("first", AddOneFragment)
self.setattr_fragment("second", AddOneFragment)
self.setattr_fragment("third", AddOneFragment)

self.setattr_param_like("value", self.first, default=2)

self.first.bind_param("value", self.value)
self.second.bind_param("value", self.first.value)
self.third.bind_param("value", self.second.value)

def run_once(self):
self.first.run_once()
self.second.run_once()
self.third.run_once()


class TestRebinding(HasEnvironmentCase):
def test_recursive_rebind_default(self):
rrf = self.create(ReboundReboundAddOneFragment, [])
Expand All @@ -76,6 +97,44 @@ def test_multi_rebind(self):
self.assertEqual(result[mrf.first.result], 3)
self.assertEqual(result[mrf.second.result], 3)

def test_transitive_rebind(self):
trf = self.create(TransitiveReboundAddOneFragment, [])

result = run_fragment_once(trf)
self.assertEqual(result[trf.first.result], 3)
self.assertEqual(result[trf.second.result], 3)
self.assertEqual(result[trf.third.result], 3)

def test_transitive_rebind_with_final_override(self):
trf = self.create(TransitiveReboundAddOneFragment, [])
trf.override_param("value", 3)
result = run_fragment_once(trf)
self.assertEqual(result[trf.first.result], 4)
self.assertEqual(result[trf.second.result], 4)
self.assertEqual(result[trf.third.result], 4)

def test_transitive_rebind_with_initial_override(self):
class OverriddenTransitiveReboundAddOneFragment(ExpFragment):
def build_fragment(self):
self.setattr_fragment("first", AddOneFragment)
self.setattr_fragment("second", AddOneFragment)
self.setattr_fragment("third", AddOneFragment)

self.first.override_param("value", 2)
self.second.bind_param("value", self.first.value)
self.third.bind_param("value", self.second.value)

def run_once(self):
self.first.run_once()
self.second.run_once()
self.third.run_once()

trf = self.create(OverriddenTransitiveReboundAddOneFragment, [])
result = run_fragment_once(trf)
self.assertEqual(result[trf.first.result], 3)
self.assertEqual(result[trf.second.result], 3)
self.assertEqual(result[trf.third.result], 3)

def test_invalid_bind(self):
class InvalidBindFragment(ExpFragment):
def build_fragment(self):
Expand All @@ -87,6 +146,71 @@ def build_fragment(self):
self.create(InvalidBindFragment, [])


class StrOptions(Enum):
first = "A"
second = "B"
third = "C"


class IntOptions(Enum):
first = 1
second = 2
third = 3


class EnumFragString(ExpFragment):
def build_fragment(self):
self.setattr_param("value_str",
EnumParam,
description="Enum string",
default=StrOptions.first)

def run_once(self):
print(self.value_str.get())


class EnumFragInt(ExpFragment):
def build_fragment(self):
self.setattr_param("value_int",
EnumParam,
description="Enum int",
default=IntOptions.first)

def run_once(self):
print(self.value_int.get())


class TestEnumRebinding(HasEnvironmentCase):
def test_binding_wrong_enum_fails(self):
class EnumFragsWrong(ExpFragment):
def build_fragment(self):
self.setattr_fragment("enum_frag_str", EnumFragString)
self.setattr_fragment("enum_frag_int", EnumFragInt)
self.enum_frag_int.bind_param("value_int", self.enum_frag_str.value_str)

def run_once(self):
self.enum_frag_str.run_once()
self.enum_frag_int.run_once()

with self.assertRaises(AssertionError):
self.create(EnumFragsWrong, [])

def test_binding_right_enum(self):
class EnumFragsRight(ExpFragment):
def build_fragment(self):
self.setattr_fragment("enum_frag_str1", EnumFragString)
self.setattr_fragment("enum_frag_str2", EnumFragString)
self.enum_frag_str1.bind_param("value_str",
self.enum_frag_str2.value_str)

def run_once(self):
self.enum_frag_str1.run_once()
self.enum_frag_str2.run_once()

frag = self.create(EnumFragsRight, [])
run_fragment_once(frag)


class TestMisc(HasEnvironmentCase):
def test_namespacing(self):
a = self.create(AddOneFragment, ["a"])
Expand Down

0 comments on commit 9765e57

Please sign in to comment.