Skip to content

Commit

Permalink
Make runners configurable from the command-line using OmegaConf (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbmi authored Oct 27, 2023
1 parent e10a19c commit 4ac73cc
Show file tree
Hide file tree
Showing 18 changed files with 706 additions and 318 deletions.
224 changes: 132 additions & 92 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ pydantic = "^1.9.1"
pylsl = "^1.16.2"
colorednoise = "^2.2.0"
pydantic-yaml = "^0.9.0"
pyaml = "^21.10.1"
scipy = "^1.10.0"
pygame = "2.2.0.dev2"
pooch = "^1.7.0"
neo = "^0.12.0"
joblib = "^1.2.0"
matplotlib = "^3.7.1"
rich = ">=10.0.0"
scikit-learn = "1.2.1"
screeninfo = "0.8.1"
pyobjc-framework-Quartz = { version = "*", platform = "darwin" }
omegaconf = "^2.3"

[tool.poetry.dev-dependencies]
pytest = "^7.1.2"
Expand Down
1 change: 1 addition & 0 deletions src/neural_data_simulator/config/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ timer:
encoder:
# set a path to your custom preprocessor plugin
# preprocessor: 'plugins/preprocessor.py'
preprocessor: null

model: "plugins/model.py"

Expand Down
62 changes: 37 additions & 25 deletions src/neural_data_simulator/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Models for parsing and validating the contents of `settings.yaml`."""
from enum import Enum
from enum import unique
from typing import Dict, Optional

from pydantic import BaseModel
from pydantic import Extra
from pydantic import Json
from pydantic import validator
from pydantic_yaml import VersionedYamlModel
from pydantic_yaml import YamlStrEnum


class LogLevel(YamlStrEnum):
@unique
class LogLevel(str, Enum):
"""Possible log levels."""

_DEBUG = "DEBUG"
Expand All @@ -18,28 +21,32 @@ class LogLevel(YamlStrEnum):
_CRITICAL = "CRITICAL"


@unique
class EncoderEndpointType(str, Enum):
    """Possible types for the encoder input or output.

    Mixes in ``str`` so members compare equal to their plain string values
    when parsed from YAML settings. ``@unique`` guards against two names
    accidentally sharing one value.
    """

    FILE = "file"
    LSL = "LSL"


@unique
class EphysGeneratorEndpointType(str, Enum):
    """Possible types of input for the ephys generator.

    ``str`` mixin lets members compare equal to the raw strings found in
    the YAML settings; ``@unique`` rejects duplicate values.
    """

    TESTING = "testing"
    LSL = "LSL"


@unique
class EncoderModelType(str, Enum):
    """Possible types of input for the encoder model.

    ``str`` mixin allows direct comparison against the strings used in the
    settings file; ``@unique`` rejects duplicate values.
    """

    PLUGIN = "plugin"
    VELOCITY_TUNING_CURVES = "velocity_tuning_curves"


class LSLChannelFormatType(YamlStrEnum):
@unique
class LSLChannelFormatType(str, Enum):
"""Possible values for the LSL channel format."""

_FLOAT32 = "float32"
Expand All @@ -50,24 +57,24 @@ class LSLChannelFormatType(YamlStrEnum):
_INT64 = "int64"


class TimerModel(BaseModel, extra=Extra.forbid):
    """Settings for the timer implementation.

    ``extra=Extra.forbid`` makes pydantic reject unknown keys, so typos in
    the YAML settings fail fast instead of being silently ignored.
    """

    # NOTE(review): presumably both values are in seconds — confirm against
    # the timer implementation.
    max_cpu_buffer: float
    loop_time: float


class LSLInputModel(BaseModel, extra=Extra.forbid):
    """Settings for all LSL inlets.

    ``extra=Extra.forbid`` makes pydantic reject unknown keys so that
    misspelled settings raise a validation error.
    """

    # NOTE(review): presumably the timeout is in seconds — confirm with pylsl usage.
    connection_timeout: float
    # Name of the LSL stream to resolve and connect to.
    stream_name: str


class LSLOutputModel(BaseModel):
class LSLOutputModel(BaseModel, extra=Extra.forbid):
"""Settings for all LSL outlets."""

class _Instrument(BaseModel):
class _Instrument(BaseModel, extra=Extra.forbid):
manufacturer: str
model: str
id: int
Expand All @@ -80,13 +87,13 @@ class _Instrument(BaseModel):
channel_labels: Optional[list[str]]


class EncoderSettings(BaseModel):
class EncoderSettings(BaseModel, extra=Extra.forbid):
"""Settings for the encoder."""

class Input(BaseModel):
class Input(BaseModel, extra=Extra.forbid):
"""Settings for the encoder input."""

class File(BaseModel):
class File(BaseModel, extra=Extra.forbid):
"""Settings for the encoder input type file."""

path: str
Expand All @@ -98,7 +105,7 @@ class File(BaseModel):
file: Optional[File]
lsl: Optional[LSLInputModel]

class Output(BaseModel):
class Output(BaseModel, extra=Extra.forbid):
"""Settings for the encoder output."""

n_channels: int
Expand All @@ -115,13 +122,13 @@ class Output(BaseModel):

@validator("model")
def _model_entry_point_must_be_a_python_file(cls, v):
    """Allow an unset model; otherwise require a ``.py`` entry point."""
    # None means no model plugin is configured, which is valid.
    if v is None or v.endswith(".py"):
        return v
    raise ValueError("The model entry point must be a Python file")

@validator("preprocessor", "postprocessor")
def _plugin_entry_point_must_be_a_python_file(cls, v):
    """Allow an unset pre/postprocessor; otherwise require a ``.py`` entry point."""
    # None means the optional plugin is not configured, which is valid.
    if v is None or v.endswith(".py"):
        return v
    raise ValueError("The plugin entry point must be a Python file")

Expand All @@ -138,10 +145,10 @@ def _lsl_type_must_have_a_lsl_object(cls, value):
return value


class EphysGeneratorSettings(BaseModel):
class EphysGeneratorSettings(BaseModel, extra=Extra.forbid):
"""Settings for the spike generator."""

class Waveforms(BaseModel):
class Waveforms(BaseModel, extra=Extra.forbid):
"""Settings for the spike waveform prototypes."""

n_samples: int
Expand Down Expand Up @@ -169,10 +176,10 @@ def _unit_prototype_mapping_needs_a_default_value(cls, v):
raise ValueError("Mapped prototype doesn't have a default value.")
return v

class Input(BaseModel):
class Input(BaseModel, extra=Extra.forbid):
"""Settings for the ephys generator input."""

class Testing(BaseModel):
class Testing(BaseModel, extra=Extra.forbid):
"""Settings for the ephys generator input type testing."""

n_channels: int
Expand All @@ -182,23 +189,23 @@ class Testing(BaseModel):
lsl: Optional[LSLInputModel]
testing: Optional[Testing]

class Output(BaseModel):
class Output(BaseModel, extra=Extra.forbid):
"""Settings for the ephys generator output."""

class Raw(BaseModel, extra=Extra.forbid):
    """Settings for the ephys generator output type raw.

    ``extra=Extra.forbid`` rejects unknown keys in the settings file.
    """

    # LSL outlet configuration for the raw output stream.
    lsl: LSLOutputModel

class LFP(BaseModel, extra=Extra.forbid):
    """Settings for the ephys generator output type LFP.

    ``extra=Extra.forbid`` rejects unknown keys in the settings file.
    """

    # NOTE(review): presumably data_frequency and filter_cutoff are in Hz
    # and filter_order is the filter's integer order — confirm with the
    # LFP filtering implementation.
    data_frequency: float
    filter_cutoff: float
    filter_order: int
    # LSL outlet configuration for the LFP output stream.
    lsl: LSLOutputModel

class SpikeEvents(BaseModel):
class SpikeEvents(BaseModel, extra=Extra.forbid):
"""Settings for the ephys generator output type spike events."""

lsl: LSLOutputModel
Expand All @@ -207,7 +214,7 @@ class SpikeEvents(BaseModel):
lfp: LFP
spike_events: SpikeEvents

class Noise(BaseModel):
class Noise(BaseModel, extra=Extra.forbid):
"""Settings for the ephys generator noise."""

beta: float
Expand Down Expand Up @@ -250,3 +257,8 @@ class Settings(VersionedYamlModel):
timer: TimerModel
encoder: EncoderSettings
ephys_generator: EphysGeneratorSettings

class Config:
"""Pydantic configuration."""

extra = Extra.forbid
54 changes: 45 additions & 9 deletions src/neural_data_simulator/decoder/run_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from pathlib import Path
from typing import cast

from pydantic import Extra
from pydantic_yaml import VersionedYamlModel
from rich.pretty import pprint
import yaml

from neural_data_simulator.core import inputs
from neural_data_simulator.core import outputs
Expand All @@ -17,9 +20,11 @@
from neural_data_simulator.decoder.settings import DecoderSettings
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import get_abs_path
from neural_data_simulator.util.runtime import get_configs_dir
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.runtime import open_connection
from neural_data_simulator.util.settings_loader import get_script_settings
from neural_data_simulator.util.settings_loader import check_config_override_str
from neural_data_simulator.util.settings_loader import load_settings

SCRIPT_NAME = "nds-decoder"
logger = logging.getLogger(__name__)
Expand All @@ -32,17 +37,39 @@ class _Settings(VersionedYamlModel):
decoder: DecoderSettings
timer: TimerModel

class Config:
extra = Extra.forbid

def _parse_args():
    """Parse command-line arguments for the decoder runner.

    Returns the full :class:`argparse.Namespace` — settings path, optional
    OmegaConf-style dotlist overrides, and the print-settings-only flag —
    rather than just the settings path.
    """
    parser = argparse.ArgumentParser(
        description="Decode behavior from input neural data.",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--settings-path",
        type=Path,
        default=Path(get_configs_dir()).joinpath("settings_decoder.yaml"),
        help="Path to the settings_decoder.yaml file.",
    )
    parser.add_argument(
        "--overrides",
        "-o",
        nargs="*",
        # Validates each item is a well-formed key=value override string.
        type=check_config_override_str,
        help=(
            "Specify settings overrides as key-value pairs, separated by spaces. "
            "For example: -o log_level=DEBUG decoder.spike_threshold=-210"
        ),
    )
    parser.add_argument(
        "--print-settings-only",
        "-p",
        action="store_true",
        help="Parse/print the settings and exit.",
    )
    args = parser.parse_args()
    return args


def _read_decode_send(
Expand All @@ -65,14 +92,21 @@ def _read_decode_send(
def run():
"""Run the decoder loop."""
initialize_logger(SCRIPT_NAME)

settings = cast(
args = _parse_args()
settings: _Settings = cast(
_Settings,
get_script_settings(
_parse_args_settings_path(), "settings_decoder.yaml", _Settings
load_settings(
args.settings_path,
settings_parser=_Settings,
override_dotlist=args.overrides,
),
)
if args.print_settings_only:
pprint(settings)
return

configure_logger(SCRIPT_NAME, settings.log_level)
logger.debug(f"run_decoder settings:\n{yaml.dump(settings.dict())}")

# Set up timer
timer_settings = settings.timer
Expand All @@ -90,6 +124,7 @@ def run():
stream_name=lsl_input_settings.stream_name,
connection_timeout=lsl_input_settings.connection_timeout,
)
logger.debug(f"Querying info from LSL stream: {lsl_input_settings.stream_name}")

# Set up decoder
decoder_model = PersistedFileDecoderModel(get_abs_path(settings.decoder.model_file))
Expand All @@ -101,6 +136,7 @@ def run():
threshold=settings.decoder.spike_threshold,
)

logger.debug("Attempting to open LSL connections...")
try:
with open_connection(data_output), open_connection(data_input):
timer.start()
Expand Down
7 changes: 4 additions & 3 deletions src/neural_data_simulator/decoder/settings.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
"""Schema for Decoder settings."""

from pydantic import BaseModel
from pydantic import Extra

from neural_data_simulator.core.settings import LSLInputModel
from neural_data_simulator.core.settings import LSLOutputModel


class DecoderSettings(BaseModel):
class DecoderSettings(BaseModel, extra=Extra.forbid):
"""Decoder settings."""

class Input(BaseModel):
class Input(BaseModel, extra=Extra.forbid):
"""Decoder input settings."""

lsl: LSLInputModel

class Output(BaseModel):
class Output(BaseModel, extra=Extra.forbid):
"""Decoder output settings."""

sampling_rate: float
Expand Down
Loading

0 comments on commit 4ac73cc

Please sign in to comment.