Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,21 @@ def save(self, fname: str) -> None:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

@classmethod
def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict:
def _model_config_formatting(cls, model_config: Dict) -> Dict:
"""
Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists.
This function converts them back to tuples and numpy arrays to ensure correct id encoding.
"""
for key in model_config:
if (
isinstance(model_config[key], dict)
and "dims" in model_config[key]
and isinstance(model_config[key]["dims"], list)
):
model_config[key]["dims"] = tuple(model_config[key]["dims"])
if isinstance(model_config[key], dict):
for sub_key in model_config[key]:
if isinstance(model_config[key][sub_key], list):
# Check if "dims" key to convert it to tuple
if sub_key == "dims":
model_config[key][sub_key] = tuple(model_config[key][sub_key])
# Convert all other lists to numpy arrays
else:
model_config[key][sub_key] = np.array(model_config[key][sub_key])
return model_config

@classmethod
Expand Down Expand Up @@ -420,7 +427,7 @@ def load(cls, fname: str):
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
# needs to be converted, because json.loads was changing tuple to list
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
model_config = cls._model_config_formatting(json.loads(idata.attrs["model_config"]))
model = cls(
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
Expand Down
4 changes: 0 additions & 4 deletions pymc_experimental/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy.testing import Tester

test = Tester().test
2 changes: 1 addition & 1 deletion pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_convert_dims_to_tuple(fitted_model_instance):
],
},
}
converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config)
converted_model_config = fitted_model_instance._model_config_formatting(model_config)
assert converted_model_config["a"]["dims"] == ("x",)


Expand Down
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
[tool:pytest]
testpaths = tests
filterwarnings =
error
ignore::DeprecationWarning:numpy.core.fromnumeric
ignore:::arviz.*
ignore:DeprecationWarning

[isort]
lines_between_types = 1
Expand Down