From df5530fa5b44d62bb41aef30835e1316645738ed Mon Sep 17 00:00:00 2001 From: Wes Belt <45329127+chkp-wbelt@users.noreply.github.com> Date: Thu, 3 Jan 2019 23:10:00 -0500 Subject: [PATCH] Utilize official Check Point python API (#5) * Add reference to official API * Initial changes for official API, no testing yet * Correct some API errors before testing * Format fixes, use data and success from API * Take advantage of wait_for_task and add some more console output * minor format tweak * fix typo of last typo :) * Console output tweaks * Initial testing complete * More testing and updates, everything except rule conversions tested * Fix bug for rule names * Remove old imports --- .gitmodules | 3 + convert-wildcard.py | 416 +++++++++++++++-------------------------- cp_mgmt_api_python_sdk | 1 + 3 files changed, 157 insertions(+), 263 deletions(-) create mode 100644 .gitmodules create mode 160000 cp_mgmt_api_python_sdk diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..b997982 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cp_mgmt_api_python_sdk"] + path = cp_mgmt_api_python_sdk + url = https://github.com/CheckPointSW/cp_mgmt_api_python_sdk diff --git a/convert-wildcard.py b/convert-wildcard.py index f928267..4969da0 100644 --- a/convert-wildcard.py +++ b/convert-wildcard.py @@ -1,43 +1,121 @@ """Script to migrate R77.30 network objects to R80.20 native wildcard objects.""" import sys -import requests -import json import os.path import csv -import urllib3 import argparse import getpass -from time import sleep -from urlparse import urljoin - -__version__ = "1.3" - -class APIException(Exception): - """Exception raised when API response is abnormal.""" - - def __init__(self, code, message): - self.code = code - self.message = message - - def __str__(self): - return("{}: {}".format(self.code,self.message)) +from cp_mgmt_api_python_sdk.lib import APIClient, APIClientArgs +__version__ = "2.0" +class WildcardManager(): + def __init__(self, client): + self.client = client + + def convert(self, records): + #Loop through each record. Rename existing objects and replace references with new object + if len(records) > 0: + print '--- Convert found {:,} records for processing'.format(len(records)) + #Create variable to track number of objects + track = 0 + suffix = {} + suffix["R77"] = "_R77" + suffix["R80"] = "_WC" + for record in records: + track += 1 + print '--- Record {:,} of {:,}. Working with {} (color:{} network:{} mask:{})'.format(track,len(records),record['name'],record['color'],record['ipv4-address'],record['ipv4-mask-wildcard']) + #get uid for network object + NUID = self.getNetworkUID(record['name']) + r77name = record['name'] + suffix["R77"] + r80name = record['name'] + suffix["R80"] + if NUID: + print ' > Found original network object "{}" (uid: {})'.format(record['name'],NUID) + params = {} + params["uid"] = NUID + params["new-name"] = r77name + response = self.client.api_call("set-network", params) + if response.success: + print ' > Renamed network object to "{}" (uid: {})'.format(r77name,NUID) + else: + print ' > Failed to renamed network object to "{}" (Message: {})'.format(r77name,response.error_message) + NUID = "" + if NUID == "": + NUID = self.getNetworkUID(r77name) + if NUID: + print ' > Found R77 network object "{}" (uid: {})'.format(r77name,NUID) + #get uid for wildcard object + WUID = self.getWildcardUID(r80name) + if WUID: + print ' > Found wildcard object "{}" (uid: {})'.format(r80name,WUID) + else: + params = {} + params["name"] = r80name + params["ipv4-address"] = record["ipv4-address"] + params["ipv4-mask-wildcard"] = record["ipv4-mask-wildcard"] + params["color"] = record["color"] + response = self.client.api_call("add-wildcard", params) + if "uid" in response.data: + WUID = response.data["uid"] + print ' > Created new wildcard object "{}" (uid: {})'.format(r80name,WUID) + response = self.client.api_call("where-used", {'uid': '{}'.format(NUID)}) + for rule in response.data["used-directly"]["access-control-rules"]: + params = {} + params["uid"] = rule["rule"]["uid"] + params["layer"] = rule["layer"]["uid"] + params["details-level"] = "uid" + ruledata = self.client.api_call('show-access-rule', params) + ruledesc = "(uid: {})".format(params["uid"]) + if "name" in ruledata.data: + if ruledata.data["name"] != "": + ruledesc = '"{}"'.format(ruledata.data["name"]) + print ' > Checking rule {}'.format(ruledesc) + sources = 0 + destinations = 0 + params["source"] = ruledata.data["source"] + params["destination"] = ruledata.data["destination"] + for i, record in enumerate(params["source"]): + if record == NUID: + params["source"][i] = WUID + sources += 1 + for i, record in enumerate(params["destination"]): + if record == NUID: + params["destination"][i] = WUID + destinations += 1 + if sources > 0 or destinations > 0: + if sources > 0 and destinations > 0: + description = "source and destination columns" + elif sources > 0: + description = "source column" + else: + description = "destination column" + self.client.api_call('set-access-rule', params) + print ' > Updated {}'.format(description) + + def publish(self): + response = self.client.api_call('publish', {}, wait_for_task=True) + return response + + def logout(self): + response = self.client.api_call('logout', {}) + return response + + def getGenericUID(self, objectType, name): + retval = "" + response = self.client.api_call(objectType, {'name': '{}'.format(name),'details-level': 'uid'}) + if response.success and 'uid' in response.data: + retval = response.data['uid'] + return retval + + def getNetworkUID(self, name): + return self.getGenericUID('show-network',name) + + def getWildcardUID(self, name): + return self.getGenericUID('show-wildcard',name) -class APIVariables(): - def __init__(self): - self.managmentURL = "" - self.chkpSID = "" - self.sessionUID = "" - self.sessionHeaders = {} - self.sessionHeaders['Content-Type'] = 'application/json' - self.sessionHeaders['cache-control'] = 'no-cache' - -apivars = APIVariables() def main(): """Main entry point for the script.""" parser = argparse.ArgumentParser(description="Wildcard import/rule conversion script v" + __version__) - print ("*** ") - print ("*** " + parser.description) - print ("*** ") + print "*** " + print "*** " + parser.description + print "*** " parser._action_groups.pop() required = parser.add_argument_group('required arguments') optional = parser.add_argument_group('optional arguments') @@ -48,255 +126,67 @@ def main(): optional.add_argument("-d", "--domain", action="store", help="Domain (when using multidomain)") args = parser.parse_args() - apivars.managmentURL = args.server if os.path.isfile(args.input): - print ("--- Using input file '{}'".format(args.input)) + records = [] + with open(args.input, 'rb') as f: + reader = csv.DictReader(f) + for line in reader: + records.append(line) + print "--- Read {:,} records from input file '{}'".format(len(records),args.input) else: - print ("!!! Could not open input file '{}'".format(args.input)) + print "!!! Could not open input file '{}'".format(args.input) + sys.exit(1) + + print "--- Connecting to server at {}".format(args.server) + client = APIClient(APIClientArgs(server=args.server, unsafe_auto_accept=True)) + result = client.check_fingerprint() + if result is False: + print("!!! Could not get the server's fingerprint! Check connectivity with the server.") sys.exit(1) - print ("--- Connecting to server at {}".format(apivars.managmentURL)) if not args.user: args.user = raw_input('Username: ').strip('\r') else: - print ("--- Attempting to login as '{}'".format(args.user)) + print "--- Attempting to login as '{}'".format(args.user) if not args.password: args.password = getpass.getpass(stream=sys.stderr).strip('\r') - urllib3.disable_warnings() - #Login to management server - try: - params = {} - params["user"] = args.user - params["password"] = args.password - if args.domain: - print ("--- Using domain '{}'".format(args.domain)) - params["domain"] = args.domain - response = mgmtreq('login', params) + if args.domain: + print "--- Using domain '{}'".format(args.domain) + response = client.login(args.user, args.password, domain=args.domain) + else: + print "--- Standard login..." + response = client.login(args.user, args.password) - if response["api-server-version"]: - apiversion = response["api-server-version"] - if float(apiversion) < 1.3: - raise (APIException(409,"Server API needs to be version 1.3 or greater '{}' returned from login.".format(apiversion))) - print ("--- Login to management via API {} complete. (session-uid: {})".format(apiversion,apivars.sessionUID)) - except APIException as apie: - print ("!!! Login failed with message ({})".format(apie)) + if response.success is False: + print "!!! Login failed: {}".format(response.error_message) sys.exit(1) - #Parse CSV for records - records = getRecords(args.input) - - #Loop through each record. Rename existing objects and replace references with new object - if len(records) > 0: - print ('--- Parse CSV complete. ({:,} records found for processing)'.format(len(records))) - #Create variable to track number of objects - track = 0 - suffix = {} - suffix["R77"] = "_R77" - suffix["R80"] = "_WC" - for record in records: - track += 1 - print ('--- Record {:,} of {:,}. Working with {} (color:{} network:{} mask:{})'.format(track,len(records),record['name'],record['color'],record['ipv4-address'],record['ipv4-mask-wildcard'])) - #get uid for network object - NUID = getNetworkUID(record['name']) - r77name = record['name'] + suffix["R77"] - r80name = record['name'] + suffix["R80"] - if NUID: - try: - print (' > Found original network object "{}" (uid: {})'.format(record['name'],NUID)) - mgmtreq('set-network', {'uid': '{}'.format(NUID),'new-name': '{}'.format(r77name)}) - print (' > Renamed network object to "{}" (uid: {})'.format(r77name,NUID)) - NUID = "" - except APIException as apie: - print (' > Failed to rename network object to "{}" (uid: {}) {}'.format(r77name,NUID,apie)) - if NUID == "": - NUID = getNetworkUID(r77name) - if NUID: - print (' > Found R77 network object "{}" (uid: {})'.format(r77name,NUID)) - #get uid for wildcard object - WUID = getWildcardUID(r80name) - if WUID: - print (' > Found wildcard object "{}" (uid: {})'.format(r80name,WUID)) - else: - params = {} - params["name"] = r80name - params["ipv4-address"] = record["ipv4-address"] - params["ipv4-mask-wildcard"] = record["ipv4-mask-wildcard"] - params["color"] = record["color"] - response = mgmtreq('add-wildcard', params) - if response["uid"]: - WUID = response["uid"] - print (' > Created new wildcard object "{}" (uid: {})'.format(r80name,WUID)) - replaceWhereUsed(NUID,WUID) - changes = mgmtchanges() - if changes > 0: - print ("--- Found {:,} changes pending.".format(changes)) - mgmtpublish() - mgmtlogout() - -def replaceWhereUsed(OldID,NewID): - """Replace all instances of Old ID with New ID in rules.""" - response = mgmtreq('where-used', {'uid': '{}'.format(OldID)}) - for rule in response["used-directly"]["access-control-rules"]: - params = {} - params["uid"] = rule["rule"]["uid"] - params["layer"] = rule["layer"]["uid"] - params["details-level"] = "uid" - ruledata = mgmtreq('show-access-rule', params) - ruledesc = "(uid: {})".format(params["uid"]) - if "name" in ruledata: - if ruledata["name"] != "": - ruledesc = '"{}"'.format(ruledata["name"]) - print (' > Checking rule {}'.format(ruledesc)) - sources = 0 - destinations = 0 - params["source"] = ruledata["source"] - params["destination"] = ruledata["destination"] - for i, record in enumerate(params["source"]): - if record == OldID: - params["source"][i] = NewID - sources += 1 - for i, record in enumerate(params["destination"]): - if record == OldID: - params["destination"][i] = NewID - destinations += 1 - if sources > 0 or destinations > 0: - if sources > 0 and destinations > 0: - description = "source and destination columns" - elif sources > 0: - description = "source column" - else: - description = "destination column" - mgmtreq('set-access-rule', params) - print (' > Updated {}'.format(description)) - -def getRecords(filename): - """Read all records to be imported from CSV file.""" - retval = [] - with open(filename, 'rb') as f: - reader = csv.DictReader(f) - for line in reader: - retval.append(line) - return (retval) - -def getNetworkUID(name): - """Get UID for standard network object by name.""" - retval = "" - try: - data = mgmtreq('show-network', {'name': '{}'.format(name),'details-level': 'uid'}) - if 'uid' in data: - retval = data['uid'] - except APIException as apie: - if apie.code == 404: - retval = "" - else: - raise - return (retval) - -def getWildcardUID(name): - """Get UID for wildcard network object by name.""" - retval = "" - try: - data = mgmtreq('show-wildcard',{'name': '{}'.format(name),'details-level': 'uid'}) - if 'uid' in data: - retval = data['uid'] - except APIException as apie: - if apie.code == 404: - retval = "" - else: - raise - return (retval) - -def mgmtpublish(): - """Publish all pending changes.""" - try: - data = mgmtreq('publish', {}) - if data["task-id"]: - taskID = data["task-id"] - print ("--- Publish in-progress. (task-id: {})".format(taskID)) - processesing = True - while processesing: - data = mgmtreq('show-task',{'task-id': '{}'.format(taskID)}) - processesing = False - for task in data["tasks"]: - if task["status"] == "in progress": - processesing = True - break - if processesing: - sleep(2) - print ("--- Publish complete. (task-id: {})".format(taskID)) - - except APIException as apie: - print ("--- Publish failed! (ERROR: {})".format(apie.message)) - -def mgmtchanges(): - """Return number of changes in the current session.""" - retval = 0 - try: - data = mgmtreq('show-session', {'uid': '{}'.format(apivars.sessionUID)}) - if data["changes"]: - retval = data["changes"] - except APIException as apie: - print ("--- Show session failed! {}".format(apie)) - return (retval) - -def mgmtdiscard(): - """Discard any changes in the current session.""" - try: - data = mgmtreq('discard', {}) - if data["number-of-discarded-changes"]: - print ("--- Discard complete. ({:,} changes discarded)".format(data["number-of-discarded-changes"])) - except APIException as apie: - print ("--- Discard failed! {}".format(apie)) - -def mgmtlogout(): - """Return number of changes in the current session.""" - try: - data = mgmtreq('logout', {}) - if data["message"]: - print ("--- Logout complete. (message: {})".format(data["message"])) - except APIException as apie: - print ("--- Logout failed! {}".format(apie)) + if response.data['api-server-version'] >= 1.3: + print "--- Login to management via API {} complete. (session-uid: {})".format(response.data['api-server-version'],response.data['sid']) + else: + print "!!! Server API needs to be version 1.3 or greater '{}' returned from login.".format(response.data['api-server-version']) + sys.exit(1) -def mgmtreq(command, payload): - """Make a REST API call to management server and return results.""" - apiBaseStem = "/web_api/" - managementBase = urljoin(apivars.managmentURL,apiBaseStem) + wcm = WildcardManager(client) + print "--- Starting Convert..." + wcm.convert(records) + print "--- Starting Publish..." + response = wcm.publish() + if response.success: + print "--- Publish complete." + else: + print "--- Publish failed. (message: {})".format(response.error_message) + response = wcm.logout() + if response.success: + message = None + if 'message' in response.data: + message = response.data['message'] + print "--- Logout complete. (message: {})".format(message) + else: + print "--- Logout failed. (message: {})".format(response.error_message) - data = {} - if command: - isLogin = False - apiRequest = urljoin(managementBase,command) - headers = apivars.sessionHeaders - if command.lower() == 'login': - isLogin = True - if isLogin == False and apivars.chkpSID: - headers['X-chkp-sid'] = '{}'.format(apivars.chkpSID) - try: - response = requests.request("POST", apiRequest, json=payload, headers=headers, verify=False) - except requests.exceptions.ConnectionError: - raise APIException(522,"Connection to '{}' failed".format(apiRequest)) - if response.status_code == 200: - if response.text: - data = json.loads(response.text) - if isLogin: - if data["sid"] and data["uid"]: - apivars.chkpSID = data["sid"] - apivars.sessionUID = data["uid"] - else: - raise APIException(500,"Server returned a status code of 200, but with unexpected data.") - else: - if isLogin: - raise APIException(407,"Authentication for '{}' failed".format(apiRequest)) - else: - message = "Unknown error" - if response.text: - data = json.loads(response.text) - if data["code"]: - message = data["code"] - raise APIException(response.status_code,message) - return data if __name__ == '__main__': main() diff --git a/cp_mgmt_api_python_sdk b/cp_mgmt_api_python_sdk new file mode 160000 index 0000000..dcb0717 --- /dev/null +++ b/cp_mgmt_api_python_sdk @@ -0,0 +1 @@ +Subproject commit dcb07178cf7a0270455bab22e42868973eaa1c3b