From 610a20c5f8dde566741119648f2d93dec1df7df4 Mon Sep 17 00:00:00 2001 From: Shashank Date: Sun, 17 Feb 2019 17:48:12 +0530 Subject: [PATCH 1/2] Adding Documentation example for combining DocString --- arviz/data/inference_data.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 8bb72d868d..f043f9b69b 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -91,7 +91,14 @@ def to_netcdf(self, filename, compress=True): return filename def __add__(self, other): - """Concatenate two InferenceData objects.""" + """Concatenate two InferenceData objects. + example + -------- + A._groups == ["posterior", "posterior_predictive"] + B._groups == ["prior", "prior_predictive"] + C=A+B + C._groups ==["posterior", "posterior_predictive", "prior", "prior_predictive"] + """ return concat(self, other, copy=True, inplace=False) @@ -116,6 +123,19 @@ def concat(*args, copy=True, inplace=False): InferenceData A new InferenceData object by default. When `inplace==True` merge args to first arg and return `None` + + example + ------- + A._groups == ["posterior", "posterior_predictive"] + B._groups == ["prior", "prior_predictive"] + C = az.concat(A, B) + C._groups ==["posterior", "posterior_predictive", "prior", "prior_predictive"] + + When inplace=True + ----------------- + az.concat(A, B, inplace=True) + A._groups ==["posterior", "posterior_predictive", "prior", "prior_predictive"] + """ if len(args) == 0: return InferenceData() From 3949bee46bcac62783ba631b8e15f45d5eb4b7ec Mon Sep 17 00:00:00 2001 From: Shashank Date: Sun, 17 Feb 2019 19:35:23 +0530 Subject: [PATCH 2/2] Adding the test cases and solving the issue of creation of empty file --- arviz/data/inference_data.py | 24 +++++++++++++++--------- arviz/tests/test_data.py | 16 ++++++++++++++-- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index f043f9b69b..a756bcb3b9 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -80,15 +80,21 @@ def to_netcdf(self, filename, compress=True): Location of netcdf file """ mode = "w" # overwrite first, then append - for group in self._groups: - data = getattr(self, group) - kwargs = {} - if compress: - kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables} - data.to_netcdf(filename, mode=mode, group=group, **kwargs) - data.close() - mode = "a" - return filename + #if the netcdf file previously exists the append the file + if self._groups: + for group in self._groups: + data = getattr(self, group) + kwargs = {} + if compress: + kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables} + data.to_netcdf(filename, mode=mode, group=group, **kwargs) + data.close() + mode = "a" + return filename + else: # else if the file doesnot exists previously creating the empty netcdf file + empty_file=nc.Dataset(filename,mode="w",format="NETCDF4") + empty_file.close() + return filename def __add__(self, other): """Concatenate two InferenceData objects. diff --git a/arviz/tests/test_data.py b/arviz/tests/test_data.py index d65d6ca1b3..7e75839978 100644 --- a/arviz/tests/test_data.py +++ b/arviz/tests/test_data.py @@ -672,7 +672,7 @@ def test_io_function(self, data, eight_schools_params): assert hasattr(inference_data3, "posterior") os.remove(filepath) assert not os.path.exists(filepath) - + def test_io_method(self, data, eight_schools_params): inference_data = self.get_inference_data( # pylint: disable=W0612 data, eight_schools_params @@ -690,7 +690,19 @@ def test_io_method(self, data, eight_schools_params): assert hasattr(inference_data2, "posterior") os.remove(filepath) assert not os.path.exists(filepath) - + + #New test case added for the empty file to be created + def test_new_file(self): + inference_data=InferenceData() + here=os.path.dirname(os.path.abspath(__file__)) + data_directory=os.path.join(here,"saved_models") + filepath=os.path.join(data_directory,"io_new_file.nc") + assert not os.path.exists(filepath) + to_netcdf(inference_data,filepath) + assert os.path.exists(filepath) + assert os.path.getsize(filepath)>0 + os.remove(filepath) + assert not os.path.exists(filepath) class TestPyMC3NetCDFUtils: @pytest.fixture(scope="class")