Skip to content

Commit

Permalink
fix physical channel/device collection slice/string behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
zhindes committed Mar 8, 2024
1 parent 2c18e56 commit 12ba7a1
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 16 deletions.
6 changes: 6 additions & 0 deletions generated/nidaqmx/system/_collections/device_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def __getitem__(self, index):
return [_DeviceAlternateConstructor(name, self._interpreter) for name in self.device_names[index]]
elif isinstance(index, str):
device_names = unflatten_channel_string(index)
all_devices = self.device_names
# Validate the device names we were provided
for device in device_names:
if device not in all_devices:
raise KeyError(f'"{device}" is not a valid device name.')

if len(device_names) == 1:
return _DeviceAlternateConstructor(device_names[0], self._interpreter)
return [_DeviceAlternateConstructor(name, self._interpreter) for name in device_names]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,26 @@ def __getitem__(self, index):
if isinstance(index, int):
return _PhysicalChannelAlternateConstructor(self.channel_names[index], self._interpreter)
elif isinstance(index, slice):
return _PhysicalChannelAlternateConstructor(self.channel_names[index], self._interpreter)
return [_PhysicalChannelAlternateConstructor(channel, self._interpreter) for channel in self.channel_names[index]]
elif isinstance(index, str):
return _PhysicalChannelAlternateConstructor(f'{self._name}/{index}', self._interpreter)
requested_channels = unflatten_channel_string(index)
all_channels = self.channel_names
# Validate the channel names we were provided
channels_to_use = []
for channel in requested_channels:
if channel in all_channels:
channels_to_use.append(channel)
else:
# The channel may have been unqualified, so we'll try to qualify it
qualified_channel = f'{self._name}/{channel}'
if qualified_channel in all_channels:
channels_to_use.append(qualified_channel)
else:
raise KeyError(f'"{channel}" is not a valid channel name.')

if len(channels_to_use) == 1:
return _PhysicalChannelAlternateConstructor(channels_to_use[0], self._interpreter)
return [_PhysicalChannelAlternateConstructor(channel, self._interpreter) for channel in channels_to_use]
else:
raise DaqError(
'Invalid index type "{}" used to access collection.'
Expand Down Expand Up @@ -96,7 +113,7 @@ def all(self):
@property
def channel_names(self):
"""
List[str]: Specifies the entire list of physical channels on this
List[str]: Specifies the entire list of physical channels in this
collection.
"""
raise NotImplementedError()
Expand Down
6 changes: 6 additions & 0 deletions src/handwritten/system/_collections/device_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def __getitem__(self, index):
return [_DeviceAlternateConstructor(name, self._interpreter) for name in self.device_names[index]]
elif isinstance(index, str):
device_names = unflatten_channel_string(index)
all_devices = self.device_names
# Validate the device names we were provided
for device in device_names:
if device not in all_devices:
raise KeyError(f'"{device}" is not a valid device name.')

if len(device_names) == 1:
return _DeviceAlternateConstructor(device_names[0], self._interpreter)
return [_DeviceAlternateConstructor(name, self._interpreter) for name in device_names]
Expand Down
23 changes: 20 additions & 3 deletions src/handwritten/system/_collections/physical_channel_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,26 @@ def __getitem__(self, index):
if isinstance(index, int):
return _PhysicalChannelAlternateConstructor(self.channel_names[index], self._interpreter)
elif isinstance(index, slice):
return _PhysicalChannelAlternateConstructor(self.channel_names[index], self._interpreter)
return [_PhysicalChannelAlternateConstructor(channel, self._interpreter) for channel in self.channel_names[index]]
elif isinstance(index, str):
return _PhysicalChannelAlternateConstructor(f'{self._name}/{index}', self._interpreter)
requested_channels = unflatten_channel_string(index)
all_channels = self.channel_names
# Validate the channel names we were provided
channels_to_use = []
for channel in requested_channels:
if channel in all_channels:
channels_to_use.append(channel)
else:
# The channel may have been unqualified, so we'll try to qualify it
qualified_channel = f'{self._name}/{channel}'
if qualified_channel in all_channels:
channels_to_use.append(qualified_channel)
else:
raise KeyError(f'"{channel}" is not a valid channel name.')

if len(channels_to_use) == 1:
return _PhysicalChannelAlternateConstructor(channels_to_use[0], self._interpreter)
return [_PhysicalChannelAlternateConstructor(channel, self._interpreter) for channel in channels_to_use]
else:
raise DaqError(
'Invalid index type "{}" used to access collection.'
Expand Down Expand Up @@ -96,7 +113,7 @@ def all(self):
@property
def channel_names(self):
"""
List[str]: Specifies the entire list of physical channels on this
List[str]: Specifies the entire list of physical channels in this
collection.
"""
raise NotImplementedError()
Expand Down
4 changes: 2 additions & 2 deletions tests/component/_task_modules/channels/test_di_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test___task___add_di_chan_chan_for_all_lines___sets_channel_attributes(
num_lines: int,
) -> None:
chan: DIChannel = task.di_channels.add_di_chan(
flatten_channel_string(sim_6363_device.di_lines[:num_lines].name),
flatten_channel_string(sim_6363_device.di_lines.channel_names[:num_lines]),
line_grouping=LineGrouping.CHAN_FOR_ALL_LINES,
)

Expand All @@ -36,7 +36,7 @@ def test___task___add_di_chan_chan_per_line___sets_channel_attributes(
num_lines: int,
) -> None:
chans: DIChannel = task.di_channels.add_di_chan(
flatten_channel_string(sim_6363_device.di_lines[:num_lines].name),
flatten_channel_string(sim_6363_device.di_lines.channel_names[:num_lines]),
line_grouping=LineGrouping.CHAN_PER_LINE,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/component/_task_modules/channels/test_do_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test___task___add_do_chan_chan_for_all_lines___sets_channel_attributes(
num_lines: int,
) -> None:
chan: DOChannel = task.do_channels.add_do_chan(
flatten_channel_string(sim_6363_device.do_lines[:num_lines].name),
flatten_channel_string(sim_6363_device.do_lines.channel_names[:num_lines]),
line_grouping=LineGrouping.CHAN_FOR_ALL_LINES,
)

Expand All @@ -36,7 +36,7 @@ def test___task___add_do_chan_chan_per_line___sets_channel_attributes(
num_lines: int,
) -> None:
chans: DOChannel = task.do_channels.add_do_chan(
flatten_channel_string(sim_6363_device.do_lines[:num_lines].name),
flatten_channel_string(sim_6363_device.do_lines.channel_names[:num_lines]),
line_grouping=LineGrouping.CHAN_PER_LINE,
)

Expand Down
17 changes: 17 additions & 0 deletions tests/component/system/_collections/test_device_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ def test___devices___getitem_str_list___shared_interpreter(system: System):
assert all(dev._interpreter is system._interpreter for dev in devices)


def test___devices___getitem_invalid_device_str___raises_error(system: System):
with pytest.raises(KeyError) as exc_info:
system.devices["foo"]

assert "foo" in exc_info.value.args[0]


def test___devices___getitem_invalid_device_str_list___raises_error(system: System):
if len(system.devices) == 0:
pytest.skip("This test requires at least one device.")

with pytest.raises(KeyError) as exc_info:
system.devices[f"{system.devices.device_names[0]},foo"]

assert "foo" in exc_info.value.args[0]


def test___devices___iter___forward_order(system: System):
devices = iter(system.devices)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test___physical_channels___getitem_int___shared_interpreter(
assert all(chan._interpreter is sim_6363_device._interpreter for chan in channels)


@pytest.mark.xfail(reason="https://github.com/ni/nidaqmx-python/issues/392")
@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_slice___forward_order(
collection_name: str, sim_6363_device: Device
Expand All @@ -49,7 +48,6 @@ def test___physical_channels___getitem_slice___forward_order(
assert [chan.name for chan in channels] == physical_channels.channel_names


@pytest.mark.xfail(reason="https://github.com/ni/nidaqmx-python/issues/392")
@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_slice___shared_interpreter(
collection_name: str, sim_6363_device: Device
Expand All @@ -62,7 +60,7 @@ def test___physical_channels___getitem_slice___shared_interpreter(


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_str___shared_interpreter(
def test___physical_channels___getitem_unqualified_str___shared_interpreter(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
Expand All @@ -77,7 +75,7 @@ def test___physical_channels___getitem_str___shared_interpreter(


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_str_list___shared_interpreter(
def test___physical_channels___getitem_unqualified_str_list___shared_interpreter(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
Expand All @@ -86,9 +84,147 @@ def test___physical_channels___getitem_str_list___shared_interpreter(
name.replace(device_name + "/", "") for name in physical_channels.channel_names
]

channel = physical_channels[",".join(unqualified_channel_names)]
channels = physical_channels[",".join(unqualified_channel_names)]

assert channel._interpreter == sim_6363_device._interpreter
assert all(chan._interpreter is sim_6363_device._interpreter for chan in channels)


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_unqualified_str___name(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
device_name = sim_6363_device.name
unqualified_channel_names = [
name.replace(device_name + "/", "") for name in physical_channels.channel_names
]

channels = [physical_channels[name] for name in unqualified_channel_names]

assert all(chan.name == name for chan, name in zip(channels, physical_channels.channel_names))


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_unqualified_str_list___name(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
device_name = sim_6363_device.name
unqualified_channel_names = [
name.replace(device_name + "/", "") for name in physical_channels.channel_names
]

channels = physical_channels[",".join(unqualified_channel_names)]

assert all(chan.name == name for chan, name in zip(channels, physical_channels.channel_names))


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_qualified_str___shared_interpreter(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

channels = [physical_channels[name] for name in physical_channels.channel_names]

assert all(chan._interpreter is sim_6363_device._interpreter for chan in channels)


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_qualified_str_list___shared_interpreter(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

channels = physical_channels[",".join(physical_channels.channel_names)]

assert all(chan._interpreter is sim_6363_device._interpreter for chan in channels)


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_qualified_str___name(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

channels = [physical_channels[name] for name in physical_channels.channel_names]

assert all(chan.name == name for chan, name in zip(channels, physical_channels.channel_names))


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_qualified_str_list___name(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

channels = physical_channels[",".join(physical_channels.channel_names)]

assert all(chan.name == name for chan, name in zip(channels, physical_channels.channel_names))


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_mixed_str_list___shared_interpreter(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
device_name = sim_6363_device.name
qualified_channel_names = physical_channels.channel_names
unqualified_channel_names = [
name.replace(device_name + "/", "") for name in qualified_channel_names
]
middle_idx = len(qualified_channel_names) // 2
mixed_channel_names = (
physical_channels.channel_names[:middle_idx] + unqualified_channel_names[middle_idx:]
)

channels = physical_channels[",".join(mixed_channel_names)]

assert all(chan._interpreter is sim_6363_device._interpreter for chan in channels)


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_mixed_str_list___name(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)
device_name = sim_6363_device.name
qualified_channel_names = physical_channels.channel_names
unqualified_channel_names = [
name.replace(device_name + "/", "") for name in qualified_channel_names
]
middle_idx = len(qualified_channel_names) // 2
mixed_channel_names = (
physical_channels.channel_names[:middle_idx] + unqualified_channel_names[middle_idx:]
)

channels = physical_channels[",".join(mixed_channel_names)]

assert all(chan.name == name for chan, name in zip(channels, physical_channels.channel_names))


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_invalid_channel_str___raises_error(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

with pytest.raises(KeyError) as exc_info:
physical_channels["foo"]

assert "foo" in exc_info.value.args[0]


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
def test___physical_channels___getitem_invalid_channel_str_list___raises_error(
collection_name: str, sim_6363_device: Device
):
physical_channels = getattr(sim_6363_device, collection_name)

with pytest.raises(KeyError) as exc_info:
physical_channels[f"{physical_channels.channel_names[0]},foo"]

assert "foo" in exc_info.value.args[0]


@pytest.mark.parametrize("collection_name", COLLECTION_NAMES)
Expand Down

0 comments on commit 12ba7a1

Please sign in to comment.