Skip to content

Commit

Permalink
Fix environment variable leak for unused formatters (#338)
Browse files Browse the repository at this point in the history
* Add tests against environment pollution

* Delay `rpy2` import until the formatter is requested
  • Loading branch information
krassowski authored Jul 19, 2024
1 parent a7a9f10 commit 7cead40
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
59 changes: 27 additions & 32 deletions jupyterlab_code_formatter/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@
from functools import wraps
from typing import List, Type

try:
import rpy2
import rpy2.robjects
except ImportError:
pass
if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache

cache = lru_cache(maxsize=None)

from packaging import version
Expand Down Expand Up @@ -357,56 +353,55 @@ def format_code(self, code: str, notebook: bool, **options) -> str:
return isort.code(code=code, **options)


class FormatRFormatter(BaseFormatter):
label = "Apply FormatR Formatter"
package_name = "formatR"
class RFormatter(BaseFormatter):
@property
@abc.abstractmethod
def package_name(self) -> str:
pass

@property
def importable(self) -> bool:
try:
import rpy2.robjects.packages as rpackages
package_location = subprocess.run(
["Rscript", "-e", f"cat(system.file(package='{self.package_name}'))"],
capture_output=True,
text=True,
)
return package_location != ""

rpackages.importr(self.package_name, robject_translations={".env": "env"})

return True
except Exception:
return False
class FormatRFormatter(RFormatter):
label = "Apply FormatR Formatter"
package_name = "formatR"

@handle_line_ending_and_magic
def format_code(self, code: str, notebook: bool, **options) -> str:
import rpy2.robjects.packages as rpackages
from rpy2.robjects import conversion, default_converter

format_r = rpackages.importr(self.package_name, robject_translations={".env": "env"})
formatted_code = format_r.tidy_source(text=code, output=False, **options)
return "\n".join(formatted_code[0])
with conversion.localconverter(default_converter):
format_r = rpackages.importr(self.package_name, robject_translations={".env": "env"})
formatted_code = format_r.tidy_source(text=code, output=False, **options)
return "\n".join(formatted_code[0])


class StylerFormatter(BaseFormatter):
class StylerFormatter(RFormatter):
label = "Apply Styler Formatter"
package_name = "styler"

@property
def importable(self) -> bool:
try:
import rpy2.robjects.packages as rpackages

rpackages.importr(self.package_name)

return True
except Exception:
return False

@handle_line_ending_and_magic
def format_code(self, code: str, notebook: bool, **options) -> str:
import rpy2.robjects.packages as rpackages
from rpy2.robjects import conversion, default_converter

styler_r = rpackages.importr(self.package_name)
formatted_code = styler_r.style_text(code, **self._transform_options(styler_r, options))
return "\n".join(formatted_code)
with conversion.localconverter(default_converter):
styler_r = rpackages.importr(self.package_name)
formatted_code = styler_r.style_text(code, **self._transform_options(styler_r, options))
return "\n".join(formatted_code)

@staticmethod
def _transform_options(styler_r, options):
transformed_options = copy.deepcopy(options)
import rpy2.robjects

if "math_token_spacing" in transformed_options:
if isinstance(options["math_token_spacing"], dict):
Expand Down
38 changes: 38 additions & 0 deletions jupyterlab_code_formatter/tests/test_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json
import os
import sys
from subprocess import run
from unittest import mock

import pytest

from jupyterlab_code_formatter.formatters import SERVER_FORMATTERS


def test_env_pollution_on_import():
# should not pollute environment on import
code = "; ".join(
[
"from jupyterlab_code_formatter import formatters",
"import json",
"import os",
"assert formatters",
"print(json.dumps(os.environ.copy()))",
]
)
result = run([sys.executable, "-c", f"{code}"], capture_output=True, text=True, check=True, env={})
environ = json.loads(result.stdout)
assert set(environ.keys()) - {"LC_CTYPE"} == set()


@pytest.mark.parametrize("name", SERVER_FORMATTERS)
def test_env_pollution_on_importable_check(name):
formatter = SERVER_FORMATTERS[name]
# should not pollute environment on `importable` check
with mock.patch.dict(os.environ, {}, clear=True):
# invoke the property getter
is_importable = formatter.importable
# the environment should have no extra keys
assert set(os.environ.keys()) == set()
if not is_importable:
pytest.skip(f"{name} formatter was not importable, the test may yield false negatives")

0 comments on commit 7cead40

Please sign in to comment.