Files
SimpleRemoter/server/2015Remote/IOCPServer.cpp
2026-04-19 22:55:21 +02:00

1174 lines
42 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "StdAfx.h"
#include "IOCPServer.h"
#include "2015Remote.h"
#include "common/IPWhitelist.h"
#include "common/IPBlacklist.h"
#include <iostream>
#include <ws2tcpip.h>
// Proxy Protocol v2 signature: the fixed 12-byte preamble "\r\n\r\n\0\r\nQUIT\n"
// with which every binary (v2) PROXY header begins.
static const unsigned char PROXY_PROTOCOL_V2_SIGNATURE[12] = {
    '\r', '\n', '\r', '\n',       // 0x0D 0x0A 0x0D 0x0A
    '\0',                         // 0x00
    '\r', '\n',                   // 0x0D 0x0A
    'Q', 'U', 'I', 'T', '\n'      // 0x51 0x55 0x49 0x54 0x0A
};
// 解析 Proxy Protocol v2 头,返回真实客户端 IP
// 成功返回 true 并设置 realIP失败返回 false
// 如果不是 Proxy Protocol返回 false 且不消费任何数据
static bool ParseProxyProtocolV2(SOCKET sock, std::string& realIP)
{
// 等待数据就绪(最多 200ms解决 FRP Proxy Protocol 头延迟到达的时序问题
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(sock, &readfds);
struct timeval tv = { 0, 200000 }; // 200ms
int ready = select((int)sock + 1, &readfds, NULL, NULL, &tv);
if (ready <= 0) {
return false; // 超时或错误,当作普通连接
}
// 先 peek 前 16 字节12 签名 + 4 头部)
unsigned char header[16];
int n = recv(sock, (char*)header, 16, MSG_PEEK);
if (n < 12) {
// 数据不足,检查已有数据是否匹配签名前缀
if (n > 0 && memcmp(header, PROXY_PROTOCOL_V2_SIGNATURE, n) != 0) {
return false; // 不匹配,不是 Proxy Protocol
}
// 数据太少无法判断,当作普通连接
return false;
}
// 检查签名
if (memcmp(header, PROXY_PROTOCOL_V2_SIGNATURE, 12) != 0) {
return false; // 签名不匹配,不是 Proxy Protocol
}
if (n < 16) {
// 有签名但头部不完整,等待更多数据(理论上不会发生)
return false;
}
// 解析版本和命令 (byte 12)
// 高 4 位是版本 (应该是 0x2),低 4 位是命令 (0x0=LOCAL, 0x1=PROXY)
unsigned char verCmd = header[12];
unsigned char version = (verCmd >> 4) & 0x0F;
unsigned char command = verCmd & 0x0F;
if (version != 2) {
// 不是 v2可能是 v1 或无效
return false;
}
// 解析地址族和协议 (byte 13)
// 高 4 位是地址族 (0x1=AF_INET, 0x2=AF_INET6)
// 低 4 位是协议 (0x1=STREAM, 0x2=DGRAM)
unsigned char famProto = header[13];
unsigned char addrFamily = (famProto >> 4) & 0x0F;
// 解析地址长度 (bytes 14-15, big-endian)
unsigned short addrLen = (header[14] << 8) | header[15];
// 计算完整头部长度 (IPv4: 16+12=28, IPv6: 16+36=52)
int totalHeaderLen = 16 + addrLen;
if (totalHeaderLen > 108) { // 安全上限16 + 最大 TLV 长度
return false;
}
// 读取完整头部(真正消费数据),使用固定数组避免动态分配
unsigned char fullHeader[108];
int received = 0;
while (received < totalHeaderLen) {
int r = recv(sock, (char*)fullHeader + received, totalHeaderLen - received, 0);
if (r <= 0) {
return false; // 接收失败
}
received += r;
}
// 如果是 LOCAL 命令,使用 socket 的对端地址
if (command == 0x00) {
return false; // 让调用者使用 getpeername
}
// 解析地址
if (addrFamily == 0x01 && addrLen >= 12) {
// IPv4: src_addr(4) + dst_addr(4) + src_port(2) + dst_port(2)
unsigned char* addr = fullHeader + 16;
char ipStr[INET_ADDRSTRLEN];
snprintf(ipStr, sizeof(ipStr), "%u.%u.%u.%u", addr[0], addr[1], addr[2], addr[3]);
realIP = ipStr;
return true;
} else if (addrFamily == 0x02 && addrLen >= 36) {
// IPv6: src_addr(16) + dst_addr(16) + src_port(2) + dst_port(2)
unsigned char* addr = fullHeader + 16;
char ipStr[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, addr, ipStr, sizeof(ipStr));
realIP = ipStr;
return true;
}
return false; // 未知地址族
}
// Returns the remote IPv4 address of a connected socket in dotted-quad form,
// or an empty string if it cannot be determined.
std::string GetPeerName(SOCKET sock)
{
    sockaddr_in ClientAddr = {};
    int ulClientAddrLen = sizeof(sockaddr_in);
    // getpeername() returns 0 on success and SOCKET_ERROR on failure; it does not
    // return a SOCKET, so compare against SOCKET_ERROR, not INVALID_SOCKET.
    if (getpeername(sock, (SOCKADDR*)&ClientAddr, &ulClientAddrLen) == SOCKET_ERROR) {
        return "";
    }
    // inet_ntop writes into a caller-owned buffer, unlike the deprecated inet_ntoa.
    char ipStr[INET_ADDRSTRLEN] = {};
    if (inet_ntop(AF_INET, &ClientAddr.sin_addr, ipStr, sizeof(ipStr)) == NULL) {
        return "";
    }
    return ipStr;
}
// Returns the remote IPv4 address of a connected socket. If it cannot be determined,
// the numeric socket handle value is returned instead, as a textual identifier.
std::string GetRemoteIP(SOCKET sock)
{
    sockaddr_in addr = {};
    int addrLen = sizeof(addr);
    if (getpeername(sock, (sockaddr*)&addr, &addrLen) == 0) {
        char ipStr[INET_ADDRSTRLEN] = {};
        if (inet_ntop(AF_INET, &addr.sin_addr, ipStr, sizeof(ipStr)) != NULL) {
            TRACE(">>> 对端 IP 地址: %s\n", ipStr);
            return ipStr;
        }
    }
    TRACE(">>> 获取对端 IP 失败, 错误码: %d\n", WSAGetLastError());
    // SOCKET is UINT_PTR (64-bit on x64): format it as unsigned long long, in a buffer
    // large enough for any 64-bit value (20 digits + NUL).
    char buf[24];
    sprintf_s(buf, "%llu", (unsigned long long)sock);
    return buf;
}
// Cached per-IP connection rate-limit settings, so the configuration store is not read
// on every connection. These defaults apply until ReloadBanConfig() has run.
namespace {
struct BanConfigCache {
    int banWindow = 60;       // length of a counting window, in seconds
    int banMaxConn = 15;      // connections tolerated per window before a ban
    int banDuration = 3600;   // length of a ban, in seconds
    bool loaded = false;      // true once the values were read from configuration
};
}
static BanConfigCache g_BanConfig;
// Re-reads the rate-limit settings (BanWindow, BanMaxConn, BanDuration) from the
// "settings" section of the configuration into the cache and marks it loaded.
// The third GetInt argument is the value passed as the fallback default.
void ReloadBanConfig() {
    const int window = THIS_CFG.GetInt("settings", "BanWindow", 60);
    const int maxConn = THIS_CFG.GetInt("settings", "BanMaxConn", 15);
    const int duration = THIS_CFG.GetInt("settings", "BanDuration", 3600);
    g_BanConfig.banWindow = window;
    g_BanConfig.banMaxConn = maxConn;
    g_BanConfig.banDuration = duration;
    g_BanConfig.loaded = true;
}
// Length of the connection-counting window in seconds (loads the cache on first use).
static int GetBanWindowSeconds() {
    if (g_BanConfig.loaded == false) {
        ReloadBanConfig();
    }
    return g_BanConfig.banWindow;
}
// Connections allowed per window before an IP is banned (loads the cache on first use).
static int GetBanMaxConnections() {
    if (g_BanConfig.loaded == false) {
        ReloadBanConfig();
    }
    return g_BanConfig.banMaxConn;
}
// Duration of an automatic ban in seconds (loads the cache on first use).
static int GetBanDurationSeconds() {
    if (g_BanConfig.loaded == false) {
        ReloadBanConfig();
    }
    return g_BanConfig.banDuration;
}
// Returns true if `ip` is currently under a temporary ban.
// Whitelisted addresses (the whitelist check also covers local addresses) are never
// considered banned. An expired ban entry is erased when it is looked up.
bool IOCPServer::IsIPBanned(const std::string& ip)
{
    if (IPWhitelist::getInstance().IsWhitelisted(ip)) {
        return false;
    }
    CLock lock(m_BanLock);
    auto entry = m_BannedIPs.find(ip);
    if (entry == m_BannedIPs.end()) {
        return false;
    }
    if (time(nullptr) < entry->second) {
        return true; // still inside the ban period
    }
    m_BannedIPs.erase(entry); // ban expired: forget it
    return false;
}
// Returns true if `ip` is on the permanent blacklist. Rejections are logged (and shown
// in the main window, if one is attached) only when IPBlacklist::ShouldLog() allows it,
// which throttles repeated reports for the same address.
bool IOCPServer::IsIPBlacklisted(const std::string& ip)
{
    if (!IPBlacklist::getInstance().IsBlacklisted(ip)) {
        return false;
    }
    const bool report = IPBlacklist::getInstance().ShouldLog(ip);
    if (report) {
        Mprintf("Connection rejected: %s (blacklisted)\n", ip.c_str());
        if (m_hMainWnd != NULL) {
            char tip[256];
            sprintf_s(tip, _TRF("IP %s 连接被拒绝 (黑名单)"), ip.c_str());
            PostMessageA(m_hMainWnd, WM_SHOWERRORMSG,
                         (WPARAM)new CString(tip), (LPARAM)new CString(_TR("黑名单")));
        }
    }
    return true;
}
// Counts a new connection from `ip` within a fixed window of BanWindow seconds. When
// the count in the current window exceeds BanMaxConn, the IP is banned for BanDuration
// seconds (BanIP is called after m_BanLock has been released). Whitelisted addresses
// are never counted.
void IOCPServer::RecordConnection(const std::string& ip)
{
    if (IPWhitelist::getInstance().IsWhitelisted(ip)) {
        return;
    }
    bool exceeded = false;
    {
        CLock lock(m_BanLock);
        const time_t now = time(nullptr);
        const int window = GetBanWindowSeconds();

        // Housekeeping: every 1000th call, or whenever the table grows beyond 10000
        // entries, drop counters whose window has already elapsed.
        static int cleanupCounter = 0;
        if (++cleanupCounter >= 1000 || m_ConnectionCount.size() > 10000) {
            cleanupCounter = 0;
            auto cur = m_ConnectionCount.begin();
            while (cur != m_ConnectionCount.end()) {
                if (now - cur->second.windowStart >= window) {
                    cur = m_ConnectionCount.erase(cur);
                } else {
                    ++cur;
                }
            }
        }

        auto found = m_ConnectionCount.find(ip);
        if (found == m_ConnectionCount.end()) {
            m_ConnectionCount[ip] = { 1, now };      // first connection seen from this IP
        } else if (now - found->second.windowStart >= window) {
            found->second.count = 1;                // window elapsed: start a fresh one
            found->second.windowStart = now;
        } else {
            found->second.count++;
            exceeded = found->second.count > GetBanMaxConnections();
        }
    }
    if (exceeded) {
        BanIP(ip, GetBanDurationSeconds());
    }
}
// Bans `ip` for `seconds` seconds: records the expiry time, drops its connection
// counter, logs the ban and, when a main window is attached, posts a notification.
void IOCPServer::BanIP(const std::string& ip, int seconds)
{
    {
        CLock lock(m_BanLock);
        const time_t until = time(nullptr) + seconds;
        m_BannedIPs[ip] = until;
        m_ConnectionCount.erase(ip); // counting starts over once the ban is recorded
    }
    Mprintf("IP banned: %s (duration: %d seconds, reason: too many connections)\n",
            ip.c_str(), seconds);
    if (m_hMainWnd == NULL) {
        return;
    }
    char tip[256];
    sprintf_s(tip, _TRF("IP %s 已封禁 %d 秒 (连接过于频繁)"), ip.c_str(), seconds);
    PostMessageA(m_hMainWnd, WM_SHOWERRORMSG, (WPARAM)new CString(tip), (LPARAM)new CString(_TR("IP 封禁")));
}
// Loads the IP whitelist from the configuration.
// Format: IPWhitelist=192.168.1.1;10.0.0.1;172.16.0.100
void IOCPServer::LoadIPWhitelist()
{
    const std::string configured = THIS_CFG.GetStr("settings", "IPWhitelist", "");
    IPWhitelist::getInstance().Load(configured);
    const size_t loadedCount = IPWhitelist::getInstance().Count();
    if (loadedCount != 0) {
        Mprintf("IP whitelist loaded: %zu IPs\n", loadedCount);
    }
}
// Loads the IP blacklist from the configuration.
// Format: IPBlacklist=192.168.1.1;10.0.0.1
void IOCPServer::LoadIPBlacklist()
{
    const std::string configured = THIS_CFG.GetStr("settings", "IPBlacklist", "");
    IPBlacklist::getInstance().Load(configured);
    const size_t loadedCount = IPBlacklist::getInstance().Count();
    if (loadedCount != 0) {
        Mprintf("IP blacklist loaded: %zu IPs\n", loadedCount);
    }
}
// Constructs an idle server bound to the main window `hWnd` (used for notifications).
// All members and critical sections are initialized BEFORE WSAStartup(): the
// destructor always runs Destroy() and DeleteCriticalSection(), so an early return on
// WSAStartup failure must not leave that state uninitialized. The IP white/black lists
// are loaded only when Winsock started successfully (as before).
IOCPServer::IOCPServer(HWND hWnd)
{
    m_hMainWnd = hWnd;
    m_nPort = 0;
    m_hCompletionPort = NULL;
    m_sListenSocket = INVALID_SOCKET;
    m_hListenEvent = WSA_INVALID_EVENT;
    m_hListenThread = NULL;
    m_ulMaxConnections = 10000;
    m_ulWorkThreadCount = 0;
    m_bTimeToKill = FALSE;
    m_ulThreadPoolMin = 0;
    m_ulThreadPoolMax = 0;
    m_ulCPULowThreadsHold = 0;
    m_ulCPUHighThreadsHold = 0;
    m_ulCurrentThread = 0;
    m_ulBusyThread = 0;
    m_ulKeepLiveTime = 0;
    m_hKillEvent = NULL;
    m_NotifyProc = NULL;
    m_OfflineProc = NULL;
    InitializeCriticalSection(&m_cs);
    InitializeCriticalSection(&m_BanLock);

    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2,2), &wsaData) != 0) {
        return;
    }
    LoadIPWhitelist();
    LoadIPBlacklist();
}
// Signals shutdown and releases the server's kernel objects:
// sets m_bTimeToKill, signals and closes the kill event (the listen thread waits on it),
// then closes the listening socket, the completion port and the listen event. Each
// handle is reset to its "empty" value after closing, so a second call skips it.
// NOTE(review): this does not wait for the listen/worker threads; the destructor polls
// for that afterwards. Those threads may still be using the handles closed here;
// confirm the shutdown ordering is safe. Closing the completion port presumably also
// wakes workers blocked in GetQueuedCompletionStatus.
// NOTE(review): m_hListenEvent comes from WSACreateEvent(); WSACloseEvent() is the
// documented release call (SAFE_CLOSE_HANDLE presumably uses CloseHandle).
void IOCPServer::Destroy()
{
    m_bTimeToKill = TRUE;
    if (m_hKillEvent != NULL) {
        SetEvent(m_hKillEvent);
        SAFE_CLOSE_HANDLE(m_hKillEvent);
        m_hKillEvent = NULL;
    }
    if (m_sListenSocket != INVALID_SOCKET) {
        closesocket(m_sListenSocket);
        m_sListenSocket = INVALID_SOCKET;
    }
    if (m_hCompletionPort != INVALID_HANDLE_VALUE) {
        SAFE_CLOSE_HANDLE(m_hCompletionPort);
        m_hCompletionPort = INVALID_HANDLE_VALUE;
    }
    if (m_hListenEvent != WSA_INVALID_EVENT) {
        SAFE_CLOSE_HANDLE(m_hListenEvent);
        m_hListenEvent = WSA_INVALID_EVENT;
    }
}
// Tears the server down.
// 1. Destroy() signals shutdown and closes the server handles.
// 2. Busy-waits (10 ms polls, no timeout) until m_ulWorkThreadCount reaches 0
//    (presumably maintained through AddWorkThread) and the listen thread has cleared
//    m_hListenThread.
//    NOTE(review): this hangs forever if a thread never exits.
// 3. Each live context goes through RemoveStaleContext() (offline callback, socket
//    close, move to the free pool), and its pending receive request (olps) is freed.
// 4. Every pooled context is deleted, both critical sections are destroyed, and
//    WSACleanup() balances the constructor's WSAStartup().
IOCPServer::~IOCPServer(void)
{
    Destroy();
    while (m_ulWorkThreadCount || m_hListenThread)
        Sleep(10);
    while (!m_ContextConnectionList.IsEmpty()) {
        CONTEXT_OBJECT *ContextObject = m_ContextConnectionList.GetHead();
        RemoveStaleContext(ContextObject);
        SAFE_DELETE(ContextObject->olps);
    }
    while (!m_ContextFreePoolList.IsEmpty()) {
        CONTEXT_OBJECT *ContextObject = m_ContextFreePoolList.RemoveHead();
        // The statement below could crash (noted 2019-01-14), so it stays disabled.
        //SAFE_DELETE(ContextObject->olps);
        delete ContextObject;
    }
    DeleteCriticalSection(&m_cs);
    DeleteCriticalSection(&m_BanLock);
    m_ulWorkThreadCount = 0;
    m_ulThreadPoolMin = 0;
    m_ulThreadPoolMax = 0;
    m_ulCPULowThreadsHold = 0;
    m_ulCPUHighThreadsHold = 0;
    m_ulCurrentThread = 0;
    m_ulBusyThread = 0;
    m_ulKeepLiveTime = 0;
    WSACleanup();
}
// Starts listening on TCP port `uPort` (all interfaces) and spins up the IOCP workers.
// NotifyProc is invoked for received packets; OffProc when a connection is dropped.
// Returns 0 on success. On failure it returns 1 (kill event), 2 (socket), 3 (listen
// event), 4 (IOCP setup failed with no error code), or the Win32/Winsock error code of
// the failing call.
UINT IOCPServer::StartServer(pfnNotifyProc NotifyProc, pfnOfflineProc OffProc, USHORT uPort)
{
    m_nPort = uPort;
    m_NotifyProc = NotifyProc;
    m_OfflineProc = OffProc;
    m_hKillEvent = CreateEvent(NULL, FALSE, FALSE, NULL);
    if (m_hKillEvent == NULL) {
        return 1;
    }
    // Overlapped listening socket.
    m_sListenSocket = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
    if (m_sListenSocket == INVALID_SOCKET) {
        return 2;
    }
    m_hListenEvent = WSACreateEvent();
    if (m_hListenEvent == WSA_INVALID_EVENT) {
        closesocket(m_sListenSocket);
        m_sListenSocket = INVALID_SOCKET;
        return 3;
    }
    // Releases the listening socket and event, then returns `code` unchanged.
    // The error code is captured by the caller BEFORE this cleanup runs.
    auto failWith = [this](int code) -> UINT {
        closesocket(m_sListenSocket);
        m_sListenSocket = INVALID_SOCKET;
        WSACloseEvent(m_hListenEvent);
        m_hListenEvent = WSA_INVALID_EVENT;
        return code;
    };
    // Signal m_hListenEvent whenever a connection is ready to be accepted.
    if (WSAEventSelect(m_sListenSocket, m_hListenEvent, FD_ACCEPT) == SOCKET_ERROR) {
        return failWith(GetLastError());
    }
    SOCKADDR_IN ServerAddr = {};    // zeroed, so sin_zero is clean
    ServerAddr.sin_family = AF_INET;
    ServerAddr.sin_port = htons(uPort);
    ServerAddr.sin_addr.s_addr = INADDR_ANY;
    if (bind(m_sListenSocket, (sockaddr*)&ServerAddr, sizeof(ServerAddr)) == SOCKET_ERROR) {
        return failWith(GetLastError());
    }
    if (listen(m_sListenSocket, SOMAXCONN) == SOCKET_ERROR) {
        return failWith(GetLastError());
    }
    // Create the completion port and worker threads BEFORE the listen thread starts.
    // OnAccept() associates every accepted socket with m_hCompletionPort. Accepting
    // before the port exists would drop the connection. Previously the IOCP setup
    // result was also ignored.
    if (!InitializeIOCP()) {
        int err = GetLastError();
        return failWith(err != 0 ? err : 4);
    }
    // `this` is handed to the thread so the static thread proc can reach the members.
    m_hListenThread = (HANDLE)CreateThread(NULL, 0, ListenThreadProc, (void*)this, 0, NULL);
    if (m_hListenThread == NULL) {
        return failWith(GetLastError());
    }
    return 0;
}
// Creates the I/O completion port and the initial worker thread pool
// (2 x number of processors). Returns TRUE if the port exists and at least one worker
// is running. If some CreateThread calls fail after others succeeded, the server keeps
// running with the threads it has: closing the port under live workers would leave
// them spinning on a dead handle (WorkThreadProc also grows the pool later on demand).
BOOL IOCPServer::InitializeIOCP(VOID)
{
    // CreateIoCompletionPort reports failure with NULL (never INVALID_HANDLE_VALUE).
    m_hCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (m_hCompletionPort == NULL) {
        return FALSE;
    }
    SYSTEM_INFO SystemInfo;
    GetSystemInfo(&SystemInfo); // number of logical processors
    m_ulThreadPoolMin = 1;
    m_ulThreadPoolMax = SystemInfo.dwNumberOfProcessors * 2;
    m_ulCPULowThreadsHold = 10;
    m_ulCPUHighThreadsHold = 75;

    ULONG ulStarted = 0;
    for (ULONG i = 0; i < m_ulThreadPoolMax; ++i) {
        // Workers service the packets queued on the completion port.
        HANDLE hWorkThread = (HANDLE)CreateThread(NULL, 0, WorkThreadProc, (void*)this, 0, NULL);
        if (hWorkThread == NULL) {
            break;
        }
        AddWorkThread(1);
        SAFE_CLOSE_HANDLE(hWorkThread);
        ++ulStarted;
    }
    if (ulStarted == 0) {
        // No worker uses the port yet, so closing it here is safe.
        SAFE_CLOSE_HANDLE(m_hCompletionPort);
        return FALSE;
    }
    return TRUE;
}
// Worker thread entry point (lParam is the owning IOCPServer). Dequeues completion
// packets and dispatches them to HandleIO() until m_bTimeToKill is set or the
// completion port becomes unusable. Each worker owns its own zstd/zlib decompression
// state, which it passes to HandleIO().
DWORD IOCPServer::WorkThreadProc(LPVOID lParam)
{
    // Per-thread decompression state.
    ZSTD_DCtx* m_Dctx = ZSTD_createDCtx();
    z_stream m_stream = {};
    inflateInit2(&m_stream, 15);
    IOCPServer* This = (IOCPServer*)(lParam);
    HANDLE hCompletionPort = This->m_hCompletionPort;
    DWORD dwTrans = 0;
    PCONTEXT_OBJECT ContextObject = NULL;
    LPOVERLAPPED Overlapped = NULL;
    OVERLAPPEDPLUS* OverlappedPlus = NULL;
    ULONG ulBusyThread = 0;
    BOOL bError = FALSE;
    InterlockedIncrement(&This->m_ulCurrentThread);
    InterlockedIncrement(&This->m_ulBusyThread);
    timeBeginPeriod(1);
    while (This->m_bTimeToKill==FALSE) {
        InterlockedDecrement(&This->m_ulBusyThread);
        // When GetQueuedCompletionStatus dequeues nothing, it leaves the byte count and
        // completion key indeterminate. Reset them so stale values from a previous
        // iteration are never acted upon.
        dwTrans = 0;
        ContextObject = NULL;
        Overlapped = NULL;
        // Time spent blocked here can limit how fast client data is consumed.
        BOOL bOk = GetQueuedCompletionStatus(
                       hCompletionPort,
                       &dwTrans,
                       (PULONG_PTR)&ContextObject,
                       &Overlapped, INFINITE);
        DWORD dwIOError = GetLastError();
        // Derive the OVERLAPPEDPLUS only when a packet was actually dequeued:
        // CONTAINING_RECORD(NULL, ...) gives a bogus pointer unless m_ol is at offset 0.
        OverlappedPlus = Overlapped ? CONTAINING_RECORD(Overlapped, OVERLAPPEDPLUS, m_ol) : NULL;
        ulBusyThread = InterlockedIncrement(&This->m_ulBusyThread);
        if (!bOk && Overlapped == NULL && dwIOError != WAIT_TIMEOUT) {
            // No packet dequeued: the completion port itself failed (e.g. it was closed).
            // Retrying would spin forever, so leave the loop.
            break;
        }
        if ( !bOk && dwIOError != WAIT_TIMEOUT ) {
            // A dequeued I/O operation failed, typically because the peer closed the socket.
            if (ContextObject && This->m_bTimeToKill == FALSE && dwTrans == 0) {
                ContextObject->olps = NULL;
                Mprintf("!!! RemoveStaleContext: %d \n", WSAGetLastError());
                This->RemoveStaleContext(ContextObject);
            }
            SAFE_DELETE(OverlappedPlus);
            continue;
        }
        if (!bError) {
            // All threads are busy: add one more worker, up to m_ulThreadPoolMax.
            if (ulBusyThread == This->m_ulCurrentThread) {
                if (ulBusyThread < This->m_ulThreadPoolMax) {
                    if (ContextObject != NULL) {
                        HANDLE hThread = (HANDLE)CreateThread(NULL,
                                                              0,
                                                              WorkThreadProc,
                                                              (void*)This,
                                                              0,
                                                              NULL);
                        This->AddWorkThread(hThread ? 1:0);
                        SAFE_CLOSE_HANDLE(hThread);
                    }
                }
            }
            // Idle-timeout shrinking. This is only reachable if the wait times out,
            // which the INFINITE wait above should prevent.
            if (!bOk && dwIOError == WAIT_TIMEOUT) {
                if (ContextObject == NULL) {
                    if (This->m_ulCurrentThread > This->m_ulThreadPoolMin) {
                        break;
                    }
                    bError = TRUE;
                }
            }
        }
        if (!bError && !This->m_bTimeToKill) {
            if(bOk && OverlappedPlus!=NULL && ContextObject!=NULL) {
                try {
                    This->HandleIO(OverlappedPlus->m_ioType, ContextObject, dwTrans, m_Dctx, &m_stream);
                    ContextObject = NULL;
                } catch (...) {
                    Mprintf("This->HandleIO catched an error!!!");
                }
            }
        }
        SAFE_DELETE(OverlappedPlus);
    }
    timeEndPeriod(1);
    SAFE_DELETE(OverlappedPlus);
    InterlockedDecrement(&This->m_ulCurrentThread);
    InterlockedDecrement(&This->m_ulBusyThread);
    int n= This->AddWorkThread(-1);
    if (n == 0) {
        Mprintf("======> IOCPServer All WorkThreadProc done\n");
    }
    inflateEnd(&m_stream);
    ZSTD_freeDCtx(m_Dctx);
    return 0;
}
// Called on a worker thread. Dispatches one completion packet for `ContextObject` by
// type (IOInitialize / IORead / IOWrite), passing the worker's decompression state
// through to the receive path.
// Returns FALSE if the context was already removed; otherwise the handler's result.
// A handler result of -2 is mapped to FALSE.
BOOL IOCPServer::HandleIO(IOType PacketFlags,PCONTEXT_OBJECT ContextObject, DWORD dwTrans, ZSTD_DCtx* ctx, z_stream* z)
{
    // Race prevention (#215): the reference count MUST be incremented before IsRemoved
    // is checked. With the opposite order there is a TOCTOU window in which
    // RemoveStaleContext could finish (it waits for IoRefCount to drop to 0) and the
    // object could be recycled between the check and the increment.
    ContextObject->IoRefCount.fetch_add(1);
    // Bail out if the context has already been marked for removal.
    if (ContextObject->IsRemoved.load()) {
        ContextObject->IoRefCount.fetch_sub(1);
        return FALSE;
    }
    BOOL bRet = FALSE;
    switch (PacketFlags) {
    case IOInitialize:
        bRet = OnClientInitializing(ContextObject, dwTrans);
        break;
    case IORead:
        bRet = OnClientReceiving(ContextObject, dwTrans, ctx, z);
        break;
    case IOWrite:
        bRet = OnClientPostSending(ContextObject, dwTrans);
        break;
    case IOIdle:
        Mprintf("=> HandleIO PacketFlags= IOIdle\n");
        break;
    default:
        break;
    }
    // Release the reference (#215). The special value -2 means the handler already
    // released it and called RemoveStaleContext; the object may now be in the free pool
    // or reused, so it must not be touched again.
    if (bRet != -2) {
        ContextObject->IoRefCount.fetch_sub(1);
    }
    return bRet == -2 ? FALSE : bRet;
}
// Handles the IOInitialize packet queued by OnAccept(). No per-connection work is
// needed at this stage, so it always reports success.
BOOL IOCPServer::OnClientInitializing(PCONTEXT_OBJECT ContextObject, DWORD dwTrans)
{
    (void)ContextObject;
    (void)dwTrans;
    return TRUE;
}
// May be this function should be a member of `CONTEXT_OBJECT`.
// Appends the `dwTrans` bytes just received into ContextObject->szBuffer to the
// compressed input buffer, then extracts every complete packet. Each packet is
// decoded/decompressed into InDeCompressedBuffer and handed to `m_NotifyProc`.
// Returns:
//   FALSE - dwTrans == 0 (the peer closed the connection);
//   999   - m_NotifyProc returned true for a packet whose first byte is TOKEN_LOGIN
//           (reflects the last accepted packet in this call);
//   1     - otherwise, including after a caught parse/decompression error (the input
//           buffers are cleared but the connection is kept).
// `m_Dctx` is not used in this body.
// NOTE(review): presumably Muncompress uses its own zstd context. Confirm.
BOOL ParseReceivedData(CONTEXT_OBJECT * ContextObject, DWORD dwTrans, pfnNotifyProc m_NotifyProc, ZSTD_DCtx* m_Dctx, z_stream* z)
{
    AUTO_TICK(50, ContextObject->GetPeerName());
    BOOL ret = 1;
    try {
        if (dwTrans == 0) { // the peer closed the socket
            return FALSE;
        }
        // Copy the received bytes (wsaInBuf points at szBuffer) into our own buffer.
        ContextObject->InCompressedBuffer.WriteBuffer((PBYTE)ContextObject->szBuffer,dwTrans);
        // Extract as many complete packets as are currently buffered.
        while (true) {
            PR pr = ContextObject->Parse(ContextObject->InCompressedBuffer);
            if (pr.IsFailed()) {
                ContextObject->InCompressedBuffer.ClearBuffer();
                break;
            } else if (pr.IsNeedMore()) {
                break;
            } else if (pr.IsWinOSLogin()) {
                ContextObject->InDeCompressedBuffer.ClearBuffer();
                ULONG ulCompressedLength = 0;
                ULONG ulOriginalLength = 0;
                PBYTE CompressedBuffer = ContextObject->ReadBuffer(ulCompressedLength, ulOriginalLength);
                ContextObject->InDeCompressedBuffer.WriteBuffer(CompressedBuffer, ulCompressedLength);
                // NOTE(review): CompressedBuffer[0] is read even if ulCompressedLength is 0.
                if (m_NotifyProc(ContextObject))
                    ret = CompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
                // CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
                break;
            }
            ULONG ulPackTotalLength = 0;
            // Read the packet's total length field. The original note describes the header
            // as flag(5) + total length(4) + original length(4). pr.Result is presumably
            // the offset of the length field. TODO confirm.
            ContextObject->InCompressedBuffer.CopyBuffer(&ulPackTotalLength, sizeof(ULONG), pr.Result);
            int bufLen = ContextObject->InCompressedBuffer.GetBufferLength();
            if (ulPackTotalLength && bufLen >= ulPackTotalLength) {
                ULONG ulCompressedLength = 0;
                ULONG ulOriginalLength = 0;
                PBYTE CompressedBuffer = ContextObject->ReadBuffer(ulCompressedLength, ulOriginalLength);
                if (ContextObject->CompressMethod == COMPRESS_UNKNOWN) {
                    // CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
                    throw "Unknown method";
                } else if (ContextObject->CompressMethod == COMPRESS_NONE) {
                    // Uncompressed payload.
                    // NOTE(review): ulOriginalLength (peer-supplied) is used as the length
                    // of a buffer that ReadBuffer reports as ulCompressedLength bytes. If a
                    // peer sends original > compressed, this may over-read. Confirm
                    // ReadBuffer's guarantees.
                    ContextObject->InDeCompressedBuffer.ClearBuffer();
                    ContextObject->Decode(CompressedBuffer, ulOriginalLength);
                    ContextObject->InDeCompressedBuffer.WriteBuffer(CompressedBuffer, ulOriginalLength);
                    if (m_NotifyProc(ContextObject))
                        ret = CompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
                    // CompressedBuffer is owned by CONTEXT_OBJECT; do not free it here.
                    continue;
                }
                bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD, zlibFailed = false;
                // Reuse the context's preallocated buffer instead of allocating per packet.
                PBYTE DeCompressedBuffer = ContextObject->GetDecompressBuffer(ulOriginalLength);
                size_t iRet = usingZstd ?
                              Muncompress(DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength) :
                              z_uncompress(z, DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength);
                if (usingZstd ? C_SUCCESS(iRet) : (S_OK==iRet)) {
                    ContextObject->InDeCompressedBuffer.ClearBuffer();
                    ContextObject->Decode(DeCompressedBuffer, ulOriginalLength);
                    ContextObject->InDeCompressedBuffer.WriteBuffer(DeCompressedBuffer, ulOriginalLength);
                    if (m_NotifyProc(ContextObject))
                        ret = DeCompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
                } else if (usingZstd) {
                    // zstd failed: try zlib. On success the context switches to zlib for good.
                    if (Z_OK == z_uncompress(z, DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength)) {
                        ContextObject->CompressMethod = COMPRESS_ZLIB;
                        ContextObject->InDeCompressedBuffer.ClearBuffer();
                        ContextObject->Decode(DeCompressedBuffer, ulOriginalLength);
                        ContextObject->InDeCompressedBuffer.WriteBuffer(DeCompressedBuffer, ulOriginalLength);
                        if (m_NotifyProc(ContextObject))
                            ret = DeCompressedBuffer[0] == TOKEN_LOGIN ? 999 : 1;
                    } else {
                        zlibFailed = true;
                        // Deliberately not set to COMPRESS_UNKNOWN: later packets still try ZSTD.
                    }
                } else {
                    zlibFailed = true;
                }
                // CompressedBuffer and DeCompressedBuffer are owned by CONTEXT_OBJECT; do not free them here.
                if (zlibFailed) {
                    Mprintf("[ERROR] uncompress failed: cmd=0x%02X, compressed=%u, original=%u\n",
                            (unsigned char)CompressedBuffer[0], ulCompressedLength, ulOriginalLength);
                    throw "Bad Buffer"; // the catch block below clears the buffers
                }
            } else {
                break; // packet not complete yet
            }
        }
    } catch(...) {
        Mprintf("[ERROR] OnClientReceiving catch an error \n");
        ContextObject->InCompressedBuffer.ClearBuffer();
        ContextObject->InDeCompressedBuffer.ClearBuffer();
    }
    return ret;
}
// Handles a completed receive: parses the data and, while the connection is healthy,
// queues the next receive. Returns TRUE on success. Returns -2 when the connection was
// torn down; in that case this thread's I/O reference has already been released, so
// HandleIO() must not release it again.
BOOL IOCPServer::OnClientReceiving(PCONTEXT_OBJECT ContextObject, DWORD dwTrans, ZSTD_DCtx* ctx, z_stream* z)
{
    const BOOL parsed = ParseReceivedData(ContextObject, dwTrans, m_NotifyProc, ctx, z);
    if (parsed != FALSE) {
        PostRecv(ContextObject); // keep a receive outstanding on this connection
        return TRUE;
    }
    // Drop our reference BEFORE RemoveStaleContext (#215), which waits for it to reach 0.
    // Afterwards the context may already be back in the free pool: do not touch it.
    ContextObject->IoRefCount.fetch_sub(1);
    RemoveStaleContext(ContextObject);
    return -2;
}
// Encodes and (unless the context uses COMPRESS_NONE) compresses `ulOriginalLength`
// bytes from `szBuffer`, then appends the result to the context's outgoing buffer via
// ContextObject->WriteBuffer(). Short commands (other than frequent screen/heartbeat
// ones) are logged. Returns FALSE for an empty/NULL payload, an unknown compression
// method, a compression failure or an exception; TRUE otherwise. `z` is unused here.
BOOL WriteContextData(CONTEXT_OBJECT* ContextObject, PBYTE szBuffer, size_t ulOriginalLength, ZSTD_CCtx* m_Cctx, z_stream* z)
{
    assert(ContextObject);
    // Nothing to send. Checked before the command byte is read: szBuffer[0] does not
    // exist for an empty (or NULL) buffer.
    if (szBuffer == NULL || ulOriginalLength == 0) {
        return FALSE;
    }
    // Log the command sent by the server (first byte = command id).
    int cmd = szBuffer[0];
    if (ulOriginalLength < 100 && cmd != COMMAND_SCREEN_CONTROL && cmd != CMD_HEARTBEAT_ACK &&
        cmd != CMD_DRAW_POINT && cmd != CMD_MOVEWINDOW && cmd != CMD_SET_SIZE) {
        char buf[100] = { 0 };
        if (ulOriginalLength == 1) {
            sprintf_s(buf, "command %d", cmd);
        } else {
            memcpy(buf, szBuffer, ulOriginalLength); // < 100 bytes, stays NUL-terminated
        }
        Mprintf("[COMMAND] Send: %s\r\n", buf);
    }
    try {
        do {
            if (ContextObject->CompressMethod == COMPRESS_UNKNOWN) {
                Mprintf("[ERROR] UNKNOWN compress method \n");
                return FALSE;
            } else if (ContextObject->CompressMethod == COMPRESS_NONE) {
                Buffer tmp(szBuffer, ulOriginalLength);
                szBuffer = tmp.Buf();
                ContextObject->WriteBuffer(szBuffer, ulOriginalLength, ulOriginalLength, cmd);
                break;
            }
            bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD;
            unsigned long ulCompressedLength = usingZstd ?
                                               ZSTD_compressBound(ulOriginalLength) : (unsigned long)((double)ulOriginalLength * 1.001 + 12);
            // Reuse the context's preallocated send buffer instead of allocating per call.
            LPBYTE CompressedBuffer = ContextObject->GetSendCompressBuffer(ulCompressedLength);
            // Work on a copy so the caller's buffer is not modified by Encode().
            Buffer tmp(szBuffer, ulOriginalLength);
            szBuffer = tmp.Buf();
            ContextObject->Encode(szBuffer, ulOriginalLength);
            if (!m_Cctx) ContextObject->Encode(szBuffer, ulOriginalLength, usingZstd);
            size_t iRet = usingZstd ?
                          Mcompress(CompressedBuffer, &ulCompressedLength, (LPBYTE)szBuffer, ulOriginalLength, ContextObject->GetZstdLevel()):
                          compress(CompressedBuffer, &ulCompressedLength, (LPBYTE)szBuffer, ulOriginalLength);
            if (usingZstd ? C_FAILED(iRet) : (S_OK != iRet)) {
                Mprintf("[ERROR] compress failed \n");
                // The send buffer is owned by CONTEXT_OBJECT; do not free it here.
                return FALSE;
            }
            ulCompressedLength = usingZstd ? iRet : ulCompressedLength;
            ContextObject->WriteBuffer(CompressedBuffer, ulCompressedLength, ulOriginalLength, cmd);
        } while (false);
        return TRUE;
    } catch (...) {
        Mprintf("[ERROR] OnClientPreSending catch an error \n");
        return FALSE;
    }
}
// Queues `szBuffer` for sending on `ContextObject`. The data is staged with
// WriteContextData(). Then a zero-byte IOWrite packet is posted to the completion port,
// so a worker runs OnClientPostSending() and starts transmitting the queued bytes.
// Returns TRUE if the send was scheduled. On a posting failure, the connection is
// removed and FALSE is returned.
BOOL IOCPServer::OnClientPreSending(CONTEXT_OBJECT* ContextObject, PBYTE szBuffer, size_t ulOriginalLength)
{
    if (!WriteContextData(ContextObject, szBuffer, ulOriginalLength, ContextObject->Zcctx)) {
        return FALSE;
    }
    OVERLAPPEDPLUS* request = new OVERLAPPEDPLUS(IOWrite);
    const BOOL posted = PostQueuedCompletionStatus(m_hCompletionPort, 0, (ULONG_PTR)ContextObject, &request->m_ol);
    if (posted || GetLastError() == ERROR_IO_PENDING) {
        return TRUE;
    }
    Mprintf("!!! OnClientPreSending 投递消息失败\n");
    RemoveStaleContext(ContextObject);
    SAFE_DELETE(request);
    return FALSE;
}
// Handles an IOWrite completion under the context's SendLock. It drops the
// `ulCompletedLength` bytes already sent from OutCompressedBuffer, then either
// reports completion or issues a WSASend for the remainder.
// Returns TRUE when everything was sent or the next send was queued; FALSE on an
// exception. Returns -2 when the send failed and the connection was removed; this
// thread's I/O reference has then already been released (HandleIO must not release it
// again).
BOOL IOCPServer::OnClientPostSending(CONTEXT_OBJECT* ContextObject,ULONG ulCompletedLength)
{
    CAutoCLock L(ContextObject->SendLock);
    try {
        DWORD ulFlags = MSG_PARTIAL;
        ContextObject->OutCompressedBuffer.RemoveCompletedBuffer(ulCompletedLength); // discard bytes already sent
        if (ContextObject->OutCompressedBuffer.GetBufferLength() == 0) {
            ContextObject->OutCompressedBuffer.ClearBuffer();
            return TRUE; // all queued data has been sent
        } else {
            // Data remains: post another send for the rest.
            OVERLAPPEDPLUS * OverlappedPlus = new OVERLAPPEDPLUS(IOWrite);
            ContextObject->wsaOutBuffer.buf = (char*)ContextObject->OutCompressedBuffer.GetBuffer(0);
            ContextObject->wsaOutBuffer.len = ContextObject->OutCompressedBuffer.GetBufferLength();
            int iOk = WSASend(ContextObject->sClientSocket, &ContextObject->wsaOutBuffer,1,
                              NULL, ulFlags,&OverlappedPlus->m_ol, NULL);
            if ( iOk == SOCKET_ERROR && WSAGetLastError() != WSA_IO_PENDING ) {
                // Capture the error now: RemoveStaleContext() closes the socket, which may
                // overwrite the thread's last-error value before it would be logged.
                int err = WSAGetLastError();
                // Release our reference BEFORE RemoveStaleContext (#215).
                ContextObject->IoRefCount.fetch_sub(1);
                if (RemoveStaleContext(ContextObject))
                    Mprintf("!!! OnClientPostSending 投递消息失败: %d\n", err);
                SAFE_DELETE(OverlappedPlus);
                return -2; // tells HandleIO the reference was already released
            }
            return TRUE;
        }
    } catch(...) {
        Mprintf("[ERROR] OnClientPostSending catch an error \n");
    }
    return FALSE;
}
// Listen thread entry point (lParam is the owning IOCPServer). Waits for FD_ACCEPT on
// the listening socket and calls OnAccept() for each pending connection. It stops when
// m_bTimeToKill is set, the kill event fires, or a wait/socket error occurs. It clears
// m_hListenThread on exit; the destructor polls that value.
DWORD IOCPServer::ListenThreadProc(LPVOID lParam)
{
    IOCPServer* This = (IOCPServer*)(lParam);
    WSANETWORKEVENTS NetWorkEvents;
    while (!This->m_bTimeToKill) {
        // Wait on the kill event and the listen event TOGETHER. The old code first
        // blocked 100 ms on the kill event alone on every iteration, which capped accepts
        // at roughly 10 connections per second. The kill event is index 0, so it wins
        // when both are signaled.
        WSAEVENT hEvents[2] = { This->m_hKillEvent, This->m_hListenEvent };
        DWORD dwRet = WSAWaitForMultipleEvents(2, hEvents, FALSE, 100, FALSE);
        if (dwRet == WSA_WAIT_TIMEOUT)
            continue;
        if (dwRet != WSA_WAIT_EVENT_0 + 1)
            break; // kill event signaled, or the wait failed (e.g. a handle was closed)
        // Convert the signaled event into network events (this also resets the event).
        int iRet = WSAEnumNetworkEvents(This->m_sListenSocket,
                                        This->m_hListenEvent,
                                        &NetWorkEvents);
        if (iRet == SOCKET_ERROR)
            break;
        if (NetWorkEvents.lNetworkEvents & FD_ACCEPT) {
            if (NetWorkEvents.iErrorCode[FD_ACCEPT_BIT] == 0) {
                This->OnAccept();
            } else {
                break;
            }
        }
    }
    This->m_hListenThread = NULL;
    return 0;
}
void IOCPServer::OnAccept()
{
SOCKADDR_IN ClientAddr = {0};
SOCKET sClientSocket = INVALID_SOCKET;
int iLen = sizeof(SOCKADDR_IN);
sClientSocket = accept(m_sListenSocket,
(sockaddr*)&ClientAddr,
&iLen); //通过我们的监听套接字来生成一个与之信号通信的套接字
if (sClientSocket == SOCKET_ERROR) {
return;
}
//我们在这里为每一个到达的信号维护了一个与之关联的数据结构这里简称为用户的上下背景文
PCONTEXT_OBJECT ContextObject = AllocateContext(sClientSocket); // Context
if (ContextObject == NULL) {
closesocket(sClientSocket);
sClientSocket = INVALID_SOCKET;
return;
}
ContextObject->sClientSocket = sClientSocket;
// 尝试解析 Proxy Protocol v2 头,获取真实客户端 IP
// 如果解析成功,更新 PeerName否则保持 getpeername 的结果
std::string realIP;
if (ParseProxyProtocolV2(sClientSocket, realIP)) {
ContextObject->SetPeerName (realIP);
}
// IP 黑名单和封禁检查
std::string clientIP = ContextObject->GetPeerName().empty() ?
inet_ntoa(ClientAddr.sin_addr) : ContextObject->GetPeerName();
// 先检查黑名单
if (IsIPBlacklisted(clientIP)) {
delete ContextObject;
closesocket(sClientSocket);
return;
}
// 再检查临时封禁
if (IsIPBanned(clientIP)) {
delete ContextObject;
closesocket(sClientSocket);
return;
}
RecordConnection(clientIP);
ContextObject->wsaInBuf.buf = (char*)ContextObject->szBuffer;
ContextObject->wsaInBuf.len = sizeof(ContextObject->szBuffer);
HANDLE Handle = CreateIoCompletionPort((HANDLE)sClientSocket, m_hCompletionPort, (ULONG_PTR)ContextObject, 0);
if (Handle!=m_hCompletionPort) {
delete ContextObject;
ContextObject = NULL;
if (sClientSocket!=INVALID_SOCKET) {
closesocket(sClientSocket);
sClientSocket = INVALID_SOCKET;
}
return;
}
//设置套接字的选项卡 Set KeepAlive 开启保活机制 SO_KEEPALIVE
//保持连接检测对方主机是否崩溃如果2小时内在此套接口的任一方向都没
//有数据交换TCP就自动给对方 发一个保持存活
m_ulKeepLiveTime = 1000 * 60 * 3;
const BOOL bKeepAlive = TRUE;
setsockopt(ContextObject->sClientSocket,SOL_SOCKET,SO_KEEPALIVE,(char*)&bKeepAlive,sizeof(bKeepAlive));
//设置超时详细信息
tcp_keepalive KeepAlive;
KeepAlive.onoff = 1; // 启用保活
KeepAlive.keepalivetime = m_ulKeepLiveTime; //超过3分钟没有数据就发送探测包
KeepAlive.keepaliveinterval = 1000 * 10; //重试间隔为10秒 Resend if No-Reply
WSAIoctl(ContextObject->sClientSocket, SIO_KEEPALIVE_VALS,&KeepAlive,sizeof(KeepAlive),
NULL,0,(unsigned long *)&bKeepAlive,0,NULL);
//在做服务器时如果发生客户端网线或断电等非正常断开的现象如果服务器没有设置SO_KEEPALIVE选项
//则会一直不关闭SOCKET。因为上的的设置是默认两个小时时间太长了所以我们就修正这个值
EnterCriticalSection(&m_cs);
m_ContextConnectionList.AddTail(ContextObject); //插入到我们的内存列表中
LeaveCriticalSection(&m_cs);
OVERLAPPEDPLUS *OverlappedPlus = new OVERLAPPEDPLUS(IOInitialize); //注意这里的重叠IO请求是 用户请求上线
BOOL bOk = PostQueuedCompletionStatus(m_hCompletionPort, 0, (ULONG_PTR)ContextObject, &OverlappedPlus->m_ol); // 工作线程
//因为我们接受到了一个用户上线的请求那么我们就将该请求发送给我们的完成端口 让我们的工作线程处理它
if ( (!bOk && GetLastError() != ERROR_IO_PENDING)) { //如果投递失败
int a = GetLastError();
Mprintf("!!! OnAccept 投递消息失败\n");
RemoveStaleContext(ContextObject);
SAFE_DELETE(OverlappedPlus);
return;
}
PostRecv(ContextObject);
}
// Posts an overlapped receive into ContextObject->wsaInBuf. A worker thread handles
// the completion (IORead, dispatched to OnClientReceiving). The request is recorded in
// ContextObject->olps. If posting fails, the connection is removed.
VOID IOCPServer::PostRecv(CONTEXT_OBJECT* ContextObject)
{
    OVERLAPPEDPLUS * OverlappedPlus = new OVERLAPPEDPLUS(IORead);
    ContextObject->olps = OverlappedPlus;
    ULONG ulFlags = MSG_PARTIAL;
    // lpNumberOfBytesRecvd is NULL: for overlapped calls the byte count arrives with the
    // completion packet, and the WSARecv documentation recommends NULL here.
    int iOk = WSARecv(ContextObject->sClientSocket, &ContextObject->wsaInBuf,
                      1, NULL, &ulFlags, &OverlappedPlus->m_ol, NULL);
    if (iOk == SOCKET_ERROR && WSAGetLastError() != WSA_IO_PENDING) {
        Mprintf("!!! PostRecv 投递消息失败\n");
        // Detach the request before the context is recycled, so olps never points at
        // freed memory (mirrors the failure path in WorkThreadProc).
        ContextObject->olps = NULL;
        RemoveStaleContext(ContextObject);
        SAFE_DELETE(OverlappedPlus);
    }
}
// Returns an initialized context for socket `s`: a pooled one when available, else a
// new one. Returns NULL when the live-connection limit (m_ulMaxConnections) has been
// reached; the refusal is logged and reported to the main window at most once every
// 15 seconds.
PCONTEXT_OBJECT IOCPServer::AllocateContext(SOCKET s)
{
    CLock cs(m_cs);
    if (m_ContextConnectionList.GetCount() < m_ulMaxConnections) {
        PCONTEXT_OBJECT ctx = NULL;
        if (m_ContextFreePoolList.IsEmpty()) {
            ctx = new CONTEXT_OBJECT;
        } else {
            ctx = m_ContextFreePoolList.RemoveHead();
        }
        if (ctx != NULL) {
            ctx->InitMember(s, this);
        }
        return ctx;
    }
    static uint64_t notifyTime = 0;
    auto now = time(0);
    if (now - notifyTime > 15) {
        notifyTime = now;
        Mprintf("!!! AllocateContext: 达到最大连接数 %lu拒绝新连接\n", m_ulMaxConnections);
        if (m_hMainWnd) {
            char tip[256];
            sprintf_s(tip, _TRF("达到最大连接数限制: %lu, 请释放连接"), m_ulMaxConnections);
            PostMessageA(m_hMainWnd, WM_SHOWNOTIFY, (WPARAM)new CharMsg(_TR("达到最大连接数")),
                         (LPARAM)new CharMsg(tip));
        }
    }
    return NULL;
}
// Tears down a live connection and recycles its context. The steps are: mark it
// removed, notify m_OfflineProc, cancel pending I/O, close the socket, wait (max ~5 s)
// for workers still inside HandleIO, then move it to the free pool.
// Returns TRUE if this call performed the removal. Returns FALSE if the context is not
// in the connection list, or if another thread is already removing it.
BOOL IOCPServer::RemoveStaleContext(CONTEXT_OBJECT* ContextObject)
{
    EnterCriticalSection(&m_cs);
    auto find = m_ContextConnectionList.Find(ContextObject);
    LeaveCriticalSection(&m_cs);
    if (!find) {
        return FALSE;
    }
    // Claim the removal atomically (#215). The lookup above is not atomic with the
    // teardown below, so two threads (e.g. a failed send and a failed receive) can both
    // find the context. Only the thread that flips IsRemoved from false to true continues.
    // Otherwise m_OfflineProc would run twice and the socket handle could be closed twice
    // (possibly after the value was reused by a new connection). Setting IsRemoved also
    // stops new I/O processing in HandleIO.
    if (ContextObject->IsRemoved.exchange(true)) {
        return FALSE;
    }
    m_OfflineProc(ContextObject);
    CancelIo((HANDLE)ContextObject->sClientSocket); // cancel pending async I/O (e.g. from PostRecv)
    closesocket(ContextObject->sClientSocket);
    ContextObject->sClientSocket = INVALID_SOCKET;
    // Wait for every worker currently processing this object to finish (#215).
    // (An earlier HasOverlappedIoCompleted((LPOVERLAPPED)ContextObject) check was wrong:
    // CONTEXT_OBJECT is not an OVERLAPPED, since its first field is the vtable pointer.
    // The check therefore returned TRUE immediately, causing race-condition crashes.)
    int waitCount = 0;
    while (ContextObject->IoRefCount.load() > 0) {
        Sleep(1);
        if (++waitCount > 5000) { // ~5 s safety timeout
            Mprintf("!!! RemoveStaleContext: IoRefCount wait timeout (ref=%d)\n",
                    ContextObject->IoRefCount.load());
            break;
        }
    }
    MoveContextToFreePoolList(ContextObject); // recycle into the free pool
    return TRUE;
}
// Moves a context from the live-connection list to the free pool, under m_cs. Its
// input/output buffers are cleared and its receive area is zeroed. Does nothing if the
// context is not in the connection list.
VOID IOCPServer::MoveContextToFreePoolList(CONTEXT_OBJECT* ContextObject)
{
    CLock cs(m_cs);
    POSITION Pos = m_ContextConnectionList.Find(ContextObject);
    if (Pos) {
        ContextObject->InCompressedBuffer.ClearBuffer();
        ContextObject->InDeCompressedBuffer.ClearBuffer();
        ContextObject->OutCompressedBuffer.ClearBuffer();
        // Use the real array size instead of a hard-coded 8192, so this stays correct if
        // the buffer declaration changes (OnAccept sizes wsaInBuf the same way).
        memset(ContextObject->szBuffer, 0, sizeof(ContextObject->szBuffer));
        m_ContextFreePoolList.AddTail(ContextObject);  // recycle into the pool
        m_ContextConnectionList.RemoveAt(Pos);         // no longer a live connection
    }
}
// Sets the live-connection cap enforced by AllocateContext(). It takes m_cs, the same
// lock under which AllocateContext() compares the list size against the cap.
void IOCPServer::UpdateMaxConnection(int maxConn)
{
    CLock guard(m_cs);
    m_ulMaxConnections = maxConn;
}