go/types, types2: reverse inference of function type arguments

This CL implements type inference for generic functions used in
assignments: variable init expressions, regular assignments, and
return statements, but (not yet) function arguments passed to
functions. For instance, given a generic function

        func f[P any](x P)

and a variable of function type

        var v func(x int)

the assignment

        v = f

is valid w/o explicit instantiation of f, and the missing type
argument for f is inferred from the type of v. More generally,
the function f may have multiple type arguments, and it may be
partially instantiated.

This new form of inference is not enabled by default (it needs
to go through the proposal process first). It can be enabled
by setting Config.EnableReverseTypeInference.

The mechanism is implemented as follows:

- The various expression evaluation functions take an additional
  (first) argument T, which is the target type for the expression.
  If not nil, it is the type of the LHS in an assignment.

- The method Checker.funcInst is changed such that it uses both,
  provided type arguments (if any), and a target type (if any)
  to augment type inference.

Change-Id: Idfde61078e1ee4f22abcca894a4c84d681734ff6
Reviewed-on: https://go-review.googlesource.com/c/go/+/476075
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Robert Griesemer <gri@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
Reviewed-by: Robert Griesemer <gri@google.com>
Run-TryBot: Robert Griesemer <gri@google.com>
This commit is contained in:
Robert Griesemer 2023-03-13 16:38:14 -07:00 committed by Gopher Robot
parent 93b3035dbb
commit cc048b32f3
23 changed files with 462 additions and 159 deletions

View File

@ -169,6 +169,13 @@ type Config struct {
// If DisableUnusedImportCheck is set, packages are not checked
// for unused imports.
DisableUnusedImportCheck bool
// If EnableReverseTypeInference is set, uninstantiated and
// partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments.
// Experimental. Needs a proposal.
EnableReverseTypeInference bool
}
func srcimporter_setUsesCgo(conf *Config) {

View File

@ -189,7 +189,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
}
var x operand
check.expr(&x, lhs)
check.expr(nil, &x, lhs)
if v != nil {
v.used = v_used // restore v.used
@ -205,7 +205,7 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
default:
if sel, ok := x.expr.(*syntax.SelectorExpr); ok {
var op operand
check.expr(&op, sel.X)
check.expr(nil, &op, sel.X)
if op.mode == mapindex {
check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", syntax.String(x.expr))
return Typ[Invalid]
@ -218,15 +218,20 @@ func (check *Checker) lhsVar(lhs syntax.Expr) Type {
return x.typ
}
// assignVar checks the assignment lhs = x.
func (check *Checker) assignVar(lhs syntax.Expr, x *operand) {
if x.mode == invalid {
check.useLHS(lhs)
// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
func (check *Checker) assignVar(lhs, rhs syntax.Expr, x *operand) {
T := check.lhsVar(lhs) // nil if lhs is _
if T == Typ[Invalid] {
check.use(rhs)
return
}
T := check.lhsVar(lhs) // nil if lhs is _
if T == Typ[Invalid] {
if x == nil {
x = new(operand)
check.expr(T, x, rhs)
}
if x.mode == invalid {
return
}
@ -351,7 +356,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []syntax.Expr, returnStmt sy
if l == r && !isCall {
var x operand
for i, lhs := range lhs {
check.expr(&x, orig_rhs[i])
check.expr(lhs.typ, &x, orig_rhs[i])
check.initVar(lhs, &x, context)
}
return
@ -423,9 +428,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) {
// each value can be assigned to its corresponding variable.
if l == r && !isCall {
for i, lhs := range lhs {
var x operand
check.expr(&x, orig_rhs[i])
check.assignVar(lhs, &x)
check.assignVar(lhs, orig_rhs[i], nil)
}
return
}
@ -446,7 +449,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []syntax.Expr) {
r = len(rhs)
if l == r {
for i, lhs := range lhs {
check.assignVar(lhs, rhs[i])
check.assignVar(lhs, nil, rhs[i])
}
if commaOk {
check.recordCommaOkTypes(orig_rhs[0], rhs)

View File

@ -678,7 +678,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
return
}
check.expr(x, selx.X)
check.expr(nil, x, selx.X)
if x.mode == invalid {
return
}
@ -878,7 +878,7 @@ func (check *Checker) builtin(x *operand, call *syntax.CallExpr, id builtinId) (
var t operand
x1 := x
for _, arg := range call.ArgList {
check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.dump("%v: %s", posFor(x1), x1)
x1 = &t // use incoming x only for first argument
}

View File

@ -8,31 +8,54 @@ package types2
import (
"cmd/compile/internal/syntax"
"fmt"
. "internal/types/errors"
"strings"
"unicode"
)
// funcInst type-checks a function instantiation inst and returns the result in x.
// The operand x must be the evaluation of inst.X and its type must be a signature.
func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
// funcInst type-checks a function instantiation and returns the result in x.
// The incoming x must be an uninstantiated generic function. If inst != 0,
// it provides (some or all of) the type arguments (inst.Index) for the
// instantiation. If the target type T != nil and is a (non-generic) function
// signature, the signature's parameter types are used to infer additional
// missing type arguments of x, if any.
// At least one of inst or T must be provided.
func (check *Checker) funcInst(T Type, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.versionErrorf(inst.Pos(), "go1.18", "function instantiation")
}
xlist := unpackExpr(inst.Index)
targs := check.typeList(xlist)
// tsig is the (assignment) target function signature, or nil.
// TODO(gri) refactor and pass in tsig to funcInst instead
var tsig *Signature
if check.conf.EnableReverseTypeInference && T != nil {
tsig, _ = under(T).(*Signature)
}
// targs and xlist are the type arguments and corresponding type expressions, or nil.
var targs []Type
var xlist []syntax.Expr
if inst != nil {
xlist = unpackExpr(inst.Index)
targs = check.typeList(xlist)
if targs == nil {
x.mode = invalid
x.expr = inst
return
}
assert(len(targs) == len(xlist))
}
// check number of type arguments (got) vs number of type parameters (want)
assert(tsig != nil || targs != nil)
// Check the number of type arguments (got) vs number of type parameters (want).
// Note that x is a function value, not a type expression, so we don't need to
// call under below.
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
// Providing too many type arguments is always an error.
check.errorf(xlist[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = inst
@ -40,7 +63,37 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
}
if got < want {
targs = check.infer(inst.Pos(), sig.TypeParams().list(), targs, nil, nil)
// If the uninstantiated or partially instantiated function x is used in an
// assignment (tsig != nil), use the respective function parameter and result
// types to infer additional type arguments.
var args []*operand
var params []*Var
if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
// x is a generic function and the signature arity matches the target function.
// To infer x's missing type arguments, treat the function assignment as a call
// of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and
// result types of x.
n := tsig.params.Len()
m := tsig.results.Len()
args = make([]*operand, n+m)
params = make([]*Var, n+m)
for i := 0; i < n; i++ {
lvar := tsig.params.At(i)
lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "parameter"))
args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
params[i] = sig.params.At(i)
}
for i := 0; i < m; i++ {
lvar := tsig.results.At(i)
lname := syntax.NewName(x.Pos(), paramName(lvar.name, i, "result parameter"))
args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
params[n+i] = sig.results.At(i)
}
}
// Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
targs = check.infer(pos, sig.TypeParams().list(), targs, NewTuple(params...), args)
if targs == nil {
// error was already reported
x.mode = invalid
@ -54,10 +107,33 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
// instantiate function signature
sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(inst.X, targs, sig)
x.typ = sig
x.mode = value
// If we don't have an index expression, keep the existing expression of x.
if inst != nil {
x.expr = inst
}
check.recordInstance(x.expr, targs, sig)
}
func paramName(name string, i int, kind string) string {
if name != "" {
return name
}
return nth(i+1) + " " + kind
}
func nth(n int) string {
switch n {
case 1:
return "1st"
case 2:
return "2nd"
case 3:
return "3rd"
}
return fmt.Sprintf("%dth", n)
}
func (check *Checker) instantiateSignature(pos syntax.Pos, typ *Signature, targs []Type, xlist []syntax.Expr) (res *Signature) {
@ -119,7 +195,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
case typexpr:
// conversion
check.nonGeneric(x)
check.nonGeneric(nil, x)
if x.mode == invalid {
return conversion
}
@ -129,7 +205,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
case 0:
check.errorf(call, WrongArgCount, "missing argument in conversion to %s", T)
case 1:
check.expr(x, call.ArgList[0])
check.expr(nil, x, call.ArgList[0])
if x.mode != invalid {
if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) {
if !t.IsMethodSet() {
@ -272,7 +348,7 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
check.expr(&x, e)
check.expr(nil, &x, e)
xlist[i] = &x
}
}
@ -744,14 +820,14 @@ func (check *Checker) use1(e syntax.Expr, lhs bool) bool {
}
}
}
check.rawExpr(&x, n, nil, true)
check.rawExpr(nil, &x, n, nil, true)
if v != nil {
v.used = v_used // restore v.used
}
case *syntax.ListExpr:
return check.useN(n.ElemList, lhs)
default:
check.rawExpr(&x, e, nil, true)
check.rawExpr(nil, &x, e, nil, true)
}
return x.mode != invalid
}

View File

@ -133,6 +133,7 @@ func testFiles(t *testing.T, filenames []string, colDelta uint, manual bool) {
flags := flag.NewFlagSet("", flag.PanicOnError)
flags.StringVar(&conf.GoVersion, "lang", "", "")
flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
flags.BoolVar(&conf.EnableReverseTypeInference, "reverseTypeInference", false, "")
if err := parseFlags(filenames[0], nil, flags); err != nil {
t.Fatal(err)
}

View File

@ -408,7 +408,7 @@ func (check *Checker) constDecl(obj *Const, typ, init syntax.Expr, inherited boo
// (see issues go.dev/issue/42991, go.dev/issue/42992).
check.errpos = obj.pos
}
check.expr(&x, init)
check.expr(nil, &x, init)
}
check.initConst(obj, &x)
}
@ -455,7 +455,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init syntax.Expr) {
if lhs == nil || len(lhs) == 1 {
assert(lhs == nil || lhs[0] == obj)
var x operand
check.expr(&x, init)
check.expr(obj.typ, &x, init)
check.initVar(obj, &x, "variable declaration")
return
}

View File

@ -173,7 +173,7 @@ func underIs(typ Type, f func(Type) bool) bool {
}
func (check *Checker) unary(x *operand, e *syntax.Operation) {
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
return
}
@ -1097,8 +1097,8 @@ func init() {
func (check *Checker) binary(x *operand, e syntax.Expr, lhs, rhs syntax.Expr, op syntax.Operator) {
var y operand
check.expr(x, lhs)
check.expr(&y, rhs)
check.expr(nil, x, lhs)
check.expr(nil, &y, rhs)
if x.mode == invalid {
return
@ -1245,12 +1245,18 @@ const (
statement
)
// TODO(gri) In rawExpr below, consider using T instead of hint and
// some sort of "operation mode" instead of allowGeneric.
// May be clearer and less error-prone.
// rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occurred, x.mode is set to invalid.
// If a non-nil target type T is given and e is a generic function
// or function call, T is used to infer the type arguments for e.
// If hint != nil, it is the type of a composite literal element.
// If allowGeneric is set, the operand type may be an uninstantiated
// parameterized type or function value.
func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
func (check *Checker) rawExpr(T Type, x *operand, e syntax.Expr, hint Type, allowGeneric bool) exprKind {
if check.conf.Trace {
check.trace(e.Pos(), "-- expr %s", e)
check.indent++
@ -1260,10 +1266,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric
}()
}
kind := check.exprInternal(x, e, hint)
kind := check.exprInternal(T, x, e, hint)
if !allowGeneric {
check.nonGeneric(x)
check.nonGeneric(T, x)
}
check.record(x)
@ -1271,9 +1277,10 @@ func (check *Checker) rawExpr(x *operand, e syntax.Expr, hint Type, allowGeneric
return kind
}
// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
// If x is a generic type, or a generic function whose type arguments cannot be inferred
// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
// Otherwise it leaves x alone.
func (check *Checker) nonGeneric(x *operand) {
func (check *Checker) nonGeneric(T Type, x *operand) {
if x.mode == invalid || x.mode == novalue {
return
}
@ -1285,6 +1292,12 @@ func (check *Checker) nonGeneric(x *operand) {
}
case *Signature:
if t.tparams != nil {
if check.conf.EnableReverseTypeInference && T != nil {
if _, ok := under(T).(*Signature); ok {
check.funcInst(T, x.Pos(), x, nil)
return
}
}
what = "function"
}
}
@ -1297,7 +1310,8 @@ func (check *Checker) nonGeneric(x *operand) {
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKind {
// (See rawExpr for an explanation of the parameters.)
func (check *Checker) exprInternal(T Type, x *operand, e syntax.Expr, hint Type) exprKind {
// make sure x has a valid state in case of bailout
// (was go.dev/issue/5770)
x.mode = invalid
@ -1438,7 +1452,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
key, _ := kv.Key.(*syntax.Name)
// do all possible checks early (before exiting due to errors)
// so we don't drop information on the floor
check.expr(x, kv.Value)
check.expr(nil, x, kv.Value)
if key == nil {
check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
continue
@ -1466,7 +1480,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
check.expr(x, e)
check.expr(nil, x, e)
if i >= len(fields) {
check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
break // cannot continue
@ -1593,7 +1607,8 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
x.typ = typ
case *syntax.ParenExpr:
kind := check.rawExpr(x, e.X, nil, false)
// type inference doesn't go past parentheses (targe type T = nil)
kind := check.rawExpr(nil, x, e.X, nil, false)
x.expr = e
return kind
@ -1602,7 +1617,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
case *syntax.IndexExpr:
if check.indexExpr(x, e) {
check.funcInst(x, e)
check.funcInst(T, e.Pos(), x, e)
}
if x.mode == invalid {
goto Error
@ -1615,7 +1630,7 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
}
case *syntax.AssertExpr:
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
goto Error
}
@ -1624,7 +1639,6 @@ func (check *Checker) exprInternal(x *operand, e syntax.Expr, hint Type) exprKin
check.error(e, InvalidSyntaxTree, "invalid use of AssertExpr")
goto Error
}
// TODO(gri) we may want to permit type assertions on type parameter values at some point
if isTypeParam(x.typ) {
check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
goto Error
@ -1814,10 +1828,19 @@ func (check *Checker) typeAssertion(e syntax.Expr, x *operand, T Type, typeSwitc
}
// expr typechecks expression e and initializes x with the expression value.
// If a non-nil target type T is given and e is a generic function
// or function call, T is used to infer the type arguments for e.
// The result must be a single value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) expr(x *operand, e syntax.Expr) {
check.rawExpr(x, e, nil, false)
func (check *Checker) expr(T Type, x *operand, e syntax.Expr) {
check.rawExpr(T, x, e, nil, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// genericExpr is like expr but the result may also be generic.
func (check *Checker) genericExpr(x *operand, e syntax.Expr) {
check.rawExpr(nil, x, e, nil, true)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
@ -1829,7 +1852,7 @@ func (check *Checker) expr(x *operand, e syntax.Expr) {
// If an error occurred, list[0] is not valid.
func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
var x operand
check.rawExpr(&x, e, nil, false)
check.rawExpr(nil, &x, e, nil, false)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
@ -1860,7 +1883,7 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprWithHint(x *operand, e syntax.Expr, hint Type) {
assert(hint != nil)
check.rawExpr(x, e, hint, false)
check.rawExpr(nil, x, e, hint, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
@ -1870,7 +1893,7 @@ func (check *Checker) exprWithHint(x *operand, e syntax.Expr, hint Type) {
// value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprOrType(x *operand, e syntax.Expr, allowGeneric bool) {
check.rawExpr(x, e, nil, allowGeneric)
check.rawExpr(nil, x, e, nil, allowGeneric)
check.exclude(x, 1<<novalue)
check.singleValue(x)
}

View File

@ -42,7 +42,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
}
// x should not be generic at this point, but be safe and check
check.nonGeneric(x)
check.nonGeneric(nil, x)
if x.mode == invalid {
return false
}
@ -92,7 +92,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
return false
}
var key operand
check.expr(&key, index)
check.expr(nil, &key, index)
check.assignment(&key, typ.key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
@ -166,7 +166,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
return false
}
var k operand
check.expr(&k, index)
check.expr(nil, &k, index)
check.assignment(&k, key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
@ -206,7 +206,7 @@ func (check *Checker) indexExpr(x *operand, e *syntax.IndexExpr) (isFuncInst boo
}
func (check *Checker) sliceExpr(x *operand, e *syntax.SliceExpr) {
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
check.use(e.Index[:]...)
return
@ -353,7 +353,7 @@ func (check *Checker) index(index syntax.Expr, max int64) (typ Type, val int64)
val = -1
var x operand
check.expr(&x, index)
check.expr(nil, &x, index)
if !check.isValidIndex(&x, InvalidIndex, "index", false) {
return
}

View File

@ -180,7 +180,7 @@ func (check *Checker) suspendedCall(keyword string, call syntax.Expr) {
var x operand
var msg string
switch check.rawExpr(&x, call, nil, false) {
switch check.rawExpr(nil, &x, call, nil, false) {
case conversion:
msg = "requires function call, not conversion"
case expression:
@ -240,7 +240,7 @@ func (check *Checker) caseValues(x *operand, values []syntax.Expr, seen valueMap
L:
for _, e := range values {
var v operand
check.expr(&v, e)
check.expr(nil, &v, e)
if x.mode == invalid || v.mode == invalid {
continue L
}
@ -294,7 +294,7 @@ L:
// The spec allows the value nil instead of a type.
if check.isNil(e) {
T = nil
check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
} else {
T = check.varType(e)
if T == Typ[Invalid] {
@ -336,7 +336,7 @@ L:
// // The spec allows the value nil instead of a type.
// var hash string
// if check.isNil(e) {
// check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
// check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
// T = nil
// hash = "<nil>" // avoid collision with a type named nil
// } else {
@ -403,7 +403,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
// function and method calls and receive operations can appear
// in statement context. Such statements may be parenthesized."
var x operand
kind := check.rawExpr(&x, s.X, nil, false)
kind := check.rawExpr(nil, &x, s.X, nil, false)
var msg string
var code Code
switch x.mode {
@ -424,8 +424,8 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
case *syntax.SendStmt:
var ch, val operand
check.expr(&ch, s.Chan)
check.expr(&val, s.Value)
check.expr(nil, &ch, s.Chan)
check.expr(nil, &val, s.Value)
if ch.mode == invalid || val.mode == invalid {
return
}
@ -450,7 +450,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
// x++ or x--
// (no need to call unpackExpr as s.Lhs must be single-valued)
var x operand
check.expr(&x, s.Lhs)
check.expr(nil, &x, s.Lhs)
if x.mode == invalid {
return
}
@ -458,7 +458,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
check.errorf(s.Lhs, NonNumericIncDec, invalidOp+"%s%s%s (non-numeric type %s)", s.Lhs, s.Op, s.Op, x.typ)
return
}
check.assignVar(s.Lhs, &x)
check.assignVar(s.Lhs, nil, &x)
return
}
@ -481,7 +481,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
var x operand
check.binary(&x, nil, lhs[0], rhs[0], s.Op)
check.assignVar(lhs[0], &x)
check.assignVar(lhs[0], nil, &x)
case *syntax.CallStmt:
kind := "go"
@ -566,7 +566,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
check.simpleStmt(s.Init)
var x operand
check.expr(&x, s.Cond)
check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
}
@ -656,7 +656,7 @@ func (check *Checker) stmt(ctxt stmtContext, s syntax.Stmt) {
check.simpleStmt(s.Init)
if s.Cond != nil {
var x operand
check.expr(&x, s.Cond)
check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
}
@ -680,7 +680,7 @@ func (check *Checker) switchStmt(inner stmtContext, s *syntax.SwitchStmt) {
var x operand
if s.Tag != nil {
check.expr(&x, s.Tag)
check.expr(nil, &x, s.Tag)
// By checking assignment of x to an invisible temporary
// (as a compiler would), we get all the relevant checks.
check.assignment(&x, nil, "switch expression")
@ -747,7 +747,7 @@ func (check *Checker) typeSwitchStmt(inner stmtContext, s *syntax.SwitchStmt, gu
// check rhs
var x operand
check.expr(&x, guard.X)
check.expr(nil, &x, guard.X)
if x.mode == invalid {
return
}
@ -847,7 +847,7 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
// check expression to iterate over
var x operand
check.expr(&x, rclause.X)
check.expr(nil, &x, rclause.X)
// determine key/value types
var key, val Type
@ -950,7 +950,7 @@ func (check *Checker) rangeStmt(inner stmtContext, s *syntax.ForStmt, rclause *s
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
check.assignVar(lhs, &x)
check.assignVar(lhs, nil, &x)
}
}
}

View File

@ -489,7 +489,7 @@ func (check *Checker) arrayLength(e syntax.Expr) int64 {
}
var x operand
check.expr(&x, e)
check.expr(nil, &x, e)
if x.mode != constant_ {
if x.mode != invalid {
check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)

View File

@ -170,6 +170,13 @@ type Config struct {
// If DisableUnusedImportCheck is set, packages are not checked
// for unused imports.
DisableUnusedImportCheck bool
// If _EnableReverseTypeInference is set, uninstantiated and
// partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments.
// Experimental. Needs a proposal.
_EnableReverseTypeInference bool
}
func srcimporter_setUsesCgo(conf *Config) {

View File

@ -187,7 +187,7 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
}
var x operand
check.expr(&x, lhs)
check.expr(nil, &x, lhs)
if v != nil {
v.used = v_used // restore v.used
@ -203,7 +203,7 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
default:
if sel, ok := x.expr.(*ast.SelectorExpr); ok {
var op operand
check.expr(&op, sel.X)
check.expr(nil, &op, sel.X)
if op.mode == mapindex {
check.errorf(&x, UnaddressableFieldAssign, "cannot assign to struct field %s in map", ExprString(x.expr))
return Typ[Invalid]
@ -216,15 +216,20 @@ func (check *Checker) lhsVar(lhs ast.Expr) Type {
return x.typ
}
// assignVar checks the assignment lhs = x.
func (check *Checker) assignVar(lhs ast.Expr, x *operand) {
if x.mode == invalid {
check.useLHS(lhs)
// assignVar checks the assignment lhs = rhs (if x == nil), or lhs = x (if x != nil).
// If x != nil, it must be the evaluation of rhs (and rhs will be ignored).
func (check *Checker) assignVar(lhs, rhs ast.Expr, x *operand) {
T := check.lhsVar(lhs) // nil if lhs is _
if T == Typ[Invalid] {
check.use(rhs)
return
}
T := check.lhsVar(lhs) // nil if lhs is _
if T == Typ[Invalid] {
if x == nil {
x = new(operand)
check.expr(T, x, rhs)
}
if x.mode == invalid {
return
}
@ -349,7 +354,7 @@ func (check *Checker) initVars(lhs []*Var, orig_rhs []ast.Expr, returnStmt ast.S
if l == r && !isCall {
var x operand
for i, lhs := range lhs {
check.expr(&x, orig_rhs[i])
check.expr(lhs.typ, &x, orig_rhs[i])
check.initVar(lhs, &x, context)
}
return
@ -421,9 +426,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []ast.Expr) {
// each value can be assigned to its corresponding variable.
if l == r && !isCall {
for i, lhs := range lhs {
var x operand
check.expr(&x, orig_rhs[i])
check.assignVar(lhs, &x)
check.assignVar(lhs, orig_rhs[i], nil)
}
return
}
@ -444,7 +447,7 @@ func (check *Checker) assignVars(lhs, orig_rhs []ast.Expr) {
r = len(rhs)
if l == r {
for i, lhs := range lhs {
check.assignVar(lhs, rhs[i])
check.assignVar(lhs, nil, rhs[i])
}
if commaOk {
check.recordCommaOkTypes(orig_rhs[0], rhs)

View File

@ -679,7 +679,7 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
return
}
check.expr(x, selx.X)
check.expr(nil, x, selx.X)
if x.mode == invalid {
return
}
@ -879,7 +879,7 @@ func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ b
var t operand
x1 := x
for _, arg := range call.Args {
check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.rawExpr(nil, x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.dump("%v: %s", x1.Pos(), x1)
x1 = &t // use incoming x only for first argument
}

View File

@ -7,6 +7,7 @@
package types
import (
"fmt"
"go/ast"
"go/internal/typeparams"
"go/token"
@ -15,25 +16,48 @@ import (
"unicode"
)
// funcInst type-checks a function instantiation inst and returns the result in x.
// The operand x must be the evaluation of inst.X and its type must be a signature.
func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
// funcInst type-checks a function instantiation and returns the result in x.
// The incoming x must be an uninstantiated generic function. If ix != 0,
// it provides (some or all of) the type arguments (ix.Indices) for the
// instantiation. If the target type T != nil and is a (non-generic) function
// signature, the signature's parameter types are used to infer additional
// missing type arguments of x, if any.
// At least one of inst or T must be provided.
func (check *Checker) funcInst(T Type, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(ix.Orig, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
}
targs := check.typeList(ix.Indices)
// tsig is the (assignment) target function signature, or nil.
// TODO(gri) refactor and pass in tsig to funcInst instead
var tsig *Signature
if check.conf._EnableReverseTypeInference && T != nil {
tsig, _ = under(T).(*Signature)
}
// targs and xlist are the type arguments and corresponding type expressions, or nil.
var targs []Type
var xlist []ast.Expr
if ix != nil {
xlist = ix.Indices
targs = check.typeList(xlist)
if targs == nil {
x.mode = invalid
x.expr = ix.Orig
x.expr = ix
return
}
assert(len(targs) == len(ix.Indices))
assert(len(targs) == len(xlist))
}
// check number of type arguments (got) vs number of type parameters (want)
assert(tsig != nil || targs != nil)
// Check the number of type arguments (got) vs number of type parameters (want).
// Note that x is a function value, not a type expression, so we don't need to
// call under below.
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
// Providing too many type arguments is always an error.
check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = ix.Orig
@ -41,11 +65,43 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
}
if got < want {
targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil)
// If the uninstantiated or partially instantiated function x is used in an
// assignment (tsig != nil), use the respective function parameter and result
// types to infer additional type arguments.
var args []*operand
var params []*Var
if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
// x is a generic function and the signature arity matches the target function.
// To infer x's missing type arguments, treat the function assignment as a call
// of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and
// result types of x.
n := tsig.params.Len()
m := tsig.results.Len()
args = make([]*operand, n+m)
params = make([]*Var, n+m)
for i := 0; i < n; i++ {
lvar := tsig.params.At(i)
lname := ast.NewIdent(paramName(lvar.name, i, "parameter"))
lname.NamePos = x.Pos() // correct position
args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
params[i] = sig.params.At(i)
}
for i := 0; i < m; i++ {
lvar := tsig.results.At(i)
lname := ast.NewIdent(paramName(lvar.name, i, "result parameter"))
lname.NamePos = x.Pos() // correct position
args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
params[n+i] = sig.results.At(i)
}
}
// Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args)
if targs == nil {
// error was already reported
x.mode = invalid
x.expr = ix.Orig
x.expr = ix // TODO(gri) is this correct?
return
}
got = len(targs)
@ -53,12 +109,35 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
assert(got == want)
// instantiate function signature
sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices)
sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(ix.Orig, targs, sig)
x.typ = sig
x.mode = value
// If we don't have an index expression, keep the existing expression of x.
if ix != nil {
x.expr = ix.Orig
}
check.recordInstance(x.expr, targs, sig)
}
func paramName(name string, i int, kind string) string {
if name != "" {
return name
}
return nth(i+1) + " " + kind
}
func nth(n int) string {
switch n {
case 1:
return "1st"
case 2:
return "2nd"
case 3:
return "3rd"
}
return fmt.Sprintf("%dth", n)
}
func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) {
@ -121,7 +200,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
case typexpr:
// conversion
check.nonGeneric(x)
check.nonGeneric(nil, x)
if x.mode == invalid {
return conversion
}
@ -131,7 +210,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
case 0:
check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T)
case 1:
check.expr(x, call.Args[0])
check.expr(nil, x, call.Args[0])
if x.mode != invalid {
if call.Ellipsis.IsValid() {
check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T)
@ -274,7 +353,7 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
check.expr(&x, e)
check.expr(nil, &x, e)
xlist[i] = &x
}
}
@ -791,12 +870,12 @@ func (check *Checker) use1(e ast.Expr, lhs bool) bool {
}
}
}
check.rawExpr(&x, n, nil, true)
check.rawExpr(nil, &x, n, nil, true)
if v != nil {
v.used = v_used // restore v.used
}
default:
check.rawExpr(&x, e, nil, true)
check.rawExpr(nil, &x, e, nil, true)
}
return x.mode != invalid
}

View File

@ -145,6 +145,7 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
flags := flag.NewFlagSet("", flag.PanicOnError)
flags.StringVar(&conf.GoVersion, "lang", "", "")
flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "")
flags.BoolVar(boolFieldAddr(&conf, "_EnableReverseTypeInference"), "reverseTypeInference", false, "")
if err := parseFlags(filenames[0], srcs[0], flags); err != nil {
t.Fatal(err)
}

View File

@ -477,7 +477,7 @@ func (check *Checker) constDecl(obj *Const, typ, init ast.Expr, inherited bool)
// (see issues go.dev/issue/42991, go.dev/issue/42992).
check.errpos = atPos(obj.pos)
}
check.expr(&x, init)
check.expr(nil, &x, init)
}
check.initConst(obj, &x)
}
@ -510,7 +510,7 @@ func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) {
if lhs == nil || len(lhs) == 1 {
assert(lhs == nil || lhs[0] == obj)
var x operand
check.expr(&x, init)
check.expr(obj.typ, &x, init)
check.initVar(obj, &x, "variable declaration")
return
}

View File

@ -91,7 +91,7 @@ func CheckExpr(fset *token.FileSet, pkg *Package, pos token.Pos, expr ast.Expr,
// evaluate node
var x operand
check.rawExpr(&x, expr, nil, true) // allow generic expressions
check.rawExpr(nil, &x, expr, nil, true) // allow generic expressions
check.processDelayed(0) // incl. all functions
check.recordUntyped()

View File

@ -160,7 +160,7 @@ func underIs(typ Type, f func(Type) bool) bool {
// The unary expression e may be nil. It's passed in for better error messages only.
func (check *Checker) unary(x *operand, e *ast.UnaryExpr) {
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
return
}
@ -1079,8 +1079,8 @@ func init() {
func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) {
var y operand
check.expr(x, lhs)
check.expr(&y, rhs)
check.expr(nil, x, lhs)
check.expr(nil, &y, rhs)
if x.mode == invalid {
return
@ -1230,12 +1230,18 @@ const (
statement
)
// TODO(gri) In rawExpr below, consider using T instead of hint and
// some sort of "operation mode" instead of allowGeneric.
// May be clearer and less error-prone.
// rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occurred, x.mode is set to invalid.
// If a non-nil target type T is given and e is a generic function
// or function call, T is used to infer the type arguments for e.
// If hint != nil, it is the type of a composite literal element.
// If allowGeneric is set, the operand type may be an uninstantiated
// parameterized type or function value.
func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
func (check *Checker) rawExpr(T Type, x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
if check.conf._Trace {
check.trace(e.Pos(), "-- expr %s", e)
check.indent++
@ -1245,10 +1251,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo
}()
}
kind := check.exprInternal(x, e, hint)
kind := check.exprInternal(T, x, e, hint)
if !allowGeneric {
check.nonGeneric(x)
check.nonGeneric(T, x)
}
check.record(x)
@ -1256,9 +1262,10 @@ func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bo
return kind
}
// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
// If x is a generic type, or a generic function whose type arguments cannot be inferred
// from a non-nil target type T, nonGeneric reports an error and invalidates x.mode and x.typ.
// Otherwise it leaves x alone.
func (check *Checker) nonGeneric(x *operand) {
func (check *Checker) nonGeneric(T Type, x *operand) {
if x.mode == invalid || x.mode == novalue {
return
}
@ -1270,6 +1277,12 @@ func (check *Checker) nonGeneric(x *operand) {
}
case *Signature:
if t.tparams != nil {
if check.conf._EnableReverseTypeInference && T != nil {
if _, ok := under(T).(*Signature); ok {
check.funcInst(T, x.Pos(), x, nil)
return
}
}
what = "function"
}
}
@ -1282,7 +1295,8 @@ func (check *Checker) nonGeneric(x *operand) {
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
// (See rawExpr for an explanation of the parameters.)
func (check *Checker) exprInternal(T Type, x *operand, e ast.Expr, hint Type) exprKind {
// make sure x has a valid state in case of bailout
// (was go.dev/issue/5770)
x.mode = invalid
@ -1418,7 +1432,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
key, _ := kv.Key.(*ast.Ident)
// do all possible checks early (before exiting due to errors)
// so we don't drop information on the floor
check.expr(x, kv.Value)
check.expr(nil, x, kv.Value)
if key == nil {
check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
continue
@ -1446,7 +1460,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
check.expr(x, e)
check.expr(nil, x, e)
if i >= len(fields) {
check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
break // cannot continue
@ -1575,7 +1589,8 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
x.typ = typ
case *ast.ParenExpr:
kind := check.rawExpr(x, e.X, nil, false)
// type inference doesn't go past parentheses (targe type T = nil)
kind := check.rawExpr(nil, x, e.X, nil, false)
x.expr = e
return kind
@ -1585,7 +1600,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(e)
if check.indexExpr(x, ix) {
check.funcInst(x, ix)
check.funcInst(T, e.Pos(), x, ix)
}
if x.mode == invalid {
goto Error
@ -1598,7 +1613,7 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
}
case *ast.TypeAssertExpr:
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
goto Error
}
@ -1609,7 +1624,6 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
check.error(e, BadTypeKeyword, "use of .(type) outside type switch")
goto Error
}
// TODO(gri) we may want to permit type assertions on type parameter values at some point
if isTypeParam(x.typ) {
check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
goto Error
@ -1761,10 +1775,19 @@ func (check *Checker) typeAssertion(e ast.Expr, x *operand, T Type, typeSwitch b
}
// expr typechecks expression e and initializes x with the expression value.
// If a non-nil target type T is given and e is a generic function
// or function call, T is used to infer the type arguments for e.
// The result must be a single value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) expr(x *operand, e ast.Expr) {
check.rawExpr(x, e, nil, false)
func (check *Checker) expr(T Type, x *operand, e ast.Expr) {
check.rawExpr(T, x, e, nil, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// genericExpr is like expr but the result may also be generic.
func (check *Checker) genericExpr(x *operand, e ast.Expr) {
check.rawExpr(nil, x, e, nil, true)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
@ -1776,7 +1799,7 @@ func (check *Checker) expr(x *operand, e ast.Expr) {
// If an error occurred, list[0] is not valid.
func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand, commaOk bool) {
var x operand
check.rawExpr(&x, e, nil, false)
check.rawExpr(nil, &x, e, nil, false)
check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
@ -1807,7 +1830,7 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
assert(hint != nil)
check.rawExpr(x, e, hint, false)
check.rawExpr(nil, x, e, hint, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
@ -1817,7 +1840,7 @@ func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
// value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprOrType(x *operand, e ast.Expr, allowGeneric bool) {
check.rawExpr(x, e, nil, allowGeneric)
check.rawExpr(nil, x, e, nil, allowGeneric)
check.exclude(x, 1<<novalue)
check.singleValue(x)
}

View File

@ -43,7 +43,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
}
// x should not be generic at this point, but be safe and check
check.nonGeneric(x)
check.nonGeneric(nil, x)
if x.mode == invalid {
return false
}
@ -93,7 +93,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
return false
}
var key operand
check.expr(&key, index)
check.expr(nil, &key, index)
check.assignment(&key, typ.key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
@ -167,7 +167,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
return false
}
var k operand
check.expr(&k, index)
check.expr(nil, &k, index)
check.assignment(&k, key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
@ -208,7 +208,7 @@ func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst
}
func (check *Checker) sliceExpr(x *operand, e *ast.SliceExpr) {
check.expr(x, e.X)
check.expr(nil, x, e.X)
if x.mode == invalid {
check.use(e.Low, e.High, e.Max)
return
@ -350,7 +350,7 @@ func (check *Checker) index(index ast.Expr, max int64) (typ Type, val int64) {
val = -1
var x operand
check.expr(&x, index)
check.expr(nil, &x, index)
if !check.isValidIndex(&x, InvalidIndex, "index", false) {
return
}

View File

@ -173,7 +173,7 @@ func (check *Checker) suspendedCall(keyword string, call *ast.CallExpr) {
var x operand
var msg string
var code Code
switch check.rawExpr(&x, call, nil, false) {
switch check.rawExpr(nil, &x, call, nil, false) {
case conversion:
msg = "requires function call, not conversion"
code = InvalidDefer
@ -237,7 +237,7 @@ func (check *Checker) caseValues(x *operand, values []ast.Expr, seen valueMap) {
L:
for _, e := range values {
var v operand
check.expr(&v, e)
check.expr(nil, &v, e)
if x.mode == invalid || v.mode == invalid {
continue L
}
@ -288,7 +288,7 @@ L:
// The spec allows the value nil instead of a type.
if check.isNil(e) {
T = nil
check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
} else {
T = check.varType(e)
if T == Typ[Invalid] {
@ -327,7 +327,7 @@ L:
// // The spec allows the value nil instead of a type.
// var hash string
// if check.isNil(e) {
// check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
// check.expr(nil, &dummy, e) // run e through expr so we get the usual Info recordings
// T = nil
// hash = "<nil>" // avoid collision with a type named nil
// } else {
@ -394,7 +394,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
// function and method calls and receive operations can appear
// in statement context. Such statements may be parenthesized."
var x operand
kind := check.rawExpr(&x, s.X, nil, false)
kind := check.rawExpr(nil, &x, s.X, nil, false)
var msg string
var code Code
switch x.mode {
@ -415,8 +415,8 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
case *ast.SendStmt:
var ch, val operand
check.expr(&ch, s.Chan)
check.expr(&val, s.Value)
check.expr(nil, &ch, s.Chan)
check.expr(nil, &val, s.Value)
if ch.mode == invalid || val.mode == invalid {
return
}
@ -449,7 +449,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
}
var x operand
check.expr(&x, s.X)
check.expr(nil, &x, s.X)
if x.mode == invalid {
return
}
@ -463,7 +463,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
if x.mode == invalid {
return
}
check.assignVar(s.X, &x)
check.assignVar(s.X, nil, &x)
case *ast.AssignStmt:
switch s.Tok {
@ -495,7 +495,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
if x.mode == invalid {
return
}
check.assignVar(s.Lhs[0], &x)
check.assignVar(s.Lhs[0], nil, &x)
}
case *ast.GoStmt:
@ -570,7 +570,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
check.simpleStmt(s.Init)
var x operand
check.expr(&x, s.Cond)
check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
}
@ -594,7 +594,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
check.simpleStmt(s.Init)
var x operand
if s.Tag != nil {
check.expr(&x, s.Tag)
check.expr(nil, &x, s.Tag)
// By checking assignment of x to an invisible temporary
// (as a compiler would), we get all the relevant checks.
check.assignment(&x, nil, "switch expression")
@ -686,7 +686,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
return
}
var x operand
check.expr(&x, expr.X)
check.expr(nil, &x, expr.X)
if x.mode == invalid {
return
}
@ -808,7 +808,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
check.simpleStmt(s.Init)
if s.Cond != nil {
var x operand
check.expr(&x, s.Cond)
check.expr(nil, &x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
}
@ -830,7 +830,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
// check expression to iterate over
var x operand
check.expr(&x, s.X)
check.expr(nil, &x, s.X)
// determine key/value types
var key, val Type
@ -928,7 +928,7 @@ func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
check.assignVar(lhs, &x)
check.assignVar(lhs, nil, &x)
}
}
}

View File

@ -480,7 +480,7 @@ func (check *Checker) arrayLength(e ast.Expr) int64 {
}
var x operand
check.expr(&x, e)
check.expr(nil, &x, e)
if x.mode != constant_ {
if x.mode != invalid {
check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)

View File

@ -148,3 +148,10 @@ func _() {
wantsMethods /* ERROR "any does not satisfy interface{m1(Q); m2() R} (missing method m1)" */ (any(nil))
wantsMethods /* ERROR "hasMethods4 does not satisfy interface{m1(Q); m2() R} (wrong type for method m1)" */ (hasMethods4(nil))
}
// "Reverse" type inference is not yet permitted.
func f[P any](P) {}
// This must not crash.
var _ func(int) = f // ERROR "cannot use generic function f without instantiation"

View File

@ -0,0 +1,73 @@
// -reverseTypeInference
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file shows some examples of "reverse" type inference
// where the type arguments for generic functions are determined
// from assigning the functions.
package p
func f1[P any](P) {}
func f2[P any]() P { var x P; return x }
func f3[P, Q any](P) Q { var x Q; return x }
func f4[P any](P, P) {}
func f5[P any](P) []P { return nil }
// initialization expressions
var (
v1 = f1 // ERROR "cannot use generic function f1 without instantiation"
v2 func(int) = f2 // ERROR "cannot infer P"
v3 func(int) = f1
v4 func() int = f2
v5 func(int) int = f3
_ func(int) int = f3[int]
v6 func(int, int) = f4
v7 func(int, string) = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
v8 func(int) []int = f5
v9 func(string) []int = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
_, _ func(int) = f1, f1
_, _ func(int) = f1, f2 // ERROR "cannot infer P"
)
// Regular assignments
func _() {
v1 = f1 // no error here because v1 is invalid (we don't know its type) due to the error above
var v1_ func() int
_ = v1_
v1_ = f1 // ERROR "cannot infer P"
v2 = f2 // ERROR "cannot infer P"
v3 = f1
v4 = f2
v5 = f3
v5 = f3[int]
v6 = f4
v7 = f4 // ERROR "type string of 2nd parameter does not match inferred type int for P"
v8 = f5
v9 = f5 // ERROR "type []int of 1st result parameter does not match inferred type []string for []P"
}
// Return statements
func _() func(int) { return f1 }
func _() func() int { return f2 }
func _() func(int) int { return f3 }
func _() func(int) int { return f3[int] }
func _() func(int, int) { return f4 }
func _() func(int, string) {
return f4 /* ERROR "type string of 2nd parameter does not match inferred type int for P" */
}
func _() func(int) []int { return f5 }
func _() func(string) []int {
return f5 /* ERROR "type []int of 1st result parameter does not match inferred type []string for []P" */
}
func _() (_, _ func(int)) { return f1, f1 }
func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }