Commit 7ea8cdaa authored by Robert Griesemer's avatar Robert Griesemer

go/ast: merge CaseClause and TypeCaseClause

(per rsc's suggestion)

R=rsc
CC=golang-dev
https://golang.org/cl/4276057
parent de811cc0
......@@ -325,26 +325,28 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{}
f.walk(n.Results, "expr", visit)
case *ast.BranchStmt:
case *ast.BlockStmt:
f.walk(n.List, "stmt", visit)
f.walk(n.List, context, visit)
case *ast.IfStmt:
f.walk(n.Init, "stmt", visit)
f.walk(&n.Cond, "expr", visit)
f.walk(n.Body, "stmt", visit)
f.walk(n.Else, "stmt", visit)
case *ast.CaseClause:
f.walk(n.Values, "expr", visit)
if context == "typeswitch" {
context = "type"
} else {
context = "expr"
}
f.walk(n.List, context, visit)
f.walk(n.Body, "stmt", visit)
case *ast.SwitchStmt:
f.walk(n.Init, "stmt", visit)
f.walk(&n.Tag, "expr", visit)
f.walk(n.Body, "stmt", visit)
case *ast.TypeCaseClause:
f.walk(n.Types, "type", visit)
f.walk(n.Body, "stmt", visit)
f.walk(n.Body, "switch", visit)
case *ast.TypeSwitchStmt:
f.walk(n.Init, "stmt", visit)
f.walk(n.Assign, "stmt", visit)
f.walk(n.Body, "stmt", visit)
f.walk(n.Body, "typeswitch", visit)
case *ast.CommClause:
f.walk(n.Comm, "stmt", visit)
f.walk(n.Body, "stmt", visit)
......
......@@ -145,15 +145,12 @@ func rewrite(x interface{}, visit func(interface{})) {
rewrite(n.Body, visit)
rewrite(n.Else, visit)
case *ast.CaseClause:
rewrite(n.Values, visit)
rewrite(n.List, visit)
rewrite(n.Body, visit)
case *ast.SwitchStmt:
rewrite(n.Init, visit)
rewrite(&n.Tag, visit)
rewrite(n.Body, visit)
case *ast.TypeCaseClause:
rewrite(n.Types, visit)
rewrite(n.Body, visit)
case *ast.TypeSwitchStmt:
rewrite(n.Init, visit)
rewrite(n.Assign, visit)
......
......@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) {
case *ast.SwitchStmt:
a.compileSwitchStmt(s)
case *ast.TypeCaseClause:
notimpl = true
case *ast.TypeSwitchStmt:
notimpl = true
......@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
a.diagAt(clause.Pos(), "switch statement must contain case clauses")
continue
}
if clause.Values == nil {
if clause.List == nil {
if hasDefault {
a.diagAt(clause.Pos(), "switch statement contains more than one default case")
}
hasDefault = true
} else {
ncases += len(clause.Values)
ncases += len(clause.List)
}
}
......@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
if !ok {
continue
}
for _, v := range clause.Values {
for _, v := range clause.List {
e := condbc.compileExpr(condbc.block, false, v)
switch {
case e == nil:
......@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
// Save jump PC's
pc := a.nextPC()
if clause.Values != nil {
for _ = range clause.Values {
if clause.List != nil {
for _ = range clause.List {
casePCs[i] = &pc
i++
}
......
......@@ -602,12 +602,12 @@ type (
Else Stmt // else branch; or nil
}
// A CaseClause represents a case of an expression switch statement.
// A CaseClause represents a case of an expression or type switch statement.
CaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Values []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
Case token.Pos // position of "case" or "default" keyword
List []Expr // list of expressions or types; nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// A SwitchStmt node represents an expression switch statement.
......@@ -618,20 +618,12 @@ type (
Body *BlockStmt // CaseClauses only
}
// A TypeCaseClause represents a case of a type switch statement.
TypeCaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Types []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// An TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword
Init Stmt // initalization statement; or nil
Assign Stmt // x := y.(type)
Body *BlockStmt // TypeCaseClauses only
Assign Stmt // x := y.(type) or y.(type)
Body *BlockStmt // CaseClauses only
}
// A CommClause node represents a case of a select statement.
......@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace }
func (s *IfStmt) Pos() token.Pos { return s.If }
func (s *CaseClause) Pos() token.Pos { return s.Case }
func (s *SwitchStmt) Pos() token.Pos { return s.Switch }
func (s *TypeCaseClause) Pos() token.Pos { return s.Case }
func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
func (s *CommClause) Pos() token.Pos { return s.Case }
func (s *SelectStmt) Pos() token.Pos { return s.Select }
......@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos {
}
return s.Colon + 1
}
func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeCaseClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
func (s *CommClause) End() token.Pos {
if n := len(s.Body); n > 0 {
......@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {}
func (s *TypeCaseClause) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {}
......
......@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) {
}
case *CaseClause:
walkExprList(v, n.Values)
walkExprList(v, n.List)
walkStmtList(v, n.Body)
case *SwitchStmt:
......@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) {
}
Walk(v, n.Body)
case *TypeCaseClause:
for _, x := range n.Types {
Walk(v, x)
}
walkStmtList(v, n.Body)
case *TypeSwitchStmt:
if n.Init != nil {
Walk(v, n.Init)
......
......@@ -1518,29 +1518,6 @@ func (p *parser) parseIfStmt() *ast.IfStmt {
}
func (p *parser) parseCaseClause() *ast.CaseClause {
if p.trace {
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var x []ast.Expr
if p.tok == token.CASE {
p.next()
x = p.parseExprList()
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
p.openScope()
body := p.parseStmtList()
p.closeScope()
return &ast.CaseClause{pos, x, colon, body}
}
func (p *parser) parseTypeList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "TypeList"))
......@@ -1556,16 +1533,20 @@ func (p *parser) parseTypeList() (list []ast.Expr) {
}
func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause {
func (p *parser) parseCaseClause(exprSwitch bool) *ast.CaseClause {
if p.trace {
defer un(trace(p, "TypeCaseClause"))
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var types []ast.Expr
var list []ast.Expr
if p.tok == token.CASE {
p.next()
types = p.parseTypeList()
if exprSwitch {
list = p.parseExprList()
} else {
list = p.parseTypeList()
}
} else {
p.expect(token.DEFAULT)
}
......@@ -1575,7 +1556,7 @@ func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause {
body := p.parseStmtList()
p.closeScope()
return &ast.TypeCaseClause{pos, types, colon, body}
return &ast.CaseClause{pos, list, colon, body}
}
......@@ -1620,28 +1601,21 @@ func (p *parser) parseSwitchStmt() ast.Stmt {
p.exprLev = prevLev
}
if isExprSwitch(s2) {
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCaseClause())
}
rbrace := p.expect(token.RBRACE)
body := &ast.BlockStmt{lbrace, list, rbrace}
p.expectSemi()
return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
}
// type switch
// TODO(gri): do all the checks!
exprSwitch := isExprSwitch(s2)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseTypeCaseClause())
list = append(list, p.parseCaseClause(exprSwitch))
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{lbrace, list, rbrace}
if exprSwitch {
return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
}
// type switch
// TODO(gri): do all the checks!
return &ast.TypeSwitchStmt{pos, s1, s2, body}
}
......
......@@ -1148,9 +1148,9 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
}
case *ast.CaseClause:
if s.Values != nil {
if s.List != nil {
p.print(token.CASE)
p.exprList(s.Pos(), s.Values, 1, blankStart|commaSep, multiLine, s.Colon)
p.exprList(s.Pos(), s.List, 1, blankStart|commaSep, multiLine, s.Colon)
} else {
p.print(token.DEFAULT)
}
......@@ -1163,16 +1163,6 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
p.block(s.Body, 0)
*multiLine = true
case *ast.TypeCaseClause:
if s.Types != nil {
p.print(token.CASE)
p.exprList(s.Pos(), s.Types, 1, blankStart|commaSep, multiLine, s.Colon)
} else {
p.print(token.DEFAULT)
}
p.print(s.Colon, token.COLON)
p.stmtList(s.Body, 1, nextIsRBrace)
case *ast.TypeSwitchStmt:
p.print(token.SWITCH)
if s.Init != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment