Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions Harmony/Public/Harmony.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace HarmonyLib
{
Expand Down Expand Up @@ -132,21 +133,41 @@ public void PatchCategory(string category)
PatchCategory(assembly, category);
}

private static readonly ConditionalWeakTable<Assembly, Dictionary<string, List<Type>>> AssemblyCachedCategories = new();

/// <summary>Searches an assembly for HarmonyPatch-annotated classes/structs with a specific category and uses them to create patches</summary>
/// <param name="assembly">The assembly</param>
/// <param name="category">Name of patch category</param>
///
public void PatchCategory(Assembly assembly, string category)
{
AccessTools.GetTypesFromAssembly(assembly)
.Where(type =>
var categoryCache = AssemblyCachedCategories.GetValue(assembly, BuildCategoryCache);
if (categoryCache.TryGetValue(category, out var toPatch))
{
toPatch.Do(type => CreateClassProcessor(type).Patch());
}
}

private static Dictionary<string, List<Type>> BuildCategoryCache(Assembly assembly)
{
Dictionary<string, List<Type>> toBuild = [];
foreach (var type in AccessTools.GetTypesFromAssembly(assembly))
{
var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type);
if (harmonyAttributes.Count == 0) continue;
var containerAttributes = HarmonyMethod.Merge(harmonyAttributes);
var category = containerAttributes.category;
if (!string.IsNullOrEmpty(category))
{
var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type);
if (harmonyAttributes.Count == 0) return false;
var containerAttributes = HarmonyMethod.Merge(harmonyAttributes);
return containerAttributes.category == category;
})
.Do(type => CreateClassProcessor(type).Patch());
if (!toBuild.TryGetValue(category, out var typeList))
{
typeList ??= [];
}
typeList.Add(type);
toBuild[category] = typeList;
}
}
return toBuild;
}

/// <summary>Creates patches by manually specifying the methods</summary>
Expand Down Expand Up @@ -239,15 +260,11 @@ public void UnpatchCategory(string category)
///
public void UnpatchCategory(Assembly assembly, string category)
{
AccessTools.GetTypesFromAssembly(assembly)
.Where(type =>
{
var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type);
if (harmonyAttributes.Count == 0) return false;
var containerAttributes = HarmonyMethod.Merge(harmonyAttributes);
return containerAttributes.category == category;
})
.Do(type => CreateClassProcessor(type).Unpatch());
var categoryCache = AssemblyCachedCategories.GetValue(assembly, BuildCategoryCache);
if (categoryCache.TryGetValue(category, out var toPatch))
{
toPatch.Do(type => CreateClassProcessor(type).Unpatch());
}
}

/// <summary>Test for patches from a specific Harmony ID</summary>
Expand Down
89 changes: 89 additions & 0 deletions HarmonyTests/Patching/CategoryPatches.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using HarmonyLib;
using NUnit.Framework;
using System.Runtime.CompilerServices;

namespace HarmonyLibTests.Patching
{
[TestFixture, NonParallelizable]
public class CategoryPatches : TestLogger
{
[Test]
public void Test_HarmonyPatchAll()
{
var harmony = new Harmony("test");
harmony.PatchCategory("CategoryA");

Assert.AreEqual(2, Get1());
Assert.AreEqual(false, GetTrue());
Assert.AreEqual("Hello World", GetHelloWorld());
Assert.AreEqual(18, Multiply(3, 6));


harmony.PatchCategory("CategoryB");

Assert.AreEqual(2, Get1());
Assert.AreEqual(false, GetTrue());
Assert.AreEqual("Hello World!", GetHelloWorld());
Assert.AreEqual(36, Multiply(3, 6));

harmony.UnpatchCategory("CategoryA");

Assert.AreEqual(1, Get1());
Assert.AreEqual(true, GetTrue());
Assert.AreEqual("Hello World!", GetHelloWorld());
Assert.AreEqual(36, Multiply(3, 6));

harmony.UnpatchCategory("CategoryB");

Assert.AreEqual(1, Get1());
Assert.AreEqual(true, GetTrue());
Assert.AreEqual("Hello World", GetHelloWorld());
Assert.AreEqual(18, Multiply(3, 6));
}
[MethodImpl(MethodImplOptions.NoInlining)]
public static int Get1() => 1;

[MethodImpl(MethodImplOptions.NoInlining)]
public static bool GetTrue() => true;

[MethodImpl(MethodImplOptions.NoInlining)]
public static string GetHelloWorld() => "Hello World";

[MethodImpl(MethodImplOptions.NoInlining)]
public static int Multiply(int a, int b) => a * b;

[HarmonyPatch]
[HarmonyPatch(typeof(CategoryPatches))]
[HarmonyPatchCategory("CategoryA")]
static class CategoryAPatches
{
[HarmonyPatch(nameof(Get1)), HarmonyPrefix]
public static bool Get1Patch(ref int __result)
{
__result = 2;
return false;
}
[HarmonyPatch(nameof(GetTrue)), HarmonyPostfix]
public static void GetTruePatch(ref bool __result)
{
__result = false;
}
}

[HarmonyPatch]
[HarmonyPatchCategory("CategoryB")]
static class CategoryBPatches
{
[HarmonyPatch(typeof(CategoryPatches), nameof(GetHelloWorld)), HarmonyPostfix]
public static void GetHelloWorldPatch(ref string __result)
{
__result = __result + "!";
}
[HarmonyPatch(typeof(CategoryPatches), nameof(Multiply)), HarmonyPrefix]
public static void Multiply(ref int a)
{
a *= 2;
}
}
}
}
Loading