Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 468794403
  • Loading branch information
gauravmishra authored and SeqIO committed Aug 19, 2022
1 parent e5c51c5 commit 71e47ac
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 65 deletions.
16 changes: 2 additions & 14 deletions seqio/beam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,8 @@ class GetStats(beam.PTransform):
prefixed by the identifiers.
"""

def __init__(self,
output_features: Mapping[str, seqio.Feature],
task_ids: Optional[Mapping[str, int]] = None):
def __init__(self, output_features: Mapping[str, seqio.Feature]):
self._output_features = output_features
self._task_ids = task_ids or {}

def expand(self, pcoll):
example_counts = (
Expand Down Expand Up @@ -399,15 +396,6 @@ def _merge_dicts(dicts):
merged_dict.update(d)
return merged_dict

stats = [example_counts, total_tokens, max_tokens, char_length]
if self._task_ids:
task_ids_dict = {"task_ids": self._task_ids}
task_ids = (
pcoll
| "sample_for_task_ids" >> beam.combiners.Sample.FixedSizeGlobally(1)
| "create_task_ids" >> beam.Map(lambda _: task_ids_dict))
stats.append(task_ids)

return (stats
return ([example_counts, total_tokens, max_tokens, char_length]
| "flatten_counts" >> beam.Flatten()
| "merge_stats" >> beam.CombineGlobally(_merge_dicts))
43 changes: 0 additions & 43 deletions seqio/beam_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,49 +222,6 @@ def test_get_stats_tokenized_dataset(self):
"targets_chars": 12,
}]))

def test_get_stats_task_ids(self):
  """Tests that GetStats emits the provided `task_ids` mapping in its output.

  Runs two tokenized examples through `beam_utils.GetStats` with an explicit
  `task_ids` argument and asserts that, alongside the usual token/char/count
  stats, the merged output dict carries the `task_ids` mapping unchanged.
  """
  # These examples are assumed to be decoded by
  # `seqio.test_utils.sentencepiece_vocab()`.
  input_examples = [{
      # Decoded as "ea", i.e., length 2 string
      "inputs": np.array([4, 5]),
      # Decoded as "ea test", i.e., length 7 string
      "targets": np.array([4, 5, 10]),
  }, {
      # Decoded as "e", i.e., length 1 string
      "inputs": np.array([4]),
      # Decoded as "asoil", i.e., length 5 string. "1" is an EOS id.
      "targets": np.array([5, 6, 7, 8, 9, 1])
  }]

  output_features = seqio.test_utils.FakeTaskTest.DEFAULT_OUTPUT_FEATURES
  with TestPipeline() as p:
    pcoll = (
        p
        | beam.Create(input_examples)
        | beam_utils.GetStats(
            output_features=output_features,
            # Arbitrary name->id mapping; expected to pass through verbatim.
            task_ids={
                "task_name_1": 1,
                "task_name_2": 2
            }))

    util.assert_that(
        pcoll,
        util.equal_to([{
            "inputs_tokens": 3,  # 4 and 3 from the first and second examples.
            "targets_tokens": 8,
            "inputs_max_tokens": 2,
            "targets_max_tokens": 5,
            "examples": 2,
            "inputs_chars": 3,
            "targets_chars": 12,
            # The task_ids mapping passed to GetStats must appear unmodified.
            "task_ids": {
                "task_name_1": 1,
                "task_name_2": 2
            }
        }]))

def test_count_characters_tokenized_dataset(self):
# These examples are assumed to be decoded by
# `seqio.test_utils.sentencepiece_vocab()`.
Expand Down
5 changes: 1 addition & 4 deletions seqio/test_data/cached_task_with_provenance/stats.train.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,5 @@
"inputs_tokens": 36,
"targets_chars": 29,
"targets_max_tokens": 6,
"targets_tokens": 18,
"task_ids": {
"task_name": 1
}
"targets_tokens": 18
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,5 @@
"inputs_tokens": 23,
"targets_chars": 37,
"targets_max_tokens": 21,
"targets_tokens": 36,
"task_ids": {
"task_name": 1
}
"targets_tokens": 36
}

0 comments on commit 71e47ac

Please sign in to comment.