From 65bb31862b0f69f826ff08a975d98f284643eb4a Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 31 Jul 2024 19:01:24 +0800 Subject: [PATCH] [Typing][C-46] Add type annotations for `python/paddle/distributed/fleet/dataset/index_dataset.py` (#66760) --- .../fleet/dataset/index_dataset.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/python/paddle/distributed/fleet/dataset/index_dataset.py b/python/paddle/distributed/fleet/dataset/index_dataset.py index 4c02cb8d82c7a3..eab780846c2c15 100644 --- a/python/paddle/distributed/fleet/dataset/index_dataset.py +++ b/python/paddle/distributed/fleet/dataset/index_dataset.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +from typing import Any + from paddle.base import core __all__ = [] class Index: - def __init__(self, name): + def __init__(self, name: str) -> None: self._name = name class TreeIndex(Index): - def __init__(self, name, path): + def __init__(self, name: str, path: str) -> None: super().__init__(name) self._wrapper = core.IndexWrapper() self._wrapper.insert_tree_index(name, path) @@ -33,57 +38,65 @@ def __init__(self, name, path): self._emb_size = self._tree.emb_size() self._layerwise_sampler = None - def height(self): + def height(self) -> int: return self._height - def branch(self): + def branch(self) -> int: return self._branch - def total_node_nums(self): + def total_node_nums(self) -> int: return self._total_node_nums - def emb_size(self): + def emb_size(self) -> int: return self._emb_size - def get_all_leafs(self): + def get_all_leafs(self) -> list[Any]: return self._tree.get_all_leafs() - def get_nodes(self, codes): + def get_nodes(self, codes: list[int]) -> list[Any]: return self._tree.get_nodes(codes) - def get_layer_codes(self, level): + def get_layer_codes(self, level: int) -> list[int]: return self._tree.get_layer_codes(level) - def get_travel_codes(self, id, start_level=0): + def get_travel_codes(self, id: int, start_level: int = 0) -> list[int]: return self._tree.get_travel_codes(id, start_level) - def get_ancestor_codes(self, ids, level): + def get_ancestor_codes(self, ids: list[int], level: int) -> list[int]: return self._tree.get_ancestor_codes(ids, level) - def get_children_codes(self, ancestor, level): + def get_children_codes(self, ancestor: int, level: int) -> list[int]: return self._tree.get_children_codes(ancestor, level) - def get_travel_path(self, child, ancestor): + def get_travel_path(self, child: int, ancestor: int) -> list[int]: res = [] while child > ancestor: res.append(child) child = int((child - 1) / self._branch) return res - def get_pi_relation(self, ids, level): + def get_pi_relation(self, ids: list[int], level: int) -> dict[int, int]: codes = self.get_ancestor_codes(ids, level) return dict(zip(ids, codes)) def init_layerwise_sampler( - self, layer_sample_counts, start_sample_layer=1, seed=0 - ): + self, + layer_sample_counts: list[int], + start_sample_layer: int = 1, + seed: int = 0, + ) -> None: assert self._layerwise_sampler is None self._layerwise_sampler = core.IndexSampler("by_layerwise", self._name) self._layerwise_sampler.init_layerwise_conf( layer_sample_counts, start_sample_layer, seed ) - def layerwise_sample(self, user_input, index_input, with_hierarchy=False): + def layerwise_sample( + self, + user_input: list[list[int]], + index_input: list[int], + with_hierarchy: bool = False, + ) -> list[list[int]]: if self._layerwise_sampler is None: raise ValueError("please init layerwise_sampler first.") return self._layerwise_sampler.sample(