Commit 58e77990 authored by Robert Griesemer's avatar Robert Griesemer

big: use fast shift routines

- fixed a couple of bugs in the process
  (shift right was incorrect for negative numbers)
- added more tests and made some tests more robust
- changed pidigits back to using shifts to multiply
  by 2 instead of add

  This improves pidigit -s -n 10000 by approx. 5%:

  user 0m6.496s (old)
  user 0m6.156s (new)

R=rsc
CC=golang-dev
https://golang.org/cl/963044
parent 161b44c7
...@@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) { ...@@ -216,13 +216,16 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
if scanned != len(s) { if scanned != len(s) {
goto Error goto Error
} }
if len(z.abs) == 0 {
z.neg = false // 0 has no sign
}
return z, true return z, true
Error: Error:
z.neg = false z.neg = false
z.abs = nil z.abs = nil
return nil, false return z, false
} }
...@@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n) ...@@ -384,26 +387,24 @@ func ProbablyPrime(z *Int, n int) bool { return !z.neg && z.abs.probablyPrime(n)
// Lsh sets z = x << n and returns z. // Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int { func (z *Int) Lsh(x *Int, n uint) *Int {
addedWords := int(n) / _W
// Don't assign z.abs yet, in case z == x
znew := z.abs.make(len(x.abs) + addedWords + 1)
z.neg = x.neg z.neg = x.neg
znew[addedWords:].shiftLeft(x.abs, n%_W) z.abs = z.abs.shl(x.abs, n)
for i := range znew[0:addedWords] {
znew[i] = 0
}
z.abs = znew.norm()
return z return z
} }
// Rsh sets z = x >> n and returns z. // Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int { func (z *Int) Rsh(x *Int, n uint) *Int {
removedWords := int(n) / _W if x.neg {
// Don't assign z.abs yet, in case z == x // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
znew := z.abs.make(len(x.abs) - removedWords) z.neg = true
z.neg = x.neg t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
znew.shiftRight(x.abs[removedWords:], n%_W) t = t.shr(t, n)
z.abs = znew.norm() z.abs = t.add(t, natOne)
return z
}
z.neg = false
z.abs = z.abs.shr(x.abs, n)
return z return z
} }
...@@ -562,6 +562,7 @@ type intShiftTest struct { ...@@ -562,6 +562,7 @@ type intShiftTest struct {
var rshTests = []intShiftTest{ var rshTests = []intShiftTest{
intShiftTest{"0", 0, "0"}, intShiftTest{"0", 0, "0"},
intShiftTest{"-0", 0, "0"},
intShiftTest{"0", 1, "0"}, intShiftTest{"0", 1, "0"},
intShiftTest{"0", 2, "0"}, intShiftTest{"0", 2, "0"},
intShiftTest{"1", 0, "1"}, intShiftTest{"1", 0, "1"},
...@@ -569,7 +570,12 @@ var rshTests = []intShiftTest{ ...@@ -569,7 +570,12 @@ var rshTests = []intShiftTest{
intShiftTest{"1", 2, "0"}, intShiftTest{"1", 2, "0"},
intShiftTest{"2", 0, "2"}, intShiftTest{"2", 0, "2"},
intShiftTest{"2", 1, "1"}, intShiftTest{"2", 1, "1"},
intShiftTest{"2", 2, "0"}, intShiftTest{"-1", 0, "-1"},
intShiftTest{"-1", 1, "-1"},
intShiftTest{"-1", 10, "-1"},
intShiftTest{"-100", 2, "-25"},
intShiftTest{"-100", 3, "-13"},
intShiftTest{"-100", 100, "-1"},
intShiftTest{"4294967296", 0, "4294967296"}, intShiftTest{"4294967296", 0, "4294967296"},
intShiftTest{"4294967296", 1, "2147483648"}, intShiftTest{"4294967296", 1, "2147483648"},
intShiftTest{"4294967296", 2, "1073741824"}, intShiftTest{"4294967296", 2, "1073741824"},
......
...@@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { ...@@ -554,8 +554,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
// D1. // D1.
shift := uint(leadingZeroBits(v[n-1])) shift := uint(leadingZeroBits(v[n-1]))
v.shiftLeft(v, shift) v.shiftLeftDeprecated(v, shift)
u.shiftLeft(uIn, shift) u.shiftLeftDeprecated(uIn, shift)
u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
// D2. // D2.
...@@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { ...@@ -597,8 +597,8 @@ func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
} }
q = q.norm() q = q.norm()
u.shiftRight(u, shift) u.shiftRightDeprecated(u, shift)
v.shiftRight(v, shift) v.shiftRightDeprecated(v, shift)
r = u.norm() r = u.norm()
return q, r return q, r
...@@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int { ...@@ -780,12 +780,56 @@ func trailingZeroBits(x Word) int {
} }
// TODO(gri) Make the shift routines faster. // z = x << s
// Use pidigits.go benchmark as a test case. func (z nat) shl(x nat, s uint) nat {
m := len(x)
if m == 0 {
return z.make(0)
}
// m > 0
// determine if z can be reused
// TODO(gri) change shlVW so we don't need this
if len(z) > 0 && alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
n := m + int(s/_W)
z = z.make(n + 1)
z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
return z.norm()
}
// z = x >> s
func (z nat) shr(x nat, s uint) nat {
m := len(x)
n := m - int(s/_W)
if n <= 0 {
return z.make(0)
}
// n > 0
// determine if z can be reused
// TODO(gri) change shrVW so we don't need this
if len(z) > 0 && alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
z = z.make(n)
shrVW(&z[0], &x[m-n], Word(s%_W), m)
return z.norm()
}
// TODO(gri) Remove these shift functions once shlVW and shrVW can be
// used directly in divLarge and powersOfTwoDecompose
//
// To avoid losing the top n bits, z should be sized so that // To avoid losing the top n bits, z should be sized so that
// len(z) == len(x) + 1. // len(z) == len(x) + 1.
func (z nat) shiftLeft(x nat, n uint) nat { func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
if len(x) == 0 { if len(x) == 0 {
return x return x
} }
...@@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat { ...@@ -805,7 +849,7 @@ func (z nat) shiftLeft(x nat, n uint) nat {
} }
func (z nat) shiftRight(x nat, n uint) nat { func (z nat) shiftRightDeprecated(x nat, n uint) nat {
if len(x) == 0 { if len(x) == 0 {
return x return x
} }
...@@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) { ...@@ -850,7 +894,7 @@ func (n nat) powersOfTwoDecompose() (q nat, k Word) {
x := trailingZeroBits(n[zeroWords]) x := trailingZeroBits(n[zeroWords])
q = q.make(len(n) - zeroWords) q = q.make(len(n) - zeroWords)
q.shiftRight(n[zeroWords:], uint(x)) q.shiftRightDeprecated(n[zeroWords:], uint(x))
q = q.norm() q = q.norm()
k = Word(_W*zeroWords + x) k = Word(_W*zeroWords + x)
......
...@@ -230,9 +230,8 @@ type shiftTest struct { ...@@ -230,9 +230,8 @@ type shiftTest struct {
var leftShiftTests = []shiftTest{ var leftShiftTests = []shiftTest{
shiftTest{nil, 0, nil}, shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil}, shiftTest{nil, 1, nil},
shiftTest{nat{0}, 0, nat{0}}, shiftTest{natOne, 0, natOne},
shiftTest{nat{1}, 0, nat{1}}, shiftTest{natOne, 1, natTwo},
shiftTest{nat{1}, 1, nat{2}},
shiftTest{nat{1 << (_W - 1)}, 1, nat{0}}, shiftTest{nat{1 << (_W - 1)}, 1, nat{0}},
shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}}, shiftTest{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
} }
...@@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{ ...@@ -240,11 +239,11 @@ var leftShiftTests = []shiftTest{
func TestShiftLeft(t *testing.T) { func TestShiftLeft(t *testing.T) {
for i, test := range leftShiftTests { for i, test := range leftShiftTests {
dst := make(nat, len(test.out)) var z nat
dst.shiftLeft(test.in, test.shift) z = z.shl(test.in, test.shift)
for j, v := range dst { for j, d := range test.out {
if test.out[j] != v { if j >= len(z) || z[j] != d {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out) t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break break
} }
} }
...@@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) { ...@@ -255,22 +254,21 @@ func TestShiftLeft(t *testing.T) {
var rightShiftTests = []shiftTest{ var rightShiftTests = []shiftTest{
shiftTest{nil, 0, nil}, shiftTest{nil, 0, nil},
shiftTest{nil, 1, nil}, shiftTest{nil, 1, nil},
shiftTest{nat{0}, 0, nat{0}}, shiftTest{natOne, 0, natOne},
shiftTest{nat{1}, 0, nat{1}}, shiftTest{natOne, 1, nil},
shiftTest{nat{1}, 1, nat{0}}, shiftTest{natTwo, 1, natOne},
shiftTest{nat{2}, 1, nat{1}}, shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1)}},
shiftTest{nat{0, 1}, 1, nat{1 << (_W - 1), 0}}, shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
shiftTest{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1), 0}},
} }
func TestShiftRight(t *testing.T) { func TestShiftRight(t *testing.T) {
for i, test := range rightShiftTests { for i, test := range rightShiftTests {
dst := make(nat, len(test.out)) var z nat
dst.shiftRight(test.in, test.shift) z = z.shr(test.in, test.shift)
for j, v := range dst { for j, d := range test.out {
if test.out[j] != v { if j >= len(z) || z[j] != d {
t.Errorf("#%d: got: %v want: %v", i, dst, test.out) t.Errorf("#%d: got: %v want: %v", i, z, test.out)
break break
} }
} }
......
...@@ -63,7 +63,7 @@ func extract_digit() int64 { ...@@ -63,7 +63,7 @@ func extract_digit() int64 {
} }
// Compute (numer * 3 + accum) / denom // Compute (numer * 3 + accum) / denom
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1) tmp1.Lsh(numer, 1)
tmp1.Add(tmp1, numer) tmp1.Add(tmp1, numer)
tmp1.Add(tmp1, accum) tmp1.Add(tmp1, accum)
tmp1.DivMod(tmp1, denom, tmp2) tmp1.DivMod(tmp1, denom, tmp2)
...@@ -84,7 +84,7 @@ func next_term(k int64) { ...@@ -84,7 +84,7 @@ func next_term(k int64) {
y2.New(k*2 + 1) y2.New(k*2 + 1)
bigk.New(k) bigk.New(k)
tmp1.Add(numer, numer) // tmp1.Lsh(numer, 1) tmp1.Lsh(numer, 1)
accum.Add(accum, tmp1) accum.Add(accum, tmp1)
accum.Mul(accum, y2) accum.Mul(accum, y2)
numer.Mul(numer, bigk) numer.Mul(numer, bigk)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment