Skip to content

Commit

Permalink
[RLlib] Use requests.Session object to reuse connections and use reso…
Browse files Browse the repository at this point in the history
…urce more efficiently in PolicyClient (#33035)

Signed-off-by: mattias <mattias.decharleroy@gmail.com>
  • Loading branch information
MattiasDC authored May 21, 2023
1 parent 0a1f435 commit 323d9d5
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions rllib/env/policy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ class PolicyClient:

@PublicAPI
def __init__(
self, address: str, inference_mode: str = "local", update_interval: float = 10.0
self,
address: str,
inference_mode: str = "local",
update_interval: float = 10.0,
session: Optional[requests.Session] = None,
):
"""Create a PolicyClient instance.
Expand All @@ -71,8 +75,13 @@ def __init__(
update_interval (float or None): If using 'local' inference mode,
the policy is refreshed after this many seconds have passed,
or None for manual control via client.
session (requests.Session or None): If available the session object
is used to communicate with the policy server. Using a session
can lead to speedups as connections are reused. It is the
responsibility of the creator of the session to close it.
"""
self.address = address
self.session = session
self.env: ExternalEnv = None
if inference_mode == "local":
self.local = True
Expand Down Expand Up @@ -241,7 +250,12 @@ def update_policy_weights(self) -> None:

def _send(self, data):
payload = pickle.dumps(data)
response = requests.post(self.address, data=payload)

if self.session is None:
response = requests.post(self.address, data=payload)
else:
response = self.session.post(self.address, data=payload)

if response.status_code != 200:
logger.error("Request failed {}: {}".format(response.text, data))
response.raise_for_status()
Expand Down

0 comments on commit 323d9d5

Please sign in to comment.