Improve Go Server to support remote desktop and command control #1
@@ -56,6 +56,7 @@ server/go/
|
|||||||
Web 应用能力 (Phase 3-7):
|
Web 应用能力 (Phase 3-7):
|
||||||
|
|
||||||
- **Web 鉴权**: challenge-response 登录 + 不透明 token,与 users.json schema 互通
|
- **Web 鉴权**: challenge-response 登录 + 不透明 token,与 users.json schema 互通
|
||||||
|
- **登录加固**: 双维度速率限制(10 次/分钟·IP + 5 次/15 分钟·用户名)+ 失败固定延迟,防口令枚举;`/get_salt` 用确定性假盐响应未知用户,杜绝用户名探测;WebSocket Origin 同源校验 + 显式白名单;`/api/devices` Bearer Token 鉴权
|
||||||
- **设备列表与监控**: 在线设备 / RTT / 活动窗口 / 分辨率 实时下发
|
- **设备列表与监控**: 在线设备 / RTT / 活动窗口 / 分辨率 实时下发
|
||||||
- **Web 远程桌面**: 浏览器 WebCodecs 解码 H.264,二进制 WS 帧低延迟中继;late-join 自动重发最近 IDR;优雅 BYE 关闭防止客户端无意义重连
|
- **Web 远程桌面**: 浏览器 WebCodecs 解码 H.264,二进制 WS 帧低延迟中继;late-join 自动重发最近 IDR;优雅 BYE 关闭防止客户端无意义重连
|
||||||
- **鼠标 / 键盘输入**: Win32 消息映射 (`WM_*` / `VK_*` / `MK_*`),MSG64 48 字节布局直传客户端
|
- **鼠标 / 键盘输入**: Win32 消息映射 (`WM_*` / `VK_*` / `MK_*`),MSG64 48 字节布局直传客户端
|
||||||
@@ -149,6 +150,8 @@ VSCode F5 调试时由 `sync-web-assets` preLaunchTask 自动同步。
|
|||||||
| `YAMA_WEB_ADMIN_PASS` | Web UI 的 admin 密码(明文);优先于 `YAMA_PWD`。两者都未设置时 Web 登录禁用 | `your_admin_password` |
|
| `YAMA_WEB_ADMIN_PASS` | Web UI 的 admin 密码(明文);优先于 `YAMA_PWD`。两者都未设置时 Web 登录禁用 | `your_admin_password` |
|
||||||
| `YAMA_SIGN_PASSWORD` | HMAC-SHA256 key used to sign CMD_MASTERSETTING replies; must match the client's expected value. Provision out-of-band. Unset → client refuses screen/file ops. | `<deployment-shared-secret>` |
|
| `YAMA_SIGN_PASSWORD` | HMAC-SHA256 key used to sign CMD_MASTERSETTING replies; must match the client's expected value. Provision out-of-band. Unset → client refuses screen/file ops. | `<deployment-shared-secret>` |
|
||||||
| `YAMA_USERS_FILE` | Path to the JSON file that persists non-admin web users (allowed_groups, password hash, salt). Default is `users.json` in the working directory. | `users.json` |
|
| `YAMA_USERS_FILE` | Path to the JSON file that persists non-admin web users (allowed_groups, password hash, salt). Default is `users.json` in the working directory. | `users.json` |
|
||||||
|
| `YAMA_WEB_ALLOWED_ORIGINS` | Comma-separated WebSocket Origin allowlist for cross-origin upgrades. Empty (default) → only same-origin upgrades are accepted, which is correct when the web UI and `/ws` share a host. Add an entry per trusted PWA / dev origin. | `https://yama.example.com,https://yama-mobile.example.com` |
|
||||||
|
| `YAMA_WEB_TRUST_PROXY` | Set to `1` only when running behind a reverse proxy you control (caddy / nginx / cloudflare). Switches client-IP extraction to use the last entry of `X-Forwarded-For` instead of `RemoteAddr`, so per-IP login rate limit sees the real client. Direct-exposure deployments MUST leave this unset — otherwise attackers can spoof the header to evade rate limits. | `1` |
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Linux/macOS
|
# Linux/macOS
|
||||||
|
|||||||
@@ -611,6 +611,27 @@ func parsePorts(portStr string) ([]int, error) {
|
|||||||
return ports, nil
|
return ports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// splitCSV splits a comma-separated env-var value into trimmed, non-empty
|
||||||
|
// entries. Returns nil for an empty input so callers can keep the natural
|
||||||
|
// "no value → no restriction" semantics with a single nil check.
|
||||||
|
func splitCSV(s string) []string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parts := strings.Split(s, ",")
|
||||||
|
out := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
out = append(out, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Parse command line flags
|
// Parse command line flags
|
||||||
portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)")
|
portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)")
|
||||||
@@ -733,9 +754,33 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Web-UI hardening knobs for public-HTTPS deployment.
|
||||||
|
//
|
||||||
|
// YAMA_WEB_ALLOWED_ORIGINS: comma-separated Origin allowlist (e.g.
|
||||||
|
// "https://yama.example.com,https://yama-mobile.example.com").
|
||||||
|
// Empty (default) → only same-origin WS upgrades accepted, which
|
||||||
|
// is correct when the web UI and WS endpoint share a host.
|
||||||
|
//
|
||||||
|
// Login rate limits are hard-coded at sensible defaults for the
|
||||||
|
// small-user web UI: 10 attempts / minute per IP, 5 / 15 min per
|
||||||
|
// username. The handler also injects a 250 ms delay on every failure
|
||||||
|
// so online brute force is impractical even within budget.
|
||||||
|
allowedOrigins := splitCSV(os.Getenv("YAMA_WEB_ALLOWED_ORIGINS"))
|
||||||
|
trustProxy := os.Getenv("YAMA_WEB_TRUST_PROXY") == "1"
|
||||||
|
if trustProxy {
|
||||||
|
log.Info("Trusting X-Forwarded-For for client IP — make sure a reverse proxy is in front")
|
||||||
|
}
|
||||||
|
webCfg := web.Config{
|
||||||
|
AllowedOrigins: allowedOrigins,
|
||||||
|
LoginIPLimit: wsauth.NewRateLimiter(10, time.Minute),
|
||||||
|
LoginUserLimit: wsauth.NewRateLimiter(5, 15*time.Minute),
|
||||||
|
TrustForwardedFor: trustProxy,
|
||||||
|
}
|
||||||
|
|
||||||
// Start HTTP server for web UI. Hub gives it read-only access to the
|
// Start HTTP server for web UI. Hub gives it read-only access to the
|
||||||
// device registry; the authenticator owns user accounts and session tokens.
|
// device registry; the authenticator owns user accounts and session tokens.
|
||||||
httpSrv := web.New(*httpPort, log.WithPrefix("Web"), deviceHub, webAuth)
|
httpSrv := web.New(*httpPort, log.WithPrefix("Web"), deviceHub, webAuth).
|
||||||
|
WithConfig(webCfg)
|
||||||
if err := httpSrv.Start(); err != nil {
|
if err := httpSrv.Start(); err != nil {
|
||||||
log.Fatal("Failed to start HTTP server: %v", err)
|
log.Fatal("Failed to start HTTP server: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
|
||||||
@@ -25,6 +26,32 @@ type Server struct {
|
|||||||
hub *hub.Hub
|
hub *hub.Hub
|
||||||
auth *wsauth.Authenticator
|
auth *wsauth.Authenticator
|
||||||
ws *wsHub
|
ws *wsHub
|
||||||
|
allowedOrigins []string // for WS Origin allowlist; empty = same-origin only
|
||||||
|
loginIPLimit *wsauth.RateLimiter
|
||||||
|
loginUserLimit *wsauth.RateLimiter
|
||||||
|
trustForwardedFor bool // honor X-Forwarded-For (behind trusted proxy only)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config tunes the server's exposed-on-public-HTTPS hardening knobs.
|
||||||
|
// All fields are optional; zero values pick reasonable defaults.
|
||||||
|
type Config struct {
|
||||||
|
// AllowedOrigins is the comma-separated list of Origin header values
|
||||||
|
// the WebSocket upgrade will accept in addition to same-origin
|
||||||
|
// requests. Empty (default) → only same-origin upgrades are allowed,
|
||||||
|
// which is correct when the web UI and the WS endpoint are served
|
||||||
|
// from the same host.
|
||||||
|
AllowedOrigins []string
|
||||||
|
// LoginIPLimit / LoginUserLimit throttle the get_salt + login flow
|
||||||
|
// per source IP and per username respectively. Pass nil to disable
|
||||||
|
// either dimension (e.g. dev mode).
|
||||||
|
LoginIPLimit *wsauth.RateLimiter
|
||||||
|
LoginUserLimit *wsauth.RateLimiter
|
||||||
|
// TrustForwardedFor switches client-IP extraction from RemoteAddr
|
||||||
|
// (default) to the last entry of X-Forwarded-For. Set true only when
|
||||||
|
// running behind a reverse proxy that you control; on direct
|
||||||
|
// exposure the header is client-controlled and would let attackers
|
||||||
|
// evade per-IP rate limits.
|
||||||
|
TrustForwardedFor bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an HTTP server bound to the given port. port=0 disables the server.
|
// New creates an HTTP server bound to the given port. port=0 disables the server.
|
||||||
@@ -34,6 +61,16 @@ func New(port int, log *logger.Logger, h *hub.Hub, auth *wsauth.Authenticator) *
|
|||||||
return &Server{port: port, log: log, hub: h, auth: auth}
|
return &Server{port: port, log: log, hub: h, auth: auth}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithConfig applies hardening configuration. Returns the receiver for
|
||||||
|
// chainable setup. Safe to call before Start; ignored thereafter.
|
||||||
|
func (s *Server) WithConfig(cfg Config) *Server {
|
||||||
|
s.allowedOrigins = cfg.AllowedOrigins
|
||||||
|
s.loginIPLimit = cfg.LoginIPLimit
|
||||||
|
s.loginUserLimit = cfg.LoginUserLimit
|
||||||
|
s.trustForwardedFor = cfg.TrustForwardedFor
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Start launches the server in a goroutine and returns immediately.
|
// Start launches the server in a goroutine and returns immediately.
|
||||||
// If port is 0, returns nil without starting anything.
|
// If port is 0, returns nil without starting anything.
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
@@ -42,13 +79,16 @@ func (s *Server) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s.ws = newWSHub(s.auth, s.hub, s.log)
|
s.ws = newWSHub(s.auth, s.hub, s.log).
|
||||||
|
withOriginAllowlist(s.allowedOrigins).
|
||||||
|
withLoginRateLimiters(s.loginIPLimit, s.loginUserLimit).
|
||||||
|
withTrustForwardedFor(s.trustForwardedFor)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/", s.handleIndex)
|
mux.HandleFunc("/", s.handleIndex)
|
||||||
mux.HandleFunc("/health", s.handleHealth)
|
mux.HandleFunc("/health", s.handleHealth)
|
||||||
mux.HandleFunc("/manifest.json", s.handleManifest)
|
mux.HandleFunc("/manifest.json", s.handleManifest)
|
||||||
mux.HandleFunc("/api/devices", s.handleDevices)
|
mux.HandleFunc("/api/devices", s.requireBearer(s.handleDevices))
|
||||||
mux.HandleFunc("/ws", s.ws.serve)
|
mux.HandleFunc("/ws", s.ws.serve)
|
||||||
mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8"))
|
mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8"))
|
||||||
mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8"))
|
mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8"))
|
||||||
@@ -106,7 +146,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// handleDevices returns a JSON snapshot of currently-online devices. Empty
|
// handleDevices returns a JSON snapshot of currently-online devices. Empty
|
||||||
// array (not null) when no clients are connected — matches what the front-end
|
// array (not null) when no clients are connected — matches what the front-end
|
||||||
// will eventually expect.
|
// will eventually expect. Auth-gated via requireBearer.
|
||||||
func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
|
||||||
devices := s.hub.ListDevices()
|
devices := s.hub.ListDevices()
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
@@ -116,6 +156,39 @@ func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requireBearer wraps a handler with `Authorization: Bearer <token>` auth
|
||||||
|
// against the same session-token store the WebSocket uses. Returns 401 on
|
||||||
|
// missing / invalid / expired tokens. Used to gate REST endpoints that
|
||||||
|
// previously fell through with no auth (notably /api/devices, which
|
||||||
|
// otherwise leaks the full online-device list to anyone on the internet).
|
||||||
|
func (s *Server) requireBearer(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
const prefix = "Bearer "
|
||||||
|
hdr := r.Header.Get("Authorization")
|
||||||
|
if !strings.HasPrefix(hdr, prefix) {
|
||||||
|
s.unauthorized(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := strings.TrimSpace(hdr[len(prefix):])
|
||||||
|
if token == "" {
|
||||||
|
s.unauthorized(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := s.auth.ValidateToken(token); err != nil {
|
||||||
|
s.unauthorized(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) unauthorized(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Bearer realm="yama"`)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||||
|
}
|
||||||
|
|
||||||
// PWA manifest. Referenced by <link rel="manifest"> in index.html.
|
// PWA manifest. Referenced by <link rel="manifest"> in index.html.
|
||||||
// Static JSON, no template needed.
|
// Static JSON, no template needed.
|
||||||
const manifestJSON = `{
|
const manifestJSON = `{
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package web
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,13 +23,12 @@ const (
|
|||||||
wsSendBuffer = 64 // outbound queue depth per client
|
wsSendBuffer = 64 // outbound queue depth per client
|
||||||
)
|
)
|
||||||
|
|
||||||
// upgrader allows any origin — this service is meant to be tunneled through
|
// baseUpgrader carries the buffer-size config shared by all WS upgrades.
|
||||||
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
|
// CheckOrigin is set per-hub in wsHub.upgradeWS so the allowlist is
|
||||||
// Adjust CheckOrigin once we have a deployment story.
|
// closed-over per instance instead of being a global mutable.
|
||||||
var upgrader = websocket.Upgrader{
|
var baseUpgrader = websocket.Upgrader{
|
||||||
ReadBufferSize: 4096,
|
ReadBufferSize: 4096,
|
||||||
WriteBufferSize: 4096,
|
WriteBufferSize: 4096,
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----- per-connection client state ----------------------------------------
|
// ----- per-connection client state ----------------------------------------
|
||||||
@@ -94,6 +96,14 @@ type wsHub struct {
|
|||||||
clients map[*wsClient]struct{}
|
clients map[*wsClient]struct{}
|
||||||
|
|
||||||
unsub func()
|
unsub func()
|
||||||
|
|
||||||
|
// Hardening knobs wired from server.Config. Nil/empty values mean
|
||||||
|
// "no extra restriction" — useful for local dev where the hub is
|
||||||
|
// exercised without server.Server wiring up the env-driven defaults.
|
||||||
|
allowedOrigins []string // empty → only same-origin upgrades accepted
|
||||||
|
loginIPLimit *wsauth.RateLimiter
|
||||||
|
loginUserLimit *wsauth.RateLimiter
|
||||||
|
trustForwardedFor bool // honor X-Forwarded-For (only when behind a trusted proxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
|
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
|
||||||
@@ -107,6 +117,104 @@ func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger)
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// withOriginAllowlist returns h after installing the explicit Origin
|
||||||
|
// allowlist. Chainable. Pass empty/nil to keep "same-origin only".
|
||||||
|
func (h *wsHub) withOriginAllowlist(origins []string) *wsHub {
|
||||||
|
h.allowedOrigins = origins
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// withLoginRateLimiters wires per-IP and per-username throttles into
|
||||||
|
// the login flow. Either may be nil to disable that dimension.
|
||||||
|
func (h *wsHub) withLoginRateLimiters(byIP, byUser *wsauth.RateLimiter) *wsHub {
|
||||||
|
h.loginIPLimit = byIP
|
||||||
|
h.loginUserLimit = byUser
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// withTrustForwardedFor opts in to using the last entry of the
|
||||||
|
// X-Forwarded-For header as the client IP. Safe only when the server is
|
||||||
|
// behind a reverse proxy that you control.
|
||||||
|
func (h *wsHub) withTrustForwardedFor(trust bool) *wsHub {
|
||||||
|
h.trustForwardedFor = trust
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkOrigin decides whether to accept a WebSocket upgrade based on
|
||||||
|
// the request's Origin header. Same-origin (Origin host == Host) is
|
||||||
|
// always accepted; explicit allowlist entries cover the
|
||||||
|
// PWA-from-different-domain or local-dev cases.
|
||||||
|
//
|
||||||
|
// An empty Origin header is rejected: a legitimate browser always sends
|
||||||
|
// it on cross-origin requests, and same-origin requests have it too in
|
||||||
|
// modern Chrome/Safari/Firefox. Non-browser clients (curl, scripts) that
|
||||||
|
// omit Origin shouldn't be talking to the WS endpoint anyway.
|
||||||
|
func (h *wsHub) checkOrigin(r *http.Request) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
u, err := url.Parse(origin)
|
||||||
|
if err != nil || u.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Same-origin (Origin host matches the Host the request came in on).
|
||||||
|
// Strip any port mismatch: if the server is behind a proxy, Host may
|
||||||
|
// not include a port while Origin does (or vice versa), so compare
|
||||||
|
// the hostname components.
|
||||||
|
originHost := u.Hostname()
|
||||||
|
reqHost := stripPort(r.Host)
|
||||||
|
if originHost == reqHost && originHost != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Explicit allowlist entries — match Origin in full (scheme + host
|
||||||
|
// + port) so a customer can pin exactly one trusted PWA origin.
|
||||||
|
for _, allowed := range h.allowedOrigins {
|
||||||
|
if allowed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(origin, strings.TrimSpace(allowed)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripPort(hostport string) string {
|
||||||
|
if h, _, err := net.SplitHostPort(hostport); err == nil {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return hostport
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientIP returns the source IP of an HTTP request. By default uses
|
||||||
|
// r.RemoteAddr (the actual TCP peer); this is the only safe choice when
|
||||||
|
// the server is directly exposed to the internet, because a malicious
|
||||||
|
// client can put anything in X-Forwarded-For and would otherwise rotate
|
||||||
|
// it to evade per-IP rate limits.
|
||||||
|
//
|
||||||
|
// When `trustForwardedFor` is true the LAST entry of X-Forwarded-For is
|
||||||
|
// returned instead — appropriate only when running behind a reverse
|
||||||
|
// proxy that you control and that overwrites/appends the header (caddy,
|
||||||
|
// nginx with proper config, etc). Toggled via the YAMA_WEB_TRUST_PROXY
|
||||||
|
// env var at startup.
|
||||||
|
func clientIP(r *http.Request, trustForwardedFor bool) string {
|
||||||
|
if trustForwardedFor {
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
// Take the LAST entry — the closest hop, i.e. our own
|
||||||
|
// trusted proxy's view of the peer. Trusting the first
|
||||||
|
// entry would let a malicious client at the head of the
|
||||||
|
// chain set an arbitrary value.
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
ip := strings.TrimSpace(parts[len(parts)-1])
|
||||||
|
if ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stripPort(r.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// stop unsubscribes from the device hub. Existing connections keep running
|
// stop unsubscribes from the device hub. Existing connections keep running
|
||||||
// until they close on their own; we only block new event delivery.
|
// until they close on their own; we only block new event delivery.
|
||||||
func (h *wsHub) stop() {
|
func (h *wsHub) stop() {
|
||||||
@@ -294,6 +402,11 @@ func (h *wsHub) unregister(c *wsClient) {
|
|||||||
// ----- HTTP handler -------------------------------------------------------
|
// ----- HTTP handler -------------------------------------------------------
|
||||||
|
|
||||||
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Build a per-call upgrader so CheckOrigin closes over this hub's
|
||||||
|
// allowlist instead of a package-level mutable.
|
||||||
|
upgrader := baseUpgrader
|
||||||
|
upgrader.CheckOrigin = h.checkOrigin
|
||||||
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.log.Error("ws upgrade: %v", err)
|
h.log.Error("ws upgrade: %v", err)
|
||||||
@@ -313,7 +426,7 @@ func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
|
|||||||
send: make(chan wsMsg, wsSendBuffer),
|
send: make(chan wsMsg, wsSendBuffer),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
nonce: nonce,
|
nonce: nonce,
|
||||||
addr: r.RemoteAddr,
|
addr: clientIP(r, h.trustForwardedFor),
|
||||||
}
|
}
|
||||||
h.register(client)
|
h.register(client)
|
||||||
defer h.unregister(client)
|
defer h.unregister(client)
|
||||||
|
|||||||
@@ -79,18 +79,27 @@ func (h *wsHub) requireAdmin(c *wsClient, raw []byte, replyCmd string) (ok bool)
|
|||||||
// ----- handlers ------------------------------------------------------------
|
// ----- handlers ------------------------------------------------------------
|
||||||
|
|
||||||
func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) {
|
func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) {
|
||||||
|
// Throttle the salt-probe surface together with login: an attacker
|
||||||
|
// who can poll get_salt freely would otherwise still learn nothing
|
||||||
|
// (the unknown-user fake salt mitigation handles that), but the
|
||||||
|
// endpoint is otherwise free CPU on the server. Limiting by IP is
|
||||||
|
// enough; we don't have a username yet to limit by user.
|
||||||
|
if !h.allowLoginByIP(c) {
|
||||||
|
// Stall the response so a tight-loop attacker doesn't flood the
|
||||||
|
// queue. Still return a well-formed salt to avoid making the
|
||||||
|
// limit detectable from the client side.
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
var in struct {
|
var in struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
}
|
}
|
||||||
_ = json.Unmarshal(raw, &in)
|
_ = json.Unmarshal(raw, &in)
|
||||||
|
|
||||||
salt, ok := h.auth.GetSalt(in.Username)
|
salt, _ := h.auth.GetSalt(in.Username)
|
||||||
// Do not leak which usernames exist: always return ok=true with a salt.
|
// GetSalt now returns a deterministic fake salt (16 hex chars) for
|
||||||
// For unknown users hand back the empty salt (matches admin convention)
|
// unknown users — same shape as a real salt — so an attacker can't
|
||||||
// so the timing/shape of the response is uniform.
|
// tell from this response alone whether the username exists.
|
||||||
if !ok {
|
|
||||||
salt = ""
|
|
||||||
}
|
|
||||||
c.queue(mustJSON(map[string]any{
|
c.queue(mustJSON(map[string]any{
|
||||||
"cmd": "salt",
|
"cmd": "salt",
|
||||||
"ok": true,
|
"ok": true,
|
||||||
@@ -109,6 +118,20 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rate-limit BEFORE doing the hash work, so a flood doesn't pin CPU.
|
||||||
|
// Two-dimensional throttle: per-IP catches scanners that try many
|
||||||
|
// usernames; per-username catches scanners that rotate IPs against a
|
||||||
|
// known account (admin). Either dimension tripping rejects the call
|
||||||
|
// with a uniform "credentials" error so the limit is not detectable.
|
||||||
|
if !h.allowLoginByIP(c) || !h.allowLoginByUsername(in.Username) {
|
||||||
|
h.log.Warn("ws login throttled: user=%s addr=%s", in.Username, c.addr)
|
||||||
|
// Burn the challenge so the attacker can't immediately replay.
|
||||||
|
c.nonce = ""
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Bind the response to the challenge we issued at connect time so that
|
// Bind the response to the challenge we issued at connect time so that
|
||||||
// replays from a different connection can't reuse a captured response.
|
// replays from a different connection can't reuse a captured response.
|
||||||
if in.Nonce == "" || in.Nonce != c.nonce {
|
if in.Nonce == "" || in.Nonce != c.nonce {
|
||||||
@@ -120,9 +143,17 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Burn the challenge on failure too — forces a new round on retry.
|
// Burn the challenge on failure too — forces a new round on retry.
|
||||||
c.nonce = ""
|
c.nonce = ""
|
||||||
|
// Fixed delay on failure: makes online brute force impractical
|
||||||
|
// even within the rate-limit budget, and erases the timing
|
||||||
|
// difference between "wrong password" and "wrong nonce".
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
|
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Successful login: clear the per-IP/per-user budgets so a legitimate
|
||||||
|
// user who fat-fingered a few times doesn't stay throttled.
|
||||||
|
h.resetLoginThrottle(c, in.Username)
|
||||||
|
|
||||||
c.nonce = ""
|
c.nonce = ""
|
||||||
c.token = token
|
c.token = token
|
||||||
c.role = role
|
c.role = role
|
||||||
@@ -136,6 +167,31 @@ func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allowLoginByIP / allowLoginByUsername return true when the call is
|
||||||
|
// within budget; nil limiter always returns true (effectively disabled).
|
||||||
|
func (h *wsHub) allowLoginByIP(c *wsClient) bool {
|
||||||
|
if h.loginIPLimit == nil || c == nil || c.addr == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return h.loginIPLimit.Allow(c.addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *wsHub) allowLoginByUsername(username string) bool {
|
||||||
|
if h.loginUserLimit == nil || username == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return h.loginUserLimit.Allow(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *wsHub) resetLoginThrottle(c *wsClient, username string) {
|
||||||
|
if h.loginIPLimit != nil && c != nil && c.addr != "" {
|
||||||
|
h.loginIPLimit.Reset(c.addr)
|
||||||
|
}
|
||||||
|
if h.loginUserLimit != nil && username != "" {
|
||||||
|
h.loginUserLimit.Reset(username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleConnect kicks off a screen-sharing session for the browser. We send
|
// handleConnect kicks off a screen-sharing session for the browser. We send
|
||||||
// COMMAND_SCREEN_SPY to the device's main TCP connection; the device then
|
// COMMAND_SCREEN_SPY to the device's main TCP connection; the device then
|
||||||
// opens a new sub-connection (TOKEN_BITMAPINFO) which the TCP side binds to
|
// opens a new sub-connection (TOKEN_BITMAPINFO) which the TCP side binds to
|
||||||
|
|||||||
110
server/go/wsauth/ratelimit.go
Normal file
110
server/go/wsauth/ratelimit.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package wsauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimiter is a sliding-window per-key counter used to throttle login
|
||||||
|
// attempts. Two instances are typically created: one keyed by client IP
|
||||||
|
// (to slow distributed brute force), one keyed by username (to slow
|
||||||
|
// targeted attacks against a known account).
|
||||||
|
//
|
||||||
|
// Design notes:
|
||||||
|
// - Denied attempts are NOT recorded — the window slides naturally and a
|
||||||
|
// legitimate user who fat-fingers their password recovers as soon as
|
||||||
|
// the oldest attempt ages out, while a determined attacker is capped
|
||||||
|
// at `limit` successful attempts per `window` indefinitely.
|
||||||
|
// - Lazy cleanup: stale timestamps for a key are pruned on every Allow()
|
||||||
|
// call. Truly idle keys are GC'd by Sweep(), which callers should run
|
||||||
|
// periodically from a background goroutine.
|
||||||
|
// - Map size is bounded by the count of recently-active keys; for the
|
||||||
|
// web UI's expected load (a handful of users + occasional scanners),
|
||||||
|
// no extra GC pressure considerations needed.
|
||||||
|
type RateLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limit int
|
||||||
|
window time.Duration
|
||||||
|
entries map[string][]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter returns a limiter that allows up to `limit` events per
|
||||||
|
// `window` duration per key. Zero or negative limit/window disables the
|
||||||
|
// limiter (Allow always returns true) — useful for tests / dev mode.
|
||||||
|
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||||
|
return &RateLimiter{
|
||||||
|
limit: limit,
|
||||||
|
window: window,
|
||||||
|
entries: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow records an attempt for `key` if and only if the caller is under
|
||||||
|
// the per-key limit. Returns true when allowed, false when over limit.
|
||||||
|
// Empty key is treated as "no throttle" (returns true without recording)
|
||||||
|
// so the caller can fall through when the IP/username is unavailable.
|
||||||
|
func (r *RateLimiter) Allow(key string) bool {
|
||||||
|
if r == nil || r.limit <= 0 || r.window <= 0 || key == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
cutoff := time.Now().Add(-r.window)
|
||||||
|
times := r.entries[key]
|
||||||
|
// Compact in place — keep only timestamps within the window.
|
||||||
|
keep := times[:0]
|
||||||
|
for _, t := range times {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
keep = append(keep, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keep) >= r.limit {
|
||||||
|
// Update the map even when denying so the compacted slice doesn't
|
||||||
|
// keep stale entries forever. Don't append the new attempt: that
|
||||||
|
// would let attackers extend the window arbitrarily.
|
||||||
|
r.entries[key] = keep
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
r.entries[key] = append(keep, time.Now())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears state for a key. Call on successful login to give the user
|
||||||
|
// a fresh budget — otherwise a string of failed attempts followed by a
|
||||||
|
// correct one still leaves the budget partially consumed.
|
||||||
|
func (r *RateLimiter) Reset(key string) {
|
||||||
|
if r == nil || key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
delete(r.entries, key)
|
||||||
|
r.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sweep removes entries whose timestamps have all aged out of the window.
|
||||||
|
// Safe to call concurrently with Allow. Intended for periodic invocation
|
||||||
|
// from a background ticker (e.g. every window-length) to bound the map.
|
||||||
|
func (r *RateLimiter) Sweep() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
cutoff := time.Now().Add(-r.window)
|
||||||
|
for key, times := range r.entries {
|
||||||
|
keep := times[:0]
|
||||||
|
for _, t := range times {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
keep = append(keep, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(keep) == 0 {
|
||||||
|
delete(r.entries, key)
|
||||||
|
} else {
|
||||||
|
r.entries[key] = keep
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -95,13 +95,26 @@ func (a *Authenticator) AddUser(u User) {
|
|||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAdminFromPlainPassword is a convenience for the bootstrap admin: salt is
|
// AddAdminFromPlainPassword is a convenience for the bootstrap admin.
|
||||||
// empty (matching the C++ admin record), hash is SHA256(password).
|
// Unlike legacy convention, the admin record is given a real per-instance
|
||||||
|
// salt — exposing an empty salt for admin while everyone else has a real
|
||||||
|
// 16-hex one would let an unauthenticated probe distinguish admin from
|
||||||
|
// other accounts via /get_salt alone. The cost is a tiny break in
|
||||||
|
// users.json schema compat: admin is never persisted to users.json
|
||||||
|
// anyway (snapshotPersistableLocked excludes it), so this is in-memory
|
||||||
|
// only.
|
||||||
func (a *Authenticator) AddAdminFromPlainPassword(username, plainPassword string) {
|
func (a *Authenticator) AddAdminFromPlainPassword(username, plainPassword string) {
|
||||||
|
salt, err := NewSalt()
|
||||||
|
if err != nil {
|
||||||
|
// Fall back to deterministic salt derived from the password hash
|
||||||
|
// rather than empty — preserves the uniform-shape property even
|
||||||
|
// if crypto/rand briefly errors at startup.
|
||||||
|
salt = ComputeSHA256(plainPassword)[:saltBytes*2]
|
||||||
|
}
|
||||||
a.AddUser(User{
|
a.AddUser(User{
|
||||||
Username: username,
|
Username: username,
|
||||||
PasswordHash: ComputeSHA256(plainPassword),
|
PasswordHash: HashPassword(plainPassword, salt),
|
||||||
Salt: "",
|
Salt: salt,
|
||||||
Role: "admin",
|
Role: "admin",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -200,18 +213,44 @@ func (a *Authenticator) ListUsers() []User {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSalt returns the per-user salt. If the user does not exist, returns ("", false).
|
// GetSalt returns the per-user salt for an existing user, or a
|
||||||
// Note: the C++ admin uses an empty salt — that is still considered "found"
|
// deterministic 16-hex pseudo-salt for an unknown user. The ok flag
|
||||||
// and the empty string is returned with ok=true.
|
// reports which case occurred, so callers can decide whether to update
|
||||||
|
// rate-limit / audit state — but the returned salt itself is shaped
|
||||||
|
// identically (16 hex chars) in both cases, defeating the user-existence
|
||||||
|
// probe an attacker would otherwise mount via /get_salt.
|
||||||
|
//
|
||||||
|
// The pseudo-salt is derived from a server-instance secret (the admin
|
||||||
|
// password hash, taken at first call) mixed with the username, so the
|
||||||
|
// same unknown user always sees the same fake salt across requests.
|
||||||
|
// Without this, an attacker could fingerprint the "fake-salt branch"
|
||||||
|
// by submitting the same username twice and watching for differences.
|
||||||
func (a *Authenticator) GetSalt(username string) (string, bool) {
|
func (a *Authenticator) GetSalt(username string) (string, bool) {
|
||||||
a.mu.RLock()
|
a.mu.RLock()
|
||||||
u, ok := a.users[username]
|
u, ok := a.users[username]
|
||||||
a.mu.RUnlock()
|
a.mu.RUnlock()
|
||||||
if !ok {
|
if ok {
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return u.Salt, true
|
return u.Salt, true
|
||||||
}
|
}
|
||||||
|
return a.fakeSalt(username), false
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeSalt derives a deterministic 16-hex value for unknown usernames.
|
||||||
|
// The secret pepper is the bootstrap admin's password hash — present as
|
||||||
|
// long as the server has any admin, deterministic per deployment, never
|
||||||
|
// transmitted. Reveals nothing useful to an attacker even if reverse-
|
||||||
|
// engineered: the only thing they can do with it is reproduce the fake
|
||||||
|
// salt, which they already see in the response.
|
||||||
|
func (a *Authenticator) fakeSalt(username string) string {
|
||||||
|
a.mu.RLock()
|
||||||
|
pepper := ""
|
||||||
|
if admin, ok := a.users["admin"]; ok {
|
||||||
|
pepper = admin.PasswordHash
|
||||||
|
}
|
||||||
|
a.mu.RUnlock()
|
||||||
|
digest := ComputeSHA256("yama-fake-salt|" + pepper + "|" + username)
|
||||||
|
return digest[:saltBytes*2]
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyLogin checks a challenge-response login. The browser sends
|
// VerifyLogin checks a challenge-response login. The browser sends
|
||||||
// response = SHA256(passwordHash + nonce). On success the function mints a
|
// response = SHA256(passwordHash + nonce). On success the function mints a
|
||||||
|
|||||||
@@ -14,19 +14,31 @@ func TestSHA256Vector(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoginRoundTripAdminEmptySalt(t *testing.T) {
|
// adminLoginResponse helps tests compute the right login response for an
|
||||||
|
// admin account that now uses a real per-instance salt.
|
||||||
|
func adminLoginResponse(t *testing.T, a *Authenticator, username, password, nonce string) string {
|
||||||
|
t.Helper()
|
||||||
|
salt, ok := a.GetSalt(username)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("admin %s not registered", username)
|
||||||
|
}
|
||||||
|
return ComputeSHA256(HashPassword(password, salt) + nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginRoundTripAdmin(t *testing.T) {
|
||||||
a := New()
|
a := New()
|
||||||
a.AddAdminFromPlainPassword("admin", "hunter2")
|
a.AddAdminFromPlainPassword("admin", "hunter2")
|
||||||
|
|
||||||
salt, ok := a.GetSalt("admin")
|
salt, ok := a.GetSalt("admin")
|
||||||
if !ok || salt != "" {
|
if !ok {
|
||||||
t.Fatalf("admin salt: ok=%v salt=%q", ok, salt)
|
t.Fatal("admin should be found")
|
||||||
|
}
|
||||||
|
if len(salt) != 2*saltBytes {
|
||||||
|
t.Fatalf("admin salt should be a real 16-hex value, got %q (len=%d)", salt, len(salt))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate the browser: nonce = "abc123", response = SHA256(passwordHash + nonce)
|
|
||||||
nonce := "abc123"
|
nonce := "abc123"
|
||||||
passwordHash := ComputeSHA256("hunter2")
|
response := adminLoginResponse(t, a, "admin", "hunter2", nonce)
|
||||||
response := ComputeSHA256(passwordHash + nonce)
|
|
||||||
|
|
||||||
token, role, err := a.VerifyLogin("admin", response, nonce)
|
token, role, err := a.VerifyLogin("admin", response, nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,6 +83,33 @@ func TestLoginRoundTripViewerWithSalt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetSaltUnknownUserShape verifies the salt-probe mitigation: an
|
||||||
|
// unknown user must get back a value that's shape-identical to a real
|
||||||
|
// salt, so an attacker can't tell from /get_salt alone whether a
|
||||||
|
// username exists.
|
||||||
|
func TestGetSaltUnknownUserShape(t *testing.T) {
|
||||||
|
a := New()
|
||||||
|
a.AddAdminFromPlainPassword("admin", "pw")
|
||||||
|
|
||||||
|
fake, ok := a.GetSalt("nobody")
|
||||||
|
if ok {
|
||||||
|
t.Fatal("ok should be false for unknown user")
|
||||||
|
}
|
||||||
|
if len(fake) != 2*saltBytes {
|
||||||
|
t.Fatalf("fake salt should be %d hex chars; got %q (len=%d)", 2*saltBytes, fake, len(fake))
|
||||||
|
}
|
||||||
|
// Determinism: repeated probes for the same username get the same fake.
|
||||||
|
fake2, _ := a.GetSalt("nobody")
|
||||||
|
if fake != fake2 {
|
||||||
|
t.Fatalf("fake salt should be deterministic for repeated probes; got %q vs %q", fake, fake2)
|
||||||
|
}
|
||||||
|
// Different usernames get different fake salts.
|
||||||
|
other, _ := a.GetSalt("ghost")
|
||||||
|
if fake == other {
|
||||||
|
t.Fatalf("fake salts should differ across usernames; both = %q", fake)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoginRejectsWrongResponse(t *testing.T) {
|
func TestLoginRejectsWrongResponse(t *testing.T) {
|
||||||
a := New()
|
a := New()
|
||||||
a.AddAdminFromPlainPassword("admin", "x")
|
a.AddAdminFromPlainPassword("admin", "x")
|
||||||
@@ -89,7 +128,7 @@ func TestTokenExpiry(t *testing.T) {
|
|||||||
a.SetTokenExpire(50 * time.Millisecond)
|
a.SetTokenExpire(50 * time.Millisecond)
|
||||||
a.AddAdminFromPlainPassword("admin", "x")
|
a.AddAdminFromPlainPassword("admin", "x")
|
||||||
nonce, _ := NewNonce()
|
nonce, _ := NewNonce()
|
||||||
response := ComputeSHA256(ComputeSHA256("x") + nonce)
|
response := adminLoginResponse(t, a, "admin", "x", nonce)
|
||||||
token, _, err := a.VerifyLogin("admin", response, nonce)
|
token, _, err := a.VerifyLogin("admin", response, nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -107,10 +146,57 @@ func TestRevoke(t *testing.T) {
|
|||||||
a := New()
|
a := New()
|
||||||
a.AddAdminFromPlainPassword("admin", "x")
|
a.AddAdminFromPlainPassword("admin", "x")
|
||||||
nonce, _ := NewNonce()
|
nonce, _ := NewNonce()
|
||||||
response := ComputeSHA256(ComputeSHA256("x") + nonce)
|
response := adminLoginResponse(t, a, "admin", "x", nonce)
|
||||||
token, _, _ := a.VerifyLogin("admin", response, nonce)
|
token, _, _ := a.VerifyLogin("admin", response, nonce)
|
||||||
a.RevokeToken(token)
|
a.RevokeToken(token)
|
||||||
if _, err := a.ValidateToken(token); err == nil {
|
if _, err := a.ValidateToken(token); err == nil {
|
||||||
t.Fatal("revoked token should not validate")
|
t.Fatal("revoked token should not validate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterAllowsBurstThenBlocks(t *testing.T) {
|
||||||
|
r := NewRateLimiter(3, time.Minute)
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if !r.Allow("ip-a") {
|
||||||
|
t.Fatalf("attempt %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Allow("ip-a") {
|
||||||
|
t.Fatal("4th attempt should be denied")
|
||||||
|
}
|
||||||
|
// Different key has independent budget.
|
||||||
|
if !r.Allow("ip-b") {
|
||||||
|
t.Fatal("different key should still be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterReset(t *testing.T) {
|
||||||
|
r := NewRateLimiter(2, time.Minute)
|
||||||
|
r.Allow("k")
|
||||||
|
r.Allow("k")
|
||||||
|
if r.Allow("k") {
|
||||||
|
t.Fatal("3rd should be denied")
|
||||||
|
}
|
||||||
|
r.Reset("k")
|
||||||
|
if !r.Allow("k") {
|
||||||
|
t.Fatal("after Reset, should be allowed again")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDisabledWhenZeroLimit(t *testing.T) {
|
||||||
|
r := NewRateLimiter(0, time.Minute)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
if !r.Allow("k") {
|
||||||
|
t.Fatalf("limit=0 should never deny, denied at i=%d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterNilSafe(t *testing.T) {
|
||||||
|
var r *RateLimiter
|
||||||
|
if !r.Allow("anything") {
|
||||||
|
t.Fatal("nil limiter should allow")
|
||||||
|
}
|
||||||
|
r.Reset("anything")
|
||||||
|
r.Sweep()
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user