diff --git a/samples/basic/condition.py b/samples/basic/condition.py index 76d02336e77..62125c20f37 100755 --- a/samples/basic/condition.py +++ b/samples/basic/condition.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2018 Google LLC +# Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,66 +14,62 @@ # limitations under the License. -import kfp.dsl as dsl +import kfp +from kfp import dsl -class RandomNumOp(dsl.ContainerOp): - """Generate a random number between low and high.""" +def random_num_op(low, high): + """Generate a random number between low and high.""" + return dsl.ContainerOp( + name='Generate random number', + image='python:alpine3.6', + command=['sh', '-c'], + arguments=['python -c "import random; print(random.randint($0, $1))" | tee $2', str(low), str(high), '/tmp/output'], + file_outputs={'output': '/tmp/output'} + ) - def __init__(self, low, high): - super(RandomNumOp, self).__init__( - name='Random number', - image='python:alpine3.6', - command=['sh', '-c'], - arguments=['python -c "import random; print(random.randint(%s,%s))" | tee /tmp/output' % (low, high)], - file_outputs={'output': '/tmp/output'}) +def flip_coin_op(): + """Flip a coin and output heads or tails randomly.""" + return dsl.ContainerOp( + name='Flip coin', + image='python:alpine3.6', + command=['sh', '-c'], + arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 ' + 'else \'tails\'; print(result)" | tee /tmp/output'], + file_outputs={'output': '/tmp/output'} + ) -class FlipCoinOp(dsl.ContainerOp): - """Flip a coin and output heads or tails randomly.""" - def __init__(self): - super(FlipCoinOp, self).__init__( - name='Flip', - image='python:alpine3.6', - command=['sh', '-c'], - arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 ' - 'else \'tails\'; print(result)" | tee /tmp/output'], - file_outputs={'output': '/tmp/output'}) - - -class PrintOp(dsl.ContainerOp): - """Print a message.""" - - def __init__(self, msg): - super(PrintOp, self).__init__( - name='Print', - image='alpine:3.6', - command=['echo', msg], - ) +def print_op(msg): + """Print a message.""" + return dsl.ContainerOp( + name='Print', + image='alpine:3.6', + command=['echo', msg], + ) @dsl.pipeline( - name='pipeline flip coin', - description='shows how to use dsl.Condition.' + name='Conditional execution pipeline', + description='Shows how to use dsl.Condition().' ) -def flipcoin(): - flip = FlipCoinOp() - with dsl.Condition(flip.output == 'heads'): - random_num_head = RandomNumOp(0, 9) - with dsl.Condition(random_num_head.output > 5): - PrintOp('heads and %s > 5!' % random_num_head.output) - with dsl.Condition(random_num_head.output <= 5): - PrintOp('heads and %s <= 5!' % random_num_head.output) +def flipcoin_pipeline(): + flip = flip_coin_op() + with dsl.Condition(flip.output == 'heads'): + random_num_head = random_num_op(0, 9) + with dsl.Condition(random_num_head.output > 5): + print_op('heads and %s > 5!' % random_num_head.output) + with dsl.Condition(random_num_head.output <= 5): + print_op('heads and %s <= 5!' % random_num_head.output) - with dsl.Condition(flip.output == 'tails'): - random_num_tail = RandomNumOp(10, 19) - with dsl.Condition(random_num_tail.output > 15): - PrintOp('tails and %s > 15!' % random_num_tail.output) - with dsl.Condition(random_num_tail.output <= 15): - PrintOp('tails and %s <= 15!' % random_num_tail.output) + with dsl.Condition(flip.output == 'tails'): + random_num_tail = random_num_op(10, 19) + with dsl.Condition(random_num_tail.output > 15): + print_op('tails and %s > 15!' % random_num_tail.output) + with dsl.Condition(random_num_tail.output <= 15): + print_op('tails and %s <= 15!' % random_num_tail.output) if __name__ == '__main__': - import kfp.compiler as compiler - compiler.Compiler().compile(flipcoin, __file__ + '.zip') + kfp.compiler.Compiler().compile(flipcoin_pipeline, __file__ + '.zip')