Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exapt squote() to provide secure shell-quoting #496

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions WDL/StdLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import json
import tempfile
import shlex
from typing import List, Tuple, Callable, BinaryIO, Optional
from abc import ABC, abstractmethod
import regex
Expand Down Expand Up @@ -140,7 +141,9 @@ def sep(sep: Value.String, iterable: Value.Array) -> Value.String:
self.min = _ArithmeticOperator("min", lambda l, r: min(l, r))
self.max = _ArithmeticOperator("max", lambda l, r: max(l, r))
self.quote = _Quote()
self.squote = _Quote(squote=True)
self.squote = (
_Quote(squote=True) if self.wdl_version != "development" else _ShellQuote()
)
self.keys = _Keys()
self.as_map = _AsMap()
self.as_pairs = _AsPairs()
Expand Down Expand Up @@ -1012,7 +1015,6 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
class _Quote(EagerFunction):
# t array -> string array
# if input array is nonempty then so is output
# Append a suffix to every element within the array

def __init__(self, squote: bool = False) -> None:
if squote:
Expand All @@ -1025,7 +1027,6 @@ def infer_type(self, expr: "Expr.Apply") -> Type.Base:
raise Error.WrongArity(expr, 1)
expr.arguments[0].typecheck(Type.Array(Type.String()))
arg0ty = expr.arguments[0].type
nonempty = isinstance(arg0ty, Type.Array) and arg0ty.nonempty
return Type.Array(
Type.String(), nonempty=(isinstance(arg0ty, Type.Array) and arg0ty.nonempty)
)
Expand All @@ -1040,6 +1041,39 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
)


class _ShellQuote(EagerFunction):
# t -> string or t array -> string array

def infer_type(self, expr: "Expr.Apply") -> Type.Base:
if len(expr.arguments) != 1:
raise Error.WrongArity(expr, 1)
if not isinstance(expr.arguments[0].type, Type.Array):
expr.arguments[0].typecheck(Type.String(optional=True))
return Type.String()
expr.arguments[0].typecheck(Type.Array(Type.String(optional=True), optional=True))
arg0ty = expr.arguments[0].type
return Type.Array(
Type.String(), nonempty=(isinstance(arg0ty, Type.Array) and arg0ty.nonempty)
)

def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.Base:
ty = self.infer_type(expr)
if isinstance(ty, Type.String):
return Value.String(self._shellquote(arguments[0].coerce(Type.String()).value))
assert isinstance(ty, Type.Array)
return Value.Array(
Type.String(),
[
Value.String(self._shellquote(s.coerce(Type.String()).value))
for s in arguments[0].value
],
)

def _shellquote(self, s: str) -> str:
q = shlex.quote(s)
return q if q != s else ("'" + s + "'")


class _Keys(EagerFunction):
def infer_type(self, expr: "Expr.Apply") -> Type.Base:
if len(expr.arguments) != 1:
Expand Down
49 changes: 41 additions & 8 deletions tests/test_5stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,23 +1079,23 @@ def test_quote(self):

def test_squote(self):
outputs = self._test_task(R"""
version development
version 1.1
task test_squote {
command {}
output {
Array[String] arguments = ["foo","bar","baz"]
Array[String] quoted_args = squote(arguments) # ["'foo'","'bar'","'baz'"]
Array[String] arguments = ["foo","bar","baz'"]
Array[String] quoted_args = squote(arguments) # ["'foo'","'bar'","'baz''"]
}
}
""")
# Check to make sure each element has be quoted appropriately
# Check to make sure each element has been quoted appropriately
self.assertEqual(outputs, {
"arguments": ["foo","bar","baz"],
"quoted_args": ["'foo'","'bar'","'baz'"]
"arguments": ["foo","bar","baz'"],
"quoted_args": ["'foo'","'bar'","'baz''"]
})

outputs = self._test_task(R"""
version development
version 1.1
task test_squote {
command {}
output {
Expand All @@ -1113,7 +1113,7 @@ def test_squote(self):

# Check invalid type does not work
outputs = self._test_task(R"""
version development
version 1.1
task test_squote {
command {}
output {
Expand All @@ -1123,6 +1123,39 @@ def test_squote(self):
}
""",expected_exception=WDL.Error.StaticTypeMismatch)

def test_shellquote(self):
outputs = self._test_task(R"""
version development
task test_squote {
command {}
output {
Array[String?] arguments = ["foo","bar","baz'", None]
Array[String] quoted_args = squote(arguments)
}
}
""")
# Check to make sure each element has been quoted appropriately
self.assertEqual(outputs, {
"arguments": ["foo","bar","baz'", None],
"quoted_args": ["'foo'","'bar'",""" 'baz'"'"'' """.strip(), "''"]
})

outputs = self._test_task(R"""
version development
task test_squote {
command {}
output {
String quoted = squote(1)
String quoted2 = squote("so'wl'chu'")
}
}
""")

self.assertEqual(outputs, {
"quoted": "'1'",
"quoted2": """ 'so'"'"'wl'"'"'chu'"'"'' """.strip()
})

def test_keys(self):
outputs = self._test_task(R"""
version development
Expand Down