Skip to content

Commit

Permalink
Fix max_seq_length variable using the maximum number of baskets in tr…
Browse files Browse the repository at this point in the history
…aining sequences as default (#591)
  • Loading branch information
lthoang authored Jan 22, 2024
1 parent 85ce38c commit 07760ba
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions cornac/models/beacon/recom_beacon.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class Beacon(NextBasketRecommender):
Number of hops for constructing correlation matrix.
If 0, zeros matrix will be used.
max_seq_length: int, optional, default: None
Maximum basket sequence length.
If None, it is the maximum number of basket in training sequences.
n_epochs: int, optional, default: 15
Number of training epochs
Expand Down Expand Up @@ -83,6 +87,7 @@ def __init__(
rnn_cell_type="LSTM",
dropout_rate=0.5,
nb_hop=1,
max_seq_length=None,
n_epochs=15,
batch_size=32,
lr=0.001,
Expand All @@ -99,10 +104,12 @@ def __init__(
self.alpha = alpha
self.rnn_cell_type = rnn_cell_type
self.dropout_rate = dropout_rate
self.max_seq_length = max_seq_length
self.seed = seed
self.lr = lr

def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
import tensorflow.compat.v1 as tf

from .beacon_tf import BeaconModel
Expand All @@ -113,8 +120,12 @@ def fit(self, train_set, val_set=None):
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.logging.set_verbosity(tf.logging.ERROR)

super().fit(train_set=train_set, val_set=val_set)

# max sequence length
self.max_seq_length = (
max([len(bids) for bids in train_set.user_basket_data.values()])
if self.max_seq_length is None # init max_seq_length
else self.max_seq_length
)
self.correlation_matrix = self._build_correlation_matrix(
train_set=train_set, val_set=val_set, n_items=self.total_items
)
Expand All @@ -132,7 +143,7 @@ def fit(self, train_set, val_set=None):
self.emb_dim,
self.rnn_unit,
self.alpha,
train_set.max_basket_size,
self.max_seq_length,
self.total_items,
self.item_probs,
self.correlation_matrix,
Expand Down

0 comments on commit 07760ba

Please sign in to comment.