
Commit 3a060ba

1. Add col_name as default args in feature_preprocess (Fix #105)
2. Update TransAct with adjust_mask

xpai committed Aug 12, 2024
1 parent 72667e3 commit 3a060ba
Showing 4 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fuxictr/datasets/criteo.py
@@ -22,7 +22,7 @@
 
 
 class CustomizedFeatureProcessor(FeatureProcessor):
-    def convert_to_bucket(self, col_name=None):
+    def convert_to_bucket(self, col_name):
         def _convert_to_bucket(value):
             if value > 2:
                 value = int(np.floor(np.log(value) ** 2))
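For context, a quick worked check of the bucketization shown above (not part of the commit): raw counts above 2 collapse onto floor(ln(v)**2) buckets, so large counts share coarse bucket ids.

import numpy as np

v = 100
bucket = int(np.floor(np.log(v) ** 2))  # ln(100) ≈ 4.605, squared ≈ 21.2
print(bucket)  # -> 21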
6 changes: 3 additions & 3 deletions fuxictr/datasets/kkbox.py
@@ -21,12 +21,12 @@
 
 
 class CustomizedFeatureProcessor(FeatureProcessor):
-    def extract_country_code(self, col_name=None):
+    def extract_country_code(self, col_name):
         return pl.col(col_name).apply(lambda isrc: isrc[0:2] if not pl.is_null(isrc) else "")
 
-    def bucketize_age(self, col_name=None):
+    def bucketize_age(self, col_name):
         def _bucketize(age):
-            if pd.isnull(age):
+            if pl.is_null(age):
                 return ""
             else:
                 age = float(age)
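With the call-site default introduced in this commit, a feature spec can omit the argument entirely; a hedged illustration (the spec dicts below are hypothetical, not taken from this commit):

# Both forms now resolve to the same call, extract_country_code("isrc"):
explicit = {"name": "isrc", "preprocess": "extract_country_code(isrc)"}
implicit = {"name": "isrc", "preprocess": "extract_country_code"}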
10 changes: 7 additions & 3 deletions fuxictr/preprocess/feature_processor.py
@@ -96,8 +96,12 @@ def preprocess(self, ddf):
             if col.get("preprocess"):
                 preprocess_args = re.split(r"\(|\)", col["preprocess"])
                 preprocess_fn = getattr(self, preprocess_args[0])
+                if len(preprocess_args) == 1:
+                    preprocess_args = [name]  # use col_name as args when not being explicitly set
+                else:
+                    preprocess_args = preprocess_args[1:-1]
                 ddf = ddf.with_columns(
-                    preprocess_fn(*preprocess_args[1:-1])
+                    preprocess_fn(*preprocess_args)
                     .alias(name)
                     .cast(self.dtype_dict[name])
                 )
@@ -364,5 +368,5 @@ def save_vocab(self, vocab_file):
         with open(vocab_file, "w") as fd:
             fd.write(json.dumps(vocab, indent=4))
 
-    def copy_from(self, col_name):
-        return pl.col(col_name)
+    def copy_from(self, src_col):
+        return pl.col(src_col)
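A minimal sketch of the new dispatch behavior (resolve_preprocess_args is a hypothetical helper written for illustration, not code from the repository):

import re

def resolve_preprocess_args(spec, name):
    parts = re.split(r"\(|\)", spec)   # "fn(arg)" -> ["fn", "arg", ""]
    if len(parts) == 1:                # "fn" -> ["fn"]: no explicit args given
        return parts[0], [name]        # fall back to the column's own name
    return parts[0], parts[1:-1]

assert resolve_preprocess_args("bucketize_age", "age") == ("bucketize_age", ["age"])
assert resolve_preprocess_args("copy_from(isrc)", "genre") == ("copy_from", ["isrc"])

The copy_from rename from col_name to src_col fits this scheme: its argument names a source column that may differ from the target feature, so it keeps an explicit argument rather than relying on the fallback.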
3 changes: 1 addition & 2 deletions model_zoo/TransAct/src/TransAct.py
@@ -220,7 +220,7 @@ def forward(self, target_emb, sequence_emb, time_interval_seq=None, mask=None):
         concat_seq_emb = torch.cat([sequence_emb,
                                     target_emb.unsqueeze(1).expand(-1, seq_len, -1)], dim=-1)
         # get sequence mask (1's are masked)
-        key_padding_mask = mask
+        key_padding_mask = self.adjust_mask(mask)  # keep the last dim
         if self.use_time_window_mask and self.training:
             rand_time_window_ms = random.randint(0, self.time_window_ms)
             time_window_mask = (time_interval_seq < rand_time_window_ms)
@@ -235,7 +235,6 @@ def forward(self, target_emb, sequence_emb, time_interval_seq=None, mask=None):
             output_concat.append(tfmr_out[:, -self.first_k_cols:].flatten(start_dim=1))
         if self.concat_max_pool:
             # Apply max pooling to the transformer output
-            key_padding_mask = self.adjust_mask(key_padding_mask)  # keep the last dim
             tfmr_out = tfmr_out.masked_fill(
                 key_padding_mask.unsqueeze(-1).repeat(1, 1, tfmr_out.shape[-1]), -1e9
             )
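The net effect of these two hunks is that adjust_mask now runs before the transformer call, so the key padding mask is already adjusted when it reaches both the attention layers and the max-pooling branch. The commit does not show adjust_mask itself; below is a plausible standalone sketch, assuming its job is to unmask the last position of fully padded rows so attention and masked_fill never see an all-masked sequence:

import torch

def adjust_mask(mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch. mask: (batch, seq_len) bool, True where padded.
    # If every position in a row is masked, unmask its last position so the
    # attention softmax cannot degenerate to NaN on that row.
    fully_masked = mask.all(dim=-1, keepdim=True)   # (batch, 1)
    last_pos = torch.zeros_like(mask)
    last_pos[:, -1] = True
    return mask & ~(fully_masked & last_pos)

The "keep the last dim" comment in the diff is consistent with that reading.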
