Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify config #417

Open
wants to merge 4 commits into
base: refactor_config2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 20 additions & 78 deletions alphadia/constants/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ library_prediction:

# define custom alphabase modifications not part of unimod or alphabase
# also used for decoy channels
# TODO make this a list
# - name: Dimethyl:d12@K
# composition: H(-2)2H(8)13C(2)
custom_modifications:
# Dimethyl @K channel decoy
- name: Dimethyl:d12@K
Expand Down Expand Up @@ -130,9 +127,6 @@ calibration:
# the maximum number of times an automatic optimizer can be skipped before it is considered to have converged
max_skips: 1

# TODO: remove this parameter
final_full_calibration: False

# TODO: remove this parameter
norm_rt_mode: 'linear'

Expand Down Expand Up @@ -201,27 +195,27 @@ library_multiplexing:
# channels can be either a number or a string
# for every channel, the library gets copied and the modifications are translated according to the mapping
# the following example shows how to multiplex mTRAQ to three sample channels and a decoy channel
# TODO make this a list
# - channel_name: 0
# channel_modifications:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term
multiplex_mapping: {}
#0:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term

#4:
# mTRAQ@K: mTRAQ:13C(3)15N(1)@K
# mTRAQ@Any_N-term: mTRAQ:13C(3)15N(1)@Any_N-term
multiplex_mapping: []
# - channel_name: 0
# modifications:
# mTRAQ@K: mTRAQ@K
# mTRAQ@Any_N-term: mTRAQ@Any_N-term
#
# - channel_name: 4
# modifications:
# mTRAQ@K: mTRAQ:13C(3)15N(1)@K
# mTRAQ@Any_N-term: mTRAQ:13C(3)15N(1)@Any_N-term
#
# - channel_name: 8
# modifications:
# mTRAQ@K: mTRAQ:13C(6)15N(2)@K
# mTRAQ@Any_N-term: mTRAQ:13C(6)15N(2)@Any_N-term
#
# - channel_name: 12
# modifications:
# mTRAQ@K: mTRAQ:d12@K
# mTRAQ@Any_N-term: mTRAQ:d12@Any_N-term

#8:
# mTRAQ@K: mTRAQ:13C(6)15N(2)@K
# mTRAQ@Any_N-term: mTRAQ:13C(6)15N(2)@Any_N-term

#12:
# mTRAQ@K: mTRAQ:d12@K
# mTRAQ@Any_N-term: mTRAQ:d12@Any_N-term



Expand Down Expand Up @@ -388,58 +382,6 @@ transfer_learning:
instrument: 'Lumos'


# configuration for the calibration manager
# the config has to start with the calibration keyword and consists of a list of calibration groups.
# each group consists of datapoints which have multiple properties.
# This can be for example precursors (mz, rt ...), fragments (mz, ...), quadrupole (transfer_efficiency)
calibration_manager: # TODO move to a separate file or hard-code
- name: fragment
estimators:
- name: mz
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mz_library
target_columns:
- mz_observed
output_columns:
- mz_calibrated
transform_deviation: 1e6
- name: precursor
estimators:
- name: mz
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mz_library
target_columns:
- mz_observed
output_columns:
- mz_calibrated
transform_deviation: 1e6
- name: rt
model: LOESSRegression
model_args:
n_kernels: 6
input_columns:
- rt_library
target_columns:
- rt_observed
output_columns:
- rt_calibrated
- name: mobility
model: LOESSRegression
model_args:
n_kernels: 2
input_columns:
- mobility_library
target_columns:
- mobility_observed
output_columns:
- mobility_calibrated

# scope of default yaml should be one search step
multistep_search:
transfer_step_enabled: False
Expand Down
12 changes: 10 additions & 2 deletions alphadia/libtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,12 +617,20 @@ def forward(self, input: SpecLibBase) -> SpecLibBase:


class MultiplexLibrary(ProcessingStep):
def __init__(self, multiplex_mapping: dict, input_channel: str | int | None = None):
def __init__(self, multiplex_mapping: list, input_channel: str | int | None = None):
"""Initialize the MultiplexLibrary step."""

self._multiplex_mapping = multiplex_mapping
self._multiplex_mapping = self._create_multiplex_mapping(multiplex_mapping)
self._input_channel = input_channel

@staticmethod
def _create_multiplex_mapping(multiplex_mapping: list) -> dict:
"""Create a dictionary from the multiplex mapping list."""
mapping = {}
for list_item in multiplex_mapping:
mapping[list_item["channel_name"]] = list_item["modifications"]
return mapping

def validate(self, input: str) -> bool:
"""Validate the input object. It is expected that the input is a path to a file which exists."""
valid = True
Expand Down
6 changes: 3 additions & 3 deletions alphadia/search_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _init_config(

config_updates = []

if user_config is not None:
if user_config:
logger.info("loading additional config provided via CLI")
# load update config from dict
if isinstance(user_config, dict):
Expand All @@ -108,15 +108,15 @@ def _init_config(
"'config' parameter must be of type 'dict' or 'Config'"
)

if cli_config is not None:
if cli_config:
logger.info("loading additional config provided via CLI parameters")
cli_config_update = Config(
cli_config, experiment_name=USER_DEFINED_CLI_PARAM
)
config_updates.append(cli_config_update)

# this needs to be last
if extra_config is not None:
if extra_config:
extra_config_update = Config(extra_config, experiment_name=MULTISTEP_SEARCH)
# need to overwrite user-defined output folder here to have correct value in config dump
extra_config[ConfigKeys.OUTPUT_DIRECTORY] = output_folder
Expand Down
1 change: 0 additions & 1 deletion alphadia/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def load(

# initialize the calibration manager
self._calibration_manager = manager.CalibrationManager(
self.config["calibration_manager"],
path=os.path.join(self.path, self.CALIBRATION_MANAGER_PKL_NAME),
load_from_file=self.config["general"]["reuse_calibration"],
reporter=self.reporter,
Expand Down
59 changes: 53 additions & 6 deletions alphadia/workflow/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,58 @@

# TODO move all managers to dedicated modules

# configuration for the calibration manager
# the config has to start with the calibration keyword and consists of a list of calibration groups.
# each group consists of datapoints which have multiple properties.
# This can be for example precursors (mz, rt ...), fragments (mz, ...), quadrupole (transfer_efficiency)
# TODO simplify this structure and the config loading
CALIBRATION_MANAGER_CONFIG = [
{
"estimators": [
{
"input_columns": ["mz_library"],
"model": "LOESSRegression",
"model_args": {"n_kernels": 2},
"name": "mz",
"output_columns": ["mz_calibrated"],
"target_columns": ["mz_observed"],
"transform_deviation": "1e6",
}
],
"name": "fragment",
},
{
"estimators": [
{
"input_columns": ["mz_library"],
"model": "LOESSRegression",
"model_args": {"n_kernels": 2},
"name": "mz",
"output_columns": ["mz_calibrated"],
"target_columns": ["mz_observed"],
"transform_deviation": "1e6",
},
{
"input_columns": ["rt_library"],
"model": "LOESSRegression",
"model_args": {"n_kernels": 6},
"name": "rt",
"output_columns": ["rt_calibrated"],
"target_columns": ["rt_observed"],
},
{
"input_columns": ["mobility_library"],
"model": "LOESSRegression",
"model_args": {"n_kernels": 2},
"name": "mobility",
"output_columns": ["mobility_calibrated"],
"target_columns": ["mobility_observed"],
},
],
"name": "precursor",
},
]


class BaseManager:
def __init__(
Expand Down Expand Up @@ -151,7 +203,6 @@ def fit_predict(self):
class CalibrationManager(BaseManager):
def __init__(
self,
config: None | dict = None,
path: None | str = None,
load_from_file: bool = True,
**kwargs,
Expand All @@ -162,10 +213,6 @@ def __init__(

Parameters
----------

config : typing.Union[None, dict], default=None
Calibration config dict. If None, the default config is used.

path : str, default=None
Path where the current parameter set is saved to and loaded from.

Expand All @@ -181,7 +228,7 @@ def __init__(

if not self.is_loaded_from_file:
self.estimator_groups = []
self.load_config(config)
self.load_config(CALIBRATION_MANAGER_CONFIG)

@property
def estimator_groups(self):
Expand Down
15 changes: 9 additions & 6 deletions tests/unit_tests/test_libtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,14 @@ def test_multiplex_library():
test_lib.calc_precursor_mz()
test_lib.calc_fragment_mz_df()

test_multiplex_mapping = {
0: {"mTRAQ@K": "mTRAQ@K"},
"magic_chanel": {"mTRAQ@K": "mTRAQ:13C(3)15N(1)@K"},
1337: {"mTRAQ@K": "mTRAQ:13C(6)15N(2)@K"},
}
test_multiplex_mapping = [
{"channel_name": 0, "modifications": {"mTRAQ@K": "mTRAQ@K"}},
{
"channel_name": "magic_channel",
"modifications": {"mTRAQ@K": "mTRAQ:13C(3)15N(1)@K"},
},
{"channel_name": 1337, "modifications": {"mTRAQ@K": "mTRAQ:13C(6)15N(2)@K"}},
]

# when
multiplexer = libtransform.MultiplexLibrary(test_multiplex_mapping)
Expand All @@ -116,7 +119,7 @@ def test_multiplex_library():
assert result_lib.precursor_df["charge"].nunique() == 2
assert result_lib.precursor_df["frag_stop_idx"].nunique() == 6

for channel in [0, 1337, "magic_chanel"]:
for channel in [0, 1337, "magic_channel"]:
assert (
result_lib.precursor_df[
result_lib.precursor_df["channel"] == channel
Expand Down
44 changes: 30 additions & 14 deletions tests/unit_tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_base_manager_load():
os.remove(my_base_manager.path)


TEST_CONFIG = [
TEST_CALIBRATION_MANAGER_CONFIG = [
{
"name": "precursor",
"estimators": [
Expand Down Expand Up @@ -90,9 +91,13 @@ def test_base_manager_load():
def test_calibration_manager_init():
# initialize the calibration manager
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

assert calibration_manager.path == temp_path
assert calibration_manager.is_loaded_from_file is False
Expand Down Expand Up @@ -158,9 +163,13 @@ def calibration_testdata():

def test_calibration_manager_fit_predict():
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

test_df = calibration_testdata()

Expand All @@ -182,9 +191,13 @@ def test_calibration_manager_fit_predict():

def test_calibration_manager_save_load():
temp_path = os.path.join(tempfile.tempdir, "calibration_manager.pkl")
calibration_manager = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=False
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager = manager.CalibrationManager(
path=temp_path, load_from_file=False
)

test_df = calibration_testdata()
calibration_manager.fit(test_df, "precursor", plot=False)
Expand All @@ -195,9 +208,13 @@ def test_calibration_manager_save_load():

calibration_manager.save()

calibration_manager_loaded = manager.CalibrationManager(
TEST_CONFIG, path=temp_path, load_from_file=True
)
with patch(
"alphadia.workflow.manager.CALIBRATION_MANAGER_CONFIG",
TEST_CALIBRATION_MANAGER_CONFIG,
):
calibration_manager_loaded = manager.CalibrationManager(
path=temp_path, load_from_file=True
)
assert calibration_manager_loaded.is_fitted is True
assert calibration_manager_loaded.is_loaded_from_file is True

Expand Down Expand Up @@ -433,7 +450,6 @@ def create_workflow_instance():
]
)
workflow._calibration_manager = manager.CalibrationManager(
workflow.config["calibration_manager"],
path=os.path.join(workflow.path, workflow.CALIBRATION_MANAGER_PKL_NAME),
load_from_file=workflow.config["general"]["reuse_calibration"],
reporter=workflow.reporter,
Expand Down
Loading