Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use yaml.safe_load_all and pass bytes to base64 encode #657

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions install_scripts/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
def handle_error(e):
if isinstance(e, helpers.BadRequestException):
return str(e.message), 400
print(traceback.format_exc())
return "Internal Server Error", 500


Expand Down
24 changes: 8 additions & 16 deletions install_scripts/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,6 @@ def get_terms(app_slug, app_channel):


def get_app_version_config(app_slug, app_channel):

# kubernetes and swarm yaml is not valid yaml
def handle_iter_exc(gen):
while True:
try:
yield next(gen)
except StopIteration:
raise
except Exception as exc:
print('Invalid yaml: {}:'.format(exc), file=sys.stderr)

cursor = db.get().cursor()
query = ('SELECT ar.config '
'FROM app a '
Expand All @@ -335,10 +324,13 @@ def handle_iter_exc(gen):
if row is None:
return []
(config_raw, ) = row
return [
doc for doc in handle_iter_exc(
yaml.load_all(base64.b64decode(config_raw)))
]
config_decoded = base64.b64decode(config_raw)
return get_all_valid_yaml_files(config_decoded)

# returns a list of all valid yaml files in the combined_config
# if a file is invalid, it will be skipped
def get_all_valid_yaml_files(combined_config):
return [doc for doc in yaml.safe_load_all(combined_config)]


def get_current_replicated_version(replicated_channel, scheduler=None):
Expand Down Expand Up @@ -480,7 +472,7 @@ def does_customer_exist(customer_id):
def base64_encode(data):
if len(data) == 0:
return ''
encoded = base64.b64encode(data)
encoded = bytes.decode(base64.b64encode(bytes(data, 'utf-8')))
return '\n'.join(
encoded[pos:pos + 76] for pos in range(0, len(encoded), 76))

Expand Down