447 lines
15 KiB
Go
447 lines
15 KiB
Go
package licensing
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// Anonymous-trial rate limit: per-source-IP cap on /license/sign +
|
|
// /license/heartbeat requests without a Bearer token. Picked at "high enough
|
|
// for any legitimate single deployment, low enough to make brute-force /
|
|
// signature-probing pointless." Each /sign costs 1, /heartbeat 1, in a 60s
|
|
// sliding window. Authenticated requests skip this.
|
|
const (
|
|
anonRatePerWindow = 10
|
|
anonRateWindow = time.Minute
|
|
// anonReapInterval throws away stale buckets so the map doesn't grow
|
|
// unbounded across IP-cycling attackers. Walk the map every N requests.
|
|
anonReapEvery = 200
|
|
)
|
|
|
|
// Log-throttle cooldowns: downstream devices reconnect every few seconds when
|
|
// over-quota, so without throttling these two Warn lines flood the operator's
|
|
// log file. One log entry per cooldown window per unique key is enough signal.
|
|
const (
|
|
quotaWarnCooldown = 5 * time.Minute // per (sub, clientID) pair
|
|
rlWarnCooldown = anonRateWindow // per IP, matches the rate-limit window
|
|
)
|
|
|
|
// LicenseServer is the HTTP service the operator's LocalSigner exposes for
|
|
// RemoteSigner customer deployments. It uses the same LocalSigner instance
|
|
// (same HMAC master key) to produce signatures so customers can issue
|
|
// device logins without ever holding the master key themselves.
|
|
//
|
|
// Endpoints:
|
|
//
|
|
// POST /license/sign
|
|
// Body: {"client_id": "...", "start_time": "..."}
|
|
// Auth: Authorization: Bearer <customer-JWT>
|
|
// Reply: {"signature": "<64-hex>"} (200)
|
|
// or {"error": "..."} (4xx/5xx)
|
|
// Enforces quota: claims.max_devices for the JWT's "sub".
|
|
//
|
|
// POST /license/heartbeat
|
|
// Body: {"active_device_count": N, "active_device_ids": ["...","..."]}
|
|
// Auth: Authorization: Bearer <customer-JWT>
|
|
// Reply: {"server_view_count": M, "drift": N-M}
|
|
// Used by the customer's Go server to surface its view; cross-validated
|
|
// against what /sign has actually been asked for under that customer.
|
|
// Large drift is logged for anti-tamper review; no automatic revocation
|
|
// in v1 — operator decides.
|
|
//
|
|
// Security: serve plain HTTP and put nginx / Caddy in front for TLS. JWT
|
|
// "alg" is locked to RS256 in token.go; "alg":"none" tampering is blocked.
|
|
//
|
|
// Anonymous trial: requests without a Bearer token are treated as anonymous
|
|
// trial — the client's source IP is used as "sub" (key=`trial:<ip>`) and
|
|
// MaxDevices is capped at FreeMaxDevices. This is what lets a zero-config
|
|
// downstream binary "just work" for evaluation. Heavily rate-limited per IP
|
|
// to make brute-force / signature-probing pointless.
|
|
type LicenseServer struct {
|
|
signer *LocalSigner
|
|
pubKey *rsa.PublicKey
|
|
tracker *quotaTracker
|
|
logger Logger
|
|
mux *http.ServeMux
|
|
trustProxy bool // honor X-Forwarded-For / X-Real-IP — set only behind a trusted reverse proxy
|
|
|
|
anonMu sync.Mutex
|
|
anonBuckets map[string]*anonBucket // ip → bucket
|
|
anonReqSeen int // counter for periodic reap
|
|
|
|
warnMu sync.Mutex
|
|
lastWarn map[string]time.Time // dedup key → last log time
|
|
}
|
|
|
|
// anonBucket tracks anonymous request count within a sliding window.
|
|
type anonBucket struct {
|
|
count int
|
|
windowStart time.Time
|
|
}
|
|
|
|
// Logger is the smallest logging surface this package depends on: three
// printf-style methods, one per severity. The cmd package's *logger.Logger
// is expected to satisfy it with its existing methods.
type Logger interface {
	// Info records routine events.
	Info(format string, args ...any)
	// Warn records conditions an operator should look at.
	Warn(format string, args ...any)
	// Error records failures.
	Error(format string, args ...any)
}
|
|
|
|
// NewLicenseServer builds the HTTP handler set. evictAfter is how long a
|
|
// quiet device keeps its slot before its quota is reclaimed (recommend
|
|
// 5 min — twice a typical heartbeat interval).
|
|
func NewLicenseServer(signer *LocalSigner, pubKey *rsa.PublicKey,
|
|
evictAfter time.Duration, lg Logger) *LicenseServer {
|
|
s := &LicenseServer{
|
|
signer: signer,
|
|
pubKey: pubKey,
|
|
tracker: newQuotaTracker(evictAfter),
|
|
logger: lg,
|
|
mux: http.NewServeMux(),
|
|
anonBuckets: make(map[string]*anonBucket),
|
|
lastWarn: make(map[string]time.Time),
|
|
}
|
|
s.mux.HandleFunc("/license/sign", s.handleSign)
|
|
s.mux.HandleFunc("/license/heartbeat", s.handleHeartbeat)
|
|
return s
|
|
}
|
|
|
|
// warnOnce emits a Warn log at most once per cooldown window for the given
|
|
// dedup key. Subsequent identical events within the window are silently
|
|
// dropped. This keeps high-frequency but expected conditions (quota exceeded,
|
|
// rate limit hit) from flooding the operator's log file while still providing
|
|
// one clear signal per event burst.
|
|
func (s *LicenseServer) warnOnce(key string, cooldown time.Duration, format string, args ...any) {
|
|
s.warnMu.Lock()
|
|
if t, ok := s.lastWarn[key]; ok && time.Since(t) < cooldown {
|
|
s.warnMu.Unlock()
|
|
return
|
|
}
|
|
s.lastWarn[key] = time.Now()
|
|
s.warnMu.Unlock()
|
|
s.logger.Warn(format, args...)
|
|
}
|
|
|
|
// SetTrustProxy switches IP extraction to X-Forwarded-For / X-Real-IP for
|
|
// the anonymous-trial branch. Only set this when running behind a reverse
|
|
// proxy you control (nginx / caddy / cloudflare); direct-exposure
|
|
// deployments MUST leave it false or attackers can spoof the header to
|
|
// evade the per-IP rate limit and the trial quota.
|
|
func (s *LicenseServer) SetTrustProxy(trust bool) { s.trustProxy = trust }
|
|
|
|
// Handler returns the http.Handler the operator wires into their HTTP
|
|
// server (or runs standalone via http.ListenAndServe).
|
|
func (s *LicenseServer) Handler() http.Handler { return s.mux }
|
|
|
|
// authenticate extracts and verifies the bearer JWT. Returns the parsed
|
|
// claims on success; writes the appropriate HTTP error and returns nil
|
|
// on failure so the caller can simply `return` on a nil result.
|
|
func (s *LicenseServer) authenticate(w http.ResponseWriter, r *http.Request) *LicenseClaims {
|
|
authHdr := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(authHdr, "Bearer ") {
|
|
writeJSONError(w, http.StatusUnauthorized, "missing Bearer token")
|
|
return nil
|
|
}
|
|
tokenStr := strings.TrimPrefix(authHdr, "Bearer ")
|
|
claims, err := VerifyJWT(tokenStr, s.pubKey)
|
|
if err != nil {
|
|
writeJSONError(w, http.StatusUnauthorized, fmt.Sprintf("invalid token: %v", err))
|
|
return nil
|
|
}
|
|
return claims
|
|
}
|
|
|
|
// resolveAuth decides whether the request is paid (Bearer JWT) or anonymous
|
|
// trial (no Authorization header). Returns a LicenseClaims structure either
|
|
// way:
|
|
// - Paid: claims from JWT, untouched.
|
|
// - Trial: synthesized claims with Subject="trial:<ip>", Tier=TierFree,
|
|
// MaxDevices=FreeMaxDevices, no Bearer required.
|
|
//
|
|
// Anonymous requests are rate-limited per source IP; if the IP's bucket is
|
|
// full we write 429 and return nil. Bad JWTs still 401 as before.
|
|
func (s *LicenseServer) resolveAuth(w http.ResponseWriter, r *http.Request) *LicenseClaims {
|
|
if r.Header.Get("Authorization") != "" {
|
|
return s.authenticate(w, r)
|
|
}
|
|
|
|
// Anonymous trial branch.
|
|
ip := s.clientIP(r)
|
|
if !s.allowAnon(ip) {
|
|
s.warnOnce("rl:"+ip, rlWarnCooldown, "License Server: anonymous rate limit hit for ip=%s", ip)
|
|
w.Header().Set("Retry-After", "60")
|
|
writeJSONError(w, http.StatusTooManyRequests,
|
|
"trial rate limit exceeded; set YAMA_LICENSE_TOKEN for full license")
|
|
return nil
|
|
}
|
|
|
|
return &LicenseClaims{
|
|
Tier: TierFree,
|
|
MaxDevices: FreeMaxDevices,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: "trial:" + ip,
|
|
},
|
|
}
|
|
}
|
|
|
|
// clientIP returns the request's source IP. When trustProxy is set, prefer
|
|
// X-Real-IP, then the last entry of X-Forwarded-For. Otherwise fall back to
|
|
// r.RemoteAddr (host part only).
|
|
func (s *LicenseServer) clientIP(r *http.Request) string {
|
|
if s.trustProxy {
|
|
if v := strings.TrimSpace(r.Header.Get("X-Real-IP")); v != "" {
|
|
return v
|
|
}
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
// Last entry is the one appended by the proxy closest to us.
|
|
parts := strings.Split(xff, ",")
|
|
last := strings.TrimSpace(parts[len(parts)-1])
|
|
if last != "" {
|
|
return last
|
|
}
|
|
}
|
|
}
|
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
return r.RemoteAddr
|
|
}
|
|
return host
|
|
}
|
|
|
|
// allowAnon enforces the anonymous-trial per-IP rate limit. Returns true
|
|
// when the call is admitted, false when the IP's bucket is full. Also reaps
|
|
// stale buckets opportunistically.
|
|
func (s *LicenseServer) allowAnon(ip string) bool {
|
|
s.anonMu.Lock()
|
|
defer s.anonMu.Unlock()
|
|
|
|
now := time.Now()
|
|
b, ok := s.anonBuckets[ip]
|
|
if !ok || now.Sub(b.windowStart) >= anonRateWindow {
|
|
s.anonBuckets[ip] = &anonBucket{count: 1, windowStart: now}
|
|
s.maybeReapAnonLocked(now)
|
|
return true
|
|
}
|
|
if b.count >= anonRatePerWindow {
|
|
return false
|
|
}
|
|
b.count++
|
|
return true
|
|
}
|
|
|
|
// maybeReapAnonLocked drops buckets whose windows are stale every N requests.
|
|
// Caller must hold s.anonMu.
|
|
func (s *LicenseServer) maybeReapAnonLocked(now time.Time) {
|
|
s.anonReqSeen++
|
|
if s.anonReqSeen < anonReapEvery {
|
|
return
|
|
}
|
|
s.anonReqSeen = 0
|
|
cutoff := now.Add(-anonRateWindow)
|
|
for ip, b := range s.anonBuckets {
|
|
if b.windowStart.Before(cutoff) {
|
|
delete(s.anonBuckets, ip)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *LicenseServer) handleSign(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
claims := s.resolveAuth(w, r)
|
|
if claims == nil {
|
|
return
|
|
}
|
|
|
|
var req signRequest
|
|
if err := readJSONLimited(w, r, maxSignBodyBytes, &req); err != nil {
|
|
writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err))
|
|
return
|
|
}
|
|
if req.ClientID == "" || req.StartTime == "" {
|
|
writeJSONError(w, http.StatusBadRequest, "client_id and start_time required")
|
|
return
|
|
}
|
|
|
|
// Atomically check + reserve the slot. A rejected request does NOT
|
|
// consume a slot — see quotaTracker.Reserve.
|
|
active, accepted := s.tracker.Reserve(claims.Subject, req.ClientID, claims.MaxDevices)
|
|
if !accepted {
|
|
s.warnOnce("quota:"+claims.Subject+":"+req.ClientID, quotaWarnCooldown,
|
|
"License Server: quota exceeded for sub=%s tier=%s active=%d max=%d clientID=%s",
|
|
claims.Subject, claims.Tier, active, claims.MaxDevices, req.ClientID)
|
|
writeJSONError(w, http.StatusForbidden,
|
|
fmt.Sprintf("quota exceeded: %d/%d devices in use", active, claims.MaxDevices))
|
|
return
|
|
}
|
|
|
|
// Mint the signature using the local HMAC master key.
|
|
sig, err := s.signer.Sign(req.StartTime, req.ClientID)
|
|
if err != nil {
|
|
s.logger.Error("License Server: signer failed: %v", err)
|
|
writeJSONError(w, http.StatusInternalServerError, "signing failed")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, signResponse{Signature: sig})
|
|
s.logger.Info("License Server: signed for sub=%s clientID=%s active=%d/%d ttl=%s",
|
|
claims.Subject, req.ClientID, active, claims.MaxDevices,
|
|
formatTTL(claims.ttlSinceNow()))
|
|
}
|
|
|
|
// heartbeatRequest is the JSON body of POST /license/heartbeat: the
// customer's own view of its active devices.
type heartbeatRequest struct {
	// ActiveDeviceCount is the number of devices the customer reports.
	ActiveDeviceCount int `json:"active_device_count"`
	// ActiveDeviceIDs lists the reported devices' client IDs.
	ActiveDeviceIDs []string `json:"active_device_ids"`
}
|
|
|
|
// heartbeatResponse is the JSON reply of POST /license/heartbeat.
type heartbeatResponse struct {
	// ServerViewCount is how many devices the server tracks for the caller.
	ServerViewCount int `json:"server_view_count"`
	// Drift is the customer's reported count minus ServerViewCount.
	Drift int `json:"drift"`
}
|
|
|
|
func (s *LicenseServer) handleHeartbeat(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
claims := s.resolveAuth(w, r)
|
|
if claims == nil {
|
|
return
|
|
}
|
|
|
|
var req heartbeatRequest
|
|
if err := readJSONLimited(w, r, maxHeartbeatBodyBytes, &req); err != nil {
|
|
writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err))
|
|
return
|
|
}
|
|
|
|
// REFRESH-ONLY: only bump activity timestamps for devices already in
|
|
// the customer's set (i.e. that came through /sign). New IDs reported
|
|
// here are silently ignored — otherwise a malicious customer could
|
|
// inflate their quota by POSTing fake IDs to /heartbeat.
|
|
refreshed := s.tracker.RefreshExisting(claims.Subject, req.ActiveDeviceIDs)
|
|
|
|
serverView := len(s.tracker.Snapshot(claims.Subject))
|
|
drift := req.ActiveDeviceCount - serverView
|
|
|
|
// Soft anti-tamper: large persistent drift means the customer reports
|
|
// devices we never minted signatures for. Log for operator review; no
|
|
// automatic revocation in v1.
|
|
if drift > claims.MaxDevices/2 && drift > 5 {
|
|
s.logger.Warn("License Server: heartbeat drift sub=%s reported=%d server=%d refreshed=%d drift=%d",
|
|
claims.Subject, req.ActiveDeviceCount, serverView, refreshed, drift)
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, heartbeatResponse{
|
|
ServerViewCount: serverView,
|
|
Drift: drift,
|
|
})
|
|
}
|
|
|
|
// Issue is a small in-process helper for operators to mint customer JWTs
|
|
// without spinning up a separate tool. It is intentionally unexported as
|
|
// an HTTP endpoint — JWT issuance is a one-off operator action, not a
|
|
// remote API. Use it from a cmd/issue-token CLI or interactively.
|
|
//
|
|
// Returns the signed RS256 token string.
|
|
// minTokenTTL guards against fat-finger ttl=0 / negative — those produce
|
|
// already-expired tokens that 401 immediately and confuse the customer.
|
|
const minTokenTTL = time.Hour
|
|
|
|
func Issue(privKey *rsa.PrivateKey, sub, tier string, maxDevices int, ttl time.Duration) (string, error) {
|
|
if privKey == nil {
|
|
return "", errors.New("nil private key")
|
|
}
|
|
if sub == "" {
|
|
return "", errors.New("sub (customer ID) is required")
|
|
}
|
|
if ttl < minTokenTTL {
|
|
return "", fmt.Errorf("ttl too short (%v); minimum is %v", ttl, minTokenTTL)
|
|
}
|
|
switch tier {
|
|
case TierTrial:
|
|
// Trial: 0 means "use the default 20" (kept consistent with VerifyJWT).
|
|
if maxDevices <= 0 {
|
|
maxDevices = TrialMaxDevices
|
|
}
|
|
case TierPaid:
|
|
// Paid: must be explicit. A 0 here is almost certainly a misconfig
|
|
// and would silently let one customer use unlimited devices.
|
|
if maxDevices <= 0 {
|
|
return "", errors.New("paid tier requires explicit max_devices > 0")
|
|
}
|
|
default:
|
|
return "", fmt.Errorf("unsupported tier: %q", tier)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &LicenseClaims{
|
|
Tier: tier,
|
|
MaxDevices: maxDevices,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Subject: sub,
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
|
},
|
|
}
|
|
tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
return tok.SignedString(privKey)
|
|
}
|
|
|
|
// ---- helpers ----
|
|
|
|
// Request body size caps.
//
// /sign carries a tiny fixed-schema body. /heartbeat carries a list of
// client IDs whose worst case is bounded by the customer's device cap (paid
// caps are operator-controlled). 64 KiB leaves room for hundreds of 20-byte
// IDs while still bounding the worst case.
const (
	// maxSignBodyBytes caps a /license/sign body at 8 KiB.
	maxSignBodyBytes int64 = 8 * 1024
	// maxHeartbeatBodyBytes caps a /license/heartbeat body at 64 KiB.
	maxHeartbeatBodyBytes int64 = 64 * 1024
)
|
|
|
|
// readJSONLimited wraps the body with http.MaxBytesReader so a misconfigured
|
|
// (or malicious) client cannot OOM the server with a multi-GB payload.
|
|
// DisallowUnknownFields catches schema drift cleanly.
|
|
func readJSONLimited(w http.ResponseWriter, r *http.Request, limit int64, dst any) error {
|
|
r.Body = http.MaxBytesReader(w, r.Body, limit)
|
|
dec := json.NewDecoder(r.Body)
|
|
dec.DisallowUnknownFields()
|
|
return dec.Decode(dst)
|
|
}
|
|
|
|
func writeJSON(w http.ResponseWriter, status int, v any) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
_ = json.NewEncoder(w).Encode(v)
|
|
}
|
|
|
|
func writeJSONError(w http.ResponseWriter, status int, msg string) {
|
|
writeJSON(w, status, signResponse{Error: msg})
|
|
}
|
|
|
|
func formatTTL(d time.Duration) string {
|
|
if d <= 0 {
|
|
return "expired"
|
|
}
|
|
days := int(d.Hours() / 24)
|
|
if days > 0 {
|
|
return fmt.Sprintf("%dd", days)
|
|
}
|
|
return d.Round(time.Minute).String()
|
|
}
|
|
|