-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
879 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
// Copyright 2009 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package geerpc | ||
|
||
import ( | ||
"bufio" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"geerpc/codec" | ||
"io" | ||
"log" | ||
"net" | ||
"net/http" | ||
"sync" | ||
) | ||
|
||
// Call represents an active RPC. | ||
type Call struct { | ||
ServiceMethod string // format "<service>.<Method>" | ||
Args interface{} // arguments to the function | ||
Reply interface{} // reply from the function | ||
Error error // if error occurs, it will be set | ||
Done chan *Call // Strobes when call is complete. | ||
} | ||
|
||
func (call *Call) done() { | ||
call.Done <- call | ||
} | ||
|
||
// Client represents an RPC Client. | ||
// There may be multiple outstanding Calls associated | ||
// with a single Client, and a Client may be used by | ||
// multiple goroutines simultaneously. | ||
type Client struct { | ||
cc codec.Codec | ||
sending sync.Mutex // protect following | ||
header codec.Header | ||
mu sync.Mutex // protect following | ||
seq uint64 | ||
pending map[uint64]*Call | ||
closed bool // user has called Close | ||
} | ||
|
||
var _ io.Closer = (*Client)(nil) | ||
|
||
var ErrShutdown = errors.New("connection is shut down") | ||
|
||
// Close the connection | ||
func (client *Client) Close() error { | ||
client.mu.Lock() | ||
defer client.mu.Unlock() | ||
if client.closed { | ||
return ErrShutdown | ||
} | ||
client.closed = true | ||
return client.cc.Close() | ||
} | ||
|
||
func (client *Client) registerCall(call *Call) (uint64, error) { | ||
client.mu.Lock() | ||
defer client.mu.Unlock() | ||
if client.closed { | ||
return 0, ErrShutdown | ||
} | ||
seq := client.seq | ||
client.pending[seq] = call | ||
client.seq++ | ||
return seq, nil | ||
} | ||
|
||
func (client *Client) removeCall(seq uint64) *Call { | ||
client.mu.Lock() | ||
defer client.mu.Unlock() | ||
call := client.pending[seq] | ||
delete(client.pending, seq) | ||
return call | ||
} | ||
|
||
func (client *Client) terminateCalls(err error) { | ||
client.sending.Lock() | ||
defer client.sending.Unlock() | ||
client.mu.Lock() | ||
defer client.mu.Unlock() | ||
for _, call := range client.pending { | ||
call.Error = err | ||
call.done() | ||
} | ||
} | ||
|
||
func (client *Client) send(call *Call) { | ||
// make sure that the client will send a complete request | ||
client.sending.Lock() | ||
defer client.sending.Unlock() | ||
|
||
// register this call. | ||
seq, err := client.registerCall(call) | ||
if err != nil { | ||
call.Error = err | ||
call.done() | ||
return | ||
} | ||
|
||
// prepare request header | ||
client.header.ServiceMethod = call.ServiceMethod | ||
client.header.Seq = seq | ||
client.header.Error = "" | ||
|
||
// encode and send the request | ||
if err := client.cc.Write(&client.header, call.Args); err != nil { | ||
call := client.removeCall(seq) | ||
// call may be nil, it usually means that Write partially failed, | ||
// client has received the response and handled | ||
if call != nil { | ||
call.Error = err | ||
call.done() | ||
} | ||
} | ||
} | ||
|
||
func (client *Client) receive() { | ||
var err error | ||
for err == nil { | ||
var h codec.Header | ||
if err = client.cc.ReadHeader(&h); err != nil { | ||
break | ||
} | ||
call := client.removeCall(h.Seq) | ||
switch { | ||
case call == nil: | ||
// it usually means that Write partially failed | ||
// and call was already removed. | ||
err = client.cc.ReadBody(nil) | ||
case h.Error != "": | ||
call.Error = fmt.Errorf(h.Error) | ||
err = client.cc.ReadBody(nil) | ||
call.done() | ||
default: | ||
err = client.cc.ReadBody(call.Reply) | ||
if err != nil { | ||
call.Error = errors.New("reading body " + err.Error()) | ||
} | ||
call.done() | ||
} | ||
} | ||
// error occurs, so terminateCalls pending calls | ||
client.terminateCalls(err) | ||
} | ||
|
||
// Go invokes the function asynchronously. | ||
// It returns the Call structure representing the invocation. | ||
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call { | ||
if done == nil { | ||
done = make(chan *Call, 10) | ||
} else if cap(done) == 0 { | ||
log.Panic("rpc client: done channel is unbuffered") | ||
} | ||
call := &Call{ | ||
ServiceMethod: serviceMethod, | ||
Args: args, | ||
Reply: reply, | ||
Done: done, | ||
} | ||
client.send(call) | ||
return call | ||
} | ||
|
||
// Call invokes the named function, waits for it to complete, | ||
// and returns its error status. | ||
func (client *Client) Call(serviceMethod string, args, reply interface{}) error { | ||
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done | ||
return call.Error | ||
} | ||
|
||
func NewClient(conn io.ReadWriteCloser, opt *Options) (*Client, error) { | ||
var err error | ||
defer func() { | ||
if err != nil { | ||
_ = conn.Close() | ||
} | ||
}() | ||
if opt.MagicNumber == 0 { | ||
opt.MagicNumber = MagicNumber | ||
} | ||
f := codec.NewCodecFuncMap[opt.CodecType] | ||
if f == nil { | ||
err = fmt.Errorf("invalid codec type %s", opt.CodecType) | ||
log.Println("rpc client: codec error:", err) | ||
return nil, err | ||
} | ||
// send options with server | ||
if err = json.NewEncoder(conn).Encode(opt); err != nil { | ||
log.Println("rpc client: options error: ", err) | ||
return nil, err | ||
} | ||
return newClientCodec(f(conn)), nil | ||
} | ||
|
||
func newClientCodec(cc codec.Codec) *Client { | ||
client := &Client{ | ||
cc: cc, | ||
pending: make(map[uint64]*Call), | ||
} | ||
go client.receive() | ||
return client | ||
} | ||
|
||
// Dial connects to an RPC server at the specified network address | ||
func Dial(network, address string, opts ...*Options) (*Client, error) { | ||
opt := defaultOptions | ||
if len(opts) > 0 && opts[0] != nil { | ||
opt = opts[0] | ||
} | ||
conn, err := net.Dial(network, address) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return NewClient(conn, opt) | ||
} | ||
|
||
// DialHTTP connects to an HTTP RPC server at the specified network address | ||
// listening on the default HTTP RPC path. | ||
func DialHTTP(network, address string, opts ...*Options) (*Client, error) { | ||
return DialHTTPPath(network, address, defaultRPCPath, opts...) | ||
} | ||
|
||
// DialHTTPPath connects to an HTTP RPC server | ||
// at the specified network address and path. | ||
func DialHTTPPath(network, address, path string, opts ...*Options) (*Client, error) { | ||
opt := defaultOptions | ||
if len(opts) > 0 && opts[0] != nil { | ||
opt = opts[0] | ||
} | ||
var err error | ||
conn, err := net.Dial(network, address) | ||
if err != nil { | ||
return nil, err | ||
} | ||
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", path)) | ||
|
||
// Require successful HTTP response | ||
// before switching to RPC protocol. | ||
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) | ||
if err == nil && resp.Status == connected { | ||
return NewClient(conn, opt) | ||
} | ||
if err == nil { | ||
err = errors.New("unexpected HTTP response: " + resp.Status) | ||
} | ||
_ = conn.Close() | ||
return nil, err | ||
} |
Oops, something went wrong.