Add complete docstrings and typehinting to aepsych python client (#664)
Summary:
Pull Request resolved: #664

Updated the aepsych Python client so its docstrings and type hints match the rest of the aepsych documentation.

Differential Revision: D70426277
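
As a rough end-to-end sketch of the updated client (the constructor arguments, server address, config path, and the run_trial helper below are illustrative assumptions, not part of this commit):

from aepsych_client import AEPsychClient

client = AEPsychClient(ip="127.0.0.1", port=5555)       # assumes a running AEPsych server
client.configure(config_path="configs/experiment.ini")  # returns {"strat_id": ...}

while True:
    response = client.ask()              # {"config": ..., "is_finished": ..., "num_points": ...}
    if response["is_finished"]:
        break
    trial_config = response["config"]
    outcome = run_trial(trial_config)    # hypothetical experiment code
    client.tell(config=trial_config, outcome=outcome)

client.finalize()                        # {"config": "Terminate", "is_finished": True}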
JasonKChow authored and facebook-github-bot committed Mar 1, 2025
1 parent 9e094a7 commit f3afebf
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions clients/python/aepsych_client/client.py
@@ -7,9 +7,10 @@
 import json
 import socket
 import warnings
-from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union

 if TYPE_CHECKING:
+    import torch
     from aepsych.server import AEPsychServer


@@ -56,7 +57,7 @@ def __init__(
         self.socket = None

     def load_config_index(self) -> None:
-        """Loads the config index when server is not None"""
+        """Loads the config index when there is an in-memory server."""
         if self.server is None:
             raise AttributeError("there is no in-memory server")

@@ -78,7 +79,15 @@ def connect(self, ip: str, port: int) -> None:
         self.socket.connect(addr)

     def finalize(self) -> Dict[str, Any]:
-        """Let the server know experiment is complete."""
+        """Let the server know experiment is complete and stop the server.
+
+        Returns:
+            Dict[str, Any]: A dictionary with two entries:
+                - "config": dictionary with config (keys are strings, values are floats).
+                    Currently always "Terminate" if this function succeeds.
+                - "is_finished": boolean, true if the strat is finished. Currently always
+                    true if this function succeeds.
+        """
         request = {"message": "", "type": "exit"}
         return self._send_recv(request)

@@ -110,11 +119,11 @@ def ask(
             num_points[int]: Number of points to return.

         Returns:
-            Dict[int, Dict[str, Any]]: Next configuration(s) to evaluate.
-                If using the legacy backend, this is formatted as a dictionary where keys are parameter names and values
-                are lists of parameter values.
-                If using the Ax backend, this is formatted as a dictionary of dictionaries where the outer keys are trial indices,
-                the inner keys are parameter names, and the values are parameter values.
+            Dict[str, Any]: A dictionary with three entries
+                - "config": dictionary with config (keys are strings, values are floats), None
+                    if skipping computations during replay.
+                - "is_finished": boolean, true if the strat is finished
+                - "num_points": integer, number of points returned.
         """
         request = {"message": {"num_points": num_points}, "type": "ask"}
         response = self._send_recv(request)
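
A minimal sketch of consuming the return value documented above (assumes client is an already-configured AEPsychClient instance):

response = client.ask(num_points=2)
print(response["num_points"])          # integer, number of points returned
if not response["is_finished"]:
    suggested = response["config"]     # parameter names mapped to suggested values
    print(suggested)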
@@ -143,9 +152,6 @@ def tell(
                 - "trials_recorded": integer, the number of trials recorded in the
                     database.
                 - "model_data_added": integer, the number of datapoints added to the model.
-
-        Raises:
-            AssertionError if server failed to acknowledge the tell.
         """
         message = {
             "config": config,
@@ -177,9 +183,6 @@ def configure(
         Returns:
             Dict[str, Any]: A dictionary with one entry
                 - "strat_id": integer, the strategy ID for what was just set up.
-
-        Raises:
-            AssertionError if neither config_path nor config_str is passed.
         """

         if config_path is not None:
@@ -214,9 +217,6 @@ def resume(
         Returns:
             Dict[str, Any]: A dictionary with one entry
                 - "strat_id": integer, the strategy ID that was resumed.
-
-        Raises:
-            AssertionError if name or ID does not exist, or if both name and ID are passed.
         """
         if config_id is not None:
             assert config_name is None, "if config_id is passed, don't pass config_name"
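
A minimal sketch of switching strategies with resume, using the strat_id values that configure is documented to return (assumes client is a connected AEPsychClient; the config paths are placeholders):

first = client.configure(config_path="configs/detection.ini")["strat_id"]
second = client.configure(config_path="configs/discrimination.ini")["strat_id"]

resumed = client.resume(config_id=first)   # pass config_id or config_name, not both
assert resumed["strat_id"] == first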
@@ -238,11 +238,11 @@

     def query(
         self,
-        query_type="max",
-        probability_space=False,
-        x=None,
-        y=None,
-        constraints=None,
+        query_type: Literal["max", "min", "prediction", "inverse"] = "max",
+        probability_space: bool = False,
+        x: Optional[Dict[str, Any]] = None,
+        y: Optional[Union[float, "torch.Tensor"]] = None,
+        constraints: Optional[Dict[int, float]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         """Queries the underlying model for a specific query.
@@ -284,5 +284,5 @@ def query(

         return self._send_recv(request)

-    def __del___(self):
+    def __del__(self):
         self.finalize()
