Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Jun 1, 2017
0 parents commit 9d5d8d9
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__pycache__/
.tox/
.coverage
.python-version
.cache
htmlcov/
*.egg-info/
9 changes: 9 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
sudo: false
language: python
python:
- "3.5"
- "3.6"
install: pip install tox-travis coveralls
script: tox
after_success:
- coveralls
13 changes: 13 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
|Build Status| |Coverage Status|

========
sampled
========


*Decorator for reusable models in PyMC3*

.. |Build Status| image:: https://travis-ci.org/ColCarroll/sampled.svg?branch=master
:target: https://travis-ci.org/ColCarroll/sampled
.. |Coverage Status| image:: https://coveralls.io/repos/github/ColCarroll/sampled/badge.svg?branch=master
:target: https://coveralls.io/github/ColCarroll/sampled?branch=master
1 change: 1 addition & 0 deletions sampled/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .sampled import sampled # noqa
26 changes: 26 additions & 0 deletions sampled/sampled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pymc3 as pm


class ObserverModel(pm.Model):
"""Stores observed variables until the model is created."""
def __init__(self, observed):
self.observed = observed
super(ObserverModel, self).__init__()

def Var(self, name, dist, data=None, **kwargs):
return super(ObserverModel, self).Var(name, dist,
data=self.observed.get(name, data),
**kwargs)


def sampled(f):
"""Decorator to delay initializing pymc3 model until data is passed in."""
def wrapped_f(**observed):
try:
with ObserverModel(observed) as model:
f(**observed)
except TypeError:
with ObserverModel(observed) as model:
f()
return model
return wrapped_f
9 changes: 9 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[tool:pytest]
norecursedirs = .tox
testpaths = test

[bdist_wheel]
universal=1

[flake8]
max-line-length = 100
35 changes: 35 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from codecs import open
from os import path
from setuptools import setup, find_packages

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as buff:
long_description = buff.read()

setup(
name='sampled',
version='0.0.1',
description='Decorator for reusable models in PyMC3',
long_description=long_description,
author='Colin Carroll',
author_email='colcarroll@gmail.com',
url='https://github.com/ColCarroll/sampled',
license='MIT',
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
],
packages=find_packages(exclude=['test']),
install_requires=[
'pymc3>=3.0',
],
include_package_data=True,
)
78 changes: 78 additions & 0 deletions test/test_sampled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import pymc3 as pm
import theano.tensor as tt

from sampled import sampled


def test_sampled_one_model():
@sampled
def just_a_normal():
pm.Normal('x', mu=0, sd=1)

draws = 50
with just_a_normal():
decorated_trace = pm.sample(draws=draws, tune=50, init=None)

assert decorated_trace.varnames == ['x']
assert len(decorated_trace.get_values('x')) == draws


def test_reuse_model():
@sampled
def two_normals():
mu = pm.Normal('mu', mu=0, sd=1)
pm.Normal('x', mu=mu, sd=1)

with two_normals():
generated_data = pm.sample(draws=50, tune=50, init=None)

for varname in ('mu', 'x'):
assert varname in generated_data.varnames

with two_normals(mu=1):
posterior_data = pm.sample(draws=50, tune=50, init=None)

assert 'x' in posterior_data.varnames
assert 'mu' not in posterior_data.varnames
assert posterior_data.get_values('x').mean() > generated_data.get_values('x').mean()


def test_linear_model():
rows, cols = 1000, 10
X = np.random.normal(size=(rows, cols))
w = np.random.normal(size=cols)
y = X.dot(w) + np.random.normal(scale=0.1, size=rows)

@sampled
def linear_model(X, y):
shape = X.shape
X = pm.Normal('X', mu=np.mean(X, axis=0), sd=np.std(X, axis=0), shape=shape)
coefs = pm.Normal('coefs', mu=tt.zeros(shape[1]), sd=tt.ones(shape[1]), shape=shape[1])
pm.Normal('y', mu=tt.dot(X, coefs), sd=tt.ones(shape[0]), shape=shape[0])

with linear_model(X=X, y=y):
sampled_coefs = pm.sample(draws=1000, tune=500)
mean_coefs = sampled_coefs.get_values('coefs').mean(axis=0)
np.testing.assert_allclose(mean_coefs, w, atol=0.01)


def test_partial_model():
rows, cols = 1000, 10
X = np.random.normal(size=(rows, cols))
w = np.random.normal(size=cols)
y = X.dot(w) + np.random.normal(scale=0.1, size=rows)

@sampled
def partial_linear_model(X):
shape = X.shape
X = pm.Normal('X', mu=np.mean(X, axis=0), sd=np.std(X, axis=0), shape=shape)
pm.Normal('coefs', mu=tt.zeros(shape[1]), sd=tt.ones(shape[1]), shape=shape[1])

with partial_linear_model(X=X) as model:
coefs = model.named_vars['coefs']
pm.Normal('y', mu=tt.dot(X, coefs), sd=tt.ones(y.shape), observed=y)
sampled_coefs = pm.sample(draws=1000, tune=500)

mean_coefs = sampled_coefs.get_values('coefs').mean(axis=0)
np.testing.assert_allclose(mean_coefs, w, atol=0.01)
11 changes: 11 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[tox]
envlist = py35, py36

[testenv]
deps=
pytest
pytest-cov
flake8
commands=
py.test -v --cov={envsitepackagesdir}/sampled --cov-report=html --cov-report=term test/
flake8

0 comments on commit 9d5d8d9

Please sign in to comment.