Task models fix (#1922)
* added a test for the way BytePairTokenizer handles the "\n\n" sequence, which is important in Llama chat templates (see the sketch after this list)

* fix for wrongly configured task models Llama, PaliGemma, Mistral, and Phi3 + test

* comments

* uncommented the test lines that were commented out by mistake

* fixed linter errors
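
Context for the first bullet: in GPT-2 style byte-level BPE, the newline byte is written as the symbol "Ċ", and Llama chat templates rely on "\n\n" surviving as a single merged token rather than two separate newline tokens. A minimal sketch of such a check, with a made-up vocabulary and merge list rather than the actual test from this commit:

```python
from keras_hub.tokenizers import BytePairTokenizer

# Toy vocabulary: "Ċ" is the byte-level symbol for "\n",
# "ĊĊ" the merged symbol for "\n\n".
vocab = {"Ċ": 0, "ĊĊ": 1}
merges = ["Ċ Ċ"]  # one merge rule: two newlines fuse into one token

tokenizer = BytePairTokenizer(vocabulary=vocab, merges=merges)

# If "\n\n" is not split apart by the pre-tokenization step,
# this should yield the single merged id [1] rather than [0, 0].
print(tokenizer("\n\n"))
```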
martin-gorner authored Oct 16, 2024
1 parent 1777eac commit b737b83
Showing 5 changed files with 21 additions and 4 deletions.
keras_hub/src/models/llama/llama_causal_lm.py (3 additions, 1 deletion)

```diff
@@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```
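The same one-line fix repeats in the three files below, so here is why it matters: these backbones are functional models built from a dict of named inputs. A minimal sketch, assuming standalone Keras 3 (the layer names and sizes are invented, not from the commit), of how the two properties differ:

```python
import keras

# Two named inputs, as a CausalLM backbone would have.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

x = keras.layers.Embedding(input_dim=100, output_dim=8)(token_ids)
mask = keras.ops.cast(keras.ops.expand_dims(padding_mask, -1), x.dtype)
backbone = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=x * mask,
)

# "input" preserves the dict structure the model was built with ...
print(backbone.input)   # {"token_ids": <KerasTensor>, "padding_mask": <KerasTensor>}
# ... while "inputs" flattens it to a plain list, losing the names.
print(backbone.inputs)  # [<KerasTensor>, <KerasTensor>]
```

Building the task on the flattened list made the task's `_inputs_struct` a list, which no longer matched the dict the preprocessor emits; `backbone.input` keeps the named structure, which the new test in test_case.py below asserts.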
keras_hub/src/models/mistral/mistral_causal_lm.py (3 additions, 1 deletion)

```diff
@@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```
keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py (3 additions, 1 deletion)

```diff
@@ -110,7 +110,9 @@ def __init__(
         self.backbone = backbone

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_state = backbone(inputs=inputs)
         outputs = backbone.token_embedding(hidden_state, reverse=True)
         outputs = outputs[:, backbone.image_sequence_length :, :]
```
keras_hub/src/models/phi3/phi3_causal_lm.py (3 additions, 1 deletion)

```diff
@@ -41,7 +41,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         self.preprocessor = preprocessor

         # === Functional Model ===
-        inputs = backbone.inputs
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
```
keras_hub/src/tests/test_case.py (9 additions, 0 deletions)

```diff
@@ -569,6 +569,15 @@ def run_task_test(
         ds = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size)
         x, y, sw = keras.utils.unpack_x_y_sample_weight(train_data)

+        # Test: the tree struct output by the
+        # preprocessor must match what the model expects.
+        preprocessed_data = preprocessor(*train_data)[0]
+        tree.assert_same_structure(
+            preprocessed_data,
+            task._inputs_struct,
+            check_types=False,
+        )
+
         # Test predict.
         output = task.predict(x)
         if expected_output_shape is not None:
```
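A minimal sketch of what this assertion catches, with toy ints standing in for the real tensors in `preprocessed_data` and `task._inputs_struct`:

```python
from keras import tree

# Stand-ins: the real structures hold tensors, not ints.
inputs_struct = {"token_ids": 0, "padding_mask": 0}  # like task._inputs_struct
preprocessed = {"token_ids": 1, "padding_mask": 1}   # like preprocessor output

# Same nesting, same dict keys: passes silently.
tree.assert_same_structure(preprocessed, inputs_struct, check_types=False)

# A flat list (what building the task on backbone.inputs would have
# produced) does not match the dict structure, so this raises.
try:
    tree.assert_same_structure([1, 1], inputs_struct, check_types=False)
except (TypeError, ValueError) as err:
    print("mismatch caught:", err)
```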
