Improve Go Server to support remote desktop and command control #1

Merged
yuanyuanxiang merged 7 commits from feature/go-server into main 2026-05-18 22:06:08 +00:00
13 changed files with 1211 additions and 21 deletions
Showing only changes of commit b1f229706c - Show all commits

View File

@@ -8,8 +8,12 @@
"mode": "auto", "mode": "auto",
"program": "${workspaceFolder}/cmd", "program": "${workspaceFolder}/cmd",
"cwd": "${workspaceFolder}", "cwd": "${workspaceFolder}",
"args": [], "args": [
"env": {}, "-port=9090"
],
"env": {
"YAMA_WEB_ADMIN_PASS": "3.14159"
},
"console": "integratedTerminal", "console": "integratedTerminal",
"preLaunchTask": "sync-web-assets" "preLaunchTask": "sync-web-assets"
}, },
@@ -23,7 +27,9 @@
"args": [ "args": [
"-port=9090" "-port=9090"
], ],
"env": {}, "env": {
"YAMA_WEB_ADMIN_PASS": "3.14159"
},
"console": "integratedTerminal", "console": "integratedTerminal",
"buildFlags": "-gcflags='all=-N -l'", "buildFlags": "-gcflags='all=-N -l'",
"preLaunchTask": "sync-web-assets" "preLaunchTask": "sync-web-assets"

View File

@@ -25,9 +25,15 @@ server/go/
│ └── pool.go # Goroutine 工作池 │ └── pool.go # Goroutine 工作池
├── logger/ ├── logger/
│ └── logger.go # 日志模块 (基于 zerolog) │ └── logger.go # 日志模块 (基于 zerolog)
├── hub/
│ └── hub.go # 在线设备注册表 + 事件订阅
├── wsauth/
│ └── wsauth.go # Web 鉴权 (challenge-response + 不透明 token)
├── web/ ├── web/
│ ├── embed.go # //go:embed 嵌入 HTML/xterm.js 等 web 资源 │ ├── embed.go # //go:embed 嵌入 HTML/xterm.js 等 web 资源
│ ├── server.go # HTTP server (静态页面 + 后续 WS 信令) │ ├── server.go # HTTP server (静态页面 + REST + WS 路由)
│ ├── ws.go # WebSocket 连接生命周期
│ ├── ws_handlers.go # WS 消息分发与处理
│ └── assets/ │ └── assets/
│ ├── index.html # 从 ../../web/index.html sync 而来 (gitignored) │ ├── index.html # 从 ../../web/index.html sync 而来 (gitignored)
│ └── static/ # 第三方 xterm.js 资源 (checked in) │ └── static/ # 第三方 xterm.js 资源 (checked in)
@@ -109,9 +115,10 @@ VSCode F5 调试时由 `sync-web-assets` preLaunchTask 自动同步。
### 环境变量 ### 环境变量
| 变量 | 说明 | 示例 | | 变量 | 说明 | 示例 |
|------|------|------| | ---- | ---- | ---- |
| `YAMA_PWDHASH` | 密码的 SHA256 哈希值 (64位十六进制) | `61f04dd6...` | | `YAMA_PWDHASH` | 密码的 SHA256 哈希值 (64位十六进制) | `61f04dd6...` |
| `YAMA_PWD` | 超级密码,用于 HMAC 签名验证 | `your_super_password` | | `YAMA_PWD` | 超级密码,用于 HMAC 签名验证;也作为 Web admin 密码的默认来源 | `your_super_password` |
| `YAMA_WEB_ADMIN_PASS` | Web UI 的 admin 密码(明文);优先于 `YAMA_PWD`。两者都未设置时 Web 登录禁用 | `your_admin_password` |
```bash ```bash
# Linux/macOS # Linux/macOS

View File

@@ -8,13 +8,16 @@ import (
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/auth" "github.com/yuanyuanxiang/SimpleRemoter/server/go/auth"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger" "github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol" "github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server" "github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/web" "github.com/yuanyuanxiang/SimpleRemoter/server/go/web"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/wsauth"
) )
// MyHandler implements the server.Handler interface // MyHandler implements the server.Handler interface
@@ -22,6 +25,7 @@ type MyHandler struct {
log *logger.Logger log *logger.Logger
auth *auth.Authenticator auth *auth.Authenticator
srv *server.Server srv *server.Server
hub *hub.Hub
} }
// OnConnect is called when a client connects // OnConnect is called when a client connects
@@ -37,6 +41,7 @@ func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
"clientID", info.ClientID, "clientID", info.ClientID,
"computer", info.ComputerName, "computer", info.ComputerName,
) )
h.hub.Unregister(info.ClientID)
} }
} }
@@ -110,8 +115,40 @@ func (h *MyHandler) handleLogin(ctx *connection.Context, data []byte) {
"version", info.ModuleVersion, "version", info.ModuleVersion,
"path", clientInfo.FilePath, "path", clientInfo.FilePath,
) )
// PCName carries "ComputerName/Group"; ModuleVersion carries "Version-Capability".
// strings.Cut returns the full string as the head when the separator is
// absent, which gives us the natural "no group / no capability" fallback.
name, group, _ := strings.Cut(info.PCName, "/")
version, capability, _ := strings.Cut(info.ModuleVersion, "-")
// Reserved field 10 (ClientLoc) is the client-reported geo string.
location := ""
if len(reserved) > 10 {
location = info.GetReservedField(10)
}
// Register with hub so the web side can list this device. Sub-connections
// (screen / terminal etc.) reuse the MasterID and will overwrite this entry
// harmlessly, but only the main login carries enough info to be useful here.
h.hub.Register(&hub.Device{
ID: clientID,
Name: name,
Group: group,
Version: version,
Capability: capability,
OS: info.OsVerInfo,
CPU: clientInfo.CPU,
FilePath: clientInfo.FilePath,
InstallTime: info.StartTime,
Location: location,
PeerIP: ctx.GetPeerIP(),
PublicIP: clientInfo.IP,
ConnectedAt: time.Now(),
})
} }
// handleAuth handles authorization request (TOKEN_AUTH = 100) // handleAuth handles authorization request (TOKEN_AUTH = 100)
func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) { func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) {
result := h.auth.Authenticate(data) result := h.auth.Authenticate(data)
@@ -160,6 +197,25 @@ func (h *MyHandler) handleHeartbeat(ctx *connection.Context, data []byte) {
uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56 uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56
} }
// Forward live fields (ActiveWnd + Ping) to the hub so the web UI can
// display current latency and foreground window per device. Skip until
// login has happened — the hub is keyed by MasterID, which only exists
// post-login.
if info := ctx.GetInfo(); info.ClientID != "" {
var rtt int32
var activeWindow string
// ActiveWnd at data[9..521] is a 512-byte GBK-encoded string.
if len(data) >= 9+512 {
activeWindow = protocol.GbkToUTF8(data[9 : 9+512])
}
// Ping at data[521..525] is a little-endian int32.
if len(data) >= 525 {
rtt = int32(uint32(data[521]) | uint32(data[522])<<8 |
uint32(data[523])<<16 | uint32(data[524])<<24)
}
h.hub.UpdateLive(info.ClientID, int(rtt), activeWindow)
}
// Authenticate heartbeat if it contains authorization info // Authenticate heartbeat if it contains authorization info
// data[1:] skips the command byte to get the raw Heartbeat structure // data[1:] skips the command byte to get the raw Heartbeat structure
var authorized byte = 0 var authorized byte = 0
@@ -269,6 +325,27 @@ func main() {
// Create authenticator (shared by all servers) // Create authenticator (shared by all servers)
authenticator := auth.New(authCfg) authenticator := auth.New(authCfg)
// Shared device registry — every TCP handler reports devices into it,
// the HTTP server reads from it.
deviceHub := hub.New()
// Web user authenticator. Bootstrap admin from env var YAMA_WEB_ADMIN_PASS;
// if unset, fall back to YAMA_PWD (same secret the TCP authorization uses)
// so a single password env var is enough to bring up the whole stack.
// If neither is set, no admin is registered and login will always fail —
// the user must define a password before browsers can log in.
webAuth := wsauth.New()
adminPass := os.Getenv("YAMA_WEB_ADMIN_PASS")
if adminPass == "" {
adminPass = os.Getenv("YAMA_PWD")
}
if adminPass != "" {
webAuth.AddAdminFromPlainPassword("admin", adminPass)
log.Info("Web admin user configured")
} else {
log.Warn("Neither YAMA_WEB_ADMIN_PASS nor YAMA_PWD is set; web login will be unavailable")
}
// Create servers for each port // Create servers for each port
var servers []*server.Server var servers []*server.Server
for _, port := range ports { for _, port := range ports {
@@ -284,6 +361,7 @@ func main() {
log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)), log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)),
auth: authenticator, auth: authenticator,
srv: srv, srv: srv,
hub: deviceHub,
} }
srv.SetHandler(handler) srv.SetHandler(handler)
@@ -297,8 +375,9 @@ func main() {
} }
} }
// Start HTTP server for web UI (Phase 1: serves index.html only) // Start HTTP server for web UI. Hub gives it read-only access to the
httpSrv := web.New(*httpPort, log.WithPrefix("Web")) // device registry; the authenticator owns user accounts and session tokens.
httpSrv := web.New(*httpPort, log.WithPrefix("Web"), deviceHub, webAuth)
if err := httpSrv.Start(); err != nil { if err := httpSrv.Start(); err != nil {
log.Fatal("Failed to start HTTP server: %v", err) log.Fatal("Failed to start HTTP server: %v", err)
} }

View File

@@ -3,6 +3,7 @@ module github.com/yuanyuanxiang/SimpleRemoter/server/go
go 1.24.5 go 1.24.5
require ( require (
github.com/gorilla/websocket v1.5.3
github.com/klauspost/compress v1.18.2 github.com/klauspost/compress v1.18.2
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
golang.org/x/text v0.32.0 golang.org/x/text v0.32.0

View File

@@ -1,5 +1,7 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=

217
server/go/hub/hub.go Normal file
View File

@@ -0,0 +1,217 @@
// Package hub maintains the registry of currently online devices and acts as
// the bridge between the TCP server (which sees raw client connections) and
// the web server (which serves browser clients).
//
// The TCP side calls RegisterDevice / UnregisterDevice as clients come and go.
// The web side calls ListDevices / GetDevice / (Phase 4) SendToDevice.
// Neither side imports the other — both depend only on this package.
//
// Phase 3 scope: device list only. Frame/cursor pub-sub and SendToDevice are
// added in later phases as features need them.
package hub
import (
"sync"
"time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
)
// Device is the internal record for one logical end-device (keyed by MasterID).
// A single device may use multiple TCP sub-connections (screen, terminal …);
// only the main login connection is stored here.
//
// PCName from LOGIN_INFOR is interpreted as "ComputerName/Group" and
// ModuleVersion as "Version-Capability"; the split halves live in separate
// fields so the front-end can render them independently.
type Device struct {
ID string // MasterID — stable identifier the client reports at login
Name string // PCName before '/' (real computer name)
Group string // PCName after '/' (group label; may be empty)
Version string // ModuleVersion before '-' (semantic version)
Capability string // ModuleVersion after '-' (capability tags; may be empty)
OS string // OS version string
CPU string // from LOGIN_INFOR reserved field 2
FilePath string // from LOGIN_INFOR reserved field 4
InstallTime string // from LOGIN_INFOR reserved field 6 (or StartTime)
Location string // client-reported geo string (reserved field 10)
PeerIP string // network-level remote address as seen by the server
PublicIP string // client-reported public IP (reserved field 11)
ConnectedAt time.Time
// Live fields refreshed on every heartbeat. Protected by hub.mu.
RTT int // network latency in ms (Heartbeat.Ping)
ActiveWindow string // foreground window title (Heartbeat.ActiveWnd, decoded)
// conn is the main connection's context. Web side will use it in Phase 4
// to push COMMAND_SCREEN_SPY and similar commands via the hub.
conn *connection.Context
}
// DeviceInfo is the JSON-safe projection of Device for the /api/devices
// endpoint and the WS device_list message. Field names match what the
// existing browser front-end expects.
type DeviceInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Group string `json:"group,omitempty"`
Version string `json:"version"`
Capability string `json:"capability,omitempty"`
OS string `json:"os"`
CPU string `json:"cpu,omitempty"`
FilePath string `json:"file_path,omitempty"`
InstallTime string `json:"install_time,omitempty"`
Location string `json:"location,omitempty"`
IP string `json:"ip"` // client-reported public IP (matches C++ key)
PeerIP string `json:"peer_ip,omitempty"`
RTT int `json:"rtt"`
ActiveWindow string `json:"activeWindow,omitempty"`
ConnectedAt int64 `json:"connected_at"`
Online bool `json:"online"`
}
// EventHandler receives notifications about device lifecycle and per-tick
// live updates. Methods are invoked synchronously from Register / Unregister /
// UpdateLive — implementations must be non-blocking (typically just write to
// a channel or queue).
type EventHandler interface {
OnDeviceOnline(d DeviceInfo)
OnDeviceOffline(id string)
OnDeviceUpdate(id string, rtt int, activeWindow string)
}
// Hub is a thread-safe registry of online devices.
type Hub struct {
mu sync.RWMutex
devices map[string]*Device
subMu sync.RWMutex
subscribers []EventHandler
}
// New returns an empty Hub.
func New() *Hub {
return &Hub{devices: make(map[string]*Device)}
}
// Subscribe registers an EventHandler. The returned func removes it.
// Multiple handlers are supported; each receives every event.
func (h *Hub) Subscribe(eh EventHandler) (unsubscribe func()) {
h.subMu.Lock()
h.subscribers = append(h.subscribers, eh)
h.subMu.Unlock()
return func() {
h.subMu.Lock()
defer h.subMu.Unlock()
for i, x := range h.subscribers {
if x == eh {
h.subscribers = append(h.subscribers[:i], h.subscribers[i+1:]...)
return
}
}
}
}
func (h *Hub) snapshotSubscribers() []EventHandler {
h.subMu.RLock()
defer h.subMu.RUnlock()
out := make([]EventHandler, len(h.subscribers))
copy(out, h.subscribers)
return out
}
// Register records a device as online. Re-registering an existing ID overwrites
// the previous entry (e.g. a client reconnect with the same MasterID).
// A nil device or empty ID is silently ignored.
// Subscribers are notified after the device is added.
func (h *Hub) Register(d *Device) {
if d == nil || d.ID == "" {
return
}
h.mu.Lock()
h.devices[d.ID] = d
info := deviceToInfo(d)
h.mu.Unlock()
for _, s := range h.snapshotSubscribers() {
s.OnDeviceOnline(info)
}
}
// Unregister removes a device by ID. No-op if not present.
// Subscribers are notified after the device is removed (only if it existed).
func (h *Hub) Unregister(id string) {
if id == "" {
return
}
h.mu.Lock()
_, existed := h.devices[id]
delete(h.devices, id)
h.mu.Unlock()
if !existed {
return
}
for _, s := range h.snapshotSubscribers() {
s.OnDeviceOffline(id)
}
}
// ListDevices returns a fresh snapshot slice. The caller may mutate it freely;
// it shares no state with the hub.
func (h *Hub) ListDevices() []DeviceInfo {
h.mu.RLock()
defer h.mu.RUnlock()
out := make([]DeviceInfo, 0, len(h.devices))
for _, d := range h.devices {
out = append(out, deviceToInfo(d))
}
return out
}
func deviceToInfo(d *Device) DeviceInfo {
return DeviceInfo{
ID: d.ID,
Name: d.Name,
Group: d.Group,
Version: d.Version,
Capability: d.Capability,
OS: d.OS,
CPU: d.CPU,
FilePath: d.FilePath,
InstallTime: d.InstallTime,
Location: d.Location,
IP: d.PublicIP,
PeerIP: d.PeerIP,
RTT: d.RTT,
ActiveWindow: d.ActiveWindow,
ConnectedAt: d.ConnectedAt.Unix(),
Online: true, // a device that's in the map is by definition online
}
}
// UpdateLive applies a heartbeat-derived RTT and active-window title to the
// device's live fields, then notifies subscribers. No-op if the device is
// not registered (e.g. heartbeat arriving for a connection that never sent
// TOKEN_LOGIN or has already disconnected).
func (h *Hub) UpdateLive(id string, rtt int, activeWindow string) {
if id == "" {
return
}
h.mu.Lock()
d, ok := h.devices[id]
if !ok {
h.mu.Unlock()
return
}
d.RTT = rtt
d.ActiveWindow = activeWindow
h.mu.Unlock()
for _, s := range h.snapshotSubscribers() {
s.OnDeviceUpdate(id, rtt, activeWindow)
}
}
// Count returns the current number of online devices.
func (h *Hub) Count() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.devices)
}

152
server/go/hub/hub_test.go Normal file
View File

@@ -0,0 +1,152 @@
package hub
import (
"fmt"
"sync"
"testing"
"time"
)
func TestHubRegisterListUnregister(t *testing.T) {
h := New()
if got := h.Count(); got != 0 {
t.Fatalf("empty hub: want Count=0, got %d", got)
}
h.Register(&Device{ID: "a", Name: "Alice", ConnectedAt: time.Now()})
h.Register(&Device{ID: "b", Name: "Bob", ConnectedAt: time.Now()})
if got := h.Count(); got != 2 {
t.Fatalf("after 2 registers: want Count=2, got %d", got)
}
list := h.ListDevices()
if len(list) != 2 {
t.Fatalf("want 2 devices in list, got %d", len(list))
}
h.Unregister("a")
if got := h.Count(); got != 1 {
t.Fatalf("after unregister: want Count=1, got %d", got)
}
// Unregister non-existent ID is a no-op
h.Unregister("ghost")
if got := h.Count(); got != 1 {
t.Fatalf("after no-op unregister: want Count=1, got %d", got)
}
}
func TestHubNilAndEmptyIgnored(t *testing.T) {
h := New()
h.Register(nil)
h.Register(&Device{ID: ""})
h.Unregister("")
if got := h.Count(); got != 0 {
t.Fatalf("nil/empty register should be no-op, got Count=%d", got)
}
}
type captureHandler struct {
mu sync.Mutex
online []string
offline []string
updates []string // formatted "id:rtt"
}
func (c *captureHandler) OnDeviceOnline(d DeviceInfo) {
c.mu.Lock()
c.online = append(c.online, d.ID)
c.mu.Unlock()
}
func (c *captureHandler) OnDeviceOffline(id string) {
c.mu.Lock()
c.offline = append(c.offline, id)
c.mu.Unlock()
}
func (c *captureHandler) OnDeviceUpdate(id string, rtt int, _ string) {
c.mu.Lock()
c.updates = append(c.updates, fmt.Sprintf("%s:%d", id, rtt))
c.mu.Unlock()
}
func TestHubSubscribeEvents(t *testing.T) {
h := New()
c := &captureHandler{}
unsub := h.Subscribe(c)
h.Register(&Device{ID: "x", Name: "x"})
h.Register(&Device{ID: "y", Name: "y"})
h.Unregister("x")
h.Unregister("nonexistent") // no event
if len(c.online) != 2 || c.online[0] != "x" || c.online[1] != "y" {
t.Fatalf("online events: %+v", c.online)
}
if len(c.offline) != 1 || c.offline[0] != "x" {
t.Fatalf("offline events: %+v", c.offline)
}
unsub()
h.Register(&Device{ID: "z"})
if len(c.online) != 2 {
t.Fatalf("after unsubscribe should not receive events: %+v", c.online)
}
}
func TestHubUpdateLive(t *testing.T) {
h := New()
c := &captureHandler{}
h.Subscribe(c)
h.Register(&Device{ID: "x", Name: "x"})
h.UpdateLive("x", 42, "Notepad")
h.UpdateLive("ghost", 999, "should be ignored") // unknown id, no event
if len(c.updates) != 1 || c.updates[0] != "x:42" {
t.Fatalf("updates: %+v", c.updates)
}
list := h.ListDevices()
if list[0].RTT != 42 || list[0].ActiveWindow != "Notepad" {
t.Fatalf("live fields not applied: %+v", list[0])
}
}
func TestHubRegisterOverwrites(t *testing.T) {
h := New()
h.Register(&Device{ID: "x", Name: "first"})
h.Register(&Device{ID: "x", Name: "second"})
list := h.ListDevices()
if len(list) != 1 || list[0].Name != "second" {
t.Fatalf("re-register should overwrite, got %+v", list)
}
}
// Race detector should not fire under `go test -race ./hub/...`.
func TestHubConcurrent(t *testing.T) {
h := New()
const goroutines = 50
const opsPer = 100
var wg sync.WaitGroup
for g := range goroutines {
wg.Add(1)
go func(g int) {
defer wg.Done()
for i := range opsPer {
id := fmt.Sprintf("g%d-%d", g, i)
h.Register(&Device{ID: id, Name: id, ConnectedAt: time.Now()})
_ = h.ListDevices()
_ = h.Count()
h.Unregister(id)
}
}(g)
}
wg.Wait()
if got := h.Count(); got != 0 {
t.Fatalf("after all unregisters: want 0, got %d", got)
}
}

View File

@@ -9,8 +9,11 @@ import (
"golang.org/x/text/transform" "golang.org/x/text/transform"
) )
// gbkToUTF8 converts GBK encoded bytes to UTF-8 string // GbkToUTF8 converts GBK encoded bytes to UTF-8 string. The input is treated
func gbkToUTF8(data []byte) string { // as a null-terminated GBK buffer (typical for Windows clients); content
// after the first NUL byte is discarded. Non-printable characters are
// stripped from the result.
func GbkToUTF8(data []byte) string {
// Find the first null byte and truncate there // Find the first null byte and truncate there
if idx := bytes.IndexByte(data, 0); idx >= 0 { if idx := bytes.IndexByte(data, 0); idx >= 0 {
data = data[:idx] data = data[:idx]
@@ -111,17 +114,17 @@ func ParseLoginInfo(data []byte) (*LoginInfo, error) {
// Parse module version (offset 164, 24 bytes) // Parse module version (offset 164, 24 bytes)
// This contains date string like "Dec 19 2025" // This contains date string like "Dec 19 2025"
if len(data) >= OffsetModuleVersion+24 { if len(data) >= OffsetModuleVersion+24 {
info.ModuleVersion = gbkToUTF8(data[OffsetModuleVersion : OffsetModuleVersion+24]) info.ModuleVersion = GbkToUTF8(data[OffsetModuleVersion : OffsetModuleVersion+24])
} }
// Parse PC name (offset 188, 240 bytes) // Parse PC name (offset 188, 240 bytes)
if len(data) >= OffsetPCName+240 { if len(data) >= OffsetPCName+240 {
info.PCName = gbkToUTF8(data[OffsetPCName : OffsetPCName+240]) info.PCName = GbkToUTF8(data[OffsetPCName : OffsetPCName+240])
} }
// Parse Master ID (offset 428, 20 bytes) // Parse Master ID (offset 428, 20 bytes)
if len(data) >= OffsetMasterID+20 { if len(data) >= OffsetMasterID+20 {
info.MasterID = gbkToUTF8(data[OffsetMasterID : OffsetMasterID+20]) info.MasterID = GbkToUTF8(data[OffsetMasterID : OffsetMasterID+20])
} }
// Parse WebCam exist (offset 448, 4 bytes) // Parse WebCam exist (offset 448, 4 bytes)
@@ -136,14 +139,14 @@ func ParseLoginInfo(data []byte) (*LoginInfo, error) {
// Parse Start time (offset 456, 20 bytes) // Parse Start time (offset 456, 20 bytes)
if len(data) >= OffsetStartTime+20 { if len(data) >= OffsetStartTime+20 {
info.StartTime = gbkToUTF8(data[OffsetStartTime : OffsetStartTime+20]) info.StartTime = GbkToUTF8(data[OffsetStartTime : OffsetStartTime+20])
} }
// Parse Reserved (offset 476, 512 bytes) - contains additional info // Parse Reserved (offset 476, 512 bytes) - contains additional info
if len(data) >= OffsetReserved+512 { if len(data) >= OffsetReserved+512 {
info.Reserved = gbkToUTF8(data[OffsetReserved : OffsetReserved+512]) info.Reserved = GbkToUTF8(data[OffsetReserved : OffsetReserved+512])
} else if len(data) > OffsetReserved { } else if len(data) > OffsetReserved {
info.Reserved = gbkToUTF8(data[OffsetReserved:]) info.Reserved = GbkToUTF8(data[OffsetReserved:])
} }
return info, nil return info, nil
@@ -152,7 +155,7 @@ func ParseLoginInfo(data []byte) (*LoginInfo, error) {
// parseOsVersionInfo parses the OS version info field // parseOsVersionInfo parses the OS version info field
// The C++ client fills this with a readable string like "Windows 10" via getSystemName() // The C++ client fills this with a readable string like "Windows 10" via getSystemName()
func parseOsVersionInfo(data []byte) string { func parseOsVersionInfo(data []byte) string {
return gbkToUTF8(data) return GbkToUTF8(data)
} }
// ParseReserved parses the reserved field into a slice of strings // ParseReserved parses the reserved field into a slice of strings

View File

@@ -2,6 +2,7 @@ package web
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -9,21 +10,28 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger" "github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/wsauth"
) )
// Server serves the web remote desktop UI: the embedded index.html, xterm.js // Server serves the web remote desktop UI: the embedded index.html, xterm.js
// static assets, and the PWA manifest. WebSocket signaling, device list and // static assets, the PWA manifest, and JSON APIs backed by the device hub.
// screen streaming will be wired up in later phases. // WebSocket signaling and screen streaming will be wired up in later phases.
type Server struct { type Server struct {
port int port int
log *logger.Logger log *logger.Logger
srv *http.Server srv *http.Server
hub *hub.Hub
auth *wsauth.Authenticator
ws *wsHub
} }
// New creates an HTTP server bound to the given port. port=0 disables the server. // New creates an HTTP server bound to the given port. port=0 disables the server.
func New(port int, log *logger.Logger) *Server { // The hub provides read access to the online-device registry; the authenticator
return &Server{port: port, log: log} // owns user accounts and session tokens.
func New(port int, log *logger.Logger, h *hub.Hub, auth *wsauth.Authenticator) *Server {
return &Server{port: port, log: log, hub: h, auth: auth}
} }
// Start launches the server in a goroutine and returns immediately. // Start launches the server in a goroutine and returns immediately.
@@ -34,10 +42,14 @@ func (s *Server) Start() error {
return nil return nil
} }
s.ws = newWSHub(s.auth, s.hub, s.log)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", s.handleIndex) mux.HandleFunc("/", s.handleIndex)
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
mux.HandleFunc("/manifest.json", s.handleManifest) mux.HandleFunc("/manifest.json", s.handleManifest)
mux.HandleFunc("/api/devices", s.handleDevices)
mux.HandleFunc("/ws", s.ws.serve)
mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8")) mux.HandleFunc("/static/xterm.js", staticHandler(xtermJS, "application/javascript; charset=utf-8"))
mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8")) mux.HandleFunc("/static/xterm.css", staticHandler(xtermCSS, "text/css; charset=utf-8"))
mux.HandleFunc("/static/xterm-fit.js", staticHandler(xtermFitJS, "application/javascript; charset=utf-8")) mux.HandleFunc("/static/xterm-fit.js", staticHandler(xtermFitJS, "application/javascript; charset=utf-8"))
@@ -66,6 +78,9 @@ func (s *Server) Start() error {
// Stop gracefully shuts the server down. // Stop gracefully shuts the server down.
func (s *Server) Stop() { func (s *Server) Stop() {
if s.ws != nil {
s.ws.stop()
}
if s.srv == nil { if s.srv == nil {
return return
} }
@@ -89,6 +104,18 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"status":"ok"}`)) _, _ = w.Write([]byte(`{"status":"ok"}`))
} }
// handleDevices returns a JSON snapshot of currently-online devices. Empty
// array (not null) when no clients are connected — matches what the front-end
// will eventually expect.
func (s *Server) handleDevices(w http.ResponseWriter, r *http.Request) {
devices := s.hub.ListDevices()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("Cache-Control", "no-store")
if err := json.NewEncoder(w).Encode(devices); err != nil {
s.log.Error("encode /api/devices: %v", err)
}
}
// PWA manifest. Referenced by <link rel="manifest"> in index.html. // PWA manifest. Referenced by <link rel="manifest"> in index.html.
// Static JSON, no template needed. // Static JSON, no template needed.
const manifestJSON = `{ const manifestJSON = `{

222
server/go/web/ws.go Normal file
View File

@@ -0,0 +1,222 @@
package web
import (
"encoding/json"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/hub"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/wsauth"
)
// ----- WS framing knobs ---------------------------------------------------
const (
wsWriteWait = 10 * time.Second // single-frame write deadline
wsReadLimit = 1 << 20 // refuse incoming frames over 1 MB
wsSendBuffer = 64 // outbound queue depth per client
)
// upgrader allows any origin — this service is meant to be tunneled through
// frp, so requests can legitimately arrive from arbitrary front-end hosts.
// Adjust CheckOrigin once we have a deployment story.
var upgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool { return true },
}
// ----- per-connection client state ----------------------------------------
type wsClient struct {
conn *websocket.Conn
send chan []byte
closed chan struct{}
once sync.Once
// Mutated under wsHub.mu (or only by the read loop owning this client).
nonce string // outstanding challenge — cleared after a successful login
token string // set once authenticated
role string // mirrors session role after login
addr string // client address for logs
}
// queue writes a payload onto the send buffer. Drops silently if the buffer
// is full so a stuck reader can't back-pressure the broadcast path.
func (c *wsClient) queue(payload []byte) {
select {
case c.send <- payload:
case <-c.closed:
default:
// queue full — caller is responsible for noticing if it matters.
}
}
// close signals both loops to exit. Safe to call multiple times.
func (c *wsClient) close() {
c.once.Do(func() {
close(c.closed)
_ = c.conn.Close()
})
}
// ----- ws hub: registry of all connected browsers -------------------------
type wsHub struct {
auth *wsauth.Authenticator
devices *hub.Hub
log *logger.Logger
mu sync.RWMutex
clients map[*wsClient]struct{}
unsub func()
}
func newWSHub(auth *wsauth.Authenticator, devices *hub.Hub, log *logger.Logger) *wsHub {
h := &wsHub{
auth: auth,
devices: devices,
log: log,
clients: make(map[*wsClient]struct{}),
}
h.unsub = devices.Subscribe(h)
return h
}
// stop unsubscribes from the device hub. Existing connections keep running
// until they close on their own; we only block new event delivery.
func (h *wsHub) stop() {
if h.unsub != nil {
h.unsub()
h.unsub = nil
}
}
// hub.EventHandler — invoked from hub.Register / hub.Unregister.
func (h *wsHub) OnDeviceOnline(_ hub.DeviceInfo) {
h.broadcastAuthenticated(`{"cmd":"devices_changed"}`)
}
func (h *wsHub) OnDeviceOffline(_ string) {
h.broadcastAuthenticated(`{"cmd":"devices_changed"}`)
}
// OnDeviceUpdate forwards heartbeat-derived liveness data so the device-list
// rows can refresh RTT and active-window labels without re-fetching.
func (h *wsHub) OnDeviceUpdate(id string, rtt int, activeWindow string) {
payload := mustJSON(map[string]any{
"cmd": "device_update",
"id": id,
"rtt": rtt,
"activeWindow": activeWindow,
})
h.mu.RLock()
defer h.mu.RUnlock()
for c := range h.clients {
if c.token != "" {
c.queue(payload)
}
}
}
func (h *wsHub) broadcastAuthenticated(msg string) {
payload := []byte(msg)
h.mu.RLock()
defer h.mu.RUnlock()
for c := range h.clients {
if c.token != "" {
c.queue(payload)
}
}
}
func (h *wsHub) register(c *wsClient) {
h.mu.Lock()
h.clients[c] = struct{}{}
h.mu.Unlock()
}
func (h *wsHub) unregister(c *wsClient) {
h.mu.Lock()
delete(h.clients, c)
h.mu.Unlock()
// Do NOT revoke the token: tokens are session-scoped, not WS-scoped.
// Frontend may close+reopen the WS at any time (visibilitychange handler,
// brief network blip, reload) and must be able to resume with the same
// cached token. The token expires on its own TTL.
c.close()
}
// ----- HTTP handler -------------------------------------------------------
func (h *wsHub) serve(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.log.Error("ws upgrade: %v", err)
return
}
conn.SetReadLimit(wsReadLimit)
nonce, err := wsauth.NewNonce()
if err != nil {
h.log.Error("nonce gen: %v", err)
_ = conn.Close()
return
}
client := &wsClient{
conn: conn,
send: make(chan []byte, wsSendBuffer),
closed: make(chan struct{}),
nonce: nonce,
addr: r.RemoteAddr,
}
h.register(client)
defer h.unregister(client)
go h.writeLoop(client)
// Greet with a challenge nonce so the browser can compute the login response.
client.queue([]byte(`{"cmd":"challenge","nonce":"` + nonce + `"}`))
h.readLoop(client)
}
// writeLoop drains the send queue. Exits when the channel is closed or a
// write fails. Closing the underlying connection is the read loop's job.
func (h *wsHub) writeLoop(c *wsClient) {
for {
select {
case msg := <-c.send:
_ = c.conn.SetWriteDeadline(time.Now().Add(wsWriteWait))
if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
c.close()
return
}
case <-c.closed:
return
}
}
}
// readLoop dispatches incoming messages. Exits on read error (peer closed,
// timeout, malformed frame, etc.), which then triggers unregister cleanup.
func (h *wsHub) readLoop(c *wsClient) {
for {
_, raw, err := c.conn.ReadMessage()
if err != nil {
return
}
var env struct {
Cmd string `json:"cmd"`
}
if err := json.Unmarshal(raw, &env); err != nil {
continue // ignore garbage frames
}
h.dispatch(c, env.Cmd, raw)
}
}

View File

@@ -0,0 +1,166 @@
package web
import (
"encoding/json"
)
// dispatch routes one inbound message to its handler. The `raw` payload is
// passed through so handlers can re-parse to their own shape.
//
// Phase 3 implements: get_salt, login, get_devices, ping, disconnect.
// Phase 4/5/6 commands (connect, mouse, key, term_*, etc.) get a friendly
// "not yet implemented" reply so the browser UI doesn't hang silently.
func (h *wsHub) dispatch(c *wsClient, cmd string, raw []byte) {
switch cmd {
case "get_salt":
h.handleGetSalt(c, raw)
case "login":
h.handleLogin(c, raw)
case "get_devices":
h.handleGetDevices(c, raw)
case "ping":
// no-op heartbeat; the read itself was the keep-alive signal
case "disconnect":
c.queue([]byte(`{"cmd":"disconnect_result","ok":true}`))
// Reserved for later phases. Reply with a benign failure so the UI can
// surface a clear error instead of spinning indefinitely.
case "connect":
h.replyNotImplemented(c, "connect_result", "Screen sharing not yet implemented on Go server")
case "rdp_reset":
// silently ignored — UI uses this as a fire-and-forget
case "mouse", "key":
// silently ignored — no remote screen yet
case "term_open":
h.replyNotImplemented(c, "term_closed", "Web terminal not yet implemented on Go server")
case "term_input", "term_resize", "term_close":
// silently ignored — no terminal session
// Admin operations (Phase 7).
case "create_user":
h.replyNotImplemented(c, "create_user_result", "User management not yet implemented")
case "delete_user":
h.replyNotImplemented(c, "delete_user_result", "User management not yet implemented")
case "list_users":
h.replyNotImplemented(c, "list_users_result", "User management not yet implemented")
case "get_groups":
c.queue([]byte(`{"cmd":"groups","ok":true,"groups":[]}`))
}
}
func (h *wsHub) replyNotImplemented(c *wsClient, replyCmd, msg string) {
c.queue(mustJSON(map[string]any{
"cmd": replyCmd,
"ok": false,
"msg": msg,
}))
}
// ----- handlers ------------------------------------------------------------
func (h *wsHub) handleGetSalt(c *wsClient, raw []byte) {
var in struct {
Username string `json:"username"`
}
_ = json.Unmarshal(raw, &in)
salt, ok := h.auth.GetSalt(in.Username)
// Do not leak which usernames exist: always return ok=true with a salt.
// For unknown users hand back the empty salt (matches admin convention)
// so the timing/shape of the response is uniform.
if !ok {
salt = ""
}
c.queue(mustJSON(map[string]any{
"cmd": "salt",
"ok": true,
"salt": salt,
}))
}
func (h *wsHub) handleLogin(c *wsClient, raw []byte) {
var in struct {
Username string `json:"username"`
Response string `json:"response"`
Nonce string `json:"nonce"`
}
if err := json.Unmarshal(raw, &in); err != nil {
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid request"}))
return
}
// Bind the response to the challenge we issued at connect time so that
// replays from a different connection can't reuse a captured response.
if in.Nonce == "" || in.Nonce != c.nonce {
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid challenge"}))
return
}
token, role, err := h.auth.VerifyLogin(in.Username, in.Response, in.Nonce)
if err != nil {
// Burn the challenge on failure too — forces a new round on retry.
c.nonce = ""
c.queue(mustJSON(map[string]any{"cmd": "login_result", "ok": false, "msg": "Invalid credentials"}))
return
}
c.nonce = ""
c.token = token
c.role = role
h.log.Info("ws login: user=%s role=%s addr=%s", in.Username, role, c.addr)
c.queue(mustJSON(map[string]any{
"cmd": "login_result",
"ok": true,
"token": token,
"role": role,
}))
}
func (h *wsHub) handleGetDevices(c *wsClient, raw []byte) {
if !h.requireAuth(c, raw, "device_list") {
return
}
devices := h.devices.ListDevices()
c.queue(mustJSON(map[string]any{
"cmd": "device_list",
"ok": true,
"devices": devices,
}))
}
// requireAuth validates the token embedded in raw against the authenticator's
// session store (not against c.token). Tokens live independently of WS
// connections — the browser may reconnect after a visibility/network blip and
// resume with the same token, so we must not tie validity to one WS lifetime.
// On the first authenticated message we cache the token/role on the wsClient
// so broadcasts know to deliver to this connection.
func (h *wsHub) requireAuth(c *wsClient, raw []byte, replyCmd string) bool {
var in struct {
Token string `json:"token"`
}
_ = json.Unmarshal(raw, &in)
if in.Token == "" {
c.queue(mustJSON(map[string]any{"cmd": replyCmd, "ok": false}))
return false
}
sess, err := h.auth.ValidateToken(in.Token)
if err != nil {
c.queue(mustJSON(map[string]any{"cmd": replyCmd, "ok": false}))
return false
}
if c.token == "" {
c.token = in.Token
c.role = sess.Role
}
return true
}
func mustJSON(v any) []byte {
b, err := json.Marshal(v)
if err != nil {
// All callers pass simple map[string]any with primitive values;
// marshal can't realistically fail. If it does, return a safe fallback.
return []byte(`{"cmd":"error","msg":"internal encode error"}`)
}
return b
}

192
server/go/wsauth/wsauth.go Normal file
View File

@@ -0,0 +1,192 @@
// Package wsauth provides authentication and session-token management for
// the web service. Protocol surface (challenge nonce + SHA256-based response
// and SHA256(password+salt) hashes) is kept compatible with the existing
// browser front-end and users.json format. Internal token representation is
// deliberately different from the C++ counterpart — opaque random hex strings
// keyed into an in-memory map — to avoid leaking the proprietary token format.
package wsauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"sync"
"time"
)
// Default knobs. Override via SetDefaults at startup if needed.
const (
DefaultTokenExpire = 24 * time.Hour
nonceBytes = 16 // 32 hex chars
tokenBytes = 32 // 64 hex chars
saltBytes = 8 // 16 hex chars
)
// ErrInvalidToken is returned when a token is unknown or expired.
var ErrInvalidToken = errors.New("invalid or expired token")
// User is the credentials record for one web account.
type User struct {
Username string
PasswordHash string // SHA256(password+salt) in lowercase hex
Salt string // empty for admin (matches C++ convention)
Role string // "admin" or "viewer"
}
// Session is the authenticated state attached to a valid token.
type Session struct {
Username string
Role string
ExpiresAt time.Time
}
// Authenticator owns the user table and the active token map. It is safe to
// use from multiple goroutines.
type Authenticator struct {
mu sync.RWMutex
users map[string]*User // username -> user
tokens map[string]*Session // token -> session
tokenExpire time.Duration
}
// New returns an empty Authenticator. Call AddUser to populate.
func New() *Authenticator {
return &Authenticator{
users: make(map[string]*User),
tokens: make(map[string]*Session),
tokenExpire: DefaultTokenExpire,
}
}
// SetTokenExpire overrides the default session lifetime.
func (a *Authenticator) SetTokenExpire(d time.Duration) {
if d <= 0 {
return
}
a.mu.Lock()
a.tokenExpire = d
a.mu.Unlock()
}
// AddUser registers a user. PasswordHash should already be
// SHA256(password+salt) in lowercase hex; pass empty Salt to mirror the
// admin-style "no salt" convention used by the C++ side.
func (a *Authenticator) AddUser(u User) {
if u.Username == "" {
return
}
a.mu.Lock()
a.users[u.Username] = &u
a.mu.Unlock()
}
// AddAdminFromPlainPassword is a convenience for the bootstrap admin: salt is
// empty (matching the C++ admin record), hash is SHA256(password).
func (a *Authenticator) AddAdminFromPlainPassword(username, plainPassword string) {
a.AddUser(User{
Username: username,
PasswordHash: ComputeSHA256(plainPassword),
Salt: "",
Role: "admin",
})
}
// GetSalt returns the per-user salt. If the user does not exist, returns ("", false).
// Note: the C++ admin uses an empty salt — that is still considered "found"
// and the empty string is returned with ok=true.
func (a *Authenticator) GetSalt(username string) (string, bool) {
a.mu.RLock()
u, ok := a.users[username]
a.mu.RUnlock()
if !ok {
return "", false
}
return u.Salt, true
}
// VerifyLogin checks a challenge-response login. The browser sends
// response = SHA256(passwordHash + nonce). On success the function mints a
// new session token, stores it, and returns (token, role, nil).
func (a *Authenticator) VerifyLogin(username, response, nonce string) (token, role string, err error) {
a.mu.RLock()
u, ok := a.users[username]
expire := a.tokenExpire
a.mu.RUnlock()
if !ok {
return "", "", errors.New("invalid credentials")
}
expected := ComputeSHA256(u.PasswordHash + nonce)
if response != expected {
return "", "", errors.New("invalid credentials")
}
token, err = randomHex(tokenBytes)
if err != nil {
return "", "", err
}
a.mu.Lock()
a.tokens[token] = &Session{
Username: username,
Role: u.Role,
ExpiresAt: time.Now().Add(expire),
}
a.mu.Unlock()
return token, u.Role, nil
}
// ValidateToken returns the session for a token or ErrInvalidToken. Expired
// tokens are removed lazily as they are looked up.
func (a *Authenticator) ValidateToken(token string) (*Session, error) {
a.mu.RLock()
s, ok := a.tokens[token]
a.mu.RUnlock()
if !ok {
return nil, ErrInvalidToken
}
if time.Now().After(s.ExpiresAt) {
a.mu.Lock()
delete(a.tokens, token)
a.mu.Unlock()
return nil, ErrInvalidToken
}
return s, nil
}
// RevokeToken removes a token from the active set. No-op for unknown tokens.
func (a *Authenticator) RevokeToken(token string) {
a.mu.Lock()
delete(a.tokens, token)
a.mu.Unlock()
}
// NewNonce returns a fresh challenge nonce (hex string). Each WS connection
// should receive exactly one nonce, consumed by a single login attempt.
func NewNonce() (string, error) {
return randomHex(nonceBytes)
}
// NewSalt returns a fresh per-user salt (hex string).
func NewSalt() (string, error) {
return randomHex(saltBytes)
}
// ComputeSHA256 returns the lowercase-hex SHA256 of s.
func ComputeSHA256(s string) string {
sum := sha256.Sum256([]byte(s))
return hex.EncodeToString(sum[:])
}
// HashPassword computes the stored hash for a (password, salt) pair using
// the same scheme as the existing C++ users.json: SHA256(password + salt).
func HashPassword(password, salt string) string {
return ComputeSHA256(password + salt)
}
func randomHex(n int) (string, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -0,0 +1,116 @@
package wsauth
import (
"testing"
"time"
)
func TestSHA256Vector(t *testing.T) {
// Known vector — keeps us honest against accidental algorithm changes.
got := ComputeSHA256("abc")
want := "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
if got != want {
t.Fatalf("SHA256(abc): got %s want %s", got, want)
}
}
func TestLoginRoundTripAdminEmptySalt(t *testing.T) {
a := New()
a.AddAdminFromPlainPassword("admin", "hunter2")
salt, ok := a.GetSalt("admin")
if !ok || salt != "" {
t.Fatalf("admin salt: ok=%v salt=%q", ok, salt)
}
// Simulate the browser: nonce = "abc123", response = SHA256(passwordHash + nonce)
nonce := "abc123"
passwordHash := ComputeSHA256("hunter2")
response := ComputeSHA256(passwordHash + nonce)
token, role, err := a.VerifyLogin("admin", response, nonce)
if err != nil {
t.Fatalf("VerifyLogin: %v", err)
}
if role != "admin" {
t.Fatalf("role: got %q want admin", role)
}
if len(token) != 2*tokenBytes {
t.Fatalf("token length: got %d want %d", len(token), 2*tokenBytes)
}
sess, err := a.ValidateToken(token)
if err != nil {
t.Fatalf("ValidateToken: %v", err)
}
if sess.Username != "admin" || sess.Role != "admin" {
t.Fatalf("session: %+v", sess)
}
}
func TestLoginRoundTripViewerWithSalt(t *testing.T) {
a := New()
salt, _ := NewSalt()
a.AddUser(User{
Username: "alice",
PasswordHash: HashPassword("p@ss", salt),
Salt: salt,
Role: "viewer",
})
gotSalt, ok := a.GetSalt("alice")
if !ok || gotSalt != salt {
t.Fatalf("salt: ok=%v got=%q want=%q", ok, gotSalt, salt)
}
nonce, _ := NewNonce()
response := ComputeSHA256(HashPassword("p@ss", salt) + nonce)
_, role, err := a.VerifyLogin("alice", response, nonce)
if err != nil || role != "viewer" {
t.Fatalf("VerifyLogin: role=%q err=%v", role, err)
}
}
func TestLoginRejectsWrongResponse(t *testing.T) {
a := New()
a.AddAdminFromPlainPassword("admin", "x")
_, _, err := a.VerifyLogin("admin", "deadbeef", "nonce")
if err == nil {
t.Fatal("expected error for bad response")
}
_, _, err = a.VerifyLogin("ghost", "anything", "anything")
if err == nil {
t.Fatal("expected error for unknown user")
}
}
func TestTokenExpiry(t *testing.T) {
a := New()
a.SetTokenExpire(50 * time.Millisecond)
a.AddAdminFromPlainPassword("admin", "x")
nonce, _ := NewNonce()
response := ComputeSHA256(ComputeSHA256("x") + nonce)
token, _, err := a.VerifyLogin("admin", response, nonce)
if err != nil {
t.Fatal(err)
}
if _, err := a.ValidateToken(token); err != nil {
t.Fatalf("fresh token should validate: %v", err)
}
time.Sleep(80 * time.Millisecond)
if _, err := a.ValidateToken(token); err == nil {
t.Fatal("expired token should not validate")
}
}
func TestRevoke(t *testing.T) {
a := New()
a.AddAdminFromPlainPassword("admin", "x")
nonce, _ := NewNonce()
response := ComputeSHA256(ComputeSHA256("x") + nonce)
token, _, _ := a.VerifyLogin("admin", response, nonce)
a.RevokeToken(token)
if _, err := a.ValidateToken(token); err == nil {
t.Fatal("revoked token should not validate")
}
}