From cd6926cb16cd0eafef5a5df0bc94f9861b709b01 Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Thu, 27 Feb 2020 15:38:57 +0300 Subject: [PATCH] Move project dir to .datumaro --- .../datumaro/cli/contexts/project/__init__.py | 46 +++++++------ datumaro/datumaro/cli/util/project.py | 13 +--- datumaro/datumaro/components/project.py | 66 +++++++++++++------ 3 files changed, 74 insertions(+), 51 deletions(-) diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index 3ca79e047827..61bb77ad433d 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -10,7 +10,8 @@ import os.path as osp import shutil -from datumaro.components.project import Project, Environment +from datumaro.components.project import Project, Environment, \ + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.comparator import Comparator from datumaro.components.dataset_filter import DatasetItemEncoder from datumaro.components.extractor import AnnotationType @@ -18,8 +19,7 @@ from .diff import DiffVisualizer from ...util import add_subparser, CliException, MultilineFormatter, \ make_file_name -from ...util.project import make_project_path, load_project, \ - generate_next_dir_name +from ...util.project import load_project, generate_next_dir_name def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -48,19 +48,23 @@ def build_create_parser(parser_ctor=argparse.ArgumentParser): def create_command(args): project_dir = osp.abspath(args.dst_dir) - project_path = make_project_path(project_dir) - if osp.isdir(project_dir) and os.listdir(project_dir): + project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) + if osp.isdir(project_env_dir) and os.listdir(project_env_dir): if not args.overwrite: raise CliException("Directory '%s' already exists " - "(pass --overwrite to force creation)" % project_dir) + "(pass --overwrite to force creation)" % project_env_dir) else: - shutil.rmtree(project_dir) - os.makedirs(project_dir, exist_ok=True) + shutil.rmtree(project_env_dir, ignore_errors=True) - if not args.overwrite and osp.isfile(project_path): - raise CliException("Project file '%s' already exists " - "(pass --overwrite to force creation)" % project_path) + own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) + if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): + if not args.overwrite: + raise CliException("Directory '%s' already exists " + "(pass --overwrite to force creation)" % own_dataset_dir) + else: + # NOTE: remove the dir to avoid using data from previous project + shutil.rmtree(own_dataset_dir) project_name = args.name if project_name is None: @@ -138,19 +142,23 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): def import_command(args): project_dir = osp.abspath(args.dst_dir) - project_path = make_project_path(project_dir) - if osp.isdir(project_dir) and os.listdir(project_dir): + project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) + if osp.isdir(project_env_dir) and os.listdir(project_env_dir): if not args.overwrite: raise CliException("Directory '%s' already exists " - "(pass --overwrite to force creation)" % project_dir) + "(pass --overwrite to force creation)" % project_env_dir) else: - shutil.rmtree(project_dir) - os.makedirs(project_dir, exist_ok=True) + shutil.rmtree(project_env_dir, ignore_errors=True) - if not args.overwrite and osp.isfile(project_path): - raise CliException("Project file '%s' already exists " - "(pass --overwrite to force creation)" % project_path) + own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) + if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): + if not args.overwrite: + raise CliException("Directory '%s' already exists " + "(pass --overwrite to force creation)" % own_dataset_dir) + else: + # NOTE: remove the dir to avoid using data from previous project + shutil.rmtree(own_dataset_dir) project_name = args.name if project_name is None: diff --git a/datumaro/datumaro/cli/util/project.py b/datumaro/datumaro/cli/util/project.py index dde4531a2aaf..af92458bcd14 100644 --- a/datumaro/datumaro/cli/util/project.py +++ b/datumaro/datumaro/cli/util/project.py @@ -4,20 +4,11 @@ # SPDX-License-Identifier: MIT import os -import os.path as osp -from datumaro.components.project import Project, \ - PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG +from datumaro.components.project import Project -def make_project_path(project_dir, project_filename=None): - if project_filename is None: - project_filename = DEFAULT_CONFIG.project_filename - return osp.join(project_dir, project_filename) - -def load_project(project_dir, project_filename=None): - if project_filename: - project_dir = osp.join(project_dir, project_filename) +def load_project(project_dir): return Project.load(project_dir) def generate_next_dir_name(dirname, basedir='.', sep='.'): diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index f6ca653a8673..262710dc7497 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -12,11 +12,13 @@ import logging as log import os import os.path as osp +import shutil import sys from datumaro.components.config import Config, DEFAULT_FORMAT -from datumaro.components.config_model import * -from datumaro.components.extractor import DatasetItem, Extractor +from datumaro.components.config_model import (Model, Source, + PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA) +from datumaro.components.extractor import Extractor from datumaro.components.launcher import InferenceWrapper from datumaro.components.dataset_filter import \ XPathDatasetFilter, XPathAnnotationsFilter @@ -672,16 +674,21 @@ def apply_model(self, model, save_dir=None, batch_size=1): def export_project(self, save_dir, converter, filter_expr=None, filter_annotations=False, remove_empty=False): # NOTE: probably this function should be in the ViewModel layer - save_dir = osp.abspath(save_dir) - os.makedirs(save_dir, exist_ok=True) - dataset = self if filter_expr: dataset = dataset.extract(filter_expr, filter_annotations=filter_annotations, remove_empty=remove_empty) - converter(dataset, save_dir) + save_dir = osp.abspath(save_dir) + save_dir_existed = osp.exists(save_dir) + try: + os.makedirs(save_dir, exist_ok=True) + converter(dataset, save_dir) + except Exception: + if not save_dir_existed: + shutil.rmtree(save_dir) + raise def extract_project(self, filter_expr, filter_annotations=False, save_dir=None, remove_empty=False): @@ -694,24 +701,41 @@ def extract_project(self, filter_expr, filter_annotations=False, self._save_branch_project(filtered, save_dir=save_dir) class Project: - @staticmethod - def load(path): + @classmethod + def load(cls, path): path = osp.abspath(path) - if osp.isdir(path): - path = osp.join(path, PROJECT_DEFAULT_CONFIG.project_filename) - config = Config.parse(path) - config.project_dir = osp.dirname(path) - config.project_filename = osp.basename(path) + config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir, + PROJECT_DEFAULT_CONFIG.project_filename) + config = Config.parse(config_path) + config.project_dir = path + config.project_filename = osp.basename(config_path) return Project(config) def save(self, save_dir=None): config = self.config + if save_dir is None: assert config.project_dir - save_dir = osp.abspath(config.project_dir) - os.makedirs(save_dir, exist_ok=True) - config_path = osp.join(save_dir, config.project_filename) - config.dump(config_path) + project_dir = config.project_dir + else: + project_dir = save_dir + + env_dir = osp.join(project_dir, config.env_dir) + save_dir = osp.abspath(env_dir) + + project_dir_existed = osp.exists(project_dir) + env_dir_existed = osp.exists(env_dir) + try: + os.makedirs(save_dir, exist_ok=True) + + config_path = osp.join(save_dir, config.project_filename) + config.dump(config_path) + except Exception: + if not env_dir_existed: + shutil.rmtree(save_dir, ignore_errors=True) + if not project_dir_existed: + shutil.rmtree(project_dir, ignore_errors=True) + raise @staticmethod def generate(save_dir, config=None): @@ -735,8 +759,8 @@ def __init__(self, config=None): def make_dataset(self): return ProjectDataset(self) - def add_source(self, name, value=Source()): - if isinstance(value, (dict, Config)): + def add_source(self, name, value=None): + if value is None or isinstance(value, (dict, Config)): value = Source(value) self.config.sources[name] = value self.env.sources.register(name, value) @@ -760,8 +784,8 @@ def set_subsets(self, value): else: self.config.subsets = value - def add_model(self, name, value=Model()): - if isinstance(value, (dict, Config)): + def add_model(self, name, value=None): + if value is None or isinstance(value, (dict, Config)): value = Model(value) self.env.register_model(name, value) self.config.models[name] = value