Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Resolve tag addition issue from parallel runs #3247

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
### Fixes:
- Fix aggregated metrics' computations (mihran113)
- Fix bug in RunStatusReporter raising non-deterministic RuntimeError exception (VassilisVassiliadis)

- Fix tag addition issue from parallel runs (mihran113)

## 3.26.1 Dec 3, 2024
- Re-upload after PyPI size limitation fix
Expand Down
2 changes: 1 addition & 1 deletion aim/cli/runs/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import tqdm

from aim.cli.runs.utils import make_zip_archive, match_runs, upload_repo_runs
from aim.sdk.repo import Repo
from aim.sdk.index_manager import RepoIndexManager
from aim.sdk.repo import Repo
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from psutil import cpu_count

Expand Down
6 changes: 2 additions & 4 deletions aim/sdk/reporter/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

class FileManager(object):
@abstractmethod
def poll(self, pattern: str) -> Optional[str]:
...
def poll(self, pattern: str) -> Optional[str]: ...

@abstractmethod
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None):
...
def touch(self, filename: str, cleanup_file_pattern: Optional[str] = None): ...


class LocalFileManager(FileManager):
Expand Down
2 changes: 1 addition & 1 deletion aim/sdk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from aim.sdk.reporter import RunStatusReporter, ScheduledStatusReporter
from aim.sdk.reporter.file_manager import LocalFileManager
from aim.sdk.sequence import Sequence
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.sequence_collection import SingleRunSequenceCollection
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.tracker import RunTracker
from aim.sdk.types import AimObject
from aim.sdk.utils import (
Expand Down
9 changes: 3 additions & 6 deletions aim/sdk/run_status_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,13 @@ def __init__(self, *, obj_idx: Optional[str] = None, rank: Optional[int] = None,
self.message = message

@abstractmethod
def is_sent(self):
...
def is_sent(self): ...

@abstractmethod
def update_last_sent(self):
...
def update_last_sent(self): ...

@abstractmethod
def get_msg_details(self):
...
def get_msg_details(self): ...


class StatusNotification(Notification):
Expand Down
2 changes: 1 addition & 1 deletion aim/sdk/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from aim.sdk.configs import AIM_ENABLE_TRACKING_THREAD
from aim.sdk.num_utils import convert_to_py_number, is_number
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP
from aim.sdk.utils import check_types_compatibility, get_object_typename
from aim.storage.context import Context
from aim.storage.hashing import hash_auto
from aim.storage.object import CustomObject
from aim.storage.types import AimObject
from aim.sdk.sequences.sequence_type_map import SEQUENCE_TYPE_MAP


if TYPE_CHECKING:
Expand Down
25 changes: 18 additions & 7 deletions aim/storage/structured/sql_engine/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,25 @@ def tags(self) -> List[str]:
return [tag.name for tag in self.tags_obj]

def add_tag(self, value: str) -> None:
def unsafe_add_tag():
if value is None:
tag = None
else:
tag = session.query(TagModel).filter(TagModel.name == value).first()
if not tag:
tag = TagModel(value)
session.add(tag)
self._model.tags.append(tag)
session.add(self._model)

session = self._session
tag = session.query(TagModel).filter(TagModel.name == value).first()
if not tag:
tag = TagModel(value)
session.add(tag)
self._model.tags.append(tag)
session.add(self._model)
session_commit_or_flush(session)
unsafe_add_tag()
try:
session_commit_or_flush(session)
except IntegrityError:
session.rollback()
unsafe_add_tag()
session_commit_or_flush(session)

def remove_tag(self, tag_name: str) -> bool:
session = self._session
Expand Down
8 changes: 2 additions & 6 deletions troubleshooting/base_project_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import sys
import time

import tqdm

import aim
import tqdm


def count_metrics(run):
Expand All @@ -24,10 +23,7 @@ def count_dict_keys(params):
Count the number of leaf nodes in a nested dictionary.
A leaf node is a value that is not a dictionary.
"""
return sum(
count_dict_keys(value) if isinstance(value, dict) else 1
for value in params.values()
)
return sum(count_dict_keys(value) if isinstance(value, dict) else 1 for value in params.values())


parser = argparse.ArgumentParser(description='Process command line arguments.')
Expand Down
Loading