Add complete docstrings and typehinting to aepsych python client
Summary: Updated the aepsych python client to match the rest of the aepsych documentation/type hinting.

Differential Revision: D70426277
JasonKChow authored and facebook-github-bot committed Mar 1, 2025
1 parent 9e094a7 commit 153d8e7
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions clients/python/aepsych_client/client.py
@@ -7,10 +7,11 @@
import json
import socket
import warnings
-from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, Literal

if TYPE_CHECKING:
from aepsych.server import AEPsychServer
+import torch


class ServerError(RuntimeError):
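The new `import torch` line sits under the `if TYPE_CHECKING:` guard, so torch is only needed by static type checkers and the string annotation `"torch.Tensor"` used later in `query` never triggers a runtime import. A minimal sketch of that pattern (the function and names below are illustrative, not from this diff):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at runtime,
    # so torch stays an optional dependency for plain client usage.
    import torch


def to_list(y: "torch.Tensor") -> list:
    # The quoted annotation is a forward reference; nothing is imported when
    # this module is loaded.
    return y.tolist()
```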
@@ -56,7 +57,7 @@ def __init__(
self.socket = None

def load_config_index(self) -> None:
"""Loads the config index when server is not None"""
"""Loads the config index when there is an in-memory server."""
if self.server is None:
raise AttributeError("there is no in-memory server")

@@ -78,7 +79,15 @@ def connect(self, ip: str, port: int) -> None:
self.socket.connect(addr)

def finalize(self) -> Dict[str, Any]:
"""Let the server know experiment is complete."""
"""Let the server know experiment is complete and stop the server.
Returns:
Dict[str, Any]: A dictionary with two entries:
- "config": dictionary with config (keys are strings, values are floats).
Currently always "Terminate" if this function succeeds.
- "is_finished": boolean, true if the strat is finished. Currently always
true if this function succeeds.
"""
request = {"message": "", "type": "exit"}
return self._send_recv(request)
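The expanded docstring says `finalize` sends an `exit` message and reports whether the strategy finished. A minimal usage sketch, assuming the client class is `AEPsychClient` and a server is already listening at the given address:

```python
from aepsych_client import AEPsychClient

client = AEPsychClient(ip="127.0.0.1", port=5555)
# ... run the experiment ...
result = client.finalize()        # sends {"message": "", "type": "exit"}
assert result["is_finished"]      # True when the strategy wrapped up cleanly
print(result["config"])           # "Terminate" on success, per the docstring above
```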

Expand Down Expand Up @@ -110,11 +119,11 @@ def ask(
num_points[int]: Number of points to return.
Returns:
-Dict[int, Dict[str, Any]]: Next configuration(s) to evaluate.
-If using the legacy backend, this is formatted as a dictionary where keys are parameter names and values
-are lists of parameter values.
-If using the Ax backend, this is formatted as a dictionary of dictionaries where the outer keys are trial indices,
-the inner keys are parameter names, and the values are parameter values.
+Dict[str, Any]: A dictionary with three entries
+- "config": dictionary with config (keys are strings, values are floats), None
+if skipping computations during replay.
+- "is_finished": boolean, true if the strat is finished
+- "num_points": integer, number of points returned.
"""
request = {"message": {"num_points": num_points}, "type": "ask"}
response = self._send_recv(request)
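The new `ask` docstring describes a flat response with "config", "is_finished", and "num_points". A hedged sketch of consuming it, assuming the same `AEPsychClient` construction as in the earlier sketches:

```python
from aepsych_client import AEPsychClient

client = AEPsychClient(ip="127.0.0.1", port=5555)
response = client.ask(num_points=2)  # sends {"message": {"num_points": 2}, "type": "ask"}

if not response["is_finished"]:
    next_points = response["config"]          # parameter name -> value(s) for the next trial(s)
    print(response["num_points"], next_points)
```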
@@ -143,9 +152,6 @@ def tell(
- "trials_recorded": integer, the number of trials recorded in the
database.
- "model_data_added": integer, the number of datapoints added to the model.
-Raises:
-AssertionError if server failed to acknowledge the tell.
"""
message = {
"config": config,
@@ -177,19 +183,16 @@ def configure(
Returns:
Dict[str, Any]: A dictionary with one entry
- "strat_id": integer, the stategy ID for what was just set up.
-Raises:
-AssertionError if neither config path nor config_str is passed.
"""

if config_path is not None:
assert config_str is None, "if config_path is passed, don't pass config_str"
with open(config_path, "r") as f:
config_str = f.read()
elif config_str is not None:
-assert (
-config_path is None
-), "if config_str is passed, don't pass config_path"
+assert config_path is None, (
+"if config_str is passed, don't pass config_path"
+)
request = {
"type": "setup",
"message": {"config_str": config_str},
@@ -214,20 +217,17 @@ def resume(
Returns:
Dict[str, Any]: A dictionary with one entry
- "strat_id": integer, the stategy ID that was resumed.
-Raises:
-AssertionError if name or ID does not exist, or if both name and ID are passed.
"""
if config_id is not None:
assert config_name is None, "if config_id is passed, don't pass config_name"
-assert (
-config_id in self.configs
-), f"No strat with index {config_id} was created!"
+assert config_id in self.configs, (
+f"No strat with index {config_id} was created!"
+)
elif config_name is not None:
assert config_id is None, "if config_name is passed, don't pass config_id"
-assert (
-config_name in self.config_names.keys()
-), f"{config_name} not known, know {self.config_names.keys()}!"
+assert config_name in self.config_names.keys(), (
+f"{config_name} not known, know {self.config_names.keys()}!"
+)
config_id = self.config_names[config_name]
request = {
"type": "resume",
@@ -238,11 +238,11 @@

def query(
self,
query_type="max",
probability_space=False,
x=None,
y=None,
constraints=None,
query_type: Literal["max", "min", "prediction", "inverse"] = "max",
probability_space: bool = False,
x: Optional[Dict[str, Any]] = None,
y: Optional[Union[float, "torch.Tensor"]] = None,
constraints: Optional[Dict[int, float]] = None,
**kwargs,
) -> Dict[str, Any]:
"""Queries the underlying model for a specific query.
@@ -284,5 +284,5 @@ def query(

return self._send_recv(request)

-def __del___(self):
+def __del__(self):
self.finalize()
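With the name fixed from `__del___` to `__del__`, Python will now actually call this hook on garbage collection and the client will try to finalize itself; calling `finalize()` explicitly remains the predictable path. A brief sketch:

```python
from aepsych_client import AEPsychClient

client = AEPsychClient(ip="127.0.0.1", port=5555)
try:
    pass  # ask / tell loop
finally:
    client.finalize()  # explicit shutdown; don't rely on __del__ timing
```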
