vspd/rpc/client.go
jholdstock 72c16ad2c7 Close RPC connections after the web server is stopped.
Previously all of the shutdown tasks were running concurrently, which meant the RPC connections could be closed before the web server had finished using them.
2020-06-12 13:23:55 +00:00

95 lines
2.4 KiB
Go

package rpc
import (
"context"
"crypto/tls"
"crypto/x509"
"sync"
"github.com/jrick/wsrpc/v2"
)
// Caller is the client-side interface used to issue JSON-RPC remote procedure
// calls.
type Caller interface {
	// Call invokes method with the positional parameters args and blocks
	// until either a response arrives or the client connection breaks.
	// When a result is wanted, res must be a pointer to a struct, slice, or
	// map type into which the result is unmarshaled; pass nil when no
	// result is needed.
	Call(ctx context.Context, method string, res interface{}, args ...interface{}) error

	// String reports the URL that was dialed.
	String() string
}
// client holds a lazily created wsrpc.Client together with every connection
// detail needed to dial a replacement once the current connection has closed.
type client struct {
	// mu guards client; dial holds it while checking and replacing the
	// connection.
	mu *sync.Mutex
	// client is the current connection. setup leaves it nil and dial
	// (re)creates it on demand.
	client *wsrpc.Client

	// Connection parameters reused for every (re)dial.
	addr            string
	tlsOpt, authOpt wsrpc.Option
	notifier        wsrpc.Notifier
}
// setup prepares a client for the JSON-RPC server at addr without connecting
// to it; the connection is created later by dial. TLS server verification
// trusts only the PEM-encoded certificate(s) in cert, user and pass are the
// basic-auth credentials, and n is installed as the notification handler when
// the connection is dialed.
func setup(user, pass, addr string, cert []byte, n wsrpc.Notifier) *client {
	// RootCAs is set to a non-nil pool, so the host's system roots are not
	// consulted. If cert holds no usable PEM certificate, every later TLS
	// handshake will fail; report that here rather than leaving only an
	// opaque handshake error at dial time.
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(cert) {
		log.Errorf("Failed to parse RPC certificate (%s): no valid PEM certificates found", addr)
	}
	tlsOpt := wsrpc.WithTLSConfig(&tls.Config{RootCAs: pool})

	// Basic-auth credentials used when dialing.
	authOpt := wsrpc.WithBasicAuth(user, pass)

	// The wsrpc client field is intentionally left nil; dial creates it on
	// first use.
	return &client{
		mu:       new(sync.Mutex),
		addr:     addr,
		tlsOpt:   tlsOpt,
		authOpt:  authOpt,
		notifier: n,
	}
}
// Close shuts down the underlying RPC connection if one exists and has not
// already closed. A failure to close is logged rather than returned.
//
// Close holds c.mu, the same lock dial uses when it replaces c.client, so the
// two cannot race on the field. If a dial is in progress, Close waits for it
// to finish first.
func (c *client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return
	}

	select {
	case <-c.client.Done():
		log.Tracef("RPC already closed (%s)", c.addr)
		return
	default:
	}

	if err := c.client.Close(); err != nil {
		log.Errorf("Failed to close RPC (%s): %v", c.addr, err)
		return
	}
	log.Tracef("RPC closed (%s)", c.addr)
}
// dial returns the current RPC client if its connection is still live;
// otherwise it discards any dead client and dials a new one. The boolean
// result is true when this call established a new connection, and false when
// an existing connection is being reused.
func (c *client) dial(ctx context.Context) (Caller, bool, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if existing := c.client; existing != nil {
		select {
		case <-existing.Done():
			// The old connection has shut down; drop it and redial below.
			log.Debugf("RPC client %s errored (%v); reconnecting...", c.addr, existing.Err())
			c.client = nil
		default:
			// Still connected, so reuse it.
			return existing, false, nil
		}
	}

	wsURL := "wss://" + c.addr + "/ws"
	conn, dialErr := wsrpc.Dial(ctx, wsURL, c.tlsOpt, c.authOpt, wsrpc.WithNotifier(c.notifier))
	c.client = conn
	if dialErr != nil {
		return nil, false, dialErr
	}
	return conn, true, nil
}