diff --git a/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavFileLoader.java b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavFileLoader.java index edde9fad..93feea11 100644 --- a/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavFileLoader.java +++ b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavFileLoader.java @@ -7,6 +7,7 @@ import java.io.DataInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import static com.sedmelluq.discord.lavaplayer.container.MediaContainerDetection.checkNextBytes; @@ -14,7 +15,8 @@ * Loads either WAV header information or a WAV track provider from a stream. */ public class WavFileLoader { - static final int[] WAV_RIFF_HEADER = new int[]{0x52, 0x49, 0x46, 0x46, -1, -1, -1, -1, 0x57, 0x41, 0x56, 0x45}; + static final int[] WAV_RIFF_HEADER = new int[] { 0x52, 0x49, 0x46, 0x46, -1, -1, -1, -1, 0x57, 0x41, 0x56, 0x45 }; + static final byte[] FORMAT_SUBTYPE_PCM = { 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, (byte) 0x80, 0x00, 0x00, (byte) 0xaa, 0x00, 0x38, (byte) 0x9b, 0x71 }; private final SeekableInputStream inputStream; @@ -44,10 +46,11 @@ public WavFileInfo parseHeaders() throws IOException { long chunkSize = Integer.toUnsignedLong(Integer.reverseBytes(dataInput.readInt())); if ("fmt ".equals(chunkName)) { - readFormatChunk(builder, dataInput); + int bytesRead = readFormatChunk(builder, dataInput); + long chunkBytesRemaining = chunkSize - bytesRead; - if (chunkSize > 16) { - inputStream.skipFully(chunkSize - 16); + if (chunkBytesRemaining > 0) { + inputStream.skipFully(chunkBytesRemaining); } } else if ("data".equals(chunkName)) { builder.sampleAreaSize = chunkSize; @@ -65,8 +68,8 @@ private String readChunkName(DataInput dataInput) throws IOException { return new String(buffer, StandardCharsets.US_ASCII); } - private void readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IOException { - builder.audioFormat = Short.reverseBytes(dataInput.readShort()) & 0xFFFF; + private int readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IOException { + builder.setAudioFormat(Short.reverseBytes(dataInput.readShort()) & 0xFFFF); builder.channelCount = Short.reverseBytes(dataInput.readShort()) & 0xFFFF; builder.sampleRate = Integer.reverseBytes(dataInput.readInt()); @@ -75,6 +78,16 @@ private void readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IO builder.blockAlign = Short.reverseBytes(dataInput.readShort()) & 0xFFFF; builder.bitsPerSample = Short.reverseBytes(dataInput.readShort()) & 0xFFFF; + + if (builder.formatType == WaveFormatType.WAVE_FORMAT_EXTENSIBLE) { + dataInput.skipBytes(8); + byte[] subFormat = new byte[16]; + dataInput.readFully(subFormat); + builder.subFormat = subFormat; + return 40; + } + + return 16; } /** @@ -90,6 +103,8 @@ public WavTrackProvider loadTrack(AudioProcessingContext context) throws IOExcep private static class InfoBuilder { private int audioFormat; + private WaveFormatType formatType; + private byte[] subFormat; private int channelCount; private int sampleRate; private int bitsPerSample; @@ -97,6 +112,11 @@ private static class InfoBuilder { private long sampleAreaSize; private long startOffset; + private void setAudioFormat(int audioFormat) { + this.audioFormat = audioFormat; + this.formatType = WaveFormatType.getByCode(audioFormat); + } + private WavFileInfo build() { validateFormat(); validateAlignment(); @@ -105,13 +125,15 @@ private WavFileInfo build() { } private void validateFormat() { - if (audioFormat != 1) { - throw new IllegalStateException("Invalid audio format " + audioFormat + ", must be 1 (PCM)"); + if (formatType == WaveFormatType.WAVE_FORMAT_UNKNOWN) { + throw new IllegalStateException("Invalid audio format " + audioFormat + ", must be 1 (PCM) or 65534 (WAVE_FORMAT_EXTENSIBLE)"); + } else if (subFormat != null && !Arrays.equals(subFormat, FORMAT_SUBTYPE_PCM)) { + throw new IllegalStateException("Invalid subformat " + Arrays.toString(subFormat)); } else if (channelCount < 1 || channelCount > 16) { throw new IllegalStateException("Invalid channel count: " + channelCount); } else if (sampleRate < 100 || sampleRate > 384000) { throw new IllegalStateException("Invalid sample rate: " + sampleRate); - } else if (bitsPerSample != 16 && bitsPerSample != 24) { + } else if (bitsPerSample != 16 && bitsPerSample != 24 && bitsPerSample != 32) { throw new IllegalStateException("Unsupported bits per sample: " + bitsPerSample); } } diff --git a/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavTrackProvider.java b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavTrackProvider.java index e8152005..ec1d8453 100644 --- a/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavTrackProvider.java +++ b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WavTrackProvider.java @@ -10,7 +10,6 @@ import java.io.DataInputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.ShortBuffer; import static java.nio.ByteOrder.LITTLE_ENDIAN; @@ -23,11 +22,12 @@ public class WavTrackProvider { private final SeekableInputStream inputStream; private final DataInput dataInput; private final WavFileInfo info; + private final int bytesPerSample; private final AudioPipeline downstream; + private final short[] buffer; private final byte[] rawBuffer; private final ByteBuffer byteBuffer; - private final ShortBuffer nioBuffer; /** * @param context Configuration and output information for processing @@ -38,12 +38,12 @@ public WavTrackProvider(AudioProcessingContext context, SeekableInputStream inpu this.inputStream = inputStream; this.dataInput = new DataInputStream(inputStream); this.info = info; + this.bytesPerSample = info.bitsPerSample >> 3; this.downstream = AudioPipelineFactory.create(context, new PcmFormat(info.channelCount, info.sampleRate)); this.buffer = info.getPadding() > 0 ? new short[info.channelCount * BLOCKS_IN_BUFFER] : null; this.byteBuffer = ByteBuffer.allocate(info.blockAlign * BLOCKS_IN_BUFFER).order(LITTLE_ENDIAN); this.rawBuffer = byteBuffer.array(); - this.nioBuffer = byteBuffer.asShortBuffer(); } /** @@ -101,10 +101,10 @@ private void processChunkWithPadding(int blockCount) throws IOException, Interru int indexInBlock = 0; for (int i = 0; i < sampleCount; i++) { - buffer[i] = nioBuffer.get(); + buffer[i] = byteBuffer.getShort(); if (++indexInBlock == info.channelCount) { - nioBuffer.position(nioBuffer.position() + padding); + byteBuffer.position(byteBuffer.position() + padding); indexInBlock = 0; } } @@ -115,27 +115,23 @@ private void processChunkWithPadding(int blockCount) throws IOException, Interru private void processChunk(int blockCount) throws IOException, InterruptedException { int sampleCount = readChunkToBuffer(blockCount); - if (info.bitsPerSample == 16) { - downstream.process(nioBuffer); - } else if (info.bitsPerSample == 24) { - short[] samples = new short[sampleCount]; - + if (info.bitsPerSample != 16) { for (int i = 0; i < sampleCount; i++) { - samples[i] = (short) (byteBuffer.get((i * 3) + 2) << 8 | byteBuffer.get((i * 3) + 1) & 0xFF); + byteBuffer.putShort(i * 2, byteBuffer.getShort((i * bytesPerSample) + bytesPerSample - 2)); } - downstream.process(samples, 0, sampleCount); + byteBuffer.limit(sampleCount * 2); } + + downstream.process(byteBuffer.asShortBuffer()); } private int readChunkToBuffer(int blockCount) throws IOException { - int bytesPerSample = info.bitsPerSample >> 3; int bytesToRead = blockCount * info.blockAlign; dataInput.readFully(rawBuffer, 0, bytesToRead); byteBuffer.position(0); - nioBuffer.position(0); - nioBuffer.limit(bytesToRead / bytesPerSample); + byteBuffer.limit(bytesToRead); return bytesToRead / bytesPerSample; } diff --git a/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WaveFormatType.java b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WaveFormatType.java new file mode 100644 index 00000000..097e6c0d --- /dev/null +++ b/main/src/main/java/com/sedmelluq/discord/lavaplayer/container/wav/WaveFormatType.java @@ -0,0 +1,21 @@ +package com.sedmelluq.discord.lavaplayer.container.wav; + +import java.util.Arrays; + +public enum WaveFormatType { + // https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/Pages%20from%20mmreg.h.pdf + WAVE_FORMAT_UNKNOWN(0x0000), + WAVE_FORMAT_PCM(0x0001), + WAVE_FORMAT_EXTENSIBLE(0xFFFE); + + final int code; + + WaveFormatType(int code) { + this.code = code; + } + + static WaveFormatType getByCode(int code) { + return Arrays.stream(values()).filter(type -> type.code == code).findFirst() + .orElse(WAVE_FORMAT_UNKNOWN); + } +}