Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/_test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _test_job_7_delete():
assert task_info.status in ("deleted", "deleting")


def test_job_source_norm():
def test_job_source_norm(caplog):
"""test complete run"""
job = web.Job(simulation=sim_original, task_name="test_job", callback_url=CALLBACK_URL)
sim_data_norm = job.run(path=PATH_SIM_DATA, normalize_index=0)
Expand Down
28 changes: 1 addition & 27 deletions tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,7 @@

from tidy3d import *
from tidy3d.log import ValidationError, SetupError


def assert_log_level(caplog, log_level_expected):
"""ensure something got logged if log_level is not None.
note: I put this here rather than utils.py because if we import from utils.py,
it will validate the sims there and those get included in log.
"""

# get log output
logs = caplog.record_tuples

# there's a log but the log level is not None (problem)
if logs and not log_level_expected:
raise Exception

# we expect a log but none is given (problem)
if log_level_expected and not logs:
raise Exception

# both expected and got log, check the log levels match
if logs and log_level_expected:
for log in logs:
log_level = log[1]
if log_level == log_level_expected:
# log level was triggered, exit
return
raise Exception
from .utils import assert_log_level


def test_sim():
Expand Down
27 changes: 27 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,30 @@ def prepend_tmp(path):
run_time=1e-12,
pml_layers=3 * [PML()],
)


def assert_log_level(caplog, log_level_expected):
"""ensure something got logged if log_level is not None.
note: I put this here rather than utils.py because if we import from utils.py,
it will validate the sims there and those get included in log.
"""

# get log output
logs = caplog.record_tuples

# there's a log but the log level is not None (problem)
if logs and not log_level_expected:
raise Exception

# we expect a log but none is given (problem)
if log_level_expected and not logs:
raise Exception

# both expected and got log, check the log levels match
if logs and log_level_expected:
for log in logs:
log_level = log[1]
if log_level == log_level_expected:
# log level was triggered, exit
return
raise Exception
5 changes: 2 additions & 3 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,9 +1195,8 @@ def normalize(self, normalize_index: int = 0):
try:
source = self.simulation.sources[normalize_index]
source_time = source.source_time
except Exception: # pylint:disable=broad-except
logging.warning(f"Could not locate source at normalize_index={normalize_index}.")
return self
except IndexError as e:
raise DataError(f"Could not locate source at normalize_index={normalize_index}.") from e

source_time = source.source_time
sim_data_norm = self.copy(deep=True)
Expand Down
12 changes: 4 additions & 8 deletions tidy3d/plugins/dispersion/fit_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from ...components import PoleResidue
from ...constants import MICROMETER, HERTZ
from ...log import log, WebError, Tidy3dError
from ...web.httputils import get_headers
from ...web.auth import requires_auth

from .fit import DispersionFitter

BOUND_MAX_FACTOR = 10
Expand Down Expand Up @@ -143,6 +146,7 @@ def _set_url(config_env: Literal["default", "dev", "prod", "local"] = "default")
return URL_ENV[_env]

@staticmethod
@requires_auth
def _setup_server(url_server: str):
"""set up web server access

Expand All @@ -152,14 +156,6 @@ def _setup_server(url_server: str):
URL for the server
"""

from ...web.auth import ( # pylint:disable=import-outside-toplevel, unused-import
get_credentials,
)
from ...web.httputils import ( # pylint:disable=import-outside-toplevel
get_headers,
)

# get_credentials()
access_token = get_headers()
headers = {"Authorization": access_token["Authorization"]}

Expand Down
5 changes: 1 addition & 4 deletions tidy3d/plugins/smatrix/smatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ...components.viz import add_ax_if_none, equal_aspect
from ...components.base import Tidy3dBaseModel
from ...log import SetupError
from ...web.container import Batch

# fwidth of gaussian pulse in units of central frequency
FWIDTH_FRAC = 1.0 / 10
Expand Down Expand Up @@ -172,10 +173,6 @@ def _run_sims(
self, sim_dict: Dict[str, Simulation], folder_name: str, path_dir: str
) -> "BatchData":
"""Run :class:`Simulations` for each port and return the batch after saving."""

# do it here as to not trigger web auth when importing the plugin
from ...web.container import Batch

batch = Batch(simulations=sim_dict, folder_name=folder_name)

batch.upload()
Expand Down
3 changes: 0 additions & 3 deletions tidy3d/web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,4 @@
from .webapi import run, upload, get_info, start, monitor, delete, download, load
from .webapi import get_tasks, delete_old
from .container import Job, Batch

from .auth import get_credentials

get_credentials()
65 changes: 47 additions & 18 deletions tidy3d/web/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@
import getpass
import hashlib
import json
import functools

import boto3
import requests

from .config import DEFAULT_CONFIG as Config
from ..log import log

# maximum attempts for credentials input
MAX_ATTEMPTS = 3

# where we store the credentials locally
CREDENTIAL_FILE = "~/.tidy3d/auth.json"
credential_path = os.path.expanduser(CREDENTIAL_FILE)
credential_dir = os.path.split(credential_path)[0]
if not os.path.exists(credential_dir):
os.mkdir(credential_dir)

boto3.setup_default_session(region_name=Config.s3_region)


def set_authentication_config(email: str, password: str) -> None:
"""Sets the authorization and keys in the config for a for user."""
os.environ["TIDY3D_USER"] = email
os.environ["TIDY3D_PASS_HASH"] = password

url = "/".join([Config.auth_api_endpoint, "auth"])
headers = {"Application": "TIDY3D"}
resp = requests.get(url, headers=headers, auth=(email, password))
Expand All @@ -44,25 +50,32 @@ def encode_password(password: str) -> str:
return hashlib.sha512(password.encode("utf-8") + salt.encode("utf-8")).hexdigest()


# pylint:disable=too-many-branches
def get_credentials() -> None:
"""Tries to log user in from environment variables, then from file, if not working, prompts
user for login info and saves to file."""

# if we find credentials in environment variables
if "TIDY3D_USER" in os.environ and "TIDY3D_PASS" in os.environ:
print("Using Tidy3D credentials from enviornment")
if "TIDY3D_USER" in os.environ and (
"TIDY3D_PASS" in os.environ or "TIDY3D_PASS_HASH" in os.environ
):
log.debug("Using Tidy3D credentials from enviornment")
email = os.environ["TIDY3D_USER"]
password = os.environ["TIDY3D_PASS"]
password = os.environ.get("TIDY3D_PASS")
if password is None:
password = os.environ.get("TIDY3D_PASS_HASH")
else:
password = encode_password(password)
try:
set_authentication_config(email, encode_password(password))
set_authentication_config(email, password)
return

except Exception: # pylint:disable=broad-except
print("Error: Failed to log in with environment credentials.")
log.info("Error: Failed to log in with environment credentials.")

# if we find something in the credential path
if os.path.exists(credential_path):
print("Using Tidy3D credentials from stored file")
log.info("Using Tidy3D credentials from stored file")
# try to authenticate them
try:
with open(credential_path, "r", encoding="utf-8") as fp:
Expand All @@ -73,10 +86,10 @@ def get_credentials() -> None:
return

except Exception: # pylint:disable=broad-except
print("Error: Failed to log in with saved credentials.")
log.info("Error: Failed to log in with saved credentials.")

# keep trying to log in
while True:
for _ in range(MAX_ATTEMPTS):

email = input("enter your email registered at tidy3d: ")
password = getpass.getpass("enter your password: ")
Expand All @@ -88,27 +101,43 @@ def get_credentials() -> None:
break

except Exception: # pylint:disable=broad-except
print("Error: Failed to log in with new username and password.")
log.info("Error: Failed to log in with new username and password.")

# ask to stay logged in
while True:
for _ in range(MAX_ATTEMPTS):

keep_logged_in = input("Do you want to keep logged in on this machine? ([Y]es / [N]o) ")

# if user wants to stay logged in
if keep_logged_in.lower() == "y":

auth_json = {"email": email, "password": password}
with open(credential_path, "w", encoding="utf-8") as fp:
json.dump(auth_json, fp)
return
try:
if not os.path.exists(credential_dir):
os.mkdir(credential_dir)

auth_json = {"email": email, "password": password}
with open(credential_path, "w", encoding="utf-8") as fp:
json.dump(auth_json, fp)
return

except Exception: # pylint:disable=broad-except
log.info("Error: Failed to store credentials.")
return

# if doesn't want to keep logged in, just return without saving file
if keep_logged_in.lower() == "n":
return

# otherwise, prompt again
print(f"Unknown response: {keep_logged_in}")
log.info(f"Unknown response: {keep_logged_in}")


def requires_auth(func):
"""Decorator for functions that require the authentication step."""

@functools.wraps(func)
def auth_before_call(*args, **kwargs):
get_credentials()
return func(*args, **kwargs)

# get_credentials()
return auth_before_call
10 changes: 10 additions & 0 deletions tidy3d/web/webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .config import DEFAULT_CONFIG as Config
from .s3utils import get_s3_user, DownloadProgress
from .auth import requires_auth
from .task import TaskId, TaskInfo
from . import httputils as http
from ..components.simulation import Simulation
Expand Down Expand Up @@ -110,6 +111,7 @@ def upload(
return task_id


@requires_auth
def get_info(task_id: TaskId) -> TaskInfo:
"""Return information about a task.

Expand All @@ -130,6 +132,7 @@ def get_info(task_id: TaskId) -> TaskInfo:
return TaskInfo(**info_dict)


@requires_auth
def start(task_id: TaskId) -> None:
"""Start running the simulation associated with task.

Expand All @@ -149,6 +152,7 @@ def start(task_id: TaskId) -> None:
http.put(method, data=task.dict())


@requires_auth
def get_run_info(task_id: TaskId):
"""Gets the % done and field_decay for a running task.

Expand Down Expand Up @@ -177,6 +181,7 @@ def get_run_info(task_id: TaskId):
return None, None


@requires_auth
def monitor(task_id: TaskId) -> None:
"""Print the real time task progress until completion.

Expand Down Expand Up @@ -329,6 +334,7 @@ def load(
return sim_data


@requires_auth
def delete(task_id: TaskId) -> TaskInfo:
"""Delete server-side data associated with task.

Expand All @@ -347,6 +353,7 @@ def delete(task_id: TaskId) -> TaskInfo:
return http.delete(method)


@requires_auth
def delete_old(days_old: int = 100, folder: str = None) -> int:
"""Delete all tasks older than a given amount of days.

Expand Down Expand Up @@ -380,6 +387,7 @@ def delete_old(days_old: int = 100, folder: str = None) -> int:
return count


@requires_auth
def get_tasks(num_tasks: int = None, order: Literal["new", "old"] = "new") -> List[Dict]:
"""Get a list with the metadata of the last ``num_tasks`` tasks.

Expand Down Expand Up @@ -424,6 +432,7 @@ def get_tasks(num_tasks: int = None, order: Literal["new", "old"] = "new") -> Li
return out_dict


@requires_auth
def _upload_task( # pylint:disable=too-many-locals,too-many-arguments
simulation: Simulation,
task_name: str,
Expand Down Expand Up @@ -480,6 +489,7 @@ def _upload_task( # pylint:disable=too-many-locals,too-many-arguments
return task_id


@requires_auth
def _download_file(task_id: TaskId, fname: str, path: str) -> None:
"""Download a specific file from server.

Expand Down