Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix S3 PartialContentInputStream to be compliant to InputStream specification #2217

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.springframework.util.Assert;

import javax.crypto.CipherInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
Expand Down Expand Up @@ -267,7 +266,7 @@ public Resource getResource(S entity, PropertyPath propertyPath) {

// remove cast and use conversion service
unencryptedStream = encrypter.decrypt((byte[]) contentProperty.getCustomProperty(entity, this.encryptionKeyContentProperty), r.getInputStream(), 0, this.keyRing);
r = new InputStreamResource(new SkipInputStream(unencryptedStream));
r = new InputStreamResource(unencryptedStream);
} catch (IOException e) {
throw new StoreAccessException("error encrypting resource", e);
}
Expand All @@ -282,7 +281,7 @@ public Resource getResource(S entity, PropertyPath propertyPath, org.springframe
Assert.notNull(propertyPath, "propertyPath not set");
Assert.notNull(storeDelegate, "store not set");

Resource r = storeDelegate.getResource(entity, propertyPath, rewriteParamsForCTR(params));
Resource r = storeDelegate.getResource(entity, propertyPath, params);

if (r != null) {
InputStream unencryptedStream = null;
Expand All @@ -309,7 +308,7 @@ public Resource getResource(S entity, PropertyPath propertyPath, GetResourcePara
Assert.notNull(propertyPath, "propertyPath not set");
Assert.notNull(delegate, "store not set");

Resource r = delegate.getResource(entity, propertyPath, rewriteParamsForCTR(params));
Resource r = delegate.getResource(entity, propertyPath, params);

if (r != null) {
InputStream unencryptedStream = null;
Expand All @@ -330,44 +329,26 @@ public Resource getResource(S entity, PropertyPath propertyPath, GetResourcePara
return r;
}

private GetResourceParams rewriteParamsForCTR(GetResourceParams params) {
if (params.getRange() == null) {
return params;
}
int begin = Integer.parseInt(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
int blockBegin = begin - (begin % 16);
return GetResourceParams.builder().range("bytes=" + blockBegin + "-" + StringUtils.substringAfter(params.getRange(), "-")).build();
}

private org.springframework.content.commons.store.GetResourceParams rewriteParamsForCTR(org.springframework.content.commons.store.GetResourceParams params) {
if (params.getRange() == null) {
return params;
}
int begin = Integer.parseInt(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
int blockBegin = begin - (begin % 16);
return org.springframework.content.commons.store.GetResourceParams.builder().range("bytes=" + blockBegin + "-" + StringUtils.substringAfter(params.getRange(), "-")).build();
}

private int getOffset(Resource r, GetResourceParams params) {
private long getOffset(Resource r, GetResourceParams params) {
int offset = 0;

if (r instanceof RangeableResource == false)
return offset;
if (params.getRange() == null)
return offset;

return Integer.parseInt(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
return Long.parseUnsignedLong(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
}

private int getOffset(Resource r, org.springframework.content.commons.store.GetResourceParams params) {
int offset = 0;
private long getOffset(Resource r, org.springframework.content.commons.store.GetResourceParams params) {
long offset = 0;

if (r instanceof RangeableResource == false)
return offset;
if (params.getRange() == null)
return offset;

return Integer.parseInt(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
return Long.parseUnsignedLong(StringUtils.substringBetween(params.getRange(), "bytes=", "-"));
}

@Override
Expand Down Expand Up @@ -440,41 +421,6 @@ private void configure(Class<? extends Store> storeInterfaceClass) {
}
}

// CipherInputStream skip does not work. This wraps a cipherinputstream purely to override the skip with a
// working version
public class SkipInputStream extends FilterInputStream
{
private static final int MAX_SKIP_BUFFER_SIZE = 2048;

protected SkipInputStream (InputStream in)
{
super(in);
}

public long skip(long n)
throws IOException
{
long remaining = n;
int nr;

if (n <= 0) {
return 0;
}

int size = (int)Math.min(MAX_SKIP_BUFFER_SIZE, remaining);
byte[] skipBuffer = new byte[size];
while (remaining > 0) {
nr = in.read(skipBuffer, 0, (int)Math.min(size, remaining));
if (nr < 0) {
break;
}
remaining -= nr;
}

return n - remaining;
}
}

public class EncryptingContentStoreConfigurationImpl implements EncryptingContentStoreConfiguration {
private String encryptionKeyContentProperty;
private String keyring;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.springframework.content.encryption;

import org.springframework.beans.factory.InitializingBean;
import java.math.BigInteger;
import org.springframework.data.util.Pair;
import org.springframework.vault.core.VaultOperations;
import org.springframework.vault.core.VaultTransitOperations;
Expand All @@ -11,7 +11,6 @@
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
Expand Down Expand Up @@ -85,14 +84,14 @@ private SecretKey generateDataKey() {
return KEY_GENERATOR.generateKey();
}

private InputStream decryptInputStream(final SecretKeySpec secretKeySpec, byte[] nonce, int offset, InputStream is) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, IOException, InvalidAlgorithmParameterException {
private InputStream decryptInputStream(final SecretKeySpec secretKeySpec, byte[] nonce, long offset, InputStream is) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, IOException, InvalidAlgorithmParameterException {
Cipher cipher = Cipher.getInstance(transformation);

byte[] iv = new byte[128 / 8];
System.arraycopy(nonce, 0, iv, 0, nonce.length);

int AES_BLOCK_SIZE = 16;
int blockOffset = offset - (offset % AES_BLOCK_SIZE);
long blockOffset = offset - (offset % AES_BLOCK_SIZE);
final BigInteger ivBI = new BigInteger(1, iv);
final BigInteger ivForOffsetBI = ivBI.add(BigInteger.valueOf(blockOffset / AES_BLOCK_SIZE));
final byte[] ivForOffsetBA = ivForOffsetBI.toByteArray();
Expand All @@ -105,19 +104,17 @@ private InputStream decryptInputStream(final SecretKeySpec secretKeySpec, byte[]
ivForOffset = new IvParameterSpec(ivForOffsetBASized);
}

cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivForOffset);

CipherInputStream cis = new CipherInputStream(is, cipher);
// Skip the blocks that we are not going to decrypt.
// We advanced the IV manually to compensate for these skipped blocks,
// and the stream will be zero-prefixed to compensate on the other side as well.
// This saves encryption processing for all blocks that would be discarded anyways
is.skipNBytes(blockOffset);

InputStream inputStreamToReturn = cis;
if (offset == 0) {
inputStreamToReturn = new ZeroOffsetSkipInputStream(cis);
} else if (offset > 0) {
inputStreamToReturn = new OffsetSkipInputStream(cis, offset % AES_BLOCK_SIZE);
}
cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivForOffset);

return inputStreamToReturn;
return new OffsetInputStream(new SkippingInputStream(new CipherInputStream(is, cipher)), blockOffset);
}

private SecretKeySpec decryptKey(byte[] encryptedKey, String keyName) {
VaultTransitOperations transit = vaultOperations.opsForTransit();
String decryptedBase64Key = transit.decrypt(keyName, new String(encryptedKey));
Expand All @@ -127,7 +124,7 @@ private SecretKeySpec decryptKey(byte[] encryptedKey, String keyName) {
return key;
}

public InputStream decrypt(byte[] ecryptedContext, InputStream is, int offset, String keyName) {
public InputStream decrypt(byte[] ecryptedContext, InputStream is, long offset, String keyName) {

byte[] key = new byte[105];
System.arraycopy(ecryptedContext, 0, key, 0, 105);
Expand All @@ -148,12 +145,12 @@ public void rotate(String keyName) {
}

// CipherInputStream skip does not work. This wraps a cipherinputstream purely to override the skip with a
// working version. Used when backend Store has not already primed the input stream.
public class ZeroOffsetSkipInputStream extends FilterInputStream
// working version.
private static class SkippingInputStream extends FilterInputStream
{
private static final int MAX_SKIP_BUFFER_SIZE = 2048;

protected ZeroOffsetSkipInputStream(InputStream in)
protected SkippingInputStream(InputStream in)
{
super(in);
}
Expand Down Expand Up @@ -182,42 +179,64 @@ public long skip(long n)
}
}

// This wraps a cipherinputstream purely to override skip
//
// Used when a backend store has already satisfied a range request (this service will request a range to the nearest block).
// Skips then skips bytes between the beginning of the block and the start actual range that the client requested.
public class OffsetSkipInputStream extends FilterInputStream
{
private static final int MAX_SKIP_BUFFER_SIZE = 2048;
private final int offset;

protected OffsetSkipInputStream(InputStream in, int offset)
{
super(in);
this.offset = offset;
/**
* Adds a fixed amount of 0-bytes in front of the delegate {@link InputStream}
* <p>
*
* */
private static class OffsetInputStream extends InputStream {
private InputStream delegate;
private long offsetBytes;

public OffsetInputStream(InputStream delegate, long offsetBytes) {
this.delegate = delegate;
this.offsetBytes = offsetBytes;
}

public long skip(long n)
throws IOException
{
long remaining = offset;
int nr;
@Override
public long skip(long n) throws IOException {
if(n <= 0) {
return 0;
}
if(n <= offsetBytes) {
offsetBytes -= n;
return n;
}
if(offsetBytes > 0) {
n = n - offsetBytes; // Still skipping so many bytes from the offset
try {
return offsetBytes + delegate.skip(n);
} finally {
offsetBytes = 0; // Now the whole offset is consumed; skip to the delegate
}
}

if (n <= 0) {
return delegate.skip(n);
}

@Override
public int read() throws IOException {
if(offsetBytes > 0) {
offsetBytes--;
return 0;
}
return delegate.read();
}

int size = (int)Math.min(MAX_SKIP_BUFFER_SIZE, remaining);
byte[] skipBuffer = new byte[size];
while (remaining > 0) {
nr = in.read(skipBuffer, 0, (int)Math.min(size, remaining));
if (nr < 0) {
break;
}
remaining -= nr;
@Override
public int read(byte[] b, int off, int len) throws IOException {
if(offsetBytes > 0) {
return super.read(b, off, len);
}
return delegate.read(b, off, len);
}

return n - remaining;
@Override
public int available() throws IOException {
if(offsetBytes > 0) {
return (int)Math.max(offsetBytes, Integer.MAX_VALUE);
}
return delegate.available();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ private static class Singleton {
public static S3Client getAmazonS3Client() throws URISyntaxException {
return S3Client.builder()
.endpointOverride(new URI(Singleton.INSTANCE.getEndpointConfiguration(LocalStackContainer.Service.S3).getServiceEndpoint()))
.region(Region.US_EAST_1)
.credentialsProvider(new CrossAwsCredentialsProvider(Singleton.INSTANCE.getDefaultCredentialsProvider()))
.serviceConfiguration((bldr) -> bldr.pathStyleAccessEnabled(true).build())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,28 @@ public class EncryptionIT {
MockMvcResponse r =
given()
.header("accept", "text/plain")
.header("range", "bytes=16-27")
.header("range", "bytes=14-27")
.get("/files/" + f.getId() + "/content")
.then()
.statusCode(HttpStatus.SC_PARTIAL_CONTENT)
.assertThat()
.contentType(Matchers.startsWith("text/plain"))
.and().extract().response();

assertThat(r.asString(), is("e encryption"));
assertThat(r.asString(), is("ide encryption"));

r =
given()
.header("accept", "text/plain")
.header("range", "bytes=19-27")
.get("/files/" + f.getId() + "/content")
.then()
.statusCode(HttpStatus.SC_PARTIAL_CONTENT)
.assertThat()
.contentType(Matchers.startsWith("text/plain"))
.and().extract().response();

assertThat(r.asString(), is("ncryption"));
});
Context("when the keyring is rotated", () -> {
BeforeEach(() -> {
Expand Down
Loading
Loading