Skip to content

Commit

Permalink
ZstdDecompressCtx: pass nativePtr directly to JNI calls
Browse files Browse the repository at this point in the history
This avoids having to lookup the field and fixes a bug where the
wrong fieldId was passed causing a segfault if
ZstdDecompressCtx#reset() was called before ZstdCompressCtx() was used.
  • Loading branch information
jamie-walker authored and luben committed Feb 27, 2023
1 parent 73a378f commit 1317e44
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 63 deletions.
53 changes: 21 additions & 32 deletions src/main/java/com/github/luben/zstd/ZstdDecompressCtx.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ public class ZstdDecompressCtx extends AutoCloseBase {
private long nativePtr = 0;
private ZstdDictDecompress decompression_dict = null;

private native void init();
private static native long init();

private native void free();
private static native void free(long nativePtr);

/**
* Create a context for faster compress operations
* One such context is required for each thread - put this in a ThreadLocal.
*/
public ZstdDecompressCtx() {
init();
nativePtr = init();
if (0 == nativePtr) {
throw new IllegalStateException("ZSTD_createDeCompressCtx failed");
}
Expand All @@ -33,7 +33,7 @@ public ZstdDecompressCtx() {

void doClose() {
if (nativePtr != 0) {
free();
free(nativePtr);
nativePtr = 0;
}
}
Expand All @@ -43,9 +43,7 @@ void doClose() {
* @param magiclessFlag A 32-bits checksum of content is written at end of frame, default: false
*/
public ZstdDecompressCtx setMagicless(boolean magiclessFlag) {
if (nativePtr == 0) {
throw new IllegalStateException("Compression context is closed");
}
ensureOpen();
acquireSharedLock();
Zstd.setDecompressionMagicless(nativePtr, magiclessFlag);
releaseSharedLock();
Expand All @@ -58,13 +56,11 @@ public ZstdDecompressCtx setMagicless(boolean magiclessFlag) {
* @param dict the dictionary or `null` to remove loaded dictionary
*/
public ZstdDecompressCtx loadDict(ZstdDictDecompress dict) {
if (nativePtr == 0) {
throw new IllegalStateException("Decompression context is closed");
}
ensureOpen();
acquireSharedLock();
dict.acquireSharedLock();
try {
long result = loadDDictFast0(dict);
long result = loadDDictFast0(nativePtr, dict);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
Expand All @@ -76,20 +72,18 @@ public ZstdDecompressCtx loadDict(ZstdDictDecompress dict) {
}
return this;
}
private native long loadDDictFast0(ZstdDictDecompress dict);
private static native long loadDDictFast0(long nativePtr, ZstdDictDecompress dict);

/**
* Load decompression dictionary.
*
* @param dict the dictionary or `null` to remove loaded dictionary
*/
public ZstdDecompressCtx loadDict(byte[] dict) {
if (nativePtr == 0) {
throw new IllegalStateException("Compression context is closed");
}
ensureOpen();
acquireSharedLock();
try {
long result = loadDDict0(dict);
long result = loadDDict0(nativePtr, dict);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
Expand All @@ -99,17 +93,17 @@ public ZstdDecompressCtx loadDict(byte[] dict) {
}
return this;
}
private native long loadDDict0(byte[] dict);
private static native long loadDDict0(long nativePtr, byte[] dict);

/**
* Clear all state and parameters from the decompression context. This leaves the object in a
* state identical to a newly created decompression context.
*/
public void reset() {
ensureOpen();
reset0();
reset0(nativePtr);
}
private native void reset0();
private static native void reset0(long nativePtr);

private void ensureOpen() {
if (nativePtr == 0) {
Expand All @@ -127,7 +121,7 @@ private void ensureOpen() {
*/
public boolean decompressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src) {
ensureOpen();
long result = decompressDirectByteBufferStream0(dst, dst.position(), dst.limit(), src, src.position(), src.limit());
long result = decompressDirectByteBufferStream0(nativePtr, dst, dst.position(), dst.limit(), src, src.position(), src.limit());
if ((result & 0x80000000L) != 0) {
long code = result & 0xFF;
throw new ZstdException(code, Zstd.getErrorName(code));
Expand All @@ -144,7 +138,7 @@ public boolean decompressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src)
* bit is set if an error occurred. If an error occurred, the lowest 31 bits encode a zstd error
* code. Otherwise, the lowest 31 bits are the new position of the source buffer.
*/
private native long decompressDirectByteBufferStream0(ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);
private static native long decompressDirectByteBufferStream0(long nativePtr, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);

/**
* Decompresses buffer 'srcBuff' into buffer 'dstBuff' using this ZstdDecompressCtx.
Expand All @@ -162,9 +156,7 @@ public boolean decompressDirectByteBufferStream(ByteBuffer dst, ByteBuffer src)
* @return the number of bytes decompressed into destination buffer (originalSize)
*/
public int decompressDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dstSize, ByteBuffer srcBuff, int srcOffset, int srcSize) {
if (nativePtr == 0) {
throw new IllegalStateException("Decompression context is closed");
}
ensureOpen();
if (!srcBuff.isDirect()) {
throw new IllegalArgumentException("srcBuff must be a direct buffer");
}
Expand All @@ -175,7 +167,7 @@ public int decompressDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dst
acquireSharedLock();

try {
long size = decompressDirectByteBuffer0(dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
long size = decompressDirectByteBuffer0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
if (Zstd.isError(size)) {
throw new ZstdException(size);
}
Expand All @@ -188,7 +180,7 @@ public int decompressDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dst
}
}

private native long decompressDirectByteBuffer0(ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);
private static native long decompressDirectByteBuffer0(long nativePtr, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize);

/**
* Decompresses byte array 'srcBuff' into byte array 'dstBuff' using this ZstdDecompressCtx.
Expand All @@ -204,14 +196,11 @@ public int decompressDirectByteBuffer(ByteBuffer dstBuff, int dstOffset, int dst
* @return the number of bytes decompressed into destination buffer (originalSize)
*/
public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[] srcBuff, int srcOffset, int srcSize) {
if (nativePtr == 0) {
throw new IllegalStateException("Decompression context is closed");
}

ensureOpen();
acquireSharedLock();

try {
long size = decompressByteArray0(dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
long size = decompressByteArray0(nativePtr, dstBuff, dstOffset, dstSize, srcBuff, srcOffset, srcSize);
if (Zstd.isError(size)) {
throw new ZstdException(size);
}
Expand All @@ -224,7 +213,7 @@ public int decompressByteArray(byte[] dstBuff, int dstOffset, int dstSize, byte[
}
}

private native long decompressByteArray0(byte[] dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);
private static native long decompressByteArray0(long nativePtr, byte[] dst, int dstOffset, int dstSize, byte[] src, int srcOffset, int srcSize);

/* Covenience methods */

Expand Down
56 changes: 25 additions & 31 deletions src/main/native/jni_fast_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
static jfieldID compress_dict = 0;
static jfieldID decompress_dict = 0;
static jfieldID compress_ctx_nativePtr = 0;
static jfieldID decompress_ctx_nativePtr = 0;


/*
Expand Down Expand Up @@ -469,17 +468,13 @@ E1: return size;
/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: init
* Signature: ()V
* Signature: ()J
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_init
(JNIEnv *env, jobject obj)
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_init
(JNIEnv *env, jclass clazz)
{
if (decompress_ctx_nativePtr == 0) {
jclass clazz = (*env)->GetObjectClass(env, obj);
decompress_ctx_nativePtr = (*env)->GetFieldID(env, clazz, "nativePtr", "J");
}
ZSTD_DCtx* dctx = ZSTD_createDCtx();
(*env)->SetLongField(env, obj, decompress_ctx_nativePtr, (jlong)(intptr_t) dctx);
return (jlong)(intptr_t) dctx;
}

/*
Expand All @@ -488,23 +483,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_init
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_free
(JNIEnv *env, jobject obj)
(JNIEnv *env, jclass clazz, jlong ptr)
{
if (decompress_ctx_nativePtr == 0) return;
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, obj, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
if (NULL == dctx) return;
ZSTD_freeDCtx(dctx);
}

/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: loadDDictFast0
* Signature: (Lcom/github/luben/zstd/ZstdDictDecompress)J
* Signature: (JLcom/github/luben/zstd/ZstdDictDecompress;)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_loadDDictFast0
(JNIEnv *env, jobject obj, jobject dict)
(JNIEnv *env, jclass clazz, jlong ptr, jobject dict)
{
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, obj, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
if (dict == NULL) {
// remove dictionary
return ZSTD_DCtx_refDDict(dctx, NULL);
Expand All @@ -517,12 +511,12 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_loadDDictFa
/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: loadDDict0
* Signature: ([B)J
* Signature: (J[B)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_loadDDict0
(JNIEnv *env, jobject obj, jbyteArray dict)
(JNIEnv *env, jclass clazz, jlong ptr, jbyteArray dict)
{
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, obj, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
if (dict == NULL) {
// remove dictionary
return ZSTD_DCtx_loadDictionary(dctx, NULL, 0);
Expand All @@ -538,16 +532,16 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_loadDDict0
/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: reset0
* Signature: (L)J
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_reset0
(JNIEnv *env, jclass jctx) {
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, compress_ctx_nativePtr);
(JNIEnv *env, jclass clazz, jlong ptr) {
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;
return ZSTD_DCtx_reset(dctx, ZSTD_reset_session_and_parameters);
}

static size_t decompress_direct_buffer_stream
(JNIEnv *env, jclass jctx, jobject dst, jint *dst_offset, jint dst_size, jobject src, jint *src_offset, jint src_size)
(JNIEnv *env, jlong ptr, jobject dst, jint *dst_offset, jint dst_size, jobject src, jint *src_offset, jint src_size)
{
if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
if (NULL == src) return -ZSTD_error_srcSize_wrong;
Expand All @@ -561,7 +555,7 @@ static size_t decompress_direct_buffer_stream
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
if (src_size > src_cap) return -ZSTD_error_srcSize_wrong;

ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;

ZSTD_outBuffer out;
out.pos = *dst_offset;
Expand All @@ -583,12 +577,12 @@ static size_t decompress_direct_buffer_stream
/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: decompressDirectByteBufferStream0
* Signature: (Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
* Signature: (JLjava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBufferStream0
(JNIEnv *env, jclass jctx, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size)
(JNIEnv *env, jclass jclass, jlong ptr, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size)
{
size_t result = decompress_direct_buffer_stream(env, jctx, dst, &dst_offset, dst_size, src, &src_offset, src_size);
size_t result = decompress_direct_buffer_stream(env, ptr, dst, &dst_offset, dst_size, src, &src_offset, src_size);
if (ZSTD_isError(result)) {
return (1ULL << 31) | ZSTD_getErrorCode(result);
}
Expand All @@ -603,10 +597,10 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressD
/*
* Class: com_github_luben_zstd_ZstdDecompressCtx
* Method: decompressDirectByteBuffer0
* Signature: (Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
* Signature: (JLjava/nio/ByteBuffer;IILjava/nio/ByteBuffer;II)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressDirectByteBuffer0
(JNIEnv *env, jclass jctx, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size)
(JNIEnv *env, jclass jclazz, jlong ptr, jobject dst, jint dst_offset, jint dst_size, jobject src, jint src_offset, jint src_size)
{
if (NULL == dst) return -ZSTD_error_dstSize_tooSmall;
if (NULL == src) return -ZSTD_error_srcSize_wrong;
Expand All @@ -619,7 +613,7 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressD
jsize src_cap = (*env)->GetDirectBufferCapacity(env, src);
if (src_offset + src_size > src_cap) return -ZSTD_error_srcSize_wrong;

ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;

char *dst_buff = (char*)(*env)->GetDirectBufferAddress(env, dst);
if (dst_buff == NULL) return -ZSTD_error_memory_allocation;
Expand All @@ -636,7 +630,7 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressD
* Signature: (B[IIB[II)J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressByteArray0
(JNIEnv *env, jclass jctx, jbyteArray dst, jint dst_offset, jint dst_size, jbyteArray src, jint src_offset, jint src_size) {
(JNIEnv *env, jclass jclazz, jlong ptr, jbyteArray dst, jint dst_offset, jint dst_size, jbyteArray src, jint src_offset, jint src_size) {
size_t size = -ZSTD_error_memory_allocation;

if (0 > dst_offset) return -ZSTD_error_dstSize_tooSmall;
Expand All @@ -646,7 +640,7 @@ JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_ZstdDecompressCtx_decompressB
if (src_offset + src_size > (*env)->GetArrayLength(env, src)) return -ZSTD_error_srcSize_wrong;
if (dst_offset + dst_size > (*env)->GetArrayLength(env, dst)) return -ZSTD_error_dstSize_tooSmall;

ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)(*env)->GetLongField(env, jctx, decompress_ctx_nativePtr);
ZSTD_DCtx* dctx = (ZSTD_DCtx*)(intptr_t)ptr;

void *dst_buff = (*env)->GetPrimitiveArrayCritical(env, dst, NULL);
if (dst_buff == NULL) goto E1;
Expand Down

0 comments on commit 1317e44

Please sign in to comment.