Security(Go): Login rate limit + WS origin allowlist + REST bearer auth
This commit is contained in:
@@ -2,7 +2,10 @@ package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -20,13 +23,12 @@ const (
|
||||
wsSendBuffer = 64 // outbound queue depth per client
|
||||
)
|
||||
|
||||
// upgrader allows any origin — this service is meant to be tunneled through
|
||||
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
|
||||
// Adjust CheckOrigin once we have a deployment story.
|
||||
var upgrader = websocket.Upgrader{
|
||||
// baseUpgrader carries the buffer-size config shared by all WS upgrades.
|
||||
// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is
|
||||
// closed-over per instance instead of being a global mutable.
|
||||
var baseUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// ----- per-connection client state ----------------------------------------
|
||||
@@ -94,6 +96,14 @@ type wsHub struct {
|
||||
clients map[*wsClient]struct{}
|
||||
|
||||
unsub func()
|
||||
|
||||
// Hardening knobs wired from server.Config. Nil/empty values mean
|
||||
// "no extra restriction" — useful for local dev where the hub is
|
||||
// exercised without server.Server wiring up the env-driven defaults.
|
||||
allowedOrigins []string // empty → only same-origin upgrades accepted
|
||||
loginIPLimit *wsauth.RateLimiter
|
||||
loginUserLimit *wsauth.RateLimiter
|
||||
trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy)
|
||||
}
|
||||
|
||||
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
|
||||
@@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger)
|
||||
return h
|
||||
}
|
||||
|
||||
// withOriginAllowlist returns h after installing the explicit Origin
|
||||
// allowlist. Chainable. Pass empty/nil to keep "same-origin only".
|
||||
func (h *wsHub) withOriginAllowlist(origins []string) *wsHub {
|
||||
h.allowedOrigins = origins
|
||||
return h
|
||||
}
|
||||
|
||||
// withLoginRateLimiters wires per-IP and per-username throttles into
|
||||
// the login flow. Either may be nil to disable that dimension.
|
||||
func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub {
|
||||
h.loginIPLimit = byIP
|
||||
h.loginUserLimit = byUser
|
||||
return h
|
||||
}
|
||||
|
||||
// withTrustForwardedFor opts in to using the last entry of the
|
||||
// X-Forwarded-For header as the client IP. Safe only when the server is
|
||||
// behind a reverse proxy that you control.
|
||||
func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub {
|
||||
h.trustForwardedFor = trust
|
||||
return h
|
||||
}
|
||||
|
||||
// checkOrigin decides whether to accept a WebSocket upgrade based on
|
||||
// the request's Origin header. Same-origin (Origin host == Host) is
|
||||
// always accepted; explicit allowlist entries cover the
|
||||
// PWA-from-different-domain or local-dev cases.
|
||||
//
|
||||
// An empty Origin header is rejected: a legitimate browser always sends
|
||||
// it on cross-origin requests, and same-origin requests have it too in
|
||||
// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that
|
||||
// omit Origin shouldn't be talking to the WS endpoint anyway.
|
||||
func (h *wsHub) checkOrigin(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
// Same-origin (Origin host matches the Host the request came in on).
|
||||
// Strip any port mismatch: if the server is behind a proxy, Host may
|
||||
// not include a port while Origin does (or vice versa), so compare
|
||||
// the hostname components.
|
||||
originHost := u.Hostname()
|
||||
reqHost := stripPort(r.Host)
|
||||
if originHost == reqHost && originHost != "" {
|
||||
return true
|
||||
}
|
||||
// Explicit allowlist entries — match Origin in full (scheme + host
|
||||
// + port) so a customer can pin exactly one trusted PWA origin.
|
||||
for _, allowed := range h.allowedOrigins {
|
||||
if allowed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(origin, strings.TrimSpace(allowed)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stripPort(hostport string) string {
|
||||
if h, _, err := net.SplitHostPort(hostport); err == nil {
|
||||
return h
|
||||
}
|
||||
return hostport
|
||||
}
|
||||
|
||||
// clientIP returns the source IP of an HTTP request. By default uses
|
||||
// r.RemoteAddr (the actual TCP peer); this is the only safe choice when
|
||||
// the server is directly exposed to the internet, because a malicious
|
||||
// client can put anything in X-Forwarded-For and would otherwise rotate
|
||||
// it to evade per-IP rate limits.
|
||||
//
|
||||
// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is
|
||||
// returned instead — appropriate only when running behind a reverse
|
||||
// proxy that you control and that overwrites/appends the header (caddy,
|
||||
// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY
|
||||
// env var at startup.
|
||||
func clientIP(r *http.Request, trustForwardedFor bool) string {
|
||||
if trustForwardedFor {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the LAST entry — the closest hop, i.e. our own
|
||||
// trusted proxy's view of the peer. Trusting the first
|
||||
// entry would let a malicious client at the head of the
|
||||
// chain set an arbitrary value.
|
||||
parts := strings.Split(xff, ",")
|
||||
ip := strings.TrimSpace(parts[len(parts)-1])
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return stripPort(r.RemoteAddr)
|
||||
}
|
||||
|
||||
// stop unsubscribes from the device hub. Existing connections keep running
|
||||
// until they close on their own; we only block new event delivery.
|
||||
func (h *wsHub) stop() {
|
||||
@@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) {
|
||||
// ----- HTTP handler -------------------------------------------------------
|
||||
|
||||
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
||||
// Build a per-call upgrader so CheckOrigin closes over this hub's
|
||||
// allowlist instead of a package-level mutable.
|
||||
upgrader := baseUpgrader
|
||||
upgrader.CheckOrigin = h.checkOrigin
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.log.Error("ws upgrade: %v", err)
|
||||
@@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
||||
send: make(chan wsMsg, wsSendBuffer),
|
||||
closed: make(chan struct{}),
|
||||
nonce: nonce,
|
||||
addr: r.RemoteAddr,
|
||||
addr: clientIP(r, h.trustForwardedFor),
|
||||
}
|
||||
h.register(client)
|
||||
defer h.unregister(client)
|
||||
|
||||
Reference in New Issue
Block a user