Security(Go): Login rate limit + WS origin allowlist + REST bearer auth

This commit is contained in:
yuanyuanxiang
2026-05-18 23:37:58 +02:00
parent 5947d41617
commit c6244462a9
8 changed files with 566 additions and 41 deletions

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
@@ -19,12 +20,38 @@ import (
// static assets, the PWA manifest, and JSON APIs backed by the device hub.
// WebSocket signaling and screen streaming will be wired up in later phases.
type Server struct {
port int
log *logger.Logger
srv *http.Server
hub *hub.Hub
auth *wsauth.Authenticator
ws *wsHub
port int
log *logger.Logger
srv *http.Server
hub *hub.Hub
auth *wsauth.Authenticator
ws *wsHub
allowedOrigins []string // for WS Origin allowlist; empty = same-origin only
loginIPLimit *wsauth.RateLimiter
loginUserLimit *wsauth.RateLimiter
trustForwardedFor bool // honor X-Forwarded-For (behind trusted proxy only)
}
// Config tunes the server's exposed-on-public-HTTPS hardening knobs.
// All fields are optional; zero values pick reasonable defaults.
type Config struct {
// AllowedOrigins is the comma-separated list of Origin header values
// the WebSocket upgrade will accept in addition to same-origin
// requests. Empty (default) → only same-origin upgrades are allowed,
// which is correct when the web UI and the WS endpoint are served
// from the same host.
AllowedOrigins []string
// LoginIPLimit / LoginUserLimit throttle the get_salt + login flow
// per source IP and per username respectively. Pass nil to disable
// either dimension (e.g. dev mode).
LoginIPLimit *wsauth.RateLimiter
LoginUserLimit *wsauth.RateLimiter
// TrustForwardedFor switches client-IP extraction from RemoteAddr
// (default) to the last entry of X-Forwarded-For. Set true only when
// running behind a reverse proxy that you control; on direct
// exposure the header is client-controlled and would let attackers
// evade per-IP rate limits.
TrustForwardedFor bool
}
// New creates an HTTP server bound to the given port. port=0 disables the server.
@@ -34,6 +61,16 @@ func New(port int, log *logger.Logger, h *hub.Hub, auth *wsauth.Authenticator) *
return &Server{port: port, log: log, hub: h, auth: auth}
}
// WithConfig applies hardening configuration. Returns the receiver for
// chainable setup. Safe to call before Start; ignored thereafter.
func (s *Server) WithConfig(cfg Config) *Server {
s.allowedOrigins = cfg.AllowedOrigins
s.loginIPLimit = cfg.LoginIPLimit
s.loginUserLimit = cfg.LoginUserLimit
s.trustForwardedFor = cfg.TrustForwardedFor
return s
}
// Start launches the server in a goroutine and returns immediately.
// If port is 0, returns nil without starting anything.
func (s *Server) Start() error {
@@ -42,13 +79,16 @@ func (s *Server) Start() error {
return nil
}
s.ws = newWSHub(s.auth, s.hub, s.log)
s.ws = newWSHub(s.auth, s.hub, s.log).
withOriginAllowlist(s.allowedOrigins).
withLoginRateLimiters(s.loginIPLimit, s.loginUserLimit).
withTrustForwardedFor(s.trustForwardedFor)
mux := http.NewServeMux()
mux.HandleFunc("/", s.handleIndex)
mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/manifest.json", s.handleManifest)
mux.HandleFunc("/api/devices", s.handleDevices)
mux.HandleFunc("/api/devices", s.requireBearer(s.handleDevices))
mux.HandleFunc("/ws", s.ws.serve)
mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8"))
mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8"))
@@ -106,7 +146,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
// handleDevices returns a JSON snapshot of currently-online devices. Empty
// array (not null) when no clients are connected — matches what the front-end
// will eventually expect.
// will eventually expect. Auth-gated via requireBearer.
func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
devices := s.hub.ListDevices()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
@@ -116,6 +156,39 @@ func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
}
}
// requireBearer wraps a handler with `Authorization: Bearer <token>` auth
// against the same session-token store the WebSocket uses. Returns 401 on
// missing / invalid / expired tokens. Used to gate REST endpoints that
// previously fell through with no auth (notably /api/devices, which
// otherwise leaks the full online-device list to anyone on the internet).
func (s *Server) requireBearer(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
const prefix = "Bearer "
hdr := r.Header.Get("Authorization")
if !strings.HasPrefix(hdr, prefix) {
s.unauthorized(w)
return
}
token := strings.TrimSpace(hdr[len(prefix):])
if token == "" {
s.unauthorized(w)
return
}
if _, err := s.auth.ValidateToken(token); err != nil {
s.unauthorized(w)
return
}
next(w, r)
}
}
func (s *Server) unauthorized(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", `Bearer realm="yama"`)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
}
// PWA manifest. Referenced by <link rel="manifest"> in index.html.
// Static JSON, no template needed.
const manifestJSON = `{

View File

@@ -2,7 +2,10 @@ package web
import (
"encoding/json"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
@@ -20,13 +23,12 @@ const (
wsSendBuffer = 64 // outbound queue depth per client
)
// upgrader allows any origin — this service is meant to be tunneled through
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
// Adjust CheckOrigin once we have a deployment story.
var upgrader = websocket.Upgrader{
// baseUpgrader carries the buffer-size config shared by all WS upgrades.
// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is
// closed-over per instance instead of being a global mutable.
var baseUpgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool { return true },
}
// ----- per-connection client state ----------------------------------------
@@ -94,6 +96,14 @@ type wsHub struct {
clients map[*wsClient]struct{}
unsub func()
// Hardening knobs wired from server.Config. Nil/empty values mean
// "no extra restriction" — useful for local dev where the hub is
// exercised without server.Server wiring up the env-driven defaults.
allowedOrigins []string // empty → only same-origin upgrades accepted
loginIPLimit *wsauth.RateLimiter
loginUserLimit *wsauth.RateLimiter
trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy)
}
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
@@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger)
return h
}
// withOriginAllowlist returns h after installing the explicit Origin
// allowlist. Chainable. Pass empty/nil to keep "same-origin only".
func (h *wsHub) withOriginAllowlist(origins []string) *wsHub {
h.allowedOrigins = origins
return h
}
// withLoginRateLimiters wires per-IP and per-username throttles into
// the login flow. Either may be nil to disable that dimension.
func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub {
h.loginIPLimit = byIP
h.loginUserLimit = byUser
return h
}
// withTrustForwardedFor opts in to using the last entry of the
// X-Forwarded-For header as the client IP. Safe only when the server is
// behind a reverse proxy that you control.
func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub {
h.trustForwardedFor = trust
return h
}
// checkOrigin decides whether to accept a WebSocket upgrade based on
// the request's Origin header. Same-origin (Origin host == Host) is
// always accepted; explicit allowlist entries cover the
// PWA-from-different-domain or local-dev cases.
//
// An empty Origin header is rejected: a legitimate browser always sends
// it on cross-origin requests, and same-origin requests have it too in
// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that
// omit Origin shouldn't be talking to the WS endpoint anyway.
func (h *wsHub) checkOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return false
}
u, err := url.Parse(origin)
if err != nil || u.Host == "" {
return false
}
// Same-origin (Origin host matches the Host the request came in on).
// Strip any port mismatch: if the server is behind a proxy, Host may
// not include a port while Origin does (or vice versa), so compare
// the hostname components.
originHost := u.Hostname()
reqHost := stripPort(r.Host)
if originHost == reqHost && originHost != "" {
return true
}
// Explicit allowlist entries — match Origin in full (scheme + host
// + port) so a customer can pin exactly one trusted PWA origin.
for _, allowed := range h.allowedOrigins {
if allowed == "" {
continue
}
if strings.EqualFold(origin, strings.TrimSpace(allowed)) {
return true
}
}
return false
}
func stripPort(hostport string) string {
if h, _, err := net.SplitHostPort(hostport); err == nil {
return h
}
return hostport
}
// clientIP returns the source IP of an HTTP request. By default uses
// r.RemoteAddr (the actual TCP peer); this is the only safe choice when
// the server is directly exposed to the internet, because a malicious
// client can put anything in X-Forwarded-For and would otherwise rotate
// it to evade per-IP rate limits.
//
// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is
// returned instead — appropriate only when running behind a reverse
// proxy that you control and that overwrites/appends the header (caddy,
// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY
// env var at startup.
func clientIP(r *http.Request, trustForwardedFor bool) string {
if trustForwardedFor {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the LAST entry — the closest hop, i.e. our own
// trusted proxy's view of the peer. Trusting the first
// entry would let a malicious client at the head of the
// chain set an arbitrary value.
parts := strings.Split(xff, ",")
ip := strings.TrimSpace(parts[len(parts)-1])
if ip != "" {
return ip
}
}
}
return stripPort(r.RemoteAddr)
}
// stop unsubscribes from the device hub. Existing connections keep running
// until they close on their own; we only block new event delivery.
func (h *wsHub) stop() {
@@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) {
// ----- HTTP handler -------------------------------------------------------
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
// Build a per-call upgrader so CheckOrigin closes over this hub's
// allowlist instead of a package-level mutable.
upgrader := baseUpgrader
upgrader.CheckOrigin = h.checkOrigin
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.log.Error("ws upgrade: %v", err)
@@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
send: make(chan wsMsg, wsSendBuffer),
closed: make(chan struct{}),
nonce: nonce,
addr: r.RemoteAddr,
addr: clientIP(r, h.trustForwardedFor),
}
h.register(client)
defer h.unregister(client)

View File

@@ -79,18 +79,27 @@ func (h *wsHub) requireAdmin(c *wsClient, raw []byte, replyCmd string) (ok bool)
// ----- handlers ------------------------------------------------------------
func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) {
// Throttle the salt-probe surface together with login: an attacker
// who can poll get_salt freely would otherwise still learn nothing
// (the unknown-user fake salt mitigation handles that), but the
// endpoint is otherwise free CPU on the server. Limiting by IP is
// enough; we don't have a username yet to limit by user.
if !h.allowLoginByIP(c) {
// Stall the response so a tight-loop attacker doesn't flood the
// queue. Still return a well-formed salt to avoid making the
// limit detectable from the client side.
time.Sleep(250 * time.Millisecond)
}
var in struct {
Username string `json:"username"`
}
_ = json.Unmarshal(raw, &in)
salt, ok := h.auth.GetSalt(in.Username)
// Do not leak which usernames exist: always return ok=true with a salt.
// For unknown users hand back the empty salt (matches admin convention)
// so the timing/shape of the response is uniform.
if !ok {
salt = ""
}
salt, _ := h.auth.GetSalt(in.Username)
// GetSalt now returns a deterministic fake salt (16 hex chars) for
// unknown users — same shape as a real salt — so an attacker can't
// tell from this response alone whether the username exists.
c.queue(mustJSON(map[string]any{
"cmd": "salt",
"ok": true,
@@ -109,6 +118,20 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
return
}
// Rate-limit BEFORE doing the hash work, so a flood doesn't pin CPU.
// Two-dimensional throttle: per-IP catches scanners that try many
// usernames; per-username catches scanners that rotate IPs against a
// known account (admin). Either dimension tripping rejects the call
// with a uniform "credentials" error so the limit is not detectable.
if !h.allowLoginByIP(c) || !h.allowLoginByUsername(in.Username) {
h.log.Warn("ws login throttled: user=%s addr=%s", in.Username, c.addr)
// Burn the challenge so the attacker can't immediately replay.
c.nonce = ""
time.Sleep(500 * time.Millisecond)
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
return
}
// Bind the response to the challenge we issued at connect time so that
// replays from a different connection can't reuse a captured response.
if in.Nonce == "" || in.Nonce != c.nonce {
@@ -120,9 +143,17 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
if err != nil {
// Burn the challenge on failure too — forces a new round on retry.
c.nonce = ""
// Fixed delay on failure: makes online brute force impractical
// even within the rate-limit budget, and erases the timing
// difference between "wrong password" and "wrong nonce".
time.Sleep(250 * time.Millisecond)
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
return
}
// Successful login: clear the per-IP/per-user budgets so a legitimate
// user who fat-fingered a few times doesn't stay throttled.
h.resetLoginThrottle(c, in.Username)
c.nonce = ""
c.token = token
c.role = role
@@ -136,6 +167,31 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
}))
}
// allowLoginByIP / allowLoginByUsername return true when the call is
// within budget; nil limiter always returns true (effectively disabled).
func (h *wsHub) allowLoginByIP(c *wsClient) bool {
if h.loginIPLimit == nil || c == nil || c.addr == "" {
return true
}
return h.loginIPLimit.Allow(c.addr)
}
func (h *wsHub) allowLoginByUsername(username string) bool {
if h.loginUserLimit == nil || username == "" {
return true
}
return h.loginUserLimit.Allow(username)
}
func (h *wsHub) resetLoginThrottle(c *wsClient, username string) {
if h.loginIPLimit != nil && c != nil && c.addr != "" {
h.loginIPLimit.Reset(c.addr)
}
if h.loginUserLimit != nil && username != "" {
h.loginUserLimit.Reset(username)
}
}
// handleConnect kicks off a screen-sharing session for the browser. We send
// COMMAND_SCREEN_SPY to the device's main TCP connection; the device then
// opens a new sub-connection (TOKEN_BITMAPINFO) which the TCP side binds to