Commit 1b56947f authored by Rob Pike's avatar Rob Pike

netchan: allow client to send as well as receive.

much rewriting and improving, but it's still experimental.

R=rsc
CC=golang-dev
https://golang.org/cl/875045
parent effddcad
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -6,12 +6,12 @@ package netchan ...@@ -6,12 +6,12 @@ package netchan
import ( import (
"gob" "gob"
"log"
"net" "net"
"os" "os"
"sync" "sync"
) )
// The direction of a connection from the client's perspective.
type Dir int type Dir int
const ( const (
...@@ -19,8 +19,34 @@ const ( ...@@ -19,8 +19,34 @@ const (
Send Send
) )
// Mutex-protected encoder and decoder pair // Payload types
const (
payRequest = iota // request structure follows
payError // error structure follows
payData // user payload follows
)
// A header is sent as a prefix to every transmission. It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure.
type header struct {
name string
payloadType int
}
// Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count
// messages. If count is zero, it means unlimited.
type request struct {
count int
dir Dir
}
// Sent with a header to report an error.
type error struct {
error string
}
// Mutex-protected encoder and decoder pair.
type encDec struct { type encDec struct {
decLock sync.Mutex decLock sync.Mutex
dec *gob.Decoder dec *gob.Decoder
...@@ -35,29 +61,27 @@ func newEncDec(conn net.Conn) *encDec { ...@@ -35,29 +61,27 @@ func newEncDec(conn net.Conn) *encDec {
} }
} }
// Decode an item from the connection.
func (ed *encDec) decode(e interface{}) os.Error { func (ed *encDec) decode(e interface{}) os.Error {
ed.decLock.Lock() ed.decLock.Lock()
defer ed.decLock.Unlock()
err := ed.dec.Decode(e) err := ed.dec.Decode(e)
if err != nil { if err != nil {
log.Stderr("exporter decode:", err) // TODO: tear down connection?
// TODO: tear down connection
return err
} }
return nil ed.decLock.Unlock()
return err
} }
func (ed *encDec) encode(e0, e1 interface{}) os.Error { // Encode a header and payload onto the connection.
func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error {
ed.encLock.Lock() ed.encLock.Lock()
defer ed.encLock.Unlock() hdr.payloadType = payloadType
err := ed.enc.Encode(e0) err := ed.enc.Encode(hdr)
if err == nil && e1 != nil { if err == nil {
err = ed.enc.Encode(e1) err = ed.enc.Encode(payload)
} } else {
if err != nil { // TODO: tear down connection if there is an error?
log.Stderr("exporter encode:", err)
// TODO: tear down connection?
return err
} }
return nil ed.encLock.Unlock()
return err
} }
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
Networked channels are not synchronized; they always behave Networked channels are not synchronized; they always behave
as if there is a buffer of at least one element between the as if there is a buffer of at least one element between the
two machines. two machines.
TODO: at the moment, the exporting machine must send and
the importing machine must receive. This restriction will
be lifted soon.
*/ */
package netchan package netchan
...@@ -34,10 +30,12 @@ import ( ...@@ -34,10 +30,12 @@ import (
// Export // Export
// A channel and its associated information: a direction // A channel and its associated information: a direction plus
// a handy marshaling place for its data.
type exportChan struct { type exportChan struct {
ch *reflect.ChanValue ch *reflect.ChanValue
dir Dir dir Dir
ptr *reflect.PtrValue // a pointer value we can point at each new received item
} }
// An Exporter allows a set of channels to be published on a single // An Exporter allows a set of channels to be published on a single
...@@ -62,21 +60,6 @@ func newClient(exp *Exporter, conn net.Conn) *expClient { ...@@ -62,21 +60,6 @@ func newClient(exp *Exporter, conn net.Conn) *expClient {
} }
// TODO: ASSUMES EXPORT MEANS SEND
// Sent once per channel from importer to exporter to report that it's listening to a channel
type request struct {
name string
dir Dir
count int
}
// Reply to request, sent from exporter to importer on each send.
type response struct {
name string
error string
}
// Wait for incoming connections, start a new runner for each // Wait for incoming connections, start a new runner for each
func (exp *Exporter) listen() { func (exp *Exporter) listen() {
for { for {
...@@ -85,70 +68,112 @@ func (exp *Exporter) listen() { ...@@ -85,70 +68,112 @@ func (exp *Exporter) listen() {
log.Stderr("exporter.listen:", err) log.Stderr("exporter.listen:", err)
break break
} }
log.Stderr("accepted call from", conn.RemoteAddr())
client := newClient(exp, conn) client := newClient(exp, conn)
go client.run() go client.run()
} }
} }
// Send a single client all its data. For each request, this will launch func (client *expClient) sendError(hdr *header, err string) {
// a serveRecv goroutine to deliver the data for that channel. error := &error{err}
log.Stderr("export:", error.error)
client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
}
func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
exp := client.exp
exp.chanLock.Lock()
ech, ok := exp.chans[hdr.name]
exp.chanLock.Unlock()
if !ok {
client.sendError(hdr, "no such channel: "+hdr.name)
return nil
}
if ech.dir != dir {
client.sendError(hdr, "wrong direction for channel: "+hdr.name)
return nil
}
return ech
}
// Manage sends and receives for a single client. For each (client Recv) request,
// this will launch a serveRecv goroutine to deliver the data for that channel,
// while (client Send) requests are handled as data arrives from the client.
func (client *expClient) run() { func (client *expClient) run() {
hdr := new(header)
req := new(request) req := new(request)
error := new(error)
for { for {
if err := client.decode(req); err != nil { if err := client.decode(hdr); err != nil {
log.Stderr("error decoding client request:", err) log.Stderr("error decoding client header:", err)
// TODO: tear down connection // TODO: tear down connection
break return
} }
log.Stderrf("export request: %+v", req) switch hdr.payloadType {
if req.dir == Recv { case payRequest:
go client.serveRecv(req) if err := client.decode(req); err != nil {
} else { log.Stderr("error decoding client request:", err)
log.Stderr("export request: can't handle channel direction", req.dir) // TODO: tear down connection
resp := new(response) return
resp.name = req.name }
resp.error = "export request: can't handle channel direction" switch req.dir {
client.encode(resp, nil) case Recv:
break go client.serveRecv(*hdr, req.count)
case Send:
// Request to send is clear as a matter of protocol
// but not actually used by the implementation.
// The actual sends will have payload type payData.
// TODO: manage the count?
default:
error.error = "export request: can't handle channel direction"
log.Stderr(error.error, req.dir)
client.encode(hdr, payError, error)
}
case payData:
client.serveSend(*hdr)
} }
} }
} }
// Send all the data on a single channel to a client asking for a Recv // Send all the data on a single channel to a client asking for a Recv.
func (client *expClient) serveRecv(req *request) { // The header is passed by value to avoid issues of overwriting.
exp := client.exp func (client *expClient) serveRecv(hdr header, count int) {
resp := new(response) ech := client.getChan(&hdr, Send)
resp.name = req.name if ech == nil {
var ok bool
exp.chanLock.Lock()
ech, ok := exp.chans[req.name]
exp.chanLock.Unlock()
if !ok {
resp.error = "no such channel: " + req.name
log.Stderr("export:", resp.error)
client.encode(resp, nil) // ignore any encode error, hope client gets it
return return
} }
for { for {
if ech.dir != Send {
log.Stderr("TODO: recv export unimplemented")
break
}
val := ech.ch.Recv() val := ech.ch.Recv()
if err := client.encode(resp, val.Interface()); err != nil { if err := client.encode(&hdr, payData, val.Interface()); err != nil {
log.Stderr("error encoding client response:", err) log.Stderr("error encoding client response:", err)
client.sendError(&hdr, err.String())
break break
} }
if req.count > 0 { if count > 0 {
req.count-- if count--; count == 0 {
if req.count == 0 {
break break
} }
} }
} }
} }
// Receive and deliver locally one item from a client asking for a Send
// The header is passed by value to avoid issues of overwriting.
func (client *expClient) serveSend(hdr header) {
ech := client.getChan(&hdr, Recv)
if ech == nil {
return
}
// Create a new value for each received item.
val := reflect.MakeZero(ech.ptr.Type().(*reflect.PtrType).Elem())
ech.ptr.PointTo(val)
if err := client.decode(ech.ptr.Interface()); err != nil {
log.Stderr("exporter value decode:", err)
return
}
ech.ch.Send(val)
// TODO count
}
// NewExporter creates a new Exporter to export channels // NewExporter creates a new Exporter to export channels
// on the network and local address defined as in net.Listen. // on the network and local address defined as in net.Listen.
func NewExporter(network, localaddr string) (*Exporter, os.Error) { func NewExporter(network, localaddr string) (*Exporter, os.Error) {
...@@ -195,7 +220,8 @@ func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) { ...@@ -195,7 +220,8 @@ func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
// Despite the literal signature, the effective signature is // Despite the literal signature, the effective signature is
// Export(name string, chT chan T, dir Dir) // Export(name string, chT chan T, dir Dir)
// where T must be a struct, pointer to struct, etc. // where T must be a struct, pointer to struct, etc.
func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { // TODO: fix gob interface so we can eliminate the need for pT, and for structs.
func (exp *Exporter) Export(name string, chT interface{}, dir Dir, pT interface{}) os.Error {
ch, err := checkChan(chT, dir) ch, err := checkChan(chT, dir)
if err != nil { if err != nil {
return err return err
...@@ -206,6 +232,7 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { ...@@ -206,6 +232,7 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
if present { if present {
return os.ErrorString("channel name already being exported:" + name) return os.ErrorString("channel name already being exported:" + name)
} }
exp.chans[name] = &exportChan{ch, dir} ptr := reflect.MakeZero(reflect.Typeof(pT)).(*reflect.PtrValue)
exp.chans[name] = &exportChan{ch, dir, ptr}
return nil return nil
} }
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -14,12 +14,12 @@ import ( ...@@ -14,12 +14,12 @@ import (
// Import // Import
// A channel and its associated information: a template value, direction and a count // A channel and its associated information: a template value and direction,
// plus a handy marshaling place for its data.
type importChan struct { type importChan struct {
ch *reflect.ChanValue ch *reflect.ChanValue
dir Dir dir Dir
ptr *reflect.PtrValue // a pointer value we can point at each new item ptr *reflect.PtrValue // a pointer value we can point at each new received item
count int
} }
// An Importer allows a set of channels to be imported from a single // An Importer allows a set of channels to be imported from a single
...@@ -32,8 +32,6 @@ type Importer struct { ...@@ -32,8 +32,6 @@ type Importer struct {
chans map[string]*importChan chans map[string]*importChan
} }
// TODO: ASSUMES IMPORT MEANS RECEIVE
// NewImporter creates a new Importer object to import channels // NewImporter creates a new Importer object to import channels
// from an Exporter at the network and remote address as defined in net.Dial. // from an Exporter at the network and remote address as defined in net.Dial.
// The Exporter must be available and serving when the Importer is // The Exporter must be available and serving when the Importer is
...@@ -54,37 +52,49 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { ...@@ -54,37 +52,49 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
// Handle the data from a single imported data stream, which will // Handle the data from a single imported data stream, which will
// have the form // have the form
// (response, data)* // (response, data)*
// The response identifies by name which channel is receiving data. // The response identifies by name which channel is transmitting data.
// TODO: allow an importer to send.
func (imp *Importer) run() { func (imp *Importer) run() {
// Loop on responses; requests are sent by ImportNValues() // Loop on responses; requests are sent by ImportNValues()
resp := new(response) hdr := new(header)
err := new(error)
for { for {
if err := imp.decode(resp); err != nil { if e := imp.decode(hdr); e != nil {
log.Stderr("importer response decode:", err) log.Stderr("importer header:", e)
break return
} }
if resp.error != "" { switch hdr.payloadType {
log.Stderr("importer response error:", resp.error) case payData:
// TODO: tear down connection // done lower in loop
break case payError:
if e := imp.decode(err); e != nil {
log.Stderr("importer error:", e)
return
}
if err.error != "" {
log.Stderr("importer response error:", err.error)
// TODO: tear down connection
return
}
default:
log.Stderr("unexpected payload type:", hdr.payloadType)
return
} }
imp.chanLock.Lock() imp.chanLock.Lock()
ich, ok := imp.chans[resp.name] ich, ok := imp.chans[hdr.name]
imp.chanLock.Unlock() imp.chanLock.Unlock()
if !ok { if !ok {
log.Stderr("unknown name in request:", resp.name) log.Stderr("unknown name in request:", hdr.name)
break return
} }
if ich.dir != Recv { if ich.dir != Recv {
log.Stderr("TODO: import send unimplemented") log.Stderr("cannot happen: receive from non-Recv channel")
break return
} }
// Create a new value for each received item. // Create a new value for each received item.
val := reflect.MakeZero(ich.ptr.Type().(*reflect.PtrType).Elem()) val := reflect.MakeZero(ich.ptr.Type().(*reflect.PtrType).Elem())
ich.ptr.PointTo(val) ich.ptr.PointTo(val)
if err := imp.decode(ich.ptr.Interface()); err != nil { if e := imp.decode(ich.ptr.Interface()); e != nil {
log.Stderr("importer value decode:", err) log.Stderr("importer value decode:", e)
return return
} }
ich.ch.Send(val) ich.ch.Send(val)
...@@ -103,7 +113,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{ ...@@ -103,7 +113,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{
// the remote site's channel is provided in the call and may be of arbitrary // the remote site's channel is provided in the call and may be of arbitrary
// channel type. // channel type.
// Despite the literal signature, the effective signature is // Despite the literal signature, the effective signature is
// ImportNValues(name string, chT chan T, dir Dir, pT T) // ImportNValues(name string, chT chan T, dir Dir, pT T, n int) os.Error
// where T must be a struct, pointer to struct, etc. pT may be more indirect // where T must be a struct, pointer to struct, etc. pT may be more indirect
// than the value type of the channel (e.g. chan T, pT *T) but it must be a // than the value type of the channel (e.g. chan T, pT *T) but it must be a
// pointer. // pointer.
...@@ -114,7 +124,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{ ...@@ -114,7 +124,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{
// err := imp.ImportNValues("name", ch, Recv, new(myType), 1) // err := imp.ImportNValues("name", ch, Recv, new(myType), 1)
// if err != nil { log.Exit(err) } // if err != nil { log.Exit(err) }
// fmt.Printf("%+v\n", <-ch) // fmt.Printf("%+v\n", <-ch)
// (TODO: Can we eliminate the need for pT?) // TODO: fix gob interface so we can eliminate the need for pT, and for structs.
func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT interface{}, n int) os.Error { func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT interface{}, n int) os.Error {
ch, err := checkChan(chT, dir) ch, err := checkChan(chT, dir)
if err != nil { if err != nil {
...@@ -135,15 +145,28 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT int ...@@ -135,15 +145,28 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT int
return os.ErrorString("channel name already being imported:" + name) return os.ErrorString("channel name already being imported:" + name)
} }
ptr := reflect.MakeZero(reflect.Typeof(pT)).(*reflect.PtrValue) ptr := reflect.MakeZero(reflect.Typeof(pT)).(*reflect.PtrValue)
imp.chans[name] = &importChan{ch, dir, ptr, n} imp.chans[name] = &importChan{ch, dir, ptr}
// Tell the other side about this channel. // Tell the other side about this channel.
hdr := new(header)
hdr.name = name
hdr.payloadType = payRequest
req := new(request) req := new(request)
req.name = name
req.dir = dir req.dir = dir
req.count = n req.count = n
if err := imp.encode(req, nil); err != nil { if err := imp.encode(hdr, payRequest, req); err != nil {
log.Stderr("importer request encode:", err) log.Stderr("importer request encode:", err)
return err return err
} }
if dir == Send {
go func() {
for i := 0; n == 0 || i < n; i++ {
val := ch.Recv()
if err := imp.encode(hdr, payData, val.Interface()); err != nil {
log.Stderr("error encoding client response:", err)
return
}
}
}()
}
return nil return nil
} }
// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
...@@ -14,37 +14,82 @@ type value struct { ...@@ -14,37 +14,82 @@ type value struct {
s string s string
} }
const count = 10
func exportSend(exp *Exporter, t *testing.T) { func exportSend(exp *Exporter, t *testing.T) {
c := make(chan value) ch := make(chan value)
err := exp.Export("name", c, Send) err := exp.Export("exportedSend", ch, Send, new(value))
if err != nil {
t.Fatal("exportSend:", err)
}
for i := 0; i < count; i++ {
ch <- value{23 + i, "hello"}
}
}
func exportReceive(exp *Exporter, t *testing.T) {
ch := make(chan value)
err := exp.Export("exportedRecv", ch, Recv, new(value))
if err != nil { if err != nil {
t.Fatal("export:", err) t.Fatal("exportReceive:", err)
}
for i := 0; i < count; i++ {
v := <-ch
fmt.Printf("%v\n", v)
if v.i != 45+i || v.s != "hello" {
t.Errorf("export Receive: bad value: expected 4%d, hello; got %+v", 45+i, v)
}
} }
c <- value{23, "hello"}
} }
func importReceive(imp *Importer, t *testing.T) { func importReceive(imp *Importer, t *testing.T) {
ch := make(chan value) ch := make(chan value)
err := imp.ImportNValues("name", ch, Recv, new(value), 1) err := imp.ImportNValues("exportedSend", ch, Recv, new(value), count)
if err != nil { if err != nil {
t.Fatal("import:", err) t.Fatal("importReceive:", err)
} }
v := <-ch for i := 0; i < count; i++ {
fmt.Printf("%v\n", v) v := <-ch
if v.i != 23 || v.s != "hello" { fmt.Printf("%v\n", v)
t.Errorf("bad value: expected 23, hello; got %+v\n", v) if v.i != 23+i || v.s != "hello" {
t.Errorf("importReceive: bad value: expected %d, hello; got %+v", 23+i, v)
}
} }
} }
func TestBabyStep(t *testing.T) { func importSend(imp *Importer, t *testing.T) {
ch := make(chan value)
err := imp.ImportNValues("exportedRecv", ch, Send, new(value), count)
if err != nil {
t.Fatal("importSend:", err)
}
for i := 0; i < count; i++ {
ch <- value{45 + i, "hello"}
}
}
func TestExportSendImportReceive(t *testing.T) {
exp, err := NewExporter("tcp", ":0") exp, err := NewExporter("tcp", ":0")
if err != nil { if err != nil {
t.Fatal("new exporter:", err) t.Fatal("new exporter:", err)
} }
go exportSend(exp, t)
imp, err := NewImporter("tcp", exp.Addr().String()) imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil { if err != nil {
t.Fatal("new importer:", err) t.Fatal("new importer:", err)
} }
go exportSend(exp, t)
importReceive(imp, t) importReceive(imp, t)
} }
func TestExportReceiveImportSend(t *testing.T) {
exp, err := NewExporter("tcp", ":0")
if err != nil {
t.Fatal("new exporter:", err)
}
imp, err := NewImporter("tcp", exp.Addr().String())
if err != nil {
t.Fatal("new importer:", err)
}
go importSend(imp, t)
exportReceive(exp, t)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment