22import os
33import re
44import tempfile
5+ from collections import defaultdict
6+ from collections .abc import Mapping
57from concurrent .futures import ProcessPoolExecutor , as_completed
68from contextlib import contextmanager
79from typing import Iterable , Optional
810
911from funcy import cached_property
1012
1113from dvc .exceptions import DvcException
14+ from dvc .path_info import PathInfo
1215from dvc .repo .experiments .executor import ExperimentExecutor , LocalExecutor
1316from dvc .scm .git import Git
1417from dvc .stage .serialize import to_lockfile
@@ -139,21 +142,39 @@ def _scm_checkout(self, rev):
139142 logger .debug ("Checking out experiment commit '%s'" , rev )
140143 self .scm .checkout (rev )
141144
142- def _stash_exp (self , * args , ** kwargs ):
145+ def _stash_exp (self , * args , params : Optional [ dict ] = None , ** kwargs ):
143146 """Stash changes from the current (parent) workspace as an experiment.
147+
148+ Args:
149+ params: Optional dictionary of parameter values to be used.
150+ Values take priority over any parameters specified in the
151+ user's workspace.
144152 """
145153 rev = self .scm .get_rev ()
154+
155+ # patch user's workspace into experiments clone
146156 tmp = tempfile .NamedTemporaryFile (delete = False ).name
147157 try :
148158 self .repo .scm .repo .git .diff (patch = True , output = tmp )
149159 if os .path .getsize (tmp ):
150160 logger .debug ("Patching experiment workspace" )
151161 self .scm .repo .git .apply (tmp )
152- else :
162+ elif not params :
163+ # experiment matches original baseline
153164 raise UnchangedExperimentError (rev )
154165 finally :
155166 remove (tmp )
167+
168+ # update experiment params from command line
169+ if params :
170+ self ._update_params (params )
171+
172+ # save additional repro command line arguments
156173 self ._pack_args (* args , ** kwargs )
174+
175+ # save experiment as a stash commit w/message containing baseline rev
176+ # (stash commits are merge commits and do not contain a parent commit
177+ # SHA)
157178 msg = f"{ self .STASH_MSG_PREFIX } { rev } "
158179 self .scm .repo .git .stash ("push" , "-m" , msg )
159180 return self .scm .resolve_rev ("stash@{0}" )
@@ -166,6 +187,36 @@ def _unpack_args(self, tree=None):
166187 args_file = os .path .join (self .exp_dvc .tmp_dir , self .PACKED_ARGS_FILE )
167188 return ExperimentExecutor .unpack_repro_args (args_file , tree = tree )
168189
190+ def _update_params (self , params : dict ):
191+ """Update experiment params files with the specified values."""
192+ from dvc .utils .toml import dump_toml , parse_toml_for_update
193+ from dvc .utils .yaml import dump_yaml , parse_yaml_for_update
194+
195+ logger .debug ("Using experiment params '%s'" , params )
196+
197+ # recursive dict update
198+ def _update (dict_ , other ):
199+ for key , value in other .items ():
200+ if isinstance (value , Mapping ):
201+ dict_ [key ] = _update (dict_ .get (key , {}), value )
202+ else :
203+ dict_ [key ] = value
204+ return dict_
205+
206+ loaders = defaultdict (lambda : parse_yaml_for_update )
207+ loaders .update ({".toml" : parse_toml_for_update })
208+ dumpers = defaultdict (lambda : dump_yaml )
209+ dumpers .update ({".toml" : dump_toml })
210+
211+ for params_fname in params :
212+ path = PathInfo (self .exp_dvc .root_dir ) / params_fname
213+ with self .exp_dvc .tree .open (path , "r" ) as fobj :
214+ text = fobj .read ()
215+ suffix = path .suffix .lower ()
216+ data = loaders [suffix ](text , path )
217+ _update (data , params [params_fname ])
218+ dumpers [suffix ](path , data )
219+
169220 def _commit (self , exp_hash , check_exists = True , branch = True ):
170221 """Commit stages as an experiment and return the commit SHA."""
171222 if not self .scm .is_dirty ():
@@ -207,23 +258,19 @@ def reproduce_queued(self, **kwargs):
207258 )
208259 return results
209260
210- def new (self , * args , workspace = True , ** kwargs ):
261+ def new (self , * args , ** kwargs ):
211262 """Create a new experiment.
212263
213264 Experiment will be reproduced and checked out into the user's
214265 workspace.
215266 """
216267 rev = self .repo .scm .get_rev ()
217268 self ._scm_checkout (rev )
218- if workspace :
219- try :
220- stash_rev = self ._stash_exp (* args , ** kwargs )
221- except UnchangedExperimentError as exc :
222- logger .info ("Reproducing existing experiment '%s'." , rev [:7 ])
223- raise exc
224- else :
225- # configure params via command line here
226- pass
269+ try :
270+ stash_rev = self ._stash_exp (* args , ** kwargs )
271+ except UnchangedExperimentError as exc :
272+ logger .info ("Reproducing existing experiment '%s'." , rev [:7 ])
273+ raise exc
227274 logger .debug (
228275 "Stashed experiment '%s' for future execution." , stash_rev [:7 ]
229276 )
@@ -365,8 +412,10 @@ def checkout_exp(self, rev):
365412 tmp = tempfile .NamedTemporaryFile (delete = False ).name
366413 self .scm .repo .head .commit .diff ("HEAD~1" , patch = True , output = tmp )
367414
368- logger .debug ("Stashing workspace changes." )
369- self .repo .scm .repo .git .stash ("push" )
415+ dirty = self .repo .scm .is_dirty ()
416+ if dirty :
417+ logger .debug ("Stashing workspace changes." )
418+ self .repo .scm .repo .git .stash ("push" )
370419
371420 try :
372421 if os .path .getsize (tmp ):
@@ -379,7 +428,8 @@ def checkout_exp(self, rev):
379428 raise DvcException ("failed to apply experiment changes." )
380429 finally :
381430 remove (tmp )
382- self ._unstash_workspace ()
431+ if dirty :
432+ self ._unstash_workspace ()
383433
384434 if need_checkout :
385435 dvc_checkout (self .repo )
0 commit comments