Skip to content

Commit

Permalink
bump python to 3.12 in the test environment (#3343)
Browse files Browse the repository at this point in the history
Fix a bug caused by the breaking change in Keras 3 (shipped by TF 2.16).

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
(cherry picked from commit 473cc0a)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Apr 6, 2024
1 parent 36a93f2 commit 01f1dfd
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
tf: 1.14
- python: 3.8
tf:
- python: "3.11"
- python: "3.12"
tf:

steps:
Expand Down
5 changes: 5 additions & 0 deletions backend/find_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def get_tf_requirement(tf_version: str = "") -> dict:
extra_select = {}
if not (tf_version == "" or tf_version in SpecifierSet(">=2.12", prereleases=True)):
extra_requires.append("protobuf<3.20")
# keras 3 is not compatible with tf.compat.v1
if tf_version == "" or tf_version in SpecifierSet(">=2.15.0rc0", prereleases=True):
extra_requires.append("tf-keras; python_version>='3.9'")
# only TF>=2.16 is compatible with Python 3.12
extra_requires.append("tf-keras>=2.16.0rc0; python_version>='3.12'")
if tf_version == "" or tf_version in SpecifierSet(">=1.15", prereleases=True):
extra_select["mpi"] = [
"horovod",
Expand Down
1 change: 1 addition & 0 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ def _attention_layers(
input_xyz = tf.keras.layers.LayerNormalization(
beta_initializer=tf.constant_initializer(self.beta[i]),
gamma_initializer=tf.constant_initializer(self.gamma[i]),
dtype=self.filter_precision,
)(input_xyz)
# input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n)
return input_xyz
Expand Down
3 changes: 3 additions & 0 deletions deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def dlopen_library(module: str, filename: str):
dlopen_library("nvidia.cudnn.lib", "libcudnn.so*")


# keras 3 is incompatible with tf.compat.v1
# https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility
os.environ["TF_USE_LEGACY_KERAS"] = "1"
# import tensorflow v1 compatability
try:
import tensorflow.compat.v1 as tf
Expand Down

0 comments on commit 01f1dfd

Please sign in to comment.