1 //          Copyright Basile Burg 2017.
2 // Distributed under the Boost Software License, Version 1.0.
//    (See accompanying file LICENSE_1_0.txt or copy at
//          http://www.boost.org/LICENSE_1_0.txt)
5 module dscanner.analysis.vcall_in_ctor;
6 
7 import dscanner.analysis.base;
8 import dscanner.utils;
9 import dparse.ast, dparse.lexer;
10 import std.algorithm.searching : canFind;
11 import std.range: retro;
12 
/**
 * Checks virtual calls from the constructor to methods defined in the same class.
 *
 * When not used carefully, virtual calls from constructors can lead to a call
 * in a derived instance that's not yet constructed.
 *
 * The check is name-based. For each class it collects:
 *   - the callees of calls made inside its constructors, only when the
 *     callee is a bare identifier (e.g. `foo()`, not `this.foo()`);
 *   - the names of its methods considered virtual, i.e. methods that are not
 *     `private`, `package`, `static` or `final`, not templated, not declared
 *     in a `final` class and not nested inside a function body.
 * Once the class has been visited, every collected call that compares equal
 * to a collected method name is reported (once per call).
 */
final class VcallCtorChecker : BaseAnalyzer
{
    alias visit = BaseAnalyzer.visit;

    mixin AnalyzerInfo!"vcall_in_ctor";

private:

    enum string KEY = "dscanner.vcall_ctor";
    enum string MSG = "a virtual call inside a constructor may lead to"
        ~ " unexpected results in the derived classes";

    // Callee identifiers seen inside constructors, one list per class level.
    Token[][] _ctorCalls;
    // Names of the methods considered virtual, one list per class level.
    Token[][] _virtualMethods;

    // The states below are stacks: the last element is the current state,
    // the initial element is the state before any class is entered.

    // Whether the visitor is inside a class.
    bool[] _inClass = [false];
    // Whether the visitor is inside a constructor of the current class.
    bool[] _inCtor = [false];
    // Whether methods declared at this point default to virtual
    // (changed by protection labels, blocks and attributes).
    bool[] _isVirtual = [true];
    // Whether function declarations met here are nested functions
    // (declared in a function body or a call expression), never virtual.
    bool[] _isNestedFun = [false];
    // Whether the current declaration is under a `final` attribute.
    bool[] _isFinal = [false];

    void pushVirtual(bool value)
    {
        _isVirtual ~= value;
    }

    // Entering a class also opens a new level for the collected names.
    void pushInClass(bool value)
    {
        _inClass ~= value;
        _ctorCalls.length += 1;
        _virtualMethods.length += 1;
    }

    void pushInCtor(bool value)
    {
        _inCtor ~= value;
    }

    void pushNestedFunc(bool value)
    {
        _isNestedFun ~= value;
    }

    void pushIsFinal(bool value)
    {
        _isFinal ~= value;
    }

    void popVirtual()
    {
        _isVirtual.length -= 1;
    }

    // Leaving a class discards the names collected for it.
    void popInClass()
    {
        _inClass.length -= 1;
        _ctorCalls.length -= 1;
        _virtualMethods.length -= 1;
    }

    void popInCtor()
    {
        _inCtor.length -= 1;
    }

    void popNestedFunc()
    {
        _isNestedFun.length -= 1;
    }

    void popIsFinal()
    {
        _isFinal.length -= 1;
    }

    // Replaces the current virtual default without opening a new scope
    // (used for "<protection>:" labels).
    void overwriteVirtual(bool value)
    {
        _isVirtual[$-1] = value;
    }

    bool isVirtual()
    {
        return _isVirtual[$-1];
    }

    bool isInClass()
    {
        return _inClass[$-1];
    }

    bool isInCtor()
    {
        return _inCtor[$-1];
    }

    bool isFinal()
    {
        return _isFinal[$-1];
    }

    bool isInNestedFunc()
    {
        return _isNestedFun[$-1];
    }

    /// Reports every constructor call of the current class that matches a
    /// virtual method name of the same class.
    void check()
    {
        foreach (call; _ctorCalls[$-1])
        {
            foreach (vm; _virtualMethods[$-1])
            {
                if (call == vm)
                {
                    // one warning per call, even if several methods share the name
                    addErrorMessage(call.line, call.column, KEY, MSG);
                    break;
                }
            }
        }
    }

public:

    ///
    this(string fileName, bool skipTests = false)
    {
        super(fileName, null, skipTests);
    }

    /// Opens a fresh analysis scope for the class, then reports its matches.
    override void visit(const(ClassDeclaration) decl)
    {
        pushVirtual(true);
        pushInClass(true);
        pushNestedFunc(false);
        // a class declared inside a constructor is not itself "in a ctor"
        pushInCtor(false);
        decl.accept(this);
        check();
        popInCtor();
        popVirtual();
        popInClass();
        popNestedFunc();
    }

    /// Marks the constructor body so that calls made in it are collected.
    override void visit(const(Constructor) ctor)
    {
        pushInCtor(isInClass);
        // functions declared inside the constructor body are nested, not methods
        pushNestedFunc(true);
        ctor.accept(this);
        popNestedFunc();
        popInCtor();
    }

    /**
     * Tracks the attributes that decide whether the following methods are
     * virtual, and whether a class or function declaration is `final`.
     */
    override void visit(const(Declaration) d)
    {
        // "<protection>:" label: changes the default for the rest of the scope.
        // `private` and `package` members are never virtual.
        // NOTE(review): any non-protection label (e.g. `static:`, but also
        // `nothrow:` or `abstract:`) makes the following methods non-virtual
        // here; confirm this is intended.
        if (d.attributeDeclaration && d.attributeDeclaration.attribute)
        {
            const tp = d.attributeDeclaration.attribute.attribute.type;
            overwriteVirtual(isProtection(tp)
                && tp != tok!"private" && tp != tok!"package");
        }

        // "protection {}" or "protection decl": applies to this declaration
        // only; the last protection attribute wins and `static` forces
        // non-virtual.
        bool pop;
        scope(exit) if (pop)
            popVirtual();

        const bool hasAttribs = d.attributes !is null;
        const bool hasStatic = hasAttribs
            && d.attributes.canFind!(a => a.attribute.type == tok!"static");
        const bool hasFinal = hasAttribs
            && d.attributes.canFind!(a => a.attribute.type == tok!"final");

        if (d.attributes)
        {
            foreach (attr; d.attributes.retro)
            {
                if (!hasStatic &&
                   (attr.attribute == tok!"public" || attr.attribute == tok!"protected"))
                {
                    pushVirtual(true);
                    pop = true;
                    break;
                }
                else if (hasStatic || attr.attribute == tok!"private" || attr.attribute == tok!"package")
                {
                    pushVirtual(false);
                    pop = true;
                    break;
                }
            }
        }

        // "final class ..." / "final void foo()": nothing below is recorded
        // as a virtual method.
        const bool finalDecl = hasFinal
            && (d.classDeclaration !is null || d.functionDeclaration !is null);
        if (finalDecl)
            pushIsFinal(true);
        scope(exit) if (finalDecl)
            popIsFinal();

        d.accept(this);
    }

    /// Function declarations met inside a call expression are nested ones.
    override void visit(const(FunctionCallExpression) exp)
    {
        // nested function are not virtual
        pushNestedFunc(true);
        exp.accept(this);
        popNestedFunc();
    }

    /// Collects the callee of a call made while inside a constructor.
    override void visit(const(UnaryExpression) exp)
    {
        if (isInCtor)
        {
            // get function identifier for a call, only for this member (so no ident chain)
            if (const IdentifierOrTemplateInstance iot = safeAccess(exp)
                .functionCallExpression.unaryExpression.primaryExpression.identifierOrTemplateInstance)
            {
                const Token t = iot.identifier;
                if (t != tok!"")
                    _ctorCalls[$-1] ~= t;
            }
        }
        exp.accept(this);
    }

    /// Records the method name when the method is considered virtual.
    override void visit(const(FunctionDeclaration) d)
    {
        if (isInClass() && !isInNestedFunc() && !isFinal() && !d.templateParameters)
        {
            // attributes attached to the function itself may override the
            // current default given by isVirtual
            bool virtualOnce;    // made virtual by public/protected
            bool notVirtualOnce; // made non-virtual by private/package/static

            const bool hasStatic = d.attributes !is null
                && d.attributes.canFind!(a => a.attribute.type == tok!"static");

            // handle "private", "public"... for this declaration
            if (d.attributes)
            {
                foreach (attr; d.attributes.retro)
                {
                    if (!hasStatic &&
                       (attr.attribute == tok!"public" || attr.attribute == tok!"protected"))
                    {
                        if (!isVirtual)
                        {
                            virtualOnce = true;
                            break;
                        }
                    }
                    else if (hasStatic || attr.attribute == tok!"private" || attr.attribute == tok!"package")
                    {
                        if (isVirtual)
                        {
                            notVirtualOnce = true;
                            break;
                        }
                    }
                }
            }

            // virtual by default unless cancelled, or non-virtual by default
            // unless explicitly made public/protected
            if (isVirtual ? !notVirtualOnce : virtualOnce)
                _virtualMethods[$-1] ~= d.name;
        }

        // functions declared inside this function's body are nested, not methods
        pushNestedFunc(true);
        d.accept(this);
        popNestedFunc();
    }
}
282 
unittest
{
    import dscanner.analysis.config : StaticAnalysisConfig, Check, disabledConfig;
    import dscanner.analysis.helpers : assertAnalyzerWarnings;
    import std.format : format;
    import std.stdio : writeln;

    // Only the vcall_in_ctor check is enabled; in each snippet the lines
    // carrying a `// [warn]:` marker are the ones expected to be flagged.
    StaticAnalysisConfig sac = disabledConfig();
    sac.vcall_in_ctor = Check.enabled;

    // fails
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();} // [warn]: %s
            private:
            public
            void foo(){}

        }
    }c.format(VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo(); // [warn]: %s
                foo(); // [warn]: %s
                bar();
            }
            private: void bar();
            public{void foo(){}}
        }
    }c.format(VcallCtorChecker.MSG, VcallCtorChecker.MSG), sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this()
            {
                foo();
                bar(); // [warn]: %s
            }
            private: public void bar();
            public private {void foo(){}}
        }
    }c.format(VcallCtorChecker.MSG), sac);

    // passes
    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private {void foo(){}}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            private public protected private void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final private public protected void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            this(){foo();}
            final void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        final class Bar
        {
            public:
            this(){foo();}
            void foo(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Bar
        {
            public:
            this(){foo();}
            void foo(T)(){}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Foo
        {
            static void nonVirtual();
            this(){nonVirtual();}
        }
    }, sac);

    assertAnalyzerWarnings(q{
        class Foo
        {
            package void nonVirtual();
            this(){nonVirtual();}
        }
    }, sac);

    writeln("Unittest for VcallCtorChecker passed");
}
410