
Run black on python folder #514

Merged · 1 commit · Sep 20, 2024
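This PR is a pure formatting pass: the diff below is what the black code formatter produces, with no behavior changes. As a rough sketch (the black version and exact invocation are not stated in the PR), the same kind of rewrite can be reproduced through black's Python API:

```python
# Sketch only: reproduces the style of change in this PR via black's Python API.
# Assumes black is installed (pip install black); exact output may vary with the
# black version and the configured line length (88 is black's default).
import black

source = '''\
def connect(
    elastic_url: str, ca_certs=None, elastic_user: str="", elastic_password: str=""
):
    for (key, value) in zip(["url"], [elastic_url]):
        print(key, value)
'''

# black splits the over-long signature one parameter per line (with a trailing
# comma), adds spaces around "=" for annotated defaults, and drops the redundant
# parentheses around the for-loop target -- the same edits visible in the hunks below.
print(black.format_str(source, mode=black.Mode(line_length=88)))
```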
5 changes: 4 additions & 1 deletion python/elastic_search_serialization.py
@@ -226,7 +226,10 @@ def elastic_search_parallel_bulk(

@mgp.read_proc
def connect(
elastic_url: str, ca_certs: Union[str, None]=None, elastic_user: str="", elastic_password: str=""
elastic_url: str,
ca_certs: Union[str, None] = None,
elastic_user: str = "",
elastic_password: str = "",
) -> mgp.Record(connection_status=mgp.Map):
"""Establishes connection with the Elasticsearch. This configuration needs to be specific to the Elasticsearch deployment. Uses basic authentication
Args:
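One detail worth noting in the reformatted signature above: when a parameter has both a type annotation and a default, black (following PEP 8) puts spaces around the `=`, so `elastic_user: str=""` becomes `elastic_user: str = ""`. A minimal, hypothetical illustration:

```python
# Hypothetical function showing the spacing rule applied above:
# no spaces around "=" for a bare default, spaces once an annotation is present.
def greet(name="world", punctuation: str = "!") -> str:
    return f"Hello, {name}{punctuation}"


print(greet())                              # Hello, world!
print(greet("Memgraph", punctuation="."))   # Hello, Memgraph.
```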
1 change: 0 additions & 1 deletion python/graph_analyzer.py
@@ -145,7 +145,6 @@ def _get_analysis_funcs():
def _analyze_graph(
context: mgp.ProcCtx, g: nx.MultiDiGraph, analyses: List[str]
) -> List[Tuple[str, str]]:

functions = (
_get_analysis_funcs()
if analyses is None
2 changes: 0 additions & 2 deletions python/igraphalg.py
@@ -22,7 +22,6 @@ def maxflow(
target: mgp.Vertex,
capacity: str = "weight",
) -> mgp.Record(max_flow=mgp.Number):

graph = MemgraphIgraph(ctx=ctx, directed=True)
max_flow_value = graph.maxflow(source=source, target=target, capacity=capacity)

@@ -95,7 +94,6 @@ def mincut(
def topological_sort(
ctx: mgp.ProcCtx, mode: str = "out"
) -> mgp.Record(nodes=mgp.List[mgp.Vertex]):

if mode not in [
TopologicalSortingModes.IN.value,
TopologicalSortingModes.OUT.value,
2 changes: 1 addition & 1 deletion python/link_prediction.py
@@ -644,7 +644,6 @@ def recommend( # noqa: C901
def get_training_results(
ctx: mgp.ProcCtx,
) -> mgp.Record(training_results=mgp.Any, validation_results=mgp.Any):

"""This method is used when user wants to get performance data obtained from the last training. It is in the form of list of records where each record is a Dict[metric_name, metric_value]. Training and validation
results are returned.

@@ -1011,6 +1010,7 @@ def validate_user_parameters(parameters: mgp.Map) -> None: # noqa: C901
Returns:
Nothing or raises an exception if something is wrong.
"""

# Hacky Python
def raise_(ex):
raise ex
(file path not shown)
@@ -121,7 +121,6 @@ def _markow_chain(
index: int,
parameters: Dict[str, Any],
) -> None:

temperature = param_value(graph, parameters, Parameter.QA_TEMPERATURE)
max_steps = param_value(graph, parameters, Parameter.QA_MAX_STEPS)
mutation = param_value(graph, parameters, Parameter.MUTATION)
(file path not shown)
@@ -66,7 +66,6 @@ def __init__(
prev_indv: Individual,
next_indv: Individual,
):

super().__init__(graph, individuals)
self._prev_indv = prev_indv
self._next_indv = next_indv
4 changes: 1 addition & 3 deletions python/mage/graph_coloring_module/components/individual.py
@@ -25,7 +25,6 @@ def __init__(
conflict_nodes: Set[int] = None,
conflicts_counter: List[int] = None,
):

self._graph = graph
self._no_of_units = len(graph)
self._no_of_colors = no_of_colors
@@ -115,7 +114,7 @@ def replace_units(self, indices: List[int], colors: List[int]):
conflict_nodes = self._conflict_nodes.copy()
conflict_edges = self.conflicts_weight

for (index, color) in zip(indices, colors):
for index, color in zip(indices, colors):
if not (0 <= color < self._no_of_colors):
raise IllegalColorException(
"The given color is not in the range of allowed colors!"
@@ -152,7 +151,6 @@ def _calculate_diff(
conflicts_counter: List[int],
conflict_nodes: Set[int],
) -> Tuple[int, List[int], Set[int]]:

diff = 0
for neigh, weight in self.graph.weighted_neighbors(node):
if chromosome[node] == chromosome[neigh]:
1 change: 0 additions & 1 deletion python/mage/graph_coloring_module/components/population.py
@@ -10,7 +10,6 @@ class Population(ABC):
information with individuals that are located next to it."""

def __init__(self, graph: Graph, individuals: List[Individual]):

self._size = len(individuals)
self._individuals = individuals
self._best_individuals = self._individuals[:]
(file path not shown)
@@ -35,7 +35,6 @@ def execute(
population: Population,
parameters: Dict[str, Any] = None,
) -> None:

simple_tunneling_max_attempts = param_value(
graph, parameters, Parameter.SIMPLE_TUNNELING_MAX_ATTEMPTS
)
(file path not shown)
@@ -6,7 +6,6 @@
def param_value(
graph: Graph, parameters: Dict[str, Any], param: str, initial_value: Any = None
) -> Any:

if parameters is None:
if initial_value is None:
return None
3 changes: 1 addition & 2 deletions python/mage/node2vec/second_order_random_walk.py
@@ -134,7 +134,6 @@ def calculate_edge_transition_probs(
unnorm_trans_probs = []

for dest_neighbor_id in graph.get_neighbors(dest_node_id):

edge_weight = graph.get_edge_weight(dest_node_id, dest_neighbor_id)

if dest_neighbor_id == src_node_id:
@@ -153,7 +152,7 @@ def set_graph_transition_probs(self, graph: Graph) -> None:
Args:
graph (Graph): Graph for which to set first pass transition probs
"""
for (node_from, node_to) in graph.get_edges():
for node_from, node_to in graph.get_edges():
graph.set_edge_transition_probs(
(node_from, node_to),
self.calculate_edge_transition_probs(graph, node_from, node_to),
6 changes: 2 additions & 4 deletions python/mage/node2vec_online_module/walk_sampling.py
@@ -67,7 +67,6 @@ def process_new_edge(self, source: Any, target: Any, time: int) -> List[List[Any
return walks

def sample_single_walk(self, source: Any, target: Any, time: int) -> List[Any]:

node_ = source
time_ = self.last_timestamp.get(source, 0)
centrality_ = self.centrality.get(source, 0)
@@ -84,7 +83,7 @@ def sample_single_walk(self, source: Any, target: Any, time: int) -> List[Any]:
sum_ = centrality_ * random.uniform(0, 1)
sum__ = 0
broken = False
for (n, t, c) in reversed(self.graph[node_]):
for n, t, c in reversed(self.graph[node_]):
if t < time_:
sum__ += (c + 1) * self.beta * math.exp(self.c * (t - time_))
if sum__ >= sum_:
@@ -100,7 +99,6 @@ def sample_single_walk(self, source: Any, target: Any, time: int) -> List[Any]:
return (node_, target)

def update(self, source: Any, target: Any, time: int) -> None:

# all walks that terminated at the target before adding the new edge are decayed
if target in self.centrality:
self.centrality[target] *= math.exp(
@@ -133,7 +131,7 @@ def clean_in_edges(self, node: Any, current_time: int) -> None:

# possible improvement by binary search
index = 0
for (source, time, centrality) in self.graph[node]:
for source, time, centrality in self.graph[node]:
if current_time - time < self.cutoff:
break
index += 1
1 change: 0 additions & 1 deletion python/mage/node_classification/models/gatjk.py
@@ -42,7 +42,6 @@ def __init__(
self.bns = torch.nn.ModuleList()
self.bns.append(torch.nn.BatchNorm1d(hidden_features_size[0] * heads))
for i in range(len(hidden_features_size) - 2):

self.convs.append(
GATConv(
hidden_features_size[i] * heads,
(file path not shown)
@@ -38,7 +38,6 @@ def nodes_fetching(
inv_label_reindexing = defaultdict()

for node in nodes:

if features_name not in node.properties:
continue # if features are not available, skip the node

@@ -79,7 +78,6 @@ def nodes_fetching(

# since node_types is Counter, key is the node type and value is the number of nodes of that type
for node_type, num_types_node in node_types.items():

# for each node type, create a tensor of size num_types_node x embedding_lengths[node_type]
data[node_type].x = torch.tensor(
np.zeros((num_types_node, embedding_lengths[node_type])),
20 changes: 12 additions & 8 deletions python/mage/tgn/definitions/tgn.py
@@ -179,7 +179,6 @@ def _process_current_batch(
edge_idxs: np.array,
timestamps: np.array,
) -> None:

self._update_raw_message_store_current_batch(
sources=sources,
destinations=destinations,
@@ -203,7 +202,6 @@ def _process_current_batch(
self.node_features[node_id] = node_feature

def _process_previous_batches(self) -> None:

# dict nodeid -> List[event]
raw_messages = self.raw_message_store.get_messages()

@@ -226,7 +224,6 @@ def _update_raw_message_store_current_batch(
edge_features: Dict[int, torch.Tensor],
node_features: Dict[int, torch.Tensor],
) -> None:

interaction_events: Dict[int, List[Event]] = create_interaction_events(
sources=sources,
destinations=destinations,
@@ -269,7 +266,6 @@ def _create_messages(
)
)
elif type(message) is InteractionRawMessage:

interaction_raw_message = message

processed_messages_dict[node].append(
@@ -424,8 +420,12 @@ def _form_computation_graph(
cur_arr = [(n, v) for (n, v) in prev]

node_arr = []
for (v, t) in cur_arr:
(neighbors, edge_idxs, timestamps,) = (
for v, t in cur_arr:
(
neighbors,
edge_idxs,
timestamps,
) = (
self.temporal_neighborhood.get_neighborhood(
v, t, self.num_neighbors
)
@@ -447,8 +447,12 @@

global_edge_indexes = []
global_timestamps = []
for (v, t) in node_layers[0]:
(neighbors, edge_idxs, timestamps,) = (
for v, t in node_layers[0]:
(
neighbors,
edge_idxs,
timestamps,
) = (
self.temporal_neighborhood.get_neighborhood(v, t, self.num_neighbors)
if (v, t) not in sampled_neighborhood
else sampled_neighborhood[(v, t)]
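The tall tuple targets in `_form_computation_graph` above come from black's magic trailing comma: because the original target `(neighbors, edge_idxs, timestamps,)` ends with a comma, black keeps the parentheses and writes one name per line rather than collapsing the target back onto a single line. A small, hypothetical sketch of the same effect:

```python
# Hypothetical sketch of black's magic trailing comma on an assignment target.
# The trailing comma inside the parentheses keeps the target exploded one name
# per line; without it, black would join these names onto a single line.
def sample_neighborhood():
    # Stand-in for temporal_neighborhood.get_neighborhood(...)
    return [1, 2, 3], [10, 11, 12], [0.1, 0.2, 0.3]


(
    neighbors,
    edge_idxs,
    timestamps,
) = sample_neighborhood()
print(neighbors, edge_idxs, timestamps)
```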
5 changes: 0 additions & 5 deletions python/mgp_igraph.py
@@ -45,7 +45,6 @@ def pagerank(
def get_all_simple_paths(
self, v: mgp.Vertex, to: mgp.Vertex, cutoff: int
) -> List[List[mgp.Vertex]]:

paths = [
self._convert_vertex_ids_to_mgp_vertices(path)
for path in super().get_all_simple_paths(
@@ -86,7 +85,6 @@ def community_leiden(
def mincut(
self, source: mgp.Vertex, target: mgp.Vertex, capacity: str
) -> Tuple[List[mgp.Vertex], float]:

cut = super().mincut(
source=self.id_mappings[source.id],
target=self.id_mappings[target.id],
@@ -117,7 +115,6 @@ def spanning_tree(
def shortest_path_length(
self, source: mgp.Vertex, target: mgp.Vertex, weights: str
) -> float:

length = super().distances(
source=self.id_mappings[source.id],
target=self.id_mappings[target.id],
@@ -134,7 +131,6 @@ def all_shortest_path_lengths(self, weights: str) -> List[List[float]]:
def get_shortest_path(
self, source: mgp.Vertex, target: mgp.Vertex, weights: str
) -> List[mgp.Vertex]:

path = super().get_shortest_paths(
v=self.id_mappings[source.id],
to=self.id_mappings[target.id],
@@ -149,7 +145,6 @@ def get_vertex_by_id(self, id: int) -> mgp.Vertex:
def _convert_vertex_ids_to_mgp_vertices(
self, vertex_ids: List[int]
) -> List[mgp.Vertex]:

vertices = []
for id in vertex_ids:
vertices.append(
7 changes: 5 additions & 2 deletions python/node_classification.py
@@ -22,6 +22,7 @@
# constants
##############################


# parameters for the model
class ModelParams:
IN_CHANNELS = "in_channels"
@@ -597,8 +598,10 @@ def train(


@mgp.read_proc
def get_training_data() -> mgp.Record(
epoch=int, loss=float, val_loss=float, train_log=mgp.Any, val_log=mgp.Any
def get_training_data() -> (
mgp.Record(
epoch=int, loss=float, val_loss=float, train_log=mgp.Any, val_log=mgp.Any
)
):
"""This function is used so user can see what is logged data from training.

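The parentheses black adds around `get_training_data`'s return annotation above are only a line-wrapping device; a parenthesized annotation is the same annotation. A hypothetical, self-contained example:

```python
from typing import Dict, List


# The parentheses around the return annotation below are purely cosmetic;
# this is the same signature as `-> Dict[str, List[float]]`.
def summarize_results() -> (Dict[str, List[float]]):
    return {"loss": [0.1, 0.2], "accuracy": [0.8, 0.9]}


print(summarize_results.__annotations__["return"])
```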
1 change: 0 additions & 1 deletion python/tests/graph_coloring/test_conflict_error.py
@@ -73,7 +73,6 @@ def test_individual_error_no_setting(


def test_population_error(set_seed, chain_population):

error = ConflictError().population_err(
graph,
chain_population,
1 change: 0 additions & 1 deletion python/tests/node2vec/test_basic_graph.py
@@ -43,7 +43,6 @@

@pytest.fixture(params=[True, False])
def is_directed(request):

return request.param


2 changes: 0 additions & 2 deletions python/tests/node2vec_online/test_walk_sampling.py
@@ -26,7 +26,6 @@ def test_correct_walk_number(walk_sampling):


def test_legal_combinations_time_crossover_edges(walk_sampling):

walk_sampling.process_new_edge(1, 2, time.time())
walk_sampling.process_new_edge(0, 1, time.time())

@@ -40,7 +39,6 @@ def test_legal_combinations_time_crossover_edges(walk_sampling):


def test_legal_combinations_time_linear_edges(walk_sampling):

walk_sampling.process_new_edge(0, 1, time.time())
walk_sampling.process_new_edge(1, 2, time.time())
walks = walk_sampling.process_new_edge(2, 3, time.time())
4 changes: 3 additions & 1 deletion python/tgn.py
@@ -94,6 +94,7 @@
# params and classes
##################


# params TGN must receive
class TGNParameters:
NUM_OF_LAYERS = "num_of_layers"
@@ -820,7 +821,6 @@ def train_eval_epochs(
assert batch_size > 0

for epoch in range(num_epochs):

# update global epoch counter
update_epoch_counter()

@@ -886,6 +886,7 @@ def train_eval_epochs(

# all available read_procs


#####################################################
@mgp.read_proc
def predict_link_score(
@@ -1287,6 +1288,7 @@ def is_correctly_typed(defined_types, input_values):

# helper functions


#####################################
def get_tgn_layer_enum(layer_type: str) -> TGNLayerType:
if TGNLayerType(layer_type) is TGNLayerType.GraphAttentionEmbedding: