Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encodings for graphs #129

Merged
merged 23 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ff02744
fix
maypink Mar 17, 2023
987cce0
fix#2
maypink Mar 17, 2023
022363c
minor
maypink Mar 21, 2023
d24dfa2
Merge branch 'main' of https://github.com/aimclub/GOLEM into 66-singl…
maypink Mar 27, 2023
4900e5f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Mar 29, 2023
cc8729f
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 4, 2023
f320cfa
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 10, 2023
f3ca604
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Apr 21, 2023
b76b1c3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink May 3, 2023
13a76bf
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 1, 2023
df115e3
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 9, 2023
4342573
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 13, 2023
56db3a7
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 15, 2023
5826890
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 16, 2023
33339ef
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jun 21, 2023
1863a1a
Merge branch 'main' of https://github.com/aimclub/GOLEM
maypink Jul 3, 2023
f9484d9
comparing by context
maypink Jun 19, 2023
ab6992c
add available operations in bandits
maypink Jun 19, 2023
b3af064
tests for context agents
maypink Jun 26, 2023
e44d08f
add adjacency matrix encoding
maypink Jun 26, 2023
4d31a30
fix pep8
maypink Jun 26, 2023
ccfc5cf
review fixes
maypink Jun 27, 2023
6cb4ab7
minor
maypink Jun 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from examples.synthetic_graph_evolution.graph_search import graph_search_setup
from examples.synthetic_graph_evolution.generators import postprocess_nx_graph
from examples.synthetic_graph_evolution.tree_search import tree_search_setup
from golem.core.optimisers.adaptive.context_agents import ContextAgentTypeEnum
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters
Expand Down Expand Up @@ -37,9 +38,11 @@ def generate_trees(graph_sizes: Sequence[int], node_types: Sequence[str] = ('x',
return trees


def get_graph_gp_params(objective: Objective, adaptive_mutation_type: MutationAgentTypeEnum, pop_size: int = None):
def get_graph_gp_params(objective: Objective, adaptive_mutation_type: MutationAgentTypeEnum,
context_agent_type: ContextAgentTypeEnum = None, pop_size: int = None):
return GPAlgorithmParameters(
adaptive_mutation_type=adaptive_mutation_type,
context_agent_type=context_agent_type,
pop_size=pop_size or 21,
multi_objective=objective.is_multi_objective,
genetic_scheme_type=GeneticSchemeTypesEnum.generational,
Expand Down
76 changes: 52 additions & 24 deletions experiments/mab/mab_synthetic_experiment_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from examples.synthetic_graph_evolution.utils import draw_graphs_subplots
from golem.core.adapter.nx_adapter import BaseNetworkxAdapter
from golem.core.dag.graph import Graph
from golem.core.optimisers.adaptive.context_agents import ContextAgentTypeEnum
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum

from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
Expand All @@ -34,31 +35,36 @@ class MABSyntheticExperimentHelper:
""" Class to provide synthetic experiments without data to compare MABs. """

def __init__(self, launch_num: int, timeout: float, bandits_to_compare: List[MutationAgentTypeEnum],
context_agent_types: List[ContextAgentTypeEnum], bandit_labels: List[str] = None,
path_to_save: str = None, is_visualize: bool = False, n_clusters: Optional[int] = None):
self.launch_num = launch_num
self.timeout = timeout
self.bandits_to_compare = bandits_to_compare
self.bandit_metrics = dict.fromkeys(bandit.name for bandit in self.bandits_to_compare)
self.context_agent_types = context_agent_types
self.bandit_labels = bandit_labels or [bandit.name for bandit in bandits_to_compare]
self.bandit_metrics = dict.fromkeys(bandit for bandit in self.bandit_labels)
self.path_to_save = path_to_save or os.path.join(project_root(), 'mab')
self.is_visualize = is_visualize
self.histories = dict.fromkeys([bandit.name for bandit in self.bandits_to_compare])
self.histories = dict.fromkeys([bandit for bandit in self.bandit_labels])
self.cluster = MiniBatchKMeans(n_clusters=n_clusters)

def compare_bandits(self, setup_parameters: Callable, initial_population_func: Callable = None):
    """ Launches every bandit from `bandits_to_compare` `launch_num` times and collects the results.

    :param setup_parameters: callable invoked as
        `setup_parameters(initial_graphs=..., bandit_type=..., context_agent_type=...)`;
        must return a pair `(optimizer, objective)`
    :param initial_population_func: callable without arguments returning the initial graphs.
        It is called once per launch, so despite the `None` default it has to be provided

    Results are grouped by bandit label (`self.bandit_labels[j]`), not by bandit type, so the same
    bandit type can take part in the comparison several times with different context agents.
    """
    results = dict()
    # stays None if nothing was launched (launch_num == 0 or no bandits to compare)
    agent = None
    for _ in range(self.launch_num):
        # one initial population per launch, shared by all bandits of this launch
        initial_graphs = initial_population_func()
        for j, bandit in enumerate(self.bandits_to_compare):
            optimizer, objective = setup_parameters(initial_graphs=initial_graphs, bandit_type=bandit,
                                                    context_agent_type=self.context_agent_types[j])
            agent = optimizer.mutation.agent
            result = self.launch_bandit(bandit_type=bandit, optimizer=optimizer, objective=objective,
                                        bandit_num=j)
            # labels are an instance attribute: a bare `bandit_labels` only resolves when a global
            # with that name happens to exist (as in the `__main__` script) and is a NameError otherwise
            results.setdefault(self.bandit_labels[j], []).append(result)
    # NOTE(review): visualization is placed after all launches because it averages across launches;
    # the original indentation is not recoverable from the diff -- confirm
    if self.is_visualize and agent is not None:
        self.show_average_action_probabilities(show_action_probabilities=results, actions=agent.actions)

def launch_bandit(self, bandit_type: MutationAgentTypeEnum, optimizer: GraphOptimizer, objective: Callable):
def launch_bandit(self, bandit_type: MutationAgentTypeEnum, optimizer: GraphOptimizer, objective: Callable,
bandit_num: int):

stats_action_value_log: Dict[int, List[List[float]]] = dict()

Expand All @@ -70,9 +76,9 @@ def log_action_values(next_pop: PopulationT, optimizer: EvoGraphOptimizer):

def log_action_values_with_clusters(next_pop: PopulationT, optimizer: EvoGraphOptimizer):
obs_contexts = optimizer.mutation.agent.get_context(next_pop)
self.cluster.partial_fit(np.array(obs_contexts).reshape(-1, 1))
self.cluster.partial_fit(np.array(obs_contexts))
centers = self.cluster.cluster_centers_
for i, center in enumerate(sorted(centers)):
for i, center in enumerate(centers):
values = optimizer.mutation.agent.get_action_values(obs=[center])
if i not in stats_action_value_log.keys():
stats_action_value_log[i] = []
Expand All @@ -88,16 +94,20 @@ def log_action_values_with_clusters(next_pop: PopulationT, optimizer: EvoGraphOp

found_graphs = optimizer.optimise(objective)
found_graph = found_graphs[0] if isinstance(found_graphs, Sequence) else found_graphs

history = optimizer.history
if not self.histories[bandit_type.name]:
self.histories[bandit_type.name] = []
self.histories[bandit_type.name].append(history)
bandit_label = self.bandit_labels[bandit_num]
if not self.histories[bandit_label]:
self.histories[bandit_label] = []
self.histories[bandit_label].append(history)

agent = optimizer.mutation.agent
found_nx_graph = BaseNetworkxAdapter().restore(found_graph)
final_metrics = objective(found_nx_graph).value
if not self.bandit_metrics[bandit_type.name]:
self.bandit_metrics[bandit_type.name] = []
self.bandit_metrics[bandit_type.name].append(final_metrics)

if not self.bandit_metrics[bandit_label]:
self.bandit_metrics[bandit_label] = []
self.bandit_metrics[bandit_label].append(final_metrics)

print('History of action probabilities:')
pprint(stats_action_value_log)
Expand All @@ -124,16 +134,19 @@ def show_action_probabilities(self, bandit_type: MutationAgentTypeEnum, stats_ac
plot_action_values(stats=stats_action_value_log[0], action_tags=actions, titles=titles)
plt.show()
else:
centers = sorted(self.cluster.cluster_centers_)
centers = self.cluster.cluster_centers_
for i in range(self.cluster.n_clusters):
titles_centers = [title + f' for cluster with center {int(centers[i])}' for title in titles]
if len(centers[i]) > 1:
titles_centers = [title + f' for cluster with center idx={i}' for title in titles]
else:
titles_centers = [title + f' for cluster with center {centers[i]}' for title in titles]
plot_action_values(stats=stats_action_value_log[i], action_tags=actions,
titles=titles_centers)
plt.show()

def show_average_action_probabilities(self, show_action_probabilities: dict, actions):
""" Shows action probabilities across several launches. """
for bandit in list(show_action_probabilities.keys()):
for idx, bandit in enumerate(list(show_action_probabilities.keys())):
total_sum = None
for launch in show_action_probabilities[bandit]:
if not total_sum:
Expand All @@ -147,7 +160,7 @@ def show_average_action_probabilities(self, show_action_probabilities: dict, act
for i in range(len(total_sum[cluster])):
for j in range(len(total_sum[cluster][i])):
total_sum[cluster][i][j] /= len(show_action_probabilities[bandit])
self.show_action_probabilities(bandit_type=MutationAgentTypeEnum(bandit),
self.show_action_probabilities(bandit_type=MutationAgentTypeEnum(self.bandits_to_compare[idx]),
stats_action_value_log=total_sum,
actions=actions,
is_average=True)
Expand All @@ -163,6 +176,7 @@ def show_fitness_lines(self):


def setup_parameters(initial_graphs: List[Graph], bandit_type: MutationAgentTypeEnum,
context_agent_type: ContextAgentTypeEnum,
target_size: int, trial_timeout: float):
objective = Objective({'graph_size': lambda graph: abs(target_size -
graph.number_of_nodes())})
Expand All @@ -172,7 +186,8 @@ def setup_parameters(initial_graphs: List[Graph], bandit_type: MutationAgentType
objective=objective,
optimizer_cls=EvoGraphOptimizer,
algorithm_parameters=get_graph_gp_params(objective=objective,
adaptive_mutation_type=bandit_type),
adaptive_mutation_type=bandit_type,
context_agent_type=context_agent_type),
timeout=timedelta(minutes=trial_timeout),
num_iterations=target_size * 3,
initial_graphs=initial_graphs
Expand All @@ -185,22 +200,35 @@ def initial_population_func(graph_size: List[int] = None, pop_size: int = None,
return initial_graphs
initial_graphs = [nx.random_tree(graph_size[i], create_using=nx.DiGraph)
for i in range(pop_size)]
return initial_graphs
initial_opt_graphs = []
for graph in initial_graphs:
opt_graph = BaseNetworkxAdapter().adapt(item=graph)
for node in opt_graph.nodes:
node.content['name'] = 'x'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Вот, кажется, такие штуки тяжело будет ловить, тяжело избегать и невозможно запретить при учете изменений с кэшированием descriptive_id

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

как вариант -- добавить флаг в descriptive_id, чтобы его можно было пересчитать принудительно. ну и оставить там предупреждение о такого рода действиях
здесь так сделано из-за того, что при адаптации в имена нод ставится их uid

initial_opt_graphs.append(opt_graph)
return initial_opt_graphs


if __name__ == '__main__':
timeout = 0.3
launch_num = 1
target_size = 50

bandits_to_compare = [MutationAgentTypeEnum.contextual_bandit, MutationAgentTypeEnum.bandit]
# `bandits_to_compare`, `context_agent_types` and `bandit_labels` correlate one to one.
# Context must be specified for each bandit: for contextual and neural bandits real context must be specified,
# for simple bandits -- ContextAgentTypeEnum.none_encoding
bandits_to_compare = [MutationAgentTypeEnum.bandit, MutationAgentTypeEnum.contextual_bandit]
context_agent_types = [ContextAgentTypeEnum.none_encoding, ContextAgentTypeEnum.operations_quantity]
bandit_labels = ['simple_bandit', f'context_{context_agent_types[1].name}']

setup_parameters_func = partial(setup_parameters, target_size=target_size, trial_timeout=timeout)
initial_population_func = partial(initial_population_func,
graph_size=[random.randint(5, 10) for _ in range(10)] +
[random.randint(90, 95) for _ in range(10)],
pop_size=20)

helper = MABSyntheticExperimentHelper(timeout=timeout, launch_num=launch_num, bandits_to_compare=bandits_to_compare,
bandit_labels=bandit_labels, context_agent_types=context_agent_types,
n_clusters=2, is_visualize=True)
helper.compare_bandits(initial_population_func=initial_population_func,
setup_parameters=setup_parameters_func)
Expand Down
103 changes: 93 additions & 10 deletions golem/core/optimisers/adaptive/context_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,124 @@

from typing import List, Callable, Any

import numpy as np
from karateclub import FeatherGraph

from golem.core.adapter.nx_adapter import BanditNetworkxAdapter
from golem.core.optimisers.opt_history_objects.individual import Individual


def feather_graph(obs: Any) -> List[float]:
def adapter_func_to_networkx(func):
    """ Decorator that converts the observation to a networkx graph before encoding.

    The wrapped encoder receives the result of `BanditNetworkxAdapter().restore(obs)` instead of
    the raw observation; `available_operations` is passed through unchanged.

    :param func: encoder called as `func(nx_graph, available_operations)`
    :return: function called as `wrapper(obs, available_operations)` that returns what `func` returns
    """
    # local import keeps the block self-contained without touching module-level imports
    from functools import wraps

    @wraps(func)  # preserve the encoder's name and docstring for logs and introspection
    def wrapper(obs, available_operations):
        nx_graph = BanditNetworkxAdapter().restore(obs)
        return func(nx_graph, available_operations)
    return wrapper


def adapter_func_to_graph(func):
    """ Decorator that unwraps the observation to a graph before encoding.

    If the observation is an `Individual`, its `graph` attribute is passed to the encoder;
    any other observation is passed as is. `available_operations` is passed through unchanged.

    :param func: encoder called as `func(graph, available_operations)`
    :return: function called as `wrapper(obs, available_operations)` that returns what `func` returns
    """
    # local import keeps the block self-contained without touching module-level imports
    from functools import wraps

    @wraps(func)  # preserve the encoder's name and docstring for logs and introspection
    def wrapper(obs, available_operations):
        graph = obs.graph if isinstance(obs, Individual) else obs
        return func(graph, available_operations)
    return wrapper


def encode_operations(operations: List[str], available_operations: List[str], mode: str = 'label') -> List[Any]:
    """ Encodes operations by their position in the list of available operations.

    :param operations: operations to encode
    :param available_operations: list of all available operations
    :param mode: mode of encoding. 'label' (default) encodes each operation as its index in
        `available_operations`; any other value (e.g. 'OHE') encodes each operation as a one-hot
        list of length `len(available_operations)`
    :return: list with one encoding per operation, in the order of `operations`
    :raises ValueError: if an operation is missing from `available_operations`
    """
    # Build the lookup once instead of scanning the list for every operation.
    # `setdefault` keeps the first occurrence, which matches `list.index` for duplicated names.
    index_by_operation = {}
    for position, available in enumerate(available_operations):
        index_by_operation.setdefault(available, position)

    encoded = []
    for operation in operations:
        try:
            position = index_by_operation[operation]
        except KeyError:
            # keep the exception type that `list.index` used to raise
            raise ValueError(f'{operation!r} is not in the list of available operations') from None
        if mode == 'label':
            encoded.append(position)
        else:
            one_hot = [0] * len(available_operations)
            one_hot[position] = 1
            encoded.append(one_hot)
    return encoded


@adapter_func_to_networkx
def feather_graph(obs: Any, available_operations: List[str]) -> List[float]:
""" Returns embedding based on an implementation of `"FEATHER-G" <https://arxiv.org/abs/2005.07959>`_.
The procedure uses characteristic functions of node features with random walk weights to describe
node neighborhoods. These node level features are pooled by mean pooling to
create graph level statistics. """
descriptor = FeatherGraph()
nx_graph = BanditNetworkxAdapter().restore(obs)
descriptor.fit([nx_graph])
return descriptor.get_embedding()[:20]
descriptor.fit([obs])
emb = descriptor.get_embedding().reshape(-1, 1)
embd = [i[0] for i in emb]
return embd


def nodes_num(obs: Any) -> int:
@adapter_func_to_graph
def nodes_num(obs: Any, available_operations: List[str]) -> List[int]:
""" Returns number of nodes in graph. """
if isinstance(obs, Individual):
return len(obs.graph.nodes)
else:
return len(obs.nodes)
return [len(obs.nodes)]


@adapter_func_to_graph
def labeled_edges(obs: Any, available_operations: List[str]) -> List[int]:
    """ Encodes the graph as a flat sequence of operation labels, two labels per edge.

    For each node and each entry of its `nodes_from`, the name of the `nodes_from` entry is
    emitted first and the name of the node itself second; the resulting sequence of names is
    then encoded against `available_operations` with the default mode of `encode_operations`.
    """
    edge_names = [
        name
        for target in obs.nodes
        for source in target.nodes_from
        for name in (source.name, target.name)
    ]
    return encode_operations(operations=edge_names, available_operations=available_operations)
gkirgizov marked this conversation as resolved.
Show resolved Hide resolved


@adapter_func_to_graph
def operations_quantity(obs: Any, available_operations: List[str]) -> List[int]:
    """ Encodes the graph as a vector of per-operation node counts.

    Position `k` of the result holds the number of nodes whose name equals `available_operations[k]`.
    """
    counts = [0] * len(available_operations)
    for graph_node in obs.nodes:
        position = available_operations.index(graph_node.name)
        counts[position] += 1
    return counts


@adapter_func_to_graph
def adjacency_matrix(obs: Any, available_operations: List[str]) -> List[int]:
    """ Encodes the graph as a flattened operation-to-operation edge count matrix.

    The matrix is square with one row and one column per available operation. For each node and
    each entry of its `nodes_from`, the cell [index of the node's name][index of the entry's name]
    is incremented. The matrix is returned row by row as a flat list of ints.
    """
    size = len(available_operations)
    counts = np.zeros((size, size))
    for graph_node in obs.nodes:
        row = available_operations.index(graph_node.name)
        for linked_node in graph_node.nodes_from:
            column = available_operations.index(linked_node.name)
            counts[row][column] += 1
    return counts.astype(int).flatten().tolist()


def none_encoding(obs: Any, available_operations: List[str]) -> List[int]:
    """ Identity encoding: hands the observation back untouched.

    `available_operations` is not used; it is accepted so that the signature matches the other encoders.
    """
    return obs


class ContextAgentTypeEnum(Enum):
    """ Identifiers of the available context encodings; each value equals the member name. """
    feather_graph = 'feather_graph'
    nodes_num = 'nodes_num'
    labeled_edges = 'labeled_edges'
    operations_quantity = 'operations_quantity'
    adjacency_matrix = 'adjacency_matrix'
    none_encoding = 'none_encoding'


class ContextAgentsRepository:
""" Repository of functions to encode observations. """
_agents_implementations = {
ContextAgentTypeEnum.feather_graph: feather_graph,
ContextAgentTypeEnum.nodes_num: nodes_num
ContextAgentTypeEnum.nodes_num: nodes_num,
ContextAgentTypeEnum.labeled_edges: labeled_edges,
ContextAgentTypeEnum.operations_quantity: operations_quantity,
ContextAgentTypeEnum.adjacency_matrix: adjacency_matrix,
ContextAgentTypeEnum.none_encoding: none_encoding
}

@staticmethod
Expand Down
20 changes: 12 additions & 8 deletions golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from functools import partial
from typing import Union, Sequence, Optional, List

import numpy as np
Expand All @@ -15,8 +16,10 @@ class ContextualMultiArmedBanditAgent(OperatorAgent):
""" Contextual Multi-Armed bandit. Observations can be encoded with simple context agent without
using NN to guarantee convergence. """

def __init__(self, actions: Sequence[ActType], n_jobs: int = 1,
context_agent_type: ContextAgentTypeEnum = ContextAgentTypeEnum.nodes_num,
def __init__(self, actions: Sequence[ActType],
context_agent_type: ContextAgentTypeEnum,
available_operations: List[str],
n_jobs: int = 1,
enable_logging: bool = True):
super().__init__(enable_logging)
self.actions = list(actions)
Expand All @@ -26,7 +29,8 @@ def __init__(self, actions: Sequence[ActType], n_jobs: int = 1,
learning_policy=LearningPolicy.UCB1(alpha=1.25),
neighborhood_policy=NeighborhoodPolicy.Clusters(),
n_jobs=n_jobs)
self._context_agent = ContextAgentsRepository.agent_class_by_id(context_agent_type)
self._context_agent = partial(ContextAgentsRepository.agent_class_by_id(context_agent_type),
available_operations=available_operations)
self._is_fitted = False

def _initial_fit(self, obs: ObsType):
Expand All @@ -44,15 +48,15 @@ def choose_action(self, obs: ObsType) -> ActType:
if not self._is_fitted:
self._initial_fit(obs=obs)
contexts = self.get_context(obs=obs)
arm = self._agent.predict(contexts=contexts)
arm = self._agent.predict(contexts=np.array(contexts).reshape(1, -1))
action = self.actions[arm]
return action

def get_action_values(self, obs: Optional[ObsType] = None) -> Sequence[float]:
if not self._is_fitted:
self._initial_fit(obs=obs)
contexts = self.get_context(obs)
prob_dict = self._agent.predict_expectations(contexts=contexts)
prob_dict = self._agent.predict_expectations(contexts=np.array(contexts).reshape(1, -1))
prob_list = [prob_dict[i] for i in range(len(prob_dict))]
return prob_list

Expand All @@ -73,12 +77,12 @@ def partial_fit(self, experience: ExperienceBuffer):

def get_context(self, obs: Union[List[ObsType], ObsType]) -> List[List[float]]:
""" Returns contexts based on specified context agent. """
contexts = []
if not isinstance(obs, list):
obs = [obs]
return self._context_agent(obs)
contexts = []
for ob in obs:
if isinstance(ob, list) or isinstance(ob, np.ndarray):
contexts.append(ob)
else:
contexts.append([self._context_agent(ob)])
contexts.append(self._context_agent(ob))
gkirgizov marked this conversation as resolved.
Show resolved Hide resolved
return contexts
Loading