Weights sharding for Keras saving #19286
base: master
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

@@ Coverage Diff @@
## master #19286 +/- ##
==========================================
- Coverage 80.14% 75.61% -4.53%
==========================================
Files 341 365 +24
Lines 36163 39909 +3746
Branches 7116 7747 +631
==========================================
+ Hits 28982 30177 +1195
- Misses 5578 8054 +2476
- Partials 1603 1678 +75
@@ -617,6 +662,138 @@ def close(self):
        self.io_file.close()


class ShardedH5IOStore:
    def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
Shouldn't max_size be an int? e.g. in MB?
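One way to keep the human-readable string form while also accepting an int is to normalize both to a byte count up front. A minimal sketch of that idea (parse_size is a hypothetical helper, not part of this PR):

```python
import re

def parse_size(size):
    # Hypothetical helper: accept either an int byte count or a
    # human-readable string like "10GB" or "500MB" and return bytes.
    if isinstance(size, int):
        return size
    units = {"KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3}
    match = re.fullmatch(r"(\d+)\s*([KMG]B)", size.strip().upper())
    if not match:
        raise ValueError(f"Unrecognized size: {size!r}")
    return int(match.group(1)) * units[match.group(2)]

print(parse_size("10GB"))  # 10737418240
print(parse_size(1024))    # 1024
```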
@@ -754,6 +754,70 @@ def call(self, inputs):
        return self.first_layer(self.second_layer(inputs))


def _get_large_model():
    model = keras.Sequential(
Why pick a convnet for a large model?
@@ -217,7 +251,7 @@ def save_weights_only(model, filepath):
    weights_store.close()


-def load_weights_only(model, filepath, skip_mismatch=False):
+def load_weights_only(model, filepath, sharded=False, skip_mismatch=False):
Why should sharded be configurable here -- wouldn't it just depend on the file and the model?
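To illustrate that point, loading could in principle infer shardedness from what exists on disk rather than from a flag. A hypothetical sketch; the file-naming convention here is an assumption, not what the PR implements:

```python
import os

def looks_sharded(filepath):
    # Hypothetical detection: treat a save as sharded if the path is the
    # JSON index itself, or if a JSON index sits next to the weights file.
    base, ext = os.path.splitext(filepath)
    return ext == ".json" or os.path.exists(base + ".json")

print(looks_sharded("model.weights.json"))  # True
```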
Talked with Neel a bit about this, but one idea, building off the recent change Francois made with

Pseudocode:

write(path, key, value):
    if self.current_shard_size + value.nbytes > self.shard_size:
        close current shard
        open new shard file
        self.current_shard_size = 0
    group = create parent groups if needed
    self.current_shard_size += value.nbytes
    group[key] = value

read(path, key):
    for file in shards:
        if path in file:
            group = file[path]
            if key in group:
                return group[key]

This could be fairly simple. It avoids the need for a separate class if we want (though we still could), and it allows splitting an individual layer's weights across shards (important if you have one big layer). This could even allow avoiding the json file entirely, I think? Supporting something like this:

# If shard_size is set, pass a format string as path?
filenames = model.save_weights("./model_{}.weights.h5", shard_size="10GB")
# Load weights handles loading a list of files, and checking all files for the variables.
model.load_weights(filenames)

This last bit is optional, just thought it was interesting. What do people think?
Actually, thinking about this more, let's keep the json file. When downloading from hubs, we want to be able to download one file that tells us exactly what other paths to download.
This PR adds initial weights-sharding functionality to the Keras saving/loading APIs, accessed by passing the sharded=True flag to the corresponding saving/loading calls.