Security(Go): Login rate limit + WS origin allowlist + REST bearer auth

This commit is contained in:
yuanyuanxiang
2026-05-18 23:37:58 +02:00
committed by yuanyuanxiang
parent d7f38ecfdb
commit 32a75f4670
8 changed files with 566 additions and 41 deletions

View File

@@ -2,7 +2,10 @@ package web
import (
"encoding/json"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
@@ -20,13 +23,12 @@ const (
wsSendBuffer = 64 // outbound queue depth per client
)
// upgrader allows any origin — this service is meant to be tunneled through
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
// Adjust CheckOrigin once we have a deployment story.
var upgrader = websocket.Upgrader{
// baseUpgrader carries the buffer-size config shared by all WS upgrades.
// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is
// closed-over per instance instead of being a global mutable.
var baseUpgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool { return true },
}
// ----- per-connection client state ----------------------------------------
@@ -94,6 +96,14 @@ type wsHub struct {
clients map[*wsClient]struct{}
unsub func()
// Hardening knobs wired from server.Config. Nil/empty values mean
// "no extra restriction" — useful for local dev where the hub is
// exercised without server.Server wiring up the env-driven defaults.
allowedOrigins []string // empty → only same-origin upgrades accepted
loginIPLimit *wsauth.RateLimiter
loginUserLimit *wsauth.RateLimiter
trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy)
}
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
@@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger)
return h
}
// withOriginAllowlist returns h after installing the explicit Origin
// allowlist. Chainable. Pass empty/nil to keep "same-origin only".
func (h *wsHub) withOriginAllowlist(origins []string) *wsHub {
h.allowedOrigins = origins
return h
}
// withLoginRateLimiters wires per-IP and per-username throttles into
// the login flow. Either may be nil to disable that dimension.
func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub {
h.loginIPLimit = byIP
h.loginUserLimit = byUser
return h
}
// withTrustForwardedFor opts in to using the last entry of the
// X-Forwarded-For header as the client IP. Safe only when the server is
// behind a reverse proxy that you control.
func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub {
h.trustForwardedFor = trust
return h
}
// checkOrigin decides whether to accept a WebSocket upgrade based on
// the request's Origin header. Same-origin (Origin host == Host) is
// always accepted; explicit allowlist entries cover the
// PWA-from-different-domain or local-dev cases.
//
// An empty Origin header is rejected: a legitimate browser always sends
// it on cross-origin requests, and same-origin requests have it too in
// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that
// omit Origin shouldn't be talking to the WS endpoint anyway.
func (h *wsHub) checkOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return false
}
u, err := url.Parse(origin)
if err != nil || u.Host == "" {
return false
}
// Same-origin (Origin host matches the Host the request came in on).
// Strip any port mismatch: if the server is behind a proxy, Host may
// not include a port while Origin does (or vice versa), so compare
// the hostname components.
originHost := u.Hostname()
reqHost := stripPort(r.Host)
if originHost == reqHost && originHost != "" {
return true
}
// Explicit allowlist entries — match Origin in full (scheme + host
// + port) so a customer can pin exactly one trusted PWA origin.
for _, allowed := range h.allowedOrigins {
if allowed == "" {
continue
}
if strings.EqualFold(origin, strings.TrimSpace(allowed)) {
return true
}
}
return false
}
func stripPort(hostport string) string {
if h, _, err := net.SplitHostPort(hostport); err == nil {
return h
}
return hostport
}
// clientIP returns the source IP of an HTTP request. By default uses
// r.RemoteAddr (the actual TCP peer); this is the only safe choice when
// the server is directly exposed to the internet, because a malicious
// client can put anything in X-Forwarded-For and would otherwise rotate
// it to evade per-IP rate limits.
//
// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is
// returned instead — appropriate only when running behind a reverse
// proxy that you control and that overwrites/appends the header (caddy,
// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY
// env var at startup.
func clientIP(r *http.Request, trustForwardedFor bool) string {
if trustForwardedFor {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the LAST entry — the closest hop, i.e. our own
// trusted proxy's view of the peer. Trusting the first
// entry would let a malicious client at the head of the
// chain set an arbitrary value.
parts := strings.Split(xff, ",")
ip := strings.TrimSpace(parts[len(parts)-1])
if ip != "" {
return ip
}
}
}
return stripPort(r.RemoteAddr)
}
// stop unsubscribes from the device hub. Existing connections keep running
// until they close on their own; we only block new event delivery.
func (h *wsHub) stop() {
@@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) {
// ----- HTTP handler -------------------------------------------------------
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
// Build a per-call upgrader so CheckOrigin closes over this hub's
// allowlist instead of a package-level mutable.
upgrader := baseUpgrader
upgrader.CheckOrigin = h.checkOrigin
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.log.Error("ws upgrade: %v", err)
@@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
send: make(chan wsMsg, wsSendBuffer),
closed: make(chan struct{}),
nonce: nonce,
addr: r.RemoteAddr,
addr: clientIP(r, h.trustForwardedFor),
}
h.register(client)
defer h.unregister(client)