Update Go library to r60.

From-SVN: r178910
This commit is contained in:
Ian Lance Taylor 2011-09-16 15:47:21 +00:00
parent 5548ca3540
commit adb0401dac
718 changed files with 58911 additions and 30469 deletions

View File

@ -806,7 +806,7 @@ proc go-gc-tests { } {
$status $name
} else {
verbose -log $comp_output
fali $name
fail $name
}
file delete $ofile1 $ofile2 $output_file
set runtests $hold_runtests

View File

@ -37,7 +37,7 @@ func main() {
}
fmt.Fprintln(out, `}`)
}
do(recv)
do(send)
do(recvOrder)
@ -54,8 +54,8 @@ func run(t *template.Template, a interface{}, out io.Writer) {
}
}
type arg struct{
def bool
type arg struct {
def bool
nreset int
}
@ -135,181 +135,180 @@ func main() {
}
`
func parse(s string) *template.Template {
t := template.New(nil)
t.SetDelims("〈", "〉")
if err := t.Parse(s); err != nil {
panic(s)
func parse(name, s string) *template.Template {
t, err := template.New(name).Parse(s)
if err != nil {
panic(fmt.Sprintf("%q: %s", name, err))
}
return t
}
var recv = parse(`
# Send n, receive it one way or another into x, check that they match.
var recv = parse("recv", `
{{/* Send n, receive it one way or another into x, check that they match. */}}
c <- n
.section Maybe
{{if .Maybe}}
x = <-c
.or
{{else}}
select {
# Blocking or non-blocking, before the receive.
# The compiler implements two-case select where one is default with custom code,
# so test the default branch both before and after the send.
.section MaybeDefault
{{/* Blocking or non-blocking, before the receive. */}}
{{/* The compiler implements two-case select where one is default with custom code, */}}
{{/* so test the default branch both before and after the send. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment.
.section Maybe
{{end}}
{{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
{{if .Maybe}}
case x = <-c:
.or.section Maybe
{{else}}{{if .Maybe}}
case *f(&x) = <-c:
.or.section Maybe
{{else}}{{if .Maybe}}
case y := <-c:
x = y
.or.section Maybe
{{else}}{{if .Maybe}}
case i = <-c:
x = i.(int)
.or
{{else}}
case m[13] = <-c:
x = m[13]
.end.end.end.end
# Blocking or non-blocking again, after the receive.
.section MaybeDefault
{{end}}{{end}}{{end}}{{end}}
{{/* Blocking or non-blocking again, after the receive. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Dummy send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case dummy <- 1:
panic("dummy send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-dummy:
panic("dummy receive")
.end
# Nil channel send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case nilch <- 1:
panic("nilch send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-nilch:
panic("nilch recv")
.end
{{end}}
}
.end
{{end}}
if x != n {
die(x)
}
n++
`)
var recvOrder = parse(`
# Send n, receive it one way or another into x, check that they match.
# Check order of operations along the way by calling functions that check
# that the argument sequence is strictly increasing.
var recvOrder = parse("recvOrder", `
{{/* Send n, receive it one way or another into x, check that they match. */}}
{{/* Check order of operations along the way by calling functions that check */}}
{{/* that the argument sequence is strictly increasing. */}}
order = 0
c <- n
.section Maybe
# Outside of select, left-to-right rule applies.
# (Inside select, assignment waits until case is chosen,
# so right hand side happens before anything on left hand side.
{{if .Maybe}}
{{/* Outside of select, left-to-right rule applies. */}}
{{/* (Inside select, assignment waits until case is chosen, */}}
{{/* so right hand side happens before anything on left hand side. */}}
*fp(&x, 1) = <-fc(c, 2)
.or.section Maybe
{{else}}{{if .Maybe}}
m[fn(13, 1)] = <-fc(c, 2)
x = m[13]
.or
{{else}}
select {
# Blocking or non-blocking, before the receive.
# The compiler implements two-case select where one is default with custom code,
# so test the default branch both before and after the send.
.section MaybeDefault
{{/* Blocking or non-blocking, before the receive. */}}
{{/* The compiler implements two-case select where one is default with custom code, */}}
{{/* so test the default branch both before and after the send. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Receive from c. Different cases are direct, indirect, :=, interface, and map assignment.
.section Maybe
{{end}}
{{/* Receive from c. Different cases are direct, indirect, :=, interface, and map assignment. */}}
{{if .Maybe}}
case *fp(&x, 100) = <-fc(c, 1):
.or.section Maybe
{{else}}{{if .Maybe}}
case y := <-fc(c, 1):
x = y
.or.section Maybe
{{else}}{{if .Maybe}}
case i = <-fc(c, 1):
x = i.(int)
.or
{{else}}
case m[fn(13, 100)] = <-fc(c, 1):
x = m[13]
.end.end.end
# Blocking or non-blocking again, after the receive.
.section MaybeDefault
{{end}}{{end}}{{end}}
{{/* Blocking or non-blocking again, after the receive. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Dummy send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case fc(dummy, 2) <- fn(1, 3):
panic("dummy send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-fc(dummy, 4):
panic("dummy receive")
.end
# Nil channel send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case fc(nilch, 5) <- fn(1, 6):
panic("nilch send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-fc(nilch, 7):
panic("nilch recv")
.end
{{end}}
}
.end.end
{{end}}{{end}}
if x != n {
die(x)
}
n++
`)
var send = parse(`
# Send n one way or another, receive it into x, check that they match.
.section Maybe
var send = parse("send", `
{{/* Send n one way or another, receive it into x, check that they match. */}}
{{if .Maybe}}
c <- n
.or
{{else}}
select {
# Blocking or non-blocking, before the receive (same reason as in recv).
.section MaybeDefault
{{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Send c <- n. No real special cases here, because no values come back
# from the send operation.
{{end}}
{{/* Send c <- n. No real special cases here, because no values come back */}}
{{/* from the send operation. */}}
case c <- n:
# Blocking or non-blocking.
.section MaybeDefault
{{/* Blocking or non-blocking. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Dummy send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case dummy <- 1:
panic("dummy send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-dummy:
panic("dummy receive")
.end
# Nil channel send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case nilch <- 1:
panic("nilch send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-nilch:
panic("nilch recv")
.end
{{end}}
}
.end
{{end}}
x = <-c
if x != n {
die(x)
@ -317,48 +316,48 @@ var send = parse(`
n++
`)
var sendOrder = parse(`
# Send n one way or another, receive it into x, check that they match.
# Check order of operations along the way by calling functions that check
# that the argument sequence is strictly increasing.
var sendOrder = parse("sendOrder", `
{{/* Send n one way or another, receive it into x, check that they match. */}}
{{/* Check order of operations along the way by calling functions that check */}}
{{/* that the argument sequence is strictly increasing. */}}
order = 0
.section Maybe
{{if .Maybe}}
fc(c, 1) <- fn(n, 2)
.or
{{else}}
select {
# Blocking or non-blocking, before the receive (same reason as in recv).
.section MaybeDefault
{{/* Blocking or non-blocking, before the receive (same reason as in recv). */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Send c <- n. No real special cases here, because no values come back
# from the send operation.
{{end}}
{{/* Send c <- n. No real special cases here, because no values come back */}}
{{/* from the send operation. */}}
case fc(c, 1) <- fn(n, 2):
# Blocking or non-blocking.
.section MaybeDefault
{{/* Blocking or non-blocking. */}}
{{if .MaybeDefault}}
default:
panic("nonblock")
.end
# Dummy send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Dummy send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case fc(dummy, 3) <- fn(1, 4):
panic("dummy send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-fc(dummy, 5):
panic("dummy receive")
.end
# Nil channel send, receive to keep compiler from optimizing select.
.section Maybe
{{end}}
{{/* Nil channel send, receive to keep compiler from optimizing select. */}}
{{if .Maybe}}
case fc(nilch, 6) <- fn(1, 7):
panic("nilch send")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-fc(nilch, 8):
panic("nilch recv")
.end
{{end}}
}
.end
{{end}}
x = <-c
if x != n {
die(x)
@ -366,49 +365,49 @@ var sendOrder = parse(`
n++
`)
var nonblock = parse(`
var nonblock = parse("nonblock", `
x = n
# Test various combinations of non-blocking operations.
# Receive assignments must not edit or even attempt to compute the address of the lhs.
{{/* Test various combinations of non-blocking operations. */}}
{{/* Receive assignments must not edit or even attempt to compute the address of the lhs. */}}
select {
.section MaybeDefault
{{if .MaybeDefault}}
default:
.end
.section Maybe
{{end}}
{{if .Maybe}}
case dummy <- 1:
panic("dummy <- 1")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case nilch <- 1:
panic("nilch <- 1")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-dummy:
panic("<-dummy")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case x = <-dummy:
panic("<-dummy x")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case **(**int)(nil) = <-dummy:
panic("<-dummy (and didn't crash saving result!)")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case <-nilch:
panic("<-nilch")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case x = <-nilch:
panic("<-nilch x")
.end
.section Maybe
{{end}}
{{if .Maybe}}
case **(**int)(nil) = <-nilch:
panic("<-nilch (and didn't crash saving result!)")
.end
.section MustDefault
{{end}}
{{if .MustDefault}}
default:
.end
{{end}}
}
if x != n {
die(x)
@ -466,7 +465,7 @@ func next() bool {
}
// increment last choice sequence
cp = len(choices)-1
cp = len(choices) - 1
for cp >= 0 && choices[cp].i == choices[cp].n-1 {
cp--
}
@ -479,4 +478,3 @@ func next() bool {
cp = 0
return true
}

View File

@ -38,7 +38,7 @@ func Listen(x, y string) (T, string) {
}
func (t T) Addr() os.Error {
return os.ErrorString("stringer")
return os.NewError("stringer")
}
func (t T) Accept() (int, string) {
@ -49,4 +49,3 @@ func Dial(x, y, z string) (int, string) {
global <- 1
return 0, ""
}

View File

@ -18,6 +18,7 @@ var chatty = flag.Bool("v", false, "chatty")
var oldsys uint64
func bigger() {
runtime.UpdateMemStats()
if st := runtime.MemStats; oldsys < st.Sys {
oldsys = st.Sys
if *chatty {
@ -31,7 +32,7 @@ func bigger() {
}
func main() {
runtime.GC() // clean up garbage from init
runtime.GC() // clean up garbage from init
runtime.MemProfileRate = 0 // disable profiler
runtime.MemStats.Alloc = 0 // ignore stacks
flag.Parse()
@ -45,8 +46,10 @@ func main() {
panic("fail")
}
b := runtime.Alloc(uintptr(j))
runtime.UpdateMemStats()
during := runtime.MemStats.Alloc
runtime.Free(b)
runtime.UpdateMemStats()
if a := runtime.MemStats.Alloc; a != 0 {
println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)")
panic("fail")

View File

@ -42,6 +42,7 @@ func AllocAndFree(size, count int) {
if *chatty {
fmt.Printf("size=%d count=%d ...\n", size, count)
}
runtime.UpdateMemStats()
n1 := stats.Alloc
for i := 0; i < count; i++ {
b[i] = runtime.Alloc(uintptr(size))
@ -50,11 +51,13 @@ func AllocAndFree(size, count int) {
println("lookup failed: got", base, n, "for", b[i])
panic("fail")
}
if runtime.MemStats.Sys > 1e9 {
runtime.UpdateMemStats()
if stats.Sys > 1e9 {
println("too much memory allocated")
panic("fail")
}
}
runtime.UpdateMemStats()
n2 := stats.Alloc
if *chatty {
fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats)
@ -72,6 +75,7 @@ func AllocAndFree(size, count int) {
panic("fail")
}
runtime.Free(b[i])
runtime.UpdateMemStats()
if stats.Alloc != uint64(alloc-n) {
println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n)
panic("fail")
@ -81,6 +85,7 @@ func AllocAndFree(size, count int) {
panic("fail")
}
}
runtime.UpdateMemStats()
n4 := stats.Alloc
if *chatty {

View File

@ -1,4 +1,4 @@
aea0ba6e5935
504f4e9b079c
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -12,12 +12,24 @@
/* Define to 1 if you have the <inttypes.h> header file. */
#undef HAVE_INTTYPES_H
/* Define to 1 if you have the <linux/filter.h> header file. */
#undef HAVE_LINUX_FILTER_H
/* Define to 1 if you have the <linux/netlink.h> header file. */
#undef HAVE_LINUX_NETLINK_H
/* Define to 1 if you have the <linux/rtnetlink.h> header file. */
#undef HAVE_LINUX_RTNETLINK_H
/* Define to 1 if you have the <memory.h> header file. */
#undef HAVE_MEMORY_H
/* Define to 1 if you have the `mincore' function. */
#undef HAVE_MINCORE
/* Define to 1 if you have the <net/if.h> header file. */
#undef HAVE_NET_IF_H
/* Define to 1 if the system has the type `off64_t'. */
#undef HAVE_OFF64_T
@ -71,6 +83,9 @@
/* Define to 1 if you have the <sys/select.h> header file. */
#undef HAVE_SYS_SELECT_H
/* Define to 1 if you have the <sys/socket.h> header file. */
#undef HAVE_SYS_SOCKET_H
/* Define to 1 if you have the <sys/stat.h> header file. */
#undef HAVE_SYS_STAT_H

33
libgo/configure vendored
View File

@ -617,7 +617,6 @@ USING_SPLIT_STACK_FALSE
USING_SPLIT_STACK_TRUE
SPLIT_STACK
OSCFLAGS
GO_DEBUG_PROC_REGS_OS_ARCH_FILE
GO_SYSCALLS_SYSCALL_OS_ARCH_FILE
GOARCH
LIBGO_IS_X86_64_FALSE
@ -10914,7 +10913,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF
#line 10917 "configure"
#line 10916 "configure"
#include "confdefs.h"
#if HAVE_DLFCN_H
@ -11020,7 +11019,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF
#line 11023 "configure"
#line 11022 "configure"
#include "confdefs.h"
#if HAVE_DLFCN_H
@ -13558,12 +13557,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
case "$target" in
mips-sgi-irix6.5*)
# IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct
@ -14252,7 +14245,7 @@ no)
;;
esac
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
@ -14266,6 +14259,26 @@ fi
done
for ac_header in linux/filter.h linux/netlink.h linux/rtnetlink.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
"
eval as_val=\$$as_ac_Header
if test "x$as_val" = x""yes; then :
cat >>confdefs.h <<_ACEOF
#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1
_ACEOF
fi
done
if test "$ac_cv_header_sys_mman_h" = yes; then
HAVE_SYS_MMAN_H_TRUE=
HAVE_SYS_MMAN_H_FALSE='#'

View File

@ -255,12 +255,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi
AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE)
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
AC_SUBST(GO_DEBUG_PROC_REGS_OS_ARCH_FILE)
dnl Some targets need special flags to build sysinfo.go.
case "$target" in
mips-sgi-irix6.5*)
@ -431,7 +425,14 @@ no)
;;
esac
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h)
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
])
AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes)
AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv)

View File

@ -16,7 +16,7 @@ import (
)
var (
HeaderError os.Error = os.ErrorString("invalid tar header")
HeaderError = os.NewError("invalid tar header")
)
// A Reader provides sequential access to the contents of a tar archive.

View File

@ -178,7 +178,6 @@ func TestPartialRead(t *testing.T) {
}
}
func TestIncrementalRead(t *testing.T) {
test := gnuTarTest
f, err := os.Open(test.file)

View File

@ -2,18 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
This package does not support ZIP64 or disk spanning.
*/
package zip
import (
"bufio"
"bytes"
"compress/flate"
"hash"
"hash/crc32"
@ -24,9 +16,9 @@ import (
)
var (
FormatError = os.NewError("not a valid zip file")
UnsupportedMethod = os.NewError("unsupported compression algorithm")
ChecksumError = os.NewError("checksum error")
FormatError = os.NewError("zip: not a valid zip file")
UnsupportedMethod = os.NewError("zip: unsupported compression algorithm")
ChecksumError = os.NewError("zip: checksum error")
)
type Reader struct {
@ -44,15 +36,14 @@ type File struct {
FileHeader
zipr io.ReaderAt
zipsize int64
headerOffset uint32
bodyOffset int64
headerOffset int64
}
func (f *File) hasDataDescriptor() bool {
return f.Flags&0x8 != 0
}
// OpenReader will open the Zip file specified by name and return a ReaderCloser.
// OpenReader will open the Zip file specified by name and return a ReadCloser.
func OpenReader(name string) (*ReadCloser, os.Error) {
f, err := os.Open(name)
if err != nil {
@ -87,18 +78,33 @@ func (z *Reader) init(r io.ReaderAt, size int64) os.Error {
return err
}
z.r = r
z.File = make([]*File, end.directoryRecords)
z.File = make([]*File, 0, end.directoryRecords)
z.Comment = end.comment
rs := io.NewSectionReader(r, 0, size)
if _, err = rs.Seek(int64(end.directoryOffset), os.SEEK_SET); err != nil {
return err
}
buf := bufio.NewReader(rs)
for i := range z.File {
z.File[i] = &File{zipr: r, zipsize: size}
if err := readDirectoryHeader(z.File[i], buf); err != nil {
// The count of files inside a zip is truncated to fit in a uint16.
// Gloss over this by reading headers until we encounter
// a bad one, and then only report a FormatError or UnexpectedEOF if
// the file count modulo 65536 is incorrect.
for {
f := &File{zipr: r, zipsize: size}
err = readDirectoryHeader(f, buf)
if err == FormatError || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return err
}
z.File = append(z.File, f)
}
if uint16(len(z.File)) != end.directoryRecords {
// Return the readDirectoryHeader error if we read
// the wrong number of directory entries.
return err
}
return nil
}
@ -109,31 +115,22 @@ func (rc *ReadCloser) Close() os.Error {
}
// Open returns a ReadCloser that provides access to the File's contents.
// It is safe to Open and Read from files concurrently.
func (f *File) Open() (rc io.ReadCloser, err os.Error) {
off := int64(f.headerOffset)
if f.bodyOffset == 0 {
r := io.NewSectionReader(f.zipr, off, f.zipsize-off)
if err = readFileHeader(f, r); err != nil {
return
}
if f.bodyOffset, err = r.Seek(0, os.SEEK_CUR); err != nil {
return
}
bodyOffset, err := f.findBodyOffset()
if err != nil {
return
}
size := int64(f.CompressedSize)
if f.hasDataDescriptor() {
if size == 0 {
// permit SectionReader to see the rest of the file
size = f.zipsize - (off + f.bodyOffset)
} else {
size += dataDescriptorLen
}
if size == 0 && f.hasDataDescriptor() {
// permit SectionReader to see the rest of the file
size = f.zipsize - (f.headerOffset + bodyOffset)
}
r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
switch f.Method {
case 0: // store (no compression)
case Store: // (no compression)
rc = ioutil.NopCloser(r)
case 8: // DEFLATE
case Deflate:
rc = flate.NewReader(r)
default:
err = UnsupportedMethod
@ -170,90 +167,102 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
func (r *checksumReader) Close() os.Error { return r.rc.Close() }
func readFileHeader(f *File, r io.Reader) (err os.Error) {
defer func() {
if rerr, ok := recover().(os.Error); ok {
err = rerr
}
}()
var (
signature uint32
filenameLength uint16
extraLength uint16
)
read(r, &signature)
if signature != fileHeaderSignature {
func readFileHeader(f *File, r io.Reader) os.Error {
var b [fileHeaderLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
c := binary.LittleEndian
if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
return FormatError
}
read(r, &f.ReaderVersion)
read(r, &f.Flags)
read(r, &f.Method)
read(r, &f.ModifiedTime)
read(r, &f.ModifiedDate)
read(r, &f.CRC32)
read(r, &f.CompressedSize)
read(r, &f.UncompressedSize)
read(r, &filenameLength)
read(r, &extraLength)
f.Name = string(readByteSlice(r, filenameLength))
f.Extra = readByteSlice(r, extraLength)
return
f.ReaderVersion = c.Uint16(b[4:6])
f.Flags = c.Uint16(b[6:8])
f.Method = c.Uint16(b[8:10])
f.ModifiedTime = c.Uint16(b[10:12])
f.ModifiedDate = c.Uint16(b[12:14])
f.CRC32 = c.Uint32(b[14:18])
f.CompressedSize = c.Uint32(b[18:22])
f.UncompressedSize = c.Uint32(b[22:26])
filenameLen := int(c.Uint16(b[26:28]))
extraLen := int(c.Uint16(b[28:30]))
d := make([]byte, filenameLen+extraLen)
if _, err := io.ReadFull(r, d); err != nil {
return err
}
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen:]
return nil
}
func readDirectoryHeader(f *File, r io.Reader) (err os.Error) {
defer func() {
if rerr, ok := recover().(os.Error); ok {
err = rerr
}
}()
var (
signature uint32
filenameLength uint16
extraLength uint16
commentLength uint16
startDiskNumber uint16 // unused
internalAttributes uint16 // unused
externalAttributes uint32 // unused
)
read(r, &signature)
if signature != directoryHeaderSignature {
// findBodyOffset does the minimum work to verify the file has a header
// and returns the file body offset.
func (f *File) findBodyOffset() (int64, os.Error) {
r := io.NewSectionReader(f.zipr, f.headerOffset, f.zipsize-f.headerOffset)
var b [fileHeaderLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return 0, err
}
c := binary.LittleEndian
if sig := c.Uint32(b[:4]); sig != fileHeaderSignature {
return 0, FormatError
}
filenameLen := int(c.Uint16(b[26:28]))
extraLen := int(c.Uint16(b[28:30]))
return int64(fileHeaderLen + filenameLen + extraLen), nil
}
// readDirectoryHeader attempts to read a directory header from r.
// It returns io.ErrUnexpectedEOF if it cannot read a complete header,
// and FormatError if it doesn't find a valid header signature.
func readDirectoryHeader(f *File, r io.Reader) os.Error {
var b [directoryHeaderLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
c := binary.LittleEndian
if sig := c.Uint32(b[:4]); sig != directoryHeaderSignature {
return FormatError
}
read(r, &f.CreatorVersion)
read(r, &f.ReaderVersion)
read(r, &f.Flags)
read(r, &f.Method)
read(r, &f.ModifiedTime)
read(r, &f.ModifiedDate)
read(r, &f.CRC32)
read(r, &f.CompressedSize)
read(r, &f.UncompressedSize)
read(r, &filenameLength)
read(r, &extraLength)
read(r, &commentLength)
read(r, &startDiskNumber)
read(r, &internalAttributes)
read(r, &externalAttributes)
read(r, &f.headerOffset)
f.Name = string(readByteSlice(r, filenameLength))
f.Extra = readByteSlice(r, extraLength)
f.Comment = string(readByteSlice(r, commentLength))
return
f.CreatorVersion = c.Uint16(b[4:6])
f.ReaderVersion = c.Uint16(b[6:8])
f.Flags = c.Uint16(b[8:10])
f.Method = c.Uint16(b[10:12])
f.ModifiedTime = c.Uint16(b[12:14])
f.ModifiedDate = c.Uint16(b[14:16])
f.CRC32 = c.Uint32(b[16:20])
f.CompressedSize = c.Uint32(b[20:24])
f.UncompressedSize = c.Uint32(b[24:28])
filenameLen := int(c.Uint16(b[28:30]))
extraLen := int(c.Uint16(b[30:32]))
commentLen := int(c.Uint16(b[32:34]))
// startDiskNumber := c.Uint16(b[34:36]) // Unused
// internalAttributes := c.Uint16(b[36:38]) // Unused
// externalAttributes := c.Uint32(b[38:42]) // Unused
f.headerOffset = int64(c.Uint32(b[42:46]))
d := make([]byte, filenameLen+extraLen+commentLen)
if _, err := io.ReadFull(r, d); err != nil {
return err
}
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen : filenameLen+extraLen]
f.Comment = string(d[filenameLen+extraLen:])
return nil
}
func readDataDescriptor(r io.Reader, f *File) (err os.Error) {
defer func() {
if rerr, ok := recover().(os.Error); ok {
err = rerr
}
}()
read(r, &f.CRC32)
read(r, &f.CompressedSize)
read(r, &f.UncompressedSize)
return
func readDataDescriptor(r io.Reader, f *File) os.Error {
var b [dataDescriptorLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
c := binary.LittleEndian
f.CRC32 = c.Uint32(b[:4])
f.CompressedSize = c.Uint32(b[4:8])
f.UncompressedSize = c.Uint32(b[8:12])
return nil
}
func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error) {
func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err os.Error) {
// look for directoryEndSignature in the last 1k, then in the last 65k
var b []byte
for i, bLen := range []int64{1024, 65 * 1024} {
@ -274,53 +283,29 @@ func readDirectoryEnd(r io.ReaderAt, size int64) (d *directoryEnd, err os.Error)
}
// read header into struct
defer func() {
if rerr, ok := recover().(os.Error); ok {
err = rerr
d = nil
}
}()
br := bytes.NewBuffer(b[4:]) // skip over signature
d = new(directoryEnd)
read(br, &d.diskNbr)
read(br, &d.dirDiskNbr)
read(br, &d.dirRecordsThisDisk)
read(br, &d.directoryRecords)
read(br, &d.directorySize)
read(br, &d.directoryOffset)
read(br, &d.commentLen)
d.comment = string(readByteSlice(br, d.commentLen))
c := binary.LittleEndian
d := new(directoryEnd)
d.diskNbr = c.Uint16(b[4:6])
d.dirDiskNbr = c.Uint16(b[6:8])
d.dirRecordsThisDisk = c.Uint16(b[8:10])
d.directoryRecords = c.Uint16(b[10:12])
d.directorySize = c.Uint32(b[12:16])
d.directoryOffset = c.Uint32(b[16:20])
d.commentLen = c.Uint16(b[20:22])
d.comment = string(b[22 : 22+int(d.commentLen)])
return d, nil
}
func findSignatureInBlock(b []byte) int {
const minSize = 4 + 2 + 2 + 2 + 2 + 4 + 4 + 2 // fixed part of header
for i := len(b) - minSize; i >= 0; i-- {
for i := len(b) - directoryEndLen; i >= 0; i-- {
// defined from directoryEndSignature in struct.go
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
// n is length of comment
n := int(b[i+minSize-2]) | int(b[i+minSize-1])<<8
if n+minSize+i == len(b) {
n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
if n+directoryEndLen+i == len(b) {
return i
}
}
}
return -1
}
func read(r io.Reader, data interface{}) {
if err := binary.Read(r, binary.LittleEndian, data); err != nil {
panic(err)
}
}
func readByteSlice(r io.Reader, l uint16) []byte {
b := make([]byte, l)
if l == 0 {
return b
}
if _, err := io.ReadFull(r, b); err != nil {
panic(err)
}
return b
}

View File

@ -11,6 +11,7 @@ import (
"io/ioutil"
"os"
"testing"
"time"
)
type ZipTest struct {
@ -24,8 +25,19 @@ type ZipTestFile struct {
Name string
Content []byte // if blank, will attempt to compare against File
File string // name of file to compare to (relative to testdata/)
Mtime string // modified time in format "mm-dd-yy hh:mm:ss"
}
// Caution: The Mtime values found for the test files should correspond to
// the values listed with unzip -l <zipfile>. However, the values
// listed by unzip appear to be off by some hours. When creating
// fresh test files and testing them, this issue is not present.
// The test files were created in Sydney, so there might be a time
// zone issue. The time zone information does have to be encoded
// somewhere, because otherwise unzip -l could not provide a different
// time from what the archive/zip package provides, but there appears
// to be no documentation about this.
var tests = []ZipTest{
{
Name: "test.zip",
@ -34,10 +46,12 @@ var tests = []ZipTest{
{
Name: "test.txt",
Content: []byte("This is a test text file.\n"),
Mtime: "09-05-10 12:12:02",
},
{
Name: "gophercolor16x16.png",
File: "gophercolor16x16.png",
Name: "gophercolor16x16.png",
File: "gophercolor16x16.png",
Mtime: "09-05-10 15:52:58",
},
},
},
@ -45,8 +59,9 @@ var tests = []ZipTest{
Name: "r.zip",
File: []ZipTestFile{
{
Name: "r/r.zip",
File: "r.zip",
Name: "r/r.zip",
File: "r.zip",
Mtime: "03-04-10 00:24:16",
},
},
},
@ -58,6 +73,7 @@ var tests = []ZipTest{
{
Name: "filename",
Content: []byte("This is a test textfile.\n"),
Mtime: "02-02-11 13:06:20",
},
},
},
@ -136,18 +152,36 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
if f.Name != ft.Name {
t.Errorf("name=%q, want %q", f.Name, ft.Name)
}
mtime, err := time.Parse("01-02-06 15:04:05", ft.Mtime)
if err != nil {
t.Error(err)
return
}
if got, want := f.Mtime_ns()/1e9, mtime.Seconds(); got != want {
t.Errorf("%s: mtime=%s (%d); want %s (%d)", f.Name, time.SecondsToUTC(got), got, mtime, want)
}
size0 := f.UncompressedSize
var b bytes.Buffer
r, err := f.Open()
if err != nil {
t.Error(err)
return
}
if size1 := f.UncompressedSize; size0 != size1 {
t.Errorf("file %q changed f.UncompressedSize from %d to %d", f.Name, size0, size1)
}
_, err = io.Copy(&b, r)
if err != nil {
t.Error(err)
return
}
r.Close()
var c []byte
if len(ft.Content) != 0 {
c = ft.Content
@ -155,10 +189,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
t.Error(err)
return
}
if b.Len() != len(c) {
t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c))
return
}
for i, b := range b.Bytes() {
if b != c[i] {
t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i])

View File

@ -1,9 +1,32 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading and writing ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
This package does not support ZIP64 or disk spanning.
*/
package zip
import "os"
import "time"
// Compression methods.
const (
Store uint16 = 0
Deflate uint16 = 8
)
const (
fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50
fileHeaderLen = 30 // + filename + extra
directoryHeaderLen = 46 // + filename + extra + comment
directoryEndLen = 22 // + comment
dataDescriptorLen = 12
)
@ -13,8 +36,8 @@ type FileHeader struct {
ReaderVersion uint16
Flags uint16
Method uint16
ModifiedTime uint16
ModifiedDate uint16
ModifiedTime uint16 // MS-DOS time
ModifiedDate uint16 // MS-DOS date
CRC32 uint32
CompressedSize uint32
UncompressedSize uint32
@ -32,3 +55,37 @@ type directoryEnd struct {
commentLen uint16
comment string
}
func recoverError(err *os.Error) {
if e := recover(); e != nil {
if osErr, ok := e.(os.Error); ok {
*err = osErr
return
}
panic(e)
}
}
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
// The resolution is 2s.
// See: http://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Time{
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
Year: int64(dosDate>>9 + 1980),
Month: int(dosDate >> 5 & 0xf),
Day: int(dosDate & 0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
Hour: int(dosTime >> 11),
Minute: int(dosTime >> 5 & 0x3f),
Second: int(dosTime & 0x1f * 2),
}
}
// Mtime_ns returns the modified time in ns since epoch.
// The resolution is 2s.
func (h *FileHeader) Mtime_ns() int64 {
t := msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
return t.Seconds() * 1e9
}

View File

@ -0,0 +1,244 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"compress/flate"
"encoding/binary"
"hash"
"hash/crc32"
"io"
"os"
)
// TODO(adg): support zip file comments
// TODO(adg): support specifying deflate level
// Writer implements a zip file writer.
type Writer struct {
*countWriter
dir []*header
last *fileWriter
closed bool
}
type header struct {
*FileHeader
offset uint32
}
// NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}}
}
// Close finishes writing the zip file by writing the central directory.
// It does not (and can not) close the underlying writer.
func (w *Writer) Close() (err os.Error) {
if w.last != nil && !w.last.closed {
if err = w.last.close(); err != nil {
return
}
w.last = nil
}
if w.closed {
return os.NewError("zip: writer closed twice")
}
w.closed = true
defer recoverError(&err)
// write central directory
start := w.count
for _, h := range w.dir {
write(w, uint32(directoryHeaderSignature))
write(w, h.CreatorVersion)
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
write(w, uint16(len(h.Comment)))
write(w, uint16(0)) // disk number start
write(w, uint16(0)) // internal file attributes
write(w, uint32(0)) // external file attributes
write(w, h.offset)
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
writeBytes(w, []byte(h.Comment))
}
end := w.count
// write end record
write(w, uint32(directoryEndSignature))
write(w, uint16(0)) // disk number
write(w, uint16(0)) // disk number where directory starts
write(w, uint16(len(w.dir))) // number of entries this disk
write(w, uint16(len(w.dir))) // number of entries total
write(w, uint32(end-start)) // size of directory
write(w, uint32(start)) // start of directory
write(w, uint16(0)) // size of comment
return w.w.(*bufio.Writer).Flush()
}
// Create adds a file to the zip file using the provided name.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) Create(name string) (io.Writer, os.Error) {
header := &FileHeader{
Name: name,
Method: Deflate,
}
return w.CreateHeader(header)
}
// CreateHeader adds a file to the zip file using the provided FileHeader
// for the file metadata.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, os.Error) {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return nil, err
}
}
fh.Flags |= 0x8 // we will write a data descriptor
fh.CreatorVersion = 0x14
fh.ReaderVersion = 0x14
fw := &fileWriter{
zipw: w,
compCount: &countWriter{w: w},
crc32: crc32.NewIEEE(),
}
switch fh.Method {
case Store:
fw.comp = nopCloser{fw.compCount}
case Deflate:
fw.comp = flate.NewWriter(fw.compCount, 5)
default:
return nil, UnsupportedMethod
}
fw.rawCount = &countWriter{w: fw.comp}
h := &header{
FileHeader: fh,
offset: uint32(w.count),
}
w.dir = append(w.dir, h)
fw.header = h
if err := writeHeader(w, fh); err != nil {
return nil, err
}
w.last = fw
return fw, nil
}
func writeHeader(w io.Writer, h *FileHeader) (err os.Error) {
defer recoverError(&err)
write(w, uint32(fileHeaderSignature))
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
return nil
}
type fileWriter struct {
*header
zipw io.Writer
rawCount *countWriter
comp io.WriteCloser
compCount *countWriter
crc32 hash.Hash32
closed bool
}
func (w *fileWriter) Write(p []byte) (int, os.Error) {
if w.closed {
return 0, os.NewError("zip: write to closed file")
}
w.crc32.Write(p)
return w.rawCount.Write(p)
}
func (w *fileWriter) close() (err os.Error) {
if w.closed {
return os.NewError("zip: file closed twice")
}
w.closed = true
if err = w.comp.Close(); err != nil {
return
}
// update FileHeader
fh := w.header.FileHeader
fh.CRC32 = w.crc32.Sum32()
fh.CompressedSize = uint32(w.compCount.count)
fh.UncompressedSize = uint32(w.rawCount.count)
// write data descriptor
defer recoverError(&err)
write(w.zipw, fh.CRC32)
write(w.zipw, fh.CompressedSize)
write(w.zipw, fh.UncompressedSize)
return nil
}
type countWriter struct {
w io.Writer
count int64
}
func (w *countWriter) Write(p []byte) (int, os.Error) {
n, err := w.w.Write(p)
w.count += int64(n)
return n, err
}
type nopCloser struct {
io.Writer
}
func (w nopCloser) Close() os.Error {
return nil
}
func write(w io.Writer, data interface{}) {
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
panic(err)
}
}
func writeBytes(w io.Writer, b []byte) {
n, err := w.Write(b)
if err != nil {
panic(err)
}
if n != len(b) {
panic(io.ErrShortWrite)
}
}

View File

@ -0,0 +1,73 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bytes"
"io/ioutil"
"rand"
"testing"
)
// TODO(adg): a more sophisticated test suite
const testString = "Rabbits, guinea pigs, gophers, marsupial rats, and quolls."
func TestWriter(t *testing.T) {
largeData := make([]byte, 1<<17)
for i := range largeData {
largeData[i] = byte(rand.Int())
}
// write a zip file
buf := new(bytes.Buffer)
w := NewWriter(buf)
testCreate(t, w, "foo", []byte(testString), Store)
testCreate(t, w, "bar", largeData, Deflate)
if err := w.Close(); err != nil {
t.Fatal(err)
}
// read it back
r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len()))
if err != nil {
t.Fatal(err)
}
testReadFile(t, r.File[0], []byte(testString))
testReadFile(t, r.File[1], largeData)
}
func testCreate(t *testing.T, w *Writer, name string, data []byte, method uint16) {
header := &FileHeader{
Name: name,
Method: method,
}
f, err := w.CreateHeader(header)
if err != nil {
t.Fatal(err)
}
_, err = f.Write(data)
if err != nil {
t.Fatal(err)
}
}
func testReadFile(t *testing.T, f *File, data []byte) {
rc, err := f.Open()
if err != nil {
t.Fatal("opening:", err)
}
b, err := ioutil.ReadAll(rc)
if err != nil {
t.Fatal("reading:", err)
}
err = rc.Close()
if err != nil {
t.Fatal("closing:", err)
}
if !bytes.Equal(b, data) {
t.Errorf("File contents %q, want %q", b, data)
}
}

View File

@ -0,0 +1,57 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests that involve both reading and writing.
package zip
import (
"bytes"
"fmt"
"os"
"testing"
)
type stringReaderAt string
func (s stringReaderAt) ReadAt(p []byte, off int64) (n int, err os.Error) {
if off >= int64(len(s)) {
return 0, os.EOF
}
n = copy(p, s[off:])
return
}
func TestOver65kFiles(t *testing.T) {
if testing.Short() {
t.Logf("slow test; skipping")
return
}
buf := new(bytes.Buffer)
w := NewWriter(buf)
const nFiles = (1 << 16) + 42
for i := 0; i < nFiles; i++ {
_, err := w.Create(fmt.Sprintf("%d.dat", i))
if err != nil {
t.Fatalf("creating file %d: %v", i, err)
}
}
if err := w.Close(); err != nil {
t.Fatalf("Writer.Close: %v", err)
}
rat := stringReaderAt(buf.String())
zr, err := NewReader(rat, int64(len(rat)))
if err != nil {
t.Fatalf("NewReader: %v", err)
}
if got := len(zr.File); got != nFiles {
t.Fatalf("File contains %d files, want %d", got, nFiles)
}
for i := 0; i < nFiles; i++ {
want := fmt.Sprintf("%d.dat", i)
if zr.File[i].Name != want {
t.Fatalf("File(%d) = %q, want %q", i, zr.File[i].Name, want)
}
}
}

View File

@ -20,6 +20,7 @@ package asn1
// everything by any means.
import (
"big"
"fmt"
"os"
"reflect"
@ -88,6 +89,27 @@ func parseInt(bytes []byte) (int, os.Error) {
return int(ret64), nil
}
var bigOne = big.NewInt(1)
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseBigInt(bytes []byte) *big.Int {
ret := new(big.Int)
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
// This is a negative number.
notBytes := make([]byte, len(bytes))
for i := range notBytes {
notBytes[i] = ^bytes[i]
}
ret.SetBytes(notBytes)
ret.Add(ret, bigOne)
ret.Neg(ret)
return ret
}
ret.SetBytes(bytes)
return ret
}
// BIT STRING
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
@ -127,7 +149,7 @@ func (b BitString) RightAlign() []byte {
return a
}
// parseBitString parses an ASN.1 bit string from the given byte array and returns it.
// parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
func parseBitString(bytes []byte) (ret BitString, err os.Error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length BIT STRING"}
@ -164,9 +186,9 @@ func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
return true
}
// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and
// returns it. An object identifer is a sequence of variable length integers
// that are assigned in a hierarachy.
// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
// returns it. An object identifier is a sequence of variable length integers
// that are assigned in a hierarchy.
func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length OBJECT IDENTIFIER"}
@ -198,14 +220,13 @@ func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
// An Enumerated is represented as a plain int.
type Enumerated int
// FLAG
// A Flag accepts any data and is set to true if present.
type Flag bool
// parseBase128Int parses a base-128 encoded int from the given offset in the
// given byte array. It returns the value and the new offset.
// given byte slice. It returns the value and the new offset.
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
offset = initOffset
for shifted := 0; offset < len(bytes); shifted++ {
@ -237,7 +258,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
return
}
// parseGeneralizedTime parses the GeneralizedTime from the given byte array
// parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
return time.Parse("20060102150405Z0700", string(bytes))
@ -269,7 +290,7 @@ func isPrintable(b byte) bool {
b == ':' ||
b == '=' ||
b == '?' ||
// This is techincally not allowed in a PrintableString.
// This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it.
b == '*'
@ -278,7 +299,7 @@ func isPrintable(b byte) bool {
// IA5String
// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
// byte array and returns it.
// byte slice and returns it.
func parseIA5String(bytes []byte) (ret string, err os.Error) {
for _, b := range bytes {
if b >= 0x80 {
@ -293,11 +314,19 @@ func parseIA5String(bytes []byte) (ret string, err os.Error) {
// T61String
// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
// byte array and returns it.
// byte slice and returns it.
func parseT61String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil
}
// UTF8String
// parseUTF8String parses a ASN.1 UTF8String (raw UTF-8) from the given byte
// array and returns it.
func parseUTF8String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil
}
// A RawValue represents an undecoded ASN.1 object.
type RawValue struct {
Class, Tag int
@ -314,7 +343,7 @@ type RawContent []byte
// Tagging
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
// into a byte array. It returns the parsed data and the new offset. SET and
// into a byte slice. It returns the parsed data and the new offset. SET and
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
// don't distinguish between ordered and unordered objects in this code.
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
@ -371,7 +400,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
}
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte array and returns them as a
// a number of ASN.1 values from the given byte slice and returns them as a
// slice of Go values of the given type.
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
expectedTag, compoundType, ok := getUniversalType(elemType)
@ -425,6 +454,7 @@ var (
timeType = reflect.TypeOf(&time.Time{})
rawValueType = reflect.TypeOf(RawValue{})
rawContentsType = reflect.TypeOf(RawContent(nil))
bigIntType = reflect.TypeOf(new(big.Int))
)
// invalidLength returns true iff offset + length > sliceLength, or if the
@ -433,7 +463,7 @@ func invalidLength(offset, length, sliceLength int) bool {
return offset+length < offset || offset+length > sliceLength
}
// parseField is the main parsing function. Given a byte array and an offset
// parseField is the main parsing function. Given a byte slice and an offset
// into the array, it will try to parse a suitable ASN.1 value out and store it
// in the given Value.
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
@ -550,16 +580,15 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
}
}
// Special case for strings: PrintableString and IA5String both map to
// the Go type string. getUniversalType returns the tag for
// PrintableString when it sees a string so, if we see an IA5String on
// the wire, we change the universal type to match.
if universalTag == tagPrintableString && t.tag == tagIA5String {
universalTag = tagIA5String
}
// Likewise for GeneralString
if universalTag == tagPrintableString && t.tag == tagGeneralString {
universalTag = tagGeneralString
// Special case for strings: all the ASN.1 string types map to the Go
// type string. getUniversalType returns the tag for PrintableString
// when it sees a string, so if we see a different string type on the
// wire, we change the universal type to match.
if universalTag == tagPrintableString {
switch t.tag {
case tagIA5String, tagGeneralString, tagT61String, tagUTF8String:
universalTag = t.tag
}
}
// Special case for time: UTCTime and GeneralizedTime both map to the
@ -639,6 +668,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
case flagType:
v.SetBool(true)
return
case bigIntType:
parsedInt := parseBigInt(innerBytes)
v.Set(reflect.ValueOf(parsedInt))
return
}
switch val := v; val.Kind() {
case reflect.Bool:
@ -648,23 +681,21 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
}
err = err1
return
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch val.Type().Kind() {
case reflect.Int:
parsedInt, err1 := parseInt(innerBytes)
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
case reflect.Int, reflect.Int32:
parsedInt, err1 := parseInt(innerBytes)
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
// TODO(dfc) Add support for the remaining integer types
case reflect.Struct:
structType := fieldType
@ -680,7 +711,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
if i == 0 && field.Type == rawContentsType {
continue
}
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag))
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
if err != nil {
return
}
@ -711,6 +742,8 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
v, err = parseIA5String(innerBytes)
case tagT61String:
v, err = parseT61String(innerBytes)
case tagUTF8String:
v, err = parseUTF8String(innerBytes)
case tagGeneralString:
// GeneralString is specified in ISO-2022/ECMA-35,
// A brief review suggests that it includes structures
@ -725,7 +758,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
}
return
}
err = StructuralError{"unknown Go type"}
err = StructuralError{"unsupported: " + v.Type().String()}
return
}
@ -752,7 +785,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names.
//
// An ASN.1 INTEGER can be written to an int or int64.
// An ASN.1 INTEGER can be written to an int, int32 or int64.
// If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error.
//

View File

@ -42,6 +42,64 @@ func TestParseInt64(t *testing.T) {
}
}
type int32Test struct {
in []byte
ok bool
out int32
}
var int32TestData = []int32Test{
{[]byte{0x00}, true, 0},
{[]byte{0x7f}, true, 127},
{[]byte{0x00, 0x80}, true, 128},
{[]byte{0x01, 0x00}, true, 256},
{[]byte{0x80}, true, -128},
{[]byte{0xff, 0x7f}, true, -129},
{[]byte{0xff, 0xff, 0xff, 0xff}, true, -1},
{[]byte{0xff}, true, -1},
{[]byte{0x80, 0x00, 0x00, 0x00}, true, -2147483648},
{[]byte{0x80, 0x00, 0x00, 0x00, 0x00}, false, 0},
}
func TestParseInt32(t *testing.T) {
for i, test := range int32TestData {
ret, err := parseInt(test.in)
if (err == nil) != test.ok {
t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok)
}
if test.ok && int32(ret) != test.out {
t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out)
}
}
}
var bigIntTests = []struct {
in []byte
base10 string
}{
{[]byte{0xff}, "-1"},
{[]byte{0x00}, "0"},
{[]byte{0x01}, "1"},
{[]byte{0x00, 0xff}, "255"},
{[]byte{0xff, 0x00}, "-256"},
{[]byte{0x01, 0x00}, "256"},
}
func TestParseBigInt(t *testing.T) {
for i, test := range bigIntTests {
ret := parseBigInt(test.in)
if ret.String() != test.base10 {
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
}
fw := newForkableWriter()
marshalBigInt(fw, ret)
result := fw.Bytes()
if !bytes.Equal(result, test.in) {
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
}
}
}
type bitStringTest struct {
in []byte
ok bool
@ -148,10 +206,10 @@ type timeTest struct {
}
var utcTestData = []timeTest{
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}},
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}},
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, "UTC"}},
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, "UTC"}},
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, -7 * 60 * 60, ""}},
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, 7*60*60 + 30*60, ""}},
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, 0, "UTC"}},
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, 0, "UTC"}},
{"a10506234540Z", false, nil},
{"91a506234540Z", false, nil},
{"9105a6234540Z", false, nil},
@ -177,10 +235,10 @@ func TestUTCTime(t *testing.T) {
}
var generalizedTimeTestData = []timeTest{
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, "UTC"}},
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 0, "UTC"}},
{"20100102030405", false, nil},
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 6*60*60 + 7*60, ""}},
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, -6*60*60 - 7*60, ""}},
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 6*60*60 + 7*60, ""}},
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, -6*60*60 - 7*60, ""}},
}
func TestGeneralizedTime(t *testing.T) {
@ -272,11 +330,11 @@ type TestObjectIdentifierStruct struct {
}
type TestContextSpecificTags struct {
A int "tag:1"
A int `asn1:"tag:1"`
}
type TestContextSpecificTags2 struct {
A int "explicit,tag:1"
A int `asn1:"explicit,tag:1"`
B int
}
@ -326,7 +384,7 @@ type Certificate struct {
}
type TBSCertificate struct {
Version int "optional,explicit,default:0,tag:0"
Version int `asn1:"optional,explicit,default:0,tag:0"`
SerialNumber RawValue
SignatureAlgorithm AlgorithmIdentifier
Issuer RDNSequence

View File

@ -10,7 +10,7 @@ import (
"strings"
)
// ASN.1 objects have metadata preceeding them:
// ASN.1 objects have metadata preceding them:
// the tag: the type of the object
// a flag denoting if this object is compound or not
// the class type: the namespace of the tag
@ -25,6 +25,7 @@ const (
tagOctetString = 4
tagOID = 6
tagEnum = 10
tagUTF8String = 12
tagSequence = 16
tagSet = 17
tagPrintableString = 19
@ -83,7 +84,7 @@ type fieldParameters struct {
// parseFieldParameters will parse it into a fieldParameters structure,
// ignoring unknown parts of the string.
func parseFieldParameters(str string) (ret fieldParameters) {
for _, part := range strings.Split(str, ",", -1) {
for _, part := range strings.Split(str, ",") {
switch {
case part == "optional":
ret.optional = true
@ -132,6 +133,8 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
return tagUTCTime, false, true
case enumeratedType:
return tagEnum, false, true
case bigIntType:
return tagInteger, false, true
}
switch t.Kind() {
case reflect.Bool:

View File

@ -5,6 +5,7 @@
package asn1
import (
"big"
"bytes"
"fmt"
"io"
@ -125,6 +126,43 @@ func int64Length(i int64) (numBytes int) {
return
}
func marshalBigInt(out *forkableWriter, n *big.Int) (err os.Error) {
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
err = out.WriteByte(0xff)
if err != nil {
return
}
}
_, err = out.Write(bytes)
} else if n.Sign() == 0 {
// Zero is written as a single 0 zero rather than no bytes.
err = out.WriteByte(0x00)
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
err = out.WriteByte(0)
if err != nil {
return
}
}
_, err = out.Write(bytes)
}
return
}
func marshalLength(out *forkableWriter, i int) (err os.Error) {
n := lengthLength(i)
@ -334,6 +372,8 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
return marshalBitString(out, value.Interface().(BitString))
case objectIdentifierType:
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
case bigIntType:
return marshalBigInt(out, value.Interface().(*big.Int))
}
switch v := value; v.Kind() {
@ -351,7 +391,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
startingField := 0
// If the first element of the structure is a non-empty
// RawContents, then we don't bother serialising the rest.
// RawContents, then we don't bother serializing the rest.
if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
s := v.Field(0)
if s.Len() > 0 {
@ -361,7 +401,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
}
/* The RawContents will contain the tag and
* length fields but we'll also be writing
* those outselves, so we strip them out of
* those ourselves, so we strip them out of
* bytes */
_, err = out.Write(stripTagAndLength(bytes))
return
@ -373,7 +413,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
for i := startingField; i < t.NumField(); i++ {
var pre *forkableWriter
pre, out = out.fork()
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag))
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
if err != nil {
return
}
@ -418,6 +458,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return marshalField(out, v.Elem(), params)
}
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
if v.Type() == rawValueType {
rv := v.Interface().(RawValue)
err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
@ -428,10 +472,6 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return
}
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
tag, isCompound, ok := getUniversalType(v.Type())
if !ok {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}

View File

@ -30,19 +30,23 @@ type rawContentsStruct struct {
}
type implicitTagTest struct {
A int "implicit,tag:5"
A int `asn1:"implicit,tag:5"`
}
type explicitTagTest struct {
A int "explicit,tag:5"
A int `asn1:"explicit,tag:5"`
}
type ia5StringTest struct {
A string "ia5"
A string `asn1:"ia5"`
}
type printableStringTest struct {
A string "printable"
A string `asn1:"printable"`
}
type optionalRawValueTest struct {
A RawValue `asn1:"optional"`
}
type testSET []int
@ -102,6 +106,7 @@ var marshalTests = []marshalTest{
"7878787878787878787878787878787878787878787878787878787878787878",
},
{ia5StringTest{"test"}, "3006160474657374"},
{optionalRawValueTest{}, "3000"},
{printableStringTest{"test"}, "3006130474657374"},
{printableStringTest{"test*"}, "30071305746573742a"},
{rawContentsStruct{nil, 64}, "3003020140"},

View File

@ -27,7 +27,6 @@ const (
_M2 = _B2 - 1 // half digit mask
)
// ----------------------------------------------------------------------------
// Elementary operations on words
//
@ -43,7 +42,6 @@ func addWW_g(x, y, c Word) (z1, z0 Word) {
return
}
// z1<<_W + z0 = x-y-c, with c == 0 or 1
func subWW_g(x, y, c Word) (z1, z0 Word) {
yc := y + c
@ -54,7 +52,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
return
}
// z1<<_W + z0 = x*y
func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) }
// Adapted from Warren, Hacker's Delight, p. 132.
@ -73,7 +70,6 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
return
}
// z1<<_W + z0 = x*y + c
func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
z1, zz0 := mulWW(x, y)
@ -83,7 +79,6 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
return
}
// Length of x in bits.
func bitLen(x Word) (n int) {
for ; x >= 0x100; x >>= 8 {
@ -95,7 +90,6 @@ func bitLen(x Word) (n int) {
return
}
// log2 computes the integer binary logarithm of x.
// The result is the integer n for which 2^n <= x < 2^(n+1).
// If x == 0, the result is -1.
@ -103,13 +97,11 @@ func log2(x Word) int {
return bitLen(x) - 1
}
// Number of leading zeros in x.
func leadingZeros(x Word) uint {
return uint(_W - bitLen(x))
}
// q = (u1<<_W + u0 - r)/y
func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) }
// Adapted from Warren, Hacker's Delight, p. 152.
@ -155,7 +147,6 @@ again2:
return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
}
func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) }
func addVV_g(z, x, y []Word) (c Word) {
for i := range z {
@ -164,7 +155,6 @@ func addVV_g(z, x, y []Word) (c Word) {
return
}
func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) }
func subVV_g(z, x, y []Word) (c Word) {
for i := range z {
@ -173,7 +163,6 @@ func subVV_g(z, x, y []Word) (c Word) {
return
}
func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) }
func addVW_g(z, x []Word, y Word) (c Word) {
c = y
@ -183,7 +172,6 @@ func addVW_g(z, x []Word, y Word) (c Word) {
return
}
func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) }
func subVW_g(z, x []Word, y Word) (c Word) {
c = y
@ -193,9 +181,8 @@ func subVW_g(z, x []Word, y Word) (c Word) {
return
}
func shlVW(z, x []Word, s Word) (c Word) { return shlVW_g(z, x, s) }
func shlVW_g(z, x []Word, s Word) (c Word) {
func shlVU(z, x []Word, s uint) (c Word) { return shlVU_g(z, x, s) }
func shlVU_g(z, x []Word, s uint) (c Word) {
if n := len(z); n > 0 {
ŝ := _W - s
w1 := x[n-1]
@ -210,9 +197,8 @@ func shlVW_g(z, x []Word, s Word) (c Word) {
return
}
func shrVW(z, x []Word, s Word) (c Word) { return shrVW_g(z, x, s) }
func shrVW_g(z, x []Word, s Word) (c Word) {
func shrVU(z, x []Word, s uint) (c Word) { return shrVU_g(z, x, s) }
func shrVU_g(z, x []Word, s uint) (c Word) {
if n := len(z); n > 0 {
ŝ := _W - s
w1 := x[0]
@ -227,7 +213,6 @@ func shrVW_g(z, x []Word, s Word) (c Word) {
return
}
func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) }
func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
c = r
@ -237,7 +222,6 @@ func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
return
}
func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) }
func addMulVVW_g(z, x []Word, y Word) (c Word) {
for i := range z {
@ -248,7 +232,6 @@ func addMulVVW_g(z, x []Word, y Word) (c Word) {
return
}
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) }
func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) {
r = xn

View File

@ -11,8 +11,8 @@ func addVV(z, x, y []Word) (c Word)
func subVV(z, x, y []Word) (c Word)
func addVW(z, x []Word, y Word) (c Word)
func subVW(z, x []Word, y Word) (c Word)
func shlVW(z, x []Word, s Word) (c Word)
func shrVW(z, x []Word, s Word) (c Word)
func shlVU(z, x []Word, s uint) (c Word)
func shrVU(z, x []Word, s uint) (c Word)
func mulAddVWW(z, x []Word, y, r Word) (c Word)
func addMulVVW(z, x []Word, y Word) (c Word)
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word)

View File

@ -6,7 +6,6 @@ package big
import "testing"
type funWW func(x, y, c Word) (z1, z0 Word)
type argWW struct {
x, y, c, z1, z0 Word
@ -26,7 +25,6 @@ var sumWW = []argWW{
{_M, _M, 1, 1, _M},
}
func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
z1, z0 := f(a.x, a.y, a.c)
if z1 != a.z1 || z0 != a.z0 {
@ -34,7 +32,6 @@ func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
}
}
func TestFunWW(t *testing.T) {
for _, a := range sumWW {
arg := a
@ -51,7 +48,6 @@ func TestFunWW(t *testing.T) {
}
}
type funVV func(z, x, y []Word) (c Word)
type argVV struct {
z, x, y nat
@ -70,7 +66,6 @@ var sumVV = []argVV{
{nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1},
}
func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
z := make(nat, len(a.z))
c := f(z, a.x, a.y)
@ -85,7 +80,6 @@ func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
}
}
func TestFunVV(t *testing.T) {
for _, a := range sumVV {
arg := a
@ -106,7 +100,6 @@ func TestFunVV(t *testing.T) {
}
}
type funVW func(z, x []Word, y Word) (c Word)
type argVW struct {
z, x nat
@ -169,7 +162,6 @@ var rshVW = []argVW{
{nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M},
}
func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
z := make(nat, len(a.z))
c := f(z, a.x, a.y)
@ -184,6 +176,11 @@ func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
}
}
func makeFunVW(f func(z, x []Word, s uint) (c Word)) funVW {
return func(z, x []Word, s Word) (c Word) {
return f(z, x, uint(s))
}
}
func TestFunVW(t *testing.T) {
for _, a := range sumVW {
@ -196,20 +193,23 @@ func TestFunVW(t *testing.T) {
testFunVW(t, "subVW", subVW, arg)
}
shlVW_g := makeFunVW(shlVU_g)
shlVW := makeFunVW(shlVU)
for _, a := range lshVW {
arg := a
testFunVW(t, "shlVW_g", shlVW_g, arg)
testFunVW(t, "shlVW", shlVW, arg)
testFunVW(t, "shlVU_g", shlVW_g, arg)
testFunVW(t, "shlVU", shlVW, arg)
}
shrVW_g := makeFunVW(shrVU_g)
shrVW := makeFunVW(shrVU)
for _, a := range rshVW {
arg := a
testFunVW(t, "shrVW_g", shrVW_g, arg)
testFunVW(t, "shrVW", shrVW, arg)
testFunVW(t, "shrVU_g", shrVW_g, arg)
testFunVW(t, "shrVU", shrVW, arg)
}
}
type funVWW func(z, x []Word, y, r Word) (c Word)
type argVWW struct {
z, x nat
@ -243,7 +243,6 @@ var prodVWW = []argVWW{
{nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)},
}
func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
z := make(nat, len(a.z))
c := f(z, a.x, a.y, a.r)
@ -258,7 +257,6 @@ func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
}
}
// TODO(gri) mulAddVWW and divWVW are symmetric operations but
// their signature is not symmetric. Try to unify.
@ -285,7 +283,6 @@ func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) {
}
}
func TestFunVWW(t *testing.T) {
for _, a := range prodVWW {
arg := a
@ -300,7 +297,6 @@ func TestFunVWW(t *testing.T) {
}
}
var mulWWTests = []struct {
x, y Word
q, r Word
@ -309,7 +305,6 @@ var mulWWTests = []struct {
// 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4},
}
func TestMulWW(t *testing.T) {
for i, test := range mulWWTests {
q, r := mulWW_g(test.x, test.y)
@ -319,7 +314,6 @@ func TestMulWW(t *testing.T) {
}
}
var mulAddWWWTests = []struct {
x, y, c Word
q, r Word
@ -331,7 +325,6 @@ var mulAddWWWTests = []struct {
{_M, _M, _M, _M, 0},
}
func TestMulAddWWW(t *testing.T) {
for i, test := range mulAddWWWTests {
q, r := mulAddWWW_g(test.x, test.y, test.c)

View File

@ -19,10 +19,8 @@ import (
"time"
)
var calibrate = flag.Bool("calibrate", false, "run calibration test")
// measure returns the time to run f
func measure(f func()) int64 {
const N = 100
@ -34,7 +32,6 @@ func measure(f func()) int64 {
return (stop - start) / N
}
func computeThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n")
@ -84,7 +81,6 @@ func computeThresholds() {
}
}
func TestCalibrate(t *testing.T) {
if *calibrate {
computeThresholds()

View File

@ -13,13 +13,11 @@ import (
"testing"
)
type matrix struct {
n, m int
a []*Rat
}
func (a *matrix) at(i, j int) *Rat {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range")
@ -27,7 +25,6 @@ func (a *matrix) at(i, j int) *Rat {
return a.a[i*a.m+j]
}
func (a *matrix) set(i, j int, x *Rat) {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range")
@ -35,7 +32,6 @@ func (a *matrix) set(i, j int, x *Rat) {
a.a[i*a.m+j] = x
}
func newMatrix(n, m int) *matrix {
if !(0 <= n && 0 <= m) {
panic("illegal matrix")
@ -47,7 +43,6 @@ func newMatrix(n, m int) *matrix {
return a
}
func newUnit(n int) *matrix {
a := newMatrix(n, n)
for i := 0; i < n; i++ {
@ -62,7 +57,6 @@ func newUnit(n int) *matrix {
return a
}
func newHilbert(n int) *matrix {
a := newMatrix(n, n)
for i := 0; i < n; i++ {
@ -73,7 +67,6 @@ func newHilbert(n int) *matrix {
return a
}
func newInverseHilbert(n int) *matrix {
a := newMatrix(n, n)
for i := 0; i < n; i++ {
@ -98,7 +91,6 @@ func newInverseHilbert(n int) *matrix {
return a
}
func (a *matrix) mul(b *matrix) *matrix {
if a.m != b.n {
panic("illegal matrix multiply")
@ -116,7 +108,6 @@ func (a *matrix) mul(b *matrix) *matrix {
return c
}
func (a *matrix) eql(b *matrix) bool {
if a.n != b.n || a.m != b.m {
return false
@ -131,7 +122,6 @@ func (a *matrix) eql(b *matrix) bool {
return true
}
func (a *matrix) String() string {
s := ""
for i := 0; i < a.n; i++ {
@ -143,7 +133,6 @@ func (a *matrix) String() string {
return s
}
func doHilbert(t *testing.T, n int) {
a := newHilbert(n)
b := newInverseHilbert(n)
@ -160,12 +149,10 @@ func doHilbert(t *testing.T, n int) {
}
}
func TestHilbert(t *testing.T) {
doHilbert(t, 10)
}
func BenchmarkHilbert(b *testing.B) {
for i := 0; i < b.N; i++ {
doHilbert(nil, 10)

View File

@ -8,8 +8,10 @@ package big
import (
"fmt"
"io"
"os"
"rand"
"strings"
)
// An Int represents a signed multi-precision integer.
@ -19,10 +21,8 @@ type Int struct {
abs nat // absolute value of the integer
}
var intOne = &Int{false, natOne}
// Sign returns:
//
// -1 if x < 0
@ -39,7 +39,6 @@ func (x *Int) Sign() int {
return 1
}
// SetInt64 sets z to x and returns z.
func (z *Int) SetInt64(x int64) *Int {
neg := false
@ -52,13 +51,11 @@ func (z *Int) SetInt64(x int64) *Int {
return z
}
// NewInt allocates and returns a new Int set to x.
func NewInt(x int64) *Int {
return new(Int).SetInt64(x)
}
// Set sets z to x and returns z.
func (z *Int) Set(x *Int) *Int {
z.abs = z.abs.set(x.abs)
@ -66,7 +63,6 @@ func (z *Int) Set(x *Int) *Int {
return z
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Int) Abs(x *Int) *Int {
z.abs = z.abs.set(x.abs)
@ -74,7 +70,6 @@ func (z *Int) Abs(x *Int) *Int {
return z
}
// Neg sets z to -x and returns z.
func (z *Int) Neg(x *Int) *Int {
z.abs = z.abs.set(x.abs)
@ -82,7 +77,6 @@ func (z *Int) Neg(x *Int) *Int {
return z
}
// Add sets z to the sum x+y and returns z.
func (z *Int) Add(x, y *Int) *Int {
neg := x.neg
@ -104,7 +98,6 @@ func (z *Int) Add(x, y *Int) *Int {
return z
}
// Sub sets z to the difference x-y and returns z.
func (z *Int) Sub(x, y *Int) *Int {
neg := x.neg
@ -126,7 +119,6 @@ func (z *Int) Sub(x, y *Int) *Int {
return z
}
// Mul sets z to the product x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int {
// x * y == x * y
@ -138,7 +130,6 @@ func (z *Int) Mul(x, y *Int) *Int {
return z
}
// MulRange sets z to the product of all integers
// in the range [a, b] inclusively and returns z.
// If a > b (empty range), the result is 1.
@ -162,7 +153,6 @@ func (z *Int) MulRange(a, b int64) *Int {
return z
}
// Binomial sets z to the binomial coefficient of (n, k) and returns z.
func (z *Int) Binomial(n, k int64) *Int {
var a, b Int
@ -171,7 +161,6 @@ func (z *Int) Binomial(n, k int64) *Int {
return z.Quo(&a, &b)
}
// Quo sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// See QuoRem for more details.
@ -181,7 +170,6 @@ func (z *Int) Quo(x, y *Int) *Int {
return z
}
// Rem sets z to the remainder x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// See QuoRem for more details.
@ -191,7 +179,6 @@ func (z *Int) Rem(x, y *Int) *Int {
return z
}
// QuoRem sets z to the quotient x/y and r to the remainder x%y
// and returns the pair (z, r) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs.
@ -209,7 +196,6 @@ func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
return z, r
}
// Div sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// See DivMod for more details.
@ -227,7 +213,6 @@ func (z *Int) Div(x, y *Int) *Int {
return z
}
// Mod sets z to the modulus x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// See DivMod for more details.
@ -248,7 +233,6 @@ func (z *Int) Mod(x, y *Int) *Int {
return z
}
// DivMod sets z to the quotient x div y and m to the modulus x mod y
// and returns the pair (z, m) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs.
@ -281,7 +265,6 @@ func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
return z, m
}
// Cmp compares x and y and returns:
//
// -1 if x < y
@ -307,49 +290,197 @@ func (x *Int) Cmp(y *Int) (r int) {
return
}
func (x *Int) String() string {
s := ""
if x.neg {
s = "-"
switch {
case x == nil:
return "<nil>"
case x.neg:
return "-" + x.abs.decimalString()
}
return s + x.abs.string(10)
return x.abs.decimalString()
}
func fmtbase(ch int) int {
func charset(ch int) string {
switch ch {
case 'b':
return 2
return lowercaseDigits[0:2]
case 'o':
return 8
case 'd':
return 10
return lowercaseDigits[0:8]
case 'd', 's', 'v':
return lowercaseDigits[0:10]
case 'x':
return 16
return lowercaseDigits[0:16]
case 'X':
return uppercaseDigits[0:16]
}
return 10
return "" // unknown format
}
// write count copies of text to s
func writeMultiple(s fmt.State, text string, count int) {
if len(text) > 0 {
b := []byte(text)
for ; count > 0; count-- {
s.Write(b)
}
}
}
// Format is a support routine for fmt.Formatter. It accepts
// the formats 'b' (binary), 'o' (octal), 'd' (decimal) and
// 'x' (hexadecimal).
// the formats 'b' (binary), 'o' (octal), 'd' (decimal), 'x'
// (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
// Also supported are the full suite of package fmt's format
// verbs for integral types, including '+', '-', and ' '
// for sign control, '#' for leading zero in octal and for
// hexadecimal, a leading "0x" or "0X" for "%#x" and "%#X"
// respectively, specification of minimum digits precision,
// output field width, space or zero padding, and left or
// right justification.
//
func (x *Int) Format(s fmt.State, ch int) {
if x == nil {
cs := charset(ch)
// special cases
switch {
case cs == "":
// unknown format
fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String())
return
case x == nil:
fmt.Fprint(s, "<nil>")
return
}
if x.neg {
fmt.Fprint(s, "-")
// determine sign character
sign := ""
switch {
case x.neg:
sign = "-"
case s.Flag('+'): // supersedes ' ' when both specified
sign = "+"
case s.Flag(' '):
sign = " "
}
fmt.Fprint(s, x.abs.string(fmtbase(ch)))
// determine prefix characters for indicating output base
prefix := ""
if s.Flag('#') {
switch ch {
case 'o': // octal
prefix = "0"
case 'x': // hexadecimal
prefix = "0x"
case 'X':
prefix = "0X"
}
}
// determine digits with base set by len(cs) and digit characters from cs
digits := x.abs.string(cs)
// number of characters for the three classes of number padding
var left int // space characters to left of digits for right justification ("%8d")
var zeroes int // zero characters (actually cs[0]) as left-most digits ("%.8d")
var right int // space characters to right of digits for left justification ("%-8d")
// determine number padding from precision: the least number of digits to output
precision, precisionSet := s.Precision()
if precisionSet {
switch {
case len(digits) < precision:
zeroes = precision - len(digits) // count of zero padding
case digits == "0" && precision == 0:
return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
}
}
// determine field pad from width: the least number of characters to output
length := len(sign) + len(prefix) + zeroes + len(digits)
if width, widthSet := s.Width(); widthSet && length < width { // pad as specified
switch d := width - length; {
case s.Flag('-'):
// pad on the right with spaces; supersedes '0' when both specified
right = d
case s.Flag('0') && !precisionSet:
// pad with zeroes unless precision also specified
zeroes = d
default:
// pad on the left with spaces
left = d
}
}
// print number as [left pad][sign][prefix][zero pad][digits][right pad]
writeMultiple(s, " ", left)
writeMultiple(s, sign, 1)
writeMultiple(s, prefix, 1)
writeMultiple(s, "0", zeroes)
writeMultiple(s, digits, 1)
writeMultiple(s, " ", right)
}
// scan sets z to the integer value corresponding to the longest possible prefix
// read from r representing a signed integer number in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined. The syntax follows the syntax of
// integer literals in Go.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) {
// determine sign
ch, _, err := r.ReadRune()
if err != nil {
return z, 0, err
}
neg := false
switch ch {
case '-':
neg = true
case '+': // nothing to do
default:
r.UnreadRune()
}
// Int64 returns the int64 representation of z.
// If z cannot be represented in an int64, the result is undefined.
// determine mantissa
z.abs, base, err = z.abs.scan(r, base)
if err != nil {
return z, base, err
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z, base, nil
}
// Scan is a support routine for fmt.Scanner; it sets z to the value of
// the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
func (z *Int) Scan(s fmt.ScanState, ch int) os.Error {
s.SkipSpace() // skip leading space characters
base := 0
switch ch {
case 'b':
base = 2
case 'o':
base = 8
case 'd':
base = 10
case 'x', 'X':
base = 16
case 's', 'v':
// let scan determine the base
default:
return os.NewError("Int.Scan: invalid verb")
}
_, _, err := z.scan(s, base)
return err
}
// Int64 returns the int64 representation of x.
// If x cannot be represented in an int64, the result is undefined.
func (x *Int) Int64() int64 {
if len(x.abs) == 0 {
return 0
@ -364,40 +495,25 @@ func (x *Int) Int64() int64 {
return v
}
// SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. If SetString fails,
// the value of z is undefined.
//
// If the base argument is 0, the string prefix determines the actual
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects
// base 2. Otherwise the selected base is 10.
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z *Int) SetString(s string, base int) (*Int, bool) {
if len(s) == 0 || base < 0 || base == 1 || 16 < base {
r := strings.NewReader(s)
_, _, err := z.scan(r, base)
if err != nil {
return z, false
}
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
if len(s) == 0 {
return z, false
}
}
var scanned int
z.abs, _, scanned = z.abs.scan(s, base)
if scanned != len(s) {
return z, false
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z, true
_, _, err = r.ReadRune()
return z, err == os.EOF // err == os.EOF => scan consumed all of s
}
// SetBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z *Int) SetBytes(buf []byte) *Int {
@ -406,21 +522,18 @@ func (z *Int) SetBytes(buf []byte) *Int {
return z
}
// Bytes returns the absolute value of z as a big-endian byte slice.
func (z *Int) Bytes() []byte {
buf := make([]byte, len(z.abs)*_S)
return buf[z.abs.bytes(buf):]
}
// BitLen returns the length of the absolute value of z in bits.
// The bit length of 0 is 0.
func (z *Int) BitLen() int {
return z.abs.bitLen()
}
// Exp sets z = x**y mod m. If m is nil, z = x**y.
// See Knuth, volume 2, section 4.6.3.
func (z *Int) Exp(x, y, m *Int) *Int {
@ -441,7 +554,6 @@ func (z *Int) Exp(x, y, m *Int) *Int {
return z
}
// GcdInt sets d to the greatest common divisor of a and b, which must be
// positive numbers.
// If x and y are not nil, GcdInt sets x and y such that d = a*x + b*y.
@ -500,7 +612,6 @@ func GcdInt(d, x, y, a, b *Int) {
*d = *A
}
// ProbablyPrime performs n Miller-Rabin tests to check whether z is prime.
// If it returns true, z is prime with probability 1 - 1/4^n.
// If it returns false, z is not prime.
@ -508,8 +619,7 @@ func ProbablyPrime(z *Int, n int) bool {
return !z.neg && z.abs.probablyPrime(n)
}
// Rand sets z to a pseudo-random number in [0, n) and returns z.
// Rand sets z to a pseudo-random number in [0, n) and returns z.
func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
z.neg = false
if n.neg == true || len(n.abs) == 0 {
@ -520,7 +630,6 @@ func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
return z
}
// ModInverse sets z to the multiplicative inverse of g in the group /p (where
// p is a prime) and returns z.
func (z *Int) ModInverse(g, p *Int) *Int {
@ -534,7 +643,6 @@ func (z *Int) ModInverse(g, p *Int) *Int {
return z
}
// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
z.abs = z.abs.shl(x.abs, n)
@ -542,7 +650,6 @@ func (z *Int) Lsh(x *Int, n uint) *Int {
return z
}
// Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int {
if x.neg {
@ -559,6 +666,39 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
return z
}
// Bit returns the value of the i'th bit of z. That is, it
// returns (z>>i)&1. The bit index i must be >= 0.
func (z *Int) Bit(i int) uint {
if i < 0 {
panic("negative bit index")
}
if z.neg {
t := nat{}.sub(z.abs, natOne)
return t.bit(uint(i)) ^ 1
}
return z.abs.bit(uint(i))
}
// SetBit sets the i'th bit of z to bit and returns z.
// That is, if bit is 1 SetBit sets z = x | (1 << i);
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
// SetBit will panic.
func (z *Int) SetBit(x *Int, i int, b uint) *Int {
if i < 0 {
panic("negative bit index")
}
if x.neg {
t := z.abs.sub(x.abs, natOne)
t = t.setBit(t, uint(i), b^1)
z.abs = t.add(t, natOne)
z.neg = len(z.abs) > 0
return z
}
z.abs = z.abs.setBit(x.abs, uint(i), b)
z.neg = false
return z
}
// And sets z = x & y and returns z.
func (z *Int) And(x, y *Int) *Int {
@ -590,7 +730,6 @@ func (z *Int) And(x, y *Int) *Int {
return z
}
// AndNot sets z = x &^ y and returns z.
func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg {
@ -624,7 +763,6 @@ func (z *Int) AndNot(x, y *Int) *Int {
return z
}
// Or sets z = x | y and returns z.
func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg {
@ -655,7 +793,6 @@ func (z *Int) Or(x, y *Int) *Int {
return z
}
// Xor sets z = x ^ y and returns z.
func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg {
@ -686,7 +823,6 @@ func (z *Int) Xor(x, y *Int) *Int {
return z
}
// Not sets z = ^x and returns z.
func (z *Int) Not(x *Int) *Int {
if x.neg {
@ -702,15 +838,14 @@ func (z *Int) Not(x *Int) *Int {
return z
}
// Gob codec version. Permits backward-compatible changes to the encoding.
const version byte = 1
const intGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Int) GobEncode() ([]byte, os.Error) {
buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit
buf := make([]byte, 1+len(z.abs)*_S) // extra byte for version and sign bit
i := z.abs.bytes(buf) - 1 // i >= 0
b := version << 1 // make space for sign bit
b := intGobVersion << 1 // make space for sign bit
if z.neg {
b |= 1
}
@ -718,14 +853,13 @@ func (z *Int) GobEncode() ([]byte, os.Error) {
return buf[i:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Int) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Int.GobDecode: no data")
}
b := buf[0]
if b>>1 != version {
if b>>1 != intGobVersion {
return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
}
z.neg = b&1 != 0

View File

@ -13,7 +13,6 @@ import (
"testing/quick"
)
func isNormalized(x *Int) bool {
if len(x.abs) == 0 {
return !x.neg
@ -22,13 +21,11 @@ func isNormalized(x *Int) bool {
return x.abs[len(x.abs)-1] != 0
}
type funZZ func(z, x, y *Int) *Int
type argZZ struct {
z, x, y *Int
}
var sumZZ = []argZZ{
{NewInt(0), NewInt(0), NewInt(0)},
{NewInt(1), NewInt(1), NewInt(0)},
@ -38,7 +35,6 @@ var sumZZ = []argZZ{
{NewInt(-1111111110), NewInt(-123456789), NewInt(-987654321)},
}
var prodZZ = []argZZ{
{NewInt(0), NewInt(0), NewInt(0)},
{NewInt(0), NewInt(1), NewInt(0)},
@ -47,7 +43,6 @@ var prodZZ = []argZZ{
// TODO(gri) add larger products
}
func TestSignZ(t *testing.T) {
var zero Int
for _, a := range sumZZ {
@ -59,7 +54,6 @@ func TestSignZ(t *testing.T) {
}
}
func TestSetZ(t *testing.T) {
for _, a := range sumZZ {
var z Int
@ -73,7 +67,6 @@ func TestSetZ(t *testing.T) {
}
}
func TestAbsZ(t *testing.T) {
var zero Int
for _, a := range sumZZ {
@ -90,7 +83,6 @@ func TestAbsZ(t *testing.T) {
}
}
func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
var z Int
f(&z, a.x, a.y)
@ -102,7 +94,6 @@ func testFunZZ(t *testing.T, msg string, f funZZ, a argZZ) {
}
}
func TestSumZZ(t *testing.T) {
AddZZ := func(z, x, y *Int) *Int { return z.Add(x, y) }
SubZZ := func(z, x, y *Int) *Int { return z.Sub(x, y) }
@ -121,7 +112,6 @@ func TestSumZZ(t *testing.T) {
}
}
func TestProdZZ(t *testing.T) {
MulZZ := func(z, x, y *Int) *Int { return z.Mul(x, y) }
for _, a := range prodZZ {
@ -133,7 +123,6 @@ func TestProdZZ(t *testing.T) {
}
}
// mulBytes returns x*y via grade school multiplication. Both inputs
// and the result are assumed to be in big-endian representation (to
// match the semantics of Int.Bytes and Int.SetBytes).
@ -166,7 +155,6 @@ func mulBytes(x, y []byte) []byte {
return z[i:]
}
func checkMul(a, b []byte) bool {
var x, y, z1 Int
x.SetBytes(a)
@ -179,14 +167,12 @@ func checkMul(a, b []byte) bool {
return z1.Cmp(&z2) == 0
}
func TestMul(t *testing.T) {
if err := quick.Check(checkMul, nil); err != nil {
t.Error(err)
}
}
var mulRangesZ = []struct {
a, b int64
prod string
@ -212,7 +198,6 @@ var mulRangesZ = []struct {
},
}
func TestMulRangeZ(t *testing.T) {
var tmp Int
// test entirely positive ranges
@ -231,7 +216,6 @@ func TestMulRangeZ(t *testing.T) {
}
}
var stringTests = []struct {
in string
out string
@ -280,7 +264,6 @@ var stringTests = []struct {
{"1001010111", "1001010111", 2, 0x257, true},
}
func format(base int) string {
switch base {
case 2:
@ -293,7 +276,6 @@ func format(base int) string {
return "%d"
}
func TestGetString(t *testing.T) {
z := new(Int)
for i, test := range stringTests {
@ -316,7 +298,6 @@ func TestGetString(t *testing.T) {
}
}
func TestSetString(t *testing.T) {
tmp := new(Int)
for i, test := range stringTests {
@ -347,6 +328,212 @@ func TestSetString(t *testing.T) {
}
}
var formatTests = []struct {
input string
format string
output string
}{
{"<nil>", "%x", "<nil>"},
{"<nil>", "%#x", "<nil>"},
{"<nil>", "%#y", "%!y(big.Int=<nil>)"},
{"10", "%b", "1010"},
{"10", "%o", "12"},
{"10", "%d", "10"},
{"10", "%v", "10"},
{"10", "%x", "a"},
{"10", "%X", "A"},
{"-10", "%X", "-A"},
{"10", "%y", "%!y(big.Int=10)"},
{"-10", "%y", "%!y(big.Int=-10)"},
{"10", "%#b", "1010"},
{"10", "%#o", "012"},
{"10", "%#d", "10"},
{"10", "%#v", "10"},
{"10", "%#x", "0xa"},
{"10", "%#X", "0XA"},
{"-10", "%#X", "-0XA"},
{"10", "%#y", "%!y(big.Int=10)"},
{"-10", "%#y", "%!y(big.Int=-10)"},
{"1234", "%d", "1234"},
{"1234", "%3d", "1234"},
{"1234", "%4d", "1234"},
{"-1234", "%d", "-1234"},
{"1234", "% 5d", " 1234"},
{"1234", "%+5d", "+1234"},
{"1234", "%-5d", "1234 "},
{"1234", "%x", "4d2"},
{"1234", "%X", "4D2"},
{"-1234", "%3x", "-4d2"},
{"-1234", "%4x", "-4d2"},
{"-1234", "%5x", " -4d2"},
{"-1234", "%-5x", "-4d2 "},
{"1234", "%03d", "1234"},
{"1234", "%04d", "1234"},
{"1234", "%05d", "01234"},
{"1234", "%06d", "001234"},
{"-1234", "%06d", "-01234"},
{"1234", "%+06d", "+01234"},
{"1234", "% 06d", " 01234"},
{"1234", "%-6d", "1234 "},
{"1234", "%-06d", "1234 "},
{"-1234", "%-06d", "-1234 "},
{"1234", "%.3d", "1234"},
{"1234", "%.4d", "1234"},
{"1234", "%.5d", "01234"},
{"1234", "%.6d", "001234"},
{"-1234", "%.3d", "-1234"},
{"-1234", "%.4d", "-1234"},
{"-1234", "%.5d", "-01234"},
{"-1234", "%.6d", "-001234"},
{"1234", "%8.3d", " 1234"},
{"1234", "%8.4d", " 1234"},
{"1234", "%8.5d", " 01234"},
{"1234", "%8.6d", " 001234"},
{"-1234", "%8.3d", " -1234"},
{"-1234", "%8.4d", " -1234"},
{"-1234", "%8.5d", " -01234"},
{"-1234", "%8.6d", " -001234"},
{"1234", "%+8.3d", " +1234"},
{"1234", "%+8.4d", " +1234"},
{"1234", "%+8.5d", " +01234"},
{"1234", "%+8.6d", " +001234"},
{"-1234", "%+8.3d", " -1234"},
{"-1234", "%+8.4d", " -1234"},
{"-1234", "%+8.5d", " -01234"},
{"-1234", "%+8.6d", " -001234"},
{"1234", "% 8.3d", " 1234"},
{"1234", "% 8.4d", " 1234"},
{"1234", "% 8.5d", " 01234"},
{"1234", "% 8.6d", " 001234"},
{"-1234", "% 8.3d", " -1234"},
{"-1234", "% 8.4d", " -1234"},
{"-1234", "% 8.5d", " -01234"},
{"-1234", "% 8.6d", " -001234"},
{"1234", "%.3x", "4d2"},
{"1234", "%.4x", "04d2"},
{"1234", "%.5x", "004d2"},
{"1234", "%.6x", "0004d2"},
{"-1234", "%.3x", "-4d2"},
{"-1234", "%.4x", "-04d2"},
{"-1234", "%.5x", "-004d2"},
{"-1234", "%.6x", "-0004d2"},
{"1234", "%8.3x", " 4d2"},
{"1234", "%8.4x", " 04d2"},
{"1234", "%8.5x", " 004d2"},
{"1234", "%8.6x", " 0004d2"},
{"-1234", "%8.3x", " -4d2"},
{"-1234", "%8.4x", " -04d2"},
{"-1234", "%8.5x", " -004d2"},
{"-1234", "%8.6x", " -0004d2"},
{"1234", "%+8.3x", " +4d2"},
{"1234", "%+8.4x", " +04d2"},
{"1234", "%+8.5x", " +004d2"},
{"1234", "%+8.6x", " +0004d2"},
{"-1234", "%+8.3x", " -4d2"},
{"-1234", "%+8.4x", " -04d2"},
{"-1234", "%+8.5x", " -004d2"},
{"-1234", "%+8.6x", " -0004d2"},
{"1234", "% 8.3x", " 4d2"},
{"1234", "% 8.4x", " 04d2"},
{"1234", "% 8.5x", " 004d2"},
{"1234", "% 8.6x", " 0004d2"},
{"1234", "% 8.7x", " 00004d2"},
{"1234", "% 8.8x", " 000004d2"},
{"-1234", "% 8.3x", " -4d2"},
{"-1234", "% 8.4x", " -04d2"},
{"-1234", "% 8.5x", " -004d2"},
{"-1234", "% 8.6x", " -0004d2"},
{"-1234", "% 8.7x", "-00004d2"},
{"-1234", "% 8.8x", "-000004d2"},
{"1234", "%-8.3d", "1234 "},
{"1234", "%-8.4d", "1234 "},
{"1234", "%-8.5d", "01234 "},
{"1234", "%-8.6d", "001234 "},
{"1234", "%-8.7d", "0001234 "},
{"1234", "%-8.8d", "00001234"},
{"-1234", "%-8.3d", "-1234 "},
{"-1234", "%-8.4d", "-1234 "},
{"-1234", "%-8.5d", "-01234 "},
{"-1234", "%-8.6d", "-001234 "},
{"-1234", "%-8.7d", "-0001234"},
{"-1234", "%-8.8d", "-00001234"},
{"16777215", "%b", "111111111111111111111111"}, // 2**24 - 1
{"0", "%.d", ""},
{"0", "%.0d", ""},
{"0", "%3.d", ""},
}
func TestFormat(t *testing.T) {
for i, test := range formatTests {
var x *Int
if test.input != "<nil>" {
var ok bool
x, ok = new(Int).SetString(test.input, 0)
if !ok {
t.Errorf("#%d failed reading input %s", i, test.input)
}
}
output := fmt.Sprintf(test.format, x)
if output != test.output {
t.Errorf("#%d got %q; want %q, {%q, %q, %q}", i, output, test.output, test.input, test.format, test.output)
}
}
}
var scanTests = []struct {
input string
format string
output string
remaining int
}{
{"1010", "%b", "10", 0},
{"0b1010", "%v", "10", 0},
{"12", "%o", "10", 0},
{"012", "%v", "10", 0},
{"10", "%d", "10", 0},
{"10", "%v", "10", 0},
{"a", "%x", "10", 0},
{"0xa", "%v", "10", 0},
{"A", "%X", "10", 0},
{"-A", "%X", "-10", 0},
{"+0b1011001", "%v", "89", 0},
{"0xA", "%v", "10", 0},
{"0 ", "%v", "0", 1},
{"2+3", "%v", "2", 2},
{"0XABC 12", "%v", "2748", 3},
}
func TestScan(t *testing.T) {
var buf bytes.Buffer
for i, test := range scanTests {
x := new(Int)
buf.Reset()
buf.WriteString(test.input)
if _, err := fmt.Fscanf(&buf, test.format, x); err != nil {
t.Errorf("#%d error: %s", i, err.String())
}
if x.String() != test.output {
t.Errorf("#%d got %s; want %s", i, x.String(), test.output)
}
if buf.Len() != test.remaining {
t.Errorf("#%d got %d bytes remaining; want %d", i, buf.Len(), test.remaining)
}
}
}
// Examples from the Go Language Spec, section "Arithmetic operators"
var divisionSignsTests = []struct {
@ -362,7 +549,6 @@ var divisionSignsTests = []struct {
{8, 4, 2, 0, 2, 0},
}
func TestDivisionSigns(t *testing.T) {
for i, test := range divisionSignsTests {
x := NewInt(test.x)
@ -420,7 +606,6 @@ func TestDivisionSigns(t *testing.T) {
}
}
func checkSetBytes(b []byte) bool {
hex1 := hex.EncodeToString(new(Int).SetBytes(b).Bytes())
hex2 := hex.EncodeToString(b)
@ -436,27 +621,23 @@ func checkSetBytes(b []byte) bool {
return hex1 == hex2
}
func TestSetBytes(t *testing.T) {
if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err)
}
}
func checkBytes(b []byte) bool {
b2 := new(Int).SetBytes(b).Bytes()
return bytes.Compare(b, b2) == 0
}
func TestBytes(t *testing.T) {
if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err)
}
}
func checkQuo(x, y []byte) bool {
u := new(Int).SetBytes(x)
v := new(Int).SetBytes(y)
@ -479,7 +660,6 @@ func checkQuo(x, y []byte) bool {
return uprime.Cmp(u) == 0
}
var quoTests = []struct {
x, y string
q, r string
@ -498,7 +678,6 @@ var quoTests = []struct {
},
}
func TestQuo(t *testing.T) {
if err := quick.Check(checkQuo, nil); err != nil {
t.Error(err)
@ -519,7 +698,6 @@ func TestQuo(t *testing.T) {
}
}
func TestQuoStepD6(t *testing.T) {
// See Knuth, Volume 2, section 4.3.1, exercise 21. This code exercises
// a code path which only triggers 1 in 10^{-19} cases.
@ -539,7 +717,6 @@ func TestQuoStepD6(t *testing.T) {
}
}
var bitLenTests = []struct {
in string
out int
@ -558,7 +735,6 @@ var bitLenTests = []struct {
{"-0x4000000000000000000000", 87},
}
func TestBitLen(t *testing.T) {
for i, test := range bitLenTests {
x, ok := new(Int).SetString(test.in, 0)
@ -573,7 +749,6 @@ func TestBitLen(t *testing.T) {
}
}
var expTests = []struct {
x, y, m string
out string
@ -598,7 +773,6 @@ var expTests = []struct {
},
}
func TestExp(t *testing.T) {
for i, test := range expTests {
x, ok1 := new(Int).SetString(test.x, 0)
@ -629,7 +803,6 @@ func TestExp(t *testing.T) {
}
}
func checkGcd(aBytes, bBytes []byte) bool {
a := new(Int).SetBytes(aBytes)
b := new(Int).SetBytes(bBytes)
@ -646,7 +819,6 @@ func checkGcd(aBytes, bBytes []byte) bool {
return x.Cmp(d) == 0
}
var gcdTests = []struct {
a, b int64
d, x, y int64
@ -654,7 +826,6 @@ var gcdTests = []struct {
{120, 23, 1, -9, 47},
}
func TestGcd(t *testing.T) {
for i, test := range gcdTests {
a := NewInt(test.a)
@ -680,7 +851,6 @@ func TestGcd(t *testing.T) {
quick.Check(checkGcd, nil)
}
var primes = []string{
"2",
"3",
@ -706,7 +876,6 @@ var primes = []string{
"203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123",
}
var composites = []string{
"21284175091214687912771199898307297748211672914763848041968395774954376176754",
"6084766654921918907427900243509372380954290099172559290432744450051395395951",
@ -714,7 +883,6 @@ var composites = []string{
"82793403787388584738507275144194252681",
}
func TestProbablyPrime(t *testing.T) {
nreps := 20
if testing.Short() {
@ -738,14 +906,12 @@ func TestProbablyPrime(t *testing.T) {
}
}
type intShiftTest struct {
in string
shift uint
out string
}
var rshTests = []intShiftTest{
{"0", 0, "0"},
{"-0", 0, "0"},
@ -773,7 +939,6 @@ var rshTests = []intShiftTest{
{"340282366920938463463374607431768211456", 128, "1"},
}
func TestRsh(t *testing.T) {
for i, test := range rshTests {
in, _ := new(Int).SetString(test.in, 10)
@ -789,7 +954,6 @@ func TestRsh(t *testing.T) {
}
}
func TestRshSelf(t *testing.T) {
for i, test := range rshTests {
z, _ := new(Int).SetString(test.in, 10)
@ -805,7 +969,6 @@ func TestRshSelf(t *testing.T) {
}
}
var lshTests = []intShiftTest{
{"0", 0, "0"},
{"0", 1, "0"},
@ -828,7 +991,6 @@ var lshTests = []intShiftTest{
{"1", 128, "340282366920938463463374607431768211456"},
}
func TestLsh(t *testing.T) {
for i, test := range lshTests {
in, _ := new(Int).SetString(test.in, 10)
@ -844,7 +1006,6 @@ func TestLsh(t *testing.T) {
}
}
func TestLshSelf(t *testing.T) {
for i, test := range lshTests {
z, _ := new(Int).SetString(test.in, 10)
@ -860,7 +1021,6 @@ func TestLshSelf(t *testing.T) {
}
}
func TestLshRsh(t *testing.T) {
for i, test := range rshTests {
in, _ := new(Int).SetString(test.in, 10)
@ -888,7 +1048,6 @@ func TestLshRsh(t *testing.T) {
}
}
var int64Tests = []int64{
0,
1,
@ -902,7 +1061,6 @@ var int64Tests = []int64{
-9223372036854775808,
}
func TestInt64(t *testing.T) {
for i, testVal := range int64Tests {
in := NewInt(testVal)
@ -914,7 +1072,6 @@ func TestInt64(t *testing.T) {
}
}
var bitwiseTests = []struct {
x, y string
and, or, xor, andNot string
@ -958,7 +1115,6 @@ var bitwiseTests = []struct {
},
}
type bitFun func(z, x, y *Int) *Int
func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
@ -971,7 +1127,6 @@ func testBitFun(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
}
}
func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
self := new(Int)
self.Set(x)
@ -984,6 +1139,142 @@ func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
}
}
func altBit(x *Int, i int) uint {
z := new(Int).Rsh(x, uint(i))
z = z.And(z, NewInt(1))
if z.Cmp(new(Int)) != 0 {
return 1
}
return 0
}
func altSetBit(z *Int, x *Int, i int, b uint) *Int {
one := NewInt(1)
m := one.Lsh(one, uint(i))
switch b {
case 1:
return z.Or(x, m)
case 0:
return z.AndNot(x, m)
}
panic("set bit is not 0 or 1")
}
func testBitset(t *testing.T, x *Int) {
n := x.BitLen()
z := new(Int).Set(x)
z1 := new(Int).Set(x)
for i := 0; i < n+10; i++ {
old := z.Bit(i)
old1 := altBit(z1, i)
if old != old1 {
t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1)
}
z := new(Int).SetBit(z, i, 1)
z1 := altSetBit(new(Int), z1, i, 1)
if z.Bit(i) == 0 {
t.Errorf("bitset: bit %d of %s got 0 want 1", i, x)
}
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1)
}
z.SetBit(z, i, 0)
altSetBit(z1, z1, i, 0)
if z.Bit(i) != 0 {
t.Errorf("bitset: bit %d of %s got 1 want 0", i, x)
}
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1)
}
altSetBit(z1, z1, i, old)
z.SetBit(z, i, old)
if z.Cmp(z1) != 0 {
t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1)
}
}
if z.Cmp(x) != 0 {
t.Errorf("bitset: got %s want %s", z, x)
}
}
var bitsetTests = []struct {
x string
i int
b uint
}{
{"0", 0, 0},
{"0", 200, 0},
{"1", 0, 1},
{"1", 1, 0},
{"-1", 0, 1},
{"-1", 200, 1},
{"0x2000000000000000000000000000", 108, 0},
{"0x2000000000000000000000000000", 109, 1},
{"0x2000000000000000000000000000", 110, 0},
{"-0x2000000000000000000000000001", 108, 1},
{"-0x2000000000000000000000000001", 109, 0},
{"-0x2000000000000000000000000001", 110, 1},
}
func TestBitSet(t *testing.T) {
for _, test := range bitwiseTests {
x := new(Int)
x.SetString(test.x, 0)
testBitset(t, x)
x = new(Int)
x.SetString(test.y, 0)
testBitset(t, x)
}
for i, test := range bitsetTests {
x := new(Int)
x.SetString(test.x, 0)
b := x.Bit(test.i)
if b != test.b {
t.Errorf("#%d want %v got %v", i, test.b, b)
}
}
}
func BenchmarkBitset(b *testing.B) {
z := new(Int)
z.SetBit(z, 512, 1)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
z.SetBit(z, i&512, 1)
}
}
func BenchmarkBitsetNeg(b *testing.B) {
z := NewInt(-1)
z.SetBit(z, 512, 0)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
z.SetBit(z, i&512, 0)
}
}
func BenchmarkBitsetOrig(b *testing.B) {
z := new(Int)
altSetBit(z, z, 512, 1)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
altSetBit(z, z, i&512, 1)
}
}
func BenchmarkBitsetNegOrig(b *testing.B) {
z := NewInt(-1)
altSetBit(z, z, 512, 0)
b.ResetTimer()
b.StartTimer()
for i := b.N - 1; i >= 0; i-- {
altSetBit(z, z, i&512, 0)
}
}
func TestBitwise(t *testing.T) {
x := new(Int)
@ -1003,7 +1294,6 @@ func TestBitwise(t *testing.T) {
}
}
var notTests = []struct {
in string
out string
@ -1037,7 +1327,6 @@ func TestNot(t *testing.T) {
}
}
var modInverseTests = []struct {
element string
prime string
@ -1062,7 +1351,7 @@ func TestModInverse(t *testing.T) {
}
}
// used by TestIntGobEncoding and TestRatGobEncoding
var gobEncodingTests = []string{
"0",
"1",
@ -1073,7 +1362,7 @@ var gobEncodingTests = []string{
"298472983472983471903246121093472394872319615612417471234712061",
}
func TestGobEncoding(t *testing.T) {
func TestIntGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
@ -1081,7 +1370,8 @@ func TestGobEncoding(t *testing.T) {
for j := 0; j < 2; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j == 0 {
if j != 0 {
// negative numbers
stest = "-" + test
}
var tx Int

View File

@ -18,7 +18,11 @@ package big
// These are the building blocks for the operations on signed integers
// and rationals.
import "rand"
import (
"io"
"os"
"rand"
)
// An unsigned integer x of the form
//
@ -40,14 +44,12 @@ var (
natTen = nat{10}
)
func (z nat) clear() {
for i := range z {
z[i] = 0
}
}
func (z nat) norm() nat {
i := len(z)
for i > 0 && z[i-1] == 0 {
@ -56,7 +58,6 @@ func (z nat) norm() nat {
return z[0:i]
}
func (z nat) make(n int) nat {
if n <= cap(z) {
return z[0:n] // reuse z
@ -67,7 +68,6 @@ func (z nat) make(n int) nat {
return make(nat, n, n+e)
}
func (z nat) setWord(x Word) nat {
if x == 0 {
return z.make(0)
@ -77,7 +77,6 @@ func (z nat) setWord(x Word) nat {
return z
}
func (z nat) setUint64(x uint64) nat {
// single-digit values
if w := Word(x); uint64(w) == x {
@ -100,14 +99,12 @@ func (z nat) setUint64(x uint64) nat {
return z
}
func (z nat) set(x nat) nat {
z = z.make(len(x))
copy(z, x)
return z
}
func (z nat) add(x, y nat) nat {
m := len(x)
n := len(y)
@ -134,7 +131,6 @@ func (z nat) add(x, y nat) nat {
return z.norm()
}
func (z nat) sub(x, y nat) nat {
m := len(x)
n := len(y)
@ -163,7 +159,6 @@ func (z nat) sub(x, y nat) nat {
return z.norm()
}
func (x nat) cmp(y nat) (r int) {
m := len(x)
n := len(y)
@ -191,7 +186,6 @@ func (x nat) cmp(y nat) (r int) {
return
}
func (z nat) mulAddWW(x nat, y, r Word) nat {
m := len(x)
if m == 0 || y == 0 {
@ -205,7 +199,6 @@ func (z nat) mulAddWW(x nat, y, r Word) nat {
return z.norm()
}
// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(z, x, y nat) {
@ -217,7 +210,6 @@ func basicMul(z, x, y nat) {
}
}
// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
// Factored out for readability - do not use outside karatsuba.
func karatsubaAdd(z, x nat, n int) {
@ -226,7 +218,6 @@ func karatsubaAdd(z, x nat, n int) {
}
}
// Like karatsubaAdd, but does subtract.
func karatsubaSub(z, x nat, n int) {
if c := subVV(z[0:n], z, x); c != 0 {
@ -234,7 +225,6 @@ func karatsubaSub(z, x nat, n int) {
}
}
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
@ -339,13 +329,11 @@ func karatsuba(z, x, y nat) {
}
}
// alias returns true if x and y share the same base array.
func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
// addAt implements z += x*(1<<(_W*i)); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
@ -360,7 +348,6 @@ func addAt(z, x nat, i int) {
}
}
func max(x, y int) int {
if x > y {
return x
@ -368,7 +355,6 @@ func max(x, y int) int {
return y
}
// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
@ -382,7 +368,6 @@ func karatsubaLen(n int) int {
return n << i
}
func (z nat) mul(x, y nat) nat {
m := len(x)
n := len(y)
@ -450,7 +435,6 @@ func (z nat) mul(x, y nat) nat {
return z.norm()
}
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat {
@ -469,7 +453,6 @@ func (z nat) mulRange(a, b uint64) nat {
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
}
// q = (x-r)/y, with 0 <= r < y
func (z nat) divW(x nat, y Word) (q nat, r Word) {
m := len(x)
@ -490,7 +473,6 @@ func (z nat) divW(x nat, y Word) (q nat, r Word) {
return
}
func (z nat) div(z2, u, v nat) (q, r nat) {
if len(v) == 0 {
panic("division by zero")
@ -518,7 +500,6 @@ func (z nat) div(z2, u, v nat) (q, r nat) {
return
}
// q = (uIn-r)/v, with 0 <= r < y
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
@ -545,9 +526,14 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
u.clear()
// D1.
shift := Word(leadingZeros(v[n-1]))
shlVW(v, v, shift)
u[len(uIn)] = shlVW(u[0:len(uIn)], uIn, shift)
shift := leadingZeros(v[n-1])
if shift > 0 {
// do not modify v, it may be used by another goroutine simultaneously
v1 := make(nat, n)
shlVU(v1, v, shift)
v = v1
}
u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
// D2.
for j := m; j >= 0; j-- {
@ -586,14 +572,12 @@ func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
}
q = q.norm()
shrVW(u, u, shift)
shrVW(v, v, shift)
shrVU(u, u, shift)
r = u.norm()
return q, r
}
// Length of x in bits. x must be normalized.
func (x nat) bitLen() int {
if i := len(x) - 1; i >= 0 {
@ -602,103 +586,253 @@ func (x nat) bitLen() int {
return 0
}
// MaxBase is the largest number base accepted for string conversions.
const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1
func hexValue(ch byte) int {
var d byte
func hexValue(ch int) Word {
d := MaxBase + 1 // illegal base
switch {
case '0' <= ch && ch <= '9':
d = ch - '0'
case 'a' <= ch && ch <= 'f':
case 'a' <= ch && ch <= 'z':
d = ch - 'a' + 10
case 'A' <= ch && ch <= 'F':
case 'A' <= ch && ch <= 'Z':
d = ch - 'A' + 10
default:
return -1
}
return int(d)
return Word(d)
}
// scan sets z to the natural number corresponding to the longest possible prefix
// read from r representing an unsigned integer in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined. The syntax follows the syntax of
// unsigned integer literals in Go.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z nat) scan(r io.RuneScanner, base int) (nat, int, os.Error) {
// reject illegal bases
if base < 0 || base == 1 || MaxBase < base {
return z, 0, os.NewError("illegal number base")
}
// one char look-ahead
ch, _, err := r.ReadRune()
if err != nil {
return z, 0, err
}
// scan returns the natural number corresponding to the
// longest possible prefix of s representing a natural number in a
// given conversion base, the actual conversion base used, and the
// prefix length. The syntax of natural numbers follows the syntax
// of unsigned integer literals in Go.
//
// If the base argument is 0, the string prefix determines the actual
// conversion base. A prefix of ``0x'' or ``0X'' selects base 16; the
// ``0'' prefix selects base 8, and a ``0b'' or ``0B'' prefix selects
// base 2. Otherwise the selected base is 10.
//
func (z nat) scan(s string, base int) (nat, int, int) {
// determine base if necessary
i, n := 0, len(s)
b := Word(base)
if base == 0 {
base = 10
if n > 0 && s[0] == '0' {
base, i = 8, 1
if n > 1 {
switch s[1] {
b = 10
if ch == '0' {
switch ch, _, err = r.ReadRune(); err {
case nil:
b = 8
switch ch {
case 'x', 'X':
base, i = 16, 2
b = 16
case 'b', 'B':
base, i = 2, 2
b = 2
}
if b == 2 || b == 16 {
if ch, _, err = r.ReadRune(); err != nil {
return z, 0, err
}
}
case os.EOF:
return z, 10, nil
default:
return z, 10, err
}
}
}
// convert string
// - group as many digits d as possible together into a "super-digit" dd with "super-base" bb
// - only when bb does not fit into a word anymore, do a full number mulAddWW using bb and dd
z = z.make(0)
bb := Word(1)
dd := Word(0)
for max := _M / b; ; {
d := hexValue(ch)
if d >= b {
r.UnreadRune() // ch does not belong to number anymore
break
}
if bb <= max {
bb *= b
dd = dd*b + d
} else {
// bb * b would overflow
z = z.mulAddWW(z, bb, dd)
bb = b
dd = d
}
if ch, _, err = r.ReadRune(); err != nil {
if err != os.EOF {
return z, int(b), err
}
break
}
}
switch {
case bb > 1:
// there was at least one mantissa digit
z = z.mulAddWW(z, bb, dd)
case base == 0 && b == 8:
// there was only the octal prefix 0 (possibly followed by digits > 7);
// return base 10, not 8
return z, 10, nil
case base != 0 || b != 8:
// there was neither a mantissa digit nor the octal prefix 0
return z, int(b), os.NewError("syntax error scanning number")
}
return z.norm(), int(b), nil
}
// Character sets for string conversion.
const (
lowercaseDigits = "0123456789abcdefghijklmnopqrstuvwxyz"
uppercaseDigits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
// decimalString returns a decimal representation of x.
// It calls x.string with the charset "0123456789".
func (x nat) decimalString() string {
return x.string(lowercaseDigits[0:10])
}
// string converts x to a string using digits from a charset; a digit with
// value d is represented by charset[d]. The conversion base is determined
// by len(charset), which must be >= 2.
func (x nat) string(charset string) string {
b := Word(len(charset))
// special cases
switch {
case b < 2 || b > 256:
panic("illegal base")
case len(x) == 0:
return string(charset[0])
}
// allocate buffer for conversion
i := x.bitLen()/log2(b) + 1 // +1: round up
s := make([]byte, i)
// special case: power of two bases can avoid divisions completely
if b == b&-b {
// shift is base-b digit size in bits
shift := uint(trailingZeroBits(b)) // shift > 0 because b >= 2
mask := Word(1)<<shift - 1
w := x[0]
nbits := uint(_W) // number of unprocessed bits in w
// convert less-significant words
for k := 1; k < len(x); k++ {
// convert full digits
for nbits >= shift {
i--
s[i] = charset[w&mask]
w >>= shift
nbits -= shift
}
// convert any partial leading digit and advance to next word
if nbits == 0 {
// no partial digit remaining, just advance
w = x[k]
nbits = _W
} else {
// partial digit in current (k-1) and next (k) word
w |= x[k] << nbits
i--
s[i] = charset[w&mask]
// advance
w = x[k] >> (shift - nbits)
nbits = _W - (shift - nbits)
}
}
// convert digits of most-significant word (omit leading zeros)
for nbits >= 0 && w != 0 {
i--
s[i] = charset[w&mask]
w >>= shift
nbits -= shift
}
return string(s[i:])
}
// general case: extract groups of digits by multiprecision division
// maximize ndigits where b**ndigits < 2^_W; bb (big base) is b**ndigits
bb := Word(1)
ndigits := 0
for max := Word(_M / b); bb <= max; bb *= b {
ndigits++
}
// preserve x, create local copy for use in repeated divisions
q := nat(nil).set(x)
var r Word
// convert
if b == 10 { // hard-coding for 10 here speeds this up by 1.25x
for len(q) > 0 {
// extract least significant, base bb "digit"
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
if len(q) == 0 {
// skip leading zeros in most-significant group of digits
for j := 0; j < ndigits && r != 0; j++ {
i--
s[i] = charset[r%10]
r /= 10
}
} else {
for j := 0; j < ndigits; j++ {
i--
s[i] = charset[r%10]
r /= 10
}
}
}
} else {
for len(q) > 0 {
// extract least significant group of digits
q, r = q.divW(q, bb) // N.B. >82% of time is here. Optimize divW
if len(q) == 0 {
// skip leading zeros in most-significant group of digits
for j := 0; j < ndigits && r != 0; j++ {
i--
s[i] = charset[r%b]
r /= b
}
} else {
for j := 0; j < ndigits; j++ {
i--
s[i] = charset[r%b]
r /= b
}
}
}
}
// reject illegal bases or strings consisting only of prefix
if base < 2 || 16 < base || (base != 8 && i >= n) {
return z, 0, 0
}
// convert string
z = z.make(0)
for ; i < n; i++ {
d := hexValue(s[i])
if 0 <= d && d < base {
z = z.mulAddWW(z, Word(base), Word(d))
} else {
break
}
}
return z.norm(), base, i
}
// string converts x to a string for a given base, with 2 <= base <= 16.
// TODO(gri) in the style of the other routines, perhaps this should take
// a []byte buffer and return it
func (x nat) string(base int) string {
if base < 2 || 16 < base {
panic("illegal base")
}
if len(x) == 0 {
return "0"
}
// allocate buffer for conversion
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
s := make([]byte, i)
// don't destroy x
q := nat(nil).set(x)
// convert
for len(q) > 0 {
i--
var r Word
q, r = q.divW(q, Word(base))
s[i] = "0123456789abcdef"[r]
}
return string(s[i:])
}
const deBruijn32 = 0x077CB531
var deBruijn32Lookup = []byte{
@ -721,7 +855,7 @@ var deBruijn64Lookup = []byte{
func trailingZeroBits(x Word) int {
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multipling by a power of two is equivalent to
// to the power of k. Multiplying by a power of two is equivalent to
// left shifting, in this case by k bits. The de Bruijn constant is
// such that all six bit, consecutive substrings are distinct.
// Therefore, if we have a left shifted version of this constant we can
@ -739,7 +873,6 @@ func trailingZeroBits(x Word) int {
return 0
}
// z = x << s
func (z nat) shl(x nat, s uint) nat {
m := len(x)
@ -750,13 +883,12 @@ func (z nat) shl(x nat, s uint) nat {
n := m + int(s/_W)
z = z.make(n + 1)
z[n] = shlVW(z[n-m:n], x, Word(s%_W))
z[n] = shlVU(z[n-m:n], x, s%_W)
z[0 : n-m].clear()
return z.norm()
}
// z = x >> s
func (z nat) shr(x nat, s uint) nat {
m := len(x)
@ -767,11 +899,45 @@ func (z nat) shr(x nat, s uint) nat {
// n > 0
z = z.make(n)
shrVW(z, x[m-n:], Word(s%_W))
shrVU(z, x[m-n:], s%_W)
return z.norm()
}
func (z nat) setBit(x nat, i uint, b uint) nat {
j := int(i / _W)
m := Word(1) << (i % _W)
n := len(x)
switch b {
case 0:
z = z.make(n)
copy(z, x)
if j >= n {
// no need to grow
return z
}
z[j] &^= m
return z.norm()
case 1:
if j >= n {
n = j + 1
}
z = z.make(n)
copy(z, x)
z[j] |= m
// no need to normalize
return z
}
panic("set bit is not 0 or 1")
}
func (z nat) bit(i uint) uint {
j := int(i / _W)
if j >= len(z) {
return 0
}
return uint(z[j] >> (i % _W) & 1)
}
func (z nat) and(x, y nat) nat {
m := len(x)
@ -789,7 +955,6 @@ func (z nat) and(x, y nat) nat {
return z.norm()
}
func (z nat) andNot(x, y nat) nat {
m := len(x)
n := len(y)
@ -807,7 +972,6 @@ func (z nat) andNot(x, y nat) nat {
return z.norm()
}
func (z nat) or(x, y nat) nat {
m := len(x)
n := len(y)
@ -827,7 +991,6 @@ func (z nat) or(x, y nat) nat {
return z.norm()
}
func (z nat) xor(x, y nat) nat {
m := len(x)
n := len(y)
@ -847,10 +1010,10 @@ func (z nat) xor(x, y nat) nat {
return z.norm()
}
// greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 }
func greaterThan(x1, x2, y1, y2 Word) bool {
return x1 > y1 || x1 == y1 && x2 > y2
}
// modW returns x % d.
func (x nat) modW(d Word) (r Word) {
@ -860,30 +1023,29 @@ func (x nat) modW(d Word) (r Word) {
return divWVW(q, 0, x, d)
}
// powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd.
func (n nat) powersOfTwoDecompose() (q nat, k Word) {
if len(n) == 0 {
return n, 0
// powersOfTwoDecompose finds q and k with x = q * 1<<k and q is odd, or q and k are 0.
func (x nat) powersOfTwoDecompose() (q nat, k int) {
if len(x) == 0 {
return x, 0
}
zeroWords := 0
for n[zeroWords] == 0 {
zeroWords++
// One of the words must be non-zero by definition,
// so this loop will terminate with i < len(x), and
// i is the number of 0 words.
i := 0
for x[i] == 0 {
i++
}
// One of the words must be non-zero by invariant, therefore
// zeroWords < len(n).
x := trailingZeroBits(n[zeroWords])
n := trailingZeroBits(x[i]) // x[i] != 0
q = make(nat, len(x)-i)
shrVU(q, x[i:], uint(n))
q = q.make(len(n) - zeroWords)
shrVW(q, n[zeroWords:], Word(x))
q = q.norm()
k = Word(_W*zeroWords + x)
k = i*_W + n
return
}
// random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
@ -914,7 +1076,6 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
return z.norm()
}
// If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
// reuses the storage of z if possible.
func (z nat) expNN(x, y, m nat) nat {
@ -983,7 +1144,6 @@ func (z nat) expNN(x, y, m nat) nat {
return z
}
// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
// If it returns true, n is prime with probability 1 - 1/4^reps.
// If it returns false, n is not prime.
@ -1050,7 +1210,7 @@ NextRandom:
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
for j := Word(1); j < k; j++ {
for j := 1; j < k; j++ {
y = y.mul(y, y)
quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 {
@ -1066,7 +1226,6 @@ NextRandom:
return true
}
// bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of
@ -1088,7 +1247,6 @@ func (z nat) bytes(buf []byte) (i int) {
return
}
// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {

View File

@ -4,7 +4,12 @@
package big
import "testing"
import (
"fmt"
"os"
"strings"
"testing"
)
var cmpTests = []struct {
x, y nat
@ -26,7 +31,6 @@ var cmpTests = []struct {
{nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1},
}
func TestCmp(t *testing.T) {
for i, a := range cmpTests {
r := a.x.cmp(a.y)
@ -36,13 +40,11 @@ func TestCmp(t *testing.T) {
}
}
type funNN func(z, x, y nat) nat
type argNN struct {
z, x, y nat
}
var sumNN = []argNN{
{},
{nat{1}, nil, nat{1}},
@ -52,7 +54,6 @@ var sumNN = []argNN{
{nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}},
}
var prodNN = []argNN{
{},
{nil, nil, nil},
@ -64,7 +65,6 @@ var prodNN = []argNN{
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
}
func TestSet(t *testing.T) {
for _, a := range sumNN {
z := nat(nil).set(a.z)
@ -74,7 +74,6 @@ func TestSet(t *testing.T) {
}
}
func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
z := f(nil, a.x, a.y)
if z.cmp(a.z) != 0 {
@ -82,7 +81,6 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
}
}
func TestFunNN(t *testing.T) {
for _, a := range sumNN {
arg := a
@ -107,7 +105,6 @@ func TestFunNN(t *testing.T) {
}
}
var mulRangesN = []struct {
a, b uint64
prod string
@ -130,17 +127,15 @@ var mulRangesN = []struct {
},
}
func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN {
prod := nat(nil).mulRange(r.a, r.b).string(10)
prod := nat(nil).mulRange(r.a, r.b).decimalString()
if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
}
}
}
var mulArg, mulTmp nat
func init() {
@ -151,7 +146,6 @@ func init() {
}
}
func benchmarkMulLoad() {
for j := 1; j <= 10; j++ {
x := mulArg[0 : j*100]
@ -159,46 +153,376 @@ func benchmarkMulLoad() {
}
}
func BenchmarkMul(b *testing.B) {
for i := 0; i < b.N; i++ {
benchmarkMulLoad()
}
}
func toString(x nat, charset string) string {
base := len(charset)
var tab = []struct {
x nat
b int
s string
}{
{nil, 10, "0"},
{nat{1}, 10, "1"},
{nat{10}, 10, "10"},
{nat{1234567890}, 10, "1234567890"},
// special cases
switch {
case base < 2:
panic("illegal base")
case len(x) == 0:
return string(charset[0])
}
// allocate buffer for conversion
i := x.bitLen()/log2(Word(base)) + 1 // +1: round up
s := make([]byte, i)
// don't destroy x
q := nat(nil).set(x)
// convert
for len(q) > 0 {
i--
var r Word
q, r = q.divW(q, Word(base))
s[i] = charset[r]
}
return string(s[i:])
}
var strTests = []struct {
x nat // nat value to be converted
c string // conversion charset
s string // expected result
}{
{nil, "01", "0"},
{nat{1}, "01", "1"},
{nat{0xc5}, "01", "11000101"},
{nat{03271}, lowercaseDigits[0:8], "3271"},
{nat{10}, lowercaseDigits[0:10], "10"},
{nat{1234567890}, uppercaseDigits[0:10], "1234567890"},
{nat{0xdeadbeef}, lowercaseDigits[0:16], "deadbeef"},
{nat{0xdeadbeef}, uppercaseDigits[0:16], "DEADBEEF"},
{nat{0x229be7}, lowercaseDigits[0:17], "1a2b3c"},
{nat{0x309663e6}, uppercaseDigits[0:32], "O9COV6"},
}
func TestString(t *testing.T) {
for _, a := range tab {
s := a.x.string(a.b)
for _, a := range strTests {
s := a.x.string(a.c)
if s != a.s {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
}
x, b, n := nat(nil).scan(a.s, a.b)
x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
}
if b != a.b {
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.b)
if b != len(a.c) {
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, len(a.c))
}
if n != len(a.s) {
t.Errorf("scan%+v\n\tgot n = %d; want %d", a, n, len(a.s))
if err != nil {
t.Errorf("scan%+v\n\tgot error = %s", a, err)
}
}
}
var natScanTests = []struct {
s string // string to be scanned
base int // input base
x nat // expected nat
b int // expected base
ok bool // expected success
next int // next character (or 0, if at EOF)
}{
// error: illegal base
{base: -1},
{base: 1},
{base: 37},
// error: no mantissa
{},
{s: "?"},
{base: 10},
{base: 36},
{s: "?", base: 10},
{s: "0x"},
{s: "345", base: 2},
// no errors
{"0", 0, nil, 10, true, 0},
{"0", 10, nil, 10, true, 0},
{"0", 36, nil, 36, true, 0},
{"1", 0, nat{1}, 10, true, 0},
{"1", 10, nat{1}, 10, true, 0},
{"0 ", 0, nil, 10, true, ' '},
{"08", 0, nil, 10, true, '8'},
{"018", 0, nat{1}, 8, true, '8'},
{"0b1", 0, nat{1}, 2, true, 0},
{"0b11000101", 0, nat{0xc5}, 2, true, 0},
{"03271", 0, nat{03271}, 8, true, 0},
{"10ab", 0, nat{10}, 10, true, 'a'},
{"1234567890", 0, nat{1234567890}, 10, true, 0},
{"xyz", 36, nat{(33*36+34)*36 + 35}, 36, true, 0},
{"xyz?", 36, nat{(33*36+34)*36 + 35}, 36, true, '?'},
{"0x", 16, nil, 16, true, 'x'},
{"0xdeadbeef", 0, nat{0xdeadbeef}, 16, true, 0},
{"0XDEADBEEF", 0, nat{0xdeadbeef}, 16, true, 0},
}
func TestScanBase(t *testing.T) {
for _, a := range natScanTests {
r := strings.NewReader(a.s)
x, b, err := nat(nil).scan(r, a.base)
if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a)
}
if err != nil {
if a.ok {
t.Errorf("scan%+v\n\tgot error = %s", a, err)
}
continue
}
if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
}
if b != a.b {
t.Errorf("scan%+v\n\tgot b = %d; want %d", a, b, a.base)
}
next, _, err := r.ReadRune()
if err == os.EOF {
next = 0
err = nil
}
if err == nil && next != a.next {
t.Errorf("scan%+v\n\tgot next = %q; want %q", a, next, a.next)
}
}
}
var pi = "3" +
"14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534211706798214808651" +
"32823066470938446095505822317253594081284811174502841027019385211055596446229489549303819644288109756659334461" +
"28475648233786783165271201909145648566923460348610454326648213393607260249141273724587006606315588174881520920" +
"96282925409171536436789259036001133053054882046652138414695194151160943305727036575959195309218611738193261179" +
"31051185480744623799627495673518857527248912279381830119491298336733624406566430860213949463952247371907021798" +
"60943702770539217176293176752384674818467669405132000568127145263560827785771342757789609173637178721468440901" +
"22495343014654958537105079227968925892354201995611212902196086403441815981362977477130996051870721134999999837" +
"29780499510597317328160963185950244594553469083026425223082533446850352619311881710100031378387528865875332083" +
"81420617177669147303598253490428755468731159562863882353787593751957781857780532171226806613001927876611195909" +
"21642019893809525720106548586327886593615338182796823030195203530185296899577362259941389124972177528347913151" +
"55748572424541506959508295331168617278558890750983817546374649393192550604009277016711390098488240128583616035" +
"63707660104710181942955596198946767837449448255379774726847104047534646208046684259069491293313677028989152104" +
"75216205696602405803815019351125338243003558764024749647326391419927260426992279678235478163600934172164121992" +
"45863150302861829745557067498385054945885869269956909272107975093029553211653449872027559602364806654991198818" +
"34797753566369807426542527862551818417574672890977772793800081647060016145249192173217214772350141441973568548" +
"16136115735255213347574184946843852332390739414333454776241686251898356948556209921922218427255025425688767179" +
"04946016534668049886272327917860857843838279679766814541009538837863609506800642251252051173929848960841284886" +
"26945604241965285022210661186306744278622039194945047123713786960956364371917287467764657573962413890865832645" +
"99581339047802759009946576407895126946839835259570982582262052248940772671947826848260147699090264013639443745" +
"53050682034962524517493996514314298091906592509372216964615157098583874105978859597729754989301617539284681382" +
"68683868942774155991855925245953959431049972524680845987273644695848653836736222626099124608051243884390451244" +
"13654976278079771569143599770012961608944169486855584840635342207222582848864815845602850601684273945226746767" +
"88952521385225499546667278239864565961163548862305774564980355936345681743241125150760694794510965960940252288" +
"79710893145669136867228748940560101503308617928680920874760917824938589009714909675985261365549781893129784821" +
"68299894872265880485756401427047755513237964145152374623436454285844479526586782105114135473573952311342716610" +
"21359695362314429524849371871101457654035902799344037420073105785390621983874478084784896833214457138687519435" +
"06430218453191048481005370614680674919278191197939952061419663428754440643745123718192179998391015919561814675" +
"14269123974894090718649423196156794520809514655022523160388193014209376213785595663893778708303906979207734672" +
"21825625996615014215030680384477345492026054146659252014974428507325186660021324340881907104863317346496514539" +
"05796268561005508106658796998163574736384052571459102897064140110971206280439039759515677157700420337869936007" +
"23055876317635942187312514712053292819182618612586732157919841484882916447060957527069572209175671167229109816" +
"90915280173506712748583222871835209353965725121083579151369882091444210067510334671103141267111369908658516398" +
"31501970165151168517143765761835155650884909989859982387345528331635507647918535893226185489632132933089857064" +
"20467525907091548141654985946163718027098199430992448895757128289059232332609729971208443357326548938239119325" +
"97463667305836041428138830320382490375898524374417029132765618093773444030707469211201913020330380197621101100" +
"44929321516084244485963766983895228684783123552658213144957685726243344189303968642624341077322697802807318915" +
"44110104468232527162010526522721116603966655730925471105578537634668206531098965269186205647693125705863566201" +
"85581007293606598764861179104533488503461136576867532494416680396265797877185560845529654126654085306143444318" +
"58676975145661406800700237877659134401712749470420562230538994561314071127000407854733269939081454664645880797" +
"27082668306343285878569830523580893306575740679545716377525420211495576158140025012622859413021647155097925923" +
"09907965473761255176567513575178296664547791745011299614890304639947132962107340437518957359614589019389713111" +
"79042978285647503203198691514028708085990480109412147221317947647772622414254854540332157185306142288137585043" +
"06332175182979866223717215916077166925474873898665494945011465406284336639379003976926567214638530673609657120" +
"91807638327166416274888800786925602902284721040317211860820419000422966171196377921337575114959501566049631862" +
"94726547364252308177036751590673502350728354056704038674351362222477158915049530984448933309634087807693259939" +
"78054193414473774418426312986080998886874132604721569516239658645730216315981931951673538129741677294786724229" +
"24654366800980676928238280689964004824354037014163149658979409243237896907069779422362508221688957383798623001" +
"59377647165122893578601588161755782973523344604281512627203734314653197777416031990665541876397929334419521541" +
"34189948544473456738316249934191318148092777710386387734317720754565453220777092120190516609628049092636019759" +
"88281613323166636528619326686336062735676303544776280350450777235547105859548702790814356240145171806246436267" +
"94561275318134078330336254232783944975382437205835311477119926063813346776879695970309833913077109870408591337"
// Test case for BenchmarkScanPi.
func TestScanPi(t *testing.T) {
var x nat
z, _, err := x.scan(strings.NewReader(pi), 10)
if err != nil {
t.Errorf("scanning pi: %s", err)
}
if s := z.decimalString(); s != pi {
t.Errorf("scanning pi: got %s", s)
}
}
func BenchmarkScanPi(b *testing.B) {
for i := 0; i < b.N; i++ {
var x nat
x.scan(strings.NewReader(pi), 10)
}
}
const (
// 314**271
// base 2: 2249 digits
// base 8: 751 digits
// base 10: 678 digits
// base 16: 563 digits
shortBase = 314
shortExponent = 271
// 3141**2178
// base 2: 31577 digits
// base 8: 10527 digits
// base 10: 9507 digits
// base 16: 7895 digits
mediumBase = 3141
mediumExponent = 2718
// 3141**2178
// base 2: 406078 digits
// base 8: 135360 digits
// base 10: 122243 digits
// base 16: 101521 digits
longBase = 31415
longExponent = 27182
)
func BenchmarkScanShort2(b *testing.B) {
ScanHelper(b, 2, shortBase, shortExponent)
}
func BenchmarkScanShort8(b *testing.B) {
ScanHelper(b, 8, shortBase, shortExponent)
}
func BenchmarkScanSort10(b *testing.B) {
ScanHelper(b, 10, shortBase, shortExponent)
}
func BenchmarkScanShort16(b *testing.B) {
ScanHelper(b, 16, shortBase, shortExponent)
}
func BenchmarkScanMedium2(b *testing.B) {
ScanHelper(b, 2, mediumBase, mediumExponent)
}
func BenchmarkScanMedium8(b *testing.B) {
ScanHelper(b, 8, mediumBase, mediumExponent)
}
func BenchmarkScanMedium10(b *testing.B) {
ScanHelper(b, 10, mediumBase, mediumExponent)
}
func BenchmarkScanMedium16(b *testing.B) {
ScanHelper(b, 16, mediumBase, mediumExponent)
}
func BenchmarkScanLong2(b *testing.B) {
ScanHelper(b, 2, longBase, longExponent)
}
func BenchmarkScanLong8(b *testing.B) {
ScanHelper(b, 8, longBase, longExponent)
}
func BenchmarkScanLong10(b *testing.B) {
ScanHelper(b, 10, longBase, longExponent)
}
func BenchmarkScanLong16(b *testing.B) {
ScanHelper(b, 16, longBase, longExponent)
}
func ScanHelper(b *testing.B, base int, xv, yv Word) {
b.StopTimer()
var x, y, z nat
x = x.setWord(xv)
y = y.setWord(yv)
z = z.expNN(x, y, nil)
var s string
s = z.string(lowercaseDigits[0:base])
if t := toString(z, lowercaseDigits[0:base]); t != s {
panic(fmt.Sprintf("scanning: got %s; want %s", s, t))
}
b.StartTimer()
for i := 0; i < b.N; i++ {
x.scan(strings.NewReader(s), base)
}
}
func BenchmarkStringShort2(b *testing.B) {
StringHelper(b, 2, shortBase, shortExponent)
}
func BenchmarkStringShort8(b *testing.B) {
StringHelper(b, 8, shortBase, shortExponent)
}
func BenchmarkStringShort10(b *testing.B) {
StringHelper(b, 10, shortBase, shortExponent)
}
func BenchmarkStringShort16(b *testing.B) {
StringHelper(b, 16, shortBase, shortExponent)
}
func BenchmarkStringMedium2(b *testing.B) {
StringHelper(b, 2, mediumBase, mediumExponent)
}
func BenchmarkStringMedium8(b *testing.B) {
StringHelper(b, 8, mediumBase, mediumExponent)
}
func BenchmarkStringMedium10(b *testing.B) {
StringHelper(b, 10, mediumBase, mediumExponent)
}
func BenchmarkStringMedium16(b *testing.B) {
StringHelper(b, 16, mediumBase, mediumExponent)
}
func BenchmarkStringLong2(b *testing.B) {
StringHelper(b, 2, longBase, longExponent)
}
func BenchmarkStringLong8(b *testing.B) {
StringHelper(b, 8, longBase, longExponent)
}
func BenchmarkStringLong10(b *testing.B) {
StringHelper(b, 10, longBase, longExponent)
}
func BenchmarkStringLong16(b *testing.B) {
StringHelper(b, 16, longBase, longExponent)
}
func StringHelper(b *testing.B, base int, xv, yv Word) {
b.StopTimer()
var x, y, z nat
x = x.setWord(xv)
y = y.setWord(yv)
z = z.expNN(x, y, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
z.string(lowercaseDigits[0:base])
}
}
func TestLeadingZeros(t *testing.T) {
var x Word = _B >> 1
@ -210,14 +534,12 @@ func TestLeadingZeros(t *testing.T) {
}
}
type shiftTest struct {
in nat
shift uint
out nat
}
var leftShiftTests = []shiftTest{
{nil, 0, nil},
{nil, 1, nil},
@ -227,7 +549,6 @@ var leftShiftTests = []shiftTest{
{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
}
func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests {
var z nat
@ -241,7 +562,6 @@ func TestShiftLeft(t *testing.T) {
}
}
var rightShiftTests = []shiftTest{
{nil, 0, nil},
{nil, 1, nil},
@ -252,7 +572,6 @@ var rightShiftTests = []shiftTest{
{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
}
func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests {
var z nat
@ -266,24 +585,20 @@ func TestShiftRight(t *testing.T) {
}
}
type modWTest struct {
in string
dividend string
out string
}
var modWTests32 = []modWTest{
{"23492635982634928349238759823742", "252341", "220170"},
}
var modWTests64 = []modWTest{
{"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"},
}
func runModWTests(t *testing.T, tests []modWTest) {
for i, test := range tests {
in, _ := new(Int).SetString(test.in, 10)
@ -297,7 +612,6 @@ func runModWTests(t *testing.T, tests []modWTest) {
}
}
func TestModW(t *testing.T) {
if _W >= 32 {
runModWTests(t, modWTests32)
@ -307,7 +621,6 @@ func TestModW(t *testing.T) {
}
}
func TestTrailingZeroBits(t *testing.T) {
var x Word
x--
@ -319,7 +632,6 @@ func TestTrailingZeroBits(t *testing.T) {
}
}
var expNNTests = []struct {
x, y, m string
out string
@ -337,17 +649,16 @@ var expNNTests = []struct {
},
}
func TestExpNN(t *testing.T) {
for i, test := range expNNTests {
x, _, _ := nat(nil).scan(test.x, 0)
y, _, _ := nat(nil).scan(test.y, 0)
out, _, _ := nat(nil).scan(test.out, 0)
x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
var m nat
if len(test.m) > 0 {
m, _, _ = nat(nil).scan(test.m, 0)
m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
}
z := nat(nil).expNN(x, y, m)

View File

@ -6,7 +6,12 @@
package big
import "strings"
import (
"encoding/binary"
"fmt"
"os"
"strings"
)
// A Rat represents a quotient a/b of arbitrary precision. The zero value for
// a Rat, 0/0, is not a legal Rat.
@ -15,13 +20,11 @@ type Rat struct {
b nat
}
// NewRat creates a new Rat with numerator a and denominator b.
func NewRat(a, b int64) *Rat {
return new(Rat).SetFrac64(a, b)
}
// SetFrac sets z to a/b and returns z.
func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.Set(a)
@ -30,7 +33,6 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
return z.norm()
}
// SetFrac64 sets z to a/b and returns z.
func (z *Rat) SetFrac64(a, b int64) *Rat {
z.a.SetInt64(a)
@ -42,7 +44,6 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
return z.norm()
}
// SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x)
@ -50,7 +51,6 @@ func (z *Rat) SetInt(x *Int) *Rat {
return z
}
// SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x)
@ -58,7 +58,6 @@ func (z *Rat) SetInt64(x int64) *Rat {
return z
}
// Sign returns:
//
// -1 if x < 0
@ -69,13 +68,11 @@ func (x *Rat) Sign() int {
return x.a.Sign()
}
// IsInt returns true if the denominator of x is 1.
func (x *Rat) IsInt() bool {
return len(x.b) == 1 && x.b[0] == 1
}
// Num returns the numerator of z; it may be <= 0.
// The result is a reference to z's numerator; it
// may change if a new value is assigned to z.
@ -83,15 +80,13 @@ func (z *Rat) Num() *Int {
return &z.a
}
// Demom returns the denominator of z; it is always > 0.
// Denom returns the denominator of z; it is always > 0.
// The result is a reference to z's denominator; it
// may change if a new value is assigned to z.
func (z *Rat) Denom() *Int {
return &Int{false, z.b}
}
func gcd(x, y nat) nat {
// Euclidean algorithm.
var a, b nat
@ -106,7 +101,6 @@ func gcd(x, y nat) nat {
return a
}
func (z *Rat) norm() *Rat {
f := gcd(z.a.abs, z.b)
if len(z.a.abs) == 0 {
@ -122,7 +116,6 @@ func (z *Rat) norm() *Rat {
return z
}
func mulNat(x *Int, y nat) *Int {
var z Int
z.abs = z.abs.mul(x.abs, y)
@ -130,7 +123,6 @@ func mulNat(x *Int, y nat) *Int {
return &z
}
// Cmp compares x and y and returns:
//
// -1 if x < y
@ -141,7 +133,6 @@ func (x *Rat) Cmp(y *Rat) (r int) {
return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b))
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat {
z.a.Abs(&x.a)
@ -149,7 +140,6 @@ func (z *Rat) Abs(x *Rat) *Rat {
return z
}
// Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b)
@ -159,7 +149,6 @@ func (z *Rat) Add(x, y *Rat) *Rat {
return z.norm()
}
// Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b)
@ -169,7 +158,6 @@ func (z *Rat) Sub(x, y *Rat) *Rat {
return z.norm()
}
// Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat {
z.a.Mul(&x.a, &y.a)
@ -177,7 +165,6 @@ func (z *Rat) Mul(x, y *Rat) *Rat {
return z.norm()
}
// Quo sets z to the quotient x/y and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
func (z *Rat) Quo(x, y *Rat) *Rat {
@ -192,7 +179,6 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
return z.norm()
}
// Neg sets z to -x (by making a copy of x if necessary) and returns z.
func (z *Rat) Neg(x *Rat) *Rat {
z.a.Neg(&x.a)
@ -200,7 +186,6 @@ func (z *Rat) Neg(x *Rat) *Rat {
return z
}
// Set sets z to x (by making a copy of x if necessary) and returns z.
func (z *Rat) Set(x *Rat) *Rat {
z.a.Set(&x.a)
@ -208,6 +193,25 @@ func (z *Rat) Set(x *Rat) *Rat {
return z
}
func ratTok(ch int) bool {
return strings.IndexRune("+-/0123456789.eE", ch) >= 0
}
// Scan is a support routine for fmt.Scanner. It accepts the formats
// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent.
func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error {
tok, err := s.Token(true, ratTok)
if err != nil {
return err
}
if strings.IndexRune("efgEFGv", ch) < 0 {
return os.NewError("Rat.Scan: invalid verb")
}
if _, ok := z.SetString(string(tok)); !ok {
return os.NewError("Rat.Scan: invalid syntax")
}
return nil
}
// SetString sets z to the value of s and returns z and a boolean indicating
// success. s can be given as a fraction "a/b" or as a floating-point number
@ -225,8 +229,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, false
}
s = s[sep+1:]
var n int
if z.b, _, n = z.b.scan(s, 10); n != len(s) {
var err os.Error
if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
return z, false
}
return z.norm(), true
@ -267,13 +271,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, true
}
// String returns a string representation of z in the form "a/b" (even if b == 1).
func (z *Rat) String() string {
return z.a.String() + "/" + z.b.string(10)
return z.a.String() + "/" + z.b.decimalString()
}
// RatString returns a string representation of z in the form "a/b" if b != 1,
// and in the form "a" if b == 1.
func (z *Rat) RatString() string {
@ -283,12 +285,15 @@ func (z *Rat) RatString() string {
return z.String()
}
// FloatString returns a string representation of z in decimal form with prec
// digits of precision after the decimal point and the last digit rounded.
func (z *Rat) FloatString(prec int) string {
if z.IsInt() {
return z.a.String()
s := z.a.String()
if prec > 0 {
s += "." + strings.Repeat("0", prec)
}
return s
}
q, r := nat{}.div(nat{}, z.a.abs, z.b)
@ -311,16 +316,56 @@ func (z *Rat) FloatString(prec int) string {
}
}
s := q.string(10)
s := q.decimalString()
if z.a.neg {
s = "-" + s
}
if prec > 0 {
rs := r.string(10)
rs := r.decimalString()
leadingZeros := prec - len(rs)
s += "." + strings.Repeat("0", leadingZeros) + rs
}
return s
}
// Gob codec version. Permits backward-compatible changes to the encoding.
const ratGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Rat) GobEncode() ([]byte, os.Error) {
buf := make([]byte, 1+4+(len(z.a.abs)+len(z.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
i := z.b.bytes(buf)
j := z.a.abs.bytes(buf[0:i])
n := i - j
if int(uint32(n)) != n {
// this should never happen
return nil, os.NewError("Rat.GobEncode: numerator too large")
}
binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
j -= 1 + 4
b := ratGobVersion << 1 // make space for sign bit
if z.a.neg {
b |= 1
}
buf[j] = b
return buf[j:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Rat) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Rat.GobDecode: no data")
}
b := buf[0]
if b>>1 != ratGobVersion {
return os.NewError(fmt.Sprintf("Rat.GobDecode: encoding version %d not supported", b>>1))
}
const j = 1 + 4
i := j + binary.BigEndian.Uint32(buf[j-4:j])
z.a.neg = b&1 != 0
z.a.abs = z.a.abs.setBytes(buf[j:i])
z.b = z.b.setBytes(buf[i:])
return nil
}

View File

@ -4,8 +4,12 @@
package big
import "testing"
import (
"bytes"
"fmt"
"gob"
"testing"
)
var setStringTests = []struct {
in, out string
@ -52,6 +56,27 @@ func TestRatSetString(t *testing.T) {
}
}
func TestRatScan(t *testing.T) {
var buf bytes.Buffer
for i, test := range setStringTests {
x := new(Rat)
buf.Reset()
buf.WriteString(test.in)
_, err := fmt.Fscanf(&buf, "%v", x)
if err == nil != test.ok {
if test.ok {
t.Errorf("#%d error: %s", i, err.String())
} else {
t.Errorf("#%d expected error", i)
}
continue
}
if err == nil && x.RatString() != test.out {
t.Errorf("#%d got %s want %s", i, x.RatString(), test.out)
}
}
}
var floatStringTests = []struct {
in string
@ -59,12 +84,13 @@ var floatStringTests = []struct {
out string
}{
{"0", 0, "0"},
{"0", 4, "0"},
{"0", 4, "0.0000"},
{"1", 0, "1"},
{"1", 2, "1"},
{"1", 2, "1.00"},
{"-1", 0, "-1"},
{".25", 2, "0.25"},
{".25", 1, "0.3"},
{".25", 3, "0.250"},
{"-1/3", 3, "-0.333"},
{"-2/3", 4, "-0.6667"},
{"0.96", 1, "1.0"},
@ -84,7 +110,6 @@ func TestFloatString(t *testing.T) {
}
}
func TestRatSign(t *testing.T) {
zero := NewRat(0, 1)
for _, a := range setStringTests {
@ -98,7 +123,6 @@ func TestRatSign(t *testing.T) {
}
}
var ratCmpTests = []struct {
rat1, rat2 string
out int
@ -126,7 +150,6 @@ func TestRatCmp(t *testing.T) {
}
}
func TestIsInt(t *testing.T) {
one := NewInt(1)
for _, a := range setStringTests {
@ -140,7 +163,6 @@ func TestIsInt(t *testing.T) {
}
}
func TestRatAbs(t *testing.T) {
zero := NewRat(0, 1)
for _, a := range setStringTests {
@ -158,7 +180,6 @@ func TestRatAbs(t *testing.T) {
}
}
type ratBinFun func(z, x, y *Rat) *Rat
type ratBinArg struct {
x, y, z string
@ -175,7 +196,6 @@ func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) {
}
}
var ratBinTests = []struct {
x, y string
sum, prod string
@ -232,7 +252,6 @@ func TestRatBin(t *testing.T) {
}
}
func TestIssue820(t *testing.T) {
x := NewRat(3, 1)
y := NewRat(2, 1)
@ -258,7 +277,6 @@ func TestIssue820(t *testing.T) {
}
}
var setFrac64Tests = []struct {
a, b int64
out string
@ -280,3 +298,35 @@ func TestRatSetFrac64Rat(t *testing.T) {
}
}
}
func TestRatGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for i, test := range gobEncodingTests {
for j := 0; j < 4; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j&1 != 0 {
// negative numbers
stest = "-" + test
}
if j%2 != 0 {
// fractions
stest = stest + "." + test
}
var tx Rat
tx.SetString(stest)
if err := enc.Encode(&tx); err != nil {
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
}
var rx Rat
if err := dec.Decode(&rx); err != nil {
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
}
if rx.Cmp(&tx) != 0 {
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
}
}
}
}

View File

@ -15,16 +15,17 @@ import (
"utf8"
)
const (
defaultBufSize = 4096
)
// Errors introduced by this package.
type Error struct {
os.ErrorString
ErrorString string
}
func (err *Error) String() string { return err.ErrorString }
var (
ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"}
ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"}
@ -40,7 +41,6 @@ func (b BufSizeError) String() string {
return "bufio: bad buffer size " + strconv.Itoa(int(b))
}
// Buffered input.
// Reader implements buffering for an io.Reader object.
@ -101,6 +101,12 @@ func (b *Reader) fill() {
}
}
func (b *Reader) readErr() os.Error {
err := b.err
b.err = nil
return err
}
// Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If Peek returns fewer than n bytes, it
// also returns an error explaining why the read is short. The error is
@ -119,7 +125,7 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
if m > n {
m = n
}
err := b.err
err := b.readErr()
if m < n && err == nil {
err = ErrBufferFull
}
@ -134,11 +140,11 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
func (b *Reader) Read(p []byte) (n int, err os.Error) {
n = len(p)
if n == 0 {
return 0, b.err
return 0, b.readErr()
}
if b.w == b.r {
if b.err != nil {
return 0, b.err
return 0, b.readErr()
}
if len(p) >= len(b.buf) {
// Large read, empty buffer.
@ -148,11 +154,11 @@ func (b *Reader) Read(p []byte) (n int, err os.Error) {
b.lastByte = int(p[n-1])
b.lastRuneSize = -1
}
return n, b.err
return n, b.readErr()
}
b.fill()
if b.w == b.r {
return 0, b.err
return 0, b.readErr()
}
}
@ -172,7 +178,7 @@ func (b *Reader) ReadByte() (c byte, err os.Error) {
b.lastRuneSize = -1
for b.w == b.r {
if b.err != nil {
return 0, b.err
return 0, b.readErr()
}
b.fill()
}
@ -208,7 +214,7 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) {
}
b.lastRuneSize = -1
if b.r == b.w {
return 0, 0, b.err
return 0, 0, b.readErr()
}
rune, size = int(b.buf[b.r]), 1
if rune >= 0x80 {
@ -260,7 +266,7 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) {
if b.err != nil {
line := b.buf[b.r:b.w]
b.r = b.w
return line, b.err
return line, b.readErr()
}
n := b.Buffered()
@ -367,7 +373,6 @@ func (b *Reader) ReadString(delim byte) (line string, err os.Error) {
return string(bytes), e
}
// buffered output
// Writer implements buffering for an io.Writer object.

View File

@ -53,11 +53,12 @@ func readBytes(buf *Reader) string {
if e == os.EOF {
break
}
if e != nil {
if e == nil {
b[nb] = c
nb++
} else if e != iotest.ErrTimeout {
panic("Data: " + e.String())
}
b[nb] = c
nb++
}
return string(b[0:nb])
}
@ -75,7 +76,6 @@ func TestReaderSimple(t *testing.T) {
}
}
type readMaker struct {
name string
fn func(io.Reader) io.Reader
@ -86,6 +86,7 @@ var readMakers = []readMaker{
{"byte", iotest.OneByteReader},
{"half", iotest.HalfReader},
{"data+err", iotest.DataErrReader},
{"timeout", iotest.TimeoutReader},
}
// Call ReadString (which ends up calling everything else)
@ -97,7 +98,7 @@ func readLines(b *Reader) string {
if e == os.EOF {
break
}
if e != nil {
if e != nil && e != iotest.ErrTimeout {
panic("GetLines: " + e.String())
}
s += s1

135
libgo/go/builtin/builtin.go Normal file
View File

@ -0,0 +1,135 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package builtin provides documentation for Go's built-in functions.
The functions documented here are not actually in package builtin
but their descriptions here allow godoc to present documentation
for the language's special functions.
*/
package builtin
// Type is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type int
// IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc.
type IntegerType int
// FloatType is here for the purposes of documentation only. It is a stand-in
// for either float type: float32 or float64.
type FloatType int
// ComplexType is here for the purposes of documentation only. It is a
// stand-in for either complex type: complex64 or complex128.
type ComplexType int
// The append built-in function appends elements to the end of a slice. If
// it has sufficient capacity, the destination is resliced to accommodate the
// new elements. If it does not, a new underlying array will be allocated.
// Append returns the updated slice. It is therefore necessary to store the
// result of append, often in the variable holding the slice itself:
// slice = append(slice, elem1, elem2)
// slice = append(slice, anotherSlice...)
func append(slice []Type, elems ...Type) []Type
// The copy built-in function copies elements from a source slice into a
// destination slice. (As a special case, it also will copy bytes from a
// string to a slice of bytes.) The source and destination may overlap. Copy
// returns the number of elements copied, which will be the minimum of
// len(src) and len(dst).
func copy(dst, src []Type) int
// The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil).
// Slice, or map: the number of elements in v; if v is nil, len(v) is zero.
// String: the number of bytes in v.
// Channel: the number of elements queued (unread) in the channel buffer;
// if v is nil, len(v) is zero.
func len(v Type) int
// The cap built-in function returns the capacity of v, according to its type:
// Array: the number of elements in v (same as len(v)).
// Pointer to array: the number of elements in *v (same as len(v)).
// Slice: the maximum length the slice can reach when resliced;
// if v is nil, cap(v) is zero.
// Channel: the channel buffer capacity, in units of elements;
// if v is nil, cap(v) is zero.
func cap(v Type) int
// The make built-in function allocates and initializes an object of type
// slice, map, or chan (only). Like new, the first argument is a type, not a
// value. Unlike new, make's return type is the same as the type of its
// argument, not a pointer to it. The specification of the result depends on
// the type:
// Slice: The size specifies the length. The capacity of the slice is
// equal to its length. A second integer argument may be provided to
// specify a different capacity; it must be no smaller than the
// length, so make([]int, 0, 10) allocates a slice of length 0 and
// capacity 10.
// Map: An initial allocation is made according to the size but the
// resulting map has length 0. The size may be omitted, in which case
// a small starting size is allocated.
// Channel: The channel's buffer is initialized with the specified
// buffer capacity. If zero, or the size is omitted, the channel is
// unbuffered.
func make(Type, size IntegerType) Type
// The new built-in function allocates memory. The first argument is a type,
// not a value, and the value returned is a pointer to a newly
// allocated zero value of that type.
func new(Type) *Type
// The complex built-in function constructs a complex value from two
// floating-point values. The real and imaginary parts must be of the same
// size, either float32 or float64 (or assignable to them), and the return
// value will be the corresponding complex type (complex64 for float32,
// complex128 for float64).
func complex(r, i FloatType) ComplexType
// The real built-in function returns the real part of the complex number c.
// The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to
// the type of c.
func imag(c ComplexType) FloatType
// The close built-in function closes a channel, which must be either
// bidirectional or send-only. It should be executed only by the sender,
// never the receiver, and has the effect of shutting down the channel after
// the last sent value is received. After the last value has been received
// from a closed channel c, any receive from c will succeed without
// blocking, returning the zero value for the channel element. The form
// x, ok := <-c
// will also set ok to false for a closed channel.
func close(c chan<- Type)
// The panic built-in function stops normal execution of the current
// goroutine. When a function F calls panic, normal execution of F stops
// immediately. Any functions whose execution was deferred by F are run in
// the usual way, and then F returns to its caller. To the caller G, the
// invocation of F then behaves like a call to panic, terminating G's
// execution and running any deferred functions. This continues until all
// functions in the executing goroutine have stopped, in reverse order. At
// that point, the program is terminated and the error condition is reported,
// including the value of the argument to panic. This termination sequence
// is called panicking and can be controlled by the built-in function
// recover.
func panic(v interface{})
// The recover built-in function allows a program to manage behavior of a
// panicking goroutine. Executing a call to recover inside a deferred
// function (but not any function called by it) stops the panicking sequence
// by restoring normal execution and retrieves the error value passed to the
// call of panic. If recover is called outside the deferred function it will
// not stop a panicking sequence. In this case, or when the goroutine is not
// panicking, or if the argument supplied to panic was nil, recover returns
// nil. Thus the return value from recover reports whether the goroutine is
// panicking.
func recover() interface{}

View File

@ -280,7 +280,7 @@ func (b *Buffer) ReadRune() (r int, size int, err os.Error) {
// from any read operation.)
func (b *Buffer) UnreadRune() os.Error {
if b.lastRead != opReadRune {
return os.ErrorString("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
return os.NewError("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
}
b.lastRead = opInvalid
if b.off > 0 {
@ -295,7 +295,7 @@ func (b *Buffer) UnreadRune() os.Error {
// returns an error.
func (b *Buffer) UnreadByte() os.Error {
if b.lastRead != opReadRune && b.lastRead != opRead {
return os.ErrorString("bytes.Buffer: UnreadByte: previous operation was not a read")
return os.NewError("bytes.Buffer: UnreadByte: previous operation was not a read")
}
b.lastRead = opInvalid
if b.off > 0 {

View File

@ -12,7 +12,6 @@ import (
"utf8"
)
const N = 10000 // make this bigger for a larger (and slower) test
var data string // test data for write tests
var bytes []byte // test data; same as data but as a slice.
@ -47,7 +46,6 @@ func check(t *testing.T, testname string, buf *Buffer, s string) {
}
}
// Fill buf through n writes of string fus.
// The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string.
@ -67,7 +65,6 @@ func fillString(t *testing.T, testname string, buf *Buffer, s string, n int, fus
return s
}
// Fill buf through n writes of byte slice fub.
// The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string.
@ -87,19 +84,16 @@ func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub
return s
}
func TestNewBuffer(t *testing.T) {
buf := NewBuffer(bytes)
check(t, "NewBuffer", buf, data)
}
func TestNewBufferString(t *testing.T) {
buf := NewBufferString(data)
check(t, "NewBufferString", buf, data)
}
// Empty buf through repeated reads into fub.
// The initial contents of buf corresponds to the string s.
func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
@ -120,7 +114,6 @@ func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
check(t, testname+" (empty 4)", buf, "")
}
func TestBasicOperations(t *testing.T) {
var buf Buffer
@ -175,7 +168,6 @@ func TestBasicOperations(t *testing.T) {
}
}
func TestLargeStringWrites(t *testing.T) {
var buf Buffer
limit := 30
@ -189,7 +181,6 @@ func TestLargeStringWrites(t *testing.T) {
check(t, "TestLargeStringWrites (3)", &buf, "")
}
func TestLargeByteWrites(t *testing.T) {
var buf Buffer
limit := 30
@ -203,7 +194,6 @@ func TestLargeByteWrites(t *testing.T) {
check(t, "TestLargeByteWrites (3)", &buf, "")
}
func TestLargeStringReads(t *testing.T) {
var buf Buffer
for i := 3; i < 30; i += 3 {
@ -213,7 +203,6 @@ func TestLargeStringReads(t *testing.T) {
check(t, "TestLargeStringReads (3)", &buf, "")
}
func TestLargeByteReads(t *testing.T) {
var buf Buffer
for i := 3; i < 30; i += 3 {
@ -223,7 +212,6 @@ func TestLargeByteReads(t *testing.T) {
check(t, "TestLargeByteReads (3)", &buf, "")
}
func TestMixedReadsAndWrites(t *testing.T) {
var buf Buffer
s := ""
@ -243,7 +231,6 @@ func TestMixedReadsAndWrites(t *testing.T) {
empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len()))
}
func TestNil(t *testing.T) {
var b *Buffer
if b.String() != "<nil>" {
@ -251,7 +238,6 @@ func TestNil(t *testing.T) {
}
}
func TestReadFrom(t *testing.T) {
var buf Buffer
for i := 3; i < 30; i += 3 {
@ -262,7 +248,6 @@ func TestReadFrom(t *testing.T) {
}
}
func TestWriteTo(t *testing.T) {
var buf Buffer
for i := 3; i < 30; i += 3 {
@ -273,7 +258,6 @@ func TestWriteTo(t *testing.T) {
}
}
func TestRuneIO(t *testing.T) {
const NRune = 1000
// Built a test array while we write the data
@ -323,7 +307,6 @@ func TestRuneIO(t *testing.T) {
}
}
func TestNext(t *testing.T) {
b := []byte{0, 1, 2, 3, 4}
tmp := make([]byte, 5)

View File

@ -212,24 +212,38 @@ func genSplit(s, sep []byte, sepSave, n int) [][]byte {
return a[0 : na+1]
}
// Split slices s into subslices separated by sep and returns a slice of
// SplitN slices s into subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, SplitN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitN(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
// SplitAfterN slices s into subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, SplitAfterN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitAfterN(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n)
}
// Split slices s into all subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, Split splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func Split(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
// It is equivalent to SplitN with a count of -1.
func Split(s, sep []byte) [][]byte { return genSplit(s, sep, 0, -1) }
// SplitAfter slices s into subslices after each instance of sep and
// SplitAfter slices s into all subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, Split splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitAfter(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n)
// If sep is empty, SplitAfter splits after each UTF-8 sequence.
// It is equivalent to SplitAfterN with a count of -1.
func SplitAfter(s, sep []byte) [][]byte {
return genSplit(s, sep, len(sep), -1)
}
// Fields splits the array s around each instance of one or more consecutive white space
@ -384,7 +398,6 @@ func ToTitleSpecial(_case unicode.SpecialCase, s []byte) []byte {
return Map(func(r int) int { return _case.ToTitle(r) }, s)
}
// isSeparator reports whether the rune could mark a word boundary.
// TODO: update when package unicode captures more of the properties.
func isSeparator(rune int) bool {

View File

@ -6,6 +6,7 @@ package bytes_test
import (
. "bytes"
"reflect"
"testing"
"unicode"
"utf8"
@ -315,7 +316,7 @@ var explodetests = []ExplodeTest{
func TestExplode(t *testing.T) {
for _, tt := range explodetests {
a := Split([]byte(tt.s), nil, tt.n)
a := SplitN([]byte(tt.s), nil, tt.n)
result := arrayOfString(a)
if !eq(result, tt.a) {
t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a)
@ -328,7 +329,6 @@ func TestExplode(t *testing.T) {
}
}
type SplitTest struct {
s string
sep string
@ -354,7 +354,7 @@ var splittests = []SplitTest{
func TestSplit(t *testing.T) {
for _, tt := range splittests {
a := Split([]byte(tt.s), []byte(tt.sep), tt.n)
a := SplitN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a)
if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
@ -367,6 +367,12 @@ func TestSplit(t *testing.T) {
if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
}
if tt.n < 0 {
b := Split([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("Split disagrees withSplitN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
}
}
@ -388,7 +394,7 @@ var splitaftertests = []SplitTest{
func TestSplitAfter(t *testing.T) {
for _, tt := range splitaftertests {
a := SplitAfter([]byte(tt.s), []byte(tt.sep), tt.n)
a := SplitAfterN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a)
if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
@ -398,6 +404,12 @@ func TestSplitAfter(t *testing.T) {
if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
}
if tt.n < 0 {
b := SplitAfter([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("SplitAfter disagrees withSplitAfterN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
}
}
@ -649,7 +661,6 @@ func TestRunes(t *testing.T) {
}
}
type TrimTest struct {
f func([]byte, string) []byte
in, cutset, out string

View File

@ -284,7 +284,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
repeat := 0
repeat_power := 0
// The `C' array (used by the inverse BWT) needs to be zero initialised.
// The `C' array (used by the inverse BWT) needs to be zero initialized.
for i := range bz2.c {
bz2.c[i] = 0
}
@ -330,7 +330,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
if int(v) == numSymbols-1 {
// This is the EOF symbol. Because it's always at the
// end of the move-to-front list, and nevers gets moved
// end of the move-to-front list, and never gets moved
// to the front, it has this unique value.
break
}

View File

@ -68,7 +68,7 @@ func newHuffmanTree(lengths []uint8) (huffmanTree, os.Error) {
// each symbol (consider reflecting a tree down the middle, for
// example). Since the code length assignments determine the
// efficiency of the tree, each of these trees is equally good. In
// order to minimise the amount of information needed to build a tree
// order to minimize the amount of information needed to build a tree
// bzip2 uses a canonical tree so that it can be reconstructed given
// only the code length assignments.

View File

@ -11,16 +11,18 @@ import (
)
const (
NoCompression = 0
BestSpeed = 1
fastCompression = 3
BestCompression = 9
DefaultCompression = -1
logMaxOffsetSize = 15 // Standard DEFLATE
wideLogMaxOffsetSize = 22 // Wide DEFLATE
minMatchLength = 3 // The smallest match that the compressor looks for
maxMatchLength = 258 // The longest match for the compressor
minOffsetSize = 1 // The shortest offset that makes any sence
NoCompression = 0
BestSpeed = 1
fastCompression = 3
BestCompression = 9
DefaultCompression = -1
logWindowSize = 15
windowSize = 1 << logWindowSize
windowMask = windowSize - 1
logMaxOffsetSize = 15 // Standard DEFLATE
minMatchLength = 3 // The smallest match that the compressor looks for
maxMatchLength = 258 // The longest match for the compressor
minOffsetSize = 1 // The shortest offset that makes any sence
// The maximum number of tokens we put into a single flat block, just too
// stop things from getting too large.
@ -32,22 +34,6 @@ const (
hashShift = (hashBits + minMatchLength - 1) / minMatchLength
)
type syncPipeReader struct {
*io.PipeReader
closeChan chan bool
}
func (sr *syncPipeReader) CloseWithError(err os.Error) os.Error {
retErr := sr.PipeReader.CloseWithError(err)
sr.closeChan <- true // finish writer close
return retErr
}
type syncPipeWriter struct {
*io.PipeWriter
closeChan chan bool
}
type compressionLevel struct {
good, lazy, nice, chain, fastSkipHashing int
}
@ -68,105 +54,73 @@ var levels = []compressionLevel{
{32, 258, 258, 4096, math.MaxInt32},
}
func (sw *syncPipeWriter) Close() os.Error {
err := sw.PipeWriter.Close()
<-sw.closeChan // wait for reader close
return err
}
func syncPipe() (*syncPipeReader, *syncPipeWriter) {
r, w := io.Pipe()
sr := &syncPipeReader{r, make(chan bool, 1)}
sw := &syncPipeWriter{w, sr.closeChan}
return sr, sw
}
type compressor struct {
level int
logWindowSize uint
w *huffmanBitWriter
r io.Reader
// (1 << logWindowSize) - 1.
windowMask int
compressionLevel
eof bool // has eof been reached on input?
sync bool // writer wants to flush
syncChan chan os.Error
w *huffmanBitWriter
// compression algorithm
fill func(*compressor, []byte) int // copy data to window
step func(*compressor) // process window
sync bool // requesting flush
// Input hash chains
// hashHead[hashValue] contains the largest inputIndex with the specified hash value
hashHead []int
// If hashHead[hashValue] is within the current window, then
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index
// with the same hash value.
hashPrev []int
chainHead int
hashHead []int
hashPrev []int
// If we find a match of length >= niceMatch, then we don't bother searching
// any further.
niceMatch int
// input window: unprocessed data is window[index:windowEnd]
index int
window []byte
windowEnd int
blockStart int // window index where current tokens start
byteAvailable bool // if true, still need to process window[index-1].
// If we find a match of length >= goodMatch, we only do a half-hearted
// effort at doing lazy matching starting at the next character
goodMatch int
// queued output tokens: tokens[:ti]
tokens []token
ti int
// The maximum number of chains we look at when finding a match
maxChainLength int
// The sliding window we use for matching
window []byte
// The index just past the last valid character
windowEnd int
// index in "window" at which current block starts
blockStart int
// deflate state
length int
offset int
hash int
maxInsertIndex int
err os.Error
}
func (d *compressor) flush() os.Error {
d.w.flush()
return d.w.err
}
func (d *compressor) fillWindow(index int) (int, os.Error) {
if d.sync {
return index, nil
}
wSize := d.windowMask + 1
if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
// shift the window by wSize
copy(d.window, d.window[wSize:2*wSize])
index -= wSize
d.windowEnd -= wSize
if d.blockStart >= wSize {
d.blockStart -= wSize
func (d *compressor) fillDeflate(b []byte) int {
if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
// shift the window by windowSize
copy(d.window, d.window[windowSize:2*windowSize])
d.index -= windowSize
d.windowEnd -= windowSize
if d.blockStart >= windowSize {
d.blockStart -= windowSize
} else {
d.blockStart = math.MaxInt32
}
for i, h := range d.hashHead {
v := h - wSize
v := h - windowSize
if v < -1 {
v = -1
}
d.hashHead[i] = v
}
for i, h := range d.hashPrev {
v := -h - wSize
v := -h - windowSize
if v < -1 {
v = -1
}
d.hashPrev[i] = v
}
}
count, err := d.r.Read(d.window[d.windowEnd:])
d.windowEnd += count
if count == 0 && err == nil {
d.sync = true
}
if err == os.EOF {
d.eof = true
err = nil
}
return index, err
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
@ -194,21 +148,21 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// We quit when we get a match that's at least nice long
nice := len(win) - pos
if d.niceMatch < nice {
nice = d.niceMatch
if d.nice < nice {
nice = d.nice
}
// If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.maxChainLength
tries := d.chain
length = prevLength
if length >= d.goodMatch {
if length >= d.good {
tries >>= 2
}
w0 := win[pos]
w1 := win[pos+1]
wEnd := win[pos+length]
minIndex := pos - (d.windowMask + 1)
minIndex := pos - windowSize
for i := prevHead; tries > 0; tries-- {
if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] {
@ -233,7 +187,7 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// hashPrev[i & windowMask] has already been overwritten, so stop now.
break
}
if i = d.hashPrev[i&d.windowMask]; i < minIndex || i < 0 {
if i = d.hashPrev[i&windowMask]; i < minIndex || i < 0 {
break
}
}
@ -248,234 +202,224 @@ func (d *compressor) writeStoredBlock(buf []byte) os.Error {
return d.w.err
}
func (d *compressor) storedDeflate() os.Error {
buf := make([]byte, maxStoreBlockSize)
for {
n, err := d.r.Read(buf)
if n == 0 && err == nil {
d.sync = true
}
if n > 0 || d.sync {
if err := d.writeStoredBlock(buf[0:n]); err != nil {
return err
}
if d.sync {
d.syncChan <- nil
d.sync = false
}
}
if err != nil {
if err == os.EOF {
break
}
return err
}
}
return nil
func (d *compressor) initDeflate() {
d.hashHead = make([]int, hashSize)
d.hashPrev = make([]int, windowSize)
d.window = make([]byte, 2*windowSize)
fillInts(d.hashHead, -1)
d.tokens = make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
d.length = minMatchLength - 1
d.offset = 0
d.byteAvailable = false
d.index = 0
d.ti = 0
d.hash = 0
d.chainHead = -1
}
func (d *compressor) doDeflate() (err os.Error) {
// init
d.windowMask = 1<<d.logWindowSize - 1
d.hashHead = make([]int, hashSize)
d.hashPrev = make([]int, 1<<d.logWindowSize)
d.window = make([]byte, 2<<d.logWindowSize)
fillInts(d.hashHead, -1)
tokens := make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
l := levels[d.level]
d.goodMatch = l.good
d.niceMatch = l.nice
d.maxChainLength = l.chain
lazyMatch := l.lazy
length := minMatchLength - 1
offset := 0
byteAvailable := false
isFastDeflate := l.fastSkipHashing != 0
index := 0
// run
if index, err = d.fillWindow(index); err != nil {
func (d *compressor) deflate() {
if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
return
}
maxOffset := d.windowMask + 1 // (1 << logWindowSize);
// only need to change when you refill the window
windowEnd := d.windowEnd
maxInsertIndex := windowEnd - (minMatchLength - 1)
ti := 0
hash := int(0)
if index < maxInsertIndex {
hash = int(d.window[index])<<hashShift + int(d.window[index+1])
d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
if d.index < d.maxInsertIndex {
d.hash = int(d.window[d.index])<<hashShift + int(d.window[d.index+1])
}
chainHead := -1
Loop:
for {
if index > windowEnd {
if d.index > d.windowEnd {
panic("index > windowEnd")
}
lookahead := windowEnd - index
lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength {
if index, err = d.fillWindow(index); err != nil {
return
if !d.sync {
break Loop
}
windowEnd = d.windowEnd
if index > windowEnd {
if d.index > d.windowEnd {
panic("index > windowEnd")
}
maxInsertIndex = windowEnd - (minMatchLength - 1)
lookahead = windowEnd - index
if lookahead == 0 {
// Flush current output block if any.
if byteAvailable {
if d.byteAvailable {
// There is still one pending token that needs to be flushed
tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF)
ti++
byteAvailable = false
d.tokens[d.ti] = literalToken(uint32(d.window[d.index-1]))
d.ti++
d.byteAvailable = false
}
if ti > 0 {
if err = d.writeBlock(tokens[0:ti], index, false); err != nil {
if d.ti > 0 {
if d.err = d.writeBlock(d.tokens[0:d.ti], d.index, false); d.err != nil {
return
}
ti = 0
}
if d.sync {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.syncChan <- d.w.err
d.sync = false
}
// If this was only a sync (not at EOF) keep going.
if !d.eof {
continue
d.ti = 0
}
break Loop
}
}
if index < maxInsertIndex {
if d.index < d.maxInsertIndex {
// Update the hash
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
chainHead = d.hashHead[hash]
d.hashPrev[index&d.windowMask] = chainHead
d.hashHead[hash] = index
d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
d.chainHead = d.hashHead[d.hash]
d.hashPrev[d.index&windowMask] = d.chainHead
d.hashHead[d.hash] = d.index
}
prevLength := length
prevOffset := offset
length = minMatchLength - 1
offset = 0
minIndex := index - maxOffset
prevLength := d.length
prevOffset := d.offset
d.length = minMatchLength - 1
d.offset = 0
minIndex := d.index - windowSize
if minIndex < 0 {
minIndex = 0
}
if chainHead >= minIndex &&
(isFastDeflate && lookahead > minMatchLength-1 ||
!isFastDeflate && lookahead > prevLength && prevLength < lazyMatch) {
if newLength, newOffset, ok := d.findMatch(index, chainHead, minMatchLength-1, lookahead); ok {
length = newLength
offset = newOffset
if d.chainHead >= minIndex &&
(d.fastSkipHashing != 0 && lookahead > minMatchLength-1 ||
d.fastSkipHashing == 0 && lookahead > prevLength && prevLength < d.lazy) {
if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead, minMatchLength-1, lookahead); ok {
d.length = newLength
d.offset = newOffset
}
}
if isFastDeflate && length >= minMatchLength ||
!isFastDeflate && prevLength >= minMatchLength && length <= prevLength {
if d.fastSkipHashing != 0 && d.length >= minMatchLength ||
d.fastSkipHashing == 0 && prevLength >= minMatchLength && d.length <= prevLength {
// There was a match at the previous step, and the current match is
// not better. Output the previous match.
if isFastDeflate {
tokens[ti] = matchToken(uint32(length-minMatchLength), uint32(offset-minOffsetSize))
if d.fastSkipHashing != 0 {
d.tokens[d.ti] = matchToken(uint32(d.length-minMatchLength), uint32(d.offset-minOffsetSize))
} else {
tokens[ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
d.tokens[d.ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
}
ti++
d.ti++
// Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash
// table.
if length <= l.fastSkipHashing {
if d.length <= d.fastSkipHashing {
var newIndex int
if isFastDeflate {
newIndex = index + length
if d.fastSkipHashing != 0 {
newIndex = d.index + d.length
} else {
newIndex = prevLength - 1
}
for index++; index < newIndex; index++ {
if index < maxInsertIndex {
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask
for d.index++; d.index < newIndex; d.index++ {
if d.index < d.maxInsertIndex {
d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
// Get previous value with the same hash.
// Our chain should point to the previous value.
d.hashPrev[index&d.windowMask] = d.hashHead[hash]
d.hashPrev[d.index&windowMask] = d.hashHead[d.hash]
// Set the head of the hash chain to us.
d.hashHead[hash] = index
d.hashHead[d.hash] = d.index
}
}
if !isFastDeflate {
byteAvailable = false
length = minMatchLength - 1
if d.fastSkipHashing == 0 {
d.byteAvailable = false
d.length = minMatchLength - 1
}
} else {
// For matches this long, we don't bother inserting each individual
// item into the table.
index += length
hash = (int(d.window[index])<<hashShift + int(d.window[index+1]))
d.index += d.length
d.hash = (int(d.window[d.index])<<hashShift + int(d.window[d.index+1]))
}
if ti == maxFlateBlockTokens {
if d.ti == maxFlateBlockTokens {
// The block includes the current character
if err = d.writeBlock(tokens, index, false); err != nil {
if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return
}
ti = 0
d.ti = 0
}
} else {
if isFastDeflate || byteAvailable {
i := index - 1
if isFastDeflate {
i = index
if d.fastSkipHashing != 0 || d.byteAvailable {
i := d.index - 1
if d.fastSkipHashing != 0 {
i = d.index
}
tokens[ti] = literalToken(uint32(d.window[i]) & 0xFF)
ti++
if ti == maxFlateBlockTokens {
if err = d.writeBlock(tokens, i+1, false); err != nil {
d.tokens[d.ti] = literalToken(uint32(d.window[i]))
d.ti++
if d.ti == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
return
}
ti = 0
d.ti = 0
}
}
index++
if !isFastDeflate {
byteAvailable = true
d.index++
if d.fastSkipHashing == 0 {
d.byteAvailable = true
}
}
}
return
}
func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) {
d.r = r
func (d *compressor) fillStore(b []byte) int {
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) store() {
if d.windowEnd > 0 {
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
}
d.windowEnd = 0
}
func (d *compressor) write(b []byte) (n int, err os.Error) {
n = len(b)
b = b[d.fill(d, b):]
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
}
return n, d.err
}
func (d *compressor) syncFlush() os.Error {
d.sync = true
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.err = d.w.err
}
d.sync = false
return d.err
}
func (d *compressor) init(w io.Writer, level int) (err os.Error) {
d.w = newHuffmanBitWriter(w)
d.level = level
d.logWindowSize = logWindowSize
switch {
case level == NoCompression:
err = d.storedDeflate()
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).store
case level == DefaultCompression:
d.level = 6
level = 6
fallthrough
case 1 <= level && level <= 9:
err = d.doDeflate()
d.compressionLevel = levels[level]
d.initDeflate()
d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate
default:
return WrongValueError{"level", 0, 9, int32(level)}
}
return nil
}
if d.sync {
d.syncChan <- err
d.sync = false
}
if err != nil {
return err
func (d *compressor) close() os.Error {
d.sync = true
d.step(d)
if d.err != nil {
return d.err
}
if d.w.writeStoredHeader(0, true); d.w.err != nil {
return d.w.err
}
return d.flush()
d.w.flush()
return d.w.err
}
// NewWriter returns a new Writer compressing
@ -486,14 +430,9 @@ func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize
// compression; it only adds the necessary DEFLATE framing.
func NewWriter(w io.Writer, level int) *Writer {
const logWindowSize = logMaxOffsetSize
var d compressor
d.syncChan = make(chan os.Error, 1)
pr, pw := syncPipe()
go func() {
err := d.compress(pr, w, level, logWindowSize)
pr.CloseWithError(err)
}()
return &Writer{pw, &d}
var dw Writer
dw.d.init(w, level)
return &dw
}
// NewWriterDict is like NewWriter but initializes the new
@ -526,18 +465,13 @@ func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
w *syncPipeWriter
d *compressor
d compressor
}
// Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err os.Error) {
if len(data) == 0 {
// no point, and nil interferes with sync
return
}
return w.w.Write(data)
return w.d.write(data)
}
// Flush flushes any pending compressed data to the underlying writer.
@ -550,18 +484,10 @@ func (w *Writer) Write(data []byte) (n int, err os.Error) {
func (w *Writer) Flush() os.Error {
// For more about flushing:
// http://www.bolet.org/~pornin/deflate-flush.html
if w.d.sync {
panic("compress/flate: double Flush")
}
_, err := w.w.Write(nil)
err1 := <-w.d.syncChan
if err == nil {
err = err1
}
return err
return w.d.syncFlush()
}
// Close flushes and closes the writer.
func (w *Writer) Close() os.Error {
return w.w.Close()
return w.d.close()
}

View File

@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
&deflateInflateTest{[]byte{0x11, 0x12}},
&deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
&deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
&deflateInflateTest{getLargeDataChunk()},
&deflateInflateTest{largeDataChunk()},
}
var reverseBitsTests = []*reverseBitsTest{
@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
&reverseBitsTest{29, 5, 23},
}
func getLargeDataChunk() []byte {
func largeDataChunk() []byte {
result := make([]byte, 100000)
for i := range result {
result[i] = byte(int64(i) * int64(i) & 0xFF)
result[i] = byte(i * i & 0xFF)
}
return result
}
func TestDeflate(t *testing.T) {
for _, h := range deflateTests {
buffer := bytes.NewBuffer(nil)
w := NewWriter(buffer, h.level)
var buf bytes.Buffer
w := NewWriter(&buf, h.level)
w.Write(h.in)
w.Close()
if bytes.Compare(buffer.Bytes(), h.out) != 0 {
t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v",
h.level, buffer.Bytes(), h.out)
if !bytes.Equal(buf.Bytes(), h.out) {
t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
}
}
}
@ -226,7 +225,6 @@ func testSync(t *testing.T, level int, input []byte, name string) {
}
}
func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error {
buffer := bytes.NewBuffer(nil)
w := NewWriter(buffer, level)

View File

@ -15,9 +15,6 @@ const (
// The largest offset code.
offsetCodeCount = 30
// The largest offset code in the extensions.
extendedOffsetCodeCount = 42
// The special code used to mark the end of a block.
endBlockMarker = 256
@ -100,11 +97,11 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{
w: w,
literalFreq: make([]int32, maxLit),
offsetFreq: make([]int32, extendedOffsetCodeCount),
codegen: make([]uint8, maxLit+extendedOffsetCodeCount+1),
offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxLit+offsetCodeCount+1),
codegenFreq: make([]int32, codegenCodeCount),
literalEncoding: newHuffmanEncoder(maxLit),
offsetEncoding: newHuffmanEncoder(extendedOffsetCodeCount),
offsetEncoding: newHuffmanEncoder(offsetCodeCount),
codegenEncoding: newHuffmanEncoder(codegenCodeCount),
}
}
@ -185,7 +182,7 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
_, w.err = w.w.Write(bytes)
}
// RFC 1951 3.2.7 specifies a special run-length encoding for specifiying
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
// the literal and offset lengths arrays (which are concatenated into a single
// array). This method generates that run-length encoding.
//
@ -279,7 +276,7 @@ func (w *huffmanBitWriter) writeCode(code *huffmanEncoder, literal uint32) {
//
// numLiterals The number of literals specified in codegen
// numOffsets The number of offsets specified in codegen
// numCodegens Tne number of codegens used in codegen
// numCodegens The number of codegens used in codegen
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
if w.err != nil {
return
@ -290,13 +287,7 @@ func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, n
}
w.writeBits(firstBits, 3)
w.writeBits(int32(numLiterals-257), 5)
if numOffsets > offsetCodeCount {
// Extended version of decompressor
w.writeBits(int32(offsetCodeCount+((numOffsets-(1+offsetCodeCount))>>3)), 5)
w.writeBits(int32((numOffsets-(1+offsetCodeCount))&0x7), 3)
} else {
w.writeBits(int32(numOffsets-1), 5)
}
w.writeBits(int32(numOffsets-1), 5)
w.writeBits(int32(numCodegens-4), 4)
for i := 0; i < numCodegens; i++ {
@ -368,24 +359,17 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
tokens = tokens[0 : n+1]
tokens[n] = endBlockMarker
totalLength := -1 // Subtract 1 for endBlock.
for _, t := range tokens {
switch t.typ() {
case literalType:
w.literalFreq[t.literal()]++
totalLength++
break
case matchType:
length := t.length()
offset := t.offset()
totalLength += int(length + 3)
w.literalFreq[lengthCodesStart+lengthCode(length)]++
w.offsetFreq[offsetCode(offset)]++
break
}
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
// get the number of literals
numLiterals := len(w.literalFreq)
@ -394,15 +378,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
}
// get the number of offsets
numOffsets := len(w.offsetFreq)
for numOffsets > 1 && w.offsetFreq[numOffsets-1] == 0 {
for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
numOffsets--
}
if numOffsets == 0 {
// We haven't found a single match. If we want to go with the dynamic encoding,
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
w.offsetFreq[0] = 1
numOffsets = 1
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
storedBytes := 0
if input != nil {
storedBytes = len(input)
}
var extraBits int64
var storedSize int64
var storedSize int64 = math.MaxInt64
if storedBytes <= maxStoreBlockSize && input != nil {
storedSize = int64((storedBytes + 5) * 8)
// We only bother calculating the costs of the extra bits required by
@ -417,34 +411,29 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
// First four offset codes have extra size = 0.
extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode])
}
} else {
storedSize = math.MaxInt32
}
// Figure out which generates smaller code, fixed Huffman, dynamic
// Huffman, or just storing the data.
var fixedSize int64 = math.MaxInt64
if numOffsets <= offsetCodeCount {
fixedSize = int64(3) +
fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedOffsetEncoding.bitLength(w.offsetFreq) +
extraBits
}
// Figure out smallest code.
// Fixed Huffman baseline.
var size = int64(3) +
fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedOffsetEncoding.bitLength(w.offsetFreq) +
extraBits
var literalEncoding = fixedLiteralEncoding
var offsetEncoding = fixedOffsetEncoding
// Dynamic Huffman?
var numCodegens int
// Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets)
w.codegenEncoding.generate(w.codegenFreq, 7)
numCodegens := len(w.codegenFreq)
numCodegens = len(w.codegenFreq)
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
numCodegens--
}
extensionSummand := 0
if numOffsets > offsetCodeCount {
extensionSummand = 3
}
dynamicHeader := int64(3+5+5+4+(3*numCodegens)) +
// Following line is an extension.
int64(extensionSummand) +
w.codegenEncoding.bitLength(w.codegenFreq) +
int64(extraBits) +
int64(w.codegenFreq[16]*2) +
@ -454,26 +443,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
w.literalEncoding.bitLength(w.literalFreq) +
w.offsetEncoding.bitLength(w.offsetFreq)
if storedSize < fixedSize && storedSize < dynamicSize {
w.writeStoredHeader(storedBytes, eof)
w.writeBytes(input[0:storedBytes])
return
}
var literalEncoding *huffmanEncoder
var offsetEncoding *huffmanEncoder
if fixedSize <= dynamicSize {
w.writeFixedHeader(eof)
literalEncoding = fixedLiteralEncoding
offsetEncoding = fixedOffsetEncoding
} else {
// Write the header.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
if dynamicSize < size {
size = dynamicSize
literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding
}
// Write the tokens.
// Stored bytes?
if storedSize < size {
w.writeStoredHeader(storedBytes, eof)
w.writeBytes(input[0:storedBytes])
return
}
// Huffman.
if literalEncoding == fixedLiteralEncoding {
w.writeFixedHeader(eof)
} else {
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
}
for _, t := range tokens {
switch t.typ() {
case literalType:

View File

@ -363,7 +363,12 @@ func (s literalNodeSorter) Less(i, j int) bool {
func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] }
func sortByFreq(a []literalNode) {
s := &literalNodeSorter{a, func(i, j int) bool { return a[i].freq < a[j].freq }}
s := &literalNodeSorter{a, func(i, j int) bool {
if a[i].freq == a[j].freq {
return a[i].literal < a[j].literal
}
return a[i].freq < a[j].freq
}}
sort.Sort(s)
}

View File

@ -77,8 +77,6 @@ type huffmanDecoder struct {
// Initialize Huffman decoding tables from array of code lengths.
func (h *huffmanDecoder) init(bits []int) bool {
// TODO(rsc): Return false sometimes.
// Count number of codes of each length,
// compute min and max length.
var count [maxCodeLen + 1]int
@ -197,9 +195,8 @@ type Reader interface {
// Decompress state.
type decompressor struct {
// Input/output sources.
// Input source.
r Reader
w io.Writer
roffset int64
woffset int64
@ -222,38 +219,79 @@ type decompressor struct {
// Temporary buffer (avoids repeated allocation).
buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
final bool
err os.Error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
}
func (f *decompressor) inflate() (err os.Error) {
final := false
for err == nil && !final {
for f.nb < 1+2 {
if err = f.moreBits(); err != nil {
return
}
func (f *decompressor) nextBlock() {
if f.final {
if f.hw != f.hp {
f.flush((*decompressor).nextBlock)
return
}
final = f.b&1 == 1
f.b >>= 1
typ := f.b & 3
f.b >>= 2
f.nb -= 1 + 2
switch typ {
case 0:
err = f.dataBlock()
case 1:
// compressed, fixed Huffman tables
err = f.decodeBlock(&fixedHuffmanDecoder, nil)
case 2:
// compressed, dynamic Huffman tables
if err = f.readHuffman(); err == nil {
err = f.decodeBlock(&f.h1, &f.h2)
}
default:
// 3 is reserved.
err = CorruptInputError(f.roffset)
f.err = os.EOF
return
}
for f.nb < 1+2 {
if f.err = f.moreBits(); f.err != nil {
return
}
}
return
f.final = f.b&1 == 1
f.b >>= 1
typ := f.b & 3
f.b >>= 2
f.nb -= 1 + 2
switch typ {
case 0:
f.dataBlock()
case 1:
// compressed, fixed Huffman tables
f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2:
// compressed, dynamic Huffman tables
if f.err = f.readHuffman(); f.err != nil {
break
}
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default:
// 3 is reserved.
f.err = CorruptInputError(f.roffset)
}
}
func (f *decompressor) Read(b []byte) (int, os.Error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
return n, nil
}
if f.err != nil {
return 0, f.err
}
f.step(f)
}
panic("unreachable")
}
func (f *decompressor) Close() os.Error {
if f.err == os.EOF {
return nil
}
return f.err
}
// RFC 1951 section 3.2.7.
@ -358,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
// hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, using the
// fixed distance encoding associated with fixed Huffman blocks.
func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
func (f *decompressor) huffmanBlock() {
for {
v, err := f.huffSym(hl)
v, err := f.huffSym(f.hl)
if err != nil {
return err
f.err = err
return
}
var n uint // number of bits extra
var length int
@ -371,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hist[f.hp] = byte(v)
f.hp++
if f.hp == len(f.hist) {
if err = f.flush(); err != nil {
return err
}
// After the flush, continue this loop.
f.flush((*decompressor).huffmanBlock)
return
}
continue
case v == 256:
return nil
// Done with huffman block; read next block.
f.step = (*decompressor).nextBlock
return
// otherwise, reference to older data
case v < 265:
length = v - (257 - 3)
@ -404,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
if n > 0 {
for f.nb < n {
if err = f.moreBits(); err != nil {
return err
f.err = err
return
}
}
length += int(f.b & uint32(1<<n-1))
@ -413,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
}
var dist int
if hd == nil {
if f.hd == nil {
for f.nb < 5 {
if err = f.moreBits(); err != nil {
return err
f.err = err
return
}
}
dist = int(reverseByte[(f.b&0x1F)<<3])
f.b >>= 5
f.nb -= 5
} else {
if dist, err = f.huffSym(hd); err != nil {
return err
if dist, err = f.huffSym(f.hd); err != nil {
f.err = err
return
}
}
@ -432,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
case dist < 4:
dist++
case dist >= 30:
return CorruptInputError(f.roffset)
f.err = CorruptInputError(f.roffset)
return
default:
nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb
for f.nb < nb {
if err = f.moreBits(); err != nil {
return err
f.err = err
return
}
}
extra |= int(f.b & uint32(1<<nb-1))
@ -450,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
// Copy history[-dist:-dist+length] into output.
if dist > len(f.hist) {
return InternalError("bad history distance")
f.err = InternalError("bad history distance")
return
}
// No check on length; encoding can be prescient.
if !f.hfull && dist > f.hp {
return CorruptInputError(f.roffset)
f.err = CorruptInputError(f.roffset)
return
}
p := f.hp - dist
@ -467,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hp++
p++
if f.hp == len(f.hist) {
if err = f.flush(); err != nil {
return err
}
// After flush continue copying out of history.
f.copyLen = length - (i + 1)
f.copyDist = dist
f.flush((*decompressor).copyHuff)
return
}
if p == len(f.hist) {
p = 0
@ -479,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
panic("unreached")
}
func (f *decompressor) copyHuff() {
length := f.copyLen
dist := f.copyDist
p := f.hp - dist
if p < 0 {
p += len(f.hist)
}
for i := 0; i < length; i++ {
f.hist[f.hp] = f.hist[p]
f.hp++
p++
if f.hp == len(f.hist) {
f.copyLen = length - (i + 1)
f.flush((*decompressor).copyHuff)
return
}
if p == len(f.hist) {
p = 0
}
}
// Continue processing Huffman block.
f.huffmanBlock()
}
// Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() os.Error {
func (f *decompressor) dataBlock() {
// Uncompressed.
// Discard current half-byte.
f.nb = 0
@ -490,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr)
if err != nil {
return &ReadError{f.roffset, err}
f.err = &ReadError{f.roffset, err}
return
}
n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) {
return CorruptInputError(f.roffset)
f.err = CorruptInputError(f.roffset)
return
}
if n == 0 {
// 0-length block means sync
return f.flush()
f.flush((*decompressor).nextBlock)
return
}
// Read len bytes into history,
// writing as history fills.
f.copyLen = n
f.copyData()
}
func (f *decompressor) copyData() {
// Read f.dataLen bytes into history,
// pausing for reads as history fills.
n := f.copyLen
for n > 0 {
m := len(f.hist) - f.hp
if m > n {
@ -513,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
f.roffset += int64(m)
if err != nil {
return &ReadError{f.roffset, err}
f.err = &ReadError{f.roffset, err}
return
}
n -= m
f.hp += m
if f.hp == len(f.hist) {
if err = f.flush(); err != nil {
return err
}
f.copyLen = n
f.flush((*decompressor).copyData)
return
}
}
return nil
f.step = (*decompressor).nextBlock
}
func (f *decompressor) setDict(dict []byte) {
@ -579,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
}
// Flush any buffered output to the underlying writer.
func (f *decompressor) flush() os.Error {
if f.hw == f.hp {
return nil
}
n, err := f.w.Write(f.hist[f.hw:f.hp])
if n != f.hp-f.hw && err == nil {
err = io.ErrShortWrite
}
if err != nil {
return &WriteError{f.woffset, err}
}
func (f *decompressor) flush(step func(*decompressor)) {
f.toRead = f.hist[f.hw:f.hp]
f.woffset += int64(f.hp - f.hw)
f.hw = f.hp
if f.hp == len(f.hist) {
@ -597,7 +673,7 @@ func (f *decompressor) flush() os.Error {
f.hw = 0
f.hfull = true
}
return nil
f.step = step
}
func makeReader(r io.Reader) Reader {
@ -607,30 +683,15 @@ func makeReader(r io.Reader) Reader {
return bufio.NewReader(r)
}
// decompress reads DEFLATE-compressed data from r and writes
// the uncompressed data to w.
func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
f.r = makeReader(r)
f.w = w
f.woffset = 0
if err := f.inflate(); err != nil {
return err
}
if err := f.flush(); err != nil {
return err
}
return nil
}
// NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r. It is the caller's
// responsibility to call Close on the ReadCloser when
// finished reading.
func NewReader(r io.Reader) io.ReadCloser {
var f decompressor
pr, pw := io.Pipe()
go func() { pw.CloseWithError(f.decompress(r, pw)) }()
return pr
f.r = makeReader(r)
f.step = (*decompressor).nextBlock
return &f
}
// NewReaderDict is like NewReader but initializes the reader
@ -641,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor
f.setDict(dict)
pr, pw := io.Pipe()
go func() { pw.CloseWithError(f.decompress(r, pw)) }()
return pr
f.r = makeReader(r)
f.step = (*decompressor).nextBlock
return &f
}

View File

@ -36,8 +36,8 @@ func makeReader(r io.Reader) flate.Reader {
return bufio.NewReader(r)
}
var HeaderError os.Error = os.ErrorString("invalid gzip header")
var ChecksumError os.Error = os.ErrorString("gzip checksum error")
var HeaderError = os.NewError("invalid gzip header")
var ChecksumError = os.NewError("gzip checksum error")
// The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Compressor and Decompressor structs.

View File

@ -11,7 +11,7 @@ import (
)
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the
// writer end and ifunc at the reader end.
// writer end and cfunc at the reader end.
func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) {
piper, pipew := io.Pipe()
defer piper.Close()

View File

@ -32,13 +32,49 @@ const (
MSB
)
const (
maxWidth = 12
decoderInvalidCode = 0xffff
flushBuffer = 1 << maxWidth
)
// decoder is the state from which the readXxx method converts a byte
// stream into a code stream.
type decoder struct {
r io.ByteReader
bits uint32
nBits uint
width uint
r io.ByteReader
bits uint32
nBits uint
width uint
read func(*decoder) (uint16, os.Error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err os.Error
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
// overflow is the code at which hi overflows the code width.
// last is the most recently seen code, or decoderInvalidCode.
clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// output is the temporary output buffer.
// Literal codes are accumulated from the start of the buffer.
// Non-literal codes decode to a sequence of suffixes that are first
// written right-to-left from the end of the buffer before being copied
// to the start of the buffer.
// It is flushed when it contains >= 1<<maxWidth bytes,
// so that there is always room to decode an entire code.
output [2 * 1 << maxWidth]byte
o int // write index into output
toRead []byte // bytes to return from Read
}
// readLSB returns the next code for "Least Significant Bits first" data.
@ -73,119 +109,113 @@ func (d *decoder) readMSB() (uint16, os.Error) {
return code, nil
}
// decode decompresses bytes from r and writes them to pw.
// read specifies how to decode bytes into codes.
// litWidth is the width in bits of literal codes.
func decode(r io.Reader, read func(*decoder) (uint16, os.Error), litWidth int, pw *io.PipeWriter) {
br, ok := r.(io.ByteReader)
if !ok {
br = bufio.NewReader(r)
func (d *decoder) Read(b []byte) (int, os.Error) {
for {
if len(d.toRead) > 0 {
n := copy(b, d.toRead)
d.toRead = d.toRead[n:]
return n, nil
}
if d.err != nil {
return 0, d.err
}
d.decode()
}
pw.CloseWithError(decode1(pw, br, read, uint(litWidth)))
panic("unreachable")
}
func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.Error), litWidth uint) os.Error {
const (
maxWidth = 12
invalidCode = 0xffff
)
d := decoder{r, 0, 0, 1 + litWidth}
w := bufio.NewWriter(pw)
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
clear := uint16(1) << litWidth
eof, hi := clear+1, clear+1
// overflow is the code at which hi overflows the code width.
overflow := uint16(1) << d.width
var (
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// buf is a scratch buffer for reconstituting the bytes that a code expands to.
// Code suffixes are written right-to-left from the end of the buffer.
buf [1 << maxWidth]byte
)
// decode decompresses bytes from r and leaves them in d.toRead.
// read specifies how to decode bytes into codes.
// litWidth is the width in bits of literal codes.
func (d *decoder) decode() {
// Loop over the code stream, converting codes into decompressed bytes.
last := uint16(invalidCode)
for {
code, err := read(&d)
code, err := d.read(d)
if err != nil {
if err == os.EOF {
err = io.ErrUnexpectedEOF
}
return err
d.err = err
return
}
switch {
case code < clear:
case code < d.clear:
// We have a literal code.
if err := w.WriteByte(uint8(code)); err != nil {
return err
}
if last != invalidCode {
d.output[d.o] = uint8(code)
d.o++
if d.last != decoderInvalidCode {
// Save what the hi code expands to.
suffix[hi] = uint8(code)
prefix[hi] = last
d.suffix[d.hi] = uint8(code)
d.prefix[d.hi] = d.last
}
case code == clear:
d.width = 1 + litWidth
hi = eof
overflow = 1 << d.width
last = invalidCode
case code == d.clear:
d.width = 1 + uint(d.litWidth)
d.hi = d.eof
d.overflow = 1 << d.width
d.last = decoderInvalidCode
continue
case code == eof:
return w.Flush()
case code <= hi:
c, i := code, len(buf)-1
if code == hi {
case code == d.eof:
d.flush()
d.err = os.EOF
return
case code <= d.hi:
c, i := code, len(d.output)-1
if code == d.hi {
// code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code.
c = last
for c >= clear {
c = prefix[c]
c = d.last
for c >= d.clear {
c = d.prefix[c]
}
buf[i] = uint8(c)
d.output[i] = uint8(c)
i--
c = last
c = d.last
}
// Copy the suffix chain into buf and then write that to w.
for c >= clear {
buf[i] = suffix[c]
// Copy the suffix chain into output and then write that to w.
for c >= d.clear {
d.output[i] = d.suffix[c]
i--
c = prefix[c]
c = d.prefix[c]
}
buf[i] = uint8(c)
if _, err := w.Write(buf[i:]); err != nil {
return err
}
if last != invalidCode {
d.output[i] = uint8(c)
d.o += copy(d.output[d.o:], d.output[i:])
if d.last != decoderInvalidCode {
// Save what the hi code expands to.
suffix[hi] = uint8(c)
prefix[hi] = last
d.suffix[d.hi] = uint8(c)
d.prefix[d.hi] = d.last
}
default:
return os.NewError("lzw: invalid code")
d.err = os.NewError("lzw: invalid code")
return
}
last, hi = code, hi+1
if hi >= overflow {
d.last, d.hi = code, d.hi+1
if d.hi >= d.overflow {
if d.width == maxWidth {
last = invalidCode
continue
d.last = decoderInvalidCode
} else {
d.width++
d.overflow <<= 1
}
d.width++
overflow <<= 1
}
if d.o >= flushBuffer {
d.flush()
return
}
}
panic("unreachable")
}
func (d *decoder) flush() {
d.toRead = d.output[:d.o]
d.o = 0
}
func (d *decoder) Close() os.Error {
d.err = os.EINVAL // in case any Reads come along
return nil
}
// NewReader creates a new io.ReadCloser that satisfies reads by decompressing
// the data read from r.
// It is the caller's responsibility to call Close on the ReadCloser when
@ -193,21 +223,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
pr, pw := io.Pipe()
var read func(*decoder) (uint16, os.Error)
d := new(decoder)
switch order {
case LSB:
read = (*decoder).readLSB
d.read = (*decoder).readLSB
case MSB:
read = (*decoder).readMSB
d.read = (*decoder).readMSB
default:
pw.CloseWithError(os.NewError("lzw: unknown order"))
return pr
d.err = os.NewError("lzw: unknown order")
return d
}
if litWidth < 2 || 8 < litWidth {
pw.CloseWithError(fmt.Errorf("lzw: litWidth %d out of range", litWidth))
return pr
d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return d
}
go decode(r, read, litWidth, pw)
return pr
if br, ok := r.(io.ByteReader); ok {
d.r = br
} else {
d.r = bufio.NewReader(r)
}
d.litWidth = litWidth
d.width = 1 + uint(litWidth)
d.clear = uint16(1) << uint(litWidth)
d.eof, d.hi = d.clear+1, d.clear+1
d.overflow = uint16(1) << d.width
d.last = decoderInvalidCode
return d
}

View File

@ -84,7 +84,7 @@ var lzwTests = []lzwTest{
func TestReader(t *testing.T) {
b := bytes.NewBuffer(nil)
for _, tt := range lzwTests {
d := strings.Split(tt.desc, ";", -1)
d := strings.Split(tt.desc, ";")
var order Order
switch d[1] {
case "LSB":

View File

@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return
}
if len(b0) != len(b1) {
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1))
if len(b1) != len(b0) {
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
return
}
for i := 0; i < len(b0); i++ {
if b0[i] != b1[i] {
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i])
if b1[i] != b0[i] {
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
return
}
}

View File

@ -34,9 +34,9 @@ import (
const zlibDeflate = 8
var ChecksumError os.Error = os.ErrorString("zlib checksum error")
var HeaderError os.Error = os.ErrorString("invalid zlib header")
var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary")
var ChecksumError = os.NewError("zlib checksum error")
var HeaderError = os.NewError("invalid zlib header")
var DictionaryError = os.NewError("invalid zlib dictionary")
type reader struct {
r flate.Reader

View File

@ -89,7 +89,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
}
}
z.w = w
z.compressor = flate.NewWriter(w, level)
z.compressor = flate.NewWriterDict(w, level, dict)
z.digest = adler32.New()
return z, nil
}

View File

@ -5,6 +5,8 @@
package zlib
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
@ -16,15 +18,13 @@ var filenames = []string{
"../testdata/pi.txt",
}
var data = []string{
"test a reasonable sized string that can be compressed",
}
// Tests that compressing and then decompressing the given file at the given compression level and dictionary
// yields equivalent bytes to the original file.
func testFileLevelDict(t *testing.T, fn string, level int, d string) {
// Read dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Read the file, as golden output.
golden, err := os.Open(fn)
if err != nil {
@ -32,17 +32,25 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return
}
defer golden.Close()
// Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
raw, err := os.Open(fn)
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
b0, err0 := ioutil.ReadAll(golden)
if err0 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return
}
testLevelDict(t, fn, b0, level, d)
}
func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
// Make dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Push data through a pipe that compresses at the write end, and decompresses at the read end.
piper, pipew := io.Pipe()
defer piper.Close()
go func() {
defer raw.Close()
defer pipew.Close()
zlibw, err := NewWriterDict(pipew, level, dict)
if err != nil {
@ -50,25 +58,14 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return
}
defer zlibw.Close()
var b [1024]byte
for {
n, err0 := raw.Read(b[0:])
if err0 != nil && err0 != os.EOF {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return
}
_, err1 := zlibw.Write(b[0:n])
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return
}
if err0 == os.EOF {
break
}
_, err = zlibw.Write(b0)
if err == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
return
}
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return
}
}()
zlibr, err := NewReaderDict(piper, dict)
@ -78,13 +75,8 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
}
defer zlibr.Close()
// Compare the two.
b0, err0 := ioutil.ReadAll(golden)
// Compare the decompressed data.
b1, err1 := ioutil.ReadAll(zlibr)
if err0 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return
}
if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return
@ -102,6 +94,18 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
}
func TestWriter(t *testing.T) {
for i, s := range data {
b := []byte(s)
tag := fmt.Sprintf("#%d", i)
testLevelDict(t, tag, b, DefaultCompression, "")
testLevelDict(t, tag, b, NoCompression, "")
for level := BestSpeed; level <= BestCompression; level++ {
testLevelDict(t, tag, b, level, "")
}
}
}
func TestWriterBig(t *testing.T) {
for _, fn := range filenames {
testFileLevelDict(t, fn, DefaultCompression, "")
testFileLevelDict(t, fn, NoCompression, "")
@ -121,3 +125,20 @@ func TestWriterDict(t *testing.T) {
}
}
}
func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
buf := bytes.NewBuffer(nil)
compressor, err := NewWriterDict(buf, BestCompression, input)
if err != nil {
t.Errorf("error in NewWriterDict: %s", err)
return
}
compressor.Write(input)
compressor.Close()
const expectedMaxSize = 25
output := buf.Bytes()
if len(output) > expectedMaxSize {
t.Errorf("result too large (got %d, want <= %d bytes). Is the dictionary being used?", len(output), expectedMaxSize)
}
}

View File

@ -21,8 +21,7 @@ type Interface interface {
Pop() interface{}
}
// A heaper must be initialized before any of the heap operations
// A heap must be initialized before any of the heap operations
// can be used. Init is idempotent with respect to the heap invariants
// and may be called whenever the heap invariants may have been invalidated.
// Its complexity is O(n) where n = h.Len().
@ -35,7 +34,6 @@ func Init(h Interface) {
}
}
// Push pushes the element x onto the heap. The complexity is
// O(log(n)) where n = h.Len().
//
@ -44,7 +42,6 @@ func Push(h Interface, x interface{}) {
up(h, h.Len()-1)
}
// Pop removes the minimum element (according to Less) from the heap
// and returns it. The complexity is O(log(n)) where n = h.Len().
// Same as Remove(h, 0).
@ -56,7 +53,6 @@ func Pop(h Interface) interface{} {
return h.Pop()
}
// Remove removes the element at index i from the heap.
// The complexity is O(log(n)) where n = h.Len().
//
@ -70,7 +66,6 @@ func Remove(h Interface, i int) interface{} {
return h.Pop()
}
func up(h Interface, j int) {
for {
i := (j - 1) / 2 // parent
@ -82,7 +77,6 @@ func up(h Interface, j int) {
}
}
func down(h Interface, i, n int) {
for {
j1 := 2*i + 1

View File

@ -10,17 +10,14 @@ import (
. "container/heap"
)
type myHeap struct {
// A vector.Vector implements sort.Interface except for Less,
// and it implements Push and Pop as required for heap.Interface.
vector.Vector
}
func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) }
func (h *myHeap) verify(t *testing.T, i int) {
n := h.Len()
j1 := 2*i + 1
@ -41,7 +38,6 @@ func (h *myHeap) verify(t *testing.T, i int) {
}
}
func TestInit0(t *testing.T) {
h := new(myHeap)
for i := 20; i > 0; i-- {
@ -59,7 +55,6 @@ func TestInit0(t *testing.T) {
}
}
func TestInit1(t *testing.T) {
h := new(myHeap)
for i := 20; i > 0; i-- {
@ -77,7 +72,6 @@ func TestInit1(t *testing.T) {
}
}
func Test(t *testing.T) {
h := new(myHeap)
h.verify(t, 0)
@ -105,7 +99,6 @@ func Test(t *testing.T) {
}
}
func TestRemove0(t *testing.T) {
h := new(myHeap)
for i := 0; i < 10; i++ {
@ -123,7 +116,6 @@ func TestRemove0(t *testing.T) {
}
}
func TestRemove1(t *testing.T) {
h := new(myHeap)
for i := 0; i < 10; i++ {
@ -140,7 +132,6 @@ func TestRemove1(t *testing.T) {
}
}
func TestRemove2(t *testing.T) {
N := 10

View File

@ -16,14 +16,12 @@ type Ring struct {
Value interface{} // for use by client; untouched by this library
}
func (r *Ring) init() *Ring {
r.next = r
r.prev = r
return r
}
// Next returns the next ring element. r must not be empty.
func (r *Ring) Next() *Ring {
if r.next == nil {
@ -32,7 +30,6 @@ func (r *Ring) Next() *Ring {
return r.next
}
// Prev returns the previous ring element. r must not be empty.
func (r *Ring) Prev() *Ring {
if r.next == nil {
@ -41,7 +38,6 @@ func (r *Ring) Prev() *Ring {
return r.prev
}
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty.
//
@ -62,7 +58,6 @@ func (r *Ring) Move(n int) *Ring {
return r
}
// New creates a ring of n elements.
func New(n int) *Ring {
if n <= 0 {
@ -79,7 +74,6 @@ func New(n int) *Ring {
return r
}
// Link connects ring r with with ring s such that r.Next()
// becomes s and returns the original value for r.Next().
// r must not be empty.
@ -110,7 +104,6 @@ func (r *Ring) Link(s *Ring) *Ring {
return n
}
// Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty.
@ -122,7 +115,6 @@ func (r *Ring) Unlink(n int) *Ring {
return r.Link(r.Move(n + 1))
}
// Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements.
//
@ -137,7 +129,6 @@ func (r *Ring) Len() int {
return n
}
// Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r.
func (r *Ring) Do(f func(interface{})) {

View File

@ -9,7 +9,6 @@ import (
"testing"
)
// For debugging - keep around.
func dump(r *Ring) {
if r == nil {
@ -24,7 +23,6 @@ func dump(r *Ring) {
fmt.Println()
}
func verify(t *testing.T, r *Ring, N int, sum int) {
// Len
n := r.Len()
@ -96,7 +94,6 @@ func verify(t *testing.T, r *Ring, N int, sum int) {
}
}
func TestCornerCases(t *testing.T) {
var (
r0 *Ring
@ -118,7 +115,6 @@ func TestCornerCases(t *testing.T) {
verify(t, &r1, 1, 0)
}
func makeN(n int) *Ring {
r := New(n)
for i := 1; i <= n; i++ {
@ -130,7 +126,6 @@ func makeN(n int) *Ring {
func sumN(n int) int { return (n*n + n) / 2 }
func TestNew(t *testing.T) {
for i := 0; i < 10; i++ {
r := New(i)
@ -142,7 +137,6 @@ func TestNew(t *testing.T) {
}
}
func TestLink1(t *testing.T) {
r1a := makeN(1)
var r1b Ring
@ -163,7 +157,6 @@ func TestLink1(t *testing.T) {
verify(t, r2b, 1, 0)
}
func TestLink2(t *testing.T) {
var r0 *Ring
r1a := &Ring{Value: 42}
@ -183,7 +176,6 @@ func TestLink2(t *testing.T) {
verify(t, r10, 12, sumN(10)+42+77)
}
func TestLink3(t *testing.T) {
var r Ring
n := 1
@ -193,7 +185,6 @@ func TestLink3(t *testing.T) {
}
}
func TestUnlink(t *testing.T) {
r10 := makeN(10)
s10 := r10.Move(6)
@ -215,7 +206,6 @@ func TestUnlink(t *testing.T) {
verify(t, r10, 9, sum10-2)
}
func TestLinkUnlink(t *testing.T) {
for i := 1; i < 4; i++ {
ri := New(i)

View File

@ -6,29 +6,24 @@
// Vectors grow and shrink dynamically as necessary.
package vector
// Vector is a container for numbered sequences of elements of type interface{}.
// A vector's length and capacity adjusts automatically as necessary.
// The zero value for Vector is an empty vector ready to use.
type Vector []interface{}
// IntVector is a container for numbered sequences of elements of type int.
// A vector's length and capacity adjusts automatically as necessary.
// The zero value for IntVector is an empty vector ready to use.
type IntVector []int
// StringVector is a container for numbered sequences of elements of type string.
// A vector's length and capacity adjusts automatically as necessary.
// The zero value for StringVector is an empty vector ready to use.
type StringVector []string
// Initial underlying array size
const initialSize = 8
// Partial sort.Interface support
// LessInterface provides partial support of the sort.Interface.
@ -36,16 +31,13 @@ type LessInterface interface {
Less(y interface{}) bool
}
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) }
// sort.Interface support
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
// Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }

View File

@ -7,7 +7,6 @@
package vector
func (p *IntVector) realloc(length, capacity int) (b []int) {
if capacity < initialSize {
capacity = initialSize
@ -21,7 +20,6 @@ func (p *IntVector) realloc(length, capacity int) (b []int) {
return
}
// Insert n elements at position i.
func (p *IntVector) Expand(i, n int) {
a := *p
@ -51,11 +49,9 @@ func (p *IntVector) Expand(i, n int) {
*p = a
}
// Insert n elements at the end of a vector.
func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *IntVector) Resize(length, capacity int) *IntVector {
return p
}
// Len returns the number of elements in the vector.
// Same as len(*p).
func (p *IntVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing.
// Same as cap(*p).
func (p *IntVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector.
func (p *IntVector) At(i int) int { return (*p)[i] }
// Set sets the i'th element of the vector to value x.
func (p *IntVector) Set(i int, x int) { (*p)[i] = x }
// Last returns the element in the vector of highest index.
func (p *IntVector) Last() int { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it.
func (p *IntVector) Copy() IntVector {
arr := make(IntVector, len(*p))
@ -111,7 +101,6 @@ func (p *IntVector) Copy() IntVector {
return arr
}
// Insert inserts into the vector an element of value x before
// the current element at index i.
func (p *IntVector) Insert(i int, x int) {
@ -119,7 +108,6 @@ func (p *IntVector) Insert(i int, x int) {
(*p)[i] = x
}
// Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards.
func (p *IntVector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *IntVector) Delete(i int) {
*p = a[0 : n-1]
}
// InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion.
func (p *IntVector) InsertVector(i int, x *IntVector) {
@ -142,7 +129,6 @@ func (p *IntVector) InsertVector(i int, x *IntVector) {
copy((*p)[i:i+len(b)], b)
}
// Cut deletes elements i through j-1, inclusive.
func (p *IntVector) Cut(i, j int) {
a := *p
@ -158,7 +144,6 @@ func (p *IntVector) Cut(i, j int) {
*p = a[0:m]
}
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged.
func (p *IntVector) Slice(i, j int) *IntVector {
@ -168,13 +153,11 @@ func (p *IntVector) Slice(i, j int) *IntVector {
return &s
}
// Convenience wrappers
// Push appends x to the end of the vector.
func (p *IntVector) Push(x int) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector.
func (p *IntVector) Pop() int {
a := *p
@ -187,18 +170,15 @@ func (p *IntVector) Pop() int {
return x
}
// AppendVector appends the entire vector x to the end of this vector.
func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j.
func (p *IntVector) Swap(i, j int) {
a := *p
a[i], a[j] = a[j], a[i]
}
// Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p.
func (p *IntVector) Do(f func(elem int)) {

View File

@ -9,7 +9,6 @@ package vector
import "testing"
func TestIntZeroLen(t *testing.T) {
a := new(IntVector)
if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestIntZeroLen(t *testing.T) {
}
}
func TestIntResize(t *testing.T) {
var a IntVector
checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestIntResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100)
}
func TestIntResize2(t *testing.T) {
var a IntVector
checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestIntResize2(t *testing.T) {
}
}
func checkIntZero(t *testing.T, a *IntVector, i int) {
for j := 0; j < i; j++ {
if a.At(j) == intzero {
@ -82,7 +78,6 @@ func checkIntZero(t *testing.T, a *IntVector, i int) {
}
}
func TestIntTrailingElements(t *testing.T) {
var a IntVector
for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestIntTrailingElements(t *testing.T) {
checkIntZero(t, &a, 5)
}
func TestIntAccess(t *testing.T) {
const n = 100
var a IntVector
@ -120,7 +114,6 @@ func TestIntAccess(t *testing.T) {
}
}
func TestIntInsertDeleteClear(t *testing.T) {
const n = 100
var a IntVector
@ -207,7 +200,6 @@ func TestIntInsertDeleteClear(t *testing.T) {
}
}
func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
for k := i; k < j; k++ {
if elem2IntValue(x.At(k)) != int2IntValue(elt) {
@ -223,7 +215,6 @@ func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
}
}
func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
n := a + b + c
if x.Len() != n {
@ -237,7 +228,6 @@ func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
verify_sliceInt(t, x, 0, a+b, n)
}
func make_vectorInt(elt, len int) *IntVector {
x := new(IntVector).Resize(len, 0)
for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vectorInt(elt, len int) *IntVector {
return x
}
func TestIntInsertVector(t *testing.T) {
// 1
a := make_vectorInt(0, 0)
@ -270,7 +259,6 @@ func TestIntInsertVector(t *testing.T) {
verify_patternInt(t, a, 8, 1000, 2)
}
func TestIntDo(t *testing.T) {
const n = 25
const salt = 17
@ -325,7 +313,6 @@ func TestIntDo(t *testing.T) {
}
func TestIntVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10

View File

@ -4,7 +4,6 @@
package vector
import (
"fmt"
"sort"
@ -17,28 +16,23 @@ var (
strzero string
)
func int2Value(x int) int { return x }
func int2IntValue(x int) int { return x }
func int2StrValue(x int) string { return string(x) }
func elem2Value(x interface{}) int { return x.(int) }
func elem2IntValue(x int) int { return x }
func elem2StrValue(x string) string { return x }
func intf2Value(x interface{}) int { return x.(int) }
func intf2IntValue(x interface{}) int { return x.(int) }
func intf2StrValue(x interface{}) string { return x.(string) }
type VectorInterface interface {
Len() int
Cap() int
}
func checkSize(t *testing.T, v VectorInterface, len, cap int) {
if v.Len() != len {
t.Errorf("%T expected len = %d; found %d", v, len, v.Len())
@ -48,10 +42,8 @@ func checkSize(t *testing.T, v VectorInterface, len, cap int) {
}
}
func val(i int) int { return i*991 - 1234 }
func TestSorting(t *testing.T) {
const n = 100
@ -72,5 +64,4 @@ func TestSorting(t *testing.T) {
}
}
func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) }

View File

@ -11,10 +11,8 @@ import (
"testing"
)
const memTestN = 1000000
func s(n uint64) string {
str := fmt.Sprintf("%d", n)
lens := len(str)
@ -31,7 +29,6 @@ func s(n uint64) string {
return strings.Join(a, " ")
}
func TestVectorNums(t *testing.T) {
if testing.Short() {
return
@ -52,7 +49,6 @@ func TestVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
}
func TestIntVectorNums(t *testing.T) {
if testing.Short() {
return
@ -73,7 +69,6 @@ func TestIntVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
}
func TestStringVectorNums(t *testing.T) {
if testing.Short() {
return
@ -94,7 +89,6 @@ func TestStringVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
}
func BenchmarkVectorNums(b *testing.B) {
c := int(0)
var v Vector
@ -106,7 +100,6 @@ func BenchmarkVectorNums(b *testing.B) {
}
}
func BenchmarkIntVectorNums(b *testing.B) {
c := int(0)
var v IntVector
@ -118,7 +111,6 @@ func BenchmarkIntVectorNums(b *testing.B) {
}
}
func BenchmarkStringVectorNums(b *testing.B) {
c := ""
var v StringVector

View File

@ -7,7 +7,6 @@
package vector
func (p *StringVector) realloc(length, capacity int) (b []string) {
if capacity < initialSize {
capacity = initialSize
@ -21,7 +20,6 @@ func (p *StringVector) realloc(length, capacity int) (b []string) {
return
}
// Insert n elements at position i.
func (p *StringVector) Expand(i, n int) {
a := *p
@ -51,11 +49,9 @@ func (p *StringVector) Expand(i, n int) {
*p = a
}
// Insert n elements at the end of a vector.
func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *StringVector) Resize(length, capacity int) *StringVector {
return p
}
// Len returns the number of elements in the vector.
// Same as len(*p).
func (p *StringVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing.
// Same as cap(*p).
func (p *StringVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector.
func (p *StringVector) At(i int) string { return (*p)[i] }
// Set sets the i'th element of the vector to value x.
func (p *StringVector) Set(i int, x string) { (*p)[i] = x }
// Last returns the element in the vector of highest index.
func (p *StringVector) Last() string { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it.
func (p *StringVector) Copy() StringVector {
arr := make(StringVector, len(*p))
@ -111,7 +101,6 @@ func (p *StringVector) Copy() StringVector {
return arr
}
// Insert inserts into the vector an element of value x before
// the current element at index i.
func (p *StringVector) Insert(i int, x string) {
@ -119,7 +108,6 @@ func (p *StringVector) Insert(i int, x string) {
(*p)[i] = x
}
// Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards.
func (p *StringVector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *StringVector) Delete(i int) {
*p = a[0 : n-1]
}
// InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion.
func (p *StringVector) InsertVector(i int, x *StringVector) {
@ -142,7 +129,6 @@ func (p *StringVector) InsertVector(i int, x *StringVector) {
copy((*p)[i:i+len(b)], b)
}
// Cut deletes elements i through j-1, inclusive.
func (p *StringVector) Cut(i, j int) {
a := *p
@ -158,7 +144,6 @@ func (p *StringVector) Cut(i, j int) {
*p = a[0:m]
}
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged.
func (p *StringVector) Slice(i, j int) *StringVector {
@ -168,13 +153,11 @@ func (p *StringVector) Slice(i, j int) *StringVector {
return &s
}
// Convenience wrappers
// Push appends x to the end of the vector.
func (p *StringVector) Push(x string) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector.
func (p *StringVector) Pop() string {
a := *p
@ -187,18 +170,15 @@ func (p *StringVector) Pop() string {
return x
}
// AppendVector appends the entire vector x to the end of this vector.
func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j.
func (p *StringVector) Swap(i, j int) {
a := *p
a[i], a[j] = a[j], a[i]
}
// Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p.
func (p *StringVector) Do(f func(elem string)) {

View File

@ -9,7 +9,6 @@ package vector
import "testing"
func TestStrZeroLen(t *testing.T) {
a := new(StringVector)
if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestStrZeroLen(t *testing.T) {
}
}
func TestStrResize(t *testing.T) {
var a StringVector
checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestStrResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100)
}
func TestStrResize2(t *testing.T) {
var a StringVector
checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestStrResize2(t *testing.T) {
}
}
func checkStrZero(t *testing.T, a *StringVector, i int) {
for j := 0; j < i; j++ {
if a.At(j) == strzero {
@ -82,7 +78,6 @@ func checkStrZero(t *testing.T, a *StringVector, i int) {
}
}
func TestStrTrailingElements(t *testing.T) {
var a StringVector
for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestStrTrailingElements(t *testing.T) {
checkStrZero(t, &a, 5)
}
func TestStrAccess(t *testing.T) {
const n = 100
var a StringVector
@ -120,7 +114,6 @@ func TestStrAccess(t *testing.T) {
}
}
func TestStrInsertDeleteClear(t *testing.T) {
const n = 100
var a StringVector
@ -207,7 +200,6 @@ func TestStrInsertDeleteClear(t *testing.T) {
}
}
func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
for k := i; k < j; k++ {
if elem2StrValue(x.At(k)) != int2StrValue(elt) {
@ -223,7 +215,6 @@ func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
}
}
func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
n := a + b + c
if x.Len() != n {
@ -237,7 +228,6 @@ func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
verify_sliceStr(t, x, 0, a+b, n)
}
func make_vectorStr(elt, len int) *StringVector {
x := new(StringVector).Resize(len, 0)
for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vectorStr(elt, len int) *StringVector {
return x
}
func TestStrInsertVector(t *testing.T) {
// 1
a := make_vectorStr(0, 0)
@ -270,7 +259,6 @@ func TestStrInsertVector(t *testing.T) {
verify_patternStr(t, a, 8, 1000, 2)
}
func TestStrDo(t *testing.T) {
const n = 25
const salt = 17
@ -325,7 +313,6 @@ func TestStrDo(t *testing.T) {
}
func TestStrVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10

View File

@ -7,7 +7,6 @@
package vector
func (p *Vector) realloc(length, capacity int) (b []interface{}) {
if capacity < initialSize {
capacity = initialSize
@ -21,7 +20,6 @@ func (p *Vector) realloc(length, capacity int) (b []interface{}) {
return
}
// Insert n elements at position i.
func (p *Vector) Expand(i, n int) {
a := *p
@ -51,11 +49,9 @@ func (p *Vector) Expand(i, n int) {
*p = a
}
// Insert n elements at the end of a vector.
func (p *Vector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length,
@ -80,30 +76,24 @@ func (p *Vector) Resize(length, capacity int) *Vector {
return p
}
// Len returns the number of elements in the vector.
// Same as len(*p).
func (p *Vector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing.
// Same as cap(*p).
func (p *Vector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector.
func (p *Vector) At(i int) interface{} { return (*p)[i] }
// Set sets the i'th element of the vector to value x.
func (p *Vector) Set(i int, x interface{}) { (*p)[i] = x }
// Last returns the element in the vector of highest index.
func (p *Vector) Last() interface{} { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it.
func (p *Vector) Copy() Vector {
arr := make(Vector, len(*p))
@ -111,7 +101,6 @@ func (p *Vector) Copy() Vector {
return arr
}
// Insert inserts into the vector an element of value x before
// the current element at index i.
func (p *Vector) Insert(i int, x interface{}) {
@ -119,7 +108,6 @@ func (p *Vector) Insert(i int, x interface{}) {
(*p)[i] = x
}
// Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards.
func (p *Vector) Delete(i int) {
@ -132,7 +120,6 @@ func (p *Vector) Delete(i int) {
*p = a[0 : n-1]
}
// InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion.
func (p *Vector) InsertVector(i int, x *Vector) {
@ -142,7 +129,6 @@ func (p *Vector) InsertVector(i int, x *Vector) {
copy((*p)[i:i+len(b)], b)
}
// Cut deletes elements i through j-1, inclusive.
func (p *Vector) Cut(i, j int) {
a := *p
@ -158,7 +144,6 @@ func (p *Vector) Cut(i, j int) {
*p = a[0:m]
}
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged.
func (p *Vector) Slice(i, j int) *Vector {
@ -168,13 +153,11 @@ func (p *Vector) Slice(i, j int) *Vector {
return &s
}
// Convenience wrappers
// Push appends x to the end of the vector.
func (p *Vector) Push(x interface{}) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector.
func (p *Vector) Pop() interface{} {
a := *p
@ -187,18 +170,15 @@ func (p *Vector) Pop() interface{} {
return x
}
// AppendVector appends the entire vector x to the end of this vector.
func (p *Vector) AppendVector(x *Vector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j.
func (p *Vector) Swap(i, j int) {
a := *p
a[i], a[j] = a[j], a[i]
}
// Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p.
func (p *Vector) Do(f func(elem interface{})) {

View File

@ -9,7 +9,6 @@ package vector
import "testing"
func TestZeroLen(t *testing.T) {
a := new(Vector)
if a.Len() != 0 {
@ -27,7 +26,6 @@ func TestZeroLen(t *testing.T) {
}
}
func TestResize(t *testing.T) {
var a Vector
checkSize(t, &a, 0, 0)
@ -40,7 +38,6 @@ func TestResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100)
}
func TestResize2(t *testing.T) {
var a Vector
checkSize(t, &a, 0, 0)
@ -62,7 +59,6 @@ func TestResize2(t *testing.T) {
}
}
func checkZero(t *testing.T, a *Vector, i int) {
for j := 0; j < i; j++ {
if a.At(j) == zero {
@ -82,7 +78,6 @@ func checkZero(t *testing.T, a *Vector, i int) {
}
}
func TestTrailingElements(t *testing.T) {
var a Vector
for i := 0; i < 10; i++ {
@ -95,7 +90,6 @@ func TestTrailingElements(t *testing.T) {
checkZero(t, &a, 5)
}
func TestAccess(t *testing.T) {
const n = 100
var a Vector
@ -120,7 +114,6 @@ func TestAccess(t *testing.T) {
}
}
func TestInsertDeleteClear(t *testing.T) {
const n = 100
var a Vector
@ -207,7 +200,6 @@ func TestInsertDeleteClear(t *testing.T) {
}
}
func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
for k := i; k < j; k++ {
if elem2Value(x.At(k)) != int2Value(elt) {
@ -223,7 +215,6 @@ func verify_slice(t *testing.T, x *Vector, elt, i, j int) {
}
}
func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
n := a + b + c
if x.Len() != n {
@ -237,7 +228,6 @@ func verify_pattern(t *testing.T, x *Vector, a, b, c int) {
verify_slice(t, x, 0, a+b, n)
}
func make_vector(elt, len int) *Vector {
x := new(Vector).Resize(len, 0)
for i := 0; i < len; i++ {
@ -246,7 +236,6 @@ func make_vector(elt, len int) *Vector {
return x
}
func TestInsertVector(t *testing.T) {
// 1
a := make_vector(0, 0)
@ -270,7 +259,6 @@ func TestInsertVector(t *testing.T) {
verify_pattern(t, a, 8, 1000, 2)
}
func TestDo(t *testing.T) {
const n = 25
const salt = 17
@ -325,7 +313,6 @@ func TestDo(t *testing.T) {
}
func TestVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10

View File

@ -45,14 +45,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
// BlockSize returns the AES block size, 16 bytes.
// It is necessary to satisfy the Cipher interface in the
// package "crypto/block".
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 16-byte buffer src using the key k
// and stores the result in dst.
// Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/block/cbc.go).
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) }
// Decrypt decrypts the 16-byte buffer src using the key k

View File

@ -42,14 +42,14 @@ func NewCipher(key []byte) (*Cipher, os.Error) {
// BlockSize returns the Blowfish block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the
// package "crypto/block".
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 8-byte buffer src using the key k
// and stores the result in dst.
// Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/block/cbc.go).
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) {
l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])

View File

@ -20,7 +20,7 @@ type Cipher struct {
func NewCipher(key []byte) (c *Cipher, err os.Error) {
if len(key) != KeySize {
return nil, os.ErrorString("CAST5: keys must be 16 bytes")
return nil, os.NewError("CAST5: keys must be 16 bytes")
}
c = new(Cipher)

View File

@ -80,9 +80,10 @@ type ocfbDecrypter struct {
// NewOCFBDecrypter returns a Stream which decrypts data with OpenPGP's cipher
// feedback mode using the given Block. Prefix must be the first blockSize + 2
// bytes of the ciphertext, where blockSize is the Block's block size. If an
// incorrect key is detected then nil is returned. Resync determines if the
// "resynchronization step" from RFC 4880, 13.9 step 7 is performed. Different
// parts of OpenPGP vary on this point.
// incorrect key is detected then nil is returned. On successful exit,
// blockSize+2 bytes of decrypted data are written into prefix. Resync
// determines if the "resynchronization step" from RFC 4880, 13.9 step 7 is
// performed. Different parts of OpenPGP vary on this point.
func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Stream {
blockSize := block.BlockSize()
if len(prefix) != blockSize+2 {
@ -118,6 +119,7 @@ func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Strea
x.fre[1] = prefix[blockSize+1]
x.outUsed = 2
}
copy(prefix, prefixCopy)
return x
}

View File

@ -79,7 +79,7 @@ func GenerateParameters(params *Parameters, rand io.Reader, sizes ParameterSizes
L = 3072
N = 256
default:
return os.ErrorString("crypto/dsa: invalid ParameterSizes")
return os.NewError("crypto/dsa: invalid ParameterSizes")
}
qBytes := make([]byte, N/8)
@ -158,7 +158,7 @@ GeneratePrimes:
// PrivateKey must already be valid (see GenerateParameters).
func GenerateKey(priv *PrivateKey, rand io.Reader) os.Error {
if priv.P == nil || priv.Q == nil || priv.G == nil {
return os.ErrorString("crypto/dsa: parameters not set up before generating key")
return os.NewError("crypto/dsa: parameters not set up before generating key")
}
x := new(big.Int)

View File

@ -284,7 +284,7 @@ func (curve *Curve) Marshal(x, y *big.Int) []byte {
return ret
}
// Unmarshal converts a point, serialised by Marshal, into an x, y pair. On
// Unmarshal converts a point, serialized by Marshal, into an x, y pair. On
// error, x = nil.
func (curve *Curve) Unmarshal(data []byte) (x, y *big.Int) {
byteLen := (curve.BitSize + 7) >> 3

View File

@ -321,8 +321,8 @@ func TestMarshal(t *testing.T) {
t.Error(err)
return
}
serialised := p224.Marshal(x, y)
xx, yy := p224.Unmarshal(serialised)
serialized := p224.Marshal(x, y)
xx, yy := p224.Unmarshal(serialized)
if xx == nil {
t.Error("failed to unmarshal")
return

View File

@ -190,7 +190,7 @@ func TestHMAC(t *testing.T) {
continue
}
// Repetive Sum() calls should return the same value
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum())
if sum != tt.out {

View File

@ -13,6 +13,7 @@ import (
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"os"
"time"
)
@ -32,21 +33,8 @@ const (
ocspUnauthorized = 5
)
type rdnSequence []relativeDistinguishedNameSET
type relativeDistinguishedNameSET []attributeTypeAndValue
type attributeTypeAndValue struct {
Type asn1.ObjectIdentifier
Value interface{}
}
type algorithmIdentifier struct {
Algorithm asn1.ObjectIdentifier
}
type certID struct {
HashAlgorithm algorithmIdentifier
HashAlgorithm pkix.AlgorithmIdentifier
NameHash []byte
IssuerKeyHash []byte
SerialNumber asn1.RawValue
@ -54,7 +42,7 @@ type certID struct {
type responseASN1 struct {
Status asn1.Enumerated
Response responseBytes "explicit,tag:0"
Response responseBytes `asn1:"explicit,tag:0"`
}
type responseBytes struct {
@ -64,32 +52,32 @@ type responseBytes struct {
type basicResponse struct {
TBSResponseData responseData
SignatureAlgorithm algorithmIdentifier
SignatureAlgorithm pkix.AlgorithmIdentifier
Signature asn1.BitString
Certificates []asn1.RawValue "explicit,tag:0,optional"
Certificates []asn1.RawValue `asn1:"explicit,tag:0,optional"`
}
type responseData struct {
Raw asn1.RawContent
Version int "optional,default:1,explicit,tag:0"
RequestorName rdnSequence "optional,explicit,tag:1"
KeyHash []byte "optional,explicit,tag:2"
Version int `asn1:"optional,default:1,explicit,tag:0"`
RequestorName pkix.RDNSequence `asn1:"optional,explicit,tag:1"`
KeyHash []byte `asn1:"optional,explicit,tag:2"`
ProducedAt *time.Time
Responses []singleResponse
}
type singleResponse struct {
CertID certID
Good asn1.Flag "explicit,tag:0,optional"
Revoked revokedInfo "explicit,tag:1,optional"
Unknown asn1.Flag "explicit,tag:2,optional"
Good asn1.Flag `asn1:"explicit,tag:0,optional"`
Revoked revokedInfo `asn1:"explicit,tag:1,optional"`
Unknown asn1.Flag `asn1:"explicit,tag:2,optional"`
ThisUpdate *time.Time
NextUpdate *time.Time "explicit,tag:0,optional"
NextUpdate *time.Time `asn1:"explicit,tag:0,optional"`
}
type revokedInfo struct {
RevocationTime *time.Time
Reason int "explicit,tag:0,optional"
Reason int `asn1:"explicit,tag:0,optional"`
}
// This is the exposed reflection of the internal OCSP structures.

View File

@ -153,7 +153,7 @@ func (r *openpgpReader) Read(p []byte) (n int, err os.Error) {
// Decode reads a PGP armored block from the given Reader. It will ignore
// leading garbage. If it doesn't find a block, it will return nil, os.EOF. The
// given Reader is not usable after calling this function: an arbitary amount
// given Reader is not usable after calling this function: an arbitrary amount
// of data may have been read past the end of the block.
func Decode(in io.Reader) (p *Block, err os.Error) {
r, _ := bufio.NewReaderSize(in, 100)

View File

@ -30,7 +30,6 @@ func (r recordingHash) Size() int {
panic("shouldn't be called")
}
func testCanonicalText(t *testing.T, input, expected string) {
r := recordingHash{bytes.NewBuffer(nil)}
c := NewCanonicalTextHash(r)

View File

@ -0,0 +1,122 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package elgamal implements ElGamal encryption, suitable for OpenPGP,
// as specified in "A Public-Key Cryptosystem and a Signature Scheme Based on
// Discrete Logarithms," IEEE Transactions on Information Theory, v. IT-31,
// n. 4, 1985, pp. 469-472.
//
// This form of ElGamal embeds PKCS#1 v1.5 padding, which may make it
// unsuitable for other protocols. RSA should be used in preference in any
// case.
package elgamal
import (
"big"
"crypto/rand"
"crypto/subtle"
"io"
"os"
)
// PublicKey represents an ElGamal public key.
type PublicKey struct {
G, P, Y *big.Int
}
// PrivateKey represents an ElGamal private key.
type PrivateKey struct {
PublicKey
X *big.Int
}
// Encrypt encrypts the given message to the given public key. The result is a
// pair of integers. Errors can result from reading random, or because msg is
// too large to be encrypted to the public key.
func Encrypt(random io.Reader, pub *PublicKey, msg []byte) (c1, c2 *big.Int, err os.Error) {
pLen := (pub.P.BitLen() + 7) / 8
if len(msg) > pLen-11 {
err = os.NewError("elgamal: message too long")
return
}
// EM = 0x02 || PS || 0x00 || M
em := make([]byte, pLen-1)
em[0] = 2
ps, mm := em[1:len(em)-len(msg)-1], em[len(em)-len(msg):]
err = nonZeroRandomBytes(ps, random)
if err != nil {
return
}
em[len(em)-len(msg)-1] = 0
copy(mm, msg)
m := new(big.Int).SetBytes(em)
k, err := rand.Int(random, pub.P)
if err != nil {
return
}
c1 = new(big.Int).Exp(pub.G, k, pub.P)
s := new(big.Int).Exp(pub.Y, k, pub.P)
c2 = s.Mul(s, m)
c2.Mod(c2, pub.P)
return
}
// Decrypt takes two integers, resulting from an ElGamal encryption, and
// returns the plaintext of the message. An error can result only if the
// ciphertext is invalid. Users should keep in mind that this is a padding
// oracle and thus, if exposed to an adaptive chosen ciphertext attack, can
// be used to break the cryptosystem. See ``Chosen Ciphertext Attacks
// Against Protocols Based on the RSA Encryption Standard PKCS #1'', Daniel
// Bleichenbacher, Advances in Cryptology (Crypto '98),
func Decrypt(priv *PrivateKey, c1, c2 *big.Int) (msg []byte, err os.Error) {
s := new(big.Int).Exp(c1, priv.X, priv.P)
s.ModInverse(s, priv.P)
s.Mul(s, c2)
s.Mod(s, priv.P)
em := s.Bytes()
firstByteIsTwo := subtle.ConstantTimeByteEq(em[0], 2)
// The remainder of the plaintext must be a string of non-zero random
// octets, followed by a 0, followed by the message.
// lookingForIndex: 1 iff we are still looking for the zero.
// index: the offset of the first zero byte.
var lookingForIndex, index int
lookingForIndex = 1
for i := 1; i < len(em); i++ {
equals0 := subtle.ConstantTimeByteEq(em[i], 0)
index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
}
if firstByteIsTwo != 1 || lookingForIndex != 0 || index < 9 {
return nil, os.NewError("elgamal: decryption error")
}
return em[index+1:], nil
}
// nonZeroRandomBytes fills the given slice with non-zero random octets.
func nonZeroRandomBytes(s []byte, rand io.Reader) (err os.Error) {
_, err = io.ReadFull(rand, s)
if err != nil {
return
}
for i := 0; i < len(s); i++ {
for s[i] == 0 {
_, err = io.ReadFull(rand, s[i:i+1])
if err != nil {
return
}
}
}
return
}

View File

@ -0,0 +1,49 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elgamal
import (
"big"
"bytes"
"crypto/rand"
"testing"
)
// This is the 1024-bit MODP group from RFC 5114, section 2.1:
const primeHex = "B10B8F96A080E01DDE92DE5EAE5D54EC52C99FBCFB06A3C69A6A9DCA52D23B616073E28675A23D189838EF1E2EE652C013ECB4AEA906112324975C3CD49B83BFACCBDD7D90C4BD7098488E9C219A73724EFFD6FAE5644738FAA31A4FF55BCCC0A151AF5F0DC8B4BD45BF37DF365C1A65E68CFDA76D4DA708DF1FB2BC2E4A4371"
const generatorHex = "A4D1CBD5C3FD34126765A442EFB99905F8104DD258AC507FD6406CFF14266D31266FEA1E5C41564B777E690F5504F213160217B4B01B886A5E91547F9E2749F4D7FBD7D3B9A92EE1909D0D2263F80A76A6A24C087A091F531DBF0A0169B6A28AD662A4D18E73AFA32D779D5918D08BC8858F4DCEF97C2A24855E6EEB22B3B2E5"
func fromHex(hex string) *big.Int {
n, ok := new(big.Int).SetString(hex, 16)
if !ok {
panic("failed to parse hex number")
}
return n
}
func TestEncryptDecrypt(t *testing.T) {
priv := &PrivateKey{
PublicKey: PublicKey{
G: fromHex(generatorHex),
P: fromHex(primeHex),
},
X: fromHex("42"),
}
priv.Y = new(big.Int).Exp(priv.G, priv.X, priv.P)
message := []byte("hello world")
c1, c2, err := Encrypt(rand.Reader, &priv.PublicKey, message)
if err != nil {
t.Errorf("error encrypting: %s", err)
}
message2, err := Decrypt(priv, c1, c2)
if err != nil {
t.Errorf("error decrypting: %s", err)
}
if !bytes.Equal(message2, message) {
t.Errorf("decryption failed, got: %x, want: %x", message2, message)
}
}

View File

@ -5,11 +5,14 @@
package openpgp
import (
"crypto"
"crypto/openpgp/armor"
"crypto/openpgp/error"
"crypto/openpgp/packet"
"crypto/rsa"
"io"
"os"
"time"
)
// PublicKeyType is the armor type for a PGP public key.
@ -62,6 +65,78 @@ type KeyRing interface {
DecryptionKeys() []Key
}
// primaryIdentity returns the Identity marked as primary or the first identity
// if none are so marked.
func (e *Entity) primaryIdentity() *Identity {
var firstIdentity *Identity
for _, ident := range e.Identities {
if firstIdentity == nil {
firstIdentity = ident
}
if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
return ident
}
}
return firstIdentity
}
// encryptionKey returns the best candidate Key for encrypting a message to the
// given Entity.
func (e *Entity) encryptionKey() Key {
candidateSubkey := -1
for i, subkey := range e.Subkeys {
if subkey.Sig.FlagsValid && subkey.Sig.FlagEncryptCommunications && subkey.PublicKey.PubKeyAlgo.CanEncrypt() {
candidateSubkey = i
break
}
}
i := e.primaryIdentity()
if e.PrimaryKey.PubKeyAlgo.CanEncrypt() {
// If we don't have any candidate subkeys for encryption and
// the primary key doesn't have any usage metadata then we
// assume that the primary key is ok. Or, if the primary key is
// marked as ok to encrypt to, then we can obviously use it.
if candidateSubkey == -1 && !i.SelfSignature.FlagsValid || i.SelfSignature.FlagEncryptCommunications && i.SelfSignature.FlagsValid {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
}
}
if candidateSubkey != -1 {
subkey := e.Subkeys[candidateSubkey]
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
}
// This Entity appears to be signing only.
return Key{}
}
// signingKey return the best candidate Key for signing a message with this
// Entity.
func (e *Entity) signingKey() Key {
candidateSubkey := -1
for i, subkey := range e.Subkeys {
if subkey.Sig.FlagsValid && subkey.Sig.FlagSign && subkey.PublicKey.PubKeyAlgo.CanSign() {
candidateSubkey = i
break
}
}
i := e.primaryIdentity()
// If we have no candidate subkey then we assume that it's ok to sign
// with the primary key.
if candidateSubkey == -1 || i.SelfSignature.FlagsValid && i.SelfSignature.FlagSign {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature}
}
subkey := e.Subkeys[candidateSubkey]
return Key{e, subkey.PublicKey, subkey.PrivateKey, subkey.Sig}
}
// An EntityList contains one or more Entities.
type EntityList []*Entity
@ -197,6 +272,10 @@ func readEntity(packets *packet.Reader) (*Entity, os.Error) {
}
}
if !e.PrimaryKey.PubKeyAlgo.CanSign() {
return nil, error.StructuralError("primary key cannot be used for signatures")
}
var current *Identity
EachPacket:
for {
@ -227,7 +306,7 @@ EachPacket:
return nil, error.StructuralError("user ID packet not followed by self-signature")
}
if sig.SigType == packet.SigTypePositiveCert && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
return nil, error.StructuralError("user ID self-signature invalid: " + err.String())
}
@ -297,3 +376,170 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
e.Subkeys = append(e.Subkeys, subKey)
return nil
}
const defaultRSAKeyBits = 2048
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
// single identity composed of the given full name, comment and email, any of
// which may be empty but must not contain any of "()<>\x00".
func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email string) (*Entity, os.Error) {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
return nil, error.InvalidArgumentError("user id field contained invalid characters")
}
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {
return nil, err
}
encryptingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {
return nil, err
}
t := uint32(currentTimeSecs)
e := &Entity{
PrimaryKey: packet.NewRSAPublicKey(t, &signingPriv.PublicKey, false /* not a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, signingPriv, false /* not a subkey */ ),
Identities: make(map[string]*Identity),
}
isPrimaryId := true
e.Identities[uid.Id] = &Identity{
Name: uid.Name,
UserId: uid,
SelfSignature: &packet.Signature{
CreationTime: t,
SigType: packet.SigTypePositiveCert,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
IsPrimaryId: &isPrimaryId,
FlagsValid: true,
FlagSign: true,
FlagCertify: true,
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(t, &encryptingPriv.PublicKey, true /* is a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, encryptingPriv, true /* is a subkey */ ),
Sig: &packet.Signature{
CreationTime: t,
SigType: packet.SigTypeSubkeyBinding,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
FlagsValid: true,
FlagEncryptStorage: true,
FlagEncryptCommunications: true,
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
return e, nil
}
// SerializePrivate serializes an Entity, including private key material, to
// the given Writer. For now, it must only be used on an Entity returned from
// NewEntity.
func (e *Entity) SerializePrivate(w io.Writer) (err os.Error) {
err = e.PrivateKey.Serialize(w)
if err != nil {
return
}
for _, ident := range e.Identities {
err = ident.UserId.Serialize(w)
if err != nil {
return
}
err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
if err != nil {
return
}
err = ident.SelfSignature.Serialize(w)
if err != nil {
return
}
}
for _, subkey := range e.Subkeys {
err = subkey.PrivateKey.Serialize(w)
if err != nil {
return
}
err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
if err != nil {
return
}
err = subkey.Sig.Serialize(w)
if err != nil {
return
}
}
return nil
}
// Serialize writes the public part of the given Entity to w. (No private
// key material will be output).
func (e *Entity) Serialize(w io.Writer) os.Error {
err := e.PrimaryKey.Serialize(w)
if err != nil {
return err
}
for _, ident := range e.Identities {
err = ident.UserId.Serialize(w)
if err != nil {
return err
}
err = ident.SelfSignature.Serialize(w)
if err != nil {
return err
}
for _, sig := range ident.Signatures {
err = sig.Serialize(w)
if err != nil {
return err
}
}
}
for _, subkey := range e.Subkeys {
err = subkey.PublicKey.Serialize(w)
if err != nil {
return err
}
err = subkey.Sig.Serialize(w)
if err != nil {
return err
}
}
return nil
}
// SignIdentity adds a signature to e, from signer, attesting that identity is
// associated with e. The provided identity must already be an element of
// e.Identities and the private key of signer must have been decrypted if
// necessary.
func (e *Entity) SignIdentity(identity string, signer *Entity) os.Error {
if signer.PrivateKey == nil {
return error.InvalidArgumentError("signing Entity must have a private key")
}
if signer.PrivateKey.Encrypted {
return error.InvalidArgumentError("signing Entity's private key must be decrypted")
}
ident, ok := e.Identities[identity]
if !ok {
return error.InvalidArgumentError("given identity string not found in Entity")
}
sig := &packet.Signature{
SigType: packet.SigTypeGenericCert,
PubKeyAlgo: signer.PrivateKey.PubKeyAlgo,
Hash: crypto.SHA256,
CreationTime: uint32(time.Seconds()),
IssuerKeyId: &signer.PrivateKey.KeyId,
}
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
return err
}
ident.Signatures = append(ident.Signatures, sig)
return nil
}

View File

@ -5,6 +5,8 @@
package packet
import (
"big"
"crypto/openpgp/elgamal"
"crypto/openpgp/error"
"crypto/rand"
"crypto/rsa"
@ -14,14 +16,17 @@ import (
"strconv"
)
const encryptedKeyVersion = 3
// EncryptedKey represents a public-key encrypted session key. See RFC 4880,
// section 5.1.
type EncryptedKey struct {
KeyId uint64
Algo PublicKeyAlgorithm
Encrypted []byte
CipherFunc CipherFunction // only valid after a successful Decrypt
Key []byte // only valid after a successful Decrypt
encryptedMPI1, encryptedMPI2 []byte
}
func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
@ -30,37 +35,134 @@ func (e *EncryptedKey) parse(r io.Reader) (err os.Error) {
if err != nil {
return
}
if buf[0] != 3 {
if buf[0] != encryptedKeyVersion {
return error.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
}
e.KeyId = binary.BigEndian.Uint64(buf[1:9])
e.Algo = PublicKeyAlgorithm(buf[9])
if e.Algo == PubKeyAlgoRSA || e.Algo == PubKeyAlgoRSAEncryptOnly {
e.Encrypted, _, err = readMPI(r)
switch e.Algo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
e.encryptedMPI1, _, err = readMPI(r)
case PubKeyAlgoElGamal:
e.encryptedMPI1, _, err = readMPI(r)
if err != nil {
return
}
e.encryptedMPI2, _, err = readMPI(r)
}
_, err = consumeAll(r)
return
}
// DecryptRSA decrypts an RSA encrypted session key with the given private key.
func (e *EncryptedKey) DecryptRSA(priv *rsa.PrivateKey) (err os.Error) {
if e.Algo != PubKeyAlgoRSA && e.Algo != PubKeyAlgoRSAEncryptOnly {
return error.InvalidArgumentError("EncryptedKey not RSA encrypted")
func checksumKeyMaterial(key []byte) uint16 {
var checksum uint16
for _, v := range key {
checksum += uint16(v)
}
b, err := rsa.DecryptPKCS1v15(rand.Reader, priv, e.Encrypted)
return checksum
}
// Decrypt decrypts an encrypted session key with the given private key. The
// private key must have been decrypted first.
func (e *EncryptedKey) Decrypt(priv *PrivateKey) os.Error {
var err os.Error
var b []byte
// TODO(agl): use session key decryption routines here to avoid
// padding oracle attacks.
switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
b, err = rsa.DecryptPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), e.encryptedMPI1)
case PubKeyAlgoElGamal:
c1 := new(big.Int).SetBytes(e.encryptedMPI1)
c2 := new(big.Int).SetBytes(e.encryptedMPI2)
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
default:
err = error.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
}
if err != nil {
return
return err
}
e.CipherFunc = CipherFunction(b[0])
e.Key = b[1 : len(b)-2]
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
var checksum uint16
for _, v := range e.Key {
checksum += uint16(v)
}
checksum := checksumKeyMaterial(e.Key)
if checksum != expectedChecksum {
return error.StructuralError("EncryptedKey checksum incorrect")
}
return
return nil
}
// SerializeEncryptedKey serializes an encrypted key packet to w that contains
// key, encrypted to pub.
func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFunc CipherFunction, key []byte) os.Error {
var buf [10]byte
buf[0] = encryptedKeyVersion
binary.BigEndian.PutUint64(buf[1:9], pub.KeyId)
buf[9] = byte(pub.PubKeyAlgo)
keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */ )
keyBlock[0] = byte(cipherFunc)
copy(keyBlock[1:], key)
checksum := checksumKeyMaterial(key)
keyBlock[1+len(key)] = byte(checksum >> 8)
keyBlock[1+len(key)+1] = byte(checksum)
switch pub.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
return serializeEncryptedKeyRSA(w, rand, buf, pub.PublicKey.(*rsa.PublicKey), keyBlock)
case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
return error.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
return error.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) os.Error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil {
return error.InvalidArgumentError("RSA encryption failed: " + err.String())
}
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
return writeMPI(w, 8*uint16(len(cipherText)), cipherText)
}
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) os.Error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil {
return error.InvalidArgumentError("ElGamal encryption failed: " + err.String())
}
packetLen := 10 /* header length */
packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8
packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
err = writeBig(w, c1)
if err != nil {
return err
}
return writeBig(w, c2)
}

View File

@ -6,6 +6,8 @@ package packet
import (
"big"
"bytes"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
@ -19,7 +21,27 @@ func bigFromBase10(s string) *big.Int {
return b
}
func TestEncryptedKey(t *testing.T) {
var encryptedKeyPub = rsa.PublicKey{
E: 65537,
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
}
var encryptedKeyRSAPriv = &rsa.PrivateKey{
PublicKey: encryptedKeyPub,
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
}
var encryptedKeyPriv = &PrivateKey{
PublicKey: PublicKey{
PubKeyAlgo: PubKeyAlgoRSA,
},
PrivateKey: encryptedKeyRSAPriv,
}
func TestDecryptingEncryptedKey(t *testing.T) {
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
p, err := Read(readerFromHex(encryptedKeyHex))
if err != nil {
t.Errorf("error from Read: %s", err)
@ -36,19 +58,9 @@ func TestEncryptedKey(t *testing.T) {
return
}
pub := rsa.PublicKey{
E: 65537,
N: bigFromBase10("115804063926007623305902631768113868327816898845124614648849934718568541074358183759250136204762053879858102352159854352727097033322663029387610959884180306668628526686121021235757016368038585212410610742029286439607686208110250133174279811431933746643015923132833417396844716207301518956640020862630546868823"),
}
priv := &rsa.PrivateKey{
PublicKey: pub,
D: bigFromBase10("32355588668219869544751561565313228297765464314098552250409557267371233892496951383426602439009993875125222579159850054973310859166139474359774543943714622292329487391199285040721944491839695981199720170366763547754915493640685849961780092241140181198779299712578774460837139360803883139311171713302987058393"),
}
err = ek.DecryptRSA(priv)
err = ek.Decrypt(encryptedKeyPriv)
if err != nil {
t.Errorf("error from DecryptRSA: %s", err)
t.Errorf("error from Decrypt: %s", err)
return
}
@ -63,5 +75,52 @@ func TestEncryptedKey(t *testing.T) {
}
}
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
func TestEncryptingEncryptedKey(t *testing.T) {
key := []byte{1, 2, 3, 4}
const expectedKeyHex = "01020304"
const keyId = 42
pub := &PublicKey{
PublicKey: &encryptedKeyPub,
KeyId: keyId,
PubKeyAlgo: PubKeyAlgoRSAEncryptOnly,
}
buf := new(bytes.Buffer)
err := SerializeEncryptedKey(buf, rand.Reader, pub, CipherAES128, key)
if err != nil {
t.Errorf("error writing encrypted key packet: %s", err)
}
p, err := Read(buf)
if err != nil {
t.Errorf("error from Read: %s", err)
return
}
ek, ok := p.(*EncryptedKey)
if !ok {
t.Errorf("didn't parse an EncryptedKey, got %#v", p)
return
}
if ek.KeyId != keyId || ek.Algo != PubKeyAlgoRSAEncryptOnly {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
err = ek.Decrypt(encryptedKeyPriv)
if err != nil {
t.Errorf("error from Decrypt: %s", err)
return
}
if ek.CipherFunc != CipherAES128 {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
keyHex := fmt.Sprintf("%x", ek.Key)
if keyHex != expectedKeyHex {
t.Errorf("bad key, got %s want %x", keyHex, expectedKeyHex)
}
}

View File

@ -51,3 +51,40 @@ func (l *LiteralData) parse(r io.Reader) (err os.Error) {
l.Body = r
return
}
// SerializeLiteral serializes a literal data packet to w and returns a
// WriteCloser to which the data itself can be written and which MUST be closed
// on completion. The fileName is truncated to 255 bytes.
func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err os.Error) {
var buf [4]byte
buf[0] = 't'
if isBinary {
buf[0] = 'b'
}
if len(fileName) > 255 {
fileName = fileName[:255]
}
buf[1] = byte(len(fileName))
inner, err := serializeStreamHeader(w, packetTypeLiteralData)
if err != nil {
return
}
_, err = inner.Write(buf[:2])
if err != nil {
return
}
_, err = inner.Write([]byte(fileName))
if err != nil {
return
}
binary.BigEndian.PutUint32(buf[:], time)
_, err = inner.Write(buf[:])
if err != nil {
return
}
plaintext = inner
return
}

View File

@ -24,6 +24,8 @@ type OnePassSignature struct {
IsLast bool
}
const onePassSignatureVersion = 3
func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
var buf [13]byte
@ -31,7 +33,7 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
if err != nil {
return
}
if buf[0] != 3 {
if buf[0] != onePassSignatureVersion {
err = error.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
}
@ -47,3 +49,26 @@ func (ops *OnePassSignature) parse(r io.Reader) (err os.Error) {
ops.IsLast = buf[12] != 0
return
}
// Serialize marshals the given OnePassSignature to w.
func (ops *OnePassSignature) Serialize(w io.Writer) os.Error {
var buf [13]byte
buf[0] = onePassSignatureVersion
buf[1] = uint8(ops.SigType)
var ok bool
buf[2], ok = s2k.HashToHashId(ops.Hash)
if !ok {
return error.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
}
buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
if ops.IsLast {
buf[12] = 1
}
if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil {
return err
}
_, err := w.Write(buf[:])
return err
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package packet implements parsing and serialisation of OpenPGP packets, as
// Package packet implements parsing and serialization of OpenPGP packets, as
// specified in RFC 4880.
package packet
@ -92,6 +92,46 @@ func (r *partialLengthReader) Read(p []byte) (n int, err os.Error) {
return
}
// partialLengthWriter writes a stream of data using OpenPGP partial lengths.
// See RFC 4880, section 4.2.2.4.
type partialLengthWriter struct {
w io.WriteCloser
lengthByte [1]byte
}
func (w *partialLengthWriter) Write(p []byte) (n int, err os.Error) {
for len(p) > 0 {
for power := uint(14); power < 32; power-- {
l := 1 << power
if len(p) >= l {
w.lengthByte[0] = 224 + uint8(power)
_, err = w.w.Write(w.lengthByte[:])
if err != nil {
return
}
var m int
m, err = w.w.Write(p[:l])
n += m
if err != nil {
return
}
p = p[l:]
break
}
}
}
return
}
func (w *partialLengthWriter) Close() os.Error {
w.lengthByte[0] = 0
_, err := w.w.Write(w.lengthByte[:])
if err != nil {
return err
}
return w.w.Close()
}
// A spanReader is an io.LimitReader, but it returns ErrUnexpectedEOF if the
// underlying Reader returns EOF before the limit has been reached.
type spanReader struct {
@ -195,6 +235,20 @@ func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
return
}
// serializeStreamHeader writes an OpenPGP packet header to w where the
// length of the packet is unknown. It returns a io.WriteCloser which can be
// used to write the contents of the packet. See RFC 4880, section 4.2.
func serializeStreamHeader(w io.WriteCloser, ptype packetType) (out io.WriteCloser, err os.Error) {
var buf [1]byte
buf[0] = 0x80 | 0x40 | byte(ptype)
_, err = w.Write(buf[:])
if err != nil {
return
}
out = &partialLengthWriter{w: w}
return
}
// Packet represents an OpenPGP packet. Users are expected to try casting
// instances of this interface to specific packet types.
type Packet interface {
@ -301,12 +355,12 @@ type SignatureType uint8
const (
SigTypeBinary SignatureType = 0
SigTypeText = 1
SigTypeGenericCert = 0x10
SigTypePersonaCert = 0x11
SigTypeCasualCert = 0x12
SigTypePositiveCert = 0x13
SigTypeSubkeyBinding = 0x18
SigTypeText = 1
SigTypeGenericCert = 0x10
SigTypePersonaCert = 0x11
SigTypeCasualCert = 0x12
SigTypePositiveCert = 0x13
SigTypeSubkeyBinding = 0x18
)
// PublicKeyAlgorithm represents the different public key system specified for
@ -318,23 +372,43 @@ const (
PubKeyAlgoRSA PublicKeyAlgorithm = 1
PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3
PubKeyAlgoElgamal PublicKeyAlgorithm = 16
PubKeyAlgoElGamal PublicKeyAlgorithm = 16
PubKeyAlgoDSA PublicKeyAlgorithm = 17
)
// CanEncrypt returns true if it's possible to encrypt a message to a public
// key of the given type.
func (pka PublicKeyAlgorithm) CanEncrypt() bool {
switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal:
return true
}
return false
}
// CanSign returns true if it's possible for a public key of the given type to
// sign a message.
func (pka PublicKeyAlgorithm) CanSign() bool {
switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
return true
}
return false
}
// CipherFunction represents the different block ciphers specified for OpenPGP. See
// http://www.iana.org/assignments/pgp-parameters/pgp-parameters.xhtml#pgp-parameters-13
type CipherFunction uint8
const (
CipherCAST5 = 3
CipherAES128 = 7
CipherAES192 = 8
CipherAES256 = 9
CipherCAST5 CipherFunction = 3
CipherAES128 CipherFunction = 7
CipherAES192 CipherFunction = 8
CipherAES256 CipherFunction = 9
)
// keySize returns the key size, in bytes, of cipher.
func (cipher CipherFunction) keySize() int {
// KeySize returns the key size, in bytes, of cipher.
func (cipher CipherFunction) KeySize() int {
switch cipher {
case CipherCAST5:
return cast5.KeySize
@ -386,6 +460,14 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
return
}
// mpiLength returns the length of the given *big.Int when serialized as an
// MPI.
func mpiLength(n *big.Int) (mpiLengthInBytes int) {
mpiLengthInBytes = 2 /* MPI length */
mpiLengthInBytes += (n.BitLen() + 7) / 8
return
}
// writeMPI serializes a big integer to w.
func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
_, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})

View File

@ -210,3 +210,47 @@ func TestSerializeHeader(t *testing.T) {
}
}
}
func TestPartialLengths(t *testing.T) {
buf := bytes.NewBuffer(nil)
w := new(partialLengthWriter)
w.w = noOpCloser{buf}
const maxChunkSize = 64
var b [maxChunkSize]byte
var n uint8
for l := 1; l <= maxChunkSize; l++ {
for i := 0; i < l; i++ {
b[i] = n
n++
}
m, err := w.Write(b[:l])
if m != l {
t.Errorf("short write got: %d want: %d", m, l)
}
if err != nil {
t.Errorf("error from write: %s", err)
}
}
w.Close()
want := (maxChunkSize * (maxChunkSize + 1)) / 2
copyBuf := bytes.NewBuffer(nil)
r := &partialLengthReader{buf, 0, true}
m, err := io.Copy(copyBuf, r)
if m != int64(want) {
t.Errorf("short copy got: %d want: %d", m, want)
}
if err != nil {
t.Errorf("error from copy: %s", err)
}
copyBytes := copyBuf.Bytes()
for i := 0; i < want; i++ {
if copyBytes[i] != uint8(i) {
t.Errorf("bad pattern in copy at %d", i)
break
}
}
}

View File

@ -9,6 +9,7 @@ import (
"bytes"
"crypto/cipher"
"crypto/dsa"
"crypto/openpgp/elgamal"
"crypto/openpgp/error"
"crypto/openpgp/s2k"
"crypto/rsa"
@ -32,6 +33,13 @@ type PrivateKey struct {
iv []byte
}
func NewRSAPrivateKey(currentTimeSecs uint32, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewRSAPublicKey(currentTimeSecs, &priv.PublicKey, isSubkey)
pk.PrivateKey = priv
return pk
}
func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
err = (&pk.PublicKey).parse(r)
if err != nil {
@ -91,13 +99,90 @@ func (pk *PrivateKey) parse(r io.Reader) (err os.Error) {
return
}
func mod64kHash(d []byte) uint16 {
h := uint16(0)
for i := 0; i < len(d); i += 2 {
v := uint16(d[i]) << 8
if i+1 < len(d) {
v += uint16(d[i+1])
}
h += v
}
return h
}
func (pk *PrivateKey) Serialize(w io.Writer) (err os.Error) {
// TODO(agl): support encrypted private keys
buf := bytes.NewBuffer(nil)
err = pk.PublicKey.serializeWithoutHeaders(buf)
if err != nil {
return
}
buf.WriteByte(0 /* no encryption */ )
privateKeyBuf := bytes.NewBuffer(nil)
switch priv := pk.PrivateKey.(type) {
case *rsa.PrivateKey:
err = serializeRSAPrivateKey(privateKeyBuf, priv)
default:
err = error.InvalidArgumentError("non-RSA private key")
}
if err != nil {
return
}
ptype := packetTypePrivateKey
contents := buf.Bytes()
privateKeyBytes := privateKeyBuf.Bytes()
if pk.IsSubkey {
ptype = packetTypePrivateSubkey
}
err = serializeHeader(w, ptype, len(contents)+len(privateKeyBytes)+2)
if err != nil {
return
}
_, err = w.Write(contents)
if err != nil {
return
}
_, err = w.Write(privateKeyBytes)
if err != nil {
return
}
checksum := mod64kHash(privateKeyBytes)
var checksumBytes [2]byte
checksumBytes[0] = byte(checksum >> 8)
checksumBytes[1] = byte(checksum)
_, err = w.Write(checksumBytes[:])
return
}
func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) os.Error {
err := writeBig(w, priv.D)
if err != nil {
return err
}
err = writeBig(w, priv.Primes[1])
if err != nil {
return err
}
err = writeBig(w, priv.Primes[0])
if err != nil {
return err
}
return writeBig(w, priv.Precomputed.Qinv)
}
// Decrypt decrypts an encrypted private key using a passphrase.
func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
if !pk.Encrypted {
return nil
}
key := make([]byte, pk.cipher.keySize())
key := make([]byte, pk.cipher.KeySize())
pk.s2k(key, passphrase)
block := pk.cipher.new(key)
cfb := cipher.NewCFBDecrypter(block, pk.iv)
@ -140,6 +225,8 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
return pk.parseRSAPrivateKey(data)
case PubKeyAlgoDSA:
return pk.parseDSAPrivateKey(data)
case PubKeyAlgoElGamal:
return pk.parseElGamalPrivateKey(data)
}
panic("impossible")
}
@ -193,3 +280,22 @@ func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
return nil
}
func (pk *PrivateKey) parseElGamalPrivateKey(data []byte) (err os.Error) {
pub := pk.PublicKey.PublicKey.(*elgamal.PublicKey)
priv := new(elgamal.PrivateKey)
priv.PublicKey = *pub
buf := bytes.NewBuffer(data)
x, _, err := readMPI(buf)
if err != nil {
return
}
priv.X = new(big.Int).SetBytes(x)
pk.PrivateKey = priv
pk.Encrypted = false
pk.encryptedData = nil
return nil
}

View File

@ -8,30 +8,50 @@ import (
"testing"
)
var privateKeyTests = []struct {
privateKeyHex string
creationTime uint32
}{
{
privKeyRSAHex,
0x4cc349a8,
},
{
privKeyElGamalHex,
0x4df9ee1a,
},
}
func TestPrivateKeyRead(t *testing.T) {
packet, err := Read(readerFromHex(privKeyHex))
if err != nil {
t.Error(err)
return
}
for i, test := range privateKeyTests {
packet, err := Read(readerFromHex(test.privateKeyHex))
if err != nil {
t.Errorf("#%d: failed to parse: %s", i, err)
continue
}
privKey := packet.(*PrivateKey)
privKey := packet.(*PrivateKey)
if !privKey.Encrypted {
t.Error("private key isn't encrypted")
return
}
if !privKey.Encrypted {
t.Errorf("#%d: private key isn't encrypted", i)
continue
}
err = privKey.Decrypt([]byte("testing"))
if err != nil {
t.Error(err)
return
}
err = privKey.Decrypt([]byte("testing"))
if err != nil {
t.Errorf("#%d: failed to decrypt: %s", i, err)
continue
}
if privKey.CreationTime != 0x4cc349a8 || privKey.Encrypted {
t.Errorf("failed to parse, got: %#v", privKey)
if privKey.CreationTime != test.creationTime || privKey.Encrypted {
t.Errorf("#%d: bad result, got: %#v", i, privKey)
}
}
}
// Generated with `gpg --export-secret-keys "Test Key 2"`
const privKeyHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec"
const privKeyRSAHex = "9501fe044cc349a8010400b70ca0010e98c090008d45d1ee8f9113bd5861fd57b88bacb7c68658747663f1e1a3b5a98f32fda6472373c024b97359cd2efc88ff60f77751adfbf6af5e615e6a1408cfad8bf0cea30b0d5f53aa27ad59089ba9b15b7ebc2777a25d7b436144027e3bcd203909f147d0e332b240cf63d3395f5dfe0df0a6c04e8655af7eacdf0011010001fe0303024a252e7d475fd445607de39a265472aa74a9320ba2dac395faa687e9e0336aeb7e9a7397e511b5afd9dc84557c80ac0f3d4d7bfec5ae16f20d41c8c84a04552a33870b930420e230e179564f6d19bb153145e76c33ae993886c388832b0fa042ddda7f133924f3854481533e0ede31d51278c0519b29abc3bf53da673e13e3e1214b52413d179d7f66deee35cac8eacb060f78379d70ef4af8607e68131ff529439668fc39c9ce6dfef8a5ac234d234802cbfb749a26107db26406213ae5c06d4673253a3cbee1fcbae58d6ab77e38d6e2c0e7c6317c48e054edadb5a40d0d48acb44643d998139a8a66bb820be1f3f80185bc777d14b5954b60effe2448a036d565c6bc0b915fcea518acdd20ab07bc1529f561c58cd044f723109b93f6fd99f876ff891d64306b5d08f48bab59f38695e9109c4dec34013ba3153488ce070268381ba923ee1eb77125b36afcb4347ec3478c8f2735b06ef17351d872e577fa95d0c397c88c71b59629a36aec"
// Generated by `gpg --export-secret-keys` followed by a manual extraction of
// the ElGamal subkey from the packets.
const privKeyElGamalHex = "9d0157044df9ee1a100400eb8e136a58ec39b582629cdadf830bc64e0a94ed8103ca8bb247b27b11b46d1d25297ef4bcc3071785ba0c0bedfe89eabc5287fcc0edf81ab5896c1c8e4b20d27d79813c7aede75320b33eaeeaa586edc00fd1036c10133e6ba0ff277245d0d59d04b2b3421b7244aca5f4a8d870c6f1c1fbff9e1c26699a860b9504f35ca1d700030503fd1ededd3b840795be6d9ccbe3c51ee42e2f39233c432b831ddd9c4e72b7025a819317e47bf94f9ee316d7273b05d5fcf2999c3a681f519b1234bbfa6d359b4752bd9c3f77d6b6456cde152464763414ca130f4e91d91041432f90620fec0e6d6b5116076c2985d5aeaae13be492b9b329efcaf7ee25120159a0a30cd976b42d7afe030302dae7eb80db744d4960c4df930d57e87fe81412eaace9f900e6c839817a614ddb75ba6603b9417c33ea7b6c93967dfa2bcff3fa3c74a5ce2c962db65b03aece14c96cbd0038fc"

View File

@ -7,6 +7,7 @@ package packet
import (
"big"
"crypto/dsa"
"crypto/openpgp/elgamal"
"crypto/openpgp/error"
"crypto/rsa"
"crypto/sha1"
@ -30,6 +31,28 @@ type PublicKey struct {
n, e, p, q, g, y parsedMPI
}
func fromBig(n *big.Int) parsedMPI {
return parsedMPI{
bytes: n.Bytes(),
bitLength: uint16(n.BitLen()),
}
}
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewRSAPublicKey(creationTimeSecs uint32, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
pk := &PublicKey{
CreationTime: creationTimeSecs,
PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub,
IsSubkey: isSubkey,
n: fromBig(pub.N),
e: fromBig(big.NewInt(int64(pub.E))),
}
pk.setFingerPrintAndKeyId()
return pk
}
func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
// RFC 4880, section 5.5.2
var buf [6]byte
@ -47,6 +70,8 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
err = pk.parseRSA(r)
case PubKeyAlgoDSA:
err = pk.parseDSA(r)
case PubKeyAlgoElGamal:
err = pk.parseElGamal(r)
default:
err = error.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
}
@ -54,14 +79,17 @@ func (pk *PublicKey) parse(r io.Reader) (err os.Error) {
return
}
pk.setFingerPrintAndKeyId()
return
}
func (pk *PublicKey) setFingerPrintAndKeyId() {
// RFC 4880, section 12.2
fingerPrint := sha1.New()
pk.SerializeSignaturePrefix(fingerPrint)
pk.Serialize(fingerPrint)
pk.serializeWithoutHeaders(fingerPrint)
copy(pk.Fingerprint[:], fingerPrint.Sum())
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
return
}
// parseRSA parses RSA public key material from the given Reader. See RFC 4880,
@ -92,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err os.Error) {
return
}
// parseRSA parses DSA public key material from the given Reader. See RFC 4880,
// parseDSA parses DSA public key material from the given Reader. See RFC 4880,
// section 5.5.2.
func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
pk.p.bytes, pk.p.bitLength, err = readMPI(r)
@ -121,6 +149,30 @@ func (pk *PublicKey) parseDSA(r io.Reader) (err os.Error) {
return
}
// parseElGamal parses ElGamal public key material from the given Reader. See
// RFC 4880, section 5.5.2.
func (pk *PublicKey) parseElGamal(r io.Reader) (err os.Error) {
pk.p.bytes, pk.p.bitLength, err = readMPI(r)
if err != nil {
return
}
pk.g.bytes, pk.g.bitLength, err = readMPI(r)
if err != nil {
return
}
pk.y.bytes, pk.y.bitLength, err = readMPI(r)
if err != nil {
return
}
elgamal := new(elgamal.PublicKey)
elgamal.P = new(big.Int).SetBytes(pk.p.bytes)
elgamal.G = new(big.Int).SetBytes(pk.g.bytes)
elgamal.Y = new(big.Int).SetBytes(pk.y.bytes)
pk.PublicKey = elgamal
return
}
// SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
// The prefix is used when calculating a signature over this public key. See
// RFC 4880, section 5.2.4.
@ -135,6 +187,10 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
pLength += 2 + uint16(len(pk.q.bytes))
pLength += 2 + uint16(len(pk.g.bytes))
pLength += 2 + uint16(len(pk.y.bytes))
case PubKeyAlgoElGamal:
pLength += 2 + uint16(len(pk.p.bytes))
pLength += 2 + uint16(len(pk.g.bytes))
pLength += 2 + uint16(len(pk.y.bytes))
default:
panic("unknown public key algorithm")
}
@ -143,9 +199,40 @@ func (pk *PublicKey) SerializeSignaturePrefix(h hash.Hash) {
return
}
// Serialize marshals the PublicKey to w in the form of an OpenPGP public key
// packet, not including the packet header.
func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
length := 6 // 6 byte header
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
length += 2 + len(pk.n.bytes)
length += 2 + len(pk.e.bytes)
case PubKeyAlgoDSA:
length += 2 + len(pk.p.bytes)
length += 2 + len(pk.q.bytes)
length += 2 + len(pk.g.bytes)
length += 2 + len(pk.y.bytes)
case PubKeyAlgoElGamal:
length += 2 + len(pk.p.bytes)
length += 2 + len(pk.g.bytes)
length += 2 + len(pk.y.bytes)
default:
panic("unknown public key algorithm")
}
packetType := packetTypePublicKey
if pk.IsSubkey {
packetType = packetTypePublicSubkey
}
err = serializeHeader(w, packetType, length)
if err != nil {
return
}
return pk.serializeWithoutHeaders(w)
}
// serializeWithoutHeaders marshals the PublicKey to w in the form of an
// OpenPGP public key packet, not including the packet header.
func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err os.Error) {
var buf [6]byte
buf[0] = 4
buf[1] = byte(pk.CreationTime >> 24)
@ -164,13 +251,15 @@ func (pk *PublicKey) Serialize(w io.Writer) (err os.Error) {
return writeMPIs(w, pk.n, pk.e)
case PubKeyAlgoDSA:
return writeMPIs(w, pk.p, pk.q, pk.g, pk.y)
case PubKeyAlgoElGamal:
return writeMPIs(w, pk.p, pk.g, pk.y)
}
return error.InvalidArgumentError("bad public-key algorithm")
}
// CanSign returns true iff this public key can generate signatures
func (pk *PublicKey) CanSign() bool {
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElgamal
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal
}
// VerifySignature returns nil iff sig is a valid signature, made by this
@ -194,14 +283,14 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
if err != nil {
return error.SignatureError("RSA verification failure")
}
return nil
case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) {
if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
return error.SignatureError("DSA verification failure")
}
return nil
@ -211,34 +300,43 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
panic("unreachable")
}
// VerifyKeySignature returns nil iff sig is a valid signature, make by this
// public key, of the public key in signed.
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) {
h := sig.Hash.New()
// keySignatureHash returns a Hash of the message that needs to be signed for
// pk to assert a subkey relationship to signed.
func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
h = sig.Hash.New()
if h == nil {
return error.UnsupportedError("hash function")
return nil, error.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4
pk.SerializeSignaturePrefix(h)
pk.Serialize(h)
pk.serializeWithoutHeaders(h)
signed.SerializeSignaturePrefix(h)
signed.Serialize(h)
signed.serializeWithoutHeaders(h)
return
}
// VerifyKeySignature returns nil iff sig is a valid signature, made by this
// public key, of signed.
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err os.Error) {
h, err := keySignatureHash(pk, signed, sig)
if err != nil {
return err
}
return pk.VerifySignature(h, sig)
}
// VerifyUserIdSignature returns nil iff sig is a valid signature, make by this
// public key, of the given user id.
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) {
h := sig.Hash.New()
// userIdSignatureHash returns a Hash of the message that needs to be signed
// to assert that pk is a valid key for id.
func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err os.Error) {
h = sig.Hash.New()
if h == nil {
return error.UnsupportedError("hash function")
return nil, error.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4
pk.SerializeSignaturePrefix(h)
pk.Serialize(h)
pk.serializeWithoutHeaders(h)
var buf [5]byte
buf[0] = 0xb4
@ -249,6 +347,16 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
h.Write(buf[:])
h.Write([]byte(id))
return
}
// VerifyUserIdSignature returns nil iff sig is a valid signature, made by this
// public key, of id.
func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Error) {
h, err := userIdSignatureHash(id, pk, sig)
if err != nil {
return err
}
return pk.VerifySignature(h, sig)
}
@ -272,7 +380,7 @@ type parsedMPI struct {
bitLength uint16
}
// writeMPIs is a utility function for serialising several big integers to the
// writeMPIs is a utility function for serializing several big integers to the
// given Writer.
func writeMPIs(w io.Writer, mpis ...parsedMPI) (err os.Error) {
for _, mpi := range mpis {

View File

@ -28,12 +28,12 @@ func TestPublicKeyRead(t *testing.T) {
packet, err := Read(readerFromHex(test.hexData))
if err != nil {
t.Errorf("#%d: Read error: %s", i, err)
return
continue
}
pk, ok := packet.(*PublicKey)
if !ok {
t.Errorf("#%d: failed to parse, got: %#v", i, packet)
return
continue
}
if pk.PubKeyAlgo != test.pubKeyAlgo {
t.Errorf("#%d: bad public key algorithm got:%x want:%x", i, pk.PubKeyAlgo, test.pubKeyAlgo)
@ -57,6 +57,38 @@ func TestPublicKeyRead(t *testing.T) {
}
}
func TestPublicKeySerialize(t *testing.T) {
for i, test := range pubKeyTests {
packet, err := Read(readerFromHex(test.hexData))
if err != nil {
t.Errorf("#%d: Read error: %s", i, err)
continue
}
pk, ok := packet.(*PublicKey)
if !ok {
t.Errorf("#%d: failed to parse, got: %#v", i, packet)
continue
}
serializeBuf := bytes.NewBuffer(nil)
err = pk.Serialize(serializeBuf)
if err != nil {
t.Errorf("#%d: failed to serialize: %s", i, err)
continue
}
packet, err = Read(serializeBuf)
if err != nil {
t.Errorf("#%d: Read error (from serialized data): %s", i, err)
continue
}
pk, ok = packet.(*PublicKey)
if !ok {
t.Errorf("#%d: failed to parse serialized data, got: %#v", i, packet)
continue
}
}
}
const rsaFingerprintHex = "5fb74b1d03b1e3cb31bc2f8aa34d7e18c20c31bb"
const rsaPkDataHex = "988d044d3c5c10010400b1d13382944bd5aba23a4312968b5095d14f947f600eb478e14a6fcb16b0e0cac764884909c020bc495cfcc39a935387c661507bdb236a0612fb582cac3af9b29cc2c8c70090616c41b662f4da4c1201e195472eb7f4ae1ccbcbf9940fe21d985e379a5563dde5b9a23d35f1cfaa5790da3b79db26f23695107bfaca8e7b5bcd0011010001"

View File

@ -5,7 +5,6 @@
package packet
import (
"big"
"crypto"
"crypto/dsa"
"crypto/openpgp/error"
@ -32,8 +31,11 @@ type Signature struct {
HashTag [2]byte
CreationTime uint32 // Unix epoch time
RSASignature []byte
DSASigR, DSASigS *big.Int
RSASignature parsedMPI
DSASigR, DSASigS parsedMPI
// rawSubpackets contains the unparsed subpackets, in order.
rawSubpackets []outputSubpacket
// The following are optional so are nil when not included in the
// signature.
@ -128,14 +130,11 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature, _, err = readMPI(r)
sig.RSASignature.bytes, sig.RSASignature.bitLength, err = readMPI(r)
case PubKeyAlgoDSA:
var rBytes, sBytes []byte
rBytes, _, err = readMPI(r)
sig.DSASigR = new(big.Int).SetBytes(rBytes)
sig.DSASigR.bytes, sig.DSASigR.bitLength, err = readMPI(r)
if err == nil {
sBytes, _, err = readMPI(r)
sig.DSASigS = new(big.Int).SetBytes(sBytes)
sig.DSASigS.bytes, sig.DSASigS.bitLength, err = readMPI(r)
}
default:
panic("unreachable")
@ -177,7 +176,11 @@ const (
// parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1.
func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err os.Error) {
// RFC 4880, section 5.2.3.1
var length uint32
var (
length uint32
packetType signatureSubpacketType
isCritical bool
)
switch {
case subpacket[0] < 192:
length = uint32(subpacket[0])
@ -207,10 +210,11 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
err = error.StructuralError("zero length signature subpacket")
return
}
packetType := subpacket[0] & 0x7f
isCritial := subpacket[0]&0x80 == 0x80
packetType = signatureSubpacketType(subpacket[0] & 0x7f)
isCritical = subpacket[0]&0x80 == 0x80
subpacket = subpacket[1:]
switch signatureSubpacketType(packetType) {
sig.rawSubpackets = append(sig.rawSubpackets, outputSubpacket{isHashed, packetType, isCritical, subpacket})
switch packetType {
case creationTimeSubpacket:
if !isHashed {
err = error.StructuralError("signature creation time in non-hashed area")
@ -309,7 +313,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
}
default:
if isCritial {
if isCritical {
err = error.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))
return
}
@ -381,7 +385,6 @@ func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
// buildHashSuffix constructs the HashSuffix member of sig in preparation for signing.
func (sig *Signature) buildHashSuffix() (err os.Error) {
sig.outSubpackets = sig.buildSubpackets()
hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true)
var ok bool
@ -393,7 +396,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
sig.HashSuffix[3], ok = s2k.HashToHashId(sig.Hash)
if !ok {
sig.HashSuffix = nil
return error.InvalidArgumentError("hash cannot be repesented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
return error.InvalidArgumentError("hash cannot be represented in OpenPGP: " + strconv.Itoa(int(sig.Hash)))
}
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
sig.HashSuffix[5] = byte(hashedSubpacketsLen)
@ -420,45 +423,72 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error)
return
}
// SignRSA signs a message with an RSA private key. The hash, h, must contain
// Sign signs a message with a private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey) (err os.Error) {
sig.outSubpackets = sig.buildSubpackets()
digest, err := sig.signPrepareHash(h)
if err != nil {
return
}
sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature.bytes, err = rsa.SignPKCS1v15(rand.Reader, priv.PrivateKey.(*rsa.PrivateKey), sig.Hash, digest)
sig.RSASignature.bitLength = uint16(8 * len(sig.RSASignature.bytes))
case PubKeyAlgoDSA:
r, s, err := dsa.Sign(rand.Reader, priv.PrivateKey.(*dsa.PrivateKey), digest)
if err == nil {
sig.DSASigR.bytes = r.Bytes()
sig.DSASigR.bitLength = uint16(8 * len(sig.DSASigR.bytes))
sig.DSASigS.bytes = s.Bytes()
sig.DSASigS.bitLength = uint16(8 * len(sig.DSASigS.bytes))
}
default:
err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
}
return
}
// SignDSA signs a message with a DSA private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) {
digest, err := sig.signPrepareHash(h)
// SignUserId computes a signature from priv, asserting that pub is a valid
// key for the identity id. On success, the signature is stored in sig. Call
// Serialize to write it out.
func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey) os.Error {
h, err := userIdSignatureHash(id, pub, sig)
if err != nil {
return
return nil
}
sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest)
return
return sig.Sign(h, priv)
}
// SignKey computes a signature from priv, asserting that pub is a subkey. On
// success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey) os.Error {
h, err := keySignatureHash(&priv.PublicKey, pub, sig)
if err != nil {
return err
}
return sig.Sign(h, priv)
}
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if sig.RSASignature == nil && sig.DSASigR == nil {
if len(sig.outSubpackets) == 0 {
sig.outSubpackets = sig.rawSubpackets
}
if sig.RSASignature.bytes == nil && sig.DSASigR.bytes == nil {
return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
}
sigLength := 0
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sigLength = len(sig.RSASignature)
sigLength = 2 + len(sig.RSASignature.bytes)
case PubKeyAlgoDSA:
sigLength = 2 /* MPI length */
sigLength += (sig.DSASigR.BitLen() + 7) / 8
sigLength += 2 /* MPI length */
sigLength += (sig.DSASigS.BitLen() + 7) / 8
sigLength = 2 + len(sig.DSASigR.bytes)
sigLength += 2 + len(sig.DSASigS.bytes)
default:
panic("impossible")
}
@ -466,7 +496,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
length := len(sig.HashSuffix) - 6 /* trailer not included */ +
2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength
2 /* hash tag */ + sigLength
err = serializeHeader(w, packetTypeSignature, length)
if err != nil {
return
@ -493,12 +523,9 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature)
err = writeMPIs(w, sig.RSASignature)
case PubKeyAlgoDSA:
err = writeBig(w, sig.DSASigR)
if err == nil {
err = writeBig(w, sig.DSASigS)
}
err = writeMPIs(w, sig.DSASigR, sig.DSASigS)
default:
panic("impossible")
}
@ -509,6 +536,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
type outputSubpacket struct {
hashed bool // true if this subpacket is in the hashed area.
subpacketType signatureSubpacketType
isCritical bool
contents []byte
}
@ -518,12 +546,12 @@ func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
creationTime[1] = byte(sig.CreationTime >> 16)
creationTime[2] = byte(sig.CreationTime >> 8)
creationTime[3] = byte(sig.CreationTime)
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, creationTime})
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime})
if sig.IssuerKeyId != nil {
keyId := make([]byte, 8)
binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId)
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, keyId})
subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId})
}
return

View File

@ -12,9 +12,7 @@ import (
)
func TestSignatureRead(t *testing.T) {
signatureData, _ := hex.DecodeString(signatureDataHex)
buf := bytes.NewBuffer(signatureData)
packet, err := Read(buf)
packet, err := Read(readerFromHex(signatureDataHex))
if err != nil {
t.Error(err)
return
@ -25,4 +23,20 @@ func TestSignatureRead(t *testing.T) {
}
}
const signatureDataHex = "89011c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e"
func TestSignatureReserialize(t *testing.T) {
packet, _ := Read(readerFromHex(signatureDataHex))
sig := packet.(*Signature)
out := new(bytes.Buffer)
err := sig.Serialize(out)
if err != nil {
t.Errorf("error reserializing: %s", err)
return
}
expected, _ := hex.DecodeString(signatureDataHex)
if !bytes.Equal(expected, out.Bytes()) {
t.Errorf("output doesn't match input (got vs expected):\n%s\n%s", hex.Dump(out.Bytes()), hex.Dump(expected))
}
}
const signatureDataHex = "c2c05c04000102000605024cb45112000a0910ab105c91af38fb158f8d07ff5596ea368c5efe015bed6e78348c0f033c931d5f2ce5db54ce7f2a7e4b4ad64db758d65a7a71773edeab7ba2a9e0908e6a94a1175edd86c1d843279f045b021a6971a72702fcbd650efc393c5474d5b59a15f96d2eaad4c4c426797e0dcca2803ef41c6ff234d403eec38f31d610c344c06f2401c262f0993b2e66cad8a81ebc4322c723e0d4ba09fe917e8777658307ad8329adacba821420741009dfe87f007759f0982275d028a392c6ed983a0d846f890b36148c7358bdb8a516007fac760261ecd06076813831a36d0459075d1befa245ae7f7fb103d92ca759e9498fe60ef8078a39a3beda510deea251ea9f0a7f0df6ef42060f20780360686f3e400e"

View File

@ -5,6 +5,7 @@
package packet
import (
"bytes"
"crypto/cipher"
"crypto/openpgp/error"
"crypto/openpgp/s2k"
@ -27,6 +28,8 @@ type SymmetricKeyEncrypted struct {
encryptedKey []byte
}
const symmetricKeyEncryptedVersion = 4
func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
// RFC 4880, section 5.3.
var buf [2]byte
@ -34,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err os.Error) {
if err != nil {
return
}
if buf[0] != 4 {
if buf[0] != symmetricKeyEncryptedVersion {
return error.UnsupportedError("SymmetricKeyEncrypted version")
}
ske.CipherFunc = CipherFunction(buf[1])
if ske.CipherFunc.keySize() == 0 {
if ske.CipherFunc.KeySize() == 0 {
return error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
}
@ -75,7 +78,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
return nil
}
key := make([]byte, ske.CipherFunc.keySize())
key := make([]byte, ske.CipherFunc.KeySize())
ske.s2k(key, passphrase)
if len(ske.encryptedKey) == 0 {
@ -100,3 +103,60 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) os.Error {
ske.Encrypted = false
return nil
}
// SerializeSymmetricKeyEncrypted serializes a symmetric key packet to w. The
// packet contains a random session key, encrypted by a key derived from the
// given passphrase. The session key is returned and must be passed to
// SerializeSymmetricallyEncrypted.
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err os.Error) {
keySize := cipherFunc.KeySize()
if keySize == 0 {
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
}
s2kBuf := new(bytes.Buffer)
keyEncryptingKey := make([]byte, keySize)
// s2k.Serialize salts and stretches the passphrase, and writes the
// resulting key to keyEncryptingKey and the s2k descriptor to s2kBuf.
err = s2k.Serialize(s2kBuf, keyEncryptingKey, rand, passphrase)
if err != nil {
return
}
s2kBytes := s2kBuf.Bytes()
packetLength := 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize
err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength)
if err != nil {
return
}
var buf [2]byte
buf[0] = symmetricKeyEncryptedVersion
buf[1] = byte(cipherFunc)
_, err = w.Write(buf[:])
if err != nil {
return
}
_, err = w.Write(s2kBytes)
if err != nil {
return
}
sessionKey := make([]byte, keySize)
_, err = io.ReadFull(rand, sessionKey)
if err != nil {
return
}
iv := make([]byte, cipherFunc.blockSize())
c := cipher.NewCFBEncrypter(cipherFunc.new(keyEncryptingKey), iv)
encryptedCipherAndKey := make([]byte, keySize+1)
c.XORKeyStream(encryptedCipherAndKey, buf[1:])
c.XORKeyStream(encryptedCipherAndKey[1:], sessionKey)
_, err = w.Write(encryptedCipherAndKey)
if err != nil {
return
}
key = sessionKey
return
}

View File

@ -6,6 +6,7 @@ package packet
import (
"bytes"
"crypto/rand"
"encoding/hex"
"io/ioutil"
"os"
@ -60,3 +61,41 @@ func TestSymmetricKeyEncrypted(t *testing.T) {
const symmetricallyEncryptedHex = "8c0d04030302371a0b38d884f02060c91cf97c9973b8e58e028e9501708ccfe618fb92afef7fa2d80ddadd93cf"
const symmetricallyEncryptedContentsHex = "cb1062004d14c4df636f6e74656e74732e0a"
func TestSerializeSymmetricKeyEncrypted(t *testing.T) {
buf := bytes.NewBuffer(nil)
passphrase := []byte("testing")
cipherFunc := CipherAES128
key, err := SerializeSymmetricKeyEncrypted(buf, rand.Reader, passphrase, cipherFunc)
if err != nil {
t.Errorf("failed to serialize: %s", err)
return
}
p, err := Read(buf)
if err != nil {
t.Errorf("failed to reparse: %s", err)
return
}
ske, ok := p.(*SymmetricKeyEncrypted)
if !ok {
t.Errorf("parsed a different packet type: %#v", p)
return
}
if !ske.Encrypted {
t.Errorf("SKE not encrypted but should be")
}
if ske.CipherFunc != cipherFunc {
t.Errorf("SKE cipher function is %d (expected %d)", ske.CipherFunc, cipherFunc)
}
err = ske.Decrypt(passphrase)
if err != nil {
t.Errorf("failed to decrypt reparsed SKE: %s", err)
return
}
if !bytes.Equal(key, ske.Key) {
t.Errorf("keys don't match after Decrpyt: %x (original) vs %x (parsed)", key, ske.Key)
}
}

View File

@ -7,6 +7,7 @@ package packet
import (
"crypto/cipher"
"crypto/openpgp/error"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"hash"
@ -24,6 +25,8 @@ type SymmetricallyEncrypted struct {
prefix []byte
}
const symmetricallyEncryptedVersion = 1
func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
if se.MDC {
// See RFC 4880, section 5.13.
@ -32,7 +35,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
if err != nil {
return err
}
if buf[0] != 1 {
if buf[0] != symmetricallyEncryptedVersion {
return error.UnsupportedError("unknown SymmetricallyEncrypted version")
}
}
@ -44,7 +47,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) os.Error {
// packet can be read. An incorrect key can, with high probability, be detected
// immediately and this will result in a KeyIncorrect error being returned.
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, os.Error) {
keySize := c.keySize()
keySize := c.KeySize()
if keySize == 0 {
return nil, error.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
}
@ -174,6 +177,9 @@ func (ser *seMDCReader) Read(buf []byte) (n int, err os.Error) {
return
}
// This is a new-format packet tag byte for a type 19 (MDC) packet.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
func (ser *seMDCReader) Close() os.Error {
if ser.error {
return error.SignatureError("error during reading")
@ -191,16 +197,95 @@ func (ser *seMDCReader) Close() os.Error {
}
}
// This is a new-format packet tag byte for a type 19 (MDC) packet.
const mdcPacketTagByte = byte(0x80) | 0x40 | 19
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return error.SignatureError("MDC packet not found")
}
ser.h.Write(ser.trailer[:2])
final := ser.h.Sum()
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) == 1 {
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error.SignatureError("hash mismatch")
}
return nil
}
// An seMDCWriter writes through to an io.WriteCloser while maintains a running
// hash of the data written. On close, it emits an MDC packet containing the
// running hash.
type seMDCWriter struct {
w io.WriteCloser
h hash.Hash
}
func (w *seMDCWriter) Write(buf []byte) (n int, err os.Error) {
w.h.Write(buf)
return w.w.Write(buf)
}
func (w *seMDCWriter) Close() (err os.Error) {
var buf [mdcTrailerSize]byte
buf[0] = mdcPacketTagByte
buf[1] = sha1.Size
w.h.Write(buf[:2])
digest := w.h.Sum()
copy(buf[2:], digest)
_, err = w.w.Write(buf[:])
if err != nil {
return
}
return w.w.Close()
}
// noOpCloser is like an ioutil.NopCloser, but for an io.Writer.
type noOpCloser struct {
w io.Writer
}
func (c noOpCloser) Write(data []byte) (n int, err os.Error) {
return c.w.Write(data)
}
func (c noOpCloser) Close() os.Error {
return nil
}
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err os.Error) {
if c.KeySize() != len(key) {
return nil, error.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
}
writeCloser := noOpCloser{w}
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
if err != nil {
return
}
_, err = ciphertext.Write([]byte{symmetricallyEncryptedVersion})
if err != nil {
return
}
block := c.new(key)
blockSize := block.BlockSize()
iv := make([]byte, blockSize)
_, err = rand.Reader.Read(iv)
if err != nil {
return
}
s, prefix := cipher.NewOCFBEncrypter(block, iv, cipher.OCFBNoResync)
_, err = ciphertext.Write(prefix)
if err != nil {
return
}
plaintext := cipher.StreamWriter{S: s, W: ciphertext}
h := sha1.New()
h.Write(iv)
h.Write(iv[blockSize-2:])
contents = &seMDCWriter{w: plaintext, h: h}
return
}

View File

@ -9,6 +9,7 @@ import (
"crypto/openpgp/error"
"crypto/sha1"
"encoding/hex"
"io"
"io/ioutil"
"os"
"testing"
@ -76,3 +77,48 @@ func testMDCReader(t *testing.T) {
}
const mdcPlaintextHex = "a302789c3b2d93c4e0eb9aba22283539b3203335af44a134afb800c849cb4c4de10200aff40b45d31432c80cb384299a0655966d6939dfdeed1dddf980"
func TestSerialize(t *testing.T) {
buf := bytes.NewBuffer(nil)
c := CipherAES128
key := make([]byte, c.KeySize())
w, err := SerializeSymmetricallyEncrypted(buf, c, key)
if err != nil {
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
return
}
contents := []byte("hello world\n")
w.Write(contents)
w.Close()
p, err := Read(buf)
if err != nil {
t.Errorf("error from Read: %s", err)
return
}
se, ok := p.(*SymmetricallyEncrypted)
if !ok {
t.Errorf("didn't read a *SymmetricallyEncrypted")
return
}
r, err := se.Decrypt(c, key)
if err != nil {
t.Errorf("error from Decrypt: %s", err)
return
}
contentsCopy := bytes.NewBuffer(nil)
_, err = io.Copy(contentsCopy, r)
if err != nil {
t.Errorf("error from io.Copy: %s", err)
return
}
if !bytes.Equal(contentsCopy.Bytes(), contents) {
t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
}
}

Some files were not shown because too many files have changed in this diff Show More