Security(Go): Login rate limit + WS origin allowlist + REST bearer auth
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
|
||||
@@ -19,12 +20,38 @@ import (
|
||||
// static assets, the PWA manifest, and JSON APIs backed by the device hub.
|
||||
// WebSocket signaling and screen streaming will be wired up in later phases.
|
||||
type Server struct {
|
||||
port int
|
||||
log *logger.Logger
|
||||
srv *http.Server
|
||||
hub *hub.Hub
|
||||
auth *wsauth.Authenticator
|
||||
ws *wsHub
|
||||
port int
|
||||
log *logger.Logger
|
||||
srv *http.Server
|
||||
hub *hub.Hub
|
||||
auth *wsauth.Authenticator
|
||||
ws *wsHub
|
||||
allowedOrigins []string // for WS Origin allowlist; empty = same-origin only
|
||||
loginIPLimit *wsauth.RateLimiter
|
||||
loginUserLimit *wsauth.RateLimiter
|
||||
trustForwardedFor bool // honor X-Forwarded-For (behind trusted proxy only)
|
||||
}
|
||||
|
||||
// Config tunes the server's exposed-on-public-HTTPS hardening knobs.
|
||||
// All fields are optional; zero values pick reasonable defaults.
|
||||
type Config struct {
|
||||
// AllowedOrigins is the comma-separated list of Origin header values
|
||||
// the WebSocket upgrade will accept in addition to same-origin
|
||||
// requests. Empty (default) → only same-origin upgrades are allowed,
|
||||
// which is correct when the web UI and the WS endpoint are served
|
||||
// from the same host.
|
||||
AllowedOrigins []string
|
||||
// LoginIPLimit / LoginUserLimit throttle the get_salt + login flow
|
||||
// per source IP and per username respectively. Pass nil to disable
|
||||
// either dimension (e.g. dev mode).
|
||||
LoginIPLimit *wsauth.RateLimiter
|
||||
LoginUserLimit *wsauth.RateLimiter
|
||||
// TrustForwardedFor switches client-IP extraction from RemoteAddr
|
||||
// (default) to the last entry of X-Forwarded-For. Set true only when
|
||||
// running behind a reverse proxy that you control; on direct
|
||||
// exposure the header is client-controlled and would let attackers
|
||||
// evade per-IP rate limits.
|
||||
TrustForwardedFor bool
|
||||
}
|
||||
|
||||
// New creates an HTTP server bound to the given port. port=0 disables the server.
|
||||
@@ -34,6 +61,16 @@ func New(port int, log *logger.Logger, h *hub.Hub, auth *wsauth.Authenticator) *
|
||||
return &Server{port: port, log: log, hub: h, auth: auth}
|
||||
}
|
||||
|
||||
// WithConfig applies hardening configuration. Returns the receiver for
|
||||
// chainable setup. Safe to call before Start; ignored thereafter.
|
||||
func (s *Server) WithConfig(cfg Config) *Server {
|
||||
s.allowedOrigins = cfg.AllowedOrigins
|
||||
s.loginIPLimit = cfg.LoginIPLimit
|
||||
s.loginUserLimit = cfg.LoginUserLimit
|
||||
s.trustForwardedFor = cfg.TrustForwardedFor
|
||||
return s
|
||||
}
|
||||
|
||||
// Start launches the server in a goroutine and returns immediately.
|
||||
// If port is 0, returns nil without starting anything.
|
||||
func (s *Server) Start() error {
|
||||
@@ -42,13 +79,16 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.ws = newWSHub(s.auth, s.hub, s.log)
|
||||
s.ws = newWSHub(s.auth, s.hub, s.log).
|
||||
withOriginAllowlist(s.allowedOrigins).
|
||||
withLoginRateLimiters(s.loginIPLimit, s.loginUserLimit).
|
||||
withTrustForwardedFor(s.trustForwardedFor)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", s.handleIndex)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/manifest.json", s.handleManifest)
|
||||
mux.HandleFunc("/api/devices", s.handleDevices)
|
||||
mux.HandleFunc("/api/devices", s.requireBearer(s.handleDevices))
|
||||
mux.HandleFunc("/ws", s.ws.serve)
|
||||
mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8"))
|
||||
mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8"))
|
||||
@@ -106,7 +146,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleDevices returns a JSON snapshot of currently-online devices. Empty
|
||||
// array (not null) when no clients are connected — matches what the front-end
|
||||
// will eventually expect.
|
||||
// will eventually expect. Auth-gated via requireBearer.
|
||||
func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
|
||||
devices := s.hub.ListDevices()
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
@@ -116,6 +156,39 @@ func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// requireBearer wraps a handler with `Authorization: Bearer <token>` auth
|
||||
// against the same session-token store the WebSocket uses. Returns 401 on
|
||||
// missing / invalid / expired tokens. Used to gate REST endpoints that
|
||||
// previously fell through with no auth (notably /api/devices, which
|
||||
// otherwise leaks the full online-device list to anyone on the internet).
|
||||
func (s *Server) requireBearer(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
const prefix = "Bearer "
|
||||
hdr := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(hdr, prefix) {
|
||||
s.unauthorized(w)
|
||||
return
|
||||
}
|
||||
token := strings.TrimSpace(hdr[len(prefix):])
|
||||
if token == "" {
|
||||
s.unauthorized(w)
|
||||
return
|
||||
}
|
||||
if _, err := s.auth.ValidateToken(token); err != nil {
|
||||
s.unauthorized(w)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) unauthorized(w http.ResponseWriter) {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="yama"`)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
}
|
||||
|
||||
// PWA manifest. Referenced by <link rel="manifest"> in index.html.
|
||||
// Static JSON, no template needed.
|
||||
const manifestJSON = `{
|
||||
|
||||
@@ -2,7 +2,10 @@ package web
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -20,13 +23,12 @@ const (
|
||||
wsSendBuffer = 64 // outbound queue depth per client
|
||||
)
|
||||
|
||||
// upgrader allows any origin — this service is meant to be tunneled through
|
||||
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
|
||||
// Adjust CheckOrigin once we have a deployment story.
|
||||
var upgrader = websocket.Upgrader{
|
||||
// baseUpgrader carries the buffer-size config shared by all WS upgrades.
|
||||
// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is
|
||||
// closed-over per instance instead of being a global mutable.
|
||||
var baseUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// ----- per-connection client state ----------------------------------------
|
||||
@@ -94,6 +96,14 @@ type wsHub struct {
|
||||
clients map[*wsClient]struct{}
|
||||
|
||||
unsub func()
|
||||
|
||||
// Hardening knobs wired from server.Config. Nil/empty values mean
|
||||
// "no extra restriction" — useful for local dev where the hub is
|
||||
// exercised without server.Server wiring up the env-driven defaults.
|
||||
allowedOrigins []string // empty → only same-origin upgrades accepted
|
||||
loginIPLimit *wsauth.RateLimiter
|
||||
loginUserLimit *wsauth.RateLimiter
|
||||
trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy)
|
||||
}
|
||||
|
||||
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
|
||||
@@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger)
|
||||
return h
|
||||
}
|
||||
|
||||
// withOriginAllowlist returns h after installing the explicit Origin
|
||||
// allowlist. Chainable. Pass empty/nil to keep "same-origin only".
|
||||
func (h *wsHub) withOriginAllowlist(origins []string) *wsHub {
|
||||
h.allowedOrigins = origins
|
||||
return h
|
||||
}
|
||||
|
||||
// withLoginRateLimiters wires per-IP and per-username throttles into
|
||||
// the login flow. Either may be nil to disable that dimension.
|
||||
func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub {
|
||||
h.loginIPLimit = byIP
|
||||
h.loginUserLimit = byUser
|
||||
return h
|
||||
}
|
||||
|
||||
// withTrustForwardedFor opts in to using the last entry of the
|
||||
// X-Forwarded-For header as the client IP. Safe only when the server is
|
||||
// behind a reverse proxy that you control.
|
||||
func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub {
|
||||
h.trustForwardedFor = trust
|
||||
return h
|
||||
}
|
||||
|
||||
// checkOrigin decides whether to accept a WebSocket upgrade based on
|
||||
// the request's Origin header. Same-origin (Origin host == Host) is
|
||||
// always accepted; explicit allowlist entries cover the
|
||||
// PWA-from-different-domain or local-dev cases.
|
||||
//
|
||||
// An empty Origin header is rejected: a legitimate browser always sends
|
||||
// it on cross-origin requests, and same-origin requests have it too in
|
||||
// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that
|
||||
// omit Origin shouldn't be talking to the WS endpoint anyway.
|
||||
func (h *wsHub) checkOrigin(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
// Same-origin (Origin host matches the Host the request came in on).
|
||||
// Strip any port mismatch: if the server is behind a proxy, Host may
|
||||
// not include a port while Origin does (or vice versa), so compare
|
||||
// the hostname components.
|
||||
originHost := u.Hostname()
|
||||
reqHost := stripPort(r.Host)
|
||||
if originHost == reqHost && originHost != "" {
|
||||
return true
|
||||
}
|
||||
// Explicit allowlist entries — match Origin in full (scheme + host
|
||||
// + port) so a customer can pin exactly one trusted PWA origin.
|
||||
for _, allowed := range h.allowedOrigins {
|
||||
if allowed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(origin, strings.TrimSpace(allowed)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stripPort(hostport string) string {
|
||||
if h, _, err := net.SplitHostPort(hostport); err == nil {
|
||||
return h
|
||||
}
|
||||
return hostport
|
||||
}
|
||||
|
||||
// clientIP returns the source IP of an HTTP request. By default uses
|
||||
// r.RemoteAddr (the actual TCP peer); this is the only safe choice when
|
||||
// the server is directly exposed to the internet, because a malicious
|
||||
// client can put anything in X-Forwarded-For and would otherwise rotate
|
||||
// it to evade per-IP rate limits.
|
||||
//
|
||||
// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is
|
||||
// returned instead — appropriate only when running behind a reverse
|
||||
// proxy that you control and that overwrites/appends the header (caddy,
|
||||
// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY
|
||||
// env var at startup.
|
||||
func clientIP(r *http.Request, trustForwardedFor bool) string {
|
||||
if trustForwardedFor {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the LAST entry — the closest hop, i.e. our own
|
||||
// trusted proxy's view of the peer. Trusting the first
|
||||
// entry would let a malicious client at the head of the
|
||||
// chain set an arbitrary value.
|
||||
parts := strings.Split(xff, ",")
|
||||
ip := strings.TrimSpace(parts[len(parts)-1])
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return stripPort(r.RemoteAddr)
|
||||
}
|
||||
|
||||
// stop unsubscribes from the device hub. Existing connections keep running
|
||||
// until they close on their own; we only block new event delivery.
|
||||
func (h *wsHub) stop() {
|
||||
@@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) {
|
||||
// ----- HTTP handler -------------------------------------------------------
|
||||
|
||||
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
||||
// Build a per-call upgrader so CheckOrigin closes over this hub's
|
||||
// allowlist instead of a package-level mutable.
|
||||
upgrader := baseUpgrader
|
||||
upgrader.CheckOrigin = h.checkOrigin
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.log.Error("ws upgrade: %v", err)
|
||||
@@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
||||
send: make(chan wsMsg, wsSendBuffer),
|
||||
closed: make(chan struct{}),
|
||||
nonce: nonce,
|
||||
addr: r.RemoteAddr,
|
||||
addr: clientIP(r, h.trustForwardedFor),
|
||||
}
|
||||
h.register(client)
|
||||
defer h.unregister(client)
|
||||
|
||||
@@ -79,18 +79,27 @@ func (h *wsHub) requireAdmin(c *wsClient, raw []byte, replyCmd string) (ok bool)
|
||||
// ----- handlers ------------------------------------------------------------
|
||||
|
||||
func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) {
|
||||
// Throttle the salt-probe surface together with login: an attacker
|
||||
// who can poll get_salt freely would otherwise still learn nothing
|
||||
// (the unknown-user fake salt mitigation handles that), but the
|
||||
// endpoint is otherwise free CPU on the server. Limiting by IP is
|
||||
// enough; we don't have a username yet to limit by user.
|
||||
if !h.allowLoginByIP(c) {
|
||||
// Stall the response so a tight-loop attacker doesn't flood the
|
||||
// queue. Still return a well-formed salt to avoid making the
|
||||
// limit detectable from the client side.
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
|
||||
var in struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
_ = json.Unmarshal(raw, &in)
|
||||
|
||||
salt, ok := h.auth.GetSalt(in.Username)
|
||||
// Do not leak which usernames exist: always return ok=true with a salt.
|
||||
// For unknown users hand back the empty salt (matches admin convention)
|
||||
// so the timing/shape of the response is uniform.
|
||||
if !ok {
|
||||
salt = ""
|
||||
}
|
||||
salt, _ := h.auth.GetSalt(in.Username)
|
||||
// GetSalt now returns a deterministic fake salt (16 hex chars) for
|
||||
// unknown users — same shape as a real salt — so an attacker can't
|
||||
// tell from this response alone whether the username exists.
|
||||
c.queue(mustJSON(map[string]any{
|
||||
"cmd": "salt",
|
||||
"ok": true,
|
||||
@@ -109,6 +118,20 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// Rate-limit BEFORE doing the hash work, so a flood doesn't pin CPU.
|
||||
// Two-dimensional throttle: per-IP catches scanners that try many
|
||||
// usernames; per-username catches scanners that rotate IPs against a
|
||||
// known account (admin). Either dimension tripping rejects the call
|
||||
// with a uniform "credentials" error so the limit is not detectable.
|
||||
if !h.allowLoginByIP(c) || !h.allowLoginByUsername(in.Username) {
|
||||
h.log.Warn("ws login throttled: user=%s addr=%s", in.Username, c.addr)
|
||||
// Burn the challenge so the attacker can't immediately replay.
|
||||
c.nonce = ""
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
|
||||
return
|
||||
}
|
||||
|
||||
// Bind the response to the challenge we issued at connect time so that
|
||||
// replays from a different connection can't reuse a captured response.
|
||||
if in.Nonce == "" || in.Nonce != c.nonce {
|
||||
@@ -120,9 +143,17 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
||||
if err != nil {
|
||||
// Burn the challenge on failure too — forces a new round on retry.
|
||||
c.nonce = ""
|
||||
// Fixed delay on failure: makes online brute force impractical
|
||||
// even within the rate-limit budget, and erases the timing
|
||||
// difference between "wrong password" and "wrong nonce".
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
|
||||
return
|
||||
}
|
||||
// Successful login: clear the per-IP/per-user budgets so a legitimate
|
||||
// user who fat-fingered a few times doesn't stay throttled.
|
||||
h.resetLoginThrottle(c, in.Username)
|
||||
|
||||
c.nonce = ""
|
||||
c.token = token
|
||||
c.role = role
|
||||
@@ -136,6 +167,31 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
||||
}))
|
||||
}
|
||||
|
||||
// allowLoginByIP / allowLoginByUsername return true when the call is
|
||||
// within budget; nil limiter always returns true (effectively disabled).
|
||||
func (h *wsHub) allowLoginByIP(c *wsClient) bool {
|
||||
if h.loginIPLimit == nil || c == nil || c.addr == "" {
|
||||
return true
|
||||
}
|
||||
return h.loginIPLimit.Allow(c.addr)
|
||||
}
|
||||
|
||||
func (h *wsHub) allowLoginByUsername(username string) bool {
|
||||
if h.loginUserLimit == nil || username == "" {
|
||||
return true
|
||||
}
|
||||
return h.loginUserLimit.Allow(username)
|
||||
}
|
||||
|
||||
func (h *wsHub) resetLoginThrottle(c *wsClient, username string) {
|
||||
if h.loginIPLimit != nil && c != nil && c.addr != "" {
|
||||
h.loginIPLimit.Reset(c.addr)
|
||||
}
|
||||
if h.loginUserLimit != nil && username != "" {
|
||||
h.loginUserLimit.Reset(username)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnect kicks off a screen-sharing session for the browser. We send
|
||||
// COMMAND_SCREEN_SPY to the device's main TCP connection; the device then
|
||||
// opens a new sub-connection (TOKEN_BITMAPINFO) which the TCP side binds to
|
||||
|
||||
Reference in New Issue
Block a user