Skip to content

Commit

Permalink
Merge pull request #419 from kbase/dev_parse_in_place
Browse files Browse the repository at this point in the history
RE2022-218: save pre-processed trait info
  • Loading branch information
Tianhao-Gu authored Aug 23, 2023
2 parents a8b506d + ed3e1b4 commit 6cb7b1e
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 22 deletions.
25 changes: 17 additions & 8 deletions src/common/product_models/heatmap_common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,32 @@
"""

from enum import Enum

from pydantic import BaseModel, Field

from src.common.product_models.common_models import SubsetProcessStates


# these fields need to match the fields in the models below.
FIELD_HEATMAP_DATA = "data"
FIELD_HEATMAP_MIN_VALUE = "min_value"
FIELD_HEATMAP_MAX_VALUE = "max_value"
FIELD_HEATMAP_COUNT = "count"
FIELD_HEATMAP_CELL_VALUE = "val"
FIELD_HEATMAP_VALUES = 'values'
FIELD_HEATMAP_ROW_CELLS = "cells"
FIELD_HEATMAP_CELL_ID = 'cell_id'
FIELD_HEATMAP_COL_ID = 'col_id'
FIELD_HEATMAP_CELL_VALUE = "val"
FIELD_HEATMAP_NAME = "name"
FIELD_HEATMAP_DESCR = "description"
FIELD_HEATMAP_TYPE = "type"
FIELD_HEATMAP_CATEGORY = "category"

_FLD_CELL_ID = Field(
example="4",
description="The unique ID of the cell in the heatmap."
)


class ColumnType(str, Enum):
"""
The type of a column's values.
Expand Down Expand Up @@ -50,7 +57,7 @@ class ColumnInformation(BaseModel):
)
description: str = Field(
example="Spizizen medium (SM) is a popular minimal medium for the cultivation of "
+ "B. subtilis.",
+ "B. subtilis.",
description="The description of the column."
)
type: ColumnType = Field(
Expand All @@ -66,7 +73,7 @@ class ColumnCategory(BaseModel):
category: str | None = Field(
example="Minimal media",
description="The name of the category that groups columns together. Null if "
+ "columns are not categorized."
+ "columns are not categorized."
)
columns: list[ColumnInformation] = Field(
description="The columns in the category, provided in render order."
Expand Down Expand Up @@ -103,8 +110,9 @@ class Cell(BaseModel):
example=4.2,
description="The value of the heatmap at this cell."
)

class Config:
smart_union=True
smart_union = True


class HeatMapRow(BaseModel):
Expand Down Expand Up @@ -144,12 +152,12 @@ class HeatMap(SubsetProcessStates):
min_value: float | None = Field(
example=32.4,
description="The minimum cell value in the rows in this heatmap "
+ "or null if there are no rows."
+ "or null if there are no rows."
)
max_value: float | None = Field(
example=71.8,
description="The maximum cell value in the rows in this heatmap "
+ "or null if there are no rows."
+ "or null if there are no rows."
)
count: int | None = Field(
example=42,
Expand All @@ -169,8 +177,9 @@ class CellDetailEntry(BaseModel):
example=56.1,
description="The value of the cell entry."
)

class Config:
smart_union=True
smart_union = True


class CellDetail(BaseModel):
Expand Down
5 changes: 1 addition & 4 deletions src/loaders/common/loader_common_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@

# the name of the CSV file where we store the trait counts
TRAIT_COUNTS_FILE = 'trait_counts.csv'
SYS_TRAIT_ID = 'trait_id' # unique identifier for a trait
DETECTED_GENE_SCORE_COL = 'detected_genes_score' # column name for the detected genes score

# column name for the trait unique identifier defined in the granularity trait count table
MICROTRAIT_TRAIT_NAME = 'microtrait_trait-name'
SYS_TRAIT_ID = 'trait_id' # unique identifier for a trait
UNWRAPPED_GENE_COL = 'unwrapped_genes' # column name that contains parsed gene name from unwrapped rule

# kbase authentication token
Expand Down
2 changes: 2 additions & 0 deletions src/loaders/compute_tools/microtrait/env_microtrait.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ dependencies:
- rpy2=3.5.10
- pandas=2.0.0
- unzip=6.0
- jsonlines=3.1.0
- pydantic=1.10.12
110 changes: 103 additions & 7 deletions src/loaders/compute_tools/microtrait/microtrait.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,68 @@
"""
Runs microtrait on a set of assemblies.
"""
import json
import os
import uuid
from pathlib import Path

import pandas as pd
from rpy2 import robjects

from src.common.product_models.heatmap_common_models import (
FIELD_HEATMAP_VALUES,
FIELD_HEATMAP_ROW_CELLS,
FIELD_HEATMAP_CELL_ID,
FIELD_HEATMAP_COL_ID,
FIELD_HEATMAP_CELL_VALUE,
FIELD_HEATMAP_NAME,
FIELD_HEATMAP_DESCR,
FIELD_HEATMAP_TYPE,
FIELD_HEATMAP_CATEGORY)
from src.common.storage.field_names import FLD_KBASE_ID
from src.loaders.common import loader_common_names
from src.loaders.compute_tools.tool_common import (
FatalTuple,
FatalTuple,
ToolRunner,
write_fatal_tuples_to_dict,
)
from src.loaders.compute_tools.tool_result_parser import (
create_jsonl_files,
MICROTRAIT_META,
MICROTRAIT_CELLS,
MICROTRAIT_DATA)

# the name of the component used for extracting traits from microtrait's 'extract.traits' result
TRAIT_COUNTS_ATGRANULARITY = 'trait_counts_atgranularity3'
_GENE_NAME_COL = 'hmm_name' # column name from the genes_detected_table file that contains the gene name
_GENE_SCORE_COL = 'gene_score' # column name from the genes_detected_table file that contains the gene score

# The following features will be extracted from the MicroTrait result file as heatmap data
_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT = 'microtrait_trait-displaynameshort' # used as column name of the trait
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG = 'microtrait_trait-displaynamelong' # used as description of the trait
_MICROTRAIT_TRAIT_VALUE = 'microtrait_trait-value' # value of the trait (can be integer or 0/1 as boolean)
_MICROTRAIT_TRAIT_TYPE = 'microtrait_trait-type' # type of trait (count or binary)
_MICROTRAIT_TRAIT_ORDER = 'microtrait_trait-displayorder' # order of the trait defined by the granularity table used as the index of trait
_MICROTRAIT_TRAIT_NAME = 'microtrait_trait-name' # column name for the trait unique identifier defined in the granularity trait count table

_SYS_DEFAULT_TRAIT_VALUE = 0 # default value (0 or False) for a trait if the value is missing/not available

_DETECTED_GENE_SCORE_COL = 'detected_genes_score' # column name for the detected genes score

# The map between the MicroTrait trait names and the corresponding system trait names
# Use the microtrait_trait-name column as the unique identifier for a trait globally,
# the microtrait_trait-displaynameshort column as the column name,
# microtrait_trait-displaynamelong column as the column description, and
# microtrait_trait-value as the cell value
_MICROTRAIT_TO_SYS_TRAIT_MAP = {
_MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT: FIELD_HEATMAP_NAME,
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG: FIELD_HEATMAP_DESCR,
_MICROTRAIT_TRAIT_VALUE: FIELD_HEATMAP_CELL_VALUE,
_MICROTRAIT_TRAIT_TYPE: FIELD_HEATMAP_TYPE,
_MICROTRAIT_TRAIT_ORDER: FIELD_HEATMAP_COL_ID,
_DETECTED_GENE_SCORE_COL: _DETECTED_GENE_SCORE_COL,
}


def _get_r_list_element(r_list, element_name):
# retrieve the element from the R list
Expand Down Expand Up @@ -47,6 +90,54 @@ def _r_table_to_df(r_table):
return df


def _process_trait_counts(
trait_counts_df: pd.DataFrame,
data_id: str):
# process the trait counts file to create the heatmap data and metadata

trait_df = trait_counts_df[_MICROTRAIT_TO_SYS_TRAIT_MAP.keys()]

# Check if the trait index column has non-unique values
if len(trait_df[_MICROTRAIT_TRAIT_NAME].unique()) != len(trait_df):
raise ValueError(f"The {_MICROTRAIT_TRAIT_NAME} column has non-unique values")

# Extract the substring of the 'microtrait_trait-displaynamelong' column before the first colon character
# and assign it to a new 'category' column in the DataFrame
trait_df[FIELD_HEATMAP_CATEGORY] = trait_df[_MICROTRAIT_TRAIT_DISPLAYNAME_LONG].str.split(':').str[0]

trait_df = trait_df.rename(columns=_MICROTRAIT_TO_SYS_TRAIT_MAP)

# ensure the col_id column is string type
trait_df[FIELD_HEATMAP_COL_ID] = trait_df[FIELD_HEATMAP_COL_ID].astype(str)

traits = trait_df.to_dict(orient='records')
cells, cells_meta, traits_meta = list(), list(), list()
for trait in traits:
cell_uuid = str(uuid.uuid4())
# process cell data
cells.append({FIELD_HEATMAP_CELL_ID: cell_uuid,
FIELD_HEATMAP_COL_ID: trait[FIELD_HEATMAP_COL_ID],
FIELD_HEATMAP_CELL_VALUE: trait[FIELD_HEATMAP_CELL_VALUE]})

# process cell meta
cells_meta.append({FIELD_HEATMAP_CELL_ID: cell_uuid,
FIELD_HEATMAP_VALUES: trait[_DETECTED_GENE_SCORE_COL]})

# process trait meta
trait_meta_keys = [FIELD_HEATMAP_COL_ID,
FIELD_HEATMAP_NAME,
FIELD_HEATMAP_DESCR,
FIELD_HEATMAP_CATEGORY,
FIELD_HEATMAP_TYPE]

traits_meta.append({key: trait[key] for key in trait_meta_keys})

heatmap_row = [{FLD_KBASE_ID: data_id,
FIELD_HEATMAP_ROW_CELLS: cells}]

return heatmap_row, cells_meta, traits_meta


def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome_dir: Path, debug: bool):
# run microtrait.extract_traits on the genome file
# https://github.com/ukaraoz/microtrait
Expand Down Expand Up @@ -78,7 +169,7 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
error_message = "Microtrait output no data"
fatal_tuples = [FatalTuple(data_id, error_message, str(fna_file), None)]
write_fatal_tuples_to_dict(fatal_tuples, genome_dir)
return
return
# example trait_counts_df from trait_counts_atgranularity3
# microtrait_trait-name,microtrait_trait-value,microtrait_trait-displaynameshort,microtrait_trait-displaynamelong,microtrait_trait-strategy,microtrait_trait-type,microtrait_trait-granularity,microtrait_trait-version,microtrait_trait-displayorder,microtrait_trait-value1
# Resource Acquisition:Substrate uptake:aromatic acid transport,1,Aromatic acid transport,Resource Acquisition:Substrate uptake:aromatic acid transport,Resource Acquisition,count,3,production,1,1
Expand All @@ -89,7 +180,7 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
if trait_unwrapped_rules_file:
trait_unwrapped_rules_df = pd.read_csv(trait_unwrapped_rules_file, sep='\t')
trait_counts_df = trait_counts_df.merge(trait_unwrapped_rules_df,
left_on=loader_common_names.MICROTRAIT_TRAIT_NAME,
left_on=_MICROTRAIT_TRAIT_NAME,
right_on=loader_common_names.SYS_TRAIT_ID,
how='left')
trait_counts_df.drop(columns=[loader_common_names.SYS_TRAIT_ID], inplace=True)
Expand All @@ -98,15 +189,20 @@ def _run_microtrait(tool_safe_data_id: str, data_id: str, fna_file: Path, genome
genes_detected_df = _r_table_to_df(genes_detected_table)
detected_genes_score = dict(zip(genes_detected_df[_GENE_NAME_COL], genes_detected_df[_GENE_SCORE_COL]))

trait_counts_df[loader_common_names.DETECTED_GENE_SCORE_COL] = trait_counts_df[
trait_counts_df[_DETECTED_GENE_SCORE_COL] = trait_counts_df[
loader_common_names.UNWRAPPED_GENE_COL].apply(
lambda x: json.dumps({gene: detected_genes_score.get(gene) for gene in str(x).split(';') if
gene in detected_genes_score}))
lambda x: {gene: detected_genes_score.get(gene) for gene in str(x).split(';') if
gene in detected_genes_score})
else:
raise ValueError('Please set environment variable MT_TRAIT_UNWRAPPED_FILE')

trait_counts_df.to_csv(os.path.join(genome_dir, loader_common_names.TRAIT_COUNTS_FILE), index=False)

heatmap_row, cells_meta, traits_meta = _process_trait_counts(trait_counts_df, data_id)
create_jsonl_files(genome_dir / MICROTRAIT_META, traits_meta)
create_jsonl_files(genome_dir / MICROTRAIT_CELLS, cells_meta)
create_jsonl_files(genome_dir / MICROTRAIT_DATA, heatmap_row)


def main():
runner = ToolRunner("microtrait")
Expand Down
6 changes: 5 additions & 1 deletion src/loaders/compute_tools/microtrait/versions.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
versions:
- version: 0.1.0
date: 2023-07-19
date: 2023-07-19
- version: 0.1.1
date: 2023-08-16
notes: |
- install jsonlines to support saving of Microtrait output
3 changes: 3 additions & 0 deletions src/loaders/compute_tools/tool_result_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from src.loaders.compute_tools.tool_common import GenomeTuple

TOOL_GENOME_ATTRI_FILE = "genome_attribs.jsonl"
MICROTRAIT_CELLS = "microtrait_cells.jsonl"
MICROTRAIT_META = "microtrait_meta.jsonl"
MICROTRAIT_DATA = "microtrait_data.jsonl"


def process_genome_attri_result(
Expand Down
4 changes: 2 additions & 2 deletions src/loaders/genome_collection/parse_tool_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@
# microtrait_trait-displaynamelong column as the column description, and
# microtrait_trait-value as the cell value
_MICROTRAIT_TO_SYS_TRAIT_MAP = {
loader_common_names.MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
# loader_common_names.MICROTRAIT_TRAIT_NAME: loader_common_names.SYS_TRAIT_ID,
_MICROTRAIT_TRAIT_DISPLAYNAME_SHORT: _SYS_TRAIT_NAME,
_MICROTRAIT_TRAIT_DISPLAYNAME_LONG: _SYS_TRAIT_DESCRIPTION,
_MICROTRAIT_TRAIT_VALUE: _SYS_TRAIT_VALUE,
_MICROTRAIT_TRAIT_TYPE: _SYS_TRAIT_TYPE,
_MICROTRAIT_TRAIT_ORDER: _SYS_TRAIT_INDEX,
loader_common_names.DETECTED_GENE_SCORE_COL: loader_common_names.DETECTED_GENE_SCORE_COL,
# loader_common_names.DETECTED_GENE_SCORE_COL: loader_common_names.DETECTED_GENE_SCORE_COL,
}

# The suffix for the sequence metadata file name for Assembly Homology service
Expand Down

0 comments on commit 6cb7b1e

Please sign in to comment.