Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
case UseProperAssertMethodsAnalyzer.CodeFixModeAddArgument:
createChangedDocument = ct => FixAssertMethodForAddArgumentModeAsync(context.Document, diagnostic.AdditionalLocations[0], diagnostic.AdditionalLocations[1], diagnostic.AdditionalLocations[2], root, simpleNameSyntax, properAssertMethodName, ct);
break;
case UseProperAssertMethodsAnalyzer.CodeFixModeAddTwoArguments:
createChangedDocument = ct => FixAssertMethodForAddTwoArgumentsModeAsync(context.Document, diagnostic.AdditionalLocations[0], diagnostic.AdditionalLocations[1], diagnostic.AdditionalLocations[2], diagnostic.AdditionalLocations[3], root, simpleNameSyntax, properAssertMethodName, ct);
break;
case UseProperAssertMethodsAnalyzer.CodeFixModeRemoveArgument:
createChangedDocument = ct => FixAssertMethodForRemoveArgumentModeAsync(context.Document, diagnostic.AdditionalLocations, root, simpleNameSyntax, properAssertMethodName, diagnostic.Properties.ContainsKey(UseProperAssertMethodsAnalyzer.NeedsNullableBooleanCastKey), ct);
break;
Expand Down Expand Up @@ -215,6 +218,98 @@ private static async Task<Document> FixAssertMethodForAddArgumentModeAsync(Docum
return editor.GetChangedDocument();
}

private static async Task<Document> FixAssertMethodForAddTwoArgumentsModeAsync(Document document, Location conditionLocation, Location firstArgLocation, Location secondArgLocation, Location thirdArgLocation, SyntaxNode root, SimpleNameSyntax simpleNameSyntax, string properAssertMethodName, CancellationToken cancellationToken)
{
// Handle Contains with comparer: Assert.IsTrue(enumerable.Contains(item, comparer)) -> Assert.Contains(item, enumerable, comparer)
if (root.FindNode(conditionLocation.SourceSpan) is not ArgumentSyntax conditionNode)
{
return document;
}

if (conditionNode.Parent is not ArgumentListSyntax argumentList)
{
return document;
}

// FindNode may return ArgumentSyntax (outermost for tied spans) when the expression
// is a direct child of an ArgumentSyntax in the inner invocation. Extract the expression.
if (!TryGetExpressionFromNode(root, firstArgLocation, out ExpressionSyntax? firstArgNode))
{
return document;
}

if (!TryGetExpressionFromNode(root, secondArgLocation, out ExpressionSyntax? secondArgNode))
{
return document;
}

if (!TryGetExpressionFromNode(root, thirdArgLocation, out ExpressionSyntax? thirdArgNode))
{
return document;
}

DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
FixInvocationMethodName(editor, simpleNameSyntax, properAssertMethodName);

int conditionIndex = argumentList.Arguments.IndexOf(conditionNode);

// Build the new arguments list: (firstArg, secondArg, thirdArg, ...remaining)
var newArguments = new List<ArgumentSyntax>();
var newSeparators = new List<SyntaxToken>();

newArguments.Add(SyntaxFactory.Argument(
firstArgNode
.WithLeadingTrivia(conditionNode.GetLeadingTrivia())
.WithoutTrailingTrivia()));

newSeparators.Add(SyntaxFactory.Token(SyntaxKind.CommaToken));
newArguments.Add(SyntaxFactory.Argument(
secondArgNode
.WithoutLeadingTrivia()
.WithoutTrailingTrivia()
.WithLeadingTrivia(SyntaxFactory.Space)));

newSeparators.Add(SyntaxFactory.Token(SyntaxKind.CommaToken));
newArguments.Add(SyntaxFactory.Argument(
thirdArgNode
.WithoutLeadingTrivia()
.WithoutTrailingTrivia()
.WithLeadingTrivia(SyntaxFactory.Space)));

// Add remaining arguments (e.g., message) with their original separators and trivia
var originalSeparators = argumentList.Arguments.GetSeparators().ToList();
for (int i = conditionIndex + 1; i < argumentList.Arguments.Count; i++)
{
if (i - 1 < originalSeparators.Count)
{
newSeparators.Add(originalSeparators[i - 1]);
}

newArguments.Add(argumentList.Arguments[i]);
}

ArgumentListSyntax newArgumentList = argumentList.WithArguments(SyntaxFactory.SeparatedList(newArguments, newSeparators));
editor.ReplaceNode(argumentList, newArgumentList);

return editor.GetChangedDocument();
}

// FindNode may return an ArgumentSyntax when the location span coincides with
// the span of an argument in the inner invocation (e.g. Contains(item, comparer)).
// We handle this by extracting the expression from the ArgumentSyntax.
private static bool TryGetExpressionFromNode(SyntaxNode root, Location location, [NotNullWhen(true)] out ExpressionSyntax? expression)
{
SyntaxNode? node = root.FindNode(location.SourceSpan);
expression = node switch
{
ArgumentSyntax argument => argument.Expression,
ExpressionSyntax expr => expr,
_ => null,
};

return expression is not null;
}

private static async Task<Document> FixAssertMethodForRemoveArgumentModeAsync(
Document document,
IReadOnlyList<Location> additionalLocations,
Expand Down
Loading