Skip to content

Commit

Permalink
Add timeout option
Browse files Browse the repository at this point in the history
  • Loading branch information
remip2 committed Apr 29, 2020
1 parent 40f1209 commit 9330788
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
7 changes: 6 additions & 1 deletion stormshield/sns/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def main():
group.add_argument("-i", "--ip", help="Remote UTM ip", default=None)
group.add_argument("-P", "--port", help="Remote port", default=443, type=int)
group.add_argument("--proxy", help="Proxy URL (scheme://user:password@host:port)", default=None)
group.add_argument("-t", "--timeout", help="Connection timeout in seconds", default=-1, type=int)

group = parser.add_argument_group("Authentication parameters")
group.add_argument("-u", "--user", help="User name", default="admin")
Expand Down Expand Up @@ -120,6 +121,7 @@ def main():
password = args.password
port = args.port
proxy = args.proxy
timeout = args.timeout
user = args.user
sslverifypeer = args.sslverifypeer
sslverifyhost = args.sslverifyhost
Expand Down Expand Up @@ -195,11 +197,14 @@ def logcommand(self, message, *args, **kwargs):
if password is None and usercert is None:
password = getpass.getpass()

if timeout == -1:
timeout = None

try:
client = SSLClient(
host=host, ip=ip, port=port, user=user, password=password,
sslverifypeer=sslverifypeer, sslverifyhost=sslverifyhost,
credentials=credentials, proxy=proxy,
credentials=credentials, proxy=proxy, timeout=timeout,
usercert=usercert, cabundle=cabundle, autoconnect=False)
except Exception as exception:
logging.error(str(exception))
Expand Down
25 changes: 17 additions & 8 deletions stormshield/sns/sslclient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class SSLClient:

def __init__(self, user='admin', password=None, host=None, ip=None, port=443, cabundle=None,
sslverifypeer=True, sslverifyhost=True, credentials=None,
usercert=None, autoconnect=True, proxy=None):
usercert=None, autoconnect=True, proxy=None, timeout=None):
""":class:`SSLclient <SSLClient>` constructor.
:param user: Optional user name.
Expand All @@ -283,6 +283,7 @@ def __init__(self, user='admin', password=None, host=None, ip=None, port=443, ca
:param usercert: Optional user certificate.
:param autoconnect: Connect to the appliance at initialization
:param proxy: https proxy url (socks5://user:pass@host:port http://user:password@host/)
:param timeout: connection and read timeout in seconds
"""

self.user = user
Expand All @@ -303,6 +304,7 @@ def __init__(self, user='admin', password=None, host=None, ip=None, port=443, ca
self.dl_crc = ""
self.autoconnect = autoconnect
self.proxy = proxy
self.conn_options = {}

if host is None:
raise MissingHost("Host parameter must be provided")
Expand Down Expand Up @@ -355,6 +357,9 @@ def __init__(self, user='admin', password=None, host=None, ip=None, port=443, ca
if self.proxy:
self.session.proxies = { "https": self.proxy}

if timeout is not None:
self.conn_options = { "timeout": timeout }

self.logger = logging.getLogger()

if self.autoconnect:
Expand All @@ -376,7 +381,7 @@ def connect(self):
# user cert authentication
request = self.session.get(
self.baseurl + '/auth/admin.html?sslcert=1&app={}'.format(self.app),
headers=self.headers)
headers=self.headers, **self.conn_options)
else:
# password authentication
request = self.session.post(
Expand All @@ -385,7 +390,8 @@ def connect(self):
'uid':base64.b64encode(self.user.encode('utf-8')),
'pswd':base64.b64encode(self.password.encode('utf-8')),
'app':self.app},
headers=self.headers)
headers=self.headers,
**self.conn_options)

self.logger.log(logging.DEBUG, request.text)

Expand All @@ -405,7 +411,8 @@ def connect(self):
request = self.session.post(
self.baseurl + '/api/auth/login',
data=data,
headers=self.headers)
headers=self.headers,
**self.conn_options)

self.logger.log(logging.DEBUG, request.text)

Expand Down Expand Up @@ -435,7 +442,7 @@ def disconnect(self):

request = self.session.get(
self.baseurl + '/api/auth/logout?sessionid=' + self.sessionid,
headers=self.headers)
headers=self.headers, **self.conn_options)

if request.status_code == requests.codes.OK:
self.logger.log(logging.INFO, 'Disconnected from %s', self.host)
Expand Down Expand Up @@ -473,7 +480,7 @@ def send_command(self, command):
request = self.session.get(
self.baseurl + '/api/command?sessionid=' + self.sessionid +
'&cmd=' + requests.compat.quote(command.encode('utf-8')), # manually done since we need %20 encoding
headers=self.headers)
headers=self.headers, **self.conn_options)

self.logger.log(logging.DEBUG, request.text)

Expand Down Expand Up @@ -524,7 +531,8 @@ def download(self, filename):
request = self.session.get(
self.baseurl + '/api/download/tmp.file?sessionid=' + self.sessionid,
headers=self.headers,
stream=True)
stream=True,
**self.conn_options)

if request.status_code == requests.codes.OK:
size = 0
Expand Down Expand Up @@ -569,7 +577,8 @@ def upload(self, filename):
request = self.session.post(
self.baseurl + '/api/upload?sessionid=' + self.sessionid,
headers=headers,
data=data)
data=data,
**self.conn_options)

uploadh.close()

Expand Down

0 comments on commit 9330788

Please sign in to comment.