Skip to content

Commit

Permalink
Merge threading and other numpy optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
johnarevalo committed Sep 13, 2023
2 parents 8d6ec0a + 55168c0 commit e80e36d
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 55 deletions.
26 changes: 26 additions & 0 deletions src/copairs/map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import logging
import re

import numpy as np
import pandas as pd
Expand All @@ -11,12 +12,36 @@
logger = logging.getLogger('copairs')


def evaluate_and_filter(df, columns) -> tuple:
    '''Resolve column names and query expressions against a dataframe.

    Each entry of `columns` is either the name of an existing column in
    `df` or a query expression understood by `DataFrame.query` (for
    example ``"dose > 10"``). Query expressions are applied to `df` in the
    order given, so the returned dataframe is filtered cumulatively.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe to validate against and to filter.
    columns : iterable of str
        Column names and/or query expressions.

    Returns
    -------
    tuple
        ``(filtered_df, parsed_cols)``: the dataframe after every query
        has been applied, and the list of column names that were either
        given directly or referenced by a query. A column may appear more
        than once if it is named by several entries.

    Raises
    ------
    ValueError
        If an entry is neither an existing column nor an expression that
        references at least one existing column, or if `DataFrame.query`
        rejects the expression. In the latter case the original error is
        attached as ``__cause__``.
    '''
    parsed_cols = []
    for col in columns:
        if col in df.columns:
            parsed_cols.append(col)
            continue

        # Not a plain column name: treat it as a query expression and
        # collect the identifiers that directly precede a comparison
        # operator.
        candidates = re.findall(r'(\w+)\s*[=<>!]+', col)
        valid_column_names = [
            name for name in candidates if name in df.columns
        ]
        if not valid_column_names:
            raise ValueError(f"Invalid query or column name: {col}")

        try:
            df = df.query(col)
        except Exception as err:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit propagate, and chain the
            # original error so the root cause stays visible.
            raise ValueError(f"Invalid query expression: {col}") from err
        parsed_cols.extend(valid_column_names)

    return df, parsed_cols


def flatten_str_list(*args):
'''create a single list with all the params given'''
columns = set()
for col in args:
if isinstance(col, str):
columns.add(col)
elif isinstance(col, dict):
columns.update(itertools.chain.from_iterable(col.values()))
else:
columns.update(col)
columns = list(columns)
Expand All @@ -30,6 +55,7 @@ def create_matcher(obs: pd.DataFrame,
neg_diffby,
multilabel_col=None):
columns = flatten_str_list(pos_sameby, pos_diffby, neg_sameby, neg_diffby)
obs, columns = evaluate_and_filter(obs, columns)
if multilabel_col:
return MatcherMultilabel(obs, columns, multilabel_col, seed=0)
return Matcher(obs, columns, seed=0)
Expand Down
Loading

0 comments on commit e80e36d

Please sign in to comment.