package licensing

import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

// LicenseClaims is the JWT payload the operator signs and ships to each
// customer. The operator picks "sub" (a unique customer ID), "tier" (trial
// or paid), "max_devices" (concurrent device cap), and "exp" (token
// expiry — independent of any in-memory cache TTL on the RemoteSigner
// side).
//
// The registered claims ("sub", "exp", "iat", "nbf", ...) come from the
// embedded jwt.RegisteredClaims.
type LicenseClaims struct {
	// Tier selects the license tier: trial or paid.
	Tier string `json:"tier"`
	// MaxDevices is the concurrent device cap granted by this license.
	MaxDevices int `json:"max_devices"`
	jwt.RegisteredClaims
}

// LoadRSAPrivateKey parses an RSA private key from a PEM file. Used by the
// "issue-token" CLI subcommand to sign customer JWTs offline.
//
// Only the first PEM block in the file is examined. Both PKCS#1
// ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") encodings are accepted;
// the format is detected by trying each parser rather than trusting the
// block's type label. Password-protected keys are not supported and are
// reported with an explicit error. If neither parser succeeds, the returned
// error carries both parser errors and the PEM block type.
func LoadRSAPrivateKey(pemPath string) (*rsa.PrivateKey, error) {
	data, err := os.ReadFile(pemPath)
	if err != nil {
		return nil, fmt.Errorf("read private key %s: %w", pemPath, err)
	}
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block in %s", pemPath)
	}

	// Encrypted keys would otherwise fall through to the generic parse
	// failure below with an opaque ASN.1 error. Legacy OpenSSL encryption is
	// signalled by a DEK-Info header (the same test x509.IsEncryptedPEMBlock
	// uses); encrypted PKCS#8 uses its own block type.
	if _, legacyEncrypted := block.Headers["DEK-Info"]; legacyEncrypted || block.Type == "ENCRYPTED PRIVATE KEY" {
		return nil, fmt.Errorf("private key in %s is encrypted; supply an unencrypted PEM key", pemPath)
	}

	// PKCS#1: "RSA PRIVATE KEY"
	key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err == nil {
		return key, nil
	}

	// PKCS#8: "PRIVATE KEY" — may hold a non-RSA key, which we reject.
	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err == nil {
		rsaKey, ok := parsed.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("PKCS#8 key in %s is not RSA (got %T)", pemPath, parsed)
		}
		return rsaKey, nil
	}

	// Keep both parser errors so a wrong or corrupt file is diagnosable.
	// %v (not %w) keeps this compatible with Go versions before 1.20.
	return nil, fmt.Errorf("failed to parse %s as PKCS#1 or PKCS#8 RSA private key (PEM type %q): PKCS#1: %v; PKCS#8: %v",
		pemPath, block.Type, pkcs1Err, pkcs8Err)
}

// LoadRSAPublicKey parses an RSA public key from a PEM file. The License
// Server loads this once at startup to verify incoming customer JWTs.
// Only the first PEM block in the file is examined. Accepts both PKCS#1
// ("RSA PUBLIC KEY") and PKIX ("PUBLIC KEY") PEM encodings — openssl emits
// PKIX by default; "openssl rsa -RSAPublicKey_out" emits PKCS#1.
func LoadRSAPublicKey(pemPath string) (*rsa.PublicKey, error) { data, err := os.ReadFile(pemPath) if err != nil { return nil, fmt.Errorf("read public key %s: %w", pemPath, err) } block, _ := pem.Decode(data) if block == nil { return nil, fmt.Errorf("no PEM block in %s", pemPath) } // Try PKIX first (most common output of openssl genrsa | openssl rsa -pubout). if pub, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil { rsaPub, ok := pub.(*rsa.PublicKey) if !ok { return nil, fmt.Errorf("PKIX key in %s is not RSA", pemPath) } return rsaPub, nil } // Fall back to PKCS#1. if pub, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil { return pub, nil } return nil, fmt.Errorf("failed to parse %s as PKIX or PKCS#1 RSA public key", pemPath) } // VerifyJWT validates the customer's JWT against the License Server's // public key. Returns the parsed claims on success. Caller enforces the // tier-specific quota using claims.MaxDevices. // // Validation done by jwt.ParseWithClaims: // - signature (RS256, using the supplied public key) // - "exp" claim (expiry) // - "iat" / "nbf" if present // // We additionally require tier ∈ {trial, paid}, max_devices > 0, and "sub" // non-empty. func VerifyJWT(tokenStr string, pubKey *rsa.PublicKey) (*LicenseClaims, error) { claims := &LicenseClaims{} token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) { // Lock to exactly RS256. The wider *SigningMethodRSA check (which // would also accept RS384/RS512) is technically still safe because // all RSA variants require the private key — but pinning the exact // alg matches the docs and avoids surprises if Issue() ever changes. // "alg":"none" tampering is blocked because none isn't SigningMethodRS256. 
if t.Method != jwt.SigningMethodRS256 { return nil, fmt.Errorf("unexpected JWT alg: %v (need RS256)", t.Header["alg"]) } return pubKey, nil }) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("invalid JWT") } if claims.Subject == "" { return nil, errors.New("JWT missing 'sub' claim") } switch claims.Tier { case TierTrial: if claims.MaxDevices <= 0 { claims.MaxDevices = TrialMaxDevices } case TierPaid: if claims.MaxDevices <= 0 { return nil, errors.New("paid-tier JWT missing max_devices") } default: return nil, fmt.Errorf("unsupported tier: %q", claims.Tier) } return claims, nil } // ttlSinceNow returns the time.Duration until exp; useful for logging. func (c *LicenseClaims) ttlSinceNow() time.Duration { if c.ExpiresAt == nil { return 0 } return time.Until(c.ExpiresAt.Time) }