1 //          Copyright Basile Burg 2017.
2 // Distributed under the Boost Software License, Version 1.0.
3 //    (See accompanying file LICENSE_1_0.txt or copy at
//          http://www.boost.org/LICENSE_1_0.txt)
5 module dscanner.analysis.vcall_in_ctor;
6 
7 import dscanner.analysis.base;
8 import dparse.ast, dparse.lexer;
9 import std.algorithm: among;
10 import std.algorithm.iteration : filter;
11 import std.algorithm.searching : find;
12 import std.range.primitives : empty;
13 import std.range: retro;
14 
/**
 * Checks virtual calls from the constructor to methods defined in the same class.
 *
 * When not used carefully, virtual calls from constructors can lead to a call
 * in a derived instance that's not yet constructed.
 *
 * A method is not considered virtual when it is `private`, `package`,
 * `static`, `final`, a template, or a member of a `final` class. Functions
 * declared in a struct, union or interface nested in a class are not methods
 * of that class and are ignored.
 */
class VcallCtorChecker : BaseAnalyzer
{
    alias visit = BaseAnalyzer.visit;

private:

    enum string KEY = "dscanner.vcall_ctor";
    enum string MSG = "a virtual call inside a constructor may lead to"
        ~ " unexpected results in the derived classes";

    // Names called from the constructors, one list per enclosing aggregate.
    Token[][] _ctorCalls;
    // Names of the virtual methods, one list per enclosing aggregate.
    Token[][] _virtualMethods;

    // The problem only happens in classes
    bool[] _inClass = [false];
    // The problem only happens in __ctor
    bool[] _inCtor = [false];
    // The problem only happens with calls to virtual methods: whether the
    // protection / storage attributes in effect make a method virtual.
    bool[] _isVirtual = [true];
    // Set while visiting a function call expression; function declarations
    // found there are not recorded as methods.
    bool[] _isNestedFun = [false];
    // The problem only happens in derived classes that override: set while
    // inside a final class or a final function.
    bool[] _isFinal = [false];

    void pushVirtual(bool value)
    {
        _isVirtual ~= value;
    }

    // Entering an aggregate also opens new call / method lists for it.
    void pushInClass(bool value)
    {
        _inClass ~= value;
        _ctorCalls.length += 1;
        _virtualMethods.length += 1;
    }

    void pushInCtor(bool value)
    {
        _inCtor ~= value;
    }

    void pushNestedFunc(bool value)
    {
        _isNestedFun ~= value;
    }

    void pushIsFinal(bool value)
    {
        _isFinal ~= value;
    }

    void popVirtual()
    {
        _isVirtual.length -= 1;
    }

    void popInClass()
    {
        _inClass.length -= 1;
        _ctorCalls.length -= 1;
        _virtualMethods.length -= 1;
    }

    void popInCtor()
    {
        _inCtor.length -= 1;
    }

    void popNestedFunc()
    {
        _isNestedFun.length -= 1;
    }

    void popIsFinal()
    {
        _isFinal.length -= 1;
    }

    // Used by "<attribute>:", which applies to the rest of the current scope.
    void overwriteVirtual(bool value)
    {
        _isVirtual[$-1] = value;
    }

    bool isVirtual()
    {
        return _isVirtual[$-1];
    }

    bool isInClass()
    {
        return _inClass[$-1];
    }

    bool isInCtor()
    {
        return _inCtor[$-1];
    }

    bool isFinal()
    {
        return _isFinal[$-1];
    }

    bool isInNestedFunc()
    {
        return _isNestedFun[$-1];
    }

    /// Returns: `true` if `attribs` contains an attribute whose keyword is `type`.
    static bool containsAttribute(const(Attribute)[] attribs, IdType type)
    {
        foreach (attr; attribs)
        {
            if (attr.attribute.type == type)
                return true;
        }
        return false;
    }

    /**
     * Returns: `true` if member functions with the protection `tp` are virtual.
     * In D, `private` and `package` member functions are never virtual.
     */
    static bool isVirtualProtection(IdType tp)
    {
        return tp.among(tok!"public", tok!"protected", tok!"export") != 0;
    }

    /**
     * Computes the virtuality that a list of attributes gives to the
     * declaration(s) it annotates.
     *
     * Params:
     *     attribs = the attributes of a declaration, possibly empty
     *     virt = set to `true` when the attributes make the declaration virtual
     * Returns: `true` if the attributes decide the virtuality, `false` if it is
     *     inherited from the enclosing scope (`virt` is then `false`).
     */
    static bool resolveVirtuality(const(Attribute)[] attribs, out bool virt)
    {
        // static and final members are never virtual, whatever their protection
        if (containsAttribute(attribs, tok!"static") || containsAttribute(attribs, tok!"final"))
        {
            virt = false;
            return true;
        }
        // the last protection attribute wins, e.g. "public private" is private
        foreach (attr; attribs.retro)
        {
            const tp = attr.attribute.type;
            if (isProtection(tp))
            {
                virt = isVirtualProtection(tp);
                return true;
            }
        }
        return false;
    }

    /**
     * Visits a struct, union or interface: its members are not methods of an
     * enclosing class and its constructors are not class constructors.
     */
    void visitNonClassAggregate(T)(const(T) decl)
    {
        pushInClass(false);
        pushInCtor(false);
        decl.accept(this);
        popInCtor();
        popInClass();
    }

    /// Reports each constructor call that targets a virtual method of the
    /// current class (at most one message per call).
    void check()
    {
        foreach (call; _ctorCalls[$-1])
        {
            foreach (vm; _virtualMethods[$-1])
            {
                if (call == vm)
                {
                    addErrorMessage(call.line, call.column, KEY, MSG);
                    break;
                }
            }
        }
    }

public:

    ///
    this(string fileName, bool skipTests = false)
    {
        super(fileName, null, skipTests);
    }

    override void visit(const(ClassDeclaration) decl)
    {
        pushVirtual(true);
        pushInClass(true);
        // a constructor enclosing this class does not extend into it: only
        // the class's own constructors collect calls
        pushInCtor(false);
        pushNestedFunc(false);
        decl.accept(this);
        check();
        popVirtual();
        popInClass();
        popInCtor();
        popNestedFunc();
    }

    override void visit(const(StructDeclaration) decl)
    {
        visitNonClassAggregate(decl);
    }

    override void visit(const(UnionDeclaration) decl)
    {
        visitNonClassAggregate(decl);
    }

    override void visit(const(InterfaceDeclaration) decl)
    {
        visitNonClassAggregate(decl);
    }

    override void visit(const(Constructor) ctor)
    {
        pushInCtor(isInClass);
        ctor.accept(this);
        popInCtor();
    }

    override void visit(const(Declaration) d)
    {
        // "<attribute>:" applies to the rest of the enclosing scope
        if (d.attributeDeclaration && d.attributeDeclaration.attribute)
        {
            const tp = d.attributeDeclaration.attribute.attribute.type;
            if (isProtection(tp))
                overwriteVirtual(isVirtualProtection(tp));
            else if (tp == tok!"static" || tp == tok!"final")
                overwriteVirtual(false);
            // other attributes ("nothrow:", "@safe:", ...) leave it unchanged
        }

        // "<attributes> declaration" and "<attributes> { ... }" only apply to
        // what they annotate
        bool virt;
        const bool pushedVirtual = resolveVirtuality(d.attributes, virt);
        if (pushedVirtual)
            pushVirtual(virt);
        scope(exit) if (pushedVirtual)
            popVirtual();

        // final class... final function
        const bool pushedFinal = (d.classDeclaration || d.functionDeclaration)
            && containsAttribute(d.attributes, tok!"final");
        if (pushedFinal)
            pushIsFinal(true);
        scope(exit) if (pushedFinal)
            popIsFinal();

        d.accept(this);
    }

    override void visit(const(FunctionCallExpression) exp)
    {
        // nested function are not virtual
        pushNestedFunc(true);
        exp.accept(this);
        popNestedFunc();
    }

    override void visit(const(UnaryExpression) exp)
    {
        // get function identifier for a call, only for this member (so no ident chain)
        if (isInCtor && exp.functionCallExpression)
        {
            const callee = exp.functionCallExpression.unaryExpression;
            if (callee && callee.primaryExpression
                && callee.primaryExpression.identifierOrTemplateInstance)
            {
                const Token t = callee.primaryExpression
                    .identifierOrTemplateInstance.identifier;
                if (t != tok!"")
                    _ctorCalls[$-1] ~= t;
            }
        }
        exp.accept(this);
    }

    override void visit(const(FunctionDeclaration) d)
    {
        // templates, members of final classes, final functions and functions
        // declared in call expressions are not recorded
        if (isInClass() && !isInNestedFunc() && !isFinal() && !d.templateParameters)
        {
            // attributes attached to the function itself take precedence over
            // the state inherited from the enclosing scope
            bool virt;
            if (!resolveVirtuality(d.attributes, virt))
                virt = isVirtual();
            if (virt)
                _virtualMethods[$-1] ~= d.name;
        }
        d.accept(this);
    }

    unittest
    {
        import dscanner.analysis.config : StaticAnalysisConfig, Check, disabledConfig;
        import dscanner.analysis.helpers : assertAnalyzerWarnings;
        import std.format : format;

        StaticAnalysisConfig sac = disabledConfig();
        sac.vcall_in_ctor = Check.enabled;

        // package and static member functions are never virtual
        assertAnalyzerWarnings(q{
            class Bar
            {
                this(){foo();bar();}
                package void foo(){}
                static void bar(){}
            }
        }, sac);

        // members of a nested struct are not methods of the class
        assertAnalyzerWarnings(q{
            class Bar
            {
                this(){foo();}
                private void foo(){}
                struct S {void foo(){}}
            }
        }, sac);

        // "nothrow:" does not change the virtuality of what follows
        assertAnalyzerWarnings(q{
            class Bar
            {
                this(){foo();} // [warn]: %s
                nothrow:
                void foo(){}
            }
        }c.format(VcallCtorChecker.MSG), sac);
    }
}
279 
unittest
{
    import dscanner.analysis.config : StaticAnalysisConfig, Check, disabledConfig;
    import dscanner.analysis.helpers : assertAnalyzerWarnings;
    import std.format : format;
    import std.stdio : writeln;

    // enable only the checker under test
    StaticAnalysisConfig sac = disabledConfig();
    sac.vcall_in_ctor = Check.enabled;

    // fails: every line marked "// [warn]:" must be reported
    // (the token strings use the `c` postfix: `}c.format`, not `}}.format`)
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();} // [warn]: %s
            private:
            public
            void foo(){}

        }
    }c.format(VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo(); // [warn]: %s
                foo(); // [warn]: %s
                bar();
            }
            private: void bar();
            public{void foo(){}}
        }
    }c.format(VcallCtorChecker.MSG, VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo();
                bar(); // [warn]: %s
            }
            private: public void bar();
            public private {void foo(){}}
        }
    }c.format(VcallCtorChecker.MSG), sac);

    // passes: no warning expected
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private {void foo(){}}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private public protected private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final private public protected void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        final class Bar
        {
            public:
            this(){foo();}
            void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            public:
            this(){foo();}
            void foo(T)(){}
        }
    }, sac);

    writeln("Unittest for VcallCtorChecker passed");
}
391