From cc048b32f3de4168de6b0207fd01c65e51d37ac0 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Mon, 13 Mar 2023 16:38:14 -0700 Subject: [PATCH] go/types, types2: reverse inference of function type arguments This CL implements type inference for generic functions used in assignments: variable init expressions, regular assignments, and return statements, but (not yet) function arguments passed to functions. For instance, given a generic function func f[P any](x P) and a variable of function type var v func(x int) the assignment v = f is valid w/o explicit instantiation of f, and the missing type argument for f is inferred from the type of v. More generally, the function f may have multiple type arguments, and it may be partially instantiated. This new form of inference is not enabled by default (it needs to go through the proposal process first). It can be enabled by setting Config.EnableReverseTypeInference. The mechanism is implemented as follows: - The various expression evaluation functions take an additional (first) argument T, which is the target type for the expression. If not nil, it is the type of the LHS in an assignment. - The method Checker.funcInst is changed such that it uses both, provided type arguments (if any), and a target type (if any) to augment type inference. Change-Id: Idfde61078e1ee4f22abcca894a4c84d681734ff6 Reviewed-on: https://go-review.googlesource.com/c/go/+/476075 TryBot-Result: Gopher Robot Auto-Submit: Robert Griesemer Reviewed-by: Robert Findley Reviewed-by: Robert Griesemer Run-TryBot: Robert Griesemer --- src/cmd/compile/internal/types2/api.go | 7 ++ .../compile/internal/types2/assignments.go | 29 +++-- src/cmd/compile/internal/types2/builtins.go | 4 +- src/cmd/compile/internal/types2/call.go | 114 ++++++++++++++--- src/cmd/compile/internal/types2/check_test.go | 1 + src/cmd/compile/internal/types2/decl.go | 4 +- src/cmd/compile/internal/types2/expr.go | 63 +++++++--- src/cmd/compile/internal/types2/index.go | 10 +- src/cmd/compile/internal/types2/stmt.go | 32 ++--- src/cmd/compile/internal/types2/typexpr.go | 2 +- src/go/types/api.go | 7 ++ src/go/types/assignments.go | 29 +++-- src/go/types/builtins.go | 4 +- src/go/types/call.go | 119 +++++++++++++++--- src/go/types/check_test.go | 1 + src/go/types/decl.go | 4 +- src/go/types/eval.go | 4 +- src/go/types/expr.go | 63 +++++++--- src/go/types/index.go | 10 +- src/go/types/stmt.go | 32 ++--- src/go/types/typexpr.go | 2 +- .../types/testdata/examples/inference.go | 7 ++ .../types/testdata/examples/inference2.go | 73 +++++++++++ 23 files changed, 462 insertions(+), 159 deletions(-) create mode 100644 src/internal/types/testdata/examples/inference2.go diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go index 56fb578943..e027b9a7e2 100644 --- a/src/cmd/compile/internal/types2/api.go +++ b/src/cmd/compile/internal/types2/api.go @@ -169,6 +169,13 @@ type Config struct { // If DisableUnusedImportCheck is set, packages are not checked // for unused imports. DisableUnusedImportCheck bool + + // If EnableReverseTypeInference is set, uninstantiated and + // partially instantiated generic functions may be assigned + // (incl. returned) to variables of function type and type + // inference will attempt to infer the missing type arguments. + // Experimental. Needs a proposal. + EnableReverseTypeInference bool } func srcimporter_setUsesCgo(conf *Config) { diff --git a/src/cmd/compile/internal/types2/assignments.go b/src/cmd/compile/internal/types2/assignments.go index 3ca6bebd31..5a51b3de1e 100644 --- a/src/cmd/compile/internal/types2/assignments.go +++ b/src/cmd/compile/internal/types2/assignments.go @@ -189,7 +189,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { } var x operand - check.expr(&x, lhs) + check.expr(nil, &x, lhs) if v != nil { v.used = v_used // restore v.used @@ -205,7 +205,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { default: if sel, ok := x.expr.(*syntax.SelectorExpr); ok { var op operand - check.expr(&op, sel.X) + check.expr(nil, &op, sel.X) if op.mode == mapindex { check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", syntax.String(x.expr)) return Typ[Invalid] @@ -218,15 +218,20 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { return x.typ } -// assignVar checks the assignment lhs = x. -func (check *Checker) assignVar(lhs syntax.Expr, x *operand) { - if x.mode == invalid { - check.useLHS(lhs) +// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil). +// If x != nil, it must be the evaluation of rhs (and rhs will be ignored). +func (check *Checker) assignVar(lhs, rhs syntax.Expr, x *operand) { + T := check.lhsVar(lhs) // nil if lhs is _ + if T == Typ[Invalid] { + check.use(rhs) return } - T := check.lhsVar(lhs) // nil if lhs is _ - if T == Typ[Invalid] { + if x == nil { + x = new(operand) + check.expr(T, x, rhs) + } + if x.mode == invalid { return } @@ -351,7 +356,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []syntax.Expr, returnStmt sy if l == r && !isCall { var x operand for i, lhs := range lhs { - check.expr(&x, orig_rhs[i]) + check.expr(lhs.typ, &x, orig_rhs[i]) check.initVar(lhs, &x, context) } return @@ -423,9 +428,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) { // each value can be assigned to its corresponding variable. if l == r && !isCall { for i, lhs := range lhs { - var x operand - check.expr(&x, orig_rhs[i]) - check.assignVar(lhs, &x) + check.assignVar(lhs, orig_rhs[i], nil) } return } @@ -446,7 +449,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) { r = len(rhs) if l == r { for i, lhs := range lhs { - check.assignVar(lhs, rhs[i]) + check.assignVar(lhs, nil, rhs[i]) } if commaOk { check.recordCommaOkTypes(orig_rhs[0], rhs) diff --git a/src/cmd/compile/internal/types2/builtins.go b/src/cmd/compile/internal/types2/builtins.go index e35dab8140..67aa37e401 100644 --- a/src/cmd/compile/internal/types2/builtins.go +++ b/src/cmd/compile/internal/types2/builtins.go @@ -678,7 +678,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) ( return } - check.expr(x, selx.X) + check.expr(nil, x, selx.X) if x.mode == invalid { return } @@ -878,7 +878,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) ( var t operand x1 := x for _, arg := range call.ArgList { - check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T)) + check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T)) check.dump("%v: %s", posFor(x1), x1) x1 = &t // use incoming x only for first argument } diff --git a/src/cmd/compile/internal/types2/call.go b/src/cmd/compile/internal/types2/call.go index 72608dea26..bb82c2464e 100644 --- a/src/cmd/compile/internal/types2/call.go +++ b/src/cmd/compile/internal/types2/call.go @@ -8,31 +8,54 @@ package types2 import ( "cmd/compile/internal/syntax" + "fmt" . "internal/types/errors" "strings" "unicode" ) -// funcInst type-checks a function instantiation inst and returns the result in x. -// The operand x must be the evaluation of inst.X and its type must be a signature. -func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { +// funcInst type-checks a function instantiation and returns the result in x. +// The incoming x must be an uninstantiated generic function. If inst != 0, +// it provides (some or all of) the type arguments (inst.Index) for the +// instantiation. If the target type T != nil and is a (non-generic) function +// signature, the signature's parameter types are used to infer additional +// missing type arguments of x, if any. +// At least one of inst or T must be provided. +func (check *Checker) funcInst(T Type, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) { if !check.allowVersion(check.pkg, 1, 18) { check.versionErrorf(inst.Pos(), "go1.18", "function instantiation") } - xlist := unpackExpr(inst.Index) - targs := check.typeList(xlist) - if targs == nil { - x.mode = invalid - x.expr = inst - return + // tsig is the (assignment) target function signature, or nil. + // TODO(gri) refactor and pass in tsig to funcInst instead + var tsig *Signature + if check.conf.EnableReverseTypeInference && T != nil { + tsig, _ = under(T).(*Signature) } - assert(len(targs) == len(xlist)) - // check number of type arguments (got) vs number of type parameters (want) + // targs and xlist are the type arguments and corresponding type expressions, or nil. + var targs []Type + var xlist []syntax.Expr + if inst != nil { + xlist = unpackExpr(inst.Index) + targs = check.typeList(xlist) + if targs == nil { + x.mode = invalid + x.expr = inst + return + } + assert(len(targs) == len(xlist)) + } + + assert(tsig != nil || targs != nil) + + // Check the number of type arguments (got) vs number of type parameters (want). + // Note that x is a function value, not a type expression, so we don't need to + // call under below. sig := x.typ.(*Signature) got, want := len(targs), sig.TypeParams().Len() if got > want { + // Providing too many type arguments is always an error. check.errorf(xlist[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want) x.mode = invalid x.expr = inst @@ -40,7 +63,37 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { } if got < want { - targs = check.infer(inst.Pos(), sig.TypeParams().list(), targs, nil, nil) + // If the uninstantiated or partially instantiated function x is used in an + // assignment (tsig != nil), use the respective function parameter and result + // types to infer additional type arguments. + var args []*operand + var params []*Var + if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() { + // x is a generic function and the signature arity matches the target function. + // To infer x's missing type arguments, treat the function assignment as a call + // of a synthetic function f where f's parameters are the parameters and results + // of x and where the arguments to the call of f are values of the parameter and + // result types of x. + n := tsig.params.Len() + m := tsig.results.Len() + args = make([]*operand, n+m) + params = make([]*Var, n+m) + for i := 0; i < n; i++ { + lvar := tsig.params.At(i) + lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "parameter")) + args[i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[i] = sig.params.At(i) + } + for i := 0; i < m; i++ { + lvar := tsig.results.At(i) + lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "result parameter")) + args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[n+i] = sig.results.At(i) + } + } + + // Note that NewTuple(params...) below is nil if len(params) == 0, as desired. + targs = check.infer(pos, sig.TypeParams().list(), targs, NewTuple(params...), args) if targs == nil { // error was already reported x.mode = invalid @@ -54,10 +107,33 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { // instantiate function signature sig = check.instantiateSignature(x.Pos(), sig, targs, xlist) assert(sig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(inst.X, targs, sig) + x.typ = sig x.mode = value - x.expr = inst + // If we don't have an index expression, keep the existing expression of x. + if inst != nil { + x.expr = inst + } + check.recordInstance(x.expr, targs, sig) +} + +func paramName(name string, i int, kind string) string { + if name != "" { + return name + } + return nth(i+1) + " " + kind +} + +func nth(n int) string { + switch n { + case 1: + return "1st" + case 2: + return "2nd" + case 3: + return "3rd" + } + return fmt.Sprintf("%dth", n) } func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, xlist []syntax.Expr) (res *Signature) { @@ -119,7 +195,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { case typexpr: // conversion - check.nonGeneric(x) + check.nonGeneric(nil, x) if x.mode == invalid { return conversion } @@ -129,7 +205,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { case 0: check.errorf(call, WrongArgCount, "missing argument in conversion to %s", T) case 1: - check.expr(x, call.ArgList[0]) + check.expr(nil, x, call.ArgList[0]) if x.mode != invalid { if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) { if !t.IsMethodSet() { @@ -272,7 +348,7 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) { xlist = make([]*operand, len(elist)) for i, e := range elist { var x operand - check.expr(&x, e) + check.expr(nil, &x, e) xlist[i] = &x } } @@ -744,14 +820,14 @@ func (check *Checker) use1(e syntax.Expr, lhs bool) bool { } } } - check.rawExpr(&x, n, nil, true) + check.rawExpr(nil, &x, n, nil, true) if v != nil { v.used = v_used // restore v.used } case *syntax.ListExpr: return check.useN(n.ElemList, lhs) default: - check.rawExpr(&x, e, nil, true) + check.rawExpr(nil, &x, e, nil, true) } return x.mode != invalid } diff --git a/src/cmd/compile/internal/types2/check_test.go b/src/cmd/compile/internal/types2/check_test.go index 26bb1aed9e..382d1ad19e 100644 --- a/src/cmd/compile/internal/types2/check_test.go +++ b/src/cmd/compile/internal/types2/check_test.go @@ -133,6 +133,7 @@ func testFiles(t *testing.T, filenames []string, colDelta uint, manual bool) { flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(&conf.EnableReverseTypeInference, "reverseTypeInference", false, "") if err := parseFlags(filenames[0], nil, flags); err != nil { t.Fatal(err) } diff --git a/src/cmd/compile/internal/types2/decl.go b/src/cmd/compile/internal/types2/decl.go index 0ac0f6196a..afa32c1a5f 100644 --- a/src/cmd/compile/internal/types2/decl.go +++ b/src/cmd/compile/internal/types2/decl.go @@ -408,7 +408,7 @@ func (check *Checker) constDecl(obj *Const, typ, init syntax.Expr, inherited boo // (see issues go.dev/issue/42991, go.dev/issue/42992). check.errpos = obj.pos } - check.expr(&x, init) + check.expr(nil, &x, init) } check.initConst(obj, &x) } @@ -455,7 +455,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init syntax.Expr) { if lhs == nil || len(lhs) == 1 { assert(lhs == nil || lhs[0] == obj) var x operand - check.expr(&x, init) + check.expr(obj.typ, &x, init) check.initVar(obj, &x, "variable declaration") return } diff --git a/src/cmd/compile/internal/types2/expr.go b/src/cmd/compile/internal/types2/expr.go index fdc7bdbef0..bab52b253b 100644 --- a/src/cmd/compile/internal/types2/expr.go +++ b/src/cmd/compile/internal/types2/expr.go @@ -173,7 +173,7 @@ func underIs(typ Type, f func(Type) bool) bool { } func (check *Checker) unary(x *operand, e *syntax.Operation) { - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { return } @@ -1097,8 +1097,8 @@ func init() { func (check *Checker) binary(x *operand, e syntax.Expr, lhs, rhs syntax.Expr, op syntax.Operator) { var y operand - check.expr(x, lhs) - check.expr(&y, rhs) + check.expr(nil, x, lhs) + check.expr(nil, &y, rhs) if x.mode == invalid { return @@ -1245,12 +1245,18 @@ const ( statement ) +// TODO(gri) In rawExpr below, consider using T instead of hint and +// some sort of "operation mode" instead of allowGeneric. +// May be clearer and less error-prone. + // rawExpr typechecks expression e and initializes x with the expression // value or type. If an error occurred, x.mode is set to invalid. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // If hint != nil, it is the type of a composite literal element. // If allowGeneric is set, the operand type may be an uninstantiated // parameterized type or function value. -func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind { +func (check *Checker) rawExpr(T Type, x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind { if check.conf.Trace { check.trace(e.Pos(), "-- expr %s", e) check.indent++ @@ -1260,10 +1266,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric }() } - kind := check.exprInternal(x, e, hint) + kind := check.exprInternal(T, x, e, hint) if !allowGeneric { - check.nonGeneric(x) + check.nonGeneric(T, x) } check.record(x) @@ -1271,9 +1277,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric return kind } -// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ. +// If x is a generic type, or a generic function whose type arguments cannot be inferred +// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ. // Otherwise it leaves x alone. -func (check *Checker) nonGeneric(x *operand) { +func (check *Checker) nonGeneric(T Type, x *operand) { if x.mode == invalid || x.mode == novalue { return } @@ -1285,6 +1292,12 @@ func (check *Checker) nonGeneric(x *operand) { } case *Signature: if t.tparams != nil { + if check.conf.EnableReverseTypeInference && T != nil { + if _, ok := under(T).(*Signature); ok { + check.funcInst(T, x.Pos(), x, nil) + return + } + } what = "function" } } @@ -1297,7 +1310,8 @@ func (check *Checker) nonGeneric(x *operand) { // exprInternal contains the core of type checking of expressions. // Must only be called by rawExpr. -func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKind { +// (See rawExpr for an explanation of the parameters.) +func (check *Checker) exprInternal(T Type, x *operand, e syntax.Expr, hint Type) exprKind { // make sure x has a valid state in case of bailout // (was go.dev/issue/5770) x.mode = invalid @@ -1438,7 +1452,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin key, _ := kv.Key.(*syntax.Name) // do all possible checks early (before exiting due to errors) // so we don't drop information on the floor - check.expr(x, kv.Value) + check.expr(nil, x, kv.Value) if key == nil { check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key) continue @@ -1466,7 +1480,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal") continue } - check.expr(x, e) + check.expr(nil, x, e) if i >= len(fields) { check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base) break // cannot continue @@ -1593,7 +1607,8 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin x.typ = typ case *syntax.ParenExpr: - kind := check.rawExpr(x, e.X, nil, false) + // type inference doesn't go past parentheses (targe type T = nil) + kind := check.rawExpr(nil, x, e.X, nil, false) x.expr = e return kind @@ -1602,7 +1617,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin case *syntax.IndexExpr: if check.indexExpr(x, e) { - check.funcInst(x, e) + check.funcInst(T, e.Pos(), x, e) } if x.mode == invalid { goto Error @@ -1615,7 +1630,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin } case *syntax.AssertExpr: - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { goto Error } @@ -1624,7 +1639,6 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin check.error(e, InvalidSyntaxTree, "invalid use of AssertExpr") goto Error } - // TODO(gri) we may want to permit type assertions on type parameter values at some point if isTypeParam(x.typ) { check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x) goto Error @@ -1814,10 +1828,19 @@ func (check *Checker) typeAssertion(e syntax.Expr, x *operand, T Type, typeSwitc } // expr typechecks expression e and initializes x with the expression value. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // The result must be a single value. // If an error occurred, x.mode is set to invalid. -func (check *Checker) expr(x *operand, e syntax.Expr) { - check.rawExpr(x, e, nil, false) +func (check *Checker) expr(T Type, x *operand, e syntax.Expr) { + check.rawExpr(T, x, e, nil, false) + check.exclude(x, 1< want { + // Providing too many type arguments is always an error. check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want) x.mode = invalid x.expr = ix.Orig @@ -41,11 +65,43 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) { } if got < want { - targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil) + // If the uninstantiated or partially instantiated function x is used in an + // assignment (tsig != nil), use the respective function parameter and result + // types to infer additional type arguments. + var args []*operand + var params []*Var + if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() { + // x is a generic function and the signature arity matches the target function. + // To infer x's missing type arguments, treat the function assignment as a call + // of a synthetic function f where f's parameters are the parameters and results + // of x and where the arguments to the call of f are values of the parameter and + // result types of x. + n := tsig.params.Len() + m := tsig.results.Len() + args = make([]*operand, n+m) + params = make([]*Var, n+m) + for i := 0; i < n; i++ { + lvar := tsig.params.At(i) + lname := ast.NewIdent(paramName(lvar.name, i, "parameter")) + lname.NamePos = x.Pos() // correct position + args[i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[i] = sig.params.At(i) + } + for i := 0; i < m; i++ { + lvar := tsig.results.At(i) + lname := ast.NewIdent(paramName(lvar.name, i, "result parameter")) + lname.NamePos = x.Pos() // correct position + args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[n+i] = sig.results.At(i) + } + } + + // Note that NewTuple(params...) below is nil if len(params) == 0, as desired. + targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args) if targs == nil { // error was already reported x.mode = invalid - x.expr = ix.Orig + x.expr = ix // TODO(gri) is this correct? return } got = len(targs) @@ -53,12 +109,35 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) { assert(got == want) // instantiate function signature - sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices) + sig = check.instantiateSignature(x.Pos(), sig, targs, xlist) assert(sig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(ix.Orig, targs, sig) + x.typ = sig x.mode = value - x.expr = ix.Orig + // If we don't have an index expression, keep the existing expression of x. + if ix != nil { + x.expr = ix.Orig + } + check.recordInstance(x.expr, targs, sig) +} + +func paramName(name string, i int, kind string) string { + if name != "" { + return name + } + return nth(i+1) + " " + kind +} + +func nth(n int) string { + switch n { + case 1: + return "1st" + case 2: + return "2nd" + case 3: + return "3rd" + } + return fmt.Sprintf("%dth", n) } func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) { @@ -121,7 +200,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { case typexpr: // conversion - check.nonGeneric(x) + check.nonGeneric(nil, x) if x.mode == invalid { return conversion } @@ -131,7 +210,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { case 0: check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T) case 1: - check.expr(x, call.Args[0]) + check.expr(nil, x, call.Args[0]) if x.mode != invalid { if call.Ellipsis.IsValid() { check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T) @@ -274,7 +353,7 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { xlist = make([]*operand, len(elist)) for i, e := range elist { var x operand - check.expr(&x, e) + check.expr(nil, &x, e) xlist[i] = &x } } @@ -791,12 +870,12 @@ func (check *Checker) use1(e ast.Expr, lhs bool) bool { } } } - check.rawExpr(&x, n, nil, true) + check.rawExpr(nil, &x, n, nil, true) if v != nil { v.used = v_used // restore v.used } default: - check.rawExpr(&x, e, nil, true) + check.rawExpr(nil, &x, e, nil, true) } return x.mode != invalid } diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go index 36809838c7..0f4c320a47 100644 --- a/src/go/types/check_test.go +++ b/src/go/types/check_test.go @@ -145,6 +145,7 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(boolFieldAddr(&conf, "_EnableReverseTypeInference"), "reverseTypeInference", false, "") if err := parseFlags(filenames[0], srcs[0], flags); err != nil { t.Fatal(err) } diff --git a/src/go/types/decl.go b/src/go/types/decl.go index 393d8f34e2..3065da2e8e 100644 --- a/src/go/types/decl.go +++ b/src/go/types/decl.go @@ -477,7 +477,7 @@ func (check *Checker) constDecl(obj *Const, typ, init ast.Expr, inherited bool) // (see issues go.dev/issue/42991, go.dev/issue/42992). check.errpos = atPos(obj.pos) } - check.expr(&x, init) + check.expr(nil, &x, init) } check.initConst(obj, &x) } @@ -510,7 +510,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) { if lhs == nil || len(lhs) == 1 { assert(lhs == nil || lhs[0] == obj) var x operand - check.expr(&x, init) + check.expr(obj.typ, &x, init) check.initVar(obj, &x, "variable declaration") return } diff --git a/src/go/types/eval.go b/src/go/types/eval.go index 1e4d64fe96..1655a8bd27 100644 --- a/src/go/types/eval.go +++ b/src/go/types/eval.go @@ -91,8 +91,8 @@ func CheckExpr(fset *token.FileSet, pkg *Package, pos token.Pos, expr ast.Expr, // evaluate node var x operand - check.rawExpr(&x, expr, nil, true) // allow generic expressions - check.processDelayed(0) // incl. all functions + check.rawExpr(nil, &x, expr, nil, true) // allow generic expressions + check.processDelayed(0) // incl. all functions check.recordUntyped() return nil diff --git a/src/go/types/expr.go b/src/go/types/expr.go index 1abf963b7f..219a392b88 100644 --- a/src/go/types/expr.go +++ b/src/go/types/expr.go @@ -160,7 +160,7 @@ func underIs(typ Type, f func(Type) bool) bool { // The unary expression e may be nil. It's passed in for better error messages only. func (check *Checker) unary(x *operand, e *ast.UnaryExpr) { - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { return } @@ -1079,8 +1079,8 @@ func init() { func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) { var y operand - check.expr(x, lhs) - check.expr(&y, rhs) + check.expr(nil, x, lhs) + check.expr(nil, &y, rhs) if x.mode == invalid { return @@ -1230,12 +1230,18 @@ const ( statement ) +// TODO(gri) In rawExpr below, consider using T instead of hint and +// some sort of "operation mode" instead of allowGeneric. +// May be clearer and less error-prone. + // rawExpr typechecks expression e and initializes x with the expression // value or type. If an error occurred, x.mode is set to invalid. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // If hint != nil, it is the type of a composite literal element. // If allowGeneric is set, the operand type may be an uninstantiated // parameterized type or function value. -func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind { +func (check *Checker) rawExpr(T Type, x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind { if check.conf._Trace { check.trace(e.Pos(), "-- expr %s", e) check.indent++ @@ -1245,10 +1251,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo }() } - kind := check.exprInternal(x, e, hint) + kind := check.exprInternal(T, x, e, hint) if !allowGeneric { - check.nonGeneric(x) + check.nonGeneric(T, x) } check.record(x) @@ -1256,9 +1262,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo return kind } -// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ. +// If x is a generic type, or a generic function whose type arguments cannot be inferred +// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ. // Otherwise it leaves x alone. -func (check *Checker) nonGeneric(x *operand) { +func (check *Checker) nonGeneric(T Type, x *operand) { if x.mode == invalid || x.mode == novalue { return } @@ -1270,6 +1277,12 @@ func (check *Checker) nonGeneric(x *operand) { } case *Signature: if t.tparams != nil { + if check.conf._EnableReverseTypeInference && T != nil { + if _, ok := under(T).(*Signature); ok { + check.funcInst(T, x.Pos(), x, nil) + return + } + } what = "function" } } @@ -1282,7 +1295,8 @@ func (check *Checker) nonGeneric(x *operand) { // exprInternal contains the core of type checking of expressions. // Must only be called by rawExpr. -func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { +// (See rawExpr for an explanation of the parameters.) +func (check *Checker) exprInternal(T Type, x *operand, e ast.Expr, hint Type) exprKind { // make sure x has a valid state in case of bailout // (was go.dev/issue/5770) x.mode = invalid @@ -1418,7 +1432,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { key, _ := kv.Key.(*ast.Ident) // do all possible checks early (before exiting due to errors) // so we don't drop information on the floor - check.expr(x, kv.Value) + check.expr(nil, x, kv.Value) if key == nil { check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key) continue @@ -1446,7 +1460,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal") continue } - check.expr(x, e) + check.expr(nil, x, e) if i >= len(fields) { check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base) break // cannot continue @@ -1575,7 +1589,8 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { x.typ = typ case *ast.ParenExpr: - kind := check.rawExpr(x, e.X, nil, false) + // type inference doesn't go past parentheses (targe type T = nil) + kind := check.rawExpr(nil, x, e.X, nil, false) x.expr = e return kind @@ -1585,7 +1600,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { case *ast.IndexExpr, *ast.IndexListExpr: ix := typeparams.UnpackIndexExpr(e) if check.indexExpr(x, ix) { - check.funcInst(x, ix) + check.funcInst(T, e.Pos(), x, ix) } if x.mode == invalid { goto Error @@ -1598,7 +1613,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { } case *ast.TypeAssertExpr: - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { goto Error } @@ -1609,7 +1624,6 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { check.error(e, BadTypeKeyword, "use of .(type) outside type switch") goto Error } - // TODO(gri) we may want to permit type assertions on type parameter values at some point if isTypeParam(x.typ) { check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x) goto Error @@ -1761,10 +1775,19 @@ func (check *Checker) typeAssertion(e ast.Expr, x *operand, T Type, typeSwitch b } // expr typechecks expression e and initializes x with the expression value. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // The result must be a single value. // If an error occurred, x.mode is set to invalid. -func (check *Checker) expr(x *operand, e ast.Expr) { - check.rawExpr(x, e, nil, false) +func (check *Checker) expr(T Type, x *operand, e ast.Expr) { + check.rawExpr(T, x, e, nil, false) + check.exclude(x, 1<