143 lines
4.4 KiB
Go
143 lines
4.4 KiB
Go
package licensing
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// LicenseClaims is the JWT payload the operator signs and ships to each
|
|
// customer. The operator picks "sub" (a unique customer ID), "tier" (trial
|
|
// or paid), "max_devices" (concurrent device cap), and "exp" (token
|
|
// expiry — independent of any in-memory cache TTL on the RemoteSigner
|
|
// side).
|
|
type LicenseClaims struct {
|
|
Tier string `json:"tier"`
|
|
MaxDevices int `json:"max_devices"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// LoadRSAPrivateKey reads the PEM file at pemPath and returns the RSA private
// key held in its first PEM block. Used by the "issue-token" CLI subcommand
// to sign customer JWTs offline.
//
// Accepts PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") PEM encodings.
// The PEM block type is not inspected: the DER payload is tried as PKCS#1
// first and as PKCS#8 second.
//
// An error is returned when the file cannot be read (the underlying error is
// wrapped), when it contains no PEM block, when the PKCS#8 key is not RSA, or
// when the payload parses as neither encoding. In the last case the error
// includes the reason each parser rejected the data.
func LoadRSAPrivateKey(pemPath string) (*rsa.PrivateKey, error) {
	data, err := os.ReadFile(pemPath)
	if err != nil {
		return nil, fmt.Errorf("read private key %s: %w", pemPath, err)
	}
	// Only the first PEM block is considered; trailing data is ignored.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block in %s", pemPath)
	}

	// PKCS#1: "RSA PRIVATE KEY"
	key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err == nil {
		return key, nil
	}

	// PKCS#8: "PRIVATE KEY"
	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err != nil {
		// Report both causes so the operator can tell a corrupt file from a
		// key in an unsupported format.
		return nil, fmt.Errorf(
			"failed to parse %s as PKCS#1 or PKCS#8 RSA private key: PKCS#1: %v; PKCS#8: %v",
			pemPath, pkcs1Err, pkcs8Err,
		)
	}
	// PKCS#8 can carry non-RSA keys (e.g. ECDSA, Ed25519); only RSA is usable here.
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("PKCS#8 key in %s is not RSA", pemPath)
	}
	return rsaKey, nil
}
|
|
|
|
// LoadRSAPublicKey reads the PEM file at pemPath and returns the RSA public
// key held in its first PEM block. The License Server loads this once at
// startup to verify incoming customer JWTs.
//
// Accepts both PKIX ("PUBLIC KEY") and PKCS#1 ("RSA PUBLIC KEY") PEM
// encodings — openssl emits PKIX by default; "openssl rsa -RSAPublicKey_out"
// emits PKCS#1. The PEM block type is not inspected: the DER payload is tried
// as PKIX first and as PKCS#1 second.
//
// An error is returned when the file cannot be read (the underlying error is
// wrapped), when it contains no PEM block, when the PKIX key is not RSA, or
// when the payload parses as neither encoding. In the last case the error
// includes the reason each parser rejected the data.
func LoadRSAPublicKey(pemPath string) (*rsa.PublicKey, error) {
	data, err := os.ReadFile(pemPath)
	if err != nil {
		return nil, fmt.Errorf("read public key %s: %w", pemPath, err)
	}
	// Only the first PEM block is considered; trailing data is ignored.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block in %s", pemPath)
	}

	// Try PKIX first (most common output of openssl genrsa | openssl rsa -pubout).
	parsed, pkixErr := x509.ParsePKIXPublicKey(block.Bytes)
	if pkixErr == nil {
		// PKIX can carry non-RSA keys; only RSA can verify RS256 tokens.
		rsaPub, ok := parsed.(*rsa.PublicKey)
		if !ok {
			return nil, fmt.Errorf("PKIX key in %s is not RSA", pemPath)
		}
		return rsaPub, nil
	}

	// Fall back to PKCS#1.
	pub, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes)
	if pkcs1Err == nil {
		return pub, nil
	}

	// Report both causes so the operator can tell a corrupt file from a key
	// in an unsupported format.
	return nil, fmt.Errorf(
		"failed to parse %s as PKIX or PKCS#1 RSA public key: PKIX: %v; PKCS#1: %v",
		pemPath, pkixErr, pkcs1Err,
	)
}
|
|
|
|
// VerifyJWT validates the customer's JWT against the License Server's
|
|
// public key. Returns the parsed claims on success. Caller enforces the
|
|
// tier-specific quota using claims.MaxDevices.
|
|
//
|
|
// Validation done by jwt.ParseWithClaims:
|
|
// - signature (RS256, using the supplied public key)
|
|
// - "exp" claim (expiry)
|
|
// - "iat" / "nbf" if present
|
|
//
|
|
// We additionally require tier ∈ {trial, paid}, max_devices > 0, and "sub"
|
|
// non-empty.
|
|
func VerifyJWT(tokenStr string, pubKey *rsa.PublicKey) (*LicenseClaims, error) {
|
|
claims := &LicenseClaims{}
|
|
token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) {
|
|
// Lock to exactly RS256. The wider *SigningMethodRSA check (which
|
|
// would also accept RS384/RS512) is technically still safe because
|
|
// all RSA variants require the private key — but pinning the exact
|
|
// alg matches the docs and avoids surprises if Issue() ever changes.
|
|
// "alg":"none" tampering is blocked because none isn't SigningMethodRS256.
|
|
if t.Method != jwt.SigningMethodRS256 {
|
|
return nil, fmt.Errorf("unexpected JWT alg: %v (need RS256)", t.Header["alg"])
|
|
}
|
|
return pubKey, nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !token.Valid {
|
|
return nil, errors.New("invalid JWT")
|
|
}
|
|
|
|
if claims.Subject == "" {
|
|
return nil, errors.New("JWT missing 'sub' claim")
|
|
}
|
|
switch claims.Tier {
|
|
case TierTrial:
|
|
if claims.MaxDevices <= 0 {
|
|
claims.MaxDevices = TrialMaxDevices
|
|
}
|
|
case TierPaid:
|
|
if claims.MaxDevices <= 0 {
|
|
return nil, errors.New("paid-tier JWT missing max_devices")
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unsupported tier: %q", claims.Tier)
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// ttlSinceNow returns the time.Duration until exp; useful for logging.
|
|
func (c *LicenseClaims) ttlSinceNow() time.Duration {
|
|
if c.ExpiresAt == nil {
|
|
return 0
|
|
}
|
|
return time.Until(c.ExpiresAt.Time)
|
|
}
|