Skip to content

Commit

Permalink
added docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rsexton2 committed Jan 29, 2024
1 parent 470f2b8 commit a731bb3
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 65 deletions.
2 changes: 1 addition & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
conda create -n basicrta python=3.8
conda activate basicrta
conda install mamba
mamba install numpy tqdm matplotlib MDAnalysis scipy pandas seaborn ipython jupyter pymbar
mamba install -c conda-forge numpy tqdm matplotlib MDAnalysis scipy seaborn
pip install .
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,4 @@ may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

139 changes: 75 additions & 64 deletions basicrta/functions.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
"""Analysis functions"""

from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, \
AutoMinorLocator)
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import ast
import multiprocessing
import ast, multiprocessing, os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import pymbar.timeseries as pmts
from MDAnalysis.analysis.base import Results
import pickle
import pickle, bz2, gc
from glob import glob
import seaborn as sns
import math
from numpy.random import default_rng
from tqdm import tqdm
import MDAnalysis as mda
import gc
from scipy.optimize import linear_sum_assignment as lsa
import bz2
from scipy import stats
from sklearn.cluster import KMeans

Expand Down Expand Up @@ -76,15 +72,21 @@ def tm(Prot,i):


class gibbs(object):
def __init__(self, times, residue, loc=0, ncomp=15, niter=10000):
"""Gibbs sampler to estimate parameters of an exponential mixture for a set
of data. Results are stored in gibbs.results, which uses
MDAnalysis.analysis.base.Results(). If 'results=None' the gibbs sampler has
not been executed, which requires calling '.run()'
"""

def __init__(self, times, residue, loc=0, ncomp=15, niter=50000):
self.times, self.residue = times, residue
self.niter, self.loc, self.ncomp = niter, loc, ncomp
self.results = None

diff = (np.sort(times)[1:]-np.sort(times)[:-1])
self.ts = diff[diff!=0][0]

def __repr__(self):
return f'Gibbs sampler'


def __str__(self):
Expand All @@ -93,61 +95,67 @@ def __str__(self):

def run(self):
x = self.times
g = 100
t, _s = get_s(x, self.ts)
if not os.path.exists(f'{self.residue}'):
os.mkdir(f'{self.residue}')

# initialize arrays
inrates = 0.5*10**np.arange(-self.ncomp+2, 2, dtype=float)
indicator = np.memmap(f'{self.residue}/.indicator_{self.niter}.npy', \
shape=(self.niter, x.shape[0]), mode='w+', \
shape=((self.niter+1)//g, x.shape[0]), mode='w+',\
dtype=np.uint8)
mcweights = np.zeros((self.niter + 1, self.ncomp))
mcrates = np.zeros((self.niter + 1, self.ncomp))
Ns, lnp = np.zeros((self.niter, self.ncomp)), np.zeros(self.niter)
mcweights = np.zeros(((self.niter+1)//g, self.ncomp))
mcrates = np.zeros(((self.niter+1)//g, self.ncomp))
#lnp = np.zeros(self.niter)
tmpw = 9*10**(-np.arange(1, self.ncomp+1, dtype=float))
mcweights[0], mcrates[0] = tmpw/tmpw.sum(), inrates[::-1]
weights, rates = tmpw/tmpw.sum(), inrates[::-1]

# guess hyperparameters
whypers = np.ones(self.ncomp)/[self.ncomp]
rhypers = np.ones((self.ncomp, 2))*[1, 3]

# gibbs sampler
for j in tqdm(range(self.niter), desc=f'{self.residue}-K{self.ncomp}', \
for j in tqdm(range(1, self.niter+1), \
desc=f'{self.residue}-K{self.ncomp}', \
position=self.loc, leave=False):

# compute probabilities
tmp = mcweights[j]*mcrates[j]*np.exp(np.outer(-mcrates[j],x)).T
tmp = weights*rates*np.exp(np.outer(-rates,x)).T
z = (tmp.T/tmp.sum(axis=1)).T

# sample and store indicator
# sample indicator
s = np.argmax(rng.multinomial(1, z), axis=1)
indicator[j] = s

# get occupied states
uniqs = np.unique(s)
inds = [np.where(s==i)[0] for i in range(self.ncomp)]

# compute total time and number of point for each component
Ns[j][:] = np.array([len(inds[i]) for i in range(self.ncomp)])
Ns = np.array([len(inds[i]) for i in range(self.ncomp)])
Ts = np.array([x[inds[i]].sum() for i in range(self.ncomp)])

# compute log posterior
lnp[j] = np.log(tmp.take(s)).sum()+\
np.log(mcweights[j][uniqs]).sum()-\
(mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\
np.log(mcweights[j][uniqs]**(whypers[uniqs]-1)).sum()
#lnp[j] = np.log(tmp.take(s)).sum()+\
# np.log(mcweights[j][uniqs]).sum()-\
# (mcrates[j][uniqs]*rhypers[uniqs, 1]).sum()+\
# np.log(mcweights[j][uniqs]**(whypers[uniqs]-1)).sum()

# sample posteriors
mcweights[j+1] = rng.dirichlet(whypers+Ns[j])
mcrates[j+1] = rng.gamma(rhypers[:,0]+Ns[j], 1/(rhypers[:,1]+Ts))
weights = rng.dirichlet(whypers+Ns)
rates = rng.gamma(rhypers[:,0]+Ns, 1/(rhypers[:,1]+Ts))

# save every g steps
if j%g==0:
ind = j//g-1
mcweights[ind], mcrates[ind] = weights, rates
indicator[ind] = s



attrs = ["mcweights", "mcrates", "ncomp", "niter", "s", "t", "residue",
"lnp", "times"]
"times"]
values = [mcweights, mcrates, self.ncomp, self.niter, _s, t,
self.residue, lnp, x]
self.residue, x]

r = save_results(attrs, values)

Expand Down Expand Up @@ -191,30 +199,29 @@ def estimate_params(processed_results):
def process_gibbs(results, cutoff=1e-4):
r = results


#burnin, g, nsample = pmts.detect_equilibration(r.lnp, nskip=20)
burnin, g = int(burnin), int(np.ceil(g))

if burnin==0:
burnin = 500
else:
burnin = burnin
burnin, g = 10000, 100
burnin_ind = burnin//g

inds = np.where(r.mcweights[burnin::g]>cutoff)
indices = np.arange(burnin, r.niter, g)[inds[0]]
H = np.histogram([len(row[row>cutoff]) for row in r.mcweights[burnin::g]], \
bins=np.arange(1, r.ncomp+1))
ncomp = int(H[1][:-1][H[0]==H[0].max()][0])

weights, rates = r.mcweights[burnin::g][inds], r.mcrates[burnin::g][inds]
lnp = r.lnp[burnin::g][inds[0]]
inds = np.where(r.mcweights[burnin_ind:]>cutoff)
indices = np.arange(burnin, r.niter+1, g)[inds[0]]//g
lens = [len(row[row>cutoff]) for row in r.mcweights[burnin_ind:]]
ncomp = stats.mode(lens, keepdims=False)[0]

weights = r.mcweights[burnin_ind::][inds]
rates = r.mcrates[burnin_ind::][inds]
#lnp = r.lnp[burnin::g][inds[0]]

data = np.stack((weights, rates), axis=1)
km = KMeans(n_clusters=ncomp).fit(np.log(data))
Indicator = np.zeros((r.times.shape[0], ncomp))

for j,iteration in enumerate(np.unique(indices)):
indicator = np.memmap(f'{r.residue}/.indicator_{r.niter}.npy', \
shape=((r.niter+1)//g, r.times.shape[0]), mode='r', \
dtype=np.uint8)

for j in np.unique(inds[0]):
mapinds = km.labels_[inds[0]==j]
for i,indx in enumerate(inds[1][indices==iteration]):
tmpind = np.where(indicator[iteration]==indx)[0]
for i,indx in enumerate(inds[1][inds[0]==j]):
tmpind = np.where(indicator[j]==indx)[0]
Indicator[tmpind, mapinds[i]] += 1

Indicator = (Indicator.T/Indicator.sum(axis=1)).T
Expand Down Expand Up @@ -377,7 +384,8 @@ def plot_post(results, attr, comp=None, save=False, show=False):
plt.show()


def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, show=False):
def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, \
show=False):
outdir = results.name
if attr=='weights':
tmp = getattr(results, 'mcweights')
Expand Down Expand Up @@ -409,8 +417,10 @@ def plot_trace(results, attr, comp=None, xrange=None, yrange=None, save=False, s
if yrange!=None:
plt.ylim(yrange[0], yrange[1])
if save:
plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-{"-".join([str(i) for i in comp])}.png')
plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-{"-".join([str(i) for i in comp])}.pdf')
plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-\
{"-".join([str(i) for i in comp])}.png')
plt.savefig(f'{outdir}/figs/k{results.ncomp}-trace_{attr}_comps-\
{"-".join([str(i) for i in comp])}.pdf')
if show:
plt.show()
plt.close('all')
Expand Down Expand Up @@ -590,8 +600,10 @@ def check_results(residues, times, ts):
os.mkdir('result_check')
for time, residue in zip(times, residues):
if os.path.exists(residue):
kmax = glob(f'{residue}/K*_results.pkl')[-1].split('/')[-1].split('/')[-1].split('_')[0][1:]
os.popen(f'cp {residue}/figs/k{kmax}-mean_results.png result_check/{residue}-k{kmax}-results.png')
kmax = glob(f'{residue}/K*_results.pkl')[-1].split('/')[-1].\
split('/')[-1].split('_')[0][1:]
os.popen(f'cp {residue}/figs/k{kmax}-mean_results.png result_check/\
{residue}-k{kmax}-results.png')
else:
t, s = get_s(np.array(time), ts)
plt.scatter(t, s, label='data')
Expand Down Expand Up @@ -643,7 +655,8 @@ def write_trajs(u, time, trajtime, indicator, residue, lipind, step):
for comp in np.where(lens != 0)[0]:
write_frames, write_Linds = get_write_frames(u, time, trajtime, lipind, comp+2)
if len(write_frames) > step:
write_frames, write_Linds = write_frames[::step], write_Linds[::step]
write_frames = write_frames[::step]
write_Linds = write_Linds[::step]
with mda.Writer(f"{residue}/comp{comp}_traj.xtc", \
len((prot+chol.residues[0].atoms).atoms)) as W:
for i, ts in tqdm(enumerate(u.trajectory[write_frames]), \
Expand All @@ -654,18 +667,15 @@ def write_trajs(u, time, trajtime, indicator, residue, lipind, step):


def plot_hists(timelens, indicators, residues):
for timelen, indicator, residue in tqdm(zip(timelens, indicators, residues), total=len(timelens),
for timelen, indicator, residue in tqdm(zip(timelens, indicators, residues),
total=len(timelens),
desc='ploting hists'):
# framec = (np.round(timelen, 1) * 10).astype(int)
#inds = np.array([np.where(indicator.argmax(axis=0) == i)[0] for i in range(8)], dtype=object)
#lens = np.array([len(ind) for ind in inds])
#ncomps = len(np.where(lens != 0)[0])
ncomps = indicator[:,0].shape[0]

plt.close()
for i in range(ncomps):
# h, edges = np.histogram(framec, density=True, bins=50, weights=indicator[i])
h, edges = np.histogram(timelen, density=True, bins=50, weights=indicator[i])
h, edges = np.histogram(timelen, density=True, bins=50, \
weights=indicator[i])
m = 0.5*(edges[1:]+edges[:-1])
plt.plot(m, h, '.', label=i, alpha=0.5)
plt.ylabel('p')
Expand Down Expand Up @@ -738,7 +748,8 @@ def expand_times(contacts):
restimes = []
for lip in range(times.shape[1]):
for i in range(times[res, lip].shape[0]):
[restimes.append(j) for j in [times[res, lip][i]]*Ns[res, lip][i].astype(int)]
[restimes.append(j) for j in [times[res, lip][i]]*\
Ns[res, lip][i].astype(int)]
alltimes.append(restimes)
return np.asarray(alltimes)

Expand Down

0 comments on commit a731bb3

Please sign in to comment.