diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py
index 1a7b376a36..5c3987e1c5 100644
--- a/deepmd/dpmodel/descriptor/__init__.py
+++ b/deepmd/dpmodel/descriptor/__init__.py
@@ -11,6 +11,9 @@
 from .make_base_descriptor import (
     make_base_descriptor,
 )
+from .se_atten_v2 import (
+    DescrptSeAttenV2,
+)
 from .se_e2_a import (
     DescrptSeA,
 )
@@ -26,6 +29,7 @@
     "DescrptSeR",
     "DescrptSeT",
     "DescrptDPA1",
+    "DescrptSeAttenV2",
     "DescrptDPA2",
     "DescrptHybrid",
     "make_base_descriptor",
diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py
new file mode 100644
index 0000000000..1375d2265f
--- /dev/null
+++ b/deepmd/dpmodel/descriptor/se_atten_v2.py
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import numpy as np
+
+from deepmd.dpmodel import (
+    DEFAULT_PRECISION,
+    PRECISION_DICT,
+)
+from deepmd.dpmodel.utils import (
+    NetworkCollection,
+)
+from deepmd.dpmodel.utils.type_embed import (
+    TypeEmbedNet,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+from .base_descriptor import (
+    BaseDescriptor,
+)
+from .dpa1 import (
+    DescrptDPA1,
+    NeighborGatedAttention,
+)
+
+
+@BaseDescriptor.register("se_atten_v2")
+class DescrptSeAttenV2(DescrptDPA1):
+    def __init__(
+        self,
+        rcut: float,
+        rcut_smth: float,
+        sel: Union[List[int], int],
+        ntypes: int,
+        neuron: List[int] = [25, 50, 100],
+        axis_neuron: int = 8,
+        tebd_dim: int = 8,
+        resnet_dt: bool = False,
+        trainable: bool = True,
+        type_one_side: bool = False,
+        attn: int = 128,
+        attn_layer: int = 2,
+        attn_dotr: bool = True,
+        attn_mask: bool = False,
+        exclude_types: List[Tuple[int, int]] = [],
+        env_protection: float = 0.0,
+        set_davg_zero: bool = False,
+        activation_function: str = "tanh",
+        precision: str = DEFAULT_PRECISION,
+        scaling_factor=1.0,
+        normalize: bool = True,
+        temperature: Optional[float] = None,
+        trainable_ln: bool = True,
+        ln_eps: Optional[float] = 1e-5,
+        concat_output_tebd: bool = True,
+        spin: Optional[Any] = None,
+        stripped_type_embedding: Optional[bool] = None,
+        use_econf_tebd: bool = False,
+        type_map: Optional[List[str]] = None,
+        # consistent with argcheck, not used though
+        seed: Optional[int] = None,
+    ) -> None:
+        DescrptDPA1.__init__(
+            self,
+            rcut,
+            rcut_smth,
+            sel,
+            ntypes,
+            neuron=neuron,
+            axis_neuron=axis_neuron,
+            tebd_dim=tebd_dim,
+            tebd_input_mode="strip",
+            resnet_dt=resnet_dt,
+            trainable=trainable,
+            type_one_side=type_one_side,
+            attn=attn,
+            attn_layer=attn_layer,
+            attn_dotr=attn_dotr,
+            attn_mask=attn_mask,
+            exclude_types=exclude_types,
+            env_protection=env_protection,
+            set_davg_zero=set_davg_zero,
+            activation_function=activation_function,
+            precision=precision,
+            scaling_factor=scaling_factor,
+            normalize=normalize,
+            temperature=temperature,
+            trainable_ln=trainable_ln,
+            ln_eps=ln_eps,
+            smooth_type_embedding=True,
+            concat_output_tebd=concat_output_tebd,
+            spin=spin,
+            stripped_type_embedding=stripped_type_embedding,
+            use_econf_tebd=use_econf_tebd,
+            type_map=type_map,
+            # consistent with argcheck, not used though
+            seed=seed,
+        )
+
+    def serialize(self) -> dict:
+        """Serialize the descriptor to dict."""
+        obj = self.se_atten
+        data = {
+            "@class": "Descriptor",
+            "type": "se_atten_v2",
+            "@version": 1,
+            "rcut": obj.rcut,
+            "rcut_smth": obj.rcut_smth,
+            "sel": obj.sel,
+            "ntypes": obj.ntypes,
+            "neuron": obj.neuron,
+            "axis_neuron": obj.axis_neuron,
+            "tebd_dim": obj.tebd_dim,
+            "set_davg_zero": obj.set_davg_zero,
+            "attn": obj.attn,
+            "attn_layer": obj.attn_layer,
+            "attn_dotr": obj.attn_dotr,
+            "attn_mask": False,
+            "activation_function": obj.activation_function,
+            "resnet_dt": obj.resnet_dt,
+            "scaling_factor": obj.scaling_factor,
+            "normalize": obj.normalize,
+            "temperature": obj.temperature,
+            "trainable_ln": obj.trainable_ln,
+            "ln_eps": obj.ln_eps,
+            "type_one_side": obj.type_one_side,
+            "concat_output_tebd": self.concat_output_tebd,
+            "use_econf_tebd": self.use_econf_tebd,
+            "type_map": self.type_map,
+            # make deterministic
+            "precision": np.dtype(PRECISION_DICT[obj.precision]).name,
+            "embeddings": obj.embeddings.serialize(),
+            "embeddings_strip": obj.embeddings_strip.serialize(),
+            "attention_layers": obj.dpa1_attention.serialize(),
+            "env_mat": obj.env_mat.serialize(),
+            "type_embedding": self.type_embedding.serialize(),
+            "exclude_types": obj.exclude_types,
+            "env_protection": obj.env_protection,
+            "@variables": {
+                "davg": obj["davg"],
+                "dstd": obj["dstd"],
+            },
+            ## to be updated when the options are supported.
+            "trainable": self.trainable,
+            "spin": None,
+        }
+        return data
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
+        """Deserialize from dict."""
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        data.pop("type")
+        variables = data.pop("@variables")
+        embeddings = data.pop("embeddings")
+        type_embedding = data.pop("type_embedding")
+        attention_layers = data.pop("attention_layers")
+        data.pop("env_mat")
+        embeddings_strip = data.pop("embeddings_strip")
+        obj = cls(**data)
+
+        obj.se_atten["davg"] = variables["davg"]
+        obj.se_atten["dstd"] = variables["dstd"]
+        obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings)
+        obj.se_atten.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
+        obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
+        obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
+            attention_layers
+        )
+        return obj
diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py
index e5298ba3ef..b42aa98380 100644
--- a/deepmd/pt/model/descriptor/__init__.py
+++ b/deepmd/pt/model/descriptor/__init__.py
@@ -29,6 +29,9 @@
     DescrptBlockSeA,
     DescrptSeA,
 )
+from .se_atten_v2 import (
+    DescrptSeAttenV2,
+)
 from .se_r import (
     DescrptSeR,
 )
@@ -42,6 +45,7 @@
     "make_default_type_embedding",
     "DescrptBlockSeA",
     "DescrptBlockSeAtten",
+    "DescrptSeAttenV2",
     "DescrptSeA",
     "DescrptSeR",
     "DescrptSeT",
diff --git a/deepmd/pt/model/descriptor/se_atten_v2.py b/deepmd/pt/model/descriptor/se_atten_v2.py
new file mode 100644
index 0000000000..3b350ded98
--- /dev/null
+++ b/deepmd/pt/model/descriptor/se_atten_v2.py
@@ -0,0 +1,253 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import torch
+
+from deepmd.dpmodel.utils import EnvMat as DPEnvMat
+from deepmd.pt.model.descriptor.dpa1 import (
+    DescrptDPA1,
+)
+from deepmd.pt.model.network.mlp import (
+    NetworkCollection,
+)
+from deepmd.pt.model.network.network import (
+    TypeEmbedNetConsistent,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.env import (
+    RESERVED_PRECISON_DICT,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+from .base_descriptor import (
+    BaseDescriptor,
+)
+from .se_atten import (
+    NeighborGatedAttention,
+)
+
+
+@BaseDescriptor.register("se_atten_v2")
+class DescrptSeAttenV2(DescrptDPA1):
+    def __init__(
+        self,
+        rcut: float,
+        rcut_smth: float,
+        sel: Union[List[int], int],
+        ntypes: int,
+        neuron: list = [25, 50, 100],
+        axis_neuron: int = 16,
+        tebd_dim: int = 8,
+        set_davg_zero: bool = True,
+        attn: int = 128,
+        attn_layer: int = 2,
+        attn_dotr: bool = True,
+        attn_mask: bool = False,
+        activation_function: str = "tanh",
+        precision: str = "float64",
+        resnet_dt: bool = False,
+        exclude_types: List[Tuple[int, int]] = [],
+        env_protection: float = 0.0,
+        scaling_factor: int = 1.0,
+        normalize=True,
+        temperature=None,
+        concat_output_tebd: bool = True,
+        trainable: bool = True,
+        trainable_ln: bool = True,
+        ln_eps: Optional[float] = 1e-5,
+        type_one_side: bool = False,
+        stripped_type_embedding: Optional[bool] = None,
+        seed: Optional[int] = None,
+        use_econf_tebd: bool = False,
+        type_map: Optional[List[str]] = None,
+        # not implemented
+        spin=None,
+        type: Optional[str] = None,
+        old_impl: bool = False,
+    ) -> None:
+        r"""Construct smooth version of embedding net of type `se_atten_v2`.
+
+        Parameters
+        ----------
+        rcut : float
+            The cut-off radius :math:`r_c`
+        rcut_smth : float
+            From where the environment matrix should be smoothed :math:`r_s`
+        sel : list[int], int
+            list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius
+            int: the total maxmum number of atoms in the cut-off radius
+        ntypes : int
+            Number of element types
+        neuron : list[int]
+            Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
+        axis_neuron : int
+            Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
+        tebd_dim : int
+            Dimension of the type embedding
+        set_davg_zero : bool
+            Set the shift of embedding net input to zero.
+        attn : int
+            Hidden dimension of the attention vectors
+        attn_layer : int
+            Number of attention layers
+        attn_dotr : bool
+            If dot the angular gate to the attention weights
+        attn_mask : bool
+            (Only support False to keep consistent with other backend references.)
+            (Not used in this version.)
+            If mask the diagonal of attention weights
+        activation_function : str
+            The activation function in the embedding net. Supported options are |ACTIVATION_FN|
+        precision : str
+            The precision of the embedding net parameters. Supported options are |PRECISION|
+        resnet_dt : bool
+            Time-step `dt` in the resnet construction:
+            y = x + dt * \phi (Wx + b)
+        exclude_types : List[List[int]]
+            The excluded pairs of types which have no interaction with each other.
+            For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+        env_protection : float
+            Protection parameter to prevent division by zero errors during environment matrix calculations.
+        scaling_factor : float
+            The scaling factor of normalization in calculations of attention weights.
+            If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
+        normalize : bool
+            Whether to normalize the hidden vectors in attention weights calculation.
+        temperature : float
+            If not None, the scaling of attention weights is `temperature` itself.
+        trainable_ln : bool
+            Whether to use trainable shift and scale weights in layer normalization.
+        ln_eps : float, Optional
+            The epsilon value for layer normalization.
+        type_one_side : bool
+            If 'False', type embeddings of both neighbor and central atoms are considered.
+            If 'True', only type embeddings of neighbor atoms are considered.
+            Default is 'False'.
+        seed : int, Optional
+            Random seed for parameter initialization.
+        """
+        DescrptDPA1.__init__(
+            self,
+            rcut,
+            rcut_smth,
+            sel,
+            ntypes,
+            neuron=neuron,
+            axis_neuron=axis_neuron,
+            tebd_dim=tebd_dim,
+            tebd_input_mode="strip",
+            set_davg_zero=set_davg_zero,
+            attn=attn,
+            attn_layer=attn_layer,
+            attn_dotr=attn_dotr,
+            attn_mask=attn_mask,
+            activation_function=activation_function,
+            precision=precision,
+            resnet_dt=resnet_dt,
+            exclude_types=exclude_types,
+            env_protection=env_protection,
+            scaling_factor=scaling_factor,
+            normalize=normalize,
+            temperature=temperature,
+            concat_output_tebd=concat_output_tebd,
+            trainable=trainable,
+            trainable_ln=trainable_ln,
+            ln_eps=ln_eps,
+            smooth_type_embedding=True,
+            type_one_side=type_one_side,
+            stripped_type_embedding=stripped_type_embedding,
+            seed=seed,
+            use_econf_tebd=use_econf_tebd,
+            type_map=type_map,
+            # not implemented
+            spin=spin,
+            type=type,
+            old_impl=old_impl,
+        )
+
+    def serialize(self) -> dict:
+        obj = self.se_atten
+        data = {
+            "@class": "Descriptor",
+            "type": "se_atten_v2",
+            "@version": 1,
+            "rcut": obj.rcut,
+            "rcut_smth": obj.rcut_smth,
+            "sel": obj.sel,
+            "ntypes": obj.ntypes,
+            "neuron": obj.neuron,
+            "axis_neuron": obj.axis_neuron,
+            "tebd_dim": obj.tebd_dim,
+            "set_davg_zero": obj.set_davg_zero,
+            "attn": obj.attn_dim,
+            "attn_layer": obj.attn_layer,
+            "attn_dotr": obj.attn_dotr,
+            "attn_mask": False,
+            "activation_function": obj.activation_function,
+            "resnet_dt": obj.resnet_dt,
+            "scaling_factor": obj.scaling_factor,
+            "normalize": obj.normalize,
+            "temperature": obj.temperature,
+            "trainable_ln": obj.trainable_ln,
+            "ln_eps": obj.ln_eps,
+            "type_one_side": obj.type_one_side,
+            "concat_output_tebd": self.concat_output_tebd,
+            "use_econf_tebd": self.use_econf_tebd,
+            "type_map": self.type_map,
+            # make deterministic
+            "precision": RESERVED_PRECISON_DICT[obj.prec],
+            "embeddings": obj.filter_layers.serialize(),
+            "embeddings_strip": obj.filter_layers_strip.serialize(),
+            "attention_layers": obj.dpa1_attention.serialize(),
+            "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
+            "type_embedding": self.type_embedding.embedding.serialize(),
+            "exclude_types": obj.exclude_types,
+            "env_protection": obj.env_protection,
+            "@variables": {
+                "davg": obj["davg"].detach().cpu().numpy(),
+                "dstd": obj["dstd"].detach().cpu().numpy(),
+            },
+            "trainable": self.trainable,
+            "spin": None,
+        }
+        return data
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        data.pop("type")
+        variables = data.pop("@variables")
+        embeddings = data.pop("embeddings")
+        type_embedding = data.pop("type_embedding")
+        attention_layers = data.pop("attention_layers")
+        data.pop("env_mat")
+        embeddings_strip = data.pop("embeddings_strip")
+        obj = cls(**data)
+
+        def t_cvt(xx):
+            return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE)
+
+        obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
+            type_embedding
+        )
+        obj.se_atten["davg"] = t_cvt(variables["davg"])
+        obj.se_atten["dstd"] = t_cvt(variables["dstd"])
+        obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
+        obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
+            embeddings_strip
+        )
+        obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
+            attention_layers
+        )
+        return obj
diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py
index b240f00647..2bfe71fcf8 100644
--- a/deepmd/tf/descriptor/se_atten.py
+++ b/deepmd/tf/descriptor/se_atten.py
@@ -1878,12 +1878,8 @@ def serialize(self, suffix: str = "") -> dict:
         dict
             The serialized data
         """
-        if type(self) not in [DescrptSeAtten, DescrptDPA1Compat]:
-            raise NotImplementedError(
-                f"Not implemented in class {self.__class__.__name__}"
-            )
-        if self.stripped_type_embedding and type(self) is not DescrptDPA1Compat:
-            # only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'
+        if self.stripped_type_embedding and type(self) is DescrptSeAtten:
+            # only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'
             raise NotImplementedError(
                 "serialization is unsupported by the native model when tebd_input_mode=='strip'"
             )
@@ -1963,8 +1959,8 @@ def serialize(self, suffix: str = "") -> dict:
         }
         if self.tebd_input_mode in ["strip"]:
             assert (
-                type(self) is DescrptDPA1Compat
-            ), "only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'"
+                type(self) is not DescrptSeAtten
+            ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
             data.update(
                 {
                     "embeddings_strip": self.serialize_network_strip(
diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py
index 6204f27855..a4fdf24a55 100644
--- a/deepmd/tf/descriptor/se_atten_v2.py
+++ b/deepmd/tf/descriptor/se_atten_v2.py
@@ -5,6 +5,10 @@
     Optional,
 )
 
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
 from .descriptor import (
     Descriptor,
 )
@@ -109,3 +113,68 @@ def __init__(
             smooth_type_embedding=True,
             **kwargs,
         )
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = ""):
+        """Deserialize the model.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+
+        Returns
+        -------
+        Model
+            The deserialized model
+        """
+        if cls is not DescrptSeAttenV2:
+            raise NotImplementedError(f"Not implemented in class {cls.__name__}")
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        data.pop("type")
+        embedding_net_variables = cls.deserialize_network(
+            data.pop("embeddings"), suffix=suffix
+        )
+        attention_layer_variables = cls.deserialize_attention_layers(
+            data.pop("attention_layers"), suffix=suffix
+        )
+        data.pop("env_mat")
+        variables = data.pop("@variables")
+        type_one_side = data["type_one_side"]
+        two_side_embeeding_net_variables = cls.deserialize_network_strip(
+            data.pop("embeddings_strip"),
+            suffix=suffix,
+            type_one_side=type_one_side,
+        )
+        descriptor = cls(**data)
+        descriptor.embedding_net_variables = embedding_net_variables
+        descriptor.attention_layer_variables = attention_layer_variables
+        descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables
+        descriptor.davg = variables["davg"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        descriptor.dstd = variables["dstd"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        return descriptor
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the model.
+
+        Parameters
+        ----------
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        data = super().serialize(suffix)
+        data.pop("smooth_type_embedding")
+        data.pop("tebd_input_mode")
+        data.update({"type": "se_atten_v2"})
+        return data
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index fadec096eb..bbb203eea9 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -637,15 +637,79 @@ def descrpt_se_atten_args():
     ]
 
 
-@descrpt_args_plugin.register("se_atten_v2", doc=doc_only_tf_supported)
+@descrpt_args_plugin.register("se_atten_v2")
 def descrpt_se_atten_v2_args():
     doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"
+    doc_trainable_ln = (
+        "Whether to use trainable shift and scale weights in layer normalization."
+    )
+    doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with keras while set to 1e-5 in PyTorch and DP implementation."
+    doc_tebd_dim = "The dimension of atom type embedding."
+    doc_use_econf_tebd = r"Whether to use electronic configuration type embedding. For TensorFlow backend, please set `use_econf_tebd` in `type_embedding` block instead."
+    doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)."
+    doc_scaling_factor = (
+        "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K). "
+        "If `temperature` is None, the scaling of attention weights is (N_hidden_dim * scaling_factor)**0.5. "
+        "Else, the scaling of attention weights is setting to `temperature`."
+    )
+    doc_normalize = (
+        "Whether to normalize the hidden vectors during attention calculation."
+    )
+    doc_concat_output_tebd = (
+        "Whether to concat type embedding at the output of the descriptor."
+    )
 
     return [
         *descrpt_se_atten_common_args(),
         Argument(
             "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
         ),
+        Argument(
+            "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln
+        ),
+        Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps),
+        # pt only
+        Argument(
+            "tebd_dim",
+            int,
+            optional=True,
+            default=8,
+            doc=doc_only_pt_supported + doc_tebd_dim,
+        ),
+        Argument(
+            "use_econf_tebd",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_only_pt_supported + doc_use_econf_tebd,
+        ),
+        Argument(
+            "scaling_factor",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_scaling_factor,
+        ),
+        Argument(
+            "normalize",
+            bool,
+            optional=True,
+            default=True,
+            doc=doc_only_pt_supported + doc_normalize,
+        ),
+        Argument(
+            "temperature",
+            float,
+            optional=True,
+            doc=doc_only_pt_supported + doc_temperature,
+        ),
+        Argument(
+            "concat_output_tebd",
+            bool,
+            optional=True,
+            default=True,
+            doc=doc_only_pt_supported + doc_concat_output_tebd,
+        ),
     ]
 
 
diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md
index acd1a500a7..24950d9595 100644
--- a/doc/model/train-se-atten.md
+++ b/doc/model/train-se-atten.md
@@ -126,7 +126,7 @@ We highly recommend using the version 2.0 of the attention-based descriptor `"se
       "set_davg_zero": false
 ```
 
-When using PyTorch backend, you must continue to use descriptor `"se_atten"` and specify `tebd_input_mode` as `"strip"` and `smooth_type_embedding` as `"true"`, which achieves the effect of `"se_atten_v2"`. The `tebd_input_mode` can take `"concat"` and `"strip"` as values. When using TensorFlow backend, you need to use descriptor `"se_atten_v2"` and do not need to set `tebd_input_mode` and `smooth_type_embedding` because the default value of `tebd_input_mode` is `"strip"`, and the default value of `smooth_type_embedding` is `"true"` in TensorFlow backend. When `tebd_input_mode` is set to `"strip"`, the embedding matrix $\mathcal{G}^i$ is constructed as:
+You can use descriptor `"se_atten_v2"` and do not need to set `tebd_input_mode` and `smooth_type_embedding`. In `"se_atten_v2"`, `tebd_input_mode` is forced to be `"strip"` and `smooth_type_embedding` is forced to be `"true"`. When `tebd_input_mode` is `"strip"`, the embedding matrix $\mathcal{G}^i$ is constructed as:
 
 ```math
    (\mathcal{G}^i)_j = \mathcal{N}_{e,2}(s(r_{ij})) + \mathcal{N}_{e,2}(s(r_{ij})) \odot ({N}_{e,2}(\{\mathcal{A}^i, \mathcal{A}^j\}) \odot s(r_{ij})) \quad \mathrm{or}
diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py
new file mode 100644
index 0000000000..54f3cb5826
--- /dev/null
+++ b/source/tests/consistent/descriptor/test_se_atten_v2.py
@@ -0,0 +1,300 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from typing import (
+    Any,
+    Optional,
+    Tuple,
+)
+
+import numpy as np
+from dargs import (
+    Argument,
+)
+
+from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
+from deepmd.env import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+
+from ..common import (
+    INSTALLED_PT,
+    CommonTest,
+    parameterized,
+)
+from .common import (
+    DescriptorTest,
+)
+
+if INSTALLED_PT:
+    from deepmd.pt.model.descriptor.se_atten_v2 import (
+        DescrptSeAttenV2 as DescrptSeAttenV2PT,
+    )
+else:
+    DescrptSeAttenV2PT = None
+DescrptSeAttenV2TF = None
+from deepmd.utils.argcheck import (
+    descrpt_se_atten_args,
+)
+
+
+@parameterized(
+    (4,),  # tebd_dim
+    (True,),  # resnet_dt
+    (True, False),  # type_one_side
+    (20,),  # attn
+    (0, 2),  # attn_layer
+    (True, False),  # attn_dotr
+    ([], [[0, 1]]),  # excluded_types
+    (0.0,),  # env_protection
+    (True, False),  # set_davg_zero
+    (1.0,),  # scaling_factor
+    (True, False),  # normalize
+    (None, 1.0),  # temperature
+    (1e-5,),  # ln_eps
+    (True,),  # concat_output_tebd
+    ("float64",),  # precision
+    (True, False),  # use_econf_tebd
+)
+class TestSeAttenV2(CommonTest, DescriptorTest, unittest.TestCase):
+    @property
+    def data(self) -> dict:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return {
+            "sel": [10],
+            "rcut_smth": 5.80,
+            "rcut": 6.00,
+            "neuron": [6, 12, 24],
+            "ntypes": self.ntypes,
+            "axis_neuron": 3,
+            "tebd_dim": tebd_dim,
+            "attn": attn,
+            "attn_layer": attn_layer,
+            "attn_dotr": attn_dotr,
+            "attn_mask": False,
+            "scaling_factor": scaling_factor,
+            "normalize": normalize,
+            "temperature": temperature,
+            "ln_eps": ln_eps,
+            "concat_output_tebd": concat_output_tebd,
+            "resnet_dt": resnet_dt,
+            "type_one_side": type_one_side,
+            "exclude_types": excluded_types,
+            "env_protection": env_protection,
+            "precision": precision,
+            "set_davg_zero": set_davg_zero,
+            "use_econf_tebd": use_econf_tebd,
+            "type_map": ["O", "H"] if use_econf_tebd else None,
+            "seed": 1145141919810,
+        }
+
+    def is_meaningless_zero_attention_layer_tests(
+        self,
+        attn_layer: int,
+        attn_dotr: bool,
+        normalize: bool,
+        temperature: Optional[float],
+    ) -> bool:
+        return attn_layer == 0 and (attn_dotr or normalize or temperature is not None)
+
+    @property
+    def skip_pt(self) -> bool:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
+
+    @property
+    def skip_dp(self) -> bool:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
+
+    @property
+    def skip_tf(self) -> bool:
+        return True
+
+    tf_class = DescrptSeAttenV2TF
+    dp_class = DescrptSeAttenV2DP
+    pt_class = DescrptSeAttenV2PT
+    args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))
+
+    def setUp(self):
+        CommonTest.setUp(self)
+
+        self.ntypes = 2
+        self.coords = np.array(
+            [
+                12.83,
+                2.56,
+                2.18,
+                12.09,
+                2.87,
+                2.74,
+                00.25,
+                3.32,
+                1.68,
+                3.36,
+                3.00,
+                1.81,
+                3.51,
+                2.51,
+                2.60,
+                4.27,
+                3.22,
+                1.56,
+            ],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
+        self.box = np.array(
+            [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
+
+    def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
+        return self.build_tf_descriptor(
+            obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            suffix,
+        )
+
+    def eval_dp(self, dp_obj: Any) -> Any:
+        return self.eval_dp_descriptor(
+            dp_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            mixed_types=True,
+        )
+
+    def eval_pt(self, pt_obj: Any) -> Any:
+        return self.eval_pt_descriptor(
+            pt_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            mixed_types=True,
+        )
+
+    def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
+        return (ret[0],)
+
+    @property
+    def rtol(self) -> float:
+        """Relative tolerance for comparing the return value."""
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        if precision == "float64":
+            return 1e-10
+        elif precision == "float32":
+            return 1e-4
+        else:
+            raise ValueError(f"Unknown precision: {precision}")
+
+    @property
+    def atol(self) -> float:
+        """Absolute tolerance for comparing the return value."""
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        if precision == "float64":
+            return 1e-10
+        elif precision == "float32":
+            return 1e-4
+        else:
+            raise ValueError(f"Unknown precision: {precision}")
diff --git a/source/tests/pt/model/test_se_atten_v2.py b/source/tests/pt/model/test_se_atten_v2.py
new file mode 100644
index 0000000000..caecd0a118
--- /dev/null
+++ b/source/tests/pt/model/test_se_atten_v2.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import itertools
+import unittest
+
+import numpy as np
+import torch
+
+from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DPDescrptSeAttenV2
+from deepmd.pt.model.descriptor.se_atten_v2 import (
+    DescrptSeAttenV2,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.env import (
+    PRECISION_DICT,
+)
+
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+from .test_mlp import (
+    get_tols,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+
+
+class TestDescrptSeAttenV2(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    def setUp(self):
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def test_consistency(
+        self,
+    ):
+        rng = np.random.default_rng(100)
+        nf, nloc, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(dstd)
+
+        for idt, to, prec, ect in itertools.product(
+            [False, True],  # resnet_dt
+            [False, True],  # type_one_side
+            [
+                "float64",
+            ],  # precision
+            [False, True],  # use_econf_tebd
+        ):
+            dtype = PRECISION_DICT[prec]
+            rtol, atol = get_tols(prec)
+            err_msg = f"idt={idt} prec={prec}"
+
+            # dpa1 new impl
+            dd0 = DescrptSeAttenV2(
+                self.rcut,
+                self.rcut_smth,
+                self.sel_mix,
+                self.nt,
+                attn_layer=2,
+                precision=prec,
+                resnet_dt=idt,
+                type_one_side=to,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                old_impl=False,
+            ).to(env.DEVICE)
+            dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            rd0, _, _, _, _ = dd0(
+                torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
+            )
+            # serialization
+            dd1 = DescrptSeAttenV2.deserialize(dd0.serialize())
+            rd1, _, _, _, _ = dd1(
+                torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
+                torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
+                torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
+            )
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd1.detach().cpu().numpy(),
+                rtol=rtol,
+                atol=atol,
+                err_msg=err_msg,
+            )
+            # dp impl
+            dd2 = DPDescrptSeAttenV2.deserialize(dd0.serialize())
+            rd2, _, _, _, _ = dd2.call(
+                self.coord_ext,
+                self.atype_ext,
+                self.nlist,
+            )
+            np.testing.assert_allclose(
+                rd0.detach().cpu().numpy(),
+                rd2,
+                rtol=rtol,
+                atol=atol,
+                err_msg=err_msg,
+            )
+
+    def test_jit(
+        self,
+    ):
+        rng = np.random.default_rng()
+        _, _, nnei = self.nlist.shape
+        davg = rng.normal(size=(self.nt, nnei, 4))
+        dstd = rng.normal(size=(self.nt, nnei, 4))
+        dstd = 0.1 + np.abs(dstd)
+
+        for idt, prec, to, ect in itertools.product(
+            [
+                False,
+            ],  # resnet_dt
+            [
+                "float64",
+            ],  # precision
+            [
+                False,
+            ],  # type_one_side
+            [False, True],  # use_econf_tebd
+        ):
+            dtype = PRECISION_DICT[prec]
+            # dpa1 new impl
+            dd0 = DescrptSeAttenV2(
+                self.rcut,
+                self.rcut_smth,
+                self.sel,
+                self.nt,
+                precision=prec,
+                resnet_dt=idt,
+                type_one_side=to,
+                use_econf_tebd=ect,
+                type_map=["O", "H"] if ect else None,
+                old_impl=False,
+            )
+            dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
+            dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
+            _ = torch.jit.script(dd0)