runtime: make netpoll events sources identifiable on Windows

This is another attempt at CL 558895, but without adding stale pollDescs
protection, which deviates from the original goal of the CL and adds
complexity without proper testing.

It is currently not possible to distinguish between a netpollBreak,
an internal/poll WSA operation, and an external WSA operation (as
in #58870). This can cause spurious wakeups when external WSA operations
are retrieved from the queue, as they are treated as netpollBreak
events.

This CL makes use of completion keys to identify the source of the
event.

While here, fix TestWSASocketConflict, which was not properly
exercising the "external WSA operation" case.

Change-Id: I91f746d300d32eb7fed3c8f27266fef379360d98
Reviewed-on: https://go-review.googlesource.com/c/go/+/561895
Reviewed-by: Than McIntosh <thanm@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
This commit is contained in:
qmuntal 2024-02-08 08:46:24 +01:00 committed by Quim Muntal
parent 69d6c7b8ee
commit 1de46564a7
2 changed files with 93 additions and 33 deletions

View File

@ -133,23 +133,14 @@ func TestWSASocketConflict(t *testing.T) {
var outbuf _TCP_INFO_v0
cbbr := uint32(0)
var ovs []syscall.Overlapped = make([]syscall.Overlapped, 2)
// Attempt to exercise behavior where a user-owned syscall.Overlapped
// induces an invalid pointer dereference in the Windows-specific version
// of runtime.netpoll.
ovs[1].Internal -= 1
var ov syscall.Overlapped
// Create an event so that we can efficiently wait for completion
// of a requested overlapped I/O operation.
ovs[0].HEvent, _ = windows.CreateEvent(nil, 0, 0, nil)
if ovs[0].HEvent == 0 {
ov.HEvent, _ = windows.CreateEvent(nil, 0, 0, nil)
if ov.HEvent == 0 {
t.Fatalf("could not create the event!")
}
// Set the low bit of the Event Handle so that the completion
// of the overlapped I/O event will not trigger a completion event
// on any I/O completion port associated with the handle.
ovs[0].HEvent |= 0x1
defer syscall.CloseHandle(ov.HEvent)
if err = fd.WSAIoctl(
SIO_TCP_INFO,
@ -158,7 +149,7 @@ func TestWSASocketConflict(t *testing.T) {
(*byte)(unsafe.Pointer(&outbuf)),
uint32(unsafe.Sizeof(outbuf)),
&cbbr,
&ovs[0],
&ov,
0,
); err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
t.Fatalf("could not perform the WSAIoctl: %v", err)
@ -167,14 +158,10 @@ func TestWSASocketConflict(t *testing.T) {
if err != nil && errors.Is(err, syscall.ERROR_IO_PENDING) {
// It is possible that the overlapped I/O operation completed
// immediately so there is no need to wait for it to complete.
if res, err := syscall.WaitForSingleObject(ovs[0].HEvent, syscall.INFINITE); res != 0 {
if res, err := syscall.WaitForSingleObject(ov.HEvent, syscall.INFINITE); res != 0 {
t.Fatalf("waiting for the completion of the overlapped IO failed: %v", err)
}
}
if err = syscall.CloseHandle(ovs[0].HEvent); err != nil {
t.Fatalf("could not close the event handle: %v", err)
}
}
type _TCP_INFO_v0 struct {

View File

@ -5,6 +5,7 @@
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
@ -13,19 +14,82 @@ const _DWORD_MAX = 0xffffffff
const _INVALID_HANDLE_VALUE = ^uintptr(0)
// net_op must be the same as beginning of internal/poll.operation.
// Sources are used to identify the event that created an overlapped entry.
// The source values are arbitrary. There is no risk of collision with user
// defined values because the only way to set the key of an overlapped entry
// is using the iocphandle, which is not accessible to user code.
const (
netpollSourceReady = iota + 1
netpollSourceBreak
)
const (
// sourceBits is the number of bits needed to represent a source.
// 4 bits can hold 16 different sources, which is more than enough.
// It is set to a low value so the overlapped entry key can
// contain as much bits as possible for the pollDesc pointer.
sourceBits = 4 // 4 bits can hold 16 different sources, which is more than enough.
sourceMasks = 1<<sourceBits - 1
)
// packNetpollKey creates a key from a source and a tag.
// Bits that don't fit in the result are discarded.
func packNetpollKey(source uint8, pd *pollDesc) uintptr {
// TODO: Consider combining the source with pd.fdseq to detect stale pollDescs.
if source > (1<<sourceBits)-1 {
// Also fail on 64-bit systems, even though it can hold more bits.
throw("runtime: source value is too large")
}
if goarch.PtrSize == 4 {
return uintptr(unsafe.Pointer(pd))<<sourceBits | uintptr(source)
}
return uintptr(taggedPointerPack(unsafe.Pointer(pd), uintptr(source)))
}
// unpackNetpollSource returns the source packed key.
func unpackNetpollSource(key uintptr) uint8 {
if goarch.PtrSize == 4 {
return uint8(key & sourceMasks)
}
return uint8(taggedPointer(key).tag())
}
// pollOperation must be the same as beginning of internal/poll.operation.
// Keep these in sync.
type net_op struct {
type pollOperation struct {
// used by windows
o overlapped
_ overlapped
// used by netpoll
pd *pollDesc
mode int32
}
// pollOperationFromOverlappedEntry returns the pollOperation contained in
// e. It can return nil if the entry is not from internal/poll.
// See go.dev/issue/58870
func pollOperationFromOverlappedEntry(e *overlappedEntry) *pollOperation {
if e.ov == nil {
return nil
}
op := (*pollOperation)(unsafe.Pointer(e.ov))
// Check that the key matches the pollDesc pointer.
var keyMatch bool
if goarch.PtrSize == 4 {
keyMatch = e.key&^sourceMasks == uintptr(unsafe.Pointer(op.pd))<<sourceBits
} else {
keyMatch = (*pollDesc)(taggedPointer(e.key).pointer()) == op.pd
}
if !keyMatch {
return nil
}
return op
}
// overlappedEntry contains the information returned by a call to GetQueuedCompletionStatusEx.
// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-overlapped_entry
type overlappedEntry struct {
key *pollDesc
op *net_op // In reality it's *overlapped, but we cast it to *net_op anyway.
key uintptr
ov *overlapped
internal uintptr
qty uint32
}
@ -49,8 +113,8 @@ func netpollIsPollDescriptor(fd uintptr) bool {
}
func netpollopen(fd uintptr, pd *pollDesc) int32 {
// TODO(iant): Consider using taggedPointer on 64-bit systems.
if stdcall4(_CreateIoCompletionPort, fd, iocphandle, uintptr(unsafe.Pointer(pd)), 0) == 0 {
key := packNetpollKey(netpollSourceReady, pd)
if stdcall4(_CreateIoCompletionPort, fd, iocphandle, key, 0) == 0 {
return int32(getlasterror())
}
return 0
@ -71,7 +135,8 @@ func netpollBreak() {
return
}
if stdcall4(_PostQueuedCompletionStatus, iocphandle, 0, 0, 0) == 0 {
key := packNetpollKey(netpollSourceBreak, nil)
if stdcall4(_PostQueuedCompletionStatus, iocphandle, 0, key, 0) == 0 {
println("runtime: netpoll: PostQueuedCompletionStatus failed (errno=", getlasterror(), ")")
throw("runtime: netpoll: PostQueuedCompletionStatus failed")
}
@ -86,7 +151,6 @@ func netpoll(delay int64) (gList, int32) {
var entries [64]overlappedEntry
var wait, n, i uint32
var errno int32
var op *net_op
var toRun gList
mp := getg().m
@ -127,21 +191,30 @@ func netpoll(delay int64) (gList, int32) {
mp.blocked = false
delta := int32(0)
for i = 0; i < n; i++ {
op = entries[i].op
if op != nil && op.pd == entries[i].key {
e := &entries[i]
switch unpackNetpollSource(e.key) {
case netpollSourceReady:
op := pollOperationFromOverlappedEntry(e)
if op == nil {
// Entry from outside the Go runtime and internal/poll, ignore.
continue
}
// Entry from internal/poll.
mode := op.mode
if mode != 'r' && mode != 'w' {
println("runtime: GetQueuedCompletionStatusEx returned net_op with invalid mode=", mode)
throw("runtime: netpoll failed")
}
delta += netpollready(&toRun, op.pd, mode)
} else {
case netpollSourceBreak:
netpollWakeSig.Store(0)
if delay == 0 {
// Forward the notification to the
// blocked poller.
// Forward the notification to the blocked poller.
netpollBreak()
}
default:
println("runtime: GetQueuedCompletionStatusEx returned net_op with invalid key=", e.key)
throw("runtime: netpoll failed")
}
}
return toRun, delta