Commit 726190da authored by Josh Bleecher Snyder's avatar Josh Bleecher Snyder

cmd/compile: explicitly manage default and nil switch cases

Rather than juggle default and nil cases as part
of a slice, handle them explicitly.

Change-Id: I97b200c9d3f23fe1a438acdbf3d13b0cf7e0851e
Reviewed-on: https://go-review.googlesource.com/26761
TryBot-Result: Gobot Gobot <gobot@golang.org>
Run-TryBot: Josh Bleecher Snyder <josharian@gmail.com>
Reviewed-by: 's avatarMatthew Dempsky <mdempsky@google.com>
parent cf20525b
...@@ -17,14 +17,11 @@ const ( ...@@ -17,14 +17,11 @@ const (
) )
const ( const (
caseKindDefault = iota // default:
// expression switch // expression switch
caseKindExprConst // case 5: caseKindExprConst = iota // case 5:
caseKindExprVar // case x: caseKindExprVar // case x:
// type switch // type switch
caseKindTypeNil // case nil:
caseKindTypeConst // case time.Time: (concrete type, has type hash) caseKindTypeConst // case time.Time: (concrete type, has type hash)
caseKindTypeVar // case io.Reader: (interface type) caseKindTypeVar // case io.Reader: (interface type)
) )
...@@ -52,6 +49,13 @@ type caseClause struct { ...@@ -52,6 +49,13 @@ type caseClause struct {
typ uint8 // type of case typ uint8 // type of case
} }
// caseClauses are all the case clauses in a switch statement.
type caseClauses struct {
list []caseClause // general cases
defjmp *Node // OGOTO for default case or OBREAK if no default case present
niljmp *Node // OGOTO for nil type case in a type switch
}
// typecheckswitch typechecks a switch statement. // typecheckswitch typechecks a switch statement.
func typecheckswitch(n *Node) { func typecheckswitch(n *Node) {
lno := lineno lno := lineno
...@@ -248,16 +252,10 @@ func (s *exprSwitch) walk(sw *Node) { ...@@ -248,16 +252,10 @@ func (s *exprSwitch) walk(sw *Node) {
typecheckslice(cas, Etop) typecheckslice(cas, Etop)
} }
// enumerate the cases, and lop off the default case // Enumerate the cases and prepare the default case.
cc := caseClauses(sw, s.kind) clauses := genCaseClauses(sw, s.kind)
sw.List.Set(nil) sw.List.Set(nil)
var def *Node cc := clauses.list
if len(cc) > 0 && cc[0].typ == caseKindDefault {
def = cc[0].node.Right
cc = cc[1:]
} else {
def = Nod(OBREAK, nil, nil)
}
// handle the cases in order // handle the cases in order
for len(cc) > 0 { for len(cc) > 0 {
...@@ -283,14 +281,14 @@ func (s *exprSwitch) walk(sw *Node) { ...@@ -283,14 +281,14 @@ func (s *exprSwitch) walk(sw *Node) {
// handle default case // handle default case
if nerrors == 0 { if nerrors == 0 {
cas = append(cas, def) cas = append(cas, clauses.defjmp)
sw.Nbody.Set(append(cas, sw.Nbody.Slice()...)) sw.Nbody.Set(append(cas, sw.Nbody.Slice()...))
walkstmtlist(sw.Nbody.Slice()) walkstmtlist(sw.Nbody.Slice())
} }
} }
// walkCases generates an AST implementing the cases in cc. // walkCases generates an AST implementing the cases in cc.
func (s *exprSwitch) walkCases(cc []*caseClause) *Node { func (s *exprSwitch) walkCases(cc []caseClause) *Node {
if len(cc) < binarySearchMin { if len(cc) < binarySearchMin {
// linear search // linear search
var cas []*Node var cas []*Node
...@@ -422,27 +420,35 @@ func casebody(sw *Node, typeswvar *Node) { ...@@ -422,27 +420,35 @@ func casebody(sw *Node, typeswvar *Node) {
lineno = lno lineno = lno
} }
// caseClauses generates a slice of caseClauses // genCaseClauses generates the caseClauses value
// corresponding to the clauses in the switch statement sw. // corresponding to the clauses in the switch statement sw.
// Kind is the kind of switch statement. // Kind is the kind of switch statement.
func caseClauses(sw *Node, kind int) []*caseClause { func genCaseClauses(sw *Node, kind int) caseClauses {
var cc []*caseClause var cc caseClauses
for _, n := range sw.List.Slice() { for _, n := range sw.List.Slice() {
c := new(caseClause)
cc = append(cc, c)
c.ordinal = len(cc)
c.node = n
if n.Left == nil { if n.Left == nil {
c.typ = caseKindDefault // default case
if cc.defjmp != nil {
Fatalf("duplicate default case not detected during typechecking")
}
cc.defjmp = n.Right
continue continue
} }
if kind == switchKindType && n.Left.Op == OLITERAL {
// nil case in type switch
if cc.niljmp != nil {
Fatalf("duplicate nil case not detected during typechecking")
}
cc.niljmp = n.Right
continue
}
// general case
c := caseClause{node: n, ordinal: len(cc.list)}
if kind == switchKindType { if kind == switchKindType {
// type switch // type switch
switch { switch {
case n.Left.Op == OLITERAL:
c.typ = caseKindTypeNil
case n.Left.Type.IsInterface(): case n.Left.Type.IsInterface():
c.typ = caseKindTypeVar c.typ = caseKindTypeVar
default: default:
...@@ -458,22 +464,24 @@ func caseClauses(sw *Node, kind int) []*caseClause { ...@@ -458,22 +464,24 @@ func caseClauses(sw *Node, kind int) []*caseClause {
c.typ = caseKindExprVar c.typ = caseKindExprVar
} }
} }
cc.list = append(cc.list, c)
} }
if cc == nil { if cc.defjmp == nil {
return nil cc.defjmp = Nod(OBREAK, nil, nil)
}
if cc.list == nil {
return cc
} }
// sort by value and diagnose duplicate cases // sort by value and diagnose duplicate cases
if kind == switchKindType { if kind == switchKindType {
// type switch // type switch
sort.Sort(caseClauseByType(cc)) sort.Sort(caseClauseByType(cc.list))
for i, c1 := range cc { for i, c1 := range cc.list {
if c1.typ == caseKindTypeNil || c1.typ == caseKindDefault { for _, c2 := range cc.list[i+1:] {
break if c1.hash != c2.hash {
}
for _, c2 := range cc[i+1:] {
if c2.typ == caseKindTypeNil || c2.typ == caseKindDefault || c1.hash != c2.hash {
break break
} }
if Eqtype(c1.node.Left.Type, c2.node.Left.Type) { if Eqtype(c1.node.Left.Type, c2.node.Left.Type) {
...@@ -483,12 +491,12 @@ func caseClauses(sw *Node, kind int) []*caseClause { ...@@ -483,12 +491,12 @@ func caseClauses(sw *Node, kind int) []*caseClause {
} }
} else { } else {
// expression switch // expression switch
sort.Sort(caseClauseByExpr(cc)) sort.Sort(caseClauseByExpr(cc.list))
for i, c1 := range cc { for i, c1 := range cc.list {
if i+1 == len(cc) { if i+1 == len(cc.list) {
break break
} }
c2 := cc[i+1] c2 := cc.list[i+1]
if exprcmp(c1, c2) != 0 { if exprcmp(c1, c2) != 0 {
continue continue
} }
...@@ -498,7 +506,7 @@ func caseClauses(sw *Node, kind int) []*caseClause { ...@@ -498,7 +506,7 @@ func caseClauses(sw *Node, kind int) []*caseClause {
} }
// put list back in processing order // put list back in processing order
sort.Sort(caseClauseByOrd(cc)) sort.Sort(caseClauseByOrd(cc.list))
return cc return cc
} }
...@@ -545,20 +553,9 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -545,20 +553,9 @@ func (s *typeSwitch) walk(sw *Node) {
// set up labels and jumps // set up labels and jumps
casebody(sw, s.facename) casebody(sw, s.facename)
cc := caseClauses(sw, switchKindType) clauses := genCaseClauses(sw, switchKindType)
sw.List.Set(nil) sw.List.Set(nil)
var def *Node def := clauses.defjmp
if len(cc) > 0 && cc[0].typ == caseKindDefault {
def = cc[0].node.Right
cc = cc[1:]
} else {
def = Nod(OBREAK, nil, nil)
}
var typenil *Node
if len(cc) > 0 && cc[0].typ == caseKindTypeNil {
typenil = cc[0].node.Right
cc = cc[1:]
}
// For empty interfaces, do: // For empty interfaces, do:
// if e._type == nil { // if e._type == nil {
...@@ -573,9 +570,9 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -573,9 +570,9 @@ func (s *typeSwitch) walk(sw *Node) {
// Check for nil first. // Check for nil first.
i := Nod(OIF, nil, nil) i := Nod(OIF, nil, nil)
i.Left = Nod(OEQ, typ, nodnil()) i.Left = Nod(OEQ, typ, nodnil())
if typenil != nil { if clauses.niljmp != nil {
// Do explicit nil case right here. // Do explicit nil case right here.
i.Nbody.Set1(typenil) i.Nbody.Set1(clauses.niljmp)
} else { } else {
// Jump to default case. // Jump to default case.
lbl := autolabel(".s") lbl := autolabel(".s")
...@@ -602,6 +599,8 @@ func (s *typeSwitch) walk(sw *Node) { ...@@ -602,6 +599,8 @@ func (s *typeSwitch) walk(sw *Node) {
a = typecheck(a, Etop) a = typecheck(a, Etop)
cas = append(cas, a) cas = append(cas, a)
cc := clauses.list
// insert type equality check into each case block // insert type equality check into each case block
for _, c := range cc { for _, c := range cc {
n := c.node n := c.node
...@@ -696,7 +695,7 @@ func (s *typeSwitch) typeone(t *Node) *Node { ...@@ -696,7 +695,7 @@ func (s *typeSwitch) typeone(t *Node) *Node {
} }
// walkCases generates an AST implementing the cases in cc. // walkCases generates an AST implementing the cases in cc.
func (s *typeSwitch) walkCases(cc []*caseClause) *Node { func (s *typeSwitch) walkCases(cc []caseClause) *Node {
if len(cc) < binarySearchMin { if len(cc) < binarySearchMin {
var cas []*Node var cas []*Node
for _, c := range cc { for _, c := range cc {
...@@ -723,31 +722,13 @@ func (s *typeSwitch) walkCases(cc []*caseClause) *Node { ...@@ -723,31 +722,13 @@ func (s *typeSwitch) walkCases(cc []*caseClause) *Node {
return a return a
} }
type caseClauseByOrd []*caseClause type caseClauseByOrd []caseClause
func (x caseClauseByOrd) Len() int { return len(x) }
func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x caseClauseByOrd) Less(i, j int) bool {
c1, c2 := x[i], x[j]
switch {
// sort default first
case c1.typ == caseKindDefault:
return true
case c2.typ == caseKindDefault:
return false
// sort nil second
case c1.typ == caseKindTypeNil:
return true
case c2.typ == caseKindTypeNil:
return false
}
// sort by ordinal func (x caseClauseByOrd) Len() int { return len(x) }
return c1.ordinal < c2.ordinal func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
} func (x caseClauseByOrd) Less(i, j int) bool { return x[i].ordinal < x[j].ordinal }
type caseClauseByExpr []*caseClause type caseClauseByExpr []caseClause
func (x caseClauseByExpr) Len() int { return len(x) } func (x caseClauseByExpr) Len() int { return len(x) }
func (x caseClauseByExpr) Swap(i, j int) { x[i], x[j] = x[j], x[i] } func (x caseClauseByExpr) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
...@@ -755,7 +736,7 @@ func (x caseClauseByExpr) Less(i, j int) bool { ...@@ -755,7 +736,7 @@ func (x caseClauseByExpr) Less(i, j int) bool {
return exprcmp(x[i], x[j]) < 0 return exprcmp(x[i], x[j]) < 0
} }
func exprcmp(c1, c2 *caseClause) int { func exprcmp(c1, c2 caseClause) int {
// sort non-constants last // sort non-constants last
if c1.typ != caseKindExprConst { if c1.typ != caseKindExprConst {
return +1 return +1
...@@ -814,7 +795,7 @@ func exprcmp(c1, c2 *caseClause) int { ...@@ -814,7 +795,7 @@ func exprcmp(c1, c2 *caseClause) int {
return 0 return 0
} }
type caseClauseByType []*caseClause type caseClauseByType []caseClause
func (x caseClauseByType) Len() int { return len(x) } func (x caseClauseByType) Len() int { return len(x) }
func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
......
...@@ -134,7 +134,7 @@ func TestExprcmp(t *testing.T) { ...@@ -134,7 +134,7 @@ func TestExprcmp(t *testing.T) {
}, },
} }
for i, d := range testdata { for i, d := range testdata {
got := exprcmp(&d.a, &d.b) got := exprcmp(d.a, d.b)
if d.want != got { if d.want != got {
t.Errorf("%d: exprcmp(a, b) = %d; want %d", i, got, d.want) t.Errorf("%d: exprcmp(a, b) = %d; want %d", i, got, d.want)
t.Logf("\ta = caseClause{node: %#v, typ: %#v}", d.a.node, d.a.typ) t.Logf("\ta = caseClause{node: %#v, typ: %#v}", d.a.node, d.a.typ)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment