Skip to content

Commit

Permalink
change return type of classmethods-constructors to Self
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Nov 12, 2024
1 parent d6884d4 commit 9e55a65
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 8 deletions.
5 changes: 3 additions & 2 deletions autointent/modules/scoring/dnnc/head_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from sentence_transformers import CrossEncoder
from sklearn.linear_model import LogisticRegressionCV
from typing_extensions import Self

from autointent.custom_types import LabelType

Expand Down Expand Up @@ -133,7 +134,7 @@ def set_classifier(self, clf: LogisticRegressionCV) -> None:
self._clf = clf

@classmethod
def load(cls, path: str) -> "CrossEncoderWithLogreg":
def load(cls, path: str) -> Self:
dump_dir = Path(path)

# load sklearn model
Expand All @@ -144,7 +145,7 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":
crossencoder_dir = str(dump_dir / "crossencoder")
model = CrossEncoder(crossencoder_dir) # TODO control device

res = CrossEncoderWithLogreg(model)
res = cls(model)
res.set_classifier(clf)

return res
3 changes: 2 additions & 1 deletion autointent/nodes/inference/inference_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gc

import torch
from typing_extensions import Self

from autointent.configs.node import InferenceNodeConfig
from autointent.modules.base import Module
Expand All @@ -13,7 +14,7 @@ def __init__(self, module: Module, node_type: str) -> None:
self.node_type = node_type

@classmethod
def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
def from_config(cls, config: InferenceNodeConfig) -> Self:
node_info = NODES_INFO[config.node_type]
module = node_info.modules_available[config.module_type](**config.module_config)
if config.load_path is not None:
Expand Down
10 changes: 6 additions & 4 deletions autointent/pipeline/inference/inference_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from typing_extensions import Self

from autointent.configs.node import InferenceNodeConfig
from autointent.context import Context
from autointent.custom_types import LabelType, NodeType
Expand All @@ -11,13 +13,13 @@ def __init__(self, nodes: list[InferenceNode]) -> None:
self.nodes = {node.node_type: node for node in nodes}

@classmethod
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "InferencePipeline":
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> Self:
nodes_configs_ = [InferenceNodeConfig(**cfg) for cfg in nodes_configs]
nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs_]
return cls(nodes)

@classmethod
def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "InferencePipeline":
def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> Self:
nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
return cls(nodes)

Expand All @@ -29,12 +31,12 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
pass

@classmethod
def from_context(cls, context: Context) -> "InferencePipeline":
def from_context(cls, context: Context) -> Self:
if not context.has_saved_modules():
config = context.optimization_info.get_inference_nodes_config()
return cls.from_config(config)
nodes = [
InferenceNode(module, node_type)
for node_type, module in context.optimization_info.get_best_modules().items()
]
return InferencePipeline(nodes)
return cls(nodes)
3 changes: 2 additions & 1 deletion autointent/pipeline/optimization/pipeline_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from hydra.utils import instantiate
from typing_extensions import Self

from autointent import Context
from autointent.configs.optimization_cli import EmbedderConfig, LoggingConfig, VectorIndexConfig
Expand Down Expand Up @@ -37,7 +38,7 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig)
raise TypeError(msg)

@classmethod
def from_dict_config(cls, config: dict[str, Any]) -> "PipelineOptimizer":
def from_dict_config(cls, config: dict[str, Any]) -> Self:
return instantiate(PipelineOptimizerConfig, **config) # type: ignore[no-any-return]

def optimize(self, context: Context) -> None:
Expand Down

0 comments on commit 9e55a65

Please sign in to comment.