Skip to content

Commit

Permalink
Merge pull request #1 from seemingwang/accessor_merge
Browse files Browse the repository at this point in the history
change CtrSparseTable to default sparse table
  • Loading branch information
zhaocaibei123 authored Aug 18, 2021
2 parents c34d8fb + a1a3b42 commit d916743
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/runtime/the_one_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parse_table_class(varname, o_main_program):
if op.has_attr('table_class') and op.attr("table_class") != "none":
return op.attr('table_class')
else:
return "CommonSparseTable"
return "CtrSparseTable"


class Accessor:
Expand Down
9 changes: 6 additions & 3 deletions python/paddle/fluid/contrib/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ def sparse_embedding(input,
padding_idx=None,
is_test=False,
entry=None,
table_class="CommonSparseTable",
table_class="CtrSparseTable",
param_attr=None,
dtype='float32'):
helper = LayerHelper('sparse_embedding', **locals())
Expand All @@ -991,9 +991,12 @@ def sparse_embedding(input,
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx)

if table_class not in ["CommonSparseTable", "SSDSparseTable"]:
if table_class not in [
"CommonSparseTable", "SSDSparseTable", "CtrSparseTable"
]:
raise ValueError(
"table_class must be in [CommonSparseTable, SSDSparseTable]")
"table_class must be in [CommonSparseTable, SSDSparseTable, CtrSparseTable]"
)

entry_str = "none"

Expand Down

0 comments on commit d916743

Please sign in to comment.