Feat(go): add Signer interface + License Server for multi-customer deployments
This commit is contained in:
114
server/go/licensing/token.go
Normal file
114
server/go/licensing/token.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// LicenseClaims is the JWT payload the operator signs and ships to each
|
||||
// customer. The operator picks "sub" (a unique customer ID), "tier" (trial
|
||||
// or paid), "max_devices" (concurrent device cap), and "exp" (token
|
||||
// expiry — independent of any in-memory cache TTL on the RemoteSigner
|
||||
// side).
|
||||
type LicenseClaims struct {
|
||||
Tier string `json:"tier"`
|
||||
MaxDevices int `json:"max_devices"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// LoadRSAPublicKey parses an RSA public key from a PEM file. The License
|
||||
// Server loads this once at startup to verify incoming customer JWTs.
|
||||
// Accepts both PKCS#1 ("RSA PUBLIC KEY") and PKIX ("PUBLIC KEY") PEM
|
||||
// encodings — openssl emits PKIX by default; "openssl rsa -RSAPublicKey_out"
|
||||
// emits PKCS#1.
|
||||
func LoadRSAPublicKey(pemPath string) (*rsa.PublicKey, error) {
|
||||
data, err := os.ReadFile(pemPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read public key %s: %w", pemPath, err)
|
||||
}
|
||||
block, _ := pem.Decode(data)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block in %s", pemPath)
|
||||
}
|
||||
|
||||
// Try PKIX first (most common output of openssl genrsa | openssl rsa -pubout).
|
||||
if pub, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil {
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("PKIX key in %s is not RSA", pemPath)
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
|
||||
// Fall back to PKCS#1.
|
||||
if pub, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil {
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to parse %s as PKIX or PKCS#1 RSA public key", pemPath)
|
||||
}
|
||||
|
||||
// VerifyJWT validates the customer's JWT against the License Server's
|
||||
// public key. Returns the parsed claims on success. Caller enforces the
|
||||
// tier-specific quota using claims.MaxDevices.
|
||||
//
|
||||
// Validation done by jwt.ParseWithClaims:
|
||||
// - signature (RS256, using the supplied public key)
|
||||
// - "exp" claim (expiry)
|
||||
// - "iat" / "nbf" if present
|
||||
//
|
||||
// We additionally require tier ∈ {trial, paid}, max_devices > 0, and "sub"
|
||||
// non-empty.
|
||||
func VerifyJWT(tokenStr string, pubKey *rsa.PublicKey) (*LicenseClaims, error) {
|
||||
claims := &LicenseClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) {
|
||||
// Lock to exactly RS256. The wider *SigningMethodRSA check (which
|
||||
// would also accept RS384/RS512) is technically still safe because
|
||||
// all RSA variants require the private key — but pinning the exact
|
||||
// alg matches the docs and avoids surprises if Issue() ever changes.
|
||||
// "alg":"none" tampering is blocked because none isn't SigningMethodRS256.
|
||||
if t.Method != jwt.SigningMethodRS256 {
|
||||
return nil, fmt.Errorf("unexpected JWT alg: %v (need RS256)", t.Header["alg"])
|
||||
}
|
||||
return pubKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, errors.New("JWT missing 'sub' claim")
|
||||
}
|
||||
switch claims.Tier {
|
||||
case TierTrial:
|
||||
if claims.MaxDevices <= 0 {
|
||||
claims.MaxDevices = TrialMaxDevices
|
||||
}
|
||||
case TierPaid:
|
||||
if claims.MaxDevices <= 0 {
|
||||
return nil, errors.New("paid-tier JWT missing max_devices")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported tier: %q", claims.Tier)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ttlSinceNow returns the time.Duration until exp; useful for logging.
|
||||
func (c *LicenseClaims) ttlSinceNow() time.Duration {
|
||||
if c.ExpiresAt == nil {
|
||||
return 0
|
||||
}
|
||||
return time.Until(c.ExpiresAt.Time)
|
||||
}
|
||||
Reference in New Issue
Block a user