diff --git a/src/dotenv/variables.py b/src/dotenv/variables.py index d77b700c..cad22b9f 100644 --- a/src/dotenv/variables.py +++ b/src/dotenv/variables.py @@ -1,18 +1,6 @@ import re from abc import ABCMeta -from typing import Iterator, Mapping, Optional, Pattern - -_posix_variable = re.compile( - r""" - \$\{ - (?P<name>[^\}:]*) - (?::- - (?P<default>[^\}]*) - )? - \} - """, - re.VERBOSE, -) # type: Pattern[str] +from typing import Iterator, Mapping, Optional, Pattern, List class Atom(): @@ -48,7 +36,7 @@ def resolve(self, env: Mapping[str, Optional[str]]) -> str: class Variable(Atom): - def __init__(self, name: str, default: Optional[str]) -> None: + def __init__(self, name: str, default: Optional[List[Atom]]) -> None: self.name = name self.default = default @@ -64,24 +52,50 @@ def __hash__(self) -> int: return hash((self.__class__, self.name, self.default)) def resolve(self, env: Mapping[str, Optional[str]]) -> str: - default = self.default if self.default is not None else "" + default = "".join(atom.resolve(env) for atom in self.default) if self.default is not None else "" result = env.get(self.name, default) return result if result is not None else "" -def parse_variables(value: str) -> Iterator[Atom]: - cursor = 0 +_variable_re = re.compile( + r""" + ^ + (?P<name>[^\}:]*?) + (?::[-=] + (?P<default>.*) + )? + $ + """, + re.VERBOSE, +) # type: Pattern[str] + +ESC_CHAR = '\\' - for match in _posix_variable.finditer(value): - (start, end) = match.span() - name = match.groupdict()["name"] - default = match.groupdict()["default"] - if start > cursor: - yield Literal(value=value[cursor:start]) +def parse_variables(value: str) -> Iterator[Atom]: + cursor = 0 - yield Variable(name=name, default=default) - cursor = end + starts: List[int] = [] + esc = False + for i in range(len(value)): + if esc: + esc = False + elif ESC_CHAR == value[i]: + esc = True + elif i < len(value) - 1 and '$' == value[i] and '{' == value[i+1]: + if len(starts) == 0 and cursor < i: + yield Literal(value=value[cursor:i]) + starts.append(i + 2) + elif '}' == value[i]: + start = starts.pop() + end = i + cursor = i+1 + if len(starts) == 0: + for match in _variable_re.finditer(value[start:end]): + name = match.groupdict()["name"] + default = match.groupdict()["default"] + default = None if default is None else list(parse_variables(default)) + yield Variable(name=name, default=default) length = len(value) if cursor < length: diff --git a/tests/test_variables.py b/tests/test_variables.py index 86b06466..28e16677 100644 --- a/tests/test_variables.py +++ b/tests/test_variables.py @@ -9,7 +9,22 @@ ("", []), ("a", [Literal(value="a")]), ("${a}", [Variable(name="a", default=None)]), - ("${a:-b}", [Variable(name="a", default="b")]), + ("${a:-b}", [Variable(name="a", default=[Literal(value="b")])]), + ("${a:=b}", [Variable(name="a", default=[Literal(value="b")])]), + ( + "${a:-a${b:-c${d}e}f}", + [ + Variable(name="a", default=[ + Literal(value="a"), + Variable(name="b", default=[ + Literal(value="c"), + Variable(name="d", default=None), + Literal(value="e") + ]), + Literal(value="f") + ]) + ] + ), ( "${a}${b}", [