Skip to content

Commit

Permalink
Simplify method signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed May 18, 2019
1 parent 79221d4 commit 009382d
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions chatterbot/logic/unit_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from chatterbot.conversation import Statement
from chatterbot import languages
from chatterbot import parsing
from pint import UnitRegistry
from mathparse import mathparse
import re

Expand All @@ -23,6 +22,7 @@ class UnitConversion(LogicAdapter):

def __init__(self, chatbot, **kwargs):
super().__init__(chatbot, **kwargs)
from pint import UnitRegistry

self.language = kwargs.get('language', languages.ENG)
self.cache = {}
Expand Down Expand Up @@ -63,34 +63,29 @@ def __init__(self, chatbot, **kwargs):
lambda m: self.handle_matches(m)
)
]
self.unit_registry = UnitRegistry()

def get_unit(self, ureg, unit_variations):
def get_unit(self, unit_variations):
"""
Get the first match unit metric object supported by pint library
given a variation of unit metric names (Ex:['HOUR', 'hour']).
:param ureg: unit registry which units are defined and handled
:type ureg: pint.registry.UnitRegistry object
:param unit_variations: A list of strings with names of units
:type unit_variations: str
"""
for unit in unit_variations:
try:
return getattr(ureg, unit)
return getattr(self.unit_registry, unit)
except Exception:
continue
return None

def get_valid_units(self, ureg, from_unit, target_unit):
def get_valid_units(self, from_unit, target_unit):
"""
Returns the firt match `pint.unit.Unit` object for from_unit and
target_unit strings from a possible variation of metric unit names
supported by pint library.
:param ureg: unit registry which units are defined and handled
:type ureg: `pint.registry.UnitRegistry`
:param from_unit: source metric unit
:type from_unit: str
Expand All @@ -99,8 +94,8 @@ def get_valid_units(self, ureg, from_unit, target_unit):
"""
from_unit_variations = [from_unit.lower(), from_unit.upper()]
target_unit_variations = [target_unit.lower(), target_unit.upper()]
from_unit = self.get_unit(ureg, from_unit_variations)
target_unit = self.get_unit(ureg, target_unit_variations)
from_unit = self.get_unit(from_unit_variations)
target_unit = self.get_unit(target_unit_variations)
return from_unit, target_unit

def handle_matches(self, match):
Expand All @@ -121,13 +116,12 @@ def handle_matches(self, match):

n = mathparse.parse(n_statement, self.language.ISO_639.upper())

ureg = UnitRegistry()
from_parsed, target_parsed = self.get_valid_units(ureg, from_parsed, target_parsed)
from_parsed, target_parsed = self.get_valid_units(from_parsed, target_parsed)

if from_parsed is None or target_parsed is None:
response.confidence = 0.0
else:
from_value = ureg.Quantity(float(n), from_parsed)
from_value = self.unit_registry.Quantity(float(n), from_parsed)
target_value = from_value.to(target_parsed)
response.confidence = 1.0
response.text = str(target_value.magnitude)
Expand Down

0 comments on commit 009382d

Please sign in to comment.