Files
SimpleRemoter/server/go/licensing/licensing_test.go

521 lines
17 KiB
Go

package licensing
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// testKey returns a freshly generated 2048-bit RSA private key that tests
// use to sign and verify license JWTs. A generation failure aborts the
// calling test immediately. 2048 bits keeps generation quick while still
// matching a realistic production key size.
func testKey(t *testing.T) *rsa.PrivateKey {
	t.Helper()
	key, genErr := rsa.GenerateKey(rand.Reader, 2048)
	if genErr == nil {
		return key
	}
	t.Fatalf("rsa.GenerateKey: %v", genErr)
	return nil // unreachable: Fatalf stops the test goroutine
}
// silentLogger is a do-nothing logger handed to the license server and
// remote signer in tests. Each method discards its arguments, so test output
// stays free of log noise while the type still satisfies the Logger interface.
type silentLogger struct{}

// Info discards an informational message.
func (silentLogger) Info(_ string, _ ...any) {}

// Warn discards a warning message.
func (silentLogger) Warn(_ string, _ ...any) {}

// Error discards an error message.
func (silentLogger) Error(_ string, _ ...any) {}
// mustLocal builds a LocalSigner through NewLocal and fails the calling
// test if construction errors, so per-test setup stays a single line.
// Test keys must meet NewLocal's minimum length (>= 16 chars, see
// minMasterKeyLen).
func mustLocal(t *testing.T, key string) *LocalSigner {
	t.Helper()
	signer, err := NewLocal(key)
	if err == nil {
		return signer
	}
	t.Fatalf("NewLocal: %v", err)
	return nil // unreachable: Fatalf stops the test goroutine
}
// TestIssueVerifyRoundTrip is the basic smoke test: a token produced by
// Issue must pass VerifyJWT under the matching public key, and subject,
// tier and device cap must survive the round trip.
func TestIssueVerifyRoundTrip(t *testing.T) {
	key := testKey(t)
	const ttl = 30 * 24 * time.Hour

	signed, err := Issue(key, "customer-abc", TierTrial, 0, ttl)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}
	got, err := VerifyJWT(signed, &key.PublicKey)
	if err != nil {
		t.Fatalf("VerifyJWT: %v", err)
	}

	if got.Subject != "customer-abc" {
		t.Errorf("sub = %q, want customer-abc", got.Subject)
	}
	if got.Tier != TierTrial {
		t.Errorf("tier = %q, want %q", got.Tier, TierTrial)
	}
	// The trial token was minted with MaxDevices=0. VerifyJWT is expected to
	// replace that with TrialMaxDevices (see token.go).
	if got.MaxDevices != TrialMaxDevices {
		t.Errorf("max_devices = %d, want %d (trial default)", got.MaxDevices, TrialMaxDevices)
	}
}
// TestVerifyRejectsWrongKey: a token signed with one RSA key must not
// verify under an unrelated key's public half. If it did, anyone could mint
// licenses, or keys would effectively be interchangeable.
func TestVerifyRejectsWrongKey(t *testing.T) {
	signer, other := testKey(t), testKey(t)

	tok, err := Issue(signer, "customer-x", TierPaid, 50, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}

	_, err = VerifyJWT(tok, &other.PublicKey)
	if err == nil {
		t.Fatal("VerifyJWT accepted token signed with a different key")
	}
}
// TestPaidRequiresMaxDevices: paid tier must carry an explicit cap; trial
// gets a default. Catches misconfigured tokens at verify time.
func TestPaidRequiresMaxDevices(t *testing.T) {
priv := testKey(t)
_, err := Issue(priv, "customer-y", TierPaid, 0, time.Hour)
if err == nil {
t.Fatal("Issue accepted paid tier with max_devices=0")
}
}
// TestNoOpSignerReturnsEmpty: the free-tier signer must succeed but produce
// an empty signature, and report its mode as "noop". The empty signature is
// what makes the client refuse high-tier features.
func TestNoOpSignerReturnsEmpty(t *testing.T) {
	noop := NewNoOp()

	sig, err := noop.Sign("2026-05-20", "12345")
	switch {
	case err != nil:
		t.Fatalf("Sign: %v", err)
	case sig != "":
		t.Errorf("NoOpSigner.Sign = %q, want empty", sig)
	}

	if mode := noop.Mode(); mode != "noop" {
		t.Errorf("Mode = %q, want noop", mode)
	}
}
// TestLocalSignerDeterministic: HMAC is deterministic — same key + same
// input must always yield the same output. This is the property that
// makes RemoteSigner's cache correct.
func TestLocalSignerDeterministic(t *testing.T) {
s, err := NewLocal("shared-secret-xyz-long-enough")
if err != nil {
t.Fatalf("NewLocal: %v", err)
}
a, _ := s.Sign("2026-05-20T10:00:00Z", "12345")
b, _ := s.Sign("2026-05-20T10:00:00Z", "12345")
if a != b {
t.Errorf("non-deterministic: %q vs %q", a, b)
}
if len(a) != 64 {
t.Errorf("signature length = %d, want 64 (hex of HMAC-SHA256)", len(a))
}
}
// TestRemoteSignerCacheHit verifies that a second Sign for the same
// (startTime, clientID) tuple is served from RemoteSigner's cache rather
// than the network. The fake License Server counts incoming requests, and
// the test expects exactly one.
func TestRemoteSignerCacheHit(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "real-hmac-key-for-test-xx")
	ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{})
	tok, err := Issue(priv, "cust-cache", TierPaid, 10, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}

	// Front the License Server with a request counter. The handler is built
	// once and shared across requests instead of being rebuilt per call.
	lsHandler := ls.Handler()
	var calls atomic.Int64
	counting := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls.Add(1)
		lsHandler.ServeHTTP(w, r)
	}))
	defer counting.Close()

	rs := NewRemote(counting.URL, tok, time.Hour, silentLogger{})
	defer rs.Close()

	sig1, err := rs.Sign("st-1", "client-1")
	if err != nil {
		t.Fatalf("first Sign: %v", err)
	}
	sig2, err := rs.Sign("st-1", "client-1") // identical → cache hit
	if err != nil {
		t.Fatalf("second Sign: %v", err)
	}
	if sig1 != sig2 {
		t.Errorf("signatures differ across cache: %q vs %q", sig1, sig2)
	}
	if got := calls.Load(); got != 1 {
		t.Errorf("expected exactly 1 HTTP call (second served from cache), got %d", got)
	}
}
// TestRemoteSignerStaleFallback: when the License Server goes down, an
// expired cache entry is still better than no signature at all. It keeps
// reconnects working for existing devices during a transient outage.
func TestRemoteSignerStaleFallback(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "master-fallback-test-xxx")
	ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{})
	ts := httptest.NewServer(ls.Handler())
	// The server is closed explicitly mid-test to simulate the outage.
	// httptest.Server.Close is idempotent, so this deferred Close only has
	// an effect when an earlier t.Fatalf aborts the test before that point.
	defer ts.Close()

	tok, err := Issue(priv, "cust-fallback", TierPaid, 5, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}
	// Grace = 1 ns so the next call considers the cache stale.
	rs := NewRemote(ts.URL, tok, time.Nanosecond, silentLogger{})
	defer rs.Close()

	first, err := rs.Sign("st-fb", "client-fb")
	if err != nil {
		t.Fatalf("first Sign: %v", err)
	}

	// Take the server offline.
	ts.Close()

	second, err := rs.Sign("st-fb", "client-fb")
	if err != nil {
		t.Fatalf("post-outage Sign should fall back to stale cache, got err=%v", err)
	}
	if second != first {
		t.Errorf("stale fallback returned %q, want cached %q", second, first)
	}
}
// TestQuotaEnforcement: a trial token capped at 2 devices accepts the first
// two devices, each with a non-empty signature, and rejects the third with
// 403.
func TestQuotaEnforcement(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "master-quota-test-xxxxxx")
	ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{})
	ts := httptest.NewServer(ls.Handler())
	defer ts.Close()

	// Trial JWT capped at 2 devices.
	tok, err := Issue(priv, "cust-quota", TierTrial, 2, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}

	// call POSTs an authenticated /license/sign request for clientID and
	// returns the HTTP status along with the decoded response body.
	call := func(clientID string) (int, signResponse) {
		t.Helper()
		// Marshal cannot fail for a struct of plain string fields.
		body, _ := json.Marshal(signRequest{ClientID: clientID, StartTime: "st"})
		req, err := http.NewRequest(http.MethodPost, ts.URL+"/license/sign", strings.NewReader(string(body)))
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		req.Header.Set("Authorization", "Bearer "+tok)
		req.Header.Set("Content-Type", "application/json")
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			t.Fatalf("Do: %v", err)
		}
		defer resp.Body.Close()
		var sr signResponse
		// Best effort: a body that fails to decode leaves sr zeroed, and the
		// status/signature checks below still report the failure.
		_ = json.NewDecoder(resp.Body).Decode(&sr)
		return resp.StatusCode, sr
	}

	for _, dev := range []string{"dev-1", "dev-2"} {
		code, sr := call(dev)
		if code != http.StatusOK {
			t.Errorf("%s expected 200, got %d (%q)", dev, code, sr.Error)
			continue
		}
		if sr.Signature == "" {
			t.Errorf("%s: 200 response carried an empty signature", dev)
		}
	}
	if code, sr := call("dev-3"); code != http.StatusForbidden {
		t.Errorf("dev-3 expected 403, got %d (%q)", code, sr.Error)
	}
}
// TestAuthRejectsMissingBearer: a request without a bearer token must get
// 401, not 200 or 500. This is belt and braces: the auth check sits in front
// of both /sign and /heartbeat, so both routes are exercised here.
func TestAuthRejectsMissingBearer(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "master-auth-test-xxxxxxx")
	ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{})
	ts := httptest.NewServer(ls.Handler())
	defer ts.Close()

	cases := []struct {
		name string
		path string
		body string
	}{
		{"sign", "/license/sign", `{"client_id":"x","start_time":"y"}`},
		// The body is irrelevant: auth must reject before the payload is used.
		{"heartbeat", "/license/heartbeat", `{}`},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			resp, err := http.Post(ts.URL+c.path, "application/json", strings.NewReader(c.body))
			if err != nil {
				t.Fatalf("Post: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusUnauthorized {
				t.Errorf("expected 401, got %d", resp.StatusCode)
			}
		})
	}
}
// TestRemoteSignerHardFailNoCacheReturnsError: when the LS is unreachable
// AND there is no cached signature for this (startTime, clientID), Sign must
// return ("", err), so that the caller (sendMasterSetting) ships zeroed
// bytes. Previously untested and easy to silently regress.
func TestRemoteSignerHardFailNoCacheReturnsError(t *testing.T) {
	// Start a throwaway server only to obtain a loopback address, then shut
	// it down so connections to that address are refused. This avoids
	// relying on a hard-coded port (e.g. 127.0.0.1:1) happening to be closed.
	dead := httptest.NewServer(http.NotFoundHandler())
	deadURL := dead.URL
	dead.Close()

	rs := NewRemote(deadURL, "any-token", time.Hour, silentLogger{})
	defer rs.Close()

	sig, err := rs.Sign("st-x", "client-x")
	if err == nil {
		t.Fatal("expected error when LS unreachable and cache empty")
	}
	if sig != "" {
		t.Errorf("expected empty signature on hard failure, got %q", sig)
	}
}
// TestHeartbeatRefreshOnly: a malicious customer POSTs fake clientIDs to
// /license/heartbeat. The fake IDs MUST NOT show up in the server's view;
// only IDs already reserved via /sign get refreshed. This anti-tamper
// property is what makes quota enforcement meaningful.
func TestHeartbeatRefreshOnly(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "master-hb-test-xxxxxxxxxx")
	ls := NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{})
	ts := httptest.NewServer(ls.Handler())
	defer ts.Close()
	tok, err := Issue(priv, "cust-hb", TierPaid, 5, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}

	// post sends an authenticated JSON POST of payload to path and returns
	// the response. The caller must close the body.
	post := func(path string, payload any) *http.Response {
		t.Helper()
		body, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("Marshal %s payload: %v", path, err)
		}
		req, err := http.NewRequest(http.MethodPost, ts.URL+path, strings.NewReader(string(body)))
		if err != nil {
			t.Fatalf("NewRequest %s: %v", path, err)
		}
		req.Header.Set("Authorization", "Bearer "+tok)
		req.Header.Set("Content-Type", "application/json")
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			t.Fatalf("Do %s: %v", path, err)
		}
		return resp
	}
	// tryReserve asks /license/sign for a slot for cid and returns the status.
	tryReserve := func(cid string) int {
		t.Helper()
		resp := post("/license/sign", signRequest{ClientID: cid, StartTime: "st"})
		defer resp.Body.Close()
		return resp.StatusCode
	}

	// Reserve one legit device via /sign first. If this fails, every count
	// asserted below would be misleading, so stop immediately.
	if code := tryReserve("legit-1"); code != http.StatusOK {
		t.Fatalf("legit-1 sign expected 200, got %d", code)
	}

	// Now heartbeat reports 1 legit ID + 99 fake IDs.
	fakes := make([]string, 99)
	for i := range fakes {
		fakes[i] = fmt.Sprintf("fake-%d", i)
	}
	all := append([]string{"legit-1"}, fakes...)
	resp := post("/license/heartbeat", heartbeatRequest{
		ActiveDeviceCount: len(all),
		ActiveDeviceIDs:   all,
	})
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("heartbeat returned %d", resp.StatusCode)
	}
	var hbResp heartbeatResponse
	if err := json.NewDecoder(resp.Body).Decode(&hbResp); err != nil {
		t.Fatalf("decode heartbeat response: %v", err)
	}

	// Server should report exactly 1 active device (the one reserved via
	// /sign), NOT 1 + 99. drift = 99.
	if hbResp.ServerViewCount != 1 {
		t.Errorf("server view = %d, want 1 (heartbeat must not insert fake IDs)", hbResp.ServerViewCount)
	}
	if hbResp.Drift != 99 {
		t.Errorf("drift = %d, want 99", hbResp.Drift)
	}

	// The quota cap must still be enforced: fresh /sign calls for
	// legit-2..legit-5 should succeed (only 1 of 5 slots used), and legit-6
	// should be rejected.
	for i := 2; i <= 5; i++ {
		if code := tryReserve(fmt.Sprintf("legit-%d", i)); code != http.StatusOK {
			t.Errorf("legit-%d expected 200, got %d", i, code)
		}
	}
	if code := tryReserve("legit-6"); code != http.StatusForbidden {
		t.Errorf("legit-6 expected 403 (max=5), got %d", code)
	}
}
// TestQuotaRejectionDoesNotConsumeSlot: a /sign call rejected for quota must
// leave no trace of its clientID in the quota map. Otherwise a denied third
// device would permanently steal a slot from the legitimate first and
// second devices.
func TestQuotaRejectionDoesNotConsumeSlot(t *testing.T) {
	priv := testKey(t)
	master := mustLocal(t, "master-no-leak-xxxxxxxxxxxx")
	srv := httptest.NewServer(NewLicenseServer(master, &priv.PublicKey, time.Minute, silentLogger{}).Handler())
	defer srv.Close()
	tok, err := Issue(priv, "cust-leak", TierTrial, 2, time.Hour)
	if err != nil {
		t.Fatalf("Issue: %v", err)
	}

	// signStatus requests a signature for cid and reports the HTTP status.
	signStatus := func(cid string) int {
		payload, _ := json.Marshal(signRequest{ClientID: cid, StartTime: "st"})
		req, _ := http.NewRequest("POST", srv.URL+"/license/sign", strings.NewReader(string(payload)))
		req.Header.Set("Authorization", "Bearer "+tok)
		req.Header.Set("Content-Type", "application/json")
		resp, doErr := http.DefaultClient.Do(req)
		if doErr != nil {
			t.Fatalf("Do: %v", doErr)
		}
		defer resp.Body.Close()
		return resp.StatusCode
	}

	steps := []struct {
		cid   string
		label string
		want  int
	}{
		// Fill the trial cap of 2.
		{"a", "a", http.StatusOK},
		{"b", "b", http.StatusOK},
		// Over cap: must be denied AND must not consume a slot.
		{"c", "c", http.StatusForbidden},
		{"d", "d", http.StatusForbidden},
		// Existing device re-signs: must still succeed (idempotent refresh).
		{"a", "a re-sign", http.StatusOK},
	}
	for _, step := range steps {
		if got := signStatus(step.cid); got != step.want {
			t.Fatalf("%s expected %d, got %d", step.label, step.want, got)
		}
	}
}
// TestQuotaTrackerEviction: once the eviction window passes, a previously
// occupied slot must be reclaimed so a new device can take it. This covers
// the time-based path that the HTTP-level quota tests never reach, because
// they run with a one-minute window.
func TestQuotaTrackerEviction(t *testing.T) {
	const window = 50 * time.Millisecond
	tracker := newQuotaTracker(window)

	count, ok := tracker.Reserve("cust", "dev-1", 1)
	if !ok || count != 1 {
		t.Fatalf("first Reserve: count=%d ok=%v", count, ok)
	}
	// NOTE(review): this rejection relies on less than `window` elapsing
	// since the first Reserve; a heavily loaded runner could make it flaky.
	count, ok = tracker.Reserve("cust", "dev-2", 1)
	if ok {
		t.Fatalf("expected over-cap rejection, got count=%d ok=%v", count, ok)
	}

	time.Sleep(80 * time.Millisecond)

	// dev-1's entry should now be stale; dev-2 should be admitted.
	count, ok = tracker.Reserve("cust", "dev-2", 1)
	if !ok || count != 1 {
		t.Fatalf("post-eviction Reserve: count=%d ok=%v", count, ok)
	}
}
// TestValidateRemoteURL: the remote-signer URL check must reject plain
// http:// for non-loopback hosts, so JWTs and signatures never cross the
// network in cleartext. It must also reject unsupported schemes and
// unparseable input.
func TestValidateRemoteURL(t *testing.T) {
	tests := []struct {
		url       string
		wantError bool
	}{
		{"https://license.example.com", false},
		{"https://localhost:8443", false},
		{"http://localhost:8443", false},     // loopback exception
		{"http://127.0.0.1:8443", false},     // loopback exception
		{"http://license.example.com", true}, // public http → reject
		{"ftp://license.example.com", true},  // bad scheme
		{"not a url at all", true},
	}
	for _, tc := range tests {
		err := ValidateRemoteURL(tc.url)
		if gotErr := err != nil; gotErr != tc.wantError {
			t.Errorf("ValidateRemoteURL(%q): err=%v, wantError=%v", tc.url, err, tc.wantError)
		}
	}
}
// TestIssueRejectsShortTTL: catches fat-finger TTLs (zero, negative, or
// below the minimum) that would mint an already-expired or near-useless
// token. It also checks that an empty subject is refused.
func TestIssueRejectsShortTTL(t *testing.T) {
	priv := testKey(t)
	cases := []struct {
		sub string
		ttl time.Duration
		msg string // reported if Issue unexpectedly succeeds
	}{
		{"cust", 0, "expected error for ttl=0"},
		{"cust", -time.Hour, "expected error for negative ttl"},
		{"cust", time.Minute, "expected error for ttl below minimum"},
		{"", time.Hour, "expected error for empty sub"},
	}
	for _, c := range cases {
		if _, err := Issue(priv, c.sub, TierPaid, 10, c.ttl); err == nil {
			t.Error(c.msg)
		}
	}
}
// TestNewLocalRejectsShortKey: a misconfigured YAMA_SIGN_PASSWORD (empty or
// truncated) must be rejected when the signer is constructed, rather than
// silently producing signatures under a weak key.
func TestNewLocalRejectsShortKey(t *testing.T) {
	bad := []struct{ key, msg string }{
		{"", "expected error for empty master key"},
		{"short", "expected error for too-short master key"},
	}
	for _, b := range bad {
		if _, err := NewLocal(b.key); err == nil {
			t.Error(b.msg)
		}
	}
}
// TestJWTAlgLockedToRS256: a token signed with any algorithm other than
// RS256 (here RS384) must fail verification, even though the underlying RSA
// key is the same. This pins the documented algorithm in code.
func TestJWTAlgLockedToRS256(t *testing.T) {
	key := testKey(t)
	issued := time.Now()

	// Hand-build otherwise well-formed claims, so the signing algorithm is
	// the only thing distinguishing this token from one Issue would mint.
	rs384 := jwt.NewWithClaims(jwt.SigningMethodRS384, &LicenseClaims{
		Tier:       TierTrial,
		MaxDevices: 10,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   "cust",
			IssuedAt:  jwt.NewNumericDate(issued),
			ExpiresAt: jwt.NewNumericDate(issued.Add(time.Hour)),
		},
	})
	signed, err := rs384.SignedString(key)
	if err != nil {
		t.Fatalf("sign: %v", err)
	}

	_, err = VerifyJWT(signed, &key.PublicKey)
	if err == nil {
		t.Error("VerifyJWT accepted RS384; alg should be locked to RS256")
	}
}