Skip to content

Commit

Permalink
Merge pull request #20 from redst4r/sparse_fix
Browse files Browse the repository at this point in the history
BUGFIX for sparse expression matrices
  • Loading branch information
cflerin authored Feb 8, 2021
2 parents 3ff7b6f + 77ddb9f commit 3cf1923
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions arboreto/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import logging

import scipy
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor, ExtraTreesRegressor
from dask import delayed
from dask.dataframe import from_delayed
Expand Down Expand Up @@ -119,7 +120,12 @@ def fit_model(regressor_type,
"""
regressor_type = regressor_type.upper()

assert tf_matrix.shape[0] == len(target_gene_expression)

if isinstance(target_gene_expression, scipy.sparse.spmatrix):
target_gene_expression = target_gene_expression.A.flatten()

assert tf_matrix.shape[0] == target_gene_expression.shape[0]


def do_sklearn_regression():
regressor = SKLEARN_REGRESSOR_FACTORY[regressor_type](random_state=seed, **regressor_kwargs)
Expand Down Expand Up @@ -226,7 +232,12 @@ def clean(tf_matrix,
if target_gene_name not in tf_matrix_gene_names:
clean_tf_matrix = tf_matrix
else:
clean_tf_matrix = np.delete(tf_matrix, tf_matrix_gene_names.index(target_gene_name), 1)
ix = tf_matrix_gene_names.index(target_gene_name)
if isinstance(tf_matrix, scipy.sparse.spmatrix):
clean_tf_matrix = scipy.sparse.hstack([tf_matrix[:, :ix],
tf_matrix[:, ix+1:]])
else:
clean_tf_matrix = np.delete(tf_matrix, ix, 1)

clean_tf_names = [tf for tf in tf_matrix_gene_names if tf != target_gene_name]

Expand Down

0 comments on commit 3cf1923

Please sign in to comment.