Skip to content

Commit

Permalink
Fix communicator_test.py (NVIDIA#2019)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Sep 21, 2023
1 parent acbdf46 commit 35da789
Showing 1 changed file with 42 additions and 35 deletions.
77 changes: 42 additions & 35 deletions tests/unit_test/fuel/f3/communicator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from threading import Event

import pytest

Expand All @@ -32,66 +34,72 @@
MESSAGE_FROM_B = "Test message from b"


class CommState:
def __init__(self):
self.a_ready_event = Event()
self.a_received_event = Event()
self.b_ready_event = Event()
self.b_received_event = Event()


class Monitor(EndpointMonitor):
def __init__(self, tester):
self.tester = tester
def __init__(self, comm_state: CommState):
self.comm_state = comm_state

def state_change(self, endpoint: Endpoint):
if endpoint.state == EndpointState.READY:
if endpoint.name == NODE_A:
self.tester.a_ready = True
self.comm_state.a_ready_event.set()
else:
self.tester.b_ready = True
self.comm_state.b_ready_event.set()


class Receiver(MessageReceiver):
def __init__(self, tester):
self.tester = tester
def __init__(self, comm_state: CommState):
self.comm_state = comm_state

def process_message(self, endpoint: Endpoint, connection: Connection, app_id: int, message: Message):
text = message.payload.decode("utf-8")
if endpoint.name == NODE_A:
assert text == MESSAGE_FROM_A
self.tester.a_received = True
self.comm_state.a_received_event.set()
else:
assert text == MESSAGE_FROM_B
self.tester.b_received = True
self.comm_state.b_received_event.set()


@pytest.mark.xdist_group(name="test_f3_communicator")
class TestCommunicator:
@pytest.fixture
def comm_a(self):
local_endpoint = Endpoint(NODE_A, {"foo": "test"})
comm = Communicator(local_endpoint)
comm.register_monitor(Monitor(self))
comm.register_message_receiver(APP_ID, Receiver(self))
self.a_ready = False
self.a_received = False
return comm

@pytest.fixture
def comm_b(self):
local_endpoint = Endpoint(NODE_B, {"bar": 123})
comm = Communicator(local_endpoint)
comm.register_monitor(Monitor(self))
comm.register_message_receiver(APP_ID, Receiver(self))
self.b_ready = False
self.b_received = False
return comm
def get_comm_a(comm_state):
local_endpoint = Endpoint(NODE_A, {"foo": "test"})
comm = Communicator(local_endpoint)
comm.register_monitor(Monitor(comm_state))
comm.register_message_receiver(APP_ID, Receiver(comm_state))
return comm


def get_comm_b(comm_state):
local_endpoint = Endpoint(NODE_B, {"bar": 123})
comm = Communicator(local_endpoint)
comm.register_monitor(Monitor(comm_state))
comm.register_message_receiver(APP_ID, Receiver(comm_state))
return comm


class TestCommunicator:
@pytest.mark.parametrize(
"scheme, port_range",
[
("tcp", "2000-3000"),
("grpc", "3000-4000"),
("http", "4000-5000"),
# ("http", "4000-5000"), TODO (YT): We disable this, as it is causing our jenkins hanging
("atcp", "5000-6000"),
],
)
def test_sfm_message(self, comm_a, comm_b, scheme, port_range):
def test_sfm_message(self, scheme, port_range):
comm_state = CommState()
comm_a = get_comm_a(comm_state)
comm_b = get_comm_b(comm_state)

handle1, url = comm_a.start_listener(scheme, {"ports": port_range})
_, url = comm_a.start_listener(scheme, {"ports": port_range})
comm_a.start()

# Check port is in the range
Expand All @@ -106,17 +114,16 @@ def test_sfm_message(self, comm_a, comm_b, scheme, port_range):
comm_b.add_connector(url, Mode.ACTIVE)
comm_b.start()

while not self.a_ready or not self.b_ready:
while not comm_state.a_ready_event.wait(10) or not comm_state.b_ready_event.wait(10):
log.info("Waiting for both endpoints to be ready")
time.sleep(0.1)

comm_a.send(Endpoint(NODE_B), APP_ID, Message({}, MESSAGE_FROM_A.encode("utf-8")))

comm_b.send(Endpoint(NODE_A), APP_ID, Message({}, MESSAGE_FROM_B.encode("utf-8")))

time.sleep(1)

assert self.a_received and self.b_received
assert comm_state.a_received_event.is_set() and comm_state.b_received_event.is_set()

comm_b.stop()
comm_a.stop()

0 comments on commit 35da789

Please sign in to comment.