Files
SimpleRemoter/server/go/licensing/token.go

143 lines
4.4 KiB
Go

package licensing
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"time"
"github.com/golang-jwt/jwt/v5"
)
// LicenseClaims is the payload of the JWT that the operator signs and hands
// to each customer. Besides the two license-specific fields below, the
// operator sets the standard "sub" claim to a unique customer ID and "exp" to
// the token's expiry. That expiry is independent of any in-memory cache TTL
// on the RemoteSigner side.
type LicenseClaims struct {
	// Tier selects the license class: trial or paid.
	Tier string `json:"tier"`

	// MaxDevices is the customer's concurrent device cap.
	MaxDevices int `json:"max_devices"`

	// RegisteredClaims supplies the standard JWT fields ("sub", "exp", ...).
	jwt.RegisteredClaims
}
// LoadRSAPrivateKey parses an RSA private key from a PEM file. Used by the
// "issue-token" CLI subcommand to sign customer JWTs offline.
//
// Accepts unencrypted PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY")
// PEM encodings. Only the first PEM block in the file is read, and both
// parsers are tried regardless of the block's type label.
//
// When neither parser succeeds, the error carries each parser's reason so a
// bad key file can be diagnosed. Password-protected keys get a dedicated
// message, since this loader takes no passphrase. Read errors are wrapped
// with %w, so callers can use errors.Is (e.g. against os.ErrNotExist).
func LoadRSAPrivateKey(pemPath string) (*rsa.PrivateKey, error) {
	data, err := os.ReadFile(pemPath)
	if err != nil {
		return nil, fmt.Errorf("read private key %s: %w", pemPath, err)
	}
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block in %s", pemPath)
	}
	// PKCS#1: "RSA PRIVATE KEY"
	key1, errPKCS1 := x509.ParsePKCS1PrivateKey(block.Bytes)
	if errPKCS1 == nil {
		return key1, nil
	}
	// PKCS#8: "PRIVATE KEY"
	key8, errPKCS8 := x509.ParsePKCS8PrivateKey(block.Bytes)
	if errPKCS8 == nil {
		rsaKey, ok := key8.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("PKCS#8 key in %s is not RSA (got %T)", pemPath, key8)
		}
		return rsaKey, nil
	}
	// Encrypted key material cannot be parsed by either function above, and
	// the resulting ASN.1 errors don't say why. Name the real cause when the
	// block is marked encrypted. That is either a PKCS#8 "ENCRYPTED PRIVATE
	// KEY" or a legacy RFC 1421-style block carrying a DEK-Info header.
	if _, legacy := block.Headers["DEK-Info"]; legacy || block.Type == "ENCRYPTED PRIVATE KEY" {
		return nil, fmt.Errorf("private key in %s is encrypted (PEM type %q); decrypt it before use", pemPath, block.Type)
	}
	return nil, fmt.Errorf("failed to parse %s as PKCS#1 or PKCS#8 RSA private key (PKCS#1: %v; PKCS#8: %v)", pemPath, errPKCS1, errPKCS8)
}
// LoadRSAPublicKey parses an RSA public key from a PEM file. The License
// Server loads this once at startup to verify incoming customer JWTs.
// Accepts both PKIX ("PUBLIC KEY") and PKCS#1 ("RSA PUBLIC KEY") PEM
// encodings — openssl emits PKIX by default; "openssl rsa -RSAPublicKey_out"
// emits PKCS#1.
//
// Only the first PEM block in the file is read. When neither format parses,
// the error includes both parsers' reasons. If the block is labelled as a
// private key, the error says so explicitly, because the server only verifies
// signatures and should be given just the public key. Read errors are wrapped
// with %w.
func LoadRSAPublicKey(pemPath string) (*rsa.PublicKey, error) {
	data, err := os.ReadFile(pemPath)
	if err != nil {
		return nil, fmt.Errorf("read public key %s: %w", pemPath, err)
	}
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("no PEM block in %s", pemPath)
	}
	// Try PKIX first (most common output of openssl genrsa | openssl rsa -pubout).
	pub, errPKIX := x509.ParsePKIXPublicKey(block.Bytes)
	if errPKIX == nil {
		rsaPub, ok := pub.(*rsa.PublicKey)
		if !ok {
			return nil, fmt.Errorf("PKIX key in %s is not RSA (got %T)", pemPath, pub)
		}
		return rsaPub, nil
	}
	// Fall back to PKCS#1.
	rsaPub, errPKCS1 := x509.ParsePKCS1PublicKey(block.Bytes)
	if errPKCS1 == nil {
		return rsaPub, nil
	}
	// Both parsers failed. Point out the common misconfiguration of handing
	// the server a private key file instead of the public key.
	switch block.Type {
	case "RSA PRIVATE KEY", "PRIVATE KEY", "ENCRYPTED PRIVATE KEY", "EC PRIVATE KEY":
		return nil, fmt.Errorf("%s contains a %q PEM block, not a public key; give the License Server only the public key (PKIX: %v; PKCS#1: %v)", pemPath, block.Type, errPKIX, errPKCS1)
	}
	return nil, fmt.Errorf("failed to parse %s as PKIX or PKCS#1 RSA public key (PKIX: %v; PKCS#1: %v)", pemPath, errPKIX, errPKCS1)
}
// VerifyJWT validates the customer's JWT against the License Server's
// public key and returns the parsed claims on success. The caller enforces
// the tier-specific quota using claims.MaxDevices.
//
// Checks performed by jwt.ParseWithClaims (golang-jwt v5 defaults):
//   - signature: pinned to RS256 and verified with pubKey.
//   - "exp" and "nbf": validated only when present in the token. A token
//     with no "exp" is accepted with no expiry; v5 rejects it only when the
//     WithExpirationRequired parser option is passed. "iat" is not validated
//     unless the WithIssuedAt option is passed.
//
// Checks performed here:
//   - pubKey must be non-nil. A nil key would otherwise be handed to the
//     RS256 verifier as a typed nil *rsa.PublicKey and dereferenced there.
//   - "sub" must be non-empty.
//   - tier must be TierTrial or TierPaid.
//   - paid tokens must carry max_devices > 0. Trial tokens with a missing or
//     non-positive max_devices are not rejected; MaxDevices is set to
//     TrialMaxDevices in the returned claims instead.
func VerifyJWT(tokenStr string, pubKey *rsa.PublicKey) (*LicenseClaims, error) {
	if pubKey == nil {
		return nil, errors.New("nil RSA public key for JWT verification")
	}
	claims := &LicenseClaims{}
	token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) {
		// Lock to exactly RS256. The wider *SigningMethodRSA check (which
		// would also accept RS384/RS512) is technically still safe because
		// all RSA variants require the private key — but pinning the exact
		// alg matches the docs and avoids surprises if Issue() ever changes.
		// "alg":"none" tampering is blocked because none isn't SigningMethodRS256.
		if t.Method != jwt.SigningMethodRS256 {
			return nil, fmt.Errorf("unexpected JWT alg: %v (need RS256)", t.Header["alg"])
		}
		return pubKey, nil
	})
	if err != nil {
		return nil, err
	}
	// Defensive: v5 already reports invalid tokens through err.
	if !token.Valid {
		return nil, errors.New("invalid JWT")
	}
	if claims.Subject == "" {
		return nil, errors.New("JWT missing 'sub' claim")
	}
	switch claims.Tier {
	case TierTrial:
		// Trial tokens may omit max_devices; fall back to the trial default.
		if claims.MaxDevices <= 0 {
			claims.MaxDevices = TrialMaxDevices
		}
	case TierPaid:
		if claims.MaxDevices <= 0 {
			return nil, errors.New("paid-tier JWT missing max_devices")
		}
	default:
		return nil, fmt.Errorf("unsupported tier: %q", claims.Tier)
	}
	return claims, nil
}
// ttlSinceNow reports how long remains until the token's "exp" time. The
// result is negative once the token has expired, and zero when the token
// carries no "exp" claim. Intended for logging.
func (c *LicenseClaims) ttlSinceNow() time.Duration {
	if exp := c.ExpiresAt; exp != nil {
		return time.Until(exp.Time)
	}
	return 0
}