[bugfix]: fix slow node crash in hive predictor (#220)
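In distributed prediction, every worker slice used to finalize the shared Hive output on its own, so a slow node could still be writing its part file when the output table was created and loaded. With this change, each slice writes an empty SUCCESS-<slice_id> marker after closing its part file; only slice 0 creates the Hive table and loads the results, and it first polls until every slice's marker exists. Output is staged in a *_tmp HDFS directory and moved into a managed table with LOAD DATA INPATH, replacing the external table that pointed at the live output directory. Table reads are now sliced by a seeded rand() predicate instead of hash(concat(...)) over user-supplied columns, so the hash_fields option is removed from HiveConfig and the sample configs.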
dawn310826 authored Jun 21, 2022
1 parent 7192f99 commit ccacb30
Showing 5 changed files with 53 additions and 38 deletions.
35 changes: 26 additions & 9 deletions easy_rec/python/inference/predictor.py
@@ -395,7 +395,7 @@ def out_of_range_exception(self):
   def _write_line(self, table_writer, outputs):
     pass
 
-  def load_to_table(self, output_path):
+  def load_to_table(self, output_path, slice_num, slice_id):
     pass
 
   def _get_fg_json(self, fg_json_path, model_path):
@@ -519,6 +519,7 @@ def _parse_value(all_vals):
                                               self._output_cols, all_vals,
                                               outputs)
         outputs = [x for x in zip(*reserve_vals)]
+        logging.info('predict size: %s' % len(outputs))
         self._write_line(table_writer, outputs)
 
       ts3 = time.time()
@@ -536,7 +537,7 @@ def _parse_value(all_vals):
       logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' %
                    (sum_t0, sum_t1, sum_t2))
       table_writer.close()
-    self.load_to_table(output_path)
+    self.load_to_table(output_path, slice_num, slice_id)
     logging.info('Predict %s done.' % input_path)
 
   def predict(self, input_data_dict_list, output_names=None, batch_size=1):
@@ -886,7 +887,7 @@ def _get_writer(self, output_path, slice_id):
     assert not is_exist, '%s is already exists. Please drop it.' % output_path
 
     output_path = output_path.replace('.', '/')
-    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s' % (self._hive_config.host, output_path)
+    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (self._hive_config.host, output_path)
     if not gfile.Exists(self._hdfs_path):
       gfile.MakeDirs(self._hdfs_path)
     res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
@@ -903,7 +904,21 @@ def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
         [all_vals[k] for k in reserved_cols]
     return reserve_vals
 
-  def load_to_table(self, output_path):
+  def load_to_table(self, output_path, slice_num, slice_id):
+    res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
+    success_writer = gfile.GFile(res_path, 'w')
+    success_writer.write('')
+
+    if slice_id != 0:
+      return
+
+    for id in range(slice_num):
+      res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
+      while not gfile.Exists(res_path):
+        time.sleep(10)
+
+    table_name, partition_name, partition_val = self.get_table_info(
+        output_path)
     schema = ''
     for output_col_name in self._output_cols:
       tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
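The hunk above is the core of the fix: a file-based barrier. Below is a minimal standalone sketch of the same pattern — tf.io.gfile stands in for the gfile module the predictor uses, and the function name and poll_secs parameter are illustrative, not part of the commit:

```python
import os
import time

import tensorflow as tf


def wait_for_all_slices(hdfs_path, slice_num, slice_id, poll_secs=10):
  # Every worker drops an empty SUCCESS-<slice_id> marker once its
  # part-<slice_id>.csv is fully written.
  marker = os.path.join(hdfs_path, 'SUCCESS-%s' % slice_id)
  with tf.io.gfile.GFile(marker, 'w') as f:
    f.write('')

  # Only slice 0 finalizes the Hive table; every other slice returns
  # immediately instead of racing on the shared output.
  if slice_id != 0:
    return False

  # Slice 0 blocks until every worker's marker exists, so slow nodes can
  # finish writing before the table is created and loaded.
  for i in range(slice_num):
    marker = os.path.join(hdfs_path, 'SUCCESS-%s' % i)
    while not tf.io.gfile.exists(marker):
      time.sleep(poll_secs)
  return True
```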
@@ -914,11 +929,10 @@ def load_to_table(self, output_path):
       assert output_col_name in self._all_cols, 'Column: %s not exists.' % output_col_name
       idx = self._all_cols.index(output_col_name)
       output_col_types = self._all_col_types[idx]
-      schema += output_col_name + ' ' + output_col_types + ','
+      if output_col_name != partition_name:
+        schema += output_col_name + ' ' + output_col_types + ','
     schema = schema.rstrip(',')
 
-    table_name, partition_name, partition_val = self.get_table_info(
-        output_path)
     if partition_name and partition_val:
       sql = "create table if not exists %s (%s) PARTITIONED BY (%s string)" % \
             (table_name, schema, partition_name)
@@ -927,8 +941,11 @@
             (self._hdfs_path, table_name, partition_name, partition_val)
       self._hive_util.run_sql(sql)
     else:
-      sql = "create external table if not exists %s (%s) location '%s'" % \
-            (table_name, schema, self._hdfs_path)
+      sql = "create table if not exists %s (%s)" % \
+            (table_name, schema)
+      self._hive_util.run_sql(sql)
+      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
+            (self._hdfs_path, table_name)
       self._hive_util.run_sql(sql)
 
   @property
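For unpartitioned output, finalization now issues two statements instead of declaring an external table over the staging directory. A hedged sketch of the SQL pair, with made-up host, table, and schema names:

```python
# Illustrative values; in the predictor these come from get_table_info()
# and the schema loop above.
hdfs_path = 'hdfs://nn-host:9000/user/easy_rec/test_db/test_out_tmp'  # staging dir
table_name = 'test_db.test_out'
schema = 'user_id string, probs double'

create_sql = 'create table if not exists %s (%s)' % (table_name, schema)
# LOAD DATA INPATH moves (not copies) the part files into the table's
# warehouse directory, so the *_tmp staging dir is drained on success.
load_sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % (hdfs_path, table_name)
print(create_sql)
print(load_sql)
```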
8 changes: 3 additions & 5 deletions easy_rec/python/protos/hive_config.proto
@@ -9,17 +9,15 @@ message HiveConfig {
   required uint32 port = 2 [default = 10000];
 
   // hive username
-  required string username = 3;
+  required string username = 3 [default = 'admin'];
 
   // hive database
   required string database = 4 [default = 'default'];
 
   required string table_name = 5;
 
-  required string hash_fields = 6;
+  optional uint32 limit_num = 6 [default = 0];
 
-  optional uint32 limit_num = 7 [default = 0];
-
-  required uint32 fetch_size = 8 [default = 512];
+  required uint32 fetch_size = 7 [default = 512];
 
 }
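Note that limit_num and fetch_size are renumbered (7 → 6, 8 → 7) to close the gap left by hash_fields; that is harmless for the text-format configs EasyRec reads, though it would break any binary-serialized HiveConfig. A small sketch of the new defaults, assuming the usual protoc-generated module path:

```python
from easy_rec.python.protos.hive_config_pb2 import HiveConfig

conf = HiveConfig()
print(conf.username)    # 'admin' (new default)
print(conf.limit_num)   # 0
print(conf.fetch_size)  # 512
```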
44 changes: 23 additions & 21 deletions easy_rec/python/utils/hive_utils.py
@@ -16,7 +16,6 @@ def __init__(self,
                tablename,
                selected_cols,
                partition_kv,
-               hash_fields,
                limit_num,
                batch_size=16,
                task_index=0,
@@ -25,7 +24,6 @@
     self.tablename = tablename
     self.selected_cols = selected_cols
     self.partition_kv = partition_kv
-    self.hash_fields = hash_fields
     self.limit_num = limit_num
     self.task_index = task_index
     self.task_num = task_num
@@ -41,19 +39,15 @@ def gen_sql(self):
       part = ' '.join(res)
     sql = """select {}
     from {}""".format(self.selected_cols, self.tablename)
-    assert self.hash_fields is not None, 'hash_fields must not be empty'
-    fields = [
-        'cast({} as string)'.format(key) for key in self.hash_fields.split(',')
-    ]
-    str_fields = ','.join(fields)
 
     if not part:
       sql += """
-      where hash(concat({}))%{}={}
-      """.format(str_fields, self.task_num, self.task_index)
+      where CAST((rand(1) * {}) AS BIGINT) = {}
+      """.format(self.task_num, self.task_index)
     else:
       sql += """
-      where {} and hash(concat({}))%{}={}
-      """.format(part, str_fields, self.task_num, self.task_index)
+      where {} and CAST((rand(1) * {}) AS BIGINT) = {}
+      """.format(part, self.task_num, self.task_index)
     if self.limit_num is not None and self.limit_num > 0:
       sql += ' limit {}'.format(self.limit_num)
     return sql
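The slicing predicate changes from hash(concat(cast(c1 as string), ...)) % task_num = task_index, which needed a user-chosen hash_fields list, to a seeded rand() bucket that needs no column names — the fixed seed is what is assumed to keep bucket assignment consistent across the workers' identical scans. A standalone sketch of the SQL gen_sql now emits (function, table, and column names are made up; whitespace simplified):

```python
def gen_slice_sql(selected_cols, tablename, task_num, task_index,
                  part=None, limit_num=None):
  # Mirrors TableInfo.gen_sql after this change: every worker issues the
  # same scan with its own task_index and keeps only its rand(1) bucket.
  sql = 'select {} from {}'.format(selected_cols, tablename)
  pred = 'CAST((rand(1) * {}) AS BIGINT) = {}'.format(task_num, task_index)
  sql += ' where {}'.format(pred if not part else '{} and {}'.format(part, pred))
  if limit_num:
    sql += ' limit {}'.format(limit_num)
  return sql


print(gen_slice_sql('age, marital_status', 'census_income_train_simple', 4, 0))
# Prints (as one line):
# select age, marital_status from census_income_train_simple where CAST((rand(1) * 4) AS BIGINT) = 0
```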
@@ -84,7 +78,7 @@ def __init__(self,
     self._selected_cols = selected_cols
     self._record_defaults = record_defaults
 
-  def _construct_table_info(self, table_name, hash_fields, limit_num):
+  def _construct_table_info(self, table_name, limit_num):
     # sample_table/dt=2014-11-23/name=a
     segs = table_name.split('/')
     table_name = segs[0].strip()
@@ -94,7 +88,7 @@ def _construct_table_info(self, table_name, hash_fields, limit_num):
       partition_kv = None
 
     table_info = TableInfo(table_name, self._selected_cols, partition_kv,
-                           hash_fields, limit_num, self._data_config.batch_size,
+                           limit_num, self._data_config.batch_size,
                            self._task_index, self._task_num, self._num_epoch)
     return table_info

@@ -120,7 +114,6 @@ def hive_read(self, input_path):
 
     for table_path in input_path.split(','):
       table_info = self._construct_table_info(table_path,
-                                              self._hive_config.hash_fields,
                                               self._hive_config.limit_num)
       batch_size = self._this_batch_size
       batch_defaults = []
@@ -160,8 +153,8 @@ def hive_read(self, input_path):
         conn.close()
       logging.info('finish epoch[%d]' % self._num_epoch_record)
 
-  def hive_read_line(self, input_path, hash_fields, limit_num=None):
-    table_info = self._construct_table_info(input_path, hash_fields, limit_num)
+  def hive_read_line(self, input_path, limit_num=None):
+    table_info = self._construct_table_info(input_path, limit_num)
     conn = self._construct_hive_connect()
     cursor = conn.cursor()
     sql = table_info.gen_sql()
@@ -189,13 +182,22 @@ def is_table_or_partition_exist(self,
                                   partition_val=None):
     if partition_name and partition_val:
       sql = 'show partitions %s partition(%s=%s)' % (table_name, partition_name, partition_val)
+      try:
+        res = self.run_sql(sql)
+        if not res:
+          return False
+        else:
+          return True
+      except:
+        return False
+
     else:
       sql = 'desc %s' % table_name
-    try:
-      self.run_sql(sql)
-      return True
-    except:
-      return False
+      try:
+        self.run_sql(sql)
+        return True
+      except:
+        return False
 
   def get_all_cols(self, input_path):
     conn = self._construct_hive_connect()
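The existence check is split per branch because show partitions on a missing partition can succeed with an empty result rather than raising, while desc still raises for a missing table. A condensed sketch of the new control flow, with run_sql passed in as a stand-in for the class method:

```python
def is_table_or_partition_exist(run_sql, table_name,
                                partition_name=None, partition_val=None):
  if partition_name and partition_val:
    sql = 'show partitions %s partition(%s=%s)' % (
        table_name, partition_name, partition_val)
    try:
      return bool(run_sql(sql))  # empty result => partition missing
    except Exception:
      return False
  sql = 'desc %s' % table_name
  try:
    run_sql(sql)  # raises if the table does not exist
    return True
  except Exception:
    return False
```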
2 changes: 1 addition & 1 deletion easy_rec/version.py
@@ -1,3 +1,3 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-__version__ = '0.4.9'
+__version__ = '0.4.10'
2 changes: 0 additions & 2 deletions samples/emr_script/mmoe/mmoe_census_income.config
@@ -3,7 +3,6 @@ hive_train_input {
   username: "admin"
   table_name: "census_income_train_simple"
   limit_num: 500
-  hash_fields: "age,class_of_worker,marital_status,education"
   fetch_size: 1024
 }

@@ -12,7 +11,6 @@ hive_eval_input {
   username: "admin"
   table_name: "census_income_train_simple"
   limit_num: 500
-  hash_fields: "age,class_of_worker,marital_status,education"
   fetch_size: 1024
 }
