package web

import (
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
	"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
	"github.com/yuanyuanxiang/SimpleRemoter/server/go/wsauth"
)

// ----- WS framing knobs ---------------------------------------------------

const (
	wsWriteWait  = 10 * time.Second // single-frame write deadline
	wsReadLimit  = 1 << 20          // refuse incoming frames over 1 MB
	wsSendBuffer = 64               // outbound queue depth per client
)

// baseUpgrader carries the buffer-size config shared by all WS upgrades.
// CheckOrigin is deliberately left unset here: wsHub.serve copies this
// value for every upgrade and installs the hub's own checkOrigin on the
// copy, so the allowlist is closed over per instance instead of being a
// global mutable.
var baseUpgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
}

// ----- per-connection client state ----------------------------------------

// wsMsg is one queued WebSocket frame. binary toggles between
// websocket.TextMessage (JSON signaling) and websocket.BinaryMessage
// (screen frames).
type wsMsg struct {
	binary bool
	data   []byte
}

// wsClient is the per-connection state for one upgraded WebSocket.
type wsClient struct {
	conn   *websocket.Conn
	send   chan wsMsg    // outbound queue, drained by the hub's writeLoop
	closed chan struct{} // closed exactly once by close() to stop both loops
	once   sync.Once     // guards close() so it is safe to call repeatedly

	// Mutated under wsHub.mu (or only by the read loop owning this client).
	nonce        string // outstanding challenge — cleared after a successful login
	token        string // set once authenticated
	role         string // mirrors session role after login
	addr         string // client address for logs
	watching     string // device ID this browser is currently streaming, "" when on the list
	termWatching string // device ID for an open web terminal session, "" otherwise
}

// queue writes a JSON text frame onto the send buffer. Drops silently if the
// buffer is full so a stuck reader can't back-pressure the broadcast path.
func (c *wsClient) queue(payload []byte) {
	c.enqueue(wsMsg{binary: false, data: payload})
}

// queueBinary writes a binary WS frame. Used for screen-stream packets.
func (c *wsClient) queueBinary(payload []byte) { c.enqueue(wsMsg{binary: true, data: payload}) } func (c *wsClient) enqueue(m wsMsg) { select { case c.send <- m: case <-c.closed: default: // queue full — drop (acceptable for video; signaling clients are // typically not behind enough for the small text buffer to fill). } } // close signals both loops to exit. Safe to call multiple times. func (c *wsClient) close() { c.once.Do(func() { close(c.closed) _ = c.conn.Close() }) } // ----- ws hub: registry of all connected browsers ------------------------- type wsHub struct { auth *wsauth.Authenticator devices *hub.Hub log *logger.Logger mu sync.RWMutex clients map[*wsClient]struct{} unsub func() // Hardening knobs wired from server.Config. Nil/empty values mean // "no extra restriction" — useful for local dev where the hub is // exercised without server.Server wiring up the env-driven defaults. allowedOrigins []string // empty → only same-origin upgrades accepted loginIPLimit *wsauth.RateLimiter loginUserLimit *wsauth.RateLimiter trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy) } func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub { h := &wsHub{ auth: auth, devices: devices, log: log, clients: make(map[*wsClient]struct{}), } h.unsub = devices.Subscribe(h) return h } // withOriginAllowlist returns h after installing the explicit Origin // allowlist. Chainable. Pass empty/nil to keep "same-origin only". func (h *wsHub) withOriginAllowlist(origins []string) *wsHub { h.allowedOrigins = origins return h } // withLoginRateLimiters wires per-IP and per-username throttles into // the login flow. Either may be nil to disable that dimension. func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub { h.loginIPLimit = byIP h.loginUserLimit = byUser return h } // withTrustForwardedFor opts in to using the last entry of the // X-Forwarded-For header as the client IP. 
Safe only when the server is // behind a reverse proxy that you control. func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub { h.trustForwardedFor = trust return h } // checkOrigin decides whether to accept a WebSocket upgrade based on // the request's Origin header. Same-origin (Origin host == Host) is // always accepted; explicit allowlist entries cover the // PWA-from-different-domain or local-dev cases. // // An empty Origin header is rejected: a legitimate browser always sends // it on cross-origin requests, and same-origin requests have it too in // modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that // omit Origin shouldn't be talking to the WS endpoint anyway. func (h *wsHub) checkOrigin(r *http.Request) bool { origin := r.Header.Get("Origin") if origin == "" { return false } u, err := url.Parse(origin) if err != nil || u.Host == "" { return false } // Same-origin (Origin host matches the Host the request came in on). // Strip any port mismatch: if the server is behind a proxy, Host may // not include a port while Origin does (or vice versa), so compare // the hostname components. originHost := u.Hostname() reqHost := stripPort(r.Host) if originHost == reqHost && originHost != "" { return true } // Explicit allowlist entries — match Origin in full (scheme + host // + port) so a customer can pin exactly one trusted PWA origin. for _, allowed := range h.allowedOrigins { if allowed == "" { continue } if strings.EqualFold(origin, strings.TrimSpace(allowed)) { return true } } return false } func stripPort(hostport string) string { if h, _, err := net.SplitHostPort(hostport); err == nil { return h } return hostport } // clientIP returns the source IP of an HTTP request. By default uses // r.RemoteAddr (the actual TCP peer); this is the only safe choice when // the server is directly exposed to the internet, because a malicious // client can put anything in X-Forwarded-For and would otherwise rotate // it to evade per-IP rate limits. 
// // When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is // returned instead — appropriate only when running behind a reverse // proxy that you control and that overwrites/appends the header (caddy, // nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY // env var at startup. func clientIP(r *http.Request, trustForwardedFor bool) string { if trustForwardedFor { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the LAST entry — the closest hop, i.e. our own // trusted proxy's view of the peer. Trusting the first // entry would let a malicious client at the head of the // chain set an arbitrary value. parts := strings.Split(xff, ",") ip := strings.TrimSpace(parts[len(parts)-1]) if ip != "" { return ip } } } return stripPort(r.RemoteAddr) } // stop unsubscribes from the device hub. Existing connections keep running // until they close on their own; we only block new event delivery. func (h *wsHub) stop() { if h.unsub != nil { h.unsub() h.unsub = nil } } // hub.EventHandler — invoked from hub.Register / hub.Unregister. func (h *wsHub) OnDeviceOnline(_ hub.DeviceInfo) { h.broadcastAuthenticated(`{"cmd":"devices_changed"}`) } func (h *wsHub) OnDeviceOffline(id string) { // Tell everyone authenticated to refresh their device list (covers // users sitting on the devices-page). h.broadcastAuthenticated(`{"cmd":"devices_changed"}`) // Also tell any browser actively viewing this device's screen — the // devices_changed handler only refreshes the list page; viewers on the // screen page would otherwise see a frozen frame with "Connected" status // indefinitely. The browser handles this by showing "Device offline" and // bouncing back to the device list after 2 s. 
if id == "" { return } msg := mustJSON(map[string]any{ "cmd": "device_offline", "id": id, }) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.watching == id && c.token != "" { c.queue(msg) } } } // OnCursorChange relays the remote cursor index to every viewer of this // device. The browser maps the index to a CSS cursor (desktop) or overlay // SVG variant (touch). Hub already de-duplicates so we always have a real // transition here. func (h *wsHub) OnCursorChange(deviceID string, index byte) { msg := mustJSON(map[string]any{ "cmd": "cursor", "index": index, }) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.watching == deviceID && c.token != "" { c.queue(msg) } } } // OnResolutionChange notifies viewers so the browser-side WebCodecs decoder // can be (re)initialized with the right frame size. Without this, incoming // binary frames after connect_result are decoded by an uninitialized // VideoDecoder and the page stays on "Waiting for video...". func (h *wsHub) OnResolutionChange(deviceID string, width, height int) { msg := mustJSON(map[string]any{ "cmd": "resolution_changed", "id": deviceID, "width": width, "height": height, }) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.watching == deviceID && c.token != "" { c.queue(msg) } } } // OnScreenFrame ships a screen packet to every browser currently watching // this device. We hold the read lock for the whole iteration, but each // queueBinary is non-blocking (drops on backpressure) so a slow viewer // cannot stall the fast ones. func (h *wsHub) OnScreenFrame(deviceID string, packet []byte, _ bool) { h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.watching == deviceID && c.token != "" { c.queueBinary(packet) } } } // OnTerminalReady notifies the requesting browser that its term_open // handshake completed. mode is "pty" or "legacy" — xterm.js disables the // resize callback in legacy mode (no PTY behind the cmd pipe). 
func (h *wsHub) OnTerminalReady(deviceID string, isPTY bool) { mode := "legacy" if isPTY { mode = "pty" } msg := mustJSON(map[string]any{ "cmd": "term_ready", "id": deviceID, "mode": mode, }) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.termWatching == deviceID && c.token != "" { c.queue(msg) } } } // OnTerminalData ships one chunk of raw shell output (already wrapped in // the "TRM1" magic header) over the binary WS frame. Single-viewer is // enforced upstream so at most one client matches per device. func (h *wsHub) OnTerminalData(deviceID string, packet []byte) { h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.termWatching == deviceID && c.token != "" { c.queueBinary(packet) } } } // OnTerminalClosed fires when the device's shell exits or the sub-conn // drops. The browser closes its xterm panel. We also clear termWatching // so a subsequent term_open from the same browser isn't rejected as // "already open" by stale state. func (h *wsHub) OnTerminalClosed(deviceID string, reason string) { msg := mustJSON(map[string]any{ "cmd": "term_closed", "ok": true, "reason": reason, }) h.mu.Lock() defer h.mu.Unlock() for c := range h.clients { if c.termWatching == deviceID && c.token != "" { c.termWatching = "" c.queue(msg) } } } // OnDeviceUpdate forwards heartbeat-derived liveness data so the device-list // rows can refresh RTT and active-window labels without re-fetching. 
func (h *wsHub) OnDeviceUpdate(id string, rtt int, activeWindow string) { payload := mustJSON(map[string]any{ "cmd": "device_update", "id": id, "rtt": rtt, "activeWindow": activeWindow, }) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.token != "" { c.queue(payload) } } } func (h *wsHub) broadcastAuthenticated(msg string) { payload := []byte(msg) h.mu.RLock() defer h.mu.RUnlock() for c := range h.clients { if c.token != "" { c.queue(payload) } } } func (h *wsHub) register(c *wsClient) { h.mu.Lock() h.clients[c] = struct{}{} h.mu.Unlock() } func (h *wsHub) unregister(c *wsClient) { h.mu.Lock() delete(h.clients, c) h.mu.Unlock() // If this client was the last viewer of a device, tear down the screen // session so the device stops encoding. Done OUTSIDE the lock so the // hub's mutators can take their own locks without risk of recursion. if c.watching != "" && h.countWatchers(c.watching) == 0 { h.devices.CloseScreen(c.watching) } // Terminal sessions are single-viewer by design, so any open session // belongs to this client. Tear it down so the next viewer doesn't // hit ErrTerminalBusy from an abandoned session. if c.termWatching != "" { h.devices.CloseTerminalSession(c.termWatching) c.termWatching = "" } // Do NOT revoke the token: tokens are session-scoped, not WS-scoped. // Frontend may close+reopen the WS at any time (visibilitychange handler, // brief network blip, reload) and must be able to resume with the same // cached token. The token expires on its own TTL. c.close() } // ----- HTTP handler ------------------------------------------------------- func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) { // Build a per-call upgrader so CheckOrigin closes over this hub's // allowlist instead of a package-level mutable. 
upgrader := baseUpgrader upgrader.CheckOrigin = h.checkOrigin conn, err := upgrader.Upgrade(w, r, nil) if err != nil { h.log.Error("ws upgrade: %v", err) return } conn.SetReadLimit(wsReadLimit) nonce, err := wsauth.NewNonce() if err != nil { h.log.Error("nonce gen: %v", err) _ = conn.Close() return } client := &wsClient{ conn: conn, send: make(chan wsMsg, wsSendBuffer), closed: make(chan struct{}), nonce: nonce, addr: clientIP(r, h.trustForwardedFor), } h.register(client) defer h.unregister(client) go h.writeLoop(client) // Greet with a challenge nonce so the browser can compute the login response. client.queue([]byte(`{"cmd":"challenge","nonce":"` + nonce + `"}`)) h.readLoop(client) } // writeLoop drains the send queue. Exits when the channel is closed or a // write fails. Closing the underlying connection is the read loop's job. func (h *wsHub) writeLoop(c *wsClient) { for { select { case msg := <-c.send: msgType := websocket.TextMessage if msg.binary { msgType = websocket.BinaryMessage } _ = c.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)) if err := c.conn.WriteMessage(msgType, msg.data); err != nil { c.close() return } case <-c.closed: return } } } // readLoop dispatches incoming messages. Exits on read error (peer closed, // timeout, malformed frame, etc.), which then triggers unregister cleanup. func (h *wsHub) readLoop(c *wsClient) { for { _, raw, err := c.conn.ReadMessage() if err != nil { return } var env struct { Cmd string `json:"cmd"` } if err := json.Unmarshal(raw, &env); err != nil { continue // ignore garbage frames } h.dispatch(c, env.Cmd, raw) } }