Bugfix/SK-1193 | Handle unknown error, reconnect channel (#743)
Wrede authored Nov 19, 2024
1 parent 58ec837 commit 3e1dcd7
Showing 2 changed files with 261 additions and 18 deletions.
210 changes: 210 additions & 0 deletions .ci/tests/chaos_test.py
@@ -0,0 +1,210 @@
from toxiproxy import Toxiproxy
import unittest
import grpc
import time
from fedn.network.clients.grpc_handler import GrpcHandler
import fedn.network.grpc.fedn_pb2 as fedn
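# Note: these chaos tests assume (hedged, inferred from the code below) a
# Toxiproxy server reachable at the toxiproxy client's default API address
# (localhost:8474) and a FEDn combiner listening on localhost:12080; each
# test creates a proxy on a 1208x port in front of the combiner.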


class TestGRPCWithToxiproxy(unittest.TestCase):
@classmethod
def setUpClass(cls):
client_name = 'test-client'
client_id = 'test-client-id'
host = 'localhost'
port_proxy = 12081
port_server = 12080
token = ""
combiner_name = 'combiner'

cls.toxiproxy = Toxiproxy()
if cls.toxiproxy.proxies():
cls.toxiproxy.destroy_all()

@classmethod
def tearDownClass(cls):
# Close the proxy and gRPC channel when done
cls.toxiproxy.destroy_all()

@unittest.skip("Not implemented")
def test_normal_heartbeat(self):
# Test the heartbeat without any toxic
client_name = 'test-client'
client_id = 'test-client-id'
# Connect directly to the combiner port (no proxy in between)
grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')
try:
response = grpc_handler.heartbeat(client_name, client_id)
self.assertIsInstance(response, fedn.Response)
except grpc.RpcError as e:
self.fail(f'gRPC error: {e.code()} {e.details()}')
finally:
grpc_handler.channel.close()

@unittest.skip("Not implemented")
def test_latency_2s_toxic_heartbeat(self):
# Add latency of 2000 ms (2 s)
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12082', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12082, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='latency', type='latency', attributes={'latency': 2000})

start_time = time.time()
try:
response = grpc_handler.heartbeat(client_name, client_id)
finally:
grpc_handler.channel.close()
proxy.destroy()
end_time = time.time()

# Check that the latency delay is present
self.assertGreaterEqual(end_time - start_time, 2) # Expect at least a 2 s delay
self.assertIsInstance(response, fedn.Response)

def test_latency_long_toxic_heartbeat(self):
"""Test gRPC request with a simulated latency of 25s. Should timeout based on KEEPALIVE_TIMEOUT_MS (default set to 20000)."""
client_name = 'test-client'
client_id = 'test-client-id'
latency = 20 # 20 s latency

proxy = self.toxiproxy.create(name='test_latency_toxic_heartbeat', listen='localhost:12083', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12083, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='latency', type='latency', attributes={'latency': latency * 1000})

start_time = time.time()
try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()
end_time = time.time()

# Check that the latency delay is present
self.assertGreaterEqual(end_time - start_time, latency) # Expect at least `latency` seconds of delay
self.assertIsInstance(response, grpc.RpcError)
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12083: connection attempt timed out before receiving SETTINGS frame')
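# Note (sketch, not part of the original test): the 20 s cutoff asserted above
# comes from the client's gRPC keepalive settings. A channel with an explicit
# keepalive timeout would be created roughly like this; the option keys are
# standard gRPC channel arguments, and the values here are assumptions that
# mirror the default mentioned in the docstring:
#
#   channel = grpc.insecure_channel(
#       "localhost:12080",
#       options=[
#           ("grpc.keepalive_time_ms", 60000),     # assumed ping interval
#           ("grpc.keepalive_timeout_ms", 20000),  # KEEPALIVE_TIMEOUT_MS
#       ],
#   )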

def test_close_channel(self):
"""
Test closing the gRPC channel and trying to send a heartbeat.
Expect a ValueError to be raised.
"""

client_name = 'test-client'
client_id = 'test-client-id'

grpc_handler = GrpcHandler(host='localhost', port=12080, name=client_name, token='', combiner_name='combiner')

# Close the channel
grpc_handler._disconnect()

# Try to send heartbeat
with self.assertRaises(ValueError) as context:
response = grpc_handler.heartbeat(client_name, client_id)
self.assertEqual(str(context.exception), 'Cannot invoke RPC on closed channel!')


@unittest.skip("Not implemented")
def test_disconnect_toxic_heartbeat(self):
"""Test gRPC request with a simulated disconnection."""
# Add a timeout toxic to simulate network disconnection
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_disconnect_toxic_heartbeat', listen='localhost:12084', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12084, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 1000})

try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()

# Assert that the response is a gRPC error with status code UNAVAILABLE
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNAVAILABLE: ipv4:127.0.0.1:12084: Socket closed')

@unittest.skip("Not implemented")
def test_timeout_toxic_heartbeat(self):
"""Stops all data from getting through, and closes the connection after timeout. timeout is 0,
the connection won't close, and data will be delayed until the toxic is removed.
"""
# Add a timeout toxic to simulate network disconnection
client_name = 'test-client'
client_id = 'test-client-id'

proxy = self.toxiproxy.create(name='test_timeout_toxic_heartbeat', listen='localhost:12085', upstream='localhost:12080')
grpc_handler = GrpcHandler(host='localhost', port=12085, name=client_name, token='', combiner_name='combiner')
proxy.add_toxic(name='timeout', type='timeout', attributes={'timeout': 0})

try:
response = grpc_handler.heartbeat(client_name, client_id)
except grpc.RpcError as e:
response = e
finally:
grpc_handler.channel.close()
proxy.destroy()

# Assert that the response is a gRPC error with status code UNAVAILABLE
self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)
self.assertEqual(response.details(), 'failed to connect to all addresses; last error: UNKNOWN: ipv4:127.0.0.1:12085: connection attempt timed out before receiving SETTINGS frame')

@unittest.skip("Not implemented")
def test_rate_limit_toxic_heartbeat(self):
# Purpose: Limits the number of connections that can be established within a certain time frame.
# Toxic: rate_limit
# Use Case: Useful for testing how the client behaves under strict rate limits. For example, in Federated Learning,
# this could simulate constraints in networks with multiple clients trying to access the server.

# Add a rate limit toxic to the proxy
self.proxy.add_rate_limit(rate=1000)
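# Sketch of a fuller test (hedged: toxiproxy has no "rate_limit" toxic type;
# the closest built-ins are "bandwidth" and "limit_data", so this sketch uses
# "limit_data" to close the connection after 1 KB; port 12086 is an arbitrary
# example):
#
# proxy = self.toxiproxy.create(name='test_rate_limit', listen='localhost:12086', upstream='localhost:12080')
# grpc_handler = GrpcHandler(host='localhost', port=12086, name='test-client', token='', combiner_name='combiner')
# proxy.add_toxic(name='limit', type='limit_data', attributes={'bytes': 1000})
# try:
#     response = grpc_handler.heartbeat('test-client', 'test-client-id')
# except grpc.RpcError as e:
#     response = e
# finally:
#     grpc_handler.channel.close()
#     proxy.destroy()
# self.assertIsInstance(response, grpc.RpcError)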

@unittest.skip("Not implemented")
def test_bandwidth_toxic_heartbeat(self):
# Purpose: Limits the bandwidth of the connection.
# Toxic: bandwidth
# Use Case: Useful for testing how the client behaves under limited bandwidth. For example, in Federated Learning,
# this could simulate a slow network connection between the client and the server.

# Add a bandwidth toxic to the proxy
self.proxy.add_bandwidth(rate=1000) # 1 KB/s
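# Sketch of a fuller test (hedged: uses toxiproxy's real "bandwidth" toxic,
# whose `rate` attribute is already in KB/s, so rate=1 means 1 KB/s; port
# 12087 is an arbitrary example). A heartbeat message is small, so the call
# should still succeed, just more slowly:
#
# proxy = self.toxiproxy.create(name='test_bandwidth', listen='localhost:12087', upstream='localhost:12080')
# grpc_handler = GrpcHandler(host='localhost', port=12087, name='test-client', token='', combiner_name='combiner')
# proxy.add_toxic(name='bw', type='bandwidth', attributes={'rate': 1})  # 1 KB/s
# try:
#     response = grpc_handler.heartbeat('test-client', 'test-client-id')
#     self.assertIsInstance(response, fedn.Response)
# finally:
#     grpc_handler.channel.close()
#     proxy.destroy()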

@unittest.skip("Not implemented")
def test_connection_reset(self):
# Purpose: Immediately resets the connection, simulating an abrupt network drop.
# Toxic: reset_peer
# Use Case: This is helpful for testing error-handling logic on sudden network failures,
# ensuring the client retries appropriately or fails gracefully

# Add a connection_reset toxic to the proxy
self.proxy.add_reset()
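# Sketch of a fuller test (hedged: toxiproxy's toxic type is "reset_peer";
# with timeout=0 it resets the connection immediately; port 12088 is an
# arbitrary example). The client should surface an UNAVAILABLE RpcError
# rather than hang:
#
# proxy = self.toxiproxy.create(name='test_reset', listen='localhost:12088', upstream='localhost:12080')
# grpc_handler = GrpcHandler(host='localhost', port=12088, name='test-client', token='', combiner_name='combiner')
# proxy.add_toxic(name='reset', type='reset_peer', attributes={'timeout': 0})
# try:
#     response = grpc_handler.heartbeat('test-client', 'test-client-id')
# except grpc.RpcError as e:
#     response = e
# finally:
#     grpc_handler.channel.close()
#     proxy.destroy()
# self.assertEqual(response.code(), grpc.StatusCode.UNAVAILABLE)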

@unittest.skip("Not implemented")
def test_slow_close(self):
# Purpose: Simulates a slow closing of the connection.
# Toxic: slow_close
# Use Case: Useful for testing how the client behaves when the server closes the connection slowly.
# This can help ensure that the client handles slow network disconnections gracefully.

# Add a slow_close toxic to the proxy
self.proxy.add_slow_close(delay=1000) # Delay closing the connection by 1 second
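# Sketch of a fuller test (hedged: "slow_close" is a real toxic whose `delay`
# attribute is in milliseconds; it only takes effect when the upstream closes
# the connection, so a plain heartbeat should still succeed; port 12089 is an
# arbitrary example):
#
# proxy = self.toxiproxy.create(name='test_slow_close', listen='localhost:12089', upstream='localhost:12080')
# grpc_handler = GrpcHandler(host='localhost', port=12089, name='test-client', token='', combiner_name='combiner')
# proxy.add_toxic(name='slow', type='slow_close', attributes={'delay': 1000})
# try:
#     response = grpc_handler.heartbeat('test-client', 'test-client-id')
#     self.assertIsInstance(response, fedn.Response)
# finally:
#     grpc_handler.channel.close()
#     proxy.destroy()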

@unittest.skip("Not implemented")
def test_slicer(self):
# Purpose: Slices the data into smaller chunks.
# Toxic: slicer
# Use Case: Useful for testing how the client handles fragmented data.
# This can help ensure that the client can reassemble the data correctly and handle partial data gracefully.

# Add a slicer toxic to the proxy
self.proxy.add_slicer(average_size=1000, size_variation=100) # Slice data into chunks of 1 KB with 100 bytes variation
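# Sketch of a fuller test (hedged: "slicer" is a real toxic; average_size and
# size_variation are in bytes, delay is in microseconds; port 12090 is an
# arbitrary example). gRPC/HTTP2 should reassemble the fragments
# transparently, so the heartbeat is expected to succeed:
#
# proxy = self.toxiproxy.create(name='test_slicer', listen='localhost:12090', upstream='localhost:12080')
# grpc_handler = GrpcHandler(host='localhost', port=12090, name='test-client', token='', combiner_name='combiner')
# proxy.add_toxic(name='slice', type='slicer', attributes={'average_size': 1000, 'size_variation': 100, 'delay': 50})
# try:
#     response = grpc_handler.heartbeat('test-client', 'test-client-id')
#     self.assertIsInstance(response, fedn.Response)
# finally:
#     grpc_handler.channel.close()
#     proxy.destroy()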
69 changes: 51 additions & 18 deletions fedn/network/clients/grpc_handler.py
@@ -65,16 +65,25 @@ def __init__(self, host: str, port: int, name: str, token: str, combiner_name: s
("client", name),
("grpc-server", combiner_name),
]
self.host = host
self.port = port
self.token = token

if port == 443:
self._init_secure_channel(host, port, token)
else:
self._init_insecure_channel(host, port)
self._init_channel(host, port, token)

self._init_stubs()

def _init_stubs(self):
self.connectorStub = rpc.ConnectorStub(self.channel)
self.combinerStub = rpc.CombinerStub(self.channel)
self.modelStub = rpc.ModelServiceStub(self.channel)

def _init_channel(self, host: str, port: int, token: str):
if port == 443:
self._init_secure_channel(host, port, token)
else:
self._init_insecure_channel(host, port)

def _init_secure_channel(self, host: str, port: int, token: str):
url = f"{host}:{port}"
logger.info(f"Connecting (GRPC) to {url}")
@@ -116,10 +125,10 @@ def heartbeat(self, client_name: str, client_id: str):
logger.info("Sending heartbeat to combiner")
response = self.connectorStub.SendHeartbeat(heartbeat, metadata=self.metadata)
except grpc.RpcError as e:
logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
raise e
except Exception as e:
logger.error(f"GRPC (SendHeartbeat): An error occurred: {e}")
self._disconnect()
raise e
return response

@@ -130,6 +139,8 @@ def send_heartbeats(self, client_name: str, client_id: str, update_frequency: fl
response = self.heartbeat(client_name, client_id)
except grpc.RpcError as e:
return self._handle_grpc_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
except Exception as e:
return self._handle_unknown_error(e, "SendHeartbeat", lambda: self.send_heartbeats(client_name, client_id, update_frequency))
if isinstance(response, fedn.Response):
logger.info("Heartbeat successful.")
else:
@@ -166,10 +177,11 @@ def listen_to_task_stream(self, client_name: str, client_id: str, callback: Call
callback(request)

except grpc.RpcError as e:
self.logger.error(f"GRPC (TaskStream): An error occurred: {e}")
return self._handle_grpc_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))
except Exception as e:
logger.error(f"GRPC (TaskStream): An error occurred: {e}")
self._disconnect()
self._handle_unknown_error(e, "TaskStream", lambda: self.listen_to_task_stream(client_name, client_id, callback))

def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=None, sesssion_id: str = None, sender_name: str = None):
"""Send status message.
@@ -204,7 +216,7 @@ def send_status(self, msg: str, log_level=fedn.Status.INFO, type=None, request=N
return self._handle_grpc_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))
except Exception as e:
logger.error(f"GRPC (SendStatus): An error occurred: {e}")
self._disconnect()
self._handle_unknown_error(e, "SendStatus", lambda: self.send_status(msg, log_level, type, request, sesssion_id, sender_name))

def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO:
"""Fetch a model from the assigned combiner.
@@ -241,8 +253,7 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) ->
return self._handle_grpc_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
except Exception as e:
logger.error(f"GRPC (Download): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "Download", lambda: self.get_model_from_combiner(id, client_id, timeout))
return data

def send_model_to_combiner(self, model: BytesIO, id: str):
@@ -273,8 +284,7 @@ def send_model_to_combiner(self, model: BytesIO, id: str):
return self._handle_grpc_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
except Exception as e:
logger.error(f"GRPC (Upload): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "Upload", lambda: self.send_model_to_combiner(model, id))
return result

def create_update_message(
@@ -353,8 +363,7 @@ def send_model_update(self, update: fedn.ModelUpdate):
return self._handle_grpc_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
except Exception as e:
logger.error(f"GRPC (SendModelUpdate): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelUpdate", lambda: self.send_model_update(update))
return True

def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
@@ -369,8 +378,7 @@ def send_model_validation(self, validation: fedn.ModelValidation) -> bool:
)
except Exception as e:
logger.error(f"GRPC (SendModelValidation): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelValidation", lambda: self.send_model_validation(validation))
return True

def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
@@ -385,8 +393,7 @@ def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool:
)
except Exception as e:
logger.error(f"GRPC (SendModelPrediction): An error occurred: {e}")
self._disconnect()

self._handle_unknown_error(e, "SendModelPrediction", lambda: self.send_model_prediction(prediction))
return True

def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
@@ -399,12 +406,38 @@ def _handle_grpc_error(self, e, method_name: str, sender_function: Callable):
logger.warning(f"GRPC ({method_name}): connection cancelled. Retrying in 5 seconds.")
time.sleep(5)
return sender_function()
if status_code == grpc.StatusCode.UNAUTHENTICATED:
elif status_code == grpc.StatusCode.UNAUTHENTICATED:
details = e.details()
if details == "Token expired":
logger.warning(f"GRPC ({method_name}): Token expired.")
raise e
elif status_code == grpc.StatusCode.UNKNOWN:
logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
details = e.details()
if details == "Stream removed":
logger.warning(f"GRPC ({method_name}): Stream removed. Reconnecting")
self._disconnect()
self._init_channel(self.host, self.port, self.token)
self._init_stubs()
return sender_function()
raise e
self._disconnect()
logger.error(f"GRPC ({method_name}): An error occurred: {e}")
raise e

def _handle_unknown_error(self, e, method_name: str, sender_function: Callable):
# Try to reconnect
logger.warning(f"GRPC ({method_name}): An unknown error occurred: {e}.")
if isinstance(e, ValueError):
# ValueError is raised when the channel is closed
self._disconnect()
logger.warning(f"GRPC ({method_name}): Reconnecting to channel.")
# recreate the channel
self._init_channel(self.host, self.port, self.token)
self._init_stubs()
return sender_function()
else:
raise e

def _disconnect(self):
"""Disconnect from the combiner."""
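Taken together, the handler changes route transient failures through _handle_grpc_error and _handle_unknown_error, which rebuild the channel and stubs and retry via the passed-in lambda instead of letting the client loop die. A minimal usage sketch (not part of the diff; assumes a local combiner on port 12080, as in the chaos tests above):

from fedn.network.clients.grpc_handler import GrpcHandler

# With this commit, a closed or dropped channel during the heartbeat loop
# triggers a reconnect-and-retry instead of an unhandled exception.
handler = GrpcHandler(host="localhost", port=12080, name="test-client", token="", combiner_name="combiner")
try:
    handler.send_heartbeats(client_name="test-client", client_id="test-client-id", update_frequency=5.0)
finally:
    handler._disconnect()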
