allow user-provided control-conditions
jykr committed Aug 23, 2024
1 parent c7b7f54 commit e5b3a90
Showing 11 changed files with 199 additions and 90 deletions.
10 changes: 10 additions & 0 deletions bean/annotate/utils.py
@@ -450,6 +450,16 @@ def check_args(args):
raise ValueError(
"Invalid arguments: You should specify exactly one of --translate-fasta, --translate-fastas-csv, --translate-gene, translate-genes-list to translate alleles."
)
if (
args.translate_fasta is not None
or args.translate_fastas_csv is not None
or args.translate_gene is not None
or args.translate_genes_list is not None
) and not args.translate:
warn(
"fastq or gene files for translation provided without `--translate` flag. Setting `--translate` flag to True."
)
args.translate = True
if args.translate_genes_list is not None:
args.translate_genes_list = (
pd.read_csv(args.translate_genes_list, header=None).values[:, 0].tolist()
2 changes: 2 additions & 0 deletions bean/cli/count.py
@@ -72,6 +72,8 @@ def main(args):
]
if match_target_pos:
counter.screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
counter.screen.get_edit_mat_from_uns(target_base_edits)
counter.screen.write(f"{counter.output_dir}.h5ad")
counter.screen.to_Excel(f"{counter.output_dir}.xlsx")
info(f"Output written at:\n {counter.output_dir}.h5ad,\n {counter.output_dir}.xlsx")
7 changes: 5 additions & 2 deletions bean/cli/count_samples.py
@@ -92,12 +92,15 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace):
screen = counter.screen
if screen.X.max() == 0:
warn(f"Nothing counted for {sample_id}. Check your input.")
if counter.count_reporter_edits and match_target_pos:
if counter.count_reporter_edits:
screen.uns["allele_counts"] = screen.uns["allele_counts"].loc[
screen.uns["allele_counts"].allele.map(str) != "", :
]
screen.get_edit_from_allele("allele_counts", "allele")
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
if match_target_pos:
screen.get_edit_mat_from_uns(target_base_edits, match_target_pos)
else:
screen.get_edit_mat_from_uns(target_base_edits)
info(
f"Done for {sample_id}. \n\
Output written at {counter.output_dir}.h5ad"
2 changes: 1 addition & 1 deletion bean/cli/run.py
@@ -85,8 +85,8 @@ def main(args, return_data=False):
file_logger = logging.FileHandler(f"{prefix}/bean_run.log")
file_logger.setLevel(logging.INFO)
logging.getLogger().addHandler(file_logger)
info(f"Running: {' '.join(sys.argv[:])}")
if args.cuda:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
5 changes: 2 additions & 3 deletions bean/model/parser.py
@@ -121,9 +121,8 @@ def parse_args(parser=None):
"--control-condition",
default="bulk",
type=str,
help="Value in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used. Select this as the condition with the least selection- For the sorting screen, use presort (bulk). For the survival screens, use the closest one with T=0.",
help="Comma-separated list of condition values in `bdata.samples[condition_col]` that indicates control experimental condition whose editing patterns will be used.",
)

input_parser.add_argument(
"--plasmid-condition",
default="bulk",
@@ -167,7 +166,7 @@
"--sample-mask-col",
type=str,
default="mask",
help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples.",
help="Name of the column indicating the sample mask in [Reporter]Screen.samples (or AnnData.var). Sample is ignored if the value in this column is 0. This can be used to mask out low-quality samples. If you don't want to mask samples out, provide `--sample-mask-col=''`.",
)

input_parser.add_argument(
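The new `--control-condition` help above describes a comma-separated list, and the `--sample-mask-col` help now documents passing an empty string to disable masking. As a rough illustration of how those options behave once parsed, here is a minimal argparse sketch; the parser setup, the example values `D7,D14`, and the defaults are assumptions for illustration, not bean's actual code.

```python
# Minimal sketch of parsing the updated options (assumed names/values, not bean's parser).
import argparse

parser = argparse.ArgumentParser(description="illustration of --control-condition / --sample-mask-col")
parser.add_argument("--control-condition", default="bulk", type=str,
                    help="Comma-separated control condition values in bdata.samples[condition_col].")
parser.add_argument("--sample-mask-col", default="mask", type=str,
                    help="Sample-mask column; pass '' to disable masking.")

args = parser.parse_args(["--control-condition", "D7,D14", "--sample-mask-col", ""])
control_conditions = args.control_condition.split(",")  # -> ["D7", "D14"]
sample_mask_col = args.sample_mask_col or None           # '' -> None, mirroring check_args
print(control_conditions, sample_mask_col)
```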
12 changes: 10 additions & 2 deletions bean/model/run.py
@@ -130,6 +130,8 @@ def check_args(args, bdata):
warn(
f"{args.bdata_path} does not have replicate x guide outlier mask. All guides are included in analysis."
)
if args.sample_mask_col == "":
args.sample_mask_col = None
if args.sample_mask_col is not None:
if args.sample_mask_col not in bdata.samples.columns.tolist():
raise ValueError(
@@ -139,10 +141,16 @@
raise ValueError(
f"Condition column `{args.condition_col}` set by `--condition-col` not in ReporterScreen.samples.columns:{bdata.samples.columns}. Check your input."
)
if args.control_condition not in bdata.samples[args.condition_col].tolist():
if args.selection == "survival" and args.condition_col == args.time_col:
raise ValueError(
f"No sample has control label `{args.control_condition}` (set by `--control-condition`) in ReporterScreen.samples[{args.condition_col}]: {bdata.samples[args.condition_col]}. Check your input. For the selection of this argument, see more in `--condition-col` under `bean run --help`."
f"Invalid to have the same `--condition-col` ({args.condition_col}) and `--time-col` ({args.time_col})."
)
control_condits = args.control_condition.split(",")
for control_condit in control_condits:
if control_condit not in bdata.samples[args.condition_col].astype(str).tolist():
raise ValueError(
f"No sample has control label `{args.control_condition}` (set by `--control-condition`) in ReporterScreen.samples[{args.condition_col}]: {bdata.samples[args.condition_col]}. Check your input. For the selection of this argument, see more in `--condition-col` under `bean run --help`."
)
if args.replicate_col not in bdata.samples.columns:
raise ValueError(
f"Condition column set by `--replicate-col` {args.replicate_col} not in ReporterScreen.samples.columns:{bdata.samples.columns}. Check your input."
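To make the added `check_args` logic above concrete, the sketch below runs the same split-and-validate pattern against a toy samples table; the DataFrame contents and the `D0,D7` input are hypothetical.

```python
# Toy illustration of the comma-separated control-condition validation (hypothetical data).
import pandas as pd

samples = pd.DataFrame({"condition": ["D0", "D7", "D14", "D14"]})
condition_col = "condition"
control_condition = "D0,D7"  # user-provided, comma-separated

observed = samples[condition_col].astype(str).tolist()
for control_condit in control_condition.split(","):
    if control_condit not in observed:
        raise ValueError(
            f"No sample has control label `{control_condit}` in samples[{condition_col}]: {observed}."
        )
print("all control conditions found:", control_condition.split(","))
```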
74 changes: 47 additions & 27 deletions bean/model/survival_model.py
@@ -219,6 +219,7 @@ def MixtureNormalModel(
data: VariantSurvivalReporterScreenData,
alpha_prior: float = 1,
use_bcmatch: bool = True,
use_all_timepoints_for_pi: bool = True,
sd_scale: float = 0.01,
scale_by_accessibility: bool = False,
fit_noise: bool = False,
@@ -231,6 +232,7 @@
data: Input data of type VariantSortingReporterScreenData.
alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1.
use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True.
use_all_timepoints_for_pi: If True, use all available timepoints instead of only the `--control-condition` timepoint.
sd_scale: Scale for the prior standard deviation. Defaults to 0.01.
scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility.
fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise.
@@ -295,28 +297,34 @@ def MixtureNormalModel(
pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
),
)
with time_plate:
assert pi.shape == (
data.n_reps,
1,
data.n_guides,
2,
), pi.shape
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2):
with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
time_pi = data.timepoints
# Accounting for sample specific overall edit rate across all guides.
# P(allele | guide, bin=bulk)
assert pi.shape == (
data.n_reps,
1,
data.n_guides,
2,
), pi.shape
# if use_all_timepoints_for_pi:
# time_pi = data.timepoints
# expanded_allele_p = pi * r.expand(
# data.n_reps, len(data.timepoints), -1, -1
# ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
# data.n_reps, len(data.timepoints), -1, -1
# )
# pyro.sample(
# "allele_count",
# dist.Multinomial(probs=expanded_allele_p, validate_args=False),
# obs=data.allele_counts,
# )
# else:
time_pi = data.control_timepoint
# If pi is sampled in later timepoint, account for the selection.

expanded_allele_p = pi * r.expand(
data.n_reps, len(data.timepoints), -1, -1
) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
data.n_reps, len(data.timepoints), -1, -1
)
expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
pyro.sample(
"allele_count",
"control_allele_count",
dist.Multinomial(probs=expanded_allele_p, validate_args=False),
obs=data.allele_counts,
obs=data.allele_counts_control,
)
if scale_by_accessibility:
# Endogenous target site editing rate may be different
@@ -396,6 +404,7 @@ def MultiMixtureNormalModel(
data: TilingSurvivalReporterScreenData,
alpha_prior=1,
use_bcmatch=True,
use_all_timepoints_for_pi: bool = True,
sd_scale=0.01,
norm_pi=False,
scale_by_accessibility=False,
@@ -410,6 +419,7 @@
data: Input data of type VariantSortingReporterScreenData.
alpha_prior: Prior parameter for controlling the concentration of the Dirichlet process. Defaults to 1.
use_bcmatch: Flag indicating whether to use barcode-matched counts. Defaults to True.
use_all_timepoints_for_pi: If True, use all available timepoints instead of only the `--control-condition` timepoint.
sd_scale: Scale for the prior standard deviation. Defaults to 0.01.
scale_by_accessibility: If True, pi fitted from reporter data is scaled by accessibility.
fit_noise: Valid only when scale_by_accessibility is True. If True, parametrically fit noise of endo ~ reporter + noise.
@@ -486,25 +496,35 @@

with replicate_plate:
with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
time_pi = data.control_timepoint
pi = pyro.sample(
"pi",
dist.Dirichlet(
pi_a_scaled.unsqueeze(0).unsqueeze(0).expand(data.n_reps, 1, -1, -1)
),
)
with time_plate:
with pyro.plate("time_plate0", len(data.control_timepoint), dim=-2):
with guide_plate, poutine.mask(mask=data.repguide_mask.unsqueeze(1)):
# if use_all_timepoints_for_pi:
# time_pi = data.timepoints
# # If pi is sampled in later timepoint, account for the selection.
# expanded_allele_p = pi * r.expand(
# data.n_reps, 1, -1, -1
# ) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
# data.n_reps, len(data.timepoints), -1, -1
# )
# pyro.sample(
# "allele_count",
# dist.Multinomial(probs=expanded_allele_p, validate_args=False),
# obs=data.allele_counts,
# )
# else:
time_pi = data.control_timepoint
# If pi is sampled in later timepoint, account for the selection.
expanded_allele_p = pi * r.expand(
data.n_reps, 1, -1, -1
) ** time_pi.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
data.n_reps, len(data.timepoints), -1, -1
)
expanded_allele_p = pi * r.expand(data.n_reps, 1, -1, -1) ** time_pi
pyro.sample(
"allele_count",
"control_allele_count",
dist.Multinomial(probs=expanded_allele_p, validate_args=False),
obs=data.allele_counts,
obs=data.allele_counts_control,
)
if scale_by_accessibility:
# Endogenous target site editing rate may be different
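Reading the replaced sampling statements in this file, the reporter allele fractions are now tied to the control timepoint(s) only, with the growth factor correcting for selection that has already occurred by that timepoint. The formula below is my interpretation of `expanded_allele_p = pi * r ** time_pi`, not an equation stated in the source; π is the per-guide allele fraction, r the per-variant growth factor, and t_c the control timepoint.

```latex
% Interpretation of the control-timepoint allele model (assumed notation):
% editing fractions scaled by selection up to t_c, then multinomial counts.
p_{\mathrm{ctrl}} \propto \pi \cdot r^{\,t_c}, \qquad
\mathrm{allele\_counts\_control} \sim \mathrm{Multinomial}\left(p_{\mathrm{ctrl}}\right)
```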
31 changes: 31 additions & 0 deletions bean/notebooks/sample_quality_report.ipynb
@@ -93,6 +93,7 @@
"bdata.uns[\"reporter_length\"] = reporter_length\n",
"bdata.uns[\"reporter_right_flank_length\"] = reporter_right_flank_length\n",
"if posctrl_col:\n",
" bdata.guides[posctrl_col] = bdata.guides[posctrl_col].astype(str)\n",
" if posctrl_col not in bdata.guides.columns:\n",
" raise ValueError(f\"--posctrl-col argument '{posctrl_col}' is not present in the input ReporterScreen.guides.columns {bdata.guides.columns}. If you do not want to use positive control gRNA annotation for LFC calculation, feed --posctrl-col='' instead.\")\n",
" if posctrl_val not in bdata.guides[posctrl_col].tolist():\n",
@@ -325,6 +326,13 @@
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Editing rate"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -345,6 +353,29 @@
" be.qc.plot_guide_edit_rates(bdata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### R1-R2 recombination"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.hist(\n",
" 1-(\n",
" bdata[:, bdata.samples.condition == ctrl_cond].layers[\"X_bcmatch\"]\n",
" / bdata[:, bdata.samples.condition == ctrl_cond].X\n",
" ).mean(axis=1)\n",
")\n",
"plt.xlabel(\"Recombination rate\")\n",
"plt.ylabel(\"Frequency\")"
]
},
{
"cell_type": "markdown",
"metadata": {},