Commit b2183701 authored by Robert Griesemer's avatar Robert Griesemer

big: implemented Karatsuba multiplication

Plus:
- calibration "test" - include in tests with gotest -calibrate
- basic Mul benchmark
- extra multiplication tests
- various cleanups

This change improves multiplication speed of numbers >= 30 words
in length (current threshold; found empirically with calibrate):

The multiplication benchmark (multiplication of a variety of long numbers)
improves by ~35%, individual multiplies can be significantly faster.

gotest -benchmarks=Mul
big.BenchmarkMul	     500	   6829290 ns/op (w/ Karatsuba)
big.BenchmarkMul	     100	  10600760 ns/op

There's no impact on pidigits for -n=10000 or -n=20000
because the operands are too small.

R=rsc
CC=golang-dev
https://golang.org/cl/1004042
parent dc606a20
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file computes the Karatsuba threshold as a "test".
// Usage: gotest -calibrate
package big
import (
"flag"
"fmt"
"testing"
"time"
"unsafe" // for Sizeof
)
var calibrate = flag.Bool("calibrate", false, "run calibration test")
// makeNumber returns a new Int built from n words' worth of 0xff bytes,
// i.e. the n-word number 0xffff...ffff.
func makeNumber(n int) *Int {
	var w Word // only needed to obtain the size of a Word in bytes
	buf := make([]byte, n*unsafe.Sizeof(w))
	for i := 0; i < len(buf); i++ {
		buf[i] = 0xff
	}
	z := new(Int)
	z.SetBytes(buf)
	return z
}
// measure calls f a fixed number of times (N) and returns the average
// elapsed time of a single call in nanoseconds.
func measure(f func()) int64 {
	const N = 100
	t0 := time.Nanoseconds()
	for i := 0; i < N; i++ {
		f()
	}
	elapsed := time.Nanoseconds() - t0
	return elapsed / N
}
func computeThreshold(t *testing.T) int {
// use a mix of numbers as work load
x := make([]*Int, 20)
for i := range x {
x[i] = makeNumber(10 * (i + 1))
}
threshold := -1
for n := 8; threshold < 0 || n <= threshold+20; n += 2 {
// set work load
f := func() {
var t Int
for _, x := range x {
t.Mul(x, x)
}
}
karatsubaThreshold = 1e9 // disable karatsuba
t1 := measure(f)
karatsubaThreshold = n // enable karatsuba
t2 := measure(f)
c := '<'
mark := ""
if t1 > t2 {
c = '>'
if threshold < 0 {
threshold = n
mark = " *"
}
}
fmt.Printf("%4d: %8d %c %8d%s\n", n, t1, c, t2, mark)
}
return threshold
}
// TestCalibrate computes and prints the Karatsuba threshold.
// It does nothing unless the -calibrate flag is set.
func TestCalibrate(t *testing.T) {
	if !*calibrate {
		return
	}
	fmt.Printf("Computing Karatsuba threshold\n")
	threshold := computeThreshold(t)
	fmt.Printf("threshold = %d\n", threshold)
}
......@@ -230,7 +230,7 @@ Error:
// sets z to that value.
func (z *Int) SetBytes(b []byte) *Int {
s := int(_S)
z.abs = z.abs.make((len(b)+s-1)/s, false)
z.abs = z.abs.make((len(b) + s - 1) / s)
z.neg = false
j := 0
......@@ -386,7 +386,7 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
func (z *Int) Lsh(x *Int, n uint) *Int {
addedWords := int(n) / _W
// Don't assign z.abs yet, in case z == x
znew := z.abs.make(len(x.abs)+addedWords+1, false)
znew := z.abs.make(len(x.abs) + addedWords + 1)
z.neg = x.neg
znew[addedWords:].shiftLeft(x.abs, n%_W)
for i := range znew[0:addedWords] {
......@@ -401,7 +401,7 @@ func (z *Int) Lsh(x *Int, n uint) *Int {
func (z *Int) Rsh(x *Int, n uint) *Int {
removedWords := int(n) / _W
// Don't assign z.abs yet, in case z == x
znew := z.abs.make(len(x.abs)-removedWords, false)
znew := z.abs.make(len(x.abs) - removedWords)
z.neg = x.neg
znew.shiftRight(x.abs[removedWords:], n%_W)
z.abs = znew.norm()
......
......@@ -93,36 +93,55 @@ func TestProdZZ(t *testing.T) {
}
// facts maps selected values of n to the decimal representation of n!.
var facts = map[int]string{
0: "1",
1: "1",
2: "2",
10: "3628800",
20: "2432902008176640000",
100: "933262154439441526816992388562667004907159682643816214685929" +
"638952175999932299156089414639761565182862536979208272237582" +
"51185210916864000000000000000000000000",
}
// mulBytes returns x*y via grade school multiplication. Both inputs
// and the result are assumed to be in big-endian representation (to
// match the semantics of Int.Bytes and Int.SetBytes).
func mulBytes(x, y []byte) []byte {
z := make([]byte, len(x)+len(y))
// multiply
k0 := len(z) - 1
for j := len(y) - 1; j >= 0; j-- {
d := int(y[j])
if d != 0 {
k := k0
carry := 0
for i := len(x) - 1; i >= 0; i-- {
t := int(z[k]) + int(x[i])*d + carry
z[k], carry = byte(t), t>>8
k--
}
z[k] = byte(carry)
}
k0--
}
func fact(n int) *Int {
var z Int
z.New(1)
for i := 2; i <= n; i++ {
var t Int
t.New(int64(i))
z.Mul(&z, &t)
// normalize (remove leading 0's)
i := 0
for i < len(z) && z[i] == 0 {
i++
}
return &z
return z[i:]
}
func TestFact(t *testing.T) {
for n, s := range facts {
f := fact(n).String()
if f != s {
t.Errorf("%d! = %s; want %s", n, f, s)
}
// checkMul reports whether Int.Mul agrees with mulBytes for the operands
// given by the byte slices a and b (interpreted via Int.SetBytes).
func checkMul(a, b []byte) bool {
	var x, y Int
	x.SetBytes(a)
	y.SetBytes(b)

	// product as computed by the implementation under test
	var got Int
	got.Mul(&x, &y)

	// product as computed by the reference implementation
	var want Int
	want.SetBytes(mulBytes(a, b))

	return got.Cmp(&want) == 0
}
// TestMul runs checkMul through quick.Check and reports any error
// that quick.Check returns.
func TestMul(t *testing.T) {
if err := quick.Check(checkMul, nil); err != nil {
t.Error(err)
}
}
......@@ -235,8 +254,7 @@ func checkSetBytes(b []byte) bool {
func TestSetBytes(t *testing.T) {
err := quick.Check(checkSetBytes, nil)
if err != nil {
if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err)
}
}
......@@ -249,8 +267,7 @@ func checkBytes(b []byte) bool {
func TestBytes(t *testing.T) {
err := quick.Check(checkSetBytes, nil)
if err != nil {
if err := quick.Check(checkSetBytes, nil); err != nil {
t.Error(err)
}
}
......@@ -302,8 +319,7 @@ var divTests = []divTest{
func TestDiv(t *testing.T) {
err := quick.Check(checkDiv, nil)
if err != nil {
if err := quick.Check(checkDiv, nil); err != nil {
t.Error(err)
}
......@@ -676,6 +692,7 @@ var int64Tests = []int64{
-9223372036854775808,
}
func TestInt64(t *testing.T) {
for i, testVal := range int64Tests {
in := NewInt(testVal)
......
This diff is collapsed.
......@@ -111,6 +111,64 @@ func TestFunNN(t *testing.T) {
}
// A mulRange is a test case for nat.mulRange: the product of the
// integers in the range a through b, with an empty range (a > b)
// expected to yield "1".
type mulRange struct {
	a    uint64 // first value of the range
	b    uint64 // last value of the range
	prod string // expected product in decimal notation
}
// mulRanges lists the ranges checked by TestMulRange together with the
// expected products in decimal notation.
// NOTE(review): the {1, 3, "6"} case is listed twice; the second entry
// looks redundant.
var mulRanges = []mulRange{
mulRange{0, 0, "0"},
mulRange{1, 1, "1"},
mulRange{1, 2, "2"},
mulRange{1, 3, "6"},
mulRange{1, 3, "6"},
mulRange{10, 10, "10"},
mulRange{0, 100, "0"},
mulRange{0, 1e9, "0"},
mulRange{100, 1, "1"}, // empty range
mulRange{1, 10, "3628800"}, // 10!
mulRange{1, 20, "2432902008176640000"}, // 20!
mulRange{1, 100,
"933262154439441526816992388562667004907159682643816214685929" +
"638952175999932299156089414639761565182862536979208272237582" +
"51185210916864000000000000000000000000", // 100!
},
}
// TestMulRange checks nat.mulRange against every entry of mulRanges by
// comparing the decimal string of the result with the expected product.
func TestMulRange(t *testing.T) {
	for i, r := range mulRanges {
		if got := nat(nil).mulRange(r.a, r.b).string(10); got != r.prod {
			t.Errorf("%d: got %s; want %s", i, got, r.prod)
		}
	}
}
// mulArg is the operand pool for BenchmarkMul; init fills it with
// 1000 words, each set to _M.
var mulArg nat

func init() {
	const n = 1000
	mulArg = make(nat, n)
	for i := range mulArg {
		mulArg[i] = _M
	}
}
// BenchmarkMul squares prefixes of mulArg of 100, 200, ..., 1000 words
// in each benchmark iteration.
func BenchmarkMul(b *testing.B) {
	for i := 0; i < b.N; i++ {
		var z nat
		for n := 100; n <= 1000; n += 100 {
			x := mulArg[:n]
			z.mul(x, x)
		}
	}
}
type strN struct {
x nat
b int
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment