diff --git a/src/net/dial.go b/src/net/dial.go index e4e44d2263..22992d5b7a 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -5,7 +5,6 @@ package net import ( - "errors" "runtime" "time" ) @@ -140,8 +139,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) { return "", 0, UnknownNetworkError(net) } -func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) { - afnet, _, err := parseNetwork(net) +// resolverAddrList resolves addr using hint and returns a list of +// addresses. The result contains at least one address when error is +// nil. +func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) { + afnet, _, err := parseNetwork(network) if err != nil { return nil, err } @@ -154,9 +156,59 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) if err != nil { return nil, err } + if op == "dial" && hint != nil && addr.Network() != hint.Network() { + return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()} + } return addrList{addr}, nil } - return internetAddrList(afnet, addr, deadline) + addrs, err := internetAddrList(afnet, addr, deadline) + if err != nil || op != "dial" || hint == nil { + return addrs, err + } + var ( + tcp *TCPAddr + udp *UDPAddr + ip *IPAddr + wildcard bool + ) + switch hint := hint.(type) { + case *TCPAddr: + tcp = hint + wildcard = tcp.isWildcard() + case *UDPAddr: + udp = hint + wildcard = udp.isWildcard() + case *IPAddr: + ip = hint + wildcard = ip.isWildcard() + } + naddrs := addrs[:0] + for _, addr := range addrs { + if addr.Network() != hint.Network() { + return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()} + } + switch addr := addr.(type) { + case *TCPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) { + continue + } + naddrs = append(naddrs, addr) + case *UDPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) { + continue + } + naddrs = append(naddrs, addr) + case *IPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) { + continue + } + naddrs = append(naddrs, addr) + } + } + if len(naddrs) == 0 { + return nil, errNoSuitableAddress + } + return naddrs, nil } // Dial connects to the address on the named network. @@ -214,7 +266,7 @@ type dialContext struct { // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { finalDeadline := d.deadline(time.Now()) - addrs, err := resolveAddrList("dial", network, address, finalDeadline) + addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} } @@ -387,9 +439,6 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e // dial function, because some OSes don't implement the deadline feature. func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) { la := ctx.LocalAddr - if la != nil && la.Network() != ra.Network() { - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())} - } switch ra := ra.(type) { case *TCPAddr: la, _ := la.(*TCPAddr) @@ -420,7 +469,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str // instead of just the interface with the given host address. // See Dial for more details about address syntax. func Listen(net, laddr string) (Listener, error) { - addrs, err := resolveAddrList("listen", net, laddr, noDeadline) + addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } @@ -447,7 +496,7 @@ func Listen(net, laddr string) (Listener, error) { // instead of just the interface with the given host address. // See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := resolveAddrList("listen", net, laddr, noDeadline) + addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 5fe3e856f8..3335df5a93 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -646,41 +646,118 @@ func TestDialerPartialDeadline(t *testing.T) { } } +type dialerLocalAddrTest struct { + network, raddr string + laddr Addr + error +} + +var dialerLocalAddrTests = []dialerLocalAddrTest{ + {"tcp4", "127.0.0.1", nil, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{}, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil}, + {"tcp4", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress}, + {"tcp4", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}}, + {"tcp4", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}}, + + {"tcp6", "::1", nil, nil}, + {"tcp6", "::1", &TCPAddr{}, nil}, + {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil}, + {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil}, + {"tcp6", "::1", &TCPAddr{IP: ParseIP("::")}, nil}, + {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress}, + {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress}, + {"tcp6", "::1", &TCPAddr{IP: IPv6loopback}, nil}, + {"tcp6", "::1", &UDPAddr{}, &AddrError{Err: "some error"}}, + {"tcp6", "::1", &UnixAddr{}, &AddrError{Err: "some error"}}, + + {"tcp", "127.0.0.1", nil, nil}, + {"tcp", "127.0.0.1", &TCPAddr{}, nil}, + {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil}, + {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil}, + {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil}, + {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil}, + {"tcp", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress}, + {"tcp", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}}, + {"tcp", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}}, + + {"tcp", "::1", nil, nil}, + {"tcp", "::1", &TCPAddr{}, nil}, + {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil}, + {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil}, + {"tcp", "::1", &TCPAddr{IP: ParseIP("::")}, nil}, + {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress}, + {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress}, + {"tcp", "::1", &TCPAddr{IP: IPv6loopback}, nil}, + {"tcp", "::1", &UDPAddr{}, &AddrError{Err: "some error"}}, + {"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}}, +} + func TestDialerLocalAddr(t *testing.T) { - ch := make(chan error, 1) - handler := func(ls *localServer, ln Listener) { - c, err := ln.Accept() - if err != nil { - ch <- err - return - } - defer c.Close() - ch <- nil - } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } - defer ls.teardown() - if err := ls.buildup(handler); err != nil { - t.Fatal(err) + if !supportsIPv4 || !supportsIPv6 { + t.Skip("both IPv4 and IPv6 are required") } - laddr, err := ResolveTCPAddr(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) - if err != nil { - t.Fatal(err) + if supportsIPv4map { + dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{ + "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil, + }) + } else { + dialerLocalAddrTests = append(dialerLocalAddrTests, dialerLocalAddrTest{ + "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}, + }) } - laddr.Port = 0 - d := &Dialer{LocalAddr: laddr} - c, err := d.Dial(ls.Listener.Addr().Network(), ls.Addr().String()) - if err != nil { - t.Fatal(err) + + origTestHookLookupIP := testHookLookupIP + defer func() { testHookLookupIP = origTestHookLookupIP }() + testHookLookupIP = lookupLocalhost + handler := func(ls *localServer, ln Listener) { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Close() + } } - defer c.Close() - c.Read(make([]byte, 1)) - err = <-ch - if err != nil { - t.Error(err) + var err error + var lss [2]*localServer + for i, network := range []string{"tcp4", "tcp6"} { + lss[i], err = newLocalServer(network) + if err != nil { + t.Fatal(err) + } + defer lss[i].teardown() + if err := lss[i].buildup(handler); err != nil { + t.Fatal(err) + } + } + + for _, tt := range dialerLocalAddrTests { + d := &Dialer{LocalAddr: tt.laddr} + var addr string + ip := ParseIP(tt.raddr) + if ip.To4() != nil { + addr = lss[0].Listener.Addr().String() + } + if ip.To16() != nil && ip.To4() == nil { + addr = lss[1].Listener.Addr().String() + } + c, err := d.Dial(tt.network, addr) + if err == nil && tt.error != nil || err != nil && tt.error == nil { + t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error) + } + if err != nil { + if perr := parseDialError(err); perr != nil { + t.Error(perr) + } + continue + } + c.Close() } } diff --git a/src/net/error_test.go b/src/net/error_test.go index ee0979c748..c3a4d32382 100644 --- a/src/net/error_test.go +++ b/src/net/error_test.go @@ -96,7 +96,7 @@ second: goto third } switch nestedErr { - case errCanceled, errClosing, errMissingAddress: + case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -416,7 +416,7 @@ second: goto third } switch nestedErr { - case errCanceled, errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: + case errCanceled, errClosing, errMissingAddress, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) diff --git a/src/net/ip.go b/src/net/ip.go index a25729cfc9..0501f5a6a3 100644 --- a/src/net/ip.go +++ b/src/net/ip.go @@ -377,6 +377,10 @@ func bytesEqual(x, y []byte) bool { return true } +func (ip IP) matchAddrFamily(x IP) bool { + return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil +} + // If mask is a sequence of 1 bits followed by 0 bits, // return the number of 1 bits. func simpleMaskLength(mask IPMask) int { diff --git a/src/net/ipsock.go b/src/net/ipsock.go index f3ac00df05..f093b4926d 100644 --- a/src/net/ipsock.go +++ b/src/net/ipsock.go @@ -6,10 +6,7 @@ package net -import ( - "errors" - "time" -) +import "time" var ( // supportsIPv4 reports whether the platform supports IPv4 @@ -73,8 +70,6 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks return } -var errNoSuitableAddress = errors.New("no suitable address found") - // filterAddrList applies a filter to a list of IP addresses, // yielding a list of Addr objects. Known filters are nil, ipv4only, // and ipv6only. It returns every address when the filter is nil. diff --git a/src/net/net.go b/src/net/net.go index 2ff1a34981..3b37b336d1 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -364,6 +364,9 @@ type Error interface { // Various errors contained in OpError. var ( + // For connection setup operations. + errNoSuitableAddress = errors.New("no suitable address found") + // For connection setup and write operations. errMissingAddress = errors.New("missing address")