libgo: Update to weekly.2011-11-18.

From-SVN: r182266
Ian Lance Taylor 2011-12-12 23:40:51 +00:00
parent 6e456f4cf4
commit ab61e9c4da
223 changed files with 6373 additions and 3999 deletions

View File

@ -1,4 +1,4 @@
2f4482b89a6b
b4a91b693374
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.

View File

@ -648,7 +648,8 @@ go_math_files = \
go_mime_files = \
go/mime/grammar.go \
go/mime/mediatype.go \
go/mime/type.go
go/mime/type.go \
go/mime/type_unix.go
if LIBGO_IS_RTEMS
go_net_fd_os_file = go/net/fd_select.go
@ -770,7 +771,6 @@ go_os_files = \
$(go_os_dir_file) \
go/os/dir.go \
go/os/env.go \
go/os/env_unix.go \
go/os/error.go \
go/os/error_posix.go \
go/os/exec.go \
@ -1156,6 +1156,7 @@ go_exp_sql_files = \
go/exp/sql/sql.go
go_exp_ssh_files = \
go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \
@ -1164,10 +1165,11 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go
go_exp_terminal_files = \
go/exp/terminal/shell.go \
go/exp/terminal/terminal.go
go/exp/terminal/terminal.go \
go/exp/terminal/util.go
go_exp_types_files = \
go/exp/types/check.go \
go/exp/types/const.go \
@ -1546,6 +1548,7 @@ syscall_netlink_file =
endif
go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \
go/syscall/socket.go \

View File

@ -1032,7 +1032,8 @@ go_math_files = \
go_mime_files = \
go/mime/grammar.go \
go/mime/mediatype.go \
go/mime/type.go
go/mime/type.go \
go/mime/type_unix.go
# By default use select with pipes. Most systems should have
# something better.
@ -1103,7 +1104,6 @@ go_os_files = \
$(go_os_dir_file) \
go/os/dir.go \
go/os/env.go \
go/os/env_unix.go \
go/os/error.go \
go/os/error_posix.go \
go/os/exec.go \
@ -1521,6 +1521,7 @@ go_exp_sql_files = \
go_exp_ssh_files = \
go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \
@ -1529,11 +1530,12 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go
go_exp_terminal_files = \
go/exp/terminal/shell.go \
go/exp/terminal/terminal.go
go/exp/terminal/terminal.go \
go/exp/terminal/util.go
go_exp_types_files = \
go/exp/types/check.go \
@ -1890,6 +1892,7 @@ go_unicode_utf8_files = \
# Support for netlink sockets and messages.
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \
go/syscall/socket.go \

View File

@ -10,7 +10,6 @@ import (
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"testing"
"testing/iotest"
@ -425,9 +424,9 @@ var errorWriterTests = []errorWriterTest{
{0, 1, nil, io.ErrShortWrite},
{1, 2, nil, io.ErrShortWrite},
{1, 1, nil, nil},
{0, 1, os.EPIPE, os.EPIPE},
{1, 2, os.EPIPE, os.EPIPE},
{1, 1, os.EPIPE, os.EPIPE},
{0, 1, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 2, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 1, io.ErrClosedPipe, io.ErrClosedPipe},
}
func TestWriteErrors(t *testing.T) {

View File

@ -91,6 +91,11 @@ type rune rune
// invocation.
type Type int
// Type1 is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type1 int
// IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc.
type IntegerType int
@ -119,6 +124,11 @@ func append(slice []Type, elems ...Type) []Type
// len(src) and len(dst).
func copy(dst, src []Type) int
// The delete built-in function deletes the element with the specified key
// (m[key]) from the map. If there is no such element, delete is a no-op.
// If m is nil, delete panics.
func delete(m map[Type]Type1, key Type)
// The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil).
@ -171,7 +181,7 @@ func complex(r, i FloatType) ComplexType
// The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex
// The imag built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to
// the type of c.
func imag(c ComplexType) FloatType
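
The new documentation above for the delete built-in (and the corrected wording for imag) is easy to exercise directly; the following is a short, self-contained illustration of exactly the behaviour the doc comments describe:

package main

import "fmt"

func main() {
	m := map[string]int{"a": 1, "b": 2}
	delete(m, "a") // removes the element with key "a"
	delete(m, "z") // no such element: delete is a no-op
	fmt.Println(len(m), m["b"]) // 1 2

	c := complex(3, 4)
	fmt.Println(real(c), imag(c)) // 3 4
}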

View File

@ -662,48 +662,49 @@ func TestRunes(t *testing.T) {
}
type TrimTest struct {
f func([]byte, string) []byte
f string
in, cutset, out string
}
var trimTests = []TrimTest{
{Trim, "abba", "a", "bb"},
{Trim, "abba", "ab", ""},
{TrimLeft, "abba", "ab", ""},
{TrimRight, "abba", "ab", ""},
{TrimLeft, "abba", "a", "bba"},
{TrimRight, "abba", "a", "abb"},
{Trim, "<tag>", "<>", "tag"},
{Trim, "* listitem", " *", "listitem"},
{Trim, `"quote"`, `"`, "quote"},
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
{"Trim", "abba", "a", "bb"},
{"Trim", "abba", "ab", ""},
{"TrimLeft", "abba", "ab", ""},
{"TrimRight", "abba", "ab", ""},
{"TrimLeft", "abba", "a", "bba"},
{"TrimRight", "abba", "a", "abb"},
{"Trim", "<tag>", "<>", "tag"},
{"Trim", "* listitem", " *", "listitem"},
{"Trim", `"quote"`, `"`, "quote"},
{"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests
{Trim, "abba", "", "abba"},
{Trim, "", "123", ""},
{Trim, "", "", ""},
{TrimLeft, "abba", "", "abba"},
{TrimLeft, "", "123", ""},
{TrimLeft, "", "", ""},
{TrimRight, "abba", "", "abba"},
{TrimRight, "", "123", ""},
{TrimRight, "", "", ""},
{TrimRight, "☺\xc0", "☺", "☺\xc0"},
{"Trim", "abba", "", "abba"},
{"Trim", "", "123", ""},
{"Trim", "", "", ""},
{"TrimLeft", "abba", "", "abba"},
{"TrimLeft", "", "123", ""},
{"TrimLeft", "", "", ""},
{"TrimRight", "abba", "", "abba"},
{"TrimRight", "", "123", ""},
{"TrimRight", "", "", ""},
{"TrimRight", "☺\xc0", "☺", "☺\xc0"},
}
func TestTrim(t *testing.T) {
for _, tc := range trimTests {
actual := string(tc.f([]byte(tc.in), tc.cutset))
var name string
switch tc.f {
case Trim:
name = "Trim"
case TrimLeft:
name = "TrimLeft"
case TrimRight:
name = "TrimRight"
name := tc.f
var f func([]byte, string) []byte
switch name {
case "Trim":
f = Trim
case "TrimLeft":
f = TrimLeft
case "TrimRight":
f = TrimRight
default:
t.Error("Undefined trim function")
t.Error("Undefined trim function %s", name)
}
actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
}
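
The trim test table above now stores the name of the trim function rather than the function value itself, most likely because function values in Go may only be compared against nil (general function equality was removed from the language around this release), so the old switch over tc.f no longer compiles. A minimal sketch of the same name-to-function dispatch outside the test, with illustrative values:

package main

import (
	"bytes"
	"fmt"
)

// trimFuncs maps the names used in the test table to the functions they denote.
var trimFuncs = map[string]func([]byte, string) []byte{
	"Trim":      bytes.Trim,
	"TrimLeft":  bytes.TrimLeft,
	"TrimRight": bytes.TrimRight,
}

func main() {
	f := trimFuncs["Trim"]
	fmt.Printf("%q\n", f([]byte("abba"), "a")) // "bb"
}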

View File

@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"io"
"os"
)
// Order specifies the bit ordering in an LZW data stream.
@ -212,8 +211,10 @@ func (d *decoder) flush() {
d.o = 0
}
var errClosed = errors.New("compress/lzw: reader/writer is closed")
func (d *decoder) Close() error {
d.err = os.EINVAL // in case any Reads come along
d.err = errClosed // in case any Reads come along
return nil
}

View File

@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
"os"
)
// A writer is a buffered, flushable writer.
@ -49,8 +48,9 @@ const (
type encoder struct {
// w is the writer that compressed bytes are written to.
w writer
// write, bits, nBits and width are the state for converting a code stream
// into a byte stream.
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
write func(*encoder, uint32) error
bits uint32
nBits uint
@ -64,7 +64,7 @@ type encoder struct {
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
// err is the first error encountered during writing. Closing the encoder
// will make any future Write calls return os.EINVAL.
// will make any future Write calls return errClosed
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing.
@ -191,13 +191,13 @@ loop:
// flush e's underlying writer.
func (e *encoder) Close() error {
if e.err != nil {
if e.err == os.EINVAL {
if e.err == errClosed {
return nil
}
return e.err
}
// Make any future calls to Write return os.EINVAL.
e.err = os.EINVAL
// Make any future calls to Write return errClosed.
e.err = errClosed
// Write the savedCode if valid.
if e.savedCode != invalidCode {
if err := e.write(e, e.savedCode); err != nil {
@ -214,7 +214,7 @@ func (e *encoder) Close() error {
}
// Write the final bits.
if e.nBits > 0 {
if e.write == (*encoder).writeMSB {
if e.order == MSB {
e.bits >>= 24
}
if err := e.w.WriteByte(uint8(e.bits)); err != nil {
@ -250,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
lw := uint(litWidth)
return &encoder{
w: bw,
order: order,
write: write,
width: 1 + lw,
litWidth: lw,
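
The encoder now records its Order so that Close can check e.order == MSB instead of comparing the write method value, and a closed encoder reports the package-level errClosed error rather than os.EINVAL. A small usage sketch of compress/lzw reflecting that behaviour:

package main

import (
	"bytes"
	"compress/lzw"
	"fmt"
)

func main() {
	var buf bytes.Buffer
	w := lzw.NewWriter(&buf, lzw.LSB, 8) // litWidth 8 suits arbitrary byte data
	if _, err := w.Write([]byte("hello, hello, hello")); err != nil {
		panic(err)
	}
	if err := w.Close(); err != nil { // flushes any remaining bits
		panic(err)
	}
	if _, err := w.Write([]byte("more")); err != nil {
		fmt.Println("write after Close fails:", err) // the errClosed error
	}
}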

View File

@ -50,10 +50,6 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
return
}
_, err1 := lzww.Write(b[:n])
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return

View File

@ -59,10 +59,6 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
}
defer zlibw.Close()
_, err = zlibw.Write(b0)
if err == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
return
}
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return

View File

@ -41,7 +41,7 @@ func NewCipher(key []byte) (*Cipher, error) {
}
// BlockSize returns the AES block size, 16 bytes.
// It is necessary to satisfy the Cipher interface in the
// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -54,7 +54,7 @@ func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
}
// BlockSize returns the Blowfish block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the
// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -28,16 +28,16 @@ func (r *rngReader) Read(b []byte) (n int, err error) {
if r.prov == 0 {
const provType = syscall.PROV_RSA_FULL
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
errno := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
if errno != 0 {
err := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
if err != nil {
r.mu.Unlock()
return 0, os.NewSyscallError("CryptAcquireContext", errno)
return 0, os.NewSyscallError("CryptAcquireContext", err)
}
}
r.mu.Unlock()
errno := syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
if errno != 0 {
return 0, os.NewSyscallError("CryptGenRandom", errno)
err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
if err != nil {
return 0, os.NewSyscallError("CryptGenRandom", err)
}
return len(b), nil
}

View File

@ -5,16 +5,16 @@
package rand
import (
"errors"
"io"
"math/big"
"os"
)
// Prime returns a number, p, of the given size, such that p is prime
// with high probability.
func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
if bits < 1 {
err = os.EINVAL
err = errors.New("crypto/rand: prime size must be positive")
}
b := uint(bits % 8)

View File

@ -93,7 +93,8 @@ func (c *Conn) SetTimeout(nsec int64) error {
}
// SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning os.EAGAIN.
// Read will wait for data before returning a net.Error
// with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
func (c *Conn) SetReadTimeout(nsec int64) error {
return c.conn.SetReadTimeout(nsec)
@ -737,7 +738,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return c.writeRecord(recordTypeApplicationData, b)
}
// Read can be made to time out and return err == os.EAGAIN
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetTimeout and SetReadTimeout.
func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil {
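
The comments above replace the old os.EAGAIN wording: a timed-out Read now surfaces as a net.Error whose Timeout method reports true. A hedged sketch of how a caller distinguishes that case (the dial target below is only an example address used to provoke a timeout):

package main

import (
	"fmt"
	"net"
	"time"
)

// isTimeout reports whether err represents an expired deadline, which is how
// Read signals a timeout under the new contract.
func isTimeout(err error) bool {
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}

func main() {
	_, err := net.DialTimeout("tcp", "203.0.113.1:443", time.Millisecond)
	fmt.Println(isTimeout(err)) // usually true: the dial deadline expires
}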

View File

@ -4,6 +4,8 @@
package tls
import "bytes"
type clientHelloMsg struct {
raw []byte
vers uint16
@ -18,6 +20,25 @@ type clientHelloMsg struct {
supportedPoints []uint8
}
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints)
}
func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -309,6 +330,23 @@ type serverHelloMsg struct {
ocspStapling bool
}
func (m *serverHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
m.cipherSuite == m1.cipherSuite &&
m.compressionMethod == m1.compressionMethod &&
m.nextProtoNeg == m1.nextProtoNeg &&
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling
}
func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -463,6 +501,16 @@ type certificateMsg struct {
certificates [][]byte
}
func (m *certificateMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
eqByteSlices(m.certificates, m1.certificates)
}
func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
key []byte
}
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*serverKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.key, m1.key)
}
func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -571,6 +629,17 @@ type certificateStatusMsg struct {
response []byte
}
func (m *certificateStatusMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateStatusMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.statusType == m1.statusType &&
bytes.Equal(m.response, m1.response)
}
func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
_, ok := i.(*serverHelloDoneMsg)
return ok
}
func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4)
x[0] = typeServerHelloDone
@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
ciphertext []byte
}
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*clientKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.ciphertext, m1.ciphertext)
}
func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -671,6 +755,16 @@ type finishedMsg struct {
verifyData []byte
}
func (m *finishedMsg) equal(i interface{}) bool {
m1, ok := i.(*finishedMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.verifyData, m1.verifyData)
}
func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@ -698,6 +792,16 @@ type nextProtoMsg struct {
proto string
}
func (m *nextProtoMsg) equal(i interface{}) bool {
m1, ok := i.(*nextProtoMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.proto == m1.proto
}
func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil {
return m.raw
@ -759,6 +863,17 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte
}
func (m *certificateRequestMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateRequestMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
}
func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
signature []byte
}
func (m *certificateVerifyMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateVerifyMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.signature, m1.signature)
}
func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil {
return m.raw
@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return true
}
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqStrings(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqByteSlices(x, y [][]byte) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if !bytes.Equal(v, y[i]) {
return false
}
}
return true
}

View File

@ -27,10 +27,12 @@ var tests = []interface{}{
type testMessage interface {
marshal() []byte
unmarshal([]byte) bool
equal(interface{}) bool
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for i, iface := range tests {
ty := reflect.ValueOf(iface).Type()
@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
}
m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) {
if !m1.equal(m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break
}
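
The test above switches from reflect.DeepEqual to the per-message equal methods added in handshake_messages.go, likely because marshalling and unmarshalling do not preserve the distinction between nil and empty slices, and DeepEqual does. A tiny illustration of that distinction:

package main

import (
	"fmt"
	"reflect"
)

func main() {
	var a []uint16  // nil slice
	b := []uint16{} // allocated but empty slice
	fmt.Println(reflect.DeepEqual(a, b)) // false: DeepEqual tells them apart
	fmt.Println(len(a) == len(b))        // true: the equal helpers compare contents only
}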

View File

@ -12,8 +12,8 @@ import (
)
func loadStore(roots *x509.CertPool, name string) {
store, errno := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
if errno != 0 {
store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
if err != nil {
return
}

View File

@ -44,7 +44,7 @@ func NewCipher(key []byte) (*Cipher, error) {
}
// BlockSize returns the XTEA block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the
// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }

View File

@ -0,0 +1,157 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Large data benchmark.
// The JSON data is a summary of agl's changes in the
// go, webkit, and chromium open source projects.
// We benchmark converting between the JSON form
// and in-memory data structures.
package json
import (
"bytes"
"compress/gzip"
"io/ioutil"
"os"
"testing"
)
type codeResponse struct {
Tree *codeNode `json:"tree"`
Username string `json:"username"`
}
type codeNode struct {
Name string `json:"name"`
Kids []*codeNode `json:"kids"`
CLWeight float64 `json:"cl_weight"`
Touches int `json:"touches"`
MinT int64 `json:"min_t"`
MaxT int64 `json:"max_t"`
MeanT int64 `json:"mean_t"`
}
var codeJSON []byte
var codeStruct codeResponse
func codeInit() {
f, err := os.Open("testdata/code.json.gz")
if err != nil {
panic(err)
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
panic(err)
}
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(err)
}
codeJSON = data
if err := Unmarshal(codeJSON, &codeStruct); err != nil {
panic("unmarshal code.json: " + err.Error())
}
if data, err = Marshal(&codeStruct); err != nil {
panic("marshal code.json: " + err.Error())
}
if !bytes.Equal(data, codeJSON) {
println("different lengths", len(data), len(codeJSON))
for i := 0; i < len(data) && i < len(codeJSON); i++ {
if data[i] != codeJSON[i] {
println("re-marshal: changed at byte", i)
println("orig: ", string(codeJSON[i-10:i+10]))
println("new: ", string(data[i-10:i+10]))
break
}
}
panic("re-marshal code.json: different result")
}
}
func BenchmarkCodeEncoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
enc := NewEncoder(ioutil.Discard)
for i := 0; i < b.N; i++ {
if err := enc.Encode(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeMarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
if _, err := Marshal(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeDecoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var buf bytes.Buffer
dec := NewDecoder(&buf)
var r codeResponse
for i := 0; i < b.N; i++ {
buf.Write(codeJSON)
// hide EOF
buf.WriteByte('\n')
buf.WriteByte('\n')
buf.WriteByte('\n')
if err := dec.Decode(&r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
var r codeResponse
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshalReuse(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var r codeResponse
for i := 0; i < b.N; i++ {
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}

View File

@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) {
// d.scan thinks we're still at the beginning of the item.
// Feed in an empty string - the shortest, simplest value -
// so that it knows we got to the end of the value.
if d.scan.step == stateRedo {
if d.scan.redo {
panic("redo")
}
d.scan.step(&d.scan, '"')
@ -381,6 +381,7 @@ func (d *decodeState) array(v reflect.Value) {
d.error(errPhase)
}
}
if i < av.Len() {
if !sv.IsValid() {
// Array. Zero the rest.
@ -392,6 +393,9 @@ func (d *decodeState) array(v reflect.Value) {
sv.SetLen(i)
}
}
if i == 0 && av.Kind() == reflect.Slice && sv.IsNil() {
sv.Set(reflect.MakeSlice(sv.Type(), 0, 0))
}
}
// object consumes an object from d.data[d.off-1:], decoding into the value v.
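
The lines added above make decoding an empty JSON array into a nil slice yield an allocated, empty slice rather than leaving the destination nil. A quick illustration with the public API:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var s []int
	if err := json.Unmarshal([]byte(`[]`), &s); err != nil {
		panic(err)
	}
	fmt.Println(s == nil, len(s)) // false 0: an empty, non-nil slice
}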

View File

@ -80,6 +80,9 @@ type scanner struct {
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, int) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
@ -87,6 +90,7 @@ type scanner struct {
err error
// 1-byte redo (see undo method)
redo bool
redoCode int
redoState func(*scanner, int) int
@ -135,6 +139,8 @@ func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.redo = false
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
@ -143,11 +149,11 @@ func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.step == stateEndTop {
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.step == stateEndTop {
if s.endTop {
return scanEnd
}
if s.err == nil {
@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) {
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
s.redo = false
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int {
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
@ -606,16 +615,18 @@ func quoteChar(c int) string {
// undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) {
if s.step == stateRedo {
panic("invalid use of scanner")
if s.redo {
panic("json: invalid use of scanner")
}
s.redoCode = scanCode
s.redoState = s.step
s.step = stateRedo
s.redo = true
}
// stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c int) int {
s.redo = false
s.step = s.redoState
return s.redoCode
}

View File

@ -186,11 +186,12 @@ func TestNextValueBig(t *testing.T) {
}
}
var benchScan scanner
func BenchmarkSkipValue(b *testing.B) {
initBig()
var scan scanner
for i := 0; i < b.N; i++ {
nextValue(jsonBig, &scan)
nextValue(jsonBig, &benchScan)
}
b.SetBytes(int64(len(jsonBig)))
}

View File

@ -7,7 +7,6 @@ package xml
import (
"bytes"
"io"
"os"
"reflect"
"strings"
"testing"
@ -43,17 +42,17 @@ var rawTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"", "hello"}},
CharData([]byte("\n ")),
StartElement{Name{"", "goodbye"}, nil},
StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}},
CharData([]byte("\n ")),
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
StartElement{Name{"", "inner"}, nil},
StartElement{Name{"", "inner"}, []Attr{}},
EndElement{Name{"", "inner"}},
CharData([]byte("\n ")),
EndElement{Name{"", "outer"}},
CharData([]byte("\n ")),
StartElement{Name{"tag", "name"}, nil},
StartElement{Name{"tag", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
@ -77,17 +76,17 @@ var cookedTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"ns2", "hello"}},
CharData([]byte("\n ")),
StartElement{Name{"ns2", "goodbye"}, nil},
StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}},
CharData([]byte("\n ")),
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
StartElement{Name{"ns2", "inner"}, nil},
StartElement{Name{"ns2", "inner"}, []Attr{}},
EndElement{Name{"ns2", "inner"}},
CharData([]byte("\n ")),
EndElement{Name{"ns2", "outer"}},
CharData([]byte("\n ")),
StartElement{Name{"ns3", "name"}, nil},
StartElement{Name{"ns3", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
@ -105,7 +104,7 @@ var rawTokensAltEncoding = []Token{
CharData([]byte("\n")),
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
CharData([]byte("\n")),
StartElement{Name{"", "tag"}, nil},
StartElement{Name{"", "tag"}, []Attr{}},
CharData([]byte("value")),
EndElement{Name{"", "tag"}},
}
@ -205,7 +204,7 @@ func (d *downCaser) ReadByte() (c byte, err error) {
func (d *downCaser) Read(p []byte) (int, error) {
d.t.Fatalf("unexpected Read call on downCaser reader")
return 0, os.EINVAL
panic("unreachable")
}
func TestRawTokenAltEncoding(t *testing.T) {

View File

@ -105,9 +105,9 @@ func (w *Watcher) AddWatch(path string, flags uint32) error {
watchEntry.flags |= flags
flags |= syscall.IN_MASK_ADD
}
wd, errno := syscall.InotifyAddWatch(w.fd, path, flags)
if wd == -1 {
return &os.PathError{"inotify_add_watch", path, os.Errno(errno)}
wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
if err != nil {
return &os.PathError{"inotify_add_watch", path, err}
}
if !found {
@ -139,14 +139,10 @@ func (w *Watcher) RemoveWatch(path string) error {
// readEvents reads from the inotify file descriptor, converts the
// received events into Event objects and sends them via the Event channel
func (w *Watcher) readEvents() {
var (
buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
n int // Number of bytes read with read()
errno int // Syscall errno
)
var buf [syscall.SizeofInotifyEvent * 4096]byte
for {
n, errno = syscall.Read(w.fd, buf[0:])
n, err := syscall.Read(w.fd, buf[0:])
// See if there is a message on the "done" channel
var done bool
select {
@ -156,16 +152,16 @@ func (w *Watcher) readEvents() {
// If EOF or a "done" message is received
if n == 0 || done {
errno := syscall.Close(w.fd)
if errno == -1 {
w.Error <- os.NewSyscallError("close", errno)
err := syscall.Close(w.fd)
if err != nil {
w.Error <- os.NewSyscallError("close", err)
}
close(w.Event)
close(w.Error)
return
}
if n < 0 {
w.Error <- os.NewSyscallError("read", errno)
w.Error <- os.NewSyscallError("read", err)
continue
}
if n < syscall.SizeofInotifyEvent {

View File

@ -14,6 +14,21 @@ import (
"strconv"
)
// subsetTypeArgs takes a slice of arguments from callers of the sql
// package and converts them into a slice of the driver package's
// "subset types".
func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
out := make([]interface{}, len(args))
for n, arg := range args {
var err error
out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
}
}
return out, nil
}
// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
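
subsetTypeArgs, added above, normalizes caller arguments into the driver package's subset types (int64, float64, bool, nil, []byte, string) via driver.DefaultParameterConverter. A sketch of the conversion it relies on, using the converter directly (the import path is exp/sql/driver as laid out in this tree):

package main

import (
	"exp/sql/driver"
	"fmt"
)

func main() {
	v, err := driver.DefaultParameterConverter.ConvertValue(42)
	fmt.Printf("%T %v %v\n", v, v, err) // int64 42 <nil>

	_, err = driver.DefaultParameterConverter.ConvertValue(struct{}{})
	fmt.Println(err != nil) // true: not representable as a subset type
}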

View File

@ -36,19 +36,22 @@ type Driver interface {
Open(name string) (Conn, error)
}
// Execer is an optional interface that may be implemented by a Driver
// or a Conn.
//
// If a Driver does not implement Execer, the sql package's DB.Exec
// method first obtains a free connection from its free pool or from
// the driver's Open method. Execer should only be implemented by
// drivers that can provide a more efficient implementation.
// ErrSkip may be returned by some optional interfaces' methods to
// indicate at runtime that the fast path is unavailable and the sql
// package should continue as if the optional interface was not
// implemented. ErrSkip is only supported where explicitly
// documented.
var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the db package's DB.Exec will
// first prepare a query, execute the statement, and then close the
// statement.
//
// All arguments are of a subset type as defined in the package docs.
//
// Exec may return ErrSkip.
type Execer interface {
Exec(query string, args []interface{}) (Result, error)
}
@ -94,6 +97,9 @@ type Stmt interface {
Close() error
// NumInput returns the number of placeholder parameters.
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
NumInput() int
// Exec executes a query that doesn't return rows, such
@ -135,6 +141,8 @@ type Rows interface {
// The dest slice may be populated with only with values
// of subset types defined above, but excluding string.
// All string values must be converted to []byte.
//
// Next should return io.EOF when there are no more rows.
Next(dest []interface{}) error
}
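
ErrSkip, documented above, lets an optional fast path decline at runtime so the sql package falls back to the generic Prepare/Exec path. A minimal sketch of a driver connection using it, with the Execer signature as defined in this revision (the type names are hypothetical; the fakedb test driver later in this commit follows the same pattern):

package mydriver

import "exp/sql/driver" // import path as laid out in this tree

type result struct{}

func (result) LastInsertId() (int64, error) { return 0, nil }
func (result) RowsAffected() (int64, error) { return 0, nil }

type conn struct{} // other driver.Conn methods omitted in this sketch

// Exec implements the optional Execer fast path. Returning driver.ErrSkip
// tells the sql package to behave as if Execer were not implemented.
func (c *conn) Exec(query string, args []interface{}) (driver.Result, error) {
	if len(args) > 0 {
		// This toy driver only fast-paths statements without arguments.
		return nil, driver.ErrSkip
	}
	// ... execute query directly here ...
	return result{}, nil
}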

View File

@ -195,6 +195,29 @@ func (c *fakeConn) Close() error {
return nil
}
func checkSubsetTypes(args []interface{}) error {
for n, arg := range args {
switch arg.(type) {
case int64, float64, bool, nil, []byte, string:
default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
return nil, driver.ErrSkip
}
func errf(msg string, args ...interface{}) error {
return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
}
@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error {
}
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db
switch s.cmd {
case "WIPE":
@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
}
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")

View File

@ -88,8 +88,9 @@ type DB struct {
driver driver.Driver
dsn string
mu sync.Mutex
mu sync.Mutex // protects freeConn and closed
freeConn []driver.Conn
closed bool
}
// Open opens a database specified by its database driver name and a
@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return &DB{driver: driver, dsn: dataSourceName}, nil
}
// Close closes the database, releasing any open resources.
func (db *DB) Close() error {
db.mu.Lock()
defer db.mu.Unlock()
var err error
for _, c := range db.freeConn {
err1 := c.Close()
if err1 != nil {
err = err1
}
}
db.freeConn = nil
db.closed = true
return err
}
func (db *DB) maxIdleConns() int {
const defaultMaxIdleConns = 2
// TODO(bradfitz): ask driver, if supported, for its default preference
@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int {
// conn returns a newly-opened or cached driver.Conn
func (db *DB) conn() (driver.Conn, error) {
db.mu.Lock()
if db.closed {
return nil, errors.New("sql: database is closed")
}
if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1]
@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
}
func (db *DB) putConn(c driver.Conn) {
if n := len(db.freeConn); n < db.maxIdleConns() {
db.mu.Lock()
defer db.mu.Unlock()
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
db.freeConn = append(db.freeConn, c)
return
}
db.closeConn(c)
db.closeConn(c) // TODO(bradfitz): release lock before calling this?
}
func (db *DB) closeConn(c driver.Conn) {
@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
// Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
// Optional fast path, if the driver implements driver.Execer.
if execer, ok := db.driver.(driver.Execer); ok {
resi, err := execer.Exec(query, args)
if err != nil {
return nil, err
}
return result{resi}, nil
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
// If the driver does not implement driver.Execer, we need
// a connection.
ci, err := db.conn()
if err != nil {
return nil, err
@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
defer db.putConn(ci)
if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, args)
if err != nil {
return nil, err
resi, err := execer.Exec(query, sargs)
if err != driver.ErrSkip {
if err != nil {
return nil, err
}
return result{resi}, nil
}
return result{resi}, nil
}
sti, err := ci.Prepare(query)
@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return nil, err
}
defer sti.Close()
resi, err := sti.Exec(args)
resi, err := sti.Exec(sargs)
if err != nil {
return nil, err
}
@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
return nil, err
}
defer sti.Close()
resi, err := sti.Exec(args)
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
resi, err := sti.Exec(sargs)
if err != nil {
return nil, err
}
@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
}
defer releaseConn()
if want := si.NumInput(); len(args) != want {
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
}
@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
if len(args) != si.NumInput() {
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
}
rowsi, err := si.Query(args)
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
rowsi, err := si.Query(sargs)
if err != nil {
s.db.putConn(ci)
return nil, err
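
Alongside the Exec rework above, DB gains a Close method that drains the free-connection list and marks the pool closed, after which conn() refuses to hand out connections. A hedged usage sketch (assumes a driver registered under the name "fakedb", as in the package's own tests):

package main

import (
	"exp/sql" // import path as laid out in this tree
	"fmt"
)

func main() {
	db, err := sql.Open("fakedb", "foo")
	if err != nil {
		panic(err)
	}
	if err := db.Close(); err != nil {
		panic(err)
	}
	_, err = db.Exec("CREATE|t1|name=string")
	fmt.Println(err) // reports that the database is closed
}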

View File

@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
}
}
func closeDB(t *testing.T, db *DB) {
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
}
func TestQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
var name string
var age int
@ -69,6 +77,7 @@ func TestQuery(t *testing.T) {
func TestStatementQueryRow(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
stmt, err := db.Prepare("SELECT|people|age|name=?")
if err != nil {
t.Fatalf("Prepare: %v", err)
@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) {
// just a test of fakedb itself
func TestBogusPreboundParameters(t *testing.T) {
db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
if err == nil {
@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) {
func TestDb(t *testing.T) {
db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {

View File

@ -0,0 +1,88 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
)
// streamDump is used to dump the initial keystream for stream ciphers. It is
// a write-only buffer, and is not intended for reading, so it does not require a mutex.
var streamDump [512]byte
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}
func (c noneCipher) XORKeyStream(dst, src []byte) {
copy(dst, src)
}
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(c, iv), nil
}
func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type cipherMode struct {
keySize int
ivSize int
skip int
createFn func(key, iv []byte) (cipher.Stream, error)
}
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
}
// Specifies a default set of ciphers and a preference order. This is based on
// OpenSSH's default client preference order, minus algorithms that are not
// implemented.
var DefaultCipherOrder = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"arcfour256", "arcfour128",
}
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": &cipherMode{16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": &cipherMode{24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": &cipherMode{32, aes.BlockSize, 0, newAESCTR},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": &cipherMode{16, 0, 1536, newRC4},
"arcfour256": &cipherMode{32, 0, 1536, newRC4},
}
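
createCipher above builds the cipher.Stream for a negotiated algorithm and, for the arcfour modes, discards the first skip bytes of keystream (1536, per RFC 4345) before the stream is used. A short in-package sketch of driving the table directly (the zero key and IV are placeholders only):

package ssh

import "testing"

func TestArcfourKeystreamSkip(t *testing.T) {
	mode := cipherModes["arcfour128"]
	key := make([]byte, mode.keySize) // 16 bytes; all-zero placeholder
	iv := make([]byte, mode.ivSize)   // empty: arcfour takes no IV
	stream, err := mode.createCipher(key, iv)
	if err != nil {
		t.Fatal(err)
	}
	buf := make([]byte, 8)
	stream.XORKeyStream(buf, buf) // keystream has already advanced past the skipped bytes
}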

View File

@ -0,0 +1,62 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"testing"
)
// TestCipherReversal tests that each cipher factory produces ciphers that can
// encrypt and decrypt some data successfully.
func TestCipherReversal(t *testing.T) {
testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
cryptBuffer := make([]byte, 32)
for name, cipherMode := range cipherModes {
encrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create encrypter for %q: %s", name, err)
continue
}
decrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create decrypter for %q: %s", name, err)
continue
}
copy(cryptBuffer, testData)
encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if name == "none" {
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made change with 'none' cipher")
continue
}
} else {
if bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made no change with %q", name)
continue
}
}
decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("decrypted bytes not equal to input with %q", name)
continue
}
}
}
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range DefaultCipherOrder {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}

View File

@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
conn.Close()
return nil, err
}
if err := conn.authenticate(); err != nil {
conn.Close()
return nil, err
}
go conn.mainLoop()
return conn, nil
}
@ -64,8 +60,8 @@ func (c *ClientConn) handshake() error {
clientKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers,
CiphersServerClient: supportedCiphers,
CiphersClientServer: c.config.Crypto.ciphers(),
CiphersServerClient: c.config.Crypto.ciphers(),
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error {
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err
}
return c.authenticate(H)
}
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
@ -195,6 +194,7 @@ func (c *ClientConn) openChan(typ string) (*clientChan, error) {
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New(msg.Message)
@ -301,6 +301,9 @@ type ClientConfig struct {
// A slice of ClientAuth methods. Only the first instance
// of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth
// Cryptographic-related configuration.
Crypto CryptoConfig
}
func (c *ClientConfig) rand() io.Reader {
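
The client now advertises c.config.Crypto.ciphers() during key exchange, falling back to DefaultCipherOrder when the new Crypto field is left zero. A hedged sketch of a client restricted to a single cipher (credentials and address are placeholders; the password helper mirrors the one in the tests):

package main

import "exp/ssh" // import path as laid out in this tree

// password implements ssh.ClientPassword.
type password string

func (p password) Password(user string) (string, error) { return string(p), nil }

func main() {
	config := &ssh.ClientConfig{
		User: "gopher",
		Auth: []ssh.ClientAuth{ssh.ClientAuthPassword(password("secret"))},
		Crypto: ssh.CryptoConfig{
			Ciphers: []string{"aes256-ctr"}, // nil would mean DefaultCipherOrder
		},
	}
	conn, err := ssh.Dial("tcp", "localhost:22", config)
	if err != nil {
		panic(err)
	}
	defer conn.Close()
}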

View File

@ -6,10 +6,11 @@ package ssh
import (
"errors"
"io"
)
// authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate() error {
func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
@ -26,7 +27,7 @@ func (c *ClientConn) authenticate() error {
// then any untried methods suggested by the server.
tried, remain := make(map[string]bool), make(map[string]bool)
for auth := ClientAuth(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.config.User, c.transport)
ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
if err != nil {
return err
}
@ -60,7 +61,7 @@ type ClientAuth interface {
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
// method names is returned.
auth(user string, t *transport) (bool, []string, error)
auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name.
method() string
@ -69,7 +70,7 @@ type ClientAuth interface {
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) {
func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user,
Service: serviceSSH,
@ -102,7 +103,7 @@ type passwordAuth struct {
ClientPassword
}
func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) {
func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
User string
Service string
@ -155,3 +156,140 @@ type ClientPassword interface {
func ClientAuthPassword(impl ClientPassword) ClientAuth {
return &passwordAuth{impl}
}
// ClientKeyring implements access to a client key ring.
type ClientKeyring interface {
// Key returns the i'th rsa.PublicKey or dsa.PublicKey, or nil if
// no key exists at i.
Key(i int) (key interface{}, err error)
// Sign returns a signature of the given data using the i'th key
// and the supplied random source.
Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
}
// "publickey" authentication, RFC 4252 Section 7.
type publickeyAuth struct {
ClientKeyring
}
func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type publickeyAuthMsg struct {
User string
Service string
Method string
// HasSig indicates to the receiver that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
Pubkey string
// Sig is defined as []byte so marshal will exclude it during the query phase
Sig []byte `ssh:"rest"`
}
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
var index int
// a map of public keys to their index in the keyring
validKeys := make(map[int]interface{})
for {
key, err := p.Key(index)
if err != nil {
return false, nil, err
}
if key == nil {
// no more keys in the keyring
break
}
pubkey := serializePublickey(key)
algoname := algoName(key)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: false,
Algoname: algoname,
Pubkey: string(pubkey),
}
if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthPubKeyOk:
msg := decode(packet).(*userAuthPubKeyOkMsg)
if msg.Algo != algoname || msg.PubKey != string(pubkey) {
continue
}
validKeys[index] = key
case msgUserAuthFailure:
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
index++
}
// methods that may continue if this auth is not successful.
var methods []string
for i, key := range validKeys {
pubkey := serializePublickey(key)
algoname := algoName(key)
sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
}, []byte(algoname), pubkey))
if err != nil {
return false, nil, err
}
// manually wrap the serialized signature in a string
s := serializeSignature(algoname, sign)
sig := make([]byte, stringLength(s))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: true,
Algoname: algoname,
Pubkey: string(pubkey),
Sig: sig,
}
p := marshal(msgUserAuthRequest, msg)
if err := t.writePacket(p); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthSuccess:
return true, nil, nil
case msgUserAuthFailure:
msg := decode(packet).(*userAuthFailureMsg)
methods = msg.Methods
continue
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
}
return false, methods, nil
}
func (p *publickeyAuth) method() string {
return "publickey"
}
// ClientAuthPublickey returns a ClientAuth using public key authentication.
func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
return &publickeyAuth{impl}
}

View File

@ -0,0 +1,248 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"io/ioutil"
"testing"
)
const _pem = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
-----END RSA PRIVATE KEY-----`
// reused internally by tests
var serverConfig = new(ServerConfig)
func init() {
if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
panic("unable to set private key: " + err.Error())
}
}
// keychain implements the ClientKeyring interface
type keychain struct {
keys []*rsa.PrivateKey
}
func (k *keychain) Key(i int) (interface{}, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
return k.keys[i].PublicKey, nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(data)
digest := h.Sum()
return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
}
func (k *keychain) loadPEM(file string) error {
buf, err := ioutil.ReadFile(file)
if err != nil {
return err
}
block, _ := pem.Decode(buf)
if block == nil {
return errors.New("ssh: no key found")
}
r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
k.keys = append(k.keys, r)
return nil
}
var pkey *rsa.PrivateKey
func init() {
var err error
pkey, err = rsa.GenerateKey(rand.Reader, 512)
if err != nil {
panic("unable to generate public key")
}
}
func TestClientAuthPublickey(t *testing.T) {
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
serverConfig.PasswordCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool, 1)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Error(err)
}
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
// password implements the ClientPassword interface
type password string
func (p password) Password(user string) (string, error) {
return string(p), nil
}
func TestClientAuthPassword(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
func TestClientAuthPasswordAndPublickey(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
wrongPw := password("wrong")
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(wrongPw),
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}

View File

@ -0,0 +1,61 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// ClientConn functional tests.
// These tests require a running ssh server listening on port 22
// on the local host. Functional tests will be skipped unless
// -ssh.user and -ssh.pass are passed to gotest.
import (
"flag"
"testing"
)
var (
sshuser = flag.String("ssh.user", "", "ssh username")
sshpass = flag.String("ssh.pass", "", "ssh password")
sshprivkey = flag.String("ssh.privkey", "", "ssh privkey file")
)
func TestFuncPasswordAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPassword(password(*sshpass)),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("Unable to connect: %s", err)
}
defer conn.Close()
}
func TestFuncPublickeyAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
kc := new(keychain)
if err := kc.loadPEM(*sshprivkey); err != nil {
t.Fatalf("unable to load private key: %s", err)
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPublickey(kc),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("unable to connect: %s", err)
}
defer conn.Close()
}

View File

@ -5,6 +5,8 @@
package ssh
import (
"crypto/dsa"
"crypto/rsa"
"math/big"
"strconv"
"sync"
@ -14,7 +16,6 @@ import (
const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa"
cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96"
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
@ -23,7 +24,6 @@ const (
var supportedKexAlgos = []string{kexAlgoDH14SHA1}
var supportedHostKeyAlgos = []string{hostAlgoRSA}
var supportedCiphers = []string{cipherAES128CTR}
var supportedMACs = []string{macSHA196}
var supportedCompressions = []string{compressionNone}
@ -127,3 +127,100 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
ok = true
return
}
// Cryptographic configuration common to both ServerConfig and ClientConfig.
type CryptoConfig struct {
// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
// used.
Ciphers []string
}
func (c *CryptoConfig) ciphers() []string {
if c.Ciphers == nil {
return DefaultCipherOrder
}
return c.Ciphers
}
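A brief usage sketch of the Crypto field introduced above; the type and field names come from this diff, while the listening address and the omitted host-key/auth setup are illustrative only:

    config := new(ServerConfig)
    config.Crypto = CryptoConfig{
        // Restrict the server to a single cipher; leaving Ciphers nil
        // falls back to DefaultCipherOrder.
        Ciphers: []string{"aes128-ctr"},
    }
    // ...install a host key and an auth callback as in the tests above...
    listener, err := Listen("tcp", "127.0.0.1:0", config)
    if err != nil {
        // handle error
    }
    defer listener.Close()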
// serialize a signed slice according to RFC 4253 6.6.
func serializeSignature(algoname string, sig []byte) []byte {
length := stringLength([]byte(algoname))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalString(r, sig)
return ret
}
// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
func serializePublickey(key interface{}) []byte {
algoname := algoName(key)
switch key := key.(type) {
case rsa.PublicKey:
e := new(big.Int).SetInt64(int64(key.E))
length := stringLength([]byte(algoname))
length += intLength(e)
length += intLength(key.N)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, e)
marshalInt(r, key.N)
return ret
case dsa.PublicKey:
length := stringLength([]byte(algoname))
length += intLength(key.P)
length += intLength(key.Q)
length += intLength(key.G)
length += intLength(key.Y)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, key.P)
r = marshalInt(r, key.Q)
r = marshalInt(r, key.G)
marshalInt(r, key.Y)
return ret
}
panic("unexpected key type")
}
func algoName(key interface{}) string {
switch key.(type) {
case rsa.PublicKey:
return "ssh-rsa"
case dsa.PublicKey:
return "ssh-dss"
}
panic("unexpected key type")
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}

View File

@ -392,7 +392,10 @@ func parseString(in []byte) (out, rest []byte, ok bool) {
return
}
var comma = []byte{','}
var (
comma = []byte{','}
emptyNameList = []string{}
)
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in)
@ -400,6 +403,7 @@ func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
return
}
if len(contents) == 0 {
out = emptyNameList
return
}
parts := bytes.Split(contents, comma)
@ -444,8 +448,6 @@ func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
return
}
const maxPacketSize = 36000
func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */
for i, name := range namelist {

View File

@ -40,6 +40,9 @@ type ServerConfig struct {
// key authentication. It must return true iff the given public key is
// valid for the given user.
PubKeyCallback func(user, algo string, pubkey []byte) bool
// Cryptographic-related configuration.
Crypto CryptoConfig
}
func (c *ServerConfig) rand() io.Reader {
@ -221,7 +224,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return nil, nil, errors.New("internal error")
}
serializedSig := serializeRSASignature(sig)
serializedSig := serializeSignature(hostAlgoRSA, sig)
kexDHReply := kexDHReplyMsg{
HostKey: serializedHostKey,
@ -234,50 +237,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return
}
func serializeRSASignature(sig []byte) []byte {
length := stringLength([]byte(hostAlgoRSA))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(hostAlgoRSA))
r = marshalString(r, sig)
return ret
}
// serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n")
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var magics handshakeMagics
@ -298,8 +260,8 @@ func (s *ServerConn) Handshake() error {
serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers,
CiphersServerClient: supportedCiphers,
CiphersClientServer: s.config.Crypto.ciphers(),
CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
@ -364,7 +326,9 @@ func (s *ServerConn) Handshake() error {
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
return err
}
if packet, err = s.readPacket(); err != nil {
return err
}

146
libgo/go/exp/ssh/tcpip.go Normal file
View File

@ -0,0 +1,146 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
)
// Dial initiates a connection to the addr from the remote host.
// addr is resolved using net.ResolveTCPAddr before connection.
// This could allow an observer to learn the DNS name of the
// remote host. Consider using ssh.DialTCP to avoid this.
func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
raddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.DialTCP(n, nil, raddr)
}
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
if laddr == nil {
laddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
if err != nil {
return nil, err
}
return &tcpchanconn{
tcpchan: ch,
laddr: laddr,
raddr: raddr,
}, nil
}
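A hedged sketch of using the new Dial/DialTCP methods, assuming conn is an already-handshaken *ClientConn as in the client tests earlier in this change; the target address is illustrative:

    // Open a direct-tcpip channel to a service reachable from the remote host.
    remote, err := conn.Dial("tcp", "127.0.0.1:8080")
    if err != nil {
        // handle error
    }
    defer remote.Close()
    // remote satisfies net.Conn and can be handed to any code expecting one.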
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
// strings and are expected to be resolvable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
// RFC 4254 7.2
type channelOpenDirectMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
raddr string
rport uint32
laddr string
lport uint32
}
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
PeersId: ch.id,
PeersWindow: 1 << 14,
MaxPacketSize: 1 << 15, // RFC 4253 6.1
raddr: raddr,
rport: uint32(rport),
laddr: laddr,
lport: uint32(lport),
})); err != nil {
c.chanlist.remove(ch.id)
return nil, err
}
// wait for response
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
default:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: unexpected packet")
}
return &tcpchan{
clientChan: ch,
Reader: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.data,
},
Writer: &chanWriter{
packetWriter: ch,
id: ch.id,
win: ch.win,
},
}, nil
}
type tcpchan struct {
*clientChan // the backing channel
io.Reader
io.Writer
}
// tcpchanconn fulfills the net.Conn interface without
// the tcpchan having to hold laddr or raddr directly.
type tcpchanconn struct {
*tcpchan
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpchanconn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpchanconn) RemoteAddr() net.Addr {
return t.raddr
}
// SetTimeout sets the read and write deadlines associated
// with the connection.
func (t *tcpchanconn) SetTimeout(nsec int64) error {
if err := t.SetReadTimeout(nsec); err != nil {
return err
}
return t.SetWriteTimeout(nsec)
}
// SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
func (t *tcpchanconn) SetReadTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}
// SetWriteTimeout sets the time (in nanoseconds) that
// Write will wait to send its data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
func (t *tcpchanconn) SetWriteTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}

View File

@ -7,7 +7,6 @@ package ssh
import (
"bufio"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/subtle"
@ -19,7 +18,10 @@ import (
)
const (
paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
minPacketSize = 16
maxPacketSize = 36000
minPaddingSize = 4 // TODO(huin) should this be configurable?
)
// filteredConn reduces the set of methods exposed when embedding
@ -61,8 +63,7 @@ type reader struct {
type writer struct {
*sync.Mutex // protects writer.Writer from concurrent writes
*bufio.Writer
paddingMultiple int
rand io.Reader
rand io.Reader
common
}
@ -82,14 +83,11 @@ type common struct {
func (r *reader) readOnePacket() ([]byte, error) {
var lengthBytes = make([]byte, 5)
var macSize uint32
if _, err := io.ReadFull(r, lengthBytes); err != nil {
return nil, err
}
if r.cipher != nil {
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
}
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
if r.mac != nil {
r.mac.Reset()
@ -153,9 +151,9 @@ func (w *writer) writePacket(packet []byte) error {
w.Mutex.Lock()
defer w.Mutex.Unlock()
paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
if paddingLength < 4 {
paddingLength += paddingMultiple
paddingLength += packetSizeMultiple
}
length := len(packet) + 1 + paddingLength
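As a worked check of the padding arithmetic above: for a 10-byte payload, 5+10 = 15, so paddingLength = 16 - 15%16 = 1; because 1 < 4 it becomes 17, and the packet on the wire is 4 (length field) + 1 (padding-length byte) + 10 (payload) + 17 (padding) = 32 bytes, a multiple of packetSizeMultiple.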
@ -188,11 +186,9 @@ func (w *writer) writePacket(packet []byte) error {
// TODO(dfc) lengthBytes, packet and padding should be
// subslices of a single buffer
if w.cipher != nil {
w.cipher.XORKeyStream(lengthBytes, lengthBytes)
w.cipher.XORKeyStream(packet, packet)
w.cipher.XORKeyStream(padding, padding)
}
w.cipher.XORKeyStream(lengthBytes, lengthBytes)
w.cipher.XORKeyStream(packet, packet)
w.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(lengthBytes); err != nil {
return err
@ -227,11 +223,17 @@ func newTransport(conn net.Conn, rand io.Reader) *transport {
return &transport{
reader: reader{
Reader: bufio.NewReader(conn),
common: common{
cipher: noneCipher{},
},
},
writer: writer{
Writer: bufio.NewWriter(conn),
rand: rand,
Mutex: new(sync.Mutex),
common: common{
cipher: noneCipher{},
},
},
filteredConn: conn,
}
@ -249,29 +251,32 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// setupKeys sets the cipher and MAC keys from K, H and sessionId, as
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
h := hashFunc.New()
cipherMode := cipherModes[c.cipherAlgo]
blockSize := 16
keySize := 16
macKeySize := 20
iv := make([]byte, blockSize)
key := make([]byte, keySize)
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macKeySize)
h := hashFunc.New()
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
cipher, err := cipherMode.createCipher(key, iv)
if err != nil {
return err
}
c.cipher = cipher.NewCTR(aes, iv)
c.cipher = cipher
return nil
}
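For reference, the derivation that setupKeys and generateKeyMaterial implement is the one in RFC 4253, section 6.4: each quantity is K1 = HASH(K || H || X || session_id), where X is the single-byte tag taken from the direction value ('A' for the client-to-server IV, 'C' for its encryption key, 'E' for its MAC key, per clientKeys above; 'B', 'D', 'F' for the server-to-client direction), and when more bytes are needed the material is extended with K2 = HASH(K || H || K1), K3 = HASH(K || H || K1 || K2), and so on.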

View File

@ -1,356 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import "io"
// Shell contains the state for running a VT100 terminal that is capable of
// reading lines of input.
type Shell struct {
c io.ReadWriter
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewShell runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewShell(c io.ReadWriter, prompt string) *Shell {
return &Shell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of ss.outBuf
func (ss *Shell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *Shell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *Shell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *Shell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
func (ss *Shell) Write(buf []byte) (n int, err error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *Shell) ReadLine() (line string, err error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
var n int
n, err = ss.c.Read(readBuf)
if err != nil {
return
}
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
}
panic("unreachable")
}

View File

@ -2,102 +2,361 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err.String())
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"os"
"syscall"
import "io"
// Terminal contains the state for running a VT100 terminal that is capable of
// reading lines of input.
type Terminal struct {
c io.ReadWriter
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
return &Terminal{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
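A minimal usage sketch of the renamed Terminal type; channel stands for any io.ReadWriter already in raw mode (for example an SSH session), and the prompt is illustrative:

    term := NewTerminal(channel, "> ")
    term.SetSize(80, 24) // optional; these match the defaults
    for {
        line, err := term.ReadLine()
        if err != nil {
            break // io.EOF is returned on Ctrl-D
        }
        _ = line // act on the entered line
    }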
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
e := syscall.Tcgetattr(fd, &termios)
return e == 0
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if e := syscall.Tcgetattr(fd, &oldState.termios); e != 0 {
return nil, os.Errno(e)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
return nil, os.Errno(e)
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
e := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
return os.Errno(e)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if e := syscall.Tcgetattr(fd, &oldState); e != 0 {
return nil, os.Errno(e)
}
newState := oldState
newState.Lflag &^= syscall.ECHO
if e := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); e != 0 {
return nil, os.Errno(e)
}
defer func() {
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
}()
var buf [16]byte
var ret []byte
for {
n, errno := syscall.Read(fd, buf[:])
if errno != 0 {
return nil, os.Errno(errno)
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of t.outBuf
func (t *Terminal) queue(data []byte) {
if len(t.outBuf)+len(data) > cap(t.outBuf) {
newOutBuf := make([]byte, len(t.outBuf), 2*(len(t.outBuf)+len(data)))
copy(newOutBuf, t.outBuf)
t.outBuf = newOutBuf
}
oldLen := len(t.outBuf)
t.outBuf = t.outBuf[:len(t.outBuf)+len(data)]
copy(t.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to t.outBuf which will move the cursor to the
// given, logical position in the text.
func (t *Terminal) moveCursorToPos(pos int) {
x := len(t.prompt) + pos
y := x / t.termWidth
x = x % t.termWidth
up := 0
if y < t.cursorY {
up = t.cursorY - y
}
down := 0
if y > t.cursorY {
down = y - t.cursorY
}
left := 0
if x < t.cursorX {
left = t.cursorX - x
}
right := 0
if x > t.cursorX {
right = x - t.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
t.cursorX = x
t.cursorY = y
t.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (t *Terminal) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if t.pos == 0 {
return
}
t.pos--
copy(t.line[t.pos:], t.line[1+t.pos:])
t.line = t.line[:len(t.line)-1]
t.writeLine(t.line[t.pos:])
t.moveCursorToPos(t.pos)
t.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if t.pos == 0 {
return
}
t.pos--
for t.pos > 0 {
if t.line[t.pos] != ' ' {
break
}
break
t.pos--
}
if buf[n-1] == '\n' {
n--
for t.pos > 0 {
if t.line[t.pos] == ' ' {
t.pos++
break
}
t.pos--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
t.moveCursorToPos(t.pos)
case keyAltRight:
// move right by a word.
for t.pos < len(t.line) {
if t.line[t.pos] == ' ' {
break
}
t.pos++
}
for t.pos < len(t.line) {
if t.line[t.pos] != ' ' {
break
}
t.pos++
}
t.moveCursorToPos(t.pos)
case keyLeft:
if t.pos == 0 {
return
}
t.pos--
t.moveCursorToPos(t.pos)
case keyRight:
if t.pos == len(t.line) {
return
}
t.pos++
t.moveCursorToPos(t.pos)
case keyEnter:
t.moveCursorToPos(len(t.line))
t.queue([]byte("\r\n"))
line = string(t.line)
ok = true
t.line = t.line[:0]
t.pos = 0
t.cursorX = 0
t.cursorY = 0
t.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(t.line) == maxLineLength {
return
}
if len(t.line) == cap(t.line) {
newLine := make([]byte, len(t.line), 2*(1+len(t.line)))
copy(newLine, t.line)
t.line = newLine
}
t.line = t.line[:len(t.line)+1]
copy(t.line[t.pos+1:], t.line[t.pos:])
t.line[t.pos] = byte(key)
t.writeLine(t.line[t.pos:])
t.pos++
t.moveCursorToPos(t.pos)
}
return
}
func (t *Terminal) writeLine(line []byte) {
for len(line) != 0 {
if t.cursorX == t.termWidth {
t.queue([]byte("\r\n"))
t.cursorX = 0
t.cursorY++
if t.cursorY > t.maxLine {
t.maxLine = t.cursorY
}
}
remainingOnLine := t.termWidth - t.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
t.queue(line[:todo])
t.cursorX += todo
line = line[todo:]
}
}
func (t *Terminal) Write(buf []byte) (n int, err error) {
return t.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (t *Terminal) ReadLine() (line string, err error) {
if t.cursorX == 0 {
t.writeLine([]byte(t.prompt))
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
}
return ret, nil
for {
// t.remainder is a slice at the beginning of t.inBuf
// containing a partial key sequence
readBuf := t.inBuf[len(t.remainder):]
var n int
n, err = t.c.Read(readBuf)
if err != nil {
return
}
if err == nil {
t.remainder = t.inBuf[:n+len(t.remainder)]
rest := t.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = t.handleKey(key)
}
if len(rest) > 0 {
n := copy(t.inBuf[:], rest)
t.remainder = t.inBuf[:n]
} else {
t.remainder = nil
}
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
return
}
continue
}
}
panic("unreachable")
}
func (t *Terminal) SetSize(width, height int) {
t.termWidth, t.termHeight = width, height
}

View File

@ -41,7 +41,7 @@ func (c *MockTerminal) Write(data []byte) (n int, err error) {
func TestClose(t *testing.T) {
c := &MockTerminal{}
ss := NewShell(c, "> ")
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
@ -95,7 +95,7 @@ func TestKeyPresses(t *testing.T) {
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewShell(c, "> ")
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)

View File

@ -0,0 +1,102 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err.String())
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"syscall"
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
err := syscall.Tcgetattr(fd, &termios)
return err == nil
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if err := syscall.Tcgetattr(fd, &oldState.termios); err != nil {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
err := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
return err
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if err := syscall.Tcgetattr(fd, &oldState); err != nil {
return nil, err
}
newState := oldState
newState.Lflag &^= syscall.ECHO
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
defer func() {
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
}()
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}
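A minimal sketch of calling the helper above from client code, assuming the package is imported as terminal and file descriptor 0 is a local terminal:

    fmt.Print("Password: ")
    pw, err := terminal.ReadPassword(0)
    if err != nil {
        // handle error
    }
    fmt.Println()
    // pw holds the password without the trailing '\n'.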

View File

@ -357,6 +357,10 @@ var fmttests = []struct {
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
{"%#v", []int(nil), `[]int(nil)`},
{"%#v", []int{}, `[]int{}`},
{"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
{"%#v", map[int]byte{}, `map[int] uint8{}`},
// slices with other formats
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},

View File

@ -795,6 +795,10 @@ BigSwitch:
case reflect.Map:
if goSyntax {
p.buf.WriteString(f.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{')
} else {
p.buf.Write(mapBytes)
@ -873,6 +877,10 @@ BigSwitch:
}
if goSyntax {
p.buf.WriteString(value.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{')
} else {
p.buf.WriteByte('[')

View File

@ -324,7 +324,7 @@ var x, y Xs
var z IntString
var multiTests = []ScanfMultiTest{
{"", "", nil, nil, ""},
{"", "", []interface{}{}, []interface{}{}, ""},
{"%d", "23", args(&i), args(23), ""},
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
@ -378,7 +378,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}
}
val := v.Interface()
if !reflect.DeepEqual(val, test.out) {
t.Errorf("%s scanning %q: expected %v got %v, type %T", name, test.text, test.out, val, val)
t.Errorf("%s scanning %q: expected %#v got %#v, type %T", name, test.text, test.out, val, val)
}
}
}
@ -417,7 +417,7 @@ func TestScanf(t *testing.T) {
}
val := v.Interface()
if !reflect.DeepEqual(val, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v, type %T", test.format, test.text, test.out, val, val)
t.Errorf("scanning (%q, %q): expected %#v got %#v, type %T", test.format, test.text, test.out, val, val)
}
}
}
@ -520,7 +520,7 @@ func testScanfMulti(name string, t *testing.T) {
}
result := resultVal.Interface()
if !reflect.DeepEqual(result, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v", test.format, test.text, test.out, result)
t.Errorf("scanning (%q, %q): expected %#v got %#v", test.format, test.text, test.out, result)
}
}
}

View File

@ -412,29 +412,29 @@ func (x *ChanType) End() token.Pos { return x.Value.End() }
// exprNode() ensures that only expression/type nodes can be
// assigned to an ExprNode.
//
func (x *BadExpr) exprNode() {}
func (x *Ident) exprNode() {}
func (x *Ellipsis) exprNode() {}
func (x *BasicLit) exprNode() {}
func (x *FuncLit) exprNode() {}
func (x *CompositeLit) exprNode() {}
func (x *ParenExpr) exprNode() {}
func (x *SelectorExpr) exprNode() {}
func (x *IndexExpr) exprNode() {}
func (x *SliceExpr) exprNode() {}
func (x *TypeAssertExpr) exprNode() {}
func (x *CallExpr) exprNode() {}
func (x *StarExpr) exprNode() {}
func (x *UnaryExpr) exprNode() {}
func (x *BinaryExpr) exprNode() {}
func (x *KeyValueExpr) exprNode() {}
func (*BadExpr) exprNode() {}
func (*Ident) exprNode() {}
func (*Ellipsis) exprNode() {}
func (*BasicLit) exprNode() {}
func (*FuncLit) exprNode() {}
func (*CompositeLit) exprNode() {}
func (*ParenExpr) exprNode() {}
func (*SelectorExpr) exprNode() {}
func (*IndexExpr) exprNode() {}
func (*SliceExpr) exprNode() {}
func (*TypeAssertExpr) exprNode() {}
func (*CallExpr) exprNode() {}
func (*StarExpr) exprNode() {}
func (*UnaryExpr) exprNode() {}
func (*BinaryExpr) exprNode() {}
func (*KeyValueExpr) exprNode() {}
func (x *ArrayType) exprNode() {}
func (x *StructType) exprNode() {}
func (x *FuncType) exprNode() {}
func (x *InterfaceType) exprNode() {}
func (x *MapType) exprNode() {}
func (x *ChanType) exprNode() {}
func (*ArrayType) exprNode() {}
func (*StructType) exprNode() {}
func (*FuncType) exprNode() {}
func (*InterfaceType) exprNode() {}
func (*MapType) exprNode() {}
func (*ChanType) exprNode() {}
// ----------------------------------------------------------------------------
// Convenience functions for Idents
@ -711,27 +711,27 @@ func (s *RangeStmt) End() token.Pos { return s.Body.End() }
// stmtNode() ensures that only statement nodes can be
// assigned to a StmtNode.
//
func (s *BadStmt) stmtNode() {}
func (s *DeclStmt) stmtNode() {}
func (s *EmptyStmt) stmtNode() {}
func (s *LabeledStmt) stmtNode() {}
func (s *ExprStmt) stmtNode() {}
func (s *SendStmt) stmtNode() {}
func (s *IncDecStmt) stmtNode() {}
func (s *AssignStmt) stmtNode() {}
func (s *GoStmt) stmtNode() {}
func (s *DeferStmt) stmtNode() {}
func (s *ReturnStmt) stmtNode() {}
func (s *BranchStmt) stmtNode() {}
func (s *BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {}
func (s *ForStmt) stmtNode() {}
func (s *RangeStmt) stmtNode() {}
func (*BadStmt) stmtNode() {}
func (*DeclStmt) stmtNode() {}
func (*EmptyStmt) stmtNode() {}
func (*LabeledStmt) stmtNode() {}
func (*ExprStmt) stmtNode() {}
func (*SendStmt) stmtNode() {}
func (*IncDecStmt) stmtNode() {}
func (*AssignStmt) stmtNode() {}
func (*GoStmt) stmtNode() {}
func (*DeferStmt) stmtNode() {}
func (*ReturnStmt) stmtNode() {}
func (*BranchStmt) stmtNode() {}
func (*BlockStmt) stmtNode() {}
func (*IfStmt) stmtNode() {}
func (*CaseClause) stmtNode() {}
func (*SwitchStmt) stmtNode() {}
func (*TypeSwitchStmt) stmtNode() {}
func (*CommClause) stmtNode() {}
func (*SelectStmt) stmtNode() {}
func (*ForStmt) stmtNode() {}
func (*RangeStmt) stmtNode() {}
// ----------------------------------------------------------------------------
// Declarations
@ -807,9 +807,9 @@ func (s *TypeSpec) End() token.Pos { return s.Type.End() }
// specNode() ensures that only spec nodes can be
// assigned to a Spec.
//
func (s *ImportSpec) specNode() {}
func (s *ValueSpec) specNode() {}
func (s *TypeSpec) specNode() {}
func (*ImportSpec) specNode() {}
func (*ValueSpec) specNode() {}
func (*TypeSpec) specNode() {}
// A declaration is represented by one of the following declaration nodes.
//
@ -875,9 +875,9 @@ func (d *FuncDecl) End() token.Pos {
// declNode() ensures that only declaration nodes can be
// assigned to a DeclNode.
//
func (d *BadDecl) declNode() {}
func (d *GenDecl) declNode() {}
func (d *FuncDecl) declNode() {}
func (*BadDecl) declNode() {}
func (*GenDecl) declNode() {}
func (*FuncDecl) declNode() {}
// ----------------------------------------------------------------------------
// Files and packages

View File

@ -24,7 +24,7 @@ func exportFilter(name string) bool {
// it returns false otherwise.
//
func FileExports(src *File) bool {
return FilterFile(src, exportFilter)
return filterFile(src, exportFilter, true)
}
// PackageExports trims the AST for a Go package in place such that
@ -35,7 +35,7 @@ func FileExports(src *File) bool {
// it returns false otherwise.
//
func PackageExports(pkg *Package) bool {
return FilterPackage(pkg, exportFilter)
return filterPackage(pkg, exportFilter, true)
}
// ----------------------------------------------------------------------------
@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident {
return nil
}
func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil {
return false
}
@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
keepField = len(f.Names) > 0
}
if keepField {
if filter == exportFilter {
filterType(f.Type, filter)
if export {
filterType(f.Type, filter, export)
}
list[j] = f
j++
@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
return
}
func filterParamList(fields *FieldList, filter Filter) bool {
func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil {
return false
}
var b bool
for _, f := range fields.List {
if filterType(f.Type, filter) {
if filterType(f.Type, filter, export) {
b = true
}
}
return b
}
func filterType(typ Expr, f Filter) bool {
func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) {
case *Ident:
return f(t.Name)
case *ParenExpr:
return filterType(t.X, f)
return filterType(t.X, f, export)
case *ArrayType:
return filterType(t.Elt, f)
return filterType(t.Elt, f, export)
case *StructType:
if filterFieldList(t.Fields, f) {
if filterFieldList(t.Fields, f, export) {
t.Incomplete = true
}
return len(t.Fields.List) > 0
case *FuncType:
b1 := filterParamList(t.Params, f)
b2 := filterParamList(t.Results, f)
b1 := filterParamList(t.Params, f, export)
b2 := filterParamList(t.Results, f, export)
return b1 || b2
case *InterfaceType:
if filterFieldList(t.Methods, f) {
if filterFieldList(t.Methods, f, export) {
t.Incomplete = true
}
return len(t.Methods.List) > 0
case *MapType:
b1 := filterType(t.Key, f)
b2 := filterType(t.Value, f)
b1 := filterType(t.Key, f, export)
b2 := filterType(t.Value, f, export)
return b1 || b2
case *ChanType:
return filterType(t.Value, f)
return filterType(t.Value, f, export)
}
return false
}
func filterSpec(spec Spec, f Filter) bool {
func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) {
case *ValueSpec:
s.Names = filterIdentList(s.Names, f)
if len(s.Names) > 0 {
if f == exportFilter {
filterType(s.Type, f)
if export {
filterType(s.Type, f, export)
}
return true
}
case *TypeSpec:
if f(s.Name.Name) {
if f == exportFilter {
filterType(s.Type, f)
if export {
filterType(s.Type, f, export)
}
return true
}
if f != exportFilter {
if !export {
// For general filtering (not just exports),
// filter type even if name is not filtered
// out.
// If the type contains filtered elements,
// keep the declaration.
return filterType(s.Type, f)
return filterType(s.Type, f, export)
}
}
return false
}
func filterSpecList(list []Spec, f Filter) []Spec {
func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0
for _, s := range list {
if filterSpec(s, f) {
if filterSpec(s, f, export) {
list[j] = s
j++
}
@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec {
// filtering; it returns false otherwise.
//
func FilterDecl(decl Decl, f Filter) bool {
return filterDecl(decl, f, false)
}
func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) {
case *GenDecl:
d.Specs = filterSpecList(d.Specs, f)
d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0
case *FuncDecl:
return f(d.Name.Name)
@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool {
// left after filtering; it returns false otherwise.
//
func FilterFile(src *File, f Filter) bool {
return filterFile(src, f, false)
}
func filterFile(src *File, f Filter, export bool) bool {
j := 0
for _, d := range src.Decls {
if FilterDecl(d, f) {
if filterDecl(d, f, export) {
src.Decls[j] = d
j++
}
@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool {
// left after filtering; it returns false otherwise.
//
func FilterPackage(pkg *Package, f Filter) bool {
return filterPackage(pkg, f, false)
}
func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false
for _, src := range pkg.Files {
if FilterFile(src, f) {
if filterFile(src, f, export) {
hasDecls = true
}
}

View File

@ -37,18 +37,20 @@ var buildPkgs = []struct {
{
"go/build/cmdtest",
&DirInfo{
GoFiles: []string{"main.go"},
Package: "main",
Imports: []string{"go/build/pkgtest"},
GoFiles: []string{"main.go"},
Package: "main",
Imports: []string{"go/build/pkgtest"},
TestImports: []string{},
},
},
{
"go/build/cgotest",
&DirInfo{
CgoFiles: []string{"cgotest.go"},
CFiles: []string{"cgotest.c"},
Imports: []string{"C", "unsafe"},
Package: "cgotest",
CgoFiles: []string{"cgotest.go"},
CFiles: []string{"cgotest.c"},
Imports: []string{"C", "unsafe"},
TestImports: []string{},
Package: "cgotest",
},
},
}

View File

@ -13,6 +13,8 @@ import (
"io"
"os"
"path/filepath"
"strconv"
"strings"
"text/tabwriter"
)
@ -244,6 +246,8 @@ func (p *printer) writeItem(pos token.Position, data string) {
p.last = p.pos
}
const linePrefix = "//line "
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
@ -252,7 +256,7 @@ func (p *printer) writeItem(pos token.Position, data string) {
// a group of comments (or nil), and isKeyword indicates if the
// next item is a keyword.
//
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, isKeyword bool) {
func (p *printer) writeCommentPrefix(pos, next token.Position, prev, comment *ast.Comment, isKeyword bool) {
if p.written == 0 {
// the comment is the first item to be printed - don't write any whitespace
return
@ -337,6 +341,13 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
}
p.writeWhitespace(j)
}
// turn off indent if we're about to print a line directive.
indent := p.indent
if strings.HasPrefix(comment.Text, linePrefix) {
p.indent = 0
}
// use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate
// individual lines of /*-style comments - but make
@ -347,6 +358,7 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
n = 1
}
p.writeNewlines(n, true)
p.indent = indent
}
}
@ -526,6 +538,26 @@ func stripCommonPrefix(lines [][]byte) {
func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text
if strings.HasPrefix(text, linePrefix) {
pos := strings.TrimSpace(text[len(linePrefix):])
i := strings.LastIndex(pos, ":")
if i >= 0 {
// The line directive we are about to print changed
// the Filename and Line number used by go/token
// as it was reading the input originally.
// In order to match the original input, we have to
// update our own idea of the file and line number
// accordingly, after printing the directive.
file := pos[:i]
line, _ := strconv.Atoi(string(pos[i+1:]))
defer func() {
p.pos.Filename = string(file)
p.pos.Line = line
p.pos.Column = 1
}()
}
}
// shortcut common case of //-style comments
if text[1] == '/' {
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
@ -599,7 +631,7 @@ func (p *printer) intersperseComments(next token.Position, tok token.Token) (dro
var last *ast.Comment
for ; p.commentBefore(next); p.cindex++ {
for _, c := range p.comments[p.cindex].List {
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, tok.IsKeyword())
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, c, tok.IsKeyword())
p.writeComment(c)
last = c
}

View File

@ -37,7 +37,7 @@ lower-cased, and attributes are collected into a []Attribute. For example:
for {
if z.Next() == html.ErrorToken {
// Returning io.EOF indicates success.
return z.Error()
return z.Err()
}
emitToken(z.Token())
}
@ -51,7 +51,7 @@ call to Next. For example, to extract an HTML page's anchor text:
tt := z.Next()
switch tt {
case ErrorToken:
return z.Error()
return z.Err()
case TextToken:
if depth > 0 {
// emitBytes should copy the []byte it receives,

File diff suppressed because it is too large

View File

@ -133,8 +133,8 @@ func TestParser(t *testing.T) {
n int
}{
// TODO(nigeltao): Process all the test cases from all the .dat files.
{"tests1.dat", 92},
{"tests2.dat", 0},
{"tests1.dat", -1},
{"tests2.dat", 43},
{"tests3.dat", 0},
}
for _, tf := range testFiles {
@ -213,4 +213,8 @@ var renderTestBlacklist = map[string]bool{
// More cases of <a> being reparented:
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
`<a><table><a></table><p><a><div><a>`: true,
`<a><table><td><a><table></table><a></tr><a></table><a>`: true,
// A <plaintext> element is reparented, putting it before a table.
// A <plaintext> element can't have anything after it in HTML.
`<table><plaintext><td>`: true,
}

View File

@ -52,7 +52,19 @@ func Render(w io.Writer, n *Node) error {
return buf.Flush()
}
// plaintextAbort is returned from render1 when a <plaintext> element
// has been rendered. No more end tags should be rendered after that.
var plaintextAbort = errors.New("html: internal error (plaintext abort)")
func render(w writer, n *Node) error {
err := render1(w, n)
if err == plaintextAbort {
err = nil
}
return err
}
func render1(w writer, n *Node) error {
// Render non-element nodes; these are the easy cases.
switch n.Type {
case ErrorNode:
@ -61,7 +73,7 @@ func render(w writer, n *Node) error {
return escape(w, n.Data)
case DocumentNode:
for _, c := range n.Child {
if err := render(w, c); err != nil {
if err := render1(w, c); err != nil {
return err
}
}
@ -128,7 +140,7 @@ func render(w writer, n *Node) error {
// Render any child nodes.
switch n.Data {
case "noembed", "noframes", "noscript", "script", "style":
case "noembed", "noframes", "noscript", "plaintext", "script", "style":
for _, c := range n.Child {
if c.Type != TextNode {
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
@ -137,18 +149,23 @@ func render(w writer, n *Node) error {
return err
}
}
if n.Data == "plaintext" {
// Don't render anything else. <plaintext> must be the
// last element in the file, with no closing tag.
return plaintextAbort
}
case "textarea", "title":
for _, c := range n.Child {
if c.Type != TextNode {
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
}
if err := render(w, c); err != nil {
if err := render1(w, c); err != nil {
return err
}
}
default:
for _, c := range n.Child {
if err := render(w, c); err != nil {
if err := render1(w, c); err != nil {
return err
}
}

View File

@ -6,6 +6,7 @@ package template
import (
"fmt"
"reflect"
)
// Strings of content from a trusted source.
@ -70,10 +71,25 @@ const (
contentTypeUnsafe
)
// indirect returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil).
func indirect(a interface{}) interface{} {
if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
// Avoid creating a reflect.Value if it's not a pointer.
return a
}
v := reflect.ValueOf(a)
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
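A small illustration of the dereferencing behaviour of indirect (conceptual only, since indirect is unexported and runs inside the package):

    s := "hello"
    p := &s
    pp := &p
    v := indirect(pp) // follows both pointers; v is the string "hello"
    _ = v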
// stringify converts its arguments to a string and the type of the content.
// All pointers are dereferenced, as in the text/template package.
func stringify(args ...interface{}) (string, contentType) {
if len(args) == 1 {
switch s := args[0].(type) {
switch s := indirect(args[0]).(type) {
case string:
return s, contentTypePlain
case CSS:
@ -90,5 +106,8 @@ func stringify(args ...interface{}) (string, contentType) {
return string(s), contentTypeURL
}
}
for i, arg := range args {
args[i] = indirect(arg)
}
return fmt.Sprint(args...), contentTypePlain
}

View File

@ -28,7 +28,7 @@ func (x *goodMarshaler) MarshalJSON() ([]byte, error) {
}
func TestEscape(t *testing.T) {
var data = struct {
data := struct {
F, T bool
C, G, H string
A, E []string
@ -50,6 +50,7 @@ func TestEscape(t *testing.T) {
Z: nil,
W: HTML(`&iexcl;<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
}
pdata := &data
tests := []struct {
name string
@ -668,6 +669,15 @@ func TestEscape(t *testing.T) {
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue
}
b.Reset()
if err := tmpl.Execute(b, pdata); err != nil {
t.Errorf("%s: template execution failed for pointer: %s", test.name, err)
continue
}
if w, g := test.output, b.String(); w != g {
t.Errorf("%s: escaped output for pointer: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue
}
}
}
@ -1605,6 +1615,29 @@ func TestRedundantFuncs(t *testing.T) {
}
}
func TestIndirectPrint(t *testing.T) {
a := 3
ap := &a
b := "hello"
bp := &b
bpp := &bp
tmpl := Must(New("t").Parse(`{{.}}`))
var buf bytes.Buffer
err := tmpl.Execute(&buf, ap)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "3" {
t.Errorf(`Expected "3"; got %q`, buf.String())
}
buf.Reset()
err = tmpl.Execute(&buf, bpp)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "hello" {
t.Errorf(`Expected "hello"; got %q`, buf.String())
}
}
func BenchmarkEscapedExecute(b *testing.B) {
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
var buf bytes.Buffer

View File

@ -8,6 +8,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strings"
"unicode/utf8"
)
@ -117,12 +118,24 @@ var regexpPrecederKeywords = map[string]bool{
"void": true,
}
var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
// indirectToJSONMarshaler returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of json.Marshaler.
func indirectToJSONMarshaler(a interface{}) interface{} {
v := reflect.ValueOf(a)
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
// nether side-effects nor free variables outside (NaN, Infinity).
// neither side-effects nor free variables outside (NaN, Infinity).
func jsValEscaper(args ...interface{}) string {
var a interface{}
if len(args) == 1 {
a = args[0]
a = indirectToJSONMarshaler(args[0])
switch t := a.(type) {
case JS:
return string(t)
@ -135,6 +148,9 @@ func jsValEscaper(args ...interface{}) string {
a = t.String()
}
} else {
for i, arg := range args {
args[i] = indirectToJSONMarshaler(arg)
}
a = fmt.Sprint(args...)
}
// TODO: detect cycles before calling Marshal which loops infinitely on

View File

@ -401,14 +401,14 @@ func (z *Tokenizer) readStartTag() TokenType {
break
}
}
// Any "<noembed>", "<noframes>", "<noscript>", "<script>", "<style>",
// Any "<noembed>", "<noframes>", "<noscript>", "<plaintext", "<script>", "<style>",
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
// The tag name lengths of these special cases ranges in [5, 8].
if x := z.data.end - z.data.start; 5 <= x && x <= 8 {
// The tag name lengths of these special cases range in [5, 9].
if x := z.data.end - z.data.start; 5 <= x && x <= 9 {
switch z.buf[z.data.start] {
case 'n', 's', 't', 'N', 'S', 'T':
case 'n', 'p', 's', 't', 'N', 'P', 'S', 'T':
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
case "noembed", "noframes", "noscript", "script", "style", "textarea", "title":
case "noembed", "noframes", "noscript", "plaintext", "script", "style", "textarea", "title":
z.rawTag = s
}
}
@ -551,9 +551,19 @@ func (z *Tokenizer) Next() TokenType {
z.data.start = z.raw.end
z.data.end = z.raw.end
if z.rawTag != "" {
z.readRawOrRCDATA()
z.tt = TextToken
return z.tt
if z.rawTag == "plaintext" {
// Read everything up to EOF.
for z.err == nil {
z.readByte()
}
z.textIsRaw = true
} else {
z.readRawOrRCDATA()
}
if z.data.end > z.data.start {
z.tt = TextToken
return z.tt
}
}
z.textIsRaw = false

View File

@ -4,10 +4,7 @@
package tiff
import (
"io"
"os"
)
import "io"
// buffer buffers an io.Reader to satisfy io.ReaderAt.
type buffer struct {
@ -19,7 +16,7 @@ func (b *buffer) ReadAt(p []byte, off int64) (int, error) {
o := int(off)
end := o + len(p)
if int64(end) != off+int64(len(p)) {
return 0, os.EINVAL
return 0, io.ErrUnexpectedEOF
}
m := len(b.buf)

View File

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strconv"
"time"
)
// Random number state, accessed without lock; racy but harmless.
@ -17,8 +18,7 @@ import (
var rand uint32
func reseed() uint32 {
sec, nsec, _ := os.Time()
return uint32(sec*1e9 + nsec + int64(os.Getpid()))
return uint32(time.Nanoseconds() + int64(os.Getpid()))
}
func nextSuffix() string {

View File

@ -8,6 +8,7 @@
package syslog
import (
"errors"
"fmt"
"log"
"net"
@ -75,7 +76,7 @@ func Dial(network, raddr string, priority Priority, prefix string) (w *Writer, e
// Write sends a log message to the syslog daemon.
func (w *Writer) Write(b []byte) (int, error) {
if w.priority > LOG_DEBUG || w.priority < LOG_EMERG {
return 0, os.EINVAL
return 0, errors.New("log/syslog: invalid priority")
}
return w.conn.writeBytes(w.priority, w.prefix, b)
}

View File

@ -176,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int {
// If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see QuoRem for more details.
func (z *Int) Rem(x, y *Int) *Int {
_, z.abs = nat{}.div(z.abs, x.abs, y.abs)
_, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z
}
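
Note: this and the following hunks replace the composite literal nat{} with the conversion nat(nil) as the throwaway receiver for methods that allocate their own result. Both give a zero-length nat, but nat{} yields a non-nil empty slice while nat(nil) is simply a typed nil and allocates nothing. A small sketch of the distinction (Word and nat stand in for the math/big types of the same shape):

package main

import "fmt"

type Word uintptr
type nat []Word

func main() {
	a := nat{}    // empty, non-nil slice (composite literal)
	b := nat(nil) // nil slice of type nat (conversion, no allocation)
	fmt.Println(a == nil, len(a), cap(a)) // false 0 0
	fmt.Println(b == nil, len(b), cap(b)) // true 0 0
}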
@ -678,14 +678,14 @@ func (z *Int) Bit(i int) uint {
panic("negative bit index")
}
if z.neg {
t := nat{}.sub(z.abs, natOne)
t := nat(nil).sub(z.abs, natOne)
return t.bit(uint(i)) ^ 1
}
return z.abs.bit(uint(i))
}
// SetBit sets the i'th bit of z to bit and returns z.
// SetBit sets z to x, with x's i'th bit set to b (0 or 1).
// That is, if bit is 1 SetBit sets z = x | (1 << i);
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
// SetBit will panic.
@ -710,8 +710,8 @@ func (z *Int) And(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative
return z
@ -729,7 +729,7 @@ func (z *Int) And(x, y *Int) *Int {
}
// x & (-y) == x & ^(y-1) == x &^ (y-1)
y1 := nat{}.sub(y.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(x.abs, y1)
z.neg = false
return z
@ -740,8 +740,8 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
x1 := nat{}.sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(y1, x1)
z.neg = false
return z
@ -755,14 +755,14 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg {
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
x1 := nat{}.sub(x.abs, natOne)
x1 := nat(nil).sub(x.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
z.neg = true // z cannot be zero if x is negative and y is positive
return z
}
// x &^ (-y) == x &^ ^(y-1) == x & (y-1)
y1 := nat{}.add(y.abs, natOne)
y1 := nat(nil).add(y.abs, natOne)
z.abs = z.abs.and(x.abs, y1)
z.neg = false
return z
@ -773,8 +773,8 @@ func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative
return z
@ -792,7 +792,7 @@ func (z *Int) Or(x, y *Int) *Int {
}
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
y1 := nat{}.sub(y.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
z.neg = true // z cannot be zero if one of x or y is negative
return z
@ -803,8 +803,8 @@ func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
x1 := nat{}.sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.xor(x1, y1)
z.neg = false
return z
@ -822,7 +822,7 @@ func (z *Int) Xor(x, y *Int) *Int {
}
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
y1 := nat{}.sub(y.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
z.neg = true // z cannot be zero if only one of x or y is negative
return z

View File

@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat {
case a == b:
return z.setUint64(a)
case a+1 == b:
return z.mul(nat{}.setUint64(a), nat{}.setUint64(b))
return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
}
m := (a + b) / 2
return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b))
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
}
// q = (x-r)/y, with 0 <= r < y
@ -785,7 +785,7 @@ func (x nat) string(charset string) string {
}
// preserve x, create local copy for use in repeated divisions
q := nat{}.set(x)
q := nat(nil).set(x)
var r Word
// convert
@ -1191,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool {
return false
}
nm1 := nat{}.sub(n, natOne)
nm1 := nat(nil).sub(n, natOne)
// 1<<k * q = nm1;
q, k := nm1.powersOfTwoDecompose()
nm3 := nat{}.sub(nm1, natTwo)
nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0])))
var x, y, quotient nat

View File

@ -16,9 +16,9 @@ var cmpTests = []struct {
r int
}{
{nil, nil, 0},
{nil, nat{}, 0},
{nat{}, nil, 0},
{nat{}, nat{}, 0},
{nil, nat(nil), 0},
{nat(nil), nil, 0},
{nat(nil), nat(nil), 0},
{nat{0}, nat{0}, 0},
{nat{0}, nat{1}, -1},
{nat{1}, nat{0}, 1},
@ -67,7 +67,7 @@ var prodNN = []argNN{
func TestSet(t *testing.T) {
for _, a := range sumNN {
z := nat{}.set(a.z)
z := nat(nil).set(a.z)
if z.cmp(a.z) != 0 {
t.Errorf("got z = %v; want %v", z, a.z)
}
@ -129,7 +129,7 @@ var mulRangesN = []struct {
func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN {
prod := nat{}.mulRange(r.a, r.b).decimalString()
prod := nat(nil).mulRange(r.a, r.b).decimalString()
if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
}
@ -175,7 +175,7 @@ func toString(x nat, charset string) string {
s := make([]byte, i)
// don't destroy x
q := nat{}.set(x)
q := nat(nil).set(x)
// convert
for len(q) > 0 {
@ -212,7 +212,7 @@ func TestString(t *testing.T) {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
}
x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c))
x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
}
@ -271,7 +271,7 @@ var natScanTests = []struct {
func TestScanBase(t *testing.T) {
for _, a := range natScanTests {
r := strings.NewReader(a.s)
x, b, err := nat{}.scan(r, a.base)
x, b, err := nat(nil).scan(r, a.base)
if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a)
}
@ -651,17 +651,17 @@ var expNNTests = []struct {
func TestExpNN(t *testing.T) {
for i, test := range expNNTests {
x, _, _ := nat{}.scan(strings.NewReader(test.x), 0)
y, _, _ := nat{}.scan(strings.NewReader(test.y), 0)
out, _, _ := nat{}.scan(strings.NewReader(test.out), 0)
x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
var m nat
if len(test.m) > 0 {
m, _, _ = nat{}.scan(strings.NewReader(test.m), 0)
m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
}
z := nat{}.expNN(x, y, m)
z := nat(nil).expNN(x, y, m)
if z.cmp(out) != 0 {
t.Errorf("#%d got %v want %v", i, z, out)
}

View File

@ -33,7 +33,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
panic("division by zero")
}
if &z.a == b || alias(z.a.abs, babs) {
babs = nat{}.set(babs) // make a copy
babs = nat(nil).set(babs) // make a copy
}
z.a.abs = z.a.abs.set(a.abs)
z.b = z.b.set(babs)
@ -315,7 +315,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
if _, ok := z.a.SetString(s, 10); !ok {
return nil, false
}
powTen := nat{}.expNN(natTen, exp.abs, nil)
powTen := nat(nil).expNN(natTen, exp.abs, nil)
if exp.neg {
z.b = powTen
z.norm()
@ -357,23 +357,23 @@ func (z *Rat) FloatString(prec int) string {
}
// z.b != 0
q, r := nat{}.div(nat{}, z.a.abs, z.b)
q, r := nat(nil).div(nat(nil), z.a.abs, z.b)
p := natOne
if prec > 0 {
p = nat{}.expNN(natTen, nat{}.setUint64(uint64(prec)), nil)
p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil)
}
r = r.mul(r, p)
r, r2 := r.div(nat{}, r, z.b)
r, r2 := r.div(nat(nil), r, z.b)
// see if we need to round up
r2 = r2.add(r2, r2)
if z.b.cmp(r2) <= 0 {
r = r.add(r, natOne)
if r.cmp(p) >= 0 {
q = nat{}.add(q, natOne)
r = nat{}.sub(r, p)
q = nat(nil).add(q, natOne)
r = nat(nil).sub(r, p)
}
}

View File

@ -63,7 +63,7 @@ package math
// Stephen L. Moshier
// moshier@na-net.ornl.gov
var _P = [...]float64{
var _gamP = [...]float64{
1.60119522476751861407e-04,
1.19135147006586384913e-03,
1.04213797561761569935e-02,
@ -72,7 +72,7 @@ var _P = [...]float64{
4.94214826801497100753e-01,
9.99999999999999996796e-01,
}
var _Q = [...]float64{
var _gamQ = [...]float64{
-2.31581873324120129819e-05,
5.39605580493303397842e-04,
-4.45641913851797240494e-03,
@ -82,7 +82,7 @@ var _Q = [...]float64{
7.14304917030273074085e-02,
1.00000000000000000320e+00,
}
var _S = [...]float64{
var _gamS = [...]float64{
7.87311395793093628397e-04,
-2.29549961613378126380e-04,
-2.68132617805781232825e-03,
@ -98,7 +98,7 @@ func stirling(x float64) float64 {
MaxStirling = 143.01608
)
w := 1 / x
w = 1 + w*((((_S[0]*w+_S[1])*w+_S[2])*w+_S[3])*w+_S[4])
w = 1 + w*((((_gamS[0]*w+_gamS[1])*w+_gamS[2])*w+_gamS[3])*w+_gamS[4])
y := Exp(x)
if x > MaxStirling { // avoid Pow() overflow
v := Pow(x, 0.5*x-0.25)
@ -176,8 +176,8 @@ func Gamma(x float64) float64 {
}
x = x - 2
p = (((((x*_P[0]+_P[1])*x+_P[2])*x+_P[3])*x+_P[4])*x+_P[5])*x + _P[6]
q = ((((((x*_Q[0]+_Q[1])*x+_Q[2])*x+_Q[3])*x+_Q[4])*x+_Q[5])*x+_Q[6])*x + _Q[7]
p = (((((x*_gamP[0]+_gamP[1])*x+_gamP[2])*x+_gamP[3])*x+_gamP[4])*x+_gamP[5])*x + _gamP[6]
q = ((((((x*_gamQ[0]+_gamQ[1])*x+_gamQ[2])*x+_gamQ[3])*x+_gamQ[4])*x+_gamQ[5])*x+_gamQ[6])*x + _gamQ[7]
return z * p / q
small:

View File

@ -88,6 +88,81 @@ package math
//
//
var _lgamA = [...]float64{
7.72156649015328655494e-02, // 0x3FB3C467E37DB0C8
3.22467033424113591611e-01, // 0x3FD4A34CC4A60FAD
6.73523010531292681824e-02, // 0x3FB13E001A5562A7
2.05808084325167332806e-02, // 0x3F951322AC92547B
7.38555086081402883957e-03, // 0x3F7E404FB68FEFE8
2.89051383673415629091e-03, // 0x3F67ADD8CCB7926B
1.19270763183362067845e-03, // 0x3F538A94116F3F5D
5.10069792153511336608e-04, // 0x3F40B6C689B99C00
2.20862790713908385557e-04, // 0x3F2CF2ECED10E54D
1.08011567247583939954e-04, // 0x3F1C5088987DFB07
2.52144565451257326939e-05, // 0x3EFA7074428CFA52
4.48640949618915160150e-05, // 0x3F07858E90A45837
}
var _lgamR = [...]float64{
1.0, // placeholder
1.39200533467621045958e+00, // 0x3FF645A762C4AB74
7.21935547567138069525e-01, // 0x3FE71A1893D3DCDC
1.71933865632803078993e-01, // 0x3FC601EDCCFBDF27
1.86459191715652901344e-02, // 0x3F9317EA742ED475
7.77942496381893596434e-04, // 0x3F497DDACA41A95B
7.32668430744625636189e-06, // 0x3EDEBAF7A5B38140
}
var _lgamS = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
2.14982415960608852501e-01, // 0x3FCB848B36E20878
3.25778796408930981787e-01, // 0x3FD4D98F4F139F59
1.46350472652464452805e-01, // 0x3FC2BB9CBEE5F2F7
2.66422703033638609560e-02, // 0x3F9B481C7E939961
1.84028451407337715652e-03, // 0x3F5E26B67368F239
3.19475326584100867617e-05, // 0x3F00BFECDD17E945
}
var _lgamT = [...]float64{
4.83836122723810047042e-01, // 0x3FDEF72BC8EE38A2
-1.47587722994593911752e-01, // 0xBFC2E4278DC6C509
6.46249402391333854778e-02, // 0x3FB08B4294D5419B
-3.27885410759859649565e-02, // 0xBFA0C9A8DF35B713
1.79706750811820387126e-02, // 0x3F9266E7970AF9EC
-1.03142241298341437450e-02, // 0xBF851F9FBA91EC6A
6.10053870246291332635e-03, // 0x3F78FCE0E370E344
-3.68452016781138256760e-03, // 0xBF6E2EFFB3E914D7
2.25964780900612472250e-03, // 0x3F6282D32E15C915
-1.40346469989232843813e-03, // 0xBF56FE8EBF2D1AF1
8.81081882437654011382e-04, // 0x3F4CDF0CEF61A8E9
-5.38595305356740546715e-04, // 0xBF41A6109C73E0EC
3.15632070903625950361e-04, // 0x3F34AF6D6C0EBBF7
-3.12754168375120860518e-04, // 0xBF347F24ECC38C38
3.35529192635519073543e-04, // 0x3F35FD3EE8C2D3F4
}
var _lgamU = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
6.32827064025093366517e-01, // 0x3FE4401E8B005DFF
1.45492250137234768737e+00, // 0x3FF7475CD119BD6F
9.77717527963372745603e-01, // 0x3FEF497644EA8450
2.28963728064692451092e-01, // 0x3FCD4EAEF6010924
1.33810918536787660377e-02, // 0x3F8B678BBF2BAB09
}
var _lgamV = [...]float64{
1.0,
2.45597793713041134822e+00, // 0x4003A5D7C2BD619C
2.12848976379893395361e+00, // 0x40010725A42B18F5
7.69285150456672783825e-01, // 0x3FE89DFBE45050AF
1.04222645593369134254e-01, // 0x3FBAAE55D6537C88
3.21709242282423911810e-03, // 0x3F6A5ABB57D0CF61
}
var _lgamW = [...]float64{
4.18938533204672725052e-01, // 0x3FDACFE390C97D69
8.33333333333329678849e-02, // 0x3FB555555555553B
-2.77777777728775536470e-03, // 0xBF66C16C16B02E5C
7.93650558643019558500e-04, // 0x3F4A019F98CF38B6
-5.95187557450339963135e-04, // 0xBF4380CB8C0FE741
8.36339918996282139126e-04, // 0x3F4B67BA4CDAD5D1
-1.63092934096575273989e-03, // 0xBF5AB89D0B9E43E4
}
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
//
// Special cases are:
@ -103,68 +178,10 @@ func Lgamma(x float64) (lgamma float64, sign int) {
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
A0 = 7.72156649015328655494e-02 // 0x3FB3C467E37DB0C8
A1 = 3.22467033424113591611e-01 // 0x3FD4A34CC4A60FAD
A2 = 6.73523010531292681824e-02 // 0x3FB13E001A5562A7
A3 = 2.05808084325167332806e-02 // 0x3F951322AC92547B
A4 = 7.38555086081402883957e-03 // 0x3F7E404FB68FEFE8
A5 = 2.89051383673415629091e-03 // 0x3F67ADD8CCB7926B
A6 = 1.19270763183362067845e-03 // 0x3F538A94116F3F5D
A7 = 5.10069792153511336608e-04 // 0x3F40B6C689B99C00
A8 = 2.20862790713908385557e-04 // 0x3F2CF2ECED10E54D
A9 = 1.08011567247583939954e-04 // 0x3F1C5088987DFB07
A10 = 2.52144565451257326939e-05 // 0x3EFA7074428CFA52
A11 = 4.48640949618915160150e-05 // 0x3F07858E90A45837
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
// Tt = -(tail of Tf)
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
T0 = 4.83836122723810047042e-01 // 0x3FDEF72BC8EE38A2
T1 = -1.47587722994593911752e-01 // 0xBFC2E4278DC6C509
T2 = 6.46249402391333854778e-02 // 0x3FB08B4294D5419B
T3 = -3.27885410759859649565e-02 // 0xBFA0C9A8DF35B713
T4 = 1.79706750811820387126e-02 // 0x3F9266E7970AF9EC
T5 = -1.03142241298341437450e-02 // 0xBF851F9FBA91EC6A
T6 = 6.10053870246291332635e-03 // 0x3F78FCE0E370E344
T7 = -3.68452016781138256760e-03 // 0xBF6E2EFFB3E914D7
T8 = 2.25964780900612472250e-03 // 0x3F6282D32E15C915
T9 = -1.40346469989232843813e-03 // 0xBF56FE8EBF2D1AF1
T10 = 8.81081882437654011382e-04 // 0x3F4CDF0CEF61A8E9
T11 = -5.38595305356740546715e-04 // 0xBF41A6109C73E0EC
T12 = 3.15632070903625950361e-04 // 0x3F34AF6D6C0EBBF7
T13 = -3.12754168375120860518e-04 // 0xBF347F24ECC38C38
T14 = 3.35529192635519073543e-04 // 0x3F35FD3EE8C2D3F4
U0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
U1 = 6.32827064025093366517e-01 // 0x3FE4401E8B005DFF
U2 = 1.45492250137234768737e+00 // 0x3FF7475CD119BD6F
U3 = 9.77717527963372745603e-01 // 0x3FEF497644EA8450
U4 = 2.28963728064692451092e-01 // 0x3FCD4EAEF6010924
U5 = 1.33810918536787660377e-02 // 0x3F8B678BBF2BAB09
V1 = 2.45597793713041134822e+00 // 0x4003A5D7C2BD619C
V2 = 2.12848976379893395361e+00 // 0x40010725A42B18F5
V3 = 7.69285150456672783825e-01 // 0x3FE89DFBE45050AF
V4 = 1.04222645593369134254e-01 // 0x3FBAAE55D6537C88
V5 = 3.21709242282423911810e-03 // 0x3F6A5ABB57D0CF61
S0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
S1 = 2.14982415960608852501e-01 // 0x3FCB848B36E20878
S2 = 3.25778796408930981787e-01 // 0x3FD4D98F4F139F59
S3 = 1.46350472652464452805e-01 // 0x3FC2BB9CBEE5F2F7
S4 = 2.66422703033638609560e-02 // 0x3F9B481C7E939961
S5 = 1.84028451407337715652e-03 // 0x3F5E26B67368F239
S6 = 3.19475326584100867617e-05 // 0x3F00BFECDD17E945
R1 = 1.39200533467621045958e+00 // 0x3FF645A762C4AB74
R2 = 7.21935547567138069525e-01 // 0x3FE71A1893D3DCDC
R3 = 1.71933865632803078993e-01 // 0x3FC601EDCCFBDF27
R4 = 1.86459191715652901344e-02 // 0x3F9317EA742ED475
R5 = 7.77942496381893596434e-04 // 0x3F497DDACA41A95B
R6 = 7.32668430744625636189e-06 // 0x3EDEBAF7A5B38140
W0 = 4.18938533204672725052e-01 // 0x3FDACFE390C97D69
W1 = 8.33333333333329678849e-02 // 0x3FB555555555553B
W2 = -2.77777777728775536470e-03 // 0xBF66C16C16B02E5C
W3 = 7.93650558643019558500e-04 // 0x3F4A019F98CF38B6
W4 = -5.95187557450339963135e-04 // 0xBF4380CB8C0FE741
W5 = 8.36339918996282139126e-04 // 0x3F4B67BA4CDAD5D1
W6 = -1.63092934096575273989e-03 // 0xBF5AB89D0B9E43E4
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
)
// TODO(rsc): Remove manual inlining of IsNaN, IsInf
// when compiler does it for us
@ -249,28 +266,28 @@ func Lgamma(x float64) (lgamma float64, sign int) {
switch i {
case 0:
z := y * y
p1 := A0 + z*(A2+z*(A4+z*(A6+z*(A8+z*A10))))
p2 := z * (A1 + z*(A3+z*(A5+z*(A7+z*(A9+z*A11)))))
p1 := _lgamA[0] + z*(_lgamA[2]+z*(_lgamA[4]+z*(_lgamA[6]+z*(_lgamA[8]+z*_lgamA[10]))))
p2 := z * (_lgamA[1] + z*(+_lgamA[3]+z*(_lgamA[5]+z*(_lgamA[7]+z*(_lgamA[9]+z*_lgamA[11])))))
p := y*p1 + p2
lgamma += (p - 0.5*y)
case 1:
z := y * y
w := z * y
p1 := T0 + w*(T3+w*(T6+w*(T9+w*T12))) // parallel comp
p2 := T1 + w*(T4+w*(T7+w*(T10+w*T13)))
p3 := T2 + w*(T5+w*(T8+w*(T11+w*T14)))
p1 := _lgamT[0] + w*(_lgamT[3]+w*(_lgamT[6]+w*(_lgamT[9]+w*_lgamT[12]))) // parallel comp
p2 := _lgamT[1] + w*(_lgamT[4]+w*(_lgamT[7]+w*(_lgamT[10]+w*_lgamT[13])))
p3 := _lgamT[2] + w*(_lgamT[5]+w*(_lgamT[8]+w*(_lgamT[11]+w*_lgamT[14])))
p := z*p1 - (Tt - w*(p2+y*p3))
lgamma += (Tf + p)
case 2:
p1 := y * (U0 + y*(U1+y*(U2+y*(U3+y*(U4+y*U5)))))
p2 := 1 + y*(V1+y*(V2+y*(V3+y*(V4+y*V5))))
p1 := y * (_lgamU[0] + y*(_lgamU[1]+y*(_lgamU[2]+y*(_lgamU[3]+y*(_lgamU[4]+y*_lgamU[5])))))
p2 := 1 + y*(_lgamV[1]+y*(_lgamV[2]+y*(_lgamV[3]+y*(_lgamV[4]+y*_lgamV[5]))))
lgamma += (-0.5*y + p1/p2)
}
case x < 8: // 2 <= x < 8
i := int(x)
y := x - float64(i)
p := y * (S0 + y*(S1+y*(S2+y*(S3+y*(S4+y*(S5+y*S6))))))
q := 1 + y*(R1+y*(R2+y*(R3+y*(R4+y*(R5+y*R6)))))
p := y * (_lgamS[0] + y*(_lgamS[1]+y*(_lgamS[2]+y*(_lgamS[3]+y*(_lgamS[4]+y*(_lgamS[5]+y*_lgamS[6]))))))
q := 1 + y*(_lgamR[1]+y*(_lgamR[2]+y*(_lgamR[3]+y*(_lgamR[4]+y*(_lgamR[5]+y*_lgamR[6])))))
lgamma = 0.5*y + p/q
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
switch i {
@ -294,7 +311,7 @@ func Lgamma(x float64) (lgamma float64, sign int) {
t := Log(x)
z := 1 / x
y := z * z
w := W0 + z*(W1+y*(W2+y*(W3+y*(W4+y*(W5+y*W6)))))
w := _lgamW[0] + z*(_lgamW[1]+y*(_lgamW[2]+y*(_lgamW[3]+y*(_lgamW[4]+y*(_lgamW[5]+y*_lgamW[6])))))
lgamma = (x-0.5)*(t-1) + w
default: // 2**58 <= x <= Inf
lgamma = x * (Log(x) - 1)

View File

@ -160,7 +160,7 @@ type sliceReaderAt []byte
func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) {
if int(off) >= len(r) || off < 0 {
return 0, os.EINVAL
return 0, io.ErrUnexpectedEOF
}
n := copy(b, r[int(off):])
return n, nil

View File

@ -6,19 +6,11 @@
package mime
import (
"bufio"
"fmt"
"os"
"strings"
"sync"
)
var typeFiles = []string{
"/etc/mime.types",
"/etc/apache2/mime.types",
"/etc/apache/mime.types",
}
var mimeTypes = map[string]string{
".css": "text/css; charset=utf-8",
".gif": "image/gif",
@ -33,46 +25,13 @@ var mimeTypes = map[string]string{
var mimeLock sync.RWMutex
func loadMimeFile(filename string) {
f, err := os.Open(filename)
if err != nil {
return
}
reader := bufio.NewReader(f)
for {
line, err := reader.ReadString('\n')
if err != nil {
f.Close()
return
}
fields := strings.Fields(line)
if len(fields) <= 1 || fields[0][0] == '#' {
continue
}
mimeType := fields[0]
for _, ext := range fields[1:] {
if ext[0] == '#' {
break
}
setExtensionType("."+ext, mimeType)
}
}
}
func initMime() {
for _, filename := range typeFiles {
loadMimeFile(filename)
}
}
var once sync.Once
// TypeByExtension returns the MIME type associated with the file extension ext.
// The extension ext should begin with a leading dot, as in ".html".
// When ext has no associated type, TypeByExtension returns "".
//
// The built-in table is small but is is augmented by the local
// The built-in table is small but on unix it is augmented by the local
// system's mime.types file(s) if available under one or more of these
// names:
//
@ -80,6 +39,8 @@ var once sync.Once
// /etc/apache2/mime.types
// /etc/apache/mime.types
//
// Windows system mime types are extracted from registry.
//
// Text types have the charset parameter set to "utf-8" by default.
func TypeByExtension(ext string) string {
once.Do(initMime)
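
Note: a quick look-up example for the documented behaviour (the results shown assume the built-in table; the system mime.types files or the registry may add or override entries):

package main

import (
	"fmt"
	"mime"
)

func main() {
	fmt.Println(mime.TypeByExtension(".png"))  // image/png
	fmt.Println(mime.TypeByExtension(".html")) // text/html; charset=utf-8
	fmt.Println(mime.TypeByExtension(".xyz"))  // "" when the extension is unknown
}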

View File

@ -6,15 +6,9 @@ package mime
import "testing"
var typeTests = map[string]string{
".t1": "application/test",
".t2": "text/test; charset=utf-8",
".png": "image/png",
}
var typeTests = initMimeForTests()
func TestTypeByExtension(t *testing.T) {
typeFiles = []string{"test.types"}
for ext, want := range typeTests {
val := TypeByExtension(ext)
if val != want {

View File

@ -0,0 +1,59 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"bufio"
"os"
"strings"
)
var typeFiles = []string{
"/etc/mime.types",
"/etc/apache2/mime.types",
"/etc/apache/mime.types",
}
func loadMimeFile(filename string) {
f, err := os.Open(filename)
if err != nil {
return
}
reader := bufio.NewReader(f)
for {
line, err := reader.ReadString('\n')
if err != nil {
f.Close()
return
}
fields := strings.Fields(line)
if len(fields) <= 1 || fields[0][0] == '#' {
continue
}
mimeType := fields[0]
for _, ext := range fields[1:] {
if ext[0] == '#' {
break
}
setExtensionType("."+ext, mimeType)
}
}
}
func initMime() {
for _, filename := range typeFiles {
loadMimeFile(filename)
}
}
func initMimeForTests() map[string]string {
typeFiles = []string{"test.types"}
return map[string]string{
".t1": "application/test",
".t2": "text/test; charset=utf-8",
".png": "image/png",
}
}
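
Note: loadMimeFile above expects the classic mime.types layout: a MIME type followed by whitespace-separated extensions, with '#' starting a comment. A file containing the sample lines below would register .html, .htm and .txt:

# MIME type         extensions
text/html           html htm
text/plain          txt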

View File

@ -0,0 +1,61 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"syscall"
"unsafe"
)
func initMime() {
var root syscall.Handle
if syscall.RegOpenKeyEx(syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`),
0, syscall.KEY_READ, &root) != 0 {
return
}
defer syscall.RegCloseKey(root)
var count uint32
if syscall.RegQueryInfoKey(root, nil, nil, nil, &count, nil, nil, nil, nil, nil, nil, nil) != 0 {
return
}
var buf [1 << 10]uint16
for i := uint32(0); i < count; i++ {
n := uint32(len(buf))
if syscall.RegEnumKeyEx(root, i, &buf[0], &n, nil, nil, nil, nil) != 0 {
continue
}
ext := syscall.UTF16ToString(buf[:])
if len(ext) < 2 || ext[0] != '.' { // looking for extensions only
continue
}
var h syscall.Handle
if syscall.RegOpenKeyEx(
syscall.HKEY_CLASSES_ROOT, syscall.StringToUTF16Ptr(`\`+ext),
0, syscall.KEY_READ, &h) != 0 {
continue
}
var typ uint32
n = uint32(len(buf) * 2) // api expects array of bytes, not uint16
if syscall.RegQueryValueEx(
h, syscall.StringToUTF16Ptr("Content Type"),
nil, &typ, (*byte)(unsafe.Pointer(&buf[0])), &n) != 0 {
syscall.RegCloseKey(h)
continue
}
syscall.RegCloseKey(h)
if typ != syscall.REG_SZ { // null terminated strings only
continue
}
mimeType := syscall.UTF16ToString(buf[:])
setExtensionType(ext, mimeType)
}
}
func initMimeForTests() map[string]string {
return map[string]string{
".bmp": "image/bmp",
".png": "image/png",
}
}

View File

@ -109,7 +109,7 @@ func cgoLookupIPCNAME(name string) (addrs []IP, cname string, err error, complet
if gerrno == syscall.EAI_NONAME {
str = noSuchHost
} else if gerrno == syscall.EAI_SYSTEM {
str = syscall.Errstr(syscall.GetErrno())
str = syscall.GetErrno().Error()
} else {
str = bytePtrToString(libc_gai_strerror(gerrno))
}

View File

@ -278,8 +278,8 @@ func startServer() {
func newFD(fd, family, proto int, net string) (f *netFD, err error) {
onceStartServer.Do(startServer)
if e := syscall.SetNonblock(fd, true); e != 0 {
return nil, os.Errno(e)
if e := syscall.SetNonblock(fd, true); e != nil {
return nil, e
}
f = &netFD{
sysfd: fd,
@ -306,19 +306,19 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
}
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
e := syscall.Connect(fd.sysfd, ra)
if e == syscall.EINPROGRESS {
var errno int
err = syscall.Connect(fd.sysfd, ra)
if err == syscall.EINPROGRESS {
pollserver.WaitWrite(fd)
e, errno = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if errno != 0 {
return os.NewSyscallError("getsockopt", errno)
var e int
e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil {
return os.NewSyscallError("getsockopt", err)
}
if e != 0 {
err = syscall.Errno(e)
}
}
if e != 0 {
return os.Errno(e)
}
return nil
return err
}
// Add a reference to this fd.
@ -362,9 +362,9 @@ func (fd *netFD) shutdown(how int) error {
if fd == nil || fd.sysfile == nil {
return os.EINVAL
}
errno := syscall.Shutdown(fd.sysfd, how)
if errno != 0 {
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)}
err := syscall.Shutdown(fd.sysfd, how)
if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, err}
}
return nil
}
@ -377,6 +377,14 @@ func (fd *netFD) CloseWrite() error {
return fd.shutdown(syscall.SHUT_WR)
}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
var errTimeout error = &timeoutError{}
func (fd *netFD) Read(p []byte) (n int, err error) {
if fd == nil {
return 0, os.EINVAL
@ -393,24 +401,24 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
} else {
fd.rdeadline = 0
}
var oserr error
for {
var errno int
n, errno = syscall.Read(fd.sysfile.Fd(), p)
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
n, err = syscall.Read(fd.sysfile.Fd(), p)
if err == syscall.EAGAIN {
if fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
}
err = errTimeout
}
if errno != 0 {
if err != nil {
n = 0
oserr = os.Errno(errno)
} else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM {
} else if n == 0 && err == nil && fd.proto != syscall.SOCK_DGRAM {
err = io.EOF
}
break
}
if oserr != nil {
err = &OpError{"read", fd.net, fd.raddr, oserr}
if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.raddr, err}
}
return
}
@ -428,22 +436,22 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
} else {
fd.rdeadline = 0
}
var oserr error
for {
var errno int
n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0)
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0)
if err == syscall.EAGAIN {
if fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
}
err = errTimeout
}
if errno != 0 {
if err != nil {
n = 0
oserr = os.Errno(errno)
}
break
}
if oserr != nil {
err = &OpError{"read", fd.net, fd.laddr, oserr}
if err != nil {
err = &OpError{"read", fd.net, fd.laddr, err}
}
return
}
@ -461,24 +469,22 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S
} else {
fd.rdeadline = 0
}
var oserr error
for {
var errno int
n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0)
if errno == syscall.EAGAIN && fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0)
if err == syscall.EAGAIN {
if fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
}
err = errTimeout
}
if errno != 0 {
oserr = os.Errno(errno)
}
if n == 0 {
oserr = io.EOF
if err == nil && n == 0 {
err = io.EOF
}
break
}
if oserr != nil {
err = &OpError{"read", fd.net, fd.laddr, oserr}
if err != nil && err != io.EOF {
err = &OpError{"read", fd.net, fd.laddr, err}
return
}
return
@ -501,32 +507,34 @@ func (fd *netFD) Write(p []byte) (n int, err error) {
fd.wdeadline = 0
}
nn := 0
var oserr error
for {
n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:])
var n int
n, err = syscall.Write(fd.sysfile.Fd(), p[nn:])
if n > 0 {
nn += n
}
if nn == len(p) {
break
}
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
if err == syscall.EAGAIN {
if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
}
err = errTimeout
}
if errno != 0 {
if err != nil {
n = 0
oserr = os.Errno(errno)
break
}
if n == 0 {
oserr = io.ErrUnexpectedEOF
err = io.ErrUnexpectedEOF
break
}
}
if oserr != nil {
err = &OpError{"write", fd.net, fd.raddr, oserr}
if err != nil {
err = &OpError{"write", fd.net, fd.raddr, err}
}
return nn, err
}
@ -544,22 +552,21 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
} else {
fd.wdeadline = 0
}
var oserr error
for {
errno := syscall.Sendto(fd.sysfd, p, 0, sa)
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
}
if errno != 0 {
oserr = os.Errno(errno)
err = syscall.Sendto(fd.sysfd, p, 0, sa)
if err == syscall.EAGAIN {
if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
}
err = errTimeout
}
break
}
if oserr == nil {
if err == nil {
n = len(p)
} else {
err = &OpError{"write", fd.net, fd.raddr, oserr}
err = &OpError{"write", fd.net, fd.raddr, err}
}
return
}
@ -577,24 +584,22 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
} else {
fd.wdeadline = 0
}
var oserr error
for {
var errno int
errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
if errno == syscall.EAGAIN && fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
}
if errno != 0 {
oserr = os.Errno(errno)
err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
if err == syscall.EAGAIN {
if fd.wdeadline >= 0 {
pollserver.WaitWrite(fd)
continue
}
err = errTimeout
}
break
}
if oserr == nil {
if err == nil {
n = len(p)
oobn = len(oob)
} else {
err = &OpError{"write", fd.net, fd.raddr, oserr}
err = &OpError{"write", fd.net, fd.raddr, err}
}
return
}
@ -615,25 +620,26 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// See ../syscall/exec.go for description of ForkLock.
// It is okay to hold the lock across syscall.Accept
// because we have put fd.sysfd into non-blocking mode.
syscall.ForkLock.RLock()
var s, e int
var s int
var rsa syscall.Sockaddr
for {
if fd.closing {
syscall.ForkLock.RUnlock()
return nil, os.EINVAL
}
s, rsa, e = syscall.Accept(fd.sysfd)
if e != syscall.EAGAIN || fd.rdeadline < 0 {
break
}
syscall.ForkLock.RUnlock()
pollserver.WaitRead(fd)
syscall.ForkLock.RLock()
}
if e != 0 {
syscall.ForkLock.RUnlock()
return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)}
s, rsa, err = syscall.Accept(fd.sysfd)
if err != nil {
syscall.ForkLock.RUnlock()
if err == syscall.EAGAIN {
if fd.rdeadline >= 0 {
pollserver.WaitRead(fd)
continue
}
err = errTimeout
}
return nil, &OpError{"accept", fd.net, fd.laddr, err}
}
break
}
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()
@ -648,19 +654,19 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
}
func (fd *netFD) dup() (f *os.File, err error) {
ns, e := syscall.Dup(fd.sysfd)
if e != 0 {
return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)}
ns, err := syscall.Dup(fd.sysfd)
if err != nil {
return nil, &OpError{"dup", fd.net, fd.laddr, err}
}
// We want blocking mode for the new fd, hence the double negative.
if e = syscall.SetNonblock(ns, false); e != 0 {
return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)}
if err = syscall.SetNonblock(ns, false); err != nil {
return nil, &OpError{"setnonblock", fd.net, fd.laddr, err}
}
return os.NewFile(ns, fd.sysfile.Name()), nil
}
func closesocket(s int) (errno int) {
func closesocket(s int) error {
return syscall.Close(s)
}

View File

@ -35,12 +35,12 @@ type pollster struct {
func newpollster() (p *pollster, err error) {
p = new(pollster)
var e int
var e error
// The arg to epoll_create is a hint to the kernel
// about the number of FDs we will care about.
// We don't know, and since 2.6.8 the kernel ignores it anyhow.
if p.epfd, e = syscall.EpollCreate(16); e != 0 {
if p.epfd, e = syscall.EpollCreate(16); e != nil {
return nil, os.NewSyscallError("epoll_create", e)
}
p.events = make(map[int]uint32)
@ -68,7 +68,7 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
} else {
op = syscall.EPOLL_CTL_ADD
}
if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != 0 {
if e := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); e != nil {
return false, os.NewSyscallError("epoll_ctl", e)
}
p.events[fd] = p.ctlEvent.Events
@ -97,13 +97,13 @@ func (p *pollster) StopWaiting(fd int, bits uint) {
if int32(events)&^syscall.EPOLLONESHOT != 0 {
p.ctlEvent.Fd = int32(fd)
p.ctlEvent.Events = events
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != 0 {
print("Epoll modify fd=", fd, ": ", os.Errno(e).Error(), "\n")
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); e != nil {
print("Epoll modify fd=", fd, ": ", e.Error(), "\n")
}
p.events[fd] = events
} else {
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 {
print("Epoll delete fd=", fd, ": ", os.Errno(e).Error(), "\n")
if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != nil {
print("Epoll delete fd=", fd, ": ", e.Error(), "\n")
}
delete(p.events, fd)
}
@ -141,7 +141,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
n, e := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec)
s.Lock()
if e != 0 {
if e != nil {
if e == syscall.EAGAIN || e == syscall.EINTR {
continue
}

View File

@ -23,9 +23,8 @@ type pollster struct {
func newpollster() (p *pollster, err error) {
p = new(pollster)
var e int
if p.kq, e = syscall.Kqueue(); e != 0 {
return nil, os.NewSyscallError("kqueue", e)
if p.kq, err = syscall.Kqueue(); err != nil {
return nil, os.NewSyscallError("kqueue", err)
}
p.events = p.eventbuf[0:0]
return p, nil
@ -50,14 +49,14 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) {
syscall.SetKevent(ev, fd, kmode, flags)
n, e := syscall.Kevent(p.kq, p.kbuf[:], nil, nil)
if e != 0 {
if e != nil {
return false, os.NewSyscallError("kevent", e)
}
if n != 1 || (ev.Flags&syscall.EV_ERROR) == 0 || int(ev.Ident) != fd || int(ev.Filter) != kmode {
return false, os.NewSyscallError("kqueue phase error", e)
}
if ev.Data != 0 {
return false, os.Errno(int(ev.Data))
return false, syscall.Errno(int(ev.Data))
}
return false, nil
}
@ -91,7 +90,7 @@ func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err erro
nn, e := syscall.Kevent(p.kq, nil, p.eventbuf[:], t)
s.Lock()
if e != 0 {
if e != nil {
if e == syscall.EINTR {
continue
}

View File

@ -26,11 +26,11 @@ func init() {
var d syscall.WSAData
e := syscall.WSAStartup(uint32(0x202), &d)
if e != 0 {
initErr = os.NewSyscallError("WSAStartup", e)
initErr = os.NewSyscallError("WSAStartup", syscall.Errno(e))
}
}
func closesocket(s syscall.Handle) (errno int) {
func closesocket(s syscall.Handle) (err error) {
return syscall.Closesocket(s)
}
@ -38,13 +38,13 @@ func closesocket(s syscall.Handle) (errno int) {
type anOpIface interface {
Op() *anOp
Name() string
Submit() (errno int)
Submit() (err error)
}
// IO completion result parameters.
type ioResult struct {
qty uint32
err int
err error
}
// anOp implements functionality common to all io operations.
@ -54,7 +54,7 @@ type anOp struct {
o syscall.Overlapped
resultc chan ioResult
errnoc chan int
errnoc chan error
fd *netFD
}
@ -71,7 +71,7 @@ func (o *anOp) Init(fd *netFD, mode int) {
}
o.resultc = fd.resultc[i]
if fd.errnoc[i] == nil {
fd.errnoc[i] = make(chan int)
fd.errnoc[i] = make(chan error)
}
o.errnoc = fd.errnoc[i]
}
@ -111,14 +111,14 @@ func (s *resultSrv) Run() {
for {
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
switch {
case r.err == 0:
case r.err == nil:
// Dequeued successfully completed io packet.
case r.err == syscall.WAIT_TIMEOUT && o == nil:
case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
// Wait has timed out (should not happen now, but might be used in the future).
panic("GetQueuedCompletionStatus timed out")
case o == nil:
// Failed to dequeue anything -> report the error.
panic("GetQueuedCompletionStatus failed " + syscall.Errstr(r.err))
panic("GetQueuedCompletionStatus failed " + r.err.Error())
default:
// Dequeued failed io packet.
}
@ -153,7 +153,7 @@ func (s *ioSrv) ProcessRemoteIO() {
// inline, or, if timeouts are employed, passes the request onto
// a special goroutine and waits for completion or cancels request.
func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
var e int
var e error
o := oi.Op()
if deadline_delta > 0 {
// Send request to a special dedicated thread,
@ -164,12 +164,12 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
e = oi.Submit()
}
switch e {
case 0:
case nil:
// IO completed immediately, but we need to get our completion message anyway.
case syscall.ERROR_IO_PENDING:
// IO started, and we have to wait for its completion.
default:
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(e)}
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, e}
}
// Wait for our request to complete.
var r ioResult
@ -187,8 +187,8 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline_delta int64) (n int, err error) {
} else {
r = <-o.resultc
}
if r.err != 0 {
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, os.Errno(r.err)}
if r.err != nil {
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
}
return int(r.qty), err
}
@ -200,10 +200,10 @@ var onceStartServer sync.Once
func startServer() {
resultsrv = new(resultSrv)
var errno int
resultsrv.iocp, errno = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
if errno != 0 {
panic("CreateIoCompletionPort failed " + syscall.Errstr(errno))
var err error
resultsrv.iocp, err = syscall.CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 1)
if err != nil {
panic("CreateIoCompletionPort: " + err.Error())
}
go resultsrv.Run()
@ -228,7 +228,7 @@ type netFD struct {
laddr Addr
raddr Addr
resultc [2]chan ioResult // read/write completion results
errnoc [2]chan int // read/write submit or cancel operation errors
errnoc [2]chan error // read/write submit or cancel operation errors
// owned by client
rdeadline_delta int64
@ -256,8 +256,8 @@ func newFD(fd syscall.Handle, family, proto int, net string) (f *netFD, err erro
}
onceStartServer.Do(startServer)
// Associate our socket with resultsrv.iocp.
if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != 0 {
return nil, os.Errno(e)
if _, e := syscall.CreateIoCompletionPort(syscall.Handle(fd), resultsrv.iocp, 0, 0); e != nil {
return nil, e
}
return allocFD(fd, family, proto, net), nil
}
@ -268,11 +268,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
}
func (fd *netFD) connect(ra syscall.Sockaddr) (err error) {
e := syscall.Connect(fd.sysfd, ra)
if e != 0 {
return os.Errno(e)
}
return nil
return syscall.Connect(fd.sysfd, ra)
}
// Add a reference to this fd.
@ -317,9 +313,9 @@ func (fd *netFD) shutdown(how int) error {
if fd == nil || fd.sysfd == syscall.InvalidHandle {
return os.EINVAL
}
errno := syscall.Shutdown(fd.sysfd, how)
if errno != 0 {
return &OpError{"shutdown", fd.net, fd.laddr, os.Errno(errno)}
err := syscall.Shutdown(fd.sysfd, how)
if err != nil {
return &OpError{"shutdown", fd.net, fd.laddr, err}
}
return nil
}
@ -338,7 +334,7 @@ type readOp struct {
bufOp
}
func (o *readOp) Submit() (errno int) {
func (o *readOp) Submit() (err error) {
var d, f uint32
return syscall.WSARecv(syscall.Handle(o.fd.sysfd), &o.buf, 1, &d, &f, &o.o, nil)
}
@ -375,7 +371,7 @@ type readFromOp struct {
rsan int32
}
func (o *readFromOp) Submit() (errno int) {
func (o *readFromOp) Submit() (err error) {
var d, f uint32
return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &d, &f, &o.rsa, &o.rsan, &o.o, nil)
}
@ -415,7 +411,7 @@ type writeOp struct {
bufOp
}
func (o *writeOp) Submit() (errno int) {
func (o *writeOp) Submit() (err error) {
var d uint32
return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &d, 0, &o.o, nil)
}
@ -447,7 +443,7 @@ type writeToOp struct {
sa syscall.Sockaddr
}
func (o *writeToOp) Submit() (errno int) {
func (o *writeToOp) Submit() (err error) {
var d uint32
return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &d, 0, o.sa, &o.o, nil)
}
@ -484,7 +480,7 @@ type acceptOp struct {
attrs [2]syscall.RawSockaddrAny // space for local and remote address only
}
func (o *acceptOp) Submit() (errno int) {
func (o *acceptOp) Submit() (err error) {
var d uint32
l := uint32(unsafe.Sizeof(o.attrs[0]))
return syscall.AcceptEx(o.fd.sysfd, o.newsock,
@ -506,17 +502,17 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock()
s, e := syscall.Socket(fd.family, fd.proto, 0)
if e != 0 {
if e != nil {
syscall.ForkLock.RUnlock()
return nil, os.Errno(e)
return nil, e
}
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()
// Associate our new socket with IOCP.
onceStartServer.Do(startServer)
if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != 0 {
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)}
if _, e = syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); e != nil {
return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, e}
}
// Submit accept request.
@ -531,9 +527,9 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err err
// Inherit properties of the listening socket.
e = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
if e != 0 {
if e != nil {
closesocket(s)
return nil, err
return nil, e
}
// Get local and peer addr out of AcceptEx buffer.

View File

@ -13,12 +13,12 @@ import (
func newFileFD(f *os.File) (nfd *netFD, err error) {
fd, errno := syscall.Dup(f.Fd())
if errno != 0 {
if errno != nil {
return nil, os.NewSyscallError("dup", errno)
}
proto, errno := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
if errno != 0 {
if errno != nil {
return nil, os.NewSyscallError("getsockopt", errno)
}

View File

@ -7,8 +7,8 @@
package net
import (
"os"
"sync"
"time"
)
const cacheMaxAge = int64(300) // 5 minutes.
@ -26,7 +26,7 @@ var hosts struct {
}
func readHosts() {
now, _, _ := os.Time()
now := time.Seconds()
hp := hostsPath
if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp {
hs := make(map[string][]string)
@ -51,7 +51,7 @@ func readHosts() {
}
}
// Update the data cache.
hosts.time, _, _ = os.Time()
hosts.time = time.Seconds()
hosts.path = hp
hosts.byName = hs
hosts.byAddr = is

View File

@ -363,14 +363,13 @@ func TestCopyError(t *testing.T) {
}
conn.Close()
if tries := 0; childRunning() {
for tries < 15 && childRunning() {
time.Sleep(50e6 * int64(tries))
tries++
}
if childRunning() {
t.Fatalf("post-conn.Close, expected child to be gone")
}
tries := 0
for tries < 15 && childRunning() {
time.Sleep(50e6 * int64(tries))
tries++
}
if childRunning() {
t.Fatalf("post-conn.Close, expected child to be gone")
}
}

View File

@ -2,20 +2,137 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
// This code is duplicated in httputil/chunked.go.
// Please make any changes in both files.
package http
import (
"bufio"
"bytes"
"errors"
"io"
"strconv"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// newChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// newChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func newChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
}
return &chunkedReader{r: br}
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
p = bytes.TrimRight(p, " \r\t\n")
return p, nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream.
//
// newChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
// Content-Length header. Using newChunkedWriter inside a handler
// would result in double chunking or chunking with a Content-Length
// length, both of which are wrong.
func newChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w}
}
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire writer.
// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire chunkedWriter.
type chunkedWriter struct {
Wire io.Writer
}
@ -51,7 +168,3 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
func newChunkedReader(r *bufio.Reader) io.Reader {
return &chunkedReader{r: r}
}

View File

@ -0,0 +1,39 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is duplicated in httputil/chunked_test.go.
// Please make any changes in both files.
package http
import (
"bytes"
"io/ioutil"
"testing"
)
func TestChunk(t *testing.T) {
var b bytes.Buffer
w := newChunkedWriter(&b)
const chunk1 = "hello, "
const chunk2 = "world! 0123456789abcdef"
w.Write([]byte(chunk1))
w.Write([]byte(chunk2))
w.Close()
if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
t.Fatalf("chunk writer wrote %q; want %q", g, e)
}
r := newChunkedReader(&b)
data, err := ioutil.ReadAll(r)
if err != nil {
t.Logf(`data: "%s"`, data)
t.Fatalf("ReadAll from reader: %v", err)
}
if g, e := string(data), chunk1+chunk2; g != e {
t.Errorf("chunk reader read %q; want %q", g, e)
}
}

View File

@ -26,6 +26,31 @@ var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
})
// pedanticReadAll works like ioutil.ReadAll but additionally
// verifies that r obeys the documented io.Reader contract.
func pedanticReadAll(r io.Reader) (b []byte, err error) {
var bufa [64]byte
buf := bufa[:]
for {
n, err := r.Read(buf)
if n == 0 && err == nil {
return nil, fmt.Errorf("Read: n=0 with err=nil")
}
b = append(b, buf[:n]...)
if err == io.EOF {
n, err := r.Read(buf)
if n != 0 || err != io.EOF {
return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
}
return b, nil
}
if err != nil {
return b, err
}
}
panic("unreachable")
}
func TestClient(t *testing.T) {
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
@ -33,7 +58,7 @@ func TestClient(t *testing.T) {
r, err := Get(ts.URL)
var b []byte
if err == nil {
b, err = ioutil.ReadAll(r.Body)
b, err = pedanticReadAll(r.Body)
r.Body.Close()
}
if err != nil {

View File

@ -7,6 +7,7 @@ package fcgi
// This file implements FastCGI from the perspective of a child process.
import (
"errors"
"fmt"
"io"
"net"
@ -123,91 +124,103 @@ func (r *response) Close() error {
}
type child struct {
conn *conn
handler http.Handler
conn *conn
handler http.Handler
requests map[uint16]*request // keyed by request ID
}
func newChild(rwc net.Conn, handler http.Handler) *child {
return &child{newConn(rwc), handler}
func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
return &child{
conn: newConn(rwc),
handler: handler,
requests: make(map[uint16]*request),
}
}
func (c *child) serve() {
requests := map[uint16]*request{}
defer c.conn.Close()
var rec record
var br beginRequest
for {
if err := rec.read(c.conn.rwc); err != nil {
return
}
req, ok := requests[rec.h.Id]
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
// The spec says to ignore unknown request IDs.
continue
}
if ok && rec.h.Type == typeBeginRequest {
// The server is trying to begin a request with the same ID
// as an in-progress request. This is an error.
if err := c.handleRecord(&rec); err != nil {
return
}
switch rec.h.Type {
case typeBeginRequest:
if err := br.read(rec.content()); err != nil {
return
}
if br.role != roleResponder {
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
break
}
requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
case typeParams:
// NOTE(eds): Technically a key-value pair can straddle the boundary
// between two packets. We buffer until we've received all parameters.
if len(rec.content()) > 0 {
req.rawParams = append(req.rawParams, rec.content()...)
break
}
req.parseParams()
case typeStdin:
content := rec.content()
if req.pw == nil {
var body io.ReadCloser
if len(content) > 0 {
// body could be an io.LimitReader, but it shouldn't matter
// as long as both sides are behaving.
body, req.pw = io.Pipe()
}
go c.serveRequest(req, body)
}
if len(content) > 0 {
// TODO(eds): This blocks until the handler reads from the pipe.
// If the handler takes a long time, it might be a problem.
req.pw.Write(content)
} else if req.pw != nil {
req.pw.Close()
}
case typeGetValues:
values := map[string]string{"FCGI_MPXS_CONNS": "1"}
c.conn.writePairs(0, typeGetValuesResult, values)
case typeData:
// If the filter role is implemented, read the data stream here.
case typeAbortRequest:
delete(requests, rec.h.Id)
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
if !req.keepConn {
// connection will close upon return
return
}
default:
b := make([]byte, 8)
b[0] = rec.h.Type
c.conn.writeRecord(typeUnknownType, 0, b)
}
}
}
var errCloseConn = errors.New("fcgi: connection should be closed")
func (c *child) handleRecord(rec *record) error {
req, ok := c.requests[rec.h.Id]
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
// The spec says to ignore unknown request IDs.
return nil
}
if ok && rec.h.Type == typeBeginRequest {
// The server is trying to begin a request with the same ID
// as an in-progress request. This is an error.
return errors.New("fcgi: received ID that is already in-flight")
}
switch rec.h.Type {
case typeBeginRequest:
var br beginRequest
if err := br.read(rec.content()); err != nil {
return err
}
if br.role != roleResponder {
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
return nil
}
c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags)
case typeParams:
// NOTE(eds): Technically a key-value pair can straddle the boundary
// between two packets. We buffer until we've received all parameters.
if len(rec.content()) > 0 {
req.rawParams = append(req.rawParams, rec.content()...)
return nil
}
req.parseParams()
case typeStdin:
content := rec.content()
if req.pw == nil {
var body io.ReadCloser
if len(content) > 0 {
// body could be an io.LimitReader, but it shouldn't matter
// as long as both sides are behaving.
body, req.pw = io.Pipe()
}
go c.serveRequest(req, body)
}
if len(content) > 0 {
// TODO(eds): This blocks until the handler reads from the pipe.
// If the handler takes a long time, it might be a problem.
req.pw.Write(content)
} else if req.pw != nil {
req.pw.Close()
}
case typeGetValues:
values := map[string]string{"FCGI_MPXS_CONNS": "1"}
c.conn.writePairs(typeGetValuesResult, 0, values)
case typeData:
// If the filter role is implemented, read the data stream here.
case typeAbortRequest:
delete(c.requests, rec.h.Id)
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
if !req.keepConn {
// connection will close upon return
return errCloseConn
}
default:
b := make([]byte, 8)
b[0] = byte(rec.h.Type)
c.conn.writeRecord(typeUnknownType, 0, b)
}
return nil
}
func (c *child) serveRequest(req *request, body io.ReadCloser) {
r := newResponse(c, req)
httpReq, err := cgi.RequestFromMap(req.params)

View File

@ -19,19 +19,22 @@ import (
"sync"
)
// recType is a record type, as defined by
// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8
type recType uint8
const (
// Packet Types
typeBeginRequest = iota + 1
typeAbortRequest
typeEndRequest
typeParams
typeStdin
typeStdout
typeStderr
typeData
typeGetValues
typeGetValuesResult
typeUnknownType
typeBeginRequest recType = 1
typeAbortRequest recType = 2
typeEndRequest recType = 3
typeParams recType = 4
typeStdin recType = 5
typeStdout recType = 6
typeStderr recType = 7
typeData recType = 8
typeGetValues recType = 9
typeGetValuesResult recType = 10
typeUnknownType recType = 11
)
// keep the connection between web-server and responder open after request
@ -59,7 +62,7 @@ const headerLen = 8
type header struct {
Version uint8
Type uint8
Type recType
Id uint16
ContentLength uint16
PaddingLength uint8
@ -85,7 +88,7 @@ func (br *beginRequest) read(content []byte) error {
// not synchronized because we don't care what the contents are
var pad [maxPad]byte
func (h *header) init(recType uint8, reqId uint16, contentLength int) {
func (h *header) init(recType recType, reqId uint16, contentLength int) {
h.Version = 1
h.Type = recType
h.Id = reqId
@ -137,7 +140,7 @@ func (r *record) content() []byte {
}
// writeRecord writes and sends a single record.
func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) error {
func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.buf.Reset()
@ -167,12 +170,12 @@ func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8
return c.writeRecord(typeEndRequest, reqId, b)
}
func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) error {
func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
w := newWriter(c, recType, reqId)
b := make([]byte, 8)
for k, v := range pairs {
n := encodeSize(b, uint32(len(k)))
n += encodeSize(b[n:], uint32(len(k)))
n += encodeSize(b[n:], uint32(len(v)))
if _, err := w.Write(b[:n]); err != nil {
return err
}
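
Note: writePairs relies on encodeSize for the FastCGI name-value length encoding: lengths under 128 take a single byte, larger ones take four bytes with the top bit set. A sketch of that helper as assumed here (it follows the FastCGI spec, not necessarily this file verbatim):

package main

import (
	"encoding/binary"
	"fmt"
)

// encodeSize writes size into b using the FastCGI length encoding and
// returns how many bytes were used: 1 byte for sizes < 128, otherwise
// 4 bytes big-endian with the high bit of the first byte set.
func encodeSize(b []byte, size uint32) int {
	if size > 127 {
		size |= 1 << 31
		binary.BigEndian.PutUint32(b, size)
		return 4
	}
	b[0] = byte(size)
	return 1
}

func main() {
	b := make([]byte, 8)
	fmt.Println(encodeSize(b, 15), b[:1])  // 1 [15]
	fmt.Println(encodeSize(b, 300), b[:4]) // 4 [128 0 1 44]
}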
@ -235,7 +238,7 @@ func (w *bufWriter) Close() error {
return w.closer.Close()
}
func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
s := &streamWriter{c: c, recType: recType, reqId: reqId}
w, _ := bufio.NewWriterSize(s, maxWrite)
return &bufWriter{s, w}
@ -245,7 +248,7 @@ func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter {
// It only writes maxWrite bytes at a time.
type streamWriter struct {
c *conn
recType uint8
recType recType
reqId uint16
}

View File

@ -6,6 +6,7 @@ package fcgi
import (
"bytes"
"errors"
"io"
"testing"
)
@ -40,25 +41,25 @@ func TestSize(t *testing.T) {
var streamTests = []struct {
desc string
recType uint8
recType recType
reqId uint16
content []byte
raw []byte
}{
{"single record", typeStdout, 1, nil,
[]byte{1, typeStdout, 0, 1, 0, 0, 0, 0},
[]byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
},
// this data will have to be split into two records
{"two records", typeStdin, 300, make([]byte, 66000),
bytes.Join([][]byte{
// header for the first record
{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
{1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
make([]byte, 65536),
// header for the second
{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0},
{1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
make([]byte, 472),
// header for the empty record
{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0},
{1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
},
nil),
},
@ -111,3 +112,39 @@ outer:
}
}
}
type writeOnlyConn struct {
buf []byte
}
func (c *writeOnlyConn) Write(p []byte) (int, error) {
c.buf = append(c.buf, p...)
return len(p), nil
}
func (c *writeOnlyConn) Read(p []byte) (int, error) {
return 0, errors.New("conn is write-only")
}
func (c *writeOnlyConn) Close() error {
return nil
}
func TestGetValues(t *testing.T) {
var rec record
rec.h.Type = typeGetValues
wc := new(writeOnlyConn)
c := newChild(wc, nil)
err := c.handleRecord(&rec)
if err != nil {
t.Fatalf("handleRecord: %v", err)
}
const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
"\x0f\x01FCGI_MPXS_CONNS1" +
"\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
if got := string(wc.buf); got != want {
t.Errorf(" got: %q\nwant: %q\n", got, want)
}
}
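For readers checking the expected bytes by hand, they decode as two records of type typeGetValuesResult. A small standalone decode of the first header, assuming the 8-byte record layout from the FastCGI spec (an illustration, not code from this commit):

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	want := "\x01\n\x00\x00\x00\x12\x06\x00" +
		"\x0f\x01FCGI_MPXS_CONNS1" +
		"\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
	h := []byte(want[:8])
	fmt.Println("version:       ", h[0])                            // 1
	fmt.Println("type:          ", h[1])                            // 10 = typeGetValuesResult
	fmt.Println("request id:    ", binary.BigEndian.Uint16(h[2:4])) // 0 (management record)
	fmt.Println("content length:", binary.BigEndian.Uint16(h[4:6])) // 18 = 2 length bytes + "FCGI_MPXS_CONNS" + "1"
	fmt.Println("padding:       ", h[6])                            // 6, rounding the content up to a multiple of 8
	// The trailing 8 bytes are an empty record of the same type, which
	// terminates the result stream.
}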

View File

@ -22,13 +22,19 @@ import (
// A Dir implements http.FileSystem using the native file
// system restricted to a specific directory tree.
//
// An empty Dir is treated as ".".
type Dir string
func (d Dir) Open(name string) (File, error) {
if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 {
return nil, errors.New("http: invalid character in file path")
}
f, err := os.Open(filepath.Join(string(d), filepath.FromSlash(path.Clean("/"+name))))
dir := string(d)
if dir == "" {
dir = "."
}
f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
if err != nil {
return nil, err
}
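With the empty string now mapping to ".", a file server can be rooted at the process's working directory without spelling the dot out. A minimal usage sketch; the handler registration and port are illustrative, not part of this change:

package main

import (
	"log"
	"net/http"
)

func main() {
	// http.Dir("") is now equivalent to http.Dir("."): serve files from
	// the current working directory.
	http.Handle("/", http.FileServer(http.Dir("")))
	log.Fatal(http.ListenAndServe(":8080", nil))
}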

View File

@ -208,6 +208,20 @@ func TestDirJoin(t *testing.T) {
test(Dir("/etc/hosts"), "../")
}
func TestEmptyDirOpenCWD(t *testing.T) {
test := func(d Dir) {
name := "fs_test.go"
f, err := d.Open(name)
if err != nil {
t.Fatalf("open of %s: %v", name, err)
}
defer f.Close()
}
test(Dir(""))
test(Dir("."))
test(Dir("./"))
}
func TestServeFileContentType(t *testing.T) {
const ctype = "icecream/chocolate"
override := false
@ -247,6 +261,20 @@ func TestServeFileMimeType(t *testing.T) {
}
}
func TestServeFileFromCWD(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "fs_test.go")
}))
defer ts.Close()
r, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if r.StatusCode != 200 {
t.Fatalf("expected 200 OK, got %s", r.Status)
}
}
func TestServeFileWithContentEncoding(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "foo")

View File

@ -2,18 +2,126 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
// This code is a duplicate of ../chunked.go with these edits:
// s/newChunked/NewChunked/g
// s/package http/package httputil/
// Please make any changes in both files.
package httputil
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"strconv"
"strings"
)
// NewChunkedWriter returns a new writer that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned writer
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func NewChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
}
return &chunkedReader{r: br}
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
p = bytes.TrimRight(p, " \r\t\n")
return p, nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream.
//
// NewChunkedWriter is not needed by normal applications. The http
@ -25,8 +133,8 @@ func NewChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w}
}
// Writing to ChunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire writer.
// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
// Encoding wire format to the underlying Wire writer.
type chunkedWriter struct {
Wire io.Writer
}
@ -62,23 +170,3 @@ func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
// NewChunkedReader returns a new reader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The reader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func NewChunkedReader(r io.Reader) io.Reader {
// This is a bit of a hack so we don't have to copy chunkedReader into
// httputil. It's a bit more complex than chunkedWriter, which is copied
// above.
req, err := http.ReadRequest(bufio.NewReader(io.MultiReader(
strings.NewReader("POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"),
r,
strings.NewReader("\r\n"))))
if err != nil {
panic("bad fake request: " + err.Error())
}
return req.Body
}
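With chunkedReader copied into the package, httputil no longer needs the fake-request hack deleted above, and the exported pair can round-trip data directly. A short sketch of how the two halves fit together (illustrative, not part of this commit):

package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http/httputil"
)

func main() {
	var wire bytes.Buffer

	// Each Write becomes one chunk; Close appends the terminating
	// 0-length chunk.
	w := httputil.NewChunkedWriter(&wire)
	fmt.Fprint(w, "hello, ")
	fmt.Fprint(w, "chunked world")
	w.Close()

	// The reader strips the framing again and reports io.EOF once the
	// 0-length chunk is seen.
	body, err := ioutil.ReadAll(httputil.NewChunkedReader(&wire))
	fmt.Printf("%q %v\n", body, err) // "hello, chunked world" <nil>
}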

View File

@ -2,6 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is a duplicate of ../chunked_test.go with these edits:
// s/newChunked/NewChunked/g
// s/package http/package httputil/
// Please make any changes in both files.
package httputil
import (
@ -27,7 +32,8 @@ func TestChunk(t *testing.T) {
r := NewChunkedReader(&b)
data, err := ioutil.ReadAll(r)
if err != nil {
t.Fatalf("ReadAll from NewChunkedReader: %v", err)
t.Logf(`data: "%s"`, data)
t.Fatalf("ReadAll from reader: %v", err)
}
if g, e := string(data), chunk1+chunk2; g != e {
t.Errorf("chunk reader read %q; want %q", g, e)

View File

@ -22,6 +22,10 @@ var (
ErrPipeline = &http.ProtocolError{"pipeline error"}
)
// This is an API usage error - the local side is closed.
// ErrPersistEOF (above) reports that the remote side is closed.
var errClosed = errors.New("i/o operation on closed connection")
// A ServerConn reads requests and sends responses over an underlying
// connection, until the HTTP keepalive logic commands an end. ServerConn
// also allows hijacking the underlying connection by calling Hijack
@ -108,7 +112,7 @@ func (sc *ServerConn) Read() (req *http.Request, err error) {
}
if sc.r == nil { // connection closed by user in the meantime
defer sc.lk.Unlock()
return nil, os.EBADF
return nil, errClosed
}
r := sc.r
lastbody := sc.lastbody
@ -313,7 +317,7 @@ func (cc *ClientConn) Write(req *http.Request) (err error) {
}
if cc.c == nil { // connection closed by user in the meantime
defer cc.lk.Unlock()
return os.EBADF
return errClosed
}
c := cc.c
if req.Close {
@ -369,7 +373,7 @@ func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
}
if cc.r == nil { // connection closed by user in the meantime
defer cc.lk.Unlock()
return nil, os.EBADF
return nil, errClosed
}
r := cc.r
lastbody := cc.lastbody
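The distinction matters to callers: ErrPersistEOF signals that the remote side ended the keepalive exchange and is expected in normal use, while the error introduced above (replacing os.EBADF) only appears when the caller has already closed the connection itself. A hedged usage sketch, assuming the exported ClientConn API (NewClientConn, Do, ErrPersistEOF); the host and request are placeholders:

package main

import (
	"log"
	"net"
	"net/http"
	"net/http/httputil"
)

func main() {
	conn, err := net.Dial("tcp", "example.com:80") // placeholder address
	if err != nil {
		log.Fatal(err)
	}
	cc := httputil.NewClientConn(conn, nil)
	defer cc.Close()

	req, _ := http.NewRequest("GET", "http://example.com/", nil)
	resp, err := cc.Do(req)
	if err != nil && err != httputil.ErrPersistEOF {
		// The unexported "i/o operation on closed connection" error would
		// land here, but only if we had already called cc.Close ourselves.
		log.Fatal(err)
	}
	if resp != nil {
		log.Println(resp.Status)
		resp.Body.Close()
	}
}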

View File

@ -70,7 +70,6 @@ var reqTests = []reqTest{
Close: false,
ContentLength: 7,
Host: "www.techcrunch.com",
Form: url.Values{},
},
"abcdef\n",
@ -94,10 +93,10 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{},
Close: false,
ContentLength: 0,
Host: "foo.com",
Form: url.Values{},
},
noBody,
@ -131,7 +130,6 @@ var reqTests = []reqTest{
Close: false,
ContentLength: 0,
Host: "test",
Form: url.Values{},
},
noBody,
@ -180,9 +178,9 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{},
ContentLength: -1,
Host: "foo.com",
Form: url.Values{},
},
"foobar",

View File

@ -19,12 +19,10 @@ import (
"mime/multipart"
"net/textproto"
"net/url"
"strconv"
"strings"
)
const (
maxLineLength = 4096 // assumed <= bufio.defaultBufSize
maxValueLength = 4096
maxHeaderLines = 1024
chunkSize = 4 << 10 // 4 KB chunks
@ -43,7 +41,6 @@ type ProtocolError struct {
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrLineTooLong = &ProtocolError{"header line too long"}
ErrHeaderTooLong = &ProtocolError{"header too long"}
ErrShortBody = &ProtocolError{"entity body too short"}
ErrNotSupported = &ProtocolError{"feature not supported"}
@ -375,44 +372,6 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
return nil
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are a pointer into storage in
// the bufio, so they are only valid until the next bufio read.
func readLineBytes(b *bufio.Reader) (p []byte, err error) {
if p, err = b.ReadSlice('\n'); err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
// Chop off trailing white space.
var i int
for i = len(p); i > 0; i-- {
if c := p[i-1]; c != ' ' && c != '\r' && c != '\t' && c != '\n' {
break
}
}
return p[0:i], nil
}
// readLineBytes, but convert the bytes into a string.
func readLine(b *bufio.Reader) (s string, err error) {
p, e := readLineBytes(b)
if e != nil {
return "", e
}
return string(p), nil
}
// Convert decimal at s[i:len(s)] to integer,
// returning value, string position where the digits stopped,
// and whether there was a valid number (digits, not too big).
@ -448,55 +407,6 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
return major, minor, true
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line string
line, cr.err = readLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = strconv.Btoui64(line, 16)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
if cr.err != nil {
return 0, cr.err
}
if cr.n == 0 {
cr.beginChunk()
if cr.err != nil {
return 0, cr.err
}
}
if uint64(len(b)) > cr.n {
b = b[0:cr.n]
}
n, cr.err = cr.r.Read(b)
cr.n -= uint64(n)
if cr.n == 0 && cr.err == nil {
// end of chunk (CRLF)
b := make([]byte, 2)
if _, cr.err = io.ReadFull(cr.r, b); cr.err == nil {
if b[0] != '\r' || b[1] != '\n' {
cr.err = errors.New("malformed chunked encoding")
}
}
}
return n, cr.err
}
// NewRequest returns a new Request given a method, URL, and optional body.
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
u, err := url.Parse(urlStr)

View File

@ -65,6 +65,7 @@ var respTests = []respTest{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{},
Request: dummyReq("GET"),
Close: true,
ContentLength: -1,
@ -85,6 +86,7 @@ var respTests = []respTest{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: Header{},
Request: dummyReq("GET"),
Close: false,
ContentLength: 0,
@ -315,7 +317,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
}
var wr io.Writer = &buf
if test.chunked {
wr = &chunkedWriter{wr}
wr = newChunkedWriter(wr)
}
if test.compressed {
buf.WriteString("Content-Encoding: gzip\r\n")

View File

@ -1077,6 +1077,31 @@ func TestClientWriteShutdown(t *testing.T) {
}
}
// Tests that chunked server responses that write 1 byte at a time are
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
if true {
t.Logf("Skipping known broken test; see Issue 2357")
return
}
conn := new(testConn)
conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
done := make(chan bool)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
defer close(done)
rw.Header().Set("Content-Type", "text/plain") // prevent sniffing, which buffers
rw.Write([]byte{'x'})
rw.Write([]byte{'y'})
rw.Write([]byte{'z'})
}))
<-done
if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
conn.writeBuf.Bytes())
}
}
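The suffix the test looks for is the framing of one buffered 3-byte chunk; three unbuffered 1-byte writes would instead produce three chunks. A sketch of the two wire forms, using the httputil writer purely for illustration (the server additionally ends the message with one extra CRLF after the 0 chunk, as the expected suffix shows):

package main

import (
	"bytes"
	"fmt"
	"net/http/httputil"
)

func main() {
	// Buffered: a single Write of "xyz" yields one 3-byte chunk.
	var one bytes.Buffer
	w := httputil.NewChunkedWriter(&one)
	w.Write([]byte("xyz"))
	w.Close()
	fmt.Printf("%q\n", one.String()) // "3\r\nxyz\r\n0\r\n"

	// Unbuffered: three 1-byte Writes yield three chunks, the framing the
	// (currently skipped) test is meant to rule out.
	var three bytes.Buffer
	w = httputil.NewChunkedWriter(&three)
	for _, c := range []byte("xyz") {
		w.Write([]byte{c})
	}
	w.Close()
	fmt.Printf("%q\n", three.String()) // "1\r\nx\r\n1\r\ny\r\n1\r\nz\r\n0\r\n"
}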
// goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, ns int64, f func()) {
ch := make(chan bool, 2)
@ -1120,7 +1145,7 @@ func TestAcceptMaxFds(t *testing.T) {
ln := &errorListener{[]error{
&net.OpError{
Op: "accept",
Err: os.Errno(syscall.EMFILE),
Err: syscall.EMFILE,
}}}
err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
if err != io.EOF {

View File

@ -149,11 +149,13 @@ type writerOnly struct {
}
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
// Flush before checking w.chunking, as Flush will call
// WriteHeader if it hasn't been called yet, and WriteHeader
// is what sets w.chunking.
w.Flush()
// Call WriteHeader before checking w.chunking if it hasn't
// been called yet, since WriteHeader is what sets w.chunking.
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
if !w.chunking && w.bodyAllowed() && !w.needSniff {
w.Flush()
if rf, ok := w.conn.rwc.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(src)
w.written += n

View File

@ -6,6 +6,7 @@ package http_test
import (
"bytes"
"io"
"io/ioutil"
"log"
. "net/http"
@ -79,3 +80,35 @@ func TestServerContentType(t *testing.T) {
resp.Body.Close()
}
}
func TestContentTypeWithCopy(t *testing.T) {
const (
input = "\n<html>\n\t<head>\n"
expected = "text/html; charset=utf-8"
)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// Use io.Copy from a bytes.Buffer to trigger ReadFrom.
buf := bytes.NewBuffer([]byte(input))
n, err := io.Copy(w, buf)
if int(n) != len(input) || err != nil {
t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
}
}))
defer ts.Close()
resp, err := Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
if ct := resp.Header.Get("Content-Type"); ct != expected {
t.Errorf("Content-Type = %q, want %q", ct, expected)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("reading body: %v", err)
} else if !bytes.Equal(data, []byte(input)) {
t.Errorf("data is %q, want %q", data, input)
}
resp.Body.Close()
}

View File

@ -537,7 +537,9 @@ func (b *body) Read(p []byte) (n int, err error) {
// Read the final trailer once we hit EOF.
if err == io.EOF && b.hdr != nil {
err = b.readTrailer()
if e := b.readTrailer(); e != nil {
err = e
}
b.hdr = nil
}
return n, err

Some files were not shown because too many files have changed in this diff.