33// See the LICENSE file in the project root for more information.
44
55using System ;
6+ using System . Collections . Immutable ;
67using System . Diagnostics ;
78using System . Diagnostics . CodeAnalysis ;
89using System . Linq ;
@@ -578,6 +579,16 @@ private static bool IsConversionCastSafeToRemove(
578579 return true ;
579580 }
580581
582+ // Similarly, we want to support this for:
583+ //
584+ // int? a = b switch { true => (int?)0, false => 1 }
585+ if ( IsSwitchExpressionCaseCastSafeToRemove (
586+ castNode , originalSemanticModel ,
587+ rewrittenExpression , rewrittenSemanticModel , cancellationToken ) )
588+ {
589+ return true ;
590+ }
591+
581592 // Widening a value before bitwise negation produces the same value as bitwise negation
582593 // followed by the same widening. For example:
583594 //
@@ -825,17 +836,75 @@ private static bool IsConditionalCastSafeToRemove(
825836 ExpressionSyntax castNode , SemanticModel originalSemanticModel ,
826837 ExpressionSyntax rewrittenExpression , SemanticModel rewrittenSemanticModel , CancellationToken cancellationToken )
827838 {
828- if ( castNode is not CastExpressionSyntax castExpression )
829- return false ;
839+ // Defer to common helper to determine if the cast can be removed. This unified processing of `x ? y : z` and
840+ // `x switch { .. => y, .. => z, .. => w, ... }` expressions.
841+ return IsSwitchOrConditionalCastSafeToRemove (
842+ castNode ,
843+ originalSemanticModel ,
844+ rewrittenExpression ,
845+ rewrittenSemanticModel ,
846+ static parentExpression => parentExpression . Parent is ConditionalExpressionSyntax conditionalExpression && conditionalExpression . Condition != parentExpression
847+ ? conditionalExpression
848+ : null ,
849+ static conditionalExpression => [ conditionalExpression . WhenTrue , conditionalExpression . WhenFalse ] ,
850+ static ( conditionalExpression , armExpression ) =>
851+ {
852+ Contract . ThrowIfFalse ( conditionalExpression . WhenTrue == armExpression || conditionalExpression . WhenFalse == armExpression ) ;
853+ return armExpression == conditionalExpression . WhenTrue
854+ ? conditionalExpression . WhenFalse
855+ : conditionalExpression . WhenTrue ;
856+ } ,
857+ cancellationToken ) ;
858+ }
859+
860+ private static bool IsSwitchExpressionCaseCastSafeToRemove (
861+ ExpressionSyntax castNode , SemanticModel originalSemanticModel ,
862+ ExpressionSyntax rewrittenExpression , SemanticModel rewrittenSemanticModel , CancellationToken cancellationToken )
863+ {
864+ // Defer to common helper to determine if the cast can be removed. This unified processing of `x ? y : z` and
865+ // `x switch { .. => y, .. => z, .. => w, ... }` expressions.
866+ return IsSwitchOrConditionalCastSafeToRemove (
867+ castNode ,
868+ originalSemanticModel ,
869+ rewrittenExpression ,
870+ rewrittenSemanticModel ,
871+ static parentExpression => parentExpression . Parent is SwitchExpressionArmSyntax { Parent : SwitchExpressionSyntax switchExpression }
872+ ? switchExpression
873+ : null ,
874+ static switchExpression => switchExpression . Arms . SelectAsArray ( a => a . Expression ) ,
875+ static ( switchExpression , armExpression ) =>
876+ {
877+ if ( switchExpression . Arms . Count < 2 )
878+ return null ;
879+
880+ var arm = switchExpression . Arms . Single ( a => a . Expression == armExpression ) ;
881+ var armIndex = switchExpression . Arms . IndexOf ( arm ) ;
882+ return armIndex == 0
883+ ? switchExpression . Arms [ 1 ] . Expression
884+ : switchExpression . Arms [ armIndex - 1 ] . Expression ;
885+ } ,
886+ cancellationToken ) ;
887+ }
830888
831- var parent = castExpression . WalkUpParentheses ( ) ;
832- if ( parent . Parent is not ConditionalExpressionSyntax originalConditionalExpression )
889+ private static bool IsSwitchOrConditionalCastSafeToRemove < TConditionalOrSwitchExpression > (
890+ ExpressionSyntax castNode ,
891+ SemanticModel originalSemanticModel ,
892+ ExpressionSyntax rewrittenExpression ,
893+ SemanticModel rewrittenSemanticModel ,
894+ Func < ExpressionSyntax , TConditionalOrSwitchExpression ? > getConditionalOrSwitchExpression ,
895+ Func < TConditionalOrSwitchExpression , ImmutableArray < ExpressionSyntax > > getArmExpressions ,
896+ Func < TConditionalOrSwitchExpression , ExpressionSyntax , ExpressionSyntax ? > getAlternativeArm ,
897+ CancellationToken cancellationToken )
898+ where TConditionalOrSwitchExpression : ExpressionSyntax
899+ {
900+ if ( castNode is not CastExpressionSyntax castExpression )
833901 return false ;
834902
835- // if we were parented by a conditional before, we must be parented by a conditional afterwards.
836- var rewrittenConditionalExpression = ( ConditionalExpressionSyntax ) rewrittenExpression . WalkUpParentheses ( ) . GetRequiredParent ( ) ;
903+ var parentExpression = castExpression . WalkUpParentheses ( ) ;
837904
838- if ( parent != originalConditionalExpression . WhenFalse && parent != originalConditionalExpression . WhenTrue )
905+ var originalConditionalOrSwitchExpression = getConditionalOrSwitchExpression ( parentExpression ) ;
906+ var rewrittenConditionalOrSwitchExpression = getConditionalOrSwitchExpression ( rewrittenExpression . WalkUpParentheses ( ) ) ;
907+ if ( originalConditionalOrSwitchExpression is null || rewrittenConditionalOrSwitchExpression is null )
839908 return false ;
840909
841910 if ( originalSemanticModel . GetOperation ( castExpression , cancellationToken ) is not IConversionOperation conversionOperation )
@@ -856,16 +925,16 @@ bool IsConditionalCastSafeToRemoveDueToConversionOfEntireConditionalExpression()
856925 {
857926 // if we have `a ? (int?)b : default` then we can't remove the nullable cast as it changes the
858927 // meaning of `default`.
859- if ( originalConditionalExpression . WhenTrue . WalkDownParentheses ( ) . IsKind ( SyntaxKind . DefaultLiteralExpression ) ||
860- originalConditionalExpression . WhenFalse . WalkDownParentheses ( ) . IsKind ( SyntaxKind . DefaultLiteralExpression ) )
928+ foreach ( var armExpression in getArmExpressions ( originalConditionalOrSwitchExpression ) )
861929 {
862- return false ;
930+ if ( armExpression . WalkDownParentheses ( ) . IsKind ( SyntaxKind . DefaultLiteralExpression ) )
931+ return false ;
863932 }
864933 }
865934
866935 var originalCastExpressionTypeInfo = originalSemanticModel . GetTypeInfo ( castExpression , cancellationToken ) ;
867- var originalConditionalTypeInfo = originalSemanticModel . GetTypeInfo ( originalConditionalExpression , cancellationToken ) ;
868- var rewrittenConditionalTypeInfo = rewrittenSemanticModel . GetTypeInfo ( rewrittenConditionalExpression , cancellationToken ) ;
936+ var originalConditionalTypeInfo = originalSemanticModel . GetTypeInfo ( originalConditionalOrSwitchExpression , cancellationToken ) ;
937+ var rewrittenConditionalTypeInfo = rewrittenSemanticModel . GetTypeInfo ( rewrittenConditionalOrSwitchExpression , cancellationToken ) ;
869938
870939 if ( IsNullOrErrorType ( originalCastExpressionTypeInfo ) ||
871940 IsNullOrErrorType ( originalConditionalTypeInfo ) ||
@@ -886,13 +955,14 @@ bool IsConditionalCastSafeToRemoveDueToConversionOfEntireConditionalExpression()
886955 if ( IsNullOrErrorType ( castType ) )
887956 return false ;
888957
889- if ( rewrittenSemanticModel . GetOperation ( rewrittenConditionalExpression , cancellationToken ) is not IConditionalOperation rewrittenConditionalOperation )
958+ var rewrittenOperation = rewrittenSemanticModel . GetOperation ( rewrittenConditionalOrSwitchExpression , cancellationToken ) ;
959+ if ( rewrittenOperation is not IConditionalOperation and not ISwitchExpressionOperation )
890960 return false ;
891961
892- if ( castType . Equals ( rewrittenConditionalOperation . Type , SymbolEqualityComparer . IncludeNullability ) )
962+ if ( castType . Equals ( rewrittenOperation . Type , SymbolEqualityComparer . IncludeNullability ) )
893963 return true ;
894964
895- if ( rewrittenConditionalOperation . Parent is IConversionOperation conditionalParentConversion &&
965+ if ( rewrittenOperation . Parent is IConversionOperation conditionalParentConversion &&
896966 conditionalParentConversion . GetConversion ( ) . IsImplicit &&
897967 castType . Equals ( conditionalParentConversion . Type , SymbolEqualityComparer . IncludeNullability ) )
898968 {
@@ -911,7 +981,10 @@ bool IsConditionalCastSafeToRemoveDueToConversionToOtherBranch()
911981 if ( castExpression . Expression . WalkDownParentheses ( ) . IsKind ( SyntaxKind . DefaultLiteralExpression ) )
912982 return false ;
913983
914- var otherSide = parent == originalConditionalExpression . WhenFalse ? originalConditionalExpression . WhenTrue : originalConditionalExpression . WhenFalse ;
984+ var otherSide = getAlternativeArm ( originalConditionalOrSwitchExpression , parentExpression ) ;
985+ if ( otherSide is null )
986+ return false ;
987+
915988 var otherSideType = originalSemanticModel . GetTypeInfo ( otherSide , cancellationToken ) . Type ;
916989 var thisSideRewrittenType = rewrittenSemanticModel . GetTypeInfo ( rewrittenExpression , cancellationToken ) . Type ;
917990
@@ -925,11 +998,11 @@ bool IsConditionalCastSafeToRemoveDueToConversionToOtherBranch()
925998 // Now check that with the (T) cast removed, that the outer `x ? y : z` is still
926999 // immediately implicitly converted to a 'T'. If so, we can remove this inner (T) cast.
9271000
928- var rewrittenConditionalConvertedType = rewrittenSemanticModel . GetTypeInfo ( rewrittenConditionalExpression , cancellationToken ) . ConvertedType ;
1001+ var rewrittenConditionalConvertedType = rewrittenSemanticModel . GetTypeInfo ( rewrittenConditionalOrSwitchExpression , cancellationToken ) . ConvertedType ;
9291002 if ( rewrittenConditionalConvertedType is null )
9301003 return false ;
9311004
932- var outerConversion = rewrittenSemanticModel . GetConversion ( rewrittenConditionalExpression , cancellationToken ) ;
1005+ var outerConversion = rewrittenSemanticModel . GetConversion ( rewrittenConditionalOrSwitchExpression , cancellationToken ) ;
9331006 if ( ! outerConversion . IsImplicit )
9341007 return false ;
9351008
0 commit comments