Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added subplot convenience functions #11

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2.7.15
Empty file added __init__.py
Empty file.
1,075 changes: 1,075 additions & 0 deletions docs/.ipynb_checkpoints/easyplot_docs-checkpoint.ipynb

Large diffs are not rendered by default.

1,983 changes: 1,036 additions & 947 deletions docs/easyplot_docs.ipynb

Large diffs are not rendered by default.

Binary file added easyplot/.DS_Store
Binary file not shown.
105 changes: 85 additions & 20 deletions easyplot/easyplot.py
Original file line number Diff line number Diff line change
@@ -5,11 +5,15 @@

import matplotlib.pyplot as plt
import matplotlib as mpl

import numpy as np
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d.axes3d import Axes3D, get_test_data
from IPython import embed
if not plt.isinteractive():
print("\nMatplotlib interactive mode is currently OFF. It is "
print("\nMatplotlib interactive mode was currently OFF. It is "
"recommended to use a suitable matplotlib backend and turn it "
"on by calling matplotlib.pyplot.ion()\n")
"on by calling matplotlib.pyplot.ion()\n. Turning it on automatically now...")
plt.ion()

class EasyPlot(object):
"""
@@ -75,12 +79,15 @@ def __init__(self, *args, **kwargs):
"""
self._default_kwargs = {'fig': None,
'ax': None,
'gs':None,
'figsize': None,
'dpi': mpl.rcParams['figure.dpi'],
'showlegend': False,
'fancybox': True,
'loc': 'best',
'numpoints': 1
'numpoints': 1,
'return_handle':False,
'specialfunc': False
}
# Dictionary of plot parameter aliases
self.alias_dict = {'lw': 'linewidth', 'ls': 'linestyle',
@@ -101,6 +108,7 @@ def __init__(self, *args, **kwargs):
# Mapping between plot parameter and corresponding axes function to call
self._ax_funcs = {'xlabel': 'set_xlabel',
'ylabel': 'set_ylabel',
'zlabel': 'set_zlabel',
'xlim': 'set_xlim',
'ylim': 'set_ylim',
'title': 'set_title',
@@ -141,6 +149,10 @@ def add_plot(self, *args, **kwargs):
ax, fig = self.kwargs['ax'], self.kwargs['fig']

ax.ticklabel_format(useOffset=False) # Prevent offset notation in plots
if 'fontsize' in self.kwargs:
self.set_fontsize(self.kwargs['fontsize'])
ax.yaxis.label.set_size(self.kwargs['fontsize'])
ax.xaxis.label.set_size(self.kwargs['fontsize'])

# Apply axes functions if present in kwargs
for kwarg in self.kwargs:
@@ -155,7 +167,13 @@ def add_plot(self, *args, **kwargs):
plot_kwargs = {kwarg: self.kwargs[kwarg] for kwarg
in self._plot_kwargs if kwarg in self.kwargs}

line, = ax.plot(*self.args, **plot_kwargs)
if self.kwargs['specialfunc']:
if self.kwargs['specialfunc']=='scatter':
plot_kwargs.pop('markersize',None)
plot_kwargs.pop('linewidth',None)
line = ax.scatter(*self.args, **plot_kwargs)
else:
line, = ax.plot(*self.args, **plot_kwargs)
self.line_list.append(line)

# Display legend if required
@@ -165,14 +183,13 @@ def add_plot(self, *args, **kwargs):
leg = ax.legend(**legend_kwargs)
if leg is not None:
leg.draggable(state=True)

if 'fontsize' in self.kwargs:
self.set_fontsize(self.kwargs['fontsize'])

self._delete_uniqueparams() # Clear unique parameters from kwargs list

if plt.isinteractive(): # Only redraw canvas in interactive mode
self.redraw()
if self.kwargs['return_handle']:
return line

def update_plot(self, **kwargs):
""""Update plot parameters (keyword arguments) and replot figure
@@ -250,6 +267,18 @@ def iter_plot(self, x, y, mode='dict', **kwargs):
else:
print('Error! Incorrect mode specification. Ignoring method call')

def make_subplots(self,*args, **kwargs):
if self.kwargs['fig'] is None:
self.kwargs['fig'] = plt.figure(figsize=self.kwargs['figsize'],
dpi=self.kwargs['dpi'])
self.kwargs['gs'] = gridspec.GridSpec(*args, **kwargs)

def set_subplot(self,rows,cols,projection=None):
self.kwargs['ax'] = plt.subplot(self.kwargs['gs'][rows,cols])
if not (projection is None):
self.kwargs['ax'].remove()
self.kwargs['ax'] = self.kwargs['fig'].add_subplot(self.kwargs['gs'][rows,cols],projection=projection)

def autoscale(self, enable=True, axis='both', tight=None):
"""Autoscale the axis view to the data (toggle).

@@ -304,18 +333,22 @@ def redraw(self):

def set_fontsize(self, font_size):
""" Updates global font size for all plot elements"""
mpl.rcParams['font.size'] = font_size
self.redraw()
# mpl.rcParams['font.size'] = font_size
# mpl.rcParams['axes.labelsize'] = font_size
# self.redraw()
#TODO: Implement individual font size setting
# params = {'font.family': 'serif',
# 'font.size': 16,
# 'axes.labelsize': 18,
# 'text.fontsize': 18,
# 'legend.fontsize': 18,
# 'xtick.labelsize': 18,
# 'ytick.labelsize': 18,
# 'text.usetex': True}
# mpl.rcParams.update(params)
params = {'font.family': 'serif',
'font.size': font_size,
'axes.labelsize': font_size,
# 'text.fontsize': font_size,
'font.size': font_size,
'legend.fontsize': font_size,
'xtick.labelsize': font_size,
'ytick.labelsize': font_size,
#'text.usetex': True
}
mpl.rcParams.update(params)
self.redraw()

# def set_font(self, family=None, weight=None, size=None):
# """ Updates global font properties for all plot elements
@@ -367,4 +400,36 @@ def _reset(self, reset=False):
self.kwargs['fig'] = None
self.kwargs['ax'] = None
if reset:
self.kwargs = self._default_kwargs.copy()
self.kwargs = self._default_kwargs.copy()

def savefig(self,*args,**kwargs):
(self.kwargs['fig']).savefig(*args,**kwargs)

class EasyPlotManager(EasyPlot):
"""docstring for EasyPlotManager"""
def __init__(self,*args, **kwargs):
super(EasyPlotManager, self).__init__(*args, **kwargs)
self.handles = {}

def adddata(self,label,data,style='bo',xlabel=r'$x$',ylabel=r'$y$',showlegend=False,markersize=10,alpha=1,linewidth=2.,**kwargs):
if len(data)==3:
self.handles[label] = self.add_plot(data[0],data[1],data[2],label=label,showlegend=showlegend,markersize=markersize,alpha=alpha,xlabel=xlabel,ylabel=ylabel,linewidth=linewidth,return_handle=True,**kwargs)
else:
self.handles[label] = self.add_plot(data[0],data[1],style,label=label,showlegend=showlegend,markersize=markersize,alpha=alpha,xlabel=xlabel,ylabel=ylabel,linewidth=linewidth,return_handle=True,**kwargs)

def remove(self,label):
self.handles[label].remove()

def updatedata(self,label,data):
try:
(self.handles[label]).set_xdata(data[0])
(self.handles[label]).set_ydata(data[1])
except:
# embed()
(self.handles[label]).set_offsets(np.vstack((data[0],data[1])).T)
(self.handles[label]).set_3d_properties(data[2],'z')

def updateplot(self):
# embed()
(self.kwargs['fig']).canvas.draw()
(self.kwargs['fig']).canvas.flush_events()