Skip to content

Commit

Permalink
Merge pull request #2796 from Suor/opt-validate
Browse files (browse the repository at this point in the history)
perf: switch schema validation library
  • Loading branch information
efiop authored Nov 20, 2019
2 parents f563892 + a1433ba commit c7c852a
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 190 deletions.
230 changes: 85 additions & 145 deletions dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
import re

import configobj
from schema import And
from schema import Optional
from schema import Regex
from schema import Schema
from schema import SchemaError
from schema import Use
from voluptuous import Schema, Required, Optional, Invalid
from voluptuous import All, Any, Lower, Range, Coerce, Match

from dvc.exceptions import DvcException
from dvc.exceptions import NotDvcRepoError
Expand Down Expand Up @@ -50,82 +46,43 @@ def __init__(self, command, cause=None):


def supported_cache_type(types):
"""Checks if link type config option has a valid value.
"""Checks if link type config option consists only of valid values.
Args:
types (list/string): type(s) of links that dvc should try out.
"""
if types is None:
return None
if isinstance(types, str):
types = [typ.strip() for typ in types.split(",")]
for typ in types:
if typ not in ["reflink", "hardlink", "symlink", "copy"]:
return False
return True


def is_bool(val):
"""Checks that value is a boolean.
Args:
val (str): string value verify.
Returns:
bool: True if value stands for boolean, False otherwise.
"""
return val.lower() in ["true", "false"]


def to_bool(val):
"""Converts value to boolean.
Args:
val (str): string to convert to boolean.
Returns:
bool: True if value.lower() == 'true', False otherwise.
"""
return val.lower() == "true"


def is_whole(val):
"""Checks that value is a whole integer.
Args:
val (str): number string to verify.
Returns:
bool: True if val is a whole number, False otherwise.
"""
return int(val) >= 0

unsupported = set(types) - {"reflink", "hardlink", "symlink", "copy"}
if unsupported:
raise Invalid(
"Unsupported cache type(s): {}".format(", ".join(unsupported))
)

def is_percent(val):
"""Checks that value is a percent.
return types

Args:
val (str): number string to verify.

Returns:
bool: True if 0<=value<=100, False otherwise.
"""
return int(val) >= 0 and int(val) <= 100
# Checks that value is either true or false and converts it to bool
Bool = All(
Lower,
Any("true", "false"),
lambda v: v == "true",
msg="expected true or false",
)
to_bool = Schema(Bool)


class Choices(object):
def Choices(*choices):
"""Checks that value belongs to the specified set of values
Args:
*choices: pass allowed values as arguments, or pass a list or
tuple as a single argument
"""

def __init__(self, *choices):
if len(choices) == 1 and isinstance(choices[0], (list, tuple)):
choices = choices[0]
self.choices = choices

def __call__(self, value):
return value in self.choices
return Any(*choices, msg="expected one of {}".format(",".join(choices)))


class Config(object): # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -158,28 +115,22 @@ class Config(object): # pylint: disable=too-many-instance-attributes
LEVEL_GLOBAL = 2
LEVEL_SYSTEM = 3

BOOL_SCHEMA = And(str, is_bool, Use(to_bool))

SECTION_CORE = "core"
SECTION_CORE_LOGLEVEL = "loglevel"
SECTION_CORE_LOGLEVEL_SCHEMA = And(
Use(str.lower), Choices("info", "debug", "warning", "error")
SECTION_CORE_LOGLEVEL_SCHEMA = All(
Lower, Choices("info", "debug", "warning", "error")
)
SECTION_CORE_REMOTE = "remote"
SECTION_CORE_INTERACTIVE_SCHEMA = BOOL_SCHEMA
SECTION_CORE_INTERACTIVE = "interactive"
SECTION_CORE_ANALYTICS = "analytics"
SECTION_CORE_ANALYTICS_SCHEMA = BOOL_SCHEMA
SECTION_CORE_CHECKSUM_JOBS = "checksum_jobs"
SECTION_CORE_CHECKSUM_JOBS_SCHEMA = And(Use(int), lambda x: x > 0)

SECTION_CACHE = "cache"
SECTION_CACHE_DIR = "dir"
SECTION_CACHE_TYPE = "type"
SECTION_CACHE_TYPE_SCHEMA = supported_cache_type
SECTION_CACHE_PROTECTED = "protected"
SECTION_CACHE_SHARED = "shared"
SECTION_CACHE_SHARED_SCHEMA = And(Use(str.lower), Choices("group"))
SECTION_CACHE_SHARED_SCHEMA = All(Lower, Choices("group"))
SECTION_CACHE_LOCAL = "local"
SECTION_CACHE_S3 = "s3"
SECTION_CACHE_GS = "gs"
Expand All @@ -188,34 +139,26 @@ class Config(object): # pylint: disable=too-many-instance-attributes
SECTION_CACHE_AZURE = "azure"
SECTION_CACHE_SLOW_LINK_WARNING = "slow_link_warning"
SECTION_CACHE_SCHEMA = {
Optional(SECTION_CACHE_LOCAL): str,
Optional(SECTION_CACHE_S3): str,
Optional(SECTION_CACHE_GS): str,
Optional(SECTION_CACHE_HDFS): str,
Optional(SECTION_CACHE_SSH): str,
Optional(SECTION_CACHE_AZURE): str,
Optional(SECTION_CACHE_DIR): str,
Optional(SECTION_CACHE_TYPE, default=None): SECTION_CACHE_TYPE_SCHEMA,
Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA,
Optional(SECTION_CACHE_SHARED): SECTION_CACHE_SHARED_SCHEMA,
Optional(PRIVATE_CWD): str,
Optional(SECTION_CACHE_SLOW_LINK_WARNING, default=True): BOOL_SCHEMA,
SECTION_CACHE_LOCAL: str,
SECTION_CACHE_S3: str,
SECTION_CACHE_GS: str,
SECTION_CACHE_HDFS: str,
SECTION_CACHE_SSH: str,
SECTION_CACHE_AZURE: str,
SECTION_CACHE_DIR: str,
SECTION_CACHE_TYPE: supported_cache_type,
Optional(SECTION_CACHE_PROTECTED, default=False): Bool,
SECTION_CACHE_SHARED: SECTION_CACHE_SHARED_SCHEMA,
PRIVATE_CWD: str,
Optional(SECTION_CACHE_SLOW_LINK_WARNING, default=True): Bool,
}

SECTION_CORE_SCHEMA = {
Optional(SECTION_CORE_LOGLEVEL): And(
str, Use(str.lower), SECTION_CORE_LOGLEVEL_SCHEMA
),
Optional(SECTION_CORE_REMOTE, default=""): And(str, Use(str.lower)),
Optional(
SECTION_CORE_INTERACTIVE, default=False
): SECTION_CORE_INTERACTIVE_SCHEMA,
Optional(
SECTION_CORE_ANALYTICS, default=True
): SECTION_CORE_ANALYTICS_SCHEMA,
Optional(
SECTION_CORE_CHECKSUM_JOBS, default=None
): SECTION_CORE_CHECKSUM_JOBS_SCHEMA,
SECTION_CORE_LOGLEVEL: SECTION_CORE_LOGLEVEL_SCHEMA,
SECTION_CORE_REMOTE: Lower,
Optional(SECTION_CORE_INTERACTIVE, default=False): Bool,
Optional(SECTION_CORE_ANALYTICS, default=True): Bool,
SECTION_CORE_CHECKSUM_JOBS: All(Coerce(int), Range(1)),
}

# backward compatibility
Expand All @@ -230,15 +173,15 @@ class Config(object): # pylint: disable=too-many-instance-attributes
SECTION_AWS_SSE = "sse"
SECTION_AWS_ACL = "acl"
SECTION_AWS_SCHEMA = {
SECTION_AWS_STORAGEPATH: str,
Optional(SECTION_AWS_REGION): str,
Optional(SECTION_AWS_PROFILE): str,
Optional(SECTION_AWS_CREDENTIALPATH): str,
Optional(SECTION_AWS_ENDPOINT_URL): str,
Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA,
Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA,
Optional(SECTION_AWS_SSE): str,
Optional(SECTION_AWS_ACL): str,
Required(SECTION_AWS_STORAGEPATH): str,
SECTION_AWS_REGION: str,
SECTION_AWS_PROFILE: str,
SECTION_AWS_CREDENTIALPATH: str,
SECTION_AWS_ENDPOINT_URL: str,
Optional(SECTION_AWS_LIST_OBJECTS, default=False): Bool,
Optional(SECTION_AWS_USE_SSL, default=True): Bool,
SECTION_AWS_SSE: str,
SECTION_AWS_ACL: str,
}

# backward compatibility
Expand All @@ -247,14 +190,14 @@ class Config(object): # pylint: disable=too-many-instance-attributes
SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH
SECTION_GCP_PROJECTNAME = "projectname"
SECTION_GCP_SCHEMA = {
SECTION_GCP_STORAGEPATH: str,
Optional(SECTION_GCP_PROJECTNAME): str,
Required(SECTION_GCP_STORAGEPATH): str,
SECTION_GCP_PROJECTNAME: str,
}

# backward compatibility
SECTION_LOCAL = "local"
SECTION_LOCAL_STORAGEPATH = SECTION_AWS_STORAGEPATH
SECTION_LOCAL_SCHEMA = {SECTION_LOCAL_STORAGEPATH: str}
SECTION_LOCAL_SCHEMA = {Required(SECTION_LOCAL_STORAGEPATH): str}

SECTION_AZURE_CONNECTION_STRING = "connection_string"
# Alibabacloud oss options
Expand All @@ -274,51 +217,48 @@ class Config(object): # pylint: disable=too-many-instance-attributes
SECTION_REMOTE_GSS_AUTH = "gss_auth"
SECTION_REMOTE_NO_TRAVERSE = "no_traverse"
SECTION_REMOTE_SCHEMA = {
SECTION_REMOTE_URL: str,
Optional(SECTION_AWS_REGION): str,
Optional(SECTION_AWS_PROFILE): str,
Optional(SECTION_AWS_CREDENTIALPATH): str,
Optional(SECTION_AWS_ENDPOINT_URL): str,
Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA,
Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA,
Optional(SECTION_AWS_SSE): str,
Optional(SECTION_AWS_ACL): str,
Optional(SECTION_GCP_PROJECTNAME): str,
Optional(SECTION_CACHE_TYPE): SECTION_CACHE_TYPE_SCHEMA,
Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA,
Optional(SECTION_REMOTE_USER): str,
Optional(SECTION_REMOTE_PORT): Use(int),
Optional(SECTION_REMOTE_KEY_FILE): str,
Optional(SECTION_REMOTE_TIMEOUT): Use(int),
Optional(SECTION_REMOTE_PASSWORD): str,
Optional(SECTION_REMOTE_ASK_PASSWORD): BOOL_SCHEMA,
Optional(SECTION_REMOTE_GSS_AUTH): BOOL_SCHEMA,
Optional(SECTION_AZURE_CONNECTION_STRING): str,
Optional(SECTION_OSS_ACCESS_KEY_ID): str,
Optional(SECTION_OSS_ACCESS_KEY_SECRET): str,
Optional(SECTION_OSS_ENDPOINT): str,
Optional(PRIVATE_CWD): str,
Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): BOOL_SCHEMA,
Required(SECTION_REMOTE_URL): str,
SECTION_AWS_REGION: str,
SECTION_AWS_PROFILE: str,
SECTION_AWS_CREDENTIALPATH: str,
SECTION_AWS_ENDPOINT_URL: str,
Optional(SECTION_AWS_LIST_OBJECTS, default=False): Bool,
Optional(SECTION_AWS_USE_SSL, default=True): Bool,
SECTION_AWS_SSE: str,
SECTION_AWS_ACL: str,
SECTION_GCP_PROJECTNAME: str,
SECTION_CACHE_TYPE: supported_cache_type,
Optional(SECTION_CACHE_PROTECTED, default=False): Bool,
SECTION_REMOTE_USER: str,
SECTION_REMOTE_PORT: Coerce(int),
SECTION_REMOTE_KEY_FILE: str,
SECTION_REMOTE_TIMEOUT: Coerce(int),
SECTION_REMOTE_PASSWORD: str,
SECTION_REMOTE_ASK_PASSWORD: Bool,
SECTION_REMOTE_GSS_AUTH: Bool,
SECTION_AZURE_CONNECTION_STRING: str,
SECTION_OSS_ACCESS_KEY_ID: str,
SECTION_OSS_ACCESS_KEY_SECRET: str,
SECTION_OSS_ENDPOINT: str,
PRIVATE_CWD: str,
Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): Bool,
}

SECTION_STATE = "state"
SECTION_STATE_ROW_LIMIT = "row_limit"
SECTION_STATE_ROW_CLEANUP_QUOTA = "row_cleanup_quota"
SECTION_STATE_SCHEMA = {
Optional(SECTION_STATE_ROW_LIMIT): And(Use(int), is_whole),
Optional(SECTION_STATE_ROW_CLEANUP_QUOTA): And(Use(int), is_percent),
SECTION_STATE_ROW_LIMIT: All(Coerce(int), Range(1)),
SECTION_STATE_ROW_CLEANUP_QUOTA: All(Coerce(int), Range(0, 100)),
}

SCHEMA = {
Optional(SECTION_CORE, default={}): SECTION_CORE_SCHEMA,
Optional(Regex(SECTION_REMOTE_REGEX)): SECTION_REMOTE_SCHEMA,
Match(SECTION_REMOTE_REGEX): SECTION_REMOTE_SCHEMA,
Optional(SECTION_CACHE, default={}): SECTION_CACHE_SCHEMA,
Optional(SECTION_STATE, default={}): SECTION_STATE_SCHEMA,
# backward compatibility
Optional(SECTION_AWS, default={}): SECTION_AWS_SCHEMA,
Optional(SECTION_GCP, default={}): SECTION_GCP_SCHEMA,
Optional(SECTION_LOCAL, default={}): SECTION_LOCAL_SCHEMA,
}
COMPILED_SCHEMA = Schema(SCHEMA)

def __init__(self, dvc_dir=None, validate=True):
self.dvc_dir = dvc_dir
Expand Down Expand Up @@ -457,9 +397,9 @@ def load(self):

d = self.config.dict()
try:
d = Schema(self.SCHEMA).validate(d)
except SchemaError as exc:
raise ConfigError("config format error", cause=exc)
d = self.COMPILED_SCHEMA(d)
except Invalid as exc:
raise ConfigError(str(exc), cause=exc)
self.config = configobj.ConfigObj(d, write_empty_values=True)

def save(self, config=None):
Expand Down
10 changes: 4 additions & 6 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import unicode_literals

import schema

import dvc.output as output
from .repo import DependencyREPO
from dvc.dependency.gs import DependencyGS
from dvc.dependency.hdfs import DependencyHDFS
from dvc.dependency.http import DependencyHTTP
Expand All @@ -14,6 +11,7 @@
from dvc.output.base import OutputBase
from dvc.remote import Remote
from dvc.scheme import Schemes
from .repo import DependencyREPO


DEPS = [
Expand Down Expand Up @@ -42,9 +40,9 @@
# cached, see -o and -O flags for `dvc run`) and 'metric' (whether or not
# output is a metric file and how to parse it, see `-M` flag for `dvc run`).
SCHEMA = output.SCHEMA.copy()
del SCHEMA[schema.Optional(OutputBase.PARAM_CACHE)]
del SCHEMA[schema.Optional(OutputBase.PARAM_METRIC)]
SCHEMA[schema.Optional(DependencyREPO.PARAM_REPO)] = DependencyREPO.REPO_SCHEMA
del SCHEMA[OutputBase.PARAM_CACHE]
del SCHEMA[OutputBase.PARAM_METRIC]
SCHEMA[DependencyREPO.PARAM_REPO] = DependencyREPO.REPO_SCHEMA


def _get(stage, p, info):
Expand Down
7 changes: 1 addition & 6 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from contextlib import contextmanager

from funcy import merge
from schema import Optional

from .local import DependencyLOCAL
from dvc.external_repo import external_repo
Expand All @@ -17,11 +16,7 @@ class DependencyREPO(DependencyLOCAL):
PARAM_REV = "rev"
PARAM_REV_LOCK = "rev_lock"

REPO_SCHEMA = {
Optional(PARAM_URL): str,
Optional(PARAM_REV): str,
Optional(PARAM_REV_LOCK): str,
}
REPO_SCHEMA = {PARAM_URL: str, PARAM_REV: str, PARAM_REV_LOCK: str}

def __init__(self, def_repo, stage, *args, **kwargs):
self.def_repo = def_repo
Expand Down
Loading

0 comments on commit c7c852a

Please sign in to comment.