diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml new file mode 100644 index 0000000..4c92122 --- /dev/null +++ b/.github/workflows/build_wheels.yml @@ -0,0 +1,92 @@ +name: Build wheels + +on: + release: + types: [published] + workflow_dispatch: + +jobs: + generate-wheels-matrix: + name: Generate wheels matrix + runs-on: ubuntu-latest + outputs: + include: ${{ steps.set-matrix.outputs.include }} + steps: + - uses: actions/checkout@v3 + - name: Install cibuildwheel + run: pipx install cibuildwheel==2.14.0 + - id: set-matrix + run: | + # ... (Keep this section as is in the original SHAP example) + + build_wheels: + name: Build ${{ matrix.only }} + needs: generate-wheels-matrix + strategy: + fail-fast: false + matrix: + include: ${{ fromJson(needs.generate-wheels-matrix.outputs.include) }} + runs-on: ${{ matrix.os }} + steps: + - name: Check out the repo + uses: actions/checkout@v3 + + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v2 + with: + platforms: all + + - name: Build wheels + uses: pypa/cibuildwheel@v2.14.0 + with: + only: ${{ matrix.only }} + + - uses: actions/upload-artifact@v3 + with: + path: ./wheelhouse/*.whl + name: bdist_files + + build_sdist: + name: Build source distribution + runs-on: ubuntu-20.04 + steps: + - name: Check out the repo + uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Build sdist (pep517) + run: | + python -m pip install build + python -m build --sdist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: sdist_files + path: dist/*.tar.gz + + publish_wheels: + name: Publish wheels on pypi + needs: [build_wheels, build_sdist] + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags') + steps: + - uses: actions/download-artifact@v3 + with: + name: bdist_files + path: dist + - uses: actions/download-artifact@v3 + with: + name: sdist_files + path: dist + - name: 
Publish package to TestPyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.TEST_PYPI_TOKEN }} + repository-url: https://test.pypi.org/legacy/ + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/codeql_analysis.yml b/.github/workflows/codeql_analysis.yml new file mode 100644 index 0000000..73f4cdf --- /dev/null +++ b/.github/workflows/codeql_analysis.yml @@ -0,0 +1,54 @@ +name: CodeQL + +on: + push: + branches: + - master + pull_request: + branches: + - master + schedule: + - cron: '0 0 * * 0' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: "python" + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + + code-quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9 + + - name: Cache Python dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-3.9-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-3.9- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 mypy pylint + + - name: Run flake8 (check style) + run: flake8 . 
--count --select=E9,F63,F7,F82 --show-source --statistics + + - name: Run mypy (check static typing) + run: mypy xomics/__init__.py --follow-imports=skip \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..db5955c --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,45 @@ +name: Unit Tests + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Cache Python dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + + - name: Run Tests + run: pytest tests + env: + HYPOTHESIS_DEADLINE: 10000000 + MPLBACKEND: Agg diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml deleted file mode 100644 index 7f81800..0000000 --- a/.github/workflows/python-package.yml +++ /dev/null @@ -1,44 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -name: Python package - -on: - push: - branches: [ "master" ] - pull_request: - branches: [ "master" ] - -jobs: - build: - - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] - - steps: - - uses: actions/checkout@v3 - - 
name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: install package - run: | - python -m pip install -e . - - name: Test with pytest - run: | - pytest - diff --git a/.github/workflows/test_coverage.yml b/.github/workflows/test_coverage.yml new file mode 100644 index 0000000..45ba3fe --- /dev/null +++ b/.github/workflows/test_coverage.yml @@ -0,0 +1,40 @@ +name: Test Coverage + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest pytest-cov + + - name: Run tests with coverage and generate XML report + run: | + pytest --cov=./ --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.xml + fail_ci_if_error: true + env: + MPLBACKEND: Agg diff --git a/xomics/__init__.py b/xomics/__init__.py index a500a3a..4679c4f 100644 --- a/xomics/__init__.py +++ b/xomics/__init__.py @@ -7,11 +7,7 @@ plot_enrich_map, plot_prank, plot_prank_scatter, - plot_imput_histo, - plot_settings, - plot_legend, - plot_get_clist, - plot_gcfs) + plot_imput_histo) 
__all__ = [ "pRank", @@ -24,8 +20,4 @@ "plot_prank", "plot_prank_scatter", "plot_imput_histo", - "plot_settings", - "plot_legend", - "plot_get_clist", - "plot_gcfs" ] diff --git a/xomics/_utils/_utils.py b/xomics/_utils/_utils.py new file mode 100644 index 0000000..4463266 --- /dev/null +++ b/xomics/_utils/_utils.py @@ -0,0 +1,10 @@ +"""This is a script for utility functions for utility functions.""" + + +# Helper functions +# This function is only used in utility check function +def add_str(str_error=None, str_add=None): + """Add additional error message 'str_add' to default error message ('add_str')""" + if str_add: + str_error += "\n " + str_add + return str_error diff --git a/xomics/_utils/check_data.py b/xomics/_utils/check_data.py index 2647cee..86aed22 100644 --- a/xomics/_utils/check_data.py +++ b/xomics/_utils/check_data.py @@ -5,37 +5,104 @@ import numpy as np from sklearn.utils import check_array +from ._utils import add_str +from .utils_types import VALID_INT_TYPES, VALID_INT_FLOAT_TYPES +import xomics._utils.check_type as check_type + # Helper functions -def check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_nan=False): +def _convert_2d(val=None, name=None, str_add=None): """ - Check if the provided value matches the specified dtype. - If dtype is None, checks for general array-likeness. - If dtype is 'int', 'float', or 'any', checks for specific types. + Convert array-like data to 2D array. Handles lists of arrays, lists of lists, and 1D lists. 
""" - if name is None: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should be a 2D list or 2D array with rows having the same number of columns.", + str_add=str_add) + if isinstance(val, list): + # Check if List with arrays and return if yes + if all(isinstance(i, np.ndarray) for i in val): + try: + val = np.asarray(val) + except ValueError: + raise ValueError(str_error) + # Convert 1D list to 2D list + elif all(not isinstance(i, list) for i in val): + try: + val = np.asarray([val]) + except ValueError: + raise ValueError(str_error) + # For nested lists, ensure they are 2D (list of lists with equal lengths) + else: + try: + val = np.array(val) # Convert nested list to numpy array + if val.ndim != 2: + raise ValueError + except ValueError: + raise ValueError(str_error) + elif hasattr(val, 'ndim') and val.ndim == 1: + try: + val = np.asarray([val]) + except ValueError: + raise ValueError(str_error) + return val + + +# Check array like +def check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_nan=False, convert_2d=False, + accept_none=False, expected_dim=None, str_add=None): + """Check if the provided value is array-like and matches the specified dtype.""" + if val is None: + if accept_none: + return None # Skip tests + else: + raise ValueError(f"'{name}' should not be None.") + # Extend dtype to handle a list of dtypes including bool + dtype = check_type.check_str(name="dtype", val=dtype, accept_none=True) + valid_dtypes = ["numeric", "int", "float", "bool", None] + if dtype not in valid_dtypes: + str_error = add_str(str_error=f"'dtype' should be one of the following: {valid_dtypes}", str_add=str_add) + raise ValueError(str_error) + dict_expected_dtype = {"numeric": "numeric", "int": "int64", "float": "float64", "bool": "bool"} + expected_dtype = dict_expected_dtype[dtype] if dtype is not None else None + # Specific check for boolean arrays + if dtype == "bool": + flattened_val = 
np.array(val).flatten() + if not all(isinstance(item, (bool, np.bool_)) for item in flattened_val): + str_error = add_str(str_error=f"All elements in '{name}' must be of type 'bool' (either Python native or NumPy bool).", + str_add=str_add) + raise ValueError(str_error) + # Convert a 1D list or array to a 2D array if needed + if convert_2d: + val = _convert_2d(val=val, name=name, str_add=str_add) # Utilize Scikit-learn's check_array for robust checking - if dtype == 'int': - expected_dtype = 'int' - elif dtype == 'float': - expected_dtype = 'float64' - elif dtype == 'any' or dtype is None: - expected_dtype = None - else: - raise ValueError(f"'dtype' ({dtype}) not recognized.") try: + # Convert list to array val = check_array(val, dtype=expected_dtype, ensure_2d=ensure_2d, force_all_finite=not allow_nan) except Exception as e: - raise ValueError(f"'{name}' should be array-like with {dtype} values." - f"\nscikit message:\n\t{e}") + dtype = "any type" if dtype is None else dtype + raise ValueError(f"'{name}' should be array-like with '{dtype}' values." 
+ f"\nScikit message:\n\t{e}") + # Check dimensions if specified + if expected_dim is not None and len(val.shape) != expected_dim: + str_error = add_str(str_error=f"'{name}' should have {expected_dim} dimensions, but has {len(val.shape)}.", + str_add=str_add) + raise ValueError(str_error) return val # Check feature matrix and labels -def check_X(X, min_n_samples=3, min_n_features=2, ensure_2d=True, allow_nan=False): +def check_X(X, X_name="X", min_n_samples=3, min_n_features=2, min_n_unique_features=None, + ensure_2d=True, allow_nan=False, accept_none=False, str_add=None): """Check the feature matrix X is valid.""" + if X is None: + if not accept_none: + raise ValueError(f"'{X_name}' should not be None") + else: + return None X = check_array_like(name="X", val=X, dtype="float", ensure_2d=ensure_2d, allow_nan=allow_nan) + if np.isinf(X).any(): + str_error = add_str(str_error=f"'X' should not contain infinite values", + str_add=str_add) + raise ValueError(str_error) n_samples, n_features = X.shape if n_samples < min_n_samples: raise ValueError(f"n_samples ({n_samples} in 'X') should be >= {min_n_samples}." @@ -43,69 +110,173 @@ def check_X(X, min_n_samples=3, min_n_features=2, ensure_2d=True, allow_nan=Fals if n_features < min_n_features: raise ValueError(f"n_features ({n_features} in 'X') should be >= {min_n_features}." 
f"\nX = {X}") + if min_n_unique_features is not None: + n_unique_features = sum([len(set(X[:, col])) > 1 for col in range(n_features)]) + if n_unique_features < min_n_unique_features: + str_error = add_str(str_error=f"'n_unique_features' ({n_unique_features}) should be >= {min_n_unique_features}", + str_add=str_add) + raise ValueError(str_error) return X -def check_X_unique_samples(X, min_n_unique_samples=3): +def check_X_unique_samples(X, min_n_unique_samples=3, str_add=None): """Check if the matrix X has a sufficient number of unique samples.""" n_unique_samples = len(set(map(tuple, X))) if n_unique_samples == 1: - raise ValueError("Feature matrix 'X' should not have all identical samples.") + str_error = add_str(str_error="Feature matrix 'X' should not have all identical samples.", + str_add=str_add) + raise ValueError(str_error) if n_unique_samples < min_n_unique_samples: raise ValueError(f"n_unique_samples ({n_unique_samples}) should be >= {min_n_unique_samples}." f"\nX = {X}") return X -def check_labels(labels=None): - """""" +def check_labels(labels=None, name="labels", vals_requiered=None, len_requiered=None, allow_other_vals=True, + n_per_group_requiered=None, accept_float=False, str_add=None): + """Check the provided labels against various criteria like type, required values, and length.""" if labels is None: - raise ValueError(f"'labels' should not be None.") + raise ValueError(f"'{name}' should not be None.") + labels = check_type.check_list_like(name=name, val=labels) # Convert labels to a numpy array if it's not already labels = np.asarray(labels) + # Ensure labels is at least 1-dimensional + if labels.ndim == 0: + labels = np.array([labels.item()]) # Convert 0-d array to 1-d array unique_labels = set(labels) if len(unique_labels) == 1: - raise ValueError(f"'labels' should contain more than one different value ({unique_labels}).") - wrong_types = [l for l in unique_labels if not np.issubdtype(type(l), np.integer)] + str_error = 
add_str(str_error=f"'{name}' should contain more than one different value ({unique_labels}).", + str_add=str_add) + raise ValueError(str_error) + valid_types = VALID_INT_TYPES if not accept_float else VALID_INT_FLOAT_TYPES + wrong_types = [l for l in unique_labels if not isinstance(l, valid_types)] if wrong_types: - raise ValueError(f"Labels in 'labels' should be type int, but contain: {set(map(type, wrong_types))}") + str_error = add_str(str_error=f"Labels in '{name}' should be type int, but contain: {set(map(type, wrong_types))}", + str_add=str_add) + raise ValueError(str_error) + if vals_requiered is not None: + missing_vals = [x for x in vals_requiered if x not in labels] + if len(missing_vals) > 0: + str_error = add_str(str_error=f"'{name}' ({unique_labels}) does not contain requiered values: {missing_vals}", + str_add=str_add) + raise ValueError(str_error) + if not allow_other_vals: + wrong_vals = [x for x in labels if x not in vals_requiered] + if len(wrong_vals) > 0: + str_error = add_str(str_error=f"'{name}' ({unique_labels}) does contain wrong values: {wrong_vals}", + str_add=str_add) + raise ValueError(str_error) + if len_requiered is not None and len(labels) != len_requiered: + str_error = add_str(str_error=f"'{name}' (n={len(labels)}) should contain {len_requiered} values.", + str_add=str_add) + raise ValueError(str_error) + # Check for minimum length per group + if n_per_group_requiered is not None: + label_counts = {label: np.sum(labels == label) for label in unique_labels} + underrepresented_labels = {label: count for label, count in label_counts.items() if + count < n_per_group_requiered} + if underrepresented_labels: + str_error = add_str(str_error=f"Each label should have at least {n_per_group_requiered} occurrences. 
" + f"Underrepresented labels: {underrepresented_labels}", + str_add=str_add) + raise ValueError(str_error) return labels -def check_match_X_labels(X=None, X_name="X", labels=None, labels_name="labels"): - """""" +def check_match_X_labels(X=None, X_name="X", labels=None, labels_name="labels", check_variability=False, + str_add=None): + """Check if the number of samples in X matches the number of labels.""" n_samples, n_features = X.shape if n_samples != len(labels): - raise ValueError(f"n_samples does not match for '{X_name}' ({len(X)}) and '{labels_name}' ({len(labels)}).") + str_error = add_str(str_error=f"n_samples does not match for '{X_name}' ({len(X)}) and '{labels_name}' ({len(labels)}).", + str_add=str_add) + raise ValueError(str_error) + if check_variability: + unique_labels = np.unique(labels) + for label in unique_labels: + group_X = X[labels == label] + if not np.all(np.var(group_X, axis=0) != 0): + str_error = add_str(str_error=f"Variance in 'X' for label '{label}' from '{labels_name}' is too low.", + str_add=str_add) + raise ValueError(str_error) + + +def check_match_X_list_labels(X=None, list_labels=None, check_variability=False, vals_requiered=None, str_add=None): + """Check if each label set is matching with X""" + for i, labels in enumerate(list_labels): + check_labels(labels=labels, vals_requiered=vals_requiered) + check_match_X_labels(X=X, labels=labels, labels_name=f"list_labels (set {i+1})", + check_variability=check_variability, str_add=str_add) + + +def check_match_list_labels_names_datasets(list_labels=None, names_datasets=None, str_add=None): + """Check if length of list_labels and names match""" + if names_datasets is None: + return None # Skip check + if len(list_labels) != len(names_datasets): + str_error = add_str(str_error=f"Length of 'list_labels' ({len(list_labels)}) and 'names_datasets'" + f" ({len(names_datasets)} does not match)", + str_add=str_add) + raise ValueError(str_error) # Check sets -def 
check_superset_subset(subset=None, superset=None, name_subset=None, name_superset=None): - """""" +def check_superset_subset(subset=None, superset=None, name_subset=None, name_superset=None, str_add=None): + """Check if all elements of the subset are contained in the superset.""" wrong_elements = [x for x in subset if x not in superset] if len(wrong_elements) != 0: - raise ValueError(f"'{name_superset}' does not contain the following elements of '{name_subset}': {wrong_elements}") + str_error = add_str(str_error=f"'{name_superset}' does not contain the following elements of '{name_subset}': {wrong_elements}", + str_add=str_add) + raise ValueError(str_error) # df checking functions -def check_df(name="df", df=None, cols_req=None, accept_none=False, accept_nan=True, all_positive=False): - """""" - df = df.copy() - if not accept_none and df is None: - raise ValueError(f"'{name}' should not be None") +def check_df(name="df", df=None, accept_none=False, accept_nan=True, check_all_positive=False, + cols_requiered=None, cols_forbidden=None, cols_nan_check=None, str_add=None): + """Check if the provided DataFrame meets various criteria such as NaN values, required/forbidden columns, etc.""" + # Check DataFrame and values + if df is None: + if not accept_none: + raise ValueError(f"'{name}' should not be None") + else: + return None if not isinstance(df, pd.DataFrame): - raise ValueError(f"'{name}' ({type(df)}) should be DataFrame") + str_error = add_str(str_error= f"'{name}' ({type(df)}) should be DataFrame", + str_add=str_add) + raise ValueError(str_error) if not accept_nan and df.isna().any().any(): - raise ValueError(f"'{name}' contains NaN values, which are not allowed") - if cols_req is not None: - missing_cols = set(cols_req) - set(df.columns) - if missing_cols: - raise ValueError(f"'{name}' is missing required columns: {cols_req}") - if all_positive: + str_error = add_str(str_error=f"'{name}' contains NaN values, which are not allowed", + str_add=str_add) + raise 
ValueError(str_error) + if check_all_positive: numeric_df = df.select_dtypes(include=['float', 'int']) if numeric_df.min().min() <= 0: - raise ValueError(f"'{name}' should not contain non-positive values.") - return df.copy() + str_error = add_str(str_error=f"'{name}' should not contain non-positive values.", + str_add=str_add) + raise ValueError(str_error) + # Check columns + args = dict(accept_str=True, accept_none=True, str_add=str_add) + cols_requiered = check_type.check_list_like(name='cols_requiered', val=cols_requiered, **args) + cols_forbidden = check_type.check_list_like(name='cols_forbidden', val=cols_forbidden, **args) + cols_nan_check = check_type.check_list_like(name='cols_nan_check', val=cols_nan_check, **args) + if cols_requiered is not None: + missing_cols = [col for col in cols_requiered if col not in df.columns] + if len(missing_cols) > 0: + str_error = add_str(str_error=f"'{name}' is missing required columns: {missing_cols}", + str_add=str_add) + raise ValueError(str_error) + if cols_forbidden is not None: + forbidden_cols = [col for col in cols_forbidden if col in df.columns] + if len(forbidden_cols) > 0: + str_error = add_str(str_error=f"'{name}' is contains forbidden columns: {forbidden_cols}", + str_add=str_add) + raise ValueError(str_error) + if cols_nan_check is not None: + if df[cols_nan_check].isna().sum().sum() > 0: + str_error = add_str(str_error=f"NaN values are not allowed in '{cols_nan_check}'.", + str_add=str_add) + raise ValueError(str_error) + return df def check_col_in_df(df=None, name_df=None, cols=None, name_cols=None, accept_nan=False, error_if_exists=False, accept_none=False): @@ -131,5 +302,4 @@ def check_col_in_df(df=None, name_df=None, cols=None, name_cols=None, accept_nan # Check if NaNs are present when they are not accepted if not accept_nan: if df[cols].isna().any().any(): - raise ValueError(f"NaN values are not allowed in '{cols}'.") - + raise ValueError(f"NaN values are not allowed in '{cols}'.") \ No newline at end 
of file diff --git a/xomics/_utils/check_models.py b/xomics/_utils/check_models.py index 73e0c16..9eb2be5 100644 --- a/xomics/_utils/check_models.py +++ b/xomics/_utils/check_models.py @@ -1,39 +1,76 @@ -"""This is a script for scikit-learn model specific check functions""" +"""This is a script for scikit-learn model-specific check functions""" import inspect from inspect import isclass +from ._utils import add_str + # Helper functions # Main functions -def check_mode_class(model_class=None): - """""" +def check_mode_class(model_class=None, str_add=None): + """Check if the provided object is a class and callable, typically used for validating model classes.""" # Check if model_class is actually a class and not an instance if not isclass(model_class): - raise ValueError(f"'{model_class}' is not a model class. Please provide a valid model class.") + str_error = add_str(str_error=f"'model_class' ('{model_class}') is not a model class. " + f"Please provide a valid model class.", + str_add=str_add) + raise ValueError(str_error) # Check if model is callable if not callable(getattr(model_class, "__call__", None)): - raise ValueError(f"'{model_class}' is not a callable model.") - return model_class + str_error = add_str(str_error=f"'model_class' ('{model_class}') is not a callable model.", + str_add=str_add) + raise ValueError(str_error) -def check_model_kwargs(model_class=None, model_kwargs=None, param_to_check="n_clusters", method_to_check=None): +def check_model_kwargs(model_class=None, model_kwargs=None, name_model_class="model_class", + param_to_check=None, method_to_check=None, attribute_to_check=None, + random_state=None, str_add=None): """ - Check if the provided model has 'n_clusters' as a parameter. - Filter the model_kwargs to only include keys that are valid parameters for the model. + Check if the provided model class contains specific parameters and methods. Filters 'model_kwargs' to include only + valid parameters for the model class. 
+ + Parameters: + model_class: The class of the model to check. + model_kwargs: A dictionary of keyword arguments for the model. + name_model_class: Name of model class for model class kwargs + param_to_check: A specific parameter to check in the model class. + method_to_check: A specific method to check in the model class. + attribute_to_check: A specific attribute to check in model class + random_state: random state + str_add: additional error string + + Returns: + model_kwargs: A filtered dictionary of model_kwargs containing only valid parameters for the model class. """ model_kwargs = model_kwargs or {} if model_class is None: - raise ValueError("'model_class' must be provided.") + str_error = add_str(str_error=f"'{name_model_class}' must be provided.", str_add=str_add) + raise ValueError(str_error) valid_args = list(inspect.signature(model_class).parameters.keys()) # Check if 'param_to_check' is a parameter of the model if param_to_check is not None and param_to_check not in valid_args: - raise ValueError(f"'{param_to_check}' should be an argument in the given 'model' ({model_class}).") # Check if 'method_to_check' is a method of the model + str_error = add_str(str_error=f"'{param_to_check}' should be an argument in the given '{name_model_class}' ({model_class}).", + str_add=str_add) + raise ValueError(str_error) # Check if 'method_to_check' is a method of the model if method_to_check is not None and not hasattr(model_class, method_to_check): - raise ValueError(f"'{method_to_check}' should be a method in the given 'model' ({model_class}).") + str_error = add_str(str_error=f"'{method_to_check}' should be a method in the given '{name_model_class}' ({model_class}).", + str_add=str_add) + raise ValueError(str_error) + # Check if 'attribute_to_check' is an attribute of the model + if attribute_to_check is not None and not hasattr(model_class, attribute_to_check): + str_error = add_str(str_error=f"'{attribute_to_check}' should be an attribute in the given 
'{name_model_class}' ({model_class}).", + str_add=str_add) + raise ValueError(str_error) # Check if model_kwargs contain invalid parameters for the model invalid_kwargs = [x for x in model_kwargs if x not in valid_args] if len(invalid_kwargs): - raise ValueError(f"'model_kwargs' contains invalid arguments: {invalid_kwargs}") + str_error = add_str(str_error=f"'model_kwargs' (for '{model_class}') contains invalid arguments: {invalid_kwargs}", + str_add=str_add) + raise ValueError(str_error) + if "random_state" not in model_kwargs and "random_state" in valid_args: + model_kwargs.update(dict(random_state=random_state)) return model_kwargs + + diff --git a/xomics/_utils/check_plots.py b/xomics/_utils/check_plots.py new file mode 100644 index 0000000..9bf3d0c --- /dev/null +++ b/xomics/_utils/check_plots.py @@ -0,0 +1,212 @@ +""" +This is a script for plot checking utility functions. +""" +import re +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors + +from ._utils import add_str +import xomics._utils.check_type as check_type +import numpy as np + + +# Helper functions +def _is_valid_hex_color(val): + """Check if a value is a valid hex color.""" + return isinstance(val, str) and re.match(r'^#[0-9A-Fa-f]{6}$', val) + + +def _is_valid_rgb_tuple(val): + """Check if a value is a valid RGB tuple.""" + return (isinstance(val, tuple) and len(val) == 3 and + all(isinstance(c, (int, float)) and 0 <= c <= 255 for c in val)) + + +# Check figure +def check_fig(fig=None, accept_none=False, str_add=None): + """Check if the provided value is a matplotlib Figure instance or None.""" + import matplotlib.figure + if accept_none and fig is None: + return None + if not isinstance(fig, matplotlib.figure.Figure): + str_error = f"'fig' (type={type(fig)}) should be mpl.figure.Figure or None." 
+ if str_add: + str_error += " " + str_add + raise ValueError(str_error) + return fig + + +def check_ax(ax=None, accept_none=False, str_add=None, return_first=False): + """Check if the provided value is a matplotlib Axes instance or None.""" + import matplotlib.axes + if accept_none and ax is None: + return None + if not isinstance(ax, matplotlib.axes.Axes): + str_error = add_str(str_error=f"'ax' (type={type(ax)}) should be mpl.axes.Axes or None.", + str_add=str_add) + raise ValueError(str_error) + if return_first: + if isinstance(ax, (list, tuple, np.ndarray)) and len(ax) > 0: + ax = ax[0] + return ax + + +def check_figsize(figsize=None, accept_none=False, str_add=None): + """Check size of figure""" + if accept_none and figsize is None: + return None # skip check + check_type.check_tuple(name="figsize", val=figsize, n=2, str_add=str_add) + args = dict(min_val=1, just_int=False, str_add=str_add) + check_type.check_number_range(name="figsize:width", val=figsize[0], **args) + check_type.check_number_range(name="figsize:height", val=figsize[1], **args) + + +def check_grid_axis(grid_axis="y", accept_none=True, str_add=None): + if accept_none and grid_axis is None: + return None # Skip test + list_grid_axis = ["y", "x", "both"] + if grid_axis not in list_grid_axis: + str_error = add_str(str_error=f"'grid_axis' ({grid_axis}) should be one of following: {list_grid_axis}", + str_add=str_add) + raise ValueError(str_error) + + +def check_font_weight(name="font_weight", font_weight=None, accept_none=False, str_add=None): + if accept_none and font_weight is None: + return None # Skip test + list_weight = ["normal", "bold"] + if font_weight not in list_weight: + str_error = add_str(str_error=f"'{name}' ({font_weight}) should be one of following: {list_weight}", + str_add=str_add) + raise ValueError(str_error) + + +def check_fontsize_args(**kwargs): + """Check fontsize parameters""" + args_fs = {} + for name, val in kwargs.items(): + check_type.check_number_range(name=name, 
val=val, min_val=0, accept_none=True, just_int=False) + args_fs[name] = val + return args_fs + + +# Check min and max values +def check_vmin_vmax(vmin=None, vmax=None, str_add=None): + """Check if vmin and vmax are valid numbers and vmin is less than vmax.""" + args = dict(accept_none=True, just_int=False, str_add=str_add) + check_type.check_number_val(name="vmin", val=vmin, **args) + check_type.check_number_val(name="vmax", val=vmax, **args) + if vmin is not None and vmax is not None and vmin >= vmax: + str_error = add_str(str_error=f"'vmin' ({vmin}) < 'vmax' ({vmax}) not fulfilled.", + str_add=str_add) + raise ValueError(str_error) + + +def check_lim(name="xlim", val=None, accept_none=True, str_add=None): + """Validate that lim parameter ('xlim' or 'ylim') is tuple with two numbers, where the first is less than the second.""" + if val is None: + if accept_none: + return None # Skip check + else: + raise ValueError(f"'{name}' should not be None") + check_type.check_tuple(name=name, val=val, n=2) + min_val, max_val = val + args = dict(just_int=False, str_add=str_add) + check_type.check_number_val(name=f"{name}:min", val=min_val, **args) + check_type.check_number_val(name=f"{name}:max", val=max_val, **args) + if min_val >= max_val: + str_error = add_str(str_error=f"'{name}:min' ({min_val}) should be < '{name}:max' ({max_val}).", + str_add=str_add) + raise ValueError(str_error) + + +def check_dict_xlims(dict_xlims=None, n_ax=None, str_add=None): + """Validate the structure and content of dict_xlims to ensure it contains the correct keys and value formats.""" + if n_ax is None: + # DEV: Developer warning + raise ValueError("'n_ax' must be specified") + if dict_xlims is None: + return + check_type.check_dict(name="dict_xlims", val=dict_xlims, str_add=str_add) + wrong_keys = [x for x in list(dict_xlims) if x not in range(n_ax)] + if len(wrong_keys) > 0: + str_error = add_str(str_error= f"'dict_xlims' contains invalid keys: {wrong_keys}. 
" + f"Valid keys are axis indices from 0 to {n_ax - 1}.", + str_add=str_add) + raise ValueError(str_error) + for key in dict_xlims: + check_lim(name="xlim", val=dict_xlims[key], str_add=str_add) + + +# Check colors +def check_color(name=None, val=None, accept_none=False, str_add=None): + """Check if the provided value is a valid color for matplotlib.""" + if val is None: + if accept_none: + return None # Skip test + else: + raise ValueError(f"'{name}' should not be None") + base_colors = list(mcolors.BASE_COLORS.keys()) + tableau_colors = list(mcolors.TABLEAU_COLORS.keys()) + css4_colors = list(mcolors.CSS4_COLORS.keys()) + all_colors = base_colors + tableau_colors + css4_colors + if _is_valid_hex_color(val) or _is_valid_rgb_tuple(val): + return + elif val not in all_colors: + str_error = add_str(str_error=f"'{name}' ('{val}') is not a valid color. Chose from following: {all_colors}", + str_add=str_add) + raise ValueError(str_error) + + +def check_list_colors(name=None, val=None, accept_none=False, min_n=None, max_n=None, str_add=None): + """Check if color list is valid""" + if accept_none and val is None: + return None # Skip check + val = check_type.check_list_like(name=name, val=val, accept_none=accept_none, accept_str=True, str_add=str_add) + for l in val: + check_color(name=name, val=l, accept_none=accept_none, str_add=str_add) + if min_n is not None and len(val) < min_n: + str_error = add_str(str_error=f"'{name}' should contain at least {min_n} colors", + str_add=str_add) + raise ValueError(str_error) + if max_n is not None and len(val) > max_n: + str_error = add_str(str_error=f"'{name}' should contain no more than {max_n} colors", + str_add=str_add) + raise ValueError(str_error) + + +def check_dict_color(name="dict_color", val=None, accept_none=False, min_n=None, max_n=None, str_add=None): + """Check if colors in dict_color are valid""" + if accept_none and val is None: + return None # Skip check + check_type.check_dict(name=name, val=val, 
accept_none=accept_none) + for key in val: + check_color(name=name, val=val[key], accept_none=accept_none) + if min_n is not None and len(val) < min_n: + str_error = add_str(str_error=f"'{name}' should contain at least {min_n} colors", + str_add=str_add) + raise ValueError(str_error) + if max_n is not None and len(val) > max_n: + str_error = add_str(str_error=f"'{name}' should contain no more than {max_n} colors", + str_add=str_add) + raise ValueError(str_error) + + +def check_cmap(name=None, val=None, accept_none=False, str_add=None): + """Check if cmap is a valid colormap for matplotlib.""" + valid_cmaps = plt.colormaps() + if accept_none and val is None: + pass + elif val not in valid_cmaps: + str_error = add_str(str_error=f"'{name}' ('{val}') is not a valid cmap. Chose from following: {valid_cmaps}", + str_add=str_add) + raise ValueError(str_error) + + +def check_palette(name=None, val=None, accept_none=False, str_add=None): + """Check if the provided value is a valid color palette.""" + if isinstance(val, str): + check_cmap(name=name, val=val, accept_none=accept_none, str_add=str_add) + elif isinstance(val, list): + for v in val: + check_color(name=name, val=v, accept_none=accept_none, str_add=str_add) diff --git a/xomics/_utils/check_type.py b/xomics/_utils/check_type.py index 2b938ba..b0b2298 100644 --- a/xomics/_utils/check_type.py +++ b/xomics/_utils/check_type.py @@ -4,52 +4,78 @@ import pandas as pd import numpy as np +from ._utils import add_str +from .utils_types import VALID_INT_TYPES, VALID_INT_FLOAT_TYPES + # Type checking functions -def check_number_val(name=None, val=None, accept_none=False, just_int=False): - """Check if value is float""" +def check_number_val(name=None, val=None, accept_none=False, just_int=False, str_add=None): + """Check if value is a valid integer or float""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) 
+ raise ValueError(str_error) return None if just_int is None: raise ValueError("'just_int' must be specified") - valid_types = (int,) if just_int else (float, int) - type_description = "int" if just_int else "float or int" + # Define valid types for integers and floating points + valid_types = VALID_INT_TYPES if just_int else VALID_INT_TYPES + VALID_INT_FLOAT_TYPES + type_description = "an integer" if just_int else "a float or an integer" if not isinstance(val, valid_types): - raise ValueError(f"'{name}' ({val}) should be {type_description}.") + str_error = add_str(str_error=f"'{name}' should be {type_description}, but got {type(val).__name__}.", + str_add=str_add) + raise ValueError(str_error) -def check_number_range(name=None, val=None, min_val=0, max_val=None, accept_none=False, just_int=None): +def check_number_range(name=None, val=None, min_val=0, max_val=None, exclusive_limits=False, + accept_none=False, just_int=None, str_add=None): """Check if value of given name is within defined range""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) return None if just_int is None: raise ValueError("'just_int' must be specified") - valid_types = (int,) if just_int else (float, int) - type_description = "int" if just_int else "float or int n, with" - + # Define valid types for integers and floating points + valid_types = VALID_INT_TYPES if just_int else VALID_INT_TYPES + VALID_INT_FLOAT_TYPES # Verify the value's type and range - if not isinstance(val, valid_types) or val < min_val or (max_val is not None and val > max_val): - range_desc = f"n>={min_val}" if max_val is None else f"{min_val}<=n<={max_val}" - error = f"'{name}' ({val}) should be {type_description} {range_desc}. " - if accept_none: - error += "None is also accepted." 
- raise ValueError(error) + type_description = "an integer" if just_int else "a float or an integer" + if not isinstance(val, valid_types): + str_error = add_str(str_error=f"'{name}' should be {type_description}, but got {type(val).__name__}.", + str_add=str_add) + raise ValueError(str_error) + # Min and max values are excluded from allowed values + if exclusive_limits: + if val <= min_val or (max_val is not None and val >= max_val): + range_desc = f"n > {min_val}" if max_val is None else f"{min_val} < n < {max_val}" + str_error = add_str(str_error=f"'{name}' should be {type_description} with {range_desc}, but got {val}.", + str_add=str_add) + raise ValueError(str_error) + else: + if val < min_val or (max_val is not None and val > max_val): + range_desc = f"n >= {min_val}" if max_val is None else f"{min_val} <= n <= {max_val}" + str_error = add_str(str_error=f"'{name}' should be {type_description} with {range_desc}, but got {val}.", + str_add=str_add) + raise ValueError(str_error) + return val -def check_str(name=None, val=None, accept_none=False): +def check_str(name=None, val=None, accept_none=False, return_empty_string=False, str_add=None): """Check type string""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") - return None + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) + return "" if return_empty_string else None if not isinstance(val, str): - raise ValueError(f"'{name}' ('{val}') should be string.") + str_error = add_str(str_error=f"'{name}' ('{val}') should be string.", + str_add=str_add) + raise ValueError(str_error) + return val +# TODO check def check_str_in_list(name=None, val=None, list_options=None, accept_none=False): """Check if val is one of the given options""" if val is None: @@ -60,64 +86,103 @@ def check_str_in_list(name=None, val=None, list_options=None, accept_none=False) raise ValueError(f"'{name}' ('{val}') should be of the following: 
{list_options}") -def check_bool(name=None, val=None): +def check_bool(name=None, val=None, accept_none=False, str_add=None): """Check if the provided value is a boolean.""" + if val is None: + if not accept_none: + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) + return None if not isinstance(val, bool): - raise ValueError(f"'{name}' ({val}) should be bool.") + str_error = add_str(str_error=f"'{name}' ({val}) should be bool.", + str_add=str_add) + raise ValueError(str_error) -def check_dict(name=None, val=None, accept_none=False): +def check_dict(name=None, val=None, accept_none=False, str_add=None): """Check if the provided value is a dictionary.""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) return None if not isinstance(val, dict): - error = f"'{name}' ({val}) should be a dictionary" - error += " or None." if accept_none else "." 
- raise ValueError(error) + str_error = add_str(str_error=f"'{name}' ({val}) should be a dictionary.", + str_add=str_add) + raise ValueError(str_error) -def check_tuple(name=None, val=None, n=None, check_n=True, accept_none=False): - """""" +def check_tuple(name=None, val=None, n=None, check_number=True, accept_none=False, + accept_none_number=False, str_add=None): + """Check if the provided value is a tuple, optionally of a certain length and containing only numbers.""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) return None if not isinstance(val, tuple): - raise ValueError(f"'{name}' ({val}) should be a tuple.") - if check_n and n is not None and len(val) != n: - raise ValueError(f"'{name}' ({val}) should be a tuple with {n} elements.") - - -def check_list_like(name=None, val=None, accept_none=False, convert=True, accept_str=False): - """""" + str_error = add_str(str_error=f"'{name}' ({val}) should be a tuple.", + str_add=str_add) + raise ValueError(str_error) + if n is not None and len(val) != n: + str_error = add_str(str_error=f"'{name}' ({val}) should be a tuple with {n} elements.", + str_add=str_add) + raise ValueError(str_error) + if n is not None and check_number: + for v in val: + check_number_val(name=name, val=v, just_int=False, accept_none=accept_none_number, + str_add=str_add) + + +def check_list_like(name=None, val=None, accept_none=False, convert=True, accept_str=False, min_len=None, + check_all_non_neg_int=False, check_all_non_none=True, check_all_str_or_convertible=False, + str_add=None): + """Check if the value is list-like, optionally converting it to a list, and performing additional checks.""" if val is None: if not accept_none: - raise ValueError(f"'{name}' should not be None.") + str_error = add_str(str_error=f"'{name}' should not be None.", str_add=str_add) + raise ValueError(str_error) 
return None
     if not convert:
         if not isinstance(val, list):
-            raise ValueError(f"'{name}' (type: {type(val)}) should be a list.")
+            str_error = add_str(str_error=f"'{name}' (type: {type(val)}) should be a list.",
+                                str_add=str_add)
+            raise ValueError(str_error)
     elif accept_str and isinstance(val, str):
         return [val]
     else:
         allowed_types = (list, tuple, np.ndarray, pd.Series)
         if not isinstance(val, allowed_types):
-            raise ValueError(f"'{name}' (type: {type(val)}) should be one of {allowed_types}.")
+            str_error = add_str(str_error=f"'{name}' (type: {type(val)}) should be one of {allowed_types}.",
+                                str_add=str_add)
+            raise ValueError(str_error)
         if isinstance(val, np.ndarray) and val.ndim != 1:
-            raise ValueError(f"'{name}' is a multi-dimensional numpy array and cannot be considered as a list.")
-        val = list(val)
+            str_error = add_str(str_error=f"'{name}' is a multi-dimensional numpy array and cannot be considered as a list.",
+                                str_add=str_add)
+            raise ValueError(str_error)
+        val = list(val) if isinstance(val, (np.ndarray, pd.Series)) else val
+    if check_all_non_none:
+        n_none = len([x for x in val if x is None])
+        if n_none > 0:
+            str_error = add_str(str_error=f"'{name}' should not contain 'None' (n={n_none})",
+                                str_add=str_add)
+            raise ValueError(str_error)
+    if check_all_non_neg_int:
+        if any(type(i) != int or i < 0 for i in val):
+            str_error = add_str(str_error=f"'{name}' should only contain non-negative integers.",
+                                str_add=str_add)
+            raise ValueError(str_error)
+    if check_all_str_or_convertible:
+        wrong_elements = [x for x in val if not isinstance(x, (str, int, float, np.number))]
+        if len(wrong_elements) > 0:
+            str_error = add_str(str_error=f"The following elements in '{name}' are not strings or"
+                                          f" reasonably convertible: {wrong_elements}",
+                                str_add=str_add)
+            raise ValueError(str_error)
+        else:
+            val = [str(x) for x in val]
+    if min_len is not None and len(val) < min_len:
+        str_error = add_str(str_error=f"'{name}' should contain at least {min_len} 
elements", + str_add=str_add) + raise ValueError(str_error) return val - - -# Check special types -def check_ax(ax=None, accept_none=False): - """""" - import matplotlib.axes - if ax is None: - if not accept_none: - raise ValueError(f"'ax' should not be None.") - return None - if not isinstance(ax, matplotlib.axes.Axes): - raise ValueError(f"'ax' (type={type(ax)}) should be mpl.axes.Axes or None.") diff --git a/xomics/_utils/decorators.py b/xomics/_utils/decorators.py index 7269d1c..873cdd5 100644 --- a/xomics/_utils/decorators.py +++ b/xomics/_utils/decorators.py @@ -1,45 +1,14 @@ """ -This a script for general decorators used in xOmics +This a script for general decorators used in xOmics. +# Dev: use runtime decorator only for internal methods since they destroy the signature for some IDEs """ import warnings import traceback -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, UndefinedMetricWarning import functools import re - -# Document common interfaces -def doc_params(**kwargs): - """Decorator to add parameter descriptions to the docstring. - - Usage - ----- - @_doc_params(arg1=desc1, arg2=desc2, ...) 
- def func(): - '''Description {arg1} {arg2}.''' - """ - - def decorator(func): - doc = func.__doc__ - # Regular expression to find replacement fields and their indentation - pattern = re.compile(r'(?P *){(?P\w+)}\n') - # Function to adjust indentation - def adjust_indent(match): - key = match.group('key') - indent = match.group('indent') - try: - # Add the indent to all lines in the replacement string - replacement = kwargs[key].replace('\n', '\n' + indent) - return indent + replacement + '\n' - except KeyError: - # Key not provided in kwargs, keep original - return match.group(0) - # Replace all matching strings in doc - func.__doc__ = pattern.sub(adjust_indent, doc) - #print(func.__doc__) # Debugging line - return func - - return decorator +# Helper functions # Catch Runtime @@ -70,13 +39,12 @@ def _catch_warning(self, message, category, filename, lineno, file=None, line=No def get_warnings(self): return self._warn_list + def catch_runtime_warnings(): """Decorator to catch RuntimeWarnings and store them in a list. - Returns - ------- - decorated_func : method - The decorated function. + Returns: + decorated_func: The decorated function """ def decorator(func): @functools.wraps(func) @@ -100,13 +68,12 @@ def __init__(self, message, distinct_clusters): super().__init__(message) self.distinct_clusters = distinct_clusters + def catch_convergence_warning(): """Decorator to catch ConvergenceWarnings and raise custom exceptions. - Returns - ------- - decorated_func : method - The decorated function. + Returns: + decorated_func: The decorated function. """ def decorator(func): @functools.wraps(func) @@ -128,18 +95,18 @@ def wrapper(*args, **kwargs): return decorator + # Catch invalid division (could be added to AAclust().comp_medoids()) class InvalidDivisionException(Exception): pass + def catch_invalid_divide_warning(): """Decorator to catch specific RuntimeWarnings related to invalid division and raise custom exceptions. 
- Returns - ------- - decorated_func : method - The decorated function. + Returns: + decorated_func: The decorated function. """ def decorator(func): @functools.wraps(func) @@ -147,9 +114,47 @@ def wrapper(*args, **kwargs): with CatchRuntimeWarnings() as crw: result = func(*args, **kwargs) if crw.get_warnings(): - raise InvalidDivisionException(f"\nError due to RuntimeWarning: {crw.get_warnings()[0]}") + raise InvalidDivisionException(f"\nError due to 'RuntimeWarning': {crw.get_warnings()[0]}") return result return wrapper return decorator +# Catch UndefinedMetricWarnings +class CatchUndefinedMetricWarning: + """Context manager to catch and aggregate UndefinedMetricWarnings.""" + def __enter__(self): + self._warn_set = set() + self._other_warnings = [] + self._showwarning_orig = warnings.showwarning + warnings.showwarning = self._catch_warning + return self + def __exit__(self, exc_type, exc_value, tb): + warnings.showwarning = self._showwarning_orig + for warn_message, warn_category, filename, lineno in self._other_warnings: + warnings.warn_explicit(warn_message, warn_category, filename, lineno) + + def _catch_warning(self, message, category, filename, lineno, file=None, line=None): + if category == UndefinedMetricWarning: + self._warn_set.add(str(message)) # Add message to set (duplicates are automatically handled) + else: + self._other_warnings.append((message, category, filename, lineno)) + + def get_warnings(self): + return list(self._warn_set) + +def catch_undefined_metric_warning(): + """Decorator to catch and report UndefinedMetricWarnings once per unique message.""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with CatchUndefinedMetricWarning() as cumw: + result = func(*args, **kwargs) + if cumw.get_warnings(): + summary_msg = "The following 'UndefinedMetricWarning' was caught:\n" + "\n".join(cumw.get_warnings()) + summary_msg += ("\n This warning was likely triggered due to 'precision' or 'f1' metrics and " + "an 
imbalanced and/or small dataset.") + warnings.warn(summary_msg, UndefinedMetricWarning) + return result + return wrapper + return decorator \ No newline at end of file diff --git a/xomics/_utils/utils_groups.py b/xomics/_utils/utils_groups.py index 66823ca..39d9ccc 100644 --- a/xomics/_utils/utils_groups.py +++ b/xomics/_utils/utils_groups.py @@ -53,7 +53,7 @@ def get_qcols(df=None, groups=None, str_quant=None): str_quant: Substring in column indicating quantification """ if str_quant is None: - raise ValueError("'str_quant' must be given") + raise ValueError("'str_quant' should not be None") dict_col_group = get_dict_qcol_group(df=df, groups=groups, str_quant=str_quant) dict_group_cols = {g: [k for k, v in dict_col_group.items() if v == g] for g in groups} list_group_cols = [col for group_cols in dict_group_cols.values() for col in group_cols] diff --git a/xomics/_utils/utils_output.py b/xomics/_utils/utils_output.py index 0371d94..0ea36d4 100644 --- a/xomics/_utils/utils_output.py +++ b/xomics/_utils/utils_output.py @@ -2,6 +2,9 @@ This is a script for adjusting terminal output. """ import numpy as np +import sys + +STR_PROGRESS = "." 
# I Helper Functions @@ -10,34 +13,51 @@ def _print_red(input_str, **args): """Prints the given string in red text.""" print(f"\033[91m{input_str}\033[0m", **args) + def _print_blue(input_str, **args): """Prints the given string in blue text.""" print(f"\033[94m{input_str}\033[0m", **args) + def _print_green(input_str, **args): """Prints the given string in Matrix-style green text.""" print(f"\033[92m{input_str}\033[0m", **args) + #print(f"\033[32m{input_str}\033[0m", **args) + def print_out(input_str, **args): """Prints the given string in Matrix-style green text.""" _print_blue(input_str, **args) + # Progress bar -def print_start_progress(): +def print_start_progress(start_message=None): """Print start progress""" + # Start message + if start_message is not None: + print_out(start_message) + # Start progress bar progress_bar = " " * 25 - print_out(f"\r |{progress_bar}| 0.00%", end="") + print_out(f"\r |{progress_bar}| 0.0%", end="") + sys.stdout.flush() -def print_progress(i=0, n=0): +def print_progress(i=0, n=0, add_new_line=False): """Print progress""" progress = min(np.round(i/n * 100, 2), 100) - progress_bar = "#" * int(progress/4) + " " * (25-int(progress/4)) - print_out(f"\r |{progress_bar}| {progress:.2f}%", end="") + progress_bar = STR_PROGRESS * int(progress/4) + " " * (25-int(progress/4)) + str_end = "\n" if add_new_line else "" + print_out(f"\r |{progress_bar}| {progress:.1f}%", end=str_end) + sys.stdout.flush() -def print_finished_progress(): +def print_end_progress(end_message=None): """Print finished progress bar""" - progress_bar = "#" * 25 - print_out(f"\r |{progress_bar}| 100.00%") + # End progress bar + progress_bar = STR_PROGRESS * 25 + print_out(f"\r |{progress_bar}| 100.0%") + # End message + if end_message is not None: + print_out(end_message) + sys.stdout.flush() diff --git a/xomics/_utils/utils_plotting.py b/xomics/_utils/utils_plotting.py index bfd9fe6..537dd6f 100644 --- a/xomics/_utils/utils_plotting.py +++ 
b/xomics/_utils/utils_plotting.py @@ -1,16 +1,257 @@ """ -This is a script for internal plotting utility functions used in the backend. +This is a script for the backend of the plotting module functions used by other xOmics modules. """ import seaborn as sns +import matplotlib as mpl +from matplotlib import pyplot as plt +import matplotlib.lines as mlines +import warnings +from .check_type import check_number_range -# Helper functions +# I Helper function +def _create_marker(color, label, marker, marker_size, lw, edgecolor, linestyle, hatch, hatchcolor): + """Create custom marker based on input.""" + # Default marker (matching to plot) + if marker is None: + return mpl.patches.Patch(facecolor=color, + label=label, + lw=lw, + hatch=hatch, + edgecolor=hatchcolor) + # If marker is '-', treat it as a line + if marker == "-": + return plt.Line2D(xdata=[0, 1], ydata=[0, 1], + color=color, + linestyle=linestyle, + lw=lw, + label=label) + # Creates marker element without line (lw=0) + return plt.Line2D(xdata=[0], ydata=[0], + marker=marker, + label=label, + markerfacecolor=color, + color=edgecolor, + markersize=marker_size, + lw=0, + markeredgewidth=lw) -# Main function + +# Check functions +def _marker_has(marker, val=None): + if isinstance(marker, str): + return marker == val + elif marker is None: + return False + elif isinstance(marker, list): + return any([x == val for x in marker]) + else: + raise ValueError(f"'marker' ({marker}) is wrong") + + +def _marker_has_no(marker, val=None): + if isinstance(marker, str): + return marker != val + elif marker is None: + return False + elif isinstance(marker, list): + return any([x != val for x in marker]) + else: + raise ValueError(f"'marker' ({marker}) is wrong") + + +# Checking functions for list inputs +def _check_list_cat(dict_color=None, list_cat=None): + """Ensure items in list_cat are keys in dict_color and match in length.""" + if not list_cat: + return list(dict_color.keys()) + if not all(elem in dict_color for elem in 
list_cat):
+        missing_keys = [elem for elem in list_cat if elem not in dict_color]
+        raise ValueError(f"The following keys in 'list_cat' are not in 'dict_colors': {', '.join(missing_keys)}")
+    if len(dict_color) < len(list_cat):
+        raise ValueError(
+            f"'dict_colors' (n={len(dict_color)}) must contain >= elements than 'list_cat' (n={len(list_cat)}).")
+    return list_cat
+
+
+def _check_labels(list_cat=None, labels=None):
+    """Validate labels and match their length to list_cat."""
+    if labels is None:
+        labels = list_cat
+    if len(list_cat) != len(labels):
+        raise ValueError(f"Length must match of 'labels' ({len(labels)}) and categories ({len(list_cat)}).")
+    return labels
+
+
+# Checking functions for inputs that can be list or single values (redundancy accepted for better user communication)
+def _check_hatches(marker=None, hatch=None, list_cat=None):
+    """Check validity of the hatch input."""
+    valid_hatches = ['/', '\\', '|', '-', '+', 'x', 'o', 'O', '.', '*']
+    # Check if hatch is valid
+    if isinstance(hatch, str):
+        if hatch not in valid_hatches:
+            raise ValueError(f"'hatch' ('{hatch}') must be one of following: {valid_hatches}")
+    if isinstance(hatch, list):
+        wrong_hatch = [x for x in hatch if x not in valid_hatches]
+        if len(wrong_hatch) != 0:
+            raise ValueError(
+                f"'hatch' contains wrong values ('{wrong_hatch}')! 
Should be one of following: {valid_hatches}") + if len(hatch) != len(list_cat): + raise ValueError(f"Length must match of 'hatch' ({hatch}) and categories ({list_cat}).") # Check if hatch can be chosen + # Warn for parameter conflicts + if _marker_has_no(marker, val=None) and hatch: + warnings.warn(f"'hatch' can only be applied to the default marker, set 'marker=None'.", UserWarning) + # Create hatch list + list_hatch = [hatch] * len(list_cat) if not isinstance(hatch, list) else hatch + return list_hatch + + +def _check_marker(marker=None, list_cat=None, lw=0): + """Check validity of markers""" + # Add '-' for line and None for default marker + valid_markers = [None, "-"] + list(mlines.Line2D.markers.keys()) + # Check if marker is valid + if not isinstance(marker, list) and marker not in valid_markers: + raise ValueError(f"'marker' ('{marker}') must be one of following: {valid_markers}") + if isinstance(marker, list): + wrong_markers = [x for x in marker if x not in valid_markers] + if len(wrong_markers) != 0: + raise ValueError(f"'marker' contains wrong values ('{wrong_markers}'). 
Should be one of following: {valid_markers}")
+        if len(marker) != len(list_cat):
+            raise ValueError(f"Length must match of 'marker' ({marker}) and categories ({list_cat}).")
+    # Warn for parameter conflicts
+    if _marker_has(marker, val="-") and lw <= 0:
+        warnings.warn(f"Marker lines ('-') are only shown if 'lw' ({lw}) is > 0.", UserWarning)
+    # Create marker list
+    list_marker = [marker] * len(list_cat) if not isinstance(marker, list) else marker
+    return list_marker
+
+
+def _check_marker_size(marker_size=10, list_cat=None):
+    """Check size of markers"""
+    # Check if marker_size is valid
+    if isinstance(marker_size, (int, float)):
+        check_number_range(name='marker_size', val=marker_size, min_val=0, accept_none=True, just_int=False)
+    elif isinstance(marker_size, list):
+        for i in marker_size:
+            check_number_range(name='marker_size', val=i, min_val=0, accept_none=True, just_int=False)
+    elif isinstance(marker_size, list) and len(marker_size) != len(list_cat):
+        raise ValueError(f"Length must match of 'marker_size' ({marker_size}) and categories ({list_cat}).")
+    else:
+        raise ValueError(f"'marker_size' has wrong data type: {type(marker_size)}")
+    # Create marker_size list
+    list_marker_size = [marker_size] * len(list_cat) if not isinstance(marker_size, list) else marker_size
+    return list_marker_size
+
+
+def _check_linestyle(linestyle=None, list_cat=None, marker=None):
+    """Check validity of linestyle."""
+    _lines = ['-', '--', '-.', ':', ]
+    _names = ["solid", "dashed", "dashed-doted", "dotted"]
+    valid_mls = _lines + _names
+    # Check if marker_linestyle is valid
+    if isinstance(linestyle, list):
+        wrong_mls = [x for x in linestyle if x not in valid_mls]
+        if len(wrong_mls) != 0:
+            raise ValueError(
+                f"'marker_linestyle' contains wrong values ('{wrong_mls}')! 
Should be one of following: {valid_mls}") + if len(linestyle) != len(list_cat): + raise ValueError(f"Length must match of 'marker_linestyle' ({linestyle}) and categories ({list_cat}).") + # Check if marker_linestyle is conflicting with other settings + if isinstance(linestyle, str): + if linestyle not in valid_mls: + raise ValueError(f"'marker_linestyle' ('{linestyle}') must be one of following: {_lines}," + f" or corresponding names: {_names} ") + # Warn for parameter conflicts + if linestyle is not None and _marker_has_no(marker, val="-"): + warnings.warn(f"'linestyle' ({linestyle}) is only applicable to marker lines ('-'), not to '{marker}'.", UserWarning) + # Create list_marker_linestyle list + list_marker_linestyle = [linestyle] * len(list_cat) if not isinstance(linestyle, list) else linestyle + return list_marker_linestyle + + +# II Main Functions +# DEV: General function for plot_gcfs def plot_gco(option='font.size', show_options=False): """Get current option from plotting context""" current_context = sns.plotting_context() if show_options: print(current_context) - option_value = current_context[option] # Typically font_size + try: + option_value = current_context[option] # Typically font_size + except KeyError: + options = list(current_context.keys()) + raise ValueError(f"Option not valid, select from the following: {options}") return option_value + + +# DEV: plot_get_cdict and plot_get_cmap are implemented in main utils + +# Remaining backend plotting functions +def plot_get_clist_(n_colors=3): + """Get manually curated list of 2 to 9 colors.""" + # Base lists + list_colors_3_to_4 = ["tab:gray", "tab:blue", "tab:red", "tab:orange"] + list_colors_5_to_6 = ["tab:blue", "tab:cyan", "tab:gray","tab:red", + "tab:orange", "tab:brown"] + list_colors_8_to_9 = ["tab:blue", "tab:orange", "tab:green", "tab:red", + "tab:gray", "gold", "tab:cyan", "tab:brown", + "tab:purple"] + # Two classes + if n_colors == 2: + return ["tab:blue", "tab:red"] + # Control/base + 2-3 
classes + elif n_colors in [3, 4]: + return list_colors_3_to_4[0:n_colors] + # 5-7 classes (gray in middle as visual "breather") + elif n_colors in [5, 6]: + return list_colors_5_to_6[0:n_colors] + elif n_colors == 7: + return ["tab:blue", "tab:cyan", "tab:purple", "tab:gray", + "tab:red", "tab:orange", "tab:brown"] + # 8-9 classes (colors from scale categories) + elif n_colors in [8, 9]: + return list_colors_8_to_9[0:n_colors] + else: + return sns.color_palette(palette="husl", n_colors=n_colors) + + +def plot_legend_(ax=None, dict_color=None, list_cat=None, labels=None, + loc="upper left", loc_out=False, y=None, x=None, n_cols=3, + labelspacing=0.2, columnspacing=1.0, handletextpad=0.8, handlelength=2.0, + fontsize=None, fontsize_title=None, weight_font="normal", weight_title="normal", + marker=None, marker_size=10, lw=0, linestyle=None, edgecolor=None, + hatch=None, hatchcolor="white", title=None, title_align_left=True, + frameon=False, **kwargs): + """Sets an independently customizable plot legend""" + # Check input + if ax is None: + ax = plt.gca() + list_cat = _check_list_cat(dict_color=dict_color, list_cat=list_cat) + labels = _check_labels(list_cat=list_cat, labels=labels) + marker = _check_marker(marker=marker, list_cat=list_cat, lw=lw) + hatch = _check_hatches(marker=marker, hatch=hatch, list_cat=list_cat) + linestyle = _check_linestyle(linestyle=linestyle, list_cat=list_cat, marker=marker) + marker_size = _check_marker_size(marker_size=marker_size, list_cat=list_cat) + # Remove existing legend + if ax.get_legend() is not None and len(ax.get_legend().get_lines()) > 0: + ax.legend_.remove() + # Update legend arguments + args = dict(loc=loc, ncol=n_cols, fontsize=fontsize, labelspacing=labelspacing, columnspacing=columnspacing, + handletextpad=handletextpad, handlelength=handlelength, borderpad=0, title=title, + edgecolor=edgecolor, prop={"weight": weight_font, "size": fontsize}, frameon=frameon) + args.update(kwargs) + if fontsize_title: + 
args["title_fontproperties"] = {"weight": weight_title, "size": fontsize_title} + if loc_out: + x, y = x or 0, y or -0.25 + if x or y: + args["bbox_to_anchor"] = (x or 0, y or 1) + # Create handles and legend + handles = [_create_marker(dict_color[cat], labels[i], marker[i], marker_size[i], + lw, edgecolor, linestyle[i], hatch[i], hatchcolor) + for i, cat in enumerate(list_cat)] + legend = ax.legend(handles=handles, labels=labels, **args) + if title_align_left: + legend._legend_box.align = "left" + return ax diff --git a/xomics/_utils/new_types.py b/xomics/_utils/utils_types.py similarity index 57% rename from xomics/_utils/new_types.py rename to xomics/_utils/utils_types.py index 5af60c7..ea31223 100644 --- a/xomics/_utils/new_types.py +++ b/xomics/_utils/utils_types.py @@ -9,11 +9,24 @@ ArrayLike1DUnion = Union[Sequence[NumericType], np.ndarray, pd.Series] ArrayLike2DUnion = Union[Sequence[Sequence[NumericType]], np.ndarray, pd.DataFrame] -# Now, we'll create distinct named types using NewType. -# This won't change runtime behavior but will be recognized by static type checkers and can be documented. - # A 1D array-like object. Can be a sequence (e.g., list or tuple) of ints/floats, numpy ndarray, or pandas Series. ArrayLike1D = NewType("ArrayLike1D", ArrayLike1DUnion) # A 2D array-like object. Can be a sequence of sequence of ints/floats, numpy ndarray, or pandas DataFrame. 
ArrayLike2D = NewType("ArrayLike2D", ArrayLike2DUnion) + +# Numeric type lists +VALID_INT_TYPES = ( + int, + np.int_, np.intc, np.intp, np.integer, # np.integer covers all standard integer types + np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.longlong # Equivalent to np.int64 on most platforms +) + +VALID_FLOAT_TYPES = ( + float, + np.float_, np.float16, np.float32, np.float64, np.longdouble # np.longdouble is the highest precision float available +) + +VALID_INT_FLOAT_TYPES = VALID_INT_TYPES + VALID_FLOAT_TYPES diff --git a/xomics/config.py b/xomics/config.py index e968f51..3a50036 100644 --- a/xomics/config.py +++ b/xomics/config.py @@ -2,32 +2,146 @@ This is a script for setting system level options for xOmics. """ from typing import Dict, Any +import os + +from ._utils.check_type import check_bool, check_number_val, check_number_range, check_str +from ._utils.check_data import check_df # System level options -verbose = True -replace_underscore_in_plots = True +_dict_options = { + 'verbose': "off", + 'random_state': "off", + 'allow_multiprocessing': True, + 'replace_underscore_in_plots': True, +} + + +# Check system level (option) parameters or depending parameters +def check_verbose(verbose=None): + """Check if general verbosity is on or off. 
Adjusted based on options setting and value provided to object""" + global_verbose = options["verbose"] + if global_verbose != "off": + # System level verbosity + check_bool(name="verbose (option)", val=global_verbose) + verbose = global_verbose + else: + check_bool(name="verbose", val=verbose) + return verbose + + +def check_random_state(random_state=None): + """Adjust random state if global is not 'off' (default)""" + global_random_state = options["random_state"] + args = dict(min_val=0, accept_none=True, just_int=True) + if global_random_state != "off": + # System-level random state + check_number_range(name="random_state (option)", val=global_random_state, **args) + random_state = global_random_state + else: + check_number_range(name="random_state", val=random_state, **args) + return random_state +def check_n_jobs(n_jobs=None): + """Adjust n_jobs to 1 if multiprocessing is not allowed""" + allow_multiprocessing = options["allow_multiprocessing"] + check_bool(name="allow_multiprocessing (options)", val=allow_multiprocessing) + # Disable multiprocessing + if not allow_multiprocessing: + n_jobs = 1 + os.environ['LOKY_MAX_CPU_COUNT'] = "1" + # Set n_jobs to maximum number of CPUs + if n_jobs == -1: + n_jobs = os.cpu_count() + # Check which n_jobs are allowed + check_number_val(name="j_jobs", val=n_jobs, accept_none=True) + if n_jobs is None or n_jobs >= 1: + check_number_range(name="n_jobs", val=n_jobs, accept_none=True, just_int=True, min_val=1) + return n_jobs + + +# DEV: Parameters are used as directive to get better documentation style # Enables setting of system level variables like in matplotlib +def _check_option(name_option="", option=None): + """Check if option is valid""" + if name_option == "verbose": + if option != "off": + check_verbose(verbose=option) + if name_option == "random_state": + if option != "off": + check_random_state(random_state=option) + if name_option == "allow_multiprocessing": + check_bool(name=name_option, val=option) + + class 
Settings: + """ + A class for managing system-level settings for xOmics. + + This class mimics a dictionary-like interface, allowing the setting and retrieving + of system-level options. It is designed to be used as a single global instance, ``options``. + + Parameters + ---------- + The following options can be set: + + verbose : bool or 'off', default='off' + Sets verbose mode to ``True`` or ``False`` globally if not 'off'. + random_state : int, None, or 'off', default='off' + The seed used by the random number generator. + + * If set to a positive integer, results of stochastic processes are consistent, enabling reproducibility. + * If set to ``None``, stochastic processes will be truly random. + * If set to 'off', no global random state variable will be set, allowing the underlying libraries to use + their default random state behavior. + + allow_multiprocessing : bool, default=True + Whether multiprocessing is allowed in general. If ``False``, ``n_jobs`` is automatically set to 1. + replace_underscore_in_plots : bool, default=True + Whether to replace underscores from variables in plot labels. + + + See Also + -------- + * :class:`numpy.random.RandomState` for details on the ``random_state`` variable used to make stochastic processes + yielding consistent results. + + Warnings + -------- + * Multiprocessing Compatibility: Enabling multiprocessing (``allow_multiprocessing=True``) + can lead to issues in environments that don't support forking or when interfacing with + certain libraries. If encountering errors, consider setting ``allow_multiprocessing=False``. + Note that this may affect performance in computation-intensive operations. + + Examples + -------- + ..
include:: examples/options.rst + """ def __init__(self): - self._settings: Dict[str, Any] = { - 'verbose': verbose, - 'replace_underscore_in_plots': replace_underscore_in_plots - } + self._settings: Dict[str, Any] = _dict_options.copy() def __getitem__(self, key: str) -> Any: """Retrieve a setting's value using dict-like access.""" return self._settings.get(key, None) def __setitem__(self, key: str, value: Any) -> None: - """Set a setting's value using dict-like access.""" - self._settings[key] = value + """Set a setting's value using dict-like access. + Prevent adding new keys that are not already in the system options.""" + if key in self._settings: + _check_option(name_option=key, option=value) + self._settings[key] = value + else: + valid_options = list(_dict_options.keys()) + raise KeyError(f"'{key}' is not valid options. Valid options are: {valid_options}.") def __contains__(self, key: str) -> bool: """Check if a key is in the settings.""" return key in self._settings + def __str__(self) -> str: + """Return a string representation of the settings dictionary.""" + return str(self._settings) # Global settings instance options = Settings() + diff --git a/xomics/data_handling/_preprocess.py b/xomics/data_handling/_preprocess.py index c560f3b..b8f0b1b 100644 --- a/xomics/data_handling/_preprocess.py +++ b/xomics/data_handling/_preprocess.py @@ -14,6 +14,13 @@ # I Helper Functions +def check_all_positive(df=None): + """""" + min_val = df.min().min() + if min_val < 0: + raise ValueError(f"Minimum value ({min_val}) in 'df' should be >= 0") + + def check_base(base=None): """Ensure 'base' is a valid numerical type and has an acceptable value""" if not isinstance(base, (int, float)) or base not in [2, 10]: @@ -36,22 +43,12 @@ def check_match_df_ids(df=None, list_ids=None): # II Main Functions -# Common interface -doc_param_df_groups = \ -"""\ -df - DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. 
-groups - List with names grouping conditions from ``df`` columns.\ -""" - - class PreProcess: """ Pre-processing class for quantifications of omics data. """ def __init__(self, - col_id: str = ut.COL_PROT_ID, + col_id: str = "protein_id", col_name: str = ut.COL_GENE_NAME, str_quant: str = "log2_lfq" ): @@ -63,7 +60,7 @@ def __init__(self, col_name Name of column with sample names in DataFrame. str_quant - Identifier for the LFQ columns in the DataFrame. + Identifier for the quantification columns in the DataFrame. """ ut.check_str(name="col_id", val=col_id, accept_none=False) ut.check_str(name="col_name", val=col_name, accept_none=False) @@ -72,17 +69,19 @@ def __init__(self, self.col_name = col_name self.str_quant = str_quant - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def get_qcols(self, df: pd.DataFrame = None, - groups: list = None + groups: ut.ArrayLike1D = None ) -> list: """ Create a list with groups from df based on str_quant and given groups Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. Return ------ @@ -97,17 +96,19 @@ def get_qcols(self, cols_quant = ut.get_qcols(df=df, groups=groups, str_quant=self.str_quant) return cols_quant - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def get_dict_qcol_group(self, df: pd.DataFrame = None, - groups: list = None + groups: ut.ArrayLike1D = None ) -> dict: """ Create a dictionary with quantification columns and the group they are subordinated to Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. 
+ groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. Return ------ @@ -122,17 +123,19 @@ def get_dict_qcol_group(self, dict_qcol_group = ut.get_dict_qcol_group(df=df, groups=groups, str_quant=self.str_quant) return dict_qcol_group - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def get_dict_group_qcols(self, df: pd.DataFrame = None, - groups: list = None + groups: ut.ArrayLike1D = None ) -> dict: """ Create a dictionary with for groups from df and their corresponding columns with quantifications Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. Return ------ @@ -147,10 +150,9 @@ def get_dict_group_qcols(self, dict_group_qcols = ut.get_dict_group_qcols(df=df, groups=groups, str_quant=self.str_quant) return dict_group_qcols - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def filter_nan(self, df: pd.DataFrame = None, - groups: Optional[list] = None, + groups: Optional[ut.ArrayLike1D] = None, cols: Optional[list] = None, ) -> pd.DataFrame: """ @@ -158,7 +160,10 @@ def filter_nan(self, Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. cols List of columns from ``df`` to consider for filtering. 
@@ -178,7 +183,7 @@ def filter_nan(self, if cols is None and groups is None: cols = list(df) cols = ut.check_list_like(name="cols", val=cols, accept_none=True, accept_str=True) - ut.check_col_in_df(df=df, name_df="df", cols=cols, accept_none=True, accept_nan=True) + df = ut.check_df(df=df, name="df", cols_requiered=cols, accept_none=False, accept_nan=True) if groups is not None: ut.check_match_df_groups(groups=groups, df=df, str_quant=self.str_quant) if cols is None: @@ -188,10 +193,9 @@ def filter_nan(self, df = df.reset_index(drop=True) return df - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def filter_groups(self, df: pd.DataFrame = None, - groups: Optional[list] = None, + groups: Optional[ut.ArrayLike1D] = None, min_pct: float = 0.8, ) -> pd.DataFrame: """ @@ -199,7 +203,10 @@ def filter_groups(self, Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. min_pct Minimum percentage threshold of non-missing values in at least one group. @@ -244,7 +251,7 @@ def filter_duplicated_names(df: pd.DataFrame = None, The modified and filtered DataFrame. """ ut.check_df(df=df) - ut.check_col_in_df(df=df, name_df="df", cols=col, accept_nan=True, accept_none=False) + df = ut.check_df(df=df, name="df", cols_requiered=col, accept_none=False, accept_nan=True) ut.check_str(name="str_split", val=str_split) ut.check_bool(name="split_names", val=split_names) # Filtering @@ -284,11 +291,11 @@ def apply_log(df: pd.DataFrame = None, - NaN values will remain NaN after the transformation. 
""" # Check input - df = ut.check_df(df=df, all_positive=True) if cols is None: cols = list(df) cols = ut.check_list_like(name="cols", val=cols, accept_none=False, accept_str=True) - ut.check_col_in_df(df=df, cols=cols, accept_nan=True) + df = ut.check_df(df=df, name="df", cols_requiered=cols, accept_none=False, accept_nan=True) + check_all_positive(df=df[cols]) ut.check_bool(name="log2", val=log2) ut.check_bool(name="neg", val=neg) # Log transform @@ -333,7 +340,7 @@ def apply_exp(df: pd.DataFrame = None, if cols is None: cols = list(df) cols = ut.check_list_like(name="cols", val=cols, accept_none=False, accept_str=True) - ut.check_col_in_df(df=df, cols=cols, accept_nan=True) + df = ut.check_df(df=df, name="df", cols_requiered=cols, accept_none=False, accept_nan=True) ut.check_bool(name="neg", val=neg) check_base(base=base) # Exponential transform @@ -403,7 +410,7 @@ def add_significance(df: pd.DataFrame = None, DataFrame with added significance column. """ # Check input - df = ut.check_df(name="df", df=df, cols_req=[col_fc, col_pval]) + df = ut.check_df(name="df", df=df, cols_requiered=[col_fc, col_pval]) ut.check_number_range(name="th_fc", val=th_fc, min_val=0, just_int=False) ut.check_number_range(name="th_pval", val=th_pval, min_val=0, max_val=1, just_int=False) # Rescale p-value @@ -412,10 +419,9 @@ def add_significance(df: pd.DataFrame = None, df[ut.COL_SIG_CLASS] = ut.get_sig_classes(df=df, col_fc=col_fc, col_pval=col_pval, th_pval=th_pval, th_fc=th_fc) return df - @ut.doc_params(doc_param_df_groups=doc_param_df_groups) def run(self, df: pd.DataFrame = None, - groups: list = None, + groups: ut.ArrayLike1D = None, groups_ctrl: list = None, pvals_correction: Optional[str] = None, pvals_neg_log10: bool = True @@ -426,7 +432,10 @@ def run(self, Parameters ---------- - {doc_param_df_groups} + df : pd.DataFrame, shape (n_samples, n_conditions) + DataFrame with quantifications. ``Rows`` typically correspond to proteins and ``columns`` to conditions. 
+ groups : array-like, shape (n_groups,) + List with names grouping conditions from ``df`` columns. groups_ctrl List with names control grouping conditions from ``df`` columns. pvals_correction diff --git a/xomics/imputation/_cimpute.py b/xomics/imputation/_cimpute.py index 4acca36..5048796 100644 --- a/xomics/imputation/_cimpute.py +++ b/xomics/imputation/_cimpute.py @@ -15,15 +15,6 @@ # TODO d) Extend to other omics data # I Helper Functions -doc_param_df_groups_upmnar = \ -"""\ -df - DataFrame containing quantified values with MVs. ``Rows`` typically correspond to proteins and ``columns`` to conditions. -groups - List of quantification group (substrings of columns in ``df``). -loc_pct_upmnar - Location factor [0-1] for the upper MNAR limit (upMNAR) given as relative proportion (percentage) of the detection range.\ -""" # II Main Functions @@ -75,7 +66,6 @@ def __init__(self, self.col_name = col_name self.str_quant = str_quant - @ut.doc_params(doc_param_df_groups_upmnar=doc_param_df_groups_upmnar) def get_limits(self, df: pd.DataFrame = None, groups: ut.ArrayLike1D = None, @@ -88,8 +78,15 @@ def get_limits(self, Parameters ---------- - {doc_param_df_groups_upmnar} - cols_quant + df : pd.DataFrame, shape(n_samples, n_conditions) + DataFrame containing quantified values with missing values. ``Rows`` typically correspond to proteins + and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List of quantification group (substrings of columns in ``df``). + loc_pct_upmnar : float, default=0.25 + Location factor [0-1] for the upper MNAR limit (upMNAR) given as relative proportion (percentage) + of the detection range. + cols_quant : array-like, shape (n_columns,) Column names with quantification data in ``df``. 
Return @@ -103,7 +100,7 @@ def get_limits(self, """ # Check input cols_quant = ut.check_list_like(name="cols_quant", val=cols_quant, accept_none=True) - df = ut.check_df(df=df, accept_none=False, cols_req=cols_quant) + df = ut.check_df(df=df, accept_none=False, cols_requiered=cols_quant) groups = ut.check_list_like(name="groups", val=groups, accept_none=False) ut.check_match_df_groups(groups=groups, df=df, str_quant=self.str_quant) ut.check_number_range(name="loc_pct_upmnar", val=loc_pct_upmnar, min_val=0, max_val=1, @@ -115,7 +112,6 @@ def get_limits(self, d_max = df[cols_quant].max().max() return d_min, up_mnar, d_max - @ut.doc_params(doc_param_df_groups_upmnar=doc_param_df_groups_upmnar) def run(self, df: pd.DataFrame = None, groups: ut.ArrayLike1D = None, @@ -132,15 +128,22 @@ def run(self, Parameters ---------- - {doc_param_df_groups_upmnar} - min_cs + df : pd.DataFrame, shape(n_samples, n_conditions) + DataFrame containing quantified values with missing values. ``Rows`` typically correspond to proteins + and ``columns`` to conditions. + groups : array-like, shape (n_groups,) + List of quantification group (substrings of columns in ``df``). + loc_pct_upmnar : float, default=0.25 + Location factor [0-1] for the upper MNAR limit (upMNAR) given as relative proportion (percentage) + of the detection range. + min_cs : float, default=0.5 Minimum of confidence score [0-1] used for selecting values for protein in groups to apply imputation on. - n_neighbors + n_neighbors: int, default=5 Number of neighboring samples to use for MCAR imputation by KNN. Return ------ - df_imp + df_imp : pd.DataFrame DataFrame with (a) imputed intensities values and (b) group-wise confidence score and NaN classification. 
Notes diff --git a/xomics/plotting/__init__.py b/xomics/plotting/__init__.py index 8b5d4bb..8cd8e4e 100644 --- a/xomics/plotting/__init__.py +++ b/xomics/plotting/__init__.py @@ -1,21 +1,12 @@ -from ._plot_get_clist import plot_get_clist -from ._plot_settings import plot_settings -from ._plot_gcfs import plot_gcfs -from ._plot_legend import plot_legend from ._plot_volcano import plot_volcano from ._plot_inferno import plot_inferno from ._plot_enrich_map import plot_enrich_map from ._plot_enrich_rank import plot_enrich_rank from ._plot_prank import plot_prank, plot_prank_scatter from ._plot_imput_histo import plot_imput_histo -from ._plot_pintegrate import plot_pintegrate __all__ = [ - "plot_get_clist", - "plot_settings", - "plot_gcfs", - "plot_legend", "plot_volcano", # "plot_inferno", TODO must be developed "plot_enrich_rank", @@ -23,5 +14,4 @@ "plot_prank", "plot_prank_scatter", "plot_imput_histo", - #"plot_integ_scatter", TODO must be developed ] \ No newline at end of file diff --git a/xomics/plotting/_plot_gcfs.py b/xomics/plotting/_plot_gcfs.py deleted file mode 100644 index b4eeacf..0000000 --- a/xomics/plotting/_plot_gcfs.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -This is a script for getting current font size of figures. -""" -import seaborn as sns - - -# Main function -def plot_gcfs(option='font.size'): - """ - Gets current font size (or axes linewdith). - - This font size can be set by :func:`plot_settings` function. - - Examples - -------- - Here are the default colors used in CPP and CPP-SHAP plots: - - .. 
plot:: - :include-source: - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns - >>> import xomics as xo - >>> data = {'Classes': ['Class A', 'Class B', 'Class C'], 'Values': [23, 27, 43]} - >>> colors = xo.plot_get_clist() - >>> xo.plot_settings() - >>> sns.barplot(y='Classes', x='Values', data=data, palette=colors, hue="Classes", legend=False) - >>> sns.despine() - >>> plt.title("Two points bigger title", size=xo.plot_gcfs()+2) - >>> plt.tight_layout() - >>> plt.show() - - See Also - -------- - * Our `Plotting Prelude `_. - """ - allowed_options = ["font.size", "axes.linewidth"] - if option not in allowed_options: - return ValueError(f"'option' should be one of following: {allowed_options}") - # Get the current plotting context - current_context = sns.plotting_context() - option_value = current_context[option] # Typically font_size - return option_value \ No newline at end of file diff --git a/xomics/plotting/_plot_get_clist.py b/xomics/plotting/_plot_get_clist.py deleted file mode 100644 index f6547a3..0000000 --- a/xomics/plotting/_plot_get_clist.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Plotting utility function to obtain xOmics color list. -""" -from typing import List -import xomics.utils as ut - - -# II Main function -def plot_get_clist(n_colors: int = 3) -> List[str]: - """ - Returns list of 2 to 9 colors. - - This fuctions returns one of eight different colorl lists optimized - for appealing visualization of categories. - - Parameters - ---------- - n_colors - Number of colors. Must be between 2 and 9. - Returns - ------- - list - List with colors given as matplotlib color names. - - Examples - -------- - .. 
plot:: - :include-source: - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns - >>> import xomics as xo - >>> colors = xo.plot_get_clist(n_colors=3) - >>> data = {'Classes': ['Class A', 'Class B', 'Class C'], 'Values': [10, 23, 33]} - >>> xo.plot_settings() - >>> sns.barplot(data=data, x='Classes', y='Values', palette=colors, hue="Classes", legend=False) - >>> plt.show() - - See Also - -------- - * The example notebooks in `Plotting Prelude `_. - * `Matplotlib color names `_ - * :func:`seaborn.color_palette` function to generate a color palette in seaborn. - """ - # Check input - ut.check_number_range(name="n_colors", val=n_colors, min_val=2, max_val=9, just_int=True) - - # Base lists - list_colors_3_to_4 = ["tab:gray", "tab:blue", "tab:red", "tab:orange"] - list_colors_5_to_6 = ["tab:blue", "tab:cyan", "tab:gray","tab:red", - "tab:orange", "tab:brown"] - list_colors_8_to_9 = ["tab:blue", "tab:orange", "tab:green", "tab:red", - "tab:gray", "gold", "tab:cyan", "tab:brown", - "tab:purple"] - # Two classes - if n_colors == 2: - return ["tab:blue", "tab:red"] - # Control/base + 2-3 classes - elif n_colors in [3, 4]: - return list_colors_3_to_4[0:n_colors] - # 5-7 classes (gray in middle as visual "breather") - elif n_colors in [5, 6]: - return list_colors_5_to_6[0:n_colors] - elif n_colors == 7: - return ["tab:blue", "tab:cyan", "tab:purple", "tab:gray", - "tab:red", "tab:orange", "tab:brown"] - # 8-9 classes (colors from scale categories) - elif n_colors in [8, 9]: - return list_colors_8_to_9[0:n_colors] - diff --git a/xomics/plotting/_plot_imput_histo.py b/xomics/plotting/_plot_imput_histo.py index 1985a17..85efe87 100644 --- a/xomics/plotting/_plot_imput_histo.py +++ b/xomics/plotting/_plot_imput_histo.py @@ -16,7 +16,6 @@ # II Main Functions # TODO add check function, improve interface and make consistent, add tests, add tutorial - def plot_imput_histo(ax: Optional[plt.Axes] = None, figsize: Tuple[int, int] = (6, 5), df_raw: pd.DataFrame = None, @@ 
-69,10 +68,10 @@ def plot_imput_histo(ax: Optional[plt.Axes] = None, """ # Check input ut.check_list_like(name="cols_quant", val=cols_quant, accept_none=False) - ut.check_df(name="df_raw", df=df_raw, cols_req=cols_quant) - ut.check_df(name="df_imp", df=df_imp, cols_req=cols_quant) + ut.check_df(name="df_raw", df=df_raw, cols_requiered=cols_quant) + ut.check_df(name="df_imp", df=df_imp, cols_requiered=cols_quant) # Pre-process data - colors = xo.plot_get_clist(n_colors=3) if colors is None else colors + colors = ut.plot_get_clist_(n_colors=3) if colors is None else colors _args = dict(binwidth=binwidth, **kwargs) vals_raw = df_raw[cols_quant].values.flatten() n_raw = len(df_raw) diff --git a/xomics/plotting/_plot_legend.py b/xomics/plotting/_plot_legend.py deleted file mode 100644 index c24401a..0000000 --- a/xomics/plotting/_plot_legend.py +++ /dev/null @@ -1,361 +0,0 @@ -""" -This is a script for setting plot legend. -""" -from typing import Optional, List, Dict, Union, Tuple -import matplotlib as mpl -from matplotlib import pyplot as plt -import xomics.utils as ut -import matplotlib.lines as mlines -import warnings - - -# I Helper functions -def marker_has(marker, val=None): - if isinstance(marker, str): - return marker == val - elif marker is None: - return False - elif isinstance(marker, list): - return any([x == val for x in marker]) - else: - raise ValueError(f"'marker' ({marker}) is wrong") - - -def marker_has_no(marker, val=None): - if isinstance(marker, str): - return marker != val - elif marker is None: - return False - elif isinstance(marker, list): - return any([x != val for x in marker]) - else: - raise ValueError(f"'marker' ({marker}) is wrong") - - -# Checking functions for list inputs -def check_list_cat(dict_color=None, list_cat=None): - """Ensure items in list_cat are keys in dict_color and match in length.""" - if not list_cat: - return list(dict_color.keys()) - if not all(elem in dict_color for elem in list_cat): - missing_keys = [elem for elem 
in list_cat if elem not in dict_color] - raise ValueError(f"The following keys in 'list_cat' are not in 'dict_colors': {', '.join(missing_keys)}") - if len(dict_color) != len(list_cat): - raise ValueError( - f"Length must match between 'list_cat' ({len(list_cat)}) and 'dict_colors' ({len(dict_color)}).") - return list_cat - - -def check_labels(list_cat=None, labels=None): - """Validate labels and match their length to list_cat.""" - if labels is None: - labels = list_cat - if len(list_cat) != len(labels): - raise ValueError(f"Length must match of 'labels' ({len(labels)}) and categories ({len(list_cat)}).") - return labels - - -# Checking functions for inputs that can be list or single values (redundancy accepted for better user communication) -def check_hatches(marker=None, hatch=None, list_cat=None): - """Check validity of list_hatche.""" - valid_hatches = ['/', '\\', '|', '-', '+', 'x', 'o', 'O', '.', '*'] - # Check if hatch is valid - if isinstance(hatch, str): - if hatch not in valid_hatches: - raise ValueError(f"'hatch' ('{hatch}') must be one of following: {valid_hatches}") - if isinstance(hatch, list): - wrong_hatch = [x for x in hatch if x not in valid_hatches] - if len(wrong_hatch) != 0: - raise ValueError( - f"'hatch' contains wrong values ('{wrong_hatch}')! 
Should be one of following: {valid_hatches}") - if len(hatch) != len(list_cat): - raise ValueError(f"Length must match of 'hatch' ({hatch}) and categories ({list_cat}).") # Check if hatch can be chosen - # Warn for parameter conflicts - if marker_has_no(marker, val=None) and hatch: - warnings.warn(f"'hatch' can only be applied to the default marker, set 'marker=None'.", UserWarning) - # Create hatch list - list_hatch = [hatch] * len(list_cat) if not isinstance(hatch, list) else hatch - return list_hatch - - -def check_marker(marker=None, list_cat=None, lw=0): - """Check validity of markers""" - # Add '-' for line and None for default marker - valid_markers = [None, "-"] + list(mlines.Line2D.markers.keys()) - # Check if marker is valid - if not isinstance(marker, list) and marker not in valid_markers: - raise ValueError(f"'marker' ('{marker}') must be one of following: {valid_markers}") - if isinstance(marker, list): - wrong_markers = [x for x in marker if x not in valid_markers] - if len(wrong_markers) != 0: - raise ValueError(f"'marker' contains wrong values ('{wrong_markers}'). 
Should be one of following: {valid_markers}") - if len(marker) != len(list_cat): - raise ValueError(f"Length must match of 'marker' ({marker}) and categories ({list_cat}).") - # Warn for parameter conflicts - if marker_has(marker, val="-") and lw <= 0: - warnings.warn(f"Marker lines ('-') are only shown if 'lw' ({lw}) is > 0.", UserWarning) - # Create marker list - list_marker = [marker] * len(list_cat) if not isinstance(marker, list) else marker - return list_marker - - -def check_marker_size(marker_size=None, list_cat=None): - """""" - # Check if marker_size is valid - if isinstance(marker_size, (int, float)): - ut.check_number_range(name='marker_size', val=marker_size, min_val=0, accept_none=True, just_int=False) - elif isinstance(marker_size, list): - for i in marker_size: - ut.check_number_range(name='marker_size', val=i, min_val=0, accept_none=True, just_int=False) - elif isinstance(marker_size, list) and len(marker_size) != len(list_cat): - raise ValueError(f"Length must match of 'marker_size' (marker_size) and categories ({list_cat}).") - else: - raise ValueError(f"'marker_size' has wrong data type: {type(marker_size)}") - # Create marker_size list - list_marker_size = [marker_size] * len(list_cat) if not isinstance(marker_size, list) else marker_size - return list_marker_size - - -def check_linestyle(linestyle=None, list_cat=None, marker=None): - """Check validity of linestyle.""" - _lines = ['-', '--', '-.', ':', ] - _names = ["solid", "dashed", "dashed-doted", "dotted"] - valid_mls = _lines + _names - # Check if marker_linestyle is valid - if isinstance(linestyle, list): - wrong_mls = [x for x in linestyle if x not in valid_mls] - if len(wrong_mls) != 0: - raise ValueError( - f"'marker_linestyle' contains wrong values ('{wrong_mls}')! 
Should be one of following: {valid_mls}") - if len(linestyle) != len(list_cat): - raise ValueError(f"Length must match of 'marker_linestyle' ({linestyle}) and categories ({list_cat}).") - # Check if marker_linestyle is conflicting with other settings - if isinstance(linestyle, str): - if linestyle not in valid_mls: - raise ValueError(f"'marker_linestyle' ('{linestyle}') must be one of following: {_lines}," - f" or corresponding names: {_names} ") - # Warn for parameter conflicts - if linestyle is not None and marker_has_no(marker, val="-"): - warnings.warn(f"'linestyle' ({linestyle}) is only applicable to marker lines ('-'), not to '{marker}'.", UserWarning) - # Create list_marker_linestyle list - list_marker_linestyle = [linestyle] * len(list_cat) if not isinstance(linestyle, list) else linestyle - return list_marker_linestyle - - -# Helper function -def _create_marker(color, label, marker, marker_size, lw, edgecolor, linestyle, hatch, hatchcolor): - """Create custom marker based on input.""" - # Default marker (matching to plot) - if marker is None: - return mpl.patches.Patch(facecolor=color, - label=label, - lw=lw, - hatch=hatch, - edgecolor=hatchcolor) - # If marker is '-', treat it as a line - if marker == "-": - return plt.Line2D(xdata=[0, 1], ydata=[0, 1], - color=color, - linestyle=linestyle, - lw=lw, - label=label) - # Creates marker element without line (lw=0) - return plt.Line2D(xdata=[0], ydata=[0], - marker=marker, - label=label, - markerfacecolor=color, - color=edgecolor, - markersize=marker_size, - lw=0, - markeredgewidth=lw) - - -# II Main function -def plot_legend(ax: Optional[plt.Axes] = None, - # Categories and colors - dict_color: Optional[Dict[str, str]] = None, - list_cat: Optional[List[str]] = None, - labels: Optional[List[str]] = None, - # Position and Layout - loc: Union[str, int] = "upper left", - loc_out: bool = False, - y: Optional[Union[int, float]] = None, - x: Optional[Union[int, float]] = None, - ncol: int = 3, - labelspacing: 
Union[int, float] = 0.2, - columnspacing: Union[int, float] = 1.0, - handletextpad: Union[int, float] = 0.8, - handlelength: Union[int, float] = 2, - # Font and Style - fontsize: Optional[Union[int, float]] = None, - fontsize_title: Optional[Union[int, float]] = None, - weight: str = "normal", - fontsize_weight: str = "normal", - # Line, Marker, and Area - marker: Optional[Union[str, int, list]] = None, - marker_size: Union[int, float, List[Union[int, float]]] = 10, - lw: Union[int, float] = 0, - linestyle: Optional[Union[str, list]] = None, - edgecolor: str = None, - hatch: Optional[Union[str, List[str]]] = None, - hatchcolor: str = "white", - # Title - title: str = None, - title_align_left: bool = True, - **kwargs - ) -> Union[plt.Axes, Tuple[List, List[str]]]: - """ - Sets an independntly customizable plot legend. - - Legends can be flexbily adjusted based categories and colors provided in ``dict_color`` dictionary. - This functions comprises the most convinient settings for ``func:`matplotlib.pyplot.legend``. - - Parameters - ---------- - ax - The axes to attach the legend to. If not provided, the current axes will be used. - dict_color - A dictionary mapping categories to colors. - list_cat - List of categories to include in the legend (keys of ``dict_color``). - labels - Labels for legend items corresponding to given categories. - loc - Location for the legend. - loc_out - If ``True``, sets automatically ``x=0`` and ``y=-0.25`` if they are ``None``. - y - The y-coordinate for the legend's anchor point. - x - The x-coordinate for the legend's anchor point. - ncol - Number of columns in the legend, at least 1. - labelspacing - Vertical spacing between legend items. - columnspacing - Horizontal spacing between legend columns. - handletextpad - Horizontal spacing bewtween legend handle (marker) and label. - handlelength - Length of legend handle. - fontsize - Font size for the legend text. - fontsize_title - Font size for the legend title. 
- weight - Weight of the font. - fontsize_weight - Font weight for the legend title. - marker - Marker for legend items. Lines ('-') only visiable if ``lw>0``. - marker_size - Marker size for legend items. - lw - Line width for legend items. If negative, corners are rounded. - linestyle - Style of line. Only applied to lines (``marker='-'``). - edgecolor - Edge color of legend items. Not applicable to lines. - hatch - Filling pattern for default marker. Only applicable when ``marker=None``. - hatchcolor - Hatch color of legend items. Only applicable when ``marker=None``. - title - Title for the legend. - title_align_left - Whether to align the title to the left. - **kwargs - Furhter key word arguments for :attr:`matplotlib.axes.Axes.legend`. - - Returns - ------- - ax - Axes on which legend is applied to. - - Examples - -------- - .. plot:: - :include-source: - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns - >>> import xomics as xo - >>> data = {'Classes': ['A', 'B', 'C'], 'Values': [23, 27, 43]} - >>> colors = xo.plot_get_clist(n_colors=3) - >>> xo.plot_settings() - >>> sns.barplot(x='Classes', y='Values', data=data, palette=colors, hatch=["/", ".", "."], hue="Classes", legend=False) - >>> sns.despine() - >>> dict_color = {"Group 1": "black", "Group 2": "black"} - >>> xo.plot_legend(dict_color=dict_color, ncol=2, y=1.1, hatch=["/", "."]) - >>> plt.tight_layout() - >>> plt.show() - - Notes - ----- - Markers can be None (default), lines ('-') or one of the `matplotlib markers `_. - - See Also - -------- - * More examples in `Plotting Prelude `_. - * `Linestyles of markers `_. - * `Hatches `_, which are filling patterns. - * :class:`matplotlib.lines.Line2D` for available marker shapes and line properties. - * :class:`matplotlib.axes.Axes`, which is the core object in matplotlib. - * :func:`matplotlib.pyplot.gca` to get the current Axes instance. 
- """ - # Check input - ut.check_ax(ax=ax, accept_none=True) - if ax is None: - ax = plt.gca() - - ut.check_dict(name="dict_color", val=dict_color, accept_none=False) - list_cat = check_list_cat(dict_color=dict_color, list_cat=list_cat) - labels = check_labels(list_cat=list_cat, labels=labels) - - ut.check_bool(name="title_align_left", val=title_align_left) - ut.check_bool(name="loc_out", val=loc_out) - - ut.check_number_range(name="ncol", val=ncol, min_val=1, accept_none=True, just_int=True) - ut.check_number_val(name="x", val=x, accept_none=True, just_int=False) - ut.check_number_val(name="y", val=y, accept_none=True, just_int=False) - ut.check_number_val(name="lw", val=lw, accept_none=True, just_int=False) - - args_non_neg = {"labelspacing": labelspacing, "columnspacing": columnspacing, - "handletextpad": handletextpad, "handlelength": handlelength, - "fontsize": fontsize, "fontsize_legend": fontsize_title} - for key in args_non_neg: - ut.check_number_range(name=key, val=args_non_neg[key], min_val=0, accept_none=True, just_int=False) - - marker = check_marker(marker=marker, list_cat=list_cat, lw=lw) - hatch = check_hatches(marker=marker, hatch=hatch, list_cat=list_cat) - linestyle = check_linestyle(linestyle=linestyle, list_cat=list_cat, marker=marker) - marker_size = check_marker_size(marker_size, list_cat=list_cat) - - # Remove existing legend - if ax.get_legend() is not None and len(ax.get_legend().get_lines()) > 0: - ax.legend_.remove() - - # Update legend arguments - args = dict(loc=loc, ncol=ncol, fontsize=fontsize, labelspacing=labelspacing, columnspacing=columnspacing, - handletextpad=handletextpad, handlelength=handlelength, borderpad=0, title=title, - edgecolor=edgecolor, prop={"weight": weight, "size": fontsize}) - args.update(kwargs) - - if fontsize_title: - args["title_fontproperties"] = {"weight": fontsize_weight, "size": fontsize_title} - - if loc_out: - x, y = x or 0, y or -0.25 - if x or y: - args["bbox_to_anchor"] = (x or 0, y or 1) - - # 
Create handles and legend - handles = [_create_marker(dict_color[cat], labels[i], marker[i], marker_size[i], - lw, edgecolor, linestyle[i], hatch[i], hatchcolor) - for i, cat in enumerate(list_cat)] - - legend = ax.legend(handles=handles, labels=labels, **args) - if title_align_left: - legend._legend_box.align = "left" - return ax - diff --git a/xomics/plotting/_plot_pintegrate.py b/xomics/plotting/_plot_pintegrate.py deleted file mode 100644 index ea1ffab..0000000 --- a/xomics/plotting/_plot_pintegrate.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -This is a script for plotting use_cases integration data. -""" -from matplotlib import pyplot as plt -import matplotlib.patches as mpatches -import seaborn as sns - -import xomics.utils as ut - - -# I Helper Functions - - -# II Main Functions -def plot_pintegrate(): - """""" - # TODO Clarissa \ No newline at end of file diff --git a/xomics/plotting/_plot_volcano.py b/xomics/plotting/_plot_volcano.py index b981a37..ee13eeb 100644 --- a/xomics/plotting/_plot_volcano.py +++ b/xomics/plotting/_plot_volcano.py @@ -9,7 +9,6 @@ from typing import Optional, Tuple, Union, List from adjustText import adjust_text -import xomics as xo import xomics.utils as ut # Constats @@ -154,20 +153,21 @@ def plot_volcano(ax: Optional[plt.Axes] = None, # Initial parameter validation ut.check_ax(ax=ax, accept_none=True) ut.check_tuple(name="figsize", val=figsize, n=2, accept_none=True) - ut.check_col_in_df(df=df, name_df="df", cols=[col_fc, col_pval], name_cols=["col_fc", "col_pval"]) - df = ut.check_df(name="df", df=df, cols_req=[col_fc, col_pval]) + cols_requiered = [col_fc, col_pval] + ut.check_df(df=df, name="df", cols_requiered=cols_requiered, accept_nan=False, accept_none=False) if col_names is not None or names_to_annotate is not None: - ut.check_col_in_df(name_df="df", df=df, cols=col_names, name_cols="col_names") + cols_requiered.append(col_names) names_to_annotate = ut.check_list_like(name="names_to_annotate", val=names_to_annotate, 
accept_none=False) names_to_annotate = check_match_df_names(df=df, list_names=names_to_annotate, col_names=col_names) if col_cbar is not None: - ut.check_col_in_df(name_df="df", df=df, cols=col_cbar, accept_nan=True) + cols_requiered.append(col_cbar) + ut.check_df(df=df, name="df", cols_requiered=cols_requiered) ut.check_number_range(name="th_fc", val=th_fc, min_val=0, just_int=False) ut.check_number_range(name="th_pval", val=th_pval, min_val=0, max_val=1, just_int=False) colors = ut.check_list_like(name="colors", val=colors, accept_none=True, accept_str=True) - ut.check_tuple(name="colors_pos_neg_non", val=colors_pos_neg_non, accept_none=True, n=3, check_n=True) + ut.check_tuple(name="colors_pos_neg_non", val=colors_pos_neg_non, accept_none=True, n=3, check_number=True) ut.check_number_range(name="size", val=size, min_val=1, just_int=True) - ut.check_tuple(name="size_pos_neg_non", val=colors_pos_neg_non, accept_none=True, n=3, check_n=True) + ut.check_tuple(name="size_pos_neg_non", val=colors_pos_neg_non, accept_none=True, n=3, check_number=True) ut.check_number_range(name="alpha", val=alpha, min_val=0, max_val=1, just_int=False) ut.check_number_range(name="edge_width", val=edge_width, min_val=0, just_int=False) ut.check_dict(name="label_fontdict", val=label_fontdict, accept_none=True) @@ -179,7 +179,7 @@ def plot_volcano(ax: Optional[plt.Axes] = None, # Plot settings if colors_pos_neg_non is None: - color_non_sig, color_sig_neg, color_sig_pos = xo.plot_get_clist(n_colors=3) + color_non_sig, color_sig_neg, color_sig_pos = ut.plot_get_clist_(n_colors=3) else: color_sig_pos, color_sig_neg, color_non_sig = colors_pos_neg_non if sizes_pos_neg_non is None: @@ -253,7 +253,7 @@ def plot_volcano(ax: Optional[plt.Axes] = None, if names_to_annotate is not None: labels = [(row[col_names], row[col_fc], row[col_pval]) for i, row in df.iterrows() if row[col_names] in names_to_annotate and not np.isnan(row[col_fc])] - fontdict = dict(size=xo.plot_gcfs()-8) + fontdict = 
dict(size=ut.plot_gco()-8) if label_fontdict is not None: fontdict.update(**label_fontdict) texts = [plt.text(x, y, label, fontdict=fontdict) for label, x, y in labels] @@ -264,7 +264,7 @@ def plot_volcano(ax: Optional[plt.Axes] = None, if not legend or col_cbar is not None: ax.legend().set_visible(False) else: - xo.plot_legend(dict_color=dict_color, + ut.plot_legend_(dict_color=dict_color, list_cat=[ut.STR_NON_SIG, ut.STR_SIG_NEG, ut.STR_SIG_POS], ncol=1, marker="o", loc=loc_legend, diff --git a/xomics/plotting/_utils_plot.py b/xomics/plotting/_utils_plot.py index 56965f6..2a7fb50 100644 --- a/xomics/plotting/_utils_plot.py +++ b/xomics/plotting/_utils_plot.py @@ -10,6 +10,7 @@ import xomics.utils as ut +# TODO into utils and remove def set_legend_handles_labels(ax=None, dict_color=None, list_cat=None, labels=None, y=-0.2, x=0.5, ncol=3, fontsize=11, weight="normal", lw=0, edgecolor=None, return_handles=False, loc=9, diff --git a/xomics/plotting_ut/__init__.py b/xomics/plotting_ut/__init__.py new file mode 100644 index 0000000..34ec5c3 --- /dev/null +++ b/xomics/plotting_ut/__init__.py @@ -0,0 +1,13 @@ +from ._plot_get_clist import plot_get_clist +from ._plot_settings import plot_settings +from ._plot_gcfs import plot_gcfs +from ._plot_legend import plot_legend +from ._display_df import display_df + +__all__ = [ + "plot_get_clist", + "plot_settings", + "plot_gcfs", + "plot_legend", + "display_df", +] \ No newline at end of file diff --git a/xomics/plotting_ut/_display_df.py b/xomics/plotting_ut/_display_df.py new file mode 100644 index 0000000..e921372 --- /dev/null +++ b/xomics/plotting_ut/_display_df.py @@ -0,0 +1,144 @@ +""" +This is a script for displaying pd.DataFrames as HTML output for jupyter notebooks. 
+""" +from typing import Optional, Union +import pandas as pd +from IPython.display import display, HTML + +import xomics.utils as ut + + +# Helper functions +def _adjust_df(df=None, char_limit = 50): + df = df.copy() + list_index = df.index + # Adjust index if it consists solely of integers + if all(isinstance(i, int) for i in list_index): + df.index = [x + 1 for x in df.index] + # Function to truncate strings longer than char_limit + def truncate_string(s): + return str(s)[:int(char_limit/2)] + '...' + str(s)[-int(char_limit/2):] if len(str(s)) > char_limit else s + # Apply truncation to each cell in the DataFrame + if char_limit is not None: + for col in df.columns: + df[col] = df[col].apply(lambda x: truncate_string(x) if isinstance(x, str) else x) + return df + + +def _check_show(name="row_to_show", val=None, df=None): + """Check if valid string or int""" + if val is None: + return None # Skip test + rows_or_columns = list(df.T) if "row" in name else list(df) + n = len(rows_or_columns) + str_row_or_column = "row" if "row" in name else "column" + if isinstance(val, str): + ut.check_str(name=name, val=val, accept_none=True) + if val not in rows_or_columns: + raise ValueError(f"'{name}' ('{val}') should be one of following: {rows_or_columns}") + elif isinstance(val, int): + ut.check_number_range(name=name, val=val, accept_none=True, min_val=0, max_val=n, just_int=True) + else: + raise ValueError(f"'{name}' ('{val}') should be int (<{n}) or one of following {str_row_or_column} names: {rows_or_columns}") + +def _select_row(df=None, row_to_show=None): + """Select row""" + if row_to_show is not None: + if isinstance(row_to_show, int): + df = df.iloc[[row_to_show]] + elif isinstance(row_to_show, str): + df = df.loc[[row_to_show]] + return df + +def _select_col(df=None, col_to_show=None): + """Select column""" + if col_to_show is not None: + if isinstance(col_to_show, int): + df = df.iloc[:, [col_to_show]] + elif isinstance(col_to_show, str): + df = df[[col_to_show]] + 
return df + + +# Main functions +def display_df(df: pd.DataFrame = None, + max_width_pct: int = 100, + max_height: int = 300, + char_limit: int = 30, + show_shape=False, + n_rows: Optional[int] = None, + n_cols: Optional[int] = None, + row_to_show: Optional[Union[int, str]] = None, + col_to_show: Optional[Union[int, str]] = None, + ): + """ + Display DataFrame with specific style as HTML output for jupyter notebooks. + + Parameters + ---------- + df : pd.DataFrame + DataFrame to be displayed as HTML output. + max_width_pct: int, default=100 + Maximum width in percentage of main page for table. + max_height : int, default=300 + Maximum height in pixels of table. + char_limit : int, default=30 + Maximum number of characters to display in a cell. + show_shape : bool, default=False + If ``True``, shape of ``df`` is printed. + n_rows : int, optional + Display only the first n rows. If negative, last n rows will be shown. + n_cols : int, optional + Display only the first n columns. If negative, last n columns will be shown. + row_to_show : int or str, optional + Display only the specified row. + col_to_show : int or str, optional + Display only the specified column. + + Examples + -------- + .. 
include:: examples/display_df.rst + """ + # Check input + ut.check_df(name="df", df=df, accept_none=False) + ut.check_number_range(name="max_width_pct", val=max_width_pct, min_val=1, max_val=100, accept_none=False, just_int=True) + ut.check_number_range(name="max_height", val=max_height, min_val=1, accept_none=False, just_int=True) + ut.check_number_range(name="char_limit", val=char_limit, min_val=1, accept_none=True, just_int=True) + n_rows_, n_cols_ = len(df), len(df.T) + ut.check_number_range(name="n_rows", val=n_rows, min_val=-n_rows_, max_val=n_rows_, accept_none=True, just_int=True) + ut.check_number_range(name="n_cols", val=n_cols, min_val=-n_cols_, max_val=n_cols_, accept_none=True, just_int=True) + _check_show(name="show_only_col", val=col_to_show, df=df) + _check_show(name="show_only_row", val=row_to_show, df=df) + # Show shape before filtering + if show_shape: + print(f"DataFrame shape: {df.shape}") + # Filtering + df = df.copy() + df = _select_col(df=df, col_to_show=col_to_show) + df = _select_row(df=df, row_to_show=row_to_show) + if row_to_show is None and n_rows is not None: + if n_rows > 0: + df = df.head(n_rows) + else: + df = df.tail(abs(n_rows)) + if col_to_show is None and n_cols is not None: + if n_cols > 0: + df = df.T.head(n_cols).T + else: + df = df.T.tail(abs(n_cols)).T + # Style dataframe + df = _adjust_df(df=df, char_limit=char_limit) + styled_df = ( + df.style + .set_table_attributes(f"style='display:block; max-height: {max_height}px; max-width: {max_width_pct}%; overflow-x: auto; overflow-y: auto;'") + .set_table_styles([ + # Explicitly set background and text color for headers + {'selector': 'thead th', 'props': [('background-color', 'white'), ('color', 'black')]}, + # Style for odd and even rows + {'selector': 'tbody tr:nth-child(odd)', 'props': [('background-color', '#f2f2f2')]}, + {'selector': 'tbody tr:nth-child(even)', 'props': [('background-color', 'white')]}, + # General styling for table cells + {'selector': 'th, td', 'props': 
[('padding', '5px'), ('white-space', 'nowrap')]}, + ]) + ) + display(HTML(styled_df.to_html())) diff --git a/xomics/plotting_ut/_plot_gcfs.py b/xomics/plotting_ut/_plot_gcfs.py new file mode 100644 index 0000000..20e4de2 --- /dev/null +++ b/xomics/plotting_ut/_plot_gcfs.py @@ -0,0 +1,39 @@ +""" +This is a script for getting current font size of figures. +""" +import seaborn as sns + + +# Main function +def plot_gcfs(option: str = 'font.size') -> int: + """ + Get the current font size (or axes linewidth). + + This font size can be set by :func:`plot_settings` function. + + Parameters + ---------- + option : str, default='font.size' + Figure setting to get default value from. Either 'font.size' or 'axes.linewidth' + + Returns + ------- + option_value : int + Numerical value for selected option. + + See Also + -------- + * `Plotting Prelude `_. + + Examples + -------- + .. include:: examples/plot_gcfs.rst + """ + # Check input + allowed_options = ["font.size", "axes.linewidth"] + if option not in allowed_options: + raise ValueError(f"'option' should be one of following: {allowed_options}") + # Get the current plotting context + current_context = sns.plotting_context() + option_value = current_context[option] + return option_value diff --git a/xomics/plotting_ut/_plot_get_clist.py b/xomics/plotting_ut/_plot_get_clist.py new file mode 100644 index 0000000..c102ad9 --- /dev/null +++ b/xomics/plotting_ut/_plot_get_clist.py @@ -0,0 +1,42 @@ +""" +This is a script for frontend of plotting utility function to obtain AAanalysis color list. +The backend is in general utility module to provide function to remaining AAanalysis modules. +""" +from typing import List +import xomics.utils as ut + + +# II Main function +def plot_get_clist(n_colors: int = 3) -> List[str]: + """ + Get a manually curated list of 2 to 9 colors or 'husl' palette for more than 9 colors. + + This functions returns one of eight different color lists optimized for appealing visualization of categories. 
+ If more than 9 colors are selected, :func:`seaborn.color_palette` with 'husl' palette will be used. + + Parameters + ---------- + n_colors : int, default=3 + Number of colors. Must be greater 2. + + Returns + ------- + list + Color list given as matplotlib color names. + + See Also + -------- + * The example notebooks in `Plotting Prelude `_. + * `Matplotlib color names `_ + * :func:`seaborn.color_palette` function to generate a color palette in seaborn. + + Examples + -------- + .. include:: examples/plot_get_clist.rst + """ + # Check input + ut.check_number_range(name="n_colors", val=n_colors, min_val=2, just_int=True) + # Base lists + colors = ut.plot_get_clist_(n_colors=n_colors) + return colors + diff --git a/xomics/plotting_ut/_plot_legend.py b/xomics/plotting_ut/_plot_legend.py new file mode 100644 index 0000000..b366c74 --- /dev/null +++ b/xomics/plotting_ut/_plot_legend.py @@ -0,0 +1,162 @@ +""" +This is a script for frontend of the setting plot legend. +The backend is in general utility module to provide function to remaining AAanalysis modules. 
+""" +from typing import Optional, List, Dict, Union, Tuple +from matplotlib import pyplot as plt +import xomics.utils as ut + +# I Helper functions + + +# II Main function +def plot_legend(ax: Optional[plt.Axes] = None, + # Categories and colors + dict_color: Dict[str, str] = None, + list_cat: Optional[List[str]] = None, + labels: Optional[List[str]] = None, + # Position and Layout + loc: Union[str, int] = "upper left", + loc_out: bool = False, + frameon: bool = False, + y: Optional[Union[int, float]] = None, + x: Optional[Union[int, float]] = None, + n_cols: int = 3, + labelspacing: Union[int, float] = 0.2, + columnspacing: Union[int, float] = 1.0, + handletextpad: Union[int, float] = 0.8, + handlelength: Union[int, float] = 2.0, + # Font and Style + fontsize: Optional[Union[int, float]] = None, + fontsize_title: Optional[Union[int, float]] = None, + weight_font: str = "normal", + weight_title: str = "normal", + # Marker, Lines, and Area + marker: Optional[Union[str, int, list]] = None, + marker_size: Union[int, float, List[Union[int, float]]] = 10, + lw: Union[int, float] = 0, + linestyle: Optional[Union[str, list]] = None, + edgecolor: Optional[str] = None, + hatch: Optional[Union[str, List[str]]] = None, + hatchcolor: str = "white", + # Title + title: Optional[str] = None, + title_align_left: bool = True, + **kwargs + ) -> Union[plt.Axes, Tuple[List, List[str]]]: + """ + Set an independently customizable plot legend. + + Legends can be flexibly adjusted based categories and colors provided in ``dict_color`` dictionary. + This functions comprises the most convenient settings for ``func:`matplotlib.pyplot.legend``. + + Parameters + ---------- + ax : plt.Axes, optional + The axes to attach the legend to. If not provided, the current axes will be used. + dict_color : dict, optional + A dictionary mapping categories to colors. + list_cat : list of str, optional + List of categories to include in the legend (keys of ``dict_color``). 
+ labels : list of str, optional + Legend labels corresponding to given categories. + loc : int or str + Location for the legend. + loc_out : bool, default=False + If ``True``, sets automatically ``x=0`` and ``y=-0.25`` if they are ``None``. + frameon : bool, default=False + If ``True``, a figure background patch (frame) will be drawn. + y : int or float, optional + The y-coordinate for the legend's anchor point. + x : int or float, optional + The x-coordinate for the legend's anchor point. + n_cols : int, default=1 + Number of columns in the legend, at least 1. + labelspacing : int or float, default=0.2 + Vertical spacing between legend items. + columnspacing : int or float, default=1.0 + Horizontal spacing between legend columns. + handletextpad : int or float, default=0.8 + Horizontal spacing between legend handle (marker) and label. + handlelength : int or float, default=2.0 + Length of legend handle. + fontsize : int or float, optional + Font size of the legend text. + fontsize_title : inf or float, optional + Font size of the legend title. + weight_font : str, default='normal' + Weight of the font. + weight_title : str, default='normal' + Font weight for the legend title. + marker : str, int, or list, optional + Handle marker for legend items. Lines ('-') only visible if ``lw>0``. + marker_size : int, float, or list, optional + Marker size of legend items. + lw : int or float, default=0 + Line width for legend items. If negative, corners are rounded. + linestyle : str or list, optional + Style of line. Only applied to lines (``marker='-'``). + edgecolor : str, optional + Edge color of legend items. Not applicable to lines. + hatch : str or list, optional + Filling pattern for default marker. Only applicable when ``marker=None``. + hatchcolor : str, default='white' + Hatch color of legend items. Only applicable when ``marker=None``. + title : str, optional + Legend title. + title_align_left : bool, default=True + Whether to align the title to the left. 
+ **kwargs + Further key word arguments for :attr:`matplotlib.axes.Axes.legend`. + + Returns + ------- + ax : plt.Axes + The axes object on which legend is applied to. + + Notes + ----- + Markers can be None (default), lines ('-') or one of the `matplotlib markers + `_. + + See Also + -------- + * More examples in `Plotting Prelude `_. + * `Linestyles of markers `_. + * `Hatches `_, which are filling patterns. + * :class:`matplotlib.lines.Line2D` for available marker shapes and line properties. + * :class:`matplotlib.axes.Axes`, which is the core object in matplotlib. + * :func:`matplotlib.pyplot.gca` to get the current Axes instance. + + Examples + -------- + .. include:: examples/plot_legend.rst + """ + # Check input + ut.check_ax(ax=ax, accept_none=True) + if ax is None: + ax = plt.gca() + ut.check_dict(name="dict_color", val=dict_color, accept_none=False) + ut.check_bool(name="title_align_left", val=title_align_left) + ut.check_bool(name="loc_out", val=loc_out) + ut.check_bool(name="frameon", val=frameon) + ut.check_number_range(name="n_cols", val=n_cols, min_val=1, accept_none=True, just_int=True) + ut.check_number_val(name="x", val=x, accept_none=True, just_int=False) + ut.check_number_val(name="y", val=y, accept_none=True, just_int=False) + ut.check_number_val(name="lw", val=lw, accept_none=True, just_int=False) + args_non_neg = {"labelspacing": labelspacing, "columnspacing": columnspacing, + "handletextpad": handletextpad, "handlelength": handlelength, + "fontsize": fontsize, "fontsize_legend": fontsize_title} + for key in args_non_neg: + ut.check_number_range(name=key, val=args_non_neg[key], min_val=0, accept_none=True, just_int=False) + # Create new legend + ax = ut.plot_legend_(ax=ax, dict_color=dict_color, list_cat=list_cat, labels=labels, + loc=loc, loc_out=loc_out, y=y, x=x, n_cols=n_cols, + labelspacing=labelspacing, columnspacing=columnspacing, + handletextpad=handletextpad, handlelength=handlelength, + fontsize=fontsize, 
fontsize_title=fontsize_title, + weight_font=weight_font, weight_title=weight_title, + marker=marker, marker_size=marker_size, lw=lw, linestyle=linestyle, edgecolor=edgecolor, + hatch=hatch, hatchcolor=hatchcolor, title=title, title_align_left=title_align_left, + frameon=frameon, **kwargs) + return ax diff --git a/xomics/plotting/_plot_settings.py b/xomics/plotting_ut/_plot_settings.py similarity index 75% rename from xomics/plotting/_plot_settings.py rename to xomics/plotting_ut/_plot_settings.py index 28a1188..bac9969 100644 --- a/xomics/plotting/_plot_settings.py +++ b/xomics/plotting_ut/_plot_settings.py @@ -1,45 +1,36 @@ """ -Plotting utility functions for xOmics to create publication ready figures. Can -be used for any Python project independently of xOmics. +Plotting utility functions for AAanalysis to create publication ready figures. Can +be used for any Python project independently of AAanalysis. """ +from typing import Union import seaborn as sns import matplotlib as mpl import matplotlib.pyplot as plt import xomics.utils as ut import warnings -LIST_FONTS = ['Arial', 'Avant Garde', - 'Bitstream Vera Sans', 'Computer Modern Sans Serif', - 'DejaVu Sans', 'Geneva', - 'Helvetica', 'Lucid', - 'Lucida Grande', 'Verdana'] + +LIST_FONTS = ['Arial', 'Courier New', 'DejaVu Sans', 'Times New Roman', 'Verdana'] # I Helper functions # Check plot_settings def check_font(font="Arial"): - """""" if font not in LIST_FONTS: error_message = f"'font' ({font}) not in recommended fonts: {LIST_FONTS}. 
Set font manually by:" \ f"\n\tplt.rcParams['font.sans-serif'] = '{font}'" raise ValueError(error_message) -def check_grid_axis(grid_axis="y"): - list_grid_axis = ["y", "x", "both"] - if grid_axis not in list_grid_axis: - raise ValueError(f"'grid_axis' ({grid_axis}) should be one of following: {list_grid_axis}") - - # Helper function def set_tick_size(axis=None, major_size=None, minor_size=None): - """Set tick size for the given axis.""" + """Set tick size of the given axis.""" plt.rcParams[f"{axis}tick.major.size"] = major_size plt.rcParams[f"{axis}tick.minor.size"] = minor_size # II Main functions -def plot_settings(font_scale: float = 1, +def plot_settings(font_scale: Union[int, float] = 1, font: str = "Arial", weight_bold: bool = True, adjust_only_font: bool = False, @@ -52,9 +43,10 @@ def plot_settings(font_scale: float = 1, short_ticks_x: bool = False, no_ticks_y: bool = False, short_ticks_y: bool = False, - show_options: bool = False) -> None: + show_options: bool = False + ) -> None: """ - Configures general plot settings. + Configure general plot settings. This function modifies the global settings of :mod:`matplotlib` and :mod:`seaborn` libraries. It adjusts font embedding for vector formats like PDF and SVG, ensuring compatibility and @@ -62,79 +54,52 @@ def plot_settings(font_scale: float = 1, Parameters ---------- - font_scale + font_scale : int or float, default=1 Scaling factor to scale the size of font elements. Consistent with :func:`seaborn.set_context`. - font - Name of text font. Common options are 'Arial', 'Verdana', 'Helvetica', or 'DejaVu Sans' (Matplotlib default). - weight_bold + font : {'Arial', 'Courier New', 'DejaVu Sans', 'Times New Roman', 'Verdana'}, default='Arial' + Name of text font. Common options are 'Arial' or 'DejaVu Sans' (Matplotlib default). + weight_bold : bool, default=True If ``True``, font and line elements are bold. 
- adjust_only_font + adjust_only_font : bool, default=False If ``True``, only the font style will be adjusted, leaving other elements unchanged. - adjust_further_elements + adjust_further_elements : bool, default=True If ``True``, makes additional visual and layout adjustments to the plot (errorbars, legend). - grid + grid : bool, default=False If ``True``, display the grid in plots. - grid_axis + grid_axis : {'y', 'x', 'both'}, default='y' Choose the axis ('y', 'x', 'both') to apply the grid to. - no_ticks + no_ticks : bool, default=False If ``True``, remove all tick marks on both x and y axes. - short_ticks + short_ticks : bool, default=False If ``True``, display short tick marks on both x and y axes. Is ignored if ``no_ticks=True``. - no_ticks_x + no_ticks_x : bool, default=False If ``True``, remove tick marks on the x-axis. - short_ticks_x + short_ticks_x : bool, default=False If ``True``, display short tick marks on the x-axis. Is ignored if ``no_ticks=True``. - no_ticks_y + no_ticks_y : bool, default=False If ``True``, remove tick marks on the y-axis. - short_ticks_y + short_ticks_y : bool, default=False If ``True``, display short tick marks on the y-axis. Is ignored if ``no_ticks=True``. - show_options + show_options : bool, default=False If ``True``, show all plot runtime configurations of matplotlib. - Examples - -------- - Create default seaborn plot: - - .. plot:: - :include-source: - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns - >>> import xomics as xo - >>> data = {'Classes': ['Class A', 'Class B', 'Class C'], 'Values': [23, 27, 43]} - >>> sns.barplot(x='Classes', y='Values', data=data) - >>> sns.despine() - >>> plt.title("Seaborn default") - >>> plt.tight_layout() - >>> plt.show() - - Adjust polts with xOmics: - - .. 
plot:: - :include-source: - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns - >>> import xomics as xo - >>> data = {'Classes': ['Class A', 'Class B', 'Class C'], 'Values': [23, 27, 43]} - >>> colors = xo.plot_get_clist() - >>> xo.plot_settings() - >>> sns.barplot(data=data, x='Classes', y='Values', palette=colors, hue="Classes", legend=False) - >>> sns.despine() - >>> plt.title("Adjusted") - >>> plt.tight_layout() - >>> plt.show() + Notes + ----- + * ``grid_axis`` work only for axis with numerical values. See Also -------- - * More examples in `Plotting Prelude `_. * :func:`seaborn.set_context`, where ``font_scale`` is utilized. * :data:`matplotlib.rcParams`, which manages the global settings in :mod:`matplotlib`. + + Examples + -------- + .. include:: examples/plot_settings.rst """ # Check input ut.check_number_range(name="font_scale", val=font_scale, min_val=0, just_int=False) check_font(font=font) - check_grid_axis(grid_axis=grid_axis) + ut.check_grid_axis(grid_axis=grid_axis) args_bool = {"weight_bold": weight_bold, "adjust_only_font": adjust_only_font, "adjust_further_elements": adjust_further_elements, "grid": grid, "short_ticks": short_ticks, "short_ticks_x": short_ticks_x, "short_ticks_y": short_ticks_y, diff --git a/xomics/ranking/_backend/prank.py b/xomics/ranking/_backend/prank.py index 5d30c98..0541903 100644 --- a/xomics/ranking/_backend/prank.py +++ b/xomics/ranking/_backend/prank.py @@ -158,7 +158,6 @@ def e_score_only_pvals(names=None, name_lists=None, x_pval=None): return e_scores - def c_score(ids=None, df_imp=None, col_id=None): """Obtain protein use_cases confidence score (C score) from cImpute output""" list_ids = df_imp.index.to_list() if col_id is None else df_imp[col_id].to_list() diff --git a/xomics/ranking/_prank.py b/xomics/ranking/_prank.py index 949af10..7675035 100644 --- a/xomics/ranking/_prank.py +++ b/xomics/ranking/_prank.py @@ -126,8 +126,8 @@ class pRank: Hybrid imputation algorithm for missing values (MVs) in 
(prote)omics data. """ def __init__(self, - col_id: str = ut.COL_PROT_ID, - col_name: str = ut.COL_GENE_NAME, + col_id: str = "protein_id", + col_name: str = "gene_name", str_quant: str = ut.STR_QUANT, ): """ @@ -308,8 +308,8 @@ def e_hits(ids=None, Returns ------- - df_e_hit : pandas.DataFrame - Data frame with links between gene/protein ids and 'enrichment' terms + df_e_hit : pd.DataFrame + DataFrame with links between gene/protein ids and 'enrichment' terms Examples -------- diff --git a/xomics/utils.py b/xomics/utils.py index f773508..e9507bf 100644 --- a/xomics/utils.py +++ b/xomics/utils.py @@ -12,25 +12,66 @@ from .config import options -# Import utility functions explicitly -from ._utils.check_data import (check_X, check_X_unique_samples, check_labels, check_match_X_labels, - check_array_like, check_superset_subset, - check_col_in_df, check_df) -from ._utils.check_models import check_mode_class, check_model_kwargs -from ._utils.check_type import (check_number_range, check_number_val, check_str, check_bool, - check_dict, check_tuple, check_list_like, check_str_in_list, - check_ax) - -from ._utils.new_types import ArrayLike1D, ArrayLike2D - -from ._utils.decorators import (catch_runtime_warnings, CatchRuntimeWarnings, - catch_convergence_warning, ClusteringConvergenceException, +# Data types +from ._utils.utils_types import (ArrayLike1D, + ArrayLike2D, + VALID_INT_TYPES, + VALID_FLOAT_TYPES, + VALID_INT_FLOAT_TYPES) + +# Decorators +from ._utils.decorators import (catch_runtime_warnings, + CatchRuntimeWarnings, + catch_convergence_warning, + ClusteringConvergenceException, catch_invalid_divide_warning, - doc_params) - -from ._utils.utils_output import (print_out, print_start_progress, print_progress, print_finished_progress) + catch_undefined_metric_warning, + CatchUndefinedMetricWarning) + +# Check functions +from ._utils.check_type import (check_number_range, + check_number_val, + check_str, + check_str_in_list, + check_bool, + check_dict, + check_tuple, + 
check_list_like) +from ._utils.check_data import (check_X, + check_X_unique_samples, + check_labels, + check_match_X_labels, + check_match_X_list_labels, + check_match_list_labels_names_datasets, + check_array_like, + check_superset_subset, + check_df) +from ._utils.check_models import (check_mode_class, + check_model_kwargs) +from ._utils.check_plots import (check_fig, + check_ax, + check_figsize, + check_grid_axis, + check_font_weight, + check_fontsize_args, + check_vmin_vmax, + check_lim, + check_dict_xlims, + check_color, + check_list_colors, + check_dict_color, + check_cmap, + check_palette) + +# Internal utility functions +from ._utils.utils_output import (print_out, + print_start_progress, + print_progress, + print_end_progress) + +# External (system-level) utility functions (only backend) from ._utils.utils_groups import get_dict_qcol_group, get_dict_group_qcols, get_qcols -from ._utils.utils_plotting import plot_gco +from ._utils.utils_plotting import plot_gco, plot_legend_, plot_get_clist_ # Folder structure @@ -145,17 +186,9 @@ def read_csv_cached(name, sep=None): # Main check functions -def check_verbose(verbose): - if verbose is None: - # System level verbosity - verbose = options['verbose'] - else: - check_bool(name="verbose", val=verbose) - return verbose - - def check_match_df_groups(df=None, groups=None, name_groups="groups", str_quant=None): """""" + print(df) if str_quant is None: raise ValueError("'str_quant' must be given.") list_substr_cols = [col.replace(str_quant, "").split("_") for col in list(df)]