diff --git a/platform/mellanox/mlnx-platform-api/sonic_platform/eeprom.py b/platform/mellanox/mlnx-platform-api/sonic_platform/eeprom.py index 17f14b04430f..f5b13f0ae4a5 100644 --- a/platform/mellanox/mlnx-platform-api/sonic_platform/eeprom.py +++ b/platform/mellanox/mlnx-platform-api/sonic_platform/eeprom.py @@ -31,13 +31,13 @@ raise ImportError (str(e) + "- required module not found") from .device_data import DeviceDataManager -from .utils import default_return, is_host +from .utils import default_return, is_host, wait_until logger = Logger() # # this is mlnx-specific -# should this be moved to chass.py or here, which better? +# should this be moved to chassis.py or here, which better? # EEPROM_SYMLINK = "/var/run/hw-management/eeprom/vpd_info" platform_name = DeviceDataManager.get_platform_name() @@ -51,10 +51,12 @@ os.makedirs(os.path.dirname(EEPROM_SYMLINK)) subprocess.check_call(['/usr/bin/xxd', '-r', '-p', 'syseeprom.hex', EEPROM_SYMLINK], cwd=platform_path) +WAIT_EEPROM_READY_SEC = 10 + class Eeprom(eeprom_tlvinfo.TlvInfoDecoder): def __init__(self): - if not os.path.exists(EEPROM_SYMLINK): + if not wait_until(predict=os.path.exists, timeout=WAIT_EEPROM_READY_SEC, path=EEPROM_SYMLINK): logger.log_error("Nowhere to read syseeprom from! No symlink found") raise RuntimeError("No syseeprom symlink found") diff --git a/platform/mellanox/mlnx-platform-api/sonic_platform/utils.py b/platform/mellanox/mlnx-platform-api/sonic_platform/utils.py index 634078c9a077..83063b5c368e 100644 --- a/platform/mellanox/mlnx-platform-api/sonic_platform/utils.py +++ b/platform/mellanox/mlnx-platform-api/sonic_platform/utils.py @@ -19,6 +19,7 @@ import subprocess import json import sys +import time import os from sonic_py_common import device_info from sonic_py_common.logger import Logger @@ -266,3 +267,21 @@ def extract_RJ45_ports_index(): return RJ45_port_index_list if bool(RJ45_port_index_list) else None + +def wait_until(predict, timeout, interval=1, *args, **kwargs): + """Wait until a condition become true + + Args: + predict (object): a callable such as function, lambda + timeout (int): wait time in seconds + interval (int, optional): interval to check the predict. Defaults to 1. + + Returns: + _type_: _description_ + """ + while timeout > 0: + if predict(*args, **kwargs): + return True + time.sleep(interval) + timeout -= interval + return False diff --git a/platform/mellanox/mlnx-platform-api/tests/test_eeprom.py b/platform/mellanox/mlnx-platform-api/tests/test_eeprom.py index 5f0a30dbf519..b07f9327d098 100644 --- a/platform/mellanox/mlnx-platform-api/tests/test_eeprom.py +++ b/platform/mellanox/mlnx-platform-api/tests/test_eeprom.py @@ -49,6 +49,7 @@ def test_chassis_eeprom(self, mock_eeprom_info): assert chassis.get_serial() == 'MT2019X13878' assert chassis.get_system_eeprom_info() == mock_eeprom_info.return_value + @patch('sonic_platform.eeprom.wait_until', MagicMock(return_value=False)) def test_eeprom_init(self): # Test symlink not exist, there is an exception with pytest.raises(RuntimeError): @@ -83,7 +84,7 @@ def side_effect(key, field): @patch('os.path.exists', MagicMock(return_value=True)) @patch('os.path.islink', MagicMock(return_value=True)) - def test_get_system_eeprom_info_from_hardware(self): + def test_get_system_eeprom_info_from_hardware(self): eeprom = Eeprom() eeprom.p = os.path.join(test_path, 'mock_eeprom_data') eeprom._redis_hget = MagicMock() diff --git a/platform/mellanox/mlnx-platform-api/tests/test_utils.py b/platform/mellanox/mlnx-platform-api/tests/test_utils.py index c4c8d0c000a9..5e01fc70dc0e 100644 --- a/platform/mellanox/mlnx-platform-api/tests/test_utils.py +++ b/platform/mellanox/mlnx-platform-api/tests/test_utils.py @@ -18,6 +18,8 @@ import os import pytest import sys +import threading +import time if sys.version_info.major == 3: from unittest import mock else: @@ -125,3 +127,17 @@ def test_run_command(self): def test_extract_RJ45_ports_index(self): rj45_list = utils.extract_RJ45_ports_index() assert rj45_list is None + + def test_wait_until(self): + values = [] + assert utils.wait_until(lambda: len(values) == 0, timeout=1) + assert not utils.wait_until(lambda: len(values) > 0, timeout=1) + + def thread_func(items): + time.sleep(3) + items.append(0) + + t = threading.Thread(target=thread_func, args=(values, )) + t.start() + assert utils.wait_until(lambda: len(values) > 0, timeout=5) + t.join()