Commit 78b0013a authored by Robert Griesemer's avatar Robert Griesemer

- changed general div/mod implementation to a faster algorithm

  (operates on 30bit values at a time instead of 20bit values)
- refactored and cleaned up lots of code
- more tests
- close to check-in as complete library

R=r
OCL=18326
CL=18326
parent 2d4f7ba0
...@@ -38,22 +38,18 @@ const LogB = LogW - LogH; ...@@ -38,22 +38,18 @@ const LogB = LogW - LogH;
const ( const (
L3 = LogB / 3; L2 = LogB / 2;
B3 = 1 << L3;
M3 = B3 - 1;
L2 = L3 * 2;
B2 = 1 << L2; B2 = 1 << L2;
M2 = B2 - 1; M2 = B2 - 1;
L = L3 * 3; L = L2 * 2;
B = 1 << L; B = 1 << L;
M = B - 1; M = B - 1;
) )
type ( type (
Digit3 uint32; Digit2 uint32;
Digit uint64; Digit uint64;
) )
...@@ -69,26 +65,205 @@ func assert(p bool) { ...@@ -69,26 +65,205 @@ func assert(p bool) {
} }
func IsSmall(x Digit) bool { // ----------------------------------------------------------------------------
return x < 1<<LogH; // Raw operations
func And1(z, x *[]Digit, y Digit) {
for i := len(x) - 1; i >= 0; i-- {
z[i] = x[i] & y;
}
} }
func Split(x Digit) (Digit, Digit) { func And(z, x, y *[]Digit) {
return x>>L, x&M; for i := len(x) - 1; i >= 0; i-- {
z[i] = x[i] & y[i];
}
} }
export func Dump(x *[]Digit) { func Or1(z, x *[]Digit, y Digit) {
print("[", len(x), "]");
for i := len(x) - 1; i >= 0; i-- { for i := len(x) - 1; i >= 0; i-- {
print(" ", x[i]); z[i] = x[i] | y;
} }
println();
} }
export func Dump3(x *[]Digit3) { func Or(z, x, y *[]Digit) {
for i := len(x) - 1; i >= 0; i-- {
z[i] = x[i] | y[i];
}
}
func Xor1(z, x *[]Digit, y Digit) {
for i := len(x) - 1; i >= 0; i-- {
z[i] = x[i] ^ y;
}
}
func Xor(z, x, y *[]Digit) {
for i := len(x) - 1; i >= 0; i-- {
z[i] = x[i] ^ y[i];
}
}
func Add1(z, x *[]Digit, c Digit) Digit {
n := len(x);
for i := 0; i < n; i++ {
t := c + x[i];
c, z[i] = t>>L, t&M
}
return c;
}
func Add(z, x, y *[]Digit) Digit {
var c Digit;
n := len(x);
for i := 0; i < n; i++ {
t := c + x[i] + y[i];
c, z[i] = t>>L, t&M
}
return c;
}
func Sub1(z, x *[]Digit, c Digit) Digit {
n := len(x);
for i := 0; i < n; i++ {
t := c + x[i];
c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
}
return c;
}
func Sub(z, x, y *[]Digit) Digit {
var c Digit;
n := len(x);
for i := 0; i < n; i++ {
t := c + x[i] - y[i];
c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
}
return c;
}
// Returns c = x*y div B, z = x*y mod B.
func Mul11(x, y Digit) (Digit, Digit) {
// Split x and y into 2 sub-digits each (in base sqrt(B)),
// multiply the digits separately while avoiding overflow,
// and return the product as two separate digits.
const L0 = (L + 1)/2;
const L1 = L - L0;
const DL = L0 - L1; // 0 or 1
const b = 1<<L0;
const m = b - 1;
// split x and y into sub-digits
// x = (x1*b + x0)
// y = (y1*b + y0)
x1, x0 := x>>L0, x&m;
y1, y0 := y>>L0, y&m;
// x*y = t2*b^2 + t1*b + t0
t0 := x0*y0;
t1 := x1*y0 + x0*y1;
t2 := x1*y1;
// compute the result digits but avoid overflow
// z = z1*B + z0 = x*y
z0 := (t1<<L0 + t0)&M;
z1 := t2<<DL + (t1 + t0>>L0)>>L1;
return z1, z0;
}
func Mul(z, x, y *[]Digit) {
n := len(x);
m := len(y);
for j := 0; j < m; j++ {
d := y[j];
if d != 0 {
c := Digit(0);
for i := 0; i < n; i++ {
// z[i+j] += c + x[i]*d;
z1, z0 := Mul11(x[i], d);
t := c + z[i+j] + z0;
c, z[i+j] = t>>L, t&M;
c += z1;
}
z[n+j] = c;
}
}
}
func Mul1(z, x *[]Digit2, y Digit2) Digit2 {
n := len(x);
var c Digit;
f := Digit(y);
for i := 0; i < n; i++ {
t := c + Digit(x[i])*f;
c, z[i] = t>>L2, Digit2(t&M2);
}
return Digit2(c);
}
func Div1(z, x *[]Digit2, y Digit2) Digit2 {
n := len(x);
var c Digit;
d := Digit(y);
for i := n-1; i >= 0; i-- {
t := c*B2 + Digit(x[i]);
c, z[i] = t%d, Digit2(t/d);
}
return Digit2(c);
}
func Shl(z, x *[]Digit, s uint) Digit {
assert(s <= L);
n := len(x);
var c Digit;
for i := 0; i < n; i++ {
c, z[i] = x[i] >> (L-s), x[i] << s & M | c;
}
return c;
}
func Shr(z, x *[]Digit, s uint) Digit {
assert(s <= L);
n := len(x);
var c Digit;
for i := n - 1; i >= 0; i-- {
c, z[i] = x[i] << (L-s) & M, x[i] >> s | c;
}
return c;
}
// ----------------------------------------------------------------------------
// Support
func IsSmall(x Digit) bool {
return x < 1<<LogH;
}
func Split(x Digit) (Digit, Digit) {
return x>>L, x&M;
}
export func Dump(x *[]Digit) {
print("[", len(x), "]"); print("[", len(x), "]");
for i := len(x) - 1; i >= 0; i-- { for i := len(x) - 1; i >= 0; i-- {
print(" ", x[i]); print(" ", x[i]);
...@@ -140,7 +315,7 @@ func Normalize(x *Natural) *Natural { ...@@ -140,7 +315,7 @@ func Normalize(x *Natural) *Natural {
} }
func Normalize3(x *[]Digit3) *[]Digit3 { func Normalize2(x *[]Digit2) *[]Digit2 {
n := len(x); n := len(x);
for n > 0 && x[n - 1] == 0 { n-- } for n > 0 && x[n - 1] == 0 { n-- }
if n < len(x) { if n < len(x) {
...@@ -150,9 +325,10 @@ func Normalize3(x *[]Digit3) *[]Digit3 { ...@@ -150,9 +325,10 @@ func Normalize3(x *[]Digit3) *[]Digit3 {
} }
func (x *Natural) IsZero() bool { // Predicates
return len(x) == 0;
} func (x *Natural) IsZero() bool { return len(x) == 0; }
func (x *Natural) IsOdd() bool { return len(x) > 0 && x[0]&1 != 0; }
func (x *Natural) Add(y *Natural) *Natural { func (x *Natural) Add(y *Natural) *Natural {
...@@ -161,13 +337,10 @@ func (x *Natural) Add(y *Natural) *Natural { ...@@ -161,13 +337,10 @@ func (x *Natural) Add(y *Natural) *Natural {
if n < m { if n < m {
return y.Add(x); return y.Add(x);
} }
assert(n >= m);
z := new(Natural, n + 1);
c := Digit(0); z := new(Natural, n + 1);
for i := 0; i < m; i++ { c, z[i] = Split(c + x[i] + y[i]); } c := Add(z[0 : m], x[0 : m], y);
for i := m; i < n; i++ { c, z[i] = Split(c + x[i]); } z[n] = Add1(z[m : n], x[m : n], c);
z[n] = c;
return Normalize(z); return Normalize(z);
} }
...@@ -176,19 +349,15 @@ func (x *Natural) Add(y *Natural) *Natural { ...@@ -176,19 +349,15 @@ func (x *Natural) Add(y *Natural) *Natural {
func (x *Natural) Sub(y *Natural) *Natural { func (x *Natural) Sub(y *Natural) *Natural {
n := len(x); n := len(x);
m := len(y); m := len(y);
assert(n >= m); if n < m {
z := new(Natural, n); panic("underflow")
c := Digit(0);
for i := 0; i < m; i++ {
t := c + x[i] - y[i];
c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift!
} }
for i := m; i < n; i++ {
t := c + x[i]; z := new(Natural, n);
c, z[i] = Digit(int64(t)>>L), t&M; // arithmetic shift! c := Sub(z[0 : m], x[0 : m], y);
if Sub1(z[m : n], x[m : n], c) != 0 {
panic("underflow");
} }
assert(c == 0); // x.Sub(y) must be called with x >= y
return Normalize(z); return Normalize(z);
} }
...@@ -207,56 +376,12 @@ func (x* Natural) MulAdd1(a, c Digit) *Natural { ...@@ -207,56 +376,12 @@ func (x* Natural) MulAdd1(a, c Digit) *Natural {
} }
// Returns c = x*y div B, z = x*y mod B.
func Mul1(x, y Digit) (Digit, Digit) {
// Split x and y into 2 sub-digits each (in base sqrt(B)),
// multiply the digits separately while avoiding overflow,
// and return the product as two separate digits.
const L0 = (L + 1)/2;
const L1 = L - L0;
const DL = L0 - L1; // 0 or 1
const b = 1<<L0;
const m = b - 1;
// split x and y into sub-digits
// x = (x1*b + x0)
// y = (y1*b + y0)
x1, x0 := x>>L0, x&m;
y1, y0 := y>>L0, y&m;
// x*y = t2*b^2 + t1*b + t0
t0 := x0*y0;
t1 := x1*y0 + x0*y1;
t2 := x1*y1;
// compute the result digits but avoid overflow
// z = z1*B + z0 = x*y
z0 := (t1<<L0 + t0)&M;
z1 := t2<<DL + (t1 + t0>>L0)>>L1;
return z1, z0;
}
func (x *Natural) Mul(y *Natural) *Natural { func (x *Natural) Mul(y *Natural) *Natural {
n := len(x); n := len(x);
m := len(y); m := len(y);
z := new(Natural, n + m);
for j := 0; j < m; j++ { z := new(Natural, n + m);
d := y[j]; Mul(z, x, y);
if d != 0 {
c := Digit(0);
for i := 0; i < n; i++ {
// z[i+j] += c + x[i]*d;
z1, z0 := Mul1(x[i], d);
c, z[i+j] = Split(c + z[i+j] + z0);
c += z1;
}
z[n+j] = c;
}
}
return Normalize(z); return Normalize(z);
} }
...@@ -294,42 +419,26 @@ func (x *Natural) Pow(n uint) *Natural { ...@@ -294,42 +419,26 @@ func (x *Natural) Pow(n uint) *Natural {
} }
func Shl1(x, c Digit, s uint) (Digit, Digit) {
assert(s <= LogB);
return x >> (LogB - s), x << s & M | c
}
func Shr1(x, c Digit, s uint) (Digit, Digit) {
assert(s <= LogB);
return x << (LogB - s) & M, x >> s | c
}
func (x *Natural) Shl(s uint) *Natural { func (x *Natural) Shl(s uint) *Natural {
n := len(x); n := uint(len(x));
si := int(s / LogB); m := n + s/L;
s = s % LogB; z := new(Natural, m+1);
z := new(Natural, n + si + 1);
c := Digit(0); z[m] = Shl(z[m-n : m], x, s%L);
for i := 0; i < n; i++ { c, z[i+si] = Shl1(x[i], c, s); }
z[n+si] = c;
return Normalize(z); return Normalize(z);
} }
func (x *Natural) Shr(s uint) *Natural { func (x *Natural) Shr(s uint) *Natural {
n := len(x); n := uint(len(x));
si := int(s / LogB); m := n - s/L;
if si >= n { si = n; } if m > n { // check for underflow
s = s % LogB; m = 0;
assert(si <= n); }
z := new(Natural, n - si); z := new(Natural, m);
c := Digit(0); Shr(z, x[n-m : n], s%L);
for i := n - 1; i >= si; i-- { c, z[i-si] = Shr1(x[i], c, s); }
return Normalize(z); return Normalize(z);
} }
...@@ -337,68 +446,36 @@ func (x *Natural) Shr(s uint) *Natural { ...@@ -337,68 +446,36 @@ func (x *Natural) Shr(s uint) *Natural {
// DivMod needs multi-precision division which is not available if Digit // DivMod needs multi-precision division which is not available if Digit
// is already using the largest uint size. Split base before division, // is already using the largest uint size. Split base before division,
// and merge again after. Each Digit is split into 3 Digit3's. // and merge again after. Each Digit is split into 2 Digit2's.
func SplitBase(x *Natural) *[]Digit3 { func Unpack(x *Natural) *[]Digit2 {
// TODO Use Log() for better result - don't need Normalize3 at the end! // TODO Use Log() for better result - don't need Normalize2 at the end!
n := len(x); n := len(x);
z := new([]Digit3, n*3 + 1); // add space for extra digit (used by DivMod) z := new([]Digit2, n*2 + 1); // add space for extra digit (used by DivMod)
for i, j := 0, 0; i < n; i, j = i+1, j+3 { for i := 0; i < n; i++ {
t := x[i]; t := x[i];
z[j+0] = Digit3(t >> (L3*0) & M3); z[i*2] = Digit2(t & M2);
z[j+1] = Digit3(t >> (L3*1) & M3); z[i*2 + 1] = Digit2(t >> L2 & M2);
z[j+2] = Digit3(t >> (L3*2) & M3);
} }
return Normalize3(z); return Normalize2(z);
} }
func MergeBase(x *[]Digit3) *Natural { func Pack(x *[]Digit2) *Natural {
i := len(x); n := (len(x) + 1) / 2;
j := (i+2)/3; z := new(Natural, n);
z := new(Natural, j); if len(x) & 1 == 1 {
// handle odd len(x)
switch i%3 { n--;
case 1: z[j-1] = Digit(x[i-1]); i--; j--; z[n] = Digit(x[n*2]);
case 2: z[j-1] = Digit(x[i-1])<<L3 | Digit(x[i-2]); i -= 2; j--;
case 0:
} }
for i := 0; i < n; i++ {
for i >= 3 { z[i] = Digit(x[i*2 + 1]) << L2 | Digit(x[i*2]);
z[j-1] = ((Digit(x[i-1])<<L3) | Digit(x[i-2]))<<L3 | Digit(x[i-3]);
i -= 3;
j--;
} }
assert(j == 0);
return Normalize(z); return Normalize(z);
} }
func Split3(x Digit) (Digit, Digit3) {
return uint64(int64(x)>>L3), Digit3(x&M3)
}
func Product(x *[]Digit3, y Digit) {
n := len(x);
c := Digit(0);
for i := 0; i < n; i++ { c, x[i] = Split3(c + Digit(x[i])*y) }
assert(c == 0);
}
func Quotient(x *[]Digit3, y Digit) {
n := len(x);
c := Digit(0);
for i := n-1; i >= 0; i-- {
t := c*B3 + Digit(x[i]);
c, x[i] = t%y, Digit3(t/y);
}
assert(c == 0);
}
// Division and modulo computation - destroys x and y. Based on the // Division and modulo computation - destroys x and y. Based on the
// algorithms described in: // algorithms described in:
// //
...@@ -413,8 +490,8 @@ func Quotient(x *[]Digit3, y Digit) { ...@@ -413,8 +490,8 @@ func Quotient(x *[]Digit3, y Digit) {
// is described in 1), while 2) provides the background for a more // is described in 1), while 2) provides the background for a more
// accurate initial guess of the trial digit. // accurate initial guess of the trial digit.
func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) { func DivMod2(x, y *[]Digit2) (*[]Digit2, *[]Digit2) {
const b = B3; const b = B2;
n := len(x); n := len(x);
m := len(y); m := len(y);
...@@ -424,14 +501,9 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) { ...@@ -424,14 +501,9 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
if m == 1 { if m == 1 {
// division by single digit // division by single digit
d := Digit(y[0]); // result is shifted left by 1 in place!
c := Digit(0); x[0] = Div1(x[1 : n+1], x[0 : n], y[0]);
for i := n; i > 0; i-- {
t := c*b + Digit(x[i-1]);
c, x[i] = t%d, Digit3(t/d);
}
x[0] = Digit3(c);
} else if m > n { } else if m > n {
// quotient = 0, remainder = x // quotient = 0, remainder = x
// TODO in this case we shouldn't even split base - FIX THIS // TODO in this case we shouldn't even split base - FIX THIS
...@@ -444,24 +516,34 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) { ...@@ -444,24 +516,34 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
// normalize x and y // normalize x and y
f := b/(Digit(y[m-1]) + 1); f := b/(Digit(y[m-1]) + 1);
Product(x, f); Mul1(x, x, Digit2(f));
Product(y, f); Mul1(y, y, Digit2(f));
assert(b/2 <= y[m-1] && y[m-1] < b); // incorrect scaling assert(b/2 <= y[m-1] && y[m-1] < b); // incorrect scaling
d2 := Digit(y[m-1])*b + Digit(y[m-2]); y1, y2 := Digit(y[m-1]), Digit(y[m-2]);
d2 := Digit(y1)*b + Digit(y2);
for i := n-m; i >= 0; i-- { for i := n-m; i >= 0; i-- {
k := i+m; k := i+m;
// compute trial digit // compute trial digit
r3 := (Digit(x[k])*b + Digit(x[k-1]))*b + Digit(x[k-2]); var q Digit;
q := r3/d2; { // Knuth
if q >= b { q = b-1 } x0, x1, x2 := Digit(x[k]), Digit(x[k-1]), Digit(x[k-2]);
if x0 != y1 {
q = (x0*b + x1)/y1;
} else {
q = b-1;
}
for y2 * q > (x0*b + x1 - y1*q)*b + x2 {
q--
}
}
// subtract y*q // subtract y*q
c := Digit(0); c := Digit(0);
for j := 0; j < m; j++ { for j := 0; j < m; j++ {
t := c + Digit(x[i+j]) - Digit(y[j])*q; // arithmetic shift! t := c + Digit(x[i+j]) - Digit(y[j])*q; // arithmetic shift!
c, x[i+j] = Digit(int64(t)>>L3), Digit3(t&M3); c, x[i+j] = Digit(int64(t)>>L2), Digit2(t&M2);
} }
// correct if trial digit was too large // correct if trial digit was too large
...@@ -469,18 +551,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) { ...@@ -469,18 +551,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
// add y // add y
c := Digit(0); c := Digit(0);
for j := 0; j < m; j++ { for j := 0; j < m; j++ {
c, x[i+j] = Split3(c + Digit(x[i+j]) + Digit(y[j])); t := c + Digit(x[i+j]) + Digit(y[j]);
c, x[i+j] = uint64(int64(t) >> L2), Digit2(t & M2)
} }
assert(c + Digit(x[k]) == 0); assert(c + Digit(x[k]) == 0);
// correct trial digit // correct trial digit
q--; q--;
} }
x[k] = Digit3(q); x[k] = Digit2(q);
} }
// undo normalization for remainder // undo normalization for remainder
Quotient(x[0 : m], f); c := Div1(x[0 : m], x[0 : m], Digit2(f));
assert(c == 0);
} }
return x[m : n+1], x[0 : m]; return x[m : n+1], x[0 : m];
...@@ -488,14 +572,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) { ...@@ -488,14 +572,20 @@ func DivMod(x, y *[]Digit3) (*[]Digit3, *[]Digit3) {
func (x *Natural) Div(y *Natural) *Natural { func (x *Natural) Div(y *Natural) *Natural {
q, r := DivMod(SplitBase(x), SplitBase(y)); q, r := DivMod2(Unpack(x), Unpack(y));
return MergeBase(q); return Pack(q);
} }
func (x *Natural) Mod(y *Natural) *Natural { func (x *Natural) Mod(y *Natural) *Natural {
q, r := DivMod(SplitBase(x), SplitBase(y)); q, r := DivMod2(Unpack(x), Unpack(y));
return MergeBase(r); return Pack(r);
}
func (x *Natural) DivMod(y *Natural) (*Natural, *Natural) {
q, r := DivMod2(Unpack(x), Unpack(y));
return Pack(q), Pack(r);
} }
...@@ -520,17 +610,17 @@ func (x *Natural) Cmp(y *Natural) int { ...@@ -520,17 +610,17 @@ func (x *Natural) Cmp(y *Natural) int {
} }
func Log1(x Digit) int { func Log2(x Digit) int {
n := -1; n := -1;
for x != 0 { x = x >> 1; n++; } // BUG >>= broken for uint64 for x != 0 { x = x >> 1; n++; } // BUG >>= broken for uint64
return n; return n;
} }
func (x *Natural) Log() int { func (x *Natural) Log2() int {
n := len(x); n := len(x);
if n > 0 { if n > 0 {
n = (n - 1)*L + Log1(x[n - 1]); n = (n - 1)*L + Log2(x[n - 1]);
} else { } else {
n = -1; n = -1;
} }
...@@ -544,11 +634,10 @@ func (x *Natural) And(y *Natural) *Natural { ...@@ -544,11 +634,10 @@ func (x *Natural) And(y *Natural) *Natural {
if n < m { if n < m {
return y.And(x); return y.And(x);
} }
assert(n >= m);
z := new(Natural, n);
for i := 0; i < m; i++ { z[i] = x[i] & y[i]; } z := new(Natural, n);
for i := m; i < n; i++ { z[i] = x[i]; } And(z[0 : m], x[0 : m], y);
Or1(z[m : n], x[m : n], 0);
return Normalize(z); return Normalize(z);
} }
...@@ -560,11 +649,10 @@ func (x *Natural) Or(y *Natural) *Natural { ...@@ -560,11 +649,10 @@ func (x *Natural) Or(y *Natural) *Natural {
if n < m { if n < m {
return y.Or(x); return y.Or(x);
} }
assert(n >= m);
z := new(Natural, n);
for i := 0; i < m; i++ { z[i] = x[i] | y[i]; } z := new(Natural, n);
for i := m; i < n; i++ { z[i] = x[i]; } Or(z[0 : m], x[0 : m], y);
Or1(z[m : n], x[m : n], 0);
return Normalize(z); return Normalize(z);
} }
...@@ -576,24 +664,15 @@ func (x *Natural) Xor(y *Natural) *Natural { ...@@ -576,24 +664,15 @@ func (x *Natural) Xor(y *Natural) *Natural {
if n < m { if n < m {
return y.Xor(x); return y.Xor(x);
} }
assert(n >= m);
z := new(Natural, n);
for i := 0; i < m; i++ { z[i] = x[i] ^ y[i]; } z := new(Natural, n);
for i := m; i < n; i++ { z[i] = x[i]; } Xor(z[0 : m], x[0 : m], y);
Or1(z[m : n], x[m : n], 0);
return Normalize(z); return Normalize(z);
} }
func Copy(x *Natural) *Natural {
z := new(Natural, len(x));
//*z = *x; // BUG assignment does't work yet
for i := len(x) - 1; i >= 0; i-- { z[i] = x[i]; }
return z;
}
// Computes x = x div d (in place - the recv maybe modified) for "small" d's. // Computes x = x div d (in place - the recv maybe modified) for "small" d's.
// Returns updated x and x mod d. // Returns updated x and x mod d.
func (x *Natural) DivMod1(d Digit) (*Natural, Digit) { func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
...@@ -601,36 +680,36 @@ func (x *Natural) DivMod1(d Digit) (*Natural, Digit) { ...@@ -601,36 +680,36 @@ func (x *Natural) DivMod1(d Digit) (*Natural, Digit) {
c := Digit(0); c := Digit(0);
for i := len(x) - 1; i >= 0; i-- { for i := len(x) - 1; i >= 0; i-- {
c = c<<L + x[i]; t := c<<L + x[i];
x[i] = c/d; c, x[i] = t%d, t/d;
c %= d;
} }
return Normalize(x), c; return Normalize(x), c;
} }
func (x *Natural) String(base Digit) string { func (x *Natural) String(base uint) string {
if x.IsZero() { if x.IsZero() {
return "0"; return "0";
} }
// allocate string // allocate string
// TODO n is too small for bases < 10!!! assert(2 <= base && base <= 16);
assert(base >= 10); // for now n := (x.Log2() + 1) / Log2(Digit(base)) + 1; // TODO why the +1?
// approx. length: 1 char for 3 bits
n := x.Log()/3 + 10; // +10 (round up) - what is the right number?
s := new([]byte, n); s := new([]byte, n);
// convert // convert
const hex = "0123456789abcdef";
// don't destroy x, make a copy
t := new(Natural, len(x));
Or1(t, x, 0); // copy x
i := n; i := n;
x = Copy(x); // don't destroy recv for !t.IsZero() {
for !x.IsZero() {
i--; i--;
var d Digit; var d Digit;
x, d = x.DivMod1(base); t, d = t.DivMod1(Digit(base));
s[i] = hex[d]; s[i] = "0123456789abcdef"[d];
}; };
return string(s[i : n]); return string(s[i : n]);
...@@ -665,24 +744,24 @@ func (x *Natural) Gcd(y *Natural) *Natural { ...@@ -665,24 +744,24 @@ func (x *Natural) Gcd(y *Natural) *Natural {
} }
func HexValue(ch byte) Digit { func HexValue(ch byte) uint {
d := Digit(1 << LogH); d := uint(1 << LogH);
switch { switch {
case '0' <= ch && ch <= '9': d = Digit(ch - '0'); case '0' <= ch && ch <= '9': d = uint(ch - '0');
case 'a' <= ch && ch <= 'f': d = Digit(ch - 'a') + 10; case 'a' <= ch && ch <= 'f': d = uint(ch - 'a') + 10;
case 'A' <= ch && ch <= 'F': d = Digit(ch - 'A') + 10; case 'A' <= ch && ch <= 'F': d = uint(ch - 'A') + 10;
} }
return d; return d;
} }
// TODO auto-detect base if base argument is 0 // TODO auto-detect base if base argument is 0
export func NatFromString(s string, base Digit) *Natural { export func NatFromString(s string, base uint) *Natural {
x := NatZero; x := NatZero;
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
d := HexValue(s[i]); d := HexValue(s[i]);
if d < base { if d < base {
x = x.MulAdd1(base, d); x = x.MulAdd1(Digit(base), Digit(d));
} else { } else {
break; break;
} }
...@@ -776,19 +855,31 @@ func (x *Integer) Mul(y *Integer) *Integer { ...@@ -776,19 +855,31 @@ func (x *Integer) Mul(y *Integer) *Integer {
func (x *Integer) Quo(y *Integer) *Integer { func (x *Integer) Quo(y *Integer) *Integer {
panic("UNIMPLEMENTED"); // x / y == x / y
return nil; // x / (-y) == -(x / y)
// (-x) / y == -(x / y)
// (-x) / (-y) == x / y
return &Integer{x.sign != y.sign, x.mant.Div(y.mant)};
} }
func (x *Integer) Rem(y *Integer) *Integer { func (x *Integer) Rem(y *Integer) *Integer {
panic("UNIMPLEMENTED"); // x % y == x % y
return nil; // x % (-y) == x % y
// (-x) % y == -(x % y)
// (-x) % (-y) == -(x % y)
return &Integer{y.sign, x.mant.Mod(y.mant)};
}
func (x *Integer) QuoRem(y *Integer) (*Integer, *Integer) {
q, r := x.mant.DivMod(y.mant);
return &Integer{x.sign != y.sign, q}, &Integer{y.sign, q};
} }
func (x *Integer) Div(y *Integer) *Integer { func (x *Integer) Div(y *Integer) *Integer {
panic("UNIMPLEMENTED"); q, r := x.mant.DivMod(y.mant);
return nil; return nil;
} }
...@@ -805,7 +896,7 @@ func (x *Integer) Cmp(y *Integer) int { ...@@ -805,7 +896,7 @@ func (x *Integer) Cmp(y *Integer) int {
} }
func (x *Integer) String(base Digit) string { func (x *Integer) String(base uint) string {
if x.mant.IsZero() { if x.mant.IsZero() {
return "0"; return "0";
} }
...@@ -817,7 +908,7 @@ func (x *Integer) String(base Digit) string { ...@@ -817,7 +908,7 @@ func (x *Integer) String(base Digit) string {
} }
export func IntFromString(s string, base Digit) *Integer { export func IntFromString(s string, base uint) *Integer {
// get sign, if any // get sign, if any
sign := false; sign := false;
if len(s) > 0 && (s[0] == '-' || s[0] == '+') { if len(s) > 0 && (s[0] == '-' || s[0] == '+') {
......
...@@ -38,14 +38,54 @@ func TEST_EQ(n uint, x, y *Big.Natural) { ...@@ -38,14 +38,54 @@ func TEST_EQ(n uint, x, y *Big.Natural) {
} }
func TestLog2() {
test_msg = "TestLog2A";
TEST(0, Big.Nat(1).Log2() == 0);
TEST(1, Big.Nat(2).Log2() == 1);
TEST(2, Big.Nat(3).Log2() == 1);
TEST(3, Big.Nat(4).Log2() == 2);
test_msg = "TestLog2B";
for i := uint(0); i < 100; i++ {
TEST(i, Big.Nat(1).Shl(i).Log2() == int(i));
}
}
func TestConv() { func TestConv() {
test_msg = "TestConv"; test_msg = "TestConvA";
TEST(0, a.Cmp(Big.Nat(991)) == 0); TEST(0, a.Cmp(Big.Nat(991)) == 0);
TEST(1, b.Cmp(Big.Fact(20)) == 0); TEST(1, b.Cmp(Big.Fact(20)) == 0);
TEST(2, c.Cmp(Big.Fact(100)) == 0); TEST(2, c.Cmp(Big.Fact(100)) == 0);
TEST(3, a.String(10) == sa); TEST(3, a.String(10) == sa);
TEST(4, b.String(10) == sb); TEST(4, b.String(10) == sb);
TEST(5, c.String(10) == sc); TEST(5, c.String(10) == sc);
test_msg = "TestConvB";
t := c.Mul(c);
for base := uint(2); base <= 16; base++ {
TEST_EQ(base, Big.NatFromString(t.String(base), base), t);
}
}
func Sum(n uint, scale *Big.Natural) *Big.Natural {
s := Big.Nat(0);
for ; n > 0; n-- {
s = s.Add(Big.Nat(uint64(n)).Mul(scale));
}
return s;
}
func TestAdd() {
test_msg = "TestAddA";
test_msg = "TestAddB";
for i := uint(0); i < 100; i++ {
t := Big.Nat(uint64(i));
TEST_EQ(i, Sum(i, c), t.Mul(t).Add(t).Shr(1).Mul(c));
}
} }
...@@ -163,7 +203,9 @@ func TestPop() { ...@@ -163,7 +203,9 @@ func TestPop() {
func main() { func main() {
TestLog2();
TestConv(); TestConv();
TestAdd();
TestShift(); TestShift();
TestMul(); TestMul();
TestDiv(); TestDiv();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment