diff --git a/.travis.yml b/.travis.yml index 4fe52c9b4c5..0ce72d7a25a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -93,6 +93,8 @@ script: # Run manual S3 test - elapsed && bazel test //tensorboard/compat/tensorflow_stub:gfile_s3_test - elapsed && bazel test //tensorboard/summary/writer:event_file_writer_s3_test + - elapsed && bazel test //tensorboard/compat/tensorflow_stub:gfile_gcs_test + - elapsed && bazel test //tensorboard/summary/writer:event_file_writer_gcs_test - elapsed "script (done)" after_script: diff --git a/tensorboard/compat/tensorflow_stub/BUILD b/tensorboard/compat/tensorflow_stub/BUILD index 36b79483b83..2055a901f8d 100644 --- a/tensorboard/compat/tensorflow_stub/BUILD +++ b/tensorboard/compat/tensorflow_stub/BUILD @@ -64,3 +64,18 @@ py_test( "//tensorboard:test", ], ) + +py_test( + name = "gfile_gcs_test", + size = "small", + srcs = ["io/gfile_gcs_test.py"], + srcs_version = "PY2AND3", + tags = [ + "manual", + "notap", + ], + deps = [ + ":tensorflow_stub", + "//tensorboard:test", + ], +) diff --git a/tensorboard/compat/tensorflow_stub/io/gfile.py b/tensorboard/compat/tensorflow_stub/io/gfile.py index e56bb96917e..beb9dced8e8 100644 --- a/tensorboard/compat/tensorflow_stub/io/gfile.py +++ b/tensorboard/compat/tensorflow_stub/io/gfile.py @@ -43,6 +43,17 @@ except ImportError: S3_ENABLED = False +try: + from google.cloud import storage + from google.cloud import exceptions as gc_exceptions + from six.moves import http_client + + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "cred.json" + + GCS_ENABLED = True +except ImportError: + GCS_ENABLED = False + if sys.version_info < (3, 0): # In Python 2 FileExistsError is not defined and the # error manifests it as OSError. @@ -421,6 +432,170 @@ def stat(self, filename): register_filesystem("s3", S3FileSystem()) +class GCSFileSystem(object): + """Writes tensorboard protocol buffer files to Google Cloud Storage.""" + + def __init__(self): + if not GCS_ENABLED: + raise ImportError( + "`google-cloud-storage` must be installed in order to use " + "the 'gs://' protocol" + ) + + self.client = storage.Client() + + def get_blob(filename): + bucket_name, filepath = self.bucket_and_path(filename) + bucket = storage.Bucket(self.client, bucket_name) + return storage.Blob( + filepath, bucket, chunk_size=_DEFAULT_BLOCK_SIZE + ) + + self.blob = get_blob + + def bucket_and_path(self, url): + url = compat.as_str_any(url) + if url.startswith("gs://"): + url = url[len("gs://") :] + bp = url.split("/") + bucket = bp[0] + path = url[1 + len(bucket) :] + return bucket, path + + def exists(self, filename): + """Determines whether a path exists or not.""" + bucket, path = self.bucket_and_path(filename) + r = self.client.list_blobs(bucket_or_name=bucket, prefix=path) + if len(list(r)) != 0: + return True + return False + + def join(self, path, *paths): + """Join paths with a slash.""" + return "/".join((path,) + paths) + + def read(self, filename, binary_mode=False, size=None, continue_from=None): + + if continue_from is None: + continue_from = 0 + + if size is not None: + end = continue_from + size + else: + end = None + + try: + stream = self.blob(filename).download_as_string( + start=continue_from, end=end + ) + except Exception as e: + if e.code == http_client.REQUESTED_RANGE_NOT_SATISFIABLE: + return "", continue_from + + else: + raise + + continue_from += len(stream) + if binary_mode: + return (bytes(stream), continue_from) + else: + return (stream.decode("utf-8"), continue_from) + + def write(self, filename, file_content, binary_mode=False): + 
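+        # NOTE: upload_from_string() replaces any existing object at this
+        # path, so each call re-uploads the full contents rather than
+        # appending to it.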
file_content = compat.as_bytes(file_content) + self.blob(filename).upload_from_string( + file_content + ) # this will overwrite! + + def glob(self, filename): + """Returns a list of files that match the given pattern(s).""" + # Only support prefix with * at the end and no ? in the string + star_i = filename.find("*") + quest_i = filename.find("?") + if quest_i >= 0: + raise NotImplementedError( + "{} not supported by compat glob".format(filename) + ) + if star_i != len(filename) - 1: + # Just return empty so we can use glob from directory watcher + # + # TODO: Remove and instead handle in GetLogdirSubdirectories. + # However, we would need to handle it for all non-local registered + # filesystems in some way. + return [] + filename = filename[:-1] + bucket, path = self.bucket_and_path(filename) + result = list( + self.client.list_blobs(bucket_or_name=bucket, prefix=path) + ) + + keys = [] + for r in result: + # glob.glob('./*') returns folder as well. + if r.name[-1] != "/": # in order to pass the unit test + keys.append(filename + r.name[len(path) :]) + + return keys + + def isdir(self, dirname): + """Returns whether the path is a directory or not.""" + bucket, path = self.bucket_and_path(dirname) + if path[-1] != "/": + path += "/" + result = list( + self.client.list_blobs( + bucket_or_name=bucket, prefix=path, delimiter="/" + ) + ) + return len(result) > 0 + + def listdir(self, dirname): + """Returns a list of entries contained within a directory.""" + bucket, path = self.bucket_and_path(dirname) + + if path[-1] != "/": + path += "/" + path_depth = len(path.split("/")) - 1 + result = list( + self.client.list_blobs(bucket_or_name=bucket, prefix=path) + ) + keys = set() + + for r in result: + dirs = r.name.split("/") + if len(dirs) > path_depth: + if dirs[path_depth] != "": + keys.add(dirs[path_depth]) + return keys + + def makedirs(self, dirname): + """Creates a directory and all parent/intermediate directories.""" + if self.exists(dirname): + raise errors.AlreadyExistsError( + None, None, "Directory already exists" + ) + if not dirname.endswith("/"): + dirname += "/" # This will make sure we don't override a file + self.blob(dirname).upload_from_string("") + + def stat(self, filename): + """Returns file statistics for a given path.""" + # NOTE: Size of the file is given by ContentLength from S3, + # but we convert to .length + bucket_name, path = self.bucket_and_path(filename) + bucket = storage.Bucket(self.client, bucket_name) + blob = bucket.get_blob(path) + if blob == None: + raise errors.NotFoundError(None, None, "Could not find file") + + # use get_blob to get metadata + return StatData(bucket.get_blob(path).size) + + +if GCS_ENABLED: + register_filesystem("gs", GCSFileSystem()) + + class GFile(object): # Only methods needed for TensorBoard are implemented. diff --git a/tensorboard/compat/tensorflow_stub/io/gfile_gcs_test.py b/tensorboard/compat/tensorflow_stub/io/gfile_gcs_test.py new file mode 100644 index 00000000000..5145d89bbf1 --- /dev/null +++ b/tensorboard/compat/tensorflow_stub/io/gfile_gcs_test.py @@ -0,0 +1,468 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import six +import unittest +from google.cloud import storage + +from tensorboard.compat.tensorflow_stub import errors +from tensorboard.compat.tensorflow_stub.io import gfile + +# Placeholder values to make sure any local keys are overridden +# and moto mock is being called + + +class GFileTest(unittest.TestCase): + def testExists(self): + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + self.assertTrue(gfile.exists(temp_dir)) + + def testGlob(self): + # S3 glob includes subdirectory content, which standard + # filesystem does not. However, this is good for perf. + expected = [ + "a.tfevents.1", + "bar/b.tfevents.1", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "bar/red_herring.txt", + "model.ckpt", + "quuz/e.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", + ] + expected_listing = [self._PathJoin(temp_dir, f) for f in expected] + gotten_listing = gfile.glob(self._PathJoin(temp_dir, "*")) + six.assertCountEqual( + self, + expected_listing, + gotten_listing, + "Files must match. Expected %r. Got %r." + % (expected_listing, gotten_listing), + ) + + def testIsdir(self): + self.assertTrue(gfile.isdir(temp_dir)) + + def testListdir(self): + expected_files = [ + # Empty directory not returned + "foo", + "bar", + "quuz", + "a.tfevents.1", + "model.ckpt", + "waldo", + ] + gotten_files = gfile.listdir(temp_dir) + six.assertCountEqual(self, expected_files, gotten_files) + + # This can only run once, the second run will get AlreadyExistsError + def testMakeDirs(self): + remove_newdir() + new_dir = self._PathJoin(temp_dir, "newdir", "subdir", "subsubdir") + gfile.makedirs(new_dir) + self.assertTrue(gfile.isdir(new_dir)) + remove_newdir() + + def testMakeDirsAlreadyExists(self): + temp_dir = self._CreateDeepGCSStructure() + new_dir = self._PathJoin(temp_dir, "bar", "baz") + with self.assertRaises(errors.AlreadyExistsError): + gfile.makedirs(new_dir) + + def testWalk(self): + temp_dir = "gs://lanpa-tbx/123" + expected = [ + ["", ["a.tfevents.1", "model.ckpt",]], + # Empty directory not returned + ["foo", []], + ["bar", ["b.tfevents.1", "red_herring.txt",]], + ["bar/baz", ["c.tfevents.1", "d.tfevents.1",]], + [ + "bar/quux", + ["some_flume_output.txt", "some_more_flume_output.txt",], + ], + ["quuz", ["e.tfevents.1",]], + ["quuz/garply", ["f.tfevents.1",]], + ["quuz/garply/corge", ["g.tfevents.1",]], + ["quuz/garply/grault", ["h.tfevents.1",]], + ["waldo", []], + ["waldo/fred", ["i.tfevents.1",]], + ] + for pair in expected: + # If this is not the top-level directory, prepend the high-level + # directory. 
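+            # (gfile.walk yields absolute gs:// paths, so the expected
+            # directories must be absolute as well.)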
+ pair[0] = self._PathJoin(temp_dir, pair[0]) if pair[0] else temp_dir + gotten = gfile.walk(temp_dir) + self._CompareFilesPerSubdirectory(expected, gotten) + + def testStat(self): + ckpt_content = "asdfasdfasdffoobarbuzz" + temp_dir = self._CreateDeepGCSStructure(ckpt_content=ckpt_content) + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + ckpt_stat = gfile.stat(ckpt_path) + self.assertEqual(ckpt_stat.length, len(ckpt_content)) + bad_ckpt_path = self._PathJoin(temp_dir, "bad_model.ckpt") + with self.assertRaises(errors.NotFoundError): + gfile.stat(bad_ckpt_path) + + def testRead(self): + ckpt_content = "asdfasdfasdffoobarbuzz" + temp_dir = self._CreateDeepGCSStructure(ckpt_content=ckpt_content) + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: + f.buff_chunk_size = 4 # Test buffering by reducing chunk size + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testReadLines(self): + ckpt_lines = ( + [u"\n"] + [u"line {}\n".format(i) for i in range(10)] + [u" "] + ) + ckpt_content = u"".join(ckpt_lines) + temp_dir = self._CreateDeepGCSStructure(ckpt_content=ckpt_content) + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: + f.buff_chunk_size = 4 # Test buffering by reducing chunk size + ckpt_read_lines = list(f) # list(f) + self.assertEqual(ckpt_lines, ckpt_read_lines) + + def testReadWithOffset(self): + ckpt_content = "asdfasdfasdffoobarbuzz" + ckpt_b_content = b"asdfasdfasdffoobarbuzz" + temp_dir = self._CreateDeepGCSStructure(ckpt_content=ckpt_content) + ckpt_path = self._PathJoin(temp_dir, "model.ckpt") + with gfile.GFile(ckpt_path, "r") as f: + f.buff_chunk_size = 4 # Test buffering by reducing chunk size + ckpt_read = f.read(12) + self.assertEqual("asdfasdfasdf", ckpt_read) + ckpt_read = f.read(6) + self.assertEqual("foobar", ckpt_read) + ckpt_read = f.read(1) + self.assertEqual("b", ckpt_read) + ckpt_read = f.read() + self.assertEqual("uzz", ckpt_read) + ckpt_read = f.read(1000) + self.assertEqual("", ckpt_read) + with gfile.GFile(ckpt_path, "rb") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_b_content, ckpt_read) + + def testWrite(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: + f.write(ckpt_content) + with gfile.GFile(ckpt_path, "r") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testOverwrite(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "w") as f: + f.write(u"original") + with gfile.GFile(ckpt_path, "w") as f: + f.write(ckpt_content) + with gfile.GFile(ckpt_path, "r") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testWriteMultiple(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = u"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "w") as f: + for i in range(0, len(ckpt_content), 3): + f.write(ckpt_content[i : i + 3]) + # Test periodic flushing of the file + if i % 9 == 0: + f.flush() + with gfile.GFile(ckpt_path, "r") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testWriteEmpty(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = u"" + with gfile.GFile(ckpt_path, "w") as f: + f.write(ckpt_content) + with 
gfile.GFile(ckpt_path, "r") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testWriteBinary(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" + with gfile.GFile(ckpt_path, "wb") as f: + f.write(ckpt_content) + with gfile.GFile(ckpt_path, "rb") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def testWriteMultipleBinary(self): + remove_model2_ckpt() + ckpt_path = os.path.join(temp_dir_write, "model2.ckpt") + ckpt_content = b"asdfasdfasdffoobarbuzz" * 5 + with gfile.GFile(ckpt_path, "wb") as f: + for i in range(0, len(ckpt_content), 3): + f.write(ckpt_content[i : i + 3]) + # Test periodic flushing of the file + if i % 9 == 0: + f.flush() + with gfile.GFile(ckpt_path, "rb") as f: + ckpt_read = f.read() + self.assertEqual(ckpt_content, ckpt_read) + + def _PathJoin(self, *args): + """Join directory and path with slash and not local separator""" + return "/".join(args) + + def _CreateDeepGCSStructure( + self, + top_directory="123", + ckpt_content="", + region_name="us-east-1", + bucket_name="lanpa-tbx", + ): + """Creates a reasonable deep structure of GCS subdirectories with files. + + Args: + top_directory: The path of the top level GCS directory in which + to create the directory structure. Defaults to 'top_dir'. + ckpt_content: The content to put into model.ckpt. Default to ''. + region_name: The GCS region name. Defaults to 'us-east-1'. + bucket_name: The GCS bucket name. Defaults to 'test'. + + Returns: + GCS URL of the top directory in the form 'gs://bucket/path' + """ + gs_top_url = "gs://{}/{}".format(bucket_name, top_directory) + # return gs_top_url + # Add a few subdirectories. + directory_names = ( + # An empty directory. + "foo", + # A directory with an events file (and a text file). + "bar", + # A deeper directory with events files. + "bar/baz", + # A non-empty subdir that lacks event files (should be ignored). + "bar/quux", + # This 3-level deep set of subdirectories tests logic that replaces + # the full glob string with an absolute path prefix if there is + # only 1 subdirectory in the final mapping. + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", + # A directory that lacks events files, but contains a subdirectory + # with events files (first level should be ignored, second level + # should be included). + "waldo", + "waldo/fred", + ) + client = storage.Client() + bucket = storage.Bucket(client, bucket_name) + blob = storage.Blob(top_directory, bucket) + + for directory_name in directory_names: + # Add an end slash + path = top_directory + "/" + directory_name + "/" + # Create an empty object so the location exists + blob = storage.Blob(path, bucket) + blob.upload_from_string("") + + # Add a few files to the directory. 
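+        # Only model.ckpt is seeded with ckpt_content; every other file
+        # below is created empty.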
+ file_names = ( + "a.tfevents.1", + "model.ckpt", + "bar/b.tfevents.1", + "bar/red_herring.txt", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "quuz/e.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", + ) + for file_name in file_names: + # Add an end slash + path = top_directory + "/" + file_name + if file_name == "model.ckpt": + content = ckpt_content + else: + content = "" + blob = storage.Blob(path, bucket) + blob.upload_from_string(content) + return gs_top_url + + def _CompareFilesPerSubdirectory(self, expected, gotten): + """Compares iterables of (subdirectory path, list of absolute paths) + + Args: + expected: The expected iterable of 2-tuples. + gotten: The gotten iterable of 2-tuples. + """ + expected_directory_to_files = { + result[0]: list(result[1]) for result in expected + } + gotten_directory_to_files = { + # Note we ignore subdirectories and just compare files + result[0]: list(result[2]) + for result in gotten + } + six.assertCountEqual( + self, + expected_directory_to_files.keys(), + gotten_directory_to_files.keys(), + ) + + for subdir, expected_listing in expected_directory_to_files.items(): + gotten_listing = gotten_directory_to_files[subdir] + six.assertCountEqual( + self, + expected_listing, + gotten_listing, + "Files for subdir %r must match. Expected %r. Got %r." + % (subdir, expected_listing, gotten_listing), + ) + + +def remove_newdir(): + try: + client = storage.Client() + bucket = storage.Bucket(client, "lanpa-tbx") + blobs = bucket.list_blobs(prefix="123/newdir/subdir/subsubdir") + for b in blobs: + b.delete() + except: + pass + + +def remove_model2_ckpt(): + try: + client = storage.Client() + bucket = storage.Bucket(client, "lanpa-tbx") + blobs = bucket.list_blobs(prefix="write/model2.ckpt") + for b in blobs: + b.delete() + except: + pass + + +def CreateDeepGCSStructure( + top_directory="123", + ckpt_content="", + region_name="us-east-1", + bucket_name="lanpa-tbx", +): + """Creates a reasonable deep structure of GCS subdirectories with files. + + Args: + top_directory: The path of the top level GCS directory in which + to create the directory structure. Defaults to 'top_dir'. + ckpt_content: The content to put into model.ckpt. Default to ''. + region_name: The GCS region name. Defaults to 'us-east-1'. + bucket_name: The GCS bucket name. Defaults to 'test'. + + Returns: + GCS URL of the top directory in the form 'gs://bucket/path' + """ + gs_top_url = "gs://{}/{}".format(bucket_name, top_directory) + # return gs_top_url + # Add a few subdirectories. + directory_names = ( + # An empty directory. + "foo", + # A directory with an events file (and a text file). + "bar", + # A deeper directory with events files. + "bar/baz", + # A non-empty subdir that lacks event files (should be ignored). + "bar/quux", + # This 3-level deep set of subdirectories tests logic that replaces + # the full glob string with an absolute path prefix if there is + # only 1 subdirectory in the final mapping. + "quuz/garply", + "quuz/garply/corge", + "quuz/garply/grault", + # A directory that lacks events files, but contains a subdirectory + # with events files (first level should be ignored, second level + # should be included). 
+ "waldo", + "waldo/fred", + ) + + client = storage.Client() + bucket = storage.Bucket(client, bucket_name) + blob = storage.Blob(top_directory, bucket) + + for directory_name in directory_names: + # Add an end slash + path = top_directory + "/" + directory_name + "/" + # Create an empty object so the location exists + blob = storage.Blob(path, bucket) + blob.upload_from_string("") + + # Add a few files to the directory. + file_names = ( + "a.tfevents.1", + "model.ckpt", + "bar/b.tfevents.1", + "bar/red_herring.txt", + "bar/baz/c.tfevents.1", + "bar/baz/d.tfevents.1", + "bar/quux/some_flume_output.txt", + "bar/quux/some_more_flume_output.txt", + "quuz/e.tfevents.1", + "quuz/garply/f.tfevents.1", + "quuz/garply/corge/g.tfevents.1", + "quuz/garply/grault/h.tfevents.1", + "waldo/fred/i.tfevents.1", + ) + for file_name in file_names: + # Add an end slash + path = top_directory + "/" + file_name + if file_name == "model.ckpt": + content = ckpt_content + else: + content = "" + blob = storage.Blob(path, bucket) + blob.upload_from_string(content) + return gs_top_url + + +temp_dir = CreateDeepGCSStructure() +temp_dir_write = "gs://lanpa-tbx/write" + +if __name__ == "__main__": + unittest.main() diff --git a/tensorboard/pip_package/requirements_dev.txt b/tensorboard/pip_package/requirements_dev.txt index a0ea901a0f5..9b16bc8af4b 100644 --- a/tensorboard/pip_package/requirements_dev.txt +++ b/tensorboard/pip_package/requirements_dev.txt @@ -20,6 +20,7 @@ grpcio-testing==1.24.3 # For gfile S3 test boto3==1.9.86 moto==1.3.7 +google-cloud-storage==1.24.1 # For linting black==19.10b0; python_version >= "3" diff --git a/tensorboard/summary/writer/BUILD b/tensorboard/summary/writer/BUILD index 5f2b9c254cc..cc97c73a3c3 100644 --- a/tensorboard/summary/writer/BUILD +++ b/tensorboard/summary/writer/BUILD @@ -52,6 +52,23 @@ py_test( ], ) +py_test( + name = "event_file_writer_gcs_test", + size = "small", + srcs = ["event_file_writer_gcs_test.py"], + main = "event_file_writer_gcs_test.py", + srcs_version = "PY2AND3", + tags = [ + "manual", + "notap", + "support_notf", + ], + deps = [ + ":writer", + "//tensorboard:test", + ], +) + py_test( name = "record_writer_test", size = "small", diff --git a/tensorboard/summary/writer/event_file_writer_gcs_test.py b/tensorboard/summary/writer/event_file_writer_gcs_test.py new file mode 100644 index 00000000000..7389a5e634e --- /dev/null +++ b/tensorboard/summary/writer/event_file_writer_gcs_test.py @@ -0,0 +1,97 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# """Tests for EventFileWriter and _AsyncWriter""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest +from tensorboard.summary.writer.event_file_writer import EventFileWriter +from tensorboard.summary.writer.event_file_writer import _AsyncWriter +from tensorboard.compat import tf +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto.summary_pb2 import Summary +from tensorboard.compat.tensorflow_stub.io import gfile +from tensorboard.compat.tensorflow_stub.pywrap_tensorflow import ( + PyRecordReader_New, +) +from tensorboard import test as tb_test +from google.cloud import storage + +# Placeholder values to make sure any local keys are overridden +# and moto mock is being called + +USING_REAL_TF = tf.__version__ != "stub" + + +def gcs_temp_dir(top_directory="event-test", bucket_name="lanpa-tbx"): + """Creates a test GCS bucket and returns directory location. + The files in `top_directory` will be cleared after this call. + Args: + top_directory: The path of the top level GCS directory in which + to create the directory structure. Defaults to 'top_dir'. + bucket_name: The GCS bucket name. + + Returns GCS URL of the top directory in the form 'gs://bucket/path' + """ + gcs_url = "gs://{}/{}".format(bucket_name, top_directory) + client = storage.Client() + bucket = storage.Bucket(client, bucket_name) + bloblist = bucket.list_blobs(prefix=top_directory) + for f in bloblist: + f.delete() + return gcs_url + + +def GCS_join(*args): + """Joins an GCS directory path as a replacement for os.path.join.""" + return "/".join(args) + + +class EventFileWriterTest(tb_test.TestCase): + @unittest.skipIf(USING_REAL_TF, "Test only passes when using stub TF") + def test_event_file_writer_roundtrip(self): + _TAGNAME = "dummy" + _DUMMY_VALUE = 42 + logdir = gcs_temp_dir() + w = EventFileWriter(logdir) + summary = Summary( + value=[Summary.Value(tag=_TAGNAME, simple_value=_DUMMY_VALUE)] + ) + fakeevent = event_pb2.Event(summary=summary) + w.add_event(fakeevent) + w.close() + event_files = sorted(gfile.glob(GCS_join(logdir, "*"))) + self.assertEqual(len(event_files), 1) + r = PyRecordReader_New(event_files[0]) + r.GetNext() # meta data, so skip + r.GetNext() + self.assertEqual(fakeevent.SerializeToString(), r.record()) + + @unittest.skipIf(USING_REAL_TF, "Test only passes when using stub TF") + def test_setting_filename_suffix_works(self): + logdir = gcs_temp_dir() + + w = EventFileWriter(logdir, filename_suffix=".event_horizon") + w.close() + event_files = sorted(gfile.glob(GCS_join(logdir, "*"))) + self.assertEqual(event_files[0].split(".")[-1], "event_horizon") + + +if __name__ == "__main__": + tb_test.main()
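
A minimal manual smoke test for the new "gs://" support, independent of the Bazel test targets above. This is only a sketch: it assumes `google-cloud-storage` is installed, application-default credentials are configured, and a writable bucket exists ("my-test-bucket" below is a hypothetical placeholder, not part of this change).

    from tensorboard.compat.tensorflow_stub.io import gfile

    logdir = "gs://my-test-bucket/smoke"  # hypothetical bucket/prefix

    # Round-trip a small text file through the registered GCS filesystem.
    with gfile.GFile(logdir + "/hello.txt", "w") as f:
        f.write(u"hello, gs://")
    with gfile.GFile(logdir + "/hello.txt", "r") as f:
        assert f.read() == u"hello, gs://"

    # Directory-style helpers used by TensorBoard's directory watcher.
    assert gfile.exists(logdir)
    assert "hello.txt" in gfile.listdir(logdir)
    print(gfile.glob(logdir + "/*"))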