Skip to content

Commit

Permalink
Add change_vocabulary and save_tokenizers() support to Multitask ASR …
Browse files Browse the repository at this point in the history
…models (NVIDIA#8357) (NVIDIA#8367)

* Add change_vocabulary and save_tokenizers() support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update nemo/collections/asr/models/aed_multitask_models.py

---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Piotr Żelasko <petezor@gmail.com>
Signed-off-by: biscayan <skgudwn34@gmail.com>
  • Loading branch information
4 people authored and biscayan committed Feb 15, 2024
1 parent 1d1e628 commit 5ea2c61
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 2 deletions.
135 changes: 134 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import os
import tempfile
Expand All @@ -23,7 +24,7 @@
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torchmetrics.text import SacreBLEUScore
from tqdm.auto import tqdm
Expand Down Expand Up @@ -247,6 +248,138 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):

logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

def change_vocabulary(
self,
new_tokenizer_dir: Union[str, DictConfig],
new_tokenizer_type: str,
decoding_cfg: Optional[DictConfig] = None,
prompt_format: Optional[str] = None,
):
"""
Changes vocabulary used during AED decoding process. Use this method when fine-tuning on from pre-trained model.
This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would
use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need
model to learn capitalization, punctuation and/or special characters.
Args:
new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`)
new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`.
decoding_cfg: A config for the decoding, which is optional. If the decoding type
needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
prompt_format: A string alias of the object that represents the prompt structure.
If not None, it will be used to update the prompt format.
"""
if isinstance(new_tokenizer_dir, (dict, DictConfig)):
if new_tokenizer_type == 'agg':
if not isinstance(new_tokenizer_dir, DictConfig):
new_tokenizer_dir = OmegaConf.create(new_tokenizer_dir)

new_tokenizer_cfg = new_tokenizer_dir
else:
raise ValueError(
f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}'
)
else:
new_tokenizer_cfg = None

if new_tokenizer_cfg is not None:
tokenizer_cfg = new_tokenizer_cfg
else:
if not os.path.isdir(new_tokenizer_dir):
raise NotADirectoryError(
f'New tokenizer dir must be non-empty path to a directory. But instead got: {new_tokenizer_dir}'
)

if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`')

tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

if prompt_format is None:
prompt_format = self.cfg.prompt_format

# Setup the tokenizer
self._setup_tokenizer(tokenizer_cfg)

# Initialize a dummy vocabulary
vocabulary = self.tokenizer.tokenizer.get_vocab()

# Setup Decoder
transf_decoder_cfg_dict = self.transf_decoder.to_config_dict()

vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8)

# Auto inject vocab size for `get_transformer`
with open_dict(transf_decoder_cfg_dict):
if 'config_dict' in transf_decoder_cfg_dict:
transf_decoder_cfg_dict['config_dict']['vocab_size'] = vocab_size

original_decoder_state_dict = self.transf_decoder.state_dict()
self.transf_decoder = EncDecMultiTaskModel.from_config_dict(transf_decoder_cfg_dict)

# Partially load the original state dict into the new decoder
decoder_state_dict = self.transf_decoder.state_dict()
for og_key, og_value in original_decoder_state_dict.items():
if og_key in decoder_state_dict and og_value.shape == decoder_state_dict[og_key].shape:
decoder_state_dict[og_key] = og_value
else:
logging.warning(
f"Skipping key `{og_key}` in the `transf_decoder` module from original state dict due "
f"to shape mismatch after change in vocabulary.\n"
f"Original shape: {og_value.shape}, New shape: {decoder_state_dict[og_key].shape}"
)

self.transf_decoder.load_state_dict(decoder_state_dict)

# Setup token classifier
with open_dict(self.cfg.head):
self.cfg.head.num_classes = vocab_size

del self.log_softmax
self.log_softmax = EncDecMultiTaskModel.from_config_dict(self.cfg.head)

# Weight tying - if using TokenClassifier only
if isinstance(self.log_softmax, TokenClassifier):
self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

# Initialize weights of token classifier
std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

# Setup Decoding class
if decoding_cfg is None:
# Assume same decoding config as before
decoding_cfg = self.cfg.decoding

# Assert the decoding config with all hyper parameters
decoding_cls = OmegaConf.structured(MultiTaskDecodingConfig)
decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

del self.decoding
self.decoding = MultiTaskDecoding(
decoding_cfg=decoding_cfg,
transformer_decoder=self.transf_decoder,
log_softmax_module=self.log_softmax,
tokenizer=self.tokenizer,
)

with open_dict(self.cfg.decoding):
self.cfg.decoding = decoding_cfg

# Setup loss
with open_dict(self.cfg.loss):
self.cfg.loss.pad_id = self.tokenizer.pad_id

del self.loss
self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss)

# Update config
with open_dict(self.cfg):
self.cfg.prompt_format = prompt_format

logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")

@torch.no_grad()
def transcribe(
self,
Expand Down
95 changes: 94 additions & 1 deletion nemo/collections/asr/parts/mixins/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import os
import shutil
import tarfile
from abc import ABC, abstractmethod
from typing import List

Expand All @@ -25,7 +27,7 @@
from nemo.collections.asr.parts.utils import asr_module_utils
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common import tokenizers
from nemo.utils import logging
from nemo.utils import app_state, logging


class ASRBPEMixin(ABC):
Expand Down Expand Up @@ -372,6 +374,97 @@ def _cleanup_aggregate_config_and_artifacts_if_needed(self):
if akey.startswith('tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'):
self.artifacts.pop(akey)

def save_tokenizers(self, directory: str):
"""
Save the model tokenizer(s) to the specified directory.
Args:
directory: The directory to save the tokenizer(s) to.
"""
if not hasattr(self, 'cfg'):
raise RuntimeError(
"The model has not been initialized with a tokenizer yet. Please call the model's "
"__init__ and _setup_tokenizer methods first."
)

if self.tokenizer_type == 'agg':
for lang in self.tokenizer.langs:
subconfig = self.cfg.tokenizer.langs.get(lang)
new_dir = os.path.join(directory, lang)
self._extract_tokenizer_from_config(subconfig, new_dir)
else:
self._extract_tokenizer_from_config(self.cfg.tokenizer, directory)

def _extract_tokenizer_from_config(self, tokenizer_cfg: DictConfig, dir: str):
"""
Extracts the tokenizer from the config and write the objects to dir.
The file may be from a local path (new model init) or from a .nemo file (restored model).
If its from a newly initialized model, the file is copied to dir.
If its from a restored model, the file is extracted from the .nemo file and copied to dir.
Args:
tokenizer_cfg: The tokenizer config to extract the tokenizer from.
dir: The directory to write the tokenizer objects to.
"""
if not os.path.exists(dir):
os.makedirs(dir, exist_ok=True)

nemo_file_objects = []

for k, v in tokenizer_cfg.items():
# Check if the value is a filepath (new model init) or has `nemo:` in it (restored model)
if isinstance(v, str) and os.path.exists(v):
# local file from first instantiation
loc = shutil.copy2(v, dir)
logging.info(f"Saved {k} at {loc}")

if isinstance(v, str) and v.startswith('nemo:'):
nemo_object_name = v[5:]
nemo_file_objects.append(nemo_object_name)

if len(nemo_file_objects) > 0:
logging.debug(f"Copying the following nemo file objects to {dir}: {nemo_file_objects}")

if not hasattr(self, 'model_guid'):
raise ValueError(
"The model does not have a model_guid attribute. "
"Please ensure that the model has been restored from a .nemo file."
)

appstate = app_state.AppState()
restore_path = appstate.get_model_metadata_from_guid(self.model_guid).restoration_path
if restore_path is None:
raise ValueError(
"The model has not been restored from a .nemo file. Cannot extract the tokenizer "
"as the nemo file cannot be located."
)

# Read the nemo file without fully extracting all contents
# we start with an assumption of uncompressed tar,
# which should be true for versions 1.7.0 and above
tar_header = "r:"
try:
tar_test = tarfile.open(restore_path, tar_header)
tar_test.close()
except tarfile.ReadError:
# can be older checkpoint => try compressed tar
tar_header = "r:gz"
tar = tarfile.open(restore_path, tar_header)

for nemo_object_name in nemo_file_objects:
members = [x for x in tar.getmembers() if nemo_object_name in x.name]
for member in members:
tar.extract(member, dir)

new_name = member.name.split("_")[1:]
if len(new_name) > 1:
new_name = "_".join(new_name)
else:
new_name = new_name[0]
os.rename(os.path.join(dir, member.name), os.path.join(dir, new_name))

logging.info(f"Saved {nemo_object_name} at {os.path.join(dir, new_name)}")


class ASRModuleMixin(ASRAdapterModelMixin):
"""
Expand Down

0 comments on commit 5ea2c61

Please sign in to comment.