mirror of
https://github.com/golang/go.git
synced 2025-05-07 08:32:59 +00:00
This reverts commit CL 401754 (440c9312c8), which reverted CL 400654, thus reapplying CL 400654. It also re-adds to cgo_stub.go the func init() { netGo = true } that CL 400654 had originally removed (possibly by mistake during development); that removal had broken the darwin nocgo builder. Fixes #33097 Change-Id: I90f59746d2ceb6b5d2bd832c9fc90068f8ff7417 Reviewed-on: https://go-review.googlesource.com/c/go/+/409234 Reviewed-by: Ian Lance Taylor <iant@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Ian Lance Taylor <iant@google.com> Reviewed-by: Keith Randall <khr@google.com>
329 lines
8.3 KiB
Go
329 lines
8.3 KiB
Go
// Copyright 2022 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
//go:build !js
|
|
|
|
// Test that Resolver.Dial can be a func returning an in-memory net.Conn
|
|
// speaking DNS.
|
|
|
|
package net
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"sort"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
)
|
|
|
|
func TestResolverDialFunc(t *testing.T) {
|
|
r := &Resolver{
|
|
PreferGo: true,
|
|
Dial: newResolverDialFunc(&resolverDialHandler{
|
|
StartDial: func(network, address string) error {
|
|
t.Logf("StartDial(%q, %q) ...", network, address)
|
|
return nil
|
|
},
|
|
Question: func(h dnsmessage.Header, q dnsmessage.Question) {
|
|
t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
|
|
q.Name.String(), q.Type, q.Class)
|
|
},
|
|
// TODO: add test without HandleA* hooks specified at all, that Go
|
|
// doesn't issue retries; map to something terminal.
|
|
HandleA: func(w AWriter, name string) error {
|
|
w.AddIP([4]byte{1, 2, 3, 4})
|
|
w.AddIP([4]byte{5, 6, 7, 8})
|
|
return nil
|
|
},
|
|
HandleAAAA: func(w AAAAWriter, name string) error {
|
|
w.AddIP([16]byte{1: 1, 15: 15})
|
|
w.AddIP([16]byte{2: 2, 14: 14})
|
|
return nil
|
|
},
|
|
HandleSRV: func(w SRVWriter, name string) error {
|
|
w.AddSRV(1, 2, 80, "foo.bar.")
|
|
w.AddSRV(2, 3, 81, "bar.baz.")
|
|
return nil
|
|
},
|
|
}),
|
|
}
|
|
ctx := context.Background()
|
|
const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
|
|
|
|
t.Run("LookupIP", func(t *testing.T) {
|
|
ips, err := r.LookupIP(ctx, "ip", fakeDomain)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
|
|
t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
|
|
}
|
|
})
|
|
|
|
t.Run("LookupSRV", func(t *testing.T) {
|
|
_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want := []*SRV{
|
|
{
|
|
Target: "foo.bar.",
|
|
Port: 80,
|
|
Priority: 1,
|
|
Weight: 2,
|
|
},
|
|
{
|
|
Target: "bar.baz.",
|
|
Port: 81,
|
|
Priority: 2,
|
|
Weight: 3,
|
|
},
|
|
}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("wrong result. got:")
|
|
for _, r := range got {
|
|
t.Logf(" - %+v", r)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func sortedIPStrings(ips []IP) []string {
|
|
ret := make([]string, len(ips))
|
|
for i, ip := range ips {
|
|
ret[i] = ip.String()
|
|
}
|
|
sort.Strings(ret)
|
|
return ret
|
|
}
|
|
|
|
func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
|
|
return func(ctx context.Context, network, address string) (Conn, error) {
|
|
a := &resolverFuncConn{
|
|
h: h,
|
|
network: network,
|
|
address: address,
|
|
ttl: 10, // 10 second default if unset
|
|
}
|
|
if h.StartDial != nil {
|
|
if err := h.StartDial(network, address); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return a, nil
|
|
}
|
|
}
|
|
|
|
type resolverDialHandler struct {
|
|
// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
|
|
// Any error returned aborts the dial and is returned unwrapped.
|
|
StartDial func(network, address string) error
|
|
|
|
Question func(dnsmessage.Header, dnsmessage.Question)
|
|
|
|
// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
|
|
// A nil error means success.
|
|
HandleA func(w AWriter, name string) error
|
|
HandleAAAA func(w AAAAWriter, name string) error
|
|
HandleSRV func(w SRVWriter, name string) error
|
|
}
|
|
|
|
type ResponseWriter struct{ a *resolverFuncConn }
|
|
|
|
func (w ResponseWriter) header() dnsmessage.ResourceHeader {
|
|
q := w.a.q
|
|
return dnsmessage.ResourceHeader{
|
|
Name: q.Name,
|
|
Type: q.Type,
|
|
Class: q.Class,
|
|
TTL: w.a.ttl,
|
|
}
|
|
}
|
|
|
|
// SetTTL sets the TTL for subsequent written resources.
|
|
// Once a resource has been written, SetTTL calls are no-ops.
|
|
// That is, it can only be called at most once, before anything
|
|
// else is written.
|
|
func (w ResponseWriter) SetTTL(seconds uint32) {
|
|
// ... intention is last one wins and mutates all previously
|
|
// written records too, but that's a little annoying.
|
|
// But it's also annoying if the requirement is it needs to be set
|
|
// last.
|
|
// And it's also annoying if it's possible for users to set
|
|
// different TTLs per Answer.
|
|
if w.a.wrote {
|
|
return
|
|
}
|
|
w.a.ttl = seconds
|
|
|
|
}
|
|
|
|
type AWriter struct{ ResponseWriter }
|
|
|
|
func (w AWriter) AddIP(v4 [4]byte) {
|
|
w.a.wrote = true
|
|
err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
type AAAAWriter struct{ ResponseWriter }
|
|
|
|
func (w AAAAWriter) AddIP(v6 [16]byte) {
|
|
w.a.wrote = true
|
|
err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
type SRVWriter struct{ ResponseWriter }
|
|
|
|
// AddSRV adds a SRV record. The target name must end in a period and
|
|
// be 63 bytes or fewer.
|
|
func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
|
|
targetName, err := dnsmessage.NewName(target)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
w.a.wrote = true
|
|
err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
|
|
Priority: priority,
|
|
Weight: weight,
|
|
Port: port,
|
|
Target: targetName,
|
|
})
|
|
if err != nil {
|
|
panic(err) // internal fault, not user
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Sentinel errors a resolverDialHandler hook can return to request a
// specific DNS response code (translated by mapRCode).
var (
	// ErrNotExist maps to RCode 3, NXDOMAIN.
	ErrNotExist = errors.New("name does not exist")

	// ErrRefused maps to RCode 5, REFUSED.
	ErrRefused = errors.New("refused")
)
|
|
|
|
type resolverFuncConn struct {
|
|
h *resolverDialHandler
|
|
ctx context.Context
|
|
network string
|
|
address string
|
|
builder *dnsmessage.Builder
|
|
q dnsmessage.Question
|
|
ttl uint32
|
|
wrote bool
|
|
|
|
rbuf bytes.Buffer
|
|
}
|
|
|
|
func (*resolverFuncConn) Close() error { return nil }
|
|
func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
|
|
func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
|
|
func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
|
|
func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
|
|
func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
|
|
return a.rbuf.Read(p)
|
|
}
|
|
|
|
func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
|
|
if len(packet) < 2 {
|
|
return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
|
|
}
|
|
reqLen := int(packet[0])<<8 | int(packet[1])
|
|
req := packet[2:]
|
|
if len(req) != reqLen {
|
|
return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
|
|
}
|
|
|
|
var parser dnsmessage.Parser
|
|
h, err := parser.Start(req)
|
|
if err != nil {
|
|
// TODO: hook
|
|
return 0, err
|
|
}
|
|
q, err := parser.Question()
|
|
hadQ := (err == nil)
|
|
if err == nil && a.h.Question != nil {
|
|
a.h.Question(h, q)
|
|
}
|
|
if err != nil && err != dnsmessage.ErrSectionDone {
|
|
return 0, err
|
|
}
|
|
|
|
resh := h
|
|
resh.Response = true
|
|
resh.Authoritative = true
|
|
if hadQ {
|
|
resh.RCode = dnsmessage.RCodeSuccess
|
|
} else {
|
|
resh.RCode = dnsmessage.RCodeNotImplemented
|
|
}
|
|
a.rbuf.Grow(514)
|
|
a.rbuf.WriteByte('X') // reserved header for beu16 length
|
|
a.rbuf.WriteByte('Y') // reserved header for beu16 length
|
|
builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
|
|
a.builder = &builder
|
|
if hadQ {
|
|
a.q = q
|
|
a.builder.StartQuestions()
|
|
err := a.builder.Question(q)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("Question: %w", err)
|
|
}
|
|
a.builder.StartAnswers()
|
|
switch q.Type {
|
|
case dnsmessage.TypeA:
|
|
if a.h.HandleA != nil {
|
|
resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
|
|
}
|
|
case dnsmessage.TypeAAAA:
|
|
if a.h.HandleAAAA != nil {
|
|
resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
|
|
}
|
|
case dnsmessage.TypeSRV:
|
|
if a.h.HandleSRV != nil {
|
|
resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
|
|
}
|
|
}
|
|
}
|
|
tcpRes, err := builder.Finish()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("Finish: %w", err)
|
|
}
|
|
|
|
n = len(tcpRes) - 2
|
|
tcpRes[0] = byte(n >> 8)
|
|
tcpRes[1] = byte(n)
|
|
a.rbuf.Write(tcpRes[2:])
|
|
|
|
return len(packet), nil
|
|
}
|
|
|
|
// someaddr is the placeholder Addr that resolverFuncConn reports as both
// its local and its remote address.
type someaddr struct{}

// Network returns a fixed placeholder network name.
func (someaddr) Network() string {
	return "unused"
}

// String returns a fixed placeholder address.
func (someaddr) String() string {
	return "unused-someaddr"
}
|
|
|
|
func mapRCode(err error) dnsmessage.RCode {
|
|
switch err {
|
|
case nil:
|
|
return dnsmessage.RCodeSuccess
|
|
case ErrNotExist:
|
|
return dnsmessage.RCodeNameError
|
|
case ErrRefused:
|
|
return dnsmessage.RCodeRefused
|
|
default:
|
|
return dnsmessage.RCodeServerFailure
|
|
}
|
|
}
|