diff --git a/src/Analyzers/MSTest.Analyzers.CodeFixes/UseProperAssertMethodsFixer.cs b/src/Analyzers/MSTest.Analyzers.CodeFixes/UseProperAssertMethodsFixer.cs index 259b1d912e..2f68f777fe 100644 --- a/src/Analyzers/MSTest.Analyzers.CodeFixes/UseProperAssertMethodsFixer.cs +++ b/src/Analyzers/MSTest.Analyzers.CodeFixes/UseProperAssertMethodsFixer.cs @@ -75,6 +75,9 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) case UseProperAssertMethodsAnalyzer.CodeFixModeAddArgument: createChangedDocument = ct => FixAssertMethodForAddArgumentModeAsync(context.Document, diagnostic.AdditionalLocations[0], diagnostic.AdditionalLocations[1], diagnostic.AdditionalLocations[2], root, simpleNameSyntax, properAssertMethodName, ct); break; + case UseProperAssertMethodsAnalyzer.CodeFixModeAddTwoArguments: + createChangedDocument = ct => FixAssertMethodForAddTwoArgumentsModeAsync(context.Document, diagnostic.AdditionalLocations[0], diagnostic.AdditionalLocations[1], diagnostic.AdditionalLocations[2], diagnostic.AdditionalLocations[3], root, simpleNameSyntax, properAssertMethodName, ct); + break; case UseProperAssertMethodsAnalyzer.CodeFixModeRemoveArgument: createChangedDocument = ct => FixAssertMethodForRemoveArgumentModeAsync(context.Document, diagnostic.AdditionalLocations, root, simpleNameSyntax, properAssertMethodName, diagnostic.Properties.ContainsKey(UseProperAssertMethodsAnalyzer.NeedsNullableBooleanCastKey), ct); break; @@ -215,6 +218,98 @@ private static async Task FixAssertMethodForAddArgumentModeAsync(Docum return editor.GetChangedDocument(); } + private static async Task FixAssertMethodForAddTwoArgumentsModeAsync(Document document, Location conditionLocation, Location firstArgLocation, Location secondArgLocation, Location thirdArgLocation, SyntaxNode root, SimpleNameSyntax simpleNameSyntax, string properAssertMethodName, CancellationToken cancellationToken) + { + // Handle Contains with comparer: Assert.IsTrue(enumerable.Contains(item, comparer)) -> Assert.Contains(item, enumerable, comparer) + if (root.FindNode(conditionLocation.SourceSpan) is not ArgumentSyntax conditionNode) + { + return document; + } + + if (conditionNode.Parent is not ArgumentListSyntax argumentList) + { + return document; + } + + // FindNode may return ArgumentSyntax (outermost for tied spans) when the expression + // is a direct child of an ArgumentSyntax in the inner invocation. Extract the expression. + if (!TryGetExpressionFromNode(root, firstArgLocation, out ExpressionSyntax? firstArgNode)) + { + return document; + } + + if (!TryGetExpressionFromNode(root, secondArgLocation, out ExpressionSyntax? secondArgNode)) + { + return document; + } + + if (!TryGetExpressionFromNode(root, thirdArgLocation, out ExpressionSyntax? thirdArgNode)) + { + return document; + } + + DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); + FixInvocationMethodName(editor, simpleNameSyntax, properAssertMethodName); + + int conditionIndex = argumentList.Arguments.IndexOf(conditionNode); + + // Build the new arguments list: (firstArg, secondArg, thirdArg, ...remaining) + var newArguments = new List(); + var newSeparators = new List(); + + newArguments.Add(SyntaxFactory.Argument( + firstArgNode + .WithLeadingTrivia(conditionNode.GetLeadingTrivia()) + .WithoutTrailingTrivia())); + + newSeparators.Add(SyntaxFactory.Token(SyntaxKind.CommaToken)); + newArguments.Add(SyntaxFactory.Argument( + secondArgNode + .WithoutLeadingTrivia() + .WithoutTrailingTrivia() + .WithLeadingTrivia(SyntaxFactory.Space))); + + newSeparators.Add(SyntaxFactory.Token(SyntaxKind.CommaToken)); + newArguments.Add(SyntaxFactory.Argument( + thirdArgNode + .WithoutLeadingTrivia() + .WithoutTrailingTrivia() + .WithLeadingTrivia(SyntaxFactory.Space))); + + // Add remaining arguments (e.g., message) with their original separators and trivia + var originalSeparators = argumentList.Arguments.GetSeparators().ToList(); + for (int i = conditionIndex + 1; i < argumentList.Arguments.Count; i++) + { + if (i - 1 < originalSeparators.Count) + { + newSeparators.Add(originalSeparators[i - 1]); + } + + newArguments.Add(argumentList.Arguments[i]); + } + + ArgumentListSyntax newArgumentList = argumentList.WithArguments(SyntaxFactory.SeparatedList(newArguments, newSeparators)); + editor.ReplaceNode(argumentList, newArgumentList); + + return editor.GetChangedDocument(); + } + + // FindNode may return an ArgumentSyntax when the location span coincides with + // the span of an argument in the inner invocation (e.g. Contains(item, comparer)). + // We handle this by extracting the expression from the ArgumentSyntax. + private static bool TryGetExpressionFromNode(SyntaxNode root, Location location, [NotNullWhen(true)] out ExpressionSyntax? expression) + { + SyntaxNode? node = root.FindNode(location.SourceSpan); + expression = node switch + { + ArgumentSyntax argument => argument.Expression, + ExpressionSyntax expr => expr, + _ => null, + }; + + return expression is not null; + } + private static async Task FixAssertMethodForRemoveArgumentModeAsync( Document document, IReadOnlyList additionalLocations, diff --git a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs index da443496d4..be9462eac4 100644 --- a/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs +++ b/src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs @@ -73,9 +73,49 @@ namespace MSTest.Analyzers; /// /// /// +/// Assert.AreEqual([0|X], myEnumerable.Count()) +/// +/// +/// +/// +/// Assert.AreNotEqual(0, myCollection.[Count|Length]) +/// +/// +/// +/// +/// Assert.AreNotEqual(0, myEnumerable.Count()) +/// +/// +/// +/// /// Assert.IsTrue(myCollection.[Count|Length] [>|!=|==] 0) /// /// +/// +/// +/// Assert.IsTrue(myEnumerable.Count() [>|!=|==] 0) +/// +/// +/// +/// +/// Assert.IsTrue(myEnumerable.Any()) +/// +/// +/// +/// +/// Assert.IsFalse(myEnumerable.Any()) +/// +/// +/// +/// +/// Assert.IsTrue(myEnumerable.Contains(item, comparer)) +/// +/// +/// +/// +/// Assert.IsFalse(myEnumerable.Contains(item, comparer)) +/// +/// /// /// [DiagnosticAnalyzer(LanguageNames.CSharp, LanguageNames.VisualBasic)] @@ -116,6 +156,7 @@ private enum CollectionCheckStatus { Unknown, Contains, + ContainsWithComparer, } private enum CountCheckStatus @@ -229,6 +270,22 @@ private enum LinqPredicateCheckStatus /// internal const string CodeFixModeRemoveArgumentReplaceArgumentAndAddArgument = nameof(CodeFixModeRemoveArgumentReplaceArgumentAndAddArgument); + /// + /// This mode means the codefix operation is as follows: + /// + /// Find the right assert method name from the properties bag using . + /// Replace the identifier syntax for the invocation with the right assert method name. The identifier syntax is calculated by the codefix. + /// Replace the syntax node from the first additional locations with three new arguments from the second, third, and fourth additional locations. + /// + /// Example: For Assert.IsTrue(collection.Contains(item, comparer)), it will become Assert.Contains(item, collection, comparer). + /// The value for ProperAssertMethodNameKey is "Contains". + /// The first additional location will point to the "collection.Contains(item, comparer)" node. + /// The second additional location will point to the "item" node. + /// The third additional location will point to the "collection" node. + /// The fourth additional location will point to the "comparer" node. + /// + internal const string CodeFixModeAddTwoArguments = nameof(CodeFixModeAddTwoArguments); + private static readonly LocalizableResourceString Title = new(nameof(Resources.UseProperAssertMethodsTitle), Resources.ResourceManager, typeof(Resources)); private static readonly LocalizableResourceString MessageFormat = new(nameof(Resources.UseProperAssertMethodsMessageFormat), Resources.ResourceManager, typeof(Resources)); @@ -495,7 +552,8 @@ private static CollectionCheckStatus RecognizeCollectionMethodCheck( INamedTypeSymbol objectTypeSymbol, INamedTypeSymbol? enumerableTypeSymbol, out SyntaxNode? collectionExpression, - out SyntaxNode? itemExpression) + out SyntaxNode? itemExpression, + out SyntaxNode? comparerExpression) { if (operation is IInvocationOperation invocation) { @@ -524,6 +582,7 @@ private static CollectionCheckStatus RecognizeCollectionMethodCheck( // So, even if we are dealing with KeyedCollection, the types won't match, and we won't produce a diagnostic. collectionExpression = invocation.Instance?.Syntax; itemExpression = invocation.Arguments[0].Value.Syntax; + comparerExpression = null; return CollectionCheckStatus.Contains; } } @@ -539,12 +598,27 @@ enumerableTypeSymbol is not null && { collectionExpression = invocation.Arguments[0].Value.Syntax; itemExpression = invocation.Arguments[1].Value.Syntax; + comparerExpression = null; return CollectionCheckStatus.Contains; } + + // Handle LINQ Enumerable.Contains(this IEnumerable, TSource, IEqualityComparer) + // In the Roslyn operation model, Arguments includes the 'this' parameter, so Arguments.Length == 3. + if (methodName == "Contains" && + invocation.Arguments.Length == 3 && + enumerableTypeSymbol is not null && + SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, enumerableTypeSymbol)) + { + collectionExpression = invocation.Arguments[0].Value.Syntax; + itemExpression = invocation.Arguments[1].Value.Syntax; + comparerExpression = invocation.Arguments[2].Value.Syntax; + return CollectionCheckStatus.ContainsWithComparer; + } } collectionExpression = null; itemExpression = null; + comparerExpression = null; return CollectionCheckStatus.Unknown; } @@ -762,7 +836,7 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co } // Check for collection method patterns: myCollection.Contains(...) - CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, enumerableTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr); + CollectionCheckStatus collectionMethodStatus = RecognizeCollectionMethodCheck(conditionArgument, objectTypeSymbol, enumerableTypeSymbol, out SyntaxNode? collectionExpr, out SyntaxNode? itemExpr, out SyntaxNode? comparerExpr); if (collectionMethodStatus != CollectionCheckStatus.Unknown) { if (collectionMethodStatus == CollectionCheckStatus.Contains) @@ -780,10 +854,26 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co isTrueInvocation ? "IsTrue" : "IsFalse")); return; } + + if (collectionMethodStatus == CollectionCheckStatus.ContainsWithComparer) + { + string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain"; + + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, properAssertMethod); + properties.Add(CodeFixModeKey, CodeFixModeAddTwoArguments); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create(conditionArgument.Syntax.GetLocation(), itemExpr!.GetLocation(), collectionExpr!.GetLocation(), comparerExpr!.GetLocation()), + properties: properties.ToImmutable(), + properAssertMethod, + isTrueInvocation ? "IsTrue" : "IsFalse")); + return; + } } // Check for collection emptiness patterns: myCollection.Count > 0, myCollection.Count != 0, or myCollection.Count == 0 - CountCheckStatus countStatus = RecognizeCountCheck(conditionArgument, objectTypeSymbol, out SyntaxNode? collectionEmptinessExpr); + CountCheckStatus countStatus = RecognizeCountCheck(conditionArgument, objectTypeSymbol, enumerableTypeSymbol, out SyntaxNode? collectionEmptinessExpr); if (countStatus != CountCheckStatus.Unknown) { string properAssertMethod = countStatus switch @@ -985,6 +1075,7 @@ expectedArgument.ConstantValue.Value is int expectedCountValue && expectedArgument, actualArgumentValue, objectTypeSymbol, + enumerableTypeSymbol, out SyntaxNode? nodeToBeReplaced1, out SyntaxNode? replacement1, out SyntaxNode? nodeToBeReplaced2, @@ -1035,6 +1126,47 @@ expectedArgument.ConstantValue.Value is int expectedCountValue && } } + // Check for AreNotEqual(0, collection.Count/Length) or AreNotEqual(0, enumerable.Count()) → IsNotEmpty + if (!isAreEqualInvocation && + TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValueNotEqual)) + { + CountCheckStatus notEqualCountStatus = RecognizeCountCheck( + expectedArgument, + actualArgumentValueNotEqual, + objectTypeSymbol, + enumerableTypeSymbol, + out SyntaxNode? nodeToBeReplacedNE1, + out SyntaxNode? replacementNE1, + out SyntaxNode? nodeToBeReplacedNE2, + out _); + + // We only handle IsEmpty (i.e. AreNotEqual(0, count) → IsNotEmpty). + // HasCount is intentionally not handled: there's no semantic equivalent for + // AreNotEqual(N, count) where N != 0. + if (notEqualCountStatus == CountCheckStatus.IsEmpty) + { + if (nodeToBeReplacedNE1 is null || replacementNE1 is null || nodeToBeReplacedNE2 is null) + { + throw ApplicationStateGuard.Unreachable(); + } + + // AreNotEqual(0, collection.Count/Length/Count()) → IsNotEmpty(collection) + ImmutableDictionary.Builder properties = ImmutableDictionary.CreateBuilder(); + properties.Add(ProperAssertMethodNameKey, "IsNotEmpty"); + properties.Add(CodeFixModeKey, CodeFixModeRemoveArgumentAndReplaceArgument); + context.ReportDiagnostic(context.Operation.CreateDiagnostic( + Rule, + additionalLocations: ImmutableArray.Create( + nodeToBeReplacedNE2.GetLocation(), + nodeToBeReplacedNE1.GetLocation(), + replacementNE1.GetLocation()), + properties: properties.ToImmutable(), + "IsNotEmpty", + "AreNotEqual")); + return; + } + } + // Don't flag a warning for Assert.AreNotEqual([true|false], x). // This is not the same as Assert.IsFalse(x). if (isAreEqualInvocation && expectedArgument is ILiteralOperation { ConstantValue: { HasValue: true, Value: bool expectedLiteralBoolean } }) @@ -1149,6 +1281,7 @@ private static bool TryMatchLinqMethod( private static CountCheckStatus RecognizeCountCheck( IOperation operation, INamedTypeSymbol objectTypeSymbol, + INamedTypeSymbol? enumerableTypeSymbol, out SyntaxNode? collectionExpression) { collectionExpression = null; @@ -1161,6 +1294,14 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.HasCount; } + // Check for enumerable.Count() > 0 + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.GreaterThan, LeftOperand: IInvocationOperation linqCountInv1, RightOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } } } && + TryGetLinqCountNoPredicate(linqCountInv1, enumerableTypeSymbol, out SyntaxNode? linqExpr1)) + { + collectionExpression = linqExpr1; + return CountCheckStatus.HasCount; + } + // Check for 0 < collection.Count or 0 < collection.Length if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.LessThan, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IPropertyReferenceOperation propertyRef2 } && TryGetCollectionExpressionIfBCLCollectionLengthOrCount(propertyRef2, objectTypeSymbol) is { } expression2) @@ -1169,6 +1310,14 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.HasCount; } + // Check for 0 < enumerable.Count() + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.LessThan, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IInvocationOperation linqCountInv2 } && + TryGetLinqCountNoPredicate(linqCountInv2, enumerableTypeSymbol, out SyntaxNode? linqExpr2)) + { + collectionExpression = linqExpr2; + return CountCheckStatus.HasCount; + } + // Check for collection.Count != 0 or collection.Length != 0 if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals, LeftOperand: IPropertyReferenceOperation propertyRef3, RightOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } } } && TryGetCollectionExpressionIfBCLCollectionLengthOrCount(propertyRef3, objectTypeSymbol) is { } expression3) @@ -1177,6 +1326,14 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.HasCount; } + // Check for enumerable.Count() != 0 + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals, LeftOperand: IInvocationOperation linqCountInv3, RightOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } } } && + TryGetLinqCountNoPredicate(linqCountInv3, enumerableTypeSymbol, out SyntaxNode? linqExpr3)) + { + collectionExpression = linqExpr3; + return CountCheckStatus.HasCount; + } + // Check for 0 != collection.Count or 0 != collection.Length (reverse order) if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IPropertyReferenceOperation propertyRef4 } && TryGetCollectionExpressionIfBCLCollectionLengthOrCount(propertyRef4, objectTypeSymbol) is { } expression4) @@ -1185,6 +1342,14 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.HasCount; } + // Check for 0 != enumerable.Count() + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.NotEquals, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IInvocationOperation linqCountInv4 } && + TryGetLinqCountNoPredicate(linqCountInv4, enumerableTypeSymbol, out SyntaxNode? linqExpr4)) + { + collectionExpression = linqExpr4; + return CountCheckStatus.HasCount; + } + // Check for collection.Count == 0 or collection.Length == 0 if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals, LeftOperand: IPropertyReferenceOperation propertyRef5, RightOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } } } && TryGetCollectionExpressionIfBCLCollectionLengthOrCount(propertyRef5, objectTypeSymbol) is { } expression5) @@ -1193,6 +1358,14 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.IsEmpty; } + // Check for enumerable.Count() == 0 + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals, LeftOperand: IInvocationOperation linqCountInv5, RightOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } } } && + TryGetLinqCountNoPredicate(linqCountInv5, enumerableTypeSymbol, out SyntaxNode? linqExpr5)) + { + collectionExpression = linqExpr5; + return CountCheckStatus.IsEmpty; + } + // Check for 0 == collection.Count or 0 == collection.Length (reverse order) if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IPropertyReferenceOperation propertyRef6 } && TryGetCollectionExpressionIfBCLCollectionLengthOrCount(propertyRef6, objectTypeSymbol) is { } expression6) @@ -1201,6 +1374,26 @@ private static CountCheckStatus RecognizeCountCheck( return CountCheckStatus.IsEmpty; } + // Check for 0 == enumerable.Count() + if (operation is IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals, LeftOperand: ILiteralOperation { ConstantValue: { HasValue: true, Value: 0 } }, RightOperand: IInvocationOperation linqCountInv6 } && + TryGetLinqCountNoPredicate(linqCountInv6, enumerableTypeSymbol, out SyntaxNode? linqExpr6)) + { + collectionExpression = linqExpr6; + return CountCheckStatus.IsEmpty; + } + + // Check for enumerable.Any() (no predicate) - direct invocation. + // NOTE: We return HasCount here because the caller (AnalyzeIsTrueOrIsFalseInvocation) + // maps HasCount → IsNotEmpty (for IsTrue) and HasCount → IsEmpty (for IsFalse), + // which is the correct behavior for Any(). This method is NOT called from the + // AreEqual/AreNotEqual path (which uses the two-argument overload), so the HasCount + // value won't be misinterpreted as suggesting Assert.HasCount for Any(). + if (TryGetLinqAnyNoPredicate(operation, enumerableTypeSymbol, out SyntaxNode? linqAnyExpr)) + { + collectionExpression = linqAnyExpr; + return CountCheckStatus.HasCount; + } + return CountCheckStatus.Unknown; } @@ -1208,6 +1401,7 @@ private static CountCheckStatus RecognizeCountCheck( IOperation expectedArgument, IOperation actualArgument, INamedTypeSymbol objectTypeSymbol, + INamedTypeSymbol? enumerableTypeSymbol, out SyntaxNode? nodeToBeReplaced1, out SyntaxNode? replacement1, out SyntaxNode? nodeToBeReplaced2, @@ -1245,6 +1439,36 @@ expectedArgument.ConstantValue.Value is int expectedValue && } } + // Check if actualArgument is a LINQ Count() call with no predicate + if (actualArgument is IInvocationOperation linqCountInvocation && + TryGetLinqCountNoPredicate(linqCountInvocation, enumerableTypeSymbol, out SyntaxNode? linqCollection)) + { + bool isEmpty = expectedArgument.ConstantValue.HasValue && + expectedArgument.ConstantValue.Value is int expectedLinqValue && + expectedLinqValue == 0; + + if (isEmpty) + { + // We have Assert.AreEqual(0, enumerable.Count()) + // We want Assert.IsEmpty(enumerable) + nodeToBeReplaced1 = actualArgument.Syntax; // enumerable.Count() + replacement1 = linqCollection; // enumerable + nodeToBeReplaced2 = expectedArgument.Syntax; // 0 + replacement2 = null; + return CountCheckStatus.IsEmpty; + } + else + { + // We have Assert.AreEqual(expectedCount, enumerable.Count()) + // We want Assert.HasCount(expectedCount, enumerable) + nodeToBeReplaced1 = actualArgument.Syntax; // enumerable.Count() + replacement1 = linqCollection; // enumerable + nodeToBeReplaced2 = null; + replacement2 = null; + return CountCheckStatus.HasCount; + } + } + nodeToBeReplaced1 = null; replacement1 = null; nodeToBeReplaced2 = null; @@ -1259,6 +1483,50 @@ expectedArgument.ConstantValue.Value is int expectedValue && ? propertyReference.Instance.Syntax : null; + /// + /// Sets to the collection syntax node if the operation is a LINQ Count() call with no predicate, + /// and returns ; otherwise sets it to and returns . + /// + private static bool TryGetLinqCountNoPredicate(IOperation operation, INamedTypeSymbol? enumerableTypeSymbol, [NotNullWhen(true)] out SyntaxNode? collectionExpression) + { + // LINQ Count() with no predicate is Enumerable.Count(this IEnumerable) + // In the Roslyn operation model, extension calls include 'this' in Arguments, so Arguments.Length == 1. + if (enumerableTypeSymbol is not null && + operation is IInvocationOperation invocation && + invocation.TargetMethod.Name == "Count" && + invocation.Arguments.Length == 1 && + SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, enumerableTypeSymbol)) + { + collectionExpression = invocation.Arguments[0].Value.Syntax; + return true; + } + + collectionExpression = null; + return false; + } + + /// + /// Sets to the collection syntax node if the operation is a LINQ Any() call with no predicate, + /// and returns ; otherwise sets it to and returns . + /// + private static bool TryGetLinqAnyNoPredicate(IOperation operation, INamedTypeSymbol? enumerableTypeSymbol, [NotNullWhen(true)] out SyntaxNode? collectionExpression) + { + // LINQ Any() with no predicate is Enumerable.Any(this IEnumerable) + // In the Roslyn operation model, extension calls include 'this' in Arguments, so Arguments.Length == 1. + if (enumerableTypeSymbol is not null && + operation is IInvocationOperation invocation && + invocation.TargetMethod.Name == "Any" && + invocation.Arguments.Length == 1 && + SymbolEqualityComparer.Default.Equals(invocation.TargetMethod.ContainingType, enumerableTypeSymbol)) + { + collectionExpression = invocation.Arguments[0].Value.Syntax; + return true; + } + + collectionExpression = null; + return false; + } + private static bool TryGetFirstArgumentValue(IInvocationOperation operation, [NotNullWhen(true)] out IOperation? argumentValue) => TryGetArgumentValueForParameterOrdinal(operation, 0, out argumentValue); diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs index 8fb29a6385..e43eec561a 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/UseProperAssertMethodsAnalyzerTests.cs @@ -2100,12 +2100,32 @@ public class MyTestClass public void MyTestMethod() { var list = new List(); - Assert.AreNotEqual(0, list.Count); + {|#0:Assert.AreNotEqual(0, list.Count)|}; } } """; - await VerifyCS.VerifyAnalyzerAsync(code); + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var list = new List(); + Assert.IsNotEmpty(list); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + // /0/Test0.cs(13,9): info MSTEST0037: Use 'Assert.IsNotEmpty' instead of 'Assert.AreNotEqual' + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "AreNotEqual"), + fixedCode); } [TestMethod] @@ -4135,4 +4155,778 @@ await VerifyCS.VerifyCodeFixAsync( } #endregion + + #region LINQ enumerable.Any(), enumerable.Count(), and Contains with comparer + + [TestMethod] + public async Task WhenAssertIsTrueWithEnumerableAnyNoPredicate() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(items.Any())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.IsNotEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithEnumerableAnyNoPredicate() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsFalse(items.Any())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithEnumerableCountGreaterThanZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(items.Count() > 0)|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.IsNotEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithEnumerableCountGreaterThanZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsFalse(items.Count() > 0)|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithEnumerableCountNotEqualToZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + {|#0:Assert.IsTrue(items.Count() != 0)|}; + {|#1:Assert.IsTrue(0 != items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.IsNotEmpty(items); + Assert.IsNotEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "IsTrue"), + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsNotEmpty", "IsTrue"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithEnumerableCountEqualsZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsTrue(items.Count() == 0)|}; + {|#1:Assert.IsTrue(0 == items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "IsTrue"), + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsEmpty", "IsTrue"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreEqualWithEnumerableCountZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.AreEqual(0, items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "AreEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreEqualWithEnumerableCountNonZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + {|#0:Assert.AreEqual(3, items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.HasCount(3, items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("HasCount", "AreEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreNotEqualWithEnumerableCountZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + {|#0:Assert.AreNotEqual(0, items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.IsNotEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "AreNotEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreEqualWithArrayLengthZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] {}; + {|#0:Assert.AreEqual(0, array.Length)|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] {}; + Assert.IsEmpty(array); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "AreEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreEqualWithArrayLengthNonZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] { 1, 2, 3 }; + {|#0:Assert.AreEqual(3, array.Length)|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] { 1, 2, 3 }; + Assert.HasCount(3, array); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("HasCount", "AreEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreNotEqualWithArrayLengthZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] { 1, 2, 3 }; + {|#0:Assert.AreNotEqual(0, array.Length)|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + var array = new int[] { 1, 2, 3 }; + Assert.IsNotEmpty(array); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "AreNotEqual"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsTrueWithLinqContainsAndComparer() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + var comparer = StringComparer.OrdinalIgnoreCase; + {|#0:Assert.IsTrue(items.Contains("B", comparer))|}; + } + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + var comparer = StringComparer.OrdinalIgnoreCase; + Assert.Contains("B", items, comparer); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("Contains", "IsTrue"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithLinqContainsAndComparer() + { + string code = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + var comparer = StringComparer.OrdinalIgnoreCase; + {|#0:Assert.IsFalse(items.Contains("X", comparer))|}; + } + } + """; + + string fixedCode = """ + using System; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { "a", "b", "c" }; + var comparer = StringComparer.OrdinalIgnoreCase; + Assert.DoesNotContain("X", items, comparer); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("DoesNotContain", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithEnumerableCountNotEqualToZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsFalse(items.Count() != 0)|}; + {|#1:Assert.IsFalse(0 != items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "IsFalse"), + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsEmpty", "IsFalse"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithEnumerableCountEqualsZero() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsFalse(items.Count() == 0)|}; + {|#1:Assert.IsFalse(0 == items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsNotEmpty(items); + Assert.IsNotEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + [ + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsNotEmpty", "IsFalse"), + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(1).WithArguments("IsNotEmpty", "IsFalse"), + ], + fixedCode); + } + + [TestMethod] + public async Task WhenAssertIsFalseWithZeroLessThanEnumerableCount() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + {|#0:Assert.IsFalse(0 < items.Count())|}; + } + } + """; + + string fixedCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List(); + Assert.IsEmpty(items); + } + } + """; + + await VerifyCS.VerifyCodeFixAsync( + code, + VerifyCS.DiagnosticIgnoringAdditionalLocations().WithLocation(0).WithArguments("IsEmpty", "IsFalse"), + fixedCode); + } + + [TestMethod] + public async Task WhenAssertAreNotEqualWithNonZeroEnumerableCount_NoDiagnostic() + { + string code = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System.Collections.Generic; + using System.Linq; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + IEnumerable items = new List { 1, 2, 3 }; + Assert.AreNotEqual(5, items.Count()); + } + } + """; + + await VerifyCS.VerifyAnalyzerAsync(code); + } + + #endregion }