Skip to content
This repository has been archived by the owner on May 23, 2024. It is now read-only.

Commit

Permalink
fix: modify the way port number passing (#210)
Browse files Browse the repository at this point in the history
* refactor naming of some variables

* fix sanity test failure

* modify to pick ports in linear fashion

* modify the way port number passing

* using plural ports instead of switching to singular

* trigger rebuild test

* fix KeyError

* retrigger rebuild test

Co-authored-by: Jinpeng Qi <qijinpen@amazon.com>
  • Loading branch information
jinpengqi and Jinpeng Qi authored Jun 15, 2021
1 parent 98b7b6b commit 6a51a60
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
21 changes: 8 additions & 13 deletions docker/build_artifacts/sagemaker/python_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

SAGEMAKER_BATCHING_ENABLED = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower()
MODEL_CONFIG_FILE_PATH = "/sagemaker/model-config.cfg"
TFS_GRPC_PORT_RANGE = os.environ.get("TFS_GRPC_PORT_RANGE")
TFS_REST_PORT_RANGE = os.environ.get("TFS_REST_PORT_RANGE")
TFS_GRPC_PORTS = os.environ.get("TFS_GRPC_PORTS")
TFS_REST_PORTS = os.environ.get("TFS_REST_PORTS")
SAGEMAKER_TFS_PORT_RANGE = os.environ.get("SAGEMAKER_SAFE_PORT_RANGE")
TFS_INSTANCE_COUNT = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", "1"))

Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(self):
# during the _handle_load_model_post()
self.model_handlers = {}
else:
self._tfs_grpc_ports = self._parse_sagemaker_port_range(TFS_GRPC_PORT_RANGE)
self._tfs_rest_ports = self._parse_sagemaker_port_range(TFS_REST_PORT_RANGE)
self._tfs_grpc_ports = self._parse_concat_ports(TFS_GRPC_PORTS)
self._tfs_rest_ports = self._parse_concat_ports(TFS_REST_PORTS)

self._channels = {}
for grpc_port in self._tfs_grpc_ports:
Expand Down Expand Up @@ -98,16 +98,11 @@ def on_post(self, req, res, model_name=None):
data = json.loads(req.stream.read().decode("utf-8"))
self._handle_load_model_post(res, data)

def _parse_sagemaker_port_range(self, port_range):
lower, upper = port_range.split('-')
lower = int(lower)
upper = int(upper)
if lower == upper:
return [lower]
return [lower + 2 * i for i in range(TFS_INSTANCE_COUNT)]
def _parse_concat_ports(self, concat_ports):
return concat_ports.split(",")

def _pick_port(self, ports):
return str(random.choice(ports))
return random.choice(ports)

def _parse_sagemaker_port_range_mme(self, port_range):
lower, upper = port_range.split('-')
Expand Down Expand Up @@ -254,7 +249,7 @@ def _handle_invocation_post(self, req, res, model_name=None):
rest_port = self._pick_port(self._tfs_rest_ports)
data, context = tfs_utils.parse_request(req, rest_port, grpc_port,
self._tfs_default_model_name,
channel=self._channels[int(grpc_port)])
channel=self._channels[grpc_port])

try:
res.status = falcon.HTTP_200
Expand Down
52 changes: 28 additions & 24 deletions docker/build_artifacts/sagemaker/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,30 +89,29 @@ def __init__(self):
parts = self._sagemaker_port_range.split("-")
low = int(parts[0])
hi = int(parts[1])
self._tfs_grpc_port = []
self._tfs_rest_port = []
self._tfs_grpc_ports = []
self._tfs_rest_ports = []
if low + 2 * self._tfs_instance_count > hi:
raise ValueError("not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})"
.format(self._sagemaker_port_range))
self._tfs_grpc_port_range = "{}-{}".format(low,
low + 2 * self._tfs_instance_count)
self._tfs_rest_port_range = "{}-{}".format(low + 1,
low + 2 * self._tfs_instance_count + 1)
# select non-overlapping grpc and rest ports based on tfs instance count
for i in range(self._tfs_instance_count):
self._tfs_grpc_port.append(str(low + 2 * i))
self._tfs_rest_port.append(str(low + 2 * i + 1))
# set environment variable for python service
os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range
os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range
self._tfs_grpc_ports.append(str(low + 2 * i))
self._tfs_rest_ports.append(str(low + 2 * i + 1))
# concat selected ports respectively in order to pass them to python service
self._tfs_grpc_concat_ports = self._concat_ports(self._tfs_grpc_ports)
self._tfs_rest_concat_ports = self._concat_ports(self._tfs_rest_ports)
else:
# just use the standard default ports
self._tfs_grpc_port = ["9000"]
self._tfs_rest_port = ["8501"]
self._tfs_grpc_port_range = "9000-9000"
self._tfs_rest_port_range = "8501-8501"
# set environment variable for python service
os.environ["TFS_GRPC_PORT_RANGE"] = self._tfs_grpc_port_range
os.environ["TFS_REST_PORT_RANGE"] = self._tfs_rest_port_range
self._tfs_grpc_ports = ["9000"]
self._tfs_rest_ports = ["8501"]
# provide single concat port here for default case
self._tfs_grpc_concat_ports = "9000"
self._tfs_rest_concat_ports = "8501"

# set environment variable for python service
os.environ["TFS_GRPC_PORTS"] = self._tfs_grpc_concat_ports
os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports

def _need_python_service(self):
if os.path.exists(INFERENCE_PATH):
Expand All @@ -121,6 +120,11 @@ def _need_python_service(self):
and os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX"):
self._enable_python_service = True

def _concat_ports(self, ports):
str_ports = [str(port) for port in ports]
concat_str_ports = ",".join(str_ports)
return concat_str_ports

def _create_tfs_config(self):
models = tfs_utils.find_models()

Expand Down Expand Up @@ -194,13 +198,13 @@ def _setup_gunicorn(self):
gunicorn_command = (
"gunicorn -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker "
"--workers {} --threads {} "
"{}{} -e TFS_GRPC_PORT_RANGE={} -e TFS_REST_PORT_RANGE={} "
"{}{} -e TFS_GRPC_PORTS={} -e TFS_REST_PORTS={} "
"-e SAGEMAKER_MULTI_MODEL={} -e SAGEMAKER_SAFE_PORT_RANGE={} "
"-e SAGEMAKER_TFS_WAIT_TIME_SECONDS={} "
"python_service:app").format(self._gunicorn_worker_class,
self._gunicorn_workers, self._gunicorn_threads,
python_path_option, ",".join(python_path_content),
self._tfs_grpc_port_range, self._tfs_rest_port_range,
self._tfs_grpc_concat_ports, self._tfs_rest_concat_ports,
self._tfs_enable_multi_model_endpoint,
self._sagemaker_port_range,
self._tfs_wait_time_seconds)
Expand Down Expand Up @@ -230,7 +234,7 @@ def _download_scripts(self, bucket, prefix):
def _create_nginx_tfs_upstream(self):
indentation = " "
tfs_upstream = ""
for port in self._tfs_rest_port:
for port in self._tfs_rest_ports:
tfs_upstream += "{}server localhost:{};\n".format(indentation, port)
tfs_upstream = tfs_upstream[len(indentation):-2]

Expand Down Expand Up @@ -334,7 +338,7 @@ def _wait_for_gunicorn(self):

def _wait_for_tfs(self):
for i in range(self._tfs_instance_count):
tfs_utils.wait_for_model(self._tfs_rest_port[i],
tfs_utils.wait_for_model(self._tfs_rest_ports[i],
self._tfs_default_model_name, self._tfs_wait_time_seconds)

@contextmanager
Expand Down Expand Up @@ -370,8 +374,8 @@ def _restart_single_tfs(self, pid):

def _start_single_tfs(self, instance_id):
cmd = tfs_utils.tfs_command(
self._tfs_grpc_port[instance_id],
self._tfs_rest_port[instance_id],
self._tfs_grpc_ports[instance_id],
self._tfs_rest_ports[instance_id],
self._tfs_config_path,
self._tfs_enable_batching,
self._tfs_batching_config_path,
Expand Down

0 comments on commit 6a51a60

Please sign in to comment.