diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 642d0e427fa8c..5a5e1283e8e92 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -49,7 +49,7 @@ def parse_table_class(varname, o_main_program): if op.has_attr('table_class') and op.attr("table_class") != "none": return op.attr('table_class') else: - return "CommonSparseTable" + return "CtrSparseTable" class Accessor: diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index eb2c94b20106c..55fb59453481e 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -968,7 +968,7 @@ def sparse_embedding(input, padding_idx=None, is_test=False, entry=None, - table_class="CommonSparseTable", + table_class="CtrSparseTable", param_attr=None, dtype='float32'): helper = LayerHelper('sparse_embedding', **locals()) @@ -991,9 +991,12 @@ def sparse_embedding(input, padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( size[0] + padding_idx) - if table_class not in ["CommonSparseTable", "SSDSparseTable"]: + if table_class not in [ + "CommonSparseTable", "SSDSparseTable", "CtrSparseTable" + ]: raise ValueError( - "table_class must be in [CommonSparseTable, SSDSparseTable]") + "table_class must be in [CommonSparseTable, SSDSparseTable, CtrSparseTable]" + ) entry_str = "none"