Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add some utilities for working with altair and pyafq #1049

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions AFQ/viz/altair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from AFQ.viz.utils import COLOR_DICT, FORMAL_BUNDLE_NAMES
import numpy as np
import scipy.stats as stats
import altair as alt


def altair_color_dict(names_to_include=None):
"""
Given a list of bundle names, return a dictionary of colors for each
Formatted for Altair.
"""
altair_cd = dict(COLOR_DICT.copy())
for key in list(altair_cd.keys()):
value = altair_cd[key]
if (names_to_include is None) or (key in names_to_include):
altair_cd[key] = (
f"rgb({int(value[0]*255)},"
f"{int(value[1]*255)},"
f"{int(value[2]*255)})")
else:
del altair_cd[key]
return altair_cd


def combined_profiles_df_to_altair_df(
profiles,
tissue_properties=['dti_fa', 'dti_md']):
"""
Given a profiles dataframe that is combined
from many subjects, return a dataframe formatted for Altair.
"""
profiles = profiles.copy()
if 'dki_md' in tissue_properties:
profiles.dki_md = profiles.dki_md * 1000.
if 'dti_md' in tissue_properties:
profiles.dti_md = profiles.dti_md * 1000.

id_vars = ['tractID', 'nodeID', 'subjectID']
if 'sessionID' in profiles.columns:
id_vars.append('sessionID')

profiles = profiles.melt(
id_vars=id_vars,
value_vars=tissue_properties,
var_name='TP',
value_name='Value')

# Function to calculate 95% CI using a normal distribution
def calculate_95CI(x):
ci = stats.norm.interval(
0.95, loc=np.mean(x), scale=np.std(x) / np.sqrt(len(x)))
return ci

# Group by 'tractID', 'nodeID', 'TP' and apply the aggregation functions
profiles = profiles.groupby(['tractID', 'nodeID', 'TP'])['Value'].agg(
mean='mean',
CI_lower=lambda x: calculate_95CI(x)[0],
CI_upper=lambda x: calculate_95CI(x)[1],
IQR_lower=lambda x: x.quantile(0.25),
IQR_upper=lambda x: x.quantile(0.75)
).reset_index()

def get_hemi(cc):
if cc == "L":
return "Left"
elif cc == "R":
return "Right"
else:
return "Callosal"

def get_bname(s):
if s.startswith("Left "):
return s[5:]
elif s.startswith("Right "):
return s[6:]
return s

def formal_tp(tp_name):
return tp_name.upper().replace("_", " ")

profiles["Hemi"] = profiles["tractID"].apply(lambda x: get_hemi(x[-1]))
profiles["Bundle Name"] = profiles["tractID"].replace(
FORMAL_BUNDLE_NAMES).apply(get_bname)
profiles["TP"] = profiles["TP"].apply(formal_tp)

return profiles


def altair_df_to_chart(profiles, position_domain=(20, 80),
column_count=1, font_size=20,
line_size=10, row_label_angle=90,
legend_line_size=5,
**kwargs):
"""
Given a dataframe formatted for Altair, probably from
combined_profiles_df_to_altair_df, return a chart.

Example
-------
call_results = results[results.Hemi == "Callosal"]
stand_results = results[results.Hemi != "Callosal"]
prof_chart = altair_df_to_chart(call_results)
prof_chart.save("supp_chart_call.png", dpi=300)
prof_chart = altair_df_to_chart(stand_results,
column_count=2, color="Hemi")
prof_chart.save("supp_chart_stand.png", dpi=300)
"""
this_cd = altair_color_dict(profiles.tractID.unique())
this_cd = {FORMAL_BUNDLE_NAMES.get(
key, key): value for key, value in this_cd.items()}

alt.data_transformers.disable_max_rows()

profiles = profiles[np.logical_and(
profiles.nodeID >= position_domain[0],
profiles.nodeID < position_domain[1])]

tp_units = {
"DKI AWF": "",
"DKI FA": "",
"DKI MD": " (µm²/ms)",
"DKI MK": ""}

row_charts = []
for jj, b_name in enumerate(profiles["Bundle Name"].unique()):
row_dataframe = profiles[profiles["Bundle Name"] == b_name]
charts = []
for ii, tp in enumerate(profiles.TP.unique()):
this_dataframe = row_dataframe[row_dataframe.TP == tp]
if jj == 0:
title_name = tp + tp_units[tp]
else:
title_name = ""
if ii == 0:
y_axis_title = b_name
else:
y_axis_title = ""
if jj == len(profiles["Bundle Name"].unique()) - 1:
x_axis_title = "Position (%)"
useXlab = True
else:
x_axis_title = ""
useXlab = False
y_kwargs = dict(
scale=alt.Scale(zero=False),
# axis=alt.Axis(title=""),
title=y_axis_title
)
x_kwargs = dict(
axis=alt.Axis(title=x_axis_title, labels=useXlab))
prof_chart = alt.Chart(
this_dataframe, title=title_name).mark_line(
size=line_size).encode(
y=alt.Y('mean', **y_kwargs),
x=alt.X('nodeID', **x_kwargs),
**kwargs)
prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
size=line_size, opacity=0.5, strokeDash=[1, 1]).encode(
y=alt.Y('IQR_lower', **y_kwargs),
x=alt.X('nodeID', **x_kwargs),
**kwargs)
prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
size=line_size, opacity=0.5, strokeDash=[1, 1]).encode(
y=alt.Y('IQR_upper', **y_kwargs),
x=alt.X('nodeID', **x_kwargs),
**kwargs)
prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
size=line_size, opacity=0.5).encode(
y=alt.Y('CI_lower', **y_kwargs),
x=alt.X('nodeID', **x_kwargs),
**kwargs)
prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
size=line_size, opacity=0.5).encode(
y=alt.Y('CI_upper', **y_kwargs),
x=alt.X('nodeID', **x_kwargs),
**kwargs)
charts.append(prof_chart)
row_charts.append(alt.HConcatChart(hconcat=charts))
return alt.VConcatChart(vconcat=row_charts).configure_axis(
labelFontSize=font_size,
titleFontSize=font_size,
labelLimit=0
).configure_legend(
labelFontSize=font_size,
titleFontSize=font_size,
titleLimit=0,
labelLimit=0,
columns=column_count,
symbolStrokeWidth=legend_line_size * 10,
symbolSize=legend_line_size * 100,
orient='right'
).configure_title(
fontSize=font_size
)
6 changes: 3 additions & 3 deletions AFQ/viz/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def plot_line(self, bundle, x, y, data, ylabel, ylim, n_boot, alpha,
ax.set_ylabel(ylabel, fontsize=vut.medium_font)
ax.set_ylim(ylim)

def format(self, disable_x=True):
def format(self, disable_x=True, disable_y=True):
'''
Call this functions once after all axes that you intend to use
have been plotted on. Automatically formats brain axes.
Expand All @@ -193,14 +193,14 @@ def format(self, disable_x=True):
axis='x', which='major', labelsize=vut.small_font)
if not self.on_grid[i, j]:
self.axes[i, j].axis("off")
if self. twinning:
if self.twinning:
if j != self.size[1] - 1 and self.on_grid[i][j + 1]:
self.axes[i, j].set_yticklabels([])
self.axes[i, j].set_ylabel("")
self.axes[i, j].set_xticklabels([])
self.axes[i, j].set_xlabel("")
else:
if j != 0 and self.on_grid[i][j - 1]:
if disable_y and (j != 0 and self.on_grid[i][j - 1]):
self.axes[i, j].set_yticklabels([])
self.axes[i, j].set_ylabel("")
if disable_x or (i != self.size[0] - 1
Expand Down
143 changes: 142 additions & 1 deletion AFQ/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import imageio as io
from PIL import Image, ImageChops

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.transforms as mtransforms

import nibabel as nib
import dipy.tracking.streamlinespeed as dps
from dipy.align import resample
Expand Down Expand Up @@ -64,7 +68,15 @@
"UNC_R": tableau_20[17], "UF_R": tableau_20[17],
"ARC_L": tableau_20[18], "AF_L": tableau_20[18],
"ARC_R": tableau_20[19], "AF_R": tableau_20[19],
"median": tableau_20[6]})
"median": tableau_20[6],
"Orbital": (0.2, 0.13, 0.53), # Paul Tol's palette for callosal bundles
"AntFrontal": (0.07, 0.47, 0.2),
"SupFrontal": (0.27, 0.67, 0.6),
"Motor": (0.53, 0.8, 0.93),
"SupParietal": (0.87, 0.8, 0.47),
"PostParietal": (0.8, 0.4, 0.47),
"Occipital": (0.67, 0.27, 0.6),
"Temporal": (0.53, 0.13, 0.33)})

POSITIONS = OrderedDict({
"ATR_L": (1, 0), "ATR_R": (1, 4), "C_L": (1, 0), "C_R": (1, 4),
Expand Down Expand Up @@ -104,6 +116,135 @@
"pARC_L": ("Coronal", "Back"), "pARC_R": ("Coronal", "Back")}


FORMAL_BUNDLE_NAMES = {
"ATR_L": "Left Anterior Thalamic",
"ATR_R": "Right Anterior Thalamic",
"CST_L": "Left Corticospinal",
"CST_R": "Right Corticospinal",
"CGC_L": "Left Cingulum Cingulate",
"CGC_R": "Right Cingulum Cingulate",
"FP": "Forceps Major",
"FA": "Forceps Minor",
"IFO_L": "Left Inferior Fronto-Occipital",
"IFO_R": "Right Inferior Fronto-Occipital",
"ILF_L": "Left Inferior Longitudinal",
"ILF_R": "Right Inferior Longitudinal",
"SLF_L": "Left Superior Longitudinal",
"SLF_R": "Right Superior Longitudinal",
"UNC_L": "Left Uncinate",
"UNC_R": "Right Uncinate",
"ARC_L": "Left Arcuate",
"ARC_R": "Right Arcuate",
"VOF_L": "Left Vertical Occipital",
"VOF_R": "Right Vertical Occipital",
"pARC_L": "Left Posterior Arcuate",
"pARC_R": "Right Posterior Arcuate"
}


class PanelFigure():
"""
Super useful class for organizing existing images
into subplots using matplotlib
"""

def __init__(self, num_rows, num_cols, width, height):
"""
Initialize PanelFigure.

Parameters
----------
num_rows : int
Number of rows in figure
num_cols : int
Number of columns in figure
width : int
Width of figure in inches
height : int
Height of figure in inches
"""
self.fig = plt.figure(figsize=(width, height))
self.grid = plt.GridSpec(num_rows, num_cols, hspace=0, wspace=0)
self.subplot_count = 0

def add_img(self, fname, x_coord, y_coord, reduct_count=1,
subplot_label_pos=(0.1, 1.0), legend=None, legend_kwargs={},
add_panel_label=True, panel_label_font_size="medium"):
"""
Add image from fname into figure as a panel.

Parameters
----------
fname : str
path to image file to add to subplot
x_coord : int or slice
x coordinate(s) of subplot in matlotlib figure
y_coord : int or slice
y coordinate(s) of subplot in matlotlib figure
reduct_count : int
number of times to trim whitespace around image
Default: 1
subplot_label_pos : tuple of floats
position of subplot label
Default: (0.1, 1.0)
legend : dict
dictionary of legend items, where keys are labels
and values are colors
Default: None
legend_kwargs : dict
ADditional arguments for matplotlib's legend method
add_panel_label : bool
Whether or not to add a panel label to the subplot
Default: True
panel_label_font_size : str
Font size of panel label
Default: "medium"
"""
ax = self.fig.add_subplot(self.grid[y_coord, x_coord])
im1 = Image.open(fname)
for _ in range(reduct_count):
im1 = trim(im1)
if legend is not None:
patches = []
for value, color in legend.items():
patches.append(mpatches.Patch(
color=color,
label=value))
ax.legend(handles=patches, borderaxespad=0., **legend_kwargs)
if add_panel_label:
trans = mtransforms.ScaledTranslation(
10 / 72, -5 / 72, self.fig.dpi_scale_trans)
ax.text(
subplot_label_pos[0], subplot_label_pos[1],
f"{chr(65+self.subplot_count)})",
transform=ax.transAxes + trans,
fontsize=panel_label_font_size, verticalalignment="top",
fontfamily='serif',
bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
ax.imshow(np.asarray(im1), aspect=1)
ax.axis('off')
self.subplot_count = self.subplot_count + 1
return ax

def format_and_save_figure(self, fname, trim_final=True):
"""
Format and save figure to fname.
Parameters
----------
fname : str
Path to save figure to
trim : bool
Whether or not to trim whitespace around figure.
Default: True
"""
self.fig.tight_layout()
self.fig.savefig(fname, dpi=300)
if trim_final:
im1 = Image.open(fname)
im1 = trim(im1)
im1.save(fname)


def get_eye(view, direc):
direc = direc.lower()
view = view.lower()
Expand Down
Loading