Files
SimpleRemoter/server/go/licensing/remote.go

346 lines
10 KiB
Go

package licensing
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
// QuotaExceededError is returned by Sign when the License Server explicitly
// rejects the device because the customer's slot quota is full. Unlike
// transient network errors, stale-cache fallback is NOT appropriate — the
// License Server's 403 decision is authoritative, and serving a stale
// signature would silently bypass the operator's cap.
type QuotaExceededError struct {
	// Message is the error text taken from the License Server's 403
	// response: the JSON "error" field when present, otherwise the raw body.
	Message string
}

// Error implements the error interface. It returns Message, or a generic
// description when the server supplied no text (e.g. a 403 with an empty
// body), so the error never renders as an empty string in logs.
func (e *QuotaExceededError) Error() string {
	if e.Message == "" {
		return "license quota exceeded"
	}
	return e.Message
}
// IsQuotaExceeded reports whether err, or any error it wraps, is a
// *QuotaExceededError. A nil err yields false.
//
// Per the original note, callers such as handleLogin use this to decide
// whether to close the device connection after sending a zeroed signature;
// those callers live outside this file.
func IsQuotaExceeded(err error) bool {
	// errors.As needs a non-nil pointer to the target type; the matched
	// value itself is not needed here, only whether a match exists.
	target := new(*QuotaExceededError)
	return errors.As(err, target)
}
// RemoteSigner obtains per-login signatures from an operator-hosted License
// Server, authenticating with an optional bearer token (a JWT issued by the
// operator). Per the original note, the server URL and token are loaded
// from YAMA_LICENSE_SERVER / YAMA_LICENSE_TOKEN at startup; that wiring
// lives outside this file.
//
// Caching: a given (startTime, clientID) pair is expected to always produce
// the same signature, so Sign keeps fetched signatures in memory and
// answers from the cache, with no network traffic, while an entry is
// younger than offlineGrace. Older entries trigger a re-fetch.
//
// Thundering herd: concurrent Sign calls that miss the cache for the same
// (startTime, clientID) are coalesced through singleflight so only one
// round-trip reaches the License Server. This matters when many devices
// reconnect at the same time, e.g. after a restart.
//
// Heartbeat: a background goroutine POSTs the client IDs of fresh cache
// entries to /license/heartbeat every heartbeatInterval. This lets the
// License Server re-populate its quota view after a restart even though
// cache hits never call /license/sign.
//
// Failure handling: a 403 from the License Server is surfaced as a
// *QuotaExceededError and is never answered from the cache. Any other
// failure falls back to a previously cached signature if one exists; with
// no cached entry Sign returns ("", err) and the caller is expected to ship
// a zeroed signature.
//
// NOTE(review): the transient-failure fallback is not bounded by
// offlineGrace — an entry of any age is served. Confirm this is intended.
type RemoteSigner struct {
	// Connection settings, fixed at construction.
	serverURL    string        // base URL with trailing slashes removed
	token        string        // bearer token; empty means no Authorization header
	offlineGrace time.Duration // max age at which a cached signature is served without a fetch
	httpClient   *http.Client  // shared client for sign and heartbeat requests
	logger       Logger        // never nil; NewRemote substitutes a no-op logger

	// sf coalesces concurrent fetches for the same cache key.
	sf singleflight.Group

	// mu guards cache.
	mu    sync.Mutex
	cache map[string]cachedSig // keyed by startTime + "|" + clientID

	// Heartbeat goroutine lifecycle.
	hbDone chan struct{}  // closed by Close to stop the heartbeat loop
	hbWg   sync.WaitGroup // lets Close wait for the heartbeat loop to exit
}
// cachedSig is a single cache entry: a signature plus the time it was
// stored, which is what freshness checks compare against.
type cachedSig struct {
	fetchedAt time.Time // when the signature was written to the cache
	sig       string    // signature as returned by the License Server
}
// signRequest is the JSON body sent to POST /license/sign. Field order and
// tags define the wire format and must match the License Server's schema.
type signRequest struct {
	// ClientID identifies the device requesting a signature.
	ClientID string `json:"client_id"`
	// StartTime is the login start time, passed through as an opaque string.
	StartTime string `json:"start_time"`
}
// signResponse is the JSON body returned by the License Server for a sign
// request. Either field may be absent.
type signResponse struct {
	// Signature is the signature text on success.
	Signature string `json:"signature,omitempty"`
	// Error is a server-supplied failure description; non-empty means the
	// request was not fulfilled.
	Error string `json:"error,omitempty"`
}
// nilLogger is a logger that discards everything. NewRemote installs it
// when the caller passes no logger, so RemoteSigner can log unconditionally
// without risking a nil dereference.
type nilLogger struct{}

// Info discards the message.
func (nilLogger) Info(_ string, _ ...any) {}

// Warn discards the message.
func (nilLogger) Warn(_ string, _ ...any) {}

// Error discards the message.
func (nilLogger) Error(_ string, _ ...any) {}
// heartbeatInterval is how often the heartbeat loop POSTs to
// /license/heartbeat. Per the original note, 90s is chosen to sit well
// below the License Server's 5-minute eviction window, so actively
// reporting devices are not reaped from the server's quota view; the
// eviction window itself is defined outside this file.
const heartbeatInterval time.Duration = 90 * time.Second
// ValidateRemoteURL reports whether raw is acceptable as the License Server
// base URL, returning nil if so and a descriptive error otherwise.
//
// Accepted:
//   - https:// URLs that name a host;
//   - http:// URLs whose host is localhost (any letter case), 127.0.0.1 or
//     ::1, as an exception for local testing.
//
// Everything else is rejected: plain http:// to any other host would put
// the JWT and signatures on the wire in cleartext, other schemes are
// unsupported, and an https:// URL without a host cannot be contacted.
func ValidateRemoteURL(raw string) error {
	u, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("not a URL: %w", err)
	}
	host := u.Hostname()
	switch u.Scheme {
	case "https":
		// url.Parse accepts "https://" and "https:///path"; neither names
		// a server, so fail here rather than on the first request.
		if host == "" {
			return fmt.Errorf("missing host in URL %q", raw)
		}
		return nil
	case "http":
		// Hostnames are case-insensitive; IP literals need no folding.
		if strings.EqualFold(host, "localhost") || host == "127.0.0.1" || host == "::1" {
			return nil
		}
		return fmt.Errorf("http:// scheme exposes JWT in cleartext; use https:// (got %q)", raw)
	default:
		return fmt.Errorf("unsupported scheme %q (need https://)", u.Scheme)
	}
}
// NewRemote builds a RemoteSigner that talks to the License Server at
// serverURL and starts its background heartbeat goroutine; call Close to
// stop that goroutine.
//
// Trailing slashes on serverURL are removed. token is sent as a bearer
// token and may be empty. offlineGrace is the maximum age at which Sign
// serves a cached signature without contacting the server. A nil lg is
// replaced with a logger that discards all output. HTTP requests use a
// client with a 10-second timeout.
func NewRemote(serverURL, token string, offlineGrace time.Duration, lg Logger) *RemoteSigner {
	// Default to the discarding logger; override only when one was supplied.
	var sink Logger = nilLogger{}
	if lg != nil {
		sink = lg
	}

	client := &http.Client{Timeout: 10 * time.Second}

	signer := &RemoteSigner{
		serverURL:    strings.TrimRight(serverURL, "/"),
		token:        token,
		offlineGrace: offlineGrace,
		logger:       sink,
		httpClient:   client,
		cache:        make(map[string]cachedSig),
		hbDone:       make(chan struct{}),
	}

	// Register the goroutine before launching it so Close can always wait
	// on it.
	signer.hbWg.Add(1)
	go signer.heartbeatLoop()
	return signer
}
// Sign returns the signature for the (startTime, clientID) pair.
//
// Resolution order:
//  1. A cached signature younger than offlineGrace is returned without any
//     network traffic.
//  2. Otherwise the signature is fetched from the License Server;
//     concurrent calls for the same pair share a single request. A
//     successful result is cached and returned.
//  3. If the server rejects the request as quota-exceeded, the error (for
//     which IsQuotaExceeded is true) is returned and any cached signature
//     for the pair is discarded.
//  4. On any other failure, a previously cached signature is returned if
//     one exists, regardless of its age; otherwise ("", err) is returned.
func (r *RemoteSigner) Sign(startTime, clientID string) (string, error) {
	key := startTime + "|" + clientID

	// Fresh cache hit: serve from memory, no network.
	r.mu.Lock()
	if c, ok := r.cache[key]; ok && time.Since(c.fetchedAt) < r.offlineGrace {
		sig := c.sig
		r.mu.Unlock()
		return sig, nil
	}
	r.mu.Unlock()

	// Coalesce concurrent fetches for the same key — protects the License
	// Server from herd reconnects after a network blip.
	v, err, _ := r.sf.Do(key, func() (any, error) {
		return r.fetch(startTime, clientID)
	})
	if err == nil {
		sig := v.(string)
		r.mu.Lock()
		r.cache[key] = cachedSig{sig: sig, fetchedAt: time.Now()}
		r.mu.Unlock()
		return sig, nil
	}

	// Quota-exceeded is authoritative — skip stale-cache fallback entirely.
	// Serving a cached signature here would bypass the operator's explicit
	// cap decision.
	if IsQuotaExceeded(err) {
		// Evict the stale entry as well. Otherwise a later transient
		// failure for the same key would hit the fallback below and hand
		// out the very signature the server just refused to renew.
		r.mu.Lock()
		delete(r.cache, key)
		r.mu.Unlock()
		r.logger.Error("RemoteSigner: quota exceeded for clientID=%s (%v); sending zeroed signature",
			clientID, err)
		return "", err
	}

	// Transient failure: fall back to stale cache if any. Better to keep an
	// existing device alive than fail closed during a momentary outage.
	// The cache is re-read because another goroutine may have filled it
	// while the fetch was in flight.
	r.mu.Lock()
	c, ok := r.cache[key]
	r.mu.Unlock()
	if ok {
		age := time.Since(c.fetchedAt).Round(time.Second)
		r.logger.Warn("RemoteSigner: License Server unreachable (%v); serving stale cache (age=%s) for clientID=%s",
			err, age, clientID)
		return c.sig, nil
	}

	r.logger.Error("RemoteSigner: License Server unreachable (%v) and no cache for clientID=%s; client will see zeroed signature",
		err, clientID)
	return "", err
}
// fetch performs a single POST to /license/sign for the given pair and
// returns the signature from the response.
//
// Outcomes:
//   - 200 with a non-empty "signature" and no "error": the signature.
//   - 200 with an "error" field, an empty signature, or an undecodable
//     body: a plain error.
//   - 403: a *QuotaExceededError carrying the JSON "error" field if the
//     body decodes and has one, otherwise the raw body text.
//   - any other status: a plain error including the status code and body.
//   - request construction, transport or read failures: the underlying
//     error.
//
// At most 8 KiB of the response body is read.
func (r *RemoteSigner) fetch(startTime, clientID string) (string, error) {
	payload, err := json.Marshal(signRequest{ClientID: clientID, StartTime: startTime})
	if err != nil {
		return "", err
	}

	endpoint := r.serverURL + "/license/sign"
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	if r.token != "" {
		req.Header.Set("Authorization", "Bearer "+r.token)
	}

	resp, err := r.httpClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Cap the read so an unexpected response cannot balloon memory.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
	if err != nil {
		return "", err
	}

	switch resp.StatusCode {
	case http.StatusOK:
		// Decoded below.
	case http.StatusForbidden:
		// 403 is an authoritative rejection, not a transient failure:
		// return the typed error so Sign skips the stale-cache fallback.
		var denied signResponse
		if json.Unmarshal(raw, &denied) == nil && denied.Error != "" {
			return "", &QuotaExceededError{Message: denied.Error}
		}
		return "", &QuotaExceededError{Message: string(raw)}
	default:
		// 401 / 5xx and the like: token rejected or server error — treated
		// as transient by the caller.
		return "", fmt.Errorf("License Server returned %d: %s",
			resp.StatusCode, string(raw))
	}

	var result signResponse
	if err := json.Unmarshal(raw, &result); err != nil {
		return "", fmt.Errorf("malformed License Server response: %w", err)
	}
	switch {
	case result.Error != "":
		return "", errors.New(result.Error)
	case result.Signature == "":
		return "", errors.New("License Server returned empty signature")
	}
	return result.Signature, nil
}
// heartbeatLoop runs until hbDone is closed, calling sendHeartbeat once per
// heartbeatInterval. It marks hbWg done on exit so Close can wait for it.
//
// The purpose of the heartbeat is that, after a License Server restart, the
// devices already known to this process are reported again without each one
// first needing a cache-missing /license/sign call.
func (r *RemoteSigner) heartbeatLoop() {
	defer r.hbWg.Done()

	tick := time.NewTicker(heartbeatInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			r.sendHeartbeat()
		case <-r.hbDone:
			return
		}
	}
}
// sendHeartbeat POSTs the set of client IDs that have a fresh cache entry
// (younger than offlineGrace) to /license/heartbeat. It is best-effort:
// nothing is sent when there are no fresh entries, failures to build or
// send the request are dropped silently, and a non-200 response is logged
// as a warning. The request is bounded by a 5-second timeout.
func (r *RemoteSigner) sendHeartbeat() {
	// Snapshot the distinct clientIDs of currently-fresh cache entries.
	// One device can appear under several startTime keys; it must be
	// reported once, otherwise the device count sent to the License Server
	// would be inflated.
	r.mu.Lock()
	cutoff := time.Now().Add(-r.offlineGrace)
	seen := make(map[string]struct{}, len(r.cache))
	ids := make([]string, 0, len(r.cache))
	for key, c := range r.cache {
		if c.fetchedAt.Before(cutoff) {
			continue
		}
		// key is "startTime|clientID" — extract clientID for the heartbeat.
		_, cid, ok := strings.Cut(key, "|")
		if !ok {
			continue
		}
		if _, dup := seen[cid]; dup {
			continue
		}
		seen[cid] = struct{}{}
		ids = append(ids, cid)
	}
	r.mu.Unlock()

	if len(ids) == 0 {
		return // nothing to report yet
	}

	body, err := json.Marshal(heartbeatRequest{
		ActiveDeviceCount: len(ids),
		ActiveDeviceIDs:   ids,
	})
	if err != nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, "POST",
		r.serverURL+"/license/heartbeat", bytes.NewReader(body))
	if err != nil {
		return
	}
	if r.token != "" {
		req.Header.Set("Authorization", "Bearer "+r.token)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := r.httpClient.Do(req)
	if err != nil {
		// Transient — don't spam logs every 90s; debug-level if we add one.
		return
	}
	defer resp.Body.Close()
	// Drain a bounded amount of the body so the connection can be reused;
	// the content itself is not needed.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
	if resp.StatusCode != http.StatusOK {
		r.logger.Warn("RemoteSigner: heartbeat returned %d", resp.StatusCode)
	}
}
// Mode returns "trial" when this RemoteSigner has no token and "remote"
// when it has one.
//
// Per the original note, "trial" corresponds to an anonymous downstream
// capped at FreeMaxDevices and "remote" to a paid customer holding a JWT;
// that policy is enforced outside this file.
func (r *RemoteSigner) Mode() string {
	mode := "remote"
	if r.token == "" {
		mode = "trial"
	}
	return mode
}
// Close stops the heartbeat goroutine, waits for it to exit, and releases
// idle HTTP connections. It always returns nil and may be called more than
// once, including from several goroutines at the same time.
func (r *RemoteSigner) Close() error {
	// The "already closed?" check and the close must happen atomically:
	// without the lock, two concurrent callers could both take the default
	// branch and the second close would panic.
	r.mu.Lock()
	select {
	case <-r.hbDone:
		// already closed
	default:
		close(r.hbDone)
	}
	r.mu.Unlock()

	// Wait outside the lock; the heartbeat goroutine takes r.mu while it
	// snapshots the cache.
	r.hbWg.Wait()
	r.httpClient.CloseIdleConnections()
	return nil
}