Commit 0a24b8c7 authored by Robert Griesemer's avatar Robert Griesemer

math/big: sketched out complete set of Float/string conversion functions

Also:
- use io.ByteScanner rather than io.RuneScanner internally
- minor simplifications in Float.Add/Sub

Change-Id: Iae0e99384128dba9eccf68592c4fd389e2bd3b4f
Reviewed-on: https://go-review.googlesource.com/3380Reviewed-by: 's avatarRob Pike <r@golang.org>
parent 813e97b7
......@@ -179,7 +179,7 @@ func (x *Float) validate() {
const msb = 1 << (_W - 1)
m := len(x.mant)
if x.mant[m-1]&msb == 0 {
panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.PString()))
panic(fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.pstring()))
}
if x.prec <= 0 {
panic(fmt.Sprintf("invalid precision %d", x.prec))
......@@ -566,10 +566,6 @@ func (z *Float) Neg(x *Float) *Float {
// z = x + y, ignoring signs of x and y.
// x and y must not be 0.
func (z *Float) uadd(x, y *Float) {
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("uadd called with 0 argument")
}
// Note: This implementation requires 2 shifts most of the
// time. It is also inefficient if exponents or precisions
// differ by wide margins. The following article describes
......@@ -580,60 +576,61 @@ func (z *Float) uadd(x, y *Float) {
// Point Addition With Exact Rounding (as in the MPFR Library)"
// http://www.vinc17.net/research/papers/rnc6.pdf
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("uadd called with 0 argument")
}
// compute exponents ex, ey for mantissa with "binary point"
// on the right (mantissa.0) - use int64 to avoid overflow
ex := int64(x.exp) - int64(len(x.mant))*_W
ey := int64(y.exp) - int64(len(y.mant))*_W
var e int64
// TODO(gri) having a combined add-and-shift primitive
// could make this code significantly faster
switch {
case ex < ey:
t := z.mant.shl(y.mant, uint(ey-ex))
z.mant = t.add(x.mant, t)
e = ex
default:
// ex == ey
// ex == ey, no shift needed
z.mant = z.mant.add(x.mant, y.mant)
e = ex
case ex > ey:
t := z.mant.shl(x.mant, uint(ex-ey))
z.mant = t.add(t, y.mant)
e = ey
ex = ey
}
// len(z.mant) > 0
z.setExp(e + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
z.setExp(ex + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
z.round(0)
}
// z = x - y for x >= y, ignoring signs of x and y.
// x and y must not be zero.
func (z *Float) usub(x, y *Float) {
// This code is symmetric to uadd.
// We have not factored the common code out because
// eventually uadd (and usub) should be optimized
// by special-casing, and the code will diverge.
if debugFloat && (len(x.mant) == 0 || len(y.mant) == 0) {
panic("usub called with 0 argument")
}
if x.exp < y.exp {
panic("underflow")
}
// This code is symmetric to uadd.
ex := int64(x.exp) - int64(len(x.mant))*_W
ey := int64(y.exp) - int64(len(y.mant))*_W
var e int64
switch {
case ex < ey:
t := z.mant.shl(y.mant, uint(ey-ex))
z.mant = t.sub(x.mant, t)
e = ex
default:
// ex == ey
// ex == ey, no shift needed
z.mant = z.mant.sub(x.mant, y.mant)
e = ex
case ex > ey:
t := z.mant.shl(x.mant, uint(ex-ey))
z.mant = t.sub(t, y.mant)
e = ey
ex = ey
}
// operands may have cancelled each other out
......@@ -644,7 +641,7 @@ func (z *Float) usub(x, y *Float) {
}
// len(z.mant) > 0
z.setExp(e + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
z.setExp(ex + int64(len(z.mant))*_W - int64(fnorm(z.mant)))
z.round(0)
}
......@@ -962,13 +959,9 @@ func (x *Float) Sign() int {
return 1
}
func (x *Float) String() string {
return x.PString() // TODO(gri) fix this
}
// PString returns x as a string in the format ["-"] "0." mantissa "p" exponent
// pstring returns x as a string in the format ["-"] "0." mantissa "p" exponent
// with a hexadecimal mantissa and a decimal exponent, or ["-"] "0" if x is zero.
func (x *Float) PString() string {
func (x *Float) pstring() string {
// TODO(gri) handle Inf
var buf bytes.Buffer
if x.neg {
......@@ -1029,7 +1022,7 @@ func (z *Float) SetString(s string) (*Float, bool) {
// is 1.2 * 2**3. If the operation failed, the value of z is undefined but
// the returned value is nil.
//
func (z *Float) scan(r io.RuneScanner) (f *Float, err error) {
func (z *Float) scan(r io.ByteScanner) (f *Float, err error) {
// sign
z.neg, err = scanSign(r)
if err != nil {
......@@ -1091,3 +1084,106 @@ func (z *Float) scan(r io.RuneScanner) (f *Float, err error) {
return z, nil
}
// Scan scans the number corresponding to the longest possible prefix
// of r representing a floating-point number with a mantissa in the
// given conversion base (the exponent is always a decimal number).
// It returns the corresponding Float f, the actual base b, and an
// error err, if any. The number must be of the form:
//
// number = [ prefix ] [ sign ] mantissa [ exponent ] .
// mantissa = digits | digits "." [ digits ] | "." digits .
// prefix = prefix = "0" ( "x" | "X" | "b" | "B" ) .
// sign = "+" | "-" .
// exponent = ( "E" | "e" | "p" ) [ sign ] digits .
// digits = digit { digit } .
// digit = digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
//
// The base argument must be 0 or a value between 2 and MaxBase, inclusive.
//
// For base 0, the number prefix determines the actual base: A prefix of
// ``0x'' or ``0X'' selects base 16, and a ``0b'' or ``0B'' prefix selects
// base 2; otherwise, the actual base is 10 and no prefix is permitted.
// Note that the octal prefix ``0'' is not supported.
//
// A "p" exponent indicates power of 2 for the exponent; for instance "1.2p3"
// with base 0 or 10 corresponds to the value 1.2 * 2**3.
//
// BUG(gri) Currently, Scan only accepts base 10.
func (z *Float) Scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
if base != 10 {
err = fmt.Errorf("base %d not supported yet", base)
return
}
b = 10
f, err = z.scan(r)
return
}
// Parse is like z.Scan(r, base), but instead of reading from an
// io.ByteScanner, it parses the string s. An error is returned if the
// string contains invalid or trailing characters not belonging to the
// number.
//
// TODO(gri) define possible errors more precisely
func (z *Float) Parse(s string, base int) (f *Float, b int, err error) {
r := strings.NewReader(s)
if f, b, err = z.Scan(r, base); err != nil {
return
}
// entire string must have been consumed
var ch byte
if ch, err = r.ReadByte(); err != io.EOF {
if err == nil {
err = fmt.Errorf("expected end of string, found %q", ch)
}
}
return
}
// ScanFloat is like f.Scan(r, base) with f set to the given precision
// and rounding mode.
func ScanFloat(r io.ByteScanner, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) {
return NewFloat(0, prec, mode).Scan(r, base)
}
// ParseFloat is like f.Parse(s, base) with f set to the given precision
// and rounding mode.
func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) {
return NewFloat(0, prec, mode).Parse(s, base)
}
// Format converts the floating-point number x to a string according
// to the given format and precision prec.
//
// The format is one of
// 'e' (-d.dddde±dd, decimal exponent),
// 'E' (-d.ddddE±dd, decimal exponent),
// 'f' (-ddddd.dddd, no exponent),
// 'g' ('e' for large exponents, 'f' otherwise),
// 'G' ('E' for large exponents, 'f' otherwise),
// 'b' (-ddddddp±dd, binary exponent), or
// 'p' (-0.ddddp±dd, hexadecimal mantissa, binary exponent).
//
// The precision prec controls the number of digits (excluding the exponent)
// printed by the 'e', 'E', 'f', 'g', and 'G' formats. For 'e', 'E', and 'f'
// it is the number of digits after the decimal point. For 'g' and 'G' it is
// the total number of digits. A negative precision selects the smallest
// number of digits necessary such that ParseFloat will return f exactly.
// The prec value is ignored for the 'b' or 'p' format.
//
// BUG(gri) Currently, Format only accepts the 'p' format.
func (x *Float) Format(format byte, prec int) string {
if format != 'p' {
return fmt.Sprintf(`%c`, format)
}
return x.pstring()
}
// BUG(gri): Currently, String uses the 'p' (rather than 'g') format.
func (x *Float) String() string {
return x.Format('p', 0)
}
......@@ -208,7 +208,7 @@ func TestFloatSetUint64(t *testing.T) {
for _, want := range tests {
f := new(Float).SetUint64(want)
if got := f.Uint64(); got != want {
t.Errorf("got %d (%s); want %d", got, f.PString(), want)
t.Errorf("got %d (%s); want %d", got, f.pstring(), want)
}
}
}
......@@ -231,7 +231,7 @@ func TestFloatSetInt64(t *testing.T) {
}
f := new(Float).SetInt64(want)
if got := f.Int64(); got != want {
t.Errorf("got %d (%s); want %d", got, f.PString(), want)
t.Errorf("got %d (%s); want %d", got, f.pstring(), want)
}
}
}
......@@ -256,7 +256,7 @@ func TestFloatSetFloat64(t *testing.T) {
}
f := new(Float).SetFloat64(want)
if got, _ := f.Float64(); got != want {
t.Errorf("got %g (%s); want %g", got, f.PString(), want)
t.Errorf("got %g (%s); want %g", got, f.pstring(), want)
}
}
}
......@@ -687,7 +687,7 @@ func TestFromBits(t *testing.T) {
for _, test := range tests {
f := fromBits(test.bits...)
if got := f.PString(); got != test.want {
if got := f.pstring(); got != test.want {
t.Errorf("setBits(%v) = %s; want %s", test.bits, got, test.want)
}
}
......@@ -757,7 +757,7 @@ func TestFloatSetFloat64String(t *testing.T) {
}
}
func TestFloatPString(t *testing.T) {
func TestFloatpstring(t *testing.T) {
var tests = []struct {
x Float
want string
......@@ -768,7 +768,7 @@ func TestFloatPString(t *testing.T) {
{Float{mant: nat{0x87654321}, exp: -10}, "0.87654321p-10"},
}
for _, test := range tests {
if got := test.x.PString(); got != test.want {
if got := test.x.pstring(); got != test.want {
t.Errorf("%v: got %s; want %s", test.x, got, test.want)
}
}
......
......@@ -461,7 +461,7 @@ func (x *Int) Format(s fmt.State, ch rune) {
// ``0x'' or ``0X'' selects base 16; the ``0'' prefix selects base 8, and a
// ``0b'' or ``0B'' prefix selects base 2. Otherwise the selected base is 10.
//
func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, error) {
func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) {
// determine sign
neg, err := scanSign(r)
if err != nil {
......@@ -478,9 +478,9 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, error) {
return z, base, nil
}
func scanSign(r io.RuneScanner) (neg bool, err error) {
var ch rune
if ch, _, err = r.ReadRune(); err != nil {
func scanSign(r io.ByteScanner) (neg bool, err error) {
var ch byte
if ch, err = r.ReadByte(); err != nil {
return false, err
}
switch ch {
......@@ -489,11 +489,29 @@ func scanSign(r io.RuneScanner) (neg bool, err error) {
case '+':
// nothing to do
default:
r.UnreadRune()
r.UnreadByte()
}
return
}
// byteReader is a local wrapper around fmt.ScanState;
// it implements the ByteReader interface.
type byteReader struct {
fmt.ScanState
}
func (r byteReader) ReadByte() (byte, error) {
ch, size, err := r.ReadRune()
if size != 1 && err == nil {
err = fmt.Errorf("invalid rune %#U", ch)
}
return byte(ch), err
}
func (r byteReader) UnreadByte() error {
return r.UnreadRune()
}
// Scan is a support routine for fmt.Scanner; it sets z to the value of
// the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
......@@ -514,7 +532,7 @@ func (z *Int) Scan(s fmt.ScanState, ch rune) error {
default:
return errors.New("Int.Scan: invalid verb")
}
_, _, err := z.scan(s, base)
_, _, err := z.scan(byteReader{s}, base)
return err
}
......@@ -569,7 +587,7 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
if err != nil {
return nil, false
}
_, _, err = r.ReadRune()
_, err = r.ReadByte()
if err != io.EOF {
return nil, false
}
......
......@@ -668,7 +668,7 @@ func pow(x Word, n int) (p Word) {
// base == 1, only), and the number of fractional digits is -count. In this
// case, the value of the scanned number is res * 10**count.
//
func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error) {
func (z nat) scan(r io.ByteScanner, base int) (res nat, b, count int, err error) {
// reject illegal bases
if base < 0 || base > MaxBase {
err = errors.New("illegal number base")
......@@ -676,7 +676,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
}
// one char look-ahead
ch, _, err := r.ReadRune()
ch, err := r.ReadByte()
if err != nil {
return
}
......@@ -687,7 +687,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
// actual base is 10 unless there's a base prefix
b = 10
if ch == '0' {
switch ch, _, err = r.ReadRune(); err {
switch ch, err = r.ReadByte(); err {
case nil:
// possibly one of 0x, 0X, 0b, 0B
b = 8
......@@ -698,7 +698,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
b = 2
}
if b == 2 || b == 16 {
if ch, _, err = r.ReadRune(); err != nil {
if ch, err = r.ReadByte(); err != nil {
// io.EOF is also an error in this case
return
}
......@@ -736,7 +736,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
base = 10 // no 2nd decimal point permitted
dp = count
// advance
if ch, _, err = r.ReadRune(); err != nil {
if ch, err = r.ReadByte(); err != nil {
if err == io.EOF {
err = nil
break
......@@ -758,7 +758,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
d1 = MaxBase + 1
}
if d1 >= b1 {
r.UnreadRune() // ch does not belong to number anymore
r.UnreadByte() // ch does not belong to number anymore
break
}
count++
......@@ -775,7 +775,7 @@ func (z nat) scan(r io.RuneScanner, base int) (res nat, b, count int, err error)
}
// advance
if ch, _, err = r.ReadRune(); err != nil {
if ch, err = r.ReadByte(); err != nil {
if err == io.EOF {
err = nil
break
......
......@@ -585,7 +585,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
}
// there should be no unread characters left
if _, _, err = r.ReadRune(); err != io.EOF {
if _, err = r.ReadByte(); err != io.EOF {
return nil, false
}
......@@ -615,11 +615,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, true
}
func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
func scanExponent(r io.ByteScanner) (exp int64, base int, err error) {
base = 10
var ch rune
if ch, _, err = r.ReadRune(); err != nil {
var ch byte
if ch, err = r.ReadByte(); err != nil {
if err == io.EOF {
err = nil // no exponent; same as e0
}
......@@ -632,7 +632,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
case 'p':
base = 2
default:
r.UnreadRune()
r.UnreadByte()
return // no exponent; same as e0
}
......@@ -650,7 +650,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
// since we only care about int64 values - the
// from-scratch scan is easy enough and faster
for i := 0; ; i++ {
if ch, _, err = r.ReadRune(); err != nil {
if ch, err = r.ReadByte(); err != nil {
if err != io.EOF || i == 0 {
return
}
......@@ -659,7 +659,7 @@ func scanExponent(r io.RuneScanner) (exp int64, base int, err error) {
}
if ch < '0' || '9' < ch {
if i == 0 {
r.UnreadRune()
r.UnreadByte()
err = fmt.Errorf("invalid exponent (missing digits)")
return
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment