Commit 6fb1bf26 authored by Roger Peppe's avatar Roger Peppe Committed by Rob Pike

netchan: do not block sends; implement flow control.

When data is received for a channel, but that channel
is not ready to receive it, the central run() loop
is currently blocked, but this can lead to deadlock
and interference between independent channels.
This CL adds an explicit buffer size to netchan
channels (an API change) - the sender will not
send values until the buffer is non empty.

The protocol changes to send ids rather than channel names
because acks can still be sent after a channel is hung up,
we we need an identifier that can be ignored.

R=r, rsc
CC=golang-dev
https://golang.org/cl/2447042
parent 89993544
...@@ -38,21 +38,24 @@ const ( ...@@ -38,21 +38,24 @@ const (
payData // user payload follows payData // user payload follows
payAck // acknowledgement; no payload payAck // acknowledgement; no payload
payClosed // channel is now closed payClosed // channel is now closed
payAckSend // payload has been delivered.
) )
// A header is sent as a prefix to every transmission. It will be followed by // A header is sent as a prefix to every transmission. It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure. // a request structure, an error structure, or an arbitrary user payload structure.
type header struct { type header struct {
Name string Id int
PayloadType int PayloadType int
SeqNum int64 SeqNum int64
} }
// Sent with a header once per channel from importer to exporter to report // Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count // that it wants to bind to a channel with the specified direction for count
// messages. If count is -1, it means unlimited. // messages, with space for size buffered values. If count is -1, it means unlimited.
type request struct { type request struct {
Name string
Count int64 Count int64
Size int
Dir Dir Dir Dir
} }
...@@ -78,7 +81,7 @@ type chanDir struct { ...@@ -78,7 +81,7 @@ type chanDir struct {
// clients of an exporter and draining outstanding messages. // clients of an exporter and draining outstanding messages.
type clientSet struct { type clientSet struct {
mu sync.Mutex // protects access to channel and client maps mu sync.Mutex // protects access to channel and client maps
chans map[string]*chanDir names map[string]*chanDir
clients map[unackedCounter]bool clients map[unackedCounter]bool
} }
...@@ -132,7 +135,7 @@ func (cs *clientSet) drain(timeout int64) os.Error { ...@@ -132,7 +135,7 @@ func (cs *clientSet) drain(timeout int64) os.Error {
pending := false pending := false
cs.mu.Lock() cs.mu.Lock()
// Any messages waiting for a client? // Any messages waiting for a client?
for _, chDir := range cs.chans { for _, chDir := range cs.names {
if chDir.ch.Len() > 0 { if chDir.ch.Len() > 0 {
pending = true pending = true
} }
...@@ -189,3 +192,134 @@ func (cs *clientSet) sync(timeout int64) os.Error { ...@@ -189,3 +192,134 @@ func (cs *clientSet) sync(timeout int64) os.Error {
} }
return nil return nil
} }
// A netChan represents a channel imported or exported
// on a single connection. Flow is controlled by the receiving
// side by sending payAckSend messages when values
// are delivered into the local channel.
type netChan struct {
*chanDir
name string
id int
size int // buffer size of channel.
// sender-specific state
ackCh chan bool // buffered with space for all the acks we need
space int // available space.
// receiver-specific state
sendCh chan reflect.Value // buffered channel of values received from other end.
ed *encDec // so that we can send acks.
count int64 // number of values still to receive.
}
// Create a new netChan with the given name (only used for
// messages), id, direction, buffer size, and count.
// The connection to the other side is represented by ed.
func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count int64) *netChan {
c := &netChan{chanDir: ch, name: name, id: id, size: size, ed: ed, count: count}
if c.dir == Send {
c.ackCh = make(chan bool, size)
c.space = size
}
return c
}
// Close the channel.
func (nch *netChan) close() {
if nch.dir == Recv {
if nch.sendCh != nil {
// If the sender goroutine is active, close the channel to it.
// It will close nch.ch when it can.
close(nch.sendCh)
} else {
nch.ch.Close()
}
} else {
nch.ch.Close()
close(nch.ackCh)
}
}
// Send message from remote side to local receiver.
func (nch *netChan) send(val reflect.Value) {
if nch.dir != Recv {
panic("send on wrong direction of channel")
}
if nch.sendCh == nil {
// If possible, do local send directly and ack immediately.
if nch.ch.TrySend(val) {
nch.sendAck()
return
}
// Start sender goroutine to manage delayed delivery of values.
nch.sendCh = make(chan reflect.Value, nch.size)
go nch.sender()
}
if ok := nch.sendCh <- val; !ok {
// TODO: should this be more resilient?
panic("netchan: remote sender sent more values than allowed")
}
}
// sendAck sends an acknowledgment that a message has left
// the channel's buffer. If the messages remaining to be sent
// will fit in the channel's buffer, then we don't
// need to send an ack.
func (nch *netChan) sendAck() {
if nch.count < 0 || nch.count > int64(nch.size) {
nch.ed.encode(&header{Id: nch.id}, payAckSend, nil)
}
if nch.count > 0 {
nch.count--
}
}
// The sender process forwards items from the sending queue
// to the destination channel, acknowledging each item.
func (nch *netChan) sender() {
if nch.dir != Recv {
panic("sender on wrong direction of channel")
}
// When Exporter.Hangup is called, the underlying channel is closed,
// and so we may get a "too many operations on closed channel" error
// if there are outstanding messages in sendCh.
// Make sure that this doesn't panic the whole program.
defer func() {
if r := recover(); r != nil {
// TODO check that r is "too many operations", otherwise re-panic.
}
}()
for v := range nch.sendCh {
nch.ch.Send(v)
nch.sendAck()
}
nch.ch.Close()
}
// Receive value from local side for sending to remote side.
func (nch *netChan) recv() (val reflect.Value, closed bool) {
if nch.dir != Send {
panic("recv on wrong direction of channel")
}
if nch.space == 0 {
// Wait for buffer space.
<-nch.ackCh
nch.space++
}
nch.space--
return nch.ch.Recv(), nch.ch.Closed()
}
// acked is called when the remote side indicates that
// a value has been delivered.
func (nch *netChan) acked() {
if nch.dir != Send {
panic("recv on wrong direction of channel")
}
if ok := nch.ackCh <- true; !ok {
panic("netchan: remote receiver sent too many acks")
// TODO: should this be more resilient?
}
}
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"net" "net"
"os" "os"
"reflect" "reflect"
"strconv"
"sync" "sync"
) )
...@@ -48,11 +49,12 @@ type Exporter struct { ...@@ -48,11 +49,12 @@ type Exporter struct {
type expClient struct { type expClient struct {
*encDec *encDec
exp *Exporter exp *Exporter
mu sync.Mutex // protects remaining fields chans map[int]*netChan // channels in use by client
errored bool // client has been sent an error mu sync.Mutex // protects remaining fields
seqNum int64 // sequences messages sent to client; has value of highest sent errored bool // client has been sent an error
ackNum int64 // highest sequence number acknowledged seqNum int64 // sequences messages sent to client; has value of highest sent
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu ackNum int64 // highest sequence number acknowledged
seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu
} }
func newClient(exp *Exporter, conn net.Conn) *expClient { func newClient(exp *Exporter, conn net.Conn) *expClient {
...@@ -61,8 +63,8 @@ func newClient(exp *Exporter, conn net.Conn) *expClient { ...@@ -61,8 +63,8 @@ func newClient(exp *Exporter, conn net.Conn) *expClient {
client.encDec = newEncDec(conn) client.encDec = newEncDec(conn)
client.seqNum = 0 client.seqNum = 0
client.ackNum = 0 client.ackNum = 0
client.chans = make(map[int]*netChan)
return client return client
} }
func (client *expClient) sendError(hdr *header, err string) { func (client *expClient) sendError(hdr *header, err string) {
...@@ -74,20 +76,33 @@ func (client *expClient) sendError(hdr *header, err string) { ...@@ -74,20 +76,33 @@ func (client *expClient) sendError(hdr *header, err string) {
client.mu.Unlock() client.mu.Unlock()
} }
func (client *expClient) getChan(hdr *header, dir Dir) *chanDir { func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan {
exp := client.exp exp := client.exp
exp.mu.Lock() exp.mu.Lock()
ech, ok := exp.chans[hdr.Name] ech, ok := exp.names[name]
exp.mu.Unlock() exp.mu.Unlock()
if !ok { if !ok {
client.sendError(hdr, "no such channel: "+hdr.Name) client.sendError(hdr, "no such channel: "+name)
return nil return nil
} }
if ech.dir != dir { if ech.dir != dir {
client.sendError(hdr, "wrong direction for channel: "+hdr.Name) client.sendError(hdr, "wrong direction for channel: "+name)
return nil
}
nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count)
client.chans[hdr.Id] = nch
return nch
}
func (client *expClient) getChan(hdr *header, dir Dir) *netChan {
nch := client.chans[hdr.Id]
if nch == nil {
return nil return nil
} }
return ech if nch.dir != dir {
client.sendError(hdr, "wrong direction for channel: "+nch.name)
}
return nch
} }
// The function run manages sends and receives for a single client. For each // The function run manages sends and receives for a single client. For each
...@@ -113,12 +128,18 @@ func (client *expClient) run() { ...@@ -113,12 +128,18 @@ func (client *expClient) run() {
expLog("error decoding client request:", err) expLog("error decoding client request:", err)
break break
} }
if req.Size < 1 {
panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values")
}
switch req.Dir { switch req.Dir {
case Recv: case Recv:
go client.serveRecv(*hdr, req.Count) // look up channel before calling serveRecv to
// avoid a lock around client.chans.
if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil {
go client.serveRecv(nch, *hdr, req.Count)
}
case Send: case Send:
// Request to send is clear as a matter of protocol client.newChan(hdr, Recv, req.Name, req.Size, req.Count)
// but not actually used by the implementation.
// The actual sends will have payload type payData. // The actual sends will have payload type payData.
// TODO: manage the count? // TODO: manage the count?
default: default:
...@@ -143,6 +164,10 @@ func (client *expClient) run() { ...@@ -143,6 +164,10 @@ func (client *expClient) run() {
client.ackNum = hdr.SeqNum client.ackNum = hdr.SeqNum
} }
client.mu.Unlock() client.mu.Unlock()
case payAckSend:
if nch := client.getChan(hdr, Send); nch != nil {
nch.acked()
}
default: default:
log.Exit("netchan export: unknown payload type", hdr.PayloadType) log.Exit("netchan export: unknown payload type", hdr.PayloadType)
} }
...@@ -152,14 +177,10 @@ func (client *expClient) run() { ...@@ -152,14 +177,10 @@ func (client *expClient) run() {
// Send all the data on a single channel to a client asking for a Recv. // Send all the data on a single channel to a client asking for a Recv.
// The header is passed by value to avoid issues of overwriting. // The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveRecv(hdr header, count int64) { func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) {
ech := client.getChan(&hdr, Send)
if ech == nil {
return
}
for { for {
val := ech.ch.Recv() val, closed := nch.recv()
if ech.ch.Closed() { if closed {
if err := client.encode(&hdr, payClosed, nil); err != nil { if err := client.encode(&hdr, payClosed, nil); err != nil {
expLog("error encoding server closed message:", err) expLog("error encoding server closed message:", err)
} }
...@@ -167,7 +188,7 @@ func (client *expClient) serveRecv(hdr header, count int64) { ...@@ -167,7 +188,7 @@ func (client *expClient) serveRecv(hdr header, count int64) {
} }
// We hold the lock during transmission to guarantee messages are // We hold the lock during transmission to guarantee messages are
// sent in sequence number order. Also, we increment first so the // sent in sequence number order. Also, we increment first so the
// value of client.seqNum is the value of the highest used sequence // value of client.SeqNum is the value of the highest used sequence
// number, not one beyond. // number, not one beyond.
client.mu.Lock() client.mu.Lock()
client.seqNum++ client.seqNum++
...@@ -193,27 +214,27 @@ func (client *expClient) serveRecv(hdr header, count int64) { ...@@ -193,27 +214,27 @@ func (client *expClient) serveRecv(hdr header, count int64) {
// Receive and deliver locally one item from a client asking for a Send // Receive and deliver locally one item from a client asking for a Send
// The header is passed by value to avoid issues of overwriting. // The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveSend(hdr header) { func (client *expClient) serveSend(hdr header) {
ech := client.getChan(&hdr, Recv) nch := client.getChan(&hdr, Recv)
if ech == nil { if nch == nil {
return return
} }
// Create a new value for each received item. // Create a new value for each received item.
val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem()) val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
if err := client.decode(val); err != nil { if err := client.decode(val); err != nil {
expLog("value decode:", err) expLog("value decode:", err, "; type ", nch.ch.Type())
return return
} }
ech.ch.Send(val) nch.send(val)
} }
// Report that client has closed the channel that is sending to us. // Report that client has closed the channel that is sending to us.
// The header is passed by value to avoid issues of overwriting. // The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveClosed(hdr header) { func (client *expClient) serveClosed(hdr header) {
ech := client.getChan(&hdr, Recv) nch := client.getChan(&hdr, Recv)
if ech == nil { if nch == nil {
return return
} }
ech.ch.Close() nch.close()
} }
func (client *expClient) unackedCount() int64 { func (client *expClient) unackedCount() int64 {
...@@ -260,7 +281,7 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) { ...@@ -260,7 +281,7 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
e := &Exporter{ e := &Exporter{
listener: listener, listener: listener,
clientSet: &clientSet{ clientSet: &clientSet{
chans: make(map[string]*chanDir), names: make(map[string]*chanDir),
clients: make(map[unackedCounter]bool), clients: make(map[unackedCounter]bool),
}, },
} }
...@@ -343,11 +364,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { ...@@ -343,11 +364,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
} }
exp.mu.Lock() exp.mu.Lock()
defer exp.mu.Unlock() defer exp.mu.Unlock()
_, present := exp.chans[name] _, present := exp.names[name]
if present { if present {
return os.ErrorString("channel name already being exported:" + name) return os.ErrorString("channel name already being exported:" + name)
} }
exp.chans[name] = &chanDir{ch, dir} exp.names[name] = &chanDir{ch, dir}
return nil return nil
} }
...@@ -355,10 +376,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { ...@@ -355,10 +376,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
// the channel. Messages in flight for the channel may be dropped. // the channel. Messages in flight for the channel may be dropped.
func (exp *Exporter) Hangup(name string) os.Error { func (exp *Exporter) Hangup(name string) os.Error {
exp.mu.Lock() exp.mu.Lock()
chDir, ok := exp.chans[name] chDir, ok := exp.names[name]
if ok { if ok {
exp.chans[name] = nil, false exp.names[name] = nil, false
} }
// TODO drop all instances of channel from client sets
exp.mu.Unlock() exp.mu.Unlock()
if !ok { if !ok {
return os.ErrorString("netchan export: hangup: no such channel: " + name) return os.ErrorString("netchan export: hangup: no such channel: " + name)
......
...@@ -27,8 +27,10 @@ type Importer struct { ...@@ -27,8 +27,10 @@ type Importer struct {
*encDec *encDec
conn net.Conn conn net.Conn
chanLock sync.Mutex // protects access to channel map chanLock sync.Mutex // protects access to channel map
chans map[string]*chanDir names map[string]*netChan
chans map[int]*netChan
errors chan os.Error errors chan os.Error
maxId int
} }
// NewImporter creates a new Importer object to import channels // NewImporter creates a new Importer object to import channels
...@@ -43,7 +45,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { ...@@ -43,7 +45,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
imp := new(Importer) imp := new(Importer)
imp.encDec = newEncDec(conn) imp.encDec = newEncDec(conn)
imp.conn = conn imp.conn = conn
imp.chans = make(map[string]*chanDir) imp.chans = make(map[int]*netChan)
imp.names = make(map[string]*netChan)
imp.errors = make(chan os.Error, 10) imp.errors = make(chan os.Error, 10)
go imp.run() go imp.run()
return imp, nil return imp, nil
...@@ -54,7 +57,7 @@ func (imp *Importer) shutdown() { ...@@ -54,7 +57,7 @@ func (imp *Importer) shutdown() {
imp.chanLock.Lock() imp.chanLock.Lock()
for _, ich := range imp.chans { for _, ich := range imp.chans {
if ich.dir == Recv { if ich.dir == Recv {
ich.ch.Close() ich.close()
} }
} }
imp.chanLock.Unlock() imp.chanLock.Unlock()
...@@ -95,43 +98,53 @@ func (imp *Importer) run() { ...@@ -95,43 +98,53 @@ func (imp *Importer) run() {
continue // errors are not acknowledged. continue // errors are not acknowledged.
} }
case payClosed: case payClosed:
ich := imp.getChan(hdr.Name) nch := imp.getChan(hdr.Id, false)
if ich != nil { if nch != nil {
ich.ch.Close() nch.close()
} }
continue // closes are not acknowledged. continue // closes are not acknowledged.
case payAckSend:
// we can receive spurious acks if the channel is
// hung up, so we ask getChan to ignore any errors.
nch := imp.getChan(hdr.Id, true)
if nch != nil {
nch.acked()
}
continue
default: default:
impLog("unexpected payload type:", hdr.PayloadType) impLog("unexpected payload type:", hdr.PayloadType)
return return
} }
ich := imp.getChan(hdr.Name) nch := imp.getChan(hdr.Id, false)
if ich == nil { if nch == nil {
continue continue
} }
if ich.dir != Recv { if nch.dir != Recv {
impLog("cannot happen: receive from non-Recv channel") impLog("cannot happen: receive from non-Recv channel")
return return
} }
// Acknowledge receipt // Acknowledge receipt
ackHdr.Name = hdr.Name ackHdr.Id = hdr.Id
ackHdr.SeqNum = hdr.SeqNum ackHdr.SeqNum = hdr.SeqNum
imp.encode(ackHdr, payAck, nil) imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item. // Create a new value for each received item.
value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) value := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem())
if e := imp.decode(value); e != nil { if e := imp.decode(value); e != nil {
impLog("importer value decode:", e) impLog("importer value decode:", e)
return return
} }
ich.ch.Send(value) nch.send(value)
} }
} }
func (imp *Importer) getChan(name string) *chanDir { func (imp *Importer) getChan(id int, errOk bool) *netChan {
imp.chanLock.Lock() imp.chanLock.Lock()
ich := imp.chans[name] ich := imp.chans[id]
imp.chanLock.Unlock() imp.chanLock.Unlock()
if ich == nil { if ich == nil {
impLog("unknown name in netchan request:", name) if !errOk {
impLog("unknown id in netchan request: ", id)
}
return nil return nil
} }
return ich return ich
...@@ -145,17 +158,18 @@ func (imp *Importer) Errors() chan os.Error { ...@@ -145,17 +158,18 @@ func (imp *Importer) Errors() chan os.Error {
return imp.errors return imp.errors
} }
// Import imports a channel of the given type and specified direction. // Import imports a channel of the given type, size and specified direction.
// It is equivalent to ImportNValues with a count of -1, meaning unbounded. // It is equivalent to ImportNValues with a count of -1, meaning unbounded.
func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { func (imp *Importer) Import(name string, chT interface{}, dir Dir, size int) os.Error {
return imp.ImportNValues(name, chT, dir, -1) return imp.ImportNValues(name, chT, dir, size, -1)
} }
// ImportNValues imports a channel of the given type and specified direction // ImportNValues imports a channel of the given type and specified
// and then receives or transmits up to n values on that channel. A value of // direction and then receives or transmits up to n values on that
// n==-1 implies an unbounded number of values. The channel to be bound to // channel. A value of n==-1 implies an unbounded number of values. The
// the remote site's channel is provided in the call and may be of arbitrary // channel will have buffer space for size values, or 1 value if size < 1.
// channel type. // The channel to be bound to the remote site's channel is provided
// in the call and may be of arbitrary channel type.
// Despite the literal signature, the effective signature is // Despite the literal signature, the effective signature is
// ImportNValues(name string, chT chan T, dir Dir, n int) os.Error // ImportNValues(name string, chT chan T, dir Dir, n int) os.Error
// Example usage: // Example usage:
...@@ -165,21 +179,28 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { ...@@ -165,21 +179,28 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
// err = imp.ImportNValues("name", ch, Recv, 1) // err = imp.ImportNValues("name", ch, Recv, 1)
// if err != nil { log.Exit(err) } // if err != nil { log.Exit(err) }
// fmt.Printf("%+v\n", <-ch) // fmt.Printf("%+v\n", <-ch)
func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) os.Error { func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, n int) os.Error {
ch, err := checkChan(chT, dir) ch, err := checkChan(chT, dir)
if err != nil { if err != nil {
return err return err
} }
imp.chanLock.Lock() imp.chanLock.Lock()
defer imp.chanLock.Unlock() defer imp.chanLock.Unlock()
_, present := imp.chans[name] _, present := imp.names[name]
if present { if present {
return os.ErrorString("channel name already being imported:" + name) return os.ErrorString("channel name already being imported:" + name)
} }
imp.chans[name] = &chanDir{ch, dir} if size < 1 {
size = 1
}
id := imp.maxId
imp.maxId++
nch := newNetChan(name, id, &chanDir{ch, dir}, imp.encDec, size, int64(n))
imp.names[name] = nch
imp.chans[id] = nch
// Tell the other side about this channel. // Tell the other side about this channel.
hdr := &header{Name: name} hdr := &header{Id: id}
req := &request{Count: int64(n), Dir: dir} req := &request{Name: name, Count: int64(n), Dir: dir, Size: size}
if err = imp.encode(hdr, payRequest, req); err != nil { if err = imp.encode(hdr, payRequest, req); err != nil {
impLog("request encode:", err) impLog("request encode:", err)
return err return err
...@@ -187,8 +208,8 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) ...@@ -187,8 +208,8 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
if dir == Send { if dir == Send {
go func() { go func() {
for i := 0; n == -1 || i < n; i++ { for i := 0; n == -1 || i < n; i++ {
val := ch.Recv() val, closed := nch.recv()
if ch.Closed() { if closed {
if err = imp.encode(hdr, payClosed, nil); err != nil { if err = imp.encode(hdr, payClosed, nil); err != nil {
impLog("error encoding client closed message:", err) impLog("error encoding client closed message:", err)
} }
...@@ -208,14 +229,15 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) ...@@ -208,14 +229,15 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
// the channel. Messages in flight for the channel may be dropped. // the channel. Messages in flight for the channel may be dropped.
func (imp *Importer) Hangup(name string) os.Error { func (imp *Importer) Hangup(name string) os.Error {
imp.chanLock.Lock() imp.chanLock.Lock()
chDir, ok := imp.chans[name] nc, ok := imp.names[name]
if ok { if ok {
imp.chans[name] = nil, false imp.names[name] = nil, false
imp.chans[nc.id] = nil, false
} }
imp.chanLock.Unlock() imp.chanLock.Unlock()
if !ok { if !ok {
return os.ErrorString("netchan import: hangup: no such channel: " + name) return os.ErrorString("netchan import: hangup: no such channel: " + name)
} }
chDir.ch.Close() nc.close()
return nil return nil
} }
...@@ -15,7 +15,7 @@ const closeCount = 5 // number of items when sender closes early ...@@ -15,7 +15,7 @@ const closeCount = 5 // number of items when sender closes early
const base = 23 const base = 23
func exportSend(exp *Exporter, n int, t *testing.T) { func exportSend(exp *Exporter, n int, t *testing.T, done chan bool) {
ch := make(chan int) ch := make(chan int)
err := exp.Export("exportedSend", ch, Send) err := exp.Export("exportedSend", ch, Send)
if err != nil { if err != nil {
...@@ -26,6 +26,9 @@ func exportSend(exp *Exporter, n int, t *testing.T) { ...@@ -26,6 +26,9 @@ func exportSend(exp *Exporter, n int, t *testing.T) {
ch <- base+i ch <- base+i
} }
close(ch) close(ch)
if done != nil {
done <- true
}
}() }()
} }
...@@ -50,9 +53,9 @@ func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) { ...@@ -50,9 +53,9 @@ func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) {
} }
} }
func importSend(imp *Importer, n int, t *testing.T) { func importSend(imp *Importer, n int, t *testing.T, done chan bool) {
ch := make(chan int) ch := make(chan int)
err := imp.ImportNValues("exportedRecv", ch, Send, count) err := imp.ImportNValues("exportedRecv", ch, Send, 3, -1)
if err != nil { if err != nil {
t.Fatal("importSend:", err) t.Fatal("importSend:", err)
} }
...@@ -61,12 +64,15 @@ func importSend(imp *Importer, n int, t *testing.T) { ...@@ -61,12 +64,15 @@ func importSend(imp *Importer, n int, t *testing.T) {
ch <- base+i ch <- base+i
} }
close(ch) close(ch)
if done != nil {
done <- true
}
}() }()
} }
func importReceive(imp *Importer, t *testing.T, done chan bool) { func importReceive(imp *Importer, t *testing.T, done chan bool) {
ch := make(chan int) ch := make(chan int)
err := imp.ImportNValues("exportedSend", ch, Recv, count) err := imp.ImportNValues("exportedSend", ch, Recv, 3, count)
if err != nil { if err != nil {
t.Fatal("importReceive:", err) t.Fatal("importReceive:", err)
} }
...@@ -78,7 +84,7 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { ...@@ -78,7 +84,7 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
} }
break break
} }
if v != 23+i { if v != base+i {
t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v) t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v)
} }
} }
...@@ -96,7 +102,7 @@ func TestExportSendImportReceive(t *testing.T) { ...@@ -96,7 +102,7 @@ func TestExportSendImportReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
exportSend(exp, count, t) exportSend(exp, count, t, nil)
importReceive(imp, t, nil) importReceive(imp, t, nil)
} }
...@@ -116,7 +122,7 @@ func TestExportReceiveImportSend(t *testing.T) { ...@@ -116,7 +122,7 @@ func TestExportReceiveImportSend(t *testing.T) {
done <- true done <- true
}() }()
<-expDone <-expDone
importSend(imp, count, t) importSend(imp, count, t, nil)
<-done <-done
} }
...@@ -129,7 +135,7 @@ func TestClosingExportSendImportReceive(t *testing.T) { ...@@ -129,7 +135,7 @@ func TestClosingExportSendImportReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
exportSend(exp, closeCount, t) exportSend(exp, closeCount, t, nil)
importReceive(imp, t, nil) importReceive(imp, t, nil)
} }
...@@ -149,7 +155,7 @@ func TestClosingImportSendExportReceive(t *testing.T) { ...@@ -149,7 +155,7 @@ func TestClosingImportSendExportReceive(t *testing.T) {
done <- true done <- true
}() }()
<-expDone <-expDone
importSend(imp, closeCount, t) importSend(imp, closeCount, t, nil)
<-done <-done
} }
...@@ -172,7 +178,7 @@ func TestErrorForIllegalChannel(t *testing.T) { ...@@ -172,7 +178,7 @@ func TestErrorForIllegalChannel(t *testing.T) {
close(ch) close(ch)
// Now try to import a different channel. // Now try to import a different channel.
ch = make(chan int) ch = make(chan int)
err = imp.Import("notAChannel", ch, Recv) err = imp.Import("notAChannel", ch, Recv, 1)
if err != nil { if err != nil {
t.Fatal("import:", err) t.Fatal("import:", err)
} }
...@@ -204,7 +210,7 @@ func TestExportDrain(t *testing.T) { ...@@ -204,7 +210,7 @@ func TestExportDrain(t *testing.T) {
} }
done := make(chan bool) done := make(chan bool)
go func() { go func() {
exportSend(exp, closeCount, t) exportSend(exp, closeCount, t, nil)
done <- true done <- true
}() }()
<-done <-done
...@@ -224,7 +230,7 @@ func TestExportSync(t *testing.T) { ...@@ -224,7 +230,7 @@ func TestExportSync(t *testing.T) {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
done := make(chan bool) done := make(chan bool)
exportSend(exp, closeCount, t) exportSend(exp, closeCount, t, nil)
go importReceive(imp, t, done) go importReceive(imp, t, done)
exp.Sync(0) exp.Sync(0)
<-done <-done
...@@ -248,7 +254,7 @@ func TestExportHangup(t *testing.T) { ...@@ -248,7 +254,7 @@ func TestExportHangup(t *testing.T) {
} }
// Prepare to receive two values. We'll actually deliver only one. // Prepare to receive two values. We'll actually deliver only one.
ich := make(chan int) ich := make(chan int)
err = imp.ImportNValues("exportedSend", ich, Recv, 2) err = imp.ImportNValues("exportedSend", ich, Recv, 1, 2)
if err != nil { if err != nil {
t.Fatal("import exportedSend:", err) t.Fatal("import exportedSend:", err)
} }
...@@ -285,7 +291,7 @@ func TestImportHangup(t *testing.T) { ...@@ -285,7 +291,7 @@ func TestImportHangup(t *testing.T) {
} }
// Prepare to Send two values. We'll actually deliver only one. // Prepare to Send two values. We'll actually deliver only one.
ich := make(chan int) ich := make(chan int)
err = imp.ImportNValues("exportedRecv", ich, Send, 2) err = imp.ImportNValues("exportedRecv", ich, Send, 1, 2)
if err != nil { if err != nil {
t.Fatal("import exportedRecv:", err) t.Fatal("import exportedRecv:", err)
} }
...@@ -304,10 +310,70 @@ func TestImportHangup(t *testing.T) { ...@@ -304,10 +310,70 @@ func TestImportHangup(t *testing.T) {
} }
} }
// loop back exportedRecv to exportedSend,
// but receive a value from ctlch before starting the loop.
func exportLoopback(exp *Exporter, t *testing.T) {
inch := make(chan int)
if err := exp.Export("exportedRecv", inch, Recv); err != nil {
t.Fatal("exportRecv")
}
outch := make(chan int)
if err := exp.Export("exportedSend", outch, Send); err != nil {
t.Fatal("exportSend")
}
ctlch := make(chan int)
if err := exp.Export("exportedCtl", ctlch, Recv); err != nil {
t.Fatal("exportRecv")
}
go func() {
<-ctlch
for i := 0; i < count; i++ {
x := <-inch
if x != base+i {
t.Errorf("exportLoopback expected %d; got %d", i, x)
}
outch <- x
}
}()
}
// This test checks that channel operations can proceed
// even when other concurrent operations are blocked.
func TestIndependentSends(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
exportLoopback(exp, t)
importSend(imp, count, t, nil)
done := make(chan bool)
go importReceive(imp, t, done)
// wait for export side to try to deliver some values.
time.Sleep(0.25e9)
ctlch := make(chan int)
if err := imp.ImportNValues("exportedCtl", ctlch, Send, 1, 1); err != nil {
t.Fatal("importSend:", err)
}
ctlch <- 0
<-done
}
// This test cross-connects a pair of exporter/importer pairs. // This test cross-connects a pair of exporter/importer pairs.
type value struct { type value struct {
i int I int
source string Source string
} }
func TestCrossConnect(t *testing.T) { func TestCrossConnect(t *testing.T) {
...@@ -353,13 +419,13 @@ func crossExport(e1, e2 *Exporter, t *testing.T) { ...@@ -353,13 +419,13 @@ func crossExport(e1, e2 *Exporter, t *testing.T) {
// Import side of cross-traffic. // Import side of cross-traffic.
func crossImport(i1, i2 *Importer, t *testing.T) { func crossImport(i1, i2 *Importer, t *testing.T) {
s := make(chan value) s := make(chan value)
err := i2.Import("exportedReceive", s, Send) err := i2.Import("exportedReceive", s, Send, 2)
if err != nil { if err != nil {
t.Fatal("import of exportedReceive:", err) t.Fatal("import of exportedReceive:", err)
} }
r := make(chan value) r := make(chan value)
err = i1.Import("exportedSend", r, Recv) err = i1.Import("exportedSend", r, Recv, 2)
if err != nil { if err != nil {
t.Fatal("import of exported Send:", err) t.Fatal("import of exported Send:", err)
} }
...@@ -374,10 +440,76 @@ func crossLoop(name string, s, r chan value, t *testing.T) { ...@@ -374,10 +440,76 @@ func crossLoop(name string, s, r chan value, t *testing.T) {
case s <- value{si, name}: case s <- value{si, name}:
si++ si++
case v := <-r: case v := <-r:
if v.i != ri { if v.I != ri {
t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v) t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v)
} }
ri++ ri++
} }
} }
} }
const flowCount = 100
// test flow control from exporter to importer.
func TestExportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
sendDone := make(chan bool, 1)
exportSend(exp, flowCount, t, sendDone)
ch := make(chan int)
err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1)
if err != nil {
t.Fatal("importReceive:", err)
}
testFlow(sendDone, ch, flowCount, t)
}
// test flow control from importer to exporter.
func TestImportFlowControl(t *testing.T) {
exp, err := NewExporter("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
ch := make(chan int)
err = exp.Export("exportedRecv", ch, Recv)
if err != nil {
t.Fatal("importReceive:", err)
}
sendDone := make(chan bool, 1)
importSend(imp, flowCount, t, sendDone)
testFlow(sendDone, ch, flowCount, t)
}
func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) {
go func() {
time.Sleep(1e9)
sendDone <- false
}()
if <-sendDone {
t.Fatal("send did not block")
}
n := 0
for i := range ch {
t.Log("after blocking, got value ", i)
n++
}
if n != N {
t.Fatalf("expected %d values; got %d", N, n)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment