Skip to content

Commit

Permalink
add Result type
Browse files Browse the repository at this point in the history
  • Loading branch information
philopon committed Aug 16, 2017
1 parent b52520a commit 8d2e7b7
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 46 deletions.
3 changes: 3 additions & 0 deletions docs/mordred.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mordred package

.. automodule:: mordred

.. autoclass:: mordred.Result
:members:

.. autoclass:: mordred.Descriptor
:members:

Expand Down
2 changes: 2 additions & 0 deletions mordred/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Descriptor,
get_descriptors_from_module,
is_missing,
Result,
)

from ._version import __version__
Expand All @@ -17,4 +18,5 @@
"all_descriptors",
"get_descriptors_from_module",
"is_missing",
"Result",
)
43 changes: 3 additions & 40 deletions mordred/_base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""Mordred base package."""

import os
import warnings

from importlib import import_module
from ..error import MissingValueBase

from .descriptor import (
Expand All @@ -12,6 +8,8 @@
)
from .calculator import Calculator, get_descriptors_from_module
from .parallel import parallel
from .util import is_missing, all_descriptors
from .result import Result


__all__ = (
Expand All @@ -20,32 +18,10 @@
"Calculator",
"get_descriptors_from_module",
"is_missing",
"Result",
)


def all_descriptors():
r"""**[deprecated]** use mordred.descriptors module instead.
yield all descriptor modules.
:returns: all modules
:rtype: :py:class:`Iterator` (:py:class:`Descriptor`)
"""
warnings.warn(
"all_descriptors() is deprecated, use mordred.descriptors module instead",
DeprecationWarning,
stacklevel=2,
)
base_dir = os.path.dirname(os.path.dirname(__file__))

for name in os.listdir(base_dir):
name, ext = os.path.splitext(name)
if name[:1] == "_" or ext != ".py" or name == "descriptors":
continue

yield import_module(".." + name, __package__)


def _Descriptor__call__(self, mol, id=-1):
r"""Calculate single descriptor value.
Expand Down Expand Up @@ -116,19 +92,6 @@ def _Descriptor_from_json(self, obj):
return _from_json(obj, descs)


def is_missing(v):
"""Check argument is either MissingValue or not.
Parameters:
v(any): value
Returns:
bool
"""
return isinstance(v, MissingValueBase)


Descriptor.__call__ = _Descriptor__call__
Descriptor.from_json = _Descriptor_from_json
Calculator._parallel = parallel
14 changes: 10 additions & 4 deletions mordred/_base/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .._util import Capture, DummyBar, NotebookWrapper
from ..error import Error, Missing, MultipleFragments, DuplicatedDescriptorName
from .result import Result
from .context import Context
from .descriptor import Descriptor, MissingValueException, is_descriptor_class

Expand Down Expand Up @@ -235,16 +236,21 @@ def __call__(self, mol, id=-1):
:type id: int
:param id: conformer id
:rtype: [scalar or Error]
:rtype: Result[scalar or Error]
:returns: iterator of descriptor and value
"""
return list(self._calculate(Context.from_calculator(self, mol, id)))
return self._wrap_result(
self._calculate(Context.from_calculator(self, mol, id)),
)

def _wrap_result(self, r):
return Result(r, self._descriptors)

def _serial(self, mols, nmols, quiet, ipynb, id):
with self._progress(quiet, nmols, ipynb) as bar:
for m in mols:
with Capture() as capture:
r = list(self._calculate(Context.from_calculator(self, m, id)))
r = self._wrap_result(self._calculate(Context.from_calculator(self, m, id)))

for e in capture.result:
e = e.rstrip()
Expand Down Expand Up @@ -314,7 +320,7 @@ def map(self, mols, nproc=None, nmols=None, quiet=False, ipynb=False, id=-1):
id(int): conformer id to use. default: -1.
Returns:
Iterator[scalar]
Iterator[Result[scalar]]
"""
if hasattr(mols, "__len__"):
Expand Down
2 changes: 1 addition & 1 deletion mordred/_base/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def do_task(mol):

bar.write(e)

yield r
yield self._wrap_result(r)
bar.update()

finally:
Expand Down
63 changes: 63 additions & 0 deletions mordred/_base/result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from .util import is_missing


class Result(list):
r"""Result type."""

def __init__(self, r, d):
super(Result, self).__init__(r)
self._descriptors = d

def fillna(self, value=float('nan')):
r"""Replace missing value to 'value'.
Parameters:
value: value that missing value is replaced
Returns:
Result
"""
return self.__class__(
[(value if is_missing(v) else v) for v in self],
self._descriptors,
)

def dropna(self):
r"""Delete missing value.
Returns:
Result
"""
newvalues = []
newdescs = []
for v, d in zip(self, self._descriptors):
if not is_missing(v):
newvalues.append(v)
newdescs.append(d)

return self.__class__(newvalues, newdescs)

def asdict(self, rawkey=False):
r"""Convert Result to dict.
Parameters:
rawkey(bool):
* True: dict key is Descriptor instance
* False: dict key is str
Returns:
dict
"""
if rawkey:
def keyconv(k):
return k
else:
keyconv = str

return {
keyconv(k): v
for k, v in zip(self._descriptors, self)
}
41 changes: 41 additions & 0 deletions mordred/_base/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import warnings
from importlib import import_module

from ..error import MissingValueBase


def all_descriptors():
r"""**[deprecated]** use mordred.descriptors module instead.
yield all descriptor modules.
:returns: all modules
:rtype: :py:class:`Iterator` (:py:class:`Descriptor`)
"""
warnings.warn(
"all_descriptors() is deprecated, use mordred.descriptors module instead",
DeprecationWarning,
stacklevel=2,
)
base_dir = os.path.dirname(os.path.dirname(__file__))

for name in os.listdir(base_dir):
name, ext = os.path.splitext(name)
if name[:1] == "_" or ext != ".py" or name == "descriptors":
continue

yield import_module(".." + name, __package__)


def is_missing(v):
"""Check argument is either MissingValue or not.
Parameters:
v(any): value
Returns:
bool
"""
return isinstance(v, MissingValueBase)
2 changes: 1 addition & 1 deletion scripts/requirements-flake8.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pep8
flake8
flake8-isort
flake8-double-quotes
flake8-print
flake8-commas
Expand Down

0 comments on commit 8d2e7b7

Please sign in to comment.