Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the uses of ThreadStatic in class having static ctor #746

Merged
merged 5 commits into from
Sep 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 67 additions & 62 deletions cs/src/core/Epochs/LightEpoch.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

using System;
Expand All @@ -14,6 +14,42 @@ namespace FASTER.core
/// </summary>
public unsafe sealed class LightEpoch
{
/// <summary>
/// Store thread-static metadata separately; see https://github.com/microsoft/FASTER/pull/746
/// </summary>
private class Metadata
{
/// <summary>
/// Managed thread id of this thread
/// </summary>
[ThreadStatic]
internal static int threadId;

/// <summary>
/// Start offset to reserve entry in the epoch table
/// </summary>
[ThreadStatic]
internal static ushort startOffset1;

/// <summary>
/// Alternate start offset to reserve entry in the epoch table (to reduce probing if <see cref="startOffset1"/> slot is already filled)
/// </summary>
[ThreadStatic]
internal static ushort startOffset2;

/// <summary>
/// A thread's entry in the epoch table.
/// </summary>
[ThreadStatic]
internal static int threadEntryIndex;

/// <summary>
/// Number of instances using this entry
/// </summary>
[ThreadStatic]
internal static int threadEntryIndexCount;
}

/// <summary>
/// Size of cache line in bytes
/// </summary>
Expand Down Expand Up @@ -56,36 +92,6 @@ public unsafe sealed class LightEpoch
volatile int drainCount = 0;
readonly EpochActionPair[] drainList = new EpochActionPair[kDrainListSize];

/// <summary>
/// A thread's entry in the epoch table.
/// </summary>
[ThreadStatic]
static int threadEntryIndex;

/// <summary>
/// Number of instances using this entry
/// </summary>
[ThreadStatic]
static int threadEntryIndexCount;

/// <summary>
/// Managed thread id of this thread
/// </summary>
[ThreadStatic]
static int threadId;

/// <summary>
/// Start offset to reserve entry in the epoch table
/// </summary>
[ThreadStatic]
static ushort startOffset1;

/// <summary>
/// Alternate start offset to reserve entry in the epoch table (to reduce probing if <see cref="startOffset1"/> slot is already filled)
/// </summary>
[ThreadStatic]
static ushort startOffset2;

/// <summary>
/// Global current epoch value
/// </summary>
Expand Down Expand Up @@ -165,7 +171,7 @@ public void Dispose()
/// <returns>Result of the check</returns>
public bool ThisInstanceProtected()
{
int entry = threadEntryIndex;
int entry = Metadata.threadEntryIndex;
if (kInvalidIndex != entry)
{
if ((*(tableAligned + entry)).threadId == entry)
Expand All @@ -181,10 +187,10 @@ public bool ThisInstanceProtected()
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void ProtectAndDrain()
{
int entry = threadEntryIndex;
int entry = Metadata.threadEntryIndex;

// Protect CurrentEpoch by making an entry for it in the non-static epoch table so ComputeNewSafeToReclaimEpoch() will see it.
(*(tableAligned + entry)).threadId = threadEntryIndex;
(*(tableAligned + entry)).threadId = Metadata.threadEntryIndex;
(*(tableAligned + entry)).localCurrentEpoch = CurrentEpoch;

if (drainCount > 0)
Expand Down Expand Up @@ -279,8 +285,7 @@ public void BumpCurrentEpoch(Action onDrain)
public void Mark(int markerIdx, long version)
{
Debug.Assert(markerIdx < 6);

(*(tableAligned + threadEntryIndex)).markers[markerIdx] = version;
(*(tableAligned + Metadata.threadEntryIndex)).markers[markerIdx] = version;
}

/// <summary>
Expand Down Expand Up @@ -412,14 +417,14 @@ void Drain(long nextEpoch)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
void Acquire()
{
if (threadEntryIndex == kInvalidIndex)
threadEntryIndex = ReserveEntryForThread();
if (Metadata.threadEntryIndex == kInvalidIndex)
Metadata.threadEntryIndex = ReserveEntryForThread();

Debug.Assert((*(tableAligned + threadEntryIndex)).localCurrentEpoch == 0,
Debug.Assert((*(tableAligned + Metadata.threadEntryIndex)).localCurrentEpoch == 0,
"Trying to acquire protected epoch. Make sure you do not re-enter FASTER from callbacks or IDevice implementations. If using tasks, use TaskCreationOptions.RunContinuationsAsynchronously.");

// This corresponds to AnyInstanceProtected(). We do not mark "ThisInstanceProtected" until ProtectAndDrain().
threadEntryIndexCount++;
Metadata.threadEntryIndexCount++;
}

/// <summary>
Expand All @@ -428,7 +433,7 @@ void Acquire()
[MethodImpl(MethodImplOptions.AggressiveInlining)]
void Release()
{
int entry = threadEntryIndex;
int entry = Metadata.threadEntryIndex;

Debug.Assert((*(tableAligned + entry)).localCurrentEpoch != 0,
"Trying to release unprotected epoch. Make sure you do not re-enter FASTER from callbacks or IDevice implementations. If using tasks, use TaskCreationOptions.RunContinuationsAsynchronously.");
Expand All @@ -438,11 +443,11 @@ void Release()
(*(tableAligned + entry)).threadId = 0;

// Decrement "AnyInstanceProtected()" (static thread table)
threadEntryIndexCount--;
if (threadEntryIndexCount == 0)
Metadata.threadEntryIndexCount--;
if (Metadata.threadEntryIndexCount == 0)
{
(threadIndexAligned + threadEntryIndex)->threadId = 0;
threadEntryIndex = kInvalidIndex;
(threadIndexAligned + Metadata.threadEntryIndex)->threadId = 0;
Metadata.threadEntryIndex = kInvalidIndex;
}
}

Expand All @@ -456,24 +461,24 @@ static int ReserveEntry()
while (true)
{
// Try to acquire entry
if (0 == (threadIndexAligned + startOffset1)->threadId)
if (0 == (threadIndexAligned + Metadata.startOffset1)->threadId)
{
if (0 == Interlocked.CompareExchange(
ref (threadIndexAligned + startOffset1)->threadId,
threadId, 0))
return startOffset1;
ref (threadIndexAligned + Metadata.startOffset1)->threadId,
Metadata.threadId, 0))
return Metadata.startOffset1;
}

if (startOffset2 > 0)
if (Metadata.startOffset2 > 0)
{
// Try alternate entry
startOffset1 = startOffset2;
startOffset2 = 0;
Metadata.startOffset1 = Metadata.startOffset2;
Metadata.startOffset2 = 0;
}
else startOffset1++; // Probe next sequential entry
if (startOffset1 > kTableSize)
else Metadata.startOffset1++; // Probe next sequential entry
if (Metadata.startOffset1 > kTableSize)
{
startOffset1 -= kTableSize;
Metadata.startOffset1 -= kTableSize;
Thread.Yield();
}
}
Expand All @@ -484,7 +489,7 @@ static int ReserveEntry()
/// </summary>
/// <param name="h"></param>
/// <returns></returns>
private static int Murmur3(int h)
static int Murmur3(int h)
{
uint a = (uint)h;
a ^= a >> 16;
Expand All @@ -502,12 +507,12 @@ private static int Murmur3(int h)
/// <returns>Reserved entry</returns>
static int ReserveEntryForThread()
{
if (threadId == 0) // run once per thread for performance
if (Metadata.threadId == 0) // run once per thread for performance
{
threadId = Environment.CurrentManagedThreadId;
uint code = (uint)Murmur3(threadId);
startOffset1 = (ushort)(1 + (code % kTableSize));
startOffset2 = (ushort)(1 + ((code >> 16) % kTableSize));
Metadata.threadId = Environment.CurrentManagedThreadId;
uint code = (uint)Murmur3(Metadata.threadId);
Metadata.startOffset1 = (ushort)(1 + (code % kTableSize));
Metadata.startOffset2 = (ushort)(1 + ((code >> 16) % kTableSize));
}
return ReserveEntry();
}
Expand Down Expand Up @@ -546,4 +551,4 @@ struct EpochActionPair
public override string ToString() => $"epoch = {epoch}, action = {(action is null ? "n/a" : action.Method.ToString())}";
}
}
}
}