Skip to content

Commit

Permalink
Merge branch 'update-default-logger' into log-level-updates
Browse files (browse the repository at this point in the history)
# Conflicts:
#	runhouse/servers/http/http_client.py
  • Loading branch information
jlewitt1 committed Aug 20, 2024
2 parents de13380 + 2cb2773 commit aa20a92
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 27 deletions.
2 changes: 1 addition & 1 deletion runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
CONDA_INSTALL_CMDS = [
"wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh",
"bash ~/miniconda.sh -b -p ~/miniconda",
"source $HOME/miniconda3/bin/activate",
"source $HOME/miniconda/bin/activate",
]

TEST_ORG = "test-org"
Expand Down
17 changes: 9 additions & 8 deletions runhouse/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import logging.config
import os
import sys

from runhouse.constants import DEFAULT_LOG_LEVEL

Expand All @@ -13,13 +13,14 @@ def get_logger(name: str = __name__):
# Set the logging level
logger.setLevel(level)

if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
fmt="%(levelname)s | %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
# Clear existing handlers to prevent duplicate logs
logger.handlers.clear()
handler = logging.StreamHandler(stream=sys.stdout)
formatter = logging.Formatter(
fmt="%(levelname)s | %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
handler.setFormatter(formatter)
logger.addHandler(handler)

# Prevent the logger from propagating to the root logger
logger.propagate = False
Expand Down
2 changes: 1 addition & 1 deletion runhouse/resources/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _full_command(self, command: str):
def _run_command(self, command: str, **kwargs):
"""Run command locally inside the environment"""
command = self._full_command(command)
logging.info(f"Running command in {self.name}: {command}")
logger.info(f"Running command in {self.name}: {command}")
return run_with_logs(command, **kwargs)

def to(
Expand Down
14 changes: 5 additions & 9 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
ObjStoreError,
RaySetupOption,
)
from runhouse.utils import generate_default_name, sync_function
from runhouse.utils import sync_function

app = FastAPI(docs_url=None, redoc_url=None)

Expand Down Expand Up @@ -321,11 +321,8 @@ async def _call(
params = params or CallParams()

try:
params.run_name = params.run_name or generate_default_name(
prefix=key if method_name == "__call__" else f"{key}_{method_name}",
precision="ms", # Higher precision because we see collisions within the same second
sep="@",
)
if not params.run_name:
raise ValueError("run_name is required for all calls.")
# Call async so we can loop to collect logs until the result is ready

fut = asyncio.create_task(
Expand All @@ -345,7 +342,6 @@ async def _call(

return StreamingResponse(
HTTPServer._get_results_and_logs_generator(
key,
fut=fut,
run_name=params.run_name,
serialization=params.serialization,
Expand Down Expand Up @@ -497,7 +493,7 @@ def open_new_logfiles(key, open_files):
return open_files

@staticmethod
async def _get_results_and_logs_generator(key, fut, run_name, serialization=None):
async def _get_results_and_logs_generator(fut, run_name, serialization=None):
logger.debug(f"Streaming logs for key {run_name}")
open_logfiles = []
waiting_for_results = True
Expand Down Expand Up @@ -555,7 +551,7 @@ async def _get_results_and_logs_generator(key, fut, run_name, serialization=None
)
finally:
if not open_logfiles:
logger.warning(f"No logfiles found for call {key}")
logger.warning(f"No logfiles found for call {run_name}")
for f in open_logfiles:
f.close()

Expand Down
2 changes: 1 addition & 1 deletion runhouse/servers/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class ServerSettings(BaseModel):


class CallParams(BaseModel):
run_name: str
data: Any = None
serialization: Optional[str] = "none"
run_name: Optional[str] = None
stream_logs: Optional[bool] = False
save: Optional[bool] = False
remote: Optional[bool] = False
Expand Down
9 changes: 5 additions & 4 deletions runhouse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,11 @@ def run_command_with_password_login(
####################################################################################################


def _thread_coroutine(coroutine, context):
def thread_coroutine(coroutine, context=None):
# Copy contextvars from the parent thread to the new thread
for var, value in context.items():
var.set(value)
if context is not None:
for var, value in context.items():
var.set(value)

# Technically, event loop logic is not threadsafe. However, this event loop is only in this thread.
loop = asyncio.new_event_loop()
Expand All @@ -272,7 +273,7 @@ def wrapper(*args, **kwargs):
# and the resources are cleaned up
with ThreadPoolExecutor() as executor:
future = executor.submit(
_thread_coroutine,
thread_coroutine,
coroutine_func(*args, **kwargs),
contextvars.copy_context(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def test_http_url(self, cluster):
"kwargs": {},
},
"serialization": None,
"run_name": "test_http_url",
},
headers=rns_client.request_headers(cluster.rns_address)
if cluster.den_auth
Expand All @@ -383,6 +384,7 @@ def test_http_url(self, cluster):
"kwargs": {"a": 1, "b": 2},
},
"serialization": None,
"run_name": "test_http_url",
},
headers=rns_client.request_headers(cluster.rns_address)
if cluster.den_auth
Expand Down
13 changes: 10 additions & 3 deletions tests/test_servers/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,13 @@ def test_call_module_method(self, mocker):
# Call the method under test
method_name = "install"
module_name = EMPTY_DEFAULT_ENV_NAME

# Need to specify the run_name to avoid generating a unique one that contains the timestamp
result = self.client.call(
module_name, method_name, resource_address=self.local_cluster.rns_address
module_name,
method_name,
resource_address=self.local_cluster.rns_address,
run_name="test_run_name",
)

assert result == "final_result"
Expand All @@ -167,7 +172,7 @@ def test_call_module_method(self, mocker):
expected_json_data = {
"data": None,
"serialization": "pickle",
"run_name": None,
"run_name": "test_run_name",
"stream_logs": True,
"save": False,
"remote": False,
Expand Down Expand Up @@ -202,18 +207,20 @@ def test_call_module_method_with_args_kwargs(self, mocker):
module_name = "module"
method_name = "install"

# Need to specify the run_name to avoid generating a unique one that contains the timestamp
self.client.call(
module_name,
method_name,
data=data,
resource_address=self.local_cluster.rns_address,
run_name="test_run_name",
)

# Assert that the post request was called with the correct data
expected_json_data = {
"data": serialize_data(data, "pickle"),
"serialization": "pickle",
"run_name": None,
"run_name": "test_run_name",
"stream_logs": True,
"save": False,
"remote": False,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_servers/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_call_module_method(self, http_client, remote_func):
"data": serialize_data(data, "pickle"),
"stream_logs": True,
"serialization": "pickle",
"run_name": "test_call_module_method",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand All @@ -163,6 +164,7 @@ def test_call_module_method_get_call(self, http_client, remote_func):
params={
"a": 1,
"b": 2,
"run_name": "test_call_module_method_get_call",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand Down Expand Up @@ -192,6 +194,7 @@ def test_log_streaming_call(self, http_client, remote_log_streaming_func):
"data": serialize_data(data, "pickle"),
"stream_logs": True,
"serialization": "pickle",
"run_name": "test_log_streaming_call",
},
headers=rns_client.request_headers(clus.rns_address),
) as r:
Expand Down Expand Up @@ -447,6 +450,7 @@ async def test_async_call(self, async_http_client, remote_func):
"kwargs": {},
},
"serialization": None,
"run_name": "test_async_call",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand All @@ -468,6 +472,7 @@ async def test_async_call_with_invalid_serialization(
"kwargs": {},
},
"serialization": "random",
"run_name": "test_async_call_with_invalid_serialization",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand All @@ -492,6 +497,7 @@ async def test_async_call_with_pickle_serialization(
"pickle",
),
"serialization": "pickle",
"run_name": "test_async_call_with_pickle_serialization",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand All @@ -518,6 +524,7 @@ async def test_async_call_with_json_serialization(
}
),
"serialization": "json",
"run_name": "test_async_call_with_json_serialization",
},
headers=rns_client.request_headers(remote_func.system.rns_address),
)
Expand Down Expand Up @@ -661,6 +668,7 @@ def test_call_module_method_with_invalid_token(self, http_client, remote_func):
"data": {"args": args, "kwargs": kwargs},
"stream_logs": False,
"serialization": None,
"run_name": "test_call_module_method_with_invalid_token",
},
headers=INVALID_HEADERS,
)
Expand Down

0 comments on commit aa20a92

Please sign in to comment.