net: permit use of Resolver.PreferGo, netgo on Windows and Plan 9

This reverts commit CL 401754 (440c9312c8) which reverted CL 400654,
thus reapplying CL 400654, re-adding the func init() { netGo = true }
to cgo_stub.go CL 400654 had originally removed (mistakenly during
development?) that had broken the darwin nocgo builder.

Fixes #33097

Change-Id: I90f59746d2ceb6b5d2bd832c9fc90068f8ff7417
Reviewed-on: https://go-review.googlesource.com/c/go/+/409234
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Keith Randall <khr@google.com>
This commit is contained in:
Brad Fitzpatrick 2022-05-28 14:06:43 -07:00
parent a21cf916f4
commit af88fb6502
15 changed files with 835 additions and 290 deletions

View File

@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build unix
// Minimal RFC 6724 address selection. // Minimal RFC 6724 address selection.
package net package net

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build unix //go:build !js
package net package net
@ -21,7 +21,7 @@ type conf struct {
forceCgoLookupHost bool forceCgoLookupHost bool
netGo bool // go DNS resolution forced netGo bool // go DNS resolution forced
netCgo bool // cgo DNS resolution forced netCgo bool // non-go DNS resolution forced (cgo, or win32)
// machine has an /etc/mdns.allow file // machine has an /etc/mdns.allow file
hasMDNSAllow bool hasMDNSAllow bool
@ -49,9 +49,23 @@ func initConfVal() {
confVal.dnsDebugLevel = debugLevel confVal.dnsDebugLevel = debugLevel
confVal.netGo = netGo || dnsMode == "go" confVal.netGo = netGo || dnsMode == "go"
confVal.netCgo = netCgo || dnsMode == "cgo" confVal.netCgo = netCgo || dnsMode == "cgo"
if !confVal.netGo && !confVal.netCgo && (runtime.GOOS == "windows" || runtime.GOOS == "plan9") {
// Neither of these platforms actually use cgo.
//
// The meaning of "cgo" mode in the net package is
// really "the native OS way", which for libc meant
// cgo on the original platforms that motivated
// PreferGo support before Windows and Plan9 got support,
// at which time the GODEBUG=netdns=go and GODEBUG=netdns=cgo
// names were already kinda locked in.
confVal.netCgo = true
}
if confVal.dnsDebugLevel > 0 { if confVal.dnsDebugLevel > 0 {
defer func() { defer func() {
if confVal.dnsDebugLevel > 1 {
println("go package net: confVal.netCgo =", confVal.netCgo, " netGo =", confVal.netGo)
}
switch { switch {
case confVal.netGo: case confVal.netGo:
if netGo { if netGo {
@ -75,6 +89,10 @@ func initConfVal() {
return return
} }
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
return
}
// If any environment-specified resolver options are specified, // If any environment-specified resolver options are specified,
// force cgo. Note that LOCALDOMAIN can change behavior merely // force cgo. Note that LOCALDOMAIN can change behavior merely
// by being specified with the empty string. // by being specified with the empty string.
@ -129,7 +147,19 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde
} }
fallbackOrder := hostLookupCgo fallbackOrder := hostLookupCgo
if c.netGo || r.preferGo() { if c.netGo || r.preferGo() {
fallbackOrder = hostLookupFilesDNS switch c.goos {
case "windows":
// TODO(bradfitz): implement files-based
// lookup on Windows too? I guess /etc/hosts
// kinda exists on Windows. But for now, only
// do DNS.
fallbackOrder = hostLookupDNS
default:
fallbackOrder = hostLookupFilesDNS
}
}
if c.goos == "windows" || c.goos == "plan9" {
return fallbackOrder
} }
if c.forceCgoLookupHost || c.resolv.unknownOpt || c.goos == "android" { if c.forceCgoLookupHost || c.resolv.unknownOpt || c.goos == "android" {
return fallbackOrder return fallbackOrder

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build unix //go:build !js
// DNS client: see RFC 1035. // DNS client: see RFC 1035.
// Has to be linked into package net for Dial. // Has to be linked into package net for Dial.
@ -20,6 +20,7 @@ import (
"internal/itoa" "internal/itoa"
"io" "io"
"os" "os"
"runtime"
"sync" "sync"
"time" "time"
@ -381,12 +382,21 @@ func (conf *resolverConfig) tryUpdate(name string) {
} }
conf.lastChecked = now conf.lastChecked = now
var mtime time.Time switch runtime.GOOS {
if fi, err := os.Stat(name); err == nil { case "windows":
mtime = fi.ModTime() // There's no file on disk, so don't bother checking
} // and failing.
if mtime.Equal(conf.dnsConfig.mtime) { //
return // The Windows implementation of dnsReadConfig (called
// below) ignores the name.
default:
var mtime time.Time
if fi, err := os.Stat(name); err == nil {
mtime = fi.ModTime()
}
if mtime.Equal(conf.dnsConfig.mtime) {
return
}
} }
dnsConf := dnsReadConfig(name) dnsConf := dnsReadConfig(name)

43
src/net/dnsconfig.go Normal file
View File

@ -0,0 +1,43 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"os"
"sync/atomic"
"time"
)
var (
defaultNS = []string{"127.0.0.1:53", "[::1]:53"}
getHostname = os.Hostname // variable for testing
)
type dnsConfig struct {
servers []string // server addresses (in host:port form) to use
search []string // rooted suffixes to append to local name
ndots int // number of dots in name to trigger absolute lookup
timeout time.Duration // wait before giving up on a query, including retries
attempts int // lost packets before giving up on server
rotate bool // round robin among servers
unknownOpt bool // anything unknown was encountered
lookup []string // OpenBSD top-level database "lookup" order
err error // any error that occurs during open of resolv.conf
mtime time.Time // time of resolv.conf modification
soffset uint32 // used by serverOffset
singleRequest bool // use sequential A and AAAA queries instead of parallel queries
useTCP bool // force usage of TCP for DNS resolutions
}
// serverOffset returns an offset that can be used to determine
// indices of servers in c.servers when making queries.
// When the rotate option is enabled, this offset increases.
// Otherwise it is always 0.
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build unix //go:build !js && !windows
// Read system DNS config from /etc/resolv.conf // Read system DNS config from /etc/resolv.conf
@ -10,32 +10,9 @@ package net
import ( import (
"internal/bytealg" "internal/bytealg"
"os"
"sync/atomic"
"time" "time"
) )
var (
defaultNS = []string{"127.0.0.1:53", "[::1]:53"}
getHostname = os.Hostname // variable for testing
)
type dnsConfig struct {
servers []string // server addresses (in host:port form) to use
search []string // rooted suffixes to append to local name
ndots int // number of dots in name to trigger absolute lookup
timeout time.Duration // wait before giving up on a query, including retries
attempts int // lost packets before giving up on server
rotate bool // round robin among servers
unknownOpt bool // anything unknown was encountered
lookup []string // OpenBSD top-level database "lookup" order
err error // any error that occurs during open of resolv.conf
mtime time.Time // time of resolv.conf modification
soffset uint32 // used by serverOffset
singleRequest bool // use sequential A and AAAA queries instead of parallel queries
useTCP bool // force usage of TCP for DNS resolutions
}
// See resolv.conf(5) on a Linux machine. // See resolv.conf(5) on a Linux machine.
func dnsReadConfig(filename string) *dnsConfig { func dnsReadConfig(filename string) *dnsConfig {
conf := &dnsConfig{ conf := &dnsConfig{
@ -156,17 +133,6 @@ func dnsReadConfig(filename string) *dnsConfig {
return conf return conf
} }
// serverOffset returns an offset that can be used to determine
// indices of servers in c.servers when making queries.
// When the rotate option is enabled, this offset increases.
// Otherwise it is always 0.
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}
func dnsDefaultSearch() []string { func dnsDefaultSearch() []string {
hn, err := getHostname() hn, err := getHostname()
if err != nil { if err != nil {

View File

@ -0,0 +1,58 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"syscall"
"time"
)
func dnsReadConfig(ignoredFilename string) (conf *dnsConfig) {
conf = &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
defer func() {
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
}()
aas, err := adapterAddresses()
if err != nil {
return
}
// TODO(bradfitz): this just collects all the DNS servers on all
// the interfaces in some random order. It should order it by
// default route, or only use the default route(s) instead.
// In practice, however, it mostly works.
for _, aa := range aas {
for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
sa, err := dns.Address.Sockaddr.Sockaddr()
if err != nil {
continue
}
var ip IP
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
ip = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
case *syscall.SockaddrInet6:
ip = make(IP, IPv6len)
copy(ip, sa.Addr[:])
if ip[0] == 0xfe && ip[1] == 0xc0 {
// Ignore these fec0/10 ones. Windows seems to
// populate them as defaults on its misc rando
// interfaces.
continue
}
default:
// Unexpected type.
continue
}
conf.servers = append(conf.servers, JoinHostPort(ip.String(), "53"))
}
}
return conf
}

View File

@ -10,6 +10,8 @@ import (
"internal/singleflight" "internal/singleflight"
"net/netip" "net/netip"
"sync" "sync"
"golang.org/x/net/dns/dnsmessage"
) )
// protocols contains minimal mappings between internet protocol // protocols contains minimal mappings between internet protocol
@ -665,3 +667,227 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error
// method receives DNS records which contain invalid DNS names. This may be returned alongside // method receives DNS records which contain invalid DNS names. This may be returned alongside
// results which have had the malformed records filtered out. // results which have had the malformed records filtered out.
var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names" var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
// dial makes a new connection to the provided server (which must be
// an IP address) with the provided network type, using either r.Dial
// (if both r and r.Dial are non-nil) or else Dialer.DialContext.
func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
var c Conn
var err error
if r != nil && r.Dial != nil {
c, err = r.Dial(ctx, network, server)
} else {
var d Dialer
c, err = d.DialContext(ctx, network, server)
}
if err != nil {
return nil, mapErr(err)
}
return c, nil
}
// goLookupSRV returns the SRV records for a target name, built either
// from its component service ("sip"), protocol ("tcp"), and name
// ("example.com."), or from name directly (if service and proto are
// both empty).
//
// In either case, the returned target name ("_sip._tcp.example.com.")
// is also returned on success.
//
// The records are sorted by weight.
func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) {
if service == "" && proto == "" {
target = name
} else {
target = "_" + service + "._" + proto + "." + name
}
p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil {
return "", nil, err
}
var cname dnsmessage.Name
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeSRV {
if err := p.SkipAnswer(); err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
srv, err := p.SRVResource()
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
return cname.String(), srvs, nil
}
// goLookupMX returns the MX records for name.
func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
if err != nil {
return nil, err
}
var mxs []*MX
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeMX {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
mx, err := p.MXResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
}
byPref(mxs).sort()
return mxs, nil
}
// goLookupNS returns the NS records for name.
func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
if err != nil {
return nil, err
}
var nss []*NS
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeNS {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
ns, err := p.NSResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
}
// goLookupTXT returns the TXT records from name.
func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
if err != nil {
return nil, err
}
var txts []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeTXT {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
txt, err := p.TXTResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
// Multiple strings in one TXT record need to be
// concatenated without separator to be consistent
// with previous Go resolver.
n := 0
for _, s := range txt.TXT {
n += len(s)
}
txtJoin := make([]byte, 0, n)
for _, s := range txt.TXT {
txtJoin = append(txtJoin, s...)
}
if len(txts) == 0 {
txts = make([]string, 0, 1)
}
txts = append(txts, string(txtJoin))
}
return txts, nil
}

View File

@ -179,7 +179,27 @@ loop:
return return
} }
func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) { // preferGoOverPlan9 reports whether the resolver should use the
// "PreferGo" implementation rather than asking plan9 services
// for the answers.
func (r *Resolver) preferGoOverPlan9() bool {
conf := systemConf()
order := conf.hostLookupOrder(r, "") // name is unused
// TODO(bradfitz): for now we only permit use of the PreferGo
// implementation when there's a non-nil Resolver with a
// non-nil Dialer. This is a sign that they the code is trying
// to use their DNS-speaking net.Conn (such as an in-memory
// DNS cache) and they don't want to actually hit the network.
// Once we add support for looking the default DNS servers
// from plan9, though, then we can relax this.
return order != hostLookupCgo && r != nil && r.Dial != nil
}
func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
if r.preferGoOverPlan9() {
return r.goLookupIP(ctx, network, host)
}
lits, err := r.lookupHost(ctx, host) lits, err := r.lookupHost(ctx, host)
if err != nil { if err != nil {
return return
@ -223,7 +243,10 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port
return 0, unknownPortError return 0, unknownPortError
} }
func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
if r.preferGoOverPlan9() {
return r.goLookupCNAME(ctx, name)
}
lines, err := queryDNS(ctx, name, "cname") lines, err := queryDNS(ctx, name, "cname")
if err != nil { if err != nil {
if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") { if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") {
@ -240,7 +263,10 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, er
return "", errors.New("bad response from ndb/dns") return "", errors.New("bad response from ndb/dns")
} }
func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
if r.preferGoOverPlan9() {
return r.goLookupSRV(ctx, service, proto, name)
}
var target string var target string
if service == "" && proto == "" { if service == "" && proto == "" {
target = name target = name
@ -269,7 +295,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn
return return
} }
func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) { func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
if r.preferGoOverPlan9() {
return r.goLookupMX(ctx, name)
}
lines, err := queryDNS(ctx, name, "mx") lines, err := queryDNS(ctx, name, "mx")
if err != nil { if err != nil {
return return
@ -287,7 +316,10 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error
return return
} }
func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) { func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
if r.preferGoOverPlan9() {
return r.goLookupNS(ctx, name)
}
lines, err := queryDNS(ctx, name, "ns") lines, err := queryDNS(ctx, name, "ns")
if err != nil { if err != nil {
return return
@ -302,7 +334,10 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error
return return
} }
func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) { func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
if r.preferGoOverPlan9() {
return r.goLookupTXT(ctx, name)
}
lines, err := queryDNS(ctx, name, "txt") lines, err := queryDNS(ctx, name, "txt")
if err != nil { if err != nil {
return return
@ -315,7 +350,10 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err
return return
} }
func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) { func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
if r.preferGoOverPlan9() {
return r.goLookupPTR(ctx, addr)
}
arpa, err := reverseaddr(addr) arpa, err := reverseaddr(addr)
if err != nil { if err != nil {
return return

View File

@ -11,8 +11,6 @@ import (
"internal/bytealg" "internal/bytealg"
"sync" "sync"
"syscall" "syscall"
"golang.org/x/net/dns/dnsmessage"
) )
var onceReadProtocols sync.Once var onceReadProtocols sync.Once
@ -55,26 +53,6 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
return lookupProtocolMap(name) return lookupProtocolMap(name)
} }
func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
var c Conn
var err error
if r != nil && r.Dial != nil {
c, err = r.Dial(ctx, network, server)
} else {
var d Dialer
c, err = d.DialContext(ctx, network, server)
}
if err != nil {
return nil, mapErr(err)
}
return c, nil
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(r, host) order := systemConf().hostLookupOrder(r, host)
if !r.preferGo() && order == hostLookupCgo { if !r.preferGo() && order == hostLookupCgo {
@ -129,194 +107,19 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
} }
func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
var target string return r.goLookupSRV(ctx, service, proto, name)
if service == "" && proto == "" {
target = name
} else {
target = "_" + service + "._" + proto + "." + name
}
p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil {
return "", nil, err
}
var srvs []*SRV
var cname dnsmessage.Name
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeSRV {
if err := p.SkipAnswer(); err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
srv, err := p.SRVResource()
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
return cname.String(), srvs, nil
} }
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX) return r.goLookupMX(ctx, name)
if err != nil {
return nil, err
}
var mxs []*MX
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeMX {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
mx, err := p.MXResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
}
byPref(mxs).sort()
return mxs, nil
} }
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS) return r.goLookupNS(ctx, name)
if err != nil {
return nil, err
}
var nss []*NS
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeNS {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
ns, err := p.NSResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
} }
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT) return r.goLookupTXT(ctx, name)
if err != nil {
return nil, err
}
var txts []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeTXT {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
txt, err := p.TXTResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
// Multiple strings in one TXT record need to be
// concatenated without separator to be consistent
// with previous Go resolver.
n := 0
for _, s := range txt.TXT {
n += len(s)
}
txtJoin := make([]byte, 0, n)
for _, s := range txt.TXT {
txtJoin = append(txtJoin, s...)
}
if len(txts) == 0 {
txts = make([]string, 0, 1)
}
txts = append(txts, string(txtJoin))
}
return txts, nil
} }
func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {

View File

@ -82,7 +82,19 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
return addrs, nil return addrs, nil
} }
// preferGoOverWindows reports whether the resolver should use the
// pure Go implementation rather than making win32 calls to ask the
// kernel for its answer.
func (r *Resolver) preferGoOverWindows() bool {
conf := systemConf()
order := conf.hostLookupOrder(r, "") // name is unused
return order != hostLookupCgo
}
func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) { func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
if r.preferGoOverWindows() {
return r.goLookupIP(ctx, network, name)
}
// TODO(bradfitz,brainman): use ctx more. See TODO below. // TODO(bradfitz,brainman): use ctx more. See TODO below.
var family int32 = syscall.AF_UNSPEC var family int32 = syscall.AF_UNSPEC
@ -169,7 +181,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
} }
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
if r.preferGo() { if r.preferGoOverWindows() {
return lookupPortMap(network, service) return lookupPortMap(network, service)
} }
@ -217,12 +229,15 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
} }
func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
if r.preferGoOverWindows() {
return r.goLookupCNAME(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil)
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
// if there are no aliases, the canonical name is the input name // if there are no aliases, the canonical name is the input name
@ -231,14 +246,17 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
if e != nil { if e != nil {
return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), rec)
cname := windows.UTF16PtrToString(resolved) cname := windows.UTF16PtrToString(resolved)
return absDomainName(cname), nil return absDomainName(cname), nil
} }
func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
if r.preferGoOverWindows() {
return r.goLookupSRV(ctx, service, proto, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
@ -248,15 +266,15 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
} else { } else {
target = "_" + service + "._" + proto + "." + name target = "_" + service + "._" + proto + "." + name
} }
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
if e != nil { if e != nil {
return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target} return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
srvs := make([]*SRV, 0, 10) srvs := make([]*SRV, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { for _, p := range validRecs(rec, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight}) srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
} }
@ -264,19 +282,22 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
return absDomainName(target), srvs, nil return absDomainName(target), srvs, nil
} }
func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
if r.preferGoOverWindows() {
return r.goLookupMX(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
if e != nil { if e != nil {
return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
mxs := make([]*MX, 0, 10) mxs := make([]*MX, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { for _, p := range validRecs(rec, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference}) mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
} }
@ -284,38 +305,44 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
return mxs, nil return mxs, nil
} }
func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
if r.preferGoOverWindows() {
return r.goLookupNS(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
if e != nil { if e != nil {
return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
nss := make([]*NS, 0, 10) nss := make([]*NS, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { for _, p := range validRecs(rec, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))}) nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
} }
return nss, nil return nss, nil
} }
func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
if r.preferGoOverWindows() {
return r.lookupTXT(ctx, name)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
if e != nil { if e != nil {
return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
txts := make([]string, 0, 10) txts := make([]string, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) { for _, p := range validRecs(rec, syscall.DNS_TYPE_TEXT, name) {
d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
s := "" s := ""
for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] { for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] {
@ -326,7 +353,11 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
return txts, nil return txts, nil
} }
func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
if r.preferGoOverWindows() {
return r.goLookupPTR(ctx, addr)
}
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
@ -334,15 +365,15 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var r *syscall.DNSRecord var rec *syscall.DNSRecord
e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil) e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
if e != nil { if e != nil {
return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr} return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
} }
defer syscall.DnsRecordListFree(r, 1) defer syscall.DnsRecordListFree(rec, 1)
ptrs := make([]string, 0, 10) ptrs := make([]string, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { for _, p := range validRecs(rec, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host))) ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
} }

View File

@ -61,7 +61,7 @@ The resolver decision can be overridden by setting the netdns value of the
GODEBUG environment variable (see package runtime) to go or cgo, as in: GODEBUG environment variable (see package runtime) to go or cgo, as in:
export GODEBUG=netdns=go # force pure Go resolver export GODEBUG=netdns=go # force pure Go resolver
export GODEBUG=netdns=cgo # force cgo resolver export GODEBUG=netdns=cgo # force native resolver (cgo, win32)
The decision can also be forced while building the Go source tree The decision can also be forced while building the Go source tree
by setting the netgo or netcgo build tag. by setting the netgo or netcgo build tag.
@ -73,7 +73,8 @@ join the two settings by a plus sign, as in GODEBUG=netdns=go+1.
On Plan 9, the resolver always accesses /net/cs and /net/dns. On Plan 9, the resolver always accesses /net/cs and /net/dns.
On Windows, the resolver always uses C library functions, such as GetAddrInfo and DnsQuery. On Windows, in Go 1.18.x and earlier, the resolver always used C
library functions, such as GetAddrInfo and DnsQuery.
*/ */
package net package net

View File

@ -16,6 +16,8 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"golang.org/x/net/dns/dnsmessage"
) )
var listenersMu sync.Mutex var listenersMu sync.Mutex
@ -314,3 +316,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
func (fd *netFD) dup() (f *os.File, err error) { func (fd *netFD) dup() (f *os.File, err error) {
return nil, syscall.ENOSYS return nil, syscall.ENOSYS
} }
func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
panic("unreachable")
}

9
src/net/netgo.go Normal file
View File

@ -0,0 +1,9 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build netgo
package net
func init() { netGo = true }

View File

@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build unix
package net package net
import ( import (

View File

@ -0,0 +1,328 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js
// Test that Resolver.Dial can be a func returning an in-memory net.Conn
// speaking DNS.
package net
import (
"bytes"
"context"
"errors"
"fmt"
"reflect"
"sort"
"testing"
"time"
"golang.org/x/net/dns/dnsmessage"
)
func TestResolverDialFunc(t *testing.T) {
r := &Resolver{
PreferGo: true,
Dial: newResolverDialFunc(&resolverDialHandler{
StartDial: func(network, address string) error {
t.Logf("StartDial(%q, %q) ...", network, address)
return nil
},
Question: func(h dnsmessage.Header, q dnsmessage.Question) {
t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
q.Name.String(), q.Type, q.Class)
},
// TODO: add test without HandleA* hooks specified at all, that Go
// doesn't issue retries; map to something terminal.
HandleA: func(w AWriter, name string) error {
w.AddIP([4]byte{1, 2, 3, 4})
w.AddIP([4]byte{5, 6, 7, 8})
return nil
},
HandleAAAA: func(w AAAAWriter, name string) error {
w.AddIP([16]byte{1: 1, 15: 15})
w.AddIP([16]byte{2: 2, 14: 14})
return nil
},
HandleSRV: func(w SRVWriter, name string) error {
w.AddSRV(1, 2, 80, "foo.bar.")
w.AddSRV(2, 3, 81, "bar.baz.")
return nil
},
}),
}
ctx := context.Background()
const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
t.Run("LookupIP", func(t *testing.T) {
ips, err := r.LookupIP(ctx, "ip", fakeDomain)
if err != nil {
t.Fatal(err)
}
if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
}
})
t.Run("LookupSRV", func(t *testing.T) {
_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
if err != nil {
t.Fatal(err)
}
want := []*SRV{
{
Target: "foo.bar.",
Port: 80,
Priority: 1,
Weight: 2,
},
{
Target: "bar.baz.",
Port: 81,
Priority: 2,
Weight: 3,
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("wrong result. got:")
for _, r := range got {
t.Logf(" - %+v", r)
}
}
})
}
func sortedIPStrings(ips []IP) []string {
ret := make([]string, len(ips))
for i, ip := range ips {
ret[i] = ip.String()
}
sort.Strings(ret)
return ret
}
func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
return func(ctx context.Context, network, address string) (Conn, error) {
a := &resolverFuncConn{
h: h,
network: network,
address: address,
ttl: 10, // 10 second default if unset
}
if h.StartDial != nil {
if err := h.StartDial(network, address); err != nil {
return nil, err
}
}
return a, nil
}
}
type resolverDialHandler struct {
// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
// Any error returned aborts the dial and is returned unwrapped.
StartDial func(network, address string) error
Question func(dnsmessage.Header, dnsmessage.Question)
// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
// A nil error means success.
HandleA func(w AWriter, name string) error
HandleAAAA func(w AAAAWriter, name string) error
HandleSRV func(w SRVWriter, name string) error
}
type ResponseWriter struct{ a *resolverFuncConn }
func (w ResponseWriter) header() dnsmessage.ResourceHeader {
q := w.a.q
return dnsmessage.ResourceHeader{
Name: q.Name,
Type: q.Type,
Class: q.Class,
TTL: w.a.ttl,
}
}
// SetTTL sets the TTL for subsequent written resources.
// Once a resource has been written, SetTTL calls are no-ops.
// That is, it can only be called at most once, before anything
// else is written.
func (w ResponseWriter) SetTTL(seconds uint32) {
// ... intention is last one wins and mutates all previously
// written records too, but that's a little annoying.
// But it's also annoying if the requirement is it needs to be set
// last.
// And it's also annoying if it's possible for users to set
// different TTLs per Answer.
if w.a.wrote {
return
}
w.a.ttl = seconds
}
type AWriter struct{ ResponseWriter }
func (w AWriter) AddIP(v4 [4]byte) {
w.a.wrote = true
err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
if err != nil {
panic(err)
}
}
type AAAAWriter struct{ ResponseWriter }
func (w AAAAWriter) AddIP(v6 [16]byte) {
w.a.wrote = true
err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
if err != nil {
panic(err)
}
}
type SRVWriter struct{ ResponseWriter }
// AddSRV adds a SRV record. The target name must end in a period and
// be 63 bytes or fewer.
func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
targetName, err := dnsmessage.NewName(target)
if err != nil {
return err
}
w.a.wrote = true
err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
Priority: priority,
Weight: weight,
Port: port,
Target: targetName,
})
if err != nil {
panic(err) // internal fault, not user
}
return nil
}
var (
ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
ErrRefused = errors.New("refused") // maps to RCode5, REFUSED
)
type resolverFuncConn struct {
h *resolverDialHandler
ctx context.Context
network string
address string
builder *dnsmessage.Builder
q dnsmessage.Question
ttl uint32
wrote bool
rbuf bytes.Buffer
}
func (*resolverFuncConn) Close() error { return nil }
func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
return a.rbuf.Read(p)
}
func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
if len(packet) < 2 {
return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
}
reqLen := int(packet[0])<<8 | int(packet[1])
req := packet[2:]
if len(req) != reqLen {
return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
}
var parser dnsmessage.Parser
h, err := parser.Start(req)
if err != nil {
// TODO: hook
return 0, err
}
q, err := parser.Question()
hadQ := (err == nil)
if err == nil && a.h.Question != nil {
a.h.Question(h, q)
}
if err != nil && err != dnsmessage.ErrSectionDone {
return 0, err
}
resh := h
resh.Response = true
resh.Authoritative = true
if hadQ {
resh.RCode = dnsmessage.RCodeSuccess
} else {
resh.RCode = dnsmessage.RCodeNotImplemented
}
a.rbuf.Grow(514)
a.rbuf.WriteByte('X') // reserved header for beu16 length
a.rbuf.WriteByte('Y') // reserved header for beu16 length
builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
a.builder = &builder
if hadQ {
a.q = q
a.builder.StartQuestions()
err := a.builder.Question(q)
if err != nil {
return 0, fmt.Errorf("Question: %w", err)
}
a.builder.StartAnswers()
switch q.Type {
case dnsmessage.TypeA:
if a.h.HandleA != nil {
resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
}
case dnsmessage.TypeAAAA:
if a.h.HandleAAAA != nil {
resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
}
case dnsmessage.TypeSRV:
if a.h.HandleSRV != nil {
resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
}
}
}
tcpRes, err := builder.Finish()
if err != nil {
return 0, fmt.Errorf("Finish: %w", err)
}
n = len(tcpRes) - 2
tcpRes[0] = byte(n >> 8)
tcpRes[1] = byte(n)
a.rbuf.Write(tcpRes[2:])
return len(packet), nil
}
type someaddr struct{}
func (someaddr) Network() string { return "unused" }
func (someaddr) String() string { return "unused-someaddr" }
func mapRCode(err error) dnsmessage.RCode {
switch err {
case nil:
return dnsmessage.RCodeSuccess
case ErrNotExist:
return dnsmessage.RCodeNameError
case ErrRefused:
return dnsmessage.RCodeRefused
default:
return dnsmessage.RCodeServerFailure
}
}