Skip to content

Commit

Permalink
handled duplicate labels in arctic3d-localise plots (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgiulini authored Jul 27, 2023
1 parent 54f82a8 commit 1f59e8d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
45 changes: 43 additions & 2 deletions src/arctic3d/modules/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,54 @@ def create_barplot(cluster, sorted_dict, max_labels=70):
return


def remove_duplicate_labels(labels, values):
"""
Remove duplicate labels.
Parameters
----------
labels : list
list of labels
values : list
list of values
Returns
-------
new_labels : list
list of labels without duplicates
new_values : list
list of values without duplicates
"""
new_labels, new_values = [], []
for n in range(len(labels)):
if labels[n] not in new_labels:
new_labels.append(labels[n])
new_values.append(values[n])
else:
log.info(f"Detected duplicate label {labels[n]}.")
return new_labels, new_values


def create_barplotly(cluster, sorted_dict, format, scale, max_labels=25):
"""
Create horizontal barplot using plotly.
Parameters
----------
cluster : int or str
cluster ID
sorted_dict : dict
dictionary of sorted entries
format : str
format of the output figure
scale : float
scale of the output figure
max_labels : int
maximum number of labels to include
"""
labels = shorten_labels(list(sorted_dict.keys())[-max_labels:])
values = list(sorted_dict.values())[-max_labels:]
tmp_labels = shorten_labels(list(sorted_dict.keys())[-max_labels:])
tmp_values = list(sorted_dict.values())[-max_labels:]
labels, values = remove_duplicate_labels(tmp_labels, tmp_values)
fig = go.Figure(go.Bar(x=values, y=labels, orientation="h"))
fig_fname = f"cluster_{cluster}.html"
fig.write_html(fig_fname)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
create_output_folder,
output_pdb,
read_residues_probs,
remove_duplicate_labels,
setup_output_folder,
shorten_labels,
write_dict,
Expand Down Expand Up @@ -192,3 +193,14 @@ def test_shorten_labels(example_B_labels):
"positive regulation of transcription by RNA polymerase...",
]
assert exp_shortened_labels == obs_shortened_labels


def test_remove_duplicate_labels():
"""Test remove_duplicate_labels."""
tmp_labels = ["Polymerase...", "Polymerase...", "Polymerase..."]
tmp_values = [2, 3, 1]
exp_labels = ["Polymerase..."]
exp_values = [2]
obs_labels, obs_values = remove_duplicate_labels(tmp_labels, tmp_values)
assert exp_labels == obs_labels
assert exp_values == obs_values

0 comments on commit 1f59e8d

Please sign in to comment.