Merge pull request #428 from bkhant1/all_optis
Optimise `HashingEncoder` for both large and small dataframes
PaulWestenthanner authored Nov 11, 2023
2 parents 26ef261 + e2c1b79 commit 5c94e27
Showing 3 changed files with 97 additions and 107 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
unreleased
==========
* improved: performance of the hashing encoder (about twice as fast)
* deprecate the `max_sample` parameter; it no longer has any effect
* add `process_creation_method` parameter (usage sketched below)
* use concurrent.futures.ProcessPoolExecutor instead of hand-managed queues
* optimise hashlib calls, remove Python 2 checks, use fork instead of spawn by default

v2.6.3
======
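
A minimal usage sketch of the new interface (illustrative only, not part of this diff); it follows the parameter names in the updated docstring below:

import pandas as pd
import category_encoders as encoders

df = pd.DataFrame({"cat": ["a", "b", "a", "c"]})
# "fork" assumes Linux/macOS; on Windows the encoder falls back to "spawn" regardless
encoder = encoders.HashingEncoder(n_components=8, max_process=2,
                                  process_creation_method="fork")
encoder.fit(df)
encoded = encoder.transform(df)  # columns col_0 ... col_7 hold the hash-bucket counts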
185 changes: 78 additions & 107 deletions category_encoders/hashing.py
@@ -5,8 +5,10 @@
import category_encoders.utils as util
import multiprocessing
import pandas as pd
import numpy as np
import math
import platform
from concurrent.futures import ProcessPoolExecutor

__author__ = 'willmcginnis', 'LiuShulun'

@@ -56,6 +58,12 @@ class HashingEncoder(util.BaseEncoder, util.UnsupervisedTransformerMixin):
n_components: int
how many bits to use to represent the feature. By default, we use 8 bits.
For high-cardinality features, consider using up to 32 bits.
process_creation_method: string
either "fork", "spawn" or "forkserver" (availability depends on your
platform). See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
for more details and trade-offs. Defaults to "fork" on Linux/macOS, as it
is the fastest option, and to "spawn" on Windows, as it is the only one
available there.
Example
-------
@@ -103,12 +111,12 @@ class HashingEncoder(util.BaseEncoder, util.UnsupervisedTransformerMixin):
encoding_relation = util.EncodingRelation.ONE_TO_M

def __init__(self, max_process=0, max_sample=0, verbose=0, n_components=8, cols=None, drop_invariant=False,
return_df=True, hash_method='md5'):
return_df=True, hash_method='md5', process_creation_method='fork'):
super().__init__(verbose=verbose, cols=cols, drop_invariant=drop_invariant, return_df=return_df,
handle_unknown="does not apply", handle_missing="does not apply")

if max_process not in range(1, 128):
if platform.system == 'Windows':
if platform.system() == 'Windows':
self.max_process = 1
else:
self.max_process = int(math.ceil(multiprocessing.cpu_count() / 2))
@@ -119,7 +127,10 @@ def __init__(self, max_process=0, max_sample=0, verbose=0, n_components=8, cols=
else:
self.max_process = max_process
self.max_sample = int(max_sample)
self.auto_sample = max_sample <= 0
if platform.system() == 'Windows':
self.process_creation_method = "spawn"
else:
self.process_creation_method = process_creation_method
self.data_lines = 0
self.X = None

@@ -129,87 +140,7 @@ def __init__(self, max_process=0, max_sample=0, verbose=0, n_components=8, cols=
def _fit(self, X, y=None, **kwargs):
pass

def require_data(self, data_lock, new_start, done_index, hashing_parts, process_index):
is_finished = False
while not is_finished:
if data_lock.acquire():
if new_start.value:
end_index = 0
new_start.value = False
else:
end_index = done_index.value

if all([self.data_lines > 0, end_index < self.data_lines]):
start_index = end_index
if (self.data_lines - end_index) <= self.max_sample:
end_index = self.data_lines
else:
end_index += self.max_sample
done_index.value = end_index
data_lock.release()

data_part = self.X.iloc[start_index: end_index]
# Always get df and check it after merge all data parts
data_part = self.hashing_trick(X_in=data_part, hashing_method=self.hash_method,
N=self.n_components, cols=self.cols)
part_index = int(math.ceil(end_index / self.max_sample))
hashing_parts.put({part_index: data_part})
is_finished = end_index >= self.data_lines
if self.verbose == 5:
print(f"Process - {process_index} done hashing data : {start_index} ~ {end_index}")
else:
data_lock.release()
is_finished = True
else:
data_lock.release()

def _transform(self, X):
"""
Call _transform_single_cpu() if you want to use single CPU with all samples
"""
self.X = X

self.data_lines = len(self.X)

data_lock = multiprocessing.Manager().Lock()
new_start = multiprocessing.Manager().Value('d', True)
done_index = multiprocessing.Manager().Value('d', int(0))
hashing_parts = multiprocessing.Manager().Queue()

if self.auto_sample:
self.max_sample = int(self.data_lines / self.max_process)

if self.max_sample == 0:
self.max_sample = 1
if self.max_process == 1:
self.require_data(data_lock, new_start, done_index, hashing_parts, process_index=1)
else:
n_process = []
for thread_idx in range(self.max_process):
process = multiprocessing.Process(target=self.require_data,
args=(data_lock, new_start, done_index, hashing_parts, thread_idx + 1))
process.daemon = True
n_process.append(process)
for process in n_process:
process.start()
for process in n_process:
process.join()
data = self.X
if self.max_sample == 0 or self.max_sample == self.data_lines:
if hashing_parts:
data = list(hashing_parts.get().values())[0]
else:
list_data = {}
while not hashing_parts.empty():
list_data.update(hashing_parts.get())
sort_data = []
for part_index in sorted(list_data):
sort_data.append(list_data[part_index])
if sort_data:
data = pd.concat(sort_data)
return data

def _transform_single_cpu(self, X, override_return_df=False):
def _transform(self, X, override_return_df=False):
"""Perform the transformation to new categorical data.
Parameters
@@ -238,18 +169,66 @@ def _transform_single_cpu(self, X, override_return_df=False):
if not list(self.cols):
return X

X = self.hashing_trick(X, hashing_method=self.hash_method, N=self.n_components, cols=self.cols)

if self.drop_invariant:
X = X.drop(columns=self.invariant_cols)

if self.return_df or override_return_df:
return X
else:
return X.to_numpy()
X = self.hashing_trick(
X,
hashing_method=self.hash_method,
N=self.n_components,
cols=self.cols,
)

return X

@staticmethod
def hashing_trick(X_in, hashing_method='md5', N=2, cols=None, make_copy=False):
def hash_chunk(args):
    hash_method, np_df, N = args
    # Looking up the hashlib constructor once outside the loop saves some time in the loop
    hasher_constructor = getattr(hashlib, hash_method)
    # Binding int.from_bytes locally likewise avoids the implicit attribute lookup on each call
    int_from_bytes = int.from_bytes
    result = np.zeros((np_df.shape[0], N), dtype='int')
    for i, row in enumerate(np_df):
        for val in row:
            if val is not None:
                hasher = hasher_constructor()
                # Computes an integer index from the hasher digest. The byte order is
                # "big" because the code used to read:
                #   column_index = int(hasher.hexdigest(), 16) % N
                # which implicitly treats the hexdigest as big endian,
                # even on a little-endian system.
                # Building the index this way is about 30% faster than using the
                # hexdigest.
                hasher.update(bytes(str(val), 'utf-8'))
                column_index = int_from_bytes(hasher.digest(), byteorder='big') % N
                result[i, column_index] += 1
    return result
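
A quick standalone check (not part of this diff) of the equivalence the comment above describes — reading the digest bytes as a big-endian integer yields the same bucket index as the old hexdigest-based code:

import hashlib
N = 8
hasher = hashlib.md5()
hasher.update(bytes(str("example"), 'utf-8'))
assert int.from_bytes(hasher.digest(), byteorder='big') % N == int(hasher.hexdigest(), 16) % N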

def hashing_trick_with_np_parallel(self, df, N: int):
    np_df = df.to_numpy()
    ctx = multiprocessing.get_context(self.process_creation_method)

    with ProcessPoolExecutor(max_workers=self.max_process, mp_context=ctx) as executor:
        result = np.concatenate(list(
            executor.map(
                self.hash_chunk,
                zip(
                    [self.hash_method] * self.max_process,
                    np.array_split(np_df, self.max_process),
                    [N] * self.max_process
                )
            )
        ))

    return pd.DataFrame(result, index=df.index)
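
For reference (not part of this diff): np.array_split, unlike np.split, accepts row counts that are not a multiple of max_process and simply yields uneven chunks, so every worker gets a contiguous slice and the concatenated result keeps the original row order.

import numpy as np
chunks = np.array_split(np.arange(10).reshape(5, 2), 3)
print([c.shape for c in chunks])  # [(2, 2), (2, 2), (1, 2)]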

def hashing_trick_with_np_no_parallel(self, df, N):
    np_df = df.to_numpy()

    result = HashingEncoder.hash_chunk((self.hash_method, np_df, N))

    return pd.DataFrame(result, index=df.index)


def hashing_trick(self, X_in, hashing_method='md5', N=2, cols=None, make_copy=False):
"""A basic hashing implementation with configurable dimensionality/precision
Performs the hashing trick on a pandas dataframe, `X`, using the hashing method from hashlib
@@ -296,24 +275,16 @@ def hashing_trick(X_in, hashing_method='md5', N=2, cols=None, make_copy=False):
if cols is None:
cols = X.columns

def hash_fn(x):
tmp = [0 for _ in range(N)]
for val in x.array:
if val is not None:
hasher = hashlib.new(hashing_method)
if sys.version_info[0] == 2:
hasher.update(str(val))
else:
hasher.update(bytes(str(val), 'utf-8'))
tmp[int(hasher.hexdigest(), 16) % N] += 1
return tmp

new_cols = [f'col_{d}' for d in range(N)]

X_cat = X.loc[:, cols]
X_num = X.loc[:, [x for x in X.columns if x not in cols]]

X_cat = X_cat.apply(hash_fn, axis=1, result_type='expand')
if self.max_process == 1:
X_cat = self.hashing_trick_with_np_no_parallel(X_cat, N)
else:
X_cat = self.hashing_trick_with_np_parallel(X_cat, N)

X_cat.columns = new_cols

X = pd.concat([X_cat, X_num], axis=1)
14 changes: 14 additions & 0 deletions tests/test_hashing.py
@@ -42,3 +42,17 @@ def test_transform_works_with_single_row_df(self):
set(target_columns))
) == df_encoded_multi_process.shape[1]
)

def test_simple_example(self):
    df = pd.DataFrame({
        'strings': ["aaaa", "bbbb", "cccc"],
        "more_strings": ["aaaa", "dddd", "eeee"],
    })
    encoder = encoders.HashingEncoder(n_components=4, max_process=2)
    encoder.fit(df)
    assert encoder.transform(df).equals(pd.DataFrame({
        "col_0": [0, 1, 1],
        "col_1": [2, 0, 1],
        "col_2": [0, 1, 0],
        "col_3": [0, 0, 0]
    }))
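
A further test one might add (hypothetical, not part of this commit): because rows are split into ordered chunks and concatenated back with the original index, the single-process and multi-process paths should produce identical encodings.

def test_single_and_multi_process_agree(self):
    df = pd.DataFrame({'strings': ["aaaa", "bbbb", "cccc", "dddd"]})
    single = encoders.HashingEncoder(n_components=4, max_process=1).fit(df).transform(df)
    multi = encoders.HashingEncoder(n_components=4, max_process=2).fit(df).transform(df)
    assert single.equals(multi)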
