Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick 8321599: Data loss in AVX3 Base64 decoding #38

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/hotspot/cpu/x86/stubGenerator_x86_64.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2003, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -2324,7 +2324,7 @@ address StubGenerator::generate_base64_decodeBlock() {
const Register isURL = c_rarg5;// Base64 or URL character set
__ movl(isMIME, Address(rbp, 2 * wordSize));
#else
const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64
const Address dp_mem(rbp, 6 * wordSize); // length is on stack on Win64
const Address isURL_mem(rbp, 7 * wordSize);
const Register isURL = r10; // pick the volatile windows register
const Register dp = r12;
Expand Down Expand Up @@ -2546,10 +2546,12 @@ address StubGenerator::generate_base64_decodeBlock() {
// output_size in r13

// Strip pad characters, if any, and adjust length and mask
__ addq(length, start_offset);
__ cmpb(Address(source, length, Address::times_1, -1), '=');
__ jcc(Assembler::equal, L_padding);

__ BIND(L_donePadding);
__ subq(length, start_offset);

// Output size is (64 - output_size), output mask is (all 1s >> output_size).
__ kmovql(input_mask, rax);
Expand Down
121 changes: 120 additions & 1 deletion test/hotspot/jtreg/compiler/intrinsics/base64/TestBase64.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -46,10 +46,13 @@
import java.util.Base64;
import java.util.Base64.Decoder;
import java.util.Base64.Encoder;
import java.util.HexFormat;
import java.util.Objects;
import java.util.Random;
import java.util.Arrays;

import static java.lang.String.format;

import compiler.whitebox.CompilerWhiteBoxTest;
import jdk.test.whitebox.code.Compiler;
import jtreg.SkippedException;
Expand All @@ -69,6 +72,8 @@ public static void main(String[] args) throws Exception {

warmup();

length_checks();

test0(FileType.ASCII, Base64Type.BASIC, Base64.getEncoder(), Base64.getDecoder(),"plain.txt", "baseEncode.txt", iters);
test0(FileType.ASCII, Base64Type.URLSAFE, Base64.getUrlEncoder(), Base64.getUrlDecoder(),"plain.txt", "urlEncode.txt", iters);
test0(FileType.ASCII, Base64Type.MIME, Base64.getMimeEncoder(), Base64.getMimeDecoder(),"plain.txt", "mimeEncode.txt", iters);
Expand Down Expand Up @@ -302,4 +307,118 @@ private static final byte getBadBase64Char(Base64Type b64Type) {
throw new InternalError("Internal test error: getBadBase64Char called with unknown Base64Type value");
}
}

// Number of plain/encoded slice pairs exercised by the length checks.
static final int POSITIONS = 30_000;
// Starting slice length; slice i has length (BASE_LENGTH + i) % 2048 (see static initializer).
static final int BASE_LENGTH = 256;
// Formats bytes as upper-case hex pairs separated by single spaces, for failure messages.
static final HexFormat HEX_FORMAT = HexFormat.of().withUpperCase().withDelimiter(" ");

// plainOffsets[i] is the start of slice i inside plainBytes; the extra last
// entry marks the end of the final slice, so slice i spans
// [plainOffsets[i], plainOffsets[i + 1]).
static int[] plainOffsets = new int[POSITIONS + 1];
// Backing storage for all plain (un-encoded) slices; allocated in the static initializer.
static byte[] plainBytes;
// base64Offsets[i] is the start of encoded slice i inside base64Bytes; the
// extra last entry marks the end of the final encoded slice.
static int[] base64Offsets = new int[POSITIONS + 1];
// Backing storage for the Base64 encoding of every slice; allocated in the static initializer.
static byte[] base64Bytes;

static {
    // Pass 1: assign every slice its start offset in the plain array. The
    // slice lengths cycle through (BASE_LENGTH + index) % 2048, and the last
    // table entry doubles as the end marker of the final slice.
    final int offsetSlots = plainOffsets.length;
    int plainTotal = 0;
    for (int slot = 0; slot < offsetSlots; slot++) {
        plainOffsets[slot] = plainTotal;
        plainTotal += (BASE_LENGTH + slot) % 2048;
    }

    // Fill the plain array with a repeating 0x00..0xFF byte pattern.
    plainBytes = new byte[plainTotal];
    for (int idx = 0; idx < plainTotal; idx++) {
        plainBytes[idx] = (byte) idx;
    }

    final Base64.Encoder encoder = Base64.getEncoder();
    final ByteBuffer plainView = ByteBuffer.wrap(plainBytes);

    // Pass 2: encode each slice once just to learn its encoded size, and
    // accumulate the start offset of every encoded slice.
    int encodedTotal = 0;
    for (int pos = 0; pos < POSITIONS; pos++) {
        base64Offsets[pos] = encodedTotal;
        final int start = plainOffsets[pos];
        final int size = plainOffsets[pos + 1] - start;
        encodedTotal += encoder.encode(plainView.slice(start, size)).remaining();
    }
    base64Offsets[base64Offsets.length - 1] = encodedTotal;

    // Pass 3: encode each slice again, verify the size agrees with the
    // offset table built above, and copy the result into the shared array.
    base64Bytes = new byte[encodedTotal];
    for (int pos = 0; pos < POSITIONS; pos++) {
        final int start = plainOffsets[pos];
        final ByteBuffer encoded = encoder.encode(plainView.slice(start, plainOffsets[pos + 1] - start));
        final int dest = base64Offsets[pos];
        final int expected = base64Offsets[pos + 1] - dest;
        final int actual = encoded.remaining();
        if (actual != expected) {
            throw new IllegalStateException(format("Unexpected length: %s <> %s", actual, expected));
        }
        encoded.get(base64Bytes, dest, expected);
    }
}

// Entry point for the decode-length regression checks; invoked from main().
// Runs decodeAndCheck() and then encodeDecode(). Each of those throws
// IllegalStateException on the first mismatch, so the success message below
// is only printed when every check passed.
public static void length_checks() {
decodeAndCheck();
encodeDecode();
System.out.println("Test complete, no invalid decodes detected");
}

// Regression check for JDK-8321599. Every encoded slice is handed to the
// decoder as a ByteBuffer wrapping the shared array at a non-zero offset, so
// that decode() takes its base + offset path. The bug was triggered by
// padding characters located in the backing array *before* that offset and
// made the decoded output 1 or 2 bytes too short.
// Throws IllegalStateException describing the first slice that decodes wrongly.
static void decodeAndCheck() {
    final Base64.Decoder decoder = Base64.getDecoder();
    for (int pos = 0; pos < POSITIONS; pos++) {
        final ByteBuffer source = base64BytesAtPosition(pos);
        final ByteBuffer actual = decoder.decode(source);

        if (actual.equals(plainBytesAtPosition(pos))) {
            continue;
        }

        // Build the diagnostic: encoded input, expected bytes, decoded bytes.
        final String encodedText = base64StringAtPosition(pos);
        final String expectedHex = plainHexStringAtPosition(pos);
        final int base = actual.arrayOffset();
        final String actualHex =
                HEX_FORMAT.formatHex(actual.array(), base + actual.position(), base + actual.limit());
        throw new IllegalStateException(
                format("Mismatch for %s\n\nExpected:\n%s\n\nActual:\n%s", encodedText, expectedHex, actualHex));
    }
}

// Encode ASCII strings of every length from 1 to 512 characters, decode them
// again, and ensure the round trip reproduces the input exactly (length and
// contents). The source text contains '=' so that pad-like characters show up
// in the plain data while the encoded form cycles through 0, 1 and 2 trailing
// pad characters as the length grows.
// The charset is given explicitly so the byte lengths under test do not
// depend on the platform default encoding.
// Throws IllegalStateException on the first string that does not round-trip.
static void encodeDecode() {
    final String allAs = "A(=)".repeat(128); // 4 * 128 = 512 characters
    final Base64.Encoder encoder = Base64.getEncoder();
    final Base64.Decoder decoder = Base64.getDecoder();
    for (int i = 1; i <= allAs.length(); i++) {
        final String expected = allAs.substring(0, i);
        final String encStr = encoder.encodeToString(expected.getBytes(StandardCharsets.UTF_8));
        final String decStr = new String(decoder.decode(encStr), StandardCharsets.UTF_8);

        // String equality already implies equal length.
        if (!expected.equals(decStr)) {
            throw new IllegalStateException(format("Mismatch: Expected: %s\n Actual: %s\n", expected, decStr));
        }
    }
}

// Returns a ByteBuffer over the shared plain array whose position/limit
// window covers exactly the plain slice with the given index.
static ByteBuffer plainBytesAtPosition(int position) {
    final int start = plainOffsets[position];
    final int end = plainOffsets[position + 1];
    return ByteBuffer.wrap(plainBytes, start, end - start);
}

// Renders the plain slice with the given index as upper-case,
// space-delimited hex for use in failure messages.
static String plainHexStringAtPosition(int position) {
    final int start = plainOffsets[position];
    final int end = plainOffsets[position + 1];
    return HEX_FORMAT.formatHex(plainBytes, start, end);
}

// Returns the Base64 text of the encoded slice with the given index,
// decoded from the shared encoded array as UTF-8.
static String base64StringAtPosition(int position) {
    final int start = base64Offsets[position];
    final int end = base64Offsets[position + 1];
    return new String(base64Bytes, start, end - start, StandardCharsets.UTF_8);
}

// Returns a ByteBuffer over the shared encoded array whose position/limit
// window covers exactly the encoded slice with the given index. The buffer
// keeps the full backing array, so its window starts at a non-zero offset
// for every slice after the first.
static ByteBuffer base64BytesAtPosition(int position) {
    final int start = base64Offsets[position];
    final int end = base64Offsets[position + 1];
    return ByteBuffer.wrap(base64Bytes, start, end - start);
}
}