Skip to content

Commit 2388b9a

Browse files
authored
Merge pull request #1081 from CommunityToolkit/user/sergiopedri/handle-canexecute-overrides
Handle 'CanExecute' with method overrides
2 parents 95043f5 + 7d4d6ec commit 2388b9a

File tree

5 files changed

+224
-2
lines changed

5 files changed

+224
-2
lines changed

src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
<Compile Include="$(MSBuildThisFileDirectory)Extensions\GeneratorAttributeSyntaxContextWithOptions.cs" />
7575
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalGeneratorInitializationContextExtensions.cs" />
7676
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalValuesProviderExtensions.cs" />
77+
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IMethodSymbolExtensions.cs" />
7778
<Compile Include="$(MSBuildThisFileDirectory)Extensions\MethodDeclarationSyntaxExtensions.cs" />
7879
<Compile Include="$(MSBuildThisFileDirectory)Extensions\SymbolInfoExtensions.cs" />
7980
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ISymbolExtensions.cs" />
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Immutable;
6+
using System.Linq;
7+
using Microsoft.CodeAnalysis;
8+
9+
namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
10+
11+
/// <summary>
12+
/// Extension methods for the <see cref="IMethodSymbol"/> type.
13+
/// </summary>
14+
internal static class IMethodSymbolExtensions
15+
{
16+
/// <summary>
17+
/// Checks whether all input symbols are <see cref="IMethodSymbol"/>-s in the same override hierarchy.
18+
/// </summary>
19+
/// <param name="symbols">The input <see cref="ISymbol"/> set to check.</param>
20+
/// <returns>Whether all input symbols are <see cref="IMethodSymbol"/>-s in the same override hierarchy.</returns>
21+
public static bool AreAllInSameOverriddenMethodHierarchy(this ImmutableArray<ISymbol> symbols)
22+
{
23+
IMethodSymbol? baseSymbol = null;
24+
25+
// Look for the base method
26+
foreach (ISymbol currentSymbol in symbols)
27+
{
28+
// If any input symbol is not a method, we can stop right away
29+
if (currentSymbol is not IMethodSymbol methodSymbol)
30+
{
31+
return false;
32+
}
33+
34+
if (methodSymbol.IsVirtual)
35+
{
36+
// If we already found a base method, all methods can't possibly be in the same hierarchy
37+
if (baseSymbol is not null)
38+
{
39+
return false;
40+
}
41+
42+
baseSymbol = methodSymbol;
43+
}
44+
}
45+
46+
// If we didn't find any, stop here
47+
if (baseSymbol is null)
48+
{
49+
return false;
50+
}
51+
52+
// Verify all methods are in the same tree
53+
foreach (ISymbol currentSymbol in symbols)
54+
{
55+
IMethodSymbol methodSymbol = (IMethodSymbol)currentSymbol;
56+
57+
// Ignore the base method
58+
if (SymbolEqualityComparer.Default.Equals(methodSymbol, baseSymbol))
59+
{
60+
continue;
61+
}
62+
63+
// If the current method isn't an override, then fail
64+
if (methodSymbol.OverriddenMethod is not { } overriddenMethod)
65+
{
66+
return false;
67+
}
68+
69+
// The current method must be overriding another one in the set
70+
if (!symbols.Any(symbol => SymbolEqualityComparer.Default.Equals(symbol, overriddenMethod)))
71+
{
72+
return false;
73+
}
74+
}
75+
76+
return true;
77+
}
78+
}

src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,10 @@ private static bool TryGetCanExecuteExpressionType(
810810

811811
diagnostics.Add(InvalidCanExecuteMemberNameError, methodSymbol, memberName, methodSymbol.ContainingType);
812812
}
813-
else if (canExecuteSymbols.Length > 1)
813+
else if (canExecuteSymbols.Length > 1 && !canExecuteSymbols.AreAllInSameOverriddenMethodHierarchy())
814814
{
815+
// We specifically allow targeting methods which are overridden: they'll be more than one,
816+
// but it doesn't matter since you'd only ever call "one", being the most derived one.
815817
diagnostics.Add(MultipleCanExecuteMemberNameMatchesError, methodSymbol, memberName, methodSymbol.ContainingType);
816818
}
817819
else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType))

tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,60 @@ partial class MyViewModel
20602060
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyViewModel.Test.g.cs", result));
20612061
}
20622062

2063+
[TestMethod]
2064+
public void RelayCommandWithOverriddenCanExecute_TargetsOverriddenMethod()
2065+
{
2066+
string source = """
2067+
using CommunityToolkit.Mvvm.ComponentModel;
2068+
using CommunityToolkit.Mvvm.Input;
2069+
2070+
namespace MyApp;
2071+
2072+
public partial class BaseViewModel : ObservableObject
2073+
{
2074+
protected virtual bool CanDoStuff()
2075+
{
2076+
return false;
2077+
}
2078+
}
2079+
2080+
public partial class SampleViewModel : BaseViewModel
2081+
{
2082+
[RelayCommand(CanExecute = nameof(CanDoStuff)]
2083+
private void DoStuff()
2084+
{
2085+
}
2086+
2087+
protected override bool CanDoStuff()
2088+
{
2089+
return true;
2090+
}
2091+
}
2092+
""";
2093+
2094+
string result = """
2095+
// <auto-generated/>
2096+
#pragma warning disable
2097+
#nullable enable
2098+
namespace MyApp
2099+
{
2100+
/// <inheritdoc/>
2101+
partial class SampleViewModel
2102+
{
2103+
/// <summary>The backing field for <see cref="DoStuffCommand"/>.</summary>
2104+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", <ASSEMBLY_VERSION>)]
2105+
private global::CommunityToolkit.Mvvm.Input.RelayCommand? doStuffCommand;
2106+
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="DoStuff"/>.</summary>
2107+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", <ASSEMBLY_VERSION>)]
2108+
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
2109+
public global::CommunityToolkit.Mvvm.Input.IRelayCommand DoStuffCommand => doStuffCommand ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(DoStuff), CanDoStuff);
2110+
}
2111+
}
2112+
""";
2113+
2114+
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.SampleViewModel.DoStuff.g.cs", result));
2115+
}
2116+
20632117
[TestMethod]
20642118
public void ObservableProperty_AnnotatedFieldHasValueIdentifier()
20652119
{

tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2259,6 +2259,80 @@ await CSharpAnalyzerWithLanguageVersionTest<WinRTRelayCommandIsNotGeneratedBinda
22592259
editorconfig: [("_MvvmToolkitIsUsingWindowsRuntimePack", true)]);
22602260
}
22612261

2262+
[TestMethod]
2263+
public void RelayCommandWithOverriddenCanExecute_DoesNotWarn()
2264+
{
2265+
const string source = """
2266+
using CommunityToolkit.Mvvm.ComponentModel;
2267+
using CommunityToolkit.Mvvm.Input;
2268+
2269+
namespace MyApp
2270+
{
2271+
public partial class BaseViewModel : ObservableObject
2272+
{
2273+
protected virtual bool CanDoStuff()
2274+
{
2275+
return false;
2276+
}
2277+
}
2278+
2279+
public partial class SampleViewModel : BaseViewModel
2280+
{
2281+
[RelayCommand(CanExecute = nameof(CanDoStuff))]
2282+
private void DoStuff()
2283+
{
2284+
}
2285+
2286+
protected override bool CanDoStuff()
2287+
{
2288+
return true;
2289+
}
2290+
}
2291+
}
2292+
""";
2293+
2294+
VerifyGeneratedDiagnostics<RelayCommandGenerator>(source, LanguageVersion.CSharp12);
2295+
}
2296+
2297+
[TestMethod]
2298+
public void RelayCommandWithOverriddenCanExecute_WithOneMethodNotInTheSameHierarchy_Warns()
2299+
{
2300+
const string source = """
2301+
using CommunityToolkit.Mvvm.ComponentModel;
2302+
using CommunityToolkit.Mvvm.Input;
2303+
2304+
namespace MyApp
2305+
{
2306+
public partial class BaseViewModel : ObservableObject
2307+
{
2308+
protected virtual bool CanDoStuff()
2309+
{
2310+
return false;
2311+
}
2312+
2313+
protected bool CanDoStuff(string x)
2314+
{
2315+
}
2316+
}
2317+
2318+
public partial class SampleViewModel : BaseViewModel
2319+
{
2320+
[RelayCommand(CanExecute = nameof(CanDoStuff)]
2321+
private void DoStuff()
2322+
{
2323+
}
2324+
2325+
private override bool CanDoStuff()
2326+
{
2327+
return true;
2328+
}
2329+
}
2330+
}
2331+
""";
2332+
2333+
VerifyGeneratedDiagnostics<RelayCommandGenerator>(source, "MVVMTK0010");
2334+
}
2335+
22622336
[TestMethod]
22632337
public async Task WinRTClassUsingNotifyPropertyChangedAttributesAnalyzer_NotTargetingWindows_DoesNotWarn()
22642338
{
@@ -2451,10 +2525,23 @@ internal static async Task VerifyAnalyzerDiagnosticsAndSuccessfulGeneration<TAna
24512525
/// <param name="diagnosticsIds">The diagnostic ids to expect for the input source code.</param>
24522526
internal static void VerifyGeneratedDiagnostics<TGenerator>(string source, params string[] diagnosticsIds)
24532527
where TGenerator : class, IIncrementalGenerator, new()
2528+
{
2529+
VerifyGeneratedDiagnostics<TGenerator>(source, LanguageVersion.CSharp8, diagnosticsIds);
2530+
}
2531+
2532+
/// <summary>
2533+
/// Verifies the output of a source generator.
2534+
/// </summary>
2535+
/// <typeparam name="TGenerator">The generator type to use.</typeparam>
2536+
/// <param name="source">The input source to process.</param>
2537+
/// <param name="languageVersion">The language version to use to parse code and run tests.</param>
2538+
/// <param name="diagnosticsIds">The diagnostic ids to expect for the input source code.</param>
2539+
internal static void VerifyGeneratedDiagnostics<TGenerator>(string source, LanguageVersion languageVersion, params string[] diagnosticsIds)
2540+
where TGenerator : class, IIncrementalGenerator, new()
24542541
{
24552542
IIncrementalGenerator generator = new TGenerator();
24562543

2457-
VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp8)), new[] { generator }, diagnosticsIds, []);
2544+
VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(languageVersion)), new[] { generator }, diagnosticsIds, []);
24582545
}
24592546

24602547
/// <summary>

0 commit comments

Comments
 (0)