Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor consistency_report.py #127

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 64 additions & 123 deletions truvari/consistency_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,89 +5,62 @@
CHROM:POS ID REF ALT
"""
# pylint: disable=consider-using-f-string
import io
import gzip
import bisect
import argparse
import itertools

from collections import defaultdict, namedtuple, Counter
from collections import defaultdict, Counter


def parse_vcf(fn):
"""
Simple vcf reader
"""
VCFLine = namedtuple("VCFline",
"CHROM POS ID REF ALT QUAL FILT INFO FORMAT SAMPLES")
if fn.endswith(".gz"):
fh = io.TextIOWrapper(gzip.open(fn))
else:
fh = open(fn, 'r') # pylint: disable=consider-using-with
for line in fh:
if line.startswith("#"):
continue
data = line.strip().split('\t')
yield VCFLine(*data[:9], SAMPLES=data[9:])

openfn = gzip.open if fn.endswith(".gz") else open
with openfn(fn, "r") as fh:
for line in fh:
if line.startswith("#"):
continue
# Only keep the first 5 fields, and use them as key
yield "\t".join(line.split("\t")[:5])

class hash_list(list):

def read_files(allVCFs):
"""
A list that's hashable
"""
Read all VCFs and mark all (union) calls for their presence

def __hash__(self):
"""
Only method needed
"""
return hash(" ".join(self))
For each call, we will use an integer and mark a call's presence
at the index of the VCF file as 1.

For example, if we have 3 VCFs, and the call is present in only the
first VCF, then the integer will be `0 | (1 << 2)`, that is `0b100`,
where the first bit marks that the call appears in the first VCF.

def entry_key(entry):
"""
Turn a vcf entry into a key
Returns:
all_presence: dict of all calls, with the integers as values.
The integers are the bitwise OR of all VCFs that have the call.
n_calls_per_vcf: list of the number of calls in each VCF
"""
key = "%s:%s %s %s %s" % (entry.CHROM, entry.POS,
entry.ID, entry.REF, str(entry.ALT))
return key
n_vcfs = len(allVCFs)
# Initialize the integer to 0
all_presence = defaultdict(lambda: 0)
n_calls_per_vcf = []
for i, vcf in enumerate(allVCFs):
n_calls_per_vcf.append(0)
for key in parse_vcf(vcf):
n_calls_per_vcf[i] += 1
all_presence[key] |= (1 << (n_vcfs - i - 1))

# We don't care about the calls anyway for stats
# Then this becomes a list of integers, which will save memory
return all_presence.values(), n_calls_per_vcf

def read_files(allVCFs):
"""
Load all vcfs and count their number of entries
"""
# call exists in which files
call_lookup = defaultdict(list)
# total number of calls in a file
file_abscnt = defaultdict(float)
for vcfn in allVCFs:
v = parse_vcf(vcfn)
# disallow intra vcf duplicates
seen = {}
for entry in v:
key = entry_key(entry)
if key in seen:
continue
seen[key] = True
bisect.insort(call_lookup[key], vcfn)
file_abscnt[vcfn] += 1

return call_lookup, file_abscnt


def create_file_intersections(allVCFs):
"""
Generate all possible intersections of vcfs
"""
count_lookup = {}
combo_gen = (x for l in range(1, len(allVCFs) + 1)
for x in itertools.combinations(allVCFs, l))
for files_combo in combo_gen:
files_combo = hash_list(files_combo)
files_combo.sort()
count_lookup[files_combo] = 0
return count_lookup
def get_shared_calls(all_presence, n):
"""Get n shared calls from the all_presence dictionary"""
return sum(
1 for presence in all_presence
if bin(presence).count("1") == n
)


def parse_args(args):
Expand All @@ -100,66 +73,45 @@ def parse_args(args):
return args


def make_consistency_overlap(count_lookup, file_abscnt, allVCFs):
"""
1 I want to make a key "101010" so that they can be viz'd easier
2 - I want to sort the count_lookup by their value so that we output them in order
The group
"""
all_consistency = Counter()
all_overlap = []

for combo, value in sorted(count_lookup.items(), key=lambda i: (i[1], i[0]), reverse=True):
# There are no calls here, so we just ignore it... But I think I want to keep that information
cur_data = []
if value == 0:
continue

my_group = ["0"] * len(allVCFs)
m_cnt = 0
for j in combo:
my_group[allVCFs.index(j)] = "1"
m_cnt += 1
cur_data.append("".join(my_group))
cur_data.append(value)
cur_data.append(["0%"] * len(allVCFs))

all_consistency[m_cnt] += value
for fkey in combo:
if file_abscnt[fkey] > 0:
cur_data[-1][allVCFs.index(fkey)] = "%.2f%%" % (
count_lookup[combo] / file_abscnt[fkey] * 100)
all_overlap.append(cur_data)
return all_consistency, all_overlap


def write_report(total_unique_calls, allVCFs, file_abscnt, all_consistency, all_overlap):
def write_report(allVCFs, all_presence, n_calls_per_vcf):
"""
Write the report
"""
total_unique_calls = len(all_presence)
n_vcfs = len(allVCFs)

print("#\n# Total %d calls across %d VCFs\n#" %
(total_unique_calls, len(allVCFs)))
(total_unique_calls, n_vcfs))
print("#File\tNumCalls")
for fn in allVCFs:
print("%s\t%d" % (fn, file_abscnt[fn]))
for i, fn in enumerate(allVCFs):
print("%s\t%d" % (fn, n_calls_per_vcf[i]))

print("#\n# Summary of consistency\n#")
print("#VCFs\tCalls\tPct")

for i in sorted(all_consistency.keys(), reverse=True):
for i in reversed(range(n_vcfs)):
shared_calls = get_shared_calls(all_presence, i + 1)
print("%d\t%d\t%.2f%%" % (
i, all_consistency[i], all_consistency[i] / total_unique_calls * 100))
i + 1,
shared_calls,
shared_calls / total_unique_calls * 100
))

print("#\n# Breakdown of VCFs' consistency\n#")
print("#Group\tTotal\tTotalPct\tPctOfFileCalls")
for my_group, value, combo in all_overlap:
c_text = ""
pos = 0
for i in my_group:
c_text += combo[pos] + " "
pos += 1

all_overlap = Counter(all_presence)
for group, ncalls in sorted(all_overlap.items(), key=lambda x: (-x[1], x[0])):
# '0b100' -> '100'
group = bin(group)[2:].rjust(n_vcfs, '0')
c_text = " ".join(
"%.2f%%" % (ncalls / n_calls_per_vcf[i] * 100)
if group[i] == "1"
else "0%"
for i in range(n_vcfs)
)
print("%s\t%d\t%.2f%%\t%s" %
(my_group, value, value / total_unique_calls * 100, c_text))
(group, ncalls, ncalls / total_unique_calls * 100, c_text))


def consistency_main(args):
Expand All @@ -168,17 +120,6 @@ def consistency_main(args):
"""
args = parse_args(args)

call_lookup, file_abscnt = read_files(args.allVCFs)

count_lookup = create_file_intersections(args.allVCFs)

for key in call_lookup:
count_lookup[hash_list(call_lookup[key])] += 1

all_consistency, all_overlap = make_consistency_overlap(
count_lookup, file_abscnt, args.allVCFs)

total_unique_calls = sum(all_consistency.values())
all_presence, n_calls_per_vcf = read_files(args.allVCFs)

write_report(total_unique_calls, args.allVCFs,
file_abscnt, all_consistency, all_overlap)
write_report(args.allVCFs, all_presence, n_calls_per_vcf)