diff --git a/models/official/detection/dataloader/tf_example_decoder.py b/models/official/detection/dataloader/tf_example_decoder.py index c79ee53f9..356827578 100644 --- a/models/official/detection/dataloader/tf_example_decoder.py +++ b/models/official/detection/dataloader/tf_example_decoder.py @@ -38,8 +38,11 @@ def __init__( # copypara:strip_end regenerate_source_id=False, label_key='image/object/class/label', - label_dtype=tf.int64): + label_dtype=tf.int64, + include_keypoint=False, + num_keypoints_per_instance=0): self._include_mask = include_mask + self._include_keypoint = include_keypoint # copypara:strip_begin self._include_polygon = include_polygon # copypara:strip_end @@ -62,6 +65,13 @@ def __init__( self._label_dtype) if include_mask: self._keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string) + if include_keypoint: + self._num_keypoints_per_instance = num_keypoints_per_instance + self._keys_to_features.update({ + 'image/object/keypoint/visibility': tf.io.VarLenFeature(tf.int64), + 'image/object/keypoint/x': tf.io.VarLenFeature(tf.float32), + 'image/object/keypoint/y': tf.io.VarLenFeature(tf.float32), + }) def _decode_image(self, parsed_tensors): """Decodes the image and set its static shape.""" @@ -106,6 +116,17 @@ def _decode_areas(self, parsed_tensors): lambda: parsed_tensors['image/object/area'], lambda: (xmax - xmin) * (ymax - ymin) * height * width) + def _decode_keypoints(self, parsed_tensors): + """Decode keypoint coordinates and visibilities.""" + keypoint_x = parsed_tensors['image/object/keypoint/x'] + keypoint_y = parsed_tensors['image/object/keypoint/y'] + keypoints = tf.stack([keypoint_y, keypoint_x], axis=-1) + keypoints = tf.reshape(keypoints, [-1, self._num_keypoints_per_instance, 2]) + keypoint_visibilities = parsed_tensors['image/object/keypoint/visibility'] + keypoint_visibilities = tf.reshape(keypoint_visibilities, + [-1, self._num_keypoints_per_instance]) + return keypoints, keypoint_visibilities + def decode(self, serialized_example): """Decode the serialized example. @@ -165,6 +186,8 @@ def decode(self, serialized_example): lambda: _get_source_id_from_encoded_image(parsed_tensors)) if self._include_mask: masks = self._decode_masks(parsed_tensors) + if self._include_keypoint: + keypoints, keypoint_visibilities = self._decode_keypoints(parsed_tensors) groundtruth_classes = parsed_tensors[self._label_key] decoded_tensors = { @@ -182,4 +205,9 @@ def decode(self, serialized_example): 'groundtruth_instance_masks': masks, 'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'], }) + if self._include_keypoint: + decoded_tensors.update({ + 'groundtruth_keypoints': keypoints, + 'groundtruth_keypoint_visibilities': keypoint_visibilities, + }) return decoded_tensors diff --git a/models/official/detection/evaluation/coco_evaluator.py b/models/official/detection/evaluation/coco_evaluator.py index 4e5c7fdaa..69883e86c 100644 --- a/models/official/detection/evaluation/coco_evaluator.py +++ b/models/official/detection/evaluation/coco_evaluator.py @@ -55,6 +55,9 @@ def __init__( need_rescale_bboxes=True, per_category_metrics=False, remove_invalid_boxes=False, + include_keypoint=False, + need_rescale_keypoints=False, + kpt_oks_sigmas=None ): """Constructs COCO evaluation class. @@ -80,6 +83,13 @@ def __init__( per_category_metrics: Whether to return per category metrics. remove_invalid_boxes: A boolean indicating whether to remove invalid box during evaluation. 
+      include_keypoint: a boolean indicating whether to include keypoint
+        evaluation.
+      need_rescale_keypoints: If true, keypoints in `predictions` will be
+        rescaled back to absolute values (`image_info` is needed in this case).
+      kpt_oks_sigmas: The sigmas used to calculate keypoint OKS. See
+        http://cocodataset.org/#keypoints-eval. When None, the COCO defaults
+        are used.
    """
    if annotation_file:
      if annotation_file.startswith('gs://'):
@@ -105,7 +115,8 @@ def __init__(
        'detection_boxes'
    ]
    self._need_rescale_bboxes = need_rescale_bboxes
-    if self._need_rescale_bboxes:
+    self._need_rescale_keypoints = need_rescale_keypoints
+    if self._need_rescale_bboxes or self._need_rescale_keypoints:
      self._required_prediction_fields.append('image_info')
    self._required_groundtruth_fields = [
        'source_id', 'height', 'width', 'classes', 'boxes'
@@ -116,6 +127,18 @@ def __init__(
      self._required_prediction_fields.extend(['detection_masks'])
      self._required_groundtruth_fields.extend(['masks'])
    self.remove_invalid_boxes = remove_invalid_boxes
+    self._include_keypoint = include_keypoint
+    self._kpt_oks_sigmas = kpt_oks_sigmas
+    if self._include_keypoint:
+      keypoint_metric_names = [
+          'AP', 'AP50', 'AP75', 'APm', 'APl', 'ARmax1', 'ARmax10', 'ARmax100',
+          'ARm', 'ARl'
+      ]
+      keypoint_metric_names = ['keypoint_' + x for x in keypoint_metric_names]
+      self._metric_names.extend(keypoint_metric_names)
+      self._required_prediction_fields.extend(['detection_keypoints'])
+      self._required_groundtruth_fields.extend(['keypoints'])
+
    self.reset()

  def reset(self):
@@ -168,7 +191,7 @@ def evaluate(self):
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
-    coco_metrics = coco_eval.stats
+    metrics = coco_eval.stats

    if self._include_mask:
      mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
@@ -177,11 +200,17 @@ def evaluate(self):
      mcoco_eval.accumulate()
      mcoco_eval.summarize()
      mask_coco_metrics = mcoco_eval.stats
-
-    if self._include_mask:
-      metrics = np.hstack((coco_metrics, mask_coco_metrics))
-    else:
-      metrics = coco_metrics
+      metrics = np.hstack((metrics, mask_coco_metrics))
+
+    if self._include_keypoint:
+      kcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='keypoints',
+                                     kpt_oks_sigmas=self._kpt_oks_sigmas)
+      kcoco_eval.params.imgIds = image_ids
+      kcoco_eval.evaluate()
+      kcoco_eval.accumulate()
+      kcoco_eval.summarize()
+      keypoint_coco_metrics = kcoco_eval.stats
+      metrics = np.hstack((metrics, keypoint_coco_metrics))

    # Cleans up the internal variables in order for a fresh eval next time.
    self.reset()
@@ -192,46 +221,64 @@ def evaluate(self):

    # Adds metrics per category.
    if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
-      for category_index, category_id in enumerate(coco_eval.params.catIds):
-        metrics_dict['Precision mAP ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[0][category_index].astype(
-                np.float32)
-        metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
-            category_id)] = coco_eval.category_stats[1][category_index].astype(
-                np.float32)
-        metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
-            category_id)] = coco_eval.category_stats[2][category_index].astype(
-                np.float32)
-        metrics_dict['Precision mAP ByCategory (small) /{}'.format(
-            category_id)] = coco_eval.category_stats[3][category_index].astype(
-                np.float32)
-        metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
-            category_id)] = coco_eval.category_stats[4][category_index].astype(
-                np.float32)
-        metrics_dict['Precision mAP ByCategory (large) /{}'.format(
-            category_id)] = coco_eval.category_stats[5][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR@1 ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[6][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR@10 ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[7][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR@100 ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[8][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR (small) ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[9][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR (medium) ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[10][category_index].astype(
-                np.float32)
-        metrics_dict['Recall AR (large) ByCategory/{}'.format(
-            category_id)] = coco_eval.category_stats[11][category_index].astype(
-                np.float32)
+      metrics_dict.update(self._retrieve_per_category_metrics(coco_eval))
+
+      if self._include_keypoint:
+        metrics_dict.update(self._retrieve_per_category_metrics(
+            kcoco_eval, prefix='keypoints'))

    return metrics_dict

-  def _process_predictions(self, predictions):
+  def _retrieve_per_category_metrics(self, coco_eval, prefix=''):
+    """Retrieves per-category metrics and returns them in a dict.
+
+    Args:
+      coco_eval: a cocoeval.COCOeval object containing evaluation data.
+      prefix: str, a string used to prefix metric names.
+
+    Returns:
+      metrics_dict: A dictionary with per-category metrics.
+ """ + + metrics_dict = {} + if prefix: + prefix = prefix + ' ' + + for category_index, category_id in enumerate(coco_eval.params.catIds): + if 'keypoints' in prefix: + metrics_dict_keys = [ + 'Precision mAP ByCategory', + 'Precision mAP ByCategory@50IoU', + 'Precision mAP ByCategory@75IoU', + 'Precision mAP ByCategory (medium)', + 'Precision mAP ByCategory (large)', + 'Recall AR@1 ByCategory', + 'Recall AR@10 ByCategory', + 'Recall AR@100 ByCategory', + 'Recall AR (medium) ByCategory', + 'Recall AR (large) ByCategory', + ] + else: + metrics_dict_keys = [ + 'Precision mAP ByCategory', + 'Precision mAP ByCategory@50IoU', + 'Precision mAP ByCategory@75IoU', + 'Precision mAP ByCategory (small)', + 'Precision mAP ByCategory (medium)', + 'Precision mAP ByCategory (large)', + 'Recall AR@1 ByCategory', + 'Recall AR@10 ByCategory', + 'Recall AR@100 ByCategory', + 'Recall AR (small) ByCategory', + 'Recall AR (medium) ByCategory', + 'Recall AR (large) ByCategory', + ] + for idx, key in enumerate(metrics_dict_keys): + metrics_dict[prefix + key + '/{}'.format( + category_id)] = coco_eval.category_stats[idx][ + category_index].astype(np.float32) + return metrics_dict + + def _process_bboxes_predictions(self, predictions): image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2)) predictions['detection_boxes'] = ( predictions['detection_boxes'].astype(np.float32)) @@ -241,6 +288,13 @@ def _process_predictions(self, predictions): predictions['detection_outer_boxes'].astype(np.float32)) predictions['detection_outer_boxes'] /= image_scale + def _process_keypoints_predictions(self, predictions): + image_scale = tf.reshape(predictions['image_info'][:, 2:3, :], + [-1, 1, 1, 2]) + predictions['detection_keypoints'] = ( + predictions['detection_keypoints'].astype(np.float32)) + predictions['detection_keypoints'] /= image_scale + def update(self, predictions, groundtruths=None): """Update and aggregate detection results and groundtruth data. 
@@ -286,7 +340,9 @@ def update(self, predictions, groundtruths=None):
        raise ValueError(
            'Missing the required key `{}` in predictions!'.format(k))
    if self._need_rescale_bboxes:
-      self._process_predictions(predictions)
+      self._process_bboxes_predictions(predictions)
+    if self._need_rescale_keypoints:
+      self._process_keypoints_predictions(predictions)
    for k, v in six.iteritems(predictions):
      if k not in self._predictions:
        self._predictions[k] = [v]
@@ -305,6 +361,20 @@ def update(self, predictions, groundtruths=None):
      else:
        self._groundtruths[k].append(v)

+  def merge(self, other):
+    """Merges the states from the other COCOEvaluator."""
+    for k, v in other._predictions.items():  # pylint: disable=protected-access
+      if k not in self._predictions:
+        self._predictions[k] = v
+      else:
+        self._predictions[k].extend(v)
+
+    for k, v in other._groundtruths.items():  # pylint: disable=protected-access
+      if k not in self._groundtruths:
+        self._groundtruths[k] = v
+      else:
+        self._groundtruths[k].extend(v)
+

class ShapeMaskCOCOEvaluator(COCOEvaluator):
  """COCO evaluation metric class for ShapeMask."""
@@ -463,6 +533,7 @@ def __init__(
    self._metric_names.extend(mask_metric_names)
    self._required_prediction_fields.extend(['detection_masks'])
    self._required_groundtruth_fields.extend(['masks'])
+    self._need_rescale_keypoints = False
    self.reset()

diff --git a/models/official/detection/evaluation/coco_utils.py b/models/official/detection/evaluation/coco_utils.py
index 4abeebe5e..5d000c843 100644
--- a/models/official/detection/evaluation/coco_utils.py
+++ b/models/official/detection/evaluation/coco_utils.py
@@ -138,6 +138,8 @@ def convert_predictions_to_coco_annotations(
      Optional fields:
        - detection_masks: a list of numpy arrays of float of shape
            [batch_size, K, mask_height, mask_width].
+        - detection_keypoints: a list of numpy arrays of float of shape
+            [batch_size, K, num_keypoints, 2].
    remove_invalid_boxes: A boolean indicating whether to remove invalid box
      during evaluation.

@@ -160,6 +162,19 @@ def convert_predictions_to_coco_annotations(
    # NOTE: Batch size may differ between chunks.
    batch_size = predictions['source_id'][i].shape[0]

+    if 'detection_keypoints' in predictions:
+      # Adds extra ones to indicate the visibility for each keypoint as is
+      # recommended by MSCOCO. Also, converts keypoints from [y, x] to [x, y]
+      # as mandated by COCO.
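+      # (COCO encodes each keypoint as an [x, y, v] triplet; the predicted
+      # visibility values are ignored by the COCO keypoint evaluation, so a
+      # constant 1 is sufficient here.)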
+      num_keypoints = predictions['detection_keypoints'][i].shape[2]
+      coco_keypoints = np.concatenate(
+          [
+              predictions['detection_keypoints'][i][..., 1:],
+              predictions['detection_keypoints'][i][..., :1],
+              np.ones([batch_size, max_num_detections, num_keypoints, 1]),
+          ],
+          axis=-1,
+      ).astype(int)
    for j in range(batch_size):
      if 'detection_masks' in predictions:
        image_masks = mask_utils.paste_instance_masks(
@@ -185,6 +200,8 @@
        ann['score'] = predictions['detection_scores'][i][j, k]
        if 'detection_masks' in predictions:
          ann['segmentation'] = encoded_masks[k]
+        if 'detection_keypoints' in predictions:
+          ann['keypoints'] = coco_keypoints[j, k].flatten().tolist()
        coco_predictions.append(ann)

  for i, ann in enumerate(coco_predictions):
@@ -272,6 +289,25 @@
          ann['segmentation'] = encoded_mask
          if 'areas' not in groundtruths:
            ann['area'] = mask_api.area(encoded_mask)
+        if 'keypoints' in groundtruths:
+          keypoints = groundtruths['keypoints'][i]
+          coco_keypoints = []
+          num_valid_keypoints = 0
+          for z in range(len(keypoints[j, k, :, 1])):
+            # Convert from [y, x] to [x, y] as mandated by COCO.
+            x = float(keypoints[j, k, z, 1])
+            y = float(keypoints[j, k, z, 0])
+            coco_keypoints.append(x)
+            coco_keypoints.append(y)
+            # x and y are Python floats here, so use np.isnan rather than
+            # tf.math.is_nan and keep this conversion free of TF ops.
+            if np.isnan(x) or np.isnan(y) or (
+                x == 0 and y == 0):
+              visibility = 0
+            else:
+              visibility = 2
+              num_valid_keypoints = num_valid_keypoints + 1
+            coco_keypoints.append(visibility)
+          ann['keypoints'] = coco_keypoints
+          ann['num_keypoints'] = num_valid_keypoints
      gt_annotations.append(ann)

  for i, ann in enumerate(gt_annotations):
diff --git a/models/official/detection/export_tensorrt_model.py b/models/official/detection/export_tensorrt_model.py
deleted file mode 100644
index b1145f7f5..000000000
--- a/models/official/detection/export_tensorrt_model.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================== -r"""A binary to export the TensorRT model.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import flags -import tensorflow.compat.v1 as tf - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.compiler.tensorrt import trt_convert -# pylint: enable=g-direct-tensorflow-import - - -FLAGS = flags.FLAGS - -flags.DEFINE_string('saved_model_dir', None, 'The saved model directory.') -flags.DEFINE_string('output_dir', None, 'The export TensorRT model directory.') - -flags.DEFINE_integer('max_batch_size', 1, 'max size for the input batch') -flags.DEFINE_integer( - 'max_workspace_size_bytes', 2 << 20, 'control memory allocation in bytes.') -flags.DEFINE_enum( - 'precision_mode', 'FP16', ['FP32', 'FP16', 'INT8'], - 'TensorRT precision mode, one of `FP32`, `FP16`, or `INT8`.') -flags.DEFINE_integer( - 'minimum_segment_size', 3, - 'minimum number of nodes required for a subgraph to be replaced by ' - 'TRTEngineOp.') -flags.DEFINE_boolean( - 'is_dynamic_op', False, 'whether to generate dynamic TRT ops') -flags.DEFINE_integer( - 'maximum_cached_engines', 1, - 'max number of cached TRT engines in dynamic TRT ops.') - - -def export(saved_model_dir, - tensorrt_model_dir, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode='FP16', - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1): - """Exports TensorRT model.""" - trt_convert.create_inference_graph( - None, - None, - max_batch_size=max_batch_size, - max_workspace_size_bytes=max_workspace_size_bytes, - precision_mode=precision_mode, - minimum_segment_size=minimum_segment_size, - is_dynamic_op=is_dynamic_op, - maximum_cached_engines=maximum_cached_engines, - input_saved_model_dir=saved_model_dir, - input_saved_model_tags=None, - input_saved_model_signature_key=None, - output_saved_model_dir=tensorrt_model_dir) - - -def main(argv): - del argv # Unused. - export(FLAGS.saved_model_dir, - FLAGS.output_dir, - FLAGS.max_batch_size, - FLAGS.max_workspace_size_bytes, - FLAGS.precision_mode, - FLAGS.minimum_segment_size, - FLAGS.is_dynamic_op, - FLAGS.maximum_cached_engines) - - -if __name__ == '__main__': - flags.mark_flag_as_required('saved_model_dir') - flags.mark_flag_as_required('output_dir') - tf.app.run(main) diff --git a/models/official/detection/utils/autoaugment_utils.py b/models/official/detection/utils/autoaugment_utils.py index e6fb163d1..8574cf221 100644 --- a/models/official/detection/utils/autoaugment_utils.py +++ b/models/official/detection/utils/autoaugment_utils.py @@ -1486,19 +1486,19 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): # Check to see if prob is passed into function. This is used for operations # where we alter bboxes independently. # pytype:disable=wrong-arg-types - if 'prob' in inspect.getargspec(func)[0]: + if 'prob' in inspect.getfullargspec(func).args: args = tuple([prob] + list(args)) # pytype:enable=wrong-arg-types # Add in replace arg if it is required for the function that is being called. 
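+  # Note: inspect.getargspec was removed in Python 3.11; getfullargspec is
+  # the replacement, and its `.args` field matches getargspec()[0] here.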
-  if 'replace' in inspect.getfullargspec(func).args:
+  if 'replace' in inspect.getfullargspec(func).args:
    # Make sure replace is the final argument
-    assert 'replace' == inspect.getargspec(func)[0][-1]
+    assert 'replace' == inspect.getfullargspec(func).args[-1]
    args = tuple(list(args) + [replace_value])

  # Add bboxes as the second positional argument for the function if it does
  # not already exist.
-  if 'bboxes' not in inspect.getargspec(func)[0]:
+  if 'bboxes' not in inspect.getfullargspec(func).args:
    func = bbox_wrapper(func)
  return (func, prob, args)

@@ -1506,11 +1506,11 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):

def _apply_func_with_prob(func, image, args, prob, bboxes):
  """Apply `func` to image w/ `args` as input with probability `prob`."""
  assert isinstance(args, tuple)
-  assert 'bboxes' == inspect.getargspec(func)[0][1]
+  assert 'bboxes' == inspect.getfullargspec(func).args[1]

  # If prob is a function argument, then this randomness is being handled
  # inside the function, so make sure it is always called.
-  if 'prob' in inspect.getargspec(func)[0]:
+  if 'prob' in inspect.getfullargspec(func).args:
    prob = 1.0

  # Apply the function with probability `prob`.
diff --git a/models/official/detection/utils/object_detection/preprocessor.py b/models/official/detection/utils/object_detection/preprocessor.py
index e336f5619..d1c4a24f2 100644
--- a/models/official/detection/utils/object_detection/preprocessor.py
+++ b/models/official/detection/utils/object_detection/preprocessor.py
@@ -112,6 +112,38 @@ def keypoint_flip_horizontal(keypoints, flip_point, flip_permutation,
  return new_keypoints


+def flip_keypoint_visibility(
+    keypoint_visibilities, flip_permutation, scope=None):
+  """Flips the keypoint visibilities.
+
+  This operation permutes each instance's keypoint visibilities in the manner
+  specified by flip_permutation.
+
+  Args:
+    keypoint_visibilities: an int32 tensor of shape [num_instances,
+      num_keypoints] indicating the visibility of each keypoint, with 0
+      representing an out-of-frame keypoint and a positive value representing a
+      visible or occluded keypoint.
+    flip_permutation: a tf.Tensor with type int32 and rank 1 containing the
+      keypoint visibility flip permutation. This specifies the mapping from
+      original keypoint visibility indices to the flipped keypoint visibility
+      indices. It is used primarily for visibilities associated with keypoints
+      that are not reflection invariant. E.g. suppose there are 3 keypoints
+      representing ['head', 'right_eye', 'left_eye'], then a logical choice
+      for flip_permutation might be [0, 2, 1] since we want to swap the
+      'left_eye' and 'right_eye' after a horizontal flip.
+    scope: an optional string indicating the name scope.
+
+  Returns:
+    new_keypoint_visibilities: a tensor of shape [num_instances, num_keypoints]
+  """
+  with tf.name_scope(scope, 'FlipKeypointVisibility'):
+    keypoint_visibilities = tf.transpose(keypoint_visibilities, [1, 0])
+    keypoint_visibilities = tf.gather(keypoint_visibilities, flip_permutation)
+    new_keypoint_visibilities = tf.transpose(keypoint_visibilities, [1, 0])
+    return new_keypoint_visibilities
+
+
def keypoint_change_coordinate_frame(keypoints, window, scope=None):
  """Changes coordinate frame of the keypoints to be relative to window's frame.

@@ -178,7 +210,8 @@ def random_horizontal_flip(image,
                           roi_boxes=None,
                           keypoints=None,
                           keypoint_flip_permutation=None,
-                           seed=None):
+                           seed=None,
+                           keypoint_visibilities=None):
  """Randomly flips the image and detections horizontally.
  The probability of flipping the image is 50%.

@@ -204,6 +237,10 @@ def random_horizontal_flip(image,
    keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
      permutation.
    seed: random seed
+    keypoint_visibilities: an int32 tensor of shape [num_instances,
+      num_keypoints] indicating the visibility of each keypoint, with 0
+      representing an out-of-frame keypoint and a positive value representing a
+      visible or occluded keypoint.

  Returns:
    image: image which is the same shape as input image.
@@ -221,6 +258,8 @@ def random_horizontal_flip(image,
      between [0, 1].
    keypoints: rank 3 float32 tensor with shape
               [num_instances, num_keypoints, 2]
+    keypoint_visibilities: rank 2 int32 tensor with shape [num_instances,
+      num_keypoints]

  Raises:
    ValueError: if keypoints are provided but keypoint_flip_permutation is not.
@@ -272,6 +311,19 @@ def _flip_image(image):
                        lambda: keypoints)
      result.append(keypoints)

+    # flip keypoint visibilities
+    if (
+        keypoint_visibilities is not None
+        and keypoint_flip_permutation is not None
+    ):
+      permutation = keypoint_flip_permutation
+      keypoint_visibilities = tf.cond(
+          do_a_flip_random,
+          lambda: flip_keypoint_visibility(keypoint_visibilities, permutation),
+          lambda: keypoint_visibilities,
+      )
+      result.append(keypoint_visibilities)
+
    return tuple(result)

diff --git a/models/official/detection/utils/object_detection/visualization_utils.py b/models/official/detection/utils/object_detection/visualization_utils.py
index d88a3461e..b2aa122d0 100644
--- a/models/official/detection/utils/object_detection/visualization_utils.py
+++ b/models/official/detection/utils/object_detection/visualization_utils.py
@@ -184,7 +184,10 @@ def draw_bounding_box_on_image(image,
  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the bounding box
  # instead of above.
-  display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
+  if hasattr(font, 'getsize'):
+    display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
+  else:
+    display_str_heights = [font.getbbox(ds)[3] for ds in display_str_list]
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

@@ -194,7 +197,10 @@ def draw_bounding_box_on_image(image,
    text_bottom = bottom + total_display_str_height
  # Reverse list and print from bottom to top.
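+  # FreeTypeFont.getsize was removed in Pillow 10.0; getbbox, which returns
+  # (left, top, right, bottom), is the fallback on newer Pillow releases.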
for display_str in display_str_list[::-1]: - text_width, text_height = font.getsize(display_str) + if hasattr(font, 'getsize'): + text_width, text_height = font.getsize(display_str) + else: + text_width, text_height = font.getbbox(display_str)[2:4] margin = np.ceil(0.05 * text_height) draw.rectangle( [(left, text_bottom - text_height - 2 * margin), (left + text_width, diff --git a/models/official/mask_rcnn/export_saved_model.py b/models/official/mask_rcnn/export_saved_model.py index bc8cd3f07..95b901e01 100644 --- a/models/official/mask_rcnn/export_saved_model.py +++ b/models/official/mask_rcnn/export_saved_model.py @@ -37,6 +37,7 @@ from absl import flags import tensorflow.compat.v1 as tf from tensorflow.compat.v1 import estimator as tf_estimator +from tensorflow_estimator.python.estimator.tpu import tpu_config # pylint: disable=g-deprecated-tf-checker import sys sys.path.insert(0, 'tpu/models') @@ -44,7 +45,6 @@ from hyperparameters import params_dict import serving from configs import mask_rcnn_config -from tensorflow.python.tpu import tpu_config # pylint: disable=g-direct-tensorflow-import FLAGS = flags.FLAGS diff --git a/models/official/resnet/resnet_preprocessing.py b/models/official/resnet/resnet_preprocessing.py index abdfd0d80..ff00aa5d3 100644 --- a/models/official/resnet/resnet_preprocessing.py +++ b/models/official/resnet/resnet_preprocessing.py @@ -56,7 +56,8 @@ def distorted_bounding_box_crop(image_bytes, cropped image `Tensor` """ with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): - shape = tf.image.extract_jpeg_shape(image_bytes) + image = tf.io.decode_image(image_bytes, channels=3, expand_animations=False) + shape = tf.shape(image) sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( shape, bounding_boxes=bbox, @@ -70,8 +71,9 @@ def distorted_bounding_box_crop(image_bytes, # Crop the image to the specified bounding box. offset_y, offset_x, _ = tf.unstack(bbox_begin) target_height, target_width, _ = tf.unstack(bbox_size) - crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) - image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.crop_to_bounding_box( + image, offset_y, offset_x, target_height, target_width + ) return image diff --git a/models/official/retinanet/retinanet_main.py b/models/official/retinanet/retinanet_main.py index 2948d788e..16cf3993c 100644 --- a/models/official/retinanet/retinanet_main.py +++ b/models/official/retinanet/retinanet_main.py @@ -39,7 +39,6 @@ from tensorflow.contrib import training as contrib_training from tensorflow.core.protobuf import rewriter_config_pb2 # pylint: disable=g-direct-tensorflow-import from tensorflow.python.client import device_lib # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.eager import profiler # pylint: disable=g-direct-tensorflow-import _COLLECTIVE_COMMUNICATION_OPTIONS = { None: tf.distribute.experimental.CollectiveCommunication.AUTO, @@ -242,7 +241,7 @@ def main(argv): if FLAGS.start_profiler_server: # Starts profiler. It will perform profiling when receive profiling request. 
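+    # Note: tf.profiler.experimental.server.start is the public TF2 profiler
+    # API; it replaces the private tensorflow.python.eager profiler import
+    # removed above.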
- profiler.start_profiler_server(FLAGS.profiler_port_number) + tf.profiler.experimental.server.start(FLAGS.profiler_port_number) if FLAGS.use_tpu: if FLAGS.distribution_strategy is None: diff --git a/tools/grpc_tpu_worker/grpc_tpu_worker.py b/tools/grpc_tpu_worker/grpc_tpu_worker.py deleted file mode 100644 index 3d590c4fe..000000000 --- a/tools/grpc_tpu_worker/grpc_tpu_worker.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# pylint: disable=g-import-not-at-top -"""Python-based TPU Worker GRPC server. - -Start a blocking TPU Worker GRPC server. - -Usage: - python3 grpc_tpu_worker.py -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import sys - -import requests - - -def get_metadata(key): - return requests.get( - 'http://metadata.google.internal/computeMetadata' - '/v1/instance/attributes/{}'.format(key), - headers={ - 'Metadata-Flavor': 'Google' - }).text - - -def get_host_ip(): - return requests.get( - 'http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ip', - headers={ - 'Metadata-Flavor': 'Google' - }).text - - -def setup_env_vars(): - """Set environment variables.""" - worker_id = get_metadata('agent-worker-number') - accelerator_type = get_metadata('accelerator-type') - worker_network_endpoints = get_metadata('worker-network-endpoints') - os.environ['TPU_STDERR_LOG_LEVEL'] = '0' - os.environ['CLOUD_TPU_TASK_ID'] = worker_id - os.environ['TPU_LOCK_DEVICE'] = 'true' - os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1' - accelerator_type_to_host_bounds = { - # v2 - 'v2-8': '1,1,1', - 'v2-32': '2,2,1', - 'v2-128': '4,4,1', - 'v2-256': '4,8,1', - 'v2-512': '8,8,1', - # v3 - 'v3-8': '1,1,1', - 'v3-32': '2,2,1', - 'v3-64': '2,4,1', - 'v3-128': '4,4,1', - 'v3-256': '4,8,1', - 'v3-512': '8,8,1', - 'v3-1024': '8,16,1', - 'v3-2048': '16,16,1', - # v4 - 'v4-8': '1,1,1', - 'v4-16': '1,1,2', - 'v4-32': '1,1,4', - 'v4-64': '1,2,4', - 'v4-128': '2,2,4', - 'v4-256': '2,2,8', - 'v4-512': '2,4,8', - 'v4-1024': '4,4,8', - 'v4-2048': '4,4,16', - 'v4-4096': '4,8,16', - } - - os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[ - accelerator_type] - os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split( - ',')[0].split(':')[2] + ':8476' - os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476' - - os.environ['TPU_STDERR_LOG_LEVEL'] = '0' - - if accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']: - os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true' - - # Set the hostname override. - os.environ['TPU_HOSTNAME_OVERRIDE'] = get_host_ip() - - -def main(unused_args): - # Create Protobuf ServerDef. 
- server_def = tensorflow_server_pb2.ServerDef(protocol='grpc') - job_def = server_def.cluster.job.add() - job_def.name = 'tpu_worker' - job_def.tasks[0] = 'localhost:8470' - server_def.job_name = 'tpu_worker' - server_def.task_index = 0 - - config = config_pb2.ConfigProto() - - # Create GRPC Server instance - server = server_lib.Server(server_def, config=config) - - # join() is blocking, unlike start() - server.join() - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - - FLAGS, unparsed = parser.parse_known_args() - # Must set environment variables before importing tensorflow. - setup_env_vars() - from tensorflow.core.protobuf import config_pb2 - from tensorflow.core.protobuf import tensorflow_server_pb2 - from tensorflow.python.platform import app - from tensorflow.python.training import server_lib - app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/main.tf b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/main.tf new file mode 100644 index 000000000..c3b6990ca --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/main.tf @@ -0,0 +1,15 @@ +variable "project_id" {} +variable "resource_name_prefix" {} +variable "region" {} +variable "tpu_node_pools" {} +variable "maintenance_interval" {} + + +module "tpu-gke" { + source = "../../module" + project_id = var.project_id + resource_name_prefix = var.resource_name_prefix + region = var.region + tpu_node_pools = var.tpu_node_pools + maintenance_interval = var.maintenance_interval +} diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/outputs.tf b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/outputs.tf new file mode 100644 index 000000000..78a05be82 --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/outputs.tf @@ -0,0 +1,24 @@ +output "region" { + value = var.region + description = "GCloud Region" +} + +output "project_id" { + value = var.project_id + description = "GCloud Project ID" +} + +output "kubernetes_cluster_name" { + value = module.tpu-gke.kubernetes_cluster_name + description = "GKE Cluster Name" +} + +output "kubernetes_cluster_host" { + value = module.tpu-gke.kubernetes_cluster_host + description = "GKE Cluster Host" +} + +output "nodepool_tpu_topology" { + value = module.tpu-gke.nodepool_tpu_topology + description = "GKE TPU topology" +} diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/terraform.tfvars b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/terraform.tfvars new file mode 100644 index 000000000..84f608505 --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v4/terraform.tfvars @@ -0,0 +1,9 @@ +project_id = "project-id" +resource_name_prefix = "tpu-test" +region = "us-central2" +tpu_node_pools = [{ + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" +}] diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v5e/terraform.tfvars b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v5e/terraform.tfvars index 9148d1195..ac742690b 100644 --- 
a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v5e/terraform.tfvars +++ b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/examples/v5e/terraform.tfvars @@ -11,13 +11,5 @@ tpu_node_pools = [{ policy = "sb-compact-rp1" disk_type = "pd-balanced" disk_size_gb = 120 - }, { - zone = "us-east5-b" - node_count = 32 - machine_type = "ct5lp-hightpu-4t" - topology = "8x16" - policy = "sb-compact-rp1" - disk_type = "pd-balanced" - disk_size_gb = 120 }] maintenance_interval = "PERIODIC" diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/module/terraform.tfvars b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/module/terraform.tfvars index 520ff3a12..e44d8decd 100644 --- a/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/module/terraform.tfvars +++ b/tools/kubernetes/terraform/batching_with_compact_placement/add_node_pool/module/terraform.tfvars @@ -1,11 +1,20 @@ project_id = "project-id" resource_name_prefix = "tpu-test" -region = "us-east5" +region = "us-central2" tpu_node_pools = [{ - zone = "us-east5-b" - node_count = 32 - machine_type = "ct5lp-hightpu-4t" - topology = "8x16" - policy = "sb-compact-rp1" + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" }] maintenance_interval = "AS_NEEDED" diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/main.tf b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/main.tf new file mode 100644 index 000000000..304251dca --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/main.tf @@ -0,0 +1,17 @@ +variable "project_id" {} +variable "resource_name_prefix" {} +variable "region" {} +variable "tpu_node_pools" {} +variable "cpu_node_pool" {} +variable "maintenance_interval" {} + + +module "tpu-gke" { + source = "../../module" + project_id = var.project_id + resource_name_prefix = var.resource_name_prefix + region = var.region + tpu_node_pools = var.tpu_node_pools + cpu_node_pool = var.cpu_node_pool + maintenance_interval = var.maintenance_interval +} diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/outputs.tf b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/outputs.tf new file mode 100644 index 000000000..78a05be82 --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/outputs.tf @@ -0,0 +1,24 @@ +output "region" { + value = var.region + description = "GCloud Region" +} + +output "project_id" { + value = var.project_id + description = "GCloud Project ID" +} + +output "kubernetes_cluster_name" { + value = module.tpu-gke.kubernetes_cluster_name + description = "GKE Cluster Name" +} + +output "kubernetes_cluster_host" { + value = module.tpu-gke.kubernetes_cluster_host + description = "GKE Cluster Host" +} + +output "nodepool_tpu_topology" { + value = module.tpu-gke.nodepool_tpu_topology + description = "GKE TPU topology" +} diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/terraform.tfvars 
b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/terraform.tfvars new file mode 100644 index 000000000..20ecf2ca1 --- /dev/null +++ b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/examples/v4/terraform.tfvars @@ -0,0 +1,16 @@ +project_id = "project-id" +resource_name_prefix = "tpu-test" +region = "us-central2" +tpu_node_pools = [{ + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" +}] +cpu_node_pool = { + zone = ["us-central2-a", "us-central2-b", "us-central2-c"] + machine_type = "n2-standard-8", + initial_node_count_per_zone = 1, + min_node_count_per_zone = 1, + max_node_count_per_zone = 30, +} diff --git a/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/module/terraform.tfvars b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/module/terraform.tfvars index bdda5d5e9..f3f4e7be0 100644 --- a/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/module/terraform.tfvars +++ b/tools/kubernetes/terraform/batching_with_compact_placement/create_cluster/module/terraform.tfvars @@ -1,9 +1,9 @@ project_id = "project-id" resource_name_prefix = "tpu-test" -region = "us-east5" +region = "us-central2" authorized_cidr_blocks = [] cpu_node_pool = { - zone = ["us-east5-a", "us-east5-b", "us-east5-c"] + zone = ["us-central2-a", "us-central2-b", "us-central2-c"] machine_type = "n2-standard-64", initial_node_count_per_zone = 1, min_node_count_per_zone = 1, diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/main.tf b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/main.tf new file mode 100644 index 000000000..c3b6990ca --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/main.tf @@ -0,0 +1,15 @@ +variable "project_id" {} +variable "resource_name_prefix" {} +variable "region" {} +variable "tpu_node_pools" {} +variable "maintenance_interval" {} + + +module "tpu-gke" { + source = "../../module" + project_id = var.project_id + resource_name_prefix = var.resource_name_prefix + region = var.region + tpu_node_pools = var.tpu_node_pools + maintenance_interval = var.maintenance_interval +} diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/outputs.tf b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/outputs.tf new file mode 100644 index 000000000..44d8350f9 --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/outputs.tf @@ -0,0 +1,19 @@ +output "region" { + value = var.region + description = "GCloud Region" +} + +output "project_id" { + value = var.project_id + description = "GCloud Project ID" +} + +output "kubernetes_cluster_name" { + value = module.tpu-gke.kubernetes_cluster_name + description = "GKE Cluster Name" +} + +output "kubernetes_cluster_host" { + value = module.tpu-gke.kubernetes_cluster_host + description = "GKE Cluster Host" +} diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/terraform.tfvars b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/terraform.tfvars new file mode 100644 index 000000000..84f608505 --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/examples/v4/terraform.tfvars @@ -0,0 +1,9 @@ 
+project_id = "project-id" +resource_name_prefix = "tpu-test" +region = "us-central2" +tpu_node_pools = [{ + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" +}] diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/module/terraform.tfvars b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/module/terraform.tfvars index a38800da4..e44d8decd 100644 --- a/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/module/terraform.tfvars +++ b/tools/kubernetes/terraform/batching_without_compact_placement/add_node_pool/module/terraform.tfvars @@ -1,10 +1,20 @@ project_id = "project-id" resource_name_prefix = "tpu-test" -region = "us-east5" +region = "us-central2" tpu_node_pools = [{ - zone = "us-east5-b" - node_count = 16 - machine_type = "ct5lp-hightpu-4t" - topology = "8x8" + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" }] maintenance_interval = "AS_NEEDED" diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/main.tf b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/main.tf new file mode 100644 index 000000000..304251dca --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/main.tf @@ -0,0 +1,17 @@ +variable "project_id" {} +variable "resource_name_prefix" {} +variable "region" {} +variable "tpu_node_pools" {} +variable "cpu_node_pool" {} +variable "maintenance_interval" {} + + +module "tpu-gke" { + source = "../../module" + project_id = var.project_id + resource_name_prefix = var.resource_name_prefix + region = var.region + tpu_node_pools = var.tpu_node_pools + cpu_node_pool = var.cpu_node_pool + maintenance_interval = var.maintenance_interval +} diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/outputs.tf b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/outputs.tf new file mode 100644 index 000000000..44d8350f9 --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/outputs.tf @@ -0,0 +1,19 @@ +output "region" { + value = var.region + description = "GCloud Region" +} + +output "project_id" { + value = var.project_id + description = "GCloud Project ID" +} + +output "kubernetes_cluster_name" { + value = module.tpu-gke.kubernetes_cluster_name + description = "GKE Cluster Name" +} + +output "kubernetes_cluster_host" { + value = module.tpu-gke.kubernetes_cluster_host + description = "GKE Cluster Host" +} diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/terraform.tfvars b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/terraform.tfvars new file mode 100644 index 000000000..1e6c096f5 --- /dev/null +++ b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/examples/v4/terraform.tfvars @@ -0,0 +1,16 @@ +project_id = "project-id" +resource_name_prefix = "tpu-test" +region = "us-central2" +tpu_node_pools = [{ + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" + }] 
+cpu_node_pool = { + zone = ["us-central2-a", "us-central2-b", "us-central2-c"] + machine_type = "n2-standard-8", + initial_node_count_per_zone = 1, + min_node_count_per_zone = 1, + max_node_count_per_zone = 30, +} diff --git a/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/module/terraform.tfvars b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/module/terraform.tfvars index bdda5d5e9..f3f4e7be0 100644 --- a/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/module/terraform.tfvars +++ b/tools/kubernetes/terraform/batching_without_compact_placement/create_cluster/module/terraform.tfvars @@ -1,9 +1,9 @@ project_id = "project-id" resource_name_prefix = "tpu-test" -region = "us-east5" +region = "us-central2" authorized_cidr_blocks = [] cpu_node_pool = { - zone = ["us-east5-a", "us-east5-b", "us-east5-c"] + zone = ["us-central2-a", "us-central2-b", "us-central2-c"] machine_type = "n2-standard-64", initial_node_count_per_zone = 1, min_node_count_per_zone = 1, diff --git a/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/main.tf b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/main.tf new file mode 100644 index 000000000..c3b6990ca --- /dev/null +++ b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/main.tf @@ -0,0 +1,15 @@ +variable "project_id" {} +variable "resource_name_prefix" {} +variable "region" {} +variable "tpu_node_pools" {} +variable "maintenance_interval" {} + + +module "tpu-gke" { + source = "../../module" + project_id = var.project_id + resource_name_prefix = var.resource_name_prefix + region = var.region + tpu_node_pools = var.tpu_node_pools + maintenance_interval = var.maintenance_interval +} diff --git a/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/outputs.tf b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/outputs.tf new file mode 100644 index 000000000..78a05be82 --- /dev/null +++ b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/outputs.tf @@ -0,0 +1,24 @@ +output "region" { + value = var.region + description = "GCloud Region" +} + +output "project_id" { + value = var.project_id + description = "GCloud Project ID" +} + +output "kubernetes_cluster_name" { + value = module.tpu-gke.kubernetes_cluster_name + description = "GKE Cluster Name" +} + +output "kubernetes_cluster_host" { + value = module.tpu-gke.kubernetes_cluster_host + description = "GKE Cluster Host" +} + +output "nodepool_tpu_topology" { + value = module.tpu-gke.nodepool_tpu_topology + description = "GKE TPU topology" +} diff --git a/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/terraform.tfvars b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/terraform.tfvars new file mode 100644 index 000000000..84f608505 --- /dev/null +++ b/tools/kubernetes/terraform/non_batching_with_compact_placement/examples/v4/terraform.tfvars @@ -0,0 +1,9 @@ +project_id = "project-id" +resource_name_prefix = "tpu-test" +region = "us-central2" +tpu_node_pools = [{ + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" +}] diff --git a/tools/kubernetes/terraform/non_batching_with_compact_placement/module/terraform.tfvars b/tools/kubernetes/terraform/non_batching_with_compact_placement/module/terraform.tfvars index 75e433778..8f63265f8 100644 --- 
a/tools/kubernetes/terraform/non_batching_with_compact_placement/module/terraform.tfvars +++ b/tools/kubernetes/terraform/non_batching_with_compact_placement/module/terraform.tfvars @@ -1,23 +1,28 @@ project_id = "project-id" resource_name_prefix = "tpu-test" -region = "us-east5-a" +region = "us-central2" authorized_cidr_blocks = [] cpu_node_pool = { - zone = ["us-east5-a", "us-east5-b", "us-east5-c"] + zone = ["us-central2-a", "us-central2-b", "us-central2-c"] machine_type = "n2-standard-64", initial_node_count_per_zone = 1, min_node_count_per_zone = 1, max_node_count_per_zone = 10 } tpu_node_pools = [{ - zone = "us-east5-b" - node_count = 64 - machine_type = "ct5lp-hightpu-4t" - topology = "16x16" - },{ - zone = "us-east5-b" - node_count = 64 - machine_type = "ct5lp-hightpu-4t" - topology = "16x16" + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 4 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x4" + }, { + zone = "us-central2-b" + node_count = 2 + machine_type = "ct4p-hightpu-4t" + topology = "2x2x2" }] maintenance_interval = "AS_NEEDED"