Files
SimpleRemoter/server/2015Remote/IOCPServer.cpp

1376 lines
52 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "StdAfx.h"
#include "IOCPServer.h"
#include "2015Remote.h"
#include "common/IPWhitelist.h"
#include "common/IPBlacklist.h"
#include <iostream>
#include <ws2tcpip.h>
// 服务端 RTT 反代理(试用版执法)。声明在主对话框 cpp 中,无单独头文件。
BOOL IsTrail(const std::string& passcode);
// ============================================================================
// SIO_TCP_INFO 兼容性 shim
//
// SIO_TCP_INFO 自 Win10 1703 / Server 2016 起提供,对应的 SDK 头声明只在
// NTDDI_VERSION >= NTDDI_WIN10_RS2 (0x0A000003) 时才可见。本项目当前
// _WIN32_WINNT=0x0602 / NTDDI_VERSION=0x06020000(Win8),整体上调宏会
// 波及其他模块,且会排除 Win8/8.1 用户。因此在此处本地声明常量与结构,
// 运行时若 OS 不支持,WSAIoctl 会返回 WSAEOPNOTSUPP,由探测代码静默降级。
//
// 结构体字段顺序严格遵循 MS 公开的 TCP_INFO_v0 定义,不要随意调整。
// ============================================================================
#ifndef SIO_TCP_INFO
#define SIO_TCP_INFO _WSAIORW(IOC_VENDOR, 39)
#endif
typedef struct _TCP_INFO_v0_local {
ULONG State; // TCPSTATE枚举按 4 字节读)
ULONG Mss;
ULONG64 ConnectionTimeMs;
UCHAR TimestampsEnabled;
UCHAR Pad_[3]; // 显式 padding让 RttUs 落在 4 字节边界
ULONG RttUs; // <-- 本文件唯一关心的字段
ULONG MinRttUs;
ULONG BytesInFlight;
ULONG Cwnd;
ULONG SndWnd;
ULONG RcvWnd;
ULONG RcvBuf;
ULONG64 BytesOut;
ULONG64 BytesIn;
ULONG BytesReordered;
ULONG BytesRetrans;
ULONG FastRetrans;
ULONG DupAcksIn;
ULONG TimeoutEpisodes;
UCHAR SynRetrans;
} TCP_INFO_v0_local;
// Asks the kernel for the TCP_INFO_v0 snapshot of socket `s` via WSAIoctl(SIO_TCP_INFO)
// and extracts its RTT.
//   s     - connected TCP socket to query.
//   rttUs - optional out-param; receives TCP_INFO_v0.RttUs (microseconds) on success.
// Returns 0 on success, otherwise the WSAGetLastError() code (e.g. WSAEOPNOTSUPP when the
// OS does not support SIO_TCP_INFO).
static int QuerySocketTcpRttUs(SOCKET s, uint32_t* rttUs)
{
    TCP_INFO_v0_local tcpInfo;
    ZeroMemory(&tcpInfo, sizeof(tcpInfo));
    DWORD requestedVersion = 0;   // 0 selects the TCP_INFO_v0 layout
    DWORD bytesWritten = 0;
    const int rc = WSAIoctl(s, SIO_TCP_INFO,
                            &requestedVersion, sizeof(requestedVersion),
                            &tcpInfo, sizeof(tcpInfo),
                            &bytesWritten, NULL, NULL);
    if (rc != 0) {
        return WSAGetLastError();
    }
    if (rttUs != NULL) {
        *rttUs = tcpInfo.RttUs;
    }
    return 0;
}
// Process-wide one-shot latch shared by the IP-range trigger (OnAccept) and the RTT trigger
// (RttPollThreadProc). As a static member it is shared by every IOCPServer instance (one per
// listening port): whichever trigger wins the compare_exchange first posts the UI warning,
// and later triggers on any server no longer show another dialog.
std::atomic<bool> IOCPServer::s_TrialAbuseWarned{false};
// PROXY protocol v2 signature: the fixed 12-byte preamble "\r\n\r\n\0\r\nQUIT\n"
// (0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A) that opens every v2 header.
static const unsigned char PROXY_PROTOCOL_V2_SIGNATURE[12] = {
    '\r', '\n', '\r', '\n', '\0', '\r', '\n', 'Q', 'U', 'I', 'T', '\n'
};
// 解析 Proxy Protocol v2 头,返回真实客户端 IP
// 成功返回 true 并设置 realIP失败返回 false
// 如果不是 Proxy Protocol返回 false 且不消费任何数据
static bool ParseProxyProtocolV2(SOCKET sock, std::string& realIP)
{
// 等待数据就绪(最多 200ms解决 FRP Proxy Protocol 头延迟到达的时序问题
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(sock, &readfds);
struct timeval tv = { 0, 200000 }; // 200ms
int ready = select((int)sock + 1, &readfds, NULL, NULL, &tv);
if (ready <= 0) {
return false; // 超时或错误,当作普通连接
}
// 先 peek 前 16 字节12 签名 + 4 头部)
unsigned char header[16];
int n = recv(sock, (char*)header, 16, MSG_PEEK);
if (n < 12) {
// 数据不足,检查已有数据是否匹配签名前缀
if (n > 0 && memcmp(header, PROXY_PROTOCOL_V2_SIGNATURE, n) != 0) {
return false; // 不匹配,不是 Proxy Protocol
}
// 数据太少无法判断,当作普通连接
return false;
}
// 检查签名
if (memcmp(header, PROXY_PROTOCOL_V2_SIGNATURE, 12) != 0) {
return false; // 签名不匹配,不是 Proxy Protocol
}
if (n < 16) {
// 有签名但头部不完整,等待更多数据(理论上不会发生)
return false;
}
// 解析版本和命令 (byte 12)
// 高 4 位是版本 (应该是 0x2),低 4 位是命令 (0x0=LOCAL, 0x1=PROXY)
unsigned char verCmd = header[12];
unsigned char version = (verCmd >> 4) & 0x0F;
unsigned char command = verCmd & 0x0F;
if (version != 2) {
// 不是 v2可能是 v1 或无效
return false;
}
// 解析地址族和协议 (byte 13)
// 高 4 位是地址族 (0x1=AF_INET, 0x2=AF_INET6)
// 低 4 位是协议 (0x1=STREAM, 0x2=DGRAM)
unsigned char famProto = header[13];
unsigned char addrFamily = (famProto >> 4) & 0x0F;
// 解析地址长度 (bytes 14-15, big-endian)
unsigned short addrLen = (header[14] << 8) | header[15];
// 计算完整头部长度 (IPv4: 16+12=28, IPv6: 16+36=52)
int totalHeaderLen = 16 + addrLen;
if (totalHeaderLen > 108) { // 安全上限16 + 最大 TLV 长度
return false;
}
// 读取完整头部(真正消费数据),使用固定数组避免动态分配
unsigned char fullHeader[108];
int received = 0;
while (received < totalHeaderLen) {
int r = recv(sock, (char*)fullHeader + received, totalHeaderLen - received, 0);
if (r <= 0) {
return false; // 接收失败
}
received += r;
}
// 如果是 LOCAL 命令,使用 socket 的对端地址
if (command == 0x00) {
return false; // 让调用者使用 getpeername
}
// 解析地址
if (addrFamily == 0x01 && addrLen >= 12) {
// IPv4: src_addr(4) + dst_addr(4) + src_port(2) + dst_port(2)
unsigned char* addr = fullHeader + 16;
char ipStr[INET_ADDRSTRLEN];
snprintf(ipStr, sizeof(ipStr), "%u.%u.%u.%u", addr[0], addr[1], addr[2], addr[3]);
realIP = ipStr;
return true;
} else if (addrFamily == 0x02 && addrLen >= 36) {
// IPv6: src_addr(16) + dst_addr(16) + src_port(2) + dst_port(2)
unsigned char* addr = fullHeader + 16;
char ipStr[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, addr, ipStr, sizeof(ipStr));
realIP = ipStr;
return true;
}
return false; // 未知地址族
}
// Returns the IPv4 address of the peer connected on `sock` in dotted-decimal form, or an
// empty string when getpeername() fails or the address cannot be formatted.
std::string GetPeerName(SOCKET sock)
{
    sockaddr_in clientAddr = {};
    int addrLen = sizeof(clientAddr);
    // getpeername() returns 0 on success and SOCKET_ERROR on failure. It does not return a
    // SOCKET, so the result must not be compared against INVALID_SOCKET.
    if (getpeername(sock, (SOCKADDR*)&clientAddr, &addrLen) != 0) {
        return "";
    }
    char ipStr[INET_ADDRSTRLEN] = {};
    if (inet_ntop(AF_INET, &clientAddr.sin_addr, ipStr, sizeof(ipStr)) == NULL) {
        return "";
    }
    return ipStr;
}
// Returns the IPv4 address of the peer connected on `sock`. If getpeername() fails, returns
// the numeric socket value as text instead, so the caller always gets a non-empty label.
std::string GetRemoteIP(SOCKET sock)
{
    sockaddr_in addr = {};
    int addrLen = sizeof(addr);
    if (getpeername(sock, (sockaddr*)&addr, &addrLen) == 0) {
        char ipStr[INET_ADDRSTRLEN] = {};
        inet_ntop(AF_INET, &addr.sin_addr, ipStr, sizeof(ipStr));
        TRACE(">>> 对端 IP 地址: %s\n", ipStr);
        return ipStr;
    }
    TRACE(">>> 获取对端 IP 失败, 错误码: %d\n", WSAGetLastError());
    // SOCKET is a UINT_PTR (64-bit on x64). Format it as an unsigned 64-bit value into a buffer
    // large enough for any such value. "%d" into char[10] was a format/argument mismatch.
    char buf[24];
    sprintf_s(buf, "%llu", (unsigned long long)sock);
    return buf;
}
// Cached IP rate-limit settings, read from the "settings" config section by ReloadBanConfig.
// The values are cached so the configuration (registry) is not re-read on every connection.
//   banWindow   - length of one counting window, in seconds
//   banMaxConn  - connections tolerated per window; exceeding it triggers a ban
//   banDuration - ban length, in seconds
static struct {
    int banWindow = 60;
    int banMaxConn = 15;
    int banDuration = 3600;
    bool loaded = false;
} g_BanConfig;

// (Re)reads BanWindow / BanMaxConn / BanDuration from the config and marks the cache valid.
void ReloadBanConfig() {
    g_BanConfig.banWindow   = THIS_CFG.GetInt("settings", "BanWindow", 60);
    g_BanConfig.banMaxConn  = THIS_CFG.GetInt("settings", "BanMaxConn", 15);
    g_BanConfig.banDuration = THIS_CFG.GetInt("settings", "BanDuration", 3600);
    g_BanConfig.loaded = true;
}

// Loads the cached settings on first use.
static void EnsureBanConfigLoaded() {
    if (!g_BanConfig.loaded) {
        ReloadBanConfig();
    }
}

static int GetBanWindowSeconds() {
    EnsureBanConfigLoaded();
    return g_BanConfig.banWindow;
}

static int GetBanMaxConnections() {
    EnsureBanConfigLoaded();
    return g_BanConfig.banMaxConn;
}

static int GetBanDurationSeconds() {
    EnsureBanConfigLoaded();
    return g_BanConfig.banDuration;
}
// Returns true while `ip` is under a temporary ban (see BanIP). Whitelisted addresses
// (IPWhitelist also covers local addresses) are never considered banned. An expired ban
// entry is removed lazily when it is looked up.
bool IOCPServer::IsIPBanned(const std::string& ip)
{
    if (IPWhitelist::getInstance().IsWhitelisted(ip)) {
        return false;
    }
    CLock lock(m_BanLock);
    auto entry = m_BannedIPs.find(ip);
    if (entry == m_BannedIPs.end()) {
        return false;
    }
    if (time(nullptr) < entry->second) {
        return true;                // still within the ban period
    }
    m_BannedIPs.erase(entry);       // ban has expired
    return false;
}
// 检查 IP 是否在黑名单中
bool IOCPServer::IsIPBlacklisted(const std::string& ip)
{
if (IPBlacklist::getInstance().IsBlacklisted(ip)) {
// 防刷频日志
if (IPBlacklist::getInstance().ShouldLog(ip)) {
Mprintf("Connection rejected: %s (blacklisted)\n", ip.c_str());
if (m_hMainWnd) {
char tip[256];
sprintf_s(tip, _TRF("IP %s 连接被拒绝 (黑名单)"), ip.c_str());
PostMessageA(m_hMainWnd, WM_SHOWERRORMSG,
(WPARAM)new CString(tip), (LPARAM)new CString(_TR("黑名单")));
}
}
return true;
}
return false;
}
// Records one incoming connection from `ip` and bans the address once it exceeds
// GetBanMaxConnections() connections within one fixed window of GetBanWindowSeconds()
// seconds. The ban lasts GetBanDurationSeconds() (via BanIP). Whitelisted addresses
// (including local ones) are ignored.
void IOCPServer::RecordConnection(const std::string& ip)
{
    if (IPWhitelist::getInstance().IsWhitelisted(ip)) {
        return;
    }
    bool shouldBan = false;
    {
        CLock lock(m_BanLock);
        time_t now = time(nullptr);
        int banWindow = GetBanWindowSeconds();
        // Periodically purge stale counters: every 1000 calls, or when more than 10000 entries
        // accumulate. The counter is function-static, so it is shared by every IOCPServer
        // instance (one per listening port), while m_BanLock is per instance. Listen threads of
        // different servers can therefore reach it concurrently, so it must be atomic.
        static std::atomic<int> cleanupCounter{0};
        if (cleanupCounter.fetch_add(1) + 1 >= 1000 || m_ConnectionCount.size() > 10000) {
            cleanupCounter.store(0);
            for (auto it = m_ConnectionCount.begin(); it != m_ConnectionCount.end(); ) {
                if (now - it->second.windowStart >= banWindow) {
                    it = m_ConnectionCount.erase(it);
                } else {
                    ++it;
                }
            }
        }
        auto it = m_ConnectionCount.find(ip);
        if (it == m_ConnectionCount.end()) {
            // First connection seen from this IP: start a new window.
            m_ConnectionCount[ip] = { 1, now };
        } else if (now - it->second.windowStart < banWindow) {
            // Same window: count it and check the threshold.
            it->second.count++;
            if (it->second.count > GetBanMaxConnections()) {
                shouldBan = true;
            }
        } else {
            // Window elapsed: restart counting.
            it->second.count = 1;
            it->second.windowStart = now;
        }
    }
    // Ban outside m_BanLock: BanIP takes the lock itself and also posts to the UI.
    if (shouldBan) {
        BanIP(ip, GetBanDurationSeconds());
    }
}
// 封禁 IP
void IOCPServer::BanIP(const std::string& ip, int seconds)
{
{
CLock lock(m_BanLock);
time_t expiry = time(nullptr) + seconds;
m_BannedIPs[ip] = expiry;
// 清除连接计数
m_ConnectionCount.erase(ip);
}
Mprintf("IP banned: %s (duration: %d seconds, reason: too many connections)\n",
ip.c_str(), seconds);
// 发送到主窗口信息列表
if (m_hMainWnd) {
char tip[256];
sprintf_s(tip, _TRF("IP %s 已封禁 %d 秒 (连接过于频繁)"), ip.c_str(), seconds);
PostMessageA(m_hMainWnd, WM_SHOWERRORMSG, (WPARAM)new CString(tip),(LPARAM)new CString(_TR("IP 封禁")));
}
}
// Loads settings/IPWhitelist into the IPWhitelist singleton. The value is a ';'-separated
// list, e.g. "192.168.1.1;10.0.0.1;172.16.0.100". Logs the count when it is non-empty.
void IOCPServer::LoadIPWhitelist()
{
    const std::string configured = THIS_CFG.GetStr("settings", "IPWhitelist", "");
    IPWhitelist::getInstance().Load(configured);
    const size_t loaded = IPWhitelist::getInstance().Count();
    if (loaded != 0) {
        Mprintf("IP whitelist loaded: %zu IPs\n", loaded);
    }
}
// Loads settings/IPBlacklist into the IPBlacklist singleton. The value is a ';'-separated
// list, e.g. "192.168.1.1;10.0.0.1". Logs the count when it is non-empty.
void IOCPServer::LoadIPBlacklist()
{
    const std::string configured = THIS_CFG.GetStr("settings", "IPBlacklist", "");
    IPBlacklist::getInstance().Load(configured);
    const size_t loaded = IPBlacklist::getInstance().Count();
    if (loaded != 0) {
        Mprintf("IP blacklist loaded: %zu IPs\n", loaded);
    }
}
// Creates an idle server that reports to the main window hWnd; StartServer() does the real setup.
// Every member gets a defined value, and both critical sections are initialized, before
// anything that can fail. Destroy() and ~IOCPServer() use them unconditionally, so they
// must be valid even when WSAStartup fails.
IOCPServer::IOCPServer(HWND hWnd)
{
    m_hMainWnd = hWnd;
    m_hCompletionPort = NULL;
    m_sListenSocket = INVALID_SOCKET;
    m_hListenEvent = WSA_INVALID_EVENT;
    m_hListenThread = NULL;
    m_hRttThread = NULL;        // read by Destroy() even if StartServer never ran
    m_bTrialMode = false;
    m_ulMaxConnections = 10000;
    m_ulWorkThreadCount = 0;
    m_bTimeToKill = FALSE;
    m_ulThreadPoolMin = 0;
    m_ulThreadPoolMax = 0;
    m_ulCPULowThreadsHold = 0;
    m_ulCPUHighThreadsHold = 0;
    m_ulCurrentThread = 0;
    m_ulBusyThread = 0;
    m_ulKeepLiveTime = 0;
    m_hKillEvent = NULL;
    m_NotifyProc = NULL;
    m_OfflineProc = NULL;
    InitializeCriticalSection(&m_cs);
    InitializeCriticalSection(&m_BanLock);

    WSADATA wsaData;
    const int wsaErr = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (wsaErr != 0) {
        Mprintf("IOCPServer: WSAStartup failed (err=%d)\n", wsaErr);
        return;
    }
    LoadIPWhitelist();
    LoadIPBlacklist();
}
// Signals shutdown and releases the kernel objects owned by this server: the kill event, the
// RTT poll thread handle, the listen socket, the completion port and the listen event.
// Listen and worker threads are not joined here. They observe m_bTimeToKill / the kill
// event, and ~IOCPServer waits for them afterwards.
void IOCPServer::Destroy()
{
m_bTimeToKill = TRUE;
if (m_hKillEvent != NULL) {
SetEvent(m_hKillEvent);
// The RTT poll thread must exit before m_hKillEvent is closed: closing the handle while that
// thread is still blocked in WaitForSingleObject on it is undefined. Listen / worker threads
// use m_bTimeToKill as a backstop, so their existing shutdown ordering is left unchanged.
// NOTE(review): if the 5 s wait times out, both handles are closed anyway while the thread
// may still be running. Confirm the thread always exits within that bound.
if (m_hRttThread != NULL) {
WaitForSingleObject(m_hRttThread, 5000);
SAFE_CLOSE_HANDLE(m_hRttThread);
m_hRttThread = NULL;
}
SAFE_CLOSE_HANDLE(m_hKillEvent);
m_hKillEvent = NULL;
}
if (m_sListenSocket != INVALID_SOCKET) {
closesocket(m_sListenSocket);
m_sListenSocket = INVALID_SOCKET;
}
// NOTE(review): the constructor sets m_hCompletionPort to NULL, so this branch also runs when
// no port was ever created. It relies on SAFE_CLOSE_HANDLE tolerating NULL; confirm.
if (m_hCompletionPort != INVALID_HANDLE_VALUE) {
SAFE_CLOSE_HANDLE(m_hCompletionPort);
m_hCompletionPort = INVALID_HANDLE_VALUE;
}
// NOTE(review): m_hListenEvent comes from WSACreateEvent, whose documented counterpart is
// WSACloseEvent (StartServer's error paths use it). Closing it through the generic handle
// macro presumably works because WSA events are event handles; confirm.
if (m_hListenEvent != WSA_INVALID_EVENT) {
SAFE_CLOSE_HANDLE(m_hListenEvent);
m_hListenEvent = WSA_INVALID_EVENT;
}
}
// Shuts the server down and waits for all worker threads and the listen thread to finish.
// It then releases every in-use and pooled context, deletes the critical sections and
// balances the constructor's WSAStartup.
IOCPServer::~IOCPServer(void)
{
Destroy();
// Busy-wait until the worker count (decremented via AddWorkThread(-1) as workers exit) drops
// to zero and the listen thread has cleared m_hListenThread on exit.
while (m_ulWorkThreadCount || m_hListenThread)
Sleep(10);
// NOTE(review): this loop only terminates if RemoveStaleContext removes the head entry from
// m_ContextConnectionList. ContextObject->olps is also freed after the context has been
// handed to RemoveStaleContext. Confirm RemoveStaleContext's contract.
while (!m_ContextConnectionList.IsEmpty()) {
CONTEXT_OBJECT *ContextObject = m_ContextConnectionList.GetHead();
RemoveStaleContext(ContextObject);
SAFE_DELETE(ContextObject->olps);
}
while (!m_ContextFreePoolList.IsEmpty()) {
CONTEXT_OBJECT *ContextObject = m_ContextFreePoolList.RemoveHead();
// The statement below crashed intermittently (observed 2019-01-14), so it is commented out.
//SAFE_DELETE(ContextObject->olps);
delete ContextObject;
}
DeleteCriticalSection(&m_cs);
DeleteCriticalSection(&m_BanLock);
m_ulWorkThreadCount = 0;
m_ulThreadPoolMin = 0;
m_ulThreadPoolMax = 0;
m_ulCPULowThreadsHold = 0;
m_ulCPUHighThreadsHold = 0;
m_ulCurrentThread = 0;
m_ulBusyThread = 0;
m_ulKeepLiveTime = 0;
WSACleanup();
}
// Starts listening on TCP port uPort (all IPv4 interfaces) and brings up the I/O machinery:
// kill event, listen socket + event, completion port with worker threads, listen thread, and
// (trial mode only) the RTT poll thread.
//   NotifyProc - callback invoked for each decoded inbound packet.
//   OffProc    - stored as the offline callback.
//   uPort      - listening port.
// Returns 0 on success. Non-zero values identify the failing step:
//   1 = kill event, 2 = listen socket, 3 = listen event, 4 = completion port / workers.
// Any other value is the error code from WSAEventSelect, bind, listen or CreateThread.
UINT IOCPServer::StartServer(pfnNotifyProc NotifyProc, pfnOfflineProc OffProc, USHORT uPort)
{
    m_nPort = uPort;
    m_NotifyProc = NotifyProc;
    m_OfflineProc = OffProc;

    // Manual-reset: several threads wait on this event (ListenThreadProc, RttPollThreadProc).
    // An auto-reset event would wake only one of them per SetEvent, and the other would linger
    // until its own timeout (<= 1 s). Per the original note, nothing here calls ResetEvent on it.
    m_hKillEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
    if (m_hKillEvent == NULL) {
        return 1;
    }
    m_sListenSocket = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
    if (m_sListenSocket == INVALID_SOCKET) {
        return 2;
    }
    m_hListenEvent = WSACreateEvent();
    if (m_hListenEvent == WSA_INVALID_EVENT) {
        closesocket(m_sListenSocket);
        m_sListenSocket = INVALID_SOCKET;
        return 3;
    }

    // Releases the listen socket and its event, then returns `err` unchanged.
    auto failListen = [this](UINT err) -> UINT {
        closesocket(m_sListenSocket);
        m_sListenSocket = INVALID_SOCKET;
        WSACloseEvent(m_hListenEvent);
        m_hListenEvent = WSA_INVALID_EVENT;
        return err;
    };

    // Signal m_hListenEvent when a connection is ready to be accepted.
    if (WSAEventSelect(m_sListenSocket, m_hListenEvent, FD_ACCEPT) == SOCKET_ERROR) {
        return failListen(WSAGetLastError());
    }

    SOCKADDR_IN ServerAddr = {};
    ServerAddr.sin_family = AF_INET;
    ServerAddr.sin_port = htons(uPort);
    ServerAddr.sin_addr.s_addr = INADDR_ANY;   // all local interfaces
    if (bind(m_sListenSocket, (sockaddr*)&ServerAddr, sizeof(ServerAddr)) == SOCKET_ERROR) {
        return failListen(WSAGetLastError());
    }
    if (listen(m_sListenSocket, SOMAXCONN) == SOCKET_ERROR) {
        return failListen(WSAGetLastError());
    }

    // Create the completion port and worker threads BEFORE the listen thread. OnAccept
    // associates each accepted socket with m_hCompletionPort, so a connection accepted before
    // the port exists would be associated with a NULL port and dropped.
    if (!InitializeIOCP()) {
        Mprintf("[IOCPServer] InitializeIOCP failed on port %u\n", (unsigned)uPort);
        return failListen(4);
    }

    m_hListenThread = (HANDLE)CreateThread(NULL, 0, ListenThreadProc,
                                           (void*)this,   // the thread works on this server
                                           0, NULL);
    if (m_hListenThread == NULL) {
        return failListen(GetLastError());
    }

    // Trial-mode anti-proxy RTT polling. This runs only when this controller itself is in trial
    // mode, and detection is based on the kernel's SIO_TCP_INFO (see RttPollThreadProc).
    {
        const std::string pwd = THIS_CFG.GetStr("settings", "Password", "");
        m_bTrialMode = (IsTrail(pwd) == TRUE);
    }
    if (m_bTrialMode) {
        m_hRttThread = CreateThread(NULL, 0, RttPollThreadProc, (void*)this, 0, NULL);
        if (m_hRttThread == NULL) {
            Mprintf("[Compliance] RTT poll thread spawn failed (err=%lu); LANRttChecker (client-side) remains as fallback.\n",
                GetLastError());
        }
    }
    return 0;
}
// Creates the I/O completion port and starts the initial worker threads: two per logical
// processor, which is also the pool maximum. Returns TRUE when the port exists and at least
// one worker is running.
BOOL IOCPServer::InitializeIOCP(VOID)
{
    m_hCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    // CreateIoCompletionPort reports failure with NULL; INVALID_HANDLE_VALUE is checked defensively.
    if (m_hCompletionPort == NULL || m_hCompletionPort == INVALID_HANDLE_VALUE) {
        return FALSE;
    }

    SYSTEM_INFO SystemInfo;
    GetSystemInfo(&SystemInfo);   // number of logical processors
    m_ulThreadPoolMin = 1;
    m_ulThreadPoolMax = SystemInfo.dwNumberOfProcessors * 2;
    m_ulCPULowThreadsHold = 10;
    m_ulCPUHighThreadsHold = 75;

    const ULONG ulWorkThreadCount = m_ulThreadPoolMax;
    for (ULONG i = 0; i < ulWorkThreadCount; ++i) {
        // Each worker services packets queued on the completion port.
        HANDLE hWorkThread = (HANDLE)CreateThread(NULL, 0, WorkThreadProc, (void*)this, 0, NULL);
        if (hWorkThread == NULL) {
            if (i == 0) {
                // No worker at all: nothing can service the port.
                SAFE_CLOSE_HANDLE(m_hCompletionPort);
                return FALSE;
            }
            // Workers already running are blocked on this port; closing it under them would
            // make every GetQueuedCompletionStatus call fail. Run with fewer workers instead.
            Mprintf("[IOCPServer] only %lu of %lu worker threads started (err=%lu)\n",
                    (unsigned long)i, (unsigned long)ulWorkThreadCount, GetLastError());
            break;
        }
        AddWorkThread(1);
        SAFE_CLOSE_HANDLE(hWorkThread);
    }
    return TRUE;
}
// IOCP worker thread. Each worker owns private ZSTD / zlib decompression contexts. It dequeues
// completion packets from the server's completion port and dispatches them to HandleIO until
// m_bTimeToKill is set. When every current worker is busy and the pool is below
// m_ulThreadPoolMax, it spawns one more worker.
DWORD IOCPServer::WorkThreadProc(LPVOID lParam)
{
// Per-thread decompression contexts, passed down to HandleIO.
ZSTD_DCtx* m_Dctx = ZSTD_createDCtx(); // ZSTD decompression context
z_stream m_stream = {};
// NOTE(review): inflateInit2's return value is not checked.
inflateInit2(&m_stream, 15);
IOCPServer* This = (IOCPServer*)(lParam);
// Read once; InitializeIOCP assigns m_hCompletionPort before it creates any worker.
HANDLE hCompletionPort = This->m_hCompletionPort;
DWORD dwTrans = 0;
PCONTEXT_OBJECT ContextObject = NULL;
LPOVERLAPPED Overlapped = NULL;
OVERLAPPEDPLUS* OverlappedPlus = NULL;
ULONG ulBusyThread = 0;
BOOL bError = FALSE;
InterlockedIncrement(&This->m_ulCurrentThread);
InterlockedIncrement(&This->m_ulBusyThread);
timeBeginPeriod(1);
while (This->m_bTimeToKill==FALSE) {
// Not busy while blocked in GetQueuedCompletionStatus.
InterlockedDecrement(&This->m_ulBusyThread);
// GetQueuedCompletionStatus can take relatively long, which limits how fast client data is processed.
BOOL bOk = GetQueuedCompletionStatus(
hCompletionPort,
&dwTrans,
(PULONG_PTR)&ContextObject,
&Overlapped, INFINITE);
DWORD dwIOError = GetLastError();
// NOTE(review): when no packet was dequeued Overlapped is NULL. This yields NULL only if m_ol
// is the first member of OVERLAPPEDPLUS; confirm.
OverlappedPlus = CONTAINING_RECORD(Overlapped, OVERLAPPEDPLUS, m_ol);
ulBusyThread = InterlockedIncrement(&This->m_ulBusyThread); // busy again; the value drives pool growth below
if ( !bOk && dwIOError != WAIT_TIMEOUT ) { // failed I/O, e.g. the peer closed its socket
if (ContextObject && This->m_bTimeToKill == FALSE &&dwTrans==0) {
ContextObject->olps = NULL;
Mprintf("!!! RemoveStaleContext: %d \n", WSAGetLastError());
This->RemoveStaleContext(ContextObject);
}
SAFE_DELETE(OverlappedPlus);
continue;
}
// NOTE(review): bError is never reset to FALSE once set.
if (!bError) {
// Every worker is busy and the pool is below its maximum: add a thread to the pool.
if (ulBusyThread == This->m_ulCurrentThread) {
if (ulBusyThread < This->m_ulThreadPoolMax) {
if (ContextObject != NULL) {
HANDLE hThread = (HANDLE)CreateThread(NULL,
0,
WorkThreadProc,
(void*)This,
0,
NULL);
This->AddWorkThread(hThread ? 1:0);
SAFE_CLOSE_HANDLE(hThread);
}
}
}
// NOTE(review): with an INFINITE timeout GetQueuedCompletionStatus does not time out, so this
// WAIT_TIMEOUT shrink path appears unreachable.
if (!bOk && dwIOError == WAIT_TIMEOUT) {
if (ContextObject == NULL) {
if (This->m_ulCurrentThread > This->m_ulThreadPoolMin) {
break;
}
bError = TRUE;
}
}
}
if (!bError && !This->m_bTimeToKill) {
if(bOk && OverlappedPlus!=NULL && ContextObject!=NULL) {
try {
This->HandleIO(OverlappedPlus->m_ioType, ContextObject, dwTrans, m_Dctx, &m_stream);
ContextObject = NULL;
} catch (...) {
// NOTE(review): if a handler throws after HandleIO took its IoRefCount reference, that
// reference is never released.
Mprintf("This->HandleIO catched an error!!!");
}
}
}
// The OVERLAPPEDPLUS of this completion is freed by the worker once handled.
SAFE_DELETE(OverlappedPlus);
}
timeEndPeriod(1);
SAFE_DELETE(OverlappedPlus);
InterlockedDecrement(&This->m_ulCurrentThread);
InterlockedDecrement(&This->m_ulBusyThread);
int n= This->AddWorkThread(-1);
if (n == 0) {
Mprintf("======> IOCPServer All WorkThreadProc done\n");
}
inflateEnd(&m_stream);
ZSTD_freeDCtx(m_Dctx);
return 0;
}
// Dispatches one completed I/O packet for ContextObject to its handler. Called on a worker thread.
// Returns the handler's result, or FALSE if the context was already removed. The private -2
// sentinel is never passed on to the caller.
BOOL IOCPServer::HandleIO(IOType PacketFlags,PCONTEXT_OBJECT ContextObject, DWORD dwTrans, ZSTD_DCtx* ctx, z_stream* z)
{
    // Race guard (#215): take the I/O reference BEFORE testing IsRemoved. Testing first would
    // leave a window in which RemoveStaleContext could finish and recycle the object.
    ContextObject->IoRefCount.fetch_add(1);
    if (ContextObject->IsRemoved.load()) {
        ContextObject->IoRefCount.fetch_sub(1);
        return FALSE;
    }

    // Handlers return -2 when they have already dropped the reference and called
    // RemoveStaleContext; the object may then be pooled or reused and must not be touched.
    const BOOL kRefAlreadyReleased = -2;
    BOOL result = FALSE;
    if (PacketFlags == IOInitialize) {
        result = OnClientInitializing(ContextObject, dwTrans);
    } else if (PacketFlags == IORead) {
        result = OnClientReceiving(ContextObject, dwTrans, ctx, z);
    } else if (PacketFlags == IOWrite) {
        result = OnClientPostSending(ContextObject, dwTrans);
    } else if (PacketFlags == IOIdle) {
        Mprintf("=> HandleIO PacketFlags= IOIdle\n");
    }

    if (result == kRefAlreadyReleased) {
        return FALSE;
    }
    ContextObject->IoRefCount.fetch_sub(1);   // release the reference taken above (#215)
    return result;
}
// IOInitialize completion handler. No work is needed at this stage, so it always reports
// success; both parameters are intentionally unused.
BOOL IOCPServer::OnClientInitializing(PCONTEXT_OBJECT ContextObject, DWORD dwTrans)
{
    (void)ContextObject;
    (void)dwTrans;
    return TRUE;
}
// Appends dwTrans freshly received bytes (ContextObject->szBuffer) to the context's inbound
// buffer, then decodes every complete packet now available. For each packet it places the
// plaintext in InDeCompressedBuffer and invokes m_NotifyProc.
//   m_NotifyProc - per-packet callback.
//   m_Dctx       - not referenced in this body (ZSTD goes through Muncompress).
//   z            - zlib stream used by z_uncompress.
// Returns FALSE when dwTrans == 0 (peer closed). Otherwise returns 1, or 999 if the most
// recent packet for which m_NotifyProc returned non-zero starts with TOKEN_LOGIN.
// Decode errors are caught here: both buffers are cleared and the current value is returned.
// May be this function should be a member of `CONTEXT_OBJECT`.
BOOL ParseReceivedData(CONTEXT_OBJECT * ContextObject, DWORD dwTrans, pfnNotifyProc m_NotifyProc, ZSTD_DCtx* m_Dctx, z_stream* z)
{
// Presumably logs when this call takes longer than 50 ms; confirm AUTO_TICK semantics.
AUTO_TICK(50, ContextObject->GetPeerName());
BOOL ret = 1;
try {
if (dwTrans == 0) { // the peer closed the socket
return FALSE;
}
// Copy the received bytes (wsabuf of 8192) into our own buffer.
ContextObject->InCompressedBuffer.WriteBuffer((PBYTE)ContextObject->szBuffer,dwTrans);
// Consume as many complete packets as the buffer holds.
while (true) {
PR pr = ContextObject->Parse(ContextObject->InCompressedBuffer);
if (pr.IsFailed()) {
ContextObject->InCompressedBuffer.ClearBuffer();
break;
} else if (pr.IsNeedMore()) {
break;
} else if (pr.IsWinOSLogin()) {
ContextObject->InDeCompressedBuffer.ClearBuffer();
ULONG ulCompressedLength = 0;
ULONG ulOriginalLength = 0;
PBYTE CompressedBuffer = ContextObject->ReadBuffer(ulCompressedLength, ulOriginalLength);
ContextObject->InDeCompressedBuffer.WriteBuffer(CompressedBuffer, ulCompressedLength);
// NOTE(review): CompressedBuffer[0] is read without checking ulCompressedLength > 0.
if (m_NotifyProc(ContextObject))
ret = CompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
// CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
break;
}
ULONG ulPackTotalLength = 0;
ContextObject->InCompressedBuffer.CopyBuffer(&ulPackTotalLength, sizeof(ULONG), pr.Result);
// Total packet length; per the original note the header is a 5-byte flag + 4-byte total
// length + 4-byte original length.
int bufLen = ContextObject->InCompressedBuffer.GetBufferLength();
if (ulPackTotalLength && bufLen >= ulPackTotalLength) {
ULONG ulCompressedLength = 0;
ULONG ulOriginalLength = 0;
PBYTE CompressedBuffer = ContextObject->ReadBuffer(ulCompressedLength, ulOriginalLength);
if (ContextObject->CompressMethod == COMPRESS_UNKNOWN) {
// CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
throw "Unknown method";
} else if (ContextObject->CompressMethod == COMPRESS_NONE) {
ContextObject->InDeCompressedBuffer.ClearBuffer();
ContextObject->Decode(CompressedBuffer, ulOriginalLength);
ContextObject->InDeCompressedBuffer.WriteBuffer(CompressedBuffer, ulOriginalLength);
if (m_NotifyProc(ContextObject))
ret = CompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
// CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
continue;
}
bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD, zlibFailed = false;
// Reuse the context's preallocated decompression buffer to avoid per-packet allocation.
PBYTE DeCompressedBuffer = ContextObject->GetDecompressBuffer(ulOriginalLength);
size_t iRet = usingZstd ?
Muncompress(DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength) :
z_uncompress(z, DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength);
if (usingZstd ? C_SUCCESS(iRet) : (S_OK==iRet)) {
ContextObject->InDeCompressedBuffer.ClearBuffer();
ContextObject->Decode(DeCompressedBuffer, ulOriginalLength);
ContextObject->InDeCompressedBuffer.WriteBuffer(DeCompressedBuffer, ulOriginalLength);
if (m_NotifyProc(ContextObject))
ret = DeCompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
} else if (usingZstd) {
// ZSTD failed: retry with zlib and, on success, switch this context to zlib.
// NOTE(review): if Muncompress modified ulOriginalLength on failure, the retry uses that value.
if (Z_OK == z_uncompress(z, DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength)) {
ContextObject->CompressMethod = COMPRESS_ZLIB;
ContextObject->InDeCompressedBuffer.ClearBuffer();
ContextObject->Decode(DeCompressedBuffer, ulOriginalLength);
ContextObject->InDeCompressedBuffer.WriteBuffer(DeCompressedBuffer, ulOriginalLength);
if (m_NotifyProc(ContextObject))
ret = DeCompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
} else {
zlibFailed = true;
// Note: COMPRESS_UNKNOWN is not set, so later packets still try ZSTD first.
}
} else {
zlibFailed = true;
}
// CompressedBuffer and DeCompressedBuffer are owned by CONTEXT_OBJECT; do not free them here.
if (zlibFailed) {
Mprintf("[ERROR] uncompress failed: cmd=0x%02X, compressed=%u, original=%u\n",
(unsigned char)CompressedBuffer[0], ulCompressedLength, ulOriginalLength);
throw "Bad Buffer"; // the catch handler below clears the buffers
}
} else {
break;
}
}
} catch(...) {
Mprintf("[ERROR] OnClientReceiving catch an error \n");
ContextObject->InCompressedBuffer.ClearBuffer();
ContextObject->InDeCompressedBuffer.ClearBuffer();
}
return ret;
}
// IORead completion handler. It decodes the received bytes and, on success, re-arms the
// receive and returns TRUE. When ParseReceivedData reports FALSE (peer closed), it releases
// the HandleIO reference, removes the context and returns the -2 sentinel. In that case the
// context may already be pooled, so HandleIO must not touch it again.
BOOL IOCPServer::OnClientReceiving(PCONTEXT_OBJECT ContextObject, DWORD dwTrans, ZSTD_DCtx* ctx, z_stream* z)
{
    const BOOL parsed = ParseReceivedData(ContextObject, dwTrans, m_NotifyProc, ctx, z);
    if (parsed != FALSE) {
        PostRecv(ContextObject);   // queue the next receive
        return TRUE;
    }
    // Drop the reference BEFORE RemoveStaleContext (#215); afterwards the object is off-limits.
    ContextObject->IoRefCount.fetch_sub(1);
    RemoveStaleContext(ContextObject);
    return -2;
}
// Encodes and (depending on ContextObject->CompressMethod) compresses one outgoing command,
// then appends it to the context's send buffer via ContextObject->WriteBuffer.
//   szBuffer / ulOriginalLength - plaintext command; the first byte is the command id.
//   m_Cctx - when NULL, an additional Encode(..., usingZstd) pass is applied.
//   z      - not referenced in this body.
// Returns TRUE when the data was queued. Returns FALSE on an empty or NULL buffer, an unknown
// compression method, a compression failure, or a caught exception.
BOOL WriteContextData(CONTEXT_OBJECT* ContextObject, PBYTE szBuffer, size_t ulOriginalLength, ZSTD_CCtx* m_Cctx, z_stream* z)
{
    assert(ContextObject);
    // Validate before touching szBuffer[0]. The command byte (and, in debug builds, the whole
    // buffer) used to be read before the empty-length check.
    if (szBuffer == NULL || ulOriginalLength == 0) {
        return FALSE;
    }
    int cmd = szBuffer[0];
#ifdef _DEBUG
    // Log outgoing commands, except the command ids filtered below.
    if (ulOriginalLength < 100 && cmd != COMMAND_SCREEN_CONTROL && cmd != CMD_HEARTBEAT_ACK &&
        cmd != CMD_DRAW_POINT && cmd != CMD_MOVEWINDOW && cmd != CMD_SET_SIZE) {
        char buf[100] = { 0 };
        if (ulOriginalLength == 1) {
            sprintf_s(buf, "command %d", cmd);
        } else {
            memcpy(buf, szBuffer, ulOriginalLength);
        }
        Mprintf("[COMMAND] Send: %s\r\n", buf);
    }
#endif
    try {
        do {
            if (ContextObject->CompressMethod == COMPRESS_UNKNOWN) {
                Mprintf("[ERROR] UNKNOWN compress method \n");
                return FALSE;
            } else if (ContextObject->CompressMethod == COMPRESS_NONE) {
                Buffer tmp(szBuffer, ulOriginalLength);
                szBuffer = tmp.Buf();
                ContextObject->WriteBuffer(szBuffer, ulOriginalLength, ulOriginalLength, cmd);
                break;
            }
            bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD;
            // zlib's compress() requires destLen >= compressBound(sourceLen). The legacy
            // "1.001 * n + 12" estimate is smaller than that for inputs under ~1000 bytes.
            unsigned long ulCompressedLength = usingZstd ?
                ZSTD_compressBound(ulOriginalLength) : compressBound((uLong)ulOriginalLength);
            // Preallocated per-context buffer instead of a fresh allocation per send.
            LPBYTE CompressedBuffer = ContextObject->GetSendCompressBuffer(ulCompressedLength);
            // Encode a copy, presumably in place, so the caller's buffer is left untouched.
            Buffer tmp(szBuffer, ulOriginalLength);
            szBuffer = tmp.Buf();
            ContextObject->Encode(szBuffer, ulOriginalLength);
            if (!m_Cctx) ContextObject->Encode(szBuffer, ulOriginalLength, usingZstd);
            size_t iRet = usingZstd ?
                Mcompress(CompressedBuffer, &ulCompressedLength, (LPBYTE)szBuffer, ulOriginalLength, ContextObject->GetZstdLevel()):
                compress(CompressedBuffer, &ulCompressedLength, (LPBYTE)szBuffer, ulOriginalLength);
            if (usingZstd ? C_FAILED(iRet) : (S_OK != iRet)) {
                Mprintf("[ERROR] compress failed \n");
                // SendCompressBuffer is owned by CONTEXT_OBJECT; do not free it here.
                return FALSE;
            }
            // ZSTD reports the compressed size via its return value; zlib via ulCompressedLength.
            ulCompressedLength = usingZstd ? iRet : ulCompressedLength;
            ContextObject->WriteBuffer(CompressedBuffer, ulCompressedLength, ulOriginalLength, cmd);
            // SendCompressBuffer is owned by CONTEXT_OBJECT; do not free it here.
        } while (false);
        return TRUE;
    } catch (...) {
        Mprintf("[ERROR] OnClientPreSending catch an error \n");
        return FALSE;
    }
}
// Queues one outgoing command: WriteContextData places it in the send buffer, then an IOWrite
// packet is posted to the completion port so a worker starts sending (OnClientPostSending).
// If the post fails, the context is removed. Returns TRUE once the IOWrite packet is queued.
BOOL IOCPServer::OnClientPreSending(CONTEXT_OBJECT* ContextObject, PBYTE szBuffer, size_t ulOriginalLength)
{
    if (!WriteContextData(ContextObject, szBuffer, ulOriginalLength, ContextObject->Zcctx)) {
        return FALSE;
    }
    OVERLAPPEDPLUS* sendRequest = new OVERLAPPEDPLUS(IOWrite);
    const BOOL posted = PostQueuedCompletionStatus(m_hCompletionPort, 0, (ULONG_PTR)ContextObject, &sendRequest->m_ol);
    if (!posted && GetLastError() != ERROR_IO_PENDING) {
        Mprintf("!!! OnClientPreSending 投递消息失败\n");
        RemoveStaleContext(ContextObject);
        SAFE_DELETE(sendRequest);
        return FALSE;
    }
    return TRUE;
}
// IOWrite completion handler, run under ContextObject->SendLock. It removes
// ulCompletedLength sent bytes from OutCompressedBuffer and, if data remains, posts another
// overlapped WSASend for the rest.
// Returns TRUE when the buffer drained or the next send was queued, and FALSE if an exception
// was caught. Returns -2 when the send failed: this function has then already released the
// HandleIO reference and called RemoveStaleContext, so the caller must not touch the context.
BOOL IOCPServer::OnClientPostSending(CONTEXT_OBJECT* ContextObject,ULONG ulCompletedLength)
{
    // NOTE(review): on the -2 path this guard still unlocks ContextObject->SendLock after
    // RemoveStaleContext has run. Confirm the lock outlives removal.
    CAutoCLock L(ContextObject->SendLock);
    try {
        DWORD ulFlags = MSG_PARTIAL;
        ContextObject->OutCompressedBuffer.RemoveCompletedBuffer(ulCompletedLength);
        if (ContextObject->OutCompressedBuffer.GetBufferLength() == 0) {
            ContextObject->OutCompressedBuffer.ClearBuffer();
            return TRUE;   // everything has actually been sent
        }
        // Data remains: post another send for the rest.
        OVERLAPPEDPLUS* OverlappedPlus = new OVERLAPPEDPLUS(IOWrite);
        ContextObject->wsaOutBuffer.buf = (char*)ContextObject->OutCompressedBuffer.GetBuffer(0);
        ContextObject->wsaOutBuffer.len = ContextObject->OutCompressedBuffer.GetBufferLength();
        int iOk = WSASend(ContextObject->sClientSocket, &ContextObject->wsaOutBuffer, 1,
                          NULL, ulFlags, &OverlappedPlus->m_ol, NULL);
        if (iOk == SOCKET_ERROR) {
            // Capture the error right away: RemoveStaleContext may make further Winsock/Win32
            // calls that would overwrite it before it is logged.
            const int sendError = WSAGetLastError();
            if (sendError != WSA_IO_PENDING) {
                // Drop the reference BEFORE RemoveStaleContext (#215).
                ContextObject->IoRefCount.fetch_sub(1);
                if (RemoveStaleContext(ContextObject))
                    Mprintf("!!! OnClientPostSending 投递消息失败: %d\n", sendError);
                SAFE_DELETE(OverlappedPlus);
                return -2;   // tells HandleIO not to release the reference again
            }
        }
        return TRUE;
    } catch(...) {
        Mprintf("[ERROR] OnClientPostSending catch an error \n");
    }
    return FALSE;
}
// ============================================================================
// Trial-mode anti-proxy: server-side RTT poll thread.
//
// Started by StartServer only when this controller itself runs in trial mode
// (IsTrail(passcode) == TRUE).
//
// About once per second it walks m_ContextConnectionList and calls
// WSAIoctl(SIO_TCP_INFO) on every active connection to get the kernel-measured network RTT,
// which it feeds to ctx->m_RttDetector. When a detector fires, the process-wide
// s_TrialAbuseWarned latch lets exactly one PostMessage reach the main window. Other
// detectors keep running and logging but do not raise another dialog.
//
// Concurrency: follows the existing IoRefCount / IsRemoved pattern. Context pointers are
// snapshotted under m_cs and each gets IoRefCount++. WSAIoctl and the atomic writes happen
// outside the lock, and the reference is released afterwards. Per the original design
// (RemoveStaleContext is not visible here), contexts are recycled only once IoRefCount
// reaches 0, so the snapshot cannot dangle.
//
// On an OS without SIO_TCP_INFO (Win8 / Server 2012 ...), the first probe returns
// WSAEOPNOTSUPP. The thread then logs this and exits; the client-side LANRttChecker remains
// as the fallback.
// ============================================================================
DWORD IOCPServer::RttPollThreadProc(LPVOID lParam)
{
IOCPServer* This = (IOCPServer*)lParam;
while (!This->m_bTimeToKill) {
// Sleep up to 1 s; a signalled or already-closed kill event ends the thread.
DWORD waitRet = WaitForSingleObject(This->m_hKillEvent, 1000);
if (waitRet == WAIT_OBJECT_0 || waitRet == WAIT_FAILED) break;
if (This->m_bTimeToKill) break;
// Step 1: snapshot under the lock and take a reference on each live context. WSAIoctl
// runs outside the lock so other I/O is not blocked.
std::vector<PCONTEXT_OBJECT> snap;
EnterCriticalSection(&This->m_cs);
for (POSITION pos = This->m_ContextConnectionList.GetHeadPosition(); pos != NULL; ) {
PCONTEXT_OBJECT ctx = This->m_ContextConnectionList.GetNext(pos);
if (!ctx) continue;
if (ctx->IsRemoved.load(std::memory_order_acquire)) continue;
ctx->IoRefCount.fetch_add(1, std::memory_order_acq_rel);
snap.push_back(ctx);
}
LeaveCriticalSection(&This->m_cs);
// Step 2: one-time OS capability probe, using the first real connection. If the OS lacks
// SIO_TCP_INFO the thread has no reason to keep running, so it releases this round's
// references and exits.
if (!This->m_bSioTcpInfoProbed.load(std::memory_order_acquire) && !snap.empty()) {
uint32_t probeRtt = 0;
int err = QuerySocketTcpRttUs(snap[0]->sClientSocket, &probeRtt);
if (err == WSAEOPNOTSUPP) {
Mprintf("[Compliance] SIO_TCP_INFO not supported by OS (WSAEOPNOTSUPP); "
"server-side RTT monitoring disabled. Client-side LANRttChecker remains active.\n");
This->m_bSioTcpInfoSupported.store(false, std::memory_order_release);
This->m_bSioTcpInfoProbed.store(true, std::memory_order_release);
for (auto* c : snap) c->IoRefCount.fetch_sub(1, std::memory_order_acq_rel);
break;
}
// Other errors (e.g. WSAENOTCONN for a connection that just dropped) are not treated as an
// OS limitation; the probe is retried next round.
if (err == 0) {
This->m_bSioTcpInfoSupported.store(true, std::memory_order_release);
This->m_bSioTcpInfoProbed.store(true, std::memory_order_release);
Mprintf("[Compliance] SIO_TCP_INFO probe OK; server-side anti-proxy RTT monitor armed "
"(threshold=%d ms, trigger after >=%d consecutive median breaches @1Hz).\n",
TcpRttBreachDetector::RTT_THRESHOLD_MS, TcpRttBreachDetector::BREACH_PERSIST_COUNT);
}
}
// Step 3: query each context's RTT, feed its detector, and release its reference.
for (auto* ctx : snap) {
uint32_t rttUs = 0;
int err = QuerySocketTcpRttUs(ctx->sClientSocket, &rttUs);
if (err == 0 && rttUs > 0) {
ctx->SetRttUs(rttUs);
// RttUs is in microseconds; round to milliseconds for the detector.
int rttMs = (int)((rttUs + 500) / 1000);
if (ctx->m_RttDetector.Feed(rttMs)) {
// This context's detector fired. Always log it (for every context, to help trace the
// abusive source); the process-wide one-shot latch decides whether to show the dialog.
Mprintf("[Compliance] !!! Trial-mode anti-proxy triggered: client=%llu IP=%s "
"median RTT=%d ms (threshold=%d ms).\n",
ctx->ID, ctx->GetPeerName().c_str(),
ctx->m_RttDetector.TriggerMedianMs(),
TcpRttBreachDetector::RTT_THRESHOLD_MS);
bool expected = false;
if (s_TrialAbuseWarned.compare_exchange_strong(expected, true) && This->m_hMainWnd) {
// WPARAM: low 32 bits of the abusive context's ClientID (display only);
// LPARAM: the median RTT in ms.
PostMessageA(This->m_hMainWnd, WM_TRIAL_RTT_ABUSE,
(WPARAM)(ctx->ID & 0xFFFFFFFF),
(LPARAM)ctx->m_RttDetector.TriggerMedianMs());
}
}
}
ctx->IoRefCount.fetch_sub(1, std::memory_order_acq_rel);
}
}
return 0;
}
// Listen thread. It waits in 100 ms slices for FD_ACCEPT on the listen socket and calls
// OnAccept for each pending connection. It stops when m_bTimeToKill is set, the kill event
// is signalled, or the listen socket/event reports an error.
// On exit it clears m_hListenThread (~IOCPServer waits for that) and closes its own thread handle.
DWORD IOCPServer::ListenThreadProc(LPVOID lParam)
{
    IOCPServer* This = (IOCPServer*)(lParam);
    WSANETWORKEVENTS NetWorkEvents;
    while (!This->m_bTimeToKill) {
        if (WaitForSingleObject(This->m_hKillEvent, 100) == WAIT_OBJECT_0)
            break;
        DWORD dwRet = WSAWaitForMultipleEvents(1, &This->m_hListenEvent, FALSE, 100, FALSE);
        if (dwRet == WSA_WAIT_TIMEOUT)
            continue;
        // The listen event is no longer valid (e.g. Destroy closed it): nothing left to wait on.
        if (dwRet == WSA_WAIT_FAILED)
            break;
        // Turn the signalled event into the concrete network events (this also resets the event).
        int iRet = WSAEnumNetworkEvents(This->m_sListenSocket, This->m_hListenEvent, &NetWorkEvents);
        if (iRet == SOCKET_ERROR)
            break;
        if (NetWorkEvents.lNetworkEvents & FD_ACCEPT) {
            if (NetWorkEvents.iErrorCode[FD_ACCEPT_BIT] == 0) {
                This->OnAccept();
            } else {
                break;
            }
        }
    }
    // The handle returned by CreateThread in StartServer was previously only set to NULL and
    // never closed, which leaked it. Detach it first so the destructor may proceed, then close
    // the local copy without touching `This` again.
    HANDLE hSelf = This->m_hListenThread;
    This->m_hListenThread = NULL;
    if (hSelf != NULL) {
        CloseHandle(hSelf);
    }
    return 0;
}
void IOCPServer::OnAccept()
{
SOCKADDR_IN ClientAddr = {0};
SOCKET sClientSocket = INVALID_SOCKET;
int iLen = sizeof(SOCKADDR_IN);
sClientSocket = accept(m_sListenSocket,
(sockaddr*)&ClientAddr,
&iLen); //通过我们的监听套接字来生成一个与之信号通信的套接字
if (sClientSocket == SOCKET_ERROR) {
return;
}
//我们在这里为每一个到达的信号维护了一个与之关联的数据结构这里简称为用户的上下背景文
PCONTEXT_OBJECT ContextObject = AllocateContext(sClientSocket); // Context
if (ContextObject == NULL) {
closesocket(sClientSocket);
sClientSocket = INVALID_SOCKET;
return;
}
ContextObject->sClientSocket = sClientSocket;
// 尝试解析 Proxy Protocol v2 头,获取真实客户端 IP
// 如果解析成功,更新 PeerName否则保持 getpeername 的结果
std::string realIP;
if (ParseProxyProtocolV2(sClientSocket, realIP)) {
ContextObject->SetPeerName (realIP);
}
// IP 黑名单和封禁检查
std::string clientIP = ContextObject->GetPeerName().empty() ?
inet_ntoa(ClientAddr.sin_addr) : ContextObject->GetPeerName();
// 先检查黑名单
if (IsIPBlacklisted(clientIP)) {
delete ContextObject;
closesocket(sClientSocket);
return;
}
// 再检查临时封禁
if (IsIPBanned(clientIP)) {
delete ContextObject;
closesocket(sClientSocket);
return;
}
RecordConnection(clientIP);
// 试用版反代理 —— 入站 IP 段检测(即时触发,对合作型代理透明)
//
// 与 RttPollThreadProc 的 SIO_TCP_INFO 检测互补RTT 测的是"我↔直接 TCP 对端"
// 任何 TCP 终结型代理都能欺骗它;本检测用 Proxy Protocol v2 解出的真实 IP若有
// 或 getpeername 的 raw IP 直接判私网段。
// - 覆盖:直连 WAN、PP2 透出真实 IP 是公网
// - 不覆盖socat / 不发 PP2 的中转 —— 那种场景仍由客户端 LANRttChecker 兜底
//
// 性能:每个新连接走一次 IsPrivateIPv4Str几个位运算不放心跳路径可忽略。
// 不主动断开连接(与 RTT 路径一致仅告警),由运营商看到弹框后自行处置。
// 详见 docs/Compliance_TechnicalMeasures.md口径文档可能比这里更新
if (m_bTrialMode && !LANChecker::IsPrivateIPv4Str(clientIP)) {
Mprintf("[Compliance] !!! Trial-mode WAN inbound: IP=%s (resolved via %s).\n",
clientIP.c_str(),
ContextObject->GetPeerName().empty() ? "getpeername" : "Proxy Protocol v2 or getpeername");
bool expected = false;
if (s_TrialAbuseWarned.compare_exchange_strong(expected, true) && m_hMainWnd) {
// CString* 由 OnTrialWanIpAbuse handler 负责 delete与 OnShowErrMessage 一致
PostMessageA(m_hMainWnd, WM_TRIAL_WAN_IP_ABUSE,
(WPARAM)new CString(clientIP.c_str()), 0);
}
}
ContextObject->wsaInBuf.buf = (char*)ContextObject->szBuffer;
ContextObject->wsaInBuf.len = sizeof(ContextObject->szBuffer);
HANDLE Handle = CreateIoCompletionPort((HANDLE)sClientSocket, m_hCompletionPort, (ULONG_PTR)ContextObject, 0);
if (Handle!=m_hCompletionPort) {
delete ContextObject;
ContextObject = NULL;
if (sClientSocket!=INVALID_SOCKET) {
closesocket(sClientSocket);
sClientSocket = INVALID_SOCKET;
}
return;
}
//设置套接字的选项卡 Set KeepAlive 开启保活机制 SO_KEEPALIVE
//保持连接检测对方主机是否崩溃如果2小时内在此套接口的任一方向都没
//有数据交换TCP就自动给对方 发一个保持存活
m_ulKeepLiveTime = 1000 * 60 * 3;
const BOOL bKeepAlive = TRUE;
setsockopt(ContextObject->sClientSocket,SOL_SOCKET,SO_KEEPALIVE,(char*)&bKeepAlive,sizeof(bKeepAlive));
//设置超时详细信息
tcp_keepalive KeepAlive;
KeepAlive.onoff = 1; // 启用保活
KeepAlive.keepalivetime = m_ulKeepLiveTime; //超过3分钟没有数据就发送探测包
KeepAlive.keepaliveinterval = 1000 * 10; //重试间隔为10秒 Resend if No-Reply
WSAIoctl(ContextObject->sClientSocket, SIO_KEEPALIVE_VALS,&KeepAlive,sizeof(KeepAlive),
NULL,0,(unsigned long *)&bKeepAlive,0,NULL);
//在做服务器时如果发生客户端网线或断电等非正常断开的现象如果服务器没有设置SO_KEEPALIVE选项
//则会一直不关闭SOCKET。因为上的的设置是默认两个小时时间太长了所以我们就修正这个值
EnterCriticalSection(&m_cs);
m_ContextConnectionList.AddTail(ContextObject); //插入到我们的内存列表中
LeaveCriticalSection(&m_cs);
OVERLAPPEDPLUS *OverlappedPlus = new OVERLAPPEDPLUS(IOInitialize); //注意这里的重叠IO请求是 用户请求上线
BOOL bOk = PostQueuedCompletionStatus(m_hCompletionPort, 0, (ULONG_PTR)ContextObject, &OverlappedPlus->m_ol); // 工作线程
//因为我们接受到了一个用户上线的请求那么我们就将该请求发送给我们的完成端口 让我们的工作线程处理它
if ( (!bOk && GetLastError() != ERROR_IO_PENDING)) { //如果投递失败
int a = GetLastError();
Mprintf("!!! OnAccept 投递消息失败\n");
RemoveStaleContext(ContextObject);
SAFE_DELETE(OverlappedPlus);
return;
}
PostRecv(ContextObject);
}
VOID IOCPServer::PostRecv(CONTEXT_OBJECT* ContextObject)
{
    // Posts an overlapped receive into ContextObject->wsaInBuf. Its completion
    // (e.g. the client's first login packet) is picked up by a worker thread,
    // which ends up in ProcessIOMessage. If the post fails immediately, the
    // context is torn down via RemoveStaleContext.
    OVERLAPPEDPLUS* OverlappedPlus = new OVERLAPPEDPLUS(IORead);
    ContextObject->olps = OverlappedPlus;
    DWORD dwReturn = 0;
    ULONG ulFlags = MSG_PARTIAL;
    int iOk = WSARecv(ContextObject->sClientSocket, &ContextObject->wsaInBuf,
                      1, &dwReturn, &ulFlags, &OverlappedPlus->m_ol, NULL);
    if (iOk == SOCKET_ERROR) {
        const int iErr = WSAGetLastError();
        if (iErr != WSA_IO_PENDING) {
            Mprintf("!!! PostRecv 投递消息失败, error=%d\n", iErr);
            // Detach the overlapped before the context can go back to the free
            // pool, so olps never points at the OVERLAPPEDPLUS freed below.
            ContextObject->olps = NULL;
            RemoveStaleContext(ContextObject);
            SAFE_DELETE(OverlappedPlus);
        }
    }
}
PCONTEXT_OBJECT IOCPServer::AllocateContext(SOCKET s)
{
    // Returns a context initialised for socket s, taken from the free pool
    // when possible, otherwise newly allocated. Returns NULL when the number
    // of active connections has reached m_ulMaxConnections. All of this runs
    // under m_cs.
    CLock cs(m_cs);
    if (m_ContextConnectionList.GetCount() >= m_ulMaxConnections) {
        // Rate-limit the "limit reached" notification to once every 15 s.
        // notifyTime is only touched while m_cs is held.
        static uint64_t notifyTime = 0;
        auto now = time(0);
        if (now - notifyTime > 15) {
            notifyTime = now;
            Mprintf("!!! AllocateContext: 达到最大连接数 %lu拒绝新连接\n", m_ulMaxConnections);
            if (m_hMainWnd) {
                char tip[256];
                sprintf_s(tip, _TRF("达到最大连接数限制: %lu, 请释放连接"), m_ulMaxConnections);
                // The WM_SHOWNOTIFY handler owns both messages once posted. If
                // PostMessageA fails nobody else will free them.
                CharMsg* pTitle = new CharMsg(_TR("达到最大连接数"));
                CharMsg* pBody = new CharMsg(tip);
                if (!PostMessageA(m_hMainWnd, WM_SHOWNOTIFY, (WPARAM)pTitle, (LPARAM)pBody)) {
                    delete pTitle;
                    delete pBody;
                }
            }
        }
        return NULL;
    }
    PCONTEXT_OBJECT ContextObject = !m_ContextFreePoolList.IsEmpty()
                                    ? m_ContextFreePoolList.RemoveHead()
                                    : new CONTEXT_OBJECT;
    if (ContextObject != NULL) {
        ContextObject->InitMember(s, this);
    }
    return ContextObject;
}
BOOL IOCPServer::RemoveStaleContext(CONTEXT_OBJECT* ContextObject)
{
    // Tears down an active connection: marks it removed, runs the offline
    // callback, cancels pending I/O, closes the socket, waits (at most ~5 s)
    // for in-flight worker processing to drain, then recycles the context.
    // Returns TRUE if the context was found in the active list, else FALSE.
    //
    // NOTE(review): the list lock is released between the lookup and the
    // teardown, so two concurrent callers for the same context could both pass
    // the check. Confirm whether callers serialise this.
    EnterCriticalSection(&m_cs);
    const bool bRegistered = (m_ContextConnectionList.Find(ContextObject) != NULL);
    LeaveCriticalSection(&m_cs);
    if (!bRegistered) {
        return FALSE;
    }

    // Flag the object so no new I/O processing starts on it (#215).
    ContextObject->IsRemoved.store(true);
    m_OfflineProc(ContextObject);

    // Cancel outstanding async I/O (e.g. the receive from PostRecv), then
    // close the socket.
    CancelIo((HANDLE)ContextObject->sClientSocket);
    closesocket(ContextObject->sClientSocket);
    ContextObject->sClientSocket = INVALID_SOCKET;

    // Wait for worker threads still holding a reference (#215).
    // HasOverlappedIoCompleted((LPOVERLAPPED)ContextObject) was used here
    // previously and was wrong. CONTEXT_OBJECT is not an OVERLAPPED (its first
    // field is the vtable pointer), so the check returned TRUE immediately and
    // caused a race/crash.
    const int kMaxWaitIterations = 5000; // ~5 s of 1 ms sleeps
    for (int iWaited = 0; ContextObject->IoRefCount.load() > 0; ) {
        Sleep(1);
        if (++iWaited > kMaxWaitIterations) {
            Mprintf("!!! RemoveStaleContext: IoRefCount wait timeout (ref=%d)\n",
                    ContextObject->IoRefCount.load());
            break;
        }
    }

    MoveContextToFreePoolList(ContextObject); // return the context to the pool
    return TRUE;
}
VOID IOCPServer::MoveContextToFreePoolList(CONTEXT_OBJECT* ContextObject)
{
    // Moves ContextObject from the active connection list to the free pool,
    // clearing its buffers first. Does nothing if the context is not in the
    // active list. Runs entirely under m_cs.
    CLock cs(m_cs);
    POSITION Pos = m_ContextConnectionList.Find(ContextObject);
    if (Pos == NULL) {
        return;
    }
    ContextObject->InCompressedBuffer.ClearBuffer();
    ContextObject->InDeCompressedBuffer.ClearBuffer();
    ContextObject->OutCompressedBuffer.ClearBuffer();
    // Clear exactly the receive buffer's size (OnAccept sizes wsaInBuf the
    // same way) instead of a hard-coded byte count that could overrun or
    // under-clear if the buffer size changes.
    memset(ContextObject->szBuffer, 0, sizeof(ContextObject->szBuffer));
    m_ContextFreePoolList.AddTail(ContextObject);  // recycle into the pool
    m_ContextConnectionList.RemoveAt(Pos);         // drop from the active list
}
void IOCPServer::UpdateMaxConnection(int maxConn)
{
    // Sets the connection limit enforced by AllocateContext. The write is made
    // under m_cs, the same lock AllocateContext holds when it reads the limit.
    CLock guard(m_cs);
    m_ulMaxConnections = maxConn;
}