Skip to content

Commit

Permalink
🚨 Fix linting errors in smoke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jemrobinson committed Sep 20, 2023
1 parent 15668ec commit b039cd7
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(
self.file_hash = file_hash
self.file_target = file_target
self.file_permissions = file_permissions
self.force_refresh = Output.from_input(force_refresh).apply(lambda force: force if force else False)
self.force_refresh = Output.from_input(force_refresh).apply(
lambda force: force if force else False
)
self.subscription_name = subscription_name
self.vm_name = vm_name
self.vm_resource_group_name = vm_resource_group_name
Expand Down
10 changes: 7 additions & 3 deletions data_safe_haven/infrastructure/stacks/sre/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,12 @@ def __init__(
]

# Upload smoke tests
mustache_values={
mustache_values = {
"check_uninstallable_packages": "0",
}
file_uploads = [(FileReader(resources_path / "workspace" / "run_all_tests.bats"), "0444")]
file_uploads = [
(FileReader(resources_path / "workspace" / "run_all_tests.bats"), "0444")
]
for test_file in pathlib.Path(resources_path / "workspace").glob("test*"):
file_uploads.append((FileReader(test_file), "0444"))
for vm, vm_output in zip(vms, vm_outputs, strict=True):
Expand All @@ -189,7 +191,9 @@ def __init__(
file_smoke_test = FileUpload(
replace_separators(f"{self._name}_file_{file_upload.name}", "_"),
FileUploadProps(
file_contents=file_upload.file_contents(mustache_values=mustache_values),
file_contents=file_upload.file_contents(
mustache_values=mustache_values
),
file_hash=file_upload.sha256(),
file_permissions=file_permissions,
file_target=f"/opt/tests/{file_upload.name}",
Expand Down
23 changes: 17 additions & 6 deletions data_safe_haven/resources/workspace/test_databases_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@
import pymssql


def test_database(server_name, hostname, port, db_type, db_name, username, password):
print(f"Attempting to connect to '{db_name}' on '{server_name}' via port {port}")
def test_database(
server_name: str,
hostname: str,
port: int,
db_type: str,
db_name: str,
username: str,
password: str,
) -> None:
msg = f"Attempting to connect to '{db_name}' on '{server_name}' via port {port}"
print(msg) # noqa: T201
username_full = f"{username}@{hostname}"
cnxn = None
if db_type == "mssql":
Expand All @@ -18,13 +27,15 @@ def test_database(server_name, hostname, port, db_type, db_name, username, passw
connection_string = f"host={server_name} port={port} dbname={db_name} user={username_full} password={password}"
cnxn = psycopg.connect(connection_string)
else:
raise ValueError(f"Database type '{db_type}' was not recognised")
msg = f"Database type '{db_type}' was not recognised"
raise ValueError(msg)
df = pd.read_sql("SELECT * FROM information_schema.tables;", cnxn)
if df.size:
print(df.head(5))
print("All database tests passed")
print(df.head(5)) # noqa: T201
print("All database tests passed") # noqa: T201
else:
raise ValueError(f"Reading from database '{db_name}' failed.")
msg = f"Reading from database '{db_name}' failed."
raise ValueError(msg)


# Parse command line arguments
Expand Down
10 changes: 5 additions & 5 deletions data_safe_haven/resources/workspace/test_functionality_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
from sklearn.linear_model import LogisticRegression


def gen_data(n_samples, n_points):
def gen_data(n_samples: int, n_points: int) -> pd.DataFrame:
"""Generate data for fitting"""
target = np.random.binomial(n=1, p=0.5, size=(n_samples, 1))
theta = np.random.normal(loc=0.0, scale=1.0, size=(1, n_points))
means = np.mean(np.multiply(target, theta), axis=0)
values = np.random.multivariate_normal(
means, np.diag([1] * n_points), size=n_samples
).T
data = dict(("x{}".format(n), values[n]) for n in range(n_points))
data = {f"x{n}": values[n] for n in range(n_points)}
data["y"] = target.reshape((n_samples,))
data["weights"] = np.random.gamma(shape=1, scale=1.0, size=n_samples)
return pd.DataFrame(data=data)


def main():
def main() -> None:
"""Logistic regression"""
data = gen_data(100, 3)
input_data = data.iloc[:, :-2]
Expand All @@ -29,8 +29,8 @@ def main():
logit.fit(input_data, output_data, sample_weight=weights)
logit.score(input_data, output_data, sample_weight=weights)

print("Logistic model ran OK")
print("All functionality tests passed")
print("Logistic model ran OK") # noqa: T201
print("All functionality tests passed") # noqa: T201


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,15 @@ module = [
"cryptography.*",
"dns.*",
"msal.*",
"numpy.*",
"pandas.*",
"psycopg.*",
"pulumi.*",
"pulumi_azure_native.*",
"pymssql.*",
"rich.*",
"simple_acme_dns.*",
"sklearn.*",
"typer.*",
"websocket.*",
]
Expand Down

0 comments on commit b039cd7

Please sign in to comment.