Files

284 lines
9.5 KiB
Go

package licensing
import (
	"crypto/rsa"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
// LicenseServer is the HTTP service the operator's LocalSigner exposes for
// RemoteSigner customer deployments. It uses the same LocalSigner instance
// (same HMAC master key) to produce signatures so customers can issue
// device logins without ever holding the master key themselves.
//
// Endpoints:
//
// POST /license/sign
// Body: {"client_id": "...", "start_time": "..."}
// Auth: Authorization: Bearer <customer-JWT>
// Reply: {"signature": "<64-hex>"} (200)
// or {"error": "..."} (4xx/5xx)
// Enforces quota: claims.max_devices for the JWT's "sub".
//
// POST /license/heartbeat
// Body: {"active_device_count": N, "active_device_ids": ["...","..."]}
// Auth: Authorization: Bearer <customer-JWT>
// Reply: {"server_view_count": M, "drift": N-M}
// Used by the customer's Go server to surface its view; cross-validated
// against what /sign has actually been asked for under that customer.
// Large drift is logged for anti-tamper review; no automatic revocation
// in v1 — operator decides.
//
// Security: serve plain HTTP and put nginx / Caddy in front for TLS. JWT
// "alg" is locked to RS256 in token.go; "alg":"none" tampering is blocked.
//
// Construct with NewLicenseServer and mount the value returned by Handler.
type LicenseServer struct {
// signer mints the signatures returned by /license/sign.
signer *LocalSigner
// pubKey verifies the customer bearer JWTs presented to both endpoints.
pubKey *rsa.PublicKey
// tracker holds per-customer (JWT "sub") device slots: Reserve is called
// from /license/sign, RefreshExisting and Snapshot from /license/heartbeat.
tracker *quotaTracker
// logger receives quota, drift and signer-failure log lines.
logger Logger
// mux routes the two endpoints; exposed via Handler.
mux *http.ServeMux
}
// Logger is the minimal printf-style logging interface LicenseServer needs.
// Implementations must provide all three of Info, Warn and Error. The cmd
// package's *logger.Logger is expected to satisfy it.
// NOTE(review): confirm that type also has an Error method; this comment
// previously listed only Info/Warn.
type Logger interface {
// Info logs routine events, such as a successful signature.
Info(format string, args ...any)
// Warn logs suspicious but non-fatal events, such as a quota rejection or heartbeat drift.
Warn(format string, args ...any)
// Error logs internal failures, such as the signer returning an error.
Error(format string, args ...any)
}
// NewLicenseServer assembles a LicenseServer and registers its two routes,
// /license/sign and /license/heartbeat, on a private ServeMux.
//
// evictAfter is handed to the quota tracker. It is how long a quiet device
// keeps its slot before its quota is reclaimed (recommend 5 min — twice a
// typical heartbeat interval). lg receives the server's log output.
func NewLicenseServer(signer *LocalSigner, pubKey *rsa.PublicKey,
	evictAfter time.Duration, lg Logger) *LicenseServer {
	mux := http.NewServeMux()
	srv := &LicenseServer{
		mux:     mux,
		logger:  lg,
		tracker: newQuotaTracker(evictAfter),
		pubKey:  pubKey,
		signer:  signer,
	}

	mux.HandleFunc("/license/sign", srv.handleSign)
	mux.HandleFunc("/license/heartbeat", srv.handleHeartbeat)

	return srv
}
// Handler exposes the server's routing table as a single http.Handler, so
// the operator can mount it in an existing HTTP server or serve it
// standalone with http.ListenAndServe.
func (s *LicenseServer) Handler() http.Handler {
	return s.mux
}
// authenticate extracts and verifies the bearer JWT from the Authorization
// header.
//
// On success it returns the parsed claims. On failure it writes a 401 JSON
// error, including a WWW-Authenticate challenge as RFC 6750 requires, and
// returns nil, so the caller can simply `return` on a nil result.
//
// The auth scheme is matched case-insensitively ("bearer", "BEARER", ...),
// because HTTP authentication schemes are case-insensitive (RFC 7235 §2.1).
// Whitespace around the token is ignored. An empty token is reported as
// missing instead of being handed to the JWT verifier.
func (s *LicenseServer) authenticate(w http.ResponseWriter, r *http.Request) *LicenseClaims {
	scheme, tokenStr, ok := strings.Cut(r.Header.Get("Authorization"), " ")
	tokenStr = strings.TrimSpace(tokenStr)
	if !ok || !strings.EqualFold(scheme, "Bearer") || tokenStr == "" {
		w.Header().Set("WWW-Authenticate", "Bearer")
		writeJSONError(w, http.StatusUnauthorized, "missing Bearer token")
		return nil
	}

	claims, err := VerifyJWT(tokenStr, s.pubKey)
	if err != nil {
		w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
		writeJSONError(w, http.StatusUnauthorized, fmt.Sprintf("invalid token: %v", err))
		return nil
	}
	return claims
}
// handleSign serves POST /license/sign.
//
// It authenticates the customer JWT, validates the {client_id, start_time}
// body, mints the HMAC signature, and only then reserves a device slot for
// client_id under the JWT subject. Possible responses:
//   - 200 with the signature
//   - 400 for a bad or incomplete body
//   - 401 for a bad token
//   - 403 when the quota (claims.MaxDevices) is exhausted
//   - 405 for non-POST methods
//   - 500 if the signer fails
func (s *LicenseServer) handleSign(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// RFC 9110 §15.5.6: a 405 response must list the allowed methods.
		w.Header().Set("Allow", http.MethodPost)
		writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	claims := s.authenticate(w, r)
	if claims == nil {
		return
	}

	var req signRequest
	if err := readJSONLimited(w, r, maxSignBodyBytes, &req); err != nil {
		writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err))
		return
	}
	if req.ClientID == "" || req.StartTime == "" {
		writeJSONError(w, http.StatusBadRequest, "client_id and start_time required")
		return
	}

	// Mint the signature using the local HMAC master key BEFORE touching
	// the quota. If the signer fails (for example, it rejects start_time),
	// no slot is reserved. With the opposite order the slot would stay
	// pinned until eviction. The signature is only released below, after
	// the quota check succeeds.
	sig, err := s.signer.Sign(req.StartTime, req.ClientID)
	if err != nil {
		s.logger.Error("License Server: signer failed: %v", err)
		writeJSONError(w, http.StatusInternalServerError, "signing failed")
		return
	}

	// Atomically check + reserve the slot. A rejected request does NOT
	// consume a slot — see quotaTracker.Reserve.
	active, accepted := s.tracker.Reserve(claims.Subject, req.ClientID, claims.MaxDevices)
	if !accepted {
		s.logger.Warn("License Server: quota exceeded for sub=%s tier=%s active=%d max=%d clientID=%s",
			claims.Subject, claims.Tier, active, claims.MaxDevices, req.ClientID)
		writeJSONError(w, http.StatusForbidden,
			fmt.Sprintf("quota exceeded: %d/%d devices in use", active, claims.MaxDevices))
		return
	}

	writeJSON(w, http.StatusOK, signResponse{Signature: sig})
	s.logger.Info("License Server: signed for sub=%s clientID=%s active=%d/%d ttl=%s",
		claims.Subject, req.ClientID, active, claims.MaxDevices,
		formatTTL(claims.ttlSinceNow()))
}
// Wire format of POST /license/heartbeat.
type (
	// heartbeatRequest is the customer's self-reported view of its devices.
	heartbeatRequest struct {
		ActiveDeviceCount int      `json:"active_device_count"`
		ActiveDeviceIDs   []string `json:"active_device_ids"`
	}

	// heartbeatResponse echoes the server's view back to the customer.
	heartbeatResponse struct {
		ServerViewCount int `json:"server_view_count"`
		Drift           int `json:"drift"` // customer count - server view count
	}
)
// handleHeartbeat serves POST /license/heartbeat.
//
// It authenticates the customer JWT and refreshes activity timestamps for
// reported device IDs that the tracker already knows under the JWT subject.
// It then replies with the server's device count and the drift between the
// customer's reported count and that view. Possible responses:
//   - 200 with the count and drift
//   - 400 for a bad body or a negative active_device_count
//   - 401 for a bad token
//   - 405 for non-POST methods
func (s *LicenseServer) handleHeartbeat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// RFC 9110 §15.5.6: a 405 response must list the allowed methods.
		w.Header().Set("Allow", http.MethodPost)
		writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	claims := s.authenticate(w, r)
	if claims == nil {
		return
	}

	var req heartbeatRequest
	if err := readJSONLimited(w, r, maxHeartbeatBodyBytes, &req); err != nil {
		writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err))
		return
	}
	// The count is self-reported. A negative value is meaningless and
	// would make the drift figure (and its log line) nonsense.
	if req.ActiveDeviceCount < 0 {
		writeJSONError(w, http.StatusBadRequest, "active_device_count must be >= 0")
		return
	}

	// REFRESH-ONLY: only bump activity timestamps for devices already in
	// the customer's set (i.e. that came through /sign). New IDs reported
	// here are silently ignored — otherwise a malicious customer could
	// inflate their quota by POSTing fake IDs to /heartbeat.
	refreshed := s.tracker.RefreshExisting(claims.Subject, req.ActiveDeviceIDs)
	serverView := len(s.tracker.Snapshot(claims.Subject))
	drift := req.ActiveDeviceCount - serverView

	// Soft anti-tamper: a large positive drift means the customer reports
	// more devices than we minted signatures for. It is evaluated per
	// heartbeat and only logged for operator review; there is no automatic
	// revocation in v1.
	if drift > claims.MaxDevices/2 && drift > 5 {
		s.logger.Warn("License Server: heartbeat drift sub=%s reported=%d server=%d refreshed=%d drift=%d",
			claims.Subject, req.ActiveDeviceCount, serverView, refreshed, drift)
	}

	writeJSON(w, http.StatusOK, heartbeatResponse{
		ServerViewCount: serverView,
		Drift:           drift,
	})
}
// minTokenTTL is the shortest TTL Issue accepts. It guards against
// fat-finger ttl=0 / negative values, which would produce already-expired
// tokens that 401 immediately and confuse the customer.
const minTokenTTL = time.Hour

// Issue mints a signed RS256 customer license JWT.
//
// It is a small in-process helper that lets operators mint tokens without a
// separate tool. Although the function is exported for Go callers, it is
// deliberately not exposed as an HTTP endpoint: JWT issuance is a one-off
// operator action, not a remote API. Use it from a cmd/issue-token CLI or
// interactively.
//
// Parameters:
//   - privKey: RSA private key used to sign the token; must be non-nil.
//   - sub: customer ID, stored as the JWT "sub" claim; must not be blank.
//   - tier: TierTrial or TierPaid; any other value is rejected.
//   - maxDevices: device quota. For trial, <= 0 selects TrialMaxDevices. For
//     paid, it must be > 0.
//   - ttl: token lifetime; must be at least minTokenTTL.
//
// It returns the compact signed token string, or an error if an argument
// is invalid or signing fails.
func Issue(privKey *rsa.PrivateKey, sub, tier string, maxDevices int, ttl time.Duration) (string, error) {
	if privKey == nil {
		return "", errors.New("nil private key")
	}
	// A whitespace-only subject is not a usable customer ID. Reject it
	// the same way as an empty one. Non-blank subjects are kept verbatim.
	if strings.TrimSpace(sub) == "" {
		return "", errors.New("sub (customer ID) is required")
	}
	if ttl < minTokenTTL {
		return "", fmt.Errorf("ttl too short (%v); minimum is %v", ttl, minTokenTTL)
	}

	switch tier {
	case TierTrial:
		// Trial: 0 means "use the default 20" (kept consistent with VerifyJWT).
		if maxDevices <= 0 {
			maxDevices = TrialMaxDevices
		}
	case TierPaid:
		// Paid: must be explicit. A 0 here is almost certainly a misconfig
		// and would silently let one customer use unlimited devices.
		if maxDevices <= 0 {
			return "", errors.New("paid tier requires explicit max_devices > 0")
		}
	default:
		return "", fmt.Errorf("unsupported tier: %q", tier)
	}

	now := time.Now()
	claims := &LicenseClaims{
		Tier:       tier,
		MaxDevices: maxDevices,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   sub,
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
		},
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(privKey)
	if err != nil {
		return "", fmt.Errorf("signing license token: %w", err)
	}
	return signed, nil
}
// ---- helpers ----
// Request-body size caps enforced by readJSONLimited.
//
// /license/sign takes a tiny fixed-schema body, so 8 KiB is generous.
// /license/heartbeat carries the customer's list of client IDs. Its
// expected length tracks the customer's device cap (which, for paid tiers,
// the operator controls), but this cap is what bounds it regardless of what
// a client sends. 64 KiB fits roughly 2,800 quoted 20-byte IDs.
// NOTE(review): the 20-byte ID length is an assumption; confirm it.
const (
	maxSignBodyBytes      int64 = 8 * 1024  // 8 KiB
	maxHeartbeatBodyBytes int64 = 64 * 1024 // 64 KiB
)
// readJSONLimited decodes exactly one JSON value from the request body into
// dst.
//
// The body is wrapped with http.MaxBytesReader, so a misconfigured (or
// malicious) client cannot OOM the server with a multi-GB payload.
// DisallowUnknownFields catches schema drift cleanly.
//
// Only whitespace may follow the value. A second object or stray bytes are
// rejected, because json.Decoder.Decode on its own stops after the first
// value and would silently ignore trailing data.
func readJSONLimited(w http.ResponseWriter, r *http.Request, limit int64, dst any) error {
	r.Body = http.MaxBytesReader(w, r.Body, limit)
	dec := json.NewDecoder(r.Body)
	dec.DisallowUnknownFields()
	if err := dec.Decode(dst); err != nil {
		return err
	}
	// A clean end of input is the only acceptable outcome here. A nil
	// error means a second value was decoded; any other error means garbage.
	if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) {
		return errors.New("request body must contain a single JSON value")
	}
	return nil
}
// writeJSON sends v as a JSON response body with the given status code and
// a Content-Type of application/json.
//
// v is marshalled before any header is written. If encoding fails, the
// client therefore gets a 500 with a fixed JSON error body, instead of the
// requested (possibly success) status followed by an empty or truncated
// body. On success the body is the JSON encoding plus a trailing newline,
// exactly what json.Encoder.Encode produces.
func writeJSON(w http.ResponseWriter, status int, v any) {
	body, err := json.Marshal(v)
	w.Header().Set("Content-Type", "application/json")
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"response encoding failed"}` + "\n"))
		return
	}
	w.WriteHeader(status)
	// Once the header is sent, a write error only means the client went
	// away; there is nothing useful left to report it to.
	_, _ = w.Write(append(body, '\n'))
}
func writeJSONError(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, signResponse{Error: msg})
}
// formatTTL renders a remaining token lifetime for log lines:
//   - "expired" for zero or negative durations
//   - whole days ("3d") once at least one full day remains
//   - otherwise the duration rounded to the minute, e.g. "2h30m0s"
func formatTTL(d time.Duration) string {
	switch days := int(d.Hours() / 24); {
	case d <= 0:
		return "expired"
	case days > 0:
		return fmt.Sprintf("%dd", days)
	default:
		return d.Round(time.Minute).String()
	}
}