Commit ac121316 authored by Rémy Oudompheng's avatar Rémy Oudompheng Committed by Robert Griesemer

math/big: correct quadratic space complexity in Mul.

The previous implementation used to have a O(n) recursion
depth for unbalanced inputs. A test is added to check that a
reasonable amount of bytes is allocated in this case.

Fixes #3807.

R=golang-dev, dsymonds, gri
CC=golang-dev, remy
https://golang.org/cl/6345075
parent eb1c03ea
......@@ -342,7 +342,7 @@ func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
// addAt implements z += x*(1<<(_W*i)); z must be long enough.
// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
......@@ -405,8 +405,8 @@ func (z nat) mul(x, y nat) nat {
// determine Karatsuba length k such that
//
// x = x1*b + x0
// y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
// x = xh*b + x0 (0 <= x0 < b)
// y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n)
......@@ -417,27 +417,41 @@ func (z nat) mul(x, y nat) nat {
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
// If x1 and/or y1 are not 0, add missing terms to z explicitly:
//
// m+n 2*k 0
// z = [ ... | x0*y0 ]
// + [ x1*y1 ]
// + [ x1*y0 ]
// + [ x0*y1 ]
z = z[0 : m+n] // z has final length but may be incomplete
z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
// If xh != 0 or yh != 0, add the missing terms to z. For
//
// xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
// yh = y1*b (0 <= y1 < b)
//
// the missing terms are
//
// x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
//
// since all the yi for i > 1 are 0 by choice of k: If any of them
// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
// be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
x1 := x[k:] // x1 is normalized because x is
y1 := y[k:] // y1 is normalized because y is
var t nat
t = t.mul(x1, y1)
copy(z[2*k:], t)
z[2*k+len(t):].clear() // upper portion of z is garbage
t = t.mul(x1, y0.norm())
addAt(z, t, k)
t = t.mul(x0.norm(), y1)
addAt(z, t, k)
// add x0*y1*b
x0 := x0.norm()
y1 := y[k:] // y1 is normalized because y is
addAt(z, t.mul(x0, y1), k)
// add xi*y0<<i, xi*y1*b<<(i+k)
y0 := y0.norm()
for i := k; i < len(x); i += k {
xi := x[i:]
if len(xi) > k {
xi = xi[:k]
}
xi = xi.norm()
addAt(z, t.mul(xi, y0), i)
addAt(z, t.mul(xi, y1), i+k)
}
}
return z.norm()
......
......@@ -7,6 +7,7 @@ package big
import (
"io"
"math/rand"
"runtime"
"strings"
"testing"
)
......@@ -63,6 +64,36 @@ var prodNN = []argNN{
{nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
{nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
// 3^100 * 3^28 = 3^128
{
natFromString("11790184577738583171520872861412518665678211592275841109096961"),
natFromString("515377520732011331036461129765621272702107522001"),
natFromString("22876792454961"),
},
// z = 111....1 (70000 digits)
// x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
// y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
{
natFromString(strings.Repeat("1", 70000)),
natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
natFromString(strings.Repeat("1", 700)),
},
// z = 111....1 (20000 digits)
// x = 10^10000 + 1
// y = 111....1 (10000 digits)
{
natFromString(strings.Repeat("1", 20000)),
natFromString("1" + strings.Repeat("0", 9999) + "1"),
natFromString(strings.Repeat("1", 10000)),
},
}
func natFromString(s string) nat {
x, _, err := nat(nil).scan(strings.NewReader(s), 0)
if err != nil {
panic(err)
}
return x
}
func TestSet(t *testing.T) {
......@@ -136,6 +167,31 @@ func TestMulRangeN(t *testing.T) {
}
}
// allocBytes returns the number of bytes allocated by invoking f.
func allocBytes(f func()) uint64 {
var stats runtime.MemStats
runtime.ReadMemStats(&stats)
t := stats.TotalAlloc
f()
runtime.ReadMemStats(&stats)
return stats.TotalAlloc - t
}
// TestMulUnbalanced tests that multiplying numbers of different lengths
// does not cause deep recursion and in turn allocate too much memory.
// test case for issue 3807
func TestMulUnbalanced(t *testing.T) {
x := rndNat(50000)
y := rndNat(40)
allocSize := allocBytes(func() {
nat(nil).mul(x, y)
})
inputSize := uint64(len(x)+len(y)) * _S
if ratio := allocSize / uint64(inputSize); ratio > 10 {
t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
}
}
var rnd = rand.New(rand.NewSource(0x43de683f473542af))
var mulx = rndNat(1e4)
var muly = rndNat(1e4)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment