
Feature Update [added initial_prompt support for automatic-speech-recognition whisper pipeline] #28556

Open
wants to merge 37 commits into main

Conversation


@Biswajit2902 Biswajit2902 commented Jan 17, 2024

What does this PR do?

Fixes # (feature)

  • initial_prompt support for whisper Pipeline (automatic-speech-recognition)

Before submitting

  • Added initial_prompt as an option for the Whisper model
  • The processor is treated as an optional pipeline parameter in order to handle the initial prompt
  • The current implementation supports only the Torch decoding path
  • How to use the initial prompt:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-small"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    processor=processor
)


dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
audio = sample["array"]
sampling_rate = sample["sampling_rate"]

# including timestamps
print(pipe(audio, initial_prompt="Biswajit, Whisper", return_timestamps=True))

# without timestamps
print(pipe(audio, initial_prompt="Biswajit, Whisper"))

Who can review?

Anyone in the community is free to review the PR once the tests have passed. @sanchit-gandhi, @Narsil, could either of you help take this PR forward? Let me know if anything is needed.

fixes #27317

@Biswajit2902 Biswajit2902 changed the title Feature Update [added support for initial_prompt for automatic-speech-recognition whisper pipeline] Feature Update [added initial_prompt support for automatic-speech-recognition whisper pipeline] Jan 17, 2024
@Biswajit2902 Biswajit2902 marked this pull request as ready for review January 17, 2024 14:10
@kaminwong

Hi, thank you, your code saved my day! I think line 535 needs to be modified a bit to prompt_tensor = torch.tensor(generate_kwargs["prompt_ids"], dtype=out["tokens"].dtype).cuda() if is_torch_cuda_available() else torch.tensor(generate_kwargs["prompt_ids"], dtype=out["tokens"].dtype), and is_torch_cuda_available needs to be added to the imports on line 22. Without CUDA it will run on the CPU, which is a lot slower.

@Biswajit2902
Author

@kaminwong, this check is just there to modify the output sequence so that the initial_prompt does not show up in the transcription.

The actual generation handles the device in the line below:

            tokens = self.model.generate(
                attention_mask=attention_mask,
                **generate_kwargs,
            )

Apart from this, the token-decoding part is a sequential implementation on which the device has no effect, so moving it to the GPU would just be a misuse of the GPU.
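
For context, a rough sketch of the idea being discussed (hedged: prompt_tensor and nprompt_token match the names in the traceback below, the rest is illustrative; building the tensor on the tokens' device avoids the mismatch reported next):

# Hedged sketch, not the exact PR code: strip the initial_prompt token ids
# from a generated sequence so the prompt does not appear in the transcription.
# Assumes out["tokens"] is a single 1-D sequence of token ids.
prompt_ids = generate_kwargs["prompt_ids"]
nprompt_token = len(prompt_ids)
# Build the comparison tensor on the same device as the generated tokens.
prompt_tensor = torch.tensor(prompt_ids, dtype=out["tokens"].dtype, device=out["tokens"].device)
tokens = out["tokens"]
if (tokens[0:nprompt_token] == prompt_tensor).sum() == nprompt_token:
    out["tokens"] = tokens[nprompt_token:]  # drop the prompt ids from the output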

@kaminwong

kaminwong commented Jan 27, 2024

Thanks for the reply! But if I don't make that change I get the following error, so I assume prompt_tensor needs to be on CUDA if the device is CUDA? Or is there another way to fix the error? Thank you for your time.

File "/.../python3.10/site-packages/transformers/pipelines/automatic_speech_recognition.py", line 538, in _forward if (tmp_tokens[0:nprompt_token] == prompt_tensor).sum() == nprompt_token: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I followed the code you posted:


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-small"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    processor=processor
)

@Biswajit2902
Author

@kaminwong, thank you for flagging this. I understand the issue now; let me verify and resolve it.

@Biswajit2902
Author

@kaminwong, you can pull the latest commit and install it; it should work now. It's fixed.

@thomasmol

@Biswajit2902 any new updates? Let me know if you need help.

@Biswajit2902
Author

@thomasmol I will update this soon; I have been busy for the past two weeks. Thank you for the reminder.

@amyeroberts amyeroberts added the Core: Pipeline and Audio labels Apr 24, 2024
@Biswajit2902
Author

Biswajit2902 commented Apr 29, 2024

@thomasmol @sanchit-gandhi, I see the conflict below in AutomaticSpeechRecognitionPipeline._sanitize_parameters:

<<<<<<< main
            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
        if initial_prompt is not None:
            forward_params["generate_kwargs"]["initial_prompt"] = initial_prompt
=======
            forward_params["max_new_tokens"] = max_new_tokens
>>>>>>> main

I want to understand why generate_kwargs was removed from forward_params, and likewise initial_prompt.

My earlier changes were working fine, but after this change there seems to be a bug. I am working on resolving it, so I need your input on this.

@sanchit-gandhi
Contributor

sanchit-gandhi commented May 20, 2024

Hey @Biswajit2902 - you can read the motivation for this change here. Essentially, we're unifying the forward_params and generate_kwargs in _sanitize_parameters. However, for the purposes of your feature, you should strive to put the initial_prompt under preprocess_params:

preprocess_params["initial_prompt"] = initial_prompt

And then convert the text prompt to token ids in the preprocess method, which will then be passed to _forward.
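
A minimal, self-contained sketch of that flow (hedged: WhisperTokenizer.get_prompt_ids is the existing helper, everything else here is illustrative of how the pipeline could hand the ids to _forward, not the PR's actual code):

from transformers import WhisperTokenizer

# Keep initial_prompt as a plain string in preprocess_params ...
preprocess_params = {"initial_prompt": "Biswajit, Whisper"}

# ... and convert it to prompt token ids in preprocess, so _forward can pass
# them to model.generate as prompt_ids.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
prompt_ids = tokenizer.get_prompt_ids(preprocess_params["initial_prompt"], return_tensors="pt")
print(prompt_ids)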

@sanchit-gandhi sanchit-gandhi linked an issue May 22, 2024 that may be closed by this pull request
@Biswajit2902
Author

@sanchit-gandhi, thanks for the pointer. Sorry, I got super busy and couldn't get back to review it. I will do it soon and close this out.

@Biswajit2902
Author

@sanchit-gandhi, just an update: I have made the changes for this issue as suggested, but the output is no longer correct as it was before. It seems generate has an issue; it is adding the initial prompt to every chunk. I will check and update on this. Also, let me know if you know of any existing issue about this.

@huggingface huggingface deleted a comment from github-actions bot Jul 16, 2024
@amyeroberts
Collaborator

cc @kamilakesbi as @sanchit-gandhi is off

@basicblueberrry136

Are there any updates on this? Or are there other ways you know of to push the model to more easily detect certain words using this pipeline?

@amyeroberts
Collaborator

cc @ylacombe

@ylacombe
Contributor

ylacombe commented Sep 2, 2024

Hey @basicblueberrry136, thanks for your comment!
@sanchit-gandhi's review still has to be addressed before the next steps. Once it's done, I'll make another review! Hopefully it'll move fast!

@JacobLinCool
Contributor

I believe this is very helpful when used with the serverless inference API.

It seems that the serverless inference API uses the Transformers library to run models, and we cannot pass any parameter that is a tensor, as shown below:

const fs = require('fs');

// Read the audio file and base64-encode it for the request body.
const data = fs.readFileSync(filename);
const b64 = data.toString('base64');

const body = JSON.stringify({
    inputs: b64,
    parameters: {
        return_timestamps: true,
        generate_kwargs: {
            num_beams: 1,
            prompt_ids: [50362, 27338, 3763, 48022, 2257, 48022, 6784, 118, 25157, 1546, 15789, 23987, 5975, 17174, 28472, 25750, 6062, 1543],
        }
    }
});

It results in the following error:

{
  "error": "unknown error",
  "warnings": [
    "There was an inference error: unknown error: list indices must be integers or slices, not NoneType"
  ]
}

If initial_prompt is added, we can pass the prompt as a string to the serverless inference API.
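
For illustration, a hypothetical request once such a string parameter existed (hedged: initial_prompt is only the parameter proposed in this PR, not something the inference API accepts today; the endpoint, model id, token, and payload shape mirror the JSON above and are assumptions):

# Hypothetical sketch: pass the proposed initial_prompt as a plain string
# instead of a prompt_ids tensor. This parameter does not exist yet.
import base64
import requests

with open("sample.flac", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "https://api-inference.huggingface.co/models/openai/whisper-large-v3",
    headers={"Authorization": "Bearer hf_xxx"},  # placeholder token
    json={
        "inputs": audio_b64,
        "parameters": {
            "return_timestamps": True,
            "generate_kwargs": {"num_beams": 1, "initial_prompt": "Biswajit, Whisper"},
        },
    },
)
print(response.json())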

@jollyfish-cjy

Hi, thanks for your work! Are there any updates on this?

Successfully merging this pull request may close these issues.

  • audio pipeline support for initial_prompt?
  • openai/whisper-large-v2 prompt