diff --git a/mycli/config.py b/mycli/config.py index 03fb502a..e0f2d1fc 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,3 +1,4 @@ +import io import shutil from copy import copy from io import BytesIO, TextIOWrapper @@ -6,6 +7,7 @@ from os.path import exists import struct import sys +from typing import Union from configobj import ConfigObj, ConfigObjError from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -59,23 +61,34 @@ def read_config_file(f, list_values=True): return config -def get_included_configs(config_path) -> list: +def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: """Get a list of configuration files that are included into config_path - with !includedir directive.""" - if not os.path.exists(config_path): + with !includedir directive. + + "Normal" configs should be passed as file paths. The only exception + is .mylogin which is decoded into a stream. However, it never + contains include directives and so will be ignored by this + function. + + """ + if not isinstance(config_file, str) or not os.path.isfile(config_file): return [] included_configs = [] - with open(config_path) as f: - include_directives = filter( - lambda s: s.startswith('!includedir'), - f - ) - dirs = map(lambda s: s.strip().split()[-1], include_directives) - dirs = filter(os.path.isdir, dirs) - for dir in dirs: - for filename in os.listdir(dir): - if filename.endswith('.cnf'): - included_configs.append(os.path.join(dir, filename)) + + try: + with open(config_file) as f: + include_directives = filter( + lambda s: s.startswith('!includedir'), + f + ) + dirs = map(lambda s: s.strip().split()[-1], include_directives) + dirs = filter(os.path.isdir, dirs) + for dir in dirs: + for filename in os.listdir(dir): + if filename.endswith('.cnf'): + included_configs.append(os.path.join(dir, filename)) + except (PermissionError, UnicodeDecodeError): + pass return included_configs @@ -86,8 +99,12 @@ def read_config_files(files, list_values=True): _files = copy(files) while _files: _file = _files.pop(0) - _files = get_included_configs(_file) + _files _config = read_config_file(_file, list_values=list_values) + + # expand includes only if we were able to parse config + # (otherwise we'll just encounter the same errors again) + if config is not None: + _files = get_included_configs(_file) + _files if bool(_config) is True: config.merge(_config) config.filename = _config.filename diff --git a/mycli/main.py b/mycli/main.py index 55afedd5..1fe2a848 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -318,7 +318,7 @@ def read_my_cnf_files(self, files, keys): """ cnf = read_config_files(files, list_values=False) - sections = ['client'] + sections = ['client', 'mysqld'] if self.login_path and self.login_path != 'client': sections.append(self.login_path)