Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance validation of deserialization and reporting of errors #345

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ namespace DurableTask.Netherite
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;

/// <summary>
/// Represents a key used to identify <see cref="TrackedObject"/> instances.
/// </summary>
struct TrackedObjectKey
struct TrackedObjectKey
{
public TrackedObjectType ObjectType;
public string InstanceId;
Expand Down Expand Up @@ -70,12 +71,12 @@ public static int Compare(ref TrackedObjectKey key1, ref TrackedObjectKey key2)

public class Comparer : IComparer<TrackedObjectKey>
{
public int Compare(TrackedObjectKey x, TrackedObjectKey y) => TrackedObjectKey.Compare(ref x, ref y);
public int Compare(TrackedObjectKey x, TrackedObjectKey y) => TrackedObjectKey.Compare(ref x, ref y);
}

public override int GetHashCode()
{
return (this.InstanceId?.GetHashCode() ?? 0) + (int) this.ObjectType;
return (this.InstanceId?.GetHashCode() ?? 0) + (int)this.ObjectType;
}

public override bool Equals(object obj)
Expand All @@ -97,55 +98,37 @@ public override bool Equals(object obj)

// convenient constructors for non-singletons

public static TrackedObjectKey History(string id) => new TrackedObjectKey()
{
ObjectType = TrackedObjectType.History,
InstanceId = id,
};
public static TrackedObjectKey Instance(string id) => new TrackedObjectKey()
{
ObjectType = TrackedObjectType.Instance,
InstanceId = id,
};
public static TrackedObjectKey History(string id) => new TrackedObjectKey()
{
ObjectType = TrackedObjectType.History,
InstanceId = id,
};
public static TrackedObjectKey Instance(string id) => new TrackedObjectKey()
{
ObjectType = TrackedObjectType.Instance,
InstanceId = id,
};

public static TrackedObject Factory(TrackedObjectKey key) => key.ObjectType switch
{
TrackedObjectType.Activities => new ActivitiesState(),
TrackedObjectType.Dedup => new DedupState(),
TrackedObjectType.Outbox => new OutboxState(),
TrackedObjectType.Reassembly => new ReassemblyState(),
TrackedObjectType.Sessions => new SessionsState(),
TrackedObjectType.Timers => new TimersState(),
TrackedObjectType.Prefetch => new PrefetchState(),
TrackedObjectType.Queries => new QueriesState(),
TrackedObjectType.Stats => new StatsState(),
TrackedObjectType.History => new HistoryState() { InstanceId = key.InstanceId },
TrackedObjectType.Instance => new InstanceState() { InstanceId = key.InstanceId },
_ => throw new ArgumentException("invalid key", nameof(key)),
};

public static IEnumerable<TrackedObjectKey> GetSingletons()
{
TrackedObjectType.Activities => new ActivitiesState(),
TrackedObjectType.Dedup => new DedupState(),
TrackedObjectType.Outbox => new OutboxState(),
TrackedObjectType.Reassembly => new ReassemblyState(),
TrackedObjectType.Sessions => new SessionsState(),
TrackedObjectType.Timers => new TimersState(),
TrackedObjectType.Prefetch => new PrefetchState(),
TrackedObjectType.Queries => new QueriesState(),
TrackedObjectType.Stats => new StatsState(),
TrackedObjectType.History => new HistoryState() { InstanceId = key.InstanceId },
TrackedObjectType.Instance => new InstanceState() { InstanceId = key.InstanceId },
_ => throw new ArgumentException("invalid key", nameof(key)),
};

public static IEnumerable<TrackedObjectKey> GetSingletons()
=> Enum.GetValues(typeof(TrackedObjectType)).Cast<TrackedObjectType>().Where(t => IsSingletonType(t)).Select(t => new TrackedObjectKey() { ObjectType = t });

public override string ToString()
public override string ToString()
=> this.InstanceId == null ? this.ObjectType.ToString() : $"{this.ObjectType}-{this.InstanceId}";

public void Deserialize(BinaryReader reader)
{
this.ObjectType = (TrackedObjectType) reader.ReadByte();
if (!IsSingletonType(this.ObjectType))
{
this.InstanceId = reader.ReadString();
}
}

public void Serialize(BinaryWriter writer)
{
writer.Write((byte) this.ObjectType);
if (!IsSingletonType(this.ObjectType))
{
writer.Write(this.InstanceId);
}
}
}
}
123 changes: 95 additions & 28 deletions src/DurableTask.Netherite/StorageLayer/Faster/FasterKV.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace DurableTask.Netherite.Faster
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Text;
using System.Threading;
using System.Threading.Channels;
Expand Down Expand Up @@ -97,8 +98,8 @@ public FasterKV(Partition partition, BlobManager blobManager, MemoryTracker memo
blobManager.StoreCheckpointSettings,
new SerializerSettings<Key, Value>
{
keySerializer = () => new Key.Serializer(),
valueSerializer = () => new Value.Serializer(this.StoreStats, partition.TraceHelper, this.cacheDebugger),
keySerializer = () => new Key.Serializer(partition.ErrorHandler),
valueSerializer = () => new Value.Serializer(this.StoreStats, partition.TraceHelper, this.cacheDebugger, partition.ErrorHandler),
});

this.cacheTracker = memoryTracker.NewCacheTracker(this, (int) partition.PartitionId, this.cacheDebugger);
Expand Down Expand Up @@ -1594,13 +1595,51 @@ public bool Equals(ref Key k1, ref Key k2)

public class Serializer : BinaryObjectSerializer<Key>
{
readonly IPartitionErrorHandler errorHandler;

public Serializer(IPartitionErrorHandler errorHandler)
{
this.errorHandler = errorHandler;
}

public override void Deserialize(out Key obj)
{
obj = new Key();
obj.Val.Deserialize(this.reader);
try
{
if (!this.errorHandler.IsTerminated) // skip deserialization if the partition is already terminated - to speed up cancellation and to avoid repeated errors
{
// first, determine the object type
var objectType = (TrackedObjectKey.TrackedObjectType)this.reader.ReadByte();
if (objectType != TrackedObjectKey.TrackedObjectType.History
&& objectType != TrackedObjectKey.TrackedObjectType.Instance)
{
throw new SerializationException("invalid object type field");
}
var instanceId = this.reader.ReadString();
obj = new TrackedObjectKey(objectType, instanceId);
return;
}
}
catch (Exception ex)
{
this.errorHandler.HandleError("FasterKV.Key.Serializer", "could not deserialize key - possible data corruption", ex, true, !this.errorHandler.IsTerminated);
}

obj = default;
}

public override void Serialize(ref Key obj) => obj.Val.Serialize(this.writer);
public override void Serialize(ref Key obj)
{
try
{
this.writer.Write((byte)obj.Val.ObjectType);
this.writer.Write(obj.Val.InstanceId);
}
catch (Exception ex)
{
this.errorHandler.HandleError("FasterKV.Key.Serializer", "could not serialize key", ex, true, false);
}
}
}
}

Expand All @@ -1624,49 +1663,77 @@ public class Serializer : BinaryObjectSerializer<Value>
readonly StoreStatistics storeStats;
readonly PartitionTraceHelper traceHelper;
readonly CacheDebugger cacheDebugger;
readonly IPartitionErrorHandler errorHandler;

public Serializer(StoreStatistics storeStats, PartitionTraceHelper traceHelper, CacheDebugger cacheDebugger)
public Serializer(StoreStatistics storeStats, PartitionTraceHelper traceHelper, CacheDebugger cacheDebugger, IPartitionErrorHandler errorHandler)
{
this.storeStats = storeStats;
this.traceHelper = traceHelper;
this.cacheDebugger = cacheDebugger;
this.errorHandler = errorHandler;
}

public override void Deserialize(out Value obj)
{
int version = this.reader.ReadInt32();
int count = this.reader.ReadInt32();
byte[] bytes = this.reader.ReadBytes(count); // lazy deserialization - keep as byte array until used
obj = new Value { Val = bytes, Version = version};
if (this.cacheDebugger != null)
try
{
var trackedObject = DurableTask.Netherite.Serializer.DeserializeTrackedObject(bytes);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.DeserializeBytes, version, null, 0);
if (!this.errorHandler.IsTerminated) // skip deserialization if the partition is already terminated - to speed up cancellation and to avoid repeated errors
{
int version = this.reader.ReadInt32();
int count = this.reader.ReadInt32();
byte[] bytes = this.reader.ReadBytes(count); // lazy deserialization - keep as byte array until used

if (bytes.Length != count)
{
throw new EndOfStreamException($"trying to read {count} bytes but only found {bytes.Length}");
}

obj = new Value { Val = bytes, Version = version};
if (this.cacheDebugger != null)
{
var trackedObject = DurableTask.Netherite.Serializer.DeserializeTrackedObject(bytes);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.DeserializeBytes, version, null, 0);
}

return;
}
}
catch (Exception ex)
{
this.errorHandler.HandleError("FasterKV.Value.Serializer", "could not deserialize value - possible data corruption", ex, true, !this.errorHandler.IsTerminated);
}
obj = default;
}

public override void Serialize(ref Value obj)
{
this.writer.Write(obj.Version);
if (obj.Val is byte[] serialized)
try
{
// We did already serialize this object on the last CopyUpdate. So we can just use the byte array.
this.writer.Write(serialized.Length);
this.writer.Write(serialized);
if (this.cacheDebugger != null)
this.writer.Write(obj.Version);
if (obj.Val is byte[] serialized)
{
// We did already serialize this object on the last CopyUpdate. So we can just use the byte array.
this.writer.Write(serialized.Length);
this.writer.Write(serialized);
if (this.cacheDebugger != null)
{
var trackedObject = DurableTask.Netherite.Serializer.DeserializeTrackedObject(serialized);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.SerializeBytes, obj.Version, null, 0);
}
}
else
{
var trackedObject = DurableTask.Netherite.Serializer.DeserializeTrackedObject(serialized);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.SerializeBytes, obj.Version, null, 0);
TrackedObject trackedObject = (TrackedObject) obj.Val;
var bytes = DurableTask.Netherite.Serializer.SerializeTrackedObject(trackedObject);
this.storeStats.Serialize++;
this.writer.Write(bytes.Length);
this.writer.Write(bytes);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.SerializeObject, obj.Version, null, 0);
}
}
else
catch (Exception ex)
{
TrackedObject trackedObject = (TrackedObject) obj.Val;
var bytes = DurableTask.Netherite.Serializer.SerializeTrackedObject(trackedObject);
this.storeStats.Serialize++;
this.writer.Write(bytes.Length);
this.writer.Write(bytes);
this.cacheDebugger?.Record(trackedObject.Key, CacheDebugger.CacheEvent.SerializeObject, obj.Version, null, 0);
this.errorHandler.HandleError("FasterKV.Value.Serializer", "could not serialize value", ex, true, false);
}
}
}
Expand Down