Skip to content

Commit c337f51

Browse files
committed
Rust: Improve type inference for closures
1 parent ce1ce6a commit c337f51

File tree

5 files changed

+897
-149
lines changed

5 files changed

+897
-149
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 98 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,14 @@ private predicate isPanicMacroCall(MacroExpr me) {
408408
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
409409
}
410410

411+
// Due to "binding modes" the type of the pattern is not necessarily the
412+
// same as the type of the initializer. The pattern being an identifier
413+
// pattern is sufficient to ensure that this is not the case.
414+
private predicate identLetStmt(LetStmt let, IdentPat lhs, Expr rhs) {
415+
let.getPat() = lhs and
416+
let.getInitializer() = rhs
417+
}
418+
411419
/** Module for inferring certain type information. */
412420
module CertainTypeInference {
413421
pragma[nomagic]
@@ -485,11 +493,7 @@ module CertainTypeInference {
485493
// is not a certain type equality.
486494
exists(LetStmt let |
487495
not let.hasTypeRepr() and
488-
// Due to "binding modes" the type of the pattern is not necessarily the
489-
// same as the type of the initializer. The pattern being an identifier
490-
// pattern is sufficient to ensure that this is not the case.
491-
let.getPat().(IdentPat) = n1 and
492-
let.getInitializer() = n2
496+
identLetStmt(let, n1, n2)
493497
)
494498
or
495499
exists(LetExpr let |
@@ -513,6 +517,25 @@ module CertainTypeInference {
513517
)
514518
else prefix2.isEmpty()
515519
)
520+
or
521+
exists(CallExprImpl::DynamicCallExpr dce, TupleType tt, int i |
522+
n1 = dce.getArgList() and
523+
tt.getArity() = dce.getNumberOfSyntacticArguments() and
524+
n2 = dce.getSyntacticPositionalArgument(i) and
525+
prefix1 = TypePath::singleton(tt.getPositionalTypeParameter(i)) and
526+
prefix2.isEmpty()
527+
)
528+
or
529+
exists(ClosureExpr ce, int index |
530+
n1 = ce and
531+
n2 = ce.getParam(index).getPat() and
532+
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
533+
prefix2.isEmpty()
534+
)
535+
or
536+
n1 = any(ClosureExpr ce | not ce.hasRetType() and ce.getClosureBody() = n2) and
537+
prefix1 = closureReturnPath() and
538+
prefix2.isEmpty()
516539
}
517540

518541
pragma[nomagic]
@@ -775,17 +798,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
775798
n1.(ArrayRepeatExpr).getRepeatOperand() = n2 and
776799
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
777800
prefix2.isEmpty()
778-
or
779-
exists(ClosureExpr ce, int index |
780-
n1 = ce and
781-
n2 = ce.getParam(index).getPat() and
782-
prefix1 = closureParameterPath(ce.getNumberOfParams(), index) and
783-
prefix2.isEmpty()
784-
)
785-
or
786-
n1.(ClosureExpr).getClosureBody() = n2 and
787-
prefix1 = closureReturnPath() and
788-
prefix2.isEmpty()
789801
}
790802

791803
/**
@@ -828,6 +840,19 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
828840
)
829841
}
830842

843+
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
844+
inferType(n, path) = TUnknownType() and
845+
// Normally, these are coercion sites, but in case a type is unknown we
846+
// allow for type information to flow from the type annotation.
847+
exists(TypeMention tm | result = tm.getTypeAt(path) |
848+
tm = any(LetStmt let | identLetStmt(let, _, n)).getTypeRepr()
849+
or
850+
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
851+
or
852+
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
853+
)
854+
}
855+
831856
/**
832857
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
833858
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
@@ -1533,6 +1558,8 @@ private module MethodResolution {
15331558
* or
15341559
* 4. `MethodCallOperation`: an operation expression, `x + y`, which is syntactic sugar
15351560
* for `Add::add(x, y)`.
1561+
* 5. `ClosureMethodCall`: a call to a closure, `c(x)`, which is syntactic sugar for
1562+
* `c.call_once(x)`, `c.call_mut(x)`, or `c.call(x)`.
15361563
*
15371564
* Note that only in case 1 and 2 is auto-dereferencing and borrowing allowed.
15381565
*
@@ -1544,7 +1571,7 @@ private module MethodResolution {
15441571
abstract class MethodCall extends Expr {
15451572
abstract predicate hasNameAndArity(string name, int arity);
15461573

1547-
abstract Expr getArg(ArgumentPosition pos);
1574+
abstract AstNode getArg(ArgumentPosition pos);
15481575

15491576
abstract predicate supportsAutoDerefAndBorrow();
15501577

@@ -2093,6 +2120,26 @@ private module MethodResolution {
20932120
override Trait getTrait() { super.isOverloaded(result, _, _) }
20942121
}
20952122

2123+
private class ClosureMethodCall extends MethodCall instanceof CallExprImpl::DynamicCallExpr {
2124+
pragma[nomagic]
2125+
override predicate hasNameAndArity(string name, int arity) {
2126+
name = "call_once" and // todo: handle call_mut and call
2127+
arity = 1 // args are passed in a tuple
2128+
}
2129+
2130+
override AstNode getArg(ArgumentPosition pos) {
2131+
pos.isSelf() and
2132+
result = super.getFunction()
2133+
or
2134+
pos.asPosition() = 0 and
2135+
result = super.getArgList()
2136+
}
2137+
2138+
override predicate supportsAutoDerefAndBorrow() { any() }
2139+
2140+
override Trait getTrait() { result instanceof AnyFnTrait }
2141+
}
2142+
20962143
pragma[nomagic]
20972144
private Method getMethodSuccessor(ImplOrTraitItemNode i, string name, int arity) {
20982145
result = i.getASuccessor(name) and
@@ -2600,7 +2647,9 @@ private Type inferMethodCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePat
26002647
* empty, at which point the inferred type can be applied back to `n`.
26012648
*/
26022649
pragma[nomagic]
2603-
private Type inferMethodCallTypeSelf(MethodCall mc, AstNode n, DerefChain derefChain, TypePath path) {
2650+
private Type inferMethodCallTypeSelf(
2651+
MethodCallMatchingInput::Access mc, AstNode n, DerefChain derefChain, TypePath path
2652+
) {
26042653
exists(
26052654
MethodCallMatchingInput::AccessPosition apos, string derefChainBorrow, BorrowKind borrow,
26062655
TypePath path0
@@ -2644,7 +2693,7 @@ private Type inferMethodCallTypeSelf(MethodCall mc, AstNode n, DerefChain derefC
26442693
private Type inferMethodCallTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
26452694
result = inferMethodCallTypeNonSelf(n, pos, path)
26462695
or
2647-
exists(MethodCall mc |
2696+
exists(MethodCallMatchingInput::Access mc |
26482697
result = inferMethodCallTypeSelf(mc, n, DerefChain::nil(), path) and
26492698
if mc instanceof CallExpr then pos.asPosition() = 0 else pos.isSelf()
26502699
)
@@ -3942,14 +3991,6 @@ private module InvokedClosureSatisfiesConstraintInput implements
39423991
}
39433992
}
39443993

3945-
private module InvokedClosureSatisfiesConstraint =
3946-
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>;
3947-
3948-
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
3949-
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3950-
InvokedClosureSatisfiesConstraint::satisfiesConstraintType(ce, _, path, result)
3951-
}
3952-
39533994
/**
39543995
* Gets the root type of a closure.
39553996
*
@@ -3976,73 +4017,39 @@ private TypePath closureParameterPath(int arity, int index) {
39764017
TypePath::singleton(getTupleTypeParameter(arity, index)))
39774018
}
39784019

3979-
/** Gets the path to the return type of the `FnOnce` trait. */
3980-
private TypePath fnReturnPath() {
3981-
result = TypePath::singleton(getAssociatedTypeTypeParameter(any(FnOnceTrait t).getOutputType()))
3982-
}
3983-
3984-
/**
3985-
* Gets the path to the parameter type of the `FnOnce` trait with arity `arity`
3986-
* and index `index`.
3987-
*/
39884020
pragma[nomagic]
3989-
private TypePath fnParameterPath(int arity, int index) {
3990-
result =
3991-
TypePath::cons(TTypeParamTypeParameter(any(FnOnceTrait t).getTypeParam()),
3992-
TypePath::singleton(getTupleTypeParameter(arity, index)))
3993-
}
3994-
3995-
pragma[nomagic]
3996-
private Type inferDynamicCallExprType(Expr n, TypePath path) {
3997-
exists(InvokedClosureExpr ce |
3998-
// Propagate the function's return type to the call expression
3999-
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
4000-
n = ce.getCall() and
4001-
path = path0.stripPrefix(fnReturnPath())
4021+
private Type inferClosureExprType(AstNode n, TypePath path) {
4022+
exists(ClosureExpr ce |
4023+
n = ce and
4024+
(
4025+
path.isEmpty() and
4026+
result = closureRootType()
4027+
or
4028+
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
4029+
result.(TupleType).getArity() = ce.getNumberOfParams()
40024030
or
4003-
// Propagate the function's parameter type to the arguments
4004-
exists(int index |
4005-
n = ce.getCall().getSyntacticPositionalArgument(index) and
4006-
path =
4007-
path0.stripPrefix(fnParameterPath(ce.getCall().getArgList().getNumberOfArgs(), index))
4031+
exists(TypePath path0 |
4032+
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path0) and
4033+
path = closureReturnPath().append(path0)
40084034
)
40094035
)
40104036
or
4011-
// _If_ the invoked expression has the type of a closure, then we propagate
4012-
// the surrounding types into the closure.
4013-
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
4014-
// Propagate the type of arguments to the parameter types of closure
4015-
exists(int index, ArgList args |
4016-
n = ce and
4017-
args = ce.getCall().getArgList() and
4018-
arity = args.getNumberOfArgs() and
4019-
result = inferType(args.getArg(index), path0) and
4020-
path = closureParameterPath(arity, index).append(path0)
4021-
)
4022-
or
4023-
// Propagate the type of the call expression to the return type of the closure
4024-
n = ce and
4025-
arity = ce.getCall().getArgList().getNumberOfArgs() and
4026-
result = inferType(ce.getCall(), path0) and
4027-
path = closureReturnPath().append(path0)
4037+
exists(Param p |
4038+
p = ce.getAParam() and
4039+
not p.hasTypeRepr() and
4040+
n = p.getPat() and
4041+
result = TUnknownType() and
4042+
path.isEmpty()
40284043
)
40294044
)
40304045
}
40314046

40324047
pragma[nomagic]
4033-
private Type inferClosureExprType(AstNode n, TypePath path) {
4034-
exists(ClosureExpr ce |
4035-
n = ce and
4036-
path.isEmpty() and
4037-
result = closureRootType()
4038-
or
4039-
n = ce and
4040-
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
4041-
result.(TupleType).getArity() = ce.getNumberOfParams()
4042-
or
4043-
// Propagate return type annotation to body
4044-
n = ce.getClosureBody() and
4045-
result = ce.getRetType().getTypeRepr().(TypeMention).getTypeAt(path)
4048+
private TupleType inferArgList(ArgList args, TypePath path) {
4049+
exists(CallExprImpl::DynamicCallExpr dce |
4050+
args = dce.getArgList() and
4051+
result.getArity() = dce.getNumberOfSyntacticArguments() and
4052+
path.isEmpty()
40464053
)
40474054
}
40484055

@@ -4089,7 +4096,9 @@ private module Cached {
40894096
or
40904097
i instanceof ImplItemNode and dispatch = false
40914098
|
4092-
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) or
4099+
result = call.(MethodResolution::MethodCall).resolveCallTarget(i, _, _) and
4100+
not call instanceof CallExprImpl::DynamicCallExpr
4101+
or
40934102
result = call.(NonMethodResolution::NonMethodCall).resolveCallTargetViaTypeInference(i)
40944103
)
40954104
}
@@ -4199,13 +4208,15 @@ private module Cached {
41994208
or
42004209
result = inferForLoopExprType(n, path)
42014210
or
4202-
result = inferDynamicCallExprType(n, path)
4203-
or
42044211
result = inferClosureExprType(n, path)
42054212
or
4213+
result = inferArgList(n, path)
4214+
or
42064215
result = inferStructPatType(n, path)
42074216
or
42084217
result = inferTupleStructPatType(n, path)
4218+
or
4219+
result = inferUnknownTypeFromAnnotation(n, path)
42094220
)
42104221
}
42114222
}

rust/ql/test/library-tests/type-inference/closure.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ mod fn_once_trait {
6363
};
6464
let _r = apply(f, true); // $ target=apply type=_r:i64
6565

66-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
66+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
6767
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
6868
}
6969
}
@@ -100,7 +100,7 @@ mod fn_mut_trait {
100100
};
101101
let _r = apply(f, true); // $ target=apply type=_r:i64
102102

103-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
103+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
104104
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
105105
}
106106
}
@@ -137,7 +137,7 @@ mod fn_trait {
137137
};
138138
let _r = apply(f, true); // $ target=apply type=_r:i64
139139

140-
let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
140+
let f = |x| x + 1; // $ type=x:i64 $ MISSING: target=add
141141
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
142142
}
143143
}
@@ -183,25 +183,25 @@ mod closure_infer_param {
183183
}
184184

185185
fn test() {
186-
let f = |x| x; // $ MISSING: type=x:i64
186+
let f = |x| x; // $ type=x:i64
187187
let _r = apply1(f, 1i64); // $ target=apply1
188188

189-
let f = |x| x; // $ MISSING: type=x:i64
189+
let f = |x| x; // $ type=x:i64
190190
let _r = apply2(f, 2i64); // $ target=apply2
191191

192-
let f = |x| x; // $ MISSING: type=x:i64
192+
let f = |x| x; // $ type=x:i64
193193
let _r = apply3(&f, 3i64); // $ target=apply3
194194

195-
let f = |x| x; // $ MISSING: type=x:i64
195+
let f = |x| x; // $ type=x:i64
196196
let _r = apply4(f, 4i64); // $ target=apply4
197197

198198
let mut f = |x| x; // $ MISSING: type=x:i64
199199
let _r = apply5(&mut f, 5i64); // $ target=apply5
200200

201-
let f = |x| x; // $ MISSING: type=x:i64
201+
let f = |x| x; // $ type=x:i64
202202
let _r = apply6(f, 6i64); // $ target=apply6
203203

204-
let f = |x| x; // $ MISSING: type=x:i64
204+
let f = |x| x; // $ type=x:i64
205205
let _r = apply7(f, 7i64); // $ target=apply7
206206
}
207207
}
@@ -221,15 +221,15 @@ mod implicit_deref {
221221

222222
pub fn test() {
223223
let x = 0i64;
224-
let v = Default::default(); // $ MISSING: type=v:i64 target=default
224+
let v = Default::default(); // $ type=v:i64 target=default
225225
let s = S(v);
226-
let _ret = s(x); // $ MISSING: type=_ret:bool
226+
let _ret = s(x); // $ type=_ret:bool
227227

228228
let x = 0i32;
229-
let v = Default::default(); // $ MISSING: type=v:i32 target=default
229+
let v = Default::default(); // $ type=v:i32 target=default
230230
let s = S(v);
231231
let s_ref = &s;
232-
let _ret = s_ref(x); // $ MISSING: type=_ret:bool
232+
let _ret = s_ref(x); // $ type=_ret:bool
233233

234234
// The call below is not an implicit deref, instead it will target
235235
// `impl<A, F> FnOnce<A> for &F` from

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,7 @@ mod loops {
22592259
// for loops with arrays
22602260

22612261
for i in [1, 2, 3] {} // $ type=i:i32
2262-
for i in [1, 2, 3].map(|x| x + 1) {} // $ target=map MISSING: type=i:i32
2262+
for i in [1, 2, 3].map(|x| x + 1) {} // $ target=map target=add type=i:i32
22632263
for i in [1, 2, 3].into_iter() {} // $ target=into_iter type=i:i32
22642264

22652265
let vals1 = [1u8, 2, 3]; // $ type=vals1:TArray.u8
@@ -2770,7 +2770,7 @@ mod arg_trait_bounds {
27702770
}
27712771

27722772
fn test() {
2773-
let v = Default::default(); // $ MISSING: type=v:i64 target=default
2773+
let v = Default::default(); // $ type=v:i64 target=default
27742774
let g = Gen(v);
27752775
let _ = my_get(&g); // $ target=my_get
27762776
}

0 commit comments

Comments
 (0)