From 655205d56825ebe33702f52664ac1523a16b036c Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 10 Apr 2021 14:19:48 -0500 Subject: [PATCH 1/2] Move observations to cpu in PyroConverter --- arviz/data/io_pyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index b1314533ff..e1a2f1844a 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -110,7 +110,7 @@ def arbitrary_element(dct): if self.model is not None: trace = pyro.poutine.trace(self.model).get_trace(*self._args, **self._kwargs) observations = { - name: site["value"] + name: site["value"].cpu() for name, site in trace.nodes.items() if site["type"] == "sample" and site["is_observed"] } From 6b76eb2f5eb4483992723c9e7ae89510db30080e Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 10 Apr 2021 14:44:31 -0500 Subject: [PATCH 2/2] Updated CHANGELOG with fix --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56af478e65..347606d10b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ * Improve error messages in `stats.compare()`, and `var_name` parameter. ([1616](https://github.com/arviz-devs/arviz/pull/1616)) ### Maintenance and fixes +* Fixed conversion of Pyro output fit using GPUs ([1659](https://github.com/arviz-devs/arviz/pull/1659)) * Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201)) * Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201)) * Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580))