Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
Remove mutable defaults in mlp_model (apache#12546)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalpiszczek authored and xinetzone committed Nov 25, 2022
1 parent 877c8b5 commit a367538
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions python/tvm/meta_schedule/cost_model/mlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import tempfile
from collections import OrderedDict
from itertools import chain as itertools_chain
from typing import Dict, List, NamedTuple, Tuple
from typing import Dict, List, NamedTuple, Optional, Tuple

import numpy as np # type: ignore
import torch # type: ignore
Expand Down Expand Up @@ -418,8 +418,8 @@ def forward( # pylint: disable=missing-function-docstring
def extract_features(
context: TuneContext,
candidates: List[MeasureCandidate],
results: List[RunnerResult] = None,
extractor: FeatureExtractor = PerStoreFeature(extract_workload=True),
results: Optional[List[RunnerResult]] = None,
extractor: Optional[FeatureExtractor] = None,
):
"""Extract feature vectors and compute mean costs.
Expand All @@ -429,9 +429,9 @@ def extract_features(
The tuning context.
candidates: List[MeasureCandidate]
The measure candidates.
results: List[RunnerResult]
results: Optional[List[RunnerResult]]
The measured results, can be None if used in prediction.
extractor: FeatureExtractor
extractor: Optional[FeatureExtractor]
The feature extractor.
Returns
Expand All @@ -441,6 +441,7 @@ def extract_features(
new_mean_costs: np.ndarray
The mean costs.
"""
extractor = extractor or PerStoreFeature(extract_workload=True)

def _feature(feature: NDArray) -> np.ndarray:
return feature.numpy().astype("float32")
Expand Down Expand Up @@ -481,9 +482,12 @@ class State:

def __init__(
self,
model_config: SegmentSumMLPConfig = SegmentSumMLPConfig(),
extractor: FeatureExtractor = PerStoreFeature(extract_workload=True),
model_config: Optional[SegmentSumMLPConfig] = None,
extractor: Optional[FeatureExtractor] = None,
):
model_config = model_config or SegmentSumMLPConfig()
extractor = extractor or PerStoreFeature(extract_workload=True)

self.model = SegmentSumMLP(**model_config.to_dict())
self.data = OrderedDict()
self.data_size = 0
Expand Down Expand Up @@ -662,9 +666,12 @@ class SegmentSumMLPTrainer:

def __init__(
self,
train_config: TrainerConfig = TrainerConfig(),
state: State = State(),
train_config: Optional[TrainerConfig] = None,
state: Optional[State] = None,
):
train_config = train_config or TrainerConfig()
state = state or State()

config = train_config.to_dict()
for attr in config:
setattr(self, attr, config[attr])
Expand All @@ -676,7 +683,7 @@ def train_step(
self,
data: Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"],
batch: int = 0,
train_loss: float = None,
train_loss: Optional[float] = None,
) -> float:
"""Helper function for training on a single batch.
Expand All @@ -686,7 +693,7 @@ def train_step(
A batch of data, should be a tuple of (segment_sizes, features, gt_results).
batch: int = 0
The current batch number.
train_loss: float = None
train_loss: Optional[float] = None
The previous averaged training loss, None if it is the first batch.
Returns
Expand Down Expand Up @@ -863,15 +870,15 @@ def train_incremental(
def predict_incremental(
self,
features: List[np.ndarray],
results: np.ndarray = None,
results: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Predicting (validating) on incremental data.
Parameters
----------
features: List[np.ndarray]
The extracted features.
results: np.ndarray
results: Optional[np.ndarray]
The measured results, can be None if used for predicting.
Returns
Expand Down Expand Up @@ -943,10 +950,10 @@ class MLPModel(PyCostModel):
def __init__(
self,
*,
trainer: SegmentSumMLPTrainer = SegmentSumMLPTrainer(),
trainer: Optional[SegmentSumMLPTrainer] = None,
):
super().__init__()
self.trainer = trainer
self.trainer = trainer or SegmentSumMLPTrainer()

def load(self, path: str) -> None:
"""Load the cost model, cached data or raw data from given file location.
Expand Down

0 comments on commit a367538

Please sign in to comment.