diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go index 56fb578943..e027b9a7e2 100644 --- a/src/cmd/compile/internal/types2/api.go +++ b/src/cmd/compile/internal/types2/api.go @@ -169,6 +169,13 @@ type Config struct { // If DisableUnusedImportCheck is set, packages are not checked // for unused imports. DisableUnusedImportCheck bool + + // If EnableReverseTypeInference is set, uninstantiated and + // partially instantiated generic functions may be assigned + // (incl. returned) to variables of function type and type + // inference will attempt to infer the missing type arguments. + // Experimental. Needs a proposal. + EnableReverseTypeInference bool } func srcimporter_setUsesCgo(conf *Config) { diff --git a/src/cmd/compile/internal/types2/assignments.go b/src/cmd/compile/internal/types2/assignments.go index 3ca6bebd31..5a51b3de1e 100644 --- a/src/cmd/compile/internal/types2/assignments.go +++ b/src/cmd/compile/internal/types2/assignments.go @@ -189,7 +189,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { } var x operand - check.expr(&x, lhs) + check.expr(nil, &x, lhs) if v != nil { v.used = v_used // restore v.used @@ -205,7 +205,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { default: if sel, ok := x.expr.(*syntax.SelectorExpr); ok { var op operand - check.expr(&op, sel.X) + check.expr(nil, &op, sel.X) if op.mode == mapindex { check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", syntax.String(x.expr)) return Typ[Invalid] @@ -218,15 +218,20 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type { return x.typ } -// assignVar checks the assignment lhs = x. -func (check *Checker) assignVar(lhs syntax.Expr, x *operand) { - if x.mode == invalid { - check.useLHS(lhs) +// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil). +// If x != nil, it must be the evaluation of rhs (and rhs will be ignored). +func (check *Checker) assignVar(lhs, rhs syntax.Expr, x *operand) { + T := check.lhsVar(lhs) // nil if lhs is _ + if T == Typ[Invalid] { + check.use(rhs) return } - T := check.lhsVar(lhs) // nil if lhs is _ - if T == Typ[Invalid] { + if x == nil { + x = new(operand) + check.expr(T, x, rhs) + } + if x.mode == invalid { return } @@ -351,7 +356,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []syntax.Expr, returnStmt sy if l == r && !isCall { var x operand for i, lhs := range lhs { - check.expr(&x, orig_rhs[i]) + check.expr(lhs.typ, &x, orig_rhs[i]) check.initVar(lhs, &x, context) } return @@ -423,9 +428,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) { // each value can be assigned to its corresponding variable. if l == r && !isCall { for i, lhs := range lhs { - var x operand - check.expr(&x, orig_rhs[i]) - check.assignVar(lhs, &x) + check.assignVar(lhs, orig_rhs[i], nil) } return } @@ -446,7 +449,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) { r = len(rhs) if l == r { for i, lhs := range lhs { - check.assignVar(lhs, rhs[i]) + check.assignVar(lhs, nil, rhs[i]) } if commaOk { check.recordCommaOkTypes(orig_rhs[0], rhs) diff --git a/src/cmd/compile/internal/types2/builtins.go b/src/cmd/compile/internal/types2/builtins.go index e35dab8140..67aa37e401 100644 --- a/src/cmd/compile/internal/types2/builtins.go +++ b/src/cmd/compile/internal/types2/builtins.go @@ -678,7 +678,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) ( return } - check.expr(x, selx.X) + check.expr(nil, x, selx.X) if x.mode == invalid { return } @@ -878,7 +878,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) ( var t operand x1 := x for _, arg := range call.ArgList { - check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T)) + check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T)) check.dump("%v: %s", posFor(x1), x1) x1 = &t // use incoming x only for first argument } diff --git a/src/cmd/compile/internal/types2/call.go b/src/cmd/compile/internal/types2/call.go index 72608dea26..bb82c2464e 100644 --- a/src/cmd/compile/internal/types2/call.go +++ b/src/cmd/compile/internal/types2/call.go @@ -8,31 +8,54 @@ package types2 import ( "cmd/compile/internal/syntax" + "fmt" . "internal/types/errors" "strings" "unicode" ) -// funcInst type-checks a function instantiation inst and returns the result in x. -// The operand x must be the evaluation of inst.X and its type must be a signature. -func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { +// funcInst type-checks a function instantiation and returns the result in x. +// The incoming x must be an uninstantiated generic function. If inst != 0, +// it provides (some or all of) the type arguments (inst.Index) for the +// instantiation. If the target type T != nil and is a (non-generic) function +// signature, the signature's parameter types are used to infer additional +// missing type arguments of x, if any. +// At least one of inst or T must be provided. +func (check *Checker) funcInst(T Type, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) { if !check.allowVersion(check.pkg, 1, 18) { check.versionErrorf(inst.Pos(), "go1.18", "function instantiation") } - xlist := unpackExpr(inst.Index) - targs := check.typeList(xlist) - if targs == nil { - x.mode = invalid - x.expr = inst - return + // tsig is the (assignment) target function signature, or nil. + // TODO(gri) refactor and pass in tsig to funcInst instead + var tsig *Signature + if check.conf.EnableReverseTypeInference && T != nil { + tsig, _ = under(T).(*Signature) } - assert(len(targs) == len(xlist)) - // check number of type arguments (got) vs number of type parameters (want) + // targs and xlist are the type arguments and corresponding type expressions, or nil. + var targs []Type + var xlist []syntax.Expr + if inst != nil { + xlist = unpackExpr(inst.Index) + targs = check.typeList(xlist) + if targs == nil { + x.mode = invalid + x.expr = inst + return + } + assert(len(targs) == len(xlist)) + } + + assert(tsig != nil || targs != nil) + + // Check the number of type arguments (got) vs number of type parameters (want). + // Note that x is a function value, not a type expression, so we don't need to + // call under below. sig := x.typ.(*Signature) got, want := len(targs), sig.TypeParams().Len() if got > want { + // Providing too many type arguments is always an error. check.errorf(xlist[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want) x.mode = invalid x.expr = inst @@ -40,7 +63,37 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { } if got < want { - targs = check.infer(inst.Pos(), sig.TypeParams().list(), targs, nil, nil) + // If the uninstantiated or partially instantiated function x is used in an + // assignment (tsig != nil), use the respective function parameter and result + // types to infer additional type arguments. + var args []*operand + var params []*Var + if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() { + // x is a generic function and the signature arity matches the target function. + // To infer x's missing type arguments, treat the function assignment as a call + // of a synthetic function f where f's parameters are the parameters and results + // of x and where the arguments to the call of f are values of the parameter and + // result types of x. + n := tsig.params.Len() + m := tsig.results.Len() + args = make([]*operand, n+m) + params = make([]*Var, n+m) + for i := 0; i < n; i++ { + lvar := tsig.params.At(i) + lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "parameter")) + args[i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[i] = sig.params.At(i) + } + for i := 0; i < m; i++ { + lvar := tsig.results.At(i) + lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "result parameter")) + args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[n+i] = sig.results.At(i) + } + } + + // Note that NewTuple(params...) below is nil if len(params) == 0, as desired. + targs = check.infer(pos, sig.TypeParams().list(), targs, NewTuple(params...), args) if targs == nil { // error was already reported x.mode = invalid @@ -54,10 +107,33 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) { // instantiate function signature sig = check.instantiateSignature(x.Pos(), sig, targs, xlist) assert(sig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(inst.X, targs, sig) + x.typ = sig x.mode = value - x.expr = inst + // If we don't have an index expression, keep the existing expression of x. + if inst != nil { + x.expr = inst + } + check.recordInstance(x.expr, targs, sig) +} + +func paramName(name string, i int, kind string) string { + if name != "" { + return name + } + return nth(i+1) + " " + kind +} + +func nth(n int) string { + switch n { + case 1: + return "1st" + case 2: + return "2nd" + case 3: + return "3rd" + } + return fmt.Sprintf("%dth", n) } func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, xlist []syntax.Expr) (res *Signature) { @@ -119,7 +195,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { case typexpr: // conversion - check.nonGeneric(x) + check.nonGeneric(nil, x) if x.mode == invalid { return conversion } @@ -129,7 +205,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { case 0: check.errorf(call, WrongArgCount, "missing argument in conversion to %s", T) case 1: - check.expr(x, call.ArgList[0]) + check.expr(nil, x, call.ArgList[0]) if x.mode != invalid { if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) { if !t.IsMethodSet() { @@ -272,7 +348,7 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) { xlist = make([]*operand, len(elist)) for i, e := range elist { var x operand - check.expr(&x, e) + check.expr(nil, &x, e) xlist[i] = &x } } @@ -744,14 +820,14 @@ func (check *Checker) use1(e syntax.Expr, lhs bool) bool { } } } - check.rawExpr(&x, n, nil, true) + check.rawExpr(nil, &x, n, nil, true) if v != nil { v.used = v_used // restore v.used } case *syntax.ListExpr: return check.useN(n.ElemList, lhs) default: - check.rawExpr(&x, e, nil, true) + check.rawExpr(nil, &x, e, nil, true) } return x.mode != invalid } diff --git a/src/cmd/compile/internal/types2/check_test.go b/src/cmd/compile/internal/types2/check_test.go index 26bb1aed9e..382d1ad19e 100644 --- a/src/cmd/compile/internal/types2/check_test.go +++ b/src/cmd/compile/internal/types2/check_test.go @@ -133,6 +133,7 @@ func testFiles(t *testing.T, filenames []string, colDelta uint, manual bool) { flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(&conf.EnableReverseTypeInference, "reverseTypeInference", false, "") if err := parseFlags(filenames[0], nil, flags); err != nil { t.Fatal(err) } diff --git a/src/cmd/compile/internal/types2/decl.go b/src/cmd/compile/internal/types2/decl.go index 0ac0f6196a..afa32c1a5f 100644 --- a/src/cmd/compile/internal/types2/decl.go +++ b/src/cmd/compile/internal/types2/decl.go @@ -408,7 +408,7 @@ func (check *Checker) constDecl(obj *Const, typ, init syntax.Expr, inherited boo // (see issues go.dev/issue/42991, go.dev/issue/42992). check.errpos = obj.pos } - check.expr(&x, init) + check.expr(nil, &x, init) } check.initConst(obj, &x) } @@ -455,7 +455,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init syntax.Expr) { if lhs == nil || len(lhs) == 1 { assert(lhs == nil || lhs[0] == obj) var x operand - check.expr(&x, init) + check.expr(obj.typ, &x, init) check.initVar(obj, &x, "variable declaration") return } diff --git a/src/cmd/compile/internal/types2/expr.go b/src/cmd/compile/internal/types2/expr.go index fdc7bdbef0..bab52b253b 100644 --- a/src/cmd/compile/internal/types2/expr.go +++ b/src/cmd/compile/internal/types2/expr.go @@ -173,7 +173,7 @@ func underIs(typ Type, f func(Type) bool) bool { } func (check *Checker) unary(x *operand, e *syntax.Operation) { - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { return } @@ -1097,8 +1097,8 @@ func init() { func (check *Checker) binary(x *operand, e syntax.Expr, lhs, rhs syntax.Expr, op syntax.Operator) { var y operand - check.expr(x, lhs) - check.expr(&y, rhs) + check.expr(nil, x, lhs) + check.expr(nil, &y, rhs) if x.mode == invalid { return @@ -1245,12 +1245,18 @@ const ( statement ) +// TODO(gri) In rawExpr below, consider using T instead of hint and +// some sort of "operation mode" instead of allowGeneric. +// May be clearer and less error-prone. + // rawExpr typechecks expression e and initializes x with the expression // value or type. If an error occurred, x.mode is set to invalid. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // If hint != nil, it is the type of a composite literal element. // If allowGeneric is set, the operand type may be an uninstantiated // parameterized type or function value. -func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind { +func (check *Checker) rawExpr(T Type, x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind { if check.conf.Trace { check.trace(e.Pos(), "-- expr %s", e) check.indent++ @@ -1260,10 +1266,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric }() } - kind := check.exprInternal(x, e, hint) + kind := check.exprInternal(T, x, e, hint) if !allowGeneric { - check.nonGeneric(x) + check.nonGeneric(T, x) } check.record(x) @@ -1271,9 +1277,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric return kind } -// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ. +// If x is a generic type, or a generic function whose type arguments cannot be inferred +// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ. // Otherwise it leaves x alone. -func (check *Checker) nonGeneric(x *operand) { +func (check *Checker) nonGeneric(T Type, x *operand) { if x.mode == invalid || x.mode == novalue { return } @@ -1285,6 +1292,12 @@ func (check *Checker) nonGeneric(x *operand) { } case *Signature: if t.tparams != nil { + if check.conf.EnableReverseTypeInference && T != nil { + if _, ok := under(T).(*Signature); ok { + check.funcInst(T, x.Pos(), x, nil) + return + } + } what = "function" } } @@ -1297,7 +1310,8 @@ func (check *Checker) nonGeneric(x *operand) { // exprInternal contains the core of type checking of expressions. // Must only be called by rawExpr. -func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKind { +// (See rawExpr for an explanation of the parameters.) +func (check *Checker) exprInternal(T Type, x *operand, e syntax.Expr, hint Type) exprKind { // make sure x has a valid state in case of bailout // (was go.dev/issue/5770) x.mode = invalid @@ -1438,7 +1452,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin key, _ := kv.Key.(*syntax.Name) // do all possible checks early (before exiting due to errors) // so we don't drop information on the floor - check.expr(x, kv.Value) + check.expr(nil, x, kv.Value) if key == nil { check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key) continue @@ -1466,7 +1480,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal") continue } - check.expr(x, e) + check.expr(nil, x, e) if i >= len(fields) { check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base) break // cannot continue @@ -1593,7 +1607,8 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin x.typ = typ case *syntax.ParenExpr: - kind := check.rawExpr(x, e.X, nil, false) + // type inference doesn't go past parentheses (targe type T = nil) + kind := check.rawExpr(nil, x, e.X, nil, false) x.expr = e return kind @@ -1602,7 +1617,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin case *syntax.IndexExpr: if check.indexExpr(x, e) { - check.funcInst(x, e) + check.funcInst(T, e.Pos(), x, e) } if x.mode == invalid { goto Error @@ -1615,7 +1630,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin } case *syntax.AssertExpr: - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { goto Error } @@ -1624,7 +1639,6 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin check.error(e, InvalidSyntaxTree, "invalid use of AssertExpr") goto Error } - // TODO(gri) we may want to permit type assertions on type parameter values at some point if isTypeParam(x.typ) { check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x) goto Error @@ -1814,10 +1828,19 @@ func (check *Checker) typeAssertion(e syntax.Expr, x *operand, T Type, typeSwitc } // expr typechecks expression e and initializes x with the expression value. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // The result must be a single value. // If an error occurred, x.mode is set to invalid. -func (check *Checker) expr(x *operand, e syntax.Expr) { - check.rawExpr(x, e, nil, false) +func (check *Checker) expr(T Type, x *operand, e syntax.Expr) { + check.rawExpr(T, x, e, nil, false) + check.exclude(x, 1< want { + // Providing too many type arguments is always an error. check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want) x.mode = invalid x.expr = ix.Orig @@ -41,11 +65,43 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) { } if got < want { - targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil) + // If the uninstantiated or partially instantiated function x is used in an + // assignment (tsig != nil), use the respective function parameter and result + // types to infer additional type arguments. + var args []*operand + var params []*Var + if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() { + // x is a generic function and the signature arity matches the target function. + // To infer x's missing type arguments, treat the function assignment as a call + // of a synthetic function f where f's parameters are the parameters and results + // of x and where the arguments to the call of f are values of the parameter and + // result types of x. + n := tsig.params.Len() + m := tsig.results.Len() + args = make([]*operand, n+m) + params = make([]*Var, n+m) + for i := 0; i < n; i++ { + lvar := tsig.params.At(i) + lname := ast.NewIdent(paramName(lvar.name, i, "parameter")) + lname.NamePos = x.Pos() // correct position + args[i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[i] = sig.params.At(i) + } + for i := 0; i < m; i++ { + lvar := tsig.results.At(i) + lname := ast.NewIdent(paramName(lvar.name, i, "result parameter")) + lname.NamePos = x.Pos() // correct position + args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ} + params[n+i] = sig.results.At(i) + } + } + + // Note that NewTuple(params...) below is nil if len(params) == 0, as desired. + targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args) if targs == nil { // error was already reported x.mode = invalid - x.expr = ix.Orig + x.expr = ix // TODO(gri) is this correct? return } got = len(targs) @@ -53,12 +109,35 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) { assert(got == want) // instantiate function signature - sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices) + sig = check.instantiateSignature(x.Pos(), sig, targs, xlist) assert(sig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(ix.Orig, targs, sig) + x.typ = sig x.mode = value - x.expr = ix.Orig + // If we don't have an index expression, keep the existing expression of x. + if ix != nil { + x.expr = ix.Orig + } + check.recordInstance(x.expr, targs, sig) +} + +func paramName(name string, i int, kind string) string { + if name != "" { + return name + } + return nth(i+1) + " " + kind +} + +func nth(n int) string { + switch n { + case 1: + return "1st" + case 2: + return "2nd" + case 3: + return "3rd" + } + return fmt.Sprintf("%dth", n) } func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) { @@ -121,7 +200,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { case typexpr: // conversion - check.nonGeneric(x) + check.nonGeneric(nil, x) if x.mode == invalid { return conversion } @@ -131,7 +210,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { case 0: check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T) case 1: - check.expr(x, call.Args[0]) + check.expr(nil, x, call.Args[0]) if x.mode != invalid { if call.Ellipsis.IsValid() { check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T) @@ -274,7 +353,7 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { xlist = make([]*operand, len(elist)) for i, e := range elist { var x operand - check.expr(&x, e) + check.expr(nil, &x, e) xlist[i] = &x } } @@ -791,12 +870,12 @@ func (check *Checker) use1(e ast.Expr, lhs bool) bool { } } } - check.rawExpr(&x, n, nil, true) + check.rawExpr(nil, &x, n, nil, true) if v != nil { v.used = v_used // restore v.used } default: - check.rawExpr(&x, e, nil, true) + check.rawExpr(nil, &x, e, nil, true) } return x.mode != invalid } diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go index 36809838c7..0f4c320a47 100644 --- a/src/go/types/check_test.go +++ b/src/go/types/check_test.go @@ -145,6 +145,7 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(boolFieldAddr(&conf, "_EnableReverseTypeInference"), "reverseTypeInference", false, "") if err := parseFlags(filenames[0], srcs[0], flags); err != nil { t.Fatal(err) } diff --git a/src/go/types/decl.go b/src/go/types/decl.go index 393d8f34e2..3065da2e8e 100644 --- a/src/go/types/decl.go +++ b/src/go/types/decl.go @@ -477,7 +477,7 @@ func (check *Checker) constDecl(obj *Const, typ, init ast.Expr, inherited bool) // (see issues go.dev/issue/42991, go.dev/issue/42992). check.errpos = atPos(obj.pos) } - check.expr(&x, init) + check.expr(nil, &x, init) } check.initConst(obj, &x) } @@ -510,7 +510,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) { if lhs == nil || len(lhs) == 1 { assert(lhs == nil || lhs[0] == obj) var x operand - check.expr(&x, init) + check.expr(obj.typ, &x, init) check.initVar(obj, &x, "variable declaration") return } diff --git a/src/go/types/eval.go b/src/go/types/eval.go index 1e4d64fe96..1655a8bd27 100644 --- a/src/go/types/eval.go +++ b/src/go/types/eval.go @@ -91,8 +91,8 @@ func CheckExpr(fset *token.FileSet, pkg *Package, pos token.Pos, expr ast.Expr, // evaluate node var x operand - check.rawExpr(&x, expr, nil, true) // allow generic expressions - check.processDelayed(0) // incl. all functions + check.rawExpr(nil, &x, expr, nil, true) // allow generic expressions + check.processDelayed(0) // incl. all functions check.recordUntyped() return nil diff --git a/src/go/types/expr.go b/src/go/types/expr.go index 1abf963b7f..219a392b88 100644 --- a/src/go/types/expr.go +++ b/src/go/types/expr.go @@ -160,7 +160,7 @@ func underIs(typ Type, f func(Type) bool) bool { // The unary expression e may be nil. It's passed in for better error messages only. func (check *Checker) unary(x *operand, e *ast.UnaryExpr) { - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { return } @@ -1079,8 +1079,8 @@ func init() { func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) { var y operand - check.expr(x, lhs) - check.expr(&y, rhs) + check.expr(nil, x, lhs) + check.expr(nil, &y, rhs) if x.mode == invalid { return @@ -1230,12 +1230,18 @@ const ( statement ) +// TODO(gri) In rawExpr below, consider using T instead of hint and +// some sort of "operation mode" instead of allowGeneric. +// May be clearer and less error-prone. + // rawExpr typechecks expression e and initializes x with the expression // value or type. If an error occurred, x.mode is set to invalid. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // If hint != nil, it is the type of a composite literal element. // If allowGeneric is set, the operand type may be an uninstantiated // parameterized type or function value. -func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind { +func (check *Checker) rawExpr(T Type, x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind { if check.conf._Trace { check.trace(e.Pos(), "-- expr %s", e) check.indent++ @@ -1245,10 +1251,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo }() } - kind := check.exprInternal(x, e, hint) + kind := check.exprInternal(T, x, e, hint) if !allowGeneric { - check.nonGeneric(x) + check.nonGeneric(T, x) } check.record(x) @@ -1256,9 +1262,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo return kind } -// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ. +// If x is a generic type, or a generic function whose type arguments cannot be inferred +// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ. // Otherwise it leaves x alone. -func (check *Checker) nonGeneric(x *operand) { +func (check *Checker) nonGeneric(T Type, x *operand) { if x.mode == invalid || x.mode == novalue { return } @@ -1270,6 +1277,12 @@ func (check *Checker) nonGeneric(x *operand) { } case *Signature: if t.tparams != nil { + if check.conf._EnableReverseTypeInference && T != nil { + if _, ok := under(T).(*Signature); ok { + check.funcInst(T, x.Pos(), x, nil) + return + } + } what = "function" } } @@ -1282,7 +1295,8 @@ func (check *Checker) nonGeneric(x *operand) { // exprInternal contains the core of type checking of expressions. // Must only be called by rawExpr. -func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { +// (See rawExpr for an explanation of the parameters.) +func (check *Checker) exprInternal(T Type, x *operand, e ast.Expr, hint Type) exprKind { // make sure x has a valid state in case of bailout // (was go.dev/issue/5770) x.mode = invalid @@ -1418,7 +1432,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { key, _ := kv.Key.(*ast.Ident) // do all possible checks early (before exiting due to errors) // so we don't drop information on the floor - check.expr(x, kv.Value) + check.expr(nil, x, kv.Value) if key == nil { check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key) continue @@ -1446,7 +1460,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal") continue } - check.expr(x, e) + check.expr(nil, x, e) if i >= len(fields) { check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base) break // cannot continue @@ -1575,7 +1589,8 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { x.typ = typ case *ast.ParenExpr: - kind := check.rawExpr(x, e.X, nil, false) + // type inference doesn't go past parentheses (targe type T = nil) + kind := check.rawExpr(nil, x, e.X, nil, false) x.expr = e return kind @@ -1585,7 +1600,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { case *ast.IndexExpr, *ast.IndexListExpr: ix := typeparams.UnpackIndexExpr(e) if check.indexExpr(x, ix) { - check.funcInst(x, ix) + check.funcInst(T, e.Pos(), x, ix) } if x.mode == invalid { goto Error @@ -1598,7 +1613,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { } case *ast.TypeAssertExpr: - check.expr(x, e.X) + check.expr(nil, x, e.X) if x.mode == invalid { goto Error } @@ -1609,7 +1624,6 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind { check.error(e, BadTypeKeyword, "use of .(type) outside type switch") goto Error } - // TODO(gri) we may want to permit type assertions on type parameter values at some point if isTypeParam(x.typ) { check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x) goto Error @@ -1761,10 +1775,19 @@ func (check *Checker) typeAssertion(e ast.Expr, x *operand, T Type, typeSwitch b } // expr typechecks expression e and initializes x with the expression value. +// If a non-nil target type T is given and e is a generic function +// or function call, T is used to infer the type arguments for e. // The result must be a single value. // If an error occurred, x.mode is set to invalid. -func (check *Checker) expr(x *operand, e ast.Expr) { - check.rawExpr(x, e, nil, false) +func (check *Checker) expr(T Type, x *operand, e ast.Expr) { + check.rawExpr(T, x, e, nil, false) + check.exclude(x, 1<