Modular caching layer / network design #1261
Thanks for the issue! Mulling this one over. Will reply more soon.
Sorry for the delay! Some thoughts... At a high level I agree it's hard to bring a different attention mechanism to our encoder/decoder blocks. Note that we don't care about bringing a different attention mechanism to ... I do think there are some design constraints we should follow that aren't addressed above: ...
One option for making it easier to use a custom attention layer would be something like:

class CustomDecoder(keras_nlp.layers.TransformerDecoder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_attention_layer(self, inputs_shape):
        # Override the default MultiHeadAttention here.
        # Examples: RoPE embeddings, multi-query attention, T5-style trainable bias, etc.
        return keras.layers.CustomAttention(
            num_heads=self.num_heads,
            key_dim=int(inputs_shape[-1] // self.num_heads),
            dropout=self.dropout,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
        )

Anyway, I suspect more discussion might be needed here, but maybe this helps clarify the design constraints we are working with?
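For illustration, a subclass like that could then be dropped into a model wherever the stock decoder would go. A rough usage sketch (CustomDecoder is the hypothetical subclass above; the hyperparameters are placeholders, not recommendations):

import keras
import keras_nlp

token_ids = keras.Input(shape=(None,), dtype="int32")
x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=30_000, sequence_length=512, embedding_dim=256
)(token_ids)
for _ in range(4):
    # Drop-in replacement for keras_nlp.layers.TransformerDecoder.
    x = CustomDecoder(intermediate_dim=1024, num_heads=8, dropout=0.1)(x)
logits = keras.layers.Dense(30_000)(x)
model = keras.Model(token_ids, logits)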
It absolutely does, thanks! I thought 1 and 3 might be constraints - just wasn't sure how "hard" they were (I'm a hacker, not a maintainer - can you tell?). Regarding the restrictions you've enumerated above:

"Leaf" caching layers (i.e. those that create/use/manage their own caches) still have 3 separate call paths in the draft below: one without any cache, one that creates a cache, and one that uses an existing cache.

"variable" is an overloaded term, so I'm going to refer to these as "caches", but I get your drift. Conceptually this would make things simpler, though I don't see a way of making it computationally efficient. Most attention mechanisms have different parallel / sequential implementations: e.g. if we have a prompt for GPT2, we want to compute the valid lower-triangular attention sub-matrix corresponding to the prompt in one go, creating the cache as we do so, then update that cache token-by-token during generation. In the draft below I've left cache creation coupled with the no-cache forward pass, but if you can show there's an efficient way of doing things in a decoupled manner it wouldn't be hard to decouple.
This was almost a deal breaker for my involvement, since I didn't think ...

I've included a draft implementation below - see the bottom for a full script including the snippets immediately below - but I'll take a moment to illustrate the interface from an implementer's perspective with a simple example. The general idea is that "leaf" caching layers - layers which know how to create, use and update their own states - must implement two methods: a forward pass without a cache that also creates the cache, and a forward pass that uses (and potentially updates) the cache.

class LagAndAdd(CachingLayer):
def call_and_create_cache(self, x):
lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
# always return the cache
cache = x[:, -1:]
return x + lagged, cache
def call_with_cache(self, x, cache):
updated_cache = x
        return x + cache, updated_cache

We can construct models using these layers as if they were regular layers, ignoring their caching potential.

inp = keras.Input((None, 3))
x = LagAndAdd()(inp)
x = LagAndAdd()(x)
base_model = keras.Model(inp, x)
x = keras.random.normal((5, 7, 3))
base_out = base_model(x)

We can then use graph transformations to turn this into a model that does both a forward pass and cache creation, and (separately) a model that performs a forward pass with a cache plus a cache update:

call_and_create_cache_model = get_call_and_create_cache(base_model)
call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
leading, cache = call_and_create_cache_model(x[:, :-1])
trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed functional model version")

You can do a similar thing with a layer which isn't a model but defines a forward pass using other layers.

class CompoundCache(CachingFunctionalLayer):
def build(self, input_shape):
if self.built:
return
self.layer0 = LagAndAdd()
self.layer1 = LagAndAdd()
for layer in (self.layer0, self.layer1):
layer.build(input_shape)
super().build(input_shape)
def call_without_cache(self, x):
residual = x
x = self.layer0(x)
x = self.layer1(x)
return x + residual
layer = CompoundCache()
base_out = layer(x)
leading, cache = layer(x[:, :-1], return_cache=True)
trailing, cache = layer(x[:, -1:], cache=cache)
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed CachingFunctionalLayer version")

Below are the backing base classes, plus everything from above in a single script for convenience. It'll need some more rigorous tests and tweaks before it's ready for a PR, but is this the kind of thing that might be accepted? That said, if a PR is a better place to continue the discussion, let me know and I'll put one together.

import abc
import typing as tp
import tree
import keras_core as keras
class CachingLayer(keras.layers.Layer):
"""A layer that can create and update a cache for iterative inference.
Implementations should implement `call_and_create_cache` and
`call_with_cache`. They may optionally implement `call_without_cache`
if creation of the cache in `call_and_create_cache` is expensive.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.uses_cache = True
def call(self, *args, cache=None, return_cache=None, **kwargs):
if cache is None:
if return_cache:
return self.call_and_create_cache(*args, **kwargs)
return self.call_without_cache(*args, **kwargs)
assert return_cache is None or return_cache
return self.call_with_cache(*args, cache=cache, **kwargs)
@abc.abstractmethod
def call_and_create_cache(self, *args, **kwargs):
"""Get the output of this layer and create a cache.
The returned cache may be used in subsequent calls to
`call_with_cache`.
"""
@abc.abstractmethod
def call_with_cache(self, *args, cache, **kwargs):
"""Get the output of this layer using a previously created cache.
This method should return *args, where args[:-1] is the normal
output of the layer, and args[-1] is a single-tensor cache.
"""
def call_without_cache(self, *args, **kwargs):
"""Get the output of this layer without a cache input or output.
By default, this redirects to `call_and_create_cache` and throws
out the `cache`. Implementers should override this method if
there is a more optimal implementation that does not involve
creating the cache at all.
"""
*output, cache = self.call_and_create_cache(*args, **kwargs)
del cache
if len(output) == 1:
return output[0]
return output
def get_call_and_create_cache(model: keras.Model):
"""
Get `call_and_create_cache` model from `call_without_cache`.
"""
cache_outputs = []
def clone_function(op):
if isinstance(op, CachingLayer):
def f(*args, **kwargs):
kwargs = dict(kwargs)
kwargs["return_cache"] = True
*output, cache = op(*args, **kwargs)
cache_outputs.append(cache)
if len(output) == 1:
return output[0]
return output
return f
return op
cloned = keras.models.clone_model(
model, model.input, clone_function=clone_function
)
output = cloned.output
cache_output = keras.ops.stack(cache_outputs, axis=1)
if isinstance(output, keras.KerasTensor):
output = (output, cache_output)
else:
output = (*output, cache_output)
return keras.Model(cloned.input, output)
def get_call_with_cache(model: keras.Model):
"""
Get `call_with_cache` model from a `call_and_create_cache` model.
"""
cache_output = model.output[-1]
cache_input = keras.Input(batch_shape=cache_output.shape, dtype=cache_output.dtype)
cache_inputs = keras.ops.unstack(cache_input, axis=1)
# reverse order so we can pop in order
cache_inputs = cache_inputs[-1::-1]
def clone_function(op):
if getattr(op, "uses_cache", False):
def f(*args, **kwargs):
assert kwargs["return_cache"], kwargs
return op(*args, cache=cache_inputs.pop(), **kwargs)
return f
return op
inp = model.input
if isinstance(inp, keras.KerasTensor):
inputs = (inp, cache_input)
else:
inputs = (*inp, cache_input)
return keras.models.clone_model(model, inputs, clone_function=clone_function)
def _is_tensor(x) -> bool:
return hasattr(x, "__array__")
Tensor = tp.Any
class CachingFunctionalLayer(CachingLayer):
"""A caching layer made from other caching layers.
Implementations should implement `call_without_cache`, which should
be conceptually similar to a layer's standard `call` method without
concerns for the absence, presence or creation of any caches used by
constituent layers.
Currently, the only condition on constituent caching layers is that
they all produce caches of the same size such that they can be stacked.
The main difference between a normal layer's `call` method and
`call_without_cache` is that `call_without_cache` may be called with
symbolic inputs (`keras.KerasTensor`s). This is used for graph
transformations that create `call_and_create_cache` and `call_with_cache`
implementations.
"""
def _get_call_and_create_cache_model(
self, args, kwargs
) -> tp.Tuple[keras.Model, tp.List[Tensor]]:
tensors = [arg for arg in tree.flatten((args, kwargs)) if _is_tensor(arg)]
model_args, model_kwargs = tree.map_structure(
lambda x: keras.Input(batch_shape=x.shape, dtype=x.dtype)
if _is_tensor(x)
else x,
(args, kwargs),
)
inputs = [
arg
for arg in tree.flatten((model_args, model_kwargs))
if keras.backend.is_keras_tensor(arg)
]
output = self.call_without_cache(*model_args, **model_kwargs)
model = keras.Model(inputs, output)
call_and_create_cache_model = get_call_and_create_cache(model)
return call_and_create_cache_model, tensors
def call_and_create_cache(self, *args, **kwargs):
"""Get the output of this layer and create a cache.
The returned cache may be used in subsequent calls to
`call_with_cache`.
"""
model, tensors = self._get_call_and_create_cache_model(args, kwargs)
return model(tensors)
def call_with_cache(self, *args, cache, **kwargs):
"""Get the output of this layer using a previously created cache.
This method should return *args, where args[:-1] is the normal
output of the layer, and args[-1] is a single-tensor cache.
"""
call_and_create_cache_model, tensors = self._get_call_and_create_cache_model(
args, kwargs
)
call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
return call_with_cache_model((*tensors, cache))
@abc.abstractmethod
def call_without_cache(self, *args, **kwargs):
"""Get the output of this layer without a cache input or output.
args and kwargs may contain symbolic tensors or backend tensors, but
never both.
"""
raise NotImplementedError("Abstract method")
def main():
import numpy as np
class LagAndAdd(CachingLayer):
def call_and_create_cache(self, x):
lagged = keras.ops.pad(x[:, :-1], ((0, 0), (1, 0), (0, 0)))
# always return the cache
cache = x[:, -1:]
return x + lagged, cache
def call_with_cache(self, x, cache):
updated_cache = x
return x + cache, updated_cache
inp = keras.Input((None, 3))
x = LagAndAdd()(inp)
x = LagAndAdd()(x)
base_model = keras.Model(inp, x)
x = keras.random.normal((5, 7, 3))
base_out = base_model(x)
call_and_create_cache_model = get_call_and_create_cache(base_model)
call_with_cache_model = get_call_with_cache(call_and_create_cache_model)
leading, cache = call_and_create_cache_model(x[:, :-1])
trailing, updated_cache = call_with_cache_model((x[:, -1:], cache))
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed functional model version")
class CompoundCache(CachingFunctionalLayer):
def build(self, input_shape):
if self.built:
return
self.layer0 = LagAndAdd()
self.layer1 = LagAndAdd()
for layer in (self.layer0, self.layer1):
layer.build(input_shape)
super().build(input_shape)
def call_without_cache(self, x):
residual = x
x = self.layer0(x)
x = self.layer1(x)
return x + residual
layer = CompoundCache()
base_out = layer(x)
leading, cache = layer(x[:, :-1], return_cache=True)
trailing, cache = layer(x[:, -1:], cache=cache)
np.testing.assert_allclose(keras.ops.concatenate((leading, trailing), axis=1), base_out)
print("Passed CachingFunctionalLayer version")
if __name__ == '__main__':
    main()
I think this is actually supported today, and it's how we do it for our generative models. You can call the layer once with the whole prompt to seed the cache, then update that cache token-by-token. Or you could do the whole cached computation in a loop one token at a time - even for the user-supplied prompt.

It is worth noting that for good XLA-compiled performance you want to always pad all input lengths to the attention layer, so use a query/key/value length of either the full padded length or a single token. Also worth noting that for the second case here (where you only compute a single token at a time and slice updates into your cache), you could never mix cache creation with call efficiently for jax. Jax will need the cache to be a loop variable in the compiled loop, and you don't actually want the forward pass to be run outside the compiled loop. So here you do want cache creation (just initializing an ndarray of zeros) and call to be separate.

Not sure if that helps, but while not obvious, the current signature allows seeding a cache and incrementally updating a cache.
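The fixed-shape pattern described here might look roughly like this (a sketch of the idea, not the actual KerasNLP code; names and sizes are made up): the cache is allocated as zeros at the full max_length before the loop, and each decode step writes one token's key/value into it with a static-shaped slice update.

# Sketch only: zero-initialized fixed-size cache updated in place each step,
# so all shapes stay static for XLA compilation.
import keras  # Keras 3 (or keras_core)

batch, max_length, num_heads, head_dim = 1, 128, 2, 8

# Seed the cache outside the decode loop.
key_cache = keras.ops.zeros((batch, max_length, num_heads, head_dim))
value_cache = keras.ops.zeros((batch, max_length, num_heads, head_dim))

def decode_step(index, k_new, v_new, key_cache, value_cache):
    # k_new / v_new have shape (batch, 1, num_heads, head_dim).
    start = [0, index, 0, 0]
    key_cache = keras.ops.slice_update(key_cache, start, k_new)
    value_cache = keras.ops.slice_update(value_cache, start, v_new)
    # Attention would then run against the full padded cache, with positions
    # beyond `index` masked out, rather than slicing to a dynamic length.
    return key_cache, value_cache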
Anyway, at a high level I think this is probably not something we would want a PR for right now, for a few reasons. First, the UX is a little different from what we have. We don't really ship model-in -> model-out transformation functions like the get_call_with_cache proposed here.

The second is more to do with where all this lives and the progression of our Keras 3 release. We are somewhat in an interim state, but soon Keras 3 will be released, and this package can more readily rely on Keras 3 features. I suspect we will then want to push some of this development down into Keras 3...
A nice-to-have here would be an easy pattern for creating a transformer block with custom attention, but I don't think it's a deal breaker. More important is to have the low-level building blocks (multi-head, grouped-query, etc.) in core Keras, and the high-level popular models in KerasNLP. At the end of the day, there is still so much variation in transformer blocks that I don't really think it's a feasible design goal to cover them all in one class. So making your own "transformer block" from dense and attention layers will have to be a common path.

Hope that helps explain things! I wouldn't put the brakes on this by any means - it just might be more something for its own repo, because of slightly different design goals.
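As a rough illustration of that "make your own transformer block" path, a pre-LN causal block composed from core Keras layers might look like this (a sketch only; dropout, masking details, and initializers omitted):

# Sketch only: a causal decoder block built directly from core Keras layers.
import keras

def decoder_block(x, num_heads=8, key_dim=32, hidden_dim=1024):
    # Self-attention sub-block with a residual connection.
    residual = x
    y = keras.layers.LayerNormalization()(x)
    y = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(
        y, y, use_causal_mask=True
    )
    x = keras.layers.Add()([residual, y])
    # Position-wise feed-forward sub-block with a residual connection.
    residual = x
    y = keras.layers.LayerNormalization()(x)
    y = keras.layers.Dense(hidden_dim, activation="relu")(y)
    y = keras.layers.Dense(x.shape[-1])(y)
    return keras.layers.Add()([residual, y])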
There I was thinking ...

I don't follow this. Surely there's an initialization stage before the loop - that's where the cache could be created. I should say I'm thinking about this from a concept more general than just transformers, particularly models like RWKV and RetNet, which don't have the same linearly growing memory requirements. Just read your second point and that all sounds fine. I'll close this for now - I might comment later if I get around to splitting something like what's above into a separate repo for my own purposes, in case others decide it's useful. Thanks for the idea bounce :)
From the roadmap, "KerasNLP is focused on modular and reusable building blocks". Having tried to implement some causal generative models, I've found this not to be the case. The low-level blocks are great - e.g. CachedMultiHeadAttention - but higher-level blocks (TransformerDecoder, the GPT2 models) have implementations tightly coupled to the underlying attention mechanism. Specifically, each of these higher-level constructs needs to know not only how to call a layer in standard training, but also the types of inputs required during token-by-token generation.

Describe the solution you'd like
Conceptually, caching layers (like CachedMultiHeadAttention) need to support 3 functionalities:

- a forward pass with no cache involved;
- a forward pass that also creates a cache; and
- a forward pass that uses (and potentially updates) an existing cache.

CachedMultiHeadAttention does all three in a single call method, with the presence or absence of input arguments dictating the behaviour, but it could just as easily be implemented as three separate methods.

Higher-level layers/models can then implement call_and_create_cache and call_with_cache based on the keras graph implied in call, substituting out the relevant calls to call for call_and_create_cache or call_with_cache. Cache creation / storage / retrieval would have to be at the call-node level, rather than the child-layer level, since each layer could potentially be called multiple times and would need its own cache for each call, but the keras infrastructure is already set up to support this.

This would allow:

- higher-level layers and models to be written with standard call/__call__ methods, ignorant of any caching; and
- a generate function that "just works".

I've experimented with this here (CausalLM model here) and it works - though there are no doubt edge cases that aren't accounted for. It's based on Function._run_through_graph and uses the private member _nodes_by_depth.

I would be very happy to contribute such a feature. Before I spend any more time on it though, a few questions: