diff --git a/bean/annotate/_supporting_fn.py b/bean/annotate/_supporting_fn.py index e3b34f5..e2c6639 100755 --- a/bean/annotate/_supporting_fn.py +++ b/bean/annotate/_supporting_fn.py @@ -1,7 +1,8 @@ from copy import deepcopy -from typing import List, Tuple, Union +from typing import List, Union, Dict, Optional from tqdm.auto import tqdm -from ..framework.Edit import Edit, Allele +import numpy as np +from ..framework.Edit import Allele from ..framework.AminoAcidEdit import CodingNoncodingAllele import pandas as pd from ..annotate.translate_allele import CDS, RefBaseMismatchException @@ -9,8 +10,8 @@ def filter_allele_by_pos( allele: Allele, - pos_start: int = None, - pos_end: int = None, + pos_start: Optional[Union[float, int]] = None, + pos_end: Optional[Union[float, int]] = None, filter_rel_pos=True, ): """ @@ -24,6 +25,10 @@ def filter_allele_by_pos( filtered_edits = 0 allele_filtered = deepcopy(allele) if not (pos_start is None and pos_end is None): + if pos_start is None: + pos_start = -np.inf + if pos_end is None: + pos_end = np.inf if filter_rel_pos: for edit in allele.edits: if not (edit.rel_pos >= pos_start and edit.rel_pos < pos_end): @@ -34,7 +39,6 @@ def filter_allele_by_pos( if not (edit.pos >= pos_start and edit.pos < pos_end): filtered_edits += 1 allele_filtered.edits.remove(edit) - else: print("No threshold specified") # TODO: warn return (allele_filtered, filtered_edits) @@ -42,9 +46,9 @@ def filter_allele_by_pos( def filter_allele_by_base( allele: Allele, - allowed_base_changes: List[Tuple] = None, - allowed_ref_base: Union[List, str] = None, - allowed_alt_base: Union[List, str] = None, + allowed_base_changes: Optional[Dict[str, str]] = None, + allowed_ref_base: Optional[Union[List, str]] = None, + allowed_alt_base: Optional[Union[List, str]] = None, ): """ Filter alleles based on position and return the filtered allele and @@ -55,28 +59,29 @@ def filter_allele_by_base( allowed_ref_base = [allowed_ref_base] if isinstance(allowed_alt_base, str): allowed_alt_base = [allowed_alt_base] - if ( - not (allowed_ref_base is None and allowed_alt_base is None) - + (allowed_base_changes is None) - == 1 - ): + if (allowed_ref_base is None and allowed_alt_base is None) + ( + allowed_base_changes is None + ) != 1: print("No filters specified or misspecified filters.") - elif not allowed_base_changes is None: + elif allowed_base_changes is not None: for edit in allele.edits.copy(): - if not (edit.ref_base, edit.alt_base) in allowed_base_changes: + if ( + edit.ref_base not in allowed_base_changes + or allowed_base_changes[edit.ref_base] != edit.alt_base + ): filtered_edits += 1 allele.edits.remove(edit) - elif not allowed_ref_base is None: + elif allowed_ref_base is not None: for edit in allele.edits.copy(): if edit.ref_base not in allowed_ref_base: filtered_edits += 1 allele.edits.remove(edit) - elif not allowed_alt_base is None and edit.alt_base not in allowed_alt_base: + elif allowed_alt_base is not None and edit.alt_base not in allowed_alt_base: filtered_edits += 1 allele.edits.remove(edit) else: for edit in allele.edits.copy(): - if edit.alt_base not in allowed_alt_base: + if edit.alt_base not in allowed_alt_base: # type: ignore filtered_edits += 1 allele.edits.remove(edit) return (allele, filtered_edits) @@ -105,7 +110,9 @@ def map_alleles_to_filtered( ): guide_filtered_allele_counts = filtered_allele_counts.loc[ filtered_allele_counts.guide == guide, : - ].set_index("allele") + ].set_index( + "allele" + ) # type: ignore guide_filtered_alleles = guide_filtered_allele_counts.index.tolist() if len(guide_filtered_alleles) == 0: pass diff --git a/bean/annotate/utils.py b/bean/annotate/utils.py index dd1358f..fdd41dd 100755 --- a/bean/annotate/utils.py +++ b/bean/annotate/utils.py @@ -279,7 +279,7 @@ def parse_args(parser=None): parser.add_argument( "--filter-target-basechange", "-b", - help="Only consider target edit (stored in bdata.uns['target_base_change'])", + help="Only consider target edit (stored in bdata.uns['target_base_changes'])", action="store_true", ) parser.add_argument( diff --git a/bean/cli/count_samples.py b/bean/cli/count_samples.py index fcec874..138cdb6 100755 --- a/bean/cli/count_samples.py +++ b/bean/cli/count_samples.py @@ -42,8 +42,11 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace): args_dict["output_folder"] = os.path.join(args.output_folder, sample_id) base_editing_map = {"A": "G", "C": "T"} - edited_from = args_dict["edited_base"] - edited_to = base_editing_map[edited_from] + try: + target_base_edits = {k: base_editing_map[k] for k in args_dict["edited_base"]} + except KeyError as e: + raise KeyError(args_dict["edited_base"]) from e + match_target_pos = args_dict["match_target_pos"] if ( "guide_start_seqs_tbl" in args_dict @@ -75,7 +78,7 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace): raise ValueError( f"File {counter.output_dir}.h5ad doesn't have alllele information stored." ) from exc - screen.get_edit_mat_from_uns(edited_from, edited_to, match_target_pos) + screen.get_edit_mat_from_uns(target_base_edits, match_target_pos) info( f"Reading already existing data for {sample_id} from \n\ {counter.output_dir}.h5ad" diff --git a/bean/cli/filter.py b/bean/cli/filter.py index c707c8c..c7ec484 100755 --- a/bean/cli/filter.py +++ b/bean/cli/filter.py @@ -3,6 +3,7 @@ import sys import logging +from itertools import product import pandas as pd import bean as be import bean.annotate.filter_alleles as filter_alleles @@ -109,10 +110,9 @@ def main(args): if len(bdata.uns[allele_df_keys[-1]]) > 0 and not args.keep_indels: filtered_key = f"{allele_df_keys[-1]}_noindels" - info(f"Filtering out indels...") + info("Filtering out indels...") bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base( - ["A", "T", "G", "C"], - ["A", "T", "G", "C"], + {k: v for k, v in product(["A", "C", "T", "G"], ["A", "C", "T", "G"])}, map_to_filtered=False, allele_uns_key=allele_df_keys[-1], ).reset_index(drop=True) @@ -120,13 +120,12 @@ def main(args): allele_df_keys.append(filtered_key) if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.filter_target_basechange: - filtered_key = ( - f"{allele_df_keys[-1]}_{bdata.base_edited_from}.{bdata.base_edited_to}" - ) - info(f"Filtering out non-{bdata.uns['target_base_change']} edits...") + if "target_base_changes" not in bdata.uns and "target_base_change" in bdata.uns: + bdata.uns["target_base_changes"] = bdata.uns["target_base_change"] + filtered_key = f"{allele_df_keys[-1]}_{bdata.uns['target_base_changes']}" + info(f"Filtering out non-{bdata.uns['target_base_changes']} edits...") bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base( - bdata.base_edited_from, - bdata.base_edited_to, + bdata.target_base_changes, map_to_filtered=False, allele_uns_key=allele_df_keys[-1], ).reset_index(drop=True) diff --git a/bean/framework/Edit.py b/bean/framework/Edit.py index afd0c26..bdd60b2 100755 --- a/bean/framework/Edit.py +++ b/bean/framework/Edit.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterable +from typing import Iterable, Optional import numpy as np import re from ..utils.arithmetric import jaccard @@ -12,10 +12,10 @@ class Edit: def __init__( self, rel_pos: int, - ref_base: chr, - alt_base: chr, - chrom: str = None, - offset: int = None, + ref_base: str, + alt_base: str, + chrom: Optional[str] = None, + offset: Optional[int] = None, strand: int = 1, unique_identifier=None, ): diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 2ab6069..78cb31f 100755 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Collection, Iterable, List, Optional, Union, Sequence, Literal +from typing import Collection, Iterable, List, Optional, Union, Sequence, Literal, Dict import re import anndata as ad import numpy as np @@ -42,7 +42,7 @@ def _get_counts( res.columns = res.columns.map(lambda x: "{rep}_{spl}") guides = guides.join(res, how="left") guides = guides.fillna(0) - return guides + return guides def _get_edits( @@ -76,7 +76,7 @@ def _get_edits( this_edit.name = "{r}_{c}".format(r=rep, c=spl) edits = edits.join(this_edit, how="left") edits = edits.fillna(0) - return edits + return edits class ReporterScreen(Screen): @@ -85,7 +85,7 @@ def __init__( X=None, X_edit=None, X_bcmatch=None, - target_base_change: Optional[str] = None, + target_base_changes: Optional[str] = None, replicate_label: str = "rep", condition_label: str = "bin", tiling: Optional[bool] = None, @@ -132,12 +132,14 @@ def __init__( self.uns[k].loc[:, "aa_allele"] = self.uns[k].aa_allele.map( lambda s: CodingNoncodingAllele.from_str(s) ) - if target_base_change is not None: - if not re.fullmatch(r"[ACTG]>[ACTG]", target_base_change): + if target_base_changes is not None: + if not re.fullmatch( + r"([ACTG]>[ACTG])(,[ACTG]>[ACTG])*", target_base_changes + ): raise ValueError( - f"target_base_change {target_base_change} doesn't match the allowed base change. Feed in valid base change string ex) 'A>G', 'C>T'" + f"target_base_changes {target_base_changes} doesn't match the allowed base change. Feed in valid base change string ex) 'A>G', 'C>T'" ) - self.uns["target_base_change"] = target_base_change + self.uns["target_base_changes"] = target_base_changes if tiling is not None: self.uns["tiling"] = tiling @@ -157,55 +159,26 @@ def edit_tables(self): def allele_tables(self): return {k: self.uns[k] for k in self.uns.keys() if "allele" in k} - @property - def base_edited_from(self): - return self.uns["target_base_change"][0] + # @property + # def base_edited_from(self): + # return self.uns["target_base_change"][0] - @property - def base_edited_to(self): - return self.uns["target_base_change"][-1] + # @property + # def base_edited_to(self): + # return self.uns["target_base_change"][-1] @property - def target_base_change(self): - return self.uns["target_base_change"] + def target_base_changes(self): + try: + basechanges = self.uns["target_base_changes"] + except KeyError: + basechanges = self.uns["target_base_change"] + return {basechange[0]: basechange[-1] for basechange in basechanges.split(",")} @property def tiling(self): return self.uns["tiling"] - @classmethod - def from_file_paths( - cls, - reps: List[str] = None, - samples: List[str] = None, - guide_info_file_name: str = None, - guide_count_filenames: str = None, - guide_bcmatched_count_filenames: str = None, - edit_count_filenames: str = None, - ): - guide_info = pd.read_csv(guide_info_file_name).set_index("name") - guides = pd.DataFrame(index=(pd.read_csv(guide_info_file_name)["name"])) - guides_lenient = _get_counts(guide_count_filenames, guides, reps, samples) - edits_ag = _get_edits( - edit_count_filenames, guide_info, reps, samples, count_exact=False - ) - edits_exact = _get_edits(edit_count_filenames, guide_info, reps, samples) - if guide_bcmatched_count_filenames is not None: - guides_bcmatch = _get_counts( - guide_bcmatched_count_filenames, guides, reps, samples - ) - repscreen = cls( - guides_lenient, edits_exact, X_bcmatch=guides_bcmatch, guides=guide_info - ) - else: - repscreen = cls(guides_lenient, edits_exact, guides=guide_info) - repscreen.layers["edits_ag"] = edits_ag - repscreen.samples["replicate"] = np.repeat(reps, len(samples)) - repscreen.samples["sort"] = np.tile(samples, len(reps)) - repscreen.uns["replicates"] = reps - repscreen.uns["samples"] = samples - return repscreen - @classmethod def from_adata(cls, adata): X_bcmatch = adata.layers["X_bcmatch"] if "X_bcmatch" in adata.layers else None @@ -384,8 +357,7 @@ def _remove_zero_count(key): def get_edit_mat_from_uns( self, - ref_base: Optional[str] = None, - alt_base: Optional[str] = None, + target_base_edit: Optional[Dict[str, str]] = None, match_target_position: Optional[bool] = None, rel_pos_start=0, rel_pos_end=np.Inf, @@ -399,20 +371,15 @@ def get_edit_mat_from_uns( Args -- - ref_base: reference base of editing event to count. If not provided, default value in `.base_edited_from` is used. - alt_base: alternate base of editing event to count. If not provided, default value in `.base_edited_to` is used. - match_target_position: If `True`, edits with `.rel_pos` that matches `.guides[target_pos_col]` are counted. - If `False`, edits with `.rel_pos` in range [rel_pos_start, rel_pos_end) is counted. + target_base_edit: Dictionary of base edited from, base edited to. rel_pos_start: Position from which edits will be counted when `match_target_position` is `False`. rel_pos_end: Position until where edits will be counted when `match_target_position` is `False`. rel_pos_is_reporter: `rel_pos_start` and `rel_pos_end` is relative to the reporter position. If `False`, those are treated to be relative of spacer position. target_pos_col: Column name in `.guides` DataFrame that has target position information when `match_target_position` is `True`. edit_count_key: Key of the edit counts DataFrame to be used to count the edits (`.uns[edit_count_key]`). """ - if ref_base is None: - ref_base = self.base_edited_from - if alt_base is None: - alt_base = self.base_edited_to + if target_base_edit is None: + target_base_edit = self.target_base_changes if match_target_position is None: match_target_position = not self.tiling if edit_count_key not in self.uns or len(self.uns[edit_count_key]) == 0: @@ -431,7 +398,9 @@ def get_edit_mat_from_uns( edits["ref_base"] = edits.edit.map(lambda e: e.ref_base) edits["alt_base"] = edits.edit.map(lambda e: e.alt_base) edits = edits.loc[ - (edits.ref_base == ref_base) & (edits.alt_base == alt_base), : + (edits.ref_base.map(lambda r: r in target_base_edit)) + & (edits.ref_base.map(target_base_edit) == edits.alt_base), + :, ].reset_index() guide_len = self.guides.sequence.map(len) guide_name_to_idx = self.guides.reset_index().reset_index().set_index("name") @@ -475,11 +444,11 @@ def get_edit_mat_from_uns( def get_guide_edit_rate( self, normalize_by_editable_base: Optional[bool] = None, - edited_base: Optional[str] = None, + edited_bases: Optional[Union[List[str], str]] = None, editable_base_start=3, editable_base_end=8, bcmatch_thres=1, - prior_weight: float = None, + prior_weight: Optional[float] = None, return_result=False, count_layer="X_bcmatch", edit_layer="edits", @@ -492,18 +461,30 @@ def get_guide_edit_rate( prior weight to use when calculating posterior edit rate. unsorted_condition_label: Editing rate is calculated only for the samples that have this string in the sample index. """ - if edited_base is None: - edited_base = self.base_edited_from + if edited_bases is None: + edited_bases = list(self.target_base_changes.keys()) + if isinstance(edited_bases, str): + edited_bases = [edited_bases] + if normalize_by_editable_base is None: normalize_by_editable_base = self.tiling if self.layers[count_layer] is None or self.layers[edit_layer] is None: raise ValueError("edits or barcode matched guide counts not available.") num_targetable_sites = 1.0 if normalize_by_editable_base: - if edited_base not in ["A", "C", "T", "G"]: - raise ValueError("Specify the correct edited_base") - num_targetable_sites = self.guides.sequence.map( - lambda s: s[editable_base_start:editable_base_end].count(edited_base) + num_targetable_sites_all = [] + for edited_base in edited_bases: + if edited_base not in ["A", "C", "T", "G"]: + raise ValueError("Specify the correct edited_base") + num_targetable_sites_all.append( + self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count( + edited_base + ) + ) + ) + num_targetable_sites = pd.concat(num_targetable_sites_all, axis=1).sum( + axis=1 ) if unsorted_condition_label is not None: bulk_idx = np.where( @@ -538,11 +519,11 @@ def get_guide_edit_rate( def get_edit_rate( self, normalize_by_editable_base=None, - edited_base=None, + edited_base: Optional[Union[str, List[str]]] = None, editable_base_start=3, editable_base_end=8, bcmatch_thres=1, - prior_weight: float = None, + prior_weight: Optional[float] = None, return_result=False, count_layer="X_bcmatch", edit_layer="edits", @@ -554,17 +535,38 @@ def get_edit_rate( """ if normalize_by_editable_base is None: normalize_by_editable_base = self.tiling - if edited_base is None: - edited_base = self.base_edited_from if self.layers[count_layer] is None or self.layers[edit_layer] is None: raise ValueError("edits or barcode matched guide counts not available.") num_targetable_sites = 1.0 if normalize_by_editable_base: - if edited_base not in ["A", "C", "T", "G"]: - raise ValueError("Specify the correct edited_base") - num_targetable_sites = self.guides.sequence.map( - lambda s: s[editable_base_start:editable_base_end].count(edited_base) - ) + if edited_base is None: + edited_base = [ + basechange[0] for basechange in self.target_base_changes.split(",") + ] + for eb in edited_base: + if eb not in ["A", "C", "T", "G"]: + raise ValueError( + f"ReporterScreen target base change ({self.target_base_changes}) is invalid: {eb} as edited base." + ) + + if isinstance(edited_base, str): + num_targetable_sites = self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count( + edited_base + ) + ) + else: + num_targetable_sites_bases = [] + for eb in edited_base: + num_targetable_sites_bases.append( + self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count(eb) + ) + ) + num_targetable_sites = pd.concat( + num_targetable_sites_bases, axis=1 + ).sum(axis=1) + if prior_weight is None: prior_weight = 1 n_edits = self.layers[edit_layer] @@ -714,8 +716,7 @@ def filter_allele_counts_by_pos( def filter_allele_counts_by_base( self, - ref_base: Union[List, str] = "A", - alt_base: Union[List, str] = "G", + target_base_edits: Dict[str, str], allele_uns_key="allele_counts", map_to_filtered=True, jaccard_threshold: float = 0.5, @@ -730,7 +731,7 @@ def filter_allele_counts_by_base( filtered_allele, filtered_edits = zip( *allele_count_df.allele.map( lambda a: filter_allele_by_base( - a, allowed_ref_base=ref_base, allowed_alt_base=alt_base + a, allowed_base_changes=target_base_edits ) ) ) @@ -926,7 +927,7 @@ def concat(screens: Sequence[ReporterScreen], *args, axis: Literal[0, 1] = 1, ** if axis == 0: for k in keys: - if k in ["target_base_change", "tiling", "sample_covariates"]: + if k in ["target_base_changes", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: @@ -936,7 +937,7 @@ def concat(screens: Sequence[ReporterScreen], *args, axis: Literal[0, 1] = 1, ** if axis == 1: # If combining multiple samples, edit/allele tables should be merged. for k in keys: - if k in ["target_base_change", "tiling", "sample_covariates"]: + if k in ["target_base_changes", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: diff --git a/bean/mapping/CRISPResso2Align.pyx b/bean/mapping/CRISPResso2Align.pyx index f2c17a6..1b3ed50 100755 --- a/bean/mapping/CRISPResso2Align.pyx +++ b/bean/mapping/CRISPResso2Align.pyx @@ -63,7 +63,7 @@ def read_matrix(path): @cython.boundscheck(False) @cython.nonecheck(False) -def global_align_base_editor(str pystr_seqj, str pystr_seqi, str ref_base, str alt_base, +def global_align_base_editor(str pystr_seqj, str pystr_seqi, dict target_base_edits, np.ndarray[DTYPE_INT, ndim=2] matrix, np.ndarray[DTYPE_INT,ndim=1] gap_incentive, int gap_open=-1, int gap_extend=-1, ): @@ -377,8 +377,14 @@ def global_align_base_editor(str pystr_seqj, str pystr_seqi, str ref_base, str a print('seqj: ' + str(seqj) + ' seqi: ' + str(seqi)) raise Exception('wtf4!:pointer: %i', i) # print('at end, currMatrix is ' + str(currMatrix)) - if not (ci == ref_base and cj == alt_base): + is_target_edit = False + for ref_base, alt_base in target_base_edits.items(): + if ci == ref_base and cj == alt_base: + is_target_edit = True + if not is_target_edit: align_counter += 1 + #if not (ci == ref_base and cj == alt_base): + # align_counter += 1 try: align_j = tmp_align_j[:align_counter].decode('UTF-8', 'strict') finally: diff --git a/bean/mapping/GuideEditCounter.py b/bean/mapping/GuideEditCounter.py index c2f1c2b..d7b2a1d 100755 --- a/bean/mapping/GuideEditCounter.py +++ b/bean/mapping/GuideEditCounter.py @@ -1,10 +1,10 @@ -from typing import Tuple +from typing import Tuple, Optional, Sequence, Union import gzip import logging from copy import deepcopy import os import sys -from os import path +import random import numpy as np import pandas as pd @@ -15,7 +15,6 @@ from ._supporting_fn import ( _base_edit_to_from, - _check_readname_match, _get_edited_allele_crispresso, _get_fastq_handle, _read_count_match, @@ -62,8 +61,9 @@ class GuideEditCounter: def __init__(self, **kwargs): self.R1_filename = kwargs["R1"] self.R2_filename = kwargs["R2"] - self.base_edited_from = kwargs["edited_base"] - self.base_edited_to = _base_edit_to_from(self.base_edited_from) + self.target_base_edits = { + k: _base_edit_to_from(k) for k in kwargs["edited_base"] + } self.min_average_read_quality = kwargs["min_average_read_quality"] self.min_single_bp_quality = kwargs["min_single_bp_quality"] @@ -103,7 +103,12 @@ def __init__(self, **kwargs): X_bcmatch=np.zeros((len(self.guides_info_df), 1)), guides=self.guides_info_df, samples=pd.DataFrame(index=[self.database_id]), - target_base_change=f"{self.base_edited_from}>{self.base_edited_to}", + target_base_changes=",".join( + [ + f"{base_edited_from}>{base_edited_to}" + for base_edited_from, base_edited_to in self.target_base_edits.items() + ] + ), tiling=kwargs["tiling"], ) self.screen.guides["guide_len"] = self.screen.guides.sequence.map(len) @@ -147,17 +152,25 @@ def __init__(self, **kwargs): if not self.objectify_allele: info(f"{self.name}: Storing allele as strings.") self.keep_intermediate = kwargs["keep_intermediate"] + self.id_number = random.randint(0, int(1e6)) self.semimatch = 0 self.bcmatch = 0 self.nomatch = 0 self.duplicate_match = 0 self.duplicate_match_wo_barcode = 0 + def mask_sequence(self, seq, revcomp=False): + seq_masked = seq + for base_edited_from, base_edited_to in self.target_base_edits.items(): + if revcomp: + base_edited_from = base_revcomp[base_edited_from] + base_edited_to = base_revcomp[base_edited_to] + seq_masked = seq_masked.replace(base_edited_from, base_edited_to) + return seq_masked + def masked_equal(self, seq1, seq2): """Tests if two sequences are equal, ignoring the allowed base transition.""" - return seq1.replace(self.base_edited_from, self.base_edited_to) == seq2.replace( - self.base_edited_from, self.base_edited_to - ) + return self.mask_sequence(seq1) == self.mask_sequence(seq2) def _set_sgRNA_df(self): """set gRNA info dataframe""" @@ -174,10 +187,10 @@ def _set_sgRNA_df(self): sgRNA_df = sgRNA_df.set_index("name") self.guides_info_df = sgRNA_df self.guides_info_df["masked_sequence"] = self.guides_info_df.sequence.map( - lambda s: s.replace(self.base_edited_from, self.base_edited_to) + self.mask_sequence ) self.guides_info_df["masked_barcode"] = self.guides_info_df.barcode.map( - lambda s: s.replace(self.base_edited_from, self.base_edited_to) + self.mask_sequence ) self.guide_lengths = sgRNA_df.sequence.map(lambda s: len(s)).unique() @@ -213,14 +226,13 @@ def get_counts(self): ) if self.count_reporter_edits or self.count_guide_edits: _write_alignment_matrix( - self.base_edited_from, - self.base_edited_to, + self.target_base_edits, self.output_dir + "/.aln_mat.txt", ) if self.count_only_bcmatched: # count X - self._get_guide_counts_bcmatch() + raise NotImplementedError else: # count both bc matched & unmatched guide counts self._get_guide_counts_bcmatch_semimatch() @@ -359,8 +371,7 @@ def _count_guide_edits( guide_edit_allele, score = _get_edited_allele_crispresso( ref_seq=ref_guide_seq, query_seq=read_guide_seq, - ref_base=self.base_edited_from, - alt_base=self.base_edited_to, + target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=0, strand=guide_strand, @@ -399,7 +410,9 @@ def _get_strand_offset_from_guide_index(self, guide_idx: int) -> Tuple[int, int] offset = 0 return (guide_strand, offset) - def _update_counted_allele(self, guide_idx: int, allele: Allele) -> None: + def _update_counted_allele( + self, guide_idx: int, allele: Union[Allele, str] + ) -> None: """Add allele count to self.guide_to_allele dictionary.""" if guide_idx in self.guide_to_allele.keys(): if allele in self.guide_to_allele[guide_idx].keys(): @@ -409,7 +422,9 @@ def _update_counted_allele(self, guide_idx: int, allele: Allele) -> None: else: self.guide_to_allele[guide_idx] = {allele: 1} - def _update_counted_guide_allele(self, guide_idx: int, allele: Allele) -> None: + def _update_counted_guide_allele( + self, guide_idx: int, allele: Union[Allele, str] + ) -> None: """Add allele count to self.guide_to_allele dictionary.""" if guide_idx in self.screen.uns["guide_edit_counts"].keys(): if allele in self.screen.uns["guide_edit_counts"][guide_idx].keys(): @@ -420,7 +435,10 @@ def _update_counted_guide_allele(self, guide_idx: int, allele: Allele) -> None: self.screen.uns["guide_edit_counts"][guide_idx] = {allele: 1} def _update_counted_allele_and_guideAllele( - self, guide_idx: int, allele: Allele, guide_allele: Allele + self, + guide_idx: int, + allele: Union[Allele, str], + guide_allele: Union[str, Allele], ) -> None: """Add count of (guide allele, reporter allele) combination to self.guide_reporter_allele dictionary.""" if guide_idx in self.guide_to_guide_reporter_allele.keys(): @@ -443,8 +461,8 @@ def _count_reporter_edits( R1_seq: str, R2_record: SeqIO.SeqRecord, R2_start: int = 0, - single_base_qual_cutoff: str = 30, - guide_allele: Allele = None, + single_base_qual_cutoff: int = 30, + guide_allele: Optional[Union[str, Allele]] = None, ): """ Count edits in a single read to save as allele. @@ -472,8 +490,7 @@ def _count_reporter_edits( allele, score = _get_edited_allele_crispresso( ref_seq=ref_reporter_seq, query_seq=read_reporter_seq, - ref_base=self.base_edited_from, - alt_base=self.base_edited_to, + target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=offset, strand=guide_strand, @@ -499,7 +516,7 @@ def _count_reporter_edits( def _get_guide_counts_bcmatch_semimatch( self, bcmatch_layer="X_bcmatch", semimatch_layer="X_semimatch" ): - self.screen.layers[semimatch_layer] = np.zeros_like((self.screen.X)) + self.screen.layers[semimatch_layer] = np.zeros_like(self.screen.X) R1_iter, R2_iter = self._get_fastq_iterators( self.filtered_R1_filename, self.filtered_R2_filename ) @@ -614,12 +631,9 @@ def _write_guide_reporter_allele( def get_guide_seq(self, R1_seq, R2_seq, guide_length): """This can be edited by user based on the read construct.""" - # _seq_match = np.where(seq.replace(self.base_edited_from, self.base_edited_to) == self.screen.guides.masked_sequence)[0] if self.guide_end_seq == "": - guide_start_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find( - self.guide_start_seq.replace(self.base_edited_from, self.base_edited_to) + guide_start_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_start_seq) ) if guide_start_idx == -1: return None @@ -628,10 +642,8 @@ def get_guide_seq(self, R1_seq, R2_seq, guide_length): guide_start_idx = guide_start_idx + len(self.guide_start_seq) gRNA_seq = R1_seq[guide_start_idx : (guide_start_idx + guide_length)] else: - guide_end_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find( - self.guide_end_seq.replace(self.base_edited_from, self.base_edited_to) + guide_end_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_end_seq) ) if guide_end_idx == -1: return None @@ -641,15 +653,13 @@ def get_guide_seq(self, R1_seq, R2_seq, guide_length): return None if len(gRNA_seq) != guide_length else gRNA_seq - def get_guide_seq_qual(self, R1_record: SeqIO.SeqRecord, guide_length): + def get_guide_seq_qual( + self, R1_record: SeqIO.SeqRecord, guide_length: int + ) -> Tuple[str, Sequence]: R1_seq = R1_record.seq - guide_start_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find(self.guide_start_seq.replace(self.base_edited_from, self.base_edited_to)) - if guide_start_idx == -1: - return None, None - if guide_start_idx + guide_length >= len(R1_seq): - return None, None + guide_start_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_start_seq) + ) guide_start_idx = guide_start_idx + len(self.guide_start_seq) seq = R1_record[guide_start_idx : (guide_start_idx + guide_length)] return (str(seq.seq), seq.letter_annotations["phred_quality"]) @@ -670,13 +680,8 @@ def get_reporter_seq_qual(self, R2_record: SeqIO.SeqRecord, R2_start=0): def get_barcode(self, R1_seq, R2_seq): if self.barcode_start_seq != "": - barcode_start_idx = R2_seq.replace( - base_revcomp[self.base_edited_from], base_revcomp[self.base_edited_to] - ).find( - revcomp(self.barcode_start_seq).replace( - base_revcomp[self.base_edited_from], - base_revcomp[self.base_edited_to], - ) + barcode_start_idx = self.mask_sequence(R2_seq, revcomp=True).find( + self.mask_sequence(revcomp(self.barcode_start_seq), revcomp=True) ) if barcode_start_idx == -1: return -1, "" @@ -698,12 +703,10 @@ def _match_read_to_sgRNA_bcmatch_semimatch(self, R1_seq: str, R2_seq: str): continue _seq_match = np.where( - seq.replace(self.base_edited_from, self.base_edited_to) - == self.screen.guides.masked_sequence + self.mask_sequence(seq) == self.screen.guides.masked_sequence )[0] _bc_match = np.where( - guide_barcode.replace(self.base_edited_from, self.base_edited_to) - == self.screen.guides.masked_barcode + self.mask_sequence(guide_barcode) == self.screen.guides.masked_barcode )[0] bc_match_idx = np.append( @@ -733,7 +736,7 @@ def get_gRNA_barcode(self, R1_seq, R2_seq): def _get_fastq_handle( self, - out_type: str = None, + out_type: Optional[str] = None, ): assert out_type in { "semimatch", @@ -794,7 +797,7 @@ def _check_names_filter_fastq(self, filter_by_qual=False): def _check_readname_match_and_filter_quality( self, R1_iter, R2_iter, filter_by_qual=False - ) -> Tuple[int, int]: + ) -> int: R1_filtered = gzip.open(self.filtered_R1_filename, "wt+") R2_filtered = gzip.open(self.filtered_R2_filename, "wt+") diff --git a/bean/mapping/_supporting_fn.py b/bean/mapping/_supporting_fn.py index 9859cc7..62c0f79 100755 --- a/bean/mapping/_supporting_fn.py +++ b/bean/mapping/_supporting_fn.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Dict, Optional, Tuple import subprocess as sb import numpy as np import pandas as pd @@ -13,7 +13,7 @@ class InputFileError(Exception): pass -def _base_edit_to_from(start_base: chr = "A"): +def _base_edit_to_from(start_base: str = "A"): try: base_map = {"A": "G", "C": "T"} except KeyError: @@ -138,32 +138,32 @@ def _get_edited_allele( if ref_nt == sample_nt: continue else: - edit = Edit(i - start_pos, ref_nt, sample_nt, offset, strand=strand) + edit = Edit(i - start_pos, ref_nt, sample_nt, offset=offset, strand=strand) allele.add(edit) return allele def _write_alignment_matrix( - ref_base: str, alt_base: str, path, allow_complementary=False + target_base_edits: Dict[str, str], path, allow_complementary=False ): """ Writes base substitution matrix """ bases = ["A", "C", "T", "G"] - if not (ref_base in bases and alt_base in bases): - raise ValueError( - "Specified ref base '{}' or alt base '{}' isn't valid".format( - ref_base, alt_base + for ref_base, alt_base in target_base_edits.items(): + if ref_base not in bases or alt_base not in bases: + raise ValueError( + f"Specified ref base '{ref_base}' or alt base '{alt_base}' isn't valid" ) - ) mat = np.ones((4, 4), dtype=int) * -4 np.fill_diagonal(mat, 5) aln_df = pd.DataFrame(mat, index=bases, columns=bases) - aln_df.loc[ref_base, alt_base] = 0 - if allow_complementary: - comp_map = {"A": "T", "C": "G", "T": "A", "G": "C"} - aln_df.loc[comp_map[ref_base], comp_map[alt_base]] = 0 + for ref_base, alt_base in target_base_edits.items(): + aln_df.loc[ref_base, alt_base] = 0 + if allow_complementary: + comp_map = {"A": "T", "C": "G", "T": "A", "G": "C"} + aln_df.loc[comp_map[ref_base], comp_map[alt_base]] = 0 aln_df.to_csv(path, sep=" ") @@ -174,8 +174,8 @@ def _get_allele_from_alignment( strand: int, start_pos: int, end_pos: int, - chrom: str = None, - positionwise_quality: np.ndarray = None, + chrom: Optional[str] = None, + positionwise_quality: Optional[np.ndarray] = None, quality_thres: float = -1, ): assert len(ref_aligned) == len(query_aligned) @@ -222,26 +222,24 @@ def _get_allele_from_alignment( def _get_edited_allele_crispresso( ref_seq: str, query_seq: str, - ref_base: str, - alt_base: str, + target_base_edits: Dict[str, str], aln_mat_path: str, offset: int, strand: int = 1, - chrom: str = None, + chrom: Optional[str] = None, start_pos: int = 0, end_pos: int = 100, - positionwise_quality: np.ndarray = None, + positionwise_quality: Optional[np.ndarray] = None, quality_thres: float = 30, objectify_allele=True, -): +) -> Tuple[Union[Allele, str], float]: aln_matrix = read_matrix(aln_mat_path) assert strand in [-1, +1] gap_incentive = np.zeros(len(ref_seq) + 1, dtype=int) query_aligned, ref_aligned, score = global_align_base_editor( query_seq, ref_seq, - ref_base, - alt_base, + target_base_edits, aln_matrix, gap_incentive, gap_open=-20, diff --git a/bean/mapping/utils.py b/bean/mapping/utils.py index bdb7b96..5064aa4 100755 --- a/bean/mapping/utils.py +++ b/bean/mapping/utils.py @@ -276,13 +276,25 @@ def _check_arguments(args, info_logger, warn_logger, error_logger): _check_file(args.sgRNA_filename) # Edited base should be one of A/C/T/G + edited_bases = [] if args.edited_base.upper() not in ["A", "C", "T", "G"]: - raise ValueError( - f"The edited base should be one of A/C/T/G, {args.edited_base} provided." - ) - - edited_base = args.edited_base.upper() - info_logger(f"Using specified edited base: {edited_base}") + if "," in args.edited_base: + bases = args.edited_base.split(",") + for base in bases: + if base not in ["A", "C", "T", "G"]: + raise ValueError( + f"The edited base should be one of A/C/T/G, {args.edited_base} provided." + ) + edited_bases.append(base.upper()) + else: + raise ValueError( + f"The edited base should be one of A/C/T/G, {args.edited_base} provided." + ) + if len(edited_bases) == 0: + args.edited_base = [args.edited_base.upper()] + else: + args.edited_base = edited_bases + info_logger(f"Using specified edited base: {args.edited_base}") info_logger( f"Using guide barcode length {args.guide_bc_len}, guide start '{args.guide_start_seq}'" ) diff --git a/bean/notebooks/profile_editing_preference.ipynb b/bean/notebooks/profile_editing_preference.ipynb index 2965673..05dae33 100755 --- a/bean/notebooks/profile_editing_preference.ipynb +++ b/bean/notebooks/profile_editing_preference.ipynb @@ -29,15 +29,17 @@ "import numpy as np\n", "import pandas as pd\n", "from tqdm.auto import tqdm\n", - "\n", + "import logging\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", "import seaborn as sns\n", "import logomaker\n", "\n", "import bean as be\n", "from bean import Edit\n", - "import bean.plotting.editing_patterns" + "import bean.plotting.editing_patterns\n", + "logging.getLogger(\"matplotlib.font_manager\").setLevel(logging.ERROR)" ] }, { @@ -614,334 +616,12 @@ "_, profile_df = be.pl.editing_patterns.plot_by_pos_behive(\n", " cedit_rates_df,\n", " cdata_bulk,\n", - " target_basechange=cdata_bulk.uns['target_base_change'], \n", - " nonref_base_changes = None\n", + " target_basechanges=cdata_bulk.target_base_changes, \n", ")\n", "profile_df.to_csv(f\"{output_prefix}_behive_like_profile.csv\")\n", "if save_fig: plt.savefig(f\"{output_prefix}_behive_like_profile.pdf\", bbox_inches = 'tight')" ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "ag_rates = profile_df[cdata_bulk.uns['target_base_change']]\n", - "window_sum_maxpos = ag_rates.rolling(max_editing_window_length).sum().argmax()\n", - "window_end = ag_rates.index[window_sum_maxpos]\n", - "window_start = window_end - max_editing_window_length + 1" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "window_start" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "window_end" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.patches as patches" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
base_change | \n", - "A>C | \n", - "A>T | \n", - "A>G | \n", - "C>T | \n", - "C>G | \n", - "
---|---|---|---|---|---|
spacer_pos | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
1 | \n", - "0.000251 | \n", - "0.001344 | \n", - "0.231231 | \n", - "0.002183 | \n", - "0.001298 | \n", - "
2 | \n", - "0.000175 | \n", - "0.001010 | \n", - "0.230791 | \n", - "0.002154 | \n", - "0.001042 | \n", - "
3 | \n", - "0.000295 | \n", - "0.001137 | \n", - "0.275571 | \n", - "0.002942 | \n", - "0.001145 | \n", - "
4 | \n", - "0.000125 | \n", - "0.001094 | \n", - "0.336687 | \n", - "0.002605 | \n", - "0.001785 | \n", - "
5 | \n", - "0.000144 | \n", - "0.001337 | \n", - "0.355146 | \n", - "0.004523 | \n", - "0.001151 | \n", - "
6 | \n", - "0.000102 | \n", - "0.001129 | \n", - "0.356709 | \n", - "0.004231 | \n", - "0.001516 | \n", - "
7 | \n", - "0.000211 | \n", - "0.000865 | \n", - "0.292859 | \n", - "0.003761 | \n", - "0.000748 | \n", - "
8 | \n", - "0.000184 | \n", - "0.001677 | \n", - "0.242974 | \n", - "0.002755 | \n", - "0.000939 | \n", - "
9 | \n", - "0.000504 | \n", - "0.001305 | \n", - "0.215772 | \n", - "0.002144 | \n", - "0.000936 | \n", - "
10 | \n", - "0.000140 | \n", - "0.000423 | \n", - "0.172920 | \n", - "0.001933 | \n", - "0.000815 | \n", - "
11 | \n", - "0.000166 | \n", - "0.001235 | \n", - "0.165328 | \n", - "0.002058 | \n", - "0.000775 | \n", - "
12 | \n", - "0.000138 | \n", - "0.000724 | \n", - "0.129227 | \n", - "0.001166 | \n", - "0.001041 | \n", - "
13 | \n", - "0.000257 | \n", - "0.000850 | \n", - "0.131265 | \n", - "0.002194 | \n", - "0.001379 | \n", - "
14 | \n", - "0.000092 | \n", - "0.000813 | \n", - "0.104590 | \n", - "0.001344 | \n", - "0.000267 | \n", - "
15 | \n", - "0.000143 | \n", - "0.000780 | \n", - "0.110987 | \n", - "0.001596 | \n", - "0.000162 | \n", - "
16 | \n", - "0.000307 | \n", - "0.000638 | \n", - "0.096397 | \n", - "0.002289 | \n", - "0.000732 | \n", - "
17 | \n", - "0.000414 | \n", - "0.000454 | \n", - "0.106874 | \n", - "0.001819 | \n", - "0.000649 | \n", - "
18 | \n", - "0.000070 | \n", - "0.000577 | \n", - "0.095303 | \n", - "0.001182 | \n", - "0.000128 | \n", - "
19 | \n", - "0.000166 | \n", - "0.000852 | \n", - "0.092396 | \n", - "0.000697 | \n", - "0.000892 | \n", - "
20 | \n", - "0.000136 | \n", - "0.001039 | \n", - "0.085923 | \n", - "0.001261 | \n", - "0.000188 | \n", - "