Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix problems found when running hugging face text summariztion model on large input. #929

Merged
merged 8 commits into from
Aug 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions evadb/executor/delete_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,31 @@ def predicate_node_to_filter_clause(
left = predicate_node.get_child(0)
right = predicate_node.get_child(1)

if type(left) == TupleValueExpression:
if isinstance(left, TupleValueExpression):
column = left.name
x = table.columns[column]
elif type(left) == ConstantValueExpression:
elif isinstance(left, ConstantValueExpression):
value = left.value
x = value
else:
left_filter_clause = self.predicate_node_to_filter_clause(table, left)

if type(right) == TupleValueExpression:
if isinstance(right, TupleValueExpression):
column = right.name
y = table.columns[column]
elif type(right) == ConstantValueExpression:
elif isinstance(right, ConstantValueExpression):
value = right.value
y = value
else:
right_filter_clause = self.predicate_node_to_filter_clause(table, right)

if type(predicate_node) == LogicalExpression:
if isinstance(predicate_node, LogicalExpression):
if predicate_node.etype == ExpressionType.LOGICAL_AND:
filter_clause = and_(left_filter_clause, right_filter_clause)
elif predicate_node.etype == ExpressionType.LOGICAL_OR:
filter_clause = or_(left_filter_clause, right_filter_clause)

elif type(predicate_node) == ComparisonExpression:
elif isinstance(predicate_node, ComparisonExpression):
assert (
predicate_node.etype != ExpressionType.COMPARE_CONTAINS
and predicate_node.etype != ExpressionType.COMPARE_IS_CONTAINED
Expand Down
2 changes: 2 additions & 0 deletions evadb/optimizer/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext):
# read in a batch from storage engine.
# Todo: Experiment heuristics.
after = SeqScanPlan(None, before.target_list, before.alias)
batch_mem_size = context.db.config.get_value("executor", "batch_mem_size")
after.append_child(
StoragePlan(
before.table_obj,
Expand All @@ -880,6 +881,7 @@ def apply(self, before: LogicalGet, context: OptimizerContext):
sampling_rate=before.sampling_rate,
sampling_type=before.sampling_type,
chunk_params=before.chunk_params,
batch_mem_size=batch_mem_size,
)
)
yield after
Expand Down
6 changes: 6 additions & 0 deletions evadb/third_party/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class TextHFModel(AbstractHFUdf):
Base Model for all HF Models that take in text as input
"""

def __call__(self, *args, **kwargs):
# Use truncation=True to handle the case where num of tokens is larger
# than limit
# Ref: https://stackoverflow.com/questions/66954682/token-indices-sequence-length-is-longer-than-the-specified-maximum-sequence-leng
return self.forward(args[0], truncation=True)

def input_formatter(self, inputs: Any):
return inputs.values.flatten().tolist()

Expand Down
60 changes: 60 additions & 0 deletions test/optimizer/rules/test_batch_mem_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test.util import get_evadb_for_testing

from mock import ANY, patch

from evadb.server.command_handler import execute_query_fetch_all
from evadb.storage.sqlite_storage_engine import SQLStorageEngine


class BatchMemSizeTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.evadb = get_evadb_for_testing()
# reset the catalog manager before running each test
cls.evadb.catalog().reset()

@classmethod
def tearDownClass(cls):
execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS MyCSV;")

def test_batch_mem_size_for_sqlite_storage_engine(self):
"""
This testcase make sure that the `batch_mem_size` is correctly passed to
the storage engine.
"""
test_batch_mem_size = 100
self.evadb.config.update_value(
"executor", "batch_mem_size", test_batch_mem_size
)
create_table_query = """
CREATE TABLE IF NOT EXISTS MyCSV (
id INTEGER UNIQUE,
frame_id INTEGER,
video_id INTEGER,
dataset_name TEXT(30),
label TEXT(30),
bbox NDARRAY FLOAT32(4),
object_id INTEGER
);"""
execute_query_fetch_all(self.evadb, create_table_query)

select_table_query = "SELECT * FROM MyCSV;"
with patch.object(SQLStorageEngine, "read") as mock_read:
mock_read.__iter__.return_value = []
execute_query_fetch_all(self.evadb, select_table_query)
mock_read.assert_called_with(ANY, test_batch_mem_size)
47 changes: 47 additions & 0 deletions test/udfs/test_hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pandas as pd
from mock import MagicMock

from evadb.third_party.huggingface.model import TextHFModel


class TestTextHFModel(TextHFModel):
@property
def default_pipeline_args(self) -> dict:
# We need to improve the hugging face interface, passing
# UdfCatalogEntry into UDF is not ideal.
return {
"task": "summarization",
"model": "sshleifer/distilbart-cnn-12-6",
"min_length": 5,
"max_length": 100,
}


class HuggingFaceTest(unittest.TestCase):
def test_hugging_face_with_large_input(self):
udf_obj = MagicMock()
udf_obj.metadata = []
text_summarization_model = TestTextHFModel(udf_obj)

large_text = pd.DataFrame([{"text": "hello" * 4096}])
try:
text_summarization_model(large_text)
except IndexError:
self.fail("hugging face with large input raised IndexError.")