package licensing

import (
	"crypto/rand"
	"crypto/rsa"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// testKey generates an ephemeral 2048-bit RSA key for JWT signing in tests.
// 2048 keeps the test fast (~50ms) but matches realistic security.
func testKey(t *testing.T) *rsa.PrivateKey {
	t.Helper()
	k, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("rsa.GenerateKey: %v", err)
	}
	return k
}

// silentLogger swallows logs in tests but still satisfies the Logger
// interface — keeps test output uncluttered.
type silentLogger struct{}

func (silentLogger) Info(string, ...any)  {}
func (silentLogger) Warn(string, ...any)  {}
func (silentLogger) Error(string, ...any) {}

// mustLocal wraps NewLocal for tests so the per-test boilerplate stays
// readable. Any test-only HMAC key must be >= 16 chars (see minMasterKeyLen).
func mustLocal(t *testing.T, key string) *LocalSigner {
	t.Helper()
	s, err := NewLocal(key)
	if err != nil {
		t.Fatalf("NewLocal: %v", err)
	}
	return s
}

// TestIssueVerifyRoundTrip is the smoke test: a token minted by Issue
// must verify under VerifyJWT with the matching public key, and the
// claims must round-trip intact.
func TestIssueVerifyRoundTrip(t *testing.T) {
	priv := testKey(t)
	// MaxDevices=0 on purpose: trial tokens may omit the cap and rely on
	// the default asserted at the end of this test.
	tok, err := Issue(priv, "customer-abc", TierTrial, 0, 30*24*time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}
	claims, err := VerifyJWT(tok, &priv.PublicKey)
	if err != nil {
		t.Fatalf("VerifyJWT: %v", err)
	}
	if claims.Subject != "customer-abc" {
		t.Errorf("sub = %q, want customer-abc", claims.Subject)
	}
	if claims.Tier != TierTrial {
		t.Errorf("tier = %q, want %q", claims.Tier, TierTrial)
	}
	if claims.MaxDevices != TrialMaxDevices {
		// Trial JWT minted with MaxDevices=0 should default to TrialMaxDevices
		// (VerifyJWT normalizes this — see token.go).
		t.Errorf("max_devices = %d, want %d (trial default)", claims.MaxDevices, TrialMaxDevices)
	}
}

// TestVerifyRejectsWrongKey makes sure a token signed by key A cannot
// be verified with key B's public half. A verifier that did not actually
// check the RSA signature against the supplied public key would fail
// open here.
func TestVerifyRejectsWrongKey(t *testing.T) {
	priv1 := testKey(t)
	priv2 := testKey(t)
	tok, err := Issue(priv1, "customer-x", TierPaid, 50, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}
	if _, err := VerifyJWT(tok, &priv2.PublicKey); err == nil {
		t.Fatal("VerifyJWT accepted token signed with a different key")
	}
}

// TestPaidRequiresMaxDevices: paid tier must carry an explicit device cap
// (trial gets a default instead). Issue must reject max_devices=0 for a
// paid token, so the misconfiguration is caught when the token is minted
// rather than after it has been handed to a customer.
func TestPaidRequiresMaxDevices(t *testing.T) {
	priv := testKey(t)
	_, err := Issue(priv, "customer-y", TierPaid, 0, time.Hour)
	if err == nil {
		t.Fatal("Issue accepted paid tier with max_devices=0")
	}
}

// TestNoOpSignerReturnsEmpty: free tier produces no signature so the
// client's private library trips and refuses high-tier features.
func TestNoOpSignerReturnsEmpty(t *testing.T) {
	s := NewNoOp()
	sig, err := s.Sign("2026-05-20", "12345")
	if err != nil {
		t.Fatalf("Sign: %v", err)
	}
	if sig != "" {
		t.Errorf("NoOpSigner.Sign = %q, want empty", sig)
	}
	if s.Mode() != "noop" {
		t.Errorf("Mode = %q, want noop", s.Mode())
	}
}

// TestLocalSignerDeterministic: HMAC is deterministic — same key + same
// input must always yield the same output. This is the property that
// makes RemoteSigner's cache correct.
func TestLocalSignerDeterministic(t *testing.T) { s, err := NewLocal("shared-secret-xyz-long-enough") if err != nil { t.Fatalf("NewLocal: %v", err) } a, _ := s.Sign("2026-05-20T10:00:00Z", "12345") b, _ := s.Sign("2026-05-20T10:00:00Z", "12345") if a != b { t.Errorf("non-deterministic: %q vs %q", a, b) } if len(a) != 64 { t.Errorf("signature length = %d, want 64 (hex of HMAC-SHA256)", len(a)) } } // TestRemoteSignerCacheHit verifies that the second call for the same // (startTime, clientID) tuple doesn't hit the network. We assert this // by counting requests at the fake License Server. func TestRemoteSignerCacheHit(t *testing.T) { priv := testKey(t) master := mustLocal(t, "real-hmac-key-for-test-xx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() tok, err := Issue(priv, "cust-cache", TierPaid, 10, time.Hour) if err != nil { t.Fatalf("Issue: %v", err) } // Count requests by wrapping the LS handler. var calls atomic.Int64 counting := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { calls.Add(1) ls.Handler().ServeHTTP(w, r) })) defer counting.Close() rs := NewRemote(counting.URL, tok, time.Hour, silentLogger{}) defer rs.Close() sig1, err := rs.Sign("st-1", "client-1") if err != nil { t.Fatalf("first Sign: %v", err) } sig2, err := rs.Sign("st-1", "client-1") // identical → cache hit if err != nil { t.Fatalf("second Sign: %v", err) } if sig1 != sig2 { t.Errorf("signatures differ across cache: %q vs %q", sig1, sig2) } if got := calls.Load(); got != 1 { t.Errorf("expected exactly 1 HTTP call (second served from cache), got %d", got) } } // TestRemoteSignerStaleFallback: when the License Server goes down, an // expired-cache entry is still better than zero signature (avoids breaking // reconnects for existing devices during a transient outage). 
func TestRemoteSignerStaleFallback(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-fallback-test-xxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) tok, err := Issue(priv, "cust-fallback", TierPaid, 5, time.Hour) if err != nil { t.Fatalf("Issue: %v", err) } // Grace = 1 ns so the next call considers the cache stale. rs := NewRemote(ts.URL, tok, time.Nanosecond, silentLogger{}) defer rs.Close() first, err := rs.Sign("st-fb", "client-fb") if err != nil { t.Fatalf("first Sign: %v", err) } // Take the server offline. ts.Close() second, err := rs.Sign("st-fb", "client-fb") if err != nil { t.Fatalf("post-outage Sign should fall back to stale cache, got err=%v", err) } if second != first { t.Errorf("stale fallback returned %q, want cached %q", second, first) } } // TestQuotaEnforcement: trial tier with 2 devices accepts the first two // but rejects the third with 403. func TestQuotaEnforcement(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-quota-test-xxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() // Trial JWT capped at 2 devices. 
tok, err := Issue(priv, "cust-quota", TierTrial, 2, time.Hour) if err != nil { t.Fatalf("Issue: %v", err) } call := func(clientID string) (int, string) { body, _ := json.Marshal(signRequest{ClientID: clientID, StartTime: "st"}) req, _ := http.NewRequest("POST", ts.URL+"/license/sign", strings.NewReader(string(body))) req.Header.Set("Authorization", "Bearer "+tok) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do: %v", err) } defer resp.Body.Close() var sr signResponse _ = json.NewDecoder(resp.Body).Decode(&sr) if sr.Signature != "" { return resp.StatusCode, sr.Signature } return resp.StatusCode, sr.Error } if code, _ := call("dev-1"); code != http.StatusOK { t.Errorf("dev-1 expected 200, got %d", code) } if code, _ := call("dev-2"); code != http.StatusOK { t.Errorf("dev-2 expected 200, got %d", code) } if code, msg := call("dev-3"); code != http.StatusForbidden { t.Errorf("dev-3 expected 403, got %d (%q)", code, msg) } } // TestAnonymousTrialSignsAndCaps: no Authorization header → anonymous trial // branch. /sign returns 200 with a real signature up to FreeMaxDevices, then // 403 once the per-IP cap is reached. Replaces the older "missing bearer → // 401" test now that anonymous trial is a first-class mode. 
func TestAnonymousTrialSignsAndCaps(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-trial-test-xxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() call := func(clientID string) (int, string) { body := strings.NewReader(fmt.Sprintf( `{"client_id":%q,"start_time":"2026-01-01T00:00:00Z"}`, clientID)) resp, err := http.Post(ts.URL+"/license/sign", "application/json", body) if err != nil { t.Fatalf("Post: %v", err) } defer resp.Body.Close() var sr signResponse _ = json.NewDecoder(resp.Body).Decode(&sr) if sr.Signature != "" { return resp.StatusCode, sr.Signature } return resp.StatusCode, sr.Error } // First FreeMaxDevices distinct clientIDs get real signatures. for i := range FreeMaxDevices { code, sig := call(fmt.Sprintf("trial-dev-%d", i)) if code != http.StatusOK { t.Errorf("dev-%d expected 200, got %d (%q)", i, code, sig) } if sig == "" { t.Errorf("dev-%d signature unexpectedly empty", i) } } // Cap+1 → 403 quota exceeded. code, msg := call("trial-dev-overflow") if code != http.StatusForbidden { t.Errorf("overflow expected 403, got %d (%q)", code, msg) } } // TestAnonymousTrialIPRateLimit: anonymous /sign is capped at // anonRatePerWindow requests per minute per source IP. Hitting the cap // returns 429 with Retry-After. func TestAnonymousTrialIPRateLimit(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-rate-test-xxxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() // Reuse the same clientID so quota does NOT also reject — we want to // isolate the rate limiter. quotaTracker.Reserve treats a repeat clientID // as a refresh (always accepted), so all the 200s here are the same slot. 
hit := func() int { body := strings.NewReader(`{"client_id":"rate-dev","start_time":"t"}`) resp, err := http.Post(ts.URL+"/license/sign", "application/json", body) if err != nil { t.Fatalf("Post: %v", err) } resp.Body.Close() return resp.StatusCode } for i := range anonRatePerWindow { if code := hit(); code != http.StatusOK { t.Fatalf("req %d expected 200, got %d", i, code) } } if code := hit(); code != http.StatusTooManyRequests { t.Errorf("expected 429 after %d requests, got %d", anonRatePerWindow, code) } } // TestAuthRejectsBadBearer: invalid JWT still returns 401 (we did NOT widen // the auth surface; only "no Authorization header at all" enters trial). func TestAuthRejectsBadBearer(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-bad-bearer-xxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() req, _ := http.NewRequest("POST", ts.URL+"/license/sign", strings.NewReader(`{"client_id":"x","start_time":"y"}`)) req.Header.Set("Authorization", "Bearer not.a.real.jwt") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusUnauthorized { t.Errorf("expected 401 for malformed bearer, got %d", resp.StatusCode) } } // TestRemoteSignerHardFailNoCacheReturnsError: when the LS is unreachable // AND we have no cached signature for this (startTime, clientID), Sign must // return ("", err) so the caller (sendMasterSetting) ships zeroed bytes. // Previously untested — easy to silently regress. func TestRemoteSignerHardFailNoCacheReturnsError(t *testing.T) { // Point RemoteSigner at a URL that will refuse all connections. // Using an unreachable host avoids depending on a free local port. 
rs := NewRemote("https://127.0.0.1:1", "any-token", time.Hour, silentLogger{}) defer rs.Close() sig, err := rs.Sign("st-x", "client-x") if err == nil { t.Fatal("expected error when LS unreachable and cache empty") } if sig != "" { t.Errorf("expected empty signature on hard failure, got %q", sig) } } // TestHeartbeatRefreshOnly: malicious customer POSTs fake clientIDs to // /license/heartbeat. The fake IDs MUST NOT show up in the server's view — // only IDs already minted via /sign get refreshed. This is the anti-tamper // property that makes the quota system actually enforce. func TestHeartbeatRefreshOnly(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-hb-test-xxxxxxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() tok, err := Issue(priv, "cust-hb", TierPaid, 5, time.Hour) if err != nil { t.Fatalf("Issue: %v", err) } // Reserve one legit device via /sign first. signBody, _ := json.Marshal(signRequest{ClientID: "legit-1", StartTime: "st"}) req, _ := http.NewRequest("POST", ts.URL+"/license/sign", strings.NewReader(string(signBody))) req.Header.Set("Authorization", "Bearer "+tok) req.Header.Set("Content-Type", "application/json") if resp, err := http.DefaultClient.Do(req); err != nil { t.Fatalf("Do sign: %v", err) } else { resp.Body.Close() } // Now heartbeat reports 1 legit ID + 99 fake IDs. fakes := make([]string, 99) for i := range fakes { fakes[i] = fmt.Sprintf("fake-%d", i) } all := append([]string{"legit-1"}, fakes...) 
hbBody, _ := json.Marshal(heartbeatRequest{ ActiveDeviceCount: len(all), ActiveDeviceIDs: all, }) req, _ = http.NewRequest("POST", ts.URL+"/license/heartbeat", strings.NewReader(string(hbBody))) req.Header.Set("Authorization", "Bearer "+tok) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do heartbeat: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Fatalf("heartbeat returned %d", resp.StatusCode) } var hbResp heartbeatResponse _ = json.NewDecoder(resp.Body).Decode(&hbResp) // Server should report exactly 1 active device (the one minted via /sign), // NOT 1 + 99. drift = 99. if hbResp.ServerViewCount != 1 { t.Errorf("server view = %d, want 1 (heartbeat must not insert fake IDs)", hbResp.ServerViewCount) } if hbResp.Drift != 99 { t.Errorf("drift = %d, want 99", hbResp.Drift) } // Verify the quota cap is still enforced: a fresh /sign for dev-2..dev-5 // should succeed (only 1 slot used), dev-6 should fail. tryReserve := func(cid string) int { body, _ := json.Marshal(signRequest{ClientID: cid, StartTime: "st"}) req, _ := http.NewRequest("POST", ts.URL+"/license/sign", strings.NewReader(string(body))) req.Header.Set("Authorization", "Bearer "+tok) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do: %v", err) } defer resp.Body.Close() return resp.StatusCode } for i := 2; i <= 5; i++ { if code := tryReserve(fmt.Sprintf("legit-%d", i)); code != http.StatusOK { t.Errorf("legit-%d expected 200, got %d", i, code) } } if code := tryReserve("legit-6"); code != http.StatusForbidden { t.Errorf("legit-6 expected 403 (max=5), got %d", code) } } // TestQuotaRejectionDoesNotConsumeSlot: a rejected /sign must not leave // its clientID in the quota map. Otherwise a denied 3rd device would // permanently take a slot from the legitimate 1st/2nd. 
func TestQuotaRejectionDoesNotConsumeSlot(t *testing.T) { priv := testKey(t) master := mustLocal(t, "master-no-leak-xxxxxxxxxxxx") ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}) ts := httptest.NewServer(ls.Handler()) defer ts.Close() tok, err := Issue(priv, "cust-leak", TierTrial, 2, time.Hour) if err != nil { t.Fatalf("Issue: %v", err) } doSign := func(cid string) int { body, _ := json.Marshal(signRequest{ClientID: cid, StartTime: "st"}) req, _ := http.NewRequest("POST", ts.URL+"/license/sign", strings.NewReader(string(body))) req.Header.Set("Authorization", "Bearer "+tok) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do: %v", err) } defer resp.Body.Close() return resp.StatusCode } // Fill to cap. if c := doSign("a"); c != http.StatusOK { t.Fatalf("a expected 200, got %d", c) } if c := doSign("b"); c != http.StatusOK { t.Fatalf("b expected 200, got %d", c) } // Over cap — must be denied AND must NOT consume a slot. if c := doSign("c"); c != http.StatusForbidden { t.Fatalf("c expected 403, got %d", c) } if c := doSign("d"); c != http.StatusForbidden { t.Fatalf("d expected 403, got %d", c) } // Existing device 'a' re-signs — must still succeed (idempotent refresh). if c := doSign("a"); c != http.StatusOK { t.Fatalf("a re-sign expected 200, got %d", c) } } // TestQuotaTrackerEviction: after evictAfter elapses, a previously-occupied // slot must be reclaimed so a new device can take it. Exercises the time // path that TestQuotaEnforcement skips (it uses a long eviction window). 
func TestQuotaTrackerEviction(t *testing.T) { q := newQuotaTracker(50 * time.Millisecond) if count, ok := q.Reserve("cust", "dev-1", 1); !ok || count != 1 { t.Fatalf("first Reserve: count=%d ok=%v", count, ok) } if count, ok := q.Reserve("cust", "dev-2", 1); ok { t.Fatalf("expected over-cap rejection, got count=%d ok=%v", count, ok) } time.Sleep(80 * time.Millisecond) // dev-1's entry should now be stale; dev-2 should be admitted. if count, ok := q.Reserve("cust", "dev-2", 1); !ok || count != 1 { t.Fatalf("post-eviction Reserve: count=%d ok=%v", count, ok) } } // TestValidateRemoteURL: factory must reject http:// for non-loopback // targets so JWT/sigs don't leak in cleartext. func TestValidateRemoteURL(t *testing.T) { cases := []struct { url string wantError bool }{ {"https://license.example.com", false}, {"https://localhost:8443", false}, {"http://localhost:8443", false}, // loopback exception {"http://127.0.0.1:8443", false}, // loopback exception {"http://license.example.com", true}, // public http → reject {"ftp://license.example.com", true}, // bad scheme {"not a url at all", true}, } for _, c := range cases { err := ValidateRemoteURL(c.url) if (err != nil) != c.wantError { t.Errorf("ValidateRemoteURL(%q): err=%v, wantError=%v", c.url, err, c.wantError) } } } // TestIssueRejectsShortTTL: catches fat-finger ttl=0 / negative that mints // an already-expired token. func TestIssueRejectsShortTTL(t *testing.T) { priv := testKey(t) if _, err := Issue(priv, "cust", TierPaid, 10, 0); err == nil { t.Error("expected error for ttl=0") } if _, err := Issue(priv, "cust", TierPaid, 10, time.Minute); err == nil { t.Error("expected error for ttl below minimum") } if _, err := Issue(priv, "", TierPaid, 10, time.Hour); err == nil { t.Error("expected error for empty sub") } } // TestNewLocalRejectsShortKey: catches misconfigured YAMA_SIGN_PASSWORD // (empty / typo) at construction instead of silently signing with junk. 
func TestNewLocalRejectsShortKey(t *testing.T) { if _, err := NewLocal(""); err == nil { t.Error("expected error for empty master key") } if _, err := NewLocal("short"); err == nil { t.Error("expected error for too-short master key") } } // TestJWTAlgLockedToRS256: a token with any non-RS256 alg (here RS384) must // fail verification, even though the underlying RSA primitive is the same. // This pins the docs↔code contract. func TestJWTAlgLockedToRS256(t *testing.T) { priv := testKey(t) now := time.Now() claims := &LicenseClaims{ Tier: TierTrial, MaxDevices: 10, RegisteredClaims: jwt.RegisteredClaims{ Subject: "cust", IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), }, } tok := jwt.NewWithClaims(jwt.SigningMethodRS384, claims) signed, err := tok.SignedString(priv) if err != nil { t.Fatalf("sign: %v", err) } if _, err := VerifyJWT(signed, &priv.PublicKey); err == nil { t.Error("VerifyJWT accepted RS384; alg should be locked to RS256") } }