Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added create figure #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 130 additions & 13 deletions plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
threshold=None, bg_map=None,
bg_on_stat=False,
alpha='auto',
vmax=None, symmetric_cbar="auto",
vmin=None, vmax=None,
cbar='sequential', # or'diverging'
symmetric_cbar="auto",
figsize=None,
labels=None, label_cpal=None,
labels=None, label_col=None, label_cpal=None,
mask=None, mask_lenient=None,
**kwargs):

import numpy as np
Expand Down Expand Up @@ -53,6 +56,16 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
antialiased=False,
color='white')

# where mask is indices of nodes to include:
if mask is not None:
cmask = np.zeros(len(coords))
cmask[mask] = 1
cutoff = 2 # include triangles in cortex only if ALL nodes in mask
if mask_lenient: # include triangles in cortex if ANY are in mask
cutoff = 0
fmask = np.where(cmask[faces].sum(axis=1) > cutoff)[0]


# If depth_map and/or stat_map are provided, map these onto the surface
# set_facecolors function of Poly3DCollection is used as passing the
# facecolors argument to plot_trisurf does not seem to work
Expand All @@ -78,11 +91,18 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
stat_map_data = stat_map
stat_map_faces = np.mean(stat_map_data[faces], axis=1)

# Call _get_plot_stat_map_params to derive symmetric vmin and vmax
# And colorbar limits depending on symmetric_cbar settings
cbar_vmin, cbar_vmax, vmin, vmax = \
_get_plot_stat_map_params(stat_map_faces, vmax,
symmetric_cbar, kwargs)
if cbar is 'diverging':
print cbar
# Call _get_plot_stat_map_params to derive symmetric vmin and vmax
# And colorbar limits depending on symmetric_cbar settings
cbar_vmin, cbar_vmax, vmin, vmax = \
_get_plot_stat_map_params(stat_map_faces, vmax,
symmetric_cbar, kwargs)
if cbar is 'sequential':
if vmin is None:
vmin = stat_map_data.min()
if vmax is None:
vmax = stat_map_data.max()

if threshold is not None:
kept_indices = np.where(abs(stat_map_faces) >= threshold)[0]
Expand All @@ -96,9 +116,16 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
stat_map_faces = stat_map_faces - vmin
stat_map_faces = stat_map_faces / (vmax-vmin)
if bg_on_stat:
face_colors = cmap(stat_map_faces) * face_colors
if mask is not None:
face_colors[fmask] = cmap(stat_map_faces)[fmask] * face_colors[fmask]
else:
face_colors = cmap(stat_map_faces) * face_colors
else:
face_colors = cmap(stat_map_faces)
if mask is not None:
face_colors[fmask] = cmap(stat_map_faces)[fmask]

else:
face_colors = cmap(stat_map_faces)

if labels is not None:
'''
Expand All @@ -111,6 +138,8 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
valid color names from http://xkcd.com/color/rgb/
'''
if label_cpal is not None:
if label_col is not None:
raise ValueError("Don't use label_cpal and label_col together.")
if type(label_cpal) == str:
cpal = sns.color_palette(label_cpal, len(labels))
if type(label_cpal) == list:
Expand All @@ -121,12 +150,18 @@ def plot_surf_stat_map(coords, faces, stat_map=None,
except:
cpal = sns.xkcd_palette(label_cpal)




for n_label, label in enumerate(labels):
for n_face, face in enumerate(faces):
count = len(set(face).intersection(set(label)))
if (count > 0) & (count < 3):
if label_cpal is None:
face_colors[n_face,0:3] = sns.xkcd_palette(["black"])[0]
if label_col is not None:
face_colors[n_face,0:3] = sns.xkcd_palette([label_col])[0]
else:
face_colors[n_face,0:3] = sns.xkcd_palette(["black"])[0]
else:
face_colors[n_face,0:3] = cpal[n_label]

Expand Down Expand Up @@ -290,14 +325,96 @@ def crop_img(fig, margin=10):

kept = {'rows':[], 'cols':[]}
for row in range(img.shape[0]):
if len(set(np.ndarray.flatten(img[row,:,:]))) > 3:
if len(set(np.ndarray.flatten(img[row,:,:]))) > 1:
kept['rows'].append(row)
for col in range(img.shape[1]):
if len(set(np.ndarray.flatten(img[:,col,:]))) > 3:
if len(set(np.ndarray.flatten(img[:,col,:]))) > 1:
kept['cols'].append(col)

if margin:
return img[min(kept['rows'])-margin:max(kept['rows'])+margin,
min(kept['cols'])-margin:max(kept['cols'])+margin]
else:
return img[kept['rows']][:,kept['cols']]
return img[kept['rows']][:,kept['cols']]




def create_fig(data=None, labels=None, label_col=None,
hemi=None, surf='pial',
sulc=True, alpha='auto',
cmap='cubehelix', cpal='bright', cbar=False,
dmin=None, dmax=None,
mask=None):

import nibabel as nib, numpy as np
import matplotlib.pyplot as plt, matplotlib as mpl
from IPython.core.display import Image, display
import os

fsDir = '/afs/cbs.mpg.de/projects/mar004_lsd-lemon-preproc/freesurfer'
surf_f = '%s/fsaverage5/surf/%s.%s' % (fsDir, hemi, surf)
coords = nib.freesurfer.io.read_geometry(surf_f)[0]
faces = nib.freesurfer.io.read_geometry(surf_f)[1]
if sulc:
sulc_f = '%s/fsaverage5/surf/%s.sulc' % (fsDir, hemi)
sulc = nib.freesurfer.io.read_morph_data(sulc_f)
sulc_bool = True
else:
sulc = None
sulc_bool = False

# create images
imgs = []
for azim in [0, 180]:

if data is not None:
if dmin is None:
dmin = data[np.nonzero(data)].min()
if dmax is None:
dmax = data.max()
fig = plot_surf_stat_map(coords, faces, stat_map=data,
elev=0, azim=azim,
cmap=cmap,
bg_map=sulc,bg_on_stat=sulc_bool,
vmin=dmin, vmax=dmax,
labels=labels, label_col=label_col,
alpha=alpha,
mask=mask, mask_lenient=False)
#label_cpal=cpal)
else:
fig = plot_surf_label(coords, faces,
labels=labels,
elev=0, azim=azim,
bg_map=sulc,
cpal=cpal,
bg_on_labels=sulc_bool,
alpha=alpha)

# crop image
imgs.append((crop_img(fig, margin=15)),)
plt.close(fig)

# create figure with color bar
fig = plt.figure()
fig.set_size_inches(8, 4)

ax1 = plt.subplot2grid((4,60), (0,0), colspan = 26, rowspan =4)
plt.imshow(imgs[0])
ax1.set_axis_off()

ax2 = plt.subplot2grid((4,60), (0,28), colspan = 26, rowspan =4)
plt.imshow(imgs[1])
ax2.set_axis_off()

if cbar==True and data is not None:
cax = plt.subplot2grid((4,60), (1,59), colspan = 1, rowspan =2)
cmap = plt.cm.get_cmap(cmap)
norm = mpl.colors.Normalize(vmin=dmin, vmax=dmax)
cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
cb.set_ticks([dmin, dmax])

fig.savefig('./tempimage')
plt.close(fig)
display(Image(filename='./tempimage.png', width=800))
os.remove('./tempimage.png')