Remove implicit clone (#286)
Fixes #254 

Summary:
1. Only clone in `StypeEncoder` when `na_strategy` is specified and NA values
   are actually present; for a `MultiTensor`, clone only the `values` and reuse
   the existing `offset`.
yiweny authored Dec 20, 2023
1 parent 8fca093 commit 4a6eeed
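
As a quick illustration of the guarantee this change preserves — a minimal sketch mirroring the updated tests below, assuming a materialized dataset and a constructed numerical encoder as in `test_numerical_feature_encoder`:

```python
# Sketch only: `tensor_frame` and `encoder` are assumed to be set up as in
# test/nn/encoder/test_stype_encoder.py (materialized FakeDataset + encoder).
feat_num = tensor_frame.feat_dict[stype.numerical]
original = feat_num.clone()
x = encoder(feat_num, tensor_frame.col_names_dict[stype.numerical])
# Without an `na_strategy`, the encoder no longer clones its input, so the
# forward pass must leave the stored features untouched.
assert torch.allclose(original, tensor_frame.feat_dict[stype.numerical])
```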
Showing 4 changed files with 93 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

### Changed
- Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286))

### Deprecated

83 changes: 40 additions & 43 deletions test/nn/encoder/test_stype_encoder.py
@@ -1,11 +1,12 @@
import copy

import pytest
import torch
from torch.nn import ReLU

import torch_frame
from torch_frame import NAStrategy, stype
from torch_frame.config import ModelConfig
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.config.text_tokenizer import TextTokenizerConfig
from torch_frame.data.dataset import Dataset
from torch_frame.data.stats import StatType
@@ -22,7 +23,6 @@
StackEncoder,
TimestampEncoder,
)
from torch_frame.testing.text_embedder import HashTextEmbedder
from torch_frame.testing.text_tokenizer import (
RandomTextModel,
WhiteSpaceHashTokenizer,
@@ -44,10 +44,12 @@ def test_categorical_feature_encoder(encoder_cls_kwargs):
stype=stype.categorical,
**encoder_cls_kwargs[1],
)
feat_cat = tensor_frame.feat_dict[stype.categorical]
feat_cat = tensor_frame.feat_dict[stype.categorical].clone()
col_names = tensor_frame.col_names_dict[stype.categorical]
x = encoder(feat_cat, col_names)
assert x.shape == (feat_cat.size(0), feat_cat.size(1), 8)
# Make sure no in-place modification
assert torch.allclose(feat_cat, tensor_frame.feat_dict[stype.categorical])

# Perturb the first column
num_categories = len(stats_list[0][StatType.COUNT])
@@ -96,10 +98,12 @@ def test_numerical_feature_encoder(encoder_cls_kwargs):
stype=stype.numerical,
**encoder_cls_kwargs[1],
)
feat_num = tensor_frame.feat_dict[stype.numerical]
feat_num = tensor_frame.feat_dict[stype.numerical].clone()
col_names = tensor_frame.col_names_dict[stype.numerical]
x = encoder(feat_num, col_names)
assert x.shape == (feat_num.size(0), feat_num.size(1), 8)
# Make sure no in-place modification
assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical])
if "post_module" in encoder_cls_kwargs[1]:
assert encoder.post_module is not None
else:
@@ -142,9 +146,16 @@ def test_multicategorical_feature_encoder(encoder_cls_kwargs):
stype=stype.multicategorical,
**encoder_cls_kwargs[1],
)
feat_multicat = tensor_frame.feat_dict[stype.multicategorical]
feat_multicat = tensor_frame.feat_dict[stype.multicategorical].clone()
col_names = tensor_frame.col_names_dict[stype.multicategorical]
x = encoder(feat_multicat, col_names)
# Make sure no in-place modification
assert torch.allclose(
feat_multicat.values,
tensor_frame.feat_dict[stype.multicategorical].values)
assert torch.allclose(
feat_multicat.offset,
tensor_frame.feat_dict[stype.multicategorical].offset)
assert x.shape == (feat_multicat.size(0), feat_multicat.size(1), 8)

# Perturb the first column
@@ -178,9 +189,12 @@ def test_timestamp_feature_encoder(encoder_cls_kwargs):
stype=stype.timestamp,
**encoder_cls_kwargs[1],
)
feat_timestamp = tensor_frame.feat_dict[stype.timestamp]
feat_timestamp = tensor_frame.feat_dict[stype.timestamp].clone()
col_names = tensor_frame.col_names_dict[stype.timestamp]
x = encoder(feat_timestamp, col_names)
# Make sure no in-place modification
assert torch.allclose(feat_timestamp,
tensor_frame.feat_dict[stype.timestamp])
assert x.shape == (feat_timestamp.size(0), feat_timestamp.size(1), 8)


@@ -324,40 +338,6 @@ def test_timestamp_feature_encoder_with_nan(encoder_cls_kwargs):
assert (~torch.isnan(x)).all()


def test_text_embedded_encoder():
num_rows = 20
text_emb_channels = 10
out_channels = 5
dataset = FakeDataset(
num_rows=num_rows,
stypes=[
torch_frame.text_embedded,
],
col_to_text_embedder_cfg=TextEmbedderConfig(
text_embedder=HashTextEmbedder(text_emb_channels),
batch_size=None),
)
dataset.materialize()
tensor_frame = dataset.tensor_frame
stats_list = [
dataset.col_stats[col_name]
for col_name in tensor_frame.col_names_dict[stype.embedding]
]
encoder = LinearEmbeddingEncoder(
out_channels=out_channels,
stats_list=stats_list,
stype=stype.embedding,
)
feat_text = tensor_frame.feat_dict[stype.embedding]
col_names = tensor_frame.col_names_dict[stype.embedding]
feat = encoder(feat_text, col_names)
assert feat.shape == (
num_rows,
len(tensor_frame.col_names_dict[stype.embedding]),
out_channels,
)


def test_embedding_encoder():
num_rows = 20
out_channels = 5
@@ -378,9 +358,14 @@ def test_embedding_encoder():
stats_list=stats_list,
stype=stype.embedding,
)
feat_text = tensor_frame.feat_dict[stype.embedding]
feat_emb = tensor_frame.feat_dict[stype.embedding].clone()
col_names = tensor_frame.col_names_dict[stype.embedding]
x = encoder(feat_text, col_names)
x = encoder(feat_emb, col_names)
# Make sure no in-place modification
assert torch.allclose(feat_emb.values,
tensor_frame.feat_dict[stype.embedding].values)
assert torch.allclose(feat_emb.offset,
tensor_frame.feat_dict[stype.embedding].offset)
assert x.shape == (
num_rows,
len(tensor_frame.col_names_dict[stype.embedding]),
@@ -421,11 +406,23 @@ def test_text_tokenized_encoder():
stype=stype.text_tokenized,
col_to_model_cfg=col_to_model_cfg,
)
feat_text = tensor_frame.feat_dict[stype.text_tokenized]
feat_text = copy.deepcopy(tensor_frame.feat_dict[stype.text_tokenized])
col_names = tensor_frame.col_names_dict[stype.text_tokenized]
x = encoder(feat_text, col_names)
assert x.shape == (
num_rows,
len(tensor_frame.col_names_dict[stype.text_tokenized]),
out_channels,
)
# Make sure no in-place modification
assert isinstance(feat_text, dict) and isinstance(
tensor_frame.feat_dict[stype.text_tokenized], dict)
assert feat_text.keys() == tensor_frame.feat_dict[
stype.text_tokenized].keys()
for key in feat_text.keys():
assert torch.allclose(
feat_text[key].values,
tensor_frame.feat_dict[stype.text_tokenized][key].values)
assert torch.allclose(
feat_text[key].offset,
tensor_frame.feat_dict[stype.text_tokenized][key].offset)
4 changes: 2 additions & 2 deletions test/nn/models/test_compile.py
@@ -34,7 +34,7 @@
gamma=0.1,
),
None,
4,
7,
id="TabNet",
),
pytest.param(
@@ -54,7 +54,7 @@
Trompt,
dict(channels=8, num_prompts=2),
None,
11,
16,
id="Trompt",
),
pytest.param(
71 changes: 50 additions & 21 deletions torch_frame/nn/encoder/stype_encoder.py
@@ -37,6 +37,19 @@ def reset_parameters_soft(module: Module):
module.reset_parameters()


def get_na_mask(tensor: Tensor) -> Tensor:
r"""Obtains the Na maks of the input :obj:`Tensor`.
Args:
tensor (Tensor): Input :obj:`Tensor`.
"""
if tensor.is_floating_point():
na_mask = torch.isnan(tensor)
else:
na_mask = tensor == -1
return na_mask


class StypeEncoder(Module, ABC):
r"""Base class for stype encoder. This module transforms tensor of a
specific stype, i.e., `TensorFrame.feat_dict[stype.xxx]` into 3-dimensional
@@ -121,11 +134,6 @@ def forward(
f"The number of columns in feat and the length of "
f"col_names must match (got {num_cols} and "
f"{len(col_names)}, respectively.)")
# Clone the tensor to avoid in-place modification
if not isinstance(feat, dict):
feat = feat.clone()
else:
feat = {key: value.clone() for key, value in feat.items()}
# NaN handling of the input Tensor
feat = self.na_forward(feat)
# Main encoding into column embeddings
@@ -174,20 +182,36 @@ def na_forward(self, feat: TensorData) -> TensorData:
"""
if self.na_strategy is None:
return feat
for col in range(feat.size(1)):
column_data = feat[:, col]
if isinstance(feat, _MultiTensor):
column_data = column_data.values
if column_data.is_floating_point():
nan_mask = torch.isnan(column_data)

# Since we are not changing the number of items in each column, it's
# faster to just clone the values, while reusing the same offset
# object.
if isinstance(feat, Tensor):
if get_na_mask(feat).any():
feat = feat.clone()
else:
return feat
elif isinstance(feat, MultiEmbeddingTensor):
if get_na_mask(feat.values).any():
feat = MultiEmbeddingTensor(num_rows=feat.num_rows,
num_cols=feat.num_cols,
values=feat.values.clone(),
offset=feat.offset)
else:
return feat
elif isinstance(feat, MultiNestedTensor):
if get_na_mask(feat.values).any():
feat = MultiNestedTensor(num_rows=feat.num_rows,
num_cols=feat.num_cols,
values=feat.values.clone(),
offset=feat.offset)
else:
nan_mask = column_data == -1
if nan_mask.ndim == 2:
nan_mask = nan_mask.any(dim=-1)
assert nan_mask.ndim == 1
assert len(nan_mask) == len(column_data)
if not nan_mask.any():
continue
return feat
else:
raise ValueError(f"Unrecognized type {type(feat)} in na_forward.")

# TODO: Remove for-loop over columns
for col in range(feat.size(1)):
if self.na_strategy == NAStrategy.MOST_FREQUENT:
# Categorical index is sorted based on count,
# so 0-th index is always the most frequent.
Expand All @@ -210,7 +234,13 @@ def na_forward(self, feat: TensorData) -> TensorData:
if isinstance(feat, _MultiTensor):
feat.fillna_col(col, fill_value)
else:
column_data[nan_mask] = fill_value
column_data = feat[:, col]
na_mask = get_na_mask(column_data)
if na_mask.ndim == 2:
na_mask = na_mask.any(dim=-1)
assert na_mask.ndim == 1
assert len(na_mask) == len(column_data)
column_data[na_mask] = fill_value
# Add better safeguard here to make sure nans are actually
# replaced, especially when nans are represented as -1's. They are
# very hard to catch as they won't error out.
@@ -339,11 +369,10 @@ def encode_forward(
# Increment the index by one so that NaN index (-1) becomes 0
# (padding_idx)
# feat: [batch_size, num_cols]
feat.values = feat.values + 1
xs = []
for i, emb in enumerate(self.embs):
col_feat = feat[:, i]
xs.append(emb(col_feat.values, col_feat.offset[:-1]))
xs.append(emb(col_feat.values + 1, col_feat.offset[:-1]))
# [batch_size, num_cols, hidden_channels]
x = torch.stack(xs, dim=1)
return x
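For reference, a self-contained sketch of the NA convention encoded by the new `get_na_mask` helper above (reproduced here for illustration only; the in-tree version lives in `torch_frame/nn/encoder/stype_encoder.py`):

```python
import torch

def get_na_mask(tensor: torch.Tensor) -> torch.Tensor:
    # NaN marks missing floating-point entries; -1 marks missing
    # integer (categorical) entries.
    if tensor.is_floating_point():
        return torch.isnan(tensor)
    return tensor == -1

float_feat = torch.tensor([[1.0, float("nan")], [3.0, 4.0]])
int_feat = torch.tensor([[0, -1], [2, 1]])
assert get_na_mask(float_feat).tolist() == [[False, True], [False, False]]
assert get_na_mask(int_feat).tolist() == [[False, True], [False, False]]
```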
