Skip to content

Commit

Permalink
Move avg_by_attr from StatisticsCollector
Browse files Browse the repository at this point in the history
  • Loading branch information
Parzival-05 committed Sep 7, 2024
1 parent d0e1a93 commit 485557c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
18 changes: 9 additions & 9 deletions AIAgent/ml/training/epochs_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def sort_dict(d):
return dict(natsort.natsorted(d.items()))


def avg_by_attr(results, path_to_coverage: str) -> int:
coverage = np.average(
list(map(lambda result: getattr(result, path_to_coverage), results))
)
return coverage


@dataclass
class StatsWithTable:
avg: float
Expand All @@ -43,13 +50,6 @@ def __init__(
self._svms_stats_dict: dict[SVMName, StatsWithTable] = {}
self._failed_maps_dict: dict[SVMName, FailedMaps] = {}

@staticmethod
def avg_by_attr(results, path_to_coverage: str) -> int:
coverage = np.average(
list(map(lambda result: getattr(result, path_to_coverage), results))
)
return coverage

def update_file(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
Expand Down Expand Up @@ -102,7 +102,7 @@ def generate_svms_stats_dict(
svms_stats_dict: dict[SVMName, list[StatsWithTable]] = dict()
for svm_name, map2results_list in svms_and_map2results.items():
svms_stats_dict[svm_name] = StatsWithTable(
StatisticsCollector.avg_by_attr(
avg_by_attr(
list(
map(
lambda map2result: map2result.game_result,
Expand All @@ -124,7 +124,7 @@ def __get_results(self) -> str:
svms_stats = self._svms_stats_dict.items()
_, svms_stats_with_table = list(zip(*svms_stats))

avg_coverage = StatisticsCollector.avg_by_attr(svms_stats_with_table, "avg")
avg_coverage = avg_by_attr(svms_stats_with_table, "avg")
df_concat = pd.concat(
list(
map(lambda stats_with_table: stats_with_table.df, svms_stats_with_table)
Expand Down
6 changes: 3 additions & 3 deletions AIAgent/ml/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ml.game.play_game import play_game
from ml.training.dataset import TrainingDataset
from ml.training.wrapper import TrainingModelWrapper
from ml.training.epochs_statistics import StatisticsCollector
from ml.training.epochs_statistics import StatisticsCollector, avg_by_attr
from torch_geometric.loader import DataLoader
from paths import CURRENT_TABLE_PATH

Expand Down Expand Up @@ -88,13 +88,13 @@ def validate_coverage(

statistics_collector.update_results(all_results)

average_result = StatisticsCollector.avg_by_attr(
average_result = avg_by_attr(
list(map(lambda map2result: map2result.game_result, all_results)),
"actual_coverage_percent",
)
mlflow.log_metrics(
{
"average_dataset_state_result": StatisticsCollector.avg_by_attr(
"average_dataset_state_result": avg_by_attr(
dataset.maps_results.values(), "coverage_percent"
),
"average_result": average_result,
Expand Down

0 comments on commit 485557c

Please sign in to comment.