package rpc

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"sync"

	"github.com/jrick/wsrpc/v2"
)

// Caller provides a client interface to perform JSON-RPC remote procedure calls.
type Caller interface {
	// String returns the dialed URL.
	String() string

	// Call performs the remote procedure call defined by method and
	// waits for a response or a broken client connection.
	// Args provides positional parameters for the call.
	// Res must be a pointer to a struct, slice, or map type to unmarshal
	// a result (if any), or nil if no result is needed.
	Call(ctx context.Context, method string, res interface{}, args ...interface{}) error
}

// client wraps a wsrpc.Client, as well as all of the connection details
// required to make a new client if the existing client is closed.
type client struct {
	// mu guards client; it must be held for every read or write of it.
	mu *sync.Mutex
	// client is the current connection, or nil if none has been dialed.
	client *wsrpc.Client

	// Connection details used each time a new connection is dialed.
	addr     string
	tlsOpt   wsrpc.Option
	authOpt  wsrpc.Option
	notifier wsrpc.Notifier
}

// setup builds a client for the RPC server at addr without connecting to it;
// the connection is established lazily by dial.
//
// cert is PEM-encoded certificate data used as the only trusted root CAs for
// the TLS connection, user and pass are used for basic authentication, and n
// is passed to wsrpc as the notification handler.
func setup(user, pass, addr string, cert []byte, n wsrpc.Notifier) *client {
	// Create TLS options. The pool replaces the system roots entirely, so if
	// no certificate can be parsed from cert, server certificate verification
	// will fail on every connection attempt. Report that early rather than
	// leaving callers with an opaque TLS error later.
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(cert) {
		log.Errorf("No valid PEM certificates provided for RPC (%s)", addr)
	}
	tlsOpt := wsrpc.WithTLSConfig(&tls.Config{RootCAs: pool})

	// Create authentication options.
	authOpt := wsrpc.WithBasicAuth(user, pass)

	return &client{
		mu:       new(sync.Mutex),
		addr:     addr,
		tlsOpt:   tlsOpt,
		authOpt:  authOpt,
		notifier: n,
	}
}

// Close closes the current connection if one exists and is not already
// closed. It does nothing if no connection has been dialed. The mutex is held
// so Close cannot race with dial replacing the connection.
func (c *client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client == nil {
		return
	}

	select {
	case <-c.client.Done():
		log.Tracef("RPC already closed (%s)", c.addr)
	default:
		if err := c.client.Close(); err != nil {
			log.Errorf("Failed to close RPC (%s): %v", c.addr, err)
			return
		}
		log.Tracef("RPC closed (%s)", c.addr)
	}
}

// dial returns the existing RPC client if it is still connected, or attempts
// to create a new one if not. The returned boolean is true when a new
// connection was established, and false when an existing connection is being
// reused.
func (c *client) dial(ctx context.Context) (Caller, bool, error) { defer c.mu.Unlock() c.mu.Lock() if c.client != nil { select { case <-c.client.Done(): log.Debugf("RPC client %s errored (%v); reconnecting...", c.addr, c.client.Err()) c.client = nil default: return c.client, false, nil } } var err error fullAddr := "wss://" + c.addr + "/ws" c.client, err = wsrpc.Dial(ctx, fullAddr, c.tlsOpt, c.authOpt, wsrpc.WithNotifier(c.notifier)) if err != nil { return nil, false, err } return c.client, true, nil }