1 module dscanner.analysis.mismatched_args;
2 
3 import dscanner.analysis.base;
4 import dscanner.utils : safeAccess;
5 import dsymbol.scope_;
6 import dsymbol.symbol;
7 import dparse.ast;
8 import dparse.lexer : tok;
9 import dsymbol.builtin.names;
10 
/// Checks for mismatched argument and parameter names.
///
/// Reports a call argument that is a plain identifier whose name is the name
/// of a *different* parameter of the called function, e.g. `f(b, a)` for
/// `void f(int a, int b)`.
final class MismatchedArgumentCheck : BaseAnalyzer
{
	mixin AnalyzerInfo!"mismatched_args_check";

	///
	this(string fileName, const(Scope)* sc, bool skipTests = false)
	{
		super(fileName, sc, skipTests);
	}

	/**
	 * Compares the identifier arguments of a call with the parameter names of
	 * every symbol the callee resolves to.
	 *
	 * Nothing is reported if at least one candidate symbol yields no mismatch.
	 * Otherwise the mismatches of all candidates are reported, each at the
	 * position of the offending argument.
	 */
	override void visit(const FunctionCallExpression fce)
	{
		import std.typecons : scoped;
		import std.algorithm.iteration : map;
		import std.array : array;

		if (fce.arguments is null)
			return;
		auto argVisitor = scoped!ArgVisitor;
		argVisitor.visit(fce.arguments);
		const istring[] args = argVisitor.args;

		auto identVisitor = scoped!IdentVisitor;
		if (fce.unaryExpression !is null)
			identVisitor.visit(fce.unaryExpression);
		else if (fce.type !is null)
			identVisitor.visit(fce.type);

		// When no callee name was collected, look up the constructor symbol instead.
		const(DSymbol)*[] symbols = resolveSymbol(sc, identVisitor.names.length > 0
				? identVisitor.names : [CONSTRUCTOR_SYMBOL_NAME]);

		static struct ErrorMessage
		{
			size_t line;
			size_t column;
			string message;
		}

		ErrorMessage[] messages;

		foreach (sym; symbols)
		{
			// The cast is a hack because .array() confuses the compiler's overload
			// resolution code.
			const(istring)[] params = sym is null ? [] : sym.argNames[].map!(a => cast() a).array();
			// compareArgsToParams also reports no mismatches when the arity
			// differs, so such a candidate counts as a match as well.
			const ArgMismatch[] mismatches = compareArgsToParams(params, args);
			// One candidate that accepts the argument names as written is
			// enough to stay silent.
			if (mismatches.length == 0)
				return;
			foreach (ref const mm; mismatches)
			{
				// Report at the offending argument: `mm.argIndex` indexes the
				// argument list, whereas the position of `mm` inside
				// `mismatches` is unrelated to where the argument is.
				messages ~= ErrorMessage(argVisitor.lines[mm.argIndex],
						argVisitor.columns[mm.argIndex], createWarningFromMismatch(mm));
			}
		}

		foreach (m; messages)
			addErrorMessage(m.line, m.column, KEY, m.message);
	}

	alias visit = ASTVisitor.visit;

private:

	enum string KEY = "dscanner.confusing.argument_parameter_mismatch";
}
82 
/// Collects the names of the identifiers found in a call's callee expression.
/// `resolveSymbol` treats the collected names as a dotted symbol chain.
/// Argument lists and index values are deliberately not descended into.
final class IdentVisitor : ASTVisitor
{
	/// Interned names collected so far.
	istring[] names;

	override void visit(const IdentifierOrTemplateInstance ioti)
	{
		import dsymbol.string_interning : internString;

		// A template instance keeps its name on the nested TemplateInstance node.
		const nameToken = ioti.identifier == tok!""
			? ioti.templateInstance.identifier
			: ioti.identifier;
		names ~= internString(nameToken.text);
	}

	override void visit(const Arguments)
	{
		// Intentionally empty: identifiers used as call arguments are not part
		// of the callee's name.
	}

	override void visit(const IndexExpression ie)
	{
		// Only the indexed expression contributes; the index values are skipped.
		if (ie.unaryExpression is null)
			return;
		visit(ie.unaryExpression);
	}

	alias visit = ASTVisitor.visit;
}
109 
/// Records each argument of a call: its name, when the argument is a plain
/// identifier, and its source position.
///
/// `args`, `lines` and `columns` are kept parallel, so entry `i` of each
/// describes argument `i`. Arguments that are not plain identifiers get an
/// `istring.init` name and `size_t.max` as line and column.
final class ArgVisitor : ASTVisitor
{
	override void visit(const ArgumentList al)
	{
		foreach (a; al.items)
		{
			auto u = cast(UnaryExpression) a;
			if (u !is null)
				visit(u);
			else
				addUnnamed();
		}
	}

	override void visit(const UnaryExpression unary)
	{
		import dsymbol.string_interning : internString;

		auto iot = unary.safeAccess.primaryExpression.identifierOrTemplateInstance.unwrap;
		if (iot is null || iot.identifier == tok!"")
		{
			// Not a plain identifier (literal, member access, template
			// instance, ...). Still record an entry; otherwise the indices of
			// later arguments would shift and no longer line up with the
			// parameter list.
			addUnnamed();
			return;
		}
		immutable t = iot.identifier;
		lines ~= t.line;
		columns ~= t.column;
		args ~= internString(t.text);
	}

	alias visit = ASTVisitor.visit;

	size_t[] lines;
	size_t[] columns;
	istring[] args;

private:

	/// Appends a placeholder entry for an argument without a usable name.
	void addUnnamed()
	{
		args ~= istring.init;
		lines ~= size_t.max;
		columns ~= size_t.max;
	}
}
149 
/**
 * Resolves a symbol chain (as collected by `IdentVisitor`) in the given scope.
 *
 * Every symbol named `symbolChain[0]` is looked up in `sc`. Each candidate is
 * then walked through the remaining chain elements. Before each lookup,
 * variables, member variables and functions are replaced by their `type`.
 * A candidate whose walk fails is replaced by `null` in the result.
 *
 * Returns: the resolved candidates (individual entries may be `null`), or
 * `null` when `sc` is null, the chain is empty, or nothing matches the first
 * name.
 */
const(DSymbol)*[] resolveSymbol(const Scope* sc, const istring[] symbolChain)
{
	import std.array : empty;

	// Without a scope or a first name there is nothing to look up.
	if (sc is null || symbolChain.empty)
		return null;

	const(DSymbol)*[] matchingSymbols = sc.getSymbolsByName(symbolChain[0]);
	if (matchingSymbols.empty)
		return null;

	foreach (ref symbol; matchingSymbols)
	{
		foreach (part; symbolChain[1 .. $])
		{
			if (symbol is null)
				break;
			if (symbol.kind == CompletionKind.variableName
					|| symbol.kind == CompletionKind.memberVariableName
					|| symbol.kind == CompletionKind.functionName)
				symbol = symbol.type;
			if (symbol is null)
				break;
			auto p = symbol.getPartsByName(part);
			if (p.empty)
			{
				symbol = null;
				break;
			}
			symbol = p[0];
		}
	}
	return matchingSymbols;
}
182 
/// An argument whose name is the name of a parameter at a different position.
struct ArgMismatch
{
	/// Zero-based index of the offending argument.
	size_t argIndex;
	/// Zero-based index of the parameter that carries the argument's name.
	size_t paramIndex;
	/// The name shared by the argument and that parameter.
	string name;
}
189 
/**
 * Compares argument names with parameter names position by position.
 *
 * Params:
 *     params = parameter names of a candidate function
 *     args   = argument names; `null` entries are skipped
 * Returns: one `ArgMismatch` per (argument, parameter) pair where an
 *     argument's name equals the name of a parameter at another position.
 *     Empty when the two lists differ in length.
 */
immutable(ArgMismatch[]) compareArgsToParams(const istring[] params, const istring[] args) pure
{
	import std.exception : assumeUnique;

	ArgMismatch[] found;
	if (params.length == args.length)
	{
		for (size_t argIdx = 0; argIdx < args.length; ++argIdx)
		{
			const name = args[argIdx];
			// Unnamed arguments and arguments already in their own slot are fine.
			if (name is null || name == params[argIdx])
				continue;
			for (size_t paramIdx = 0; paramIdx < params.length; ++paramIdx)
			{
				if (params[paramIdx] == name)
					found ~= ArgMismatch(argIdx, paramIdx, name);
			}
		}
	}
	return assumeUnique(found);
}
207 
/// Builds the user-facing warning text for a mismatch, using one-based
/// argument and parameter positions.
string createWarningFromMismatch(const ArgMismatch mismatch) pure
{
	import std.format : format;

	immutable argNumber = mismatch.argIndex + 1;
	immutable paramNumber = mismatch.paramIndex + 1;
	return format("Argument %d is named '%s', but this is the name of parameter %d",
			argNumber, mismatch.name, paramNumber);
}
215 
unittest
{
	import dsymbol.string_interning : internString;
	import std.algorithm.iteration : map;
	import std.array : array;
	import std.conv : to;

	// Same names in the same positions: nothing to report.
	{
		istring[] args = ["a", "b", "c"].map!internString().array();
		istring[] params = ["a", "b", "c"].map!internString().array();
		immutable res = compareArgsToParams(params, args);
		assert(res == []);
	}

	// Two swapped arguments: both are reported.
	{
		istring[] args = ["a", "c", "b"].map!internString().array();
		istring[] params = ["a", "b", "c"].map!internString().array();
		immutable res = compareArgsToParams(params, args);
		assert(res == [ArgMismatch(1, 2, "c"), ArgMismatch(2, 1, "b")], to!string(res));
	}

	// Only arguments whose name matches some parameter are reported.
	{
		istring[] args = ["a", "c", "b"].map!internString().array();
		istring[] params = ["alpha", "bravo", "c"].map!internString().array();
		immutable res = compareArgsToParams(params, args);
		assert(res == [ArgMismatch(1, 2, "c")]);
	}

	// Unnamed parameters never match.
	{
		istring[] args = ["a", "b"].map!internString().array();
		istring[] params = [null, "b"].map!internString().array();
		immutable res = compareArgsToParams(params, args);
		assert(res == []);
	}

	// Differing arity: the comparison is skipped entirely.
	{
		istring[] args = ["b", "a"].map!internString().array();
		istring[] params = ["a", "b", "c"].map!internString().array();
		immutable res = compareArgsToParams(params, args);
		assert(res == [], to!string(res));
	}

	// The warning text uses one-based positions.
	assert(createWarningFromMismatch(ArgMismatch(1, 2, "c"))
			== "Argument 2 is named 'c', but this is the name of parameter 3");
}