diff --git a/AFQ/viz/altair.py b/AFQ/viz/altair.py new file mode 100644 index 000000000..9175f65ff --- /dev/null +++ b/AFQ/viz/altair.py @@ -0,0 +1,194 @@ +from AFQ.viz.utils import COLOR_DICT, FORMAL_BUNDLE_NAMES +import numpy as np +import scipy.stats as stats +import altair as alt + + +def altair_color_dict(names_to_include=None): + """ + Given a list of bundle names, return a dictionary of colors for each + Formatted for Altair. + """ + altair_cd = dict(COLOR_DICT.copy()) + for key in list(altair_cd.keys()): + value = altair_cd[key] + if (names_to_include is None) or (key in names_to_include): + altair_cd[key] = ( + f"rgb({int(value[0]*255)}," + f"{int(value[1]*255)}," + f"{int(value[2]*255)})") + else: + del altair_cd[key] + return altair_cd + + +def combined_profiles_df_to_altair_df( + profiles, + tissue_properties=['dti_fa', 'dti_md']): + """ + Given a profiles dataframe that is combined + from many subjects, return a dataframe formatted for Altair. + """ + profiles = profiles.copy() + if 'dki_md' in tissue_properties: + profiles.dki_md = profiles.dki_md * 1000. + if 'dti_md' in tissue_properties: + profiles.dti_md = profiles.dti_md * 1000. + + id_vars = ['tractID', 'nodeID', 'subjectID'] + if 'sessionID' in profiles.columns: + id_vars.append('sessionID') + + profiles = profiles.melt( + id_vars=id_vars, + value_vars=tissue_properties, + var_name='TP', + value_name='Value') + + # Function to calculate 95% CI using a normal distribution + def calculate_95CI(x): + ci = stats.norm.interval( + 0.95, loc=np.mean(x), scale=np.std(x) / np.sqrt(len(x))) + return ci + + # Group by 'tractID', 'nodeID', 'TP' and apply the aggregation functions + profiles = profiles.groupby(['tractID', 'nodeID', 'TP'])['Value'].agg( + mean='mean', + CI_lower=lambda x: calculate_95CI(x)[0], + CI_upper=lambda x: calculate_95CI(x)[1], + IQR_lower=lambda x: x.quantile(0.25), + IQR_upper=lambda x: x.quantile(0.75) + ).reset_index() + + def get_hemi(cc): + if cc == "L": + return "Left" + elif cc == "R": + return "Right" + else: + return "Callosal" + + def get_bname(s): + if s.startswith("Left "): + return s[5:] + elif s.startswith("Right "): + return s[6:] + return s + + def formal_tp(tp_name): + return tp_name.upper().replace("_", " ") + + profiles["Hemi"] = profiles["tractID"].apply(lambda x: get_hemi(x[-1])) + profiles["Bundle Name"] = profiles["tractID"].replace( + FORMAL_BUNDLE_NAMES).apply(get_bname) + profiles["TP"] = profiles["TP"].apply(formal_tp) + + return profiles + + +def altair_df_to_chart(profiles, position_domain=(20, 80), + column_count=1, font_size=20, + line_size=10, row_label_angle=90, + legend_line_size=5, + **kwargs): + """ + Given a dataframe formatted for Altair, probably from + combined_profiles_df_to_altair_df, return a chart. + + Example + ------- + call_results = results[results.Hemi == "Callosal"] + stand_results = results[results.Hemi != "Callosal"] + prof_chart = altair_df_to_chart(call_results) + prof_chart.save("supp_chart_call.png", dpi=300) + prof_chart = altair_df_to_chart(stand_results, + column_count=2, color="Hemi") + prof_chart.save("supp_chart_stand.png", dpi=300) + """ + this_cd = altair_color_dict(profiles.tractID.unique()) + this_cd = {FORMAL_BUNDLE_NAMES.get( + key, key): value for key, value in this_cd.items()} + + alt.data_transformers.disable_max_rows() + + profiles = profiles[np.logical_and( + profiles.nodeID >= position_domain[0], + profiles.nodeID < position_domain[1])] + + tp_units = { + "DKI AWF": "", + "DKI FA": "", + "DKI MD": " (µm²/ms)", + "DKI MK": ""} + + row_charts = [] + for jj, b_name in enumerate(profiles["Bundle Name"].unique()): + row_dataframe = profiles[profiles["Bundle Name"] == b_name] + charts = [] + for ii, tp in enumerate(profiles.TP.unique()): + this_dataframe = row_dataframe[row_dataframe.TP == tp] + if jj == 0: + title_name = tp + tp_units[tp] + else: + title_name = "" + if ii == 0: + y_axis_title = b_name + else: + y_axis_title = "" + if jj == len(profiles["Bundle Name"].unique()) - 1: + x_axis_title = "Position (%)" + useXlab = True + else: + x_axis_title = "" + useXlab = False + y_kwargs = dict( + scale=alt.Scale(zero=False), + # axis=alt.Axis(title=""), + title=y_axis_title + ) + x_kwargs = dict( + axis=alt.Axis(title=x_axis_title, labels=useXlab)) + prof_chart = alt.Chart( + this_dataframe, title=title_name).mark_line( + size=line_size).encode( + y=alt.Y('mean', **y_kwargs), + x=alt.X('nodeID', **x_kwargs), + **kwargs) + prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line( + size=line_size, opacity=0.5, strokeDash=[1, 1]).encode( + y=alt.Y('IQR_lower', **y_kwargs), + x=alt.X('nodeID', **x_kwargs), + **kwargs) + prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line( + size=line_size, opacity=0.5, strokeDash=[1, 1]).encode( + y=alt.Y('IQR_upper', **y_kwargs), + x=alt.X('nodeID', **x_kwargs), + **kwargs) + prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line( + size=line_size, opacity=0.5).encode( + y=alt.Y('CI_lower', **y_kwargs), + x=alt.X('nodeID', **x_kwargs), + **kwargs) + prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line( + size=line_size, opacity=0.5).encode( + y=alt.Y('CI_upper', **y_kwargs), + x=alt.X('nodeID', **x_kwargs), + **kwargs) + charts.append(prof_chart) + row_charts.append(alt.HConcatChart(hconcat=charts)) + return alt.VConcatChart(vconcat=row_charts).configure_axis( + labelFontSize=font_size, + titleFontSize=font_size, + labelLimit=0 + ).configure_legend( + labelFontSize=font_size, + titleFontSize=font_size, + titleLimit=0, + labelLimit=0, + columns=column_count, + symbolStrokeWidth=legend_line_size * 10, + symbolSize=legend_line_size * 100, + orient='right' + ).configure_title( + fontSize=font_size + ) diff --git a/AFQ/viz/plot.py b/AFQ/viz/plot.py index abaddf17a..b92ae6fc7 100644 --- a/AFQ/viz/plot.py +++ b/AFQ/viz/plot.py @@ -175,7 +175,7 @@ def plot_line(self, bundle, x, y, data, ylabel, ylim, n_boot, alpha, ax.set_ylabel(ylabel, fontsize=vut.medium_font) ax.set_ylim(ylim) - def format(self, disable_x=True): + def format(self, disable_x=True, disable_y=True): ''' Call this functions once after all axes that you intend to use have been plotted on. Automatically formats brain axes. @@ -193,14 +193,14 @@ def format(self, disable_x=True): axis='x', which='major', labelsize=vut.small_font) if not self.on_grid[i, j]: self.axes[i, j].axis("off") - if self. twinning: + if self.twinning: if j != self.size[1] - 1 and self.on_grid[i][j + 1]: self.axes[i, j].set_yticklabels([]) self.axes[i, j].set_ylabel("") self.axes[i, j].set_xticklabels([]) self.axes[i, j].set_xlabel("") else: - if j != 0 and self.on_grid[i][j - 1]: + if disable_y and (j != 0 and self.on_grid[i][j - 1]): self.axes[i, j].set_yticklabels([]) self.axes[i, j].set_ylabel("") if disable_x or (i != self.size[0] - 1 diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index b6b9dd865..7c7c82e68 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -6,6 +6,10 @@ import imageio as io from PIL import Image, ImageChops +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import matplotlib.transforms as mtransforms + import nibabel as nib import dipy.tracking.streamlinespeed as dps from dipy.align import resample @@ -64,7 +68,15 @@ "UNC_R": tableau_20[17], "UF_R": tableau_20[17], "ARC_L": tableau_20[18], "AF_L": tableau_20[18], "ARC_R": tableau_20[19], "AF_R": tableau_20[19], - "median": tableau_20[6]}) + "median": tableau_20[6], + "Orbital": (0.2, 0.13, 0.53), # Paul Tol's palette for callosal bundles + "AntFrontal": (0.07, 0.47, 0.2), + "SupFrontal": (0.27, 0.67, 0.6), + "Motor": (0.53, 0.8, 0.93), + "SupParietal": (0.87, 0.8, 0.47), + "PostParietal": (0.8, 0.4, 0.47), + "Occipital": (0.67, 0.27, 0.6), + "Temporal": (0.53, 0.13, 0.33)}) POSITIONS = OrderedDict({ "ATR_L": (1, 0), "ATR_R": (1, 4), "C_L": (1, 0), "C_R": (1, 4), @@ -104,6 +116,135 @@ "pARC_L": ("Coronal", "Back"), "pARC_R": ("Coronal", "Back")} +FORMAL_BUNDLE_NAMES = { + "ATR_L": "Left Anterior Thalamic", + "ATR_R": "Right Anterior Thalamic", + "CST_L": "Left Corticospinal", + "CST_R": "Right Corticospinal", + "CGC_L": "Left Cingulum Cingulate", + "CGC_R": "Right Cingulum Cingulate", + "FP": "Forceps Major", + "FA": "Forceps Minor", + "IFO_L": "Left Inferior Fronto-Occipital", + "IFO_R": "Right Inferior Fronto-Occipital", + "ILF_L": "Left Inferior Longitudinal", + "ILF_R": "Right Inferior Longitudinal", + "SLF_L": "Left Superior Longitudinal", + "SLF_R": "Right Superior Longitudinal", + "UNC_L": "Left Uncinate", + "UNC_R": "Right Uncinate", + "ARC_L": "Left Arcuate", + "ARC_R": "Right Arcuate", + "VOF_L": "Left Vertical Occipital", + "VOF_R": "Right Vertical Occipital", + "pARC_L": "Left Posterior Arcuate", + "pARC_R": "Right Posterior Arcuate" +} + + +class PanelFigure(): + """ + Super useful class for organizing existing images + into subplots using matplotlib + """ + + def __init__(self, num_rows, num_cols, width, height): + """ + Initialize PanelFigure. + + Parameters + ---------- + num_rows : int + Number of rows in figure + num_cols : int + Number of columns in figure + width : int + Width of figure in inches + height : int + Height of figure in inches + """ + self.fig = plt.figure(figsize=(width, height)) + self.grid = plt.GridSpec(num_rows, num_cols, hspace=0, wspace=0) + self.subplot_count = 0 + + def add_img(self, fname, x_coord, y_coord, reduct_count=1, + subplot_label_pos=(0.1, 1.0), legend=None, legend_kwargs={}, + add_panel_label=True, panel_label_font_size="medium"): + """ + Add image from fname into figure as a panel. + + Parameters + ---------- + fname : str + path to image file to add to subplot + x_coord : int or slice + x coordinate(s) of subplot in matlotlib figure + y_coord : int or slice + y coordinate(s) of subplot in matlotlib figure + reduct_count : int + number of times to trim whitespace around image + Default: 1 + subplot_label_pos : tuple of floats + position of subplot label + Default: (0.1, 1.0) + legend : dict + dictionary of legend items, where keys are labels + and values are colors + Default: None + legend_kwargs : dict + ADditional arguments for matplotlib's legend method + add_panel_label : bool + Whether or not to add a panel label to the subplot + Default: True + panel_label_font_size : str + Font size of panel label + Default: "medium" + """ + ax = self.fig.add_subplot(self.grid[y_coord, x_coord]) + im1 = Image.open(fname) + for _ in range(reduct_count): + im1 = trim(im1) + if legend is not None: + patches = [] + for value, color in legend.items(): + patches.append(mpatches.Patch( + color=color, + label=value)) + ax.legend(handles=patches, borderaxespad=0., **legend_kwargs) + if add_panel_label: + trans = mtransforms.ScaledTranslation( + 10 / 72, -5 / 72, self.fig.dpi_scale_trans) + ax.text( + subplot_label_pos[0], subplot_label_pos[1], + f"{chr(65+self.subplot_count)})", + transform=ax.transAxes + trans, + fontsize=panel_label_font_size, verticalalignment="top", + fontfamily='serif', + bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0)) + ax.imshow(np.asarray(im1), aspect=1) + ax.axis('off') + self.subplot_count = self.subplot_count + 1 + return ax + + def format_and_save_figure(self, fname, trim_final=True): + """ + Format and save figure to fname. + Parameters + ---------- + fname : str + Path to save figure to + trim : bool + Whether or not to trim whitespace around figure. + Default: True + """ + self.fig.tight_layout() + self.fig.savefig(fname, dpi=300) + if trim_final: + im1 = Image.open(fname) + im1 = trim(im1) + im1.save(fname) + + def get_eye(view, direc): direc = direc.lower() view = view.lower()