diff --git a/bean/annotate/_supporting_fn.py b/bean/annotate/_supporting_fn.py index e3b34f5..e2c6639 100755 --- a/bean/annotate/_supporting_fn.py +++ b/bean/annotate/_supporting_fn.py @@ -1,7 +1,8 @@ from copy import deepcopy -from typing import List, Tuple, Union +from typing import List, Union, Dict, Optional from tqdm.auto import tqdm -from ..framework.Edit import Edit, Allele +import numpy as np +from ..framework.Edit import Allele from ..framework.AminoAcidEdit import CodingNoncodingAllele import pandas as pd from ..annotate.translate_allele import CDS, RefBaseMismatchException @@ -9,8 +10,8 @@ def filter_allele_by_pos( allele: Allele, - pos_start: int = None, - pos_end: int = None, + pos_start: Optional[Union[float, int]] = None, + pos_end: Optional[Union[float, int]] = None, filter_rel_pos=True, ): """ @@ -24,6 +25,10 @@ def filter_allele_by_pos( filtered_edits = 0 allele_filtered = deepcopy(allele) if not (pos_start is None and pos_end is None): + if pos_start is None: + pos_start = -np.inf + if pos_end is None: + pos_end = np.inf if filter_rel_pos: for edit in allele.edits: if not (edit.rel_pos >= pos_start and edit.rel_pos < pos_end): @@ -34,7 +39,6 @@ def filter_allele_by_pos( if not (edit.pos >= pos_start and edit.pos < pos_end): filtered_edits += 1 allele_filtered.edits.remove(edit) - else: print("No threshold specified") # TODO: warn return (allele_filtered, filtered_edits) @@ -42,9 +46,9 @@ def filter_allele_by_pos( def filter_allele_by_base( allele: Allele, - allowed_base_changes: List[Tuple] = None, - allowed_ref_base: Union[List, str] = None, - allowed_alt_base: Union[List, str] = None, + allowed_base_changes: Optional[Dict[str, str]] = None, + allowed_ref_base: Optional[Union[List, str]] = None, + allowed_alt_base: Optional[Union[List, str]] = None, ): """ Filter alleles based on position and return the filtered allele and @@ -55,28 +59,29 @@ def filter_allele_by_base( allowed_ref_base = [allowed_ref_base] if isinstance(allowed_alt_base, str): allowed_alt_base = [allowed_alt_base] - if ( - not (allowed_ref_base is None and allowed_alt_base is None) - + (allowed_base_changes is None) - == 1 - ): + if (allowed_ref_base is None and allowed_alt_base is None) + ( + allowed_base_changes is None + ) != 1: print("No filters specified or misspecified filters.") - elif not allowed_base_changes is None: + elif allowed_base_changes is not None: for edit in allele.edits.copy(): - if not (edit.ref_base, edit.alt_base) in allowed_base_changes: + if ( + edit.ref_base not in allowed_base_changes + or allowed_base_changes[edit.ref_base] != edit.alt_base + ): filtered_edits += 1 allele.edits.remove(edit) - elif not allowed_ref_base is None: + elif allowed_ref_base is not None: for edit in allele.edits.copy(): if edit.ref_base not in allowed_ref_base: filtered_edits += 1 allele.edits.remove(edit) - elif not allowed_alt_base is None and edit.alt_base not in allowed_alt_base: + elif allowed_alt_base is not None and edit.alt_base not in allowed_alt_base: filtered_edits += 1 allele.edits.remove(edit) else: for edit in allele.edits.copy(): - if edit.alt_base not in allowed_alt_base: + if edit.alt_base not in allowed_alt_base: # type: ignore filtered_edits += 1 allele.edits.remove(edit) return (allele, filtered_edits) @@ -105,7 +110,9 @@ def map_alleles_to_filtered( ): guide_filtered_allele_counts = filtered_allele_counts.loc[ filtered_allele_counts.guide == guide, : - ].set_index("allele") + ].set_index( + "allele" + ) # type: ignore guide_filtered_alleles = guide_filtered_allele_counts.index.tolist() if len(guide_filtered_alleles) == 0: pass diff --git a/bean/annotate/utils.py b/bean/annotate/utils.py index dd1358f..fdd41dd 100755 --- a/bean/annotate/utils.py +++ b/bean/annotate/utils.py @@ -279,7 +279,7 @@ def parse_args(parser=None): parser.add_argument( "--filter-target-basechange", "-b", - help="Only consider target edit (stored in bdata.uns['target_base_change'])", + help="Only consider target edit (stored in bdata.uns['target_base_changes'])", action="store_true", ) parser.add_argument( diff --git a/bean/cli/count_samples.py b/bean/cli/count_samples.py index fcec874..138cdb6 100755 --- a/bean/cli/count_samples.py +++ b/bean/cli/count_samples.py @@ -42,8 +42,11 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace): args_dict["output_folder"] = os.path.join(args.output_folder, sample_id) base_editing_map = {"A": "G", "C": "T"} - edited_from = args_dict["edited_base"] - edited_to = base_editing_map[edited_from] + try: + target_base_edits = {k: base_editing_map[k] for k in args_dict["edited_base"]} + except KeyError as e: + raise KeyError(args_dict["edited_base"]) from e + match_target_pos = args_dict["match_target_pos"] if ( "guide_start_seqs_tbl" in args_dict @@ -75,7 +78,7 @@ def count_sample(R1: str, R2: str, sample_id: str, args: argparse.Namespace): raise ValueError( f"File {counter.output_dir}.h5ad doesn't have alllele information stored." ) from exc - screen.get_edit_mat_from_uns(edited_from, edited_to, match_target_pos) + screen.get_edit_mat_from_uns(target_base_edits, match_target_pos) info( f"Reading already existing data for {sample_id} from \n\ {counter.output_dir}.h5ad" diff --git a/bean/cli/filter.py b/bean/cli/filter.py index c707c8c..c7ec484 100755 --- a/bean/cli/filter.py +++ b/bean/cli/filter.py @@ -3,6 +3,7 @@ import sys import logging +from itertools import product import pandas as pd import bean as be import bean.annotate.filter_alleles as filter_alleles @@ -109,10 +110,9 @@ def main(args): if len(bdata.uns[allele_df_keys[-1]]) > 0 and not args.keep_indels: filtered_key = f"{allele_df_keys[-1]}_noindels" - info(f"Filtering out indels...") + info("Filtering out indels...") bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base( - ["A", "T", "G", "C"], - ["A", "T", "G", "C"], + {k: v for k, v in product(["A", "C", "T", "G"], ["A", "C", "T", "G"])}, map_to_filtered=False, allele_uns_key=allele_df_keys[-1], ).reset_index(drop=True) @@ -120,13 +120,12 @@ def main(args): allele_df_keys.append(filtered_key) if len(bdata.uns[allele_df_keys[-1]]) > 0 and args.filter_target_basechange: - filtered_key = ( - f"{allele_df_keys[-1]}_{bdata.base_edited_from}.{bdata.base_edited_to}" - ) - info(f"Filtering out non-{bdata.uns['target_base_change']} edits...") + if "target_base_changes" not in bdata.uns and "target_base_change" in bdata.uns: + bdata.uns["target_base_changes"] = bdata.uns["target_base_change"] + filtered_key = f"{allele_df_keys[-1]}_{bdata.uns['target_base_changes']}" + info(f"Filtering out non-{bdata.uns['target_base_changes']} edits...") bdata.uns[filtered_key] = bdata.filter_allele_counts_by_base( - bdata.base_edited_from, - bdata.base_edited_to, + bdata.target_base_changes, map_to_filtered=False, allele_uns_key=allele_df_keys[-1], ).reset_index(drop=True) diff --git a/bean/framework/Edit.py b/bean/framework/Edit.py index afd0c26..bdd60b2 100755 --- a/bean/framework/Edit.py +++ b/bean/framework/Edit.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterable +from typing import Iterable, Optional import numpy as np import re from ..utils.arithmetric import jaccard @@ -12,10 +12,10 @@ class Edit: def __init__( self, rel_pos: int, - ref_base: chr, - alt_base: chr, - chrom: str = None, - offset: int = None, + ref_base: str, + alt_base: str, + chrom: Optional[str] = None, + offset: Optional[int] = None, strand: int = 1, unique_identifier=None, ): diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 2ab6069..78cb31f 100755 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Collection, Iterable, List, Optional, Union, Sequence, Literal +from typing import Collection, Iterable, List, Optional, Union, Sequence, Literal, Dict import re import anndata as ad import numpy as np @@ -42,7 +42,7 @@ def _get_counts( res.columns = res.columns.map(lambda x: "{rep}_{spl}") guides = guides.join(res, how="left") guides = guides.fillna(0) - return guides + return guides def _get_edits( @@ -76,7 +76,7 @@ def _get_edits( this_edit.name = "{r}_{c}".format(r=rep, c=spl) edits = edits.join(this_edit, how="left") edits = edits.fillna(0) - return edits + return edits class ReporterScreen(Screen): @@ -85,7 +85,7 @@ def __init__( X=None, X_edit=None, X_bcmatch=None, - target_base_change: Optional[str] = None, + target_base_changes: Optional[str] = None, replicate_label: str = "rep", condition_label: str = "bin", tiling: Optional[bool] = None, @@ -132,12 +132,14 @@ def __init__( self.uns[k].loc[:, "aa_allele"] = self.uns[k].aa_allele.map( lambda s: CodingNoncodingAllele.from_str(s) ) - if target_base_change is not None: - if not re.fullmatch(r"[ACTG]>[ACTG]", target_base_change): + if target_base_changes is not None: + if not re.fullmatch( + r"([ACTG]>[ACTG])(,[ACTG]>[ACTG])*", target_base_changes + ): raise ValueError( - f"target_base_change {target_base_change} doesn't match the allowed base change. Feed in valid base change string ex) 'A>G', 'C>T'" + f"target_base_changes {target_base_changes} doesn't match the allowed base change. Feed in valid base change string ex) 'A>G', 'C>T'" ) - self.uns["target_base_change"] = target_base_change + self.uns["target_base_changes"] = target_base_changes if tiling is not None: self.uns["tiling"] = tiling @@ -157,55 +159,26 @@ def edit_tables(self): def allele_tables(self): return {k: self.uns[k] for k in self.uns.keys() if "allele" in k} - @property - def base_edited_from(self): - return self.uns["target_base_change"][0] + # @property + # def base_edited_from(self): + # return self.uns["target_base_change"][0] - @property - def base_edited_to(self): - return self.uns["target_base_change"][-1] + # @property + # def base_edited_to(self): + # return self.uns["target_base_change"][-1] @property - def target_base_change(self): - return self.uns["target_base_change"] + def target_base_changes(self): + try: + basechanges = self.uns["target_base_changes"] + except KeyError: + basechanges = self.uns["target_base_change"] + return {basechange[0]: basechange[-1] for basechange in basechanges.split(",")} @property def tiling(self): return self.uns["tiling"] - @classmethod - def from_file_paths( - cls, - reps: List[str] = None, - samples: List[str] = None, - guide_info_file_name: str = None, - guide_count_filenames: str = None, - guide_bcmatched_count_filenames: str = None, - edit_count_filenames: str = None, - ): - guide_info = pd.read_csv(guide_info_file_name).set_index("name") - guides = pd.DataFrame(index=(pd.read_csv(guide_info_file_name)["name"])) - guides_lenient = _get_counts(guide_count_filenames, guides, reps, samples) - edits_ag = _get_edits( - edit_count_filenames, guide_info, reps, samples, count_exact=False - ) - edits_exact = _get_edits(edit_count_filenames, guide_info, reps, samples) - if guide_bcmatched_count_filenames is not None: - guides_bcmatch = _get_counts( - guide_bcmatched_count_filenames, guides, reps, samples - ) - repscreen = cls( - guides_lenient, edits_exact, X_bcmatch=guides_bcmatch, guides=guide_info - ) - else: - repscreen = cls(guides_lenient, edits_exact, guides=guide_info) - repscreen.layers["edits_ag"] = edits_ag - repscreen.samples["replicate"] = np.repeat(reps, len(samples)) - repscreen.samples["sort"] = np.tile(samples, len(reps)) - repscreen.uns["replicates"] = reps - repscreen.uns["samples"] = samples - return repscreen - @classmethod def from_adata(cls, adata): X_bcmatch = adata.layers["X_bcmatch"] if "X_bcmatch" in adata.layers else None @@ -384,8 +357,7 @@ def _remove_zero_count(key): def get_edit_mat_from_uns( self, - ref_base: Optional[str] = None, - alt_base: Optional[str] = None, + target_base_edit: Optional[Dict[str, str]] = None, match_target_position: Optional[bool] = None, rel_pos_start=0, rel_pos_end=np.Inf, @@ -399,20 +371,15 @@ def get_edit_mat_from_uns( Args -- - ref_base: reference base of editing event to count. If not provided, default value in `.base_edited_from` is used. - alt_base: alternate base of editing event to count. If not provided, default value in `.base_edited_to` is used. - match_target_position: If `True`, edits with `.rel_pos` that matches `.guides[target_pos_col]` are counted. - If `False`, edits with `.rel_pos` in range [rel_pos_start, rel_pos_end) is counted. + target_base_edit: Dictionary of base edited from, base edited to. rel_pos_start: Position from which edits will be counted when `match_target_position` is `False`. rel_pos_end: Position until where edits will be counted when `match_target_position` is `False`. rel_pos_is_reporter: `rel_pos_start` and `rel_pos_end` is relative to the reporter position. If `False`, those are treated to be relative of spacer position. target_pos_col: Column name in `.guides` DataFrame that has target position information when `match_target_position` is `True`. edit_count_key: Key of the edit counts DataFrame to be used to count the edits (`.uns[edit_count_key]`). """ - if ref_base is None: - ref_base = self.base_edited_from - if alt_base is None: - alt_base = self.base_edited_to + if target_base_edit is None: + target_base_edit = self.target_base_changes if match_target_position is None: match_target_position = not self.tiling if edit_count_key not in self.uns or len(self.uns[edit_count_key]) == 0: @@ -431,7 +398,9 @@ def get_edit_mat_from_uns( edits["ref_base"] = edits.edit.map(lambda e: e.ref_base) edits["alt_base"] = edits.edit.map(lambda e: e.alt_base) edits = edits.loc[ - (edits.ref_base == ref_base) & (edits.alt_base == alt_base), : + (edits.ref_base.map(lambda r: r in target_base_edit)) + & (edits.ref_base.map(target_base_edit) == edits.alt_base), + :, ].reset_index() guide_len = self.guides.sequence.map(len) guide_name_to_idx = self.guides.reset_index().reset_index().set_index("name") @@ -475,11 +444,11 @@ def get_edit_mat_from_uns( def get_guide_edit_rate( self, normalize_by_editable_base: Optional[bool] = None, - edited_base: Optional[str] = None, + edited_bases: Optional[Union[List[str], str]] = None, editable_base_start=3, editable_base_end=8, bcmatch_thres=1, - prior_weight: float = None, + prior_weight: Optional[float] = None, return_result=False, count_layer="X_bcmatch", edit_layer="edits", @@ -492,18 +461,30 @@ def get_guide_edit_rate( prior weight to use when calculating posterior edit rate. unsorted_condition_label: Editing rate is calculated only for the samples that have this string in the sample index. """ - if edited_base is None: - edited_base = self.base_edited_from + if edited_bases is None: + edited_bases = list(self.target_base_changes.keys()) + if isinstance(edited_bases, str): + edited_bases = [edited_bases] + if normalize_by_editable_base is None: normalize_by_editable_base = self.tiling if self.layers[count_layer] is None or self.layers[edit_layer] is None: raise ValueError("edits or barcode matched guide counts not available.") num_targetable_sites = 1.0 if normalize_by_editable_base: - if edited_base not in ["A", "C", "T", "G"]: - raise ValueError("Specify the correct edited_base") - num_targetable_sites = self.guides.sequence.map( - lambda s: s[editable_base_start:editable_base_end].count(edited_base) + num_targetable_sites_all = [] + for edited_base in edited_bases: + if edited_base not in ["A", "C", "T", "G"]: + raise ValueError("Specify the correct edited_base") + num_targetable_sites_all.append( + self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count( + edited_base + ) + ) + ) + num_targetable_sites = pd.concat(num_targetable_sites_all, axis=1).sum( + axis=1 ) if unsorted_condition_label is not None: bulk_idx = np.where( @@ -538,11 +519,11 @@ def get_guide_edit_rate( def get_edit_rate( self, normalize_by_editable_base=None, - edited_base=None, + edited_base: Optional[Union[str, List[str]]] = None, editable_base_start=3, editable_base_end=8, bcmatch_thres=1, - prior_weight: float = None, + prior_weight: Optional[float] = None, return_result=False, count_layer="X_bcmatch", edit_layer="edits", @@ -554,17 +535,38 @@ def get_edit_rate( """ if normalize_by_editable_base is None: normalize_by_editable_base = self.tiling - if edited_base is None: - edited_base = self.base_edited_from if self.layers[count_layer] is None or self.layers[edit_layer] is None: raise ValueError("edits or barcode matched guide counts not available.") num_targetable_sites = 1.0 if normalize_by_editable_base: - if edited_base not in ["A", "C", "T", "G"]: - raise ValueError("Specify the correct edited_base") - num_targetable_sites = self.guides.sequence.map( - lambda s: s[editable_base_start:editable_base_end].count(edited_base) - ) + if edited_base is None: + edited_base = [ + basechange[0] for basechange in self.target_base_changes.split(",") + ] + for eb in edited_base: + if eb not in ["A", "C", "T", "G"]: + raise ValueError( + f"ReporterScreen target base change ({self.target_base_changes}) is invalid: {eb} as edited base." + ) + + if isinstance(edited_base, str): + num_targetable_sites = self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count( + edited_base + ) + ) + else: + num_targetable_sites_bases = [] + for eb in edited_base: + num_targetable_sites_bases.append( + self.guides.sequence.map( + lambda s: s[editable_base_start:editable_base_end].count(eb) + ) + ) + num_targetable_sites = pd.concat( + num_targetable_sites_bases, axis=1 + ).sum(axis=1) + if prior_weight is None: prior_weight = 1 n_edits = self.layers[edit_layer] @@ -714,8 +716,7 @@ def filter_allele_counts_by_pos( def filter_allele_counts_by_base( self, - ref_base: Union[List, str] = "A", - alt_base: Union[List, str] = "G", + target_base_edits: Dict[str, str], allele_uns_key="allele_counts", map_to_filtered=True, jaccard_threshold: float = 0.5, @@ -730,7 +731,7 @@ def filter_allele_counts_by_base( filtered_allele, filtered_edits = zip( *allele_count_df.allele.map( lambda a: filter_allele_by_base( - a, allowed_ref_base=ref_base, allowed_alt_base=alt_base + a, allowed_base_changes=target_base_edits ) ) ) @@ -926,7 +927,7 @@ def concat(screens: Sequence[ReporterScreen], *args, axis: Literal[0, 1] = 1, ** if axis == 0: for k in keys: - if k in ["target_base_change", "tiling", "sample_covariates"]: + if k in ["target_base_changes", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: @@ -936,7 +937,7 @@ def concat(screens: Sequence[ReporterScreen], *args, axis: Literal[0, 1] = 1, ** if axis == 1: # If combining multiple samples, edit/allele tables should be merged. for k in keys: - if k in ["target_base_change", "tiling", "sample_covariates"]: + if k in ["target_base_changes", "tiling", "sample_covariates"]: adata.uns[k] = screens[0].uns[k] continue elif "edit" not in k and "allele" not in k: diff --git a/bean/mapping/CRISPResso2Align.pyx b/bean/mapping/CRISPResso2Align.pyx index f2c17a6..1b3ed50 100755 --- a/bean/mapping/CRISPResso2Align.pyx +++ b/bean/mapping/CRISPResso2Align.pyx @@ -63,7 +63,7 @@ def read_matrix(path): @cython.boundscheck(False) @cython.nonecheck(False) -def global_align_base_editor(str pystr_seqj, str pystr_seqi, str ref_base, str alt_base, +def global_align_base_editor(str pystr_seqj, str pystr_seqi, dict target_base_edits, np.ndarray[DTYPE_INT, ndim=2] matrix, np.ndarray[DTYPE_INT,ndim=1] gap_incentive, int gap_open=-1, int gap_extend=-1, ): @@ -377,8 +377,14 @@ def global_align_base_editor(str pystr_seqj, str pystr_seqi, str ref_base, str a print('seqj: ' + str(seqj) + ' seqi: ' + str(seqi)) raise Exception('wtf4!:pointer: %i', i) # print('at end, currMatrix is ' + str(currMatrix)) - if not (ci == ref_base and cj == alt_base): + is_target_edit = False + for ref_base, alt_base in target_base_edits.items(): + if ci == ref_base and cj == alt_base: + is_target_edit = True + if not is_target_edit: align_counter += 1 + #if not (ci == ref_base and cj == alt_base): + # align_counter += 1 try: align_j = tmp_align_j[:align_counter].decode('UTF-8', 'strict') finally: diff --git a/bean/mapping/GuideEditCounter.py b/bean/mapping/GuideEditCounter.py index c2f1c2b..d7b2a1d 100755 --- a/bean/mapping/GuideEditCounter.py +++ b/bean/mapping/GuideEditCounter.py @@ -1,10 +1,10 @@ -from typing import Tuple +from typing import Tuple, Optional, Sequence, Union import gzip import logging from copy import deepcopy import os import sys -from os import path +import random import numpy as np import pandas as pd @@ -15,7 +15,6 @@ from ._supporting_fn import ( _base_edit_to_from, - _check_readname_match, _get_edited_allele_crispresso, _get_fastq_handle, _read_count_match, @@ -62,8 +61,9 @@ class GuideEditCounter: def __init__(self, **kwargs): self.R1_filename = kwargs["R1"] self.R2_filename = kwargs["R2"] - self.base_edited_from = kwargs["edited_base"] - self.base_edited_to = _base_edit_to_from(self.base_edited_from) + self.target_base_edits = { + k: _base_edit_to_from(k) for k in kwargs["edited_base"] + } self.min_average_read_quality = kwargs["min_average_read_quality"] self.min_single_bp_quality = kwargs["min_single_bp_quality"] @@ -103,7 +103,12 @@ def __init__(self, **kwargs): X_bcmatch=np.zeros((len(self.guides_info_df), 1)), guides=self.guides_info_df, samples=pd.DataFrame(index=[self.database_id]), - target_base_change=f"{self.base_edited_from}>{self.base_edited_to}", + target_base_changes=",".join( + [ + f"{base_edited_from}>{base_edited_to}" + for base_edited_from, base_edited_to in self.target_base_edits.items() + ] + ), tiling=kwargs["tiling"], ) self.screen.guides["guide_len"] = self.screen.guides.sequence.map(len) @@ -147,17 +152,25 @@ def __init__(self, **kwargs): if not self.objectify_allele: info(f"{self.name}: Storing allele as strings.") self.keep_intermediate = kwargs["keep_intermediate"] + self.id_number = random.randint(0, int(1e6)) self.semimatch = 0 self.bcmatch = 0 self.nomatch = 0 self.duplicate_match = 0 self.duplicate_match_wo_barcode = 0 + def mask_sequence(self, seq, revcomp=False): + seq_masked = seq + for base_edited_from, base_edited_to in self.target_base_edits.items(): + if revcomp: + base_edited_from = base_revcomp[base_edited_from] + base_edited_to = base_revcomp[base_edited_to] + seq_masked = seq_masked.replace(base_edited_from, base_edited_to) + return seq_masked + def masked_equal(self, seq1, seq2): """Tests if two sequences are equal, ignoring the allowed base transition.""" - return seq1.replace(self.base_edited_from, self.base_edited_to) == seq2.replace( - self.base_edited_from, self.base_edited_to - ) + return self.mask_sequence(seq1) == self.mask_sequence(seq2) def _set_sgRNA_df(self): """set gRNA info dataframe""" @@ -174,10 +187,10 @@ def _set_sgRNA_df(self): sgRNA_df = sgRNA_df.set_index("name") self.guides_info_df = sgRNA_df self.guides_info_df["masked_sequence"] = self.guides_info_df.sequence.map( - lambda s: s.replace(self.base_edited_from, self.base_edited_to) + self.mask_sequence ) self.guides_info_df["masked_barcode"] = self.guides_info_df.barcode.map( - lambda s: s.replace(self.base_edited_from, self.base_edited_to) + self.mask_sequence ) self.guide_lengths = sgRNA_df.sequence.map(lambda s: len(s)).unique() @@ -213,14 +226,13 @@ def get_counts(self): ) if self.count_reporter_edits or self.count_guide_edits: _write_alignment_matrix( - self.base_edited_from, - self.base_edited_to, + self.target_base_edits, self.output_dir + "/.aln_mat.txt", ) if self.count_only_bcmatched: # count X - self._get_guide_counts_bcmatch() + raise NotImplementedError else: # count both bc matched & unmatched guide counts self._get_guide_counts_bcmatch_semimatch() @@ -359,8 +371,7 @@ def _count_guide_edits( guide_edit_allele, score = _get_edited_allele_crispresso( ref_seq=ref_guide_seq, query_seq=read_guide_seq, - ref_base=self.base_edited_from, - alt_base=self.base_edited_to, + target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=0, strand=guide_strand, @@ -399,7 +410,9 @@ def _get_strand_offset_from_guide_index(self, guide_idx: int) -> Tuple[int, int] offset = 0 return (guide_strand, offset) - def _update_counted_allele(self, guide_idx: int, allele: Allele) -> None: + def _update_counted_allele( + self, guide_idx: int, allele: Union[Allele, str] + ) -> None: """Add allele count to self.guide_to_allele dictionary.""" if guide_idx in self.guide_to_allele.keys(): if allele in self.guide_to_allele[guide_idx].keys(): @@ -409,7 +422,9 @@ def _update_counted_allele(self, guide_idx: int, allele: Allele) -> None: else: self.guide_to_allele[guide_idx] = {allele: 1} - def _update_counted_guide_allele(self, guide_idx: int, allele: Allele) -> None: + def _update_counted_guide_allele( + self, guide_idx: int, allele: Union[Allele, str] + ) -> None: """Add allele count to self.guide_to_allele dictionary.""" if guide_idx in self.screen.uns["guide_edit_counts"].keys(): if allele in self.screen.uns["guide_edit_counts"][guide_idx].keys(): @@ -420,7 +435,10 @@ def _update_counted_guide_allele(self, guide_idx: int, allele: Allele) -> None: self.screen.uns["guide_edit_counts"][guide_idx] = {allele: 1} def _update_counted_allele_and_guideAllele( - self, guide_idx: int, allele: Allele, guide_allele: Allele + self, + guide_idx: int, + allele: Union[Allele, str], + guide_allele: Union[str, Allele], ) -> None: """Add count of (guide allele, reporter allele) combination to self.guide_reporter_allele dictionary.""" if guide_idx in self.guide_to_guide_reporter_allele.keys(): @@ -443,8 +461,8 @@ def _count_reporter_edits( R1_seq: str, R2_record: SeqIO.SeqRecord, R2_start: int = 0, - single_base_qual_cutoff: str = 30, - guide_allele: Allele = None, + single_base_qual_cutoff: int = 30, + guide_allele: Optional[Union[str, Allele]] = None, ): """ Count edits in a single read to save as allele. @@ -472,8 +490,7 @@ def _count_reporter_edits( allele, score = _get_edited_allele_crispresso( ref_seq=ref_reporter_seq, query_seq=read_reporter_seq, - ref_base=self.base_edited_from, - alt_base=self.base_edited_to, + target_base_edits=self.target_base_edits, aln_mat_path=self.output_dir + "/.aln_mat.txt", offset=offset, strand=guide_strand, @@ -499,7 +516,7 @@ def _count_reporter_edits( def _get_guide_counts_bcmatch_semimatch( self, bcmatch_layer="X_bcmatch", semimatch_layer="X_semimatch" ): - self.screen.layers[semimatch_layer] = np.zeros_like((self.screen.X)) + self.screen.layers[semimatch_layer] = np.zeros_like(self.screen.X) R1_iter, R2_iter = self._get_fastq_iterators( self.filtered_R1_filename, self.filtered_R2_filename ) @@ -614,12 +631,9 @@ def _write_guide_reporter_allele( def get_guide_seq(self, R1_seq, R2_seq, guide_length): """This can be edited by user based on the read construct.""" - # _seq_match = np.where(seq.replace(self.base_edited_from, self.base_edited_to) == self.screen.guides.masked_sequence)[0] if self.guide_end_seq == "": - guide_start_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find( - self.guide_start_seq.replace(self.base_edited_from, self.base_edited_to) + guide_start_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_start_seq) ) if guide_start_idx == -1: return None @@ -628,10 +642,8 @@ def get_guide_seq(self, R1_seq, R2_seq, guide_length): guide_start_idx = guide_start_idx + len(self.guide_start_seq) gRNA_seq = R1_seq[guide_start_idx : (guide_start_idx + guide_length)] else: - guide_end_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find( - self.guide_end_seq.replace(self.base_edited_from, self.base_edited_to) + guide_end_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_end_seq) ) if guide_end_idx == -1: return None @@ -641,15 +653,13 @@ def get_guide_seq(self, R1_seq, R2_seq, guide_length): return None if len(gRNA_seq) != guide_length else gRNA_seq - def get_guide_seq_qual(self, R1_record: SeqIO.SeqRecord, guide_length): + def get_guide_seq_qual( + self, R1_record: SeqIO.SeqRecord, guide_length: int + ) -> Tuple[str, Sequence]: R1_seq = R1_record.seq - guide_start_idx = R1_seq.replace( - self.base_edited_from, self.base_edited_to - ).find(self.guide_start_seq.replace(self.base_edited_from, self.base_edited_to)) - if guide_start_idx == -1: - return None, None - if guide_start_idx + guide_length >= len(R1_seq): - return None, None + guide_start_idx = self.mask_sequence(R1_seq).find( + self.mask_sequence(self.guide_start_seq) + ) guide_start_idx = guide_start_idx + len(self.guide_start_seq) seq = R1_record[guide_start_idx : (guide_start_idx + guide_length)] return (str(seq.seq), seq.letter_annotations["phred_quality"]) @@ -670,13 +680,8 @@ def get_reporter_seq_qual(self, R2_record: SeqIO.SeqRecord, R2_start=0): def get_barcode(self, R1_seq, R2_seq): if self.barcode_start_seq != "": - barcode_start_idx = R2_seq.replace( - base_revcomp[self.base_edited_from], base_revcomp[self.base_edited_to] - ).find( - revcomp(self.barcode_start_seq).replace( - base_revcomp[self.base_edited_from], - base_revcomp[self.base_edited_to], - ) + barcode_start_idx = self.mask_sequence(R2_seq, revcomp=True).find( + self.mask_sequence(revcomp(self.barcode_start_seq), revcomp=True) ) if barcode_start_idx == -1: return -1, "" @@ -698,12 +703,10 @@ def _match_read_to_sgRNA_bcmatch_semimatch(self, R1_seq: str, R2_seq: str): continue _seq_match = np.where( - seq.replace(self.base_edited_from, self.base_edited_to) - == self.screen.guides.masked_sequence + self.mask_sequence(seq) == self.screen.guides.masked_sequence )[0] _bc_match = np.where( - guide_barcode.replace(self.base_edited_from, self.base_edited_to) - == self.screen.guides.masked_barcode + self.mask_sequence(guide_barcode) == self.screen.guides.masked_barcode )[0] bc_match_idx = np.append( @@ -733,7 +736,7 @@ def get_gRNA_barcode(self, R1_seq, R2_seq): def _get_fastq_handle( self, - out_type: str = None, + out_type: Optional[str] = None, ): assert out_type in { "semimatch", @@ -794,7 +797,7 @@ def _check_names_filter_fastq(self, filter_by_qual=False): def _check_readname_match_and_filter_quality( self, R1_iter, R2_iter, filter_by_qual=False - ) -> Tuple[int, int]: + ) -> int: R1_filtered = gzip.open(self.filtered_R1_filename, "wt+") R2_filtered = gzip.open(self.filtered_R2_filename, "wt+") diff --git a/bean/mapping/_supporting_fn.py b/bean/mapping/_supporting_fn.py index 9859cc7..62c0f79 100755 --- a/bean/mapping/_supporting_fn.py +++ b/bean/mapping/_supporting_fn.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Dict, Optional, Tuple import subprocess as sb import numpy as np import pandas as pd @@ -13,7 +13,7 @@ class InputFileError(Exception): pass -def _base_edit_to_from(start_base: chr = "A"): +def _base_edit_to_from(start_base: str = "A"): try: base_map = {"A": "G", "C": "T"} except KeyError: @@ -138,32 +138,32 @@ def _get_edited_allele( if ref_nt == sample_nt: continue else: - edit = Edit(i - start_pos, ref_nt, sample_nt, offset, strand=strand) + edit = Edit(i - start_pos, ref_nt, sample_nt, offset=offset, strand=strand) allele.add(edit) return allele def _write_alignment_matrix( - ref_base: str, alt_base: str, path, allow_complementary=False + target_base_edits: Dict[str, str], path, allow_complementary=False ): """ Writes base substitution matrix """ bases = ["A", "C", "T", "G"] - if not (ref_base in bases and alt_base in bases): - raise ValueError( - "Specified ref base '{}' or alt base '{}' isn't valid".format( - ref_base, alt_base + for ref_base, alt_base in target_base_edits.items(): + if ref_base not in bases or alt_base not in bases: + raise ValueError( + f"Specified ref base '{ref_base}' or alt base '{alt_base}' isn't valid" ) - ) mat = np.ones((4, 4), dtype=int) * -4 np.fill_diagonal(mat, 5) aln_df = pd.DataFrame(mat, index=bases, columns=bases) - aln_df.loc[ref_base, alt_base] = 0 - if allow_complementary: - comp_map = {"A": "T", "C": "G", "T": "A", "G": "C"} - aln_df.loc[comp_map[ref_base], comp_map[alt_base]] = 0 + for ref_base, alt_base in target_base_edits.items(): + aln_df.loc[ref_base, alt_base] = 0 + if allow_complementary: + comp_map = {"A": "T", "C": "G", "T": "A", "G": "C"} + aln_df.loc[comp_map[ref_base], comp_map[alt_base]] = 0 aln_df.to_csv(path, sep=" ") @@ -174,8 +174,8 @@ def _get_allele_from_alignment( strand: int, start_pos: int, end_pos: int, - chrom: str = None, - positionwise_quality: np.ndarray = None, + chrom: Optional[str] = None, + positionwise_quality: Optional[np.ndarray] = None, quality_thres: float = -1, ): assert len(ref_aligned) == len(query_aligned) @@ -222,26 +222,24 @@ def _get_allele_from_alignment( def _get_edited_allele_crispresso( ref_seq: str, query_seq: str, - ref_base: str, - alt_base: str, + target_base_edits: Dict[str, str], aln_mat_path: str, offset: int, strand: int = 1, - chrom: str = None, + chrom: Optional[str] = None, start_pos: int = 0, end_pos: int = 100, - positionwise_quality: np.ndarray = None, + positionwise_quality: Optional[np.ndarray] = None, quality_thres: float = 30, objectify_allele=True, -): +) -> Tuple[Union[Allele, str], float]: aln_matrix = read_matrix(aln_mat_path) assert strand in [-1, +1] gap_incentive = np.zeros(len(ref_seq) + 1, dtype=int) query_aligned, ref_aligned, score = global_align_base_editor( query_seq, ref_seq, - ref_base, - alt_base, + target_base_edits, aln_matrix, gap_incentive, gap_open=-20, diff --git a/bean/mapping/utils.py b/bean/mapping/utils.py index bdb7b96..5064aa4 100755 --- a/bean/mapping/utils.py +++ b/bean/mapping/utils.py @@ -276,13 +276,25 @@ def _check_arguments(args, info_logger, warn_logger, error_logger): _check_file(args.sgRNA_filename) # Edited base should be one of A/C/T/G + edited_bases = [] if args.edited_base.upper() not in ["A", "C", "T", "G"]: - raise ValueError( - f"The edited base should be one of A/C/T/G, {args.edited_base} provided." - ) - - edited_base = args.edited_base.upper() - info_logger(f"Using specified edited base: {edited_base}") + if "," in args.edited_base: + bases = args.edited_base.split(",") + for base in bases: + if base not in ["A", "C", "T", "G"]: + raise ValueError( + f"The edited base should be one of A/C/T/G, {args.edited_base} provided." + ) + edited_bases.append(base.upper()) + else: + raise ValueError( + f"The edited base should be one of A/C/T/G, {args.edited_base} provided." + ) + if len(edited_bases) == 0: + args.edited_base = [args.edited_base.upper()] + else: + args.edited_base = edited_bases + info_logger(f"Using specified edited base: {args.edited_base}") info_logger( f"Using guide barcode length {args.guide_bc_len}, guide start '{args.guide_start_seq}'" ) diff --git a/bean/notebooks/profile_editing_preference.ipynb b/bean/notebooks/profile_editing_preference.ipynb index 2965673..05dae33 100755 --- a/bean/notebooks/profile_editing_preference.ipynb +++ b/bean/notebooks/profile_editing_preference.ipynb @@ -29,15 +29,17 @@ "import numpy as np\n", "import pandas as pd\n", "from tqdm.auto import tqdm\n", - "\n", + "import logging\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", "import seaborn as sns\n", "import logomaker\n", "\n", "import bean as be\n", "from bean import Edit\n", - "import bean.plotting.editing_patterns" + "import bean.plotting.editing_patterns\n", + "logging.getLogger(\"matplotlib.font_manager\").setLevel(logging.ERROR)" ] }, { @@ -614,334 +616,12 @@ "_, profile_df = be.pl.editing_patterns.plot_by_pos_behive(\n", " cedit_rates_df,\n", " cdata_bulk,\n", - " target_basechange=cdata_bulk.uns['target_base_change'], \n", - " nonref_base_changes = None\n", + " target_basechanges=cdata_bulk.target_base_changes, \n", ")\n", "profile_df.to_csv(f\"{output_prefix}_behive_like_profile.csv\")\n", "if save_fig: plt.savefig(f\"{output_prefix}_behive_like_profile.pdf\", bbox_inches = 'tight')" ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "ag_rates = profile_df[cdata_bulk.uns['target_base_change']]\n", - "window_sum_maxpos = ag_rates.rolling(max_editing_window_length).sum().argmax()\n", - "window_end = ag_rates.index[window_sum_maxpos]\n", - "window_start = window_end - max_editing_window_length + 1" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "window_start" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "window_end" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.patches as patches" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
base_changeA>CA>TA>GC>TC>G
spacer_pos
10.0002510.0013440.2312310.0021830.001298
20.0001750.0010100.2307910.0021540.001042
30.0002950.0011370.2755710.0029420.001145
40.0001250.0010940.3366870.0026050.001785
50.0001440.0013370.3551460.0045230.001151
60.0001020.0011290.3567090.0042310.001516
70.0002110.0008650.2928590.0037610.000748
80.0001840.0016770.2429740.0027550.000939
90.0005040.0013050.2157720.0021440.000936
100.0001400.0004230.1729200.0019330.000815
110.0001660.0012350.1653280.0020580.000775
120.0001380.0007240.1292270.0011660.001041
130.0002570.0008500.1312650.0021940.001379
140.0000920.0008130.1045900.0013440.000267
150.0001430.0007800.1109870.0015960.000162
160.0003070.0006380.0963970.0022890.000732
170.0004140.0004540.1068740.0018190.000649
180.0000700.0005770.0953030.0011820.000128
190.0001660.0008520.0923960.0006970.000892
200.0001360.0010390.0859230.0012610.000188
\n", - "
" - ], - "text/plain": [ - "base_change A>C A>T A>G C>T C>G\n", - "spacer_pos \n", - "1 0.000251 0.001344 0.231231 0.002183 0.001298\n", - "2 0.000175 0.001010 0.230791 0.002154 0.001042\n", - "3 0.000295 0.001137 0.275571 0.002942 0.001145\n", - "4 0.000125 0.001094 0.336687 0.002605 0.001785\n", - "5 0.000144 0.001337 0.355146 0.004523 0.001151\n", - "6 0.000102 0.001129 0.356709 0.004231 0.001516\n", - "7 0.000211 0.000865 0.292859 0.003761 0.000748\n", - "8 0.000184 0.001677 0.242974 0.002755 0.000939\n", - "9 0.000504 0.001305 0.215772 0.002144 0.000936\n", - "10 0.000140 0.000423 0.172920 0.001933 0.000815\n", - "11 0.000166 0.001235 0.165328 0.002058 0.000775\n", - "12 0.000138 0.000724 0.129227 0.001166 0.001041\n", - "13 0.000257 0.000850 0.131265 0.002194 0.001379\n", - "14 0.000092 0.000813 0.104590 0.001344 0.000267\n", - "15 0.000143 0.000780 0.110987 0.001596 0.000162\n", - "16 0.000307 0.000638 0.096397 0.002289 0.000732\n", - "17 0.000414 0.000454 0.106874 0.001819 0.000649\n", - "18 0.000070 0.000577 0.095303 0.001182 0.000128\n", - "19 0.000166 0.000852 0.092396 0.000697 0.000892\n", - "20 0.000136 0.001039 0.085923 0.001261 0.000188" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "profile_df" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "12" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "max(profile_df.index)-window_end" - ] - }, { "cell_type": "code", "execution_count": 19, @@ -962,20 +642,26 @@ "ax, _ = be.pl.editing_patterns.plot_by_pos_behive(\n", " cedit_rates_df,\n", " cdata_bulk,\n", - " target_basechange=cdata_bulk.uns['target_base_change'], \n", - " nonref_base_changes = None,\n", + " target_basechanges=cdata_bulk.target_base_changes, \n", " normalize=True\n", ")\n", - "ax.add_patch(\n", - " patches.Rectangle(\n", - " (2, window_start-1),\n", - " 1,\n", - " max_editing_window_length,\n", - " edgecolor=\"black\",\n", - " fill=False,\n", - " lw=2\n", + "\n", + "for i, _ax in enumerate(ax):\n", + " ag_rates = profile_df[cdata_bulk.uns['target_base_changes'].split(\",\")[i]]\n", + " window_sum_maxpos = ag_rates.rolling(max_editing_window_length).sum().argmax()\n", + " window_end = ag_rates.index[window_sum_maxpos]\n", + " window_start = window_end - max_editing_window_length + 1\n", + " _ax.add_patch(\n", + " patches.Rectangle(\n", + " (2, window_start-1),\n", + " 1,\n", + " max_editing_window_length,\n", + " edgecolor=\"black\",\n", + " fill=False,\n", + " lw=2\n", + " )\n", " )\n", - ")\n", + "\n", "if save_fig: plt.savefig(f\"{output_prefix}_behive_like_profile_normed.pdf\", bbox_inches = 'tight')" ] }, diff --git a/bean/plotting/editing_patterns.py b/bean/plotting/editing_patterns.py index 6454438..21e7bef 100755 --- a/bean/plotting/editing_patterns.py +++ b/bean/plotting/editing_patterns.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Optional, Sequence +from typing import Optional, Sequence, Dict, List, Union import re from ..framework.Edit import Edit import numpy as np @@ -17,13 +17,10 @@ def _add_absent_edits( bdata, guide: str, edit_tbl: pd.DataFrame, - edited_base: str = "A", - target_alt: str = "G", + target_base_edits: Dict[str, str], ): """If edit is not observed in editable position, add into the edit rate table.""" - editable_positions = np.where( - (np.array(list(bdata.guides.loc[guide, "reporter"])) == edited_base) - )[0] + if guide not in edit_tbl.guide.tolist(): observed_rel_pos = [] else: @@ -31,16 +28,20 @@ def _add_absent_edits( observed_rel_pos = edited_db.rel_pos.tolist() edits = [] positions = [] - for editable_pos in editable_positions: - if editable_pos not in observed_rel_pos: - edits.append( - Edit( - editable_pos, - edited_base, - target_alt, + for edited_base, target_alt in target_base_edits.items(): + editable_positions = np.where( + (np.array(list(bdata.guides.loc[guide, "reporter"])) == edited_base) + )[0] + for editable_pos in editable_positions: + if editable_pos not in observed_rel_pos: + edits.append( + Edit( + editable_pos, + edited_base, + target_alt, + ) ) - ) - positions.append(editable_pos) + positions.append(editable_pos) edit_tbl = pd.concat( [ edit_tbl, @@ -88,8 +89,7 @@ def get_edit_rates( bdata, guide, edit_rates_agg, - edited_base=bdata.base_edited_from, - target_alt=bdata.base_edited_to, + target_base_edits=bdata.target_base_changes, ) if adjust_spacer_pos: @@ -180,7 +180,7 @@ def _get_norm_rates_df( bdata, edit_rates_df=None, edit_count_key="edit_counts", - base_changes: Sequence[str] = None, + base_changes: Optional[Sequence[str]] = None, ): change_by_pos = pd.pivot( edit_rates_df[["base_change", "spacer_pos", "rep_mean"]] @@ -211,7 +211,7 @@ def _get_norm_rates_df( return norm_rate.astype(float)[base_changes] # _reduced -def _get_possible_changes_from_target_base(target_basechange: str): +def _get_possible_changes_from_target_base(target_basechange: str) -> List[str]: """Return base changes strings (ex. A>C) for given the same reference base to edit from to the input target_basechange. ex) returns ['A>C', 'A>T'] given input 'A>G'.""" if not re.fullmatch(r"[ATCG]>[ATCG]", target_basechange): raise ValueError( @@ -226,8 +226,7 @@ def plot_by_pos_behive( norm_rates_df: Optional[pd.DataFrame] = None, bdata=None, edit_count_key: str = "edit_counts", - target_basechange="A>G", - nonref_base_changes: Sequence[str] = None, + target_basechanges: Optional[Dict[str, str]] = None, normalize=False, ): """Plot position-wise editing pattern as in BE-Hive. @@ -239,11 +238,21 @@ def plot_by_pos_behive( normalize: Normalize the editing rate relative to the max editing rate by position (as 100). """ - if not re.fullmatch(r"[ATCG]>[ATCG]", target_basechange): - raise ValueError( - f"Input argument {target_basechange} doesn't conform to the valid format ex. 'A>G'" - ) - if nonref_base_changes is None: + if target_basechanges is None: + if bdata is None: + raise ValueError("Target base change not provided.") + target_basechange = bdata.target_base_changes + fig, axes = plt.subplots( + 1, + len(target_basechanges.keys()), + figsize=(3 * len(target_basechanges.keys()), 7), + ) + if not isinstance(axes, np.ndarray): + axes = np.ndarray([axes]) + dfs = [] + for i, (edited_base, alt_base) in enumerate(target_basechanges.items()): # type: ignore + target_basechange = f"{edited_base}>{alt_base}" + print(target_basechange) if target_basechange == "A>G": nonref_base_changes = ["C>T", "C>G"] elif target_basechange == "C>T": @@ -251,71 +260,74 @@ def plot_by_pos_behive( else: print("No non-ref base changes specified. not drawing them") nonref_base_changes = [] - ref_other_changes = _get_possible_changes_from_target_base(target_basechange) - - df_to_draw = _get_norm_rates_df( - bdata, - norm_rates_df, - edit_count_key, - base_changes=ref_other_changes - + [ - target_basechange, - ] - + nonref_base_changes, - ) - fig, ax = plt.subplots(figsize=(3, 7)) - vmax = df_to_draw.max().max() - if normalize: - df_to_draw = df_to_draw / vmax * 100 - vmax = 100 + ref_other_changes = _get_possible_changes_from_target_base(target_basechange) - target_data = df_to_draw.copy() - target_data.loc[:, target_data.columns != target_basechange] = np.nan - sns.heatmap( - target_data, - ax=ax, - annot=True, - cmap="Reds", - vmax=vmax, - cbar=False, - vmin=-0.03, - fmt=".0f" if normalize else ".2g", - annot_kws={"fontsize": 8}, - ) + df_to_draw = _get_norm_rates_df( + bdata, + norm_rates_df, + edit_count_key, + base_changes=ref_other_changes + + [ + target_basechange, + ] + + nonref_base_changes, + ) - ref_data = df_to_draw.copy() - ref_data.loc[ - :, - ~ref_data.columns.isin(ref_other_changes), - ] = np.nan - sns.heatmap( - ref_data, - ax=ax, - annot=True, - cmap="Blues", - vmax=vmax, - cbar=False, - fmt=".1g", - vmin=-0.03, - annot_kws={"fontsize": 8}, - ) + vmax = df_to_draw.max().max() + if normalize: + df_to_draw = df_to_draw / vmax * 100 + vmax = 100 + + target_data = df_to_draw.copy() + target_data.loc[:, target_data.columns != target_basechange] = np.nan + sns.heatmap( + target_data, + ax=axes[i], + annot=True, + cmap="Reds", + vmax=vmax, + cbar=False, + vmin=-0.03, + fmt=".0f" if normalize else ".2g", + annot_kws={"fontsize": 8}, + ) - nonref_data = df_to_draw.copy() - nonref_data.loc[:, ~nonref_data.columns.isin(nonref_base_changes)] = np.nan - sns.heatmap( - nonref_data, - ax=ax, - annot=True, - cmap="Greys", - vmax=vmax, - cbar=False, - fmt=".1g", - vmin=-0.03, - annot_kws={"fontsize": 8}, - ) - ax.set_ylabel("Protospacer position") - return ax, df_to_draw + ref_data = df_to_draw.copy() + ref_data.loc[ + :, + ~ref_data.columns.isin(ref_other_changes), + ] = np.nan + sns.heatmap( + ref_data, + ax=axes[i], + annot=True, + cmap="Blues", + vmax=vmax, + cbar=False, + fmt=".1g", + vmin=-0.03, + annot_kws={"fontsize": 8}, + ) + + nonref_data = df_to_draw.copy() + nonref_data.loc[:, ~nonref_data.columns.isin(nonref_base_changes)] = np.nan + sns.heatmap( + nonref_data, + ax=axes[i], + annot=True, + cmap="Greys", + vmax=vmax, + cbar=False, + fmt=".1g", + vmin=-0.03, + annot_kws={"fontsize": 8}, + ) + axes[i].set_ylabel("Protospacer position") + dfs.append(df_to_draw) + df = pd.concat(dfs, axis=1) + df = df.loc[:, ~df.columns.duplicated()].copy() + return axes, df def get_position_by_pam_rates(bdata, edit_rates_df: pd.DataFrame, pam_col="5-nt PAM"): diff --git a/setup.py b/setup.py index 157f024..c7ee60f 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="crispr-bean", - version="1.1.1", + version="1.2.0", python_requires=">=3.8.0", author="Jayoung Ryu", author_email="jayoung_ryu@g.harvard.edu", diff --git a/tests/data/var_mini_screen.h5ad b/tests/data/var_mini_screen.h5ad index 3044b5f..bfd3abf 100755 Binary files a/tests/data/var_mini_screen.h5ad and b/tests/data/var_mini_screen.h5ad differ diff --git a/tests/test_count.py b/tests/test_count.py index ff5168c..176084d 100755 --- a/tests/test_count.py +++ b/tests/test_count.py @@ -45,6 +45,19 @@ def test_count_samples(): @pytest.mark.order(105) +def test_count_samples_dual(): + cmd = "bean count-samples -i tests/data/sample_list.csv -b A,C -f tests/data/test_guide_info.csv -o tests/test_res/var/ -r --guide-start-seq=GGAAAGGACGAAACACCG" + try: + subprocess.check_output( + cmd, + shell=True, + universal_newlines=True, + ) + except subprocess.CalledProcessError as exc: + raise exc + + +@pytest.mark.order(106) def test_count_samples_bcstart(): cmd = "bean count-samples -i tests/data/sample_list.csv -b A -f tests/data/test_guide_info.csv -o tests/test_res/var2/ -r --barcode-start-seq=GGAA" try: diff --git a/tests/test_run.py b/tests/test_run.py index 8d98653..ca21f03 100755 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -108,7 +108,7 @@ def test_run_tiling_with_wo_negctrl_noacc(): @pytest.mark.order(423) def test_run_tiling_with_wo_negctrl_uniform(): - cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_noindels_A.G_translated_prop0.1_0.3 --control-guide-tag None --repguide-mask None --n-iter 10" + cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --control-guide-tag None --repguide-mask None --n-iter 10" try: subprocess.check_output( cmd, @@ -134,7 +134,7 @@ def test_run_tiling_negctrl_allelekey(): @pytest.mark.order(425) def test_run_tiling_with_negctrl_noacc(): - cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --allele-df-key allele_counts_spacer_0_19_noindels_A.G_translated_prop0.1_0.3 --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" + cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" try: subprocess.check_output( cmd, @@ -147,7 +147,7 @@ def test_run_tiling_with_negctrl_noacc(): @pytest.mark.order(426) def test_run_tiling_with_negctrl_uniform(): - cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --allele-df-key allele_counts_spacer_0_19_noindels_A.G_translated_prop0.1_0.3 --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" + cmd = "bean run sorting tiling tests/data/tiling_mini_screen_annotated.h5ad -o tests/test_res/tiling/ --uniform-edit --fit-negctrl --negctrl-col strand --negctrl-col-value neg --control-guide-tag neg --repguide-mask None --n-iter 10" try: subprocess.check_output( cmd,