Skip to content

Commit

Permalink
enforce write only access to folder http endpoint (#1109)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 authored Aug 12, 2024
1 parent 6c27589 commit ec2a418
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 4 additions & 0 deletions runhouse/servers/http/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def clear_cache(self, token: str = None):
async def averify_cluster_access(
cluster_uri: str,
token: str,
access_level_required: ResourceAccess = None,
) -> bool:
"""Checks whether the user has access to the cluster.
Note: A user with write access to the cluster or a cluster owner will have access to all other resources on
Expand All @@ -94,4 +95,7 @@ async def averify_cluster_access(

cluster_access_level = await obj_store.aresource_access_level(token, cluster_uri)

if access_level_required is not None:
return cluster_access_level == access_level_required

return cluster_access_level in [ResourceAccess.WRITE, ResourceAccess.READ]
11 changes: 9 additions & 2 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import logger
from runhouse.rns.utils.api import resolve_absolute_path
from runhouse.rns.utils.api import resolve_absolute_path, ResourceAccess
from runhouse.rns.utils.names import _generate_default_name
from runhouse.servers.caddy.config import CaddyConfig
from runhouse.servers.http.auth import averify_cluster_access
Expand Down Expand Up @@ -80,6 +80,11 @@ async def wrapper(*args, **kwargs):
is_coro = inspect.iscoroutinefunction(func)

func_call: bool = func.__name__ in ["post_call", "get_call"]

# restrict access for folder specific APIs
access_level_required = (
ResourceAccess.WRITE if func.__name__.startswith("folder") else None
)
token = get_token_from_request(request)

request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
Expand All @@ -95,7 +100,9 @@ async def wrapper(*args, **kwargs):
"provide a valid token in the Authorization header.",
)
cluster_uri = (await obj_store.aget_cluster_config()).get("name")
cluster_access = await averify_cluster_access(cluster_uri, token)
cluster_access = await averify_cluster_access(
cluster_uri, token, access_level_required
)
if not cluster_access:
# Must have cluster access for all the non func calls
# Note: for func calls we handle the auth in the object store
Expand Down

0 comments on commit ec2a418

Please sign in to comment.