Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
python-json-logger # Used by logging as per examples/other/logging_configuration.md
scipy # Required for phi-4-multimodal-instruct
ninja # Required for xgrammar, rocm, tpu, xpu
opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing
opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/phi4mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TypedDict, Union)

import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torchvision.transforms as T
Expand All @@ -26,6 +25,7 @@
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.phi4mm_utils import resample_poly
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
Expand Down Expand Up @@ -765,10 +765,10 @@ def extract_spectrogram(self, wav, fs):

# Resample to 16000 or 8000 if needed
if fs > 16000:
wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
wav = resample_poly(wav, 1, fs // 16000)
fs = 16000
elif 8000 < fs < 16000:
wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
wav = resample_poly(wav, 1, fs // 8000)
fs = 8000
elif fs < 8000:
raise RuntimeError(f"Unsupported sample rate {fs}")
Expand All @@ -777,7 +777,7 @@ def extract_spectrogram(self, wav, fs):
if self._eightk_method == "resample":
# Input audio is 8 kHz. Convert to 16 kHz before feature
# extraction
wav = scipy.signal.resample_poly(wav, 2, 1)
wav = resample_poly(wav, 2, 1)
fs = 16000
# Do nothing here for fillzero method
elif fs != 16000:
Expand Down
136 changes: 136 additions & 0 deletions vllm/model_executor/models/phi4mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
Expand Down Expand Up @@ -1881,3 +1882,138 @@ def unfold_tensor(xs_pad, max_seq_len):
# NT' x max_seq_len x D
xs_pad = xs_pad.view(-1, max_seq_len, D)
return xs_pad


def firwin(numtaps, cutoff):
    """
    Design a simple lowpass FIR filter using a sinc function multiplied
    by a Hamming window.

    Parameters
    ----------
    numtaps : int
        The number of filter taps (must be odd and at least 1).
    cutoff : float
        Cutoff frequency in cycles per sample. The useful range is
        0 < cutoff < 0.5, where 0.5 corresponds to the Nyquist rate.
        This is half the value SciPy's ``firwin`` takes for the same
        filter, because SciPy normalizes the cutoff to Nyquist by default.
        Values of 0.5 or more are accepted but do not give a useful
        lowpass (cutoff == 0.5 yields a unit impulse).

    Returns
    -------
    h : 1D ndarray
        The FIR filter coefficients, normalized so that they sum to 1.

    Raises
    ------
    ValueError
        If numtaps is even or smaller than 1, or if cutoff is not
        strictly positive (a zero cutoff would otherwise produce a
        filter full of NaNs from the final normalization).
    """
    numtaps = int(numtaps)
    if numtaps < 1:
        raise ValueError("numtaps must be a positive odd integer.")
    if numtaps % 2 == 0:
        raise ValueError("numtaps must be odd.")
    # `not cutoff > 0` also rejects NaN, which would poison every tap.
    if not cutoff > 0:
        raise ValueError("cutoff must be positive.")
    n = np.arange(numtaps)
    center = (numtaps - 1) / 2.0
    # Ideal impulse response (sinc-based); np.sinc is the normalized sinc,
    # so a cutoff of f cycles/sample needs the argument 2 * f * (n - center).
    h = np.sinc(2 * cutoff * (n - center))
    h *= 2 * cutoff  # amplitude of the ideal lowpass impulse response
    # Apply a Hamming window for smoother response
    h *= np.hamming(numtaps)
    # Normalize to unity gain at DC (the taps sum to 1)
    h /= np.sum(h)
    return h


def upfirdn(x, h, up, down, axis=0):
    """
    Zero-stuff a signal, run it through an FIR filter, then decimate.

    Parameters
    ----------
    x : array_like
        Input signal.
    h : 1D ndarray
        FIR filter coefficients.
    up : int
        Upsampling factor; ``up - 1`` zeros follow every input sample.
    down : int
        Downsampling factor; every ``down``-th filtered sample is kept,
        starting with the first one.
    axis : int, optional
        Axis along which to perform the operation.

    Returns
    -------
    y : ndarray
        The filtered and decimated signal. Along ``axis`` it is the full
        convolution (length ``n * up + len(h) - 1``) sampled with step
        ``down``; no delay compensation is applied here.
    """
    signal = np.asarray(x)
    taps = np.asarray(h)

    def along_axis(ndim, axis_slice):
        # Index tuple that applies `axis_slice` on `axis` and keeps
        # every other axis whole.
        index = [slice(None)] * ndim
        index[axis] = axis_slice
        return tuple(index)

    # Zero-stuffing: place the input on every up-th slot of a zero array.
    stuffed_shape = list(signal.shape)
    stuffed_shape[axis] = signal.shape[axis] * up
    stuffed = np.zeros(stuffed_shape, dtype=signal.dtype)
    stuffed[along_axis(signal.ndim, slice(0, stuffed_shape[axis], up))] = signal

    # Full 1-D convolution of every slice along the chosen axis.
    filtered = np.apply_along_axis(np.convolve, axis, stuffed, taps,
                                   mode='full')

    # Decimation: keep every down-th sample of the filtered signal.
    return filtered[along_axis(filtered.ndim, slice(0, None, down))]


def resample_poly(x, up, down, axis=0):
    """
    Resample the signal using polyphase-style filtering.

    This function upsamples the signal by 'up', applies a lowpass FIR filter,
    and downsamples by 'down'.
    The result has a sample rate multiplied by up/down.
    The filter delay is compensated so that the first output sample is
    aligned with the first input sample.

    Parameters
    ----------
    x : array_like
        Input signal.
    up : int
        Upsampling factor (>= 1).
    down : int
        Downsampling factor (>= 1).
    axis : int, optional
        Axis along which to resample.

    Returns
    -------
    y : ndarray
        The resampled signal, with ``ceil(n * up / down)`` samples along
        ``axis`` where ``n`` is the input length on that axis.

    Raises
    ------
    ValueError
        If up or down is not an integer, or is smaller than 1.

    Notes
    -----
    This implementation follows a strategy similar to SciPy's resample_poly
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.resample_poly.html
    but is self-contained and depends only on NumPy and math.
    """ # noqa: E501
    x = np.asarray(x)
    up_int, down_int = int(up), int(down)
    if up_int != up or down_int != down:
        raise ValueError("up and down must be integers")
    if up_int < 1 or down_int < 1:
        raise ValueError("up and down must be >= 1")
    # Reduce the ratio (as SciPy does): e.g. 4/2 is the same resampling as
    # 2/1 but would otherwise use a needlessly long filter.
    common = math.gcd(up_int, down_int)
    up, down = up_int // common, down_int // common

    # Choose filter length; heuristic: 10 * max(up, down) taps on each side
    max_rate = max(up, down)
    half_len = 10 * max_rate
    numtaps = 2 * half_len + 1
    # The lowpass must cut at the lower of the input/output Nyquist rates,
    # i.e. 1 / max_rate of the Nyquist rate of the upsampled signal. firwin
    # here takes the cutoff in cycles per sample (Nyquist == 0.5), hence
    # 0.5 / max_rate; passing 1 / max_rate would place the cutoff an octave
    # too high and, for a factor of 2, disable the filtering entirely.
    cutoff = 0.5 / max_rate
    # Scale by the up factor to compensate for the energy lost to
    # zero-stuffing.
    h = firwin(numtaps, cutoff) * up
    # Filter at the upsampled rate without decimating yet, so that the
    # decimation below can start exactly at the filter delay.
    y_full = upfirdn(x, h, up, 1, axis=axis)
    # The symmetric FIR filter introduces a delay of (numtaps - 1) / 2
    # samples at the upsampled rate.
    delay = (numtaps - 1) // 2
    n_in = x.shape[axis]
    # Skip the delay, then keep every down-th sample of the n_in * up
    # aligned samples; this yields ceil(n_in * up / down) outputs and stays
    # exact even when `down` does not divide `delay`.
    slicer = [slice(None)] * y_full.ndim
    slicer[axis] = slice(delay, delay + n_in * up, down)
    return y_full[tuple(slicer)]