Skip to content

Commit

Permalink
Merge pull request #242 from andersen-lab/solver_options
Browse files Browse the repository at this point in the history
Solver options
  • Loading branch information
joshuailevy authored Jun 14, 2024
2 parents 51b2731 + d3cb40d commit d54a8e8
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
auto-update-conda: true
channels: bioconda,conda-forge
channel-priority: true
python-version: 3.7
python-version: '3.10'
activate-environment: test

- name: Install
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/update_barcodes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
auto-update-conda: true
channels: bioconda,conda-forge
channel-priority: true
python-version: 3.7
python-version: '3.10'
activate-environment: test

- name: Install
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ Freyja is intended as a post-processing step after primer trimming and variant c

To ensure reproducibility of results, we provide old (timestamped) barcodes and metadata in the separate [Freyja-data](https://github.com/andersen-lab/Freyja-data) repository. Barcode version can be checked using the `freyja demix --version` command.

NOTE: Freyja barcodes are now stored in compressed feather format, as the initial csv barcode file became too large. Specific lineage definitions are now provided in [here](https://github.com/andersen-lab/Freyja/blob/main/freyja/data/lineage_mutations.json).

## Installation via conda
Freyja is entirely written in Python 3, but requires preprocessing by tools like iVar and [samtools](https://github.com/samtools/samtools) mpileup to generate the required input data. We recommend using python3.7, but Freyja has been tested on python versions up to 3.10. First, create an environment for freyja
Freyja is entirely written in Python 3, but requires preprocessing by tools like iVar and [samtools](https://github.com/samtools/samtools) mpileup to generate the required input data. We recommend using python3.10 to take advantage of the Clarabel solver, but Freyja has been tested on python versions from 3.7 to 3.10. First, create an environment for freyja
```
conda create -n freyja-env
```
Expand Down
18 changes: 14 additions & 4 deletions freyja/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ def print_barcode_version(ctx, param, value):
@click.option('--relaxedthresh', default=0.9,
help='associated threshold for robust mrca function',
show_default=True)
@click.option('--solver', default='CLARABEL',
help='solver used for estimating lineage prevalence',
show_default=True)
def demix(variants, depths, output, eps, barcodes, meta,
covcut, confirmedonly, depthcutoff, lineageyml,
adapt, a_eps, region_of_interest,
relaxedmrca, relaxedthresh):
relaxedmrca, relaxedthresh, solver):
"""
Generate relative lineage abundances from VARIANTS and DEPTHS
"""
Expand Down Expand Up @@ -109,12 +112,14 @@ def demix(variants, depths, output, eps, barcodes, meta,
covcut)
print('demixing')
df_barcodes, mix, depths_ = reindex_dfs(df_barcodes, mix, depths_)

try:
sample_strains, abundances, error = solve_demixing_problem(df_barcodes,
mix,
depths_,
eps, adapt,
a_eps)
a_eps,
solver)
except Exception as e:
print(e)
print('Error: Demixing step failed. Returning empty data output')
Expand Down Expand Up @@ -470,9 +475,13 @@ def variants(bamfile, ref, variants, depths, refname, minq, annot, varthresh):
@click.option('--bootseed', default=0,
help='set seed for bootstrap generation',
show_default=True)
@click.option('--solver', default='CLARABEL',
help='solver used for estimating lineage prevalence',
show_default=True)
def boot(variants, depths, output_base, eps, barcodes, meta,
nb, nt, boxplot, confirmedonly, lineageyml, depthcutoff,
rawboots, relaxedmrca, relaxedthresh, bootseed):
rawboots, relaxedmrca, relaxedthresh, bootseed,
solver):
"""
Perform bootstrapping method for freyja using VARIANTS and DEPTHS
"""
Expand Down Expand Up @@ -511,7 +520,8 @@ def boot(variants, depths, output_base, eps, barcodes, meta,
df_barcodes, mix, depths_ = reindex_dfs(df_barcodes, mix, depths_)
lin_df, constell_df = perform_bootstrap(df_barcodes, mix, depths_,
nb, eps, nt, mapDict, muts,
boxplot, output_base, bootseed)
boxplot, output_base, bootseed,
solver)
if rawboots:
lin_df.to_csv(output_base + '_lineages_boot.csv')
constell_df.to_csv(output_base + '_summarized_boot.csv')
Expand Down
25 changes: 18 additions & 7 deletions freyja/sample_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def map_to_constellation(sample_strains, vals, mapDict):
return localDict


def solve_demixing_problem(df_barcodes, mix, depths, eps, adapt, a_eps):
def solve_demixing_problem(df_barcodes, mix, depths, eps, adapt,
a_eps, solver):
# single file problem setup, solving

dep = np.log2(depths+1)
Expand All @@ -166,10 +167,18 @@ def solve_demixing_problem(df_barcodes, mix, depths, eps, adapt, a_eps):
cost = cp.norm(A @ x - b, 1)
constraints = [sum(x) == 1, x >= 0]
prob = cp.Problem(cp.Minimize(cost), constraints)

if solver == 'ECOS':
solver_ = cp.ECOS
elif solver == 'OSQP':
solver_ = cp.OSQP
else:
solver_ = cp.CLARABEL

try:
prob.solve(verbose=False, solver=cp.ECOS)
prob.solve(verbose=False, solver=solver_)
except cp.error.SolverError:
raise ValueError('Solver error encountered, most'
raise ValueError('Solver error encountered, most '
'likely due to insufficient sequencing depth.'
'Try running with --depthcutoff.')
sol = x.value
Expand Down Expand Up @@ -207,7 +216,7 @@ def solve_demixing_problem(df_barcodes, mix, depths, eps, adapt, a_eps):

def bootstrap_parallel(jj, samplesDefining, fracDepths_adj, mix_grp,
mix, df_barcodes, eps0, muts, mapDict,
adapt, a_eps, bootseed):
adapt, a_eps, bootseed, solver):
# helper function for fast bootstrap and solve
# get sequencing depth at the position of all defining mutations
mix_boot = mix.copy()
Expand Down Expand Up @@ -254,15 +263,16 @@ def bootstrap_parallel(jj, samplesDefining, fracDepths_adj, mix_grp,
sample_strains, abundances, error = solve_demixing_problem(df_barcodes,
mix_boot_,
dps_, eps0,
adapt, a_eps)
adapt, a_eps,
solver)
localDict = map_to_constellation(sample_strains, abundances, mapDict)
return sample_strains, abundances, localDict


def perform_bootstrap(df_barcodes, mix, depths_,
numBootstraps, eps0, n_jobs,
mapDict, muts, boxplot, basename, bootseed,
adapt=0., a_eps=1E-8):
solver, adapt=0., a_eps=1E-8):
depths_.index = depths_.index.to_series().apply(lambda x:
int(x[1:len(x)-1]))
depths_ = depths_[~depths_.index.duplicated(keep='first')]
Expand Down Expand Up @@ -298,7 +308,8 @@ def perform_bootstrap(df_barcodes, mix, depths_,
mapDict,
adapt,
a_eps,
bootseed)
bootseed,
solver)
for jj0 in tqdm(range(numBootstraps)))
for i in range(len(out)):
sample_lins, abundances, localDict = out[i]
Expand Down
8 changes: 6 additions & 2 deletions freyja/tests/test_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def test_demixing(self):
# generate random sequencing depth at each position
depths = negative_binomial(50, 0.25, size=len(mix))
eps = 0.001
solver = 'CLARABEL'
sample_strains, abundances, error = solve_demixing_problem(df_barcodes,
mix, depths,
eps, 0.,
1.E-8)
1.E-8,
solver)
self.assertAlmostEqual(
abundances[sample_strains.tolist().index(strain1)], mixFracs[0])
self.assertAlmostEqual(
Expand All @@ -81,10 +83,12 @@ def test_boot(self):
numBootstraps = 3
n_jobs = 1
bootseed = 10
solver = 'CLARABEL'
lin_df, constell_df = perform_bootstrap(df_barcodes, mix, depths,
numBootstraps, eps,
n_jobs, mapDict, muts,
'', 'test', bootseed)
'', 'test', bootseed,
solver)
lin_out = lin_df.quantile([0.5])
constell_out = constell_df.quantile([0.5])
self.assertAlmostEqual(lin_out.loc[0.5, 'B.1.1.7'], 0.4, delta=0.1)
Expand Down
13 changes: 8 additions & 5 deletions freyja/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime
from datetime import datetime, timedelta
import yaml
from scipy.optimize import curve_fit

Expand Down Expand Up @@ -426,7 +426,7 @@ def makePlot_simple(agg_df, lineages, outputFn, config, lineage_info,
labelList.append(label)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='center left',
bbox_to_anchor=(1, 0.5), prop={'size': 4})
bbox_to_anchor=(1, 0.5))
ax.set_ylabel('Variant Prevalence')
ax.set_xticks(range(0, agg_df.shape[0]))
ax.set_xticklabels(agg_df.index,
Expand Down Expand Up @@ -499,9 +499,10 @@ def makePlot_time(agg_df, lineages, times_df, interval, outputFn,
labels=df_abundances.columns, colors=cmap_dict.values())
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='center left',
bbox_to_anchor=(1, 0.5), prop={'size': 4})
bbox_to_anchor=(1, 0.5))
ax.set_ylabel('Variant Prevalence')
ax.set_ylim([0, 1])
ax.set_xlim([df_abundances.index.min(), df_abundances.index.max()])
plt.setp(ax.get_xticklabels(), rotation=90)
fig.tight_layout()
plt.savefig(outputFn)
Expand All @@ -514,13 +515,15 @@ def makePlot_time(agg_df, lineages, times_df, interval, outputFn,
label=label, color=cmap_dict[label])
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='center left',
bbox_to_anchor=(1, 0.5), prop={'size': 4})
bbox_to_anchor=(1, 0.5))
ax.set_ylabel('Variant Prevalence')
locator = mdates.MonthLocator(bymonthday=1)
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
ax.set_ylim([0, 1])
ax.set_aspect(150)
ax.set_xlim([df_abundances.index.min()-timedelta(days=15),
df_abundances.index.max()+timedelta(days=15)])
fig.tight_layout()
plt.savefig(outputFn)
plt.close()
Expand All @@ -537,7 +540,7 @@ def makePlot_time(agg_df, lineages, times_df, interval, outputFn,
label=label, color=cmap_dict[label])
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='center left',
bbox_to_anchor=(1, 0.5), prop={'size': 4})
bbox_to_anchor=(1, 0.5))
labelsAx = [item.split('-')[1] for item in df_abundances.index]
ax.set_xticks(range(0, len(labelsAx)))
ax.set_xticklabels(labelsAx)
Expand Down

0 comments on commit d54a8e8

Please sign in to comment.