# NumPy-only replacements for the SciPy resampling helpers previously used by
# phi4mm.py. `extract_spectrogram` calls `resample_poly(wav, 1, fs // 16000)`,
# `resample_poly(wav, 1, fs // 8000)` and `resample_poly(wav, 2, 1)`.
import numpy as np


def firwin(numtaps, cutoff):
    """
    Design a lowpass FIR filter: a sinc kernel shaped by a Hamming window.

    Parameters
    ----------
    numtaps : int
        The number of filter taps (must be odd so that the filter has an
        integer group delay of ``(numtaps - 1) / 2`` samples).
    cutoff : float
        Cutoff frequency in cycles per sample, i.e. normalized to the
        *sampling* rate: the Nyquist rate corresponds to 0.5. Meaningful
        values are ``0 < cutoff <= 0.5``; the value is not validated.

    Returns
    -------
    h : 1D ndarray
        The symmetric FIR filter coefficients, scaled to sum to 1.

    Raises
    ------
    ValueError
        If ``numtaps`` is even.
    """
    numtaps = int(numtaps)
    if numtaps % 2 == 0:
        raise ValueError("numtaps must be odd.")
    n = np.arange(numtaps)
    center = (numtaps - 1) / 2.0
    # Ideal lowpass impulse response: 2*fc*sinc(2*fc*n), with
    # np.sinc(t) = sin(pi*t) / (pi*t) and fc in cycles per sample.
    h = np.sinc(2 * cutoff * (n - center))
    h *= 2 * cutoff
    # Taper the truncated sinc with a Hamming window to reduce ripple.
    h *= np.hamming(numtaps)
    # Normalize so the gain at DC is exactly one.
    h /= np.sum(h)
    return h


def upfirdn(x, h, up, down, axis=0):
    """
    Upsample, filter, and downsample a signal along a given axis.

    Parameters
    ----------
    x : array_like
        Input signal.
    h : 1D ndarray
        FIR filter coefficients.
    up : int
        Upsampling factor.
    down : int
        Downsampling factor.
    axis : int, optional
        Axis along which to perform the operation.

    Returns
    -------
    y : ndarray
        Every ``down``-th sample of the full convolution of the zero-stuffed
        signal with ``h``. No delay compensation is applied here.
    """
    x = np.asarray(x)
    h = np.asarray(h)
    # Upsample: insert (up - 1) zeros between every sample along the axis.
    new_shape = list(x.shape)
    new_shape[axis] = x.shape[axis] * up
    x_up = np.zeros(new_shape, dtype=x.dtype)
    indexer = [slice(None)] * x.ndim
    indexer[axis] = slice(0, new_shape[axis], up)
    x_up[tuple(indexer)] = x

    # Convolve each 1-D slice along the axis using full convolution.
    def conv1d(v):
        return np.convolve(v, h, mode='full')

    y_conv = np.apply_along_axis(conv1d, axis, x_up)
    # Downsample: keep every down-th sample along the convolution axis.
    indexer = [slice(None)] * y_conv.ndim
    indexer[axis] = slice(0, None, down)
    return y_conv[tuple(indexer)]


def resample_poly(x, up, down, axis=0):
    """
    Resample the signal by the rational factor ``up / down``.

    The signal is zero-stuffed by ``up``, lowpass filtered, and decimated by
    ``down``. The filter delay is removed so that output sample ``i``
    corresponds to input time ``i * down / up``.

    Parameters
    ----------
    x : array_like
        Input signal.
    up : int
        Upsampling factor (>= 1).
    down : int
        Downsampling factor (>= 1).
    axis : int, optional
        Axis along which to resample.

    Returns
    -------
    y : ndarray
        The resampled signal, with ``ceil(n * up / down)`` samples along
        ``axis`` (after ``up`` and ``down`` are reduced by their GCD).

    Raises
    ------
    ValueError
        If ``up`` or ``down`` is smaller than 1.

    Notes
    -----
    This implementation follows a strategy similar to SciPy's resample_poly
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.resample_poly.html
    but is self-contained and depends only on NumPy. It uses a Hamming
    window, so results are close to, but not identical to, SciPy's.
    """ # noqa: E501
    x = np.asarray(x)
    up = int(up)
    down = int(down)
    if up < 1 or down < 1:
        raise ValueError("up and down must be >= 1.")

    # Reduce the ratio; e.g. 4/2 needs no more work than 2/1.
    common = int(np.gcd(up, down))
    up //= common
    down //= common
    if up == 1 and down == 1:
        # Nothing to resample. Return a fresh array with the same dtype the
        # filtering path would produce.
        return x.astype(np.result_type(x.dtype, np.float64))

    # Filter length heuristic: 10 * max(up, down) taps on each side.
    max_rate = max(up, down)
    half_len = 10 * max_rate
    numtaps = 2 * half_len + 1
    # The filter runs at the upsampled rate and must remove everything above
    # the lower of the input/output Nyquist rates, which is 1 / max_rate
    # relative to Nyquist. `firwin` expects cycles per sample (Nyquist = 0.5),
    # hence 0.5 / max_rate. Passing 1 / max_rate would double the cutoff and,
    # for a factor of 2, degenerate the filter into a unit impulse.
    cutoff = 0.5 / max_rate
    # Scale by `up` to restore the amplitude lost through zero-stuffing.
    h = firwin(numtaps, cutoff) * up

    # The symmetric filter delays the signal by half_len samples at the
    # upsampled rate. Prepend zeros so the total delay is a multiple of
    # `down`; the delay can then be removed exactly after decimation.
    n_pre_pad = (-half_len) % down
    if n_pre_pad:
        h = np.concatenate((np.zeros(n_pre_pad), h))
    y_full = upfirdn(x, h, up, down, axis=axis)
    trim = (half_len + n_pre_pad) // down

    # Expected output length (ceiling division).
    n_in = x.shape[axis]
    n_out = (n_in * up + down - 1) // down
    # Drop the initial delay and keep n_out samples along the axis.
    slicer = [slice(None)] * y_full.ndim
    slicer[axis] = slice(trim, trim + n_out)
    return y_full[tuple(slicer)]