12
12
import logging as log
13
13
import os
14
14
import os .path as osp
15
+ import shutil
15
16
import sys
16
17
17
18
from datumaro .components .config import Config , DEFAULT_FORMAT
18
- from datumaro .components .config_model import *
19
- from datumaro .components .extractor import DatasetItem , Extractor
19
+ from datumaro .components .config_model import (Model , Source ,
20
+ PROJECT_DEFAULT_CONFIG , PROJECT_SCHEMA )
21
+ from datumaro .components .extractor import Extractor
20
22
from datumaro .components .launcher import InferenceWrapper
21
23
from datumaro .components .dataset_filter import \
22
24
XPathDatasetFilter , XPathAnnotationsFilter
@@ -672,16 +674,21 @@ def apply_model(self, model, save_dir=None, batch_size=1):
672
674
def export_project (self , save_dir , converter ,
673
675
filter_expr = None , filter_annotations = False , remove_empty = False ):
674
676
# NOTE: probably this function should be in the ViewModel layer
675
- save_dir = osp .abspath (save_dir )
676
- os .makedirs (save_dir , exist_ok = True )
677
-
678
677
dataset = self
679
678
if filter_expr :
680
679
dataset = dataset .extract (filter_expr ,
681
680
filter_annotations = filter_annotations ,
682
681
remove_empty = remove_empty )
683
682
684
- converter (dataset , save_dir )
683
+ save_dir = osp .abspath (save_dir )
684
+ save_dir_existed = osp .exists (save_dir )
685
+ try :
686
+ os .makedirs (save_dir , exist_ok = True )
687
+ converter (dataset , save_dir )
688
+ except Exception :
689
+ if not save_dir_existed :
690
+ shutil .rmtree (save_dir )
691
+ raise
685
692
686
693
def extract_project (self , filter_expr , filter_annotations = False ,
687
694
save_dir = None , remove_empty = False ):
@@ -694,24 +701,41 @@ def extract_project(self, filter_expr, filter_annotations=False,
694
701
self ._save_branch_project (filtered , save_dir = save_dir )
695
702
696
703
class Project :
697
- @staticmethod
698
- def load (path ):
704
+ @classmethod
705
+ def load (cls , path ):
699
706
path = osp .abspath (path )
700
- if osp .isdir (path ):
701
- path = osp . join ( path , PROJECT_DEFAULT_CONFIG .project_filename )
702
- config = Config .parse (path )
703
- config .project_dir = osp . dirname ( path )
704
- config .project_filename = osp .basename (path )
707
+ config_path = osp .join (path , PROJECT_DEFAULT_CONFIG . env_dir ,
708
+ PROJECT_DEFAULT_CONFIG .project_filename )
709
+ config = Config .parse (config_path )
710
+ config .project_dir = path
711
+ config .project_filename = osp .basename (config_path )
705
712
return Project (config )
706
713
707
714
def save (self , save_dir = None ):
708
715
config = self .config
716
+
709
717
if save_dir is None :
710
718
assert config .project_dir
711
- save_dir = osp .abspath (config .project_dir )
712
- os .makedirs (save_dir , exist_ok = True )
713
- config_path = osp .join (save_dir , config .project_filename )
714
- config .dump (config_path )
719
+ project_dir = config .project_dir
720
+ else :
721
+ project_dir = save_dir
722
+
723
+ env_dir = osp .join (project_dir , config .env_dir )
724
+ save_dir = osp .abspath (env_dir )
725
+
726
+ project_dir_existed = osp .exists (project_dir )
727
+ env_dir_existed = osp .exists (env_dir )
728
+ try :
729
+ os .makedirs (save_dir , exist_ok = True )
730
+
731
+ config_path = osp .join (save_dir , config .project_filename )
732
+ config .dump (config_path )
733
+ except Exception :
734
+ if not env_dir_existed :
735
+ shutil .rmtree (save_dir , ignore_errors = True )
736
+ if not project_dir_existed :
737
+ shutil .rmtree (project_dir , ignore_errors = True )
738
+ raise
715
739
716
740
@staticmethod
717
741
def generate (save_dir , config = None ):
@@ -735,8 +759,8 @@ def __init__(self, config=None):
735
759
def make_dataset (self ):
736
760
return ProjectDataset (self )
737
761
738
- def add_source (self , name , value = Source () ):
739
- if isinstance (value , (dict , Config )):
762
+ def add_source (self , name , value = None ):
763
+ if value is None or isinstance (value , (dict , Config )):
740
764
value = Source (value )
741
765
self .config .sources [name ] = value
742
766
self .env .sources .register (name , value )
@@ -760,8 +784,8 @@ def set_subsets(self, value):
760
784
else :
761
785
self .config .subsets = value
762
786
763
- def add_model (self , name , value = Model () ):
764
- if isinstance (value , (dict , Config )):
787
+ def add_model (self , name , value = None ):
788
+ if value is None or isinstance (value , (dict , Config )):
765
789
value = Model (value )
766
790
self .env .register_model (name , value )
767
791
self .config .models [name ] = value
0 commit comments