1 //          Copyright Basile Burg 2017.
2 // Distributed under the Boost Software License, Version 1.0.
3 //    (See accompanying file LICENSE_1_0.txt or copy at
//          http://www.boost.org/LICENSE_1_0.txt)
5 module dscanner.analysis.vcall_in_ctor;
6 
7 import dscanner.analysis.base;
8 import dscanner.utils;
9 import dparse.ast, dparse.lexer;
10 import std.algorithm.searching : canFind;
11 import std.range: retro;
12 
/**
 * Checks virtual calls from the constructor to methods defined in the same class.
 *
 * When not used carefully, virtual calls from constructors can lead to a call
 * in a derived instance that's not yet constructed.
 *
 * The analysis is syntactic. For every class it collects the bare identifiers
 * called from its constructors and the names of the methods it considers
 * virtual. Once the class has been visited, it reports each recorded call
 * whose token matches one of those method names.
 */
final class VcallCtorChecker : BaseAnalyzer
{
    alias visit = BaseAnalyzer.visit;

private:

    enum string KEY = "dscanner.vcall_ctor";
    enum string MSG = "a virtual call inside a constructor may lead to"
        ~ " unexpected results in the derived classes";

    // what's called in the ctor, one list per class currently being visited
    Token[][] _ctorCalls;
    // the virtual methods, one list per class currently being visited
    Token[][] _virtualMethods;

    // The problem only happens in classes
    bool[] _inClass = [false];
    // The problem only happens in __ctor
    bool[] _inCtor = [false];
    // The problem only happens with call to virtual methods
    bool[] _isVirtual = [true];
    // Functions declared inside another function body are never virtual
    bool[] _isNestedFun = [false];
    // The problem only happens in derived classes that override
    bool[] _isFinal = [false];

    void pushVirtual(bool value)
    {
        _isVirtual ~= value;
    }

    // Also opens a fresh slot in `_ctorCalls` and `_virtualMethods`.
    void pushInClass(bool value)
    {
        _inClass ~= value;
        _ctorCalls.length += 1;
        _virtualMethods.length += 1;
    }

    void pushInCtor(bool value)
    {
        _inCtor ~= value;
    }

    void pushNestedFunc(bool value)
    {
        _isNestedFun ~= value;
    }

    void pushIsFinal(bool value)
    {
        _isFinal ~= value;
    }

    void popVirtual()
    {
        _isVirtual.length -= 1;
    }

    // Drops the slot opened by the matching `pushInClass`.
    void popInClass()
    {
        _inClass.length -= 1;
        _ctorCalls.length -= 1;
        _virtualMethods.length -= 1;
    }

    void popInCtor()
    {
        _inCtor.length -= 1;
    }

    void popNestedFunc()
    {
        _isNestedFun.length -= 1;
    }

    void popIsFinal()
    {
        _isFinal.length -= 1;
    }

    // Changes the default virtuality of the current scope (used for "<attr>:").
    void overwriteVirtual(bool value)
    {
        _isVirtual[$-1] = value;
    }

    bool isVirtual()
    {
        return _isVirtual[$-1];
    }

    bool isInClass()
    {
        return _inClass[$-1];
    }

    bool isInCtor()
    {
        return _inCtor[$-1];
    }

    bool isFinal()
    {
        return _isFinal[$-1];
    }

    bool isInNestedFunc()
    {
        return _isNestedFun[$-1];
    }

    // Reports every ctor call of the innermost class whose token equals the
    // name of one of that class' virtual methods (at most one report per call).
    void check()
    {
        foreach (call; _ctorCalls[$-1])
        {
            foreach (vm; _virtualMethods[$-1])
            {
                if (call == vm)
                {
                    addErrorMessage(call.line, call.column, KEY, MSG);
                    break;
                }
            }
        }
    }

public:

    ///
    this(string fileName, bool skipTests = false)
    {
        super(fileName, null, skipTests);
    }

    override void visit(const(ClassDeclaration) decl)
    {
        pushVirtual(true);
        pushInClass(true);
        pushNestedFunc(false);
        // a class declared inside a constructor body is not itself a constructor
        pushInCtor(false);
        decl.accept(this);
        check();
        popInCtor();
        popNestedFunc();
        popInClass();
        popVirtual();
    }

    override void visit(const(Constructor) ctor)
    {
        pushInCtor(isInClass);
        // functions declared in the ctor body are nested, hence never virtual
        pushNestedFunc(true);
        ctor.accept(this);
        popNestedFunc();
        popInCtor();
    }

    override void visit(const(Declaration) d)
    {
        // "<attribute>:" changes the default for the following declarations.
        // Only visibility, `static` and `final` affect virtuality; other
        // attributes (`@safe:`, `nothrow:`, ...) leave it untouched.
        if (d.attributeDeclaration && d.attributeDeclaration.attribute)
        {
            const tp = d.attributeDeclaration.attribute.attribute.type;
            if (isProtection(tp))
                overwriteVirtual(tp != tok!"private" && tp != tok!"package");
            else if (tp == tok!"static" || tp == tok!"final")
                overwriteVirtual(false);
        }

        // "protection {}" and "protection <decl>": the last visibility
        // attribute wins, and `static` always makes the declaration non-virtual.
        bool pushedVirtual;
        scope(exit) if (pushedVirtual)
            popVirtual();

        const bool hasStatic = d.attributes.canFind!(a => a.attribute.type == tok!"static");
        const bool hasFinal = d.attributes.canFind!(a => a.attribute.type == tok!"final");

        foreach (attr; d.attributes.retro)
        {
            if (!hasStatic &&
                (attr.attribute == tok!"public" || attr.attribute == tok!"protected"))
            {
                pushVirtual(true);
                pushedVirtual = true;
                break;
            }
            else if (hasStatic || attr.attribute == tok!"private" || attr.attribute == tok!"package")
            {
                pushVirtual(false);
                pushedVirtual = true;
                break;
            }
        }

        // final class... final function
        const bool pushedFinal = hasFinal
            && (d.classDeclaration !is null || d.functionDeclaration !is null);
        if (pushedFinal)
            pushIsFinal(true);
        scope(exit) if (pushedFinal)
            popIsFinal();

        d.accept(this);
    }

    override void visit(const(FunctionCallExpression) exp)
    {
        // nested function are not virtual
        pushNestedFunc(true);
        exp.accept(this);
        popNestedFunc();
    }

    override void visit(const(UnaryExpression) exp)
    {
        if (isInCtor)
        {
            // get function identifier for a call, only for this member (so no ident chain)
            if (const IdentifierOrTemplateInstance iot = safeAccess(exp)
                .functionCallExpression.unaryExpression.primaryExpression.identifierOrTemplateInstance)
            {
                const Token t = iot.identifier;
                if (t != tok!"")
                    _ctorCalls[$-1] ~= t;
            }
        }
        exp.accept(this);
    }

    override void visit(const(FunctionDeclaration) d)
    {
        if (isInClass() && !isInNestedFunc() && !isFinal() && !d.templateParameters)
        {
            const bool hasStatic = d.attributes.canFind!(a => a.attribute.type == tok!"static");

            // Start from the enclosing scope's virtuality; the last visibility
            // attribute attached to this declaration (or `static`) overrides it,
            // using the same rule as visit(const(Declaration)).
            bool virtualHere = isVirtual;
            foreach (attr; d.attributes.retro)
            {
                if (!hasStatic &&
                    (attr.attribute == tok!"public" || attr.attribute == tok!"protected"))
                {
                    virtualHere = true;
                    break;
                }
                else if (hasStatic || attr.attribute == tok!"private" || attr.attribute == tok!"package")
                {
                    virtualHere = false;
                    break;
                }
            }

            if (virtualHere)
                _virtualMethods[$-1] ~= d.name;
        }

        // anything declared inside this function's body is a nested function
        pushNestedFunc(true);
        d.accept(this);
        popNestedFunc();
    }
}
280 
unittest
{
    import dscanner.analysis.config : StaticAnalysisConfig, Check, disabledConfig;
    import dscanner.analysis.helpers : assertAnalyzerWarnings;
    import std.format : format;
    import std.stdio : writeln;

    StaticAnalysisConfig sac = disabledConfig();
    sac.vcall_in_ctor = Check.enabled;

    // cases that must produce warnings
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();} // [warn]: %s
            private:
            public
            void foo(){}

        }
    }}.format(VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo(); // [warn]: %s
                foo(); // [warn]: %s
                bar();
            }
            private: void bar();
            public{void foo(){}}
        }
    }}.format(VcallCtorChecker.MSG, VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo();
                bar(); // [warn]: %s
            }
            private: public void bar();
            public private {void foo(){}}
        }
    }}.format(VcallCtorChecker.MSG), sac);

    // cases that must not produce any warning
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private {void foo(){}}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private public protected private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final private public protected void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        final class Bar
        {
            public:
            this(){foo();}
            void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            public:
            this(){foo();}
            void foo(T)(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Foo
        {
            static void nonVirtual();
            this(){nonVirtual();}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Foo
        {
            package void nonVirtual();
            this(){nonVirtual();}
        }
    }, sac);

    writeln("Unittest for VcallCtorChecker passed");
}
408