Skip to content

Commit

Permalink
refactor: move helper functions to utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cowana-ai committed Oct 4, 2024
1 parent a9e28a9 commit 4dbbd76
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ast
import keyword
import re
from collections import deque
from typing import Any
Expand All @@ -8,55 +7,11 @@
from omegaconf import OmegaConf

import feature_fabrica.transform.registry as registry
from feature_fabrica._internal.instantiate.expressions.utils import (
CLOSE_PARENTHESIS, FUNCTION_PATTERN, OPEN_PARENTHESIS, TOKEN_PATTERN,
get_precedence, get_transformation, is_function, is_numeric, is_operator,
is_valid_variable_name)

# Operator table: maps each operator token to its precedence level and to the
# name of the transformation associated with it.
# '+', '-' and ',' share precedence 1; '*' and '/' use precedence 2.
BASIC_MATH_OPERATORS = {
'+': {'precedence': 1, 'transformation': 'SumReduce'},
'-': {'precedence': 1, 'transformation': 'SubtractReduce'},
',': {'precedence': 1, 'transformation': 'FeatureImporter'},
'*': {'precedence': 2, 'transformation': 'MultiplyReduce'},
'/': {'precedence': 2, 'transformation': 'DivideReduce'},
}
# Grouping tokens; they are deliberately not part of the operator table above.
OPEN_PARENTHESIS = "("
CLOSE_PARENTHESIS = ")"
# Method-style call: a leading dot, the captured function name, then the
# captured text between the first "(" and the last ")" (the ".*" is greedy).
FUNCTION_PATTERN = r'\.(\w+)\((.*)\)'
# Token alternatives, tried left to right at every position:
#   decimal | integer | name:stage | bare word | .func(args) | one of , ( ) + - * /
# The decimal and name:stage alternatives come before their shorter
# counterparts (integer, bare word) so that they are matched as one token.
# The .func(args) alternative stops at the first ")" because of "[^\)]*".
# NOTE: an earlier revision of this pattern had no name:stage alternative.
TOKEN_PATTERN = re.compile(r'\d+\.\d+|\d+|\b\w+:\w+\b|\b\w+\b|\.\w+\([^\)]*\)|[,()+\-*/]')

def is_operator(token: str) -> bool:
    """Return True when *token* is one of the keys of ``BASIC_MATH_OPERATORS``."""
    known_operator = token in BASIC_MATH_OPERATORS
    return known_operator

def get_precedence(token: str) -> int:
    """Look up the precedence level stored for an operator token.

    Raises:
        KeyError: if *token* is not a key of ``BASIC_MATH_OPERATORS``.
    """
    operator_entry = BASIC_MATH_OPERATORS[token]
    return operator_entry['precedence']  # type: ignore[return-value]

def get_transformation(op: str) -> str:
    """Look up the transformation name stored for an operator token.

    Raises:
        KeyError: if *op* is not a key of ``BASIC_MATH_OPERATORS``.
    """
    operator_entry = BASIC_MATH_OPERATORS[op]
    return operator_entry['transformation']  # type: ignore[return-value]

def is_valid_variable_name(name: str) -> bool:
    """Validate a variable reference used in an expression.

    Two forms are accepted:

    * a plain Python identifier that is not a reserved keyword;
    * ``<identifier>:<stage>``, where ``<identifier>`` passes the rule above
      and ``<stage>`` is present in ``TransformationRegistry.registry``.

    Anything containing more than one ``:`` is rejected.
    """
    if ":" not in name:
        return name.isidentifier() and not keyword.iskeyword(name)

    pieces = name.split(":")
    if len(pieces) != 2:
        return False

    base_name, transform_stage = pieces
    # The base part is checked first so the registry is only consulted for
    # references whose identifier part is already valid.
    if not is_valid_variable_name(base_name):
        return False
    return transform_stage in registry.TransformationRegistry.registry

def is_numeric(token: str) -> bool:
    """Report whether ``float(token)`` succeeds.

    Returns False when the conversion raises ``ValueError``; any other
    exception propagates to the caller.
    """
    try:
        float(token)
    except ValueError:
        return False
    return True

def is_function(token: str) -> bool:
    """Report whether *token* is, in its entirety, a ``.name(args)`` call.

    The pattern is matched against the whitespace-stripped token, but the
    matched text is compared with the token as given, so a token carrying
    leading or trailing whitespace is not treated as a function call.
    """
    found = re.match(FUNCTION_PATTERN, token.strip())
    if found is None:
        return False
    return found.group() == token

def tokenize(expression: str) -> list[str]:
"""Tokenize the feature-fabrica expression into numbers, variable names, operators, and functions.
Expand Down
60 changes: 60 additions & 0 deletions feature_fabrica/_internal/instantiate/expressions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import keyword
import re
from functools import lru_cache

import feature_fabrica.transform.registry as registry

# Operator table: maps each operator token to its precedence level and to the
# name of the transformation associated with it.
# '+', '-' and ',' share precedence 1; '*' and '/' use precedence 2.
BASIC_MATH_OPERATORS = {
'+': {'precedence': 1, 'transformation': 'SumReduce'},
'-': {'precedence': 1, 'transformation': 'SubtractReduce'},
',': {'precedence': 1, 'transformation': 'FeatureImporter'},
'*': {'precedence': 2, 'transformation': 'MultiplyReduce'},
'/': {'precedence': 2, 'transformation': 'DivideReduce'},
}
# Grouping tokens; they are deliberately not part of the operator table above.
OPEN_PARENTHESIS = "("
CLOSE_PARENTHESIS = ")"
# Method-style call: a leading dot, the captured function name, then the
# captured text between the first "(" and the last ")" (the ".*" is greedy).
FUNCTION_PATTERN = r'\.(\w+)\((.*)\)'
# Token alternatives, tried left to right at every position:
#   decimal | integer | name:stage | bare word | .func(args) | one of , ( ) + - * /
# The decimal and name:stage alternatives come before their shorter
# counterparts (integer, bare word) so that they are matched as one token.
# The .func(args) alternative stops at the first ")" because of "[^\)]*".
# NOTE: an earlier revision of this pattern had no name:stage alternative.
TOKEN_PATTERN = re.compile(r'\d+\.\d+|\d+|\b\w+:\w+\b|\b\w+\b|\.\w+\([^\)]*\)|[,()+\-*/]')

@lru_cache(maxsize=1024)
def is_operator(token: str) -> bool:
    """Return True when *token* is one of the keys of ``BASIC_MATH_OPERATORS``.

    Results are memoised per token (up to 1024 distinct tokens).
    """
    known_operator = token in BASIC_MATH_OPERATORS
    return known_operator

def get_precedence(token: str) -> int:
    """Look up the precedence level stored for an operator token.

    Raises:
        KeyError: if *token* is not a key of ``BASIC_MATH_OPERATORS``.
    """
    operator_entry = BASIC_MATH_OPERATORS[token]
    return operator_entry['precedence']  # type: ignore[return-value]

def get_transformation(op: str) -> str:
    """Look up the transformation name stored for an operator token.

    Raises:
        KeyError: if *op* is not a key of ``BASIC_MATH_OPERATORS``.
    """
    operator_entry = BASIC_MATH_OPERATORS[op]
    return operator_entry['transformation']  # type: ignore[return-value]

def _is_plain_identifier(name: str) -> bool:
    """Return True if *name* is a Python identifier that is not a keyword."""
    return name.isidentifier() and not keyword.iskeyword(name)


def is_valid_variable_name(name: str) -> bool:
    """Check whether *name* may be used as a variable reference.

    A name is accepted when it is either a plain Python identifier that is
    not a reserved keyword, or a promise value of the form
    ``<identifier>:<stage>`` (see ``is_valid_promise_value``).

    Not memoised: the promise-value branch depends on the current contents
    of the transformation registry, so a cached answer could go stale.
    """
    return _is_plain_identifier(name) or is_valid_promise_value(name)


def is_valid_promise_value(name: str) -> bool:
    """Check whether *name* has the form ``<identifier>:<stage>``.

    Returns True only when *name* contains exactly one ``:``, the part
    before it is a non-keyword Python identifier, and the part after it is
    present in ``TransformationRegistry.registry``.

    Not memoised: the registry is mutable global state, and caching would
    keep returning False for a stage that was registered after the first
    lookup.
    """
    base_name, separator, transform_stage = name.partition(":")
    # No colon at all, or a second colon in the remainder: not a promise value.
    if not separator or ":" in transform_stage:
        return False
    # Validate the identifier first so the registry is only consulted for
    # references whose base name is already well-formed.
    if not _is_plain_identifier(base_name):
        return False
    return transform_stage in registry.TransformationRegistry.registry

@lru_cache(maxsize=1024)
def is_numeric(token: str) -> bool:
    """Report whether ``float(token)`` succeeds.

    Returns False when the conversion raises ``ValueError``; any other
    exception propagates to the caller. Results are memoised per token
    (up to 1024 distinct tokens).
    """
    try:
        float(token)
    except ValueError:
        return False
    return True

@lru_cache(maxsize=1024)
def is_function(token: str) -> bool:
    """Report whether *token* is, in its entirety, a ``.name(args)`` call.

    The pattern is matched against the whitespace-stripped token, but the
    matched text is compared with the token as given, so a token carrying
    leading or trailing whitespace is not treated as a function call.
    Results are memoised per token (up to 1024 distinct tokens).
    """
    found = re.match(FUNCTION_PATTERN, token.strip())
    if found is None:
        return False
    return found.group() == token
6 changes: 4 additions & 2 deletions feature_fabrica/transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from easydict import EasyDict as edict
from omegaconf import OmegaConf

from feature_fabrica._internal.instantiate.expressions.utils import \
is_valid_promise_value
from feature_fabrica.models import PromiseValue
from feature_fabrica.promise_manager import get_promise_manager
from feature_fabrica.transform.registry import TransformationRegistry
Expand Down Expand Up @@ -51,10 +53,10 @@ def compile(self, feature_name: str, feature_dependencies: dict[str, Feature] |

# If cur_value is str and in features -> resolved immediately
if isinstance(cur_value, str):
if ":" in cur_value:
if is_valid_promise_value(cur_value):
dep_feature_name, transform_stage = cur_value.split(":")
if dep_feature_name not in feature_dependencies:
raise ValueError()
raise RuntimeError(f"Could not identify feature = {dep_feature_name}, make sure it's in feature dependencies!")
cur_value = promise_manager.get_promise_value(base_name=dep_feature_name, suffix=transform_stage)
elif cur_value in feature_dependencies:
cur_value = feature_dependencies[cur_value].feature_value
Expand Down

0 comments on commit 4dbbd76

Please sign in to comment.