From 07760ba9b7ad93699a1d43f7ac4eaec3b4fb3d7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AA=20Trung=20Ho=C3=A0ng?= Date: Mon, 22 Jan 2024 10:07:04 +0700 Subject: [PATCH] Fix max_seq_length variable using the maximum number of baskets in training sequences as default (#591) --- cornac/models/beacon/recom_beacon.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cornac/models/beacon/recom_beacon.py b/cornac/models/beacon/recom_beacon.py index 1838a2e3..93d1b061 100644 --- a/cornac/models/beacon/recom_beacon.py +++ b/cornac/models/beacon/recom_beacon.py @@ -51,6 +51,10 @@ class Beacon(NextBasketRecommender): Number of hops for constructing correlation matrix. If 0, zeros matrix will be used. + max_seq_length: int, optional, default: None + Maximum basket sequence length. + If None, it is the maximum number of basket in training sequences. + n_epochs: int, optional, default: 15 Number of training epochs @@ -83,6 +87,7 @@ def __init__( rnn_cell_type="LSTM", dropout_rate=0.5, nb_hop=1, + max_seq_length=None, n_epochs=15, batch_size=32, lr=0.001, @@ -99,10 +104,12 @@ def __init__( self.alpha = alpha self.rnn_cell_type = rnn_cell_type self.dropout_rate = dropout_rate + self.max_seq_length = max_seq_length self.seed = seed self.lr = lr def fit(self, train_set, val_set=None): + super().fit(train_set=train_set, val_set=val_set) import tensorflow.compat.v1 as tf from .beacon_tf import BeaconModel @@ -113,8 +120,12 @@ def fit(self, train_set, val_set=None): os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.logging.set_verbosity(tf.logging.ERROR) - super().fit(train_set=train_set, val_set=val_set) - + # max sequence length + self.max_seq_length = ( + max([len(bids) for bids in train_set.user_basket_data.values()]) + if self.max_seq_length is None # init max_seq_length + else self.max_seq_length + ) self.correlation_matrix = self._build_correlation_matrix( train_set=train_set, val_set=val_set, n_items=self.total_items ) @@ -132,7 +143,7 @@ def fit(self, train_set, val_set=None): self.emb_dim, self.rnn_unit, self.alpha, - train_set.max_basket_size, + self.max_seq_length, self.total_items, self.item_probs, self.correlation_matrix,