From 32a75f4670c7c395836d0625ece2afecc4aadb56 Mon Sep 17 00:00:00 2001 From: yuanyuanxiang <962914132@qq.com> Date: Mon, 18 May 2026 23:37:58 +0200 Subject: [PATCH] Security(Go): Login rate limit + WS origin allowlist + REST bearer auth --- server/go/README.md | 3 + server/go/cmd/main.go | 47 +++++++++++- server/go/web/server.go | 91 ++++++++++++++++++++--- server/go/web/ws.go | 125 ++++++++++++++++++++++++++++++-- server/go/web/ws_handlers.go | 70 ++++++++++++++++-- server/go/wsauth/ratelimit.go | 110 ++++++++++++++++++++++++++++ server/go/wsauth/wsauth.go | 59 ++++++++++++--- server/go/wsauth/wsauth_test.go | 102 ++++++++++++++++++++++++-- 8 files changed, 566 insertions(+), 41 deletions(-) create mode 100644 server/go/wsauth/ratelimit.go diff --git a/server/go/README.md b/server/go/README.md index dc94308..c32f5ed 100644 --- a/server/go/README.md +++ b/server/go/README.md @@ -56,6 +56,7 @@ server/go/ Web 应用能力 (Phase 3-7): - **Web 鉴权**: challenge-response 登录 + 不透明 token,与 users.json schema 互通 +- **登录加固**: 双维度速率限制(10 次/分钟·IP + 5 次/15 分钟·用户名)+ 失败固定延迟,防口令枚举;`/get_salt` 用确定性假盐响应未知用户,杜绝用户名探测;WebSocket Origin 同源校验 + 显式白名单;`/api/devices` Bearer Token 鉴权 - **设备列表与监控**: 在线设备 / RTT / 活动窗口 / 分辨率 实时下发 - **Web 远程桌面**: 浏览器 WebCodecs 解码 H.264,二进制 WS 帧低延迟中继;late-join 自动重发最近 IDR;优雅 BYE 关闭防止客户端无意义重连 - **鼠标 / 键盘输入**: Win32 消息映射 (`WM_*` / `VK_*` / `MK_*`),MSG64 48 字节布局直传客户端 @@ -149,6 +150,8 @@ VSCode F5 调试时由 `sync-web-assets` preLaunchTask 自动同步。 | `YAMA_WEB_ADMIN_PASS` | Web UI 的 admin 密码(明文);优先于 `YAMA_PWD`。两者都未设置时 Web 登录禁用 | `your_admin_password` | | `YAMA_SIGN_PASSWORD` | HMAC-SHA256 key used to sign CMD_MASTERSETTING replies; must match the client's expected value. Provision out-of-band. Unset → client refuses screen/file ops. | `` | | `YAMA_USERS_FILE` | Path to the JSON file that persists non-admin web users (allowed_groups, password hash, salt). Default is `users.json` in the working directory. | `users.json` | +| `YAMA_WEB_ALLOWED_ORIGINS` | Comma-separated WebSocket Origin allowlist for cross-origin upgrades. Empty (default) → only same-origin upgrades are accepted, which is correct when the web UI and `/ws` share a host. Add an entry per trusted PWA / dev origin. | `https://yama.example.com,https://yama-mobile.example.com` | +| `YAMA_WEB_TRUST_PROXY` | Set to `1` only when running behind a reverse proxy you control (caddy / nginx / cloudflare). Switches client-IP extraction to use the last entry of `X-Forwarded-For` instead of `RemoteAddr`, so per-IP login rate limit sees the real client. Direct-exposure deployments MUST leave this unset — otherwise attackers can spoof the header to evade rate limits. | `1` | ```bash # Linux/macOS diff --git a/server/go/cmd/main.go b/server/go/cmd/main.go index 56d602c..0f75ba5 100644 --- a/server/go/cmd/main.go +++ b/server/go/cmd/main.go @@ -611,6 +611,27 @@ func parsePorts(portStr string) ([]int, error) { return ports, nil } +// splitCSV splits a comma-separated env-var value into trimmed, non-empty +// entries. Returns nil for an empty input so callers can keep the natural +// "no value → no restriction" semantics with a single nil check. +func splitCSV(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + if len(out) == 0 { + return nil + } + return out +} + func main() { // Parse command line flags portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)") @@ -733,9 +754,33 @@ func main() { } } + // Web-UI hardening knobs for public-HTTPS deployment. + // + // YAMA_WEB_ALLOWED_ORIGINS: comma-separated Origin allowlist (e.g. + // "https://yama.example.com,https://yama-mobile.example.com"). + // Empty (default) → only same-origin WS upgrades accepted, which + // is correct when the web UI and WS endpoint share a host. + // + // Login rate limits are hard-coded at sensible defaults for the + // small-user web UI: 10 attempts / minute per IP, 5 / 15 min per + // username. The handler also injects a 250 ms delay on every failure + // so online brute force is impractical even within budget. + allowedOrigins := splitCSV(os.Getenv("YAMA_WEB_ALLOWED_ORIGINS")) + trustProxy := os.Getenv("YAMA_WEB_TRUST_PROXY") == "1" + if trustProxy { + log.Info("Trusting X-Forwarded-For for client IP — make sure a reverse proxy is in front") + } + webCfg := web.Config{ + AllowedOrigins: allowedOrigins, + LoginIPLimit: wsauth.NewRateLimiter(10, time.Minute), + LoginUserLimit: wsauth.NewRateLimiter(5, 15*time.Minute), + TrustForwardedFor: trustProxy, + } + // Start HTTP server for web UI. Hub gives it read-only access to the // device registry; the authenticator owns user accounts and session tokens. - httpSrv := web.New(*httpPort, log.WithPrefix("Web"), deviceHub, webAuth) + httpSrv := web.New(*httpPort, log.WithPrefix("Web"), deviceHub, webAuth). + WithConfig(webCfg) if err := httpSrv.Start(); err != nil { log.Fatal("Failed to start HTTP server: %v", err) } diff --git a/server/go/web/server.go b/server/go/web/server.go index 565de0e..62d9972 100644 --- a/server/go/web/server.go +++ b/server/go/web/server.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strconv" + "strings" "time" "github.com/yuanyuanxiang/SimpleRemoter/server/go/hub" @@ -19,12 +20,38 @@ import ( // static assets, the PWA manifest, and JSON APIs backed by the device hub. // WebSocket signaling and screen streaming will be wired up in later phases. type Server struct { - port int - log *logger.Logger - srv *http.Server - hub *hub.Hub - auth *wsauth.Authenticator - ws *wsHub + port int + log *logger.Logger + srv *http.Server + hub *hub.Hub + auth *wsauth.Authenticator + ws *wsHub + allowedOrigins []string // for WS Origin allowlist; empty = same-origin only + loginIPLimit *wsauth.RateLimiter + loginUserLimit *wsauth.RateLimiter + trustForwardedFor bool // honor X-Forwarded-For (behind trusted proxy only) +} + +// Config tunes the server's exposed-on-public-HTTPS hardening knobs. +// All fields are optional; zero values pick reasonable defaults. +type Config struct { + // AllowedOrigins is the comma-separated list of Origin header values + // the WebSocket upgrade will accept in addition to same-origin + // requests. Empty (default) → only same-origin upgrades are allowed, + // which is correct when the web UI and the WS endpoint are served + // from the same host. + AllowedOrigins []string + // LoginIPLimit / LoginUserLimit throttle the get_salt + login flow + // per source IP and per username respectively. Pass nil to disable + // either dimension (e.g. dev mode). + LoginIPLimit *wsauth.RateLimiter + LoginUserLimit *wsauth.RateLimiter + // TrustForwardedFor switches client-IP extraction from RemoteAddr + // (default) to the last entry of X-Forwarded-For. Set true only when + // running behind a reverse proxy that you control; on direct + // exposure the header is client-controlled and would let attackers + // evade per-IP rate limits. + TrustForwardedFor bool } // New creates an HTTP server bound to the given port. port=0 disables the server. @@ -34,6 +61,16 @@ func New(port int, log *logger.Logger, h *hub.Hub, auth *wsauth.Authenticator) * return &Server{port: port, log: log, hub: h, auth: auth} } +// WithConfig applies hardening configuration. Returns the receiver for +// chainable setup. Safe to call before Start; ignored thereafter. +func (s *Server) WithConfig(cfg Config) *Server { + s.allowedOrigins = cfg.AllowedOrigins + s.loginIPLimit = cfg.LoginIPLimit + s.loginUserLimit = cfg.LoginUserLimit + s.trustForwardedFor = cfg.TrustForwardedFor + return s +} + // Start launches the server in a goroutine and returns immediately. // If port is 0, returns nil without starting anything. func (s *Server) Start() error { @@ -42,13 +79,16 @@ func (s *Server) Start() error { return nil } - s.ws = newWSHub(s.auth, s.hub, s.log) + s.ws = newWSHub(s.auth, s.hub, s.log). + withOriginAllowlist(s.allowedOrigins). + withLoginRateLimiters(s.loginIPLimit, s.loginUserLimit). + withTrustForwardedFor(s.trustForwardedFor) mux := http.NewServeMux() mux.HandleFunc("/", s.handleIndex) mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/manifest.json", s.handleManifest) - mux.HandleFunc("/api/devices", s.handleDevices) + mux.HandleFunc("/api/devices", s.requireBearer(s.handleDevices)) mux.HandleFunc("/ws", s.ws.serve) mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8")) mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8")) @@ -106,7 +146,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { // handleDevices returns a JSON snapshot of currently-online devices. Empty // array (not null) when no clients are connected — matches what the front-end -// will eventually expect. +// will eventually expect. Auth-gated via requireBearer. func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) { devices := s.hub.ListDevices() w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -116,6 +156,39 @@ func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) { } } +// requireBearer wraps a handler with `Authorization: Bearer ` auth +// against the same session-token store the WebSocket uses. Returns 401 on +// missing / invalid / expired tokens. Used to gate REST endpoints that +// previously fell through with no auth (notably /api/devices, which +// otherwise leaks the full online-device list to anyone on the internet). +func (s *Server) requireBearer(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + const prefix = "Bearer " + hdr := r.Header.Get("Authorization") + if !strings.HasPrefix(hdr, prefix) { + s.unauthorized(w) + return + } + token := strings.TrimSpace(hdr[len(prefix):]) + if token == "" { + s.unauthorized(w) + return + } + if _, err := s.auth.ValidateToken(token); err != nil { + s.unauthorized(w) + return + } + next(w, r) + } +} + +func (s *Server) unauthorized(w http.ResponseWriter) { + w.Header().Set("WWW-Authenticate", `Bearer realm="yama"`) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) +} + // PWA manifest. Referenced by in index.html. // Static JSON, no template needed. const manifestJSON = `{ diff --git a/server/go/web/ws.go b/server/go/web/ws.go index 5b7b9c5..1328f1a 100644 --- a/server/go/web/ws.go +++ b/server/go/web/ws.go @@ -2,7 +2,10 @@ package web import ( "encoding/json" + "net" "net/http" + "net/url" + "strings" "sync" "time" @@ -20,13 +23,12 @@ const ( wsSendBuffer = 64 // outbound queue depth per client ) -// upgrader allows any origin — this service is meant to be tunneled through -// frp, so requests can legitimately arrive from arbitrary front-end hosts. -// Adjust CheckOrigin once we have a deployment story. -var upgrader = websocket.Upgrader{ +// baseUpgrader carries the buffer-size config shared by all WS upgrades. +// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is +// closed-over per instance instead of being a global mutable. +var baseUpgrader = websocket.Upgrader{ ReadBufferSize: 4096, WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { return true }, } // ----- per-connection client state ---------------------------------------- @@ -94,6 +96,14 @@ type wsHub struct { clients map[*wsClient]struct{} unsub func() + + // Hardening knobs wired from server.Config. Nil/empty values mean + // "no extra restriction" — useful for local dev where the hub is + // exercised without server.Server wiring up the env-driven defaults. + allowedOrigins []string // empty → only same-origin upgrades accepted + loginIPLimit *wsauth.RateLimiter + loginUserLimit *wsauth.RateLimiter + trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy) } func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub { @@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) return h } +// withOriginAllowlist returns h after installing the explicit Origin +// allowlist. Chainable. Pass empty/nil to keep "same-origin only". +func (h *wsHub) withOriginAllowlist(origins []string) *wsHub { + h.allowedOrigins = origins + return h +} + +// withLoginRateLimiters wires per-IP and per-username throttles into +// the login flow. Either may be nil to disable that dimension. +func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub { + h.loginIPLimit = byIP + h.loginUserLimit = byUser + return h +} + +// withTrustForwardedFor opts in to using the last entry of the +// X-Forwarded-For header as the client IP. Safe only when the server is +// behind a reverse proxy that you control. +func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub { + h.trustForwardedFor = trust + return h +} + +// checkOrigin decides whether to accept a WebSocket upgrade based on +// the request's Origin header. Same-origin (Origin host == Host) is +// always accepted; explicit allowlist entries cover the +// PWA-from-different-domain or local-dev cases. +// +// An empty Origin header is rejected: a legitimate browser always sends +// it on cross-origin requests, and same-origin requests have it too in +// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that +// omit Origin shouldn't be talking to the WS endpoint anyway. +func (h *wsHub) checkOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return false + } + u, err := url.Parse(origin) + if err != nil || u.Host == "" { + return false + } + // Same-origin (Origin host matches the Host the request came in on). + // Strip any port mismatch: if the server is behind a proxy, Host may + // not include a port while Origin does (or vice versa), so compare + // the hostname components. + originHost := u.Hostname() + reqHost := stripPort(r.Host) + if originHost == reqHost && originHost != "" { + return true + } + // Explicit allowlist entries — match Origin in full (scheme + host + // + port) so a customer can pin exactly one trusted PWA origin. + for _, allowed := range h.allowedOrigins { + if allowed == "" { + continue + } + if strings.EqualFold(origin, strings.TrimSpace(allowed)) { + return true + } + } + return false +} + +func stripPort(hostport string) string { + if h, _, err := net.SplitHostPort(hostport); err == nil { + return h + } + return hostport +} + +// clientIP returns the source IP of an HTTP request. By default uses +// r.RemoteAddr (the actual TCP peer); this is the only safe choice when +// the server is directly exposed to the internet, because a malicious +// client can put anything in X-Forwarded-For and would otherwise rotate +// it to evade per-IP rate limits. +// +// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is +// returned instead — appropriate only when running behind a reverse +// proxy that you control and that overwrites/appends the header (caddy, +// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY +// env var at startup. +func clientIP(r *http.Request, trustForwardedFor bool) string { + if trustForwardedFor { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the LAST entry — the closest hop, i.e. our own + // trusted proxy's view of the peer. Trusting the first + // entry would let a malicious client at the head of the + // chain set an arbitrary value. + parts := strings.Split(xff, ",") + ip := strings.TrimSpace(parts[len(parts)-1]) + if ip != "" { + return ip + } + } + } + return stripPort(r.RemoteAddr) +} + // stop unsubscribes from the device hub. Existing connections keep running // until they close on their own; we only block new event delivery. func (h *wsHub) stop() { @@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) { // ----- HTTP handler ------------------------------------------------------- func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) { + // Build a per-call upgrader so CheckOrigin closes over this hub's + // allowlist instead of a package-level mutable. + upgrader := baseUpgrader + upgrader.CheckOrigin = h.checkOrigin + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { h.log.Error("ws upgrade: %v", err) @@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) { send: make(chan wsMsg, wsSendBuffer), closed: make(chan struct{}), nonce: nonce, - addr: r.RemoteAddr, + addr: clientIP(r, h.trustForwardedFor), } h.register(client) defer h.unregister(client) diff --git a/server/go/web/ws_handlers.go b/server/go/web/ws_handlers.go index f1202fa..8d9f3f7 100644 --- a/server/go/web/ws_handlers.go +++ b/server/go/web/ws_handlers.go @@ -79,18 +79,27 @@ func (h *wsHub) requireAdmin(c *wsClient, raw []byte, replyCmd string) (ok bool) // ----- handlers ------------------------------------------------------------ func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) { + // Throttle the salt-probe surface together with login: an attacker + // who can poll get_salt freely would otherwise still learn nothing + // (the unknown-user fake salt mitigation handles that), but the + // endpoint is otherwise free CPU on the server. Limiting by IP is + // enough; we don't have a username yet to limit by user. + if !h.allowLoginByIP(c) { + // Stall the response so a tight-loop attacker doesn't flood the + // queue. Still return a well-formed salt to avoid making the + // limit detectable from the client side. + time.Sleep(250 * time.Millisecond) + } + var in struct { Username string `json:"username"` } _ = json.Unmarshal(raw, &in) - salt, ok := h.auth.GetSalt(in.Username) - // Do not leak which usernames exist: always return ok=true with a salt. - // For unknown users hand back the empty salt (matches admin convention) - // so the timing/shape of the response is uniform. - if !ok { - salt = "" - } + salt, _ := h.auth.GetSalt(in.Username) + // GetSalt now returns a deterministic fake salt (16 hex chars) for + // unknown users — same shape as a real salt — so an attacker can't + // tell from this response alone whether the username exists. c.queue(mustJSON(map[string]any{ "cmd": "salt", "ok": true, @@ -109,6 +118,20 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) { return } + // Rate-limit BEFORE doing the hash work, so a flood doesn't pin CPU. + // Two-dimensional throttle: per-IP catches scanners that try many + // usernames; per-username catches scanners that rotate IPs against a + // known account (admin). Either dimension tripping rejects the call + // with a uniform "credentials" error so the limit is not detectable. + if !h.allowLoginByIP(c) || !h.allowLoginByUsername(in.Username) { + h.log.Warn("ws login throttled: user=%s addr=%s", in.Username, c.addr) + // Burn the challenge so the attacker can't immediately replay. + c.nonce = "" + time.Sleep(500 * time.Millisecond) + c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"})) + return + } + // Bind the response to the challenge we issued at connect time so that // replays from a different connection can't reuse a captured response. if in.Nonce == "" || in.Nonce != c.nonce { @@ -120,9 +143,17 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) { if err != nil { // Burn the challenge on failure too — forces a new round on retry. c.nonce = "" + // Fixed delay on failure: makes online brute force impractical + // even within the rate-limit budget, and erases the timing + // difference between "wrong password" and "wrong nonce". + time.Sleep(250 * time.Millisecond) c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"})) return } + // Successful login: clear the per-IP/per-user budgets so a legitimate + // user who fat-fingered a few times doesn't stay throttled. + h.resetLoginThrottle(c, in.Username) + c.nonce = "" c.token = token c.role = role @@ -136,6 +167,31 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) { })) } +// allowLoginByIP / allowLoginByUsername return true when the call is +// within budget; nil limiter always returns true (effectively disabled). +func (h *wsHub) allowLoginByIP(c *wsClient) bool { + if h.loginIPLimit == nil || c == nil || c.addr == "" { + return true + } + return h.loginIPLimit.Allow(c.addr) +} + +func (h *wsHub) allowLoginByUsername(username string) bool { + if h.loginUserLimit == nil || username == "" { + return true + } + return h.loginUserLimit.Allow(username) +} + +func (h *wsHub) resetLoginThrottle(c *wsClient, username string) { + if h.loginIPLimit != nil && c != nil && c.addr != "" { + h.loginIPLimit.Reset(c.addr) + } + if h.loginUserLimit != nil && username != "" { + h.loginUserLimit.Reset(username) + } +} + // handleConnect kicks off a screen-sharing session for the browser. We send // COMMAND_SCREEN_SPY to the device's main TCP connection; the device then // opens a new sub-connection (TOKEN_BITMAPINFO) which the TCP side binds to diff --git a/server/go/wsauth/ratelimit.go b/server/go/wsauth/ratelimit.go new file mode 100644 index 0000000..fcaee45 --- /dev/null +++ b/server/go/wsauth/ratelimit.go @@ -0,0 +1,110 @@ +package wsauth + +import ( + "sync" + "time" +) + +// RateLimiter is a sliding-window per-key counter used to throttle login +// attempts. Two instances are typically created: one keyed by client IP +// (to slow distributed brute force), one keyed by username (to slow +// targeted attacks against a known account). +// +// Design notes: +// - Denied attempts are NOT recorded — the window slides naturally and a +// legitimate user who fat-fingers their password recovers as soon as +// the oldest attempt ages out, while a determined attacker is capped +// at `limit` successful attempts per `window` indefinitely. +// - Lazy cleanup: stale timestamps for a key are pruned on every Allow() +// call. Truly idle keys are GC'd by Sweep(), which callers should run +// periodically from a background goroutine. +// - Map size is bounded by the count of recently-active keys; for the +// web UI's expected load (a handful of users + occasional scanners), +// no extra GC pressure considerations needed. +type RateLimiter struct { + mu sync.Mutex + limit int + window time.Duration + entries map[string][]time.Time +} + +// NewRateLimiter returns a limiter that allows up to `limit` events per +// `window` duration per key. Zero or negative limit/window disables the +// limiter (Allow always returns true) — useful for tests / dev mode. +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + return &RateLimiter{ + limit: limit, + window: window, + entries: make(map[string][]time.Time), + } +} + +// Allow records an attempt for `key` if and only if the caller is under +// the per-key limit. Returns true when allowed, false when over limit. +// Empty key is treated as "no throttle" (returns true without recording) +// so the caller can fall through when the IP/username is unavailable. +func (r *RateLimiter) Allow(key string) bool { + if r == nil || r.limit <= 0 || r.window <= 0 || key == "" { + return true + } + r.mu.Lock() + defer r.mu.Unlock() + + cutoff := time.Now().Add(-r.window) + times := r.entries[key] + // Compact in place — keep only timestamps within the window. + keep := times[:0] + for _, t := range times { + if t.After(cutoff) { + keep = append(keep, t) + } + } + + if len(keep) >= r.limit { + // Update the map even when denying so the compacted slice doesn't + // keep stale entries forever. Don't append the new attempt: that + // would let attackers extend the window arbitrarily. + r.entries[key] = keep + return false + } + + r.entries[key] = append(keep, time.Now()) + return true +} + +// Reset clears state for a key. Call on successful login to give the user +// a fresh budget — otherwise a string of failed attempts followed by a +// correct one still leaves the budget partially consumed. +func (r *RateLimiter) Reset(key string) { + if r == nil || key == "" { + return + } + r.mu.Lock() + delete(r.entries, key) + r.mu.Unlock() +} + +// Sweep removes entries whose timestamps have all aged out of the window. +// Safe to call concurrently with Allow. Intended for periodic invocation +// from a background ticker (e.g. every window-length) to bound the map. +func (r *RateLimiter) Sweep() { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + cutoff := time.Now().Add(-r.window) + for key, times := range r.entries { + keep := times[:0] + for _, t := range times { + if t.After(cutoff) { + keep = append(keep, t) + } + } + if len(keep) == 0 { + delete(r.entries, key) + } else { + r.entries[key] = keep + } + } +} diff --git a/server/go/wsauth/wsauth.go b/server/go/wsauth/wsauth.go index 15b6b39..10be342 100644 --- a/server/go/wsauth/wsauth.go +++ b/server/go/wsauth/wsauth.go @@ -95,13 +95,26 @@ func (a *Authenticator) AddUser(u User) { a.mu.Unlock() } -// AddAdminFromPlainPassword is a convenience for the bootstrap admin: salt is -// empty (matching the C++ admin record), hash is SHA256(password). +// AddAdminFromPlainPassword is a convenience for the bootstrap admin. +// Unlike legacy convention, the admin record is given a real per-instance +// salt — exposing an empty salt for admin while everyone else has a real +// 16-hex one would let an unauthenticated probe distinguish admin from +// other accounts via /get_salt alone. The cost is a tiny break in +// users.json schema compat: admin is never persisted to users.json +// anyway (snapshotPersistableLocked excludes it), so this is in-memory +// only. func (a *Authenticator) AddAdminFromPlainPassword(username, plainPassword string) { + salt, err := NewSalt() + if err != nil { + // Fall back to deterministic salt derived from the password hash + // rather than empty — preserves the uniform-shape property even + // if crypto/rand briefly errors at startup. + salt = ComputeSHA256(plainPassword)[:saltBytes*2] + } a.AddUser(User{ Username: username, - PasswordHash: ComputeSHA256(plainPassword), - Salt: "", + PasswordHash: HashPassword(plainPassword, salt), + Salt: salt, Role: "admin", }) } @@ -200,17 +213,43 @@ func (a *Authenticator) ListUsers() []User { return out } -// GetSalt returns the per-user salt. If the user does not exist, returns ("", false). -// Note: the C++ admin uses an empty salt — that is still considered "found" -// and the empty string is returned with ok=true. +// GetSalt returns the per-user salt for an existing user, or a +// deterministic 16-hex pseudo-salt for an unknown user. The ok flag +// reports which case occurred, so callers can decide whether to update +// rate-limit / audit state — but the returned salt itself is shaped +// identically (16 hex chars) in both cases, defeating the user-existence +// probe an attacker would otherwise mount via /get_salt. +// +// The pseudo-salt is derived from a server-instance secret (the admin +// password hash, taken at first call) mixed with the username, so the +// same unknown user always sees the same fake salt across requests. +// Without this, an attacker could fingerprint the "fake-salt branch" +// by submitting the same username twice and watching for differences. func (a *Authenticator) GetSalt(username string) (string, bool) { a.mu.RLock() u, ok := a.users[username] a.mu.RUnlock() - if !ok { - return "", false + if ok { + return u.Salt, true } - return u.Salt, true + return a.fakeSalt(username), false +} + +// fakeSalt derives a deterministic 16-hex value for unknown usernames. +// The secret pepper is the bootstrap admin's password hash — present as +// long as the server has any admin, deterministic per deployment, never +// transmitted. Reveals nothing useful to an attacker even if reverse- +// engineered: the only thing they can do with it is reproduce the fake +// salt, which they already see in the response. +func (a *Authenticator) fakeSalt(username string) string { + a.mu.RLock() + pepper := "" + if admin, ok := a.users["admin"]; ok { + pepper = admin.PasswordHash + } + a.mu.RUnlock() + digest := ComputeSHA256("yama-fake-salt|" + pepper + "|" + username) + return digest[:saltBytes*2] } // VerifyLogin checks a challenge-response login. The browser sends diff --git a/server/go/wsauth/wsauth_test.go b/server/go/wsauth/wsauth_test.go index 6e0b7cd..78c1f24 100644 --- a/server/go/wsauth/wsauth_test.go +++ b/server/go/wsauth/wsauth_test.go @@ -14,19 +14,31 @@ func TestSHA256Vector(t *testing.T) { } } -func TestLoginRoundTripAdminEmptySalt(t *testing.T) { +// adminLoginResponse helps tests compute the right login response for an +// admin account that now uses a real per-instance salt. +func adminLoginResponse(t *testing.T, a *Authenticator, username, password, nonce string) string { + t.Helper() + salt, ok := a.GetSalt(username) + if !ok { + t.Fatalf("admin %s not registered", username) + } + return ComputeSHA256(HashPassword(password, salt) + nonce) +} + +func TestLoginRoundTripAdmin(t *testing.T) { a := New() a.AddAdminFromPlainPassword("admin", "hunter2") salt, ok := a.GetSalt("admin") - if !ok || salt != "" { - t.Fatalf("admin salt: ok=%v salt=%q", ok, salt) + if !ok { + t.Fatal("admin should be found") + } + if len(salt) != 2*saltBytes { + t.Fatalf("admin salt should be a real 16-hex value, got %q (len=%d)", salt, len(salt)) } - // Simulate the browser: nonce = "abc123", response = SHA256(passwordHash + nonce) nonce := "abc123" - passwordHash := ComputeSHA256("hunter2") - response := ComputeSHA256(passwordHash + nonce) + response := adminLoginResponse(t, a, "admin", "hunter2", nonce) token, role, err := a.VerifyLogin("admin", response, nonce) if err != nil { @@ -71,6 +83,33 @@ func TestLoginRoundTripViewerWithSalt(t *testing.T) { } } +// TestGetSaltUnknownUserShape verifies the salt-probe mitigation: an +// unknown user must get back a value that's shape-identical to a real +// salt, so an attacker can't tell from /get_salt alone whether a +// username exists. +func TestGetSaltUnknownUserShape(t *testing.T) { + a := New() + a.AddAdminFromPlainPassword("admin", "pw") + + fake, ok := a.GetSalt("nobody") + if ok { + t.Fatal("ok should be false for unknown user") + } + if len(fake) != 2*saltBytes { + t.Fatalf("fake salt should be %d hex chars; got %q (len=%d)", 2*saltBytes, fake, len(fake)) + } + // Determinism: repeated probes for the same username get the same fake. + fake2, _ := a.GetSalt("nobody") + if fake != fake2 { + t.Fatalf("fake salt should be deterministic for repeated probes; got %q vs %q", fake, fake2) + } + // Different usernames get different fake salts. + other, _ := a.GetSalt("ghost") + if fake == other { + t.Fatalf("fake salts should differ across usernames; both = %q", fake) + } +} + func TestLoginRejectsWrongResponse(t *testing.T) { a := New() a.AddAdminFromPlainPassword("admin", "x") @@ -89,7 +128,7 @@ func TestTokenExpiry(t *testing.T) { a.SetTokenExpire(50 * time.Millisecond) a.AddAdminFromPlainPassword("admin", "x") nonce, _ := NewNonce() - response := ComputeSHA256(ComputeSHA256("x") + nonce) + response := adminLoginResponse(t, a, "admin", "x", nonce) token, _, err := a.VerifyLogin("admin", response, nonce) if err != nil { t.Fatal(err) @@ -107,10 +146,57 @@ func TestRevoke(t *testing.T) { a := New() a.AddAdminFromPlainPassword("admin", "x") nonce, _ := NewNonce() - response := ComputeSHA256(ComputeSHA256("x") + nonce) + response := adminLoginResponse(t, a, "admin", "x", nonce) token, _, _ := a.VerifyLogin("admin", response, nonce) a.RevokeToken(token) if _, err := a.ValidateToken(token); err == nil { t.Fatal("revoked token should not validate") } } + +func TestRateLimiterAllowsBurstThenBlocks(t *testing.T) { + r := NewRateLimiter(3, time.Minute) + for i := 0; i < 3; i++ { + if !r.Allow("ip-a") { + t.Fatalf("attempt %d should be allowed", i+1) + } + } + if r.Allow("ip-a") { + t.Fatal("4th attempt should be denied") + } + // Different key has independent budget. + if !r.Allow("ip-b") { + t.Fatal("different key should still be allowed") + } +} + +func TestRateLimiterReset(t *testing.T) { + r := NewRateLimiter(2, time.Minute) + r.Allow("k") + r.Allow("k") + if r.Allow("k") { + t.Fatal("3rd should be denied") + } + r.Reset("k") + if !r.Allow("k") { + t.Fatal("after Reset, should be allowed again") + } +} + +func TestRateLimiterDisabledWhenZeroLimit(t *testing.T) { + r := NewRateLimiter(0, time.Minute) + for i := 0; i < 100; i++ { + if !r.Allow("k") { + t.Fatalf("limit=0 should never deny, denied at i=%d", i) + } + } +} + +func TestRateLimiterNilSafe(t *testing.T) { + var r *RateLimiter + if !r.Allow("anything") { + t.Fatal("nil limiter should allow") + } + r.Reset("anything") + r.Sweep() +}