Skip to content

Commit

Permalink
[REF-843] Automatically update api_url and deploy_url (#1954)
Browse files Browse the repository at this point in the history
  • Loading branch information
masenf authored Oct 13, 2023
1 parent d0cb5b0 commit 684912e
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 11 deletions.
23 changes: 21 additions & 2 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ export const getToken = () => {
return token;
};

/**
* Get the URL for the websocket connection
* @returns The websocket URL object.
*/
export const getEventURL = () => {
// Get backend URL object from the endpoint.
const endpoint = new URL(EVENTURL);
if (endpoint.hostname === "localhost") {
// If the backend URL references localhost, and the frontend is not on localhost,
// then use the frontend host.
const frontend_hostname = window.location.hostname;
if (frontend_hostname !== "localhost") {
endpoint.hostname = frontend_hostname;
}
}
return endpoint
}

/**
* Apply a delta to the state.
* @param state The state to apply the delta to.
Expand Down Expand Up @@ -289,9 +307,10 @@ export const connect = async (
client_storage = {},
) => {
// Get backend URL object from the endpoint.
const endpoint = new URL(EVENTURL);
const endpoint = getEventURL()

// Create the socket.
socket.current = io(EVENTURL, {
socket.current = io(endpoint.href, {
path: endpoint["pathname"],
transports: transports,
autoUnref: false,
Expand Down
28 changes: 23 additions & 5 deletions reflex/components/overlay/banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from typing import Optional

from reflex.components.base.bare import Bare
from reflex.components.component import Component
from reflex.components.layout import Box, Cond
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
from reflex.vars import Var
from reflex.utils import imports
from reflex.vars import ImportVar, Var

connection_error: Var = Var.create_safe(
value="(connectError !== null) ? connectError.message : ''",
Expand All @@ -21,19 +23,35 @@
has_connection_error.type_ = bool


def default_connection_error() -> list[str | Var]:
class WebsocketTargetURL(Bare):
"""A component that renders the websocket target URL."""

def _get_imports(self) -> imports.ImportDict:
return {
"/utils/state.js": {ImportVar(tag="getEventURL")},
}

@classmethod
def create(cls) -> Component:
"""Create a websocket target URL component.
Returns:
The websocket target URL component.
"""
return super().create(contents="{getEventURL().href}")


def default_connection_error() -> list[str | Var | Component]:
"""Get the default connection error message.
Returns:
The default connection error message.
"""
from reflex.config import get_config

return [
"Cannot connect to server: ",
connection_error,
". Check if server is reachable at ",
get_config().api_url or "<API_URL not set>",
WebsocketTargetURL.create(),
]


Expand Down
53 changes: 49 additions & 4 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import os
import sys
import urllib.parse
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

import pydantic

from reflex import constants
from reflex.base import Base
Expand Down Expand Up @@ -191,6 +193,9 @@ class Config:
# The username.
username: Optional[str] = None

# Attributes that were explicitly set by the user.
_non_default_attributes: Set[str] = pydantic.PrivateAttr(set())

def __init__(self, *args, **kwargs):
"""Initialize the config values.
Expand All @@ -204,7 +209,14 @@ def __init__(self, *args, **kwargs):
self.check_deprecated_values(**kwargs)

# Update the config from environment variables.
self.update_from_env()
env_kwargs = self.update_from_env()
for key, env_value in env_kwargs.items():
setattr(self, key, env_value)

# Update default URLs if ports were set
kwargs.update(env_kwargs)
self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs)

@staticmethod
def check_deprecated_values(**kwargs):
Expand All @@ -227,13 +239,16 @@ def check_deprecated_values(**kwargs):
"env_path is deprecated - use environment variables instead"
)

def update_from_env(self):
def update_from_env(self) -> dict[str, Any]:
"""Update the config from environment variables.
Returns:
The updated config values.
Raises:
ValueError: If an environment variable is set to an invalid type.
"""
updated_values = {}
# Iterate over the fields.
for key, field in self.__fields__.items():
# The env var name is the key in uppercase.
Expand All @@ -260,7 +275,9 @@ def update_from_env(self):
raise

# Set the value.
setattr(self, key, env_var)
updated_values[key] = env_var

return updated_values

def get_event_namespace(self) -> str | None:
"""Get the websocket event namespace.
Expand All @@ -274,6 +291,34 @@ def get_event_namespace(self) -> str | None:
event_url = constants.Endpoint.EVENT.get_url()
return urllib.parse.urlsplit(event_url).path

def _replace_defaults(self, **kwargs):
"""Replace formatted defaults when the caller provides updates.
Args:
**kwargs: The kwargs passed to the config or from the env.
"""
if "api_url" not in self._non_default_attributes and "backend_port" in kwargs:
self.api_url = f"http://localhost:{kwargs['backend_port']}"

if (
"deploy_url" not in self._non_default_attributes
and "frontend_port" in kwargs
):
self.deploy_url = f"http://localhost:{kwargs['frontend_port']}"

def _set_persistent(self, **kwargs):
"""Set values in this config and in the environment so they persist into subprocess.
Args:
**kwargs: The kwargs passed to the config.
"""
for key, value in kwargs.items():
if value is not None:
os.environ[key.upper()] = str(value)
setattr(self, key, value)
self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs)


def get_config(reload: bool = False) -> Config:
"""Get the app config.
Expand Down
1 change: 1 addition & 0 deletions reflex/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,6 @@ class Config(Base):
def check_deprecated_values(**kwargs) -> None: ...
def update_from_env(self) -> None: ...
def get_event_namespace(self) -> str | None: ...
def _set_persistent(self, **kwargs) -> None: ...

def get_config(reload: bool = ...) -> Config: ...
6 changes: 6 additions & 0 deletions reflex/reflex.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def run(
if backend and processes.is_process_on_port(backend_port):
backend_port = processes.change_or_terminate_port(backend_port, "backend")

# Apply the new ports to the config.
if frontend_port != str(config.frontend_port):
config._set_persistent(frontend_port=frontend_port)
if backend_port != str(config.backend_port):
config._set_persistent(backend_port=backend_port)

console.rule("[bold]Starting Reflex App")

if frontend:
Expand Down
92 changes: 92 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,95 @@ def test_event_namespace(mocker, kwargs, expected):
config = reflex.config.get_config()
assert conf == config
assert config.get_event_namespace() == expected


DEFAULT_CONFIG = rx.Config(app_name="a")


@pytest.mark.parametrize(
("config_kwargs", "env_vars", "set_persistent_vars", "exp_config_values"),
[
(
{},
{},
{},
{
"api_url": DEFAULT_CONFIG.api_url,
"backend_port": DEFAULT_CONFIG.backend_port,
"deploy_url": DEFAULT_CONFIG.deploy_url,
"frontend_port": DEFAULT_CONFIG.frontend_port,
},
),
# Ports set in config kwargs
(
{"backend_port": 8001, "frontend_port": 3001},
{},
{},
{
"api_url": "http://localhost:8001",
"backend_port": 8001,
"deploy_url": "http://localhost:3001",
"frontend_port": 3001,
},
),
# Ports set in environment take precendence
(
{"backend_port": 8001, "frontend_port": 3001},
{"BACKEND_PORT": 8002},
{},
{
"api_url": "http://localhost:8002",
"backend_port": 8002,
"deploy_url": "http://localhost:3001",
"frontend_port": 3001,
},
),
# Ports set on the command line take precendence
(
{"backend_port": 8001, "frontend_port": 3001},
{"BACKEND_PORT": 8002},
{"frontend_port": "3005"},
{
"api_url": "http://localhost:8002",
"backend_port": 8002,
"deploy_url": "http://localhost:3005",
"frontend_port": 3005,
},
),
# api_url / deploy_url already set should not be overridden
(
{"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"},
{"BACKEND_PORT": 8002},
{"frontend_port": "3005"},
{
"api_url": "http://foo.bar:8900",
"backend_port": 8002,
"deploy_url": "http://foo.bar:3001",
"frontend_port": 3005,
},
),
],
)
def test_replace_defaults(
monkeypatch,
config_kwargs,
env_vars,
set_persistent_vars,
exp_config_values,
):
"""Test that the config replaces defaults with values from the environment.
Args:
monkeypatch: The pytest monkeypatch object.
config_kwargs: The config kwargs.
env_vars: The environment variables.
set_persistent_vars: The values passed to config._set_persistent variables.
exp_config_values: The expected config values.
"""
mock_os_env = os.environ.copy()
monkeypatch.setattr(reflex.config.os, "environ", mock_os_env) # type: ignore
mock_os_env.update({k: str(v) for k, v in env_vars.items()})
c = rx.Config(app_name="a", **config_kwargs)
c._set_persistent(**set_persistent_vars)
for key, value in exp_config_values.items():
assert getattr(c, key) == value

0 comments on commit 684912e

Please sign in to comment.