Skip to content

Commit

Permalink
fix TensorPy for mindspore 2.5 (#1961)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored Feb 19, 2025
1 parent ab20e2e commit d859a1b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
6 changes: 5 additions & 1 deletion mindnlp/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""big modeling"""
from contextlib import contextmanager
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
try:
from mindspore._c_expression import TensorPy as Tensor_ # pylint: disable=no-name-in-module
except:
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module

from mindnlp.utils.testing_utils import parse_flag_from_env
from mindnlp.core import nn

Expand Down
7 changes: 6 additions & 1 deletion mindnlp/core/ops/creation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""creation ops"""
import numpy as np
import mindspore
from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module, import-error
try:
from mindspore._c_expression import TensorPy as CTensor # pylint: disable=no-name-in-module
except:
from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module


from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
from mindnlp.configs import use_pyboost, ON_ORANGE_PI
Expand Down
6 changes: 5 additions & 1 deletion mindnlp/parallel/comm_func.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""communication functional api."""
from mindspore import ops, Tensor
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.communication import GlobalComm, get_group_rank_from_world_rank
from mindspore.ops._primitive_cache import _get_cache_prim
try:
from mindspore._c_expression import TensorPy as Tensor_ # pylint: disable=no-name-in-module
except:
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module


def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
"""
Expand Down
7 changes: 6 additions & 1 deletion mindnlp/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import mindspore
from mindspore._c_expression import Tensor as RawTensor # pylint: disable=no-name-in-module
try:
from mindspore._c_expression import TensorPy as RawTensor # pylint: disable=no-name-in-module
except:
from mindspore._c_expression import Tensor as RawTensor # pylint: disable=no-name-in-module


import mindnlp.core.nn.functional as F
from mindnlp.core import nn, ops
from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
Expand Down

0 comments on commit d859a1b

Please sign in to comment.