diff --git a/Harmony/Public/Harmony.cs b/Harmony/Public/Harmony.cs index 5784913e..5788d896 100644 --- a/Harmony/Public/Harmony.cs +++ b/Harmony/Public/Harmony.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; namespace HarmonyLib { @@ -132,21 +133,41 @@ public void PatchCategory(string category) PatchCategory(assembly, category); } + private static readonly ConditionalWeakTable>> AssemblyCachedCategories = new(); + /// Searches an assembly for HarmonyPatch-annotated classes/structs with a specific category and uses them to create patches /// The assembly /// Name of patch category /// public void PatchCategory(Assembly assembly, string category) { - AccessTools.GetTypesFromAssembly(assembly) - .Where(type => + var categoryCache = AssemblyCachedCategories.GetValue(assembly, BuildCategoryCache); + if (categoryCache.TryGetValue(category, out var toPatch)) + { + toPatch.Do(type => CreateClassProcessor(type).Patch()); + } + } + + private static Dictionary> BuildCategoryCache(Assembly assembly) + { + Dictionary> toBuild = []; + foreach (var type in AccessTools.GetTypesFromAssembly(assembly)) + { + var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type); + if (harmonyAttributes.Count == 0) continue; + var containerAttributes = HarmonyMethod.Merge(harmonyAttributes); + var category = containerAttributes.category; + if (!string.IsNullOrEmpty(category)) { - var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type); - if (harmonyAttributes.Count == 0) return false; - var containerAttributes = HarmonyMethod.Merge(harmonyAttributes); - return containerAttributes.category == category; - }) - .Do(type => CreateClassProcessor(type).Patch()); + if (!toBuild.TryGetValue(category, out var typeList)) + { + typeList ??= []; + } + typeList.Add(type); + toBuild[category] = typeList; + } + } + return toBuild; } /// Creates patches by manually specifying the methods @@ -239,15 +260,11 @@ public void UnpatchCategory(string category) /// public void UnpatchCategory(Assembly assembly, string category) { - AccessTools.GetTypesFromAssembly(assembly) - .Where(type => - { - var harmonyAttributes = HarmonyMethodExtensions.GetFromType(type); - if (harmonyAttributes.Count == 0) return false; - var containerAttributes = HarmonyMethod.Merge(harmonyAttributes); - return containerAttributes.category == category; - }) - .Do(type => CreateClassProcessor(type).Unpatch()); + var categoryCache = AssemblyCachedCategories.GetValue(assembly, BuildCategoryCache); + if (categoryCache.TryGetValue(category, out var toPatch)) + { + toPatch.Do(type => CreateClassProcessor(type).Unpatch()); + } } /// Test for patches from a specific Harmony ID diff --git a/HarmonyTests/Patching/CategoryPatches.cs b/HarmonyTests/Patching/CategoryPatches.cs new file mode 100644 index 00000000..9b101391 --- /dev/null +++ b/HarmonyTests/Patching/CategoryPatches.cs @@ -0,0 +1,89 @@ +using HarmonyLib; +using NUnit.Framework; +using System.Runtime.CompilerServices; + +namespace HarmonyLibTests.Patching +{ + [TestFixture, NonParallelizable] + public class CategoryPatches : TestLogger + { + [Test] + public void Test_HarmonyPatchAll() + { + var harmony = new Harmony("test"); + harmony.PatchCategory("CategoryA"); + + Assert.AreEqual(2, Get1()); + Assert.AreEqual(false, GetTrue()); + Assert.AreEqual("Hello World", GetHelloWorld()); + Assert.AreEqual(18, Multiply(3, 6)); + + + harmony.PatchCategory("CategoryB"); + + Assert.AreEqual(2, Get1()); + Assert.AreEqual(false, GetTrue()); + Assert.AreEqual("Hello World!", GetHelloWorld()); + Assert.AreEqual(36, Multiply(3, 6)); + + harmony.UnpatchCategory("CategoryA"); + + Assert.AreEqual(1, Get1()); + Assert.AreEqual(true, GetTrue()); + Assert.AreEqual("Hello World!", GetHelloWorld()); + Assert.AreEqual(36, Multiply(3, 6)); + + harmony.UnpatchCategory("CategoryB"); + + Assert.AreEqual(1, Get1()); + Assert.AreEqual(true, GetTrue()); + Assert.AreEqual("Hello World", GetHelloWorld()); + Assert.AreEqual(18, Multiply(3, 6)); + } + [MethodImpl(MethodImplOptions.NoInlining)] + public static int Get1() => 1; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static bool GetTrue() => true; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static string GetHelloWorld() => "Hello World"; + + [MethodImpl(MethodImplOptions.NoInlining)] + public static int Multiply(int a, int b) => a * b; + + [HarmonyPatch] + [HarmonyPatch(typeof(CategoryPatches))] + [HarmonyPatchCategory("CategoryA")] + static class CategoryAPatches + { + [HarmonyPatch(nameof(Get1)), HarmonyPrefix] + public static bool Get1Patch(ref int __result) + { + __result = 2; + return false; + } + [HarmonyPatch(nameof(GetTrue)), HarmonyPostfix] + public static void GetTruePatch(ref bool __result) + { + __result = false; + } + } + + [HarmonyPatch] + [HarmonyPatchCategory("CategoryB")] + static class CategoryBPatches + { + [HarmonyPatch(typeof(CategoryPatches), nameof(GetHelloWorld)), HarmonyPostfix] + public static void GetHelloWorldPatch(ref string __result) + { + __result = __result + "!"; + } + [HarmonyPatch(typeof(CategoryPatches), nameof(Multiply)), HarmonyPrefix] + public static void Multiply(ref int a) + { + a *= 2; + } + } + } +}