Encryption integration and test #5544

Open · wants to merge 20 commits into base: main
@@ -28,9 +28,11 @@
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DeleteFile;
import org.apache.iceberg.ManifestFile;
import org.apache.iceberg.exceptions.RuntimeIOException;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;

@@ -109,14 +111,28 @@ public InputFile newInputFile(ManifestFile manifest) {
}
}

/**
* @deprecated will be removed in 2.0.0. use {@link #newDecryptingInputFile(String, long,
* ByteBuffer)} instead.
*/
@Deprecated
public InputFile newDecryptingInputFile(String path, ByteBuffer buffer) {
- return em.decrypt(wrap(io.newInputFile(path), buffer));
throw new RuntimeIOException("Deprecated API. File decryption without length is not safe");
}

public InputFile newDecryptingInputFile(String path, long length, ByteBuffer buffer) {
- // TODO: is the length correct for the encrypted file? It may be the length of the plaintext
- // stream
- return em.decrypt(wrap(io.newInputFile(path, length), buffer));
Preconditions.checkArgument(
length > 0, "Cannot safely decrypt file %s because its length is not specified", path);

InputFile inputFile = io.newInputFile(path, length);

if (inputFile.getLength() != length) {
Contributor:
This should not call inputFile.getLength because it will either return the length passed in above and is useless (newInputFile(path, length)) or it will make a call to the underlying storage (needless HEAD request). Don't we already catch cases where the file has been truncated?

Hm. I don't see a test in TestGcmStreams so we should probably add one that validates truncated streams specifically.

ggershinsky (Contributor Author, Aug 1, 2024):

> I don't see a test in TestGcmStreams so we should probably add one that validates truncated streams specifically.

I'm not sure how. If a file is truncated by exactly one block, GCM Streams won't detect that. The upper layer (Avro or JSON readers, etc.) might or might not detect it; it's not guaranteed. That's why we added an explicit requirement in the GCM Stream spec to take the length from a trusted source. This is "out of band" from the spec's point of view, meaning we must make sure the length comes from parent metadata (and not from the file system) everywhere in Iceberg where we decrypt a stream.

However, FileIO.newInputFile(path, length) implementations are custom; some of them simply ignore the length parameter and then indeed send a file system request upon an InputFile.getLength() call. But we can prevent a security breach by making this check (if (inputFile.getLength() != length)). In most cases this won't trigger a HEAD request, because most FileIO implementations don't ignore the length parameter in newInputFile(path, length) and store it instead. For the few that do ignore it, we need to verify the file length wasn't truncated in the file system.
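[Editor's illustration of the point above — a standalone sketch using plain javax.crypto rather than Iceberg's GCM stream classes; class and variable names are made up for the example. Each block of a GCM stream carries its own authentication tag, so dropping a whole trailing block leaves every remaining block valid, while truncating inside a block breaks that block's tag.]

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Arrays;
import javax.crypto.AEADBadTagException;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;

public class GcmTruncationDemo {
  public static void main(String[] args) throws Exception {
    SecretKey key = KeyGenerator.getInstance("AES").generateKey();
    SecureRandom random = new SecureRandom();

    // Model a two-block GCM stream: each block is encrypted independently,
    // with its own nonce, and carries its own 16-byte authentication tag.
    byte[][] nonces = new byte[2][12];
    byte[][] blocks = new byte[2][];
    String[] plaintext = {"block-0", "block-1"};
    for (int i = 0; i < 2; i++) {
      random.nextBytes(nonces[i]);
      Cipher enc = Cipher.getInstance("AES/GCM/NoPadding");
      enc.init(Cipher.ENCRYPT_MODE, key, new GCMParameterSpec(128, nonces[i]));
      blocks[i] = enc.doFinal(plaintext[i].getBytes(StandardCharsets.UTF_8));
    }

    // Truncation at a block boundary: drop block-1 entirely. The surviving
    // block still authenticates, so decryption succeeds silently; only a
    // length taken from trusted metadata can catch this.
    Cipher dec = Cipher.getInstance("AES/GCM/NoPadding");
    dec.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(128, nonces[0]));
    System.out.println(new String(dec.doFinal(blocks[0]), StandardCharsets.UTF_8));

    // Truncation inside a block: the tag check fails loudly.
    byte[] damaged = Arrays.copyOf(blocks[1], blocks[1].length - 1);
    dec = Cipher.getInstance("AES/GCM/NoPadding");
    dec.init(Cipher.DECRYPT_MODE, key, new GCMParameterSpec(128, nonces[1]));
    try {
      dec.doFinal(damaged);
    } catch (AEADBadTagException e) {
      System.out.println("mid-block truncation detected: " + e.getMessage());
    }
  }
}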

throw new RuntimeIOException(
"Cannot safely decrypt a file because its size was changed by FileIO %s from %s to %s",
io.getClass(), length, inputFile.getLength());
}

return em.decrypt(wrap(inputFile, buffer));
}
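[Editor's note: a hypothetical call site for the method above. The enclosing class name EncryptingFileIO and this helper are assumptions — only the newDecryptingInputFile signature comes from this diff. It shows the length flowing from manifest metadata rather than from the file system.]

import org.apache.iceberg.DataFile;
import org.apache.iceberg.io.InputFile;

class DecryptCallSiteSketch {
  // Hypothetical helper: the length passed in comes from the manifest entry
  // (trusted table metadata), never from a file listing or a HEAD request.
  static InputFile openForRead(EncryptingFileIO encryptingIO, DataFile file) {
    return encryptingIO.newDecryptingInputFile(
        file.path().toString(), file.fileSizeInBytes(), file.keyMetadata());
  }
}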

@Override
@@ -157,7 +173,7 @@ private static SimpleEncryptedInputFile wrap(InputFile encryptedInputFile, ByteB
}

private static EncryptionKeyMetadata toKeyMetadata(ByteBuffer buffer) {
- return buffer != null ? new SimpleKeyMetadata(buffer) : EmptyKeyMetadata.get();
return buffer != null ? new SimpleKeyMetadata(buffer) : EncryptionKeyMetadata.empty();
}

private static class SimpleEncryptedInputFile implements EncryptedInputFile {
@@ -198,22 +214,4 @@ public EncryptionKeyMetadata copy() {
return new SimpleKeyMetadata(metadataBuffer.duplicate());
}
}

- private static class EmptyKeyMetadata implements EncryptionKeyMetadata {
- private static final EmptyKeyMetadata INSTANCE = new EmptyKeyMetadata();
-
- private static EmptyKeyMetadata get() {
- return INSTANCE;
- }
-
- @Override
- public ByteBuffer buffer() {
- return null;
- }
-
- @Override
- public EncryptionKeyMetadata copy() {
- return this;
- }
- }
}
10 changes: 10 additions & 0 deletions core/src/main/java/org/apache/iceberg/TableMetadata.java
@@ -29,6 +29,7 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.iceberg.encryption.KeyEncryptionKey;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
import org.apache.iceberg.relocated.com.google.common.base.Objects;
@@ -260,6 +261,7 @@ public String toString() {
private final List<PartitionStatisticsFile> partitionStatisticsFiles;
private final List<MetadataUpdate> changes;
private SerializableSupplier<List<Snapshot>> snapshotsSupplier;
private Map<String, KeyEncryptionKey> kekCache;
private volatile List<Snapshot> snapshots;
private volatile Map<Long, Snapshot> snapshotsById;
private volatile Map<String, SnapshotRef> refs;
@@ -512,6 +514,14 @@ public List<Snapshot> snapshots() {
return snapshots;
}

public void setKekCache(Map<String, KeyEncryptionKey> kekCache) {
ggershinsky (Contributor Author):

I'll change this to withKekCache, making the implementation similar to withUUID (via new Builder).

this.kekCache = kekCache;
}

public Map<String, KeyEncryptionKey> kekCache() {
return kekCache;
}
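[Editor's sketch of the withKekCache direction described in the comment above — hypothetical, not part of this diff; Builder.setKekCache is an assumed hook mirroring how withUUID goes through a new Builder.]

// Hypothetical follow-up: produce a new TableMetadata instead of mutating
// this one, mirroring the withUUID()/Builder pattern.
public TableMetadata withKekCache(Map<String, KeyEncryptionKey> newKekCache) {
  return new Builder(this).setKekCache(newKekCache).build(); // assumed Builder hook
}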

private synchronized void ensureSnapshotsLoaded() {
if (!snapshotsLoaded) {
List<Snapshot> loadedSnapshots = Lists.newArrayList(snapshotsSupplier.get());
98 changes: 72 additions & 26 deletions core/src/main/java/org/apache/iceberg/TableMetadataParser.java
@@ -34,6 +34,8 @@
import java.util.zip.GZIPOutputStream;
import org.apache.iceberg.TableMetadata.MetadataLogEntry;
import org.apache.iceberg.TableMetadata.SnapshotLogEntry;
import org.apache.iceberg.encryption.EncryptionUtil;
import org.apache.iceberg.encryption.KeyEncryptionKey;
import org.apache.iceberg.exceptions.RuntimeIOException;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
@@ -42,6 +44,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.util.JsonUtil;

public class TableMetadataParser {
@@ -104,6 +107,9 @@ private TableMetadataParser() {}
static final String REFS = "refs";
static final String SNAPSHOTS = "snapshots";
static final String SNAPSHOT_ID = "snapshot-id";
static final String KEK_CACHE = "kek-cache";
static final String KEK_ID = "kek-id";
static final String KEK_WRAP = "kek-wrap";
static final String TIMESTAMP_MS = "timestamp-ms";
static final String SNAPSHOT_LOG = "snapshot-log";
static final String METADATA_FILE = "metadata-file";
@@ -220,6 +226,18 @@ public static void toJson(TableMetadata metadata, JsonGenerator generator) throw

toJson(metadata.refs(), generator);

if (metadata.kekCache() != null && !metadata.kekCache().isEmpty()) {
generator.writeArrayFieldStart(KEK_CACHE);
for (Map.Entry<String, KeyEncryptionKey> entry : metadata.kekCache().entrySet()) {
generator.writeStartObject();
generator.writeStringField(KEK_ID, entry.getKey());
generator.writeStringField(KEK_WRAP, entry.getValue().wrappedKey());
generator.writeNumberField(TIMESTAMP_MS, entry.getValue().timestamp());
generator.writeEndObject();
}
generator.writeEndArray();
}
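[Editor's note: for reference, the writer above renders the cache in the table metadata JSON roughly as follows; the key ID, wrapped payload, and timestamp are illustrative values.]

"kek-cache" : [ {
  "kek-id" : "kek-42",
  "kek-wrap" : "bXktd3JhcHBlZC1rZXk=",
  "timestamp-ms" : 1722470400000
} ]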

generator.writeArrayFieldStart(SNAPSHOTS);
for (Snapshot snapshot : metadata.snapshots()) {
SnapshotParser.toJson(snapshot, generator);
@@ -277,7 +295,11 @@ public static TableMetadata read(FileIO io, InputFile file) {
Codec codec = Codec.fromFileName(file.location());
try (InputStream is =
codec == Codec.GZIP ? new GZIPInputStream(file.newStream()) : file.newStream()) {
- return fromJson(file, JsonUtil.mapper().readValue(is, JsonNode.class));
TableMetadata tableMetadata = fromJson(file, JsonUtil.mapper().readValue(is, JsonNode.class));
if (tableMetadata.kekCache() != null) {
EncryptionUtil.getKekCacheFromMetadata(io, tableMetadata.kekCache());
}
return tableMetadata;
} catch (IOException e) {
throw new RuntimeIOException(e, "Failed to read file: %s", file);
}
@@ -466,6 +488,23 @@ public static TableMetadata fromJson(String metadataLocation, JsonNode node) {
refs = ImmutableMap.of();
}

Map<String, KeyEncryptionKey> kekCache = null;
if (node.has(KEK_CACHE)) {
kekCache = Maps.newHashMap();
Iterator<JsonNode> cacheIterator = node.get(KEK_CACHE).elements();
while (cacheIterator.hasNext()) {
JsonNode entryNode = cacheIterator.next();
String kekID = JsonUtil.getString(KEK_ID, entryNode);
kekCache.put(
kekID,
new KeyEncryptionKey(
kekID,
null, // key will be unwrapped later
JsonUtil.getString(KEK_WRAP, entryNode),
JsonUtil.getLong(TIMESTAMP_MS, entryNode)));
}
}

List<Snapshot> snapshots;
if (node.has(SNAPSHOTS)) {
JsonNode snapshotArray = JsonUtil.get(SNAPSHOTS, node);
@@ -519,31 +558,38 @@ public static TableMetadata fromJson(String metadataLocation, JsonNode node) {
}
}

- return new TableMetadata(
- metadataLocation,
- formatVersion,
- uuid,
- location,
- lastSequenceNumber,
- lastUpdatedMillis,
- lastAssignedColumnId,
- currentSchemaId,
- schemas,
- defaultSpecId,
- specs,
- lastAssignedPartitionId,
- defaultSortOrderId,
- sortOrders,
- properties,
- currentSnapshotId,
- snapshots,
- null,
- entries.build(),
- metadataEntries.build(),
- refs,
- statisticsFiles,
- partitionStatisticsFiles,
- ImmutableList.of() /* no changes from the file */);
TableMetadata result =
new TableMetadata(
metadataLocation,
formatVersion,
uuid,
location,
lastSequenceNumber,
lastUpdatedMillis,
lastAssignedColumnId,
currentSchemaId,
schemas,
defaultSpecId,
specs,
lastAssignedPartitionId,
defaultSortOrderId,
sortOrders,
properties,
currentSnapshotId,
snapshots,
null,
entries.build(),
metadataEntries.build(),
refs,
statisticsFiles,
partitionStatisticsFiles,
ImmutableList.of()); /* no changes from the file */

if (kekCache != null) {
result.setKekCache(kekCache);
}

return result;
}

private static Map<String, SnapshotRef> refsFromJson(JsonNode refMap) {
@@ -117,6 +117,10 @@ public void flush() throws IOException {

@Override
public void close() throws IOException {
if (isClosed) {
return;
}

if (!isHeaderWritten) {
writeHeader();
}
@@ -24,7 +24,7 @@
import java.util.Map;

/** A minimum client interface to connect to a key management service (KMS). */
- interface KeyManagementClient extends Serializable, Closeable {
public interface KeyManagementClient extends Serializable, Closeable {

/**
* Wrap a secret key, using a wrapping/master key which is stored in KMS and referenced by an ID.
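[Editor's note: making this interface public lets KMS integrations live outside the package. Below is a minimal in-memory sketch; the wrapKey/unwrapKey/initialize signatures are assumptions based on the interface's wrap/unwrap contract shown here, and the XOR "cipher" is a placeholder for a real KMS call, not cryptography.]

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import org.apache.iceberg.encryption.KeyManagementClient;

// Toy in-memory client; XOR is NOT encryption and is for test wiring only.
public class InMemoryKmsClient implements KeyManagementClient {
  private byte[] masterKey;

  @Override
  public void initialize(Map<String, String> properties) {
    // "test.master-key" is a hypothetical property, for illustration only.
    this.masterKey =
        properties
            .getOrDefault("test.master-key", "0123456789abcdef")
            .getBytes(StandardCharsets.UTF_8);
  }

  @Override
  public ByteBuffer wrapKey(ByteBuffer key, String wrappingKeyId) {
    return xor(key); // stand-in for a real KMS wrap call
  }

  @Override
  public ByteBuffer unwrapKey(ByteBuffer wrappedKey, String wrappingKeyId) {
    return xor(wrappedKey); // XOR is its own inverse
  }

  private ByteBuffer xor(ByteBuffer in) {
    ByteBuffer src = in.duplicate();
    byte[] out = new byte[src.remaining()];
    for (int i = 0; i < out.length; i++) {
      out[i] = (byte) (src.get() ^ masterKey[i % masterKey.length]);
    }
    return ByteBuffer.wrap(out);
  }

  @Override
  public void close() {}
}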
@@ -43,6 +43,8 @@
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.catalog.SupportsNamespaces;
import org.apache.iceberg.catalog.TableIdentifier;
import org.apache.iceberg.encryption.EncryptionUtil;
import org.apache.iceberg.encryption.KeyManagementClient;
import org.apache.iceberg.exceptions.NamespaceNotEmptyException;
import org.apache.iceberg.exceptions.NoSuchNamespaceException;
import org.apache.iceberg.exceptions.NoSuchTableException;
@@ -56,6 +58,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.util.LocationUtil;
import org.apache.iceberg.util.PropertyUtil;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -76,9 +79,11 @@ public class HiveCatalog extends BaseMetastoreCatalog implements SupportsNamespa
private String name;
private Configuration conf;
private FileIO fileIO;
private KeyManagementClient keyManagementClient;
private ClientPool<IMetaStoreClient, TException> clients;
private boolean listAllTables = false;
private Map<String, String> catalogProperties;
private long writerKekTimeout;

public HiveCatalog() {}

@@ -110,6 +115,15 @@ public void initialize(String inputName, Map<String, String> properties) {
? new HadoopFileIO(conf)
: CatalogUtil.loadFileIO(fileIOImpl, properties, conf);

if (catalogProperties.containsKey(CatalogProperties.ENCRYPTION_KMS_IMPL)) {
this.keyManagementClient = EncryptionUtil.createKmsClient(properties);
this.writerKekTimeout =
Contributor:
  • Wanna keep this PR active
  • (nit) Maybe it makes sense to add a unit to the variable name, i.e. writerKekTimeoutMs? For example, in LockManagers, we have
      this.heartbeatIntervalMs =
          PropertyUtil.propertyAsLong(
              properties,
              CatalogProperties.LOCK_HEARTBEAT_INTERVAL_MS,
              CatalogProperties.LOCK_HEARTBEAT_INTERVAL_MS_DEFAULT);

PropertyUtil.propertyAsLong(
properties,
CatalogProperties.WRITER_KEK_TIMEOUT_MS,
CatalogProperties.WRITER_KEK_TIMEOUT_MS_DEFAULT);
}

this.clients = new CachedClientPool(conf, properties);
}
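[Editor's note: a hypothetical catalog setup exercising this wiring. The KMS impl class and property values are placeholders; the two CatalogProperties keys are the ones referenced in the diff above.]

import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.CatalogProperties;
import org.apache.iceberg.hive.HiveCatalog;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;

public class EncryptedCatalogExample {
  public static void main(String[] args) {
    // Property keys come from the diff above; values are placeholders.
    Map<String, String> props =
        ImmutableMap.of(
            CatalogProperties.ENCRYPTION_KMS_IMPL, "com.example.InMemoryKmsClient",
            CatalogProperties.WRITER_KEK_TIMEOUT_MS, "3600000"); // writer KEK timeout, in ms

    HiveCatalog catalog = new HiveCatalog();
    catalog.setConf(new Configuration());
    catalog.initialize("hive_prod", props);
  }
}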

@@ -512,7 +526,8 @@ private boolean isValidateNamespace(Namespace namespace) {
public TableOperations newTableOps(TableIdentifier tableIdentifier) {
String dbName = tableIdentifier.namespace().level(0);
String tableName = tableIdentifier.name();
- return new HiveTableOperations(conf, clients, fileIO, name, dbName, tableName);
return new HiveTableOperations(
conf, clients, fileIO, keyManagementClient, name, dbName, tableName, writerKekTimeout);
}

@Override