3cbb7cbb09
Reviewed-on: https://go-review.googlesource.com/c/140277 From-SVN: r264932
539 lines
12 KiB
Go
539 lines
12 KiB
Go
// Copyright 2018 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// +build linux
|
|
|
|
package net
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"os"
|
|
"sync"
|
|
"syscall"
|
|
"testing"
|
|
)
|
|
|
|
func TestSplice(t *testing.T) {
|
|
t.Run("simple", testSpliceSimple)
|
|
t.Run("multipleWrite", testSpliceMultipleWrite)
|
|
t.Run("big", testSpliceBig)
|
|
t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader)
|
|
t.Run("readerAtEOF", testSpliceReaderAtEOF)
|
|
t.Run("issue25985", testSpliceIssue25985)
|
|
}
|
|
|
|
func testSpliceSimple(t *testing.T) {
|
|
srv, err := newSpliceTestServer()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer srv.Close()
|
|
copyDone := srv.Copy()
|
|
msg := []byte("splice test")
|
|
if _, err := srv.Write(msg); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
got := make([]byte, len(msg))
|
|
if _, err := io.ReadFull(srv, got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Equal(got, msg) {
|
|
t.Errorf("got %q, wrote %q", got, msg)
|
|
}
|
|
srv.CloseWrite()
|
|
srv.CloseRead()
|
|
if err := <-copyDone; err != nil {
|
|
t.Errorf("splice: %v", err)
|
|
}
|
|
}
|
|
|
|
func testSpliceMultipleWrite(t *testing.T) {
|
|
srv, err := newSpliceTestServer()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer srv.Close()
|
|
copyDone := srv.Copy()
|
|
msg1 := []byte("splice test part 1 ")
|
|
msg2 := []byte(" splice test part 2")
|
|
if _, err := srv.Write(msg1); err != nil {
|
|
t.Fatalf("Write: %v", err)
|
|
}
|
|
if _, err := srv.Write(msg2); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
got := make([]byte, len(msg1)+len(msg2))
|
|
if _, err := io.ReadFull(srv, got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want := append(msg1, msg2...)
|
|
if !bytes.Equal(got, want) {
|
|
t.Errorf("got %q, wrote %q", got, want)
|
|
}
|
|
srv.CloseWrite()
|
|
srv.CloseRead()
|
|
if err := <-copyDone; err != nil {
|
|
t.Errorf("splice: %v", err)
|
|
}
|
|
}
|
|
|
|
func testSpliceBig(t *testing.T) {
|
|
// The maximum amount of data that internal/poll.Splice will use in a
|
|
// splice(2) call is 4 << 20. Use a bigger size here so that we test an
|
|
// amount that doesn't fit in a single call.
|
|
size := 5 << 20
|
|
srv, err := newSpliceTestServer()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer srv.Close()
|
|
big := make([]byte, size)
|
|
copyDone := srv.Copy()
|
|
type readResult struct {
|
|
b []byte
|
|
err error
|
|
}
|
|
readDone := make(chan readResult)
|
|
go func() {
|
|
got := make([]byte, len(big))
|
|
_, err := io.ReadFull(srv, got)
|
|
readDone <- readResult{got, err}
|
|
}()
|
|
if _, err := srv.Write(big); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
res := <-readDone
|
|
if res.err != nil {
|
|
t.Fatal(res.err)
|
|
}
|
|
got := res.b
|
|
if !bytes.Equal(got, big) {
|
|
t.Errorf("input and output differ")
|
|
}
|
|
srv.CloseWrite()
|
|
srv.CloseRead()
|
|
if err := <-copyDone; err != nil {
|
|
t.Errorf("splice: %v", err)
|
|
}
|
|
}
|
|
|
|
func testSpliceHonorsLimitedReader(t *testing.T) {
|
|
t.Run("stopsAfterN", testSpliceStopsAfterN)
|
|
t.Run("updatesN", testSpliceUpdatesN)
|
|
t.Run("readerAtLimit", testSpliceReaderAtLimit)
|
|
}
|
|
|
|
func testSpliceStopsAfterN(t *testing.T) {
|
|
clientUp, serverUp, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientUp.Close()
|
|
defer serverUp.Close()
|
|
clientDown, serverDown, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
count := 128
|
|
copyDone := make(chan error)
|
|
lr := &io.LimitedReader{
|
|
N: int64(count),
|
|
R: serverUp,
|
|
}
|
|
go func() {
|
|
_, err := io.Copy(serverDown, lr)
|
|
serverDown.Close()
|
|
copyDone <- err
|
|
}()
|
|
msg := make([]byte, 2*count)
|
|
if _, err := clientUp.Write(msg); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientUp.Close()
|
|
var buf bytes.Buffer
|
|
if _, err := io.Copy(&buf, clientDown); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if buf.Len() != count {
|
|
t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
|
|
}
|
|
clientDown.Close()
|
|
if err := <-copyDone; err != nil {
|
|
t.Errorf("splice: %v", err)
|
|
}
|
|
}
|
|
|
|
func testSpliceUpdatesN(t *testing.T) {
|
|
clientUp, serverUp, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientUp.Close()
|
|
defer serverUp.Close()
|
|
clientDown, serverDown, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
count := 128
|
|
copyDone := make(chan error)
|
|
lr := &io.LimitedReader{
|
|
N: int64(100 + count),
|
|
R: serverUp,
|
|
}
|
|
go func() {
|
|
_, err := io.Copy(serverDown, lr)
|
|
copyDone <- err
|
|
}()
|
|
msg := make([]byte, count)
|
|
if _, err := clientUp.Write(msg); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientUp.Close()
|
|
got := make([]byte, count)
|
|
if _, err := io.ReadFull(clientDown, got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
clientDown.Close()
|
|
if err := <-copyDone; err != nil {
|
|
t.Errorf("splice: %v", err)
|
|
}
|
|
wantN := int64(100)
|
|
if lr.N != wantN {
|
|
t.Errorf("lr.N = %d, want %d", lr.N, wantN)
|
|
}
|
|
}
|
|
|
|
func testSpliceReaderAtLimit(t *testing.T) {
|
|
clientUp, serverUp, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientUp.Close()
|
|
defer serverUp.Close()
|
|
clientDown, serverDown, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientDown.Close()
|
|
defer serverDown.Close()
|
|
|
|
lr := &io.LimitedReader{
|
|
N: 0,
|
|
R: serverUp,
|
|
}
|
|
_, err, handled := splice(serverDown.(*TCPConn).fd, lr)
|
|
if !handled {
|
|
if serr, ok := err.(*os.SyscallError); ok && serr.Syscall == "pipe2" && serr.Err == syscall.ENOSYS {
|
|
t.Skip("pipe2 not supported")
|
|
}
|
|
|
|
t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
|
|
}
|
|
}
|
|
|
|
func testSpliceReaderAtEOF(t *testing.T) {
|
|
clientUp, serverUp, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientUp.Close()
|
|
clientDown, serverDown, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer clientDown.Close()
|
|
|
|
serverUp.Close()
|
|
|
|
// We'd like to call net.splice here and check the handled return
|
|
// value, but we disable splice on old Linux kernels.
|
|
//
|
|
// In that case, poll.Splice and net.splice return a non-nil error
|
|
// and handled == false. We'd ideally like to see handled == true
|
|
// because the source reader is at EOF, but if we're running on an old
|
|
// kernel, and splice is disabled, we won't see EOF from net.splice,
|
|
// because we won't touch the reader at all.
|
|
//
|
|
// Trying to untangle the errors from net.splice and match them
|
|
// against the errors created by the poll package would be brittle,
|
|
// so this is a higher level test.
|
|
//
|
|
// The following ReadFrom should return immediately, regardless of
|
|
// whether splice is disabled or not. The other side should then
|
|
// get a goodbye signal. Test for the goodbye signal.
|
|
msg := "bye"
|
|
go func() {
|
|
serverDown.(*TCPConn).ReadFrom(serverUp)
|
|
io.WriteString(serverDown, msg)
|
|
serverDown.Close()
|
|
}()
|
|
|
|
buf := make([]byte, 3)
|
|
_, err = io.ReadFull(clientDown, buf)
|
|
if err != nil {
|
|
t.Errorf("clientDown: %v", err)
|
|
}
|
|
if string(buf) != msg {
|
|
t.Errorf("clientDown got %q, want %q", buf, msg)
|
|
}
|
|
}
|
|
|
|
func testSpliceIssue25985(t *testing.T) {
|
|
front, err := newLocalListener("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer front.Close()
|
|
back, err := newLocalListener("tcp")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer back.Close()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
proxy := func() {
|
|
src, err := front.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
dst, err := Dial("tcp", back.Addr().String())
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer dst.Close()
|
|
defer src.Close()
|
|
go func() {
|
|
io.Copy(src, dst)
|
|
wg.Done()
|
|
}()
|
|
go func() {
|
|
io.Copy(dst, src)
|
|
wg.Done()
|
|
}()
|
|
}
|
|
|
|
go proxy()
|
|
|
|
toFront, err := Dial("tcp", front.Addr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
io.WriteString(toFront, "foo")
|
|
toFront.Close()
|
|
|
|
fromProxy, err := back.Accept()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer fromProxy.Close()
|
|
|
|
_, err = ioutil.ReadAll(fromProxy)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func BenchmarkTCPReadFrom(b *testing.B) {
|
|
testHookUninstaller.Do(uninstallTestHooks)
|
|
|
|
var chunkSizes []int
|
|
for i := uint(10); i <= 20; i++ {
|
|
chunkSizes = append(chunkSizes, 1<<i)
|
|
}
|
|
// To benchmark the genericReadFrom code path, set this to false.
|
|
useSplice := true
|
|
for _, chunkSize := range chunkSizes {
|
|
b.Run(fmt.Sprint(chunkSize), func(b *testing.B) {
|
|
benchmarkSplice(b, chunkSize, useSplice)
|
|
})
|
|
}
|
|
}
|
|
|
|
func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
|
|
srv, err := newSpliceTestServer()
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
defer srv.Close()
|
|
var copyDone <-chan error
|
|
if useSplice {
|
|
copyDone = srv.Copy()
|
|
} else {
|
|
copyDone = srv.CopyNoSplice()
|
|
}
|
|
chunk := make([]byte, chunkSize)
|
|
discardDone := make(chan struct{})
|
|
go func() {
|
|
for {
|
|
buf := make([]byte, chunkSize)
|
|
_, err := srv.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
discardDone <- struct{}{}
|
|
}()
|
|
b.SetBytes(int64(chunkSize))
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
srv.Write(chunk)
|
|
}
|
|
srv.CloseWrite()
|
|
<-copyDone
|
|
srv.CloseRead()
|
|
<-discardDone
|
|
}
|
|
|
|
type spliceTestServer struct {
|
|
clientUp io.WriteCloser
|
|
clientDown io.ReadCloser
|
|
serverUp io.ReadCloser
|
|
serverDown io.WriteCloser
|
|
}
|
|
|
|
func newSpliceTestServer() (*spliceTestServer, error) {
|
|
// For now, both networks are hard-coded to TCP.
|
|
// If splice is enabled for non-tcp upstream connections,
|
|
// newSpliceTestServer will need to take a network parameter.
|
|
clientUp, serverUp, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientDown, serverDown, err := spliceTestSocketPair("tcp")
|
|
if err != nil {
|
|
clientUp.Close()
|
|
serverUp.Close()
|
|
return nil, err
|
|
}
|
|
return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil
|
|
}
|
|
|
|
// Read reads from the downstream connection.
|
|
func (srv *spliceTestServer) Read(b []byte) (int, error) {
|
|
return srv.clientDown.Read(b)
|
|
}
|
|
|
|
// Write writes to the upstream connection.
|
|
func (srv *spliceTestServer) Write(b []byte) (int, error) {
|
|
return srv.clientUp.Write(b)
|
|
}
|
|
|
|
// Close closes the server.
|
|
func (srv *spliceTestServer) Close() error {
|
|
err := srv.closeUp()
|
|
err1 := srv.closeDown()
|
|
if err == nil {
|
|
return err1
|
|
}
|
|
return err
|
|
}
|
|
|
|
// CloseWrite closes the client side of the upstream connection.
|
|
func (srv *spliceTestServer) CloseWrite() error {
|
|
return srv.clientUp.Close()
|
|
}
|
|
|
|
// CloseRead closes the client side of the downstream connection.
|
|
func (srv *spliceTestServer) CloseRead() error {
|
|
return srv.clientDown.Close()
|
|
}
|
|
|
|
// Copy copies from the server side of the upstream connection
|
|
// to the server side of the downstream connection, in a separate
|
|
// goroutine. Copy is done when the first send on the returned
|
|
// channel succeeds.
|
|
func (srv *spliceTestServer) Copy() <-chan error {
|
|
ch := make(chan error)
|
|
go func() {
|
|
_, err := io.Copy(srv.serverDown, srv.serverUp)
|
|
ch <- err
|
|
close(ch)
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
// CopyNoSplice is like Copy, but ensures that the splice code path
|
|
// is not reached.
|
|
func (srv *spliceTestServer) CopyNoSplice() <-chan error {
|
|
type onlyReader struct {
|
|
io.Reader
|
|
}
|
|
ch := make(chan error)
|
|
go func() {
|
|
_, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
|
|
ch <- err
|
|
close(ch)
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
func (srv *spliceTestServer) closeUp() error {
|
|
var err, err1 error
|
|
if srv.serverUp != nil {
|
|
err = srv.serverUp.Close()
|
|
}
|
|
if srv.clientUp != nil {
|
|
err1 = srv.clientUp.Close()
|
|
}
|
|
if err == nil {
|
|
return err1
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (srv *spliceTestServer) closeDown() error {
|
|
var err, err1 error
|
|
if srv.serverDown != nil {
|
|
err = srv.serverDown.Close()
|
|
}
|
|
if srv.clientDown != nil {
|
|
err1 = srv.clientDown.Close()
|
|
}
|
|
if err == nil {
|
|
return err1
|
|
}
|
|
return err
|
|
}
|
|
|
|
func spliceTestSocketPair(net string) (client, server Conn, err error) {
|
|
ln, err := newLocalListener(net)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer ln.Close()
|
|
var cerr, serr error
|
|
acceptDone := make(chan struct{})
|
|
go func() {
|
|
server, serr = ln.Accept()
|
|
acceptDone <- struct{}{}
|
|
}()
|
|
client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
|
|
<-acceptDone
|
|
if cerr != nil {
|
|
if server != nil {
|
|
server.Close()
|
|
}
|
|
return nil, nil, cerr
|
|
}
|
|
if serr != nil {
|
|
if client != nil {
|
|
client.Close()
|
|
}
|
|
return nil, nil, serr
|
|
}
|
|
return client, server, nil
|
|
}
|