From c352814d17ca640ca5c37a5b885ab45d01b2ef3b Mon Sep 17 00:00:00 2001 From: Laszlo Treszkai Date: Thu, 6 Jan 2022 18:02:36 +0100 Subject: [PATCH] Workaround Theano issue w/ NumPy>=1.22.0 (#5310) --- RELEASE-NOTES.md | 3 +++ pymc3/__init__.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 558bfe4cf7c..0062d28c689 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -6,6 +6,9 @@ + The `pm.Distribution(testval=...)` kwarg was deprecated and will be replaced by `pm.Distribution(initval=...)`in `pymc >=4` (see [#5226](https://github.com/pymc-devs/pymc/pulls/5226)). + The `pm.sample(start=...)` kwarg was deprecated and will be replaced by `pm.sample(initvals=...)`in `pymc >=4` (see [#5226](https://github.com/pymc-devs/pymc/pulls/5226)). +### Bugfixes ++ A hotfix is applied on import to remain compatible with NumPy 1.22 (see [#5316](https://github.com/pymc-devs/pymc/pull/5316)). + ## PyMC3 3.11.4 (20 August 2021) ### New Features diff --git a/pymc3/__init__.py b/pymc3/__init__.py index 53ca2d3977c..e8ea258fb49 100644 --- a/pymc3/__init__.py +++ b/pymc3/__init__.py @@ -18,8 +18,30 @@ import logging import multiprocessing as mp import platform +import warnings +import numpy.distutils import semver + +# Workaround for Theano bug that tries to access blas_opt_info; +# must be done before importing theano. +# https://github.com/pymc-devs/pymc/issues/5310 +# Copied from theano/link/c/cmodule.py: default_blas_ldflags() +if ( + hasattr(numpy.distutils, "__config__") + and numpy.distutils.__config__ + and not hasattr(numpy.distutils.__config__, "blas_opt_info") +): + import numpy.distutils.system_info # noqa + + # We need to catch warnings as in some cases NumPy print + # stuff that we don't want the user to see. + with warnings.catch_warnings(record=True): + numpy.distutils.system_info.system_info.verbosity = 0 + blas_info = numpy.distutils.system_info.get_info("blas_opt") + + numpy.distutils.__config__.blas_opt_info = blas_info + import theano _log = logging.getLogger("pymc3")