Commit d8e3b16f authored by Alan Donovan's avatar Alan Donovan

exp/ssa: special-case 'range' loops based on type of range expression.

The lowering of ast.RangeStmt now has three distinct cases:

1) rangeIter for maps and strings; approximately:
    it = range x
    for {
      k, v, ok = next it
      if !ok { break }
      ...
    }
   The Range instruction and the interpreter's "iter"
   datatype are now restricted to these types.

2) rangeChan for channels; approximately:
    for {
      k, ok = <-x
      if !ok { break }
      ...
    }

3) rangeIndexed for slices, arrays, and *array; approximately:
    for k, l = 0, len(x); k < l; k++ {
      v = x[k]
      ...
    }

In all cases we now evaluate the side effects of the range expression
exactly once, per comments on http://code.google.com/p/go/issues/detail?id=4644.

However the exact spec wording is still being discussed in
https://golang.org/cl/7307083/.  Further (small)
changes may be required once the dust settles.

R=iant
CC=golang-dev
https://golang.org/cl/7303074
parent 1c1096ea
...@@ -1949,49 +1949,196 @@ func (b *Builder) forStmt(fn *Function, s *ast.ForStmt, label *lblock) { ...@@ -1949,49 +1949,196 @@ func (b *Builder) forStmt(fn *Function, s *ast.ForStmt, label *lblock) {
fn.currentBlock = done fn.currentBlock = done
} }
// rangeStmt emits to fn code for the range statement s, optionally // rangeIndexed emits to fn the header for an integer indexed loop
// labelled by label. // over array, *array or slice value x.
// The v result is defined only if tv is non-nil.
// //
func (b *Builder) rangeStmt(fn *Function, s *ast.RangeStmt, label *lblock) { func (b *Builder) rangeIndexed(fn *Function, x Value, tv types.Type) (k, v Value, loop, done *BasicBlock) {
// it := range x //
// jump loop // length = len(x)
// index = -1
// loop: (target of continue) // loop: (target of continue)
// okv := next it (ok, key, value?) // index++
// ok = extract okv #0 // if index < length goto body else done
// if ok goto body else done
// body: // body:
// t0 = extract okv #1 // k = index
// k = *t0 // v = x[index]
// t1 = extract okv #2
// v = *t1
// ...body... // ...body...
// jump loop // jump loop
// done: (target of break) // done: (target of break)
hasK := !isBlankIdent(s.Key) // Determine number of iterations.
hasV := s.Value != nil && !isBlankIdent(s.Value) var length Value
if arr, ok := deref(x.Type()).(*types.Array); ok {
// For array or *array, the number of iterations is
// known statically thanks to the type. We avoid a
// data dependence upon x, permitting later dead-code
// elimination if x is pure, static unrolling, etc.
// Ranging over a nil *array may have >0 iterations.
length = intLiteral(arr.Len)
} else {
// length = len(x).
var call Call
call.Func = b.globals[types.Universe.Lookup("len")]
call.Args = []Value{x}
call.setType(tInt)
length = fn.emit(&call)
}
index := fn.addLocal(tInt)
emitStore(fn, index, intLiteral(-1))
// Ranging over just the keys of a pointer to an array loop = fn.newBasicBlock("rangeindex.loop")
// doesn't (need to) evaluate the array: emitJump(fn, loop)
// for i := range (*[10]int)(nil) {...} fn.currentBlock = loop
// Instead it is transformed into a simple loop:
// i = -1 incr := &BinOp{
// jump loop Op: token.ADD,
X: emitLoad(fn, index),
Y: intLiteral(1),
}
incr.setType(tInt)
emitStore(fn, index, fn.emit(incr))
body := fn.newBasicBlock("rangeindex.body")
done = fn.newBasicBlock("rangeindex.done")
emitIf(fn, emitCompare(fn, token.LSS, incr, length), body, done)
fn.currentBlock = body
k = emitLoad(fn, index)
if tv != nil {
switch t := underlyingType(x.Type()).(type) {
case *types.Array:
instr := &Index{
X: x,
Index: k,
}
instr.setType(t.Elt)
v = fn.emit(instr)
case *types.Pointer: // *array
instr := &IndexAddr{
X: x,
Index: k,
}
instr.setType(pointer(t.Base.(*types.Array).Elt))
v = emitLoad(fn, fn.emit(instr))
case *types.Slice:
instr := &IndexAddr{
X: x,
Index: k,
}
instr.setType(pointer(t.Elt))
v = emitLoad(fn, fn.emit(instr))
default:
panic("rangeIndexed x:" + t.String())
}
}
return
}
// rangeIter emits to fn the header for a loop using
// Range/Next/Extract to iterate over map or string value x.
// tk and tv are the types of the key/value results k and v, or nil
// if the respective component is not wanted.
//
func (b *Builder) rangeIter(fn *Function, x Value, tk, tv types.Type) (k, v Value, loop, done *BasicBlock) {
//
// it = range x
// loop: (target of continue) // loop: (target of continue)
// increment i // okv = next it (ok, key, value)
// if i < 10 goto body else done // ok = extract okv #0
// if ok goto body else done
// body: // body:
// k = i // k = extract okv #1
// v = extract okv #2
// ...body... // ...body...
// jump loop // jump loop
// done: (target of break) // done: (target of break)
var arrayLen int64 = -1 //
if !hasV {
if ptr, ok := underlyingType(b.exprType(s.X)).(*types.Pointer); ok { rng := &Range{X: x}
if arr, ok := underlyingType(ptr.Base).(*types.Array); ok { rng.setType(tRangeIter)
arrayLen = arr.Len it := fn.emit(rng)
}
} loop = fn.newBasicBlock("rangeiter.loop")
emitJump(fn, loop)
fn.currentBlock = loop
okv := &Next{Iter: it}
okv.setType(&types.Result{Values: []*types.Var{
varOk,
{Name: "k", Type: tk},
{Name: "v", Type: tv},
}})
fn.emit(okv)
body := fn.newBasicBlock("rangeiter.body")
done = fn.newBasicBlock("rangeiter.done")
emitIf(fn, emitExtract(fn, okv, 0, tBool), body, done)
fn.currentBlock = body
if tk != nil {
k = emitExtract(fn, okv, 1, tk)
}
if tv != nil {
v = emitExtract(fn, okv, 2, tv)
}
return
}
// rangeChan emits to fn the header for a loop that receives from
// channel x until it fails.
// tk is the channel's element type, or nil if the k result is not
// wanted
//
func (b *Builder) rangeChan(fn *Function, x Value, tk types.Type) (k Value, loop, done *BasicBlock) {
//
// loop: (target of continue)
// ko = <-x (key, ok)
// ok = extract ko #1
// if ok goto body else done
// body:
// k = extract ko #0
// ...
// goto loop
// done: (target of break)
loop = fn.newBasicBlock("rangechan.loop")
emitJump(fn, loop)
fn.currentBlock = loop
recv := &UnOp{
Op: token.ARROW,
X: x,
CommaOk: true,
}
recv.setType(&types.Result{Values: []*types.Var{
{Name: "k", Type: tk},
varOk,
}})
ko := fn.emit(recv)
body := fn.newBasicBlock("rangechan.body")
done = fn.newBasicBlock("rangechan.done")
emitIf(fn, emitExtract(fn, ko, 1, tBool), body, done)
fn.currentBlock = body
if tk != nil {
k = emitExtract(fn, ko, 0, tk)
}
return
}
// rangeStmt emits to fn code for the range statement s, optionally
// labelled by label.
//
func (b *Builder) rangeStmt(fn *Function, s *ast.RangeStmt, label *lblock) {
var tk, tv types.Type
if !isBlankIdent(s.Key) {
tk = b.exprType(s.Key)
}
if s.Value != nil && !isBlankIdent(s.Value) {
tv = b.exprType(s.Value)
} }
// If iteration variables are defined (:=), this // If iteration variables are defined (:=), this
...@@ -2001,91 +2148,52 @@ func (b *Builder) rangeStmt(fn *Function, s *ast.RangeStmt, label *lblock) { ...@@ -2001,91 +2148,52 @@ func (b *Builder) rangeStmt(fn *Function, s *ast.RangeStmt, label *lblock) {
// using := never redeclares an existing variable; it // using := never redeclares an existing variable; it
// always creates a new one. // always creates a new one.
if s.Tok == token.DEFINE { if s.Tok == token.DEFINE {
if hasK { if tk != nil {
fn.addNamedLocal(b.obj(s.Key.(*ast.Ident))) fn.addNamedLocal(b.obj(s.Key.(*ast.Ident)))
} }
if hasV { if tv != nil {
fn.addNamedLocal(b.obj(s.Value.(*ast.Ident))) fn.addNamedLocal(b.obj(s.Value.(*ast.Ident)))
} }
} }
var ok Value x := b.expr(fn, s.X)
var okv *Next
var okvVars []*types.Var
var index *Alloc // *array index loops only
loop := fn.newBasicBlock("range.loop")
var body, done *BasicBlock
if arrayLen == -1 {
rng := &Range{X: b.expr(fn, s.X)}
rng.setType(tRangeIter)
it := fn.emit(rng)
emitJump(fn, loop) var k, v Value
fn.currentBlock = loop var loop, done *BasicBlock
switch rt := underlyingType(x.Type()).(type) {
case *types.Slice, *types.Array, *types.Pointer: // *array
k, v, loop, done = b.rangeIndexed(fn, x, tv)
okv = &Next{Iter: it} case *types.Chan:
okvVars = []*types.Var{ k, loop, done = b.rangeChan(fn, x, tk)
varOk,
{Name: "k", Type: tInvalid}, // mutated below
{Name: "v", Type: tInvalid}, // mutated below
}
okv.setType(&types.Result{Values: okvVars})
fn.emit(okv)
ok = emitExtract(fn, okv, 0, tBool)
} else {
index = fn.addLocal(tInt)
emitStore(fn, index, intLiteral(-1))
emitJump(fn, loop) case *types.Map, *types.Basic: // string
fn.currentBlock = loop k, v, loop, done = b.rangeIter(fn, x, tk, tv)
// TODO use emitArith here and elsewhere? default:
incr := &BinOp{ panic("Cannot range over: " + rt.String())
Op: token.ADD,
X: emitLoad(fn, index),
Y: intLiteral(1),
}
incr.setType(tInt)
emitStore(fn, index, fn.emit(incr))
ok = emitCompare(fn, token.LSS, incr, intLiteral(arrayLen))
} }
body = fn.newBasicBlock("range.body") // Evaluate both LHS expressions before we update either.
done = fn.newBasicBlock("range.done") var kl, vl lvalue
if tk != nil {
emitIf(fn, ok, body, done) kl = b.addr(fn, s.Key, false) // non-escaping
fn.currentBlock = body }
if tv != nil {
vl = b.addr(fn, s.Value, false) // non-escaping
}
if tk != nil {
kl.store(fn, k)
}
if tv != nil {
vl.store(fn, v)
}
if label != nil { if label != nil {
label._break = done label._break = done
label._continue = loop label._continue = loop
} }
if arrayLen == -1 {
// Evaluate both LHS expressions before we update either.
var k, v lvalue
if hasK {
k = b.addr(fn, s.Key, false) // non-escaping
okvVars[1].Type = b.exprType(s.Key)
}
if hasV {
v = b.addr(fn, s.Value, false) // non-escaping
okvVars[2].Type = b.exprType(s.Value)
}
if hasK {
k.store(fn, emitExtract(fn, okv, 1, okvVars[1].Type))
}
if hasV {
v.store(fn, emitExtract(fn, okv, 2, okvVars[2].Type))
}
} else {
// Store a copy of the index variable to k.
if hasK {
k := b.addr(fn, s.Key, false) // non-escaping
k.store(fn, emitLoad(fn, index))
}
}
fn.targets = &targets{ fn.targets = &targets{
tail: fn.targets, tail: fn.targets,
_break: done, _break: done,
......
...@@ -1009,8 +1009,6 @@ func callBuiltin(caller *frame, callpos token.Pos, fn *ssa.Builtin, args []value ...@@ -1009,8 +1009,6 @@ func callBuiltin(caller *frame, callpos token.Pos, fn *ssa.Builtin, args []value
func rangeIter(x value, t types.Type) iter { func rangeIter(x value, t types.Type) iter {
switch x := x.(type) { switch x := x.(type) {
case nil:
panic("range of nil")
case map[value]value: case map[value]value:
// TODO(adonovan): fix: leaks goroutines and channels // TODO(adonovan): fix: leaks goroutines and channels
// on each incomplete map iteration. We need to open // on each incomplete map iteration. We need to open
...@@ -1040,16 +1038,8 @@ func rangeIter(x value, t types.Type) iter { ...@@ -1040,16 +1038,8 @@ func rangeIter(x value, t types.Type) iter {
close(it) close(it)
}() }()
return it return it
case *value: // non-nil *array
return &arrayIter{a: (*x).(array)}
case array:
return &arrayIter{a: x}
case []value:
return &arrayIter{a: array(x)}
case string: case string:
return &stringIter{Reader: strings.NewReader(x)} return &stringIter{Reader: strings.NewReader(x)}
case chan value:
return chanIter(x)
} }
panic(fmt.Sprintf("cannot range over %T", x)) panic(fmt.Sprintf("cannot range over %T", x))
} }
......
...@@ -20,7 +20,7 @@ package interp ...@@ -20,7 +20,7 @@ package interp
// *ssa.Builtin } --- functions. // *ssa.Builtin } --- functions.
// *closure / // *closure /
// - tuple --- as returned by Ret, Next, "value,ok" modes, etc. // - tuple --- as returned by Ret, Next, "value,ok" modes, etc.
// - iter --- iterators from 'range'. // - iter --- iterators from 'range' over map or string.
// - bad --- a poison pill for locals that have gone out of scope. // - bad --- a poison pill for locals that have gone out of scope.
// - rtype -- the interpreter's concrete implementation of reflect.Type // - rtype -- the interpreter's concrete implementation of reflect.Type
// //
...@@ -441,31 +441,6 @@ func toString(v value) string { ...@@ -441,31 +441,6 @@ func toString(v value) string {
// ------------------------------------------------------------------------ // ------------------------------------------------------------------------
// Iterators // Iterators
type arrayIter struct {
a array
i int
}
func (it *arrayIter) next() tuple {
okv := make(tuple, 3)
ok := it.i < len(it.a)
okv[0] = ok
if ok {
okv[1] = it.i
okv[2] = copyVal(it.a[it.i])
}
it.i++
return okv
}
type chanIter chan value
func (it chanIter) next() tuple {
okv := make(tuple, 3)
okv[1], okv[0] = <-it
return okv
}
type stringIter struct { type stringIter struct {
*strings.Reader *strings.Reader
i int i int
......
...@@ -675,7 +675,9 @@ type Select struct { ...@@ -675,7 +675,9 @@ type Select struct {
Blocking bool Blocking bool
} }
// Range yields an iterator over the domain and range of X. // Range yields an iterator over the domain and range of X,
// which must be a string or map.
//
// Elements are accessed via Next. // Elements are accessed via Next.
// //
// Type() returns a *types.Result (tuple type). // Type() returns a *types.Result (tuple type).
...@@ -685,7 +687,7 @@ type Select struct { ...@@ -685,7 +687,7 @@ type Select struct {
// //
type Range struct { type Range struct {
Register Register
X Value // array, *array, slice, string, map or chan X Value // string or map
} }
// Next reads and advances the iterator Iter and returns a 3-tuple // Next reads and advances the iterator Iter and returns a 3-tuple
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment