Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RE2022-218: save pre-processed trait info #419

Merged
merged 13 commits into from
Aug 23, 2023
25 changes: 17 additions & 8 deletions src/common/product_models/heatmap_common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,32 @@
"""

from enum import Enum

from pydantic import BaseModel, Field

from src.common.product_models.common_models import SubsetProcessStates


# these fields need to match the fields in the models below.
FIELD_HEATMAP_DATA = "data"
FIELD_HEATMAP_MIN_VALUE = "min_value"
FIELD_HEATMAP_MAX_VALUE = "max_value"
FIELD_HEATMAP_COUNT = "count"
FIELD_HEATMAP_CELL_VALUE = "val"
FIELD_HEATMAP_VALUES = 'values'
FIELD_HEATMAP_ROW_CELLS = "cells"
FIELD_HEATMAP_CELL_ID = 'cell_id'
FIELD_HEATMAP_COL_ID = 'col_id'
FIELD_HEATMAP_CELL_VALUE = "val"
FIELD_HEATMAP_NAME = "name"
FIELD_HEATMAP_DESCR = "description"
FIELD_HEATMAP_TYPE = "type"
FIELD_HEATMAP_CATEGORY = "category"

_FLD_CELL_ID = Field(
example="4",
description="The unique ID of the cell in the heatmap."
)


class ColumnType(str, Enum):
"""
The type of a column's values.
Expand Down Expand Up @@ -50,7 +57,7 @@ class ColumnInformation(BaseModel):
)
description: str = Field(
example="Spizizen medium (SM) is a popular minimal medium for the cultivation of "
+ "B. subtilis.",
+ "B. subtilis.",
description="The description of the column."
)
type: ColumnType = Field(
Expand All @@ -66,7 +73,7 @@ class ColumnCategory(BaseModel):
category: str | None = Field(
example="Minimal media",
description="The name of the category that groups columns together. Null if "
+ "columns are not categorized."
+ "columns are not categorized."
)
columns: list[ColumnInformation] = Field(
description="The columns in the category, provided in render order."
Expand Down Expand Up @@ -103,8 +110,9 @@ class Cell(BaseModel):
example=4.2,
description="The value of the heatmap at this cell."
)

class Config:
smart_union=True
smart_union = True


class HeatMapRow(BaseModel):
Expand Down Expand Up @@ -144,12 +152,12 @@ class HeatMap(SubsetProcessStates):
min_value: float | None = Field(
example=32.4,
description="The minimum cell value in the rows in this heatmap "
+ "or null if there are no rows."
+ "or null if there are no rows."
)
max_value: float | None = Field(
example=71.8,
description="The maximum cell value in the rows in this heatmap "
+ "or null if there are no rows."
+ "or null if there are no rows."
)
count: int | None = Field(
example=42,
Expand All @@ -169,8 +177,9 @@ class CellDetailEntry(BaseModel):
example=56.1,
description="The value of the cell entry."
)

class Config:
smart_union=True
smart_union = True


class CellDetail(BaseModel):
Expand Down
5 changes: 1 addition & 4 deletions src/loaders/common/loader_common_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@

# the name of the CSV file where we store the trait counts
TRAIT_COUNTS_FILE = 'trait_counts.csv'
SYS_TRAIT_ID = 'trait_id' # unique identifier for a trait
DETECTED_GENE_SCORE_COL = 'detected_genes_score' # column name for the detected genes score

# column name for the trait unique identifier defined in the granularity trait count table
MICROTRAIT_TRAIT_NAME = 'microtrait_trait-name'
SYS_TRAIT_ID = 'trait_id' # unique identifier for a trait
UNWRAPPED_GENE_COL = 'unwrapped_genes' # column name that contains parsed gene name from unwrapped rule

# kbase authentication token
Expand Down
2 changes: 2 additions & 0 deletions src/loaders/compute_tools/microtrait/env_microtrait.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ dependencies:
- rpy2=3.5.10
- pandas=2.0.0
- unzip=6.0
- jsonlines=3.1.0
- pydantic=1.10.12
110 changes: 103 additions & 7 deletions src/loaders/compute_tools/microtrait/microtrait.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,68 @@
"""
Runs microtrait on a set of assemblies.
"""
import json
import os
import uuid
from pathlib import Path

import pandas as pd
from rpy2 import robjects

from src.common.product_models.heatmap_common_models import (
FIELD_HEATMAP_VALUES,
FIELD_HEATMAP_ROW_CELLS,
FIELD_HEATMAP_CELL_ID,
FIELD_HEATMAP_COL_ID,
FIELD_HEATMAP_CELL_VALUE,
FIELD_HEATMAP_NAME,
FIELD_HEATMAP_DESCR,
FIELD_HEATMAP_TYPE,
FIELD_HEATMAP_CATEGORY)
from src.common.storage.field_names import FLD_KBASE_ID
from src.loaders.common import loader_common_names
from src.loaders.compute_tools.tool_common import (
FatalTuple,
FatalTuple,
ToolRunner,
write_fatal_tuples_to_dict,
)
from src.loaders.compute_tools.tool_result_parser import (
create_jsonl_files,
MICROTRAIT_META,
MICROTRAIT_CELLS,
MICROTRAIT_DATA)

# the name of the component used for extracting traits from microtrait's 'extract.traits' result
TRAIT_COUNTS_ATGRANULARITY = 'trait_counts_atgranularity3'
_GENE_NAME_COL = 'hmm_name' # column name from the genes_detected_table file that contains the gene name
_GENE_SCORE_COL = 'gene_score' # column name from the genes_detected_table file that contains the gene score

# The following features will be extracted from the MicroTrait result file as heatmap data
_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT = 'microtrait_trait-displaynameshort' # used as column name of the trait
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG = 'microtrait_trait-displaynamelong' # used as description of the trait
_MICROTRAIT_TRAIT_VALUE = 'microtrait_trait-value' # value of the trait (can be integer or 0/1 as boolean)
_MICROTRAIT_TRAIT_TYPE = 'microtrait_trait-type' # type of trait (count or binary)
_MICROTRAIT_TRAIT_ORDER = 'microtrait_trait-displayorder' # order of the trait defined by the granularity table used as the index of trait
_MICROTRAIT_TRAIT_NAME = 'microtrait_trait-name' # column name for the trait unique identifier defined in the granularity trait count table

_SYS_DEFAULT_TRAIT_VALUE = 0 # default value (0 or False) for a trait if the value is missing/not available

_DETECTED_GENE_SCORE_COL = 'detected_genes_score' # column name for the detected genes score

# The map between the MicroTrait trait names and the corresponding system trait names
# Use the microtrait_trait-name column as the unique identifier for a trait globally,
# the microtrait_trait-displaynameshort column as the column name,
# microtrait_trait-displaynamelong column as the column description, and
# microtrait_trait-value as the cell value
_MICROTRAIT_TO_SYS_TRAIT_MAP = {
_MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT: FIELD_HEATMAP_NAME,
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG: FIELD_HEATMAP_DESCR,
_MICROTRAIT_TRAIT_VALUE: FIELD_HEATMAP_CELL_VALUE,
_MICROTRAIT_TRAIT_TYPE: FIELD_HEATMAP_TYPE,
_MICROTRAIT_TRAIT_ORDER: FIELD_HEATMAP_COL_ID,
_DETECTED_GENE_SCORE_COL: _DETECTED_GENE_SCORE_COL,
}


def _get_r_list_element(r_list, element_name):
# retrieve the element from the R list
Expand Down Expand Up @@ -47,6 +90,54 @@ def _r_table_to_df(r_table):
return df


def _process_trait_counts(
trait_counts_df: pd.DataFrame,
data_id: str):
# process the trait counts file to create the heatmap data and metadata

trait_df = trait_counts_df[_MICROTRAIT_TO_SYS_TRAIT_MAP.keys()]

# Check if the trait index column has non-unique values
if len(trait_df[_MICROTRAIT_TRAIT_NAME].unique()) != len(trait_df):
raise ValueError(f"The {_MICROTRAIT_TRAIT_NAME} column has non-unique values")

# Extract the substring of the 'microtrait_trait-displaynamelong' column before the first colon character
# and assign it to a new 'category' column in the DataFrame
trait_df[FIELD_HEATMAP_CATEGORY] = trait_df[_MICROTRAIT_TRAIT_DISPLAYNAME_LONG].str.split(':').str[0]

trait_df = trait_df.rename(columns=_MICROTRAIT_TO_SYS_TRAIT_MAP)

# ensure the col_id column is string type
trait_df[FIELD_HEATMAP_COL_ID] = trait_df[FIELD_HEATMAP_COL_ID].astype(str)

traits = trait_df.to_dict(orient='records')
cells, cells_meta, traits_meta = list(), list(), list()
for trait in traits:
cell_uuid = str(uuid.uuid4())
# process cell data
cells.append({FIELD_HEATMAP_CELL_ID: cell_uuid,
FIELD_HEATMAP_COL_ID: trait[FIELD_HEATMAP_COL_ID],
FIELD_HEATMAP_CELL_VALUE: trait[FIELD_HEATMAP_CELL_VALUE]})

# process cell meta
cells_meta.append({FIELD_HEATMAP_CELL_ID: cell_uuid,
FIELD_HEATMAP_VALUES: trait[_DETECTED_GENE_SCORE_COL]})

# process trait meta
trait_meta_keys = [FIELD_HEATMAP_COL_ID,
FIELD_HEATMAP_NAME,
FIELD_HEATMAP_DESCR,
FIELD_HEATMAP_CATEGORY,
FIELD_HEATMAP_TYPE]

traits_meta.append({key: trait[key] for key in trait_meta_keys})

heatmap_row = [{FLD_KBASE_ID: data_id,
FIELD_HEATMAP_ROW_CELLS: cells}]

return heatmap_row, cells_meta, traits_meta


def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome_dir: Path, debug: bool):
# run microtrait.extract_traits on the genome file
# https://github.com/ukaraoz/microtrait
Expand Down Expand Up @@ -78,7 +169,7 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
error_message = "Microtrait output no data"
fatal_tuples = [FatalTuple(data_id, error_message, str(fna_file), None)]
write_fatal_tuples_to_dict(fatal_tuples, genome_dir)
return
return
# example trait_counts_df from trait_counts_atgranularity3
# microtrait_trait-name,microtrait_trait-value,microtrait_trait-displaynameshort,microtrait_trait-displaynamelong,microtrait_trait-strategy,microtrait_trait-type,microtrait_trait-granularity,microtrait_trait-version,microtrait_trait-displayorder,microtrait_trait-value1
# Resource Acquisition:Substrate uptake:aromatic acid transport,1,Aromatic acid transport,Resource Acquisition:Substrate uptake:aromatic acid transport,Resource Acquisition,count,3,production,1,1
Expand All @@ -89,7 +180,7 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
if trait_unwrapped_rules_file:
trait_unwrapped_rules_df = pd.read_csv(trait_unwrapped_rules_file, sep='\t')
trait_counts_df = trait_counts_df.merge(trait_unwrapped_rules_df,
left_on=loader_common_names.MICROTRAIT_TRAIT_NAME,
left_on=_MICROTRAIT_TRAIT_NAME,
right_on=loader_common_names.SYS_TRAIT_ID,
how='left')
trait_counts_df.drop(columns=[loader_common_names.SYS_TRAIT_ID], inplace=True)
Expand All @@ -98,15 +189,20 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
genes_detected_df = _r_table_to_df(genes_detected_table)
detected_genes_score = dict(zip(genes_detected_df[_GENE_NAME_COL], genes_detected_df[_GENE_SCORE_COL]))

trait_counts_df[loader_common_names.DETECTED_GENE_SCORE_COL] = trait_counts_df[
trait_counts_df[_DETECTED_GENE_SCORE_COL] = trait_counts_df[
loader_common_names.UNWRAPPED_GENE_COL].apply(
lambda x: json.dumps({gene: detected_genes_score.get(gene) for gene in str(x).split(';') if
gene in detected_genes_score}))
lambda x: {gene: detected_genes_score.get(gene) for gene in str(x).split(';') if
gene in detected_genes_score})
else:
raise ValueError('Please set environment variable MT_TRAIT_UNWRAPPED_FILE')

trait_counts_df.to_csv(os.path.join(genome_dir, loader_common_names.TRAIT_COUNTS_FILE), index=False)

heatmap_row, cells_meta, traits_meta = _process_trait_counts(trait_counts_df, data_id)
create_jsonl_files(genome_dir / MICROTRAIT_META, traits_meta)
create_jsonl_files(genome_dir / MICROTRAIT_CELLS, cells_meta)
create_jsonl_files(genome_dir / MICROTRAIT_DATA, heatmap_row)


def main():
runner = ToolRunner("microtrait")
Expand Down
6 changes: 5 additions & 1 deletion src/loaders/compute_tools/microtrait/versions.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
versions:
- version: 0.1.0
date: 2023-07-19
date: 2023-07-19
- version: 0.1.1
date: 2023-08-16
notes: |
- install jsonlines to support saving of Microtrait output
3 changes: 3 additions & 0 deletions src/loaders/compute_tools/tool_result_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from src.loaders.compute_tools.tool_common import GenomeTuple

TOOL_GENOME_ATTRI_FILE = "genome_attribs.jsonl"
MICROTRAIT_CELLS = "microtrait_cells.jsonl"
MICROTRAIT_META = "microtrait_meta.jsonl"
MICROTRAIT_DATA = "microtrait_data.jsonl"


def process_genome_attri_result(
Expand Down
4 changes: 2 additions & 2 deletions src/loaders/genome_collection/parse_tool_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@
# microtrait_trait-displaynamelong column as the column description, and
# microtrait_trait-value as the cell value
_MICROTRAIT_TO_SYS_TRAIT_MAP = {
loader_common_names.MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
# loader_common_names.MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented them out so that tests can pass. Those will be removed in the following PRs.

_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT: _SYS_TRAIT_NAME,
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG: _SYS_TRAIT_DESCRIPTION,
_MICROTRAIT_TRAIT_VALUE: _SYS_TRAIT_VALUE,
_MICROTRAIT_TRAIT_TYPE: _SYS_TRAIT_TYPE,
_MICROTRAIT_TRAIT_ORDER: _SYS_TRAIT_INDEX,
loader_common_names.DETECTED_GENE_SCORE_COL: loader_common_names.DETECTED_GENE_SCORE_COL,
# loader_common_names.DETECTED_GENE_SCORE_COL: loader_common_names.DETECTED_GENE_SCORE_COL,
}

# The suffix for the sequence metadata file name for Assembly Homology service
Expand Down