Commit 89313063 authored by Russ Cox's avatar Russ Cox

cmd/gc: reject non-Go constants

Expressions involving nil, even if they can be evaluated
at compile time, do not count as Go constants and cannot
be used in const initializers.

Fixes #4673.
Fixes #4680.

R=ken2
CC=golang-dev
https://golang.org/cl/7278043
parent f607c479
...@@ -78,6 +78,7 @@ convlit1(Node **np, Type *t, int explicit) ...@@ -78,6 +78,7 @@ convlit1(Node **np, Type *t, int explicit)
if(!explicit && !isideal(n->type)) if(!explicit && !isideal(n->type))
return; return;
if(n->op == OLITERAL) { if(n->op == OLITERAL) {
nn = nod(OXXX, N, N); nn = nod(OXXX, N, N);
*nn = *n; *nn = *n;
...@@ -953,10 +954,6 @@ ret: ...@@ -953,10 +954,6 @@ ret:
*n = *nl; *n = *nl;
// restore value of n->orig. // restore value of n->orig.
n->orig = norig; n->orig = norig;
if(norig->op == OCONV) {
dump("N", n);
dump("NORIG", norig);
}
n->val = v; n->val = v;
// check range. // check range.
...@@ -1449,3 +1446,132 @@ cmplxdiv(Mpcplx *v, Mpcplx *rv) ...@@ -1449,3 +1446,132 @@ cmplxdiv(Mpcplx *v, Mpcplx *rv)
mpsubfltflt(&v->imag, &ad); // bc-ad mpsubfltflt(&v->imag, &ad); // bc-ad
mpdivfltflt(&v->imag, &cc_plus_dd); // (bc+ad)/(cc+dd) mpdivfltflt(&v->imag, &cc_plus_dd); // (bc+ad)/(cc+dd)
} }
static int hascallchan(Node*);
// Is n a Go language constant (as opposed to a compile-time constant)?
// Expressions derived from nil, like string([]byte(nil)), while they
// may be known at compile time, are not Go language constants.
// Only called for expressions known to evaluated to compile-time
// constants.
int
isgoconst(Node *n)
{
Node *l;
Type *t;
if(n->orig != N)
n = n->orig;
switch(n->op) {
case OADD:
case OADDSTR:
case OAND:
case OANDAND:
case OANDNOT:
case OCOM:
case ODIV:
case OEQ:
case OGE:
case OGT:
case OLE:
case OLSH:
case OLT:
case OMINUS:
case OMOD:
case OMUL:
case ONE:
case ONOT:
case OOR:
case OOROR:
case OPLUS:
case ORSH:
case OSUB:
case OXOR:
case OCONV:
case OIOTA:
case OCOMPLEX:
case OREAL:
case OIMAG:
if(isgoconst(n->left) && (n->right == N || isgoconst(n->right)))
return 1;
break;
case OLEN:
case OCAP:
l = n->left;
if(isgoconst(l))
return 1;
// Special case: len/cap is constant when applied to array or
// pointer to array when the expression does not contain
// function calls or channel receive operations.
t = l->type;
if(t != T && isptr[t->etype])
t = t->type;
if(isfixedarray(t) && !hascallchan(l))
return 1;
break;
case OLITERAL:
if(n->val.ctype != CTNIL)
return 1;
break;
case ONAME:
l = n->sym->def;
if(l->op == OLITERAL && n->val.ctype != CTNIL)
return 1;
break;
case ONONAME:
if(n->sym->def != N && n->sym->def->op == OIOTA)
return 1;
break;
case OCALL:
// Only constant calls are unsafe.Alignof, Offsetof, and Sizeof.
l = n->left;
while(l->op == OPAREN)
l = l->left;
if(l->op != ONAME || l->sym->pkg != unsafepkg)
break;
if(strcmp(l->sym->name, "Alignof") == 0 ||
strcmp(l->sym->name, "Offsetof") == 0 ||
strcmp(l->sym->name, "Sizeof") == 0)
return 1;
break;
}
//dump("nonconst", n);
return 0;
}
static int
hascallchan(Node *n)
{
NodeList *l;
if(n == N)
return 0;
switch(n->op) {
case OCALL:
case OCALLFUNC:
case OCALLMETH:
case OCALLINTER:
case ORECV:
return 1;
}
if(hascallchan(n->left) ||
hascallchan(n->right))
return 1;
for(l=n->list; l; l=l->next)
if(hascallchan(l->n))
return 1;
for(l=n->rlist; l; l=l->next)
if(hascallchan(l->n))
return 1;
return 0;
}
...@@ -997,6 +997,7 @@ void defaultlit(Node **np, Type *t); ...@@ -997,6 +997,7 @@ void defaultlit(Node **np, Type *t);
void defaultlit2(Node **lp, Node **rp, int force); void defaultlit2(Node **lp, Node **rp, int force);
void evconst(Node *n); void evconst(Node *n);
int isconst(Node *n, int ct); int isconst(Node *n, int ct);
int isgoconst(Node *n);
Node* nodcplxlit(Val r, Val i); Node* nodcplxlit(Val r, Val i);
Node* nodlit(Val v); Node* nodlit(Val v);
long nonnegconst(Node *n); long nonnegconst(Node *n);
......
...@@ -840,6 +840,7 @@ treecopy(Node *n) ...@@ -840,6 +840,7 @@ treecopy(Node *n)
default: default:
m = nod(OXXX, N, N); m = nod(OXXX, N, N);
*m = *n; *m = *n;
m->orig = m;
m->left = treecopy(n->left); m->left = treecopy(n->left);
m->right = treecopy(n->right); m->right = treecopy(n->right);
m->list = listtreecopy(n->list); m->list = listtreecopy(n->list);
......
...@@ -1336,6 +1336,9 @@ reswitch: ...@@ -1336,6 +1336,9 @@ reswitch:
case OCONV: case OCONV:
doconv: doconv:
ok |= Erv; ok |= Erv;
l = nod(OXXX, N, N);
n->orig = l;
*l = *n;
typecheck(&n->left, Erv | (top & (Eindir | Eiota))); typecheck(&n->left, Erv | (top & (Eindir | Eiota)));
convlit1(&n->left, n->type, 1); convlit1(&n->left, n->type, 1);
if((t = n->left->type) == T || n->type == T) if((t = n->left->type) == T || n->type == T)
...@@ -3007,14 +3010,14 @@ typecheckdef(Node *n) ...@@ -3007,14 +3010,14 @@ typecheckdef(Node *n)
yyerror("xxx"); yyerror("xxx");
} }
typecheck(&e, Erv | Eiota); typecheck(&e, Erv | Eiota);
if(e->type != T && e->op != OLITERAL) {
yyerror("const initializer must be constant");
goto ret;
}
if(isconst(e, CTNIL)) { if(isconst(e, CTNIL)) {
yyerror("const initializer cannot be nil"); yyerror("const initializer cannot be nil");
goto ret; goto ret;
} }
if(e->type != T && e->op != OLITERAL || !isgoconst(e)) {
yyerror("const initializer %N is not a constant", e);
goto ret;
}
t = n->type; t = n->type;
if(t != T) { if(t != T) {
if(!okforconst[t->etype]) { if(!okforconst[t->etype]) {
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
package main package main
import "unsafe"
type I interface{} type I interface{}
const ( const (
...@@ -86,3 +88,7 @@ func main() { ...@@ -86,3 +88,7 @@ func main() {
} }
const ptr = nil // ERROR "const.*nil" const ptr = nil // ERROR "const.*nil"
const _ = string([]byte(nil)) // ERROR "is not a constant"
const _ = uintptr(unsafe.Pointer((*int)(nil))) // ERROR "is not a constant"
const _ = unsafe.Pointer((*int)(nil)) // ERROR "cannot be nil"
const _ = (*int)(nil) // ERROR "cannot be nil"
...@@ -24,10 +24,10 @@ const ( ...@@ -24,10 +24,10 @@ const (
n2 = len(m[""]) n2 = len(m[""])
n3 = len(s[10]) n3 = len(s[10])
n4 = len(f()) // ERROR "must be constant|is not constant" n4 = len(f()) // ERROR "is not a constant|is not constant"
n5 = len(<-c) // ERROR "must be constant|is not constant" n5 = len(<-c) // ERROR "is not a constant|is not constant"
n6 = cap(f()) // ERROR "must be constant|is not constant" n6 = cap(f()) // ERROR "is not a constant|is not constant"
n7 = cap(<-c) // ERROR "must be constant|is not constant" n7 = cap(<-c) // ERROR "is not a constant|is not constant"
) )
...@@ -11,5 +11,5 @@ package main ...@@ -11,5 +11,5 @@ package main
type ByteSize float64 type ByteSize float64
const ( const (
_ = iota; // ignore first value by assigning to blank identifier _ = iota; // ignore first value by assigning to blank identifier
KB ByteSize = 1<<(10*X) // ERROR "undefined" "as type ByteSize" KB ByteSize = 1<<(10*X) // ERROR "undefined" "is not a constant|as type ByteSize"
) )
...@@ -7,5 +7,5 @@ ...@@ -7,5 +7,5 @@
package foo package foo
var s [][10]int var s [][10]int
const m = len(s[len(s)-1]) // ERROR "must be constant" const m = len(s[len(s)-1]) // ERROR "is not a constant"
...@@ -48,7 +48,7 @@ func f() { ...@@ -48,7 +48,7 @@ func f() {
defer recover() // ok defer recover() // ok
int(0) // ERROR "int\(0\) evaluated but not used" int(0) // ERROR "int\(0\) evaluated but not used"
string([]byte("abc")) // ERROR "string\(\[\]byte literal\) evaluated but not used" string([]byte("abc")) // ERROR "string\(.*\) evaluated but not used"
append(x, 1) // ERROR "not used" append(x, 1) // ERROR "not used"
cap(x) // ERROR "not used" cap(x) // ERROR "not used"
......
...@@ -683,6 +683,7 @@ func (t *test) errorCheck(outStr string, fullshort ...string) (err error) { ...@@ -683,6 +683,7 @@ func (t *test) errorCheck(outStr string, fullshort ...string) (err error) {
continue continue
} }
matched := false matched := false
n := len(out)
for _, errmsg := range errmsgs { for _, errmsg := range errmsgs {
if we.re.MatchString(errmsg) { if we.re.MatchString(errmsg) {
matched = true matched = true
...@@ -691,7 +692,7 @@ func (t *test) errorCheck(outStr string, fullshort ...string) (err error) { ...@@ -691,7 +692,7 @@ func (t *test) errorCheck(outStr string, fullshort ...string) (err error) {
} }
} }
if !matched { if !matched {
errs = append(errs, fmt.Errorf("%s:%d: no match for %q in%s", we.file, we.lineNum, we.reStr, strings.Join(out, "\n"))) errs = append(errs, fmt.Errorf("%s:%d: no match for %#q in:\n\t%s", we.file, we.lineNum, we.reStr, strings.Join(out[n:], "\n\t")))
continue continue
} }
} }
...@@ -758,7 +759,7 @@ func (t *test) wantedErrors(file, short string) (errs []wantedError) { ...@@ -758,7 +759,7 @@ func (t *test) wantedErrors(file, short string) (errs []wantedError) {
all := m[1] all := m[1]
mm := errQuotesRx.FindAllStringSubmatch(all, -1) mm := errQuotesRx.FindAllStringSubmatch(all, -1)
if mm == nil { if mm == nil {
log.Fatalf("invalid errchk line in %s: %s", t.goFileName(), line) log.Fatalf("%s:%d: invalid errchk line: %s", t.goFileName(), lineNum, line)
} }
for _, m := range mm { for _, m := range mm {
rx := lineRx.ReplaceAllStringFunc(m[1], func(m string) string { rx := lineRx.ReplaceAllStringFunc(m[1], func(m string) string {
...@@ -772,10 +773,14 @@ func (t *test) wantedErrors(file, short string) (errs []wantedError) { ...@@ -772,10 +773,14 @@ func (t *test) wantedErrors(file, short string) (errs []wantedError) {
} }
return fmt.Sprintf("%s:%d", short, n) return fmt.Sprintf("%s:%d", short, n)
}) })
filterPattern := fmt.Sprintf(`^(\w+/)?%s:%d[:[]`, short, lineNum) re, err := regexp.Compile(rx)
if err != nil {
log.Fatalf("%s:%d: invalid regexp in ERROR line: %v", t.goFileName(), lineNum, err)
}
filterPattern := fmt.Sprintf(`^(\w+/)?%s:%d[:[]`, regexp.QuoteMeta(short), lineNum)
errs = append(errs, wantedError{ errs = append(errs, wantedError{
reStr: rx, reStr: rx,
re: regexp.MustCompile(rx), re: re,
filterRe: regexp.MustCompile(filterPattern), filterRe: regexp.MustCompile(filterPattern),
lineNum: lineNum, lineNum: lineNum,
file: short, file: short,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment