Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding mocks and unit tests for socket code #63

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 176 additions & 12 deletions dss_datamover/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,26 @@
from logger import MultiprocessingLogger
from master_application import Master
from multiprocessing import Queue, Value, Lock
from socket_communication import ClientSocket, ServerSocket

import json
import os
import socket
import pytest
from enum import Enum


class Status(Enum):
NORMAL = 0
CONNECTIONERROR = 1
CONNECTIONREFUSEERROR = 2
SOCKETTIMEOUT = 3
SOCKETERROR = 4
EXCEPTION = 5
MISALIGNEDBUFSIZE = 6
WRONGBUFSIZE = 7
LISTENING = 8
CLOSED = 9


@pytest.fixture(scope="session")
Expand All @@ -52,7 +68,8 @@ def get_pytest_configs():

@pytest.fixture(scope="session")
def get_config_object():
test_config_filepath = os.path.dirname(__file__) + "/pytest_config.json"
# test_config_filepath = os.path.dirname(__file__) + "/pytest_config.json"
test_config_filepath = "/etc/dss/datamover/standard_config.json"
config_obj = config.Config({}, config_filepath=test_config_filepath)
return config_obj

Expand All @@ -77,12 +94,13 @@ def get_system_config_dict(get_system_config_object):
def get_multiprocessing_logger(tmpdir):
logger_status = Value('i', 0) # 0=NOT-STARTED, 1=RUNNING, 2=STOPPED
logger_queue = Queue()
logger_lock = Lock()
logger_lock = Value('i', 0)
logging_path = tmpdir
logging_level = "INFO"

logger = MultiprocessingLogger(logger_queue, logger_lock, logger_status)
logger.config(logging_path, __file__, logging_level)
logger.create_logger_handle()
logger.start()

yield logger
Expand All @@ -103,19 +121,138 @@ class MockSocket():
"""
Dummy Object for an actual socket, should simulate all basic functions of a socket object
"""
# TODO: finish building out MockSocket class
def __init__(self, family=0, type=0, proto=0, fileno=0):
self.timeout = 0
self.status = Status.NORMAL
self.data = ''
self.data_index = 0 # indicates the starting pos of the sending data when calling recv
self.max_bufsize = 10 # maximum length of return data when calling recv

def connect(self, address):
if self.status == Status.CONNECTIONERROR:
raise ConnectionError
elif self.status == Status.CONNECTIONREFUSEERROR:
raise ConnectionRefusedError
elif self.status == Status.SOCKETERROR:
raise socket.error
elif self.status == Status.SOCKETTIMEOUT:
raise socket.timeout
else:
return

def recv(self, bufsize):
if self.status == Status.CONNECTIONERROR:
raise ConnectionError
elif self.status == Status.SOCKETTIMEOUT:
raise socket.timeout
elif self.status == Status.EXCEPTION:
raise Exception
elif self.status == Status.MISALIGNEDBUFSIZE:
ret = self.data[self.data_index: self.data_index + bufsize + 1]
return ret
else:
ret = ''
if not self.data:
return ret
if self.data_index >= len(self.data):
raise Exception
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we raise a specific exception here?

if bufsize > self.max_bufsize:
bufsize = self.max_bufsize
if bufsize >= len(self.data) - self.data_index:
ret = self.data[self.data_index:]
self.data_index = len(self.data)
else:
ret = self.data[self.data_index: self.data_index + bufsize]
self.data_index += bufsize
return ret.encode("utf8", "ignore")

def send(self, data, flags=None):
return self.sendall(data, flags)

def sendall(self, data, flags=None):
self.data = ''
self.data_index = 0
if self.status == Status.CONNECTIONERROR:
raise ConnectionError
elif self.status == Status.CONNECTIONREFUSEERROR:
raise ConnectionRefusedError
elif self.status == Status.SOCKETERROR:
raise socket.error
elif self.status == Status.SOCKETTIMEOUT:
raise socket.timeout
else:
self.data = data
return

def setsockopt(self, param1, param2, param3):
pass

def settimeout(self, new_timeout):
pass

def close(self):
if self.status == Status.LISTENING or self.status == Status.NORMAL:
self.status = Status.CLOSED
else:
raise Exception

def listen(self, backlog):
self.status = Status.LISTENING

def bind(self, address):
if self.status == Status.NORMAL:
return
else:
raise Exception

def get_default_ip(self):
default_ip = ""
pytest_config_filepath = os.path.dirname(__file__) + "/pytest_config.json"
with open(pytest_config_filepath) as f:
pytest_configs = json.load(f)
default_ip = pytest_configs['default_ip']
return default_ip

def accept(self):
return self, (self.get_default_ip(), 1234)


@pytest.fixture
def get_header_length(mocker, get_config_dict):
return get_config_dict.get("socket_options", {}).get("response_header_length", 10)


class MockLogger():

def __init__(self):
self.host = "xxx.xxxx.xxxx.xxxx"
self.port = "xxxx"
self.logs = {'error': [], 'info': [], 'warn': [], 'excep': []}

def connect(self):
return True
def error(self, msg):
self.logs['error'].append(msg)

def recv(self):
pass
def info(self, msg):
self.logs['info'].append(msg)

def sendall(self):
pass
def warn(self, msg):
self.logs['warn'].append(msg)

def excep(self, msg):
self.logs['excep'].append(msg)

def get_last(self, type):
ret = ''
if len(self.logs[type]) > 0:
ret = self.logs[type][-1]
return ret

def clear(self):
for key in self.logs:
self.logs[key].clear()


@pytest.fixture
def get_mock_logger():
return MockLogger()


@pytest.fixture
Expand All @@ -126,7 +263,7 @@ def get_mock_clientsocket(mocker):

@pytest.fixture
def get_mock_serversocket(mocker):
mock_serversocket = mocker.patch('socket_communication.ClientSocket', spec=True)
mock_serversocket = mocker.patch('socket_communication.ServerSocket', spec=True)
return mock_serversocket


Expand Down Expand Up @@ -179,3 +316,30 @@ def _method(master):
master.stop_monitor()
print("stopping monitoring")
return _method


class MockMinio():
def __init__(self):
self.data = {}

def list(self, key=''):
if not key:
return list(self.data.items())
return self.data[key].items() if isinstance(self.data[key], dict) else [(key, self.data[key])]

def get(self, key):
if key in self.data:
return self.data[key]
return None

def put(self, key, value):
if key in self.data:
self.data[key] = value
return True
return False

def delete(self, key):
if key in self.data:
self.data.pop(key)
return True
return False
5 changes: 3 additions & 2 deletions dss_datamover/tests/pytest_config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"config": "/etc/dss/datamover/config.json",
"dest_path": "/tmp/xyz",
"cache": ["/var/log/dss/prefix_index_data.json", "/var/log/dss/dm_resume_prefix_dir_keys.txt"]
}
"cache": ["/var/log/dss/prefix_index_data.json", "/var/log/dss/dm_resume_prefix_dir_keys.txt"],
"default_ip": "1.2.3.4"
}
Loading