Skip to content

Commit

Permalink
[Typing][C-50] Add type annotations for `python/paddle/distributed/fl…
Browse files Browse the repository at this point in the history
…eet/utils/ps_util.py` (#66770)

---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
megemini and SigureMo authored Jul 31, 2024
1 parent 3e59023 commit 2b32c1d
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions python/paddle/distributed/fleet/utils/ps_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@
# limitations under the License.
"""Parameter Server utils"""

from __future__ import annotations

import os
import warnings
from typing import TYPE_CHECKING

import paddle

if TYPE_CHECKING:
from paddle import Tensor
from paddle.distributed.fleet.base.role_maker import RoleMakerBase
from paddle.static import Executor, Program

__all__ = []


Expand All @@ -26,7 +34,11 @@ class DistributedInfer:
Utility class for distributed infer of PaddlePaddle.
"""

def __init__(self, main_program=None, startup_program=None):
def __init__(
self,
main_program: Program | None = None,
startup_program: Program | None = None,
) -> None:
if main_program:
self.origin_main_program = main_program.clone()
else:
Expand All @@ -43,8 +55,12 @@ def __init__(self, main_program=None, startup_program=None):
self.sparse_table_maps = None

def init_distributed_infer_env(
self, exe, loss, role_maker=None, dirname=None
):
self,
exe: Executor,
loss: Tensor,
role_maker: RoleMakerBase | None = None,
dirname: str | None = None,
) -> None:
from paddle.distributed import fleet

if fleet.fleet._runtime_handle is None:
Expand Down Expand Up @@ -112,7 +128,7 @@ def _init_dense_params(self, exe=None, dirname=None):
vars=need_load_vars,
)

def get_dist_infer_program(self):
def get_dist_infer_program(self) -> Program:
varname2tables = self._get_sparse_table_map()
convert_program = self._convert_program(
self.origin_main_program, varname2tables
Expand Down

0 comments on commit 2b32c1d

Please sign in to comment.