Skip to content

Commit

Permalink
Fix cli behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
bdragon300 committed Oct 29, 2024
1 parent 4dcb88b commit 90efae2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
45 changes: 32 additions & 13 deletions pyzkaccess/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import argparse
import csv
import io
import ipaddress
Expand Down Expand Up @@ -33,7 +32,7 @@
import wrapt
from fire.core import FireError

from pyzkaccess import ZK100, ZK200, ZK400, UnsupportedPlatformError, ZKAccess, ZKModel
from pyzkaccess import ZK100, ZK200, ZK400, UnsupportedPlatformError, ZKAccess, ZKModel, Relay, Reader, AuxInput
from pyzkaccess._setup import setup
from pyzkaccess.device_data.model import Model, models_registry
from pyzkaccess.device_data.queryset import QuerySet
Expand Down Expand Up @@ -145,6 +144,7 @@ def get_writer(self) -> BaseFormatter.WriterInterface:

class ASCIITableFormatter(BaseFormatter):
"""Formatter for ASCII table format"""

class ASCIITableWriter(BaseFormatter.WriterInterface):
_writer: Optional[prettytable.PrettyTable]

Expand Down Expand Up @@ -210,6 +210,7 @@ class BaseConverter(metaclass=abc.ABCMeta):
to the objects of a particular type. It also converts the objects to
the string data suitable for output
"""

def __init__(self, formatter: BaseFormatter, *args: Any, **kwargs: Any) -> None:
self._formatter = formatter
self._args = args
Expand Down Expand Up @@ -258,7 +259,10 @@ def __init__(self, formatter: BaseFormatter, field_types: Mapping[str, Type], *a
# {type: (cast_function, error message)
self._input_converters: Mapping[Type, Tuple[Callable[[str], Any], str]] = {
str: (str, "string"),
bool: (lambda x: {"True": True, "False": False}[x.capitalize()], 'boolean, "True" or "False"'),
bool: (
lambda x: {"True": True, "False": False}[x.capitalize()] if isinstance(x, str) else bool(x),
"boolean, possible values: True, true, 1, False, false, 0",
),
int: (int, "integer"),
tuple: (self._parse_tuple, "comma separated values"),
date: (lambda x: datetime.strptime(x, "%Y-%m-%d").date(), 'date string, e.g. "2020-02-01"'),
Expand Down Expand Up @@ -328,7 +332,7 @@ def _parse_value(self, field_name: str, value: str, field_datatype: Type) -> Opt
cast_fn, error_msg = self._input_converters[field_datatype]
return cast_fn(value)
except (ValueError, TypeError, KeyError):
raise FireError(f"Bad value of {field_name}={value} but must be: {error_msg}")
raise FireError(f"Bad value of {field_name}={value}, must be: {error_msg}")

def _coalesce_value(self, value: Optional[Any], field_datatype: Type) -> str:
if value is None:
Expand All @@ -338,7 +342,9 @@ def _coalesce_value(self, value: Optional[Any], field_datatype: Type) -> str:

return self._output_converters[field_datatype](value)

def _parse_tuple(self, value: str) -> tuple:
def _parse_tuple(self, value: Union[str, tuple]) -> tuple:
if isinstance(value, tuple):
return value
return tuple(value.split(self.TUPLE_SEPARATOR))

def _coalesce_tuple(self, value: tuple) -> str:
Expand Down Expand Up @@ -444,19 +450,19 @@ def parse_array_index(opt_indexes: Optional[Union[int, str]]) -> Union[int, slic
return slice(None, None)
if isinstance(opt_indexes, str):
if not re.match(r"^\d-\d$", opt_indexes):
raise FireError("Select range must contain numbers divided by dash, for example 0-3")
raise FireError("Range must contain numbers divided by dash, e.g. 0-3")

pieces = opt_indexes.split("-")
start = int(pieces.pop(0)) if pieces else None
stop = int(pieces.pop(0) or 1000) + 1 if pieces else None
return slice(start, stop)
if isinstance(opt_indexes, int):
if opt_indexes < 0:
raise FireError("Select index must be a positive number")
raise FireError("Selection index must be a positive number")

return opt_indexes

raise FireError("Numbers must be list or int")
raise FireError("Selection must be an integer or range")


class Query:
Expand Down Expand Up @@ -593,13 +599,16 @@ class Doors:
def __init__(self, items):
self._items = items

def select(self, indexes: Union[int, list]):
def select(self, indexes: Union[int, str]):
"""Select doors to operate
Args:
indexes: Doors to select. Accepts index `select 2` or
range `select 0-2`. Indexes are started from 0.
"""
if isinstance(self._items, Door):
raise FireError("A single door is already selected")

self._items = self._items[parse_array_index(indexes)]
return self

Expand Down Expand Up @@ -639,14 +648,17 @@ class Relays:
def __init__(self, items):
self._items = items

def select(self, indexes: Union[int, list]):
def select(self, indexes: Union[int, str]):
"""
Select relays to operate
Args:
indexes: Relays to select. Accepts index `select 2` or
range `select 0-2`. Indexes are started from 0.
"""
if isinstance(self._items, Relay):
raise FireError("A single relay is already selected")

self._items = self._items[parse_array_index(indexes)]
return self

Expand All @@ -666,13 +678,16 @@ class Readers:
def __init__(self, items):
self._items = items

def select(self, indexes: Union[int, list]):
def select(self, indexes: Union[int, str]):
"""Select doors to operate
Args:
indexes: Readers to select. Accepts index `select 2` or
range `select 0-2`. Indexes are started from 0.
"""
if isinstance(self._items, Reader):
raise FireError("A single reader is already selected")

self._items = self._items[parse_array_index(indexes)]
return self

Expand All @@ -687,13 +702,16 @@ class AuxInputs:
def __init__(self, items):
self._items = items

def select(self, indexes: Union[int, list]):
def select(self, indexes: Union[int, str]):
"""Select doors to operate
Args:
indexes: Aux input to select. Accepts index `select 2` or
range `select 0-2`. Indexes are started from 0.
"""
if isinstance(self._items, AuxInput):
raise FireError("A single aux input is already selected")

self._items = self._items[parse_array_index(indexes)]
return self

Expand Down Expand Up @@ -726,6 +744,7 @@ def __init__(self, event_log) -> None:
self._io_converter = TypedFieldConverter(formatter, self._event_field_types)

def __call__(self):
self._event_log.refresh()
self._io_converter.write_records(
{s: getattr(ev, s) for s in self._event_field_types.keys()} for ev in self._event_log
)
Expand Down Expand Up @@ -1149,7 +1168,7 @@ def connect(self, ip_or_connstr: str, *, model: str = "ZK400") -> ZKCommand:
connstr = ip_or_connstr

# Hack: prevent making a connection if help is requested
if '--help' in sys.argv or '-h' in sys.argv:
if "--help" in sys.argv or "-h" in sys.argv:
connstr = None

zkcmd = ZKCommand(ZKAccess(connstr, device_model=model, dllpath=self._dllpath))
Expand Down
2 changes: 1 addition & 1 deletion pyzkaccess/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["ZKSDKError"]
__all__ = ["ZKSDKError", "UnsupportedPlatformError"]

from typing import Any

Expand Down

0 comments on commit 90efae2

Please sign in to comment.