Skip to content

Commit

Permalink
[CPU-PSLIB] Add consistency inspection of op's embedding name and sp…
Browse files Browse the repository at this point in the history
…arse table name in config_fleet.py, test=develop (#34249)
  • Loading branch information
WorgenZhang authored Jul 21, 2021
1 parent 038883f commit 2f76bb8
Showing 1 changed file with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib
import os
import logging

# Module-level alias for core.op_proto_and_checker_maker.OpRole.
OpRole = core.op_proto_and_checker_maker.OpRole
# This dict stores info about pull/push sparse ops.
Expand All @@ -41,6 +42,10 @@
"scale_sparse_grad": None,
}

# Configure the root logger: "<time> - <level> - <message>", INFO and above.
# NOTE(review): this runs at import time and touches the process-wide root
# logger (basicConfig does nothing if the root logger already has handlers);
# confirm that a library module is meant to do this.
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class DistributedOptimizerImplBase(object):
"""
Expand Down Expand Up @@ -300,6 +305,74 @@ def _generate_multi_dense_table(self,

return dense_tables, cond2denseid, lists_params, lists_grads, root_params_list, root_grads_list

def _gen_distributed_emb_to_size_dict(self, program):
    """Map each distributed embedding table name to its embedding size.

    Walks the ops of ``program.global_block()``. For every op whose type
    is listed in ``self.supported_embedding_types`` and whose
    ``is_distributed`` attribute is exactly ``True``, the first ``"W"``
    input is taken as the table name and the second dimension of that
    variable's shape (looked up in ``program.current_block().vars``) is
    recorded as its size.

    Args:
        program: program object exposing ``current_block()`` and
            ``global_block()``.

    Returns:
        dict: ``{table_name: emb_size}``.

    Raises:
        ValueError: if the same table name is seen with two different
            sizes.
    """
    block_vars = program.current_block().vars
    sizes = {}

    for op in program.global_block().ops:
        # Only distributed embedding lookups are of interest.
        if op.type not in self.supported_embedding_types:
            continue
        if op.attr('is_distributed') is not True:
            continue

        name = op.input("W")[0]
        width = block_vars[name].shape[1]
        known = sizes.get(name)
        if known is None:
            sizes[name] = width
        elif known != width:
            # The same table is used by ops that disagree on its size.
            raise ValueError("embedding size error: %s vs %s" %
                             (width, known))

    return sizes

def _check_config_fleet_with_program_op(self, strategy, table_name,
                                        emb_to_size):
    """Check or fill in ``sparse_embedx_dim`` for one sparse table.

    ``strategy[table_name]`` is created as an empty dict when missing or
    ``None``. The accessor class is read from its
    ``"sparse_accessor_class"`` entry and defaults to
    ``"DownpourCtrAccessor"``. Depending on the accessor, the expected
    ``sparse_embedx_dim`` is either ``emb_to_size[table_name] - 3`` or
    ``emb_to_size[table_name]``; any other accessor is left unchecked.
    A configured value must equal the expected one; a missing value is
    filled in and a warning is logged.

    NOTE(review): the offset of 3 is reproduced as-is; its meaning is not
    visible in this block.

    Args:
        strategy: dict of per-table config dicts; mutated in place.
        table_name: name of the sparse table to check.
        emb_to_size: dict mapping table name to embedding size.

    Returns:
        The same ``strategy`` object that was passed in.

    Raises:
        ValueError: if the configured ``sparse_embedx_dim`` disagrees
            with the size derived from ``emb_to_size``.
    """
    table_conf = strategy.get(table_name)
    if table_conf is None:
        table_conf = strategy[table_name] = {}

    accessor = table_conf.get("sparse_accessor_class")
    if accessor is None:
        accessor = "DownpourCtrAccessor"

    # Accessors for which embedx dim is the embedding size minus 3.
    minus_three_accessors = (
        "DownpourFeatureValueAccessor",
        "DownpourCtrAccessor",
        "DownpourDoubleUnitAccessor",
        "DownpourUnitAccessor", )
    unset_prefix = (
        "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
        "with same sparse table name is not set in config_fleet.py. ")

    # Fill in sparse_embedx_dim so the user need not set it in config_fleet.
    if accessor in minus_three_accessors:
        emb_size = emb_to_size[table_name]
        expected_dim = emb_size - 3
        configured = table_conf.get("sparse_embedx_dim")
        if configured is None:
            logger.warning(
                (unset_prefix +
                 "Hence automatically set sparse_embedx_dim = {} - 3.").format(
                     table_name, emb_size, emb_size))
            table_conf["sparse_embedx_dim"] = expected_dim
        elif configured != expected_dim:
            raise ValueError("fleet config sparse_embedx_dim=%s not"
                             " equal to embedding size - 3 = %s" %
                             (configured, expected_dim))
    elif accessor == "DownpourSparseValueAccessor":
        emb_size = emb_to_size[table_name]
        configured = table_conf.get("sparse_embedx_dim")
        if configured is None:
            logger.warning(
                (unset_prefix +
                 "Hence automatically set sparse_embedx_dim = {}.").format(
                     table_name, emb_size, emb_size))
            table_conf["sparse_embedx_dim"] = emb_size
        elif configured != emb_size:
            raise ValueError("fleet config sparse_embedx_dim=%s not"
                             " equal to embedding size = %s" %
                             (configured, emb_size))

    return strategy

def _minimize(self,
losses,
startup_program=None,
Expand Down Expand Up @@ -397,6 +470,10 @@ def _minimize(self,
sparse_table_to_index[tn] = sparse_table_index
sparse_table_index += 1

# get {table_name: emb_size} dict from program ops
emb_to_size = self._gen_distributed_emb_to_size_dict(
loss.block.program)

# get inputs_dict
inputs_dict = self._find_distributed_lookup_table_inputs(
loss.block.program, sparse_table)
Expand Down Expand Up @@ -511,8 +588,10 @@ def _minimize(self,
# ServerParameter add all sparse tables
for tn in sparse_table_to_index:
sparse_table_index = sparse_table_to_index[tn]
if strategy.get(tn) is not None:
server.add_sparse_table(sparse_table_index, strategy[tn])
st = self._check_config_fleet_with_program_op(strategy, tn,
emb_to_size)
if st.get(tn) is not None:
server.add_sparse_table(sparse_table_index, st[tn])
else:
server.add_sparse_table(sparse_table_index, None)

Expand Down

0 comments on commit 2f76bb8

Please sign in to comment.