diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9817298
--- /dev/null
+++ b/README.md
@@ -0,0 +1,62 @@
+# ADLC
+
+As of right now, this repository just includes a script for generating a dataset from images and annotations, and a notebook for preliminary testing of a CNN for object detection.
+
+Here is a sample output of `visualize_detections`:
+
+![four images with bounding boxes](./img/output.png)
+
+Because the images are very high resolution, the box outlines appear to cover the targets, but the boxes themselves are accurate.
+
+## Data Annotation
+
+For now, data annotation has been done using OpenCV's GUI tool [`opencv_annotation`](https://docs.opencv.org/4.x/dc/d88/tutorial_traincascade.html#Preparation-of-the-training-data), which uses an XYWH bounding-box format. The annotations are in `data/annotation_238.txt`.
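+
+Each line of the annotation file lists an image path, the number of boxes, and then one `x y w h` quadruple per box. A minimal parsing sketch, assuming that layout (`load_annotations` is an illustrative helper, not part of this repo):
+
+```python
+# Parse opencv_annotation output into {image_path: [[x, y, w, h], ...]}
+def load_annotations(path):
+    boxes = {}
+    with open(path) as f:
+        for line in f:
+            parts = line.split()
+            if not parts:
+                continue
+            img, n = parts[0], int(parts[1])
+            vals = [int(v) for v in parts[2:2 + 4 * n]]
+            boxes[img] = [vals[i:i + 4] for i in range(0, len(vals), 4)]
+    return boxes
+
+annotations = load_annotations("data/annotation_238.txt")
+```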
+
+To include the corresponding images, you will need to download them from the Kraken computer and place them in `data/flight_238/*.jpg`. They are located in `/RAID/Flights/Flight_238/*.jpg`.
+
+## Setup Development Environment
+
+### Using a Conda/Mamba Environment
+
+Create a conda environment using:
+
+```sh
+conda env create --file ncsuadlc_condaenv.yaml -n ncsuadlc
+conda activate ncsuadlc
+
+# Some requirements are only up-to-date on PyPI
+pip install -r ncsuadlc_pipreqs.txt
+```
+
+### Pip Only
+
+```sh
+pip install -r requirements.txt
+```
+
+## Using CuDNN Acceleration on VLC
+
+NCSU provides VLCs with RTX 2080 GPUs that can be used to train the CNN quickly. CUDA is already installed on these systems, but you will need to install CuDNN as well:
+
+```sh
+sudo apt-get install libcudnn8=8.8.0.121-1+cuda12.1
+sudo apt-get install libcudnn8-dev=8.8.0.121-1+cuda12.1
+sudo apt-get install libcudnn8-samples=8.8.0.121-1+cuda12.1
+```
+
+To check that CuDNN was set up correctly, run the built-in test suite:
+
+```sh
+cp -r /usr/src/cudnn_samples_v8/ $HOME
+cd $HOME/cudnn_samples_v8/mnistCUDNN
+make clean && make
+sudo apt-get install libfreeimage3 libfreeimage-dev
+make clean && make
+./mnistCUDNN
+```
+
+You will also need to make sure that TensorFlow has the needed GPU dependencies:
+
+```sh
+pip install tensorflow[and-cuda]
+```
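+
+To verify that TensorFlow can actually see the GPU afterwards (a quick sanity check, assuming the steps above succeeded):
+
+```python
+import tensorflow as tf
+
+# Should list at least one GPU device if CUDA/CuDNN are set up correctly
+print(tf.config.list_physical_devices("GPU"))
+```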
diff --git a/adlc_util.py b/adlc_util.py
index 472bf59..bca5eb1 100644
--- a/adlc_util.py
+++ b/adlc_util.py
@@ -3,26 +3,26 @@
 
 CLASS_MAPPING = {0:"target"}
 
-def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format, offset=1):
-    """
-    Unused, but works
-    """
-    it = iter(inputs.take(offset))
+def visualize_detections(model, dataset, bounding_box_format, offset=1):
+    it = iter(dataset.take(offset))
 
     for _ in range(offset):
-        inputs = next(it)
+        images, y_true = next(it)
 
-    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
+    # images, y_true = next(iter(dataset.take(1)))
+    y_pred = model.predict(images)
+    y_pred = bounding_box.to_ragged(y_pred)
     visualization.plot_bounding_box_gallery(
         images,
-        value_range=value_range,
-        rows=rows,
-        cols=cols,
-        y_true=bounding_boxes,
-        scale=10,
-        line_thickness=1,
-        font_scale=0.7,
+        value_range=(0, 255),
         bounding_box_format=bounding_box_format,
+        y_true=y_true,
+        y_pred=y_pred,
+        scale=20,
+        rows=2,
+        cols=2,
+        show=True,
+        font_scale=0.7,
         class_mapping=CLASS_MAPPING,
     )
diff --git a/augment_data.py b/augment_data.py
deleted file mode 100644
index 485ce1b..0000000
--- a/augment_data.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import keras_cv
-import tensorflow as tf
-from adlc_util import visualize_dataset
-
-
-def get_augmenter():
-    return tf.keras.Sequential(
-        layers=[
-            keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh"),
-            keras_cv.layers.JitteredResize(
-                target_size=(640, 640),
-                scale_factor=(0.75, 1.3),
-                bounding_box_format="xywh",
-            ),
-        ]
-    )
-
-
-if __name__ == "__main__":
-    # NOTE: this requires tensorflow>=2.14 on linux: https://github.com/XiaotingChen/maxatac_pip_1.0.5/issues/2
-    train_ds = tf.data.Dataset.load(path="data/238_train_ds", compression="GZIP")
-    test_ds = tf.data.Dataset.load(path="data/238_test_ds", compression="GZIP")
-
-    augmenter = get_augmenter()
-
-    train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)
-
-    visualize_dataset(
-        train_ds, bounding_box_format="xywh", value_range=(0, 255), rows=2, cols=2, offset=2
-    )
diff --git a/cudnn.sh b/cudnn.sh
deleted file mode 100644
index 3b6cd49..0000000
--- a/cudnn.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#python3 -m pip install tensorflow[and-cuda]
-cudnn_version="8.8.0.121"
-cuda_version="cuda12.1"
-
-sudo apt-get install libcudnn8=${cudnn_version}-1+${cuda_version}
-sudo apt-get install libcudnn8-dev=${cudnn_version}-1+${cuda_version}
-sudo apt-get install libcudnn8-samples=${cudnn_version}-1+${cuda_version}
\ No newline at end of file
diff --git a/img/output.png b/img/output.png
new file mode 100644
index 0000000..2fd6782
Binary files /dev/null and b/img/output.png differ
diff --git a/ncsuadlc_condaenv.yaml b/ncsuadlc_condaenv.yaml
new file mode 100644
index 0000000..9d454e3
--- /dev/null
+++ b/ncsuadlc_condaenv.yaml
@@ -0,0 +1,48 @@
+name: ncsuadlc
+channels:
+- conda-forge
+dependencies:
+- _libgcc_mutex==0.1
+- _openmp_mutex==4.5
+- bzip2==1.0.8
+- ca-certificates
+- click
+- empy
+- ipykernel
+- lark
+- ld_impl_linux-64==2.40
+- libblas==3.9.0
+- libcblas==3.9.0
+- libexpat==2.5.0
+- libffi==3.4.2
+- libgcc-ng==13.2.0
+- libgfortran-ng==13.2.0
+- libgfortran5==13.2.0
+- libgomp==13.2.0
+- liblapack==3.9.0
+- libnsl==2.0.0
+- libopenblas==0.3.24
+- libsqlite==3.43.0
+- libstdcxx-ng==13.2.0
+- libuuid==2.38.1
+- libzlib==1.2.13
+- ncurses==6.4
+- numpy==1.25.2
+- openssl
+- pandas
+- pillow
+- pip==23.2.1
+- protobuf
+- pycocotools
+- python==3.11.5
+- python_abi==3.11
+- readline==8.2
+- scikit-learn
+- setuptools==68.1.2
+- tk==8.6.12
+- tqdm
+- transforms3d==0.4.1
+- tzdata==2023c
+- wheel==0.41.2
+- xz==5.2.6
+
diff --git a/ncsuadlc_pipreqs.txt b/ncsuadlc_pipreqs.txt
new file mode 100644
index 0000000..e90f02c
--- /dev/null
+++ b/ncsuadlc_pipreqs.txt
@@ -0,0 +1,116 @@
+absl-py==1.4.0
+array-record==0.5.0
+astor==0.8.1
+asttokens==2.4.0
+astunparse==1.6.3
+atomicwrites==1.4.1
+backcall==0.2.0
+backports.functools-lru-cache==1.6.5
+cachetools==5.3.2
+certifi==2023.7.22
+charset-normalizer==3.3.1
+click==8.1.7
+colorama==0.4.6
+comm==0.1.4
+contourpy==1.1.1
+cycler==0.12.1
+Cython==3.0.5
+debugpy==1.8.0
+decorator==5.1.1
+dm-tree==0.1.8
+empy==3.3.4
+etils==1.5.2
+exceptiongroup==1.1.3
+executing==1.2.0
+flatbuffers==23.5.26
+fonttools==4.43.1
+fsspec==2023.10.0
+gast==0.5.4
+google-auth==2.23.4
+google-auth-oauthlib==1.0.0
+google-pasta==0.2.0
+googleapis-common-protos==1.61.0
+grpcio==1.59.2
+h5py==3.10.0
+idna==3.4
+importlib-metadata==6.8.0
+importlib-resources==6.1.0
+ipykernel==6.25.2
+ipython==8.16.1
+jedi==0.19.1
+joblib==1.3.2
+Js2Py==0.74
+jupyter_client==8.4.0
+jupyter_core==5.4.0
+keras==2.14.0
+keras-core==0.1.7
+keras-cv==0.6.4
+kiwisolver==1.4.5
+lark==1.1.7
+libclang==16.0.6
+Markdown==3.5.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+matplotlib==3.8.0
+matplotlib-inline==0.1.6
+mdurl==0.1.2
+ml-dtypes==0.2.0
+munkres==1.1.4
+namex==0.0.7
+nest-asyncio==1.5.8
+oauthlib==3.2.2
+opencv-python==4.8.1.78
+opt-einsum==3.3.0
+packaging==23.2
+pandas==2.1.2
+parso==0.8.3
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==10.1.0
+platformdirs==3.5.1
+promise==2.3
+prompt-toolkit==3.0.39
+protobuf==3.20.3
+psutil==5.9.5
+ptyprocess==0.7.0
+pure-eval==0.2.2
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pycocotools==2.0.6
+Pygments==2.16.1
+pyjsparser==2.7.1
+pyparsing==3.1.1
+python-dateutil==2.8.2
+pytz==2023.3.post1
+pyzmq==25.1.1
+regex==2023.10.3
+requests==2.31.0
+requests-oauthlib==1.3.1
+rich==13.6.0
+rsa==4.9
+scikit-learn==1.3.2
+scipy==1.11.3
+six==1.16.0
+stack-data==0.6.2
+tensorboard==2.14.1
+tensorboard-data-server==0.7.2
+tensorflow==2.14.0
+tensorflow-datasets==4.9.3
+tensorflow-estimator==2.14.0
+tensorflow-io-gcs-filesystem==0.34.0
+tensorflow-metadata==1.14.0
+termcolor==2.3.0
+tfds-nightly==4.9.3.dev202310060044
+threadpoolctl==3.2.0
+toml==0.10.2
+tornado==6.3.3
+tqdm==4.66.1
+traitlets==5.11.2
+typing_extensions==4.8.0
+tzdata==2023.3
+tzlocal==5.1
+urllib3==2.0.7
+wcwidth==0.2.8
+Werkzeug==3.0.1
+wrapt==1.14.1
+zipp==3.17.0
diff --git a/ncsuarc_condaenv.yaml b/ncsuarc_condaenv.yaml
deleted file mode 100644
index 281c8a2..0000000
--- a/ncsuarc_condaenv.yaml
+++ /dev/null
@@ -1,48 +0,0 @@
-name: ncsuarc
-channels:
-- conda-forge
-dependencies:
-- _libgcc_mutex==0.1=conda_forge
-- _openmp_mutex==4.5=2_gnu
-- bzip2==1.0.8=h7f98852_4
-- ca-certificates
-- click
-- empy
-- ipykernel
-- lark
-- ld_impl_linux-64==2.40=h41732ed_0
-- libblas==3.9.0=18_linux64_openblas
-- libcblas==3.9.0=18_linux64_openblas
-- libexpat==2.5.0=hcb278e6_1
-- libffi==3.4.2=h7f98852_5
-- libgcc-ng==13.2.0=h807b86a_0
-- libgfortran-ng==13.2.0=h69a702a_0
-- libgfortran5==13.2.0=ha4646dd_0
-- libgomp==13.2.0=h807b86a_0
-- liblapack==3.9.0=18_linux64_openblas
-- libnsl==2.0.0=h7f98852_0
-- libopenblas==0.3.24=pthreads_h413a1c8_0
-- libsqlite==3.43.0=h2797004_0
-- libstdcxx-ng==13.2.0=h7e041cc_0
-- libuuid==2.38.1=h0b41bf4_0
-- libzlib==1.2.13=hd590300_5
-- ncurses==6.4=hcb278e6_0
-- numpy==1.25.2=py311h64a7726_0
-- openssl
-- pandas
-- pillow
-- pip==23.2.1=pyhd8ed1ab_0
-- protobuf
-- pycocotools
-- python==3.11.5=hab00c5b_0_cpython
-- python_abi==3.11=3_cp311
-- readline==8.2=h8228510_1
-- scikit-learn
-- setuptools==68.1.2=pyhd8ed1ab_0
-- tk==8.6.12=h27826a3_0
-- tqdm
-- transforms3d==0.4.1=pyhd8ed1ab_0
-- tzdata==2023c=h71feb2d_0
-- wheel==0.41.2=pyhd8ed1ab_0
-- xz==5.2.6=h166bdaf_0
-
diff --git a/ncsuarc_pipreqs.txt b/requirements.txt
similarity index 100%
rename from ncsuarc_pipreqs.txt
rename to requirements.txt
diff --git a/test_pipeline.ipynb b/test_objdetect.ipynb
similarity index 99%
rename from test_pipeline.ipynb
rename to test_objdetect.ipynb
index 87938db..4c7e1c4 100644
--- a/test_pipeline.ipynb
+++ b/test_objdetect.ipynb
@@ -1,13 +1,10 @@
 {
  "cells": [
   {
-   "cell_type": "code",
-   "execution_count": 14,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "del visualize_dataset\n",
-    "del visualize_detections"
+    "This notebook is heavily based on [this Colab notebook](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_cv/object_detection_keras_cv.ipynb). It does not deploy the object detection model; it only verifies that the model creation and prediction steps are functioning."
    ]
   },
   {
@@ -21,7 +18,52 @@
     "from adlc_util import visualize_dataset, visualize_detections\n",
     "import tqdm\n",
     "\n",
-    "class_mapping = {0:\"target\"}"
+    "CLASS_MAPPING = {0:\"target\"}"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hyperparameters\n",
+    "\n",
+    "INPUT_SIZE_X = 640\n",
+    "INPUT_SIZE_Y = 640\n",
+    "\n",
+    "BASE_LR = 0.005\n",
+    "N_EPOCHS = 1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load datasets from file\n",
+    "\n",
+    "Load the datasets processed by `process_dataset.py`, which have already been split into train/test.\n",
+    "\n",
+    "**NOTE:** this requires tensorflow>=2.14 on Linux: https://github.com/XiaotingChen/maxatac_pip_1.0.5/issues/2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_ds = tf.data.Dataset.load(path=\"data/238_train_ds\", compression=\"GZIP\")\n",
+    "test_ds = tf.data.Dataset.load(path=\"data/238_test_ds\", compression=\"GZIP\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Augment Data\n",
+    "\n",
+    "This block creates new observations by applying random transformations to the input images and resizing them."
+   ]
+  },
   {
@@ -44,15 +86,12 @@
     }
    ],
    "source": [
-    "# NOTE: this requires tensorflow>=2.14 on linux: https://github.com/XiaotingChen/maxatac_pip_1.0.5/issues/2\n",
-    "train_ds = tf.data.Dataset.load(path=\"data/238_train_ds\", compression=\"GZIP\")\n",
-    "test_ds = tf.data.Dataset.load(path=\"data/238_test_ds\", compression=\"GZIP\")\n",
-    "\n",
     "augmenter = tf.keras.Sequential(\n",
     "    layers=[\n",
+    "        # Adding/changing layers will adjust the transformations applied to the original observations\n",
     "        keras_cv.layers.RandomFlip(mode=\"horizontal\", bounding_box_format=\"xywh\"),\n",
     "        keras_cv.layers.JitteredResize(\n",
-    "            target_size=(640, 640),\n",
+    "            target_size=(INPUT_SIZE_X, INPUT_SIZE_Y),\n",
    "            scale_factor=(0.75, 1.3),\n",
     "            bounding_box_format=\"xywh\",\n",
     "        ),\n",
@@ -61,14 +100,36 @@
     "\n",
     "train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)\n",
     "\n",
+    "# Apply resize to test data\n",
     "inference_resizing = keras_cv.layers.Resizing(\n",
-    "    640, 640, bounding_box_format=\"xywh\", pad_to_aspect_ratio=True\n",
+    "    INPUT_SIZE_X, INPUT_SIZE_Y, bounding_box_format=\"xywh\", pad_to_aspect_ratio=True\n",
     ")\n",
-    "test_ds = test_ds.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)\n",
-    "\n",
-    "# visualize_dataset(\n",
-    "#     train_ds, bounding_box_format=\"xywh\", value_range=(0, 255), rows=2, cols=2, offset=1\n",
-    "# )"
+    "test_ds = test_ds.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, visualize a sample of the dataset after augmentation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "visualize_dataset(\n",
+    "    train_ds, bounding_box_format=\"xywh\", value_range=(0, 255), rows=2, cols=2, offset=1\n",
+    ")"
+   ]
+  },
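+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, peek at one augmented batch. This assumes the KerasCV dict format used by `visualize_dataset`: an `images` tensor plus a `bounding_boxes` dict with `boxes` and `classes`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect one batch (assumes KerasCV-style dicts with \"images\"/\"bounding_boxes\")\n",
+    "sample = next(iter(train_ds.take(1)))\n",
+    "print(sample[\"images\"].shape)\n",
+    "print(sample[\"bounding_boxes\"].keys())"
+   ]
+  },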
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following converts ragged bounding boxes to a dense format usable for training."
+   ]
+  },
   {
@@ -82,8 +143,6 @@
     "        inputs[\"bounding_boxes\"], max_boxes=32\n",
     "    )\n",
     "\n",
-    "\n",
-    "\n",
     "train_ds = train_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)\n",
     "test_ds = test_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)\n",
     "\n",
@@ -95,7 +154,14 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Optimizer"
+    "### Create Model"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Initialize Optimizer"
+   ]
+  },
   {
@@ -104,10 +170,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "base_lr = 0.005\n",
     "# including a global_clipnorm is extremely important in object detection tasks\n",
     "optimizer = tf.keras.optimizers.SGD(\n",
-    "    learning_rate=base_lr, momentum=0.9, global_clipnorm=10.0\n",
+    "    learning_rate=BASE_LR, momentum=0.9, global_clipnorm=10.0\n",
     ")"
    ]
   },
@@ -115,40 +180,37 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Loss"
+    "Create RetinaNet with pre-trained weights"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_backbone():\n",
-    "    \"\"\"Builds ResNet50 with pre-trained imagenet weights\"\"\"\n",
-    "    backbone = tf.keras.applications.ResNet50(\n",
-    "        include_top=False, input_shape=[None, None, 3]\n",
-    "    )\n",
-    "    c3_output, c4_output, c5_output = [\n",
-    "        backbone.get_layer(layer_name).output\n",
-    "        for layer_name in [\"conv3_block4_out\", \"conv4_block6_out\", \"conv5_block3_out\"]\n",
-    "    ]\n",
-    "    return tf.keras.Model(\n",
-    "        inputs=[backbone.inputs], outputs=[c3_output, c4_output, c5_output]\n",
-    "    )\n",
-    "\n",
-    "# resnet50_backbone = get_backbone()\n",
-    "# model = keras_cv.models.RetinaNet(num_classes=1, backbone=resnet50_backbone, bounding_box_format=\"xywh\")\n",
-    "\n",
-    "model = keras_cv.models.RetinaNet(\n",
-    "    num_classes=len(class_mapping),\n",
-    "    bounding_box_format=\"xywh\",\n",
-    "    backbone=keras_cv.models.ResNet50Backbone.from_preset(\n",
-    "        \"resnet50_imagenet\"\n",
-    "    )\n",
-    ")\n",
-    "\n",
-    "\n",
+    "model = keras_cv.models.RetinaNet.from_preset(\n",
+    "    \"resnet50_imagenet\",\n",
+    "    num_classes=len(CLASS_MAPPING),\n",
+    "    # For more info on supported bounding box formats, visit\n",
+    "    # https://keras.io/api/keras_cv/bounding_box/\n",
+    "    bounding_box_format=\"xywh\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Initialize a decoder that will transform model output into usable results."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(\n",
     "    bounding_box_format=\"xywh\",\n",
     "    from_logits=True,\n",
@@ -160,23 +222,11 @@
     "model.prediction_decoder = prediction_decoder"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model.compile(\n",
-    "    classification_loss=\"focal\",\n",
-    "    box_loss=\"smoothl1\",\n",
-    ")"
-   ]
-  },
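+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Compile the model so that `fit` has losses and the optimizer attached. A minimal setup using the focal classification loss and smooth L1 box loss from the KerasCV object detection guide:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.compile(\n",
+    "    classification_loss=\"focal\",\n",
+    "    box_loss=\"smoothl1\",\n",
+    "    optimizer=optimizer,\n",
+    ")"
+   ]
+  },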
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Metric"
+    "Initialize metrics for callbacks using [COCO Detection Evaluation](https://cocodataset.org/#detection-eval)"
    ]
   },
   {
@@ -245,28 +295,6 @@
     "print_metrics(result)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Create Model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model = keras_cv.models.RetinaNet.from_preset(\n",
-    "    \"resnet50_imagenet\",\n",
-    "    num_classes=len(class_mapping),\n",
-    "    # For more info on supported bounding box formats, visit\n",
-    "    # https://keras.io/api/keras_cv/bounding_box/\n",
-    "    bounding_box_format=\"xywh\",\n",
-    ")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -365,20 +393,40 @@
     "\n",
     "        metrics = self.metrics.result(force=True)\n",
     "        logs.update(metrics)\n",
-    "        return logs\n",
-    "\n",
-    "\n",
+    "        return logs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Fit the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "model.fit(\n",
     "    train_ds.take(20),\n",
     "    validation_data=test_ds.take(20),\n",
     "    # Run for ~10-35 epochs to achieve good scores.\n",
-    "    epochs=1,\n",
+    "    epochs=N_EPOCHS,\n",
     "    callbacks=[EvaluateCOCOMetricsCallback(test_ds.take(20)),\n",
     "        # VisualizeDetections(test_ds.take(20))  # not sure how to pass the visualizations\n",
     "    ],\n",
     ")"
    ]
   },
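+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, re-tune the prediction decoder before visualizing: raising `confidence_threshold` hides low-confidence boxes, and `iou_threshold` controls how aggressively overlapping boxes are suppressed. A sketch with example thresholds (adjust to taste):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Show only fairly confident detections in the gallery below\n",
+    "model.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(\n",
+    "    bounding_box_format=\"xywh\",\n",
+    "    from_logits=True,\n",
+    "    iou_threshold=0.5,\n",
+    "    confidence_threshold=0.75,\n",
+    ")"
+   ]
+  },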
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Select a subset of the dataset for visualizing predictions."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 34,
@@ -391,37 +439,10 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 35,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "from keras_cv import visualization\n",
-    "from keras_cv import bounding_box\n",
-    "CLASS_MAPPING = {0:\"target\"}\n",
-    "\n",
-    "def visualize_detections(model, dataset, bounding_box_format, offset=1):\n",
-    "    it = iter(dataset.take(offset))\n",
-    "\n",
-    "    for _ in range(offset):\n",
-    "        images, y_true = next(it)\n",
-    "\n",
-    "    # images, y_true = next(iter(dataset.take(1)))\n",
-    "    y_pred = model.predict(images)\n",
-    "    y_pred = bounding_box.to_ragged(y_pred)\n",
-    "    visualization.plot_bounding_box_gallery(\n",
-    "        images,\n",
-    "        value_range=(0, 255),\n",
-    "        bounding_box_format=bounding_box_format,\n",
-    "        y_true=y_true,\n",
-    "        y_pred=y_pred,\n",
-    "        scale=20,\n",
-    "        rows=2,\n",
-    "        cols=2,\n",
-    "        show=True,\n",
-    "        font_scale=0.7,\n",
-    "        class_mapping=CLASS_MAPPING,\n",
-    "    )"
+    "Visualize predictions on test data. Requires IPython / Jupyter."
    ]
   },
   {
@@ -456,39 +477,8 @@
     }
    ],
    "source": [
-    "model.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(\n",
-    "    bounding_box_format=\"xywh\",\n",
-    "    from_logits=True,\n",
-    "    iou_threshold=0.5,\n",
-    "    confidence_threshold=0.75,\n",
-    ")\n",
-    "\n",
     "visualize_detections(model, dataset=visualization_ds, bounding_box_format=\"xywh\")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class VisualizeDetections(tf.keras.callbacks.Callback):\n",
-    "    def __init__(self, data):\n",
-    "        super().__init__()\n",
-    "        self.data = data.unbatch()\n",
-    "\n",
-    "    def on_epoch_end(self, epoch, logs):\n",
-    "        visualize_detections(\n",
-    "            self.model, bounding_box_format=\"xywh\", dataset=self.data\n",
-    "        )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {