Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dataloader to provide new output structure #101

Merged
merged 16 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/01b-Getting-started-Pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@
" with torch.no_grad():\n",
" for batch in loader:\n",
" batch_size = batch[0]['rating'].shape[0]\n",
" loss += criterion(model(batch), batch[0]['rating'].squeeze(1)) * batch_size\n",
" loss += criterion(model(batch), batch[0]['rating']) * batch_size\n",
" n += batch_size\n",
" return loss.item() / n"
]
Expand Down Expand Up @@ -392,7 +392,7 @@
"model.train()\n",
"for batch in loader:\n",
"\n",
" loss = criterion(model(batch), batch[0]['rating'].squeeze(1).float())\n",
" loss = criterion(model(batch), batch[0]['rating'].float())\n",
"\n",
" # compute gradient and do an update step\n",
" optimizer.zero_grad()\n",
Expand Down
9 changes: 9 additions & 0 deletions merlin/dataloader/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,12 @@ def _cast_to_numpy_dtype(self, dtype):

def _to_sparse_tensor(self, values_offset, column_name):
raise NotImplementedError("Sparse support isn't implemented yet for the Jax dataloader")

def _row_lengths_to_offsets(self, row_lengths):
zero_value = jnp.array([0], dtype=row_lengths.dtype)
if len(row_lengths.shape) == 2:
zero_value = zero_value.reshape(-1, 1)
return jnp.concatenate([zero_value, jnp.cumsum(row_lengths, axis=0)], axis=0)

def _reshape_dim(self, tensor):
return jax.numpy.reshape(tensor, (-1,))
17 changes: 11 additions & 6 deletions merlin/dataloader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
make_df,
pull_apart_list,
)
from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node
from merlin.dag import BaseOperator, ColumnSelector, DictArray, Graph, Node, ungroup_values_offsets
from merlin.dag.executors import LocalExecutor
from merlin.io import shuffle_df
from merlin.schema import Schema, Tags
Expand Down Expand Up @@ -502,16 +502,21 @@ def _process_batch(self, tensors):
if isinstance(k, tuple):
values = self._tensor_split(v, len(k), axis=1)
for column_name, column_value in zip(k, values):
X[column_name] = column_value
X[column_name] = self._reshape_dim(column_value)
else:
if isinstance(v, tuple):
v = tuple(self._reshape_dim(tv) for tv in v)
else:
v = self._reshape_dim(v)
X[k] = v

X = ungroup_values_offsets(X)
for column_name in self.sparse_names:
if column_name in self.sparse_max:
# raise ValueError(
# f"Did not convert {column_name} to sparse due to missing sparse_max entry"
# )
X[column_name] = self._to_sparse_tensor(X[column_name], column_name)
tensor = (X[f"{column_name}__values"], X[f"{column_name}__offsets"])
X.pop(f"{column_name}__values")
X.pop(f"{column_name}__offsets")
X[column_name] = self._to_sparse_tensor(tensor, column_name)

# Return a tensor if we have only one label column, but return a
# dictionary of tensors if there are multiple label columns, since
Expand Down
4 changes: 2 additions & 2 deletions merlin/dataloader/ops/embeddings/torch_embedding_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Torch_NumpyEmbeddingOperator(NumpyEmbeddingOperator):
"""

def _format_embeddings(self, embeddings, keys):
return torch.from_numpy(embeddings).to(keys.device).squeeze(1)
return torch.from_numpy(embeddings).to(keys.device)


class Torch_MmapNumpyTorchEmbedding(MmapNumpyTorchEmbedding):
Expand All @@ -103,4 +103,4 @@ class Torch_MmapNumpyTorchEmbedding(MmapNumpyTorchEmbedding):
"""

def _format_embeddings(self, embeddings, keys):
return torch.from_numpy(embeddings).to(keys.device).squeeze(1)
return torch.from_numpy(embeddings).to(keys.device)
18 changes: 13 additions & 5 deletions merlin/dataloader/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class Loader(tf.keras.utils.Sequence, LoaderBase):
will usually contain fewer rows.
"""

_use_row_lengths = True
_use_row_lengths = False

def __init__(
self,
Expand Down Expand Up @@ -254,15 +254,20 @@ def _pull_values_offsets(self, values_offset):
diff_offsets = None
if isinstance(values_offset, tuple):
values = tf.reshape(values_offset[0], [-1])
diff_offsets = tf.cast(tf.reshape(values_offset[1], [-1]), dtype=tf.int64)
offsets = tf.math.cumsum(diff_offsets)
offsets = tf.reshape(values_offset[1], [-1])
else:
values = tf.reshape(values_offset, [-1])
offsets = tf.arange(tf.shape(values)[0], dtype=tf.int64)
diff_offsets = offsets[1:] - offsets[:-1]
num_rows = len(offsets)
diff_offsets = offsets[1:] - offsets[:-1]
return values, offsets, diff_offsets, num_rows

def _row_lengths_to_offsets(self, row_lengths):
zero_value = tf.constant([0], dtype=row_lengths.dtype)
if len(row_lengths.shape) == 2:
zero_value = tf.expand_dims(zero_value, axis=0)
return tf.concat([zero_value, tf.cumsum(row_lengths)], axis=0)

def _get_max_seq_len(self, diff_offsets):
# get_max_seq_len, return int
return int(tf.math.reduce_max(diff_offsets))
Expand All @@ -289,7 +294,7 @@ def _get_sparse_tensor(self, values, indices, num_rows, seq_limit):
def _build_sparse_tensor(
self, values, offsets, diff_offsets, num_rows, seq_limit, sparse_as_dense
):
ragged = tf.RaggedTensor.from_row_lengths(values=values, row_lengths=diff_offsets)
ragged = tf.RaggedTensor.from_row_splits(values=values, row_splits=offsets)
tensor = tf.RaggedTensor.from_tensor(ragged.to_tensor(shape=[None, seq_limit])).to_sparse()
if sparse_as_dense:
tensor = tf.sparse.to_dense(tensor)
Expand All @@ -309,6 +314,9 @@ def _cast_to_numpy_dtype(self, dtype):
"""
return dtype.as_numpy_dtype()

def _reshape_dim(self, tensor):
return tf.reshape(tensor, shape=[-1])


class KerasSequenceValidater(tf.keras.callbacks.Callback):
# TODO: document
Expand Down
5 changes: 4 additions & 1 deletion merlin/dataloader/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def _split_fn(self, tensor, idx, axis=0):
def _tensor_split(self, tensor, idx, axis=0):
return torch.tensor_split(tensor, idx, axis=axis)

def _reshape_dim(self, tensor):
return tensor.view(-1)

def _pull_values_offsets(self, values_offset):
# pull_values_offsets, return values offsets diff_offsets
if isinstance(values_offset, tuple):
Expand Down Expand Up @@ -162,7 +165,7 @@ def _sum(self, tensor):
return tensor.sum()

def _row_lengths_to_offsets(self, row_lengths):
zero_value = torch.tensor([0], device=self.device)
zero_value = torch.tensor([0], device=self.device, dtype=row_lengths.dtype)
if len(row_lengths.shape) == 2:
zero_value = zero_value.view(-1, 1)
return torch.cat((zero_value, torch.cumsum(row_lengths, 0)))
Expand Down
100 changes: 51 additions & 49 deletions tests/unit/dataloader/test_tf_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,12 @@ def test_nested_list():

df = pd.DataFrame(
{
"data": [
np.random.rand(np.random.randint(10) + 1, 3).tolist() for i in range(num_rows)
],
# ToDo: We deprioritized nested columns - it requires multiple offsets
# "data": [
# np.random.rand(np.random.randint(10) + 1, 3).tolist() for i in range(num_rows)
# ],
# keep a data field because we use index in the tests
"data": [np.random.rand() for i in range(num_rows)],
"data2": [np.random.rand(np.random.randint(10) + 1).tolist() for i in range(num_rows)],
"label": [np.random.rand() for i in range(num_rows)],
}
Expand All @@ -127,25 +130,24 @@ def test_nested_list():
)

batch = next(loader)

# [[1,2,3],[3,1],[...],[]]
@tf.function
def _ragged_for_nested_data_col():
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
return nested_data_col

nested_data_col = _ragged_for_nested_data_col()
true_data_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
)
# @tf.function
# def _ragged_for_nested_data_col():
# nested_data_col = tf.RaggedTensor.from_row_lengths(
# batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
# ).to_tensor()
# return nested_data_col

# nested_data_col = _ragged_for_nested_data_col()
# true_data_col = tf.reshape(
# tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
# )

# [1,2,3]
@tf.function
def _ragged_for_multihot_data_col():
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
multihot_data2_col = tf.RaggedTensor.from_row_splits(
batch[0]["data2__values"], tf.cast(batch[0]["data2__offsets"], tf.int32)
).to_tensor()
return multihot_data2_col

Expand All @@ -154,8 +156,8 @@ def _ragged_for_multihot_data_col():
tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(),
[batch_size, -1],
)
assert nested_data_col.shape == true_data_col.shape
assert np.allclose(nested_data_col.numpy(), true_data_col.numpy())
# assert nested_data_col.shape == true_data_col.shape
# assert np.allclose(nested_data_col.numpy(), true_data_col.numpy())
assert multihot_data2_col.shape == true_data2_col.shape
assert np.allclose(multihot_data2_col.numpy(), true_data2_col.numpy())

Expand Down Expand Up @@ -387,39 +389,32 @@ def test_mh_support(tmpdir, multihot_data, multihot_dataset, batch_size):
batch_size=batch_size,
shuffle=False,
)
row_lengths = None
offsets = None
idx = 0

for X, y in data_itr:
assert len(X) == 4
assert len(X) == 7
n_samples = y.shape[0]

for mh_name in ["Authors", "Reviewers", "Embedding"]:
# assert (mh_name) in X
array, row_lengths = X[mh_name]
row_lengths = row_lengths.numpy()[:, 0]
array = array.numpy()[:, 0]
array, offsets = X[f"{mh_name}__values"], X[f"{mh_name}__offsets"]
offsets = offsets.numpy()
array = array.numpy()
lens = [0]
cur = 0
for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]:
cur += len(x)
lens.append(cur)
assert (offsets == np.array(lens)).all()
assert len(array) == max(lens)

if mh_name == "Embedding":
assert (row_lengths == 3).all()
else:
lens = [
len(x)
for x in multihot_data[mh_name][idx * batch_size : idx * batch_size + n_samples]
]
assert (row_lengths == np.array(lens)).all()

if mh_name == "Embedding":
assert len(array) == (n_samples * 3)
else:
assert len(array) == sum(lens)
idx += 1
assert idx == (3 // batch_size + 1)


@pytest.mark.parametrize("batch_size", [1, 2, 4])
@pytest.mark.parametrize("batch_size", [128, 256])
def test_validater(tmpdir, batch_size):
n_samples = 9
n_samples = 10000
rand = np.random.RandomState(0)

df = make_df({"a": rand.randn(n_samples), "label": rand.randint(2, size=n_samples)})
Expand All @@ -434,18 +429,18 @@ def test_validater(tmpdir, batch_size):

input_ = tf.keras.Input(name="a", dtype=tf.float32, shape=(1,))
x = tf.keras.layers.Dense(128, "relu")(input_)
x = tf.keras.layers.Dense(1, activation="softmax")(x)
x = tf.keras.layers.Dense(1, activation="sigmoid")(x)

model = tf.keras.Model(inputs=input_, outputs=x)
model.compile("sgd", "binary_crossentropy", metrics=["accuracy", tf.keras.metrics.AUC()])

validater = tf_dataloader.KerasSequenceValidater(dataloader)
model.fit(dataloader, epochs=2, verbose=0, callbacks=[validater])
model.fit(dataloader, epochs=1, verbose=0, callbacks=[validater])

predictions, labels = [], []
for X, y_true in dataloader:
y_pred = model(X)
labels.extend(y_true.numpy()[:, 0])
labels.extend(y_true.numpy())
predictions.extend(y_pred.numpy()[:, 0])
predictions = np.array(predictions)
labels = np.array(labels)
Expand All @@ -455,12 +450,16 @@ def test_validater(tmpdir, batch_size):
auc_key = [i for i in logs if i.startswith("val_auc")][0]

true_accuracy = (labels == (predictions > 0.5)).mean()
print(true_accuracy)
estimated_accuracy = logs["val_accuracy"]
assert np.isclose(true_accuracy, estimated_accuracy, rtol=1e-6)
print(estimated_accuracy)
assert np.isclose(true_accuracy, estimated_accuracy, rtol=0.1)

true_auc = roc_auc_score(labels, predictions)
estimated_auc = logs[auc_key]
assert np.isclose(true_auc, estimated_auc, rtol=1e-6)
print(true_auc)
print(estimated_auc)
assert np.isclose(true_auc, estimated_auc, rtol=0.1)


@pytest.mark.parametrize("batch_size", [1, 10, 100])
Expand Down Expand Up @@ -539,12 +538,13 @@ def test_sparse_tensors(tmpdir, sparse_dense):
feats, labs = batch
for col in spa_lst:
# grab row lengths
feature_tensor = feats[f"{col}"]
if not sparse_dense:
feature_tensor = feats[f"{col}"]
assert list(feature_tensor.shape) == [batch_size, spa_mx[col]]
assert isinstance(feature_tensor, tf.sparse.SparseTensor)
else:
assert feature_tensor[1].shape[0] == batch_size
feature_tensor = feats[f"{col}__offsets"]
assert feature_tensor.shape[0] == batch_size + 1
assert not isinstance(feature_tensor, tf.sparse.SparseTensor)


Expand Down Expand Up @@ -702,12 +702,14 @@ def test_keras_model_with_multiple_label_columns():

inputs = tf.keras.Input(name="a", dtype=tf.float32, shape=(1,))
outputs = tf.keras.layers.Dense(16, "relu")(inputs)
output_1 = tf.keras.layers.Dense(1, activation="softmax", name="label1")(outputs)
output_1 = tf.keras.layers.Dense(2, activation="softmax", name="label1")(outputs)
output_2 = tf.keras.layers.Dense(5, activation="softmax", name="label2")(outputs)
# If we are using a Keras model and dataloader returns multiple labels,
# `outputs` keys must match the multiple labels returned by the dataloader.
model = tf.keras.Model(inputs=inputs, outputs={"label1": output_1, "label2": output_2})
model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=["accuracy"])
model.compile(
optimizer="sgd", loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"]
)
model.fit(loader, epochs=2)

preds_model = model.predict({"a": tf.constant([0.1, 0.2, 0.3])})
Expand Down
Loading