diff --git a/.travis.yml b/.travis.yml index 4e759c7c9f8..00ae30b41e0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,7 +35,10 @@ script: - py.test --cov=qcodes --cov-report xml --cov-config=.coveragerc # build docs with warnings as errors - | - cd ../docs + cd .. + mypy qcodes --ignore-missing-imports + - | + cd docs make SPHINXOPTS="-W" html-api - cd .. diff --git a/qcodes/__init__.py b/qcodes/__init__.py index 18cb286c17f..e21fda06059 100644 --- a/qcodes/__init__.py +++ b/qcodes/__init__.py @@ -10,7 +10,7 @@ # we dont want spyder to reload qcodes as this will overwrite the default station # instrument list and running monitor add_to_spyder_UMR_excludelist('qcodes') -config = Config() +config = Config() # type: Config from qcodes.version import __version__ @@ -88,7 +88,7 @@ del _c try: - get_ipython() # Check if we are in iPython + get_ipython() # type: ignore # Check if we are in iPython from qcodes.utils.magic import register_magic_class _register_magic = config.core.get('register_magic', False) if _register_magic is not False: diff --git a/qcodes/config/config.py b/qcodes/config/config.py index 767832efe35..85b5056c35c 100644 --- a/qcodes/config/config.py +++ b/qcodes/config/config.py @@ -9,6 +9,7 @@ from pathlib import Path import jsonschema +from typing import Dict logger = logging.getLogger(__name__) @@ -87,8 +88,8 @@ class Config(): defaults = None defaults_schema = None - _diff_config = {} - _diff_schema = {} + _diff_config: Dict[str, dict] = {} + _diff_schema: Dict[str, dict] = {} def __init__(self): self.defaults, self.defaults_schema = self.load_default() diff --git a/qcodes/data/data_set.py b/qcodes/data/data_set.py index d7ec0f9e3ce..b1c25f7991f 100644 --- a/qcodes/data/data_set.py +++ b/qcodes/data/data_set.py @@ -5,6 +5,7 @@ from traceback import format_exc from copy import deepcopy from collections import OrderedDict +from typing import Dict, Callable from .gnuplot_format import GNUPlotFormat from .io import DiskIO @@ -168,7 +169,7 @@ class DataSet(DelegateAttributes): default_formatter = GNUPlotFormat() location_provider = FormatLocation() - background_functions = OrderedDict() + background_functions: Dict[str, Callable] = OrderedDict() def __init__(self, location=None, arrays=None, formatter=None, io=None, write_period=5): diff --git a/qcodes/data/location.py b/qcodes/data/location.py index f794bcf1806..92555307ed1 100644 --- a/qcodes/data/location.py +++ b/qcodes/data/location.py @@ -3,7 +3,7 @@ import re import string -from qcodes import config +import qcodes.config class SafeFormatter(string.Formatter): @@ -83,7 +83,7 @@ class FormatLocation: as '{date:%Y-%m-%d}' or '{counter:03}' """ - default_fmt = config['core']['default_fmt'] + default_fmt = qcodes.config['core']['default_fmt'] def __init__(self, fmt=None, fmt_date=None, fmt_time=None, fmt_counter=None, record=None): diff --git a/qcodes/dataset/data_set.py b/qcodes/dataset/data_set.py index 87d7841e1b0..5d6e0a912f4 100644 --- a/qcodes/dataset/data_set.py +++ b/qcodes/dataset/data_set.py @@ -179,8 +179,8 @@ def _new(self, name, exp_id, specs: SPECS = None, values=None, Actually perform all the side effects needed for the creation of a new dataset. """ - _, run_id, _ = create_run(self.conn, exp_id, name, - specs, values, metadata) + _, run_id, __ = create_run(self.conn, exp_id, name, + specs, values, metadata) # this is really the UUID (an ever increasing count in the db) self.run_id = run_id @@ -440,7 +440,7 @@ def modify_results(self, start_index: int, flattened_keys, flattened_values) - def add_parameter_values(self, spec: ParamSpec, values: List[VALUES]): + def add_parameter_values(self, spec: ParamSpec, values: VALUES): """ Add a parameter to the DataSet and associates result values with the new parameter. diff --git a/qcodes/dataset/experiment_container.py b/qcodes/dataset/experiment_container.py index 50caec3d549..e4a98155233 100644 --- a/qcodes/dataset/experiment_container.py +++ b/qcodes/dataset/experiment_container.py @@ -236,8 +236,8 @@ def load_experiment_by_name(name: str, for row in rows: s = f"exp_id:{row['exp_id']} ({row['name']}-{row['sample_name']}) started at({row['start_time']})" _repr.append(s) - _repr = "\n".join(_repr) - raise ValueError(f"Many experiments matching your request found {_repr}") + _repr_str = "\n".join(_repr) + raise ValueError(f"Many experiments matching your request found {_repr_str}") else: e.exp_id = rows[0]['exp_id'] return e diff --git a/qcodes/dataset/measurements.py b/qcodes/dataset/measurements.py index 3a64f19651c..7924fc79934 100644 --- a/qcodes/dataset/measurements.py +++ b/qcodes/dataset/measurements.py @@ -11,7 +11,7 @@ import qcodes as qc from qcodes import Station -from qcodes.instrument.parameter import ArrayParameter, _BaseParameter +from qcodes.instrument.parameter import ArrayParameter, _BaseParameter, Parameter from qcodes.dataset.experiment_container import Experiment from qcodes.dataset.param_spec import ParamSpec from qcodes.dataset.data_set import DataSet @@ -29,7 +29,7 @@ class DataSaver: datasaving to the database """ - default_callback = None + default_callback: Optional[dict] = None def __init__(self, dataset: DataSet, write_period: float, parameters: Dict[str, ParamSpec]) -> None: @@ -45,7 +45,7 @@ def __init__(self, dataset: DataSet, write_period: float, self._known_parameters = list(parameters.keys()) self._results: List[dict] = [] # will be filled by addResult self._last_save_time = monotonic() - self._known_dependencies: Dict[str, str] = {} + self._known_dependencies: Dict[str, List[str]] = {} for param, parspec in parameters.items(): if parspec.depends_on != '': self._known_dependencies.update({str(param): @@ -152,6 +152,7 @@ def add_result(self, # For compatibility with the old Loop, setpoints are # tuples of numbers (usually tuple(np.linspace(...)) if hasattr(value, '__len__') and not(isinstance(value, str)): + value = cast(Union[Sequence,np.ndarray], value) res_dict.update({param: value[index]}) else: res_dict.update({param: value}) @@ -398,6 +399,7 @@ def register_parameter( name = str(parameter) if isinstance(parameter, ArrayParameter): + parameter = cast(ArrayParameter, parameter) if parameter.setpoint_names: spname = (f'{parameter._instrument.name}_' f'{parameter.setpoint_names[0]}') @@ -416,8 +418,10 @@ def register_parameter( label=splabel, unit=spunit) self.parameters[spname] = sp - setpoints = setpoints if setpoints else () - setpoints += (spname,) + my_setpoints: Tuple[Union[_BaseParameter, str], ...] = setpoints if setpoints else () + my_setpoints += (spname,) + else: + my_setpoints = setpoints # We currently treat ALL parameters as 'numeric' and fail to add them # to the dataset if they can not be unraveled to fit that description @@ -426,12 +430,13 @@ def register_parameter( # requirement later and start saving binary blobs with the datasaver, # but for now binary blob saving is referred to using the DataSet # API directly + parameter = cast(Union[Parameter, ArrayParameter], parameter) paramtype = 'numeric' label = parameter.label unit = parameter.unit - if setpoints: - sp_strings = [str(sp) for sp in setpoints] + if my_setpoints: + sp_strings = [str(sp) for sp in my_setpoints] else: sp_strings = [] if basis: diff --git a/qcodes/dataset/plotting.py b/qcodes/dataset/plotting.py index d3be9e6759b..720c0cbd8d5 100644 --- a/qcodes/dataset/plotting.py +++ b/qcodes/dataset/plotting.py @@ -13,11 +13,10 @@ log = logging.getLogger(__name__) DB = qc.config["core"]["db_location"] -mplaxes = matplotlib.axes.Axes def plot_by_id(run_id: int, - axes: Optional[Union[mplaxes, - Sequence[mplaxes]]]=None) -> List[mplaxes]: + axes: Optional[Union[matplotlib.axes.Axes, + Sequence[matplotlib.axes.Axes]]]=None) -> List[matplotlib.axes.Axes]: def set_axis_labels(ax, data): if data[0]['label'] == '': lbl = data[0]['name'] @@ -50,7 +49,7 @@ def set_axis_labels(ax, data): """ alldata = get_data_by_id(run_id) nplots = len(alldata) - if isinstance(axes, mplaxes): + if isinstance(axes, matplotlib.axes.Axes): axes = [axes] if axes is None: @@ -115,7 +114,7 @@ def set_axis_labels(ax, data): def plot_on_a_plain_grid(x: np.ndarray, y: np.ndarray, z: np.ndarray, - ax: mplaxes) -> mplaxes: + ax: matplotlib.axes.Axes) -> matplotlib.axes.Axes: """ Plot a heatmap of z using x and y as axes. Assumes that the data are rectangular, i.e. that x and y together describe a rectangular diff --git a/qcodes/dataset/sqlite_base.py b/qcodes/dataset/sqlite_base.py index a9481b442ba..5ae4c2d038a 100644 --- a/qcodes/dataset/sqlite_base.py +++ b/qcodes/dataset/sqlite_base.py @@ -107,7 +107,7 @@ def _convert_array(text: bytes) -> ndarray: return np.load(out) -def one(curr: sqlite3.Cursor, column: str) -> Any: +def one(curr: sqlite3.Cursor, column: Union[int, str]) -> Any: """Get the value of one column from one row Args: curr: cursor to operate on @@ -408,11 +408,11 @@ def insert_many_values(conn: sqlite3.Connection, # According to the SQLite changelog, the version number # to check against below # ought to be 3.7.11, but that fails on Travis - if LooseVersion(version) <= LooseVersion('3.8.2'): + if LooseVersion(str(version)) <= LooseVersion('3.8.2'): max_var = qc.SQLiteSettings.limits['MAX_COMPOUND_SELECT'] else: max_var = qc.SQLiteSettings.limits['MAX_VARIABLE_NUMBER'] - rows_per_transaction = int(max_var/no_of_columns) + rows_per_transaction = int(int(max_var)/no_of_columns) _columns = ",".join(columns) _values = "(" + ",".join(["?"] * len(values[0])) + ")" diff --git a/qcodes/dataset/sqlite_settings.py b/qcodes/dataset/sqlite_settings.py index bd05b9ad1e0..65b821b3b62 100644 --- a/qcodes/dataset/sqlite_settings.py +++ b/qcodes/dataset/sqlite_settings.py @@ -2,7 +2,7 @@ from typing import Tuple, Dict, Union -def _read_settings() -> Tuple[Dict[str, str], +def _read_settings() -> Tuple[Dict[str, Union[str,int]], Dict[str, Union[bool, int, str]]]: """ Function to read the local SQLite settings at import time. @@ -19,6 +19,7 @@ def _read_settings() -> Tuple[Dict[str, str], """ # For the limits, there are known default values # (known from https://www.sqlite.org/limits.html) + DEFAULT_LIMITS: Dict[str, Union[str, int]] DEFAULT_LIMITS = {'MAX_ATTACHED': 10, 'MAX_COLUMN': 2000, 'MAX_COMPOUND_SELECT': 500, @@ -35,6 +36,7 @@ def _read_settings() -> Tuple[Dict[str, str], opt_num = 0 resp = '' + limits: Dict[str, Union[str,int]] limits = DEFAULT_LIMITS.copy() settings = {} @@ -47,6 +49,7 @@ def _read_settings() -> Tuple[Dict[str, str], opt_num += 1 lst = resp.split('=') if len(lst) == 2: + val: Union[str,int] (param, val) = lst if val.isnumeric(): val = int(val) diff --git a/qcodes/instrument/base.py b/qcodes/instrument/base.py index df6d60c33e2..2c88743fc37 100644 --- a/qcodes/instrument/base.py +++ b/qcodes/instrument/base.py @@ -3,14 +3,15 @@ import time import warnings import weakref -from typing import Sequence, Optional, Dict, Union, Callable, Any, List +from typing import Sequence, Optional, Dict, Union, Callable, Any, List, TYPE_CHECKING, cast import numpy as np - +if TYPE_CHECKING: + from qcodes.instrumet.channel import ChannelList from qcodes.utils.helpers import DelegateAttributes, strip_attrs, full_class from qcodes.utils.metadata import Metadatable from qcodes.utils.validators import Anything -from .parameter import Parameter +from .parameter import Parameter, _BaseParameter from .function import Function log = logging.getLogger(__name__) @@ -45,9 +46,9 @@ def __init__(self, name: str, metadata: Optional[Dict]=None, **kwargs) -> None: self.name = str(name) - self.parameters = {} - self.functions = {} - self.submodules = {} + self.parameters = {} # type: Dict[str, _BaseParameter] + self.functions = {} # type: Dict[str, Function] + self.submodules = {} # type: Dict[str, Union['InstrumentBase', 'ChannelList']] super().__init__(**kwargs) def add_parameter(self, name: str, @@ -109,7 +110,7 @@ def add_function(self, name: str, **kwargs) -> None: func = Function(name=name, instrument=self, **kwargs) self.functions[name] = func - def add_submodule(self, name: str, submodule: Metadatable) -> None: + def add_submodule(self, name: str, submodule: Union['InstrumentBase', 'ChannelList']) -> None: """ Bind one submodule to this instrument. @@ -360,7 +361,9 @@ class Instrument(InstrumentBase): shared_kwargs = () - _all_instruments = {} + _all_instruments = {} # type: Dict[str, weakref.ref[Instrument]] + _type = None + _instances = [] # type: List[weakref.ref] def __init__(self, name: str, metadata: Optional[Dict]=None, **kwargs) -> None: @@ -377,7 +380,7 @@ def __init__(self, name: str, self.record_instance(self) - def get_idn(self) -> Dict: + def get_idn(self) -> Dict[str, Optional[str]]: """ Parse a standard VISA '\*IDN?' response into an ID dict. @@ -399,6 +402,7 @@ def get_idn(self) -> Dict: idstr = self.ask('*IDN?') # form is supposed to be comma-separated, but we've seen # other separators occasionally + idparts = [] # type: List[Optional[str]] for separator in ',;:': # split into no more than 4 parts, so we don't lose info idparts = [p.strip() for p in idstr.split(separator, 3)] @@ -580,14 +584,14 @@ def find_instrument(cls, name: str, if ins is None: del cls._all_instruments[name] raise KeyError('Instrument {} has been removed'.format(name)) - + inst = cast('Instrument', ins) if instrument_class is not None: - if not isinstance(ins, instrument_class): + if not isinstance(inst, instrument_class): raise TypeError( 'Instrument {} is {} but {} was requested'.format( - name, type(ins), instrument_class)) + name, type(inst), instrument_class)) - return ins + return inst # `write_raw` and `ask_raw` are the interface to hardware # # `write` and `ask` are standard wrappers to help with error reporting # @@ -658,7 +662,7 @@ def ask(self, cmd: str) -> str: e.args = e.args + ('asking ' + repr(cmd) + ' to ' + inst,) raise e - def ask_raw(self, cmd: str) -> None: + def ask_raw(self, cmd: str) -> str: """ Low level method to write to the hardware and return a response. diff --git a/qcodes/instrument/channel.py b/qcodes/instrument/channel.py index 0eabe46bc01..26acccc9dbb 100644 --- a/qcodes/instrument/channel.py +++ b/qcodes/instrument/channel.py @@ -1,8 +1,8 @@ """ Base class for the channel of an instrument """ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional, Dict, Sequence, cast from .base import InstrumentBase, Instrument -from .parameter import MultiParameter, ArrayParameter +from .parameter import MultiParameter, ArrayParameter, Parameter from ..utils.validators import Validator from ..utils.metadata import Metadatable from ..utils.helpers import full_class @@ -28,7 +28,7 @@ class InstrumentChannel(InstrumentBase): channel. Usually populated via ``add_function`` """ - def __init__(self, parent: Instrument, name: str, **kwargs): + def __init__(self, parent: Instrument, name: str, **kwargs) -> None: # Initialize base classes of Instrument. We will overwrite what we # want to do in the Instrument initializer super().__init__(name=name, **kwargs) @@ -79,7 +79,7 @@ class MultiChannelInstrumentParameter(MultiParameter): def __init__(self, channels: Union[List, Tuple], param_name: str, - *args, **kwargs): + *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._channels = channels self._param_name = param_name @@ -150,9 +150,9 @@ class ChannelList(Metadatable): def __init__(self, parent: Instrument, name: str, chan_type: type, - chan_list: Union[List, Tuple, None]=None, + chan_list: Optional[Sequence[InstrumentChannel]]=None, snapshotable: bool=True, - multichan_paramclass: type = MultiChannelInstrumentParameter): + multichan_paramclass: type = MultiChannelInstrumentParameter) -> None: super().__init__() self._parent = parent @@ -171,12 +171,13 @@ def __init__(self, parent: Instrument, self._snapshotable = snapshotable self._paramclass = multichan_paramclass - self._channel_mapping = {} # provide lookup of channels by name + self._channel_mapping: Dict[str, InstrumentChannel] = {} + # provide lookup of channels by name # If a list of channels is not provided, define a list to store # channels. This will eventually become a locked tuple. if chan_list is None: self._locked = False - self._channels = [] + self._channels: Union[List[InstrumentChannel],Tuple[InstrumentChannel, ...]] = [] else: self._locked = True self._channels = tuple(chan_list) @@ -238,7 +239,7 @@ def __add__(self, other: 'ChannelList'): "together.") return ChannelList(self._parent, self._name, self._chan_type, - self._channels + other._channels) + list(self._channels) + list(other._channels)) def append(self, obj: InstrumentChannel): """ @@ -248,7 +249,7 @@ def append(self, obj: InstrumentChannel): Args: obj(chan_type): New channel to add to the list. """ - if self._locked: + if (isinstance(self._channels, tuple) or self._locked): raise AttributeError("Cannot append to a locked channel list") if not isinstance(obj, self._chan_type): raise TypeError("All items in a channel list must be of the same " @@ -294,7 +295,7 @@ def insert(self, index: int, obj: InstrumentChannel): obj(chan_type): Object of type chan_type to insert. """ - if self._locked: + if (isinstance(self._channels, tuple) or self._locked): raise AttributeError("Cannot insert into a locked channel list") if not isinstance(obj, self._chan_type): raise TypeError("All items in a channel list must be of the same " @@ -324,7 +325,7 @@ def lock(self): self._channels = tuple(self._channels) self._locked = True - def snapshot_base(self, update: bool=False): + def snapshot_base(self, update: bool=False, params_to_skip_update: Optional[Sequence[str]]=None): """ State of the instrument as a JSON-compatible dict. @@ -368,30 +369,32 @@ def __getattr__(self, name: str): if isinstance(self._channels[0].parameters[name], MultiParameter): raise NotImplementedError("Slicing is currently not " "supported for MultiParameters") + parameters = cast(List[Union[Parameter, ArrayParameter]], + [chan.parameters[name] for chan in self._channels]) names = tuple("{}_{}".format(chan.name, name) for chan in self._channels) - labels = tuple(chan.parameters[name].label - for chan in self._channels) - units = tuple(chan.parameters[name].unit - for chan in self._channels) - - if isinstance(self._channels[0].parameters[name], ArrayParameter): - shapes = tuple(chan.parameters[name].shape for - chan in self._channels) - - if self._channels[0].parameters[name].setpoints: - setpoints = tuple(chan.parameters[name].setpoints for - chan in self._channels) - if self._channels[0].parameters[name].setpoint_names: - setpoint_names = tuple(chan.parameters[name].setpoint_names - for chan in self._channels) - if self._channels[0].parameters[name].setpoint_labels: + labels = tuple(parameter.label + for parameter in parameters) + units = tuple(parameter.unit + for parameter in parameters) + + if isinstance(parameters[0], ArrayParameter): + arrayparameters = cast(List[ArrayParameter],parameters) + shapes = tuple(parameter.shape for + parameter in arrayparameters) + if arrayparameters[0].setpoints: + setpoints = tuple(parameter.setpoints for + parameter in arrayparameters) + if arrayparameters[0].setpoint_names: + setpoint_names = tuple(parameter.setpoint_names for + parameter in arrayparameters) + if arrayparameters[0].setpoint_labels: setpoint_labels = tuple( - chan.parameters[name].setpoint_labels - for chan in self._channels) - if self._channels[0].parameters[name].setpoint_units: - setpoint_units = tuple(chan.parameters[name].setpoint_units - for chan in self._channels) + parameter.setpoint_labels + for parameter in arrayparameters) + if arrayparameters[0].setpoint_units: + setpoint_units = tuple(parameter.setpoint_units + for parameter in arrayparameters) else: shapes = tuple(() for _ in self._channels) @@ -427,7 +430,7 @@ def multi_func(*args, **kwargs): ''.format(self.__class__.__name__, name)) def __dir__(self) -> list: - names = super().__dir__() + names = list(super().__dir__()) if self._channels: names += list(self._channels[0].parameters.keys()) names += list(self._channels[0].functions.keys()) @@ -448,7 +451,7 @@ class ChannelListValidator(Validator): The channel list must be locked and populated before it can be used to construct a validator. """ - def __init__(self, channel_list: ChannelList): + def __init__(self, channel_list: ChannelList) -> None: # Save the base parameter list if not isinstance(channel_list, ChannelList): raise ValueError("channel_list must be a ChannelList object containing the " diff --git a/qcodes/instrument/ip_to_visa.py b/qcodes/instrument/ip_to_visa.py index c1eaf14ab6f..3578e791713 100644 --- a/qcodes/instrument/ip_to_visa.py +++ b/qcodes/instrument/ip_to_visa.py @@ -14,7 +14,7 @@ # Such a driver is just a two-line class definition. -class IPToVisa(VisaInstrument, IPInstrument): +class IPToVisa(VisaInstrument, IPInstrument): # type: ignore """ Class to inject an VisaInstrument like behaviour in an IPInstrument that we'd like to use as a VISAInstrument with the @@ -78,5 +78,5 @@ def close(self): self.remove_instance(self) -class AMI430_VISA(AMI430, IPToVisa): +class AMI430_VISA(AMI430, IPToVisa): # type: ignore pass diff --git a/qcodes/instrument/parameter.py b/qcodes/instrument/parameter.py index 51ec81b521c..015c4256031 100644 --- a/qcodes/instrument/parameter.py +++ b/qcodes/instrument/parameter.py @@ -59,7 +59,7 @@ import os import collections import warnings -from typing import Optional, Sequence, TYPE_CHECKING, Union, Callable, List +from typing import Optional, Sequence, TYPE_CHECKING, Union, Callable, List, Dict, Any, Sized from functools import partial, wraps import numpy @@ -149,6 +149,8 @@ class _BaseParameter(Metadatable, DeferredOperations): metadata (Optional[dict]): extra information to include with the JSON snapshot of the parameter """ + get_raw = None # type: Optional[Callable] + set_raw = None # type: Optional[Callable] def __init__(self, name: str, instrument: Optional['Instrument'], @@ -164,7 +166,7 @@ def __init__(self, name: str, snapshot_value: bool=True, max_val_age: Optional[float]=None, vals: Optional[Validator]=None, - delay: Optional[Union[int, float]]=None): + delay: Optional[Union[int, float]]=None) -> None: super().__init__(metadata) self.name = str(name) self._instrument = instrument @@ -203,14 +205,14 @@ def __init__(self, name: str, self._latest = {'value': None, 'ts': None, 'raw_value': None} self.get_latest = GetLatest(self, max_val_age=max_val_age) - if hasattr(self, 'get_raw'): + if hasattr(self, 'get_raw') and self.get_raw is not None: self.get = self._wrap_get(self.get_raw) elif hasattr(self, 'get'): warnings.warn('Wrapping get method, original get method will not ' 'be directly accessible. It is recommended to ' 'define get_raw in your subclass instead.' ) self.get = self._wrap_get(self.get) - if hasattr(self, 'set_raw'): + if hasattr(self, 'set_raw') and self.set_raw is not None: self.set = self._wrap_set(self.set_raw) elif hasattr(self, 'set'): warnings.warn('Wrapping set method, original set method will not ' @@ -253,7 +255,7 @@ def __call__(self, *args, **kwargs): ' Parameter {}'.format(self.name)) def snapshot_base(self, update: bool=False, - params_to_skip_update: Sequence[str]=None) -> dict: + params_to_skip_update: Sequence[str]=None) -> Dict[str, Any]: """ State of the parameter as a JSON-compatible dict. @@ -271,7 +273,7 @@ def snapshot_base(self, update: bool=False, and self._snapshot_value and update: self.get() - state = copy(self._latest) + state = copy(self._latest) # type: Dict[str, Any] state['__class__'] = full_class(self) state['full_name'] = str(self) @@ -280,7 +282,8 @@ def snapshot_base(self, update: bool=False, state.pop('raw_value', None) if isinstance(state['ts'], datetime): - state['ts'] = state['ts'].strftime('%Y-%m-%d %H:%M:%S') + dttime = state['ts'] # type: datetime + state['ts'] = dttime.strftime('%Y-%m-%d %H:%M:%S') for attr in set(self._meta_attrs): if attr == 'instrument' and self._instrument: @@ -418,9 +421,10 @@ def set_wrapper(value, **kwargs): return set_wrapper - def get_ramp_values(self, value: Union[float, int], + def get_ramp_values(self, value: Union[float, int, Sized], step: Union[float, int]=None) -> List[Union[float, - int]]: + int, + Sized]]: """ Return values to sweep from current value to target value. This method can be overridden to have a custom sweep behaviour. @@ -435,7 +439,7 @@ def get_ramp_values(self, value: Union[float, int], if step is None: return [value] else: - if isinstance(value, collections.Iterable) and len(value) > 1: + if isinstance(value, collections.Sized) and len(value) > 1: raise RuntimeError("Don't know how to step a parameter with more than one value") if self.get_latest() is None: self.get() @@ -495,7 +499,7 @@ def step(self, step: Union[int, float]): TypeError: if step is not a number """ if step is None: - self._step = step + self._step = step # type: Optional[Union[float, int]] elif not getattr(self.vals, 'is_numeric', True): raise TypeError('you can only step numeric parameters') elif not isinstance(step, (int, float)): @@ -739,9 +743,9 @@ def __init__(self, name: str, set_cmd: Optional[Union[str, Callable, bool]]=False, initial_value: Optional[Union[float, int, str]]=None, max_val_age: Optional[float]=None, - vals: Optional[str]=None, + vals: Optional[Validator]=None, docstring: Optional[str]=None, - **kwargs): + **kwargs) -> None: super().__init__(name=name, instrument=instrument, vals=vals, **kwargs) # Enable set/get methods if get_cmd/set_cmd is given @@ -753,16 +757,16 @@ def __init__(self, name: str, 'when max_val_age is set') self.get_raw = lambda: self._latest['raw_value'] else: - exec_str = instrument.ask if instrument else None - self.get_raw = Command(arg_count=0, cmd=get_cmd, exec_str=exec_str) + exec_str_ask = instrument.ask if instrument else None + self.get_raw = Command(arg_count=0, cmd=get_cmd, exec_str=exec_str_ask) self.get = self._wrap_get(self.get_raw) if not hasattr(self, 'set') and set_cmd is not False: if set_cmd is None: - self.set_raw = partial(self._save_val, validate=False) + self.set_raw = partial(self._save_val, validate=False)# type: Callable else: - exec_str = instrument.write if instrument else None - self.set_raw = Command(arg_count=1, cmd=set_cmd, exec_str=exec_str) + exec_str_write = instrument.write if instrument else None + self.set_raw = Command(arg_count=1, cmd=set_cmd, exec_str=exec_str_write)# type: Callable self.set = self._wrap_set(self.set_raw) self._meta_attrs.extend(['label', 'unit', 'vals']) @@ -918,7 +922,7 @@ def __init__(self, docstring: Optional[str]=None, snapshot_get: bool=True, snapshot_value: bool=False, - metadata: bool=None): + metadata: Optional[dict]=None) -> None: super().__init__(name, instrument, snapshot_get, metadata, snapshot_value=snapshot_value) @@ -1094,7 +1098,7 @@ def __init__(self, docstring: str=None, snapshot_get: bool=True, snapshot_value: bool=False, - metadata: Optional[dict]=None): + metadata: Optional[dict]=None) -> None: super().__init__(name, instrument, snapshot_get, metadata, snapshot_value=snapshot_value) @@ -1278,6 +1282,7 @@ def __init__(self, parameters, name, label=None, if unit is None: unit = units self.parameter.unit = unit + self.setpoints=[] # endhack self.parameters = parameters self.sets = [parameter.set for parameter in self.parameters] @@ -1302,7 +1307,7 @@ def set(self, index: int): setFunction(value) return values - def sweep(self, *array: numpy.ndarray): + def sweep(self, *array: numpy.ndarray) -> 'CombinedParameter': """ Creates a new combined parameter to be iterated over. One can sweep over either: @@ -1324,26 +1329,26 @@ def sweep(self, *array: numpy.ndarray): dim = set([len(a) for a in array]) if len(dim) != 1: raise ValueError('Arrays have different number of setpoints') - array = numpy.array(array).transpose() + nparray = numpy.array(array).transpose() else: # cast to array in case users # decide to not read docstring # and pass a 2d list - array = numpy.array(array[0]) + nparray = numpy.array(array[0]) new = copy(self) _error_msg = """ Dimensionality of array does not match\ the number of parameter combined. Expected a \ {} dimensional array, got a {} dimensional array. \ """ try: - if array.shape[1] != self.dimensionality: + if nparray.shape[1] != self.dimensionality: raise ValueError(_error_msg.format(self.dimensionality, - array.shape[1])) + nparray.shape[1])) except KeyError: # this means the array is 1d raise ValueError(_error_msg.format(self.dimensionality, 1)) - new.setpoints = array.tolist() + new.setpoints = nparray.tolist() return new def _aggregate(self, *vals): diff --git a/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py b/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py index a1654ea5b5d..eb00573be4f 100644 --- a/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py +++ b/qcodes/instrument_drivers/Keysight/KeysightAgilent_33XXX.py @@ -1,5 +1,6 @@ from functools import partial import logging +from typing import Union from qcodes import VisaInstrument, validators as vals from qcodes.instrument.channel import InstrumentChannel @@ -27,7 +28,7 @@ def __init__(self, parent: Instrument, name: str, channum: int) -> None: """ super().__init__(parent, name) - def val_parser(parser: type, inputstring: str) -> str: + def val_parser(parser: type, inputstring: str) -> Union[float,int]: """ Parses return values from instrument. Meant to be used when a query can return a meaningful finite number or a numeric representation diff --git a/qcodes/instrument_drivers/Keysight/M3201A.py b/qcodes/instrument_drivers/Keysight/M3201A.py index ee6ba180411..d207c732d9b 100644 --- a/qcodes/instrument_drivers/Keysight/M3201A.py +++ b/qcodes/instrument_drivers/Keysight/M3201A.py @@ -2,7 +2,7 @@ # from functools import partial # # from .SD_common.SD_Module import * -from .SD_common.SD_AWG import * +from .SD_common.SD_AWG import SD_AWG class Keysight_M3201A(SD_AWG): diff --git a/qcodes/instrument_drivers/Keysight/test_suite.py b/qcodes/instrument_drivers/Keysight/test_suite.py index db663f8d88f..cffc5f10a3f 100644 --- a/qcodes/instrument_drivers/Keysight/test_suite.py +++ b/qcodes/instrument_drivers/Keysight/test_suite.py @@ -1,16 +1,23 @@ from qcodes.instrument_drivers.test import DriverTestCase import unittest + try: from .M3201A import Keysight_M3201A + Keysight_M3201A_found = True +except ImportError: + Keysight_M3201A_found = False +try: from .M3300A import M3300A_AWG + M3300A_AWG_found = True +except ImportError: + M3300A_AWG_found = False +try: from .SD_common.SD_Module import SD_Module + SD_Module_found = True except ImportError: - SD_Module = None - Keysight_M3201A = None - M3300A_AWG = None + SD_Module_found = False - -@unittest.skipIf(not SD_Module, "SD_Module tests requires the keysightSD1 module") +@unittest.skipIf(not SD_Module_found, "SD_Module tests requires the keysightSD1 module") class TestSD_Module(DriverTestCase): """ Tis is a test suite for testing the general Keysight SD_Module driver. @@ -18,8 +25,8 @@ class TestSD_Module(DriverTestCase): This test suit is only used during the development of the general SD_Module driver. In a real-life scenario, no direct instances will be made from this class, but rather instances of either SD_AWG or SD_DIG. """ - - driver = SD_Module + if SD_Module_found: + driver = SD_Module @classmethod def setUpClass(cls): @@ -40,7 +47,7 @@ def test_chassis_and_slot(self): self.assertEqual(serial_number_test, serial_number) -@unittest.skipIf(not Keysight_M3201A, "Keysight_M3201A tests requires the keysightSD1 module") +@unittest.skipIf(not Keysight_M3201A_found, "Keysight_M3201A tests requires the keysightSD1 module") class TestKeysight_M3201A(DriverTestCase): """ This is a test suite for testing the Signadyne M3201A AWG card driver. @@ -59,8 +66,8 @@ class TestKeysight_M3201A(DriverTestCase): We can however test for ValueErrors which is a useful safety test. """ - - driver = Keysight_M3201A + if Keysight_M3201A_found: + driver = Keysight_M3201A @classmethod def setUpClass(cls): @@ -279,7 +286,7 @@ def test_PXI_trigger(self): self.instrument.pxi_trigger_number_0.set(cur_pxi) -@unittest.skipIf(not M3300A_AWG, "M3300A_AWG tests requires the keysightSD1 module") +@unittest.skipIf(not M3300A_AWG_found, "M3300A_AWG tests requires the keysightSD1 module") class TestKeysight_M3300A(DriverTestCase): """ This is a test suite for testing the Signadyne M3201A AWG card driver. @@ -298,8 +305,8 @@ class TestKeysight_M3300A(DriverTestCase): We can however test for ValueErrors which is a useful safety test. """ - - driver = M3300A_AWG + if M3300A_AWG_found: + driver = M3300A_AWG @classmethod def setUpClass(cls): diff --git a/qcodes/instrument_drivers/QuTech/D4.py b/qcodes/instrument_drivers/QuTech/D4.py index 7072f0682ec..262c87134a0 100644 --- a/qcodes/instrument_drivers/QuTech/D4.py +++ b/qcodes/instrument_drivers/QuTech/D4.py @@ -1,5 +1,4 @@ -from qcodes import Instrument - +from qcodes.instrument.base import Instrument try: from spirack import D4_module except ImportError: @@ -8,7 +7,6 @@ from functools import partial - class D4(Instrument): """ Qcodes driver for the D4 ADC SPI-rack module. Requires installation diff --git a/qcodes/instrument_drivers/QuTech/D5a.py b/qcodes/instrument_drivers/QuTech/D5a.py index 0499ec2c498..1d0008cf12a 100644 --- a/qcodes/instrument_drivers/QuTech/D5a.py +++ b/qcodes/instrument_drivers/QuTech/D5a.py @@ -1,4 +1,4 @@ -from qcodes import Instrument +from qcodes.instrument.base import Instrument from qcodes.utils.validators import Enum, Numbers try: diff --git a/qcodes/instrument_drivers/QuTech/F1d.py b/qcodes/instrument_drivers/QuTech/F1d.py index 07364b0e2be..4095334ccd2 100644 --- a/qcodes/instrument_drivers/QuTech/F1d.py +++ b/qcodes/instrument_drivers/QuTech/F1d.py @@ -1,4 +1,4 @@ -from qcodes import Instrument +from qcodes.instrument.base import Instrument from qcodes.utils.validators import Enum try: diff --git a/qcodes/instrument_drivers/QuTech/S5i.py b/qcodes/instrument_drivers/QuTech/S5i.py index e324e0d0411..163b6dd808e 100644 --- a/qcodes/instrument_drivers/QuTech/S5i.py +++ b/qcodes/instrument_drivers/QuTech/S5i.py @@ -1,4 +1,4 @@ -from qcodes import Instrument +from qcodes.instrument.base import Instrument from qcodes.utils.validators import Bool, Numbers try: diff --git a/qcodes/instrument_drivers/Spectrum/M4i.py b/qcodes/instrument_drivers/Spectrum/M4i.py index 9ff9929f376..6a3838f4e7e 100644 --- a/qcodes/instrument_drivers/Spectrum/M4i.py +++ b/qcodes/instrument_drivers/Spectrum/M4i.py @@ -31,9 +31,9 @@ sys.path.append(header_dir) import pyspcm except (ImportError, OSError) as ex: - log.exception(ex) - raise ImportError( - 'to use the M4i driver install the pyspcm module and the M4i libs') + info_str = 'to use the M4i driver install the pyspcm module and the M4i libs' + log.exception(info_str) + raise ImportError(info_str) #%% Helper functions diff --git a/qcodes/instrument_drivers/Spectrum/pyspcm.py b/qcodes/instrument_drivers/Spectrum/pyspcm.py index b278ca7d1fb..5d47b396d5c 100644 --- a/qcodes/instrument_drivers/Spectrum/pyspcm.py +++ b/qcodes/instrument_drivers/Spectrum/pyspcm.py @@ -1,6 +1,7 @@ import os import platform -import sys +import sys +from ctypes import c_int8, c_int16, c_int32, c_int64, c_uint8, c_uint16, c_uint32, c_uint64, c_char_p, POINTER, c_void_p, cdll from ctypes import * # load registers for easier access @@ -59,9 +60,9 @@ # Load DLL into memory. # use windll because all driver access functions use _stdcall calling convention under windows if (bIs64Bit == 1): - spcmDll = windll.LoadLibrary ("c:\\windows\\system32\\spcm_win64.dll") + spcmDll = windll.LoadLibrary ("c:\\windows\\system32\\spcm_win64.dll") # type: ignore else: - spcmDll = windll.LoadLibrary ("c:\\windows\\system32\\spcm_win32.dll") + spcmDll = windll.LoadLibrary ("c:\\windows\\system32\\spcm_win32.dll") # type: ignore # load spcm_hOpen if (bIs64Bit): diff --git a/qcodes/instrument_drivers/ZI/ZIUHFLI.py b/qcodes/instrument_drivers/ZI/ZIUHFLI.py index 06eac506e25..6cd8e98bff2 100644 --- a/qcodes/instrument_drivers/ZI/ZIUHFLI.py +++ b/qcodes/instrument_drivers/ZI/ZIUHFLI.py @@ -4,7 +4,7 @@ from functools import partial from math import sqrt -from typing import Callable, List, Union +from typing import Callable, List, Union, cast try: import zhinst.utils @@ -1243,7 +1243,7 @@ def __init__(self, name: str, device_ID: str, **kwargs) -> None: # A "manual" parameter: a list of the signals for the sweeper # to subscribe to - self._sweeper_signals = [] + self._sweeper_signals = [] # type: List[str] # This is the dictionary keeping track of the sweeper settings # These are the default settings @@ -1629,7 +1629,7 @@ def _get_demod_sample(self, number: int, demod_param: str) -> float: setting = 'sample' if demod_param not in ['x', 'y', 'R', 'phi']: raise RuntimeError("Invalid demodulator parameter") - datadict = self._getter(module, number, mode, setting) + datadict = cast(dict, self._getter(module, number, mode, setting)) datadict['R'] = np.abs(datadict['x'] + 1j * datadict['y']) datadict['phi'] = np.angle(datadict['x'] + 1j * datadict['y'], deg=True) return datadict[demod_param] diff --git a/qcodes/instrument_drivers/american_magnetics/AMI430.py b/qcodes/instrument_drivers/american_magnetics/AMI430.py index fd1a8c284c3..e643ab40c92 100644 --- a/qcodes/instrument_drivers/american_magnetics/AMI430.py +++ b/qcodes/instrument_drivers/american_magnetics/AMI430.py @@ -28,7 +28,7 @@ def check_enabled_decorator(self, *args, **kwargs): return f(self, *args, **kwargs) return check_enabled_decorator - def __init__(self, parent: 'AMI430'): + def __init__(self, parent: 'AMI430') -> None: super().__init__(parent, "SwitchHeater") # Add state parameters diff --git a/qcodes/instrument_drivers/devices.py b/qcodes/instrument_drivers/devices.py index 6c561781385..0544ea98c3b 100644 --- a/qcodes/instrument_drivers/devices.py +++ b/qcodes/instrument_drivers/devices.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, cast from qcodes import Parameter, Instrument @@ -73,7 +73,10 @@ def __init__(self, self._meta_attrs.extend(["division_value"]) def set_raw(self, value: Union[int, float]) -> None: - instrument_value = value * self.division_value + instrument_value = value * self.division_value # type: ignore + # disable type check due to https://github.com/python/mypy/issues/2128 + instrument_value = cast(Union[int, float], instrument_value) + self._save_val(value) self.v1.set(instrument_value) diff --git a/qcodes/instrument_drivers/ithaco/Ithaco_1211.py b/qcodes/instrument_drivers/ithaco/Ithaco_1211.py index 709a537b12c..83694539e6f 100644 --- a/qcodes/instrument_drivers/ithaco/Ithaco_1211.py +++ b/qcodes/instrument_drivers/ithaco/Ithaco_1211.py @@ -1,4 +1,4 @@ -from qcodes import Instrument +from qcodes.instrument.base import Instrument from qcodes.instrument.parameter import MultiParameter from qcodes.utils.validators import Enum, Bool diff --git a/qcodes/instrument_drivers/rohde_schwarz/ZNB.py b/qcodes/instrument_drivers/rohde_schwarz/ZNB.py index 10929beb8e5..b3a7af0934b 100644 --- a/qcodes/instrument_drivers/rohde_schwarz/ZNB.py +++ b/qcodes/instrument_drivers/rohde_schwarz/ZNB.py @@ -94,7 +94,7 @@ def get_raw(self): class ZNBChannel(InstrumentChannel): - def __init__(self, parent, name, channel, vna_parameter: str=None): + def __init__(self, parent, name, channel, vna_parameter: str=None) -> None: """ Args: parent: Instrument that this channel is bound to. @@ -372,10 +372,11 @@ class ZNB(VisaInstrument): TODO: - check initialisation settings and test functions """ + CHANNEL_CLASS = ZNBChannel - def __init__(self, name: str, address: str, init_s_params: bool=True, **kwargs): + def __init__(self, name: str, address: str, init_s_params: bool=True, **kwargs) -> None: super().__init__(name=name, address=address, **kwargs) @@ -384,7 +385,11 @@ def __init__(self, name: str, address: str, init_s_params: bool=True, **kwargs): # See page 1025 in the manual. 7.3.15.10 for details of max/min freq # no attempt to support ZNB40, not clear without one how the format # is due to variants - model = self.get_idn()['model'].split('-')[0] + fullmodel = self.get_idn()['model'] + if fullmodel is not None: + model = fullmodel.split('-')[0] + else: + raise RuntimeError("Could not determine ZNB model") # format seems to be ZNB8-4Port if model == 'ZNB4': self._max_freq = 4.5e9 @@ -395,6 +400,8 @@ def __init__(self, name: str, address: str, init_s_params: bool=True, **kwargs): elif model == 'ZNB20': self._max_freq = 20e9 self._min_freq = 100e3 + else: + raise RuntimeError("Unsupported ZNB model {}".format(model)) self.add_parameter(name='num_ports', get_cmd='INST:PORT:COUN?', get_parser=int) diff --git a/qcodes/instrument_drivers/stanford_research/SR560.py b/qcodes/instrument_drivers/stanford_research/SR560.py index 1fb9f03b3f2..b961859d71a 100644 --- a/qcodes/instrument_drivers/stanford_research/SR560.py +++ b/qcodes/instrument_drivers/stanford_research/SR560.py @@ -1,4 +1,4 @@ -from qcodes import Instrument +from qcodes.instrument.base import Instrument from qcodes.instrument.parameter import MultiParameter from qcodes.utils.validators import Bool, Enum diff --git a/qcodes/instrument_drivers/stanford_research/SR830.py b/qcodes/instrument_drivers/stanford_research/SR830.py index a707d729162..0e4447fb4f8 100644 --- a/qcodes/instrument_drivers/stanford_research/SR830.py +++ b/qcodes/instrument_drivers/stanford_research/SR830.py @@ -15,7 +15,7 @@ class ChannelBuffer(ArrayParameter): The instrument natively supports this in its TRCL call. """ - def __init__(self, name: str, instrument: 'SR830', channel: int): + def __init__(self, name: str, instrument: 'SR830', channel: int) -> None: """ Args: name (str): The name of the parameter diff --git a/qcodes/instrument_drivers/stanford_research/SR86x.py b/qcodes/instrument_drivers/stanford_research/SR86x.py index c5baebcef08..e3dab3547e7 100644 --- a/qcodes/instrument_drivers/stanford_research/SR86x.py +++ b/qcodes/instrument_drivers/stanford_research/SR86x.py @@ -1,7 +1,7 @@ import numpy as np import time import logging -from typing import Sequence, Dict +from typing import Optional, Sequence, Dict from qcodes import VisaInstrument from qcodes.instrument.channel import InstrumentChannel @@ -141,7 +141,6 @@ def __init__(self, parent: 'SR86x', name: str) ->None: ) self.bytes_per_sample = 4 - self._capture_data = dict() def snapshot_base(self, update: bool = False, params_to_skip_update: Sequence[str] = None) -> Dict: @@ -151,6 +150,7 @@ def snapshot_base(self, update: bool = False, # it can only be read after a completed capture and will # timeout otherwise when the snapshot is updated, e.g. at # station creation time + params_to_skip_update = list(params_to_skip_update) params_to_skip_update.append('count_capture_kilobytes') snapshot = super().snapshot_base(update, params_to_skip_update) diff --git a/qcodes/instrument_drivers/tektronix/AWG70000A.py b/qcodes/instrument_drivers/tektronix/AWG70000A.py index e5f0e3c7523..a69667ca0b0 100644 --- a/qcodes/instrument_drivers/tektronix/AWG70000A.py +++ b/qcodes/instrument_drivers/tektronix/AWG70000A.py @@ -7,7 +7,7 @@ import logging from functools import partial -from dateutil.tz import time +import time from typing import List, Sequence from qcodes import Instrument, VisaInstrument, validators as vals @@ -262,7 +262,7 @@ def setWaveform(self, name: str) -> None: if name not in self.root_instrument.waveformList: raise ValueError('No such waveform in the waveform list') - self.root_instrument.write(f'SOURce{channel}:CASSet:WAVeform "{name}"') + self.root_instrument.write(f'SOURce{self.channel}:CASSet:WAVeform "{name}"') def setSequenceTrack(self, seqname: str, tracknr: int) -> None: """ @@ -420,10 +420,10 @@ def waveformList(self) -> List[str]: """ Return the waveform list as a list of strings """ - resp = self.ask("WLISt:LIST?") - resp = resp.strip() - resp = resp.replace('"', '') - resp = resp.split(',') + respstr = self.ask("WLISt:LIST?") + respstr = respstr.strip() + respstr = respstr.replace('"', '') + resp = respstr.split(',') return resp @@ -798,7 +798,7 @@ def makeSEQXFile(trig_waits: Sequence[int], for ch in range(1, chans+1)] # generate wfmx files for the waveforms - flat_wfmxs = [] + flat_wfmxs = [] # type: List[bytes] for amplitude, wfm_lst in zip(amplitudes, wfms): flat_wfmxs += [AWG70000A.makeWFMXFile(wfm, amplitude) for wfm in wfm_lst] diff --git a/qcodes/instrument_drivers/tektronix/Keithley_2600_channels.py b/qcodes/instrument_drivers/tektronix/Keithley_2600_channels.py index 536e004ab32..8f3849b68eb 100644 --- a/qcodes/instrument_drivers/tektronix/Keithley_2600_channels.py +++ b/qcodes/instrument_drivers/tektronix/Keithley_2600_channels.py @@ -24,9 +24,8 @@ def __init__(self, name: str, instrument: Instrument) -> None: super().__init__(name=name, shape=(1,), - docstring='Holds a sweep') - - self._instrument = instrument + docstring='Holds a sweep', + instrument=instrument) def prepareSweep(self, start: float, stop: float, steps: int, mode: str) -> None: @@ -69,10 +68,13 @@ def prepareSweep(self, start: float, stop: float, steps: int, def get_raw(self) -> np.ndarray: - data = self._instrument._fast_sweep(self.start, - self.stop, - self.steps, - self.mode) + if self._instrument is not None: + data = self._instrument._fast_sweep(self.start, + self.stop, + self.steps, + self.mode) + else: + raise RuntimeError("No instrument attached to Parameter") return data diff --git a/qcodes/instrument_drivers/test.py b/qcodes/instrument_drivers/test.py index 1169792f513..1849c263ee4 100644 --- a/qcodes/instrument_drivers/test.py +++ b/qcodes/instrument_drivers/test.py @@ -1,5 +1,5 @@ import unittest - +from typing import Optional ''' This module defines: @@ -26,7 +26,8 @@ class DriverTestCase(unittest.TestCase): - driver = None # override this in a subclass + # override this in a subclass + driver = None # type: Optional[type] @classmethod def setUpClass(cls): diff --git a/qcodes/instrument_drivers/yokogawa/GS200.py b/qcodes/instrument_drivers/yokogawa/GS200.py index f3896ebe2cc..7667880f0d2 100644 --- a/qcodes/instrument_drivers/yokogawa/GS200.py +++ b/qcodes/instrument_drivers/yokogawa/GS200.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional +from typing import Optional, Union from qcodes import VisaInstrument, InstrumentChannel from qcodes.utils.validators import Numbers, Bool, Enum, Ints @@ -195,7 +195,7 @@ def __init__(self, name: str, address: str, terminator: str="\n", # We want to cache the range value so communication with the instrument only happens when the set the # range. Getting the range always returns the cached value. This value is adjusted when calling # self._set_range - self._cached_range_value = None # type: Optional[float] + self._cached_range_value = None # type: Optional[Union[float,int]] self.add_parameter('voltage_range', label='Voltage Source Range', @@ -382,6 +382,9 @@ def _set_output(self, output_level: float) -> None: if not auto_enabled: self_range = self._cached_range_value + if self_range is None: + raise RuntimeError("Trying to set output but not in" + " auto mode and range is unknown.") else: mode = self._cached_mode if mode == "CURR": @@ -396,6 +399,9 @@ def _set_output(self, output_level: float) -> None: # Update range self.range() self_range = self._cached_range_value + if self_range is None: + raise RuntimeError("Trying to set output but not in" + " auto mode and range is unknown.") # If we are still out of range, raise a value error if abs(output_level) > abs(self_range): raise ValueError("Desired output level not in range [-{self_range:.3}, {self_range:.3}]".format( diff --git a/qcodes/monitor/monitor.py b/qcodes/monitor/monitor.py index bbaf34552d7..b0585a780fa 100644 --- a/qcodes/monitor/monitor.py +++ b/qcodes/monitor/monitor.py @@ -22,8 +22,9 @@ from copy import deepcopy from threading import Thread -from typing import Dict +from typing import Dict, Any from asyncio import CancelledError +import functools import websockets @@ -32,14 +33,14 @@ log = logging.getLogger(__name__) -def _get_metadata(*parameters) -> Dict[float, list]: +def _get_metadata(*parameters) -> Dict[str, Any]: """ Return a dict that contains the parameter metadata grouped by the instrument it belongs to. """ ts = time.time() # group meta data by instrument if any - metas = {} + metas = {} # type: Dict for parameter in parameters: _meta = getattr(parameter, "_latest", None) if _meta: @@ -59,11 +60,11 @@ def _get_metadata(*parameters) -> Dict[float, list]: accumulator = metas.get(str(baseinst), []) accumulator.append(meta) metas[str(baseinst)] = accumulator - parameters = [] + parameters_out = [] for instrument in metas: temp = {"instrument": instrument, "parameters": metas[instrument]} - parameters.append(temp) - state = {"ts": ts, "parameters": parameters} + parameters_out.append(temp) + state = {"ts": ts, "parameters": parameters_out} return state @@ -167,7 +168,7 @@ def join(self, timeout=None) -> None: except RuntimeError as e: # the above may throw a runtime error if the loop is already # stopped in which case there is nothing more to do - log.exception(e) + log.exception("Could not close loop") while not self.loop_is_closed: log.debug("waiting for loop to stop and close") time.sleep(0.01) diff --git a/qcodes/plots/pyqtgraph.py b/qcodes/plots/pyqtgraph.py index 7faa43c71c6..05ea316b33f 100644 --- a/qcodes/plots/pyqtgraph.py +++ b/qcodes/plots/pyqtgraph.py @@ -1,12 +1,15 @@ """ Live plotting using pyqtgraph """ -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, Deque, List, cast import numpy as np import pyqtgraph as pg import pyqtgraph.multiprocess as pgmp + +from pyqtgraph.multiprocess.remoteproxy import ClosedError, ObjectProxy +from pyqtgraph.graphicsItems.PlotItem.PlotItem import PlotItem from pyqtgraph import QtGui -from pyqtgraph.multiprocess.remoteproxy import ClosedError + import qcodes.utils.helpers import warnings @@ -59,7 +62,8 @@ class QtPlot(BasePlot): # close event on win but this is difficult with remote proxy process # as the list of plots lives in the main process and the plot locally # in a remote process - plots = deque(maxlen=qcodes.config['gui']['pyqtmaxplots']) + max_len = qcodes.config['gui']['pyqtmaxplots'] # type: int + plots = deque(maxlen=max_len) # type: Deque['QtPlot'] def __init__(self, *args, figsize=(1000, 600), interval=0.25, window_title='', theme=((60, 60, 60), 'w'), show_window=True, @@ -99,8 +103,7 @@ def __init__(self, *args, figsize=(1000, 600), interval=0.25, self._orig_fig_size = figsize self.set_relative_window_position(fig_x_position, fig_y_position) - - self.subplots = [self.add_subplot()] + self.subplots = [self.add_subplot()] # type: List[Union[PlotItem, ObjectProxy]] if args or kwargs: self.add(*args, **kwargs) @@ -140,7 +143,7 @@ def clear(self): """ self.win.clear() self.traces = [] - self.subplots = [] + self.subplots = [] # type: List[Union[PlotItem, ObjectProxy]] def add_subplot(self): subplot_object = self.win.addPlot() @@ -514,7 +517,7 @@ def setGeometry(self, x, y, w, h): """ Set geometry of the plotting window """ self.win.setGeometry(x, y, w, h) - def autorange(self, reset_colorbar: bool=False): + def autorange(self, reset_colorbar: bool=False) -> None: """ Auto range all limits in case they were changed during interactive plot. Reset colormap if changed and resize window to original size. @@ -522,7 +525,11 @@ def autorange(self, reset_colorbar: bool=False): reset_colorbar: Should the limits and colorscale of the colorbar be reset. Off by default """ - for subplot in self.subplots: + # seem to be a bug in mypy but the type of self.subplots cannot be + # deducted even when typed above so ignore it and cast for now + subplots = self.subplots # type: ignore + subplots = cast(List[Union[PlotItem,ObjectProxy]], subplots) + for subplot in subplots: vBox = subplot.getViewBox() vBox.enableAutoRange(vBox.XYAxes) cmap = None @@ -558,7 +565,11 @@ def fixUnitScaling(self, startranges: Optional[Dict[str, Dict[str, Union[float,i axismapping = {'x': 'bottom', 'y': 'left'} standardunits = self.standardunits - for i, plot in enumerate(self.subplots): + # seem to be a bug in mypy but the type of self.subplots cannot be + # deducted even when typed above so ignore it and cast for now + subplots = self.subplots # type: ignore + subplots = cast(List[Union[PlotItem,ObjectProxy]], subplots) + for i, plot in enumerate(subplots): # make a dict mapping axis labels to axis positions for axis in ('x', 'y', 'z'): if self.traces[i]['config'].get(axis) is not None: diff --git a/qcodes/station.py b/qcodes/station.py index 7b6cd9a020b..ba8138b7807 100644 --- a/qcodes/station.py +++ b/qcodes/station.py @@ -1,4 +1,5 @@ """Station objects - collect all the equipment you use to do an experiment.""" +from typing import Dict, List, Optional, Sequence, Any from qcodes.utils.metadata import Metadatable from qcodes.utils.helpers import make_unique, DelegateAttributes @@ -25,7 +26,7 @@ class Station(Metadatable, DelegateAttributes): *components (list[Any]): components to add immediately to the Station. can be added later via self.add_component - monitor (None): Not implememnted, the object that monitors the system continuously + monitor (None): Not implemented, the object that monitors the system continuously default (bool): is this station the default, which gets used in Loops and elsewhere that a Station can be specified, default true @@ -40,10 +41,11 @@ class Station(Metadatable, DelegateAttributes): attributes of self """ - default = None + default = None # type: 'Station' - def __init__(self, *components, monitor=None, default=True, - update_snapshot=True, **kwargs): + def __init__(self, *components: Metadatable, + monitor: Any=None, default: bool=True, + update_snapshot: bool=True, **kwargs) -> None: super().__init__(**kwargs) # when a new station is defined, store it in a class variable @@ -55,21 +57,22 @@ def __init__(self, *components, monitor=None, default=True, if default: Station.default = self - self.components = {} + self.components = {} # type: Dict[str, Metadatable] for item in components: self.add_component(item, update_snapshot=update_snapshot) self.monitor = monitor - self.default_measurement = [] + self.default_measurement = [] # type: List - def snapshot_base(self, update=False): + def snapshot_base(self, update: bool=False, + params_to_skip_update: Sequence[str]=None) -> Dict: """ State of the station as a JSON-compatible dict. Args: update (bool): If True, update the state by querying the - all the childs: f.ex. instruments, parameters, components, etc. + all the children: f.ex. instruments, parameters, components, etc. If False, just use the latest values in memory. Returns: @@ -96,7 +99,8 @@ def snapshot_base(self, update=False): return snap - def add_component(self, component, name=None, update_snapshot=True): + def add_component(self, component: Metadatable, name: str=None, + update_snapshot: bool=True) -> str: """ Record one component as part of this Station. @@ -118,9 +122,9 @@ def add_component(self, component, name=None, update_snapshot=True): if name is None: name = getattr(component, 'name', 'component{}'.format(len(self.components))) - name = make_unique(str(name), self.components) - self.components[name] = component - return name + namestr = make_unique(str(name), self.components) + self.components[namestr] = component + return namestr def set_measurement(self, *actions): """ diff --git a/qcodes/tests/instrument_mocks.py b/qcodes/tests/instrument_mocks.py index c6b0518941e..9c60e8b4bd9 100644 --- a/qcodes/tests/instrument_mocks.py +++ b/qcodes/tests/instrument_mocks.py @@ -57,8 +57,6 @@ class MockMetaParabola(Instrument): ''' Test for a meta instrument, has a tunable gain knob ''' - # TODO (giulioungaretti) remove unneded shared_kwargs - shared_kwargs = ['mock_parabola_inst'] def __init__(self, name, mock_parabola_inst, **kw): super().__init__(name, **kw) diff --git a/qcodes/tests/test_helpers.py b/qcodes/tests/test_helpers.py index 40c011ccae7..b1d4546032c 100644 --- a/qcodes/tests/test_helpers.py +++ b/qcodes/tests/test_helpers.py @@ -104,7 +104,7 @@ def f_async_old(): class TestIsSequence(TestCase): - def a_func(): + def a_func(self): raise RuntimeError('this function shouldn\'t get called') class AClass(): diff --git a/qcodes/tests/test_location_provider.py b/qcodes/tests/test_location_provider.py index 43465bcc617..8dabea25b6c 100644 --- a/qcodes/tests/test_location_provider.py +++ b/qcodes/tests/test_location_provider.py @@ -23,11 +23,11 @@ def test_missing(self): def _default(time: datetime, formatter: FormatLocation, counter:str, name: str): date = time.strftime(formatter.fmt_date) - time = time.strftime(formatter.fmt_time) + mytime = time.strftime(formatter.fmt_time) fmted = formatter.formatter.format(formatter.default_fmt, date=date, counter=counter, - time=time, + time=mytime, name=name) return fmted diff --git a/qcodes/tests/test_parameter.py b/qcodes/tests/test_parameter.py index 8c3e8574cc7..dcfad9599b9 100644 --- a/qcodes/tests/test_parameter.py +++ b/qcodes/tests/test_parameter.py @@ -3,6 +3,7 @@ """ from collections import namedtuple from unittest import TestCase +from typing import Tuple from time import sleep import numpy as np @@ -45,7 +46,7 @@ def validate(self, value, context=''): None, # no instrument at all namedtuple('noname', '')(), # no .name namedtuple('blank', 'name')('') # blank .name -) +) # type: Tuple named_instrument = namedtuple('yesname', 'name')('astro') diff --git a/qcodes/tests/test_visa.py b/qcodes/tests/test_visa.py index ee4fd21414f..dd024f8c91f 100644 --- a/qcodes/tests/test_visa.py +++ b/qcodes/tests/test_visa.py @@ -31,17 +31,18 @@ class MockVisaHandle: ''' def __init__(self): self.state = 0 + self.closed = False def clear(self): self.state = 0 def close(self): # make it an error to ask or write after close - self.write = None - self.ask = None - self.query = None + self.closed = True def write(self, cmd): + if self.closed: + raise RuntimeError("Trying to write to a closed instrument") num = float(cmd.split(':')[-1]) self.state = num @@ -56,6 +57,8 @@ def write(self, cmd): return len(cmd), ret_code def ask(self, cmd): + if self.closed: + raise RuntimeError("Trying to ask a closed instrument") if self.state > 10: raise ValueError("I'm out of fingers") return self.state diff --git a/qcodes/utils/helpers.py b/qcodes/utils/helpers.py index 4472324cda0..9ade2ca6847 100644 --- a/qcodes/utils/helpers.py +++ b/qcodes/utils/helpers.py @@ -9,10 +9,11 @@ from collections import Iterator, Sequence, Mapping from copy import deepcopy +from typing import Dict, List import numpy as np -_tprint_times = {} +_tprint_times= {} # type: Dict[str, float] log = logging.getLogger(__name__) @@ -309,9 +310,9 @@ class DelegateAttributes: 2. keys of each dict in delegate_attr_dicts (in order) 3. attributes of each object in delegate_attr_objects (in order) """ - delegate_attr_dicts = [] - delegate_attr_objects = [] - omit_delegate_attrs = [] + delegate_attr_dicts = [] # type: List[str] + delegate_attr_objects = [] # type: List[str] + omit_delegate_attrs = [] # type: List[str] def __getattr__(self, key): if key in self.omit_delegate_attrs: @@ -503,13 +504,12 @@ def add_to_spyder_UMR_excludelist(modulename: str): if any('SPYDER' in name for name in os.environ): try: - from spyder.utils.site.sitecustomize import UserModuleReloader - global __umr__ + from spyder.utils.site import sitecustomize excludednamelist = os.environ.get('SPY_UMR_NAMELIST', '').split(',') if modulename not in excludednamelist: log.info("adding {} to excluded modules".format(modulename)) excludednamelist.append(modulename) - __umr__ = UserModuleReloader(namelist=excludednamelist) + sitecustomize.__umr__ = sitecustomize.UserModuleReloader(namelist=excludednamelist) except ImportError: pass diff --git a/qcodes/utils/validators.py b/qcodes/utils/validators.py index f8d7eeb9bf2..f559550713e 100644 --- a/qcodes/utils/validators.py +++ b/qcodes/utils/validators.py @@ -1,5 +1,5 @@ import math -from typing import Union, Tuple, cast +from typing import Union, Tuple, cast, Optional import numpy as np @@ -385,7 +385,7 @@ def __init__(self, divisor: Union[float, int, np.floating], self.precision = precision self._numval = Numbers() if isinstance(divisor, int): - self._mulval = Multiples(divisor=abs(divisor)) + self._mulval = Multiples(divisor=abs(divisor)) # type: Optional[Multiples] else: self._mulval = None self._valid_values = [divisor] @@ -407,7 +407,7 @@ def validate(self, value: Union[float, int, np.floating], # multiply our way out of the problem by constructing true # multiples in the relevant range and see if `value` is one # of them (within rounding errors) - divs = int(divmod(value, self.divisor)[0]) + divs = int(divmod(value, self.divisor)[0]) # type: ignore true_vals = np.array([n*self.divisor for n in range(divs, divs+2)]) abs_errs = [abs(tv-value) for tv in true_vals] if min(abs_errs) > self.precision: diff --git a/qcodes/utils/zmq_helpers.py b/qcodes/utils/zmq_helpers.py index 518f1362836..c0940199d3a 100644 --- a/qcodes/utils/zmq_helpers.py +++ b/qcodes/utils/zmq_helpers.py @@ -16,7 +16,7 @@ class UnboundedPublisher: def __init__(self, topic: str, interface_or_socket: str="tcp://localhost:5559", - context: zmq.Context = None): + context: zmq.Context = None) -> None: """ Args: @@ -49,7 +49,7 @@ class Publisher(UnboundedPublisher): def __init__(self, topic: str, interface_or_socket: str="tcp://localhost:5559", timeout: int = _LINGER*10, - hwm: int = _ZMQ_HWM*5, context: zmq.Context = None): + hwm: int = _ZMQ_HWM*5, context: zmq.Context = None) -> None: """ Args: diff --git a/test_requirements.txt b/test_requirements.txt index 3d4f33c91fe..b364fdd5bbb 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -3,6 +3,7 @@ pytest-cov pytest codacy-coverage hypothesis +mypy!=0.570 # due to https://github.com/python/mypy/issues/4674 git+https://github.com/QCoDeS/pyvisa-sim.git lxml codecov