Commit 7ea8cdaa authored by Robert Griesemer's avatar Robert Griesemer

go/ast: merge CaseClause and TypeCaseClause

(per rsc's suggestion)

R=rsc
CC=golang-dev
https://golang.org/cl/4276057
parent de811cc0
...@@ -325,26 +325,28 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{} ...@@ -325,26 +325,28 @@ func (f *File) walk(x interface{}, context string, visit func(*File, interface{}
f.walk(n.Results, "expr", visit) f.walk(n.Results, "expr", visit)
case *ast.BranchStmt: case *ast.BranchStmt:
case *ast.BlockStmt: case *ast.BlockStmt:
f.walk(n.List, "stmt", visit) f.walk(n.List, context, visit)
case *ast.IfStmt: case *ast.IfStmt:
f.walk(n.Init, "stmt", visit) f.walk(n.Init, "stmt", visit)
f.walk(&n.Cond, "expr", visit) f.walk(&n.Cond, "expr", visit)
f.walk(n.Body, "stmt", visit) f.walk(n.Body, "stmt", visit)
f.walk(n.Else, "stmt", visit) f.walk(n.Else, "stmt", visit)
case *ast.CaseClause: case *ast.CaseClause:
f.walk(n.Values, "expr", visit) if context == "typeswitch" {
context = "type"
} else {
context = "expr"
}
f.walk(n.List, context, visit)
f.walk(n.Body, "stmt", visit) f.walk(n.Body, "stmt", visit)
case *ast.SwitchStmt: case *ast.SwitchStmt:
f.walk(n.Init, "stmt", visit) f.walk(n.Init, "stmt", visit)
f.walk(&n.Tag, "expr", visit) f.walk(&n.Tag, "expr", visit)
f.walk(n.Body, "stmt", visit) f.walk(n.Body, "switch", visit)
case *ast.TypeCaseClause:
f.walk(n.Types, "type", visit)
f.walk(n.Body, "stmt", visit)
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
f.walk(n.Init, "stmt", visit) f.walk(n.Init, "stmt", visit)
f.walk(n.Assign, "stmt", visit) f.walk(n.Assign, "stmt", visit)
f.walk(n.Body, "stmt", visit) f.walk(n.Body, "typeswitch", visit)
case *ast.CommClause: case *ast.CommClause:
f.walk(n.Comm, "stmt", visit) f.walk(n.Comm, "stmt", visit)
f.walk(n.Body, "stmt", visit) f.walk(n.Body, "stmt", visit)
......
...@@ -145,15 +145,12 @@ func rewrite(x interface{}, visit func(interface{})) { ...@@ -145,15 +145,12 @@ func rewrite(x interface{}, visit func(interface{})) {
rewrite(n.Body, visit) rewrite(n.Body, visit)
rewrite(n.Else, visit) rewrite(n.Else, visit)
case *ast.CaseClause: case *ast.CaseClause:
rewrite(n.Values, visit) rewrite(n.List, visit)
rewrite(n.Body, visit) rewrite(n.Body, visit)
case *ast.SwitchStmt: case *ast.SwitchStmt:
rewrite(n.Init, visit) rewrite(n.Init, visit)
rewrite(&n.Tag, visit) rewrite(&n.Tag, visit)
rewrite(n.Body, visit) rewrite(n.Body, visit)
case *ast.TypeCaseClause:
rewrite(n.Types, visit)
rewrite(n.Body, visit)
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
rewrite(n.Init, visit) rewrite(n.Init, visit)
rewrite(n.Assign, visit) rewrite(n.Assign, visit)
......
...@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) { ...@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) {
case *ast.SwitchStmt: case *ast.SwitchStmt:
a.compileSwitchStmt(s) a.compileSwitchStmt(s)
case *ast.TypeCaseClause:
notimpl = true
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
notimpl = true notimpl = true
...@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
a.diagAt(clause.Pos(), "switch statement must contain case clauses") a.diagAt(clause.Pos(), "switch statement must contain case clauses")
continue continue
} }
if clause.Values == nil { if clause.List == nil {
if hasDefault { if hasDefault {
a.diagAt(clause.Pos(), "switch statement contains more than one default case") a.diagAt(clause.Pos(), "switch statement contains more than one default case")
} }
hasDefault = true hasDefault = true
} else { } else {
ncases += len(clause.Values) ncases += len(clause.List)
} }
} }
...@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
if !ok { if !ok {
continue continue
} }
for _, v := range clause.Values { for _, v := range clause.List {
e := condbc.compileExpr(condbc.block, false, v) e := condbc.compileExpr(condbc.block, false, v)
switch { switch {
case e == nil: case e == nil:
...@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
// Save jump PC's // Save jump PC's
pc := a.nextPC() pc := a.nextPC()
if clause.Values != nil { if clause.List != nil {
for _ = range clause.Values { for _ = range clause.List {
casePCs[i] = &pc casePCs[i] = &pc
i++ i++
} }
......
...@@ -602,12 +602,12 @@ type ( ...@@ -602,12 +602,12 @@ type (
Else Stmt // else branch; or nil Else Stmt // else branch; or nil
} }
// A CaseClause represents a case of an expression switch statement. // A CaseClause represents a case of an expression or type switch statement.
CaseClause struct { CaseClause struct {
Case token.Pos // position of "case" or "default" keyword Case token.Pos // position of "case" or "default" keyword
Values []Expr // nil means default case List []Expr // list of expressions or types; nil means default case
Colon token.Pos // position of ":" Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil Body []Stmt // statement list; or nil
} }
// A SwitchStmt node represents an expression switch statement. // A SwitchStmt node represents an expression switch statement.
...@@ -618,20 +618,12 @@ type ( ...@@ -618,20 +618,12 @@ type (
Body *BlockStmt // CaseClauses only Body *BlockStmt // CaseClauses only
} }
// A TypeCaseClause represents a case of a type switch statement.
TypeCaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Types []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// An TypeSwitchStmt node represents a type switch statement. // An TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct { TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword Switch token.Pos // position of "switch" keyword
Init Stmt // initalization statement; or nil Init Stmt // initalization statement; or nil
Assign Stmt // x := y.(type) Assign Stmt // x := y.(type) or y.(type)
Body *BlockStmt // TypeCaseClauses only Body *BlockStmt // CaseClauses only
} }
// A CommClause node represents a case of a select statement. // A CommClause node represents a case of a select statement.
...@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace } ...@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace }
func (s *IfStmt) Pos() token.Pos { return s.If } func (s *IfStmt) Pos() token.Pos { return s.If }
func (s *CaseClause) Pos() token.Pos { return s.Case } func (s *CaseClause) Pos() token.Pos { return s.Case }
func (s *SwitchStmt) Pos() token.Pos { return s.Switch } func (s *SwitchStmt) Pos() token.Pos { return s.Switch }
func (s *TypeCaseClause) Pos() token.Pos { return s.Case }
func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch } func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
func (s *CommClause) Pos() token.Pos { return s.Case } func (s *CommClause) Pos() token.Pos { return s.Case }
func (s *SelectStmt) Pos() token.Pos { return s.Select } func (s *SelectStmt) Pos() token.Pos { return s.Select }
...@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos { ...@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos {
} }
return s.Colon + 1 return s.Colon + 1
} }
func (s *SwitchStmt) End() token.Pos { return s.Body.End() } func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeCaseClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() } func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
func (s *CommClause) End() token.Pos { func (s *CommClause) End() token.Pos {
if n := len(s.Body); n > 0 { if n := len(s.Body); n > 0 {
...@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {} ...@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {} func (s *IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {} func (s *CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {} func (s *SwitchStmt) stmtNode() {}
func (s *TypeCaseClause) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {} func (s *TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {} func (s *CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {} func (s *SelectStmt) stmtNode() {}
......
...@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) { ...@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) {
} }
case *CaseClause: case *CaseClause:
walkExprList(v, n.Values) walkExprList(v, n.List)
walkStmtList(v, n.Body) walkStmtList(v, n.Body)
case *SwitchStmt: case *SwitchStmt:
...@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) { ...@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) {
} }
Walk(v, n.Body) Walk(v, n.Body)
case *TypeCaseClause:
for _, x := range n.Types {
Walk(v, x)
}
walkStmtList(v, n.Body)
case *TypeSwitchStmt: case *TypeSwitchStmt:
if n.Init != nil { if n.Init != nil {
Walk(v, n.Init) Walk(v, n.Init)
......
...@@ -1518,29 +1518,6 @@ func (p *parser) parseIfStmt() *ast.IfStmt { ...@@ -1518,29 +1518,6 @@ func (p *parser) parseIfStmt() *ast.IfStmt {
} }
func (p *parser) parseCaseClause() *ast.CaseClause {
if p.trace {
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var x []ast.Expr
if p.tok == token.CASE {
p.next()
x = p.parseExprList()
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
p.openScope()
body := p.parseStmtList()
p.closeScope()
return &ast.CaseClause{pos, x, colon, body}
}
func (p *parser) parseTypeList() (list []ast.Expr) { func (p *parser) parseTypeList() (list []ast.Expr) {
if p.trace { if p.trace {
defer un(trace(p, "TypeList")) defer un(trace(p, "TypeList"))
...@@ -1556,16 +1533,20 @@ func (p *parser) parseTypeList() (list []ast.Expr) { ...@@ -1556,16 +1533,20 @@ func (p *parser) parseTypeList() (list []ast.Expr) {
} }
func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause { func (p *parser) parseCaseClause(exprSwitch bool) *ast.CaseClause {
if p.trace { if p.trace {
defer un(trace(p, "TypeCaseClause")) defer un(trace(p, "CaseClause"))
} }
pos := p.pos pos := p.pos
var types []ast.Expr var list []ast.Expr
if p.tok == token.CASE { if p.tok == token.CASE {
p.next() p.next()
types = p.parseTypeList() if exprSwitch {
list = p.parseExprList()
} else {
list = p.parseTypeList()
}
} else { } else {
p.expect(token.DEFAULT) p.expect(token.DEFAULT)
} }
...@@ -1575,7 +1556,7 @@ func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause { ...@@ -1575,7 +1556,7 @@ func (p *parser) parseTypeCaseClause() *ast.TypeCaseClause {
body := p.parseStmtList() body := p.parseStmtList()
p.closeScope() p.closeScope()
return &ast.TypeCaseClause{pos, types, colon, body} return &ast.CaseClause{pos, list, colon, body}
} }
...@@ -1620,28 +1601,21 @@ func (p *parser) parseSwitchStmt() ast.Stmt { ...@@ -1620,28 +1601,21 @@ func (p *parser) parseSwitchStmt() ast.Stmt {
p.exprLev = prevLev p.exprLev = prevLev
} }
if isExprSwitch(s2) { exprSwitch := isExprSwitch(s2)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCaseClause())
}
rbrace := p.expect(token.RBRACE)
body := &ast.BlockStmt{lbrace, list, rbrace}
p.expectSemi()
return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
}
// type switch
// TODO(gri): do all the checks!
lbrace := p.expect(token.LBRACE) lbrace := p.expect(token.LBRACE)
var list []ast.Stmt var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT { for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseTypeCaseClause()) list = append(list, p.parseCaseClause(exprSwitch))
} }
rbrace := p.expect(token.RBRACE) rbrace := p.expect(token.RBRACE)
p.expectSemi() p.expectSemi()
body := &ast.BlockStmt{lbrace, list, rbrace} body := &ast.BlockStmt{lbrace, list, rbrace}
if exprSwitch {
return &ast.SwitchStmt{pos, s1, p.makeExpr(s2), body}
}
// type switch
// TODO(gri): do all the checks!
return &ast.TypeSwitchStmt{pos, s1, s2, body} return &ast.TypeSwitchStmt{pos, s1, s2, body}
} }
......
...@@ -1148,9 +1148,9 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) { ...@@ -1148,9 +1148,9 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
} }
case *ast.CaseClause: case *ast.CaseClause:
if s.Values != nil { if s.List != nil {
p.print(token.CASE) p.print(token.CASE)
p.exprList(s.Pos(), s.Values, 1, blankStart|commaSep, multiLine, s.Colon) p.exprList(s.Pos(), s.List, 1, blankStart|commaSep, multiLine, s.Colon)
} else { } else {
p.print(token.DEFAULT) p.print(token.DEFAULT)
} }
...@@ -1163,16 +1163,6 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) { ...@@ -1163,16 +1163,6 @@ func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, multiLine *bool) {
p.block(s.Body, 0) p.block(s.Body, 0)
*multiLine = true *multiLine = true
case *ast.TypeCaseClause:
if s.Types != nil {
p.print(token.CASE)
p.exprList(s.Pos(), s.Types, 1, blankStart|commaSep, multiLine, s.Colon)
} else {
p.print(token.DEFAULT)
}
p.print(s.Colon, token.COLON)
p.stmtList(s.Body, 1, nextIsRBrace)
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
p.print(token.SWITCH) p.print(token.SWITCH)
if s.Init != nil { if s.Init != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment