Skip to content

Commit

Permalink
Fix missing hyphens
Browse files Browse the repository at this point in the history
  • Loading branch information
antimora committed Feb 13, 2024
1 parent 08283c6 commit 30d0b56
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions burn-book/src/import/pytorch-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,13 @@ let model = Net::<Backend>::new_with(record);
with both an encoder and a decoder, it's possible to load only the encoder weights. This is done by
defining the encoder in Burn, allowing the loading of its weights while excluding the decoder's.

### Specifying the top-level-key for state_dict
### Specifying the top-level key for state_dict

Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
is nested under a top-level-key along with other metadata as in a
is nested under a top-level key along with other metadata as in a
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
In this case, you can specify the top-level-key in `LoadArgs`:
In this case, you can specify the top-level key in `LoadArgs`:

```rust
let device = Default::default();
Expand Down
2 changes: 1 addition & 1 deletion burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A>

/// A default deserializer that always returns the default value.
struct DefaultDeserializer {
/// The originator field name (the top level missing field name)
/// The originator field name (the top-level missing field name)
originator_field_name: Option<String>,
}

Expand Down
2 changes: 1 addition & 1 deletion burn-import/src/pytorch/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use serde::{de::DeserializeOwned, Serialize};
///
/// * `path` - A string slice that holds the path of the file to read.
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
/// * `top_level_key` - An optional top level key to load state_dict from a dictionary.
/// * `top_level_key` - An optional top-level key to load state_dict from a dictionary.
pub fn from_file<PS, D, B>(
path: &Path,
key_remap: Vec<(Regex, String)>,
Expand Down
10 changes: 5 additions & 5 deletions burn-import/src/pytorch/recorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ pub struct LoadArgs {
/// A list of key remappings.
pub key_remap: Vec<(Regex, String)>,

/// Top level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top level key in a dict.
/// Top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
pub top_level_key: Option<String>,
}

Expand Down Expand Up @@ -125,12 +125,12 @@ impl LoadArgs {
self
}

/// Set top level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top level key in a dict.
/// Set top-level key to load state_dict from the file.
/// Sometimes the state_dict is nested under a top-level key in a dict.
///
/// # Arguments
///
/// * `key` - The top level key to load state_dict from the file.
/// * `key` - The top-level key to load state_dict from the file.
pub fn with_top_level_key(mut self, key: &str) -> Self {
self.top_level_key = Some(key.into());
self
Expand Down

0 comments on commit 30d0b56

Please sign in to comment.