Skip to content

Commit ed321bf

Browse files
committed
wip4
1 parent a95a619 commit ed321bf

2 files changed

Lines changed: 123 additions & 45 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ private module Input3 implements InputSig3 {
286286
(exists(resolveTupleFieldExpr(_, _)) implies any())
287287
}
288288

289+
class BoolType extends DataType {
290+
BoolType() { this.getTypeItem() instanceof Builtins::Bool }
291+
}
292+
289293
class AstNode = Rust::AstNode;
290294

291295
TypeMention getTypeAnnotation(AstNode n) {
@@ -304,32 +308,60 @@ private module Input3 implements InputSig3 {
304308
result = n.(ShorthandSelfParameterMention)
305309
}
306310

311+
class Expr = Rust::Expr;
312+
313+
class ConditionalExpr extends AstNode, IfExpr {
314+
Expr getCondition() { result = super.getCondition() }
315+
316+
Expr getThen() { result = super.getThen() }
317+
318+
Expr getElse() { result = super.getElse() }
319+
}
320+
321+
class BinaryExpr extends AstNode, Rust::BinaryExpr {
322+
Expr getLeftOperand() { result = super.getLhs() }
323+
324+
Expr getRightOperand() { result = super.getRhs() }
325+
}
326+
327+
class LogicalAndExpr extends BinaryExpr, Rust::LogicalAndExpr { }
328+
329+
class LogicalOrExpr extends BinaryExpr, Rust::LogicalOrExpr { }
330+
331+
abstract class Assignment extends BinaryExpr { }
332+
333+
class AssignExpr extends Assignment, Rust::AssignmentExpr { }
334+
335+
class ParenExpr extends AstNode, Rust::ParenExpr {
336+
AstNode getExpr() { result = super.getExpr() }
337+
}
338+
307339
class Variable extends Rust::Variable {
308340
AstNode getDefiningNode() {
309341
result = this.getPat().getName() or
310342
result = this.getParameter().(SelfParam)
311343
}
312344

313-
AstNode getAnAccess() { result = super.getAnAccess() }
345+
Expr getAnAccess() { result = super.getAnAccess() }
314346
}
315347

316-
abstract class Assignment extends AstNode {
348+
abstract class LetDeclaration extends AstNode {
317349
abstract predicate isCoercionSite();
318350

319351
abstract AstNode getLeftOperand();
320352

321353
abstract AstNode getRightOperand();
322354
}
323355

324-
private class LetExprAssignment extends Assignment, LetExpr {
356+
private class LetExprLetDeclaration extends LetDeclaration, LetExpr {
325357
override predicate isCoercionSite() { not this.getPat() instanceof IdentPat }
326358

327359
override AstNode getLeftOperand() { result = this.getPat() }
328360

329361
override AstNode getRightOperand() { result = this.getScrutinee() }
330362
}
331363

332-
private class LetStmtAssignment extends Assignment, LetStmt {
364+
private class LetStmtLetDeclaration extends LetDeclaration, LetStmt {
333365
override predicate isCoercionSite() {
334366
this.hasTypeRepr() or
335367
not identLetStmt(this, _, _)
@@ -340,18 +372,6 @@ private module Input3 implements InputSig3 {
340372
override AstNode getRightOperand() { result = this.getInitializer() }
341373
}
342374

343-
private class AssignmentExprAssignment extends Assignment, AssignmentExpr {
344-
override predicate isCoercionSite() { any() }
345-
346-
override AstNode getLeftOperand() { result = this.getLhs() }
347-
348-
override AstNode getRightOperand() { result = this.getRhs() }
349-
}
350-
351-
class ParenExpr extends AstNode, Rust::ParenExpr {
352-
AstNode getExpr() { result = super.getExpr() }
353-
}
354-
355375
predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
356376
n1 =
357377
any(IdentPat ip |
@@ -824,8 +844,6 @@ private module CertainTypeInferenceInput {
824844
result = inferRefExprType(n) and
825845
path.isEmpty()
826846
or
827-
result = inferLogicalOperationType(n, path)
828-
or
829847
result = inferCertainStructExprType(n, path)
830848
or
831849
result = inferCertainStructPatType(n, path)
@@ -857,14 +875,6 @@ private module CertainTypeInferenceInput {
857875
}
858876
}
859877

860-
private Type inferLogicalOperationType(AstNode n, TypePath path) {
861-
exists(Builtins::Bool t, BinaryLogicalOperation be |
862-
n = [be, be.getLhs(), be.getRhs()] and
863-
path.isEmpty() and
864-
result = TDataType(t)
865-
)
866-
}
867-
868878
private Type inferAssignmentOperationType(AstNode n, TypePath path) {
869879
n instanceof AssignmentOperation and
870880
path.isEmpty() and

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,6 +2141,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21412141

21422142
/**
21432143
* Provides the input to `Make3`.
2144+
*
2145+
* TODO: Eventually align the AST signature with that of the shared CFG library.
21442146
*/
21452147
signature module InputSig3 {
21462148
/**
@@ -2149,6 +2151,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21492151
*/
21502152
default predicate cachedStageRevRef() { none() }
21512153

2154+
/** A boolean type. */
2155+
class BoolType extends Type;
2156+
21522157
/** An AST node. */
21532158
class AstNode {
21542159
/** Gets a textual representation of this AST node. */
@@ -2161,13 +2166,63 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21612166
/** Gets the type annotation that applies to `n`, if any. */
21622167
TypeMention getTypeAnnotation(AstNode n);
21632168

2169+
/** An expression. */
2170+
class Expr extends AstNode;
2171+
2172+
/** A ternary conditional expression. */
2173+
class ConditionalExpr extends Expr {
2174+
/** Gets the condition of this expression. */
2175+
Expr getCondition();
2176+
2177+
/** Gets the true branch of this expression. */
2178+
Expr getThen();
2179+
2180+
/** Gets the false branch of this expression. */
2181+
Expr getElse();
2182+
}
2183+
2184+
/** A binary expression. */
2185+
class BinaryExpr extends Expr {
2186+
/** Gets the left operand of this binary expression. */
2187+
Expr getLeftOperand();
2188+
2189+
/** Gets the right operand of this binary expression. */
2190+
Expr getRightOperand();
2191+
}
2192+
2193+
/** A short-circuiting logical AND expression. */
2194+
class LogicalAndExpr extends BinaryExpr;
2195+
2196+
/** A short-circuiting logical OR expression. */
2197+
class LogicalOrExpr extends BinaryExpr;
2198+
2199+
/**
2200+
* An assignment expression, either compound or simple.
2201+
*
2202+
* Examples:
2203+
*
2204+
* ```
2205+
* x = y
2206+
* sum += element
2207+
* ```
2208+
*/
2209+
class Assignment extends BinaryExpr;
2210+
2211+
/** A simple assignment expression, for example `x = y`. */
2212+
class AssignExpr extends Assignment;
2213+
2214+
/** A parenthesized expression. */
2215+
class ParenExpr extends AstNode {
2216+
AstNode getExpr();
2217+
}
2218+
21642219
/** A variable, for example a local variable or a field. */
21652220
class Variable {
21662221
/** Gets the AST node that defines this variable. */
21672222
AstNode getDefiningNode();
21682223

21692224
/** Gets an access to this variable. */
2170-
AstNode getAnAccess();
2225+
Expr getAnAccess();
21712226

21722227
/** Gets a textual representation of this element. */
21732228
string toString();
@@ -2177,28 +2232,22 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21772232
}
21782233

21792234
/**
2180-
* An assignment where type information can flow from one operand to the
2181-
* other.
2235+
* A `let` declaration, for example a local variable declaration.
21822236
*/
2183-
class Assignment extends AstNode {
2237+
class LetDeclaration extends AstNode {
21842238
/**
2185-
* Holds if this assignment is a coercion site, meaning that the type of the right
2239+
* Holds if this declaration is a coercion site, meaning that the type of the right
21862240
* operand may have to be coerced to the type of the left operand.
21872241
*/
21882242
predicate isCoercionSite();
21892243

2190-
/** Gets the left operand of this binary expression. */
2244+
/** Gets the left operand of this declaration. */
21912245
AstNode getLeftOperand();
21922246

2193-
/** Gets the right operand of this binary expression. */
2247+
/** Gets the right operand of this declaration. */
21942248
AstNode getRightOperand();
21952249
}
21962250

2197-
/** A parenthesized expression. */
2198-
class ParenExpr extends AstNode {
2199-
AstNode getExpr();
2200-
}
2201-
22022251
/**
22032252
* Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal.
22042253
*/
@@ -2249,10 +2298,10 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22492298
(
22502299
exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode())
22512300
or
2252-
exists(Assignment a |
2253-
not a.isCoercionSite() and
2254-
n1 = a.getLeftOperand() and
2255-
n2 = a.getRightOperand()
2301+
exists(LetDeclaration let |
2302+
not let.isCoercionSite() and
2303+
n1 = let.getLeftOperand() and
2304+
n2 = let.getRightOperand()
22562305
)
22572306
or
22582307
n1 = n2.(ParenExpr).getExpr()
@@ -2273,6 +2322,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22732322
)
22742323
}
22752324

2325+
private Type inferLogicalOperationType(AstNode n, TypePath path) {
2326+
(
2327+
exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or
2328+
exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) //or
2329+
// exists(LogicalNotExpr lne | n = [lne, lne.getOperand()])
2330+
) and
2331+
result instanceof BoolType and
2332+
path.isEmpty()
2333+
}
2334+
22762335
/** Gets the inferred certain type of `n` at `path`. */
22772336
cached
22782337
Type inferCertainType(AstNode n, TypePath path) {
@@ -2283,6 +2342,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22832342
or
22842343
result = inferCertainTypeInput(n, path)
22852344
or
2345+
result = inferLogicalOperationType(n, path)
2346+
or
22862347
infersCertainTypeAt(n, path, result.getATypeParameter())
22872348
}
22882349

@@ -2336,9 +2397,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23362397
or
23372398
path1.isEmpty() and
23382399
path2.isEmpty() and
2339-
exists(Assignment a |
2340-
a.getLeftOperand() = n1 and
2341-
a.getRightOperand() = n2
2400+
(
2401+
exists(Assignment a |
2402+
a.getLeftOperand() = n1 and
2403+
a.getRightOperand() = n2
2404+
)
2405+
or
2406+
exists(LetDeclaration let |
2407+
let.getLeftOperand() = n1 and
2408+
let.getRightOperand() = n2
2409+
)
23422410
)
23432411
or
23442412
typeEqualityInput(n1, path1, n2, path2)

0 commit comments

Comments
 (0)