Skip to content

Commit

Permalink
Add group_by to named_output #1120
Browse files Browse the repository at this point in the history
  • Loading branch information
Bo Peng committed Dec 21, 2018
1 parent 5fbb28b commit be0fb78
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/sos/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,10 @@ def handle_ctl_req_msg(self, msg):
for step_output in self._completed_steps.values():
if name in step_output.sources:
found = True
self.ctl_req_socket.send_pyobj(step_output[name])
res = step_output[name]
# we also slice the groups to be the groups of res
res._groups = [x[name] for x in step_output._groups]
self.ctl_req_socket.send_pyobj(res)
break
if not found:
self.ctl_req_socket.send_pyobj(None)
Expand Down
18 changes: 16 additions & 2 deletions src/sos/section_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,23 @@ def get_step_depends(section):
step_depends.extend([sos_step(x) for x in get_output_from_steps(stmt, section.last_step)])
if 'named_output' in stmt:
# there can be multiple named_output calls
step_depends.extend(named_output(x[-1]) for x in get_param_of_function('named_output', stmt,
pars = get_param_of_function('named_output', stmt,
extra_dict=env.sos_dict._dict)
)
for par in pars:
# a single argument
if len(par) == 1:
if not isinstance(par[0], str):
raise ValueError(f'Value for named_output can only be a name (str): {par[0]} provided')
step_depends.extend(named_output(par[0]))
else:
if par[0] == 'group_by':
continue
elif par[0] == 'name':
if not isinstance(par[1], str):
raise ValueError(f'Value for named_output can only be a name (str): {par[1]} provided')
step_depends.extend(named_output(par[1]))
else:
raise ValueError(f'Unacceptable keyword argument {par[0]} for named_output()')


depends_idx = find_statement(section, 'depends')
Expand Down
56 changes: 56 additions & 0 deletions test/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,24 @@ def testReturn_OutputInStepOutput(self):
assert(step_input.groups[0] == 'a_0.txt')
assert(step_input.groups[4] == 'a_4.txt')
''')
wf = script.workflow()
Base_Executor(wf).run()
#
# test accumulation of named output
script = SoS_Script('''\
[1]
input: for_each=dict(i=range(5))
output: a=f'a_{i}.txt', b=f'b_{i}.txt'
_output.touch()
[2]
assert(len(step_input.groups) == 5)
assert(len(step_input) == 10)
assert(step_input.groups[0] == ['a_0.txt', 'b_0.txt'])
assert(step_input.groups[0].sources == ['a', 'b'])
assert(step_input.groups[4] == ['a_4.txt', 'b_4.txt'])
assert(step_input.groups[4].sources == ['a', 'b'])
''')
wf = script.workflow()
Base_Executor(wf).run()
Expand Down Expand Up @@ -1525,5 +1543,43 @@ def testOutputFrom(self):
wf = script.workflow(wf)
Base_Executor(wf).run()


def testNamedOutput(self):
    '''Check the named_output() input function, with and without group_by.

    The embedded script defines step A, which declares two named
    outputs (``aa`` and ``bb``) for each value of ``i`` in ``range(4)``.
    Steps B, C and D each take one of those outputs as input through
    named_output(): as is, relabelled as ``K``, and relabelled with
    ``group_by=2``. The actual checks are the assert statements inside
    the embedded steps.
    '''
    workflow_text = '''\
[A]
input: for_each=dict(i=range(4))
output: aa=f'a_{i}.txt', bb=f'b_{i}.txt'
_output.touch()
[B]
input: named_output('aa')
assert(len(step_input.groups) == 4)
assert(len(step_input) == 4)
assert(step_input.sources == ['aa']*4)
assert(step_input.groups[0] == 'a_0.txt')
assert(step_input.groups[3] == 'a_3.txt')
[C]
input: K=named_output('bb')
assert(len(step_input.groups) == 4)
assert(len(step_input) == 4)
assert(step_input.sources == ['K']*4)
assert(step_input.groups[0] == 'b_0.txt')
assert(step_input.groups[3] == 'b_3.txt')
[D]
input: K=named_output('bb', group_by=2)
assert(len(step_input.groups) == 2)
assert(len(step_input) == 4)
assert(step_input.sources == ['K']*4)
assert(step_input.groups[1] == ['b_2.txt', 'b_3.txt'])
'''
    script = SoS_Script(workflow_text)
    # Each consuming step is requested and executed as its own workflow.
    for target in ('B', 'C', 'D'):
        Base_Executor(script.workflow(target)).run()

# Run the unittest test runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit be0fb78

Please sign in to comment.