diff --git a/vatools/vcf_expression_annotator.py b/vatools/vcf_expression_annotator.py index 65d36e8..1818feb 100644 --- a/vatools/vcf_expression_annotator.py +++ b/vatools/vcf_expression_annotator.py @@ -54,7 +54,7 @@ def to_array(dictionary): def parse_expression_file(args, vcf_reader, vcf_writer): if args.format == 'stringtie' and args.mode == 'transcript': - df_all = read_gtf(args.expression_file) + df_all = read_gtf(args.expression_file, usecols=['reference_id', 'transcript_id', 'TPM', 'feature']) df = df_all[df_all["feature"] == "transcript"] id_column = resolve_stringtie_id_column(args, df.columns.values) else: @@ -106,24 +106,23 @@ def create_vcf_writer(args, vcf_reader): output_file = args.output_vcf return vcfpy.Writer.from_path(output_file, new_header) -def add_expressions(entry, is_multi_sample, sample_name, df, items, tag, id_column, expression_column, ignore_ensembl_id_version, missing_expressions_count, entry_count): - expressions = {} - for item in items: - entry_count += 1 - if ignore_ensembl_id_version: - subset = df.loc[df['transcript_without_version'] == re.sub(r'\.[0-9]+$', '', item)] - else: - subset = df.loc[df[id_column] == item] - if len(subset) > 0: - expressions[item] = subset[expression_column].sum() - else: - missing_expressions_count += 1 +def add_expressions(entry, is_multi_sample, sample_name, df, items, tag, id_column, expression_column, ignore_ensembl_id_version, missing_expressions_count): + subset = None + if ignore_ensembl_id_version: + items_without_version = [re.sub(r'\.[0-9]+$', '', item) for item in items] + subset = df[df['transcript_without_version'].isin(items_without_version)] + for item in items: + subset.loc[subset['transcript_without_version'] == re.sub(r'\.[0-9]+$', "", item), id_column] = item + else: + subset = df[df[id_column].isin(items)] + expressions = subset[[id_column, expression_column]].groupby(id_column).sum().to_dict()[expression_column] + missing_expressions_count += (len(items) - len(expressions)) if is_multi_sample: entry.FORMAT += [tag] entry.call_for_sample[sample_name].data[tag] = to_array(expressions) else: entry.add_format(tag, to_array(expressions)) - return (entry, missing_expressions_count, entry_count) + return (entry, missing_expressions_count) def define_parser(): parser = argparse.ArgumentParser( @@ -216,11 +215,13 @@ def main(args_input = sys.argv[1:]): if args.mode == 'gene': genes = list(genes) if len(genes) > 0: - (entry, missing_expressions_count, entry_count) = add_expressions(entry, is_multi_sample, args.sample_name, df, genes, 'GX', id_column, expression_column, args.ignore_ensembl_id_version, missing_expressions_count, entry_count) + (entry, missing_expressions_count) = add_expressions(entry, is_multi_sample, args.sample_name, df, genes, 'GX', id_column, expression_column, args.ignore_ensembl_id_version, missing_expressions_count) + entry_count += len(genes) elif args.mode == 'transcript': transcript_ids = list(transcript_ids) if len(transcript_ids) > 0: - (entry, missing_expressions_count, entry_count) = add_expressions(entry, is_multi_sample, args.sample_name, df, transcript_ids, 'TX', id_column, expression_column, args.ignore_ensembl_id_version, missing_expressions_count, entry_count) + (entry, missing_expressions_count) = add_expressions(entry, is_multi_sample, args.sample_name, df, transcript_ids, 'TX', id_column, expression_column, args.ignore_ensembl_id_version, missing_expressions_count) + entry_count += len(transcript_ids) vcf_writer.write_record(entry) vcf_reader.close()