package licensing import ( "crypto/rsa" "encoding/json" "errors" "fmt" "net" "net/http" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" ) // Anonymous-trial rate limit: per-source-IP cap on /license/sign + // /license/heartbeat requests without a Bearer token. Picked at "high enough // for any legitimate single deployment, low enough to make brute-force / // signature-probing pointless." Each /sign costs 1, /heartbeat 1, in a 60s // sliding window. Authenticated requests skip this. const ( anonRatePerWindow = 10 anonRateWindow = time.Minute // anonReapInterval throws away stale buckets so the map doesn't grow // unbounded across IP-cycling attackers. Walk the map every N requests. anonReapEvery = 200 ) // Log-throttle cooldowns: downstream devices reconnect every few seconds when // over-quota, so without throttling these two Warn lines flood the operator's // log file. One log entry per cooldown window per unique key is enough signal. const ( quotaWarnCooldown = 5 * time.Minute // per (sub, clientID) pair rlWarnCooldown = anonRateWindow // per IP, matches the rate-limit window ) // LicenseServer is the HTTP service the operator's LocalSigner exposes for // RemoteSigner customer deployments. It uses the same LocalSigner instance // (same HMAC master key) to produce signatures so customers can issue // device logins without ever holding the master key themselves. // // Endpoints: // // POST /license/sign // Body: {"client_id": "...", "start_time": "..."} // Auth: Authorization: Bearer // Reply: {"signature": "<64-hex>"} (200) // or {"error": "..."} (4xx/5xx) // Enforces quota: claims.max_devices for the JWT's "sub". // // POST /license/heartbeat // Body: {"active_device_count": N, "active_device_ids": ["...","..."]} // Auth: Authorization: Bearer // Reply: {"server_view_count": M, "drift": N-M} // Used by the customer's Go server to surface its view; cross-validated // against what /sign has actually been asked for under that customer. // Large drift is logged for anti-tamper review; no automatic revocation // in v1 — operator decides. // // Security: serve plain HTTP and put nginx / Caddy in front for TLS. JWT // "alg" is locked to RS256 in token.go; "alg":"none" tampering is blocked. // // Anonymous trial: requests without a Bearer token are treated as anonymous // trial — the client's source IP is used as "sub" (key=`trial:`) and // MaxDevices is capped at FreeMaxDevices. This is what lets a zero-config // downstream binary "just work" for evaluation. Heavily rate-limited per IP // to make brute-force / signature-probing pointless. type LicenseServer struct { signer *LocalSigner pubKey *rsa.PublicKey tracker *quotaTracker logger Logger mux *http.ServeMux trustProxy bool // honor X-Forwarded-For / X-Real-IP — set only behind a trusted reverse proxy anonMu sync.Mutex anonBuckets map[string]*anonBucket // ip → bucket anonReqSeen int // counter for periodic reap warnMu sync.Mutex lastWarn map[string]time.Time // dedup key → last log time } // anonBucket tracks anonymous request count within a sliding window. type anonBucket struct { count int windowStart time.Time } // Logger is the minimal logging interface we need. The cmd package's // *logger.Logger satisfies this with its existing Info/Warn methods. type Logger interface { Info(format string, args ...any) Warn(format string, args ...any) Error(format string, args ...any) } // NewLicenseServer builds the HTTP handler set. evictAfter is how long a // quiet device keeps its slot before its quota is reclaimed (recommend // 5 min — twice a typical heartbeat interval). func NewLicenseServer(signer *LocalSigner, pubKey *rsa.PublicKey, evictAfter time.Duration, lg Logger) *LicenseServer { s := &LicenseServer{ signer: signer, pubKey: pubKey, tracker: newQuotaTracker(evictAfter), logger: lg, mux: http.NewServeMux(), anonBuckets: make(map[string]*anonBucket), lastWarn: make(map[string]time.Time), } s.mux.HandleFunc("/license/sign", s.handleSign) s.mux.HandleFunc("/license/heartbeat", s.handleHeartbeat) return s } // warnOnce emits a Warn log at most once per cooldown window for the given // dedup key. Subsequent identical events within the window are silently // dropped. This keeps high-frequency but expected conditions (quota exceeded, // rate limit hit) from flooding the operator's log file while still providing // one clear signal per event burst. func (s *LicenseServer) warnOnce(key string, cooldown time.Duration, format string, args ...any) { s.warnMu.Lock() if t, ok := s.lastWarn[key]; ok && time.Since(t) < cooldown { s.warnMu.Unlock() return } s.lastWarn[key] = time.Now() s.warnMu.Unlock() s.logger.Warn(format, args...) } // SetTrustProxy switches IP extraction to X-Forwarded-For / X-Real-IP for // the anonymous-trial branch. Only set this when running behind a reverse // proxy you control (nginx / caddy / cloudflare); direct-exposure // deployments MUST leave it false or attackers can spoof the header to // evade the per-IP rate limit and the trial quota. func (s *LicenseServer) SetTrustProxy(trust bool) { s.trustProxy = trust } // Handler returns the http.Handler the operator wires into their HTTP // server (or runs standalone via http.ListenAndServe). func (s *LicenseServer) Handler() http.Handler { return s.mux } // authenticate extracts and verifies the bearer JWT. Returns the parsed // claims on success; writes the appropriate HTTP error and returns nil // on failure so the caller can simply `return` on a nil result. func (s *LicenseServer) authenticate(w http.ResponseWriter, r *http.Request) *LicenseClaims { authHdr := r.Header.Get("Authorization") if !strings.HasPrefix(authHdr, "Bearer ") { writeJSONError(w, http.StatusUnauthorized, "missing Bearer token") return nil } tokenStr := strings.TrimPrefix(authHdr, "Bearer ") claims, err := VerifyJWT(tokenStr, s.pubKey) if err != nil { writeJSONError(w, http.StatusUnauthorized, fmt.Sprintf("invalid token: %v", err)) return nil } return claims } // resolveAuth decides whether the request is paid (Bearer JWT) or anonymous // trial (no Authorization header). Returns a LicenseClaims structure either // way: // - Paid: claims from JWT, untouched. // - Trial: synthesized claims with Subject="trial:", Tier=TierFree, // MaxDevices=FreeMaxDevices, no Bearer required. // // Anonymous requests are rate-limited per source IP; if the IP's bucket is // full we write 429 and return nil. Bad JWTs still 401 as before. func (s *LicenseServer) resolveAuth(w http.ResponseWriter, r *http.Request) *LicenseClaims { if r.Header.Get("Authorization") != "" { return s.authenticate(w, r) } // Anonymous trial branch. ip := s.clientIP(r) if !s.allowAnon(ip) { s.warnOnce("rl:"+ip, rlWarnCooldown, "License Server: anonymous rate limit hit for ip=%s", ip) w.Header().Set("Retry-After", "60") writeJSONError(w, http.StatusTooManyRequests, "trial rate limit exceeded; set YAMA_LICENSE_TOKEN for full license") return nil } return &LicenseClaims{ Tier: TierFree, MaxDevices: FreeMaxDevices, RegisteredClaims: jwt.RegisteredClaims{ Subject: "trial:" + ip, }, } } // clientIP returns the request's source IP. When trustProxy is set, prefer // X-Real-IP, then the last entry of X-Forwarded-For. Otherwise fall back to // r.RemoteAddr (host part only). func (s *LicenseServer) clientIP(r *http.Request) string { if s.trustProxy { if v := strings.TrimSpace(r.Header.Get("X-Real-IP")); v != "" { return v } if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Last entry is the one appended by the proxy closest to us. parts := strings.Split(xff, ",") last := strings.TrimSpace(parts[len(parts)-1]) if last != "" { return last } } } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } // allowAnon enforces the anonymous-trial per-IP rate limit. Returns true // when the call is admitted, false when the IP's bucket is full. Also reaps // stale buckets opportunistically. func (s *LicenseServer) allowAnon(ip string) bool { s.anonMu.Lock() defer s.anonMu.Unlock() now := time.Now() b, ok := s.anonBuckets[ip] if !ok || now.Sub(b.windowStart) >= anonRateWindow { s.anonBuckets[ip] = &anonBucket{count: 1, windowStart: now} s.maybeReapAnonLocked(now) return true } if b.count >= anonRatePerWindow { return false } b.count++ return true } // maybeReapAnonLocked drops buckets whose windows are stale every N requests. // Caller must hold s.anonMu. func (s *LicenseServer) maybeReapAnonLocked(now time.Time) { s.anonReqSeen++ if s.anonReqSeen < anonReapEvery { return } s.anonReqSeen = 0 cutoff := now.Add(-anonRateWindow) for ip, b := range s.anonBuckets { if b.windowStart.Before(cutoff) { delete(s.anonBuckets, ip) } } } func (s *LicenseServer) handleSign(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") return } claims := s.resolveAuth(w, r) if claims == nil { return } var req signRequest if err := readJSONLimited(w, r, maxSignBodyBytes, &req); err != nil { writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err)) return } if req.ClientID == "" || req.StartTime == "" { writeJSONError(w, http.StatusBadRequest, "client_id and start_time required") return } // Atomically check + reserve the slot. A rejected request does NOT // consume a slot — see quotaTracker.Reserve. active, accepted := s.tracker.Reserve(claims.Subject, req.ClientID, claims.MaxDevices) if !accepted { s.warnOnce("quota:"+claims.Subject+":"+req.ClientID, quotaWarnCooldown, "License Server: quota exceeded for sub=%s tier=%s active=%d max=%d clientID=%s", claims.Subject, claims.Tier, active, claims.MaxDevices, req.ClientID) writeJSONError(w, http.StatusForbidden, fmt.Sprintf("quota exceeded: %d/%d devices in use", active, claims.MaxDevices)) return } // Mint the signature using the local HMAC master key. sig, err := s.signer.Sign(req.StartTime, req.ClientID) if err != nil { s.logger.Error("License Server: signer failed: %v", err) writeJSONError(w, http.StatusInternalServerError, "signing failed") return } writeJSON(w, http.StatusOK, signResponse{Signature: sig}) s.logger.Info("License Server: signed for sub=%s clientID=%s active=%d/%d ttl=%s", claims.Subject, req.ClientID, active, claims.MaxDevices, formatTTL(claims.ttlSinceNow())) } type heartbeatRequest struct { ActiveDeviceCount int `json:"active_device_count"` ActiveDeviceIDs []string `json:"active_device_ids"` } type heartbeatResponse struct { ServerViewCount int `json:"server_view_count"` Drift int `json:"drift"` // customer count - server view count } func (s *LicenseServer) handleHeartbeat(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed") return } claims := s.resolveAuth(w, r) if claims == nil { return } var req heartbeatRequest if err := readJSONLimited(w, r, maxHeartbeatBodyBytes, &req); err != nil { writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("bad request body: %v", err)) return } // REFRESH-ONLY: only bump activity timestamps for devices already in // the customer's set (i.e. that came through /sign). New IDs reported // here are silently ignored — otherwise a malicious customer could // inflate their quota by POSTing fake IDs to /heartbeat. refreshed := s.tracker.RefreshExisting(claims.Subject, req.ActiveDeviceIDs) serverView := len(s.tracker.Snapshot(claims.Subject)) drift := req.ActiveDeviceCount - serverView // Soft anti-tamper: large persistent drift means the customer reports // devices we never minted signatures for. Log for operator review; no // automatic revocation in v1. if drift > claims.MaxDevices/2 && drift > 5 { s.logger.Warn("License Server: heartbeat drift sub=%s reported=%d server=%d refreshed=%d drift=%d", claims.Subject, req.ActiveDeviceCount, serverView, refreshed, drift) } writeJSON(w, http.StatusOK, heartbeatResponse{ ServerViewCount: serverView, Drift: drift, }) } // Issue is a small in-process helper for operators to mint customer JWTs // without spinning up a separate tool. It is intentionally unexported as // an HTTP endpoint — JWT issuance is a one-off operator action, not a // remote API. Use it from a cmd/issue-token CLI or interactively. // // Returns the signed RS256 token string. // minTokenTTL guards against fat-finger ttl=0 / negative — those produce // already-expired tokens that 401 immediately and confuse the customer. const minTokenTTL = time.Hour func Issue(privKey *rsa.PrivateKey, sub, tier string, maxDevices int, ttl time.Duration) (string, error) { if privKey == nil { return "", errors.New("nil private key") } if sub == "" { return "", errors.New("sub (customer ID) is required") } if ttl < minTokenTTL { return "", fmt.Errorf("ttl too short (%v); minimum is %v", ttl, minTokenTTL) } switch tier { case TierTrial: // Trial: 0 means "use the default 20" (kept consistent with VerifyJWT). if maxDevices <= 0 { maxDevices = TrialMaxDevices } case TierPaid: // Paid: must be explicit. A 0 here is almost certainly a misconfig // and would silently let one customer use unlimited devices. if maxDevices <= 0 { return "", errors.New("paid tier requires explicit max_devices > 0") } default: return "", fmt.Errorf("unsupported tier: %q", tier) } now := time.Now() claims := &LicenseClaims{ Tier: tier, MaxDevices: maxDevices, RegisteredClaims: jwt.RegisteredClaims{ Subject: sub, IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), }, } tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) return tok.SignedString(privKey) } // ---- helpers ---- // Body size caps. /sign sends a tiny fixed-schema body; /heartbeat carries // a list of clientIDs whose worst case is bounded by the customer's max // device cap (paid is operator-controlled). 64 KiB is roomy for hundreds // of 20-byte IDs while still bounding the worst case. const ( maxSignBodyBytes int64 = 8 << 10 // 8 KiB maxHeartbeatBodyBytes int64 = 64 << 10 // 64 KiB ) // readJSONLimited wraps the body with http.MaxBytesReader so a misconfigured // (or malicious) client cannot OOM the server with a multi-GB payload. // DisallowUnknownFields catches schema drift cleanly. func readJSONLimited(w http.ResponseWriter, r *http.Request, limit int64, dst any) error { r.Body = http.MaxBytesReader(w, r.Body, limit) dec := json.NewDecoder(r.Body) dec.DisallowUnknownFields() return dec.Decode(dst) } func writeJSON(w http.ResponseWriter, status int, v any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(v) } func writeJSONError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, signResponse{Error: msg}) } func formatTTL(d time.Duration) string { if d <= 0 { return "expired" } days := int(d.Hours() / 24) if days > 0 { return fmt.Sprintf("%dd", days) } return d.Round(time.Minute).String() }