From 3bd9d8aecdca6d3cd558f94db97bb4d0576f631a Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Mon, 22 Aug 2022 14:29:09 -0700 Subject: [PATCH] Remove mutable defaults in mlp_model --- .../tvm/meta_schedule/cost_model/mlp_model.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py index 04ccca0563f9..e7f07f0a4542 100644 --- a/python/tvm/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -26,7 +26,7 @@ import tempfile from collections import OrderedDict from itertools import chain as itertools_chain -from typing import Dict, List, NamedTuple, Tuple +from typing import Dict, List, NamedTuple, Optional, Tuple import numpy as np # type: ignore import torch # type: ignore @@ -418,8 +418,8 @@ def forward( # pylint: disable=missing-function-docstring def extract_features( context: TuneContext, candidates: List[MeasureCandidate], - results: List[RunnerResult] = None, - extractor: FeatureExtractor = PerStoreFeature(extract_workload=True), + results: Optional[List[RunnerResult]] = None, + extractor: Optional[FeatureExtractor] = None, ): """Extract feature vectors and compute mean costs. @@ -429,9 +429,9 @@ def extract_features( The tuning context. candidates: List[MeasureCandidate] The measure candidates. - results: List[RunnerResult] + results: Optional[List[RunnerResult]] The measured results, can be None if used in prediction. - extractor: FeatureExtractor + extractor: Optional[FeatureExtractor] The feature extractor. Returns @@ -441,6 +441,7 @@ def extract_features( new_mean_costs: np.ndarray The mean costs. """ + extractor = extractor or PerStoreFeature(extract_workload=True) def _feature(feature: NDArray) -> np.ndarray: return feature.numpy().astype("float32") @@ -481,9 +482,12 @@ class State: def __init__( self, - model_config: SegmentSumMLPConfig = SegmentSumMLPConfig(), - extractor: FeatureExtractor = PerStoreFeature(extract_workload=True), + model_config: Optional[SegmentSumMLPConfig] = None, + extractor: Optional[FeatureExtractor] = None, ): + model_config = model_config or SegmentSumMLPConfig() + extractor = extractor or PerStoreFeature(extract_workload=True) + self.model = SegmentSumMLP(**model_config.to_dict()) self.data = OrderedDict() self.data_size = 0 @@ -662,9 +666,12 @@ class SegmentSumMLPTrainer: def __init__( self, - train_config: TrainerConfig = TrainerConfig(), - state: State = State(), + train_config: Optional[TrainerConfig] = None, + state: Optional[State] = None, ): + train_config = train_config or TrainerConfig() + state = state or State() + config = train_config.to_dict() for attr in config: setattr(self, attr, config[attr]) @@ -676,7 +683,7 @@ def train_step( self, data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"], batch: int = 0, - train_loss: float = None, + train_loss: Optional[float] = None, ) -> float: """Helper function for training on a single batch. @@ -686,7 +693,7 @@ def train_step( A batch of data, should be a tuple of (segment_sizes, features, gt_results). batch: int = 0 The current batch number. - train_loss: float = None + train_loss: Optional[float] = None The previous averaged training loss, None if it is the first batch. Returns @@ -863,7 +870,7 @@ def train_incremental( def predict_incremental( self, features: List[np.ndarray], - results: np.ndarray = None, + results: Optional[np.ndarray] = None, ) -> np.ndarray: """Predicting (validating) on incremental data. @@ -871,7 +878,7 @@ def predict_incremental( ---------- features: List[np.ndarray] The extracted features. - results: np.ndarray + results: Optional[np.ndarray] The measured results, can be None if used for predicting. Returns @@ -943,10 +950,10 @@ class MLPModel(PyCostModel): def __init__( self, *, - trainer: SegmentSumMLPTrainer = SegmentSumMLPTrainer(), + trainer: Optional[SegmentSumMLPTrainer] = None, ): super().__init__() - self.trainer = trainer + self.trainer = trainer or SegmentSumMLPTrainer() def load(self, path: str) -> None: """Load the cost model, cached data or raw data from given file location.