Update Neural Solution based on v2.4 release test (#1410)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Kaihui-intel authored Nov 23, 2023
1 parent 3b29252 commit ab8c9f0
Showing 7 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion neural_solution/backend/scheduler.py
@@ -271,7 +271,7 @@ def launch_task(self, task: Task, resource):
         )
         task_status = self.check_task_status(log_path)
         self.task_db.update_task_status(task.task_id, task_status)
-        self.q_model_path = get_q_model_path(log_path=log_path) if task_status == "done" else None
+        self.q_model_path = get_q_model_path(log_path=log_path, task_id=task.task_id) if task_status == "done" else None
         self.report_result(task.task_id, log_path, task_runtime)
 
     def dispatch_task(self, task, resource):
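For context, a minimal standalone sketch of the changed call site (Task, get_q_model_path, and all paths below are illustrative stand-ins, not the project's real implementations):

```python
class Task:
    # Hypothetical stand-in for the scheduler's task object.
    def __init__(self, task_id):
        self.task_id = task_id


def get_q_model_path(log_path, task_id):
    # Stand-in showing only the new signature, which now receives the task id.
    return "/tmp/task_workspace/{}/model".format(task_id)


task = Task("abc123")
task_status = "done"
log_path = "/tmp/{}.log".format(task.task_id)

# A quantized-model path is recorded only for finished tasks; the task id
# now travels with the log path so the parsed path can be scoped to the task.
q_model_path = get_q_model_path(log_path=log_path, task_id=task.task_id) if task_status == "done" else None
print(q_model_path)  # /tmp/task_workspace/abc123/model
```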
6 changes: 5 additions & 1 deletion neural_solution/backend/utils/utility.py
@@ -226,11 +226,12 @@ def create_dir(path):
         os.makedirs(os.path.dirname(path))
 
 
-def get_q_model_path(log_path):
+def get_q_model_path(log_path, task_id):
     """Get the quantized model path from task log.
 
     Args:
         log_path (str): log path for task
+        task_id: the id of task
 
     Returns:
         str: quantized model path
@@ -241,5 +241,8 @@ def get_q_model_path(log_path):
         match = re.search(r"(Save quantized model to|Save config file and weights of quantized model to) (.+?)\.", line)
         if match:
             q_model_path = match.group(2)
+            match_task_id = re.search(r"(.+/task_workspace/{}/[^/]+)".format(task_id), q_model_path)
+            if match_task_id:
+                q_model_path = match_task_id.group()
             return q_model_path
     return "quantized model path not found"
4 changes: 2 additions & 2 deletions neural_solution/examples/hf_models/task_request.json
@@ -5,6 +5,6 @@
             "--model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result"
         ],
         "approach": "static",
-        "requirements": [],
+        "requirements": ["datasets", "transformers=4.21.0", "torch"],
         "workers": 1
     }
 }
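For reference, a rough sketch of how the updated request body fits together, restricted to the fields visible in this hunk (the surrounding structure is an assumption; the two closing braces suggest these fields sit one level deep in the real file):

```python
import json

# Only the fields shown in the diff; anything else in the real file is omitted.
task_request = {
    "arguments": ["--model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result"],
    "approach": "static",
    # The v2.4 examples now pin the task environment instead of leaving it empty.
    "requirements": ["datasets", "transformers=4.21.0", "torch"],
    "workers": 1,
}
print(json.dumps(task_request, indent=4))
```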
1 change: 1 addition & 0 deletions neural_solution/examples/hf_models_grpc/README.md
@@ -80,6 +80,7 @@ optional arguments:
 - Step 2: Submit the task request to service, and it will return the submit status and task id for future use.
 
 ```shell
+[user@server hf_models_grpc]$ cd path/to/neural_solution/examples/hf_models_grpc
 [user@server hf_models_grpc]$ python client.py submit --request="test_task_request.json"
 
 # response if submit successfully
4 changes: 2 additions & 2 deletions neural_solution/examples/hf_models_grpc/task_request.json
@@ -5,6 +5,6 @@
             "--model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result"
         ],
         "approach": "static",
-        "requirements": [],
+        "requirements": ["datasets", "transformers=4.21.0", "torch"],
         "workers": 1
     }
 }
6 changes: 3 additions & 3 deletions neural_solution/frontend/gRPC/client.py
@@ -126,9 +126,9 @@ def run_query_task_status(args):
     query_action_parser.set_defaults(func=run_query_task_result)
     query_action_parser.add_argument("--task_id", type=str, default=None, help="Query task by task id.")
 
-    parser.add_argument("--grpc_api_port", type=str, default=None, help="grpc server port.")
-    parser.add_argument("--result_monitor_port", type=str, default=None, help="result monitor port.")
-    parser.add_argument("--task_monitor_port", type=str, default=None, help="task monitor port.")
+    parser.add_argument("--grpc_api_port", type=str, default="8001", help="grpc server port.")
+    parser.add_argument("--result_monitor_port", type=str, default="2222", help="result monitor port.")
+    parser.add_argument("--task_monitor_port", type=str, default="3333", help="task monitor port.")
 
     args = parser.parse_args()
     config.grpc_api_port = args.grpc_api_port
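A minimal sketch of the behavioral change, using plain argparse semantics rather than the project's full client: with real defaults in place, omitting the port flags no longer leaves the config set to None.

```python
import argparse

parser = argparse.ArgumentParser()
# The three port flags now default to the service's well-known ports.
parser.add_argument("--grpc_api_port", type=str, default="8001", help="grpc server port.")
parser.add_argument("--result_monitor_port", type=str, default="2222", help="result monitor port.")
parser.add_argument("--task_monitor_port", type=str, default="3333", help="task monitor port.")

args = parser.parse_args([])  # simulate a command line with no port flags
print(args.grpc_api_port, args.result_monitor_port, args.task_monitor_port)  # 8001 2222 3333
```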
10 changes: 8 additions & 2 deletions neural_solution/test/backend/utils/test_utility.py
@@ -138,13 +138,19 @@ def test_create_dir(self):
     @patch("builtins.open", mock_open(read_data="Save quantized model to /path/to/model."))
     def test_get_q_model_path_success(self):
         log_path = "fake_log_path"
-        q_model_path = get_q_model_path(log_path)
+        q_model_path = get_q_model_path(log_path, "task_id")
         self.assertEqual(q_model_path, "/path/to/model")
 
+    @patch("builtins.open", mock_open(read_data="Save quantized model to /path/to/task_workspace/task_id/model/1.pb."))
+    def test_get_q_model_path_success_task_id(self):
+        log_path = "fake_log_path"
+        q_model_path = get_q_model_path(log_path, "task_id")
+        self.assertEqual(q_model_path, "/path/to/task_workspace/task_id/model")
+
     @patch("builtins.open", mock_open(read_data="No quantized model saved."))
     def test_get_q_model_path_failure(self):
         log_path = "fake_log_path"
-        q_model_path = get_q_model_path(log_path)
+        q_model_path = get_q_model_path(log_path, "task_id")
         self.assertEqual(q_model_path, "quantized model path not found")
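The new test reuses the mocking pattern of its neighbors. A self-contained sketch of that pattern (the log line is the one from the new test; "fake_log_path" never touches disk):

```python
from unittest.mock import mock_open, patch

fake_log = "Save quantized model to /path/to/task_workspace/task_id/model/1.pb."

# Patching builtins.open with mock_open makes any read return the canned log
# line, so the parser can be exercised without creating a real file.
with patch("builtins.open", mock_open(read_data=fake_log)):
    with open("fake_log_path") as f:
        assert f.read() == fake_log
```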
