Skip to content

Commit

Permalink
Fix for random state passing (#278)
Browse files Browse the repository at this point in the history
* Fix for random state passing (#269)

*Random seed pass

*Test move

*Fix

*Imp 2

*Code review fixes

*Mirror fix

*update requirements.txt

---------

Co-authored-by: nicl-nno <nicl.nno@gmail.com>
  • Loading branch information
VadimsAhmers and nicl-nno authored Dec 10, 2024
1 parent ab6e1f3 commit 45215bf
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mirror_repo_to_gitlab.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: [push, pull_request, delete]

jobs:
call-nss-ops-mirror-workflow:
uses: ITMO-NSS-team/NSS-Ops/.github/workflows/mirror-repo.yml@master
uses: aimclub/open-source-ops/.github/workflows/mirror-repo.yml@master
with:
GITLAB_URL: 'https://gitlab.actcognitive.org/itmo-nss-team/GOLEM'
secrets:
Expand Down
11 changes: 6 additions & 5 deletions golem/core/optimisers/adaptive/agent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def fit(self, histories: Iterable[OptHistory], validate_each: int = -1) -> Opera
# Preliminary validity check
# This allows to filter out histories with different objectives automatically
if history.objective.metric_names != self.objective.metric_names:
self._log.warning(f'History #{i+1} has different objective! '
self._log.warning(f'History #{i + 1} has different objective! '
f'Expected {self.objective}, got {history.objective}.')
continue

Expand All @@ -67,13 +67,13 @@ def fit(self, histories: Iterable[OptHistory], validate_each: int = -1) -> Opera
experience, val_experience = experience.split(ratio=0.8, shuffle=True)

# Train
self._log.info(f'Training on history #{i+1} with {len(history.generations)} generations')
self._log.info(f'Training on history #{i + 1} with {len(history.generations)} generations')
self.agent.partial_fit(experience)

# Validate
if val_experience:
reward_loss, reward_target = self.validate_agent(experience=val_experience)
self._log.info(f'Agent validation for history #{i+1} & {experience}: '
self._log.info(f'Agent validation for history #{i + 1} & {experience}: '
f'Reward target={reward_target:.3f}, loss={reward_loss:.3f}')

# Reset mutation probabilities to default
Expand Down Expand Up @@ -163,9 +163,10 @@ def _apply_best_action(self, inds: Sequence[Individual]) -> TrajectoryStep:
return best_step

def _apply_action(self, action: Any, ind: Individual) -> TrajectoryStep:
new_graph, applied = self.mutation._adapt_and_apply_mutation(ind.graph, action)
new_graph = self.mutation._apply_mutations(ind.graph, action)
applied = new_graph is not None
fitness = self._eval_objective(new_graph) if applied else None
parent_op = ParentOperator(type_='mutation', operators=applied, parent_individuals=ind)
parent_op = ParentOperator(type_='mutation', operators=applied, parent_individuals=[ind])
new_ind = Individual(new_graph, fitness=fitness, parent_operator=parent_op)

prev_fitness = ind.fitness or self._eval_objective(ind.graph)
Expand Down
13 changes: 7 additions & 6 deletions golem/core/optimisers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from golem.core.dag.verification_rules import DEFAULT_DAG_RULES
from golem.core.log import default_log
from golem.core.optimisers.advisor import DefaultChangeAdvisor
from golem.core.optimisers.optimization_parameters import OptimizationParameters
from golem.core.optimisers.genetic.evaluation import DelegateEvaluator
from golem.core.optimisers.genetic.operators.operator import PopulationT
from golem.core.optimisers.graph import OptGraph
from golem.core.optimisers.objective import GraphFunction, Objective, ObjectiveFunction
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.optimisers.opt_node_factory import DefaultOptNodeFactory, OptNodeFactory
from golem.core.optimisers.optimization_parameters import OptimizationParameters
from golem.core.optimisers.random_graph_factory import RandomGraphFactory, RandomGrowthGraphFactory
from golem.utilities.random import RandomStateHandler
from golem.utilities.utilities import set_random_seed

STRUCTURAL_DIVERSITY_FREQUENCY_CHECK = 5

Expand Down Expand Up @@ -47,6 +48,7 @@ class AlgorithmParameters:
adaptive_depth: bool = False
adaptive_depth_max_stagnation: int = 3
structural_diversity_frequency_check: int = STRUCTURAL_DIVERSITY_FREQUENCY_CHECK
seed = None


@dataclass
Expand Down Expand Up @@ -102,18 +104,15 @@ class GraphOptimizer:
:param requirements: implementation-independent requirements for graph optimizer
:param graph_generation_params: parameters for new graph generation
:param graph_optimizer_params: parameters for specific implementation of graph optimizer
Additional custom params can be specified with `custom_optimizer_params`.
"""

def __init__(self,
objective: Objective,
initial_graphs: Optional[Sequence[Union[Graph, Any]]] = None,
# TODO: rename params to avoid confusion
requirements: Optional[OptimizationParameters] = None,
graph_generation_params: Optional[GraphGenerationParams] = None,
graph_optimizer_params: Optional[AlgorithmParameters] = None,
**custom_optimizer_params):
graph_optimizer_params: Optional[
AlgorithmParameters] = None): # check if correct for inherited optimizers
self.log = default_log(self)
self._objective = objective
initial_graphs = graph_generation_params.adapter.adapt(initial_graphs) if initial_graphs else None
Expand All @@ -125,6 +124,8 @@ def __init__(self,
self._iteration_callback: IterationCallback = do_nothing_callback
self._history = OptHistory(objective.get_info(), requirements.history_dir) \
if requirements and requirements.keep_history else None

set_random_seed(self.graph_optimizer_params.seed)
# Log random state for reproducibility of runs
RandomStateHandler.log_random_state()

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Data
numpy>=1.16.0, !=1.24.0
numpy>=1.16.0, !=1.24.0, < 2.0.0
pandas>=1.3.0; python_version >='3.8'

# Models and frameworks
networkx>=2.4, !=2.7.*, !=2.8.1, !=2.8.2, !=2.8.3, != 3.3
networkx>=2.4, !=2.7.*, !=2.8.1, !=2.8.2, !=2.8.3, < 3.3
scipy<1.13.0
zss>=1.2.0

Expand All @@ -20,7 +20,7 @@ Pillow>=9.5.0
func_timeout==4.3.5
joblib>=0.17.0
requests>=2.0
tqdm~=4.65.0
tqdm~=4.66.3
typing>=3.7.0
psutil>=5.9.2

Expand Down
File renamed without changes.

0 comments on commit 45215bf

Please sign in to comment.