Removed all usages of "six" in tfx_bsl.
Also, omit the explicit MRO search type passed to super() in cases of single inheritance.

PiperOrigin-RevId: 330766477
brills authored and tfx-copybara committed Sep 9, 2020
1 parent e5331ea commit c2e8033
Showing 29 changed files with 74 additions and 126 deletions.
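For orientation, here is a hedged before/after sketch of the main `six` idioms this commit replaces; the names below are illustrative, not tfx_bsl APIs:

# Dict iteration: six.iteritems(d) becomes the plain items() view.
counts = {"a": 1, "b": 2}
for name, value in counts.items():  # before: six.iteritems(counts)
  print(name, value)

# Text checks: six.text_type is simply str on Python 3.
assert isinstance(u"text", str)  # before: isinstance(u"text", six.text_type)

# Version guards such as `if six.PY2:` are deleted outright; only the
# Python 3 branch survives (see run_all_tests.py below).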
1 change: 1 addition & 0 deletions RELEASE.md
@@ -15,6 +15,7 @@
* Depends on `absl-py>=0.9,<0.11`.
* Depends on `pandas>=1.0,<2`.
* Depends on `protobuf>=3.9.2,<4`.
+ * Stopped depending on `six`.

## Breaking changes

1 change: 0 additions & 1 deletion setup.py
@@ -149,7 +149,6 @@ def has_ext_modules(self):
'pandas>=1.0,<2',
'protobuf>=3.9.2,<4',
'pyarrow>=0.17,<0.18',
- 'six>=1.12,<2',
'tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,<3',
'tensorflow-metadata>=0.23,<0.24',
'tensorflow-serving-api>=1.15,!=2.0.*,!=2.1.*,!=2.2.*,<3',
8 changes: 0 additions & 8 deletions tfx_bsl/arrow/array_util_test.py
@@ -17,7 +17,6 @@

import numpy as np
import pyarrow as pa
- import six

from tfx_bsl.arrow import array_util

@@ -235,13 +234,6 @@ def test_match(self):
values=pa.array([], type=pa.int64()),
expected=pa.array([None, None, None, None, None],
type=pa.list_(pa.int64()))),
- dict(
-     testcase_name="long_num_parent",
-     num_parents=(long(1) if six.PY2 else 1),
-     parent_indices=pa.array([0], type=pa.int64()),
-     values=pa.array([1]),
-     expected=pa.array([[1]])
- ),
dict(
testcase_name="leading nones",
num_parents=3,
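The deleted `long_num_parent` case only exercised Python 2's separate `long` integer type; a quick illustration of why it is redundant on Python 3:

big = 2 ** 64                # would have been promoted to long on Python 2
assert isinstance(big, int)  # Python 3 has one arbitrary-precision int
assert isinstance(1, int)    # so a separate long(1) case adds no coverage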
24 changes: 7 additions & 17 deletions tfx_bsl/arrow/path.py
@@ -15,27 +15,19 @@

from typing import Iterable, Text, Tuple, Union

- import six

from tensorflow_metadata.proto.v0 import path_pb2


- # Text on py3, bytes on py2.
- Step = Union[bytes, Text]


- @six.python_2_unicode_compatible
class ColumnPath(object):
"""ColumnPath addresses a column potentially nested under a StructArray."""

__slots__ = ["_steps"]

- def __init__(self, steps: Union[Iterable[Step], Step]):
+ def __init__(self, steps: Union[Iterable[Text], Text]):
    """If a single Step is specified, constructs a Path of that step."""
-   if isinstance(steps, (bytes, six.text_type)):
+   if isinstance(steps, Text):
      steps = (steps,)
-   self._steps = tuple(
-       s if isinstance(s, six.text_type) else s.decode("utf-8") for s in steps)
+   self._steps = tuple(steps)

def to_proto(self) -> path_pb2.Path:
"""Creates a tensorflow_metadata path proto this ColumnPath."""
@@ -53,7 +45,7 @@ def from_proto(path_proto: path_pb2.Path):
"""
return ColumnPath(path_proto.step)

- def steps(self) -> Tuple[Step, ...]:
+ def steps(self) -> Tuple[Text, ...]:
"""Returns the tuple of steps that represents this ColumnPath."""
return self._steps

@@ -70,7 +62,7 @@ def parent(self) -> "ColumnPath":
raise ValueError("Root does not have parent.")
return ColumnPath(self._steps[:-1])

- def child(self, child_step: Step) -> "ColumnPath":
+ def child(self, child_step: Text) -> "ColumnPath":
"""Creates a new ColumnPath with a new child.
example: ColumnPath(["this", "is", "my", "path"]).child("new_step") will
@@ -82,9 +74,7 @@ def child(self, child_step: Step) -> "ColumnPath":
Returns:
A ColumnPath with the new child_step
"""
- if isinstance(child_step, six.text_type):
-   return ColumnPath(self._steps + (child_step,))
- return ColumnPath(self._steps + (child_step.decode("utf-8"),))
+ return ColumnPath(self._steps + (child_step,))

def prefix(self, ending_index: int) -> "ColumnPath":
"""Creates a new ColumnPath, taking the prefix until the ending_index.
@@ -114,7 +104,7 @@ def suffix(self, starting_index: int) -> "ColumnPath":
"""
return ColumnPath(self._steps[starting_index:])

- def initial_step(self) -> Step:
+ def initial_step(self) -> Text:
"""Returns the first step of this path.
Raises:
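To make the new text-only contract concrete, a small usage sketch of ColumnPath as it behaves after this change (written against the methods shown above, not documented tfx_bsl examples):

from tfx_bsl.arrow.path import ColumnPath

path = ColumnPath(["parent", "child"])  # an iterable of text steps
leaf = path.child("grandchild")         # child steps must now be text
assert leaf.steps() == ("parent", "child", "grandchild")
assert leaf.parent().initial_step() == "parent"
assert ColumnPath("single").steps() == ("single",)  # a lone step also works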
20 changes: 9 additions & 11 deletions tfx_bsl/beam/run_inference.py
@@ -34,7 +34,6 @@
from googleapiclient import discovery
from googleapiclient import http
import numpy as np
- import six
import tensorflow as tf
from tfx_bsl.beam import shared
from tfx_bsl.public.proto import model_spec_pb2
@@ -207,8 +206,7 @@ def _MultiInference(pcoll: beam.pvalue.PCollection,  # pylint: disable=invalid-n
raise NotImplementedError


- @six.add_metaclass(abc.ABCMeta)
- class _BaseDoFn(beam.DoFn):
+ class _BaseDoFn(beam.DoFn, metaclass=abc.ABCMeta):
"""Base DoFn that performs bulk inference."""

class _MetricsCollector(object):
@@ -267,7 +265,7 @@ def update(self, elements: List[Union[tf.train.Example,
sum(element.ByteSize() for element in elements))

def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType):
- super(_BaseDoFn, self).__init__()
+ super().__init__()
self._clock = None
self._metrics_collector = self._MetricsCollector(inference_spec_type)

@@ -346,7 +344,7 @@ class _RemotePredictDoFn(_BaseDoFn):

def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType,
pipeline_options: PipelineOptions):
- super(_RemotePredictDoFn, self).__init__(inference_spec_type)
+ super().__init__(inference_spec_type)
self._ai_platform_prediction_model_spec = (
inference_spec_type.ai_platform_prediction_model_spec)
self._api_client = None
@@ -373,7 +371,7 @@ def __init__(self, inference_spec_type: model_spec_pb2.InferenceSpecType,
version_name)

def setup(self):
- super(_RemotePredictDoFn, self).setup()
+ super().setup()
# TODO(b/151468119): Add tfx_bsl_version and tfx_bsl_py_version to
# user agent once custom header is supported in googleapiclient.
self._api_client = discovery.build('ml', 'v1')
@@ -505,7 +503,7 @@ def __init__(
inference_spec_type: model_spec_pb2.InferenceSpecType,
shared_model_handle: shared.Shared,
):
- super(_BaseBatchSavedModelDoFn, self).__init__(inference_spec_type)
+ super().__init__(inference_spec_type)
self._inference_spec_type = inference_spec_type
self._shared_model_handle = shared_model_handle
self._model_path = inference_spec_type.saved_model_spec.model_path
@@ -524,7 +522,7 @@ def setup(self):
to b/139207285.
"""

- super(_BaseBatchSavedModelDoFn, self).setup()
+ super().setup()
self._tags = _get_tags(self._inference_spec_type)
self._io_tensor_spec = self._pre_process()

@@ -636,7 +634,7 @@ def setup(self):
'BulkInferrerClassifyDoFn requires signature method '
'name %s, got: %s' % (tf.saved_model.CLASSIFY_METHOD_NAME,
                      signature_def.method_name))
- super(_BatchClassifyDoFn, self).setup()
+ super().setup()

def _check_elements(
self, elements: List[Union[tf.train.Example,
@@ -661,7 +659,7 @@ class _BatchRegressDoFn(_BaseBatchSavedModelDoFn):
"""A DoFn that run inference on regression model."""

def setup(self):
- super(_BatchRegressDoFn, self).setup()
+ super().setup()

def _check_elements(
self, elements: List[Union[tf.train.Example,
@@ -690,7 +688,7 @@ def setup(self):
'BulkInferrerPredictDoFn requires signature method '
'name %s, got: %s' % (tf.saved_model.PREDICT_METHOD_NAME,
                      signature_def.method_name))
- super(_BatchPredictDoFn, self).setup()
+ super().setup()

def _check_elements(
self, elements: List[Union[tf.train.Example,
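The recurring edit in this file combines two Python-3-only forms: a `metaclass=` keyword in the class header instead of `@six.add_metaclass`, and zero-argument `super()`. A minimal self-contained sketch (class names are illustrative, not tfx_bsl classes):

import abc

class Base(metaclass=abc.ABCMeta):  # before: @six.add_metaclass(abc.ABCMeta)

  def __init__(self):
    super().__init__()  # before: super(Base, self).__init__()

  @abc.abstractmethod
  def process(self):
    raise NotImplementedError

class Concrete(Base):

  def process(self):
    return "processed"

assert Concrete().process() == "processed"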
10 changes: 5 additions & 5 deletions tfx_bsl/beam/run_inference_test.py
@@ -14,6 +14,7 @@
"""Tests for tfx_bsl.run_inference."""

import base64
+ from http import client as http_client
import json
import os
try:
@@ -27,7 +28,6 @@
from apache_beam.testing.util import equal_to
from googleapiclient import discovery
from googleapiclient import http
- from six.moves import http_client
import tensorflow as tf
from tfx_bsl.beam import run_inference
from tfx_bsl.public.proto import model_spec_pb2
@@ -40,7 +40,7 @@
class RunInferenceFixture(tf.test.TestCase):

def setUp(self):
- super(RunInferenceFixture, self).setUp()
+ super().setUp()
self._predict_examples = [
text_format.Parse(
"""
@@ -70,7 +70,7 @@ def _prepare_predict_examples(self, example_path):
class RunOfflineInferenceTest(RunInferenceFixture):

def setUp(self):
- super(RunOfflineInferenceTest, self).setUp()
+ super().setUp()
self._predict_examples = [
text_format.Parse(
"""
@@ -361,7 +361,7 @@ def testKerasModelPredict(self):
class TestKerasModel(tf.keras.Model):

def __init__(self, inference_model):
- super(TestKerasModel, self).__init__(name='test_keras_model')
+ super().__init__(name='test_keras_model')
self.inference_model = inference_model

@tf.function(input_signature=[
@@ -449,7 +449,7 @@ def testTelemetry(self):
class RunRemoteInferenceTest(RunInferenceFixture):

def setUp(self):
- super(RunRemoteInferenceTest, self).setUp()
+ super().setUp()
self.example_path = self._get_output_data_dir('example')
self._prepare_predict_examples(self.example_path)
# This is from https://ml.googleapis.com/$discovery/rest?version=v1.
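`six.moves.http_client` was only a compatibility alias; on Python 3 the standard-library module can be imported directly, as the test now does. A quick illustrative check:

from http import client as http_client  # before: from six.moves import http_client

# The module contents are unchanged, e.g. the standard status codes.
assert http_client.OK == 200
assert http_client.responses[http_client.NOT_FOUND] == "Not Found"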
10 changes: 2 additions & 8 deletions tfx_bsl/coders/csv_decoder.py
@@ -22,7 +22,6 @@
import apache_beam as beam
import numpy as np
import pyarrow as pa
- import six
import tensorflow as tf
from tfx_bsl.coders import batch_util

@@ -263,7 +262,7 @@ def merge_accumulators(
# Merge the types inferred in each partition using the type hierarchy.
# Specifically, whenever we observe a type higher in the type hierarchy
# we update the type.
- for feature_name, feature_type in six.iteritems(shard_types):
+ for feature_name, feature_type in shard_types.items():
if feature_name not in result or feature_type > result[feature_name]:
result[feature_name] = feature_type
return result
@@ -448,12 +447,7 @@ def __init__(self, delimiter: Union[Text, bytes]):
self._delimiter = delimiter
self._line_iterator = _MutableRepeat()
self._reader = csv.reader(self._line_iterator, delimiter=delimiter)
- # Python 2 csv reader accepts bytes while Python 3 csv reader accepts
- # unicode.
- if six.PY2:
-   self._to_reader_input = tf.compat.as_bytes
- else:
-   self._to_reader_input = tf.compat.as_text
+ self._to_reader_input = tf.compat.as_text

def ReadLine(self, csv_line: CSVLine) -> List[CSVCell]:
"""Reads out bytes for PY2 and Unicode for PY3."""
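With Python 2 gone, the csv reader always consumes text, so input bytes are normalized with `tf.compat.as_text` up front. A hedged standalone sketch of that conversion (the delimiter and sample line are made up):

import csv

import tensorflow as tf

raw_line = b"alpha,1,2.5"            # bytes from an upstream record
line = tf.compat.as_text(raw_line)   # the Python 3 csv reader needs str
reader = csv.reader([line], delimiter=",")
assert next(reader) == ["alpha", "1", "2.5"]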
8 changes: 3 additions & 5 deletions tfx_bsl/coders/tf_graph_record_decoder.py
@@ -16,7 +16,6 @@
import abc
from typing import Dict, List, Optional, Text, Union

- import six
import tensorflow as tf

from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
@@ -25,8 +24,7 @@
TensorAlike = Union[tf.Tensor, composite_tensor.CompositeTensor]


- @six.add_metaclass(abc.ABCMeta)
- class TFGraphRecordDecoder(tf.Module):
+ class TFGraphRecordDecoder(tf.Module, metaclass=abc.ABCMeta):
"""Base class for decoders that turns a list of bytes to (composite) tensors.
Sub-classes must implemented `_decode_record_internal()` (see its docstring
@@ -45,7 +43,7 @@ def __init__(self, name: Text):
name: Must be a valid TF scope name. May be used to create TF namescopes.
see https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope.
"""
- super(TFGraphRecordDecoder, self).__init__(name=name)
+ super().__init__(name=name)

@tf.function(input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.string)])
def decode_record(self, records: List[bytes]) -> Dict[Text, TensorAlike]:
@@ -131,7 +129,7 @@ class LoadedDecoder(TFGraphRecordDecoder):
"""

def __init__(self, loaded_module):
- super(LoadedDecoder, self).__init__(name="LoadedDecoder")
+ super().__init__(name="LoadedDecoder")
self._loaded_module = loaded_module
if tf.executing_eagerly():
record_index_tensor_name = (
4 changes: 2 additions & 2 deletions tfx_bsl/coders/tf_graph_record_decoder_test.py
@@ -26,7 +26,7 @@
class _DecoderForTesting(tf_graph_record_decoder.TFGraphRecordDecoder):

def __init__(self):
- super(_DecoderForTesting, self).__init__("DecoderForTesting")
+ super().__init__("DecoderForTesting")

def _decode_record_internal(self, record):
indices = tf.transpose(tf.stack([
@@ -63,7 +63,7 @@ def record_index_tensor_name(self):
class TfGraphRecordDecoderTest(tf.test.TestCase):

def setUp(self):
- super(TfGraphRecordDecoderTest, self).setUp()
+ super().setUp()
self._tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)

def test_save_load_decode(self):
19 changes: 6 additions & 13 deletions tfx_bsl/test_util/run_all_tests.py
@@ -32,7 +32,6 @@
from absl import app
from absl import flags
from absl import logging
- import six


flags.DEFINE_list(
@@ -118,18 +117,12 @@ def PrintLogs(self) -> None:
(self.stdout, "STDOUT"), (self.stderr, "STDERR")):
f.flush()
f.seek(0)
- if six.PY2:
-   sys.stdout.write("BEGIN %s of test %s\n" % (stream_name, self))
-   sys.stdout.write(f.read())
-   sys.stdout.write("END %s of test %s\n" % (stream_name, self))
-   sys.stdout.flush()
- else:
-   # Since we collected binary data, we have to write binary data.
-   encoded = (stream_name.encode(), str(self).encode())
-   sys.stdout.buffer.write(b"BEGIN %s of test %s\n" % encoded)
-   sys.stdout.buffer.write(f.read())
-   sys.stdout.buffer.write(b"END %s of test %s\n" % encoded)
-   sys.stdout.buffer.flush()
+ # Since we collected binary data, we have to write binary data.
+ encoded = (stream_name.encode(), str(self).encode())
+ sys.stdout.buffer.write(b"BEGIN %s of test %s\n" % encoded)
+ sys.stdout.buffer.write(f.read())
+ sys.stdout.buffer.write(b"END %s of test %s\n" % encoded)
+ sys.stdout.buffer.flush()


def _DiscoverTests(root_dirs: List[Text],
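Since the captured logs are raw bytes, the surviving branch writes through the binary buffer underneath sys.stdout rather than the text stream. A minimal sketch of the technique:

import sys

payload = "résumé\n".encode("utf-8")  # captured bytes, possibly non-ASCII
sys.stdout.buffer.write(b"BEGIN captured stream\n")  # text write() rejects bytes
sys.stdout.buffer.write(payload)
sys.stdout.buffer.write(b"END captured stream\n")
sys.stdout.buffer.flush()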
(Diffs for the remaining 19 of the 29 changed files are not shown.)
