Skip to content

Commit

Permalink
Merge branch 'main' into tf/continuous_seq_feats_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rnyak authored Feb 5, 2023
2 parents e16743d + c8d34ad commit 83e200f
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 22 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-horovod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
cd ${{ github.workspace }}; tox -e py38-horovod-cpu -- $branch
4 changes: 2 additions & 2 deletions .github/workflows/cpu-nvtabular.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
GIT_COMMIT=`git rev-parse HEAD` tox -e py38-nvtabular-cpu -- $branch
4 changes: 2 additions & 2 deletions .github/workflows/cpu-systems.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
GIT_COMMIT=`git rev-parse HEAD` tox -e py38-systems-cpu -- $branch
4 changes: 2 additions & 2 deletions .github/workflows/datasets.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/gpu-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
cd ${{ github.workspace }}; tox -e py38-gpu -- $branch
4 changes: 2 additions & 2 deletions .github/workflows/implicit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/lightfm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/pytorch.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

name: pytorch

on:
Expand Down Expand Up @@ -32,8 +33,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down Expand Up @@ -93,8 +93,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/xgboost.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
branch=main
if [[ $ref_type == "tag"* ]]
then
raw=$(git branch -r --contains ${{ github.ref_name }})
branch=${raw/origin\/}
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
pip install "pandas>=1.2.0,<1.4.0dev0"
pip install "NVTabular@git+https://github.com/NVIDIA-Merlin/NVTabular.git@$branch"
Expand Down
8 changes: 8 additions & 0 deletions merlin/models/tf/inputs/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def _get_dim(col, embedding_dims, infer_dim_fn):
return dim


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class AverageEmbeddingsByWeightFeature(tf.keras.layers.Layer):
def __init__(self, weight_feature_name: str, axis=1, **kwargs):
"""Computes the weighted average of a Tensor based
Expand Down Expand Up @@ -694,6 +695,13 @@ def from_schema_convention(schema: Schema, weight_features_name_suffix: str = "_

return seq_combiners

def get_config(self):
config = super().get_config()
config["axis"] = self.axis
config["weight_feature_name"] = self.weight_feature_name

return config


@dataclass
class EmbeddingOptions:
Expand Down

0 comments on commit 83e200f

Please sign in to comment.