diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 992051fc9..bd30b163d 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -381,14 +381,21 @@ def save(self, fname: str) -> None: raise RuntimeError("The model hasn't been fit yet, call .fit() first") @classmethod - def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict: + def _model_config_formatting(cls, model_config: Dict) -> Dict: + """ + Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists. + This function converts them back to tuples and numpy arrays to ensure correct id encoding. + """ for key in model_config: - if ( - isinstance(model_config[key], dict) - and "dims" in model_config[key] - and isinstance(model_config[key]["dims"], list) - ): - model_config[key]["dims"] = tuple(model_config[key]["dims"]) + if isinstance(model_config[key], dict): + for sub_key in model_config[key]: + if isinstance(model_config[key][sub_key], list): + # Check if "dims" key to convert it to tuple + if sub_key == "dims": + model_config[key][sub_key] = tuple(model_config[key][sub_key]) + # Convert all other lists to numpy arrays + else: + model_config[key][sub_key] = np.array(model_config[key][sub_key]) return model_config @classmethod @@ -420,7 +427,7 @@ def load(cls, fname: str): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) # needs to be converted, because json.loads was changing tuple to list - model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"])) + model_config = cls._model_config_formatting(json.loads(idata.attrs["model_config"])) model = cls( model_config=model_config, sampler_config=json.loads(idata.attrs["sampler_config"]), diff --git a/pymc_experimental/tests/__init__.py b/pymc_experimental/tests/__init__.py index 3fb1a4058..5421880d2 100644 --- a/pymc_experimental/tests/__init__.py +++ b/pymc_experimental/tests/__init__.py @@ -11,7 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from numpy.testing import Tester - -test = Tester().test diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 3327d1e0d..a37b00a4a 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -160,7 +160,7 @@ def test_convert_dims_to_tuple(fitted_model_instance): ], }, } - converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config) + converted_model_config = fitted_model_instance._model_config_formatting(model_config) assert converted_model_config["a"]["dims"] == ("x",) diff --git a/setup.cfg b/setup.cfg index a15615941..0b4c86f9d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,10 @@ [tool:pytest] testpaths = tests +filterwarnings = + error + ignore::DeprecationWarning:numpy.core.fromnumeric + ignore:::arviz.* + ignore:DeprecationWarning [isort] lines_between_types = 1