-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9d5d8d9
Showing
9 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
__pycache__/ | ||
.tox/ | ||
.coverage | ||
.python-version | ||
.cache | ||
htmlcov/ | ||
*.egg-info/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
sudo: false | ||
language: python | ||
python: | ||
- "3.5" | ||
- "3.6" | ||
install: pip install tox-travis coveralls | ||
script: tox | ||
after_success: | ||
- coveralls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
|Build Status| |Coverage Status| | ||
|
||
======== | ||
sampled | ||
======== | ||
|
||
|
||
*Decorator for reusable models in PyMC3* | ||
|
||
.. |Build Status| image:: https://travis-ci.org/ColCarroll/sampled.svg?branch=master | ||
:target: https://travis-ci.org/ColCarroll/sampled | ||
.. |Coverage Status| image:: https://coveralls.io/repos/github/ColCarroll/sampled/badge.svg?branch=master | ||
:target: https://coveralls.io/github/ColCarroll/sampled?branch=master |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .sampled import sampled # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import pymc3 as pm | ||
|
||
|
||
class ObserverModel(pm.Model): | ||
"""Stores observed variables until the model is created.""" | ||
def __init__(self, observed): | ||
self.observed = observed | ||
super(ObserverModel, self).__init__() | ||
|
||
def Var(self, name, dist, data=None, **kwargs): | ||
return super(ObserverModel, self).Var(name, dist, | ||
data=self.observed.get(name, data), | ||
**kwargs) | ||
|
||
|
||
def sampled(f): | ||
"""Decorator to delay initializing pymc3 model until data is passed in.""" | ||
def wrapped_f(**observed): | ||
try: | ||
with ObserverModel(observed) as model: | ||
f(**observed) | ||
except TypeError: | ||
with ObserverModel(observed) as model: | ||
f() | ||
return model | ||
return wrapped_f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[tool:pytest] | ||
norecursedirs = .tox | ||
testpaths = test | ||
|
||
[bdist_wheel] | ||
universal=1 | ||
|
||
[flake8] | ||
max-line-length = 100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from codecs import open | ||
from os import path | ||
from setuptools import setup, find_packages | ||
|
||
here = path.abspath(path.dirname(__file__)) | ||
|
||
# Get the long description from the README file | ||
with open(path.join(here, 'README.rst'), encoding='utf-8') as buff: | ||
long_description = buff.read() | ||
|
||
setup( | ||
name='sampled', | ||
version='0.0.1', | ||
description='Decorator for reusable models in PyMC3', | ||
long_description=long_description, | ||
author='Colin Carroll', | ||
author_email='colcarroll@gmail.com', | ||
url='https://github.com/ColCarroll/sampled', | ||
license='MIT', | ||
classifiers=[ | ||
'Development Status :: 3 - Alpha', | ||
'Intended Audience :: Developers', | ||
'License :: OSI Approved :: MIT License', | ||
'Programming Language :: Python :: 2', | ||
'Programming Language :: Python :: 2.7', | ||
'Programming Language :: Python :: 3', | ||
'Programming Language :: Python :: 3.5', | ||
'Programming Language :: Python :: 3.6', | ||
], | ||
packages=find_packages(exclude=['test']), | ||
install_requires=[ | ||
'pymc3>=3.0', | ||
], | ||
include_package_data=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
import pymc3 as pm | ||
import theano.tensor as tt | ||
|
||
from sampled import sampled | ||
|
||
|
||
def test_sampled_one_model(): | ||
@sampled | ||
def just_a_normal(): | ||
pm.Normal('x', mu=0, sd=1) | ||
|
||
draws = 50 | ||
with just_a_normal(): | ||
decorated_trace = pm.sample(draws=draws, tune=50, init=None) | ||
|
||
assert decorated_trace.varnames == ['x'] | ||
assert len(decorated_trace.get_values('x')) == draws | ||
|
||
|
||
def test_reuse_model(): | ||
@sampled | ||
def two_normals(): | ||
mu = pm.Normal('mu', mu=0, sd=1) | ||
pm.Normal('x', mu=mu, sd=1) | ||
|
||
with two_normals(): | ||
generated_data = pm.sample(draws=50, tune=50, init=None) | ||
|
||
for varname in ('mu', 'x'): | ||
assert varname in generated_data.varnames | ||
|
||
with two_normals(mu=1): | ||
posterior_data = pm.sample(draws=50, tune=50, init=None) | ||
|
||
assert 'x' in posterior_data.varnames | ||
assert 'mu' not in posterior_data.varnames | ||
assert posterior_data.get_values('x').mean() > generated_data.get_values('x').mean() | ||
|
||
|
||
def test_linear_model(): | ||
rows, cols = 1000, 10 | ||
X = np.random.normal(size=(rows, cols)) | ||
w = np.random.normal(size=cols) | ||
y = X.dot(w) + np.random.normal(scale=0.1, size=rows) | ||
|
||
@sampled | ||
def linear_model(X, y): | ||
shape = X.shape | ||
X = pm.Normal('X', mu=np.mean(X, axis=0), sd=np.std(X, axis=0), shape=shape) | ||
coefs = pm.Normal('coefs', mu=tt.zeros(shape[1]), sd=tt.ones(shape[1]), shape=shape[1]) | ||
pm.Normal('y', mu=tt.dot(X, coefs), sd=tt.ones(shape[0]), shape=shape[0]) | ||
|
||
with linear_model(X=X, y=y): | ||
sampled_coefs = pm.sample(draws=1000, tune=500) | ||
mean_coefs = sampled_coefs.get_values('coefs').mean(axis=0) | ||
np.testing.assert_allclose(mean_coefs, w, atol=0.01) | ||
|
||
|
||
def test_partial_model(): | ||
rows, cols = 1000, 10 | ||
X = np.random.normal(size=(rows, cols)) | ||
w = np.random.normal(size=cols) | ||
y = X.dot(w) + np.random.normal(scale=0.1, size=rows) | ||
|
||
@sampled | ||
def partial_linear_model(X): | ||
shape = X.shape | ||
X = pm.Normal('X', mu=np.mean(X, axis=0), sd=np.std(X, axis=0), shape=shape) | ||
pm.Normal('coefs', mu=tt.zeros(shape[1]), sd=tt.ones(shape[1]), shape=shape[1]) | ||
|
||
with partial_linear_model(X=X) as model: | ||
coefs = model.named_vars['coefs'] | ||
pm.Normal('y', mu=tt.dot(X, coefs), sd=tt.ones(y.shape), observed=y) | ||
sampled_coefs = pm.sample(draws=1000, tune=500) | ||
|
||
mean_coefs = sampled_coefs.get_values('coefs').mean(axis=0) | ||
np.testing.assert_allclose(mean_coefs, w, atol=0.01) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
[tox] | ||
envlist = py35, py36 | ||
|
||
[testenv] | ||
deps= | ||
pytest | ||
pytest-cov | ||
flake8 | ||
commands= | ||
py.test -v --cov={envsitepackagesdir}/sampled --cov-report=html --cov-report=term test/ | ||
flake8 |