Skip to content

Commit

Permalink
Initial changes to allow pymc3.Data() to support both int and float i…
Browse files Browse the repository at this point in the history
…nput data (previously all input data was coerced to float)

WIP for pymc-devs#3813
  • Loading branch information
hottwaj committed Feb 26, 2020
1 parent 433c693 commit 05462f2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
14 changes: 12 additions & 2 deletions pymc3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,20 @@ class Data:
For more information, take a look at this example notebook
https://docs.pymc.io/notebooks/data_container.html
"""
def __new__(self, name, value):
def __new__(self, name, value, dtype = None):
if dtype is None:
if hasattr(value, 'dtype'):
# if no dtype given, but available as attr of value, use that as dtype
dtype = value.dtype
elif isinstance(value, int):
dtype = int
else:
# otherwise, assume float
dtype = float

# `pm.model.pandas_to_array` takes care of parameter `value` and
# transforms it to something digestible for pymc3
shared_object = theano.shared(pm.model.pandas_to_array(value), name)
shared_object = theano.shared(pm.model.pandas_to_array(value, dtype = dtype), name)

# To draw the node for this variable in the graphviz Digraph we need
# its shape.
Expand Down
8 changes: 5 additions & 3 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ def init_value(self):
return self.tag.test_value


def pandas_to_array(data):
def pandas_to_array(data, dtype = float):
if hasattr(data, 'values'): # pandas
if data.isnull().any().any(): # missing values
ret = np.ma.MaskedArray(data.values, data.isnull().values)
Expand All @@ -1492,8 +1492,10 @@ def pandas_to_array(data):
ret = generator(data)
else:
ret = np.asarray(data)
return pm.floatX(ret)

if dtype in [float, np.float32, np.float64]:
return pm.floatX(ret)
elif dtype in [int, np.int32, np.int64]:
return pm.intX(ret)

def as_tensor(data, name, model, distribution):
dtype = distribution.dtype
Expand Down

0 comments on commit 05462f2

Please sign in to comment.