Skip to content

Commit

Permalink
Fix/sagemaker script generator (#712)
Browse files Browse the repository at this point in the history
Hopefully fixed by:

https://docs.python.org/3/library/shlex.html#shlex.quote

Still need to test on aws

---------

Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
  • Loading branch information
GeorgesLorre and RobbeSneyders authored Dec 11, 2023
1 parent 1f95e84 commit 38163c9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/fondant/pipeline/compiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import shlex
import tempfile
import typing as t
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -765,8 +766,12 @@ def generate_component_script(
"""Generate a bash script for a component to be used as input in a
sagemaker pipeline step. Returns the path to the script.
"""
content = " ".join(["fondant", "execute", "main", *command])
content = ["fondant", "execute", "main"]

# use shlex.quote to escape special bash chars
for c in command:
content.append(shlex.quote(c))

with open(f"{directory}/{component_name}.sh", "w") as f:
f.write(content)
f.write(" ".join(content))
return f"{directory}/{component_name}.sh"
10 changes: 8 additions & 2 deletions tests/pipeline/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import json
import subprocess
import sys
from pathlib import Path
from unittest import mock
Expand Down Expand Up @@ -645,14 +646,19 @@ def test_sagemaker_build_command():

def test_sagemaker_generate_script(tmp_path_factory):
compiler = SagemakerCompiler()
command = ["echo", "hello world"]
command = ["echo", "hello world", r"special chars: $ ! # & ' ( ) | < > ` \ ;"]
with tmp_path_factory.mktemp("temp") as fn:
script_path = compiler.generate_component_script("component_1", command, fn)

assert script_path == f"{fn}/component_1.sh"

assert not subprocess.check_call(["bash", "-n", script_path]) # nosec

with open(script_path) as f:
assert f.read() == "fondant execute main echo hello world"
assert (
f.read()
== "fondant execute main echo 'hello world' 'special chars: $ ! # & '\"'\"' ( ) | < > ` \\ ;'" # noqa E501
)


def test_sagemaker_base_path_validator():
Expand Down

0 comments on commit 38163c9

Please sign in to comment.