Files
SimpleRemoter/client/IOCPClient.cpp

872 lines
31 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// IOCPClient.cpp: implementation of the IOCPClient class.
//
//////////////////////////////////////////////////////////////////////
#ifdef _WIN32
#include "stdafx.h"
#include <WS2tcpip.h>
#include <memory>
#else
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netinet/in.h> // For struct sockaddr_in
#include <unistd.h> // For close()
#include <cerrno> // For errno
#include <cstring> // For memset()
#include <memory>
// POSIX stand-in for the Winsock error query. Returns the calling thread's
// errno, so "GetLastError=" diagnostics logged after a failed select()/recv()
// show the real cause instead of a constant -1.
inline int WSAGetLastError()
{
    return errno;
}
#define USING_COMPRESS 1
// 注意Linux 不启用 USING_CTX因为 libzstd.a (1.5.6) 与 zstd.h (1.5.7) 版本不匹配
// 可能导致 ZSTD_CCtx 结构体 ABI 不兼容,引发堆损坏
// 使用无状态 ZSTD_compress/ZSTD_decompress 更安全
#endif
#include "IOCPClient.h"
#include <assert.h>
#include <string>
#if USING_ZLIB
#include "zlib/zlib.h"
#define Z_FAILED(p) (Z_OK != (p))
#define Z_SUCCESS(p) (!Z_FAILED(p))
#else
#include "common/zstd_wrapper.h"
#ifdef _WIN64
#pragma comment(lib, "zstd/zstd_x64.lib")
#else
#pragma comment(lib, "zstd/zstd.lib")
#endif
#define Z_FAILED(p) ZSTD_isError(p)
#define Z_SUCCESS(p) (!Z_FAILED(p))
#define ZSTD_CLEVEL ZSTD_CLEVEL_DEFAULT
#if USING_CTX
#define compress(dest, destLen, source, sourceLen) zstd_compress_auto(m_Cctx, dest, *(destLen), source, sourceLen, 1024*1024)
#define uncompress(dest, destLen, source, sourceLen) ZSTD_decompressDCtx(m_Dctx, dest, *(destLen), source, sourceLen)
#else
#define compress(dest, destLen, source, sourceLen) ZSTD_compress(dest, *(destLen), source, sourceLen, ZSTD_CLEVEL_DEFAULT)
#define uncompress(dest, destLen, source, sourceLen) ZSTD_decompress(dest, *(destLen), source, sourceLen)
#endif
#endif
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////
#ifndef _WIN32
// Enables SO_KEEPALIVE on `socket` and tunes the keep-alive probe timers.
// nKeepAliveSec: idle time in seconds before the first probe (default 180).
// Returns FALSE if SO_KEEPALIVE cannot be enabled, or (except on macOS) if the
// idle / interval / count options cannot be set. On macOS only the idle time
// is set, and that call's result is not checked. A TCP_USER_TIMEOUT failure is
// logged but does not fail the call.
BOOL SetKeepAliveOptions(int socket, int nKeepAliveSec = 180)
{
    int keepAliveOn = 1;
    if (setsockopt(socket, SOL_SOCKET, SO_KEEPALIVE, &keepAliveOn, sizeof(keepAliveOn)) < 0) {
        Mprintf("Failed to enable TCP keep-alive\n");
        return FALSE;
    }

#ifdef __APPLE__
    // macOS only offers TCP_KEEPALIVE, its name for the idle time (TCP_KEEPIDLE).
    setsockopt(socket, IPPROTO_TCP, TCP_KEEPALIVE, &nKeepAliveSec, sizeof(nKeepAliveSec));
#else
    int probeIntervalSec = 5;   // seconds between unanswered probes
    int probeLimit = 5;         // unanswered probes before the connection is declared dead

    if (setsockopt(socket, IPPROTO_TCP, TCP_KEEPIDLE, &nKeepAliveSec, sizeof(nKeepAliveSec)) < 0) {
        Mprintf("Failed to set TCP_KEEPIDLE\n");
        return FALSE;
    }
    if (setsockopt(socket, IPPROTO_TCP, TCP_KEEPINTVL, &probeIntervalSec, sizeof(probeIntervalSec)) < 0) {
        Mprintf("Failed to set TCP_KEEPINTVL\n");
        return FALSE;
    }
    if (setsockopt(socket, IPPROTO_TCP, TCP_KEEPCNT, &probeLimit, sizeof(probeLimit)) < 0) {
        Mprintf("Failed to set TCP_KEEPCNT\n");
        return FALSE;
    }
#endif

#ifdef TCP_USER_TIMEOUT
    // TCP_USER_TIMEOUT (RFC 5482): if sent data stays un-ACKed by the peer for
    // longer than this, the kernel marks the socket ETIMEDOUT and the next
    // send/recv fails immediately.
    //
    // SO_KEEPALIVE alone is not enough. Keep-alive probes only run while the
    // connection is completely idle, and the application's 30 s heartbeat keeps
    // TCP from ever going idle. After VM suspend/resume, laptop lid close/wake,
    // or NAT entry expiry, the peer may be long gone while send() still
    // "succeeds" by queueing bytes into SNDBUF. The result is a half-dead
    // ESTABLISHED connection with a growing Send-Q that the application cannot
    // see. By default it is only reported once tcp_retries2 runs out (~15 min).
    //
    // 30 s is >= the default heartbeat interval (5-30 s) and < the server's
    // CheckHeartbeat timeout (>= 60 s). Linux 2.6.37+ supports it. macOS and
    // older kernels lack the macro and skip this; there the application-level
    // ACK watchdog (linux/main.cpp heartbeat loop) is the fallback.
    unsigned int unackedTimeoutMs = 30000;
    if (setsockopt(socket, IPPROTO_TCP, TCP_USER_TIMEOUT,
                   &unackedTimeoutMs, sizeof(unackedTimeoutMs)) < 0) {
        Mprintf("Failed to set TCP_USER_TIMEOUT\n");
        // Not fatal: keep-alive is already on and the ACK watchdog still applies.
    }
#endif

    Mprintf("TCP keep-alive settings applied successfully\n");
    return TRUE;
}
#endif
// Registers the manager object and its data callback. The disconnect/reconnect
// callback is kept only for clients created with exit_while_disconnect;
// otherwise it is stored as NULL.
VOID IOCPClient::setManagerCallBack(void* Manager, DataProcessCB dataProcess, OnDisconnectCB reconnect)
{
    m_Manager = Manager;
    m_DataProcess = dataProcess;
    if (m_exit_while_disconnect) {
        m_ReconnectFunc = reconnect;
    } else {
        m_ReconnectFunc = NULL;
    }
}
// 子连接身份校验:发 TOKEN_CONN_AUTH 包后阻塞等服务端响应。
// signMessage 由私有库提供(与 KernelManager.cpp 验证主控签名同款),
// 空 publicKey/privateKey 走内置 HMAC。
extern std::string signMessage(const std::string& privateKey, BYTE* msg, int len);
// Sub-connection identity check. Builds a TOKEN_CONN_AUTH packet carrying
// clientID, the current time, a 16-byte nonce and a signature over
// clientID|timestamp|nonce, then sends it and blocks until
// TryHandleAuthResponse (called from the receive path) records the server's
// reply, or timeoutMs elapses.
// Returns true only when a reply arrived in time with status CONN_AUTH_OK.
// Returns false on send failure, timeout, or any other status.
bool IOCPClient::PerformConnAuth(uint64_t clientID, int timeoutMs)
{
ConnAuthPacket pkt = {};
pkt.token = TOKEN_CONN_AUTH;
pkt.clientID = clientID;
pkt.timestamp = (uint64_t)time(NULL);
// 16-byte nonce built from rand() mixed with clock(). Considered strong enough
// because replay protection relies mainly on the timestamp.
// NOTE(review): rand() is not seeded in this function; confirm srand() is
// called elsewhere if nonce variety matters.
for (int i = 0; i < 16; ++i) {
pkt.nonce[i] = (uint8_t)((rand() ^ (clock() >> i)) & 0xFF);
}
// Signed message layout: clientID (8 bytes) | timestamp (8) | nonce (16).
BYTE sigInput[8 + 8 + 16];
memcpy(sigInput, &pkt.clientID, 8);
memcpy(sigInput + 8, &pkt.timestamp, 8);
memcpy(sigInput + 16, pkt.nonce, 16);
auto sig = signMessage("", sigInput, sizeof(sigInput));
// The packet's signature field holds at most 64 bytes; longer signatures are truncated.
size_t sigLen = sig.size() < 64 ? sig.size() : 64;
memcpy(pkt.signature, sig.data(), sigLen);
// Arm the waiting state BEFORE sending, so a fast reply cannot be missed.
{
std::lock_guard<std::mutex> lk(m_authMtx);
m_authStatus = -1;
m_authPending = true;
}
// Send the packet wrapped in an HttpMask, matching the style of other
// sub-connections' first packets.
HttpMask mask(DEFAULT_HOST, GetClientIPHeader());
int sent = Send2Server((char*)&pkt, sizeof(pkt), &mask);
if (sent <= 0) {
std::lock_guard<std::mutex> lk(m_authMtx);
m_authPending = false;
Mprintf("[ConnAuth] 发送失败\n");
return false;
}
// Wait for the reply or the timeout. The predicate guards against spurious wakeups.
std::unique_lock<std::mutex> lk(m_authMtx);
bool got = m_authCv.wait_for(lk, std::chrono::milliseconds(timeoutMs),
[this]{ return !m_authPending; });
int status = m_authStatus;
// Clear the flag so a late reply is no longer consumed by TryHandleAuthResponse.
m_authPending = false;
if (!got) {
Mprintf("[ConnAuth] 等待响应超时 (%d ms),判定失败\n", timeoutMs);
return false;
}
bool ok = (status == CONN_AUTH_OK);
Mprintf("[ConnAuth] %s (status=%d)\n", ok ? "通过" : "失败", status);
return ok;
}
// Receive-path hook for PerformConnAuth. Consumes `buf` only when it is a
// TOKEN_CONN_AUTH reply of at least sizeof(ConnAuthAck) bytes AND an
// authentication is currently pending. In that case it records the status,
// clears the pending flag and wakes the waiter.
// Returns true iff the packet was consumed; false means the caller should hand
// it on to the manager.
bool IOCPClient::TryHandleAuthResponse(PBYTE buf, ULONG len)
{
    const bool looksLikeAck = buf != nullptr
                              && len >= sizeof(ConnAuthAck)
                              && buf[0] == TOKEN_CONN_AUTH;
    if (!looksLikeAck)
        return false;

    std::unique_lock<std::mutex> guard(m_authMtx);
    if (!m_authPending)
        return false; // nobody is waiting: let the manager see it (should not normally happen)
    m_authStatus = reinterpret_cast<const ConnAuthAck*>(buf)->status;
    m_authPending = false;
    guard.unlock();

    m_authCv.notify_all();
    return true;
}
// Creates a client that is not yet connected.
//   bExit                 - external exit state, kept by reference in g_bExit.
//   exit_while_disconnect - whether the worker thread stops on disconnect, and
//                           whether setManagerCallBack keeps the reconnect callback.
//   mask                  - non-zero: wrap traffic in HttpMask; zero: plain PkgMask.
//   conn                  - connection settings. The pointer is kept, and
//                           sub-connection auth later reads conn->clientID.
//   pubIP                 - local public IP string.
//   main                  - opaque pointer stored in m_main.
IOCPClient::IOCPClient(const State&bExit, bool exit_while_disconnect, int mask, CONNECT_ADDRESS* conn,
                       const std::string& pubIP, void* main) : g_bExit(bExit)
{
    // Log the ZSTD version once per process to help diagnose library/header
    // version mismatches. A function-local static initializer runs exactly
    // once, even if several clients are constructed concurrently (C++11
    // thread-safe statics). The previous check-then-set on a plain static bool
    // was a data race.
    static const bool versionLogged = []() {
        unsigned ver = ZSTD_versionNumber();
#if USING_CTX
        Mprintf("[IOCPClient] ZSTD version: %u.%u.%u, USING_CTX=1\n",
                ver / 10000, (ver / 100) % 100, ver % 100);
#else
        Mprintf("[IOCPClient] ZSTD version: %u.%u.%u, USING_CTX=0\n",
                ver / 10000, (ver / 100) % 100, ver % 100);
#endif
        return true;
    }();
    (void)versionLogged;

    m_main = main;
    // Keep the CONNECT_ADDRESS pointer. Sub-connection auth reads
    // m_conn->clientID at each connect, so it sees the latest value the main
    // connection filled in after logging in (same pointer).
    m_conn = conn;
    int encoder = conn ? conn->GetHeaderEncType() : 0;
    m_sLocPublicIP = pubIP;
    m_ServerAddr = {};
    m_nHostPort = 0;
    m_Manager = NULL;
    m_masker = mask ? new HttpMask(DEFAULT_HOST) : new PkgMask();
    auto enc = GetHeaderEncoder(HeaderEncType(time(nullptr) % HeaderEncNum));
    m_EncoderType = encoder;
    m_Encoder = encoder ? new HellEncoder(enc, new XOREncoder16()) : new ProtocolEncoder();
#ifdef _WIN32
    WSADATA wsaData;
    WSAStartup(MAKEWORD(2, 2), &wsaData);
#endif
    m_sClientSocket = INVALID_SOCKET;
    m_hWorkThread = NULL;
    m_bWorkThread = S_STOP;
    m_bIsRunning = TRUE;
    m_bConnected = FALSE;
    m_exit_while_disconnect = exit_while_disconnect;
    m_ReconnectFunc = NULL;
#if USING_CTX
    m_Cctx = ZSTD_createCCtx();
    m_Dctx = ZSTD_createDCtx();
    // Single-threaded compression by default; SetMultiThreadCompress can raise it.
    // (Retrying the identical call after a failure, as before, could not help.)
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_nbWorkers, 0);
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_compressionLevel, ZSTD_CLEVEL);
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_hashLog, 15);
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_chainLog, 16);
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_searchLog, 1);
    ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_windowLog, 19);
#endif
}
// Requests threadNum zstd compression workers. When threadNum <= 1, or the
// library rejects the value, compression falls back to single-threaded
// (nbWorkers = 0). Does nothing unless built with USING_CTX.
void IOCPClient::SetMultiThreadCompress(int threadNum)
{
#if USING_CTX
    const bool workersAccepted = threadNum > 1 &&
        !Z_FAILED(ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_nbWorkers, threadNum));
    if (!workersAccepted) {
        ZSTD_CCtx_setParameter(m_Cctx, ZSTD_c_nbWorkers, 0);
    }
#endif
}
// Stops the worker thread, closes the connection and releases all resources.
IOCPClient::~IOCPClient()
{
    // Ask the worker loop to stop, and close the socket so its select()/recv() returns.
    m_bIsRunning = FALSE;
    Disconnect();

    // Wait for the worker to leave its loop BEFORE releasing anything it may
    // still touch. On exit it closes m_hWorkThread itself, so closing the
    // handle here while it runs races into a double close. Winsock must also
    // stay initialised while it can still call select()/recv().
    while (S_RUN == m_bWorkThread)
        Sleep(10);
    m_bWorkThread = S_END;

    // The worker is no longer running, so closing the handle cannot race with it.
    if (m_hWorkThread != NULL) {
        SAFE_CLOSE_HANDLE(m_hWorkThread);
        m_hWorkThread = NULL;
    }
#ifdef _WIN32
    WSACleanup();
#endif
#if USING_CTX
    ZSTD_freeCCtx(m_Cctx);
    ZSTD_freeDCtx(m_Dctx);
#endif
    m_masker->Destroy();
    SAFE_DELETE(m_Encoder);
}
// 从域名获取IP地址
// Resolves hostName to a dotted-quad IPv4 string.
// Returns hostName itself on Windows when it is already a numeric IPv4
// address, and "" on any failure, including NULL or empty input.
std::string GetIPAddress(const char *hostName)
{
    // Reject NULL/empty input up front. inet_pton/getaddrinfo must not be given
    // NULL, and on Windows gethostbyname("") resolves to the LOCAL machine
    // instead of failing.
    if (hostName == NULL || hostName[0] == '\0')
        return "";
#ifdef _WIN32
    struct sockaddr_in sa = { 0 };
    if (inet_pton(AF_INET, hostName, &(sa.sin_addr)) == 1) {
        return hostName;
    }
    struct hostent *host = gethostbyname(hostName);
#ifdef _DEBUG
    if (host == NULL) return "";
    Mprintf("此域名的IP类型为: %s.\n", host->h_addrtype == AF_INET ? "IPV4" : "IPV6");
    for (int i = 0; host->h_addr_list[i]; ++i)
        Mprintf("获取的第%d个IP: %s\n", i+1, inet_ntoa(*(struct in_addr*)host->h_addr_list[i]));
#endif
    if (host == NULL || host->h_addr_list == NULL)
        return "";
    return host->h_addr_list[0] ? inet_ntoa(*(struct in_addr*)host->h_addr_list[0]) : "";
#else
    struct addrinfo hints, * res = nullptr;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;       // IPv4
    hints.ai_socktype = SOCK_STREAM; // TCP socket
    int status = getaddrinfo(hostName, nullptr, &hints, &res);
    if (status != 0) {
        Mprintf("getaddrinfo failed: %s\n", gai_strerror(status));
        return "";
    }
    struct sockaddr_in* addr = reinterpret_cast<struct sockaddr_in*>(res->ai_addr);
    char ip[INET_ADDRSTRLEN] = { 0 };
    // Only trust `ip` if the conversion succeeded; otherwise it holds no address.
    const char* text = inet_ntop(AF_INET, &(addr->sin_addr), ip, sizeof(ip));
    freeaddrinfo(res); // `ip` is a local copy, so the list can be freed now
    if (text == nullptr)
        return "";
    Mprintf("IP Address: %s \n", ip);
    return ip;
#endif
}
#ifdef _WIN32
// Connects `sock` to `addr`, giving up after timeout_sec seconds.
// The socket is switched to non-blocking mode for the attempt and is back in
// blocking mode when TRUE is returned. On FALSE it is left non-blocking, and
// the caller is expected to close it.
BOOL ConnectWithTimeout(SOCKET sock, SOCKADDR *addr, int timeout_sec=5)
{
    u_long mode = 1;
    ioctlsocket(sock, FIONBIO, &mode);

    // Start the connection without blocking.
    int ret = connect(sock, addr, sizeof(*addr));
    if (ret == SOCKET_ERROR) {
        int err = WSAGetLastError();
        if (err != WSAEWOULDBLOCK && err != WSAEINPROGRESS) {
            return FALSE;
        }
    }

    // Winsock signals a SUCCESSFUL non-blocking connect via writefds, but a
    // FAILED one (e.g. connection refused) only via exceptfds. Watching both
    // lets a refused connection fail at once instead of waiting out the timeout.
    fd_set writefds, exceptfds;
    FD_ZERO(&writefds);
    FD_ZERO(&exceptfds);
    FD_SET(sock, &writefds);
    FD_SET(sock, &exceptfds);
    timeval tv;
    tv.tv_sec = timeout_sec;
    tv.tv_usec = 0;
    ret = select(0, NULL, &writefds, &exceptfds, &tv);
    if (ret <= 0 || FD_ISSET(sock, &exceptfds) || !FD_ISSET(sock, &writefds)) {
        return FALSE; // timeout, select error, or connect failure
    }

    // Confirm the connection really succeeded.
    int error = 0;
    int len = sizeof(error);
    if (getsockopt(sock, SOL_SOCKET, SO_ERROR, (char*)&error, &len) == SOCKET_ERROR || error != 0) {
        return FALSE;
    }

    // Back to blocking mode for normal operation.
    mode = 0;
    ioctlsocket(sock, FIONBIO, &mode);
    return TRUE;
}
#endif
// Connects to the server. If szServerIP/uPort are given, the server address is
// updated first. The function then applies socket options, makes sure a
// receive worker thread is running and, when enabled, performs the
// sub-connection authentication handshake.
// Returns TRUE when the connection is up, and authenticated if required.
BOOL IOCPClient::ConnectServer(const char* szServerIP, unsigned short uPort)
{
    if (szServerIP != NULL && uPort != 0) {
        SetServerAddress(szServerIP, uPort);
    }
    m_sCurIP = m_Domain.SelectIP();
    m_masker->SetServer(m_sCurIP.c_str());
    unsigned short port = m_nHostPort;
    m_sClientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (m_sClientSocket == INVALID_SOCKET) { // socket() reports failure with INVALID_SOCKET
        return FALSE;
    }
#ifdef _WIN32
    m_ServerAddr.sin_family = AF_INET;
    m_ServerAddr.sin_port = htons(port);
    m_ServerAddr.sin_addr.S_un.S_addr = inet_addr(m_sCurIP.c_str());
    if (!ConnectWithTimeout(m_sClientSocket, (SOCKADDR *)&m_ServerAddr)) {
        closesocket(m_sClientSocket);
        m_sClientSocket = INVALID_SOCKET;
        return FALSE;
    }
#else
    m_ServerAddr.sin_family = AF_INET;
    m_ServerAddr.sin_port = htons(port);
    // inet_pton expects a numeric address; m_Domain.SelectIP() is assumed to
    // yield one.
    if (inet_pton(AF_INET, m_sCurIP.c_str(), &m_ServerAddr.sin_addr) <= 0) {
        Mprintf("Invalid address or address not supported\n");
        close(m_sClientSocket); // do not leak the socket created above
        m_sClientSocket = INVALID_SOCKET;
        return false;
    }
    if (connect(m_sClientSocket, (struct sockaddr*)&m_ServerAddr, sizeof(m_ServerAddr)) == -1) {
        Mprintf("Connection failed\n");
        close(m_sClientSocket);
        m_sClientSocket = INVALID_SOCKET;
        return false;
    }
#endif
    const int chOpt = 1; // True
    // Disable Nagle (TCP_NODELAY) to cut latency for small packets.
    int nodelay = 1;
    setsockopt(m_sClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, sizeof(nodelay));
    // Enlarge the send buffer to 256 KB.
    int sendBufSize = 256 * 1024;
    setsockopt(m_sClientSocket, SOL_SOCKET, SO_SNDBUF, (char*)&sendBufSize, sizeof(sendBufSize));
    // Enable keep-alive so dead connections are eventually detected.
    if (setsockopt(m_sClientSocket, SOL_SOCKET, SO_KEEPALIVE,
                   (char *)&chOpt, sizeof(chOpt)) == 0) {
#ifdef _WIN32
        tcp_keepalive klive;
        klive.onoff = 1;                      // enable keep-alive
        klive.keepalivetime = 1000 * 60 * 3;  // 3 minutes idle before probing
        klive.keepaliveinterval = 1000 * 5;   // 5 s between unanswered probes
        // WSAIoctl writes the returned byte count through this pointer. It used
        // to point at the const chOpt, which is undefined behaviour.
        DWORD bytesReturned = 0;
        WSAIoctl(m_sClientSocket, SIO_KEEPALIVE_VALS, &klive, sizeof(tcp_keepalive),
                 NULL, 0, &bytesReturned, NULL, NULL);
#else
        SetKeepAliveOptions(m_sClientSocket);
#endif
    }
    m_bConnected = TRUE;
    Mprintf("连接服务端成功: %s:%d.\n", m_sCurIP.c_str(), (int)port);
#ifdef _WIN32
    if (m_hWorkThread == NULL) {
        m_bIsRunning = TRUE;
        m_hWorkThread = (HANDLE)__CreateThread(NULL, 0, WorkThreadProc, (LPVOID)this, 0, NULL);
        m_bWorkThread = m_hWorkThread ? S_RUN : S_STOP;
        m_bIsRunning = m_hWorkThread ? TRUE : FALSE;
    }
#else
    // m_hWorkThread is never assigned on POSIX, so it cannot tell whether a
    // worker is alive. Use m_bWorkThread instead; otherwise every reconnect
    // would start one more thread reading the same socket.
    if (m_bWorkThread != S_RUN) {
        // Publish the running state BEFORE the thread starts. A worker that
        // sees m_bIsRunning == FALSE exits at once, and one that exits quickly
        // must not have its S_STOP overwritten by a late S_RUN.
        m_bIsRunning = TRUE;
        m_bWorkThread = S_RUN;
        pthread_t id = 0;
        int ret = pthread_create(&id, nullptr, (void* (*)(void*))IOCPClient::WorkThreadProc, this);
        if (ret == 0) {
            pthread_detach(id); // never joined: let its resources be reclaimed on exit
        } else {
            m_bWorkThread = S_STOP;
            m_bIsRunning = FALSE;
        }
    }
#endif
    // Sub-connection identity check (opt-in via EnableSubConnAuth):
    // - The worker thread is already running, so it can receive the ack and
    //   wake the waiter through TryHandleAuthResponse.
    // - clientID comes from EnableSubConnAuth when given explicitly
    //   (Linux/macOS clients). Otherwise it is read from m_conn (Windows clients).
    // - On failure, disconnect and return FALSE so the caller can retry or give up.
    if (m_subConnAuthEnabled) {
        uint64_t cid = m_subConnAuthClientID;
        if (cid == 0 && m_conn) cid = m_conn->clientID;
        if (cid == 0) {
            Mprintf("[ConnAuth] 跳过校验clientID 尚未就绪(主连接还没拿到 ID\n");
            // No ID yet: don't send blindly. Treat this attempt as failed and
            // retry on the next reconnect.
            Disconnect();
            return FALSE;
        }
        if (!PerformConnAuth(cid, CONN_AUTH_CLIENT_WAIT_MS)) {
            Mprintf("[ConnAuth] 校验失败,断开连接\n");
            Disconnect();
            return FALSE;
        }
    }
    return TRUE;
}
// Receive worker thread. While the client is running, it waits (2 s select
// timeout) for data on the connected socket and feeds it to ProcessRecvData.
// It exits when the client stops running, when ProcessRecvData returns false,
// or on a select() error if exit_while_disconnect is set.
// lParam is the owning IOCPClient*.
DWORD WINAPI IOCPClient::WorkThreadProc(LPVOID lParam)
{
    IOCPClient* This = (IOCPClient*)lParam;
    std::unique_ptr<char[]> recvBuffer(new char[MAX_RECV_BUFFER]);
    char* szBuffer = recvBuffer.get();
    fd_set fd;
    struct timeval tm;
    CBuffer compressedBuffer; // reassembles packets split across recv() calls
    while (This->IsRunning()) {
        if (!This->IsConnected()) {
            Sleep(50);
            continue;
        }
        FD_ZERO(&fd);
        FD_SET(This->m_sClientSocket, &fd);
        // Linux select() modifies the timeval, so reset it on every iteration.
        tm.tv_sec = 2;
        tm.tv_usec = 0;
#ifdef _WIN32
        int iRet = select(NULL, &fd, NULL, NULL, &tm);
#else
        int iRet = select(This->m_sClientSocket + 1, &fd, NULL, NULL, &tm);
#endif
        if (iRet == 0) {
            Sleep(50);
        } else if (iRet < 0) {
            Mprintf("[select] return %d, GetLastError= %d. \n", iRet, WSAGetLastError());
            This->Disconnect();
            compressedBuffer.ClearBuffer();
            if (This->m_exit_while_disconnect)
                break;
        } else if (!This->ProcessRecvData(&compressedBuffer, szBuffer, MAX_RECV_BUFFER - 1, 0)) {
            break;
        }
    }
    // Shutdown order matters:
    // - m_bIsRunning first, so a new worker started by ConnectServer (once it
    //   sees the handle/state released) is not told to stop by us afterwards.
    // - m_bWorkThread = S_STOP LAST. The destructor waits for it to leave
    //   S_RUN and then frees *This, so no member may be touched after it.
    This->m_bIsRunning = FALSE;
    SAFE_CLOSE_HANDLE(This->m_hWorkThread);
    This->m_hWorkThread = NULL;
    This->m_bWorkThread = S_STOP;
    return 0xDEAD;
}
// Performs one ReceiveData() of up to `len` bytes into szBuffer and dispatches
// the result to OnServerReceiving.
// On error or EOF it drops the connection, clears the reassembly buffer and
// consults the reconnect callback, if any. Returns false only when that
// callback reports failure, which makes the worker thread exit.
// Assumes szBuffer has room for len + 1 bytes (the data gets NUL-terminated).
bool IOCPClient::ProcessRecvData(CBuffer *m_CompressedBuffer, char *szBuffer, int len, int flag)
{
    const int received = ReceiveData(szBuffer, len, flag);
    if (received > 0) {
        szBuffer[received] = 0;
        OnServerReceiving(m_CompressedBuffer, szBuffer, received);
        return true;
    }

    const int lastError = WSAGetLastError();
    Mprintf("[recv] return %d, GetLastError= %d. \n", received, lastError);
    Disconnect();
    m_CompressedBuffer->ClearBuffer();
    const bool reconnectRefused = m_ReconnectFunc && !m_ReconnectFunc(m_Manager);
    return !reconnectRefused;
}
// 带异常处理的数据处理逻辑:
// 如果 f 执行时 没有触发系统异常(如访问冲突),返回 0
// 如果 f 执行过程中 抛出了异常(比如空指针访问),将被 __except 捕获,返回异常码(如 0xC0000005 表示访问违规)
// Calls f(manager, data, len) when f is set.
// Windows: the call runs under a structured-exception guard. Returns 0 on
// normal completion, or the exception code (e.g. 0xC0000005, access violation)
// if one was raised.
// Other platforms: no SEH is available, so f is called directly and 0 is
// always returned.
int DataProcessWithSEH(DataProcessCB f, void* manager, LPBYTE data, ULONG len)
{
    int exceptionCode = 0;
#ifdef _WIN32
    __try {
        if (f) f(manager, data, len);
    } __except (EXCEPTION_EXECUTE_HANDLER) {
        exceptionCode = GetExceptionCode();
    }
#else
    if (f) f(manager, data, len);
#endif
    return exceptionCode;
}
// Appends ulLength freshly received bytes to the reassembly buffer and
// dispatches every complete packet it now contains.
// Packet layout after any transport mask is stripped:
//   flag[FLAG_LENGTH] | totalLen (ULONG) | originalLen (ULONG) | payload
// totalLen counts the whole packet, including the HDR_LENGTH-byte header
// (presumably FLAG_LENGTH + 2 * sizeof(ULONG); verify against the encoder).
// The payload is encoded and compressed. It is decoded, decompressed, offered
// first to TryHandleAuthResponse, then to the m_DataProcess callback.
// An unknown header or an implausible total length discards the whole buffer.
VOID IOCPClient::OnServerReceiving(CBuffer* m_CompressedBuffer, char* szBuffer, ULONG ulLength)
{
    try {
        assert (ulLength > 0);
        m_CompressedBuffer->WriteBuffer((LPBYTE)szBuffer, ulLength);
        int FLAG_LENGTH = m_Encoder->GetFlagLen();
        int HDR_LENGTH = m_Encoder->GetHeadLen();
        // Sanity bound on packet sizes guards against corrupt length fields.
        // Large DLL/code transfers stay below 50 MB.
        const ULONG MAX_PACKET_SIZE = 50 * 1024 * 1024;
        // Anything not longer than a header cannot be a complete packet.
        while (m_CompressedBuffer->GetBufferLength() > (ULONG)HDR_LENGTH) {
            // Strip any transport mask in front of the packet.
            char* src = (char*)m_CompressedBuffer->GetBuffer();
            ULONG srcSize = m_CompressedBuffer->GetBufferLength();
            PkgMaskType maskType = MaskTypeUnknown;
            ULONG ret = TryUnMask(src, srcSize, maskType);
            m_CompressedBuffer->Skip(ret);
            if (m_CompressedBuffer->GetBufferLength() <= (ULONG)HDR_LENGTH)
                break;

            char szPacketFlag[32] = {0};
            src = (char*)m_CompressedBuffer->GetBuffer();
            CopyMemory(szPacketFlag, src, FLAG_LENGTH);
            HeaderEncType encType = HeaderEncUnknown;
            FlagType flagType = CheckHead(szPacketFlag, encType);
            if (flagType == FLAG_UNKNOWN) {
                // Dump the first bytes to help diagnose protocol mismatches.
                ULONG bufLen = m_CompressedBuffer->GetBufferLength();
                Mprintf("[ERROR] Unknown header! bufLen=%lu, first 16 bytes: ", (unsigned long)bufLen);
                for (int i = 0; i < 16 && i < (int)bufLen; ++i) {
                    Mprintf("%02X ", (unsigned char)src[i]);
                }
                Mprintf("\n");
                m_CompressedBuffer->ClearBuffer();
                break;
            }

            ULONG ulPackTotalLength = 0;
            CopyMemory(&ulPackTotalLength, m_CompressedBuffer->GetBuffer(FLAG_LENGTH), sizeof(ULONG));
            if (ulPackTotalLength <= (ULONG)HDR_LENGTH || ulPackTotalLength > MAX_PACKET_SIZE) {
                Mprintf("[ERROR] Invalid packet length: %lu (HDR=%d)\n", (unsigned long)ulPackTotalLength, HDR_LENGTH);
                m_CompressedBuffer->ClearBuffer();
                break;
            }
            ULONG len = m_CompressedBuffer->GetBufferLength();
            if (len < ulPackTotalLength)
                break; // received data is incomplete

            ULONG ulOriginalLength = 0;
            m_CompressedBuffer->ReadBuffer((PBYTE)szPacketFlag, FLAG_LENGTH);
            m_CompressedBuffer->ReadBuffer((PBYTE)&ulPackTotalLength, sizeof(ULONG));
            m_CompressedBuffer->ReadBuffer((PBYTE)&ulOriginalLength, sizeof(ULONG));
            if (ulOriginalLength == 0 || ulOriginalLength > MAX_PACKET_SIZE) {
                Mprintf("[ERROR] Invalid original length: %lu. Skipping packet.\n", (unsigned long)ulOriginalLength);
                ULONG skipLen = ulPackTotalLength - HDR_LENGTH;
                if (skipLen > 0 && skipLen < len) {
                    m_CompressedBuffer->Skip(skipLen);
                }
                continue;
            }

            // Small packets use stack buffers. Larger ones use heap buffers owned
            // by unique_ptr, so they are freed even if a callback throws into the
            // catch(...) below.
            ULONG ulCompressedLength = ulPackTotalLength - HDR_LENGTH;
            const ULONG bufSize = 512;
            BYTE buf1[bufSize], buf2[bufSize];
            std::unique_ptr<BYTE[]> compressedHeap, decompressedHeap;
            PBYTE CompressedBuffer = buf1;
            if (ulCompressedLength > bufSize) {
                compressedHeap.reset(new BYTE[ulCompressedLength]);
                CompressedBuffer = compressedHeap.get();
            }
            PBYTE DeCompressedBuffer = buf2;
            if (ulOriginalLength > bufSize) {
                decompressedHeap.reset(new BYTE[ulOriginalLength]);
                DeCompressedBuffer = decompressedHeap.get();
            }

            m_CompressedBuffer->ReadBuffer(CompressedBuffer, ulCompressedLength);
            m_Encoder->Decode(CompressedBuffer, ulCompressedLength, (LPBYTE)szPacketFlag);
            size_t iRet = uncompress(DeCompressedBuffer, &ulOriginalLength, CompressedBuffer, ulCompressedLength);
            bool decoded = Z_SUCCESS(iRet);
#if !USING_ZLIB
            // zstd returns the number of bytes it actually produced (zlib updates
            // ulOriginalLength itself). Use it instead of the sender-declared
            // size, so bytes that were never written are not passed on.
            if (decoded)
                ulOriginalLength = (ULONG)iRet;
#endif
            if (decoded && ulOriginalLength > 0) {
                // A TOKEN_CONN_AUTH reply is consumed only while PerformConnAuth
                // is waiting. Everything else goes to the manager callback.
                if (!TryHandleAuthResponse(DeCompressedBuffer, ulOriginalLength)) {
                    int exCode = DataProcessWithSEH(m_DataProcess, m_Manager, DeCompressedBuffer, ulOriginalLength);
                    if (exCode) {
                        Mprintf("[ERROR] DataProcessWithSEH return exception code: [0x%08X]\n", exCode);
                    }
                }
            } else {
                Mprintf("[ERROR] uncompress fail: dstLen %lu, srcLen %lu\n",
                        (unsigned long)ulOriginalLength, (unsigned long)ulCompressedLength);
                // ReadBuffer already consumed this packet; no need to clear the buffer.
            }
        }
    } catch(...) {
        m_CompressedBuffer->ClearBuffer();
        Mprintf("[ERROR] OnServerReceiving catch an error \n");
    }
}
// 向server发送数据压缩操作比较耗时。
// 关闭压缩开关时SendWithSplit比较耗时。
// Compresses szBuffer[0..ulOriginalLength), encodes it, frames it as
//   flag | totalLen (ULONG) | originalLen (ULONG) | compressed payload
// and sends it via SendWithSplit. The first byte of szBuffer is passed on as
// the command id. Returns FALSE if compression or sending fails.
// Compression is comparatively expensive; with compression disabled,
// SendWithSplit dominates the cost.
BOOL IOCPClient::OnServerSending(const char* szBuffer, ULONG ulOriginalLength, PkgMask* mask) //Hello
{
    AUTO_TICK(100, std::to_string(ulOriginalLength));
    assert (ulOriginalLength > 0);
    // Hold the lock for the whole send (compress, frame, split-send) so
    // packets from different threads (e.g. video and audio) never interleave.
    std::lock_guard<std::mutex> lock(m_Locker);
    const int cmd = BYTE(szBuffer[0]);
#if USING_ZLIB
    // zlib worst case: the output may be ~0.1% + 12 bytes larger than the input.
    unsigned long ulCompressedLength = (double)ulOriginalLength * 1.001 + 12;
#else
    unsigned long ulCompressedLength = ZSTD_compressBound(ulOriginalLength);
#endif
    BYTE buf[1024];
    std::unique_ptr<BYTE[]> heapBuf;
    LPBYTE CompressedBuffer = buf;
    if (ulCompressedLength > sizeof(buf)) {
        heapBuf.reset(new BYTE[ulCompressedLength]);
        CompressedBuffer = heapBuf.get();
    }
    // zlib's compress() returns an int status, while zstd returns a size_t
    // (byte count or error code). `auto` keeps the active backend's type
    // instead of narrowing a size_t into an int.
    auto iRet = compress(CompressedBuffer, &ulCompressedLength, (PBYTE)szBuffer, ulOriginalLength);
    if (Z_FAILED(iRet)) {
        Mprintf("[ERROR] compress failed: srcLen %lu, dstLen %lu \n",
                (unsigned long)ulOriginalLength, (unsigned long)ulCompressedLength);
        return FALSE;
    }
#if !USING_ZLIB
    ulCompressedLength = (unsigned long)iRet; // zstd returns the compressed size
#endif
    ULONG ulPackTotalLength = ulCompressedLength + m_Encoder->GetHeadLen();
    CBuffer m_WriteBuffer;
    HeaderFlag H = m_Encoder->GetHead();
    m_Encoder->Encode(CompressedBuffer, ulCompressedLength, (LPBYTE)H.data());
    m_WriteBuffer.WriteBuffer((PBYTE)H.data(), m_Encoder->GetFlagLen());
    m_WriteBuffer.WriteBuffer((PBYTE)&ulPackTotalLength, sizeof(ULONG));
    m_WriteBuffer.WriteBuffer((PBYTE)&ulOriginalLength, sizeof(ULONG));
    m_WriteBuffer.WriteBuffer(CompressedBuffer, ulCompressedLength);
    STOP_TICK;
    // Send in chunks.
    return SendWithSplit((char*)m_WriteBuffer.GetBuffer(), m_WriteBuffer.GetBufferLength(), MAX_SEND_BUFFER, cmd, mask);
}
// 5 2 // 2 2 1
// Masks src and writes the masked bytes to the socket in chunks.
// - The per-call `mask` is used only when it is given and srcSize fits in one
//   chunk; otherwise the client's default m_masker is used.
// - Chunks are at most ulSplitLength bytes, or 256 KB once the masked data
//   reaches 256 KB (SO_SNDBUF is 256 KB, so fewer, larger writes are fine).
// - Each SendTo() is retried up to 15 times while it returns <= 0.
// Returns TRUE only if every masked byte was sent. Returns FALSE on invalid
// arguments or when a send keeps failing.
BOOL IOCPClient::SendWithSplit(const char* src, ULONG srcSize, ULONG ulSplitLength, int cmd, PkgMask* mask)
{
    AUTO_TICK(50, std::to_string(cmd));
    if (src == nullptr || srcSize == 0 || ulSplitLength == 0)
        return FALSE;

    char* szBuffer = nullptr;
    ULONG ulLength = 0;
    if (mask && srcSize <= ulSplitLength)
        mask->SetServer(m_sCurIP)->Mask(szBuffer, ulLength, (char*)src, srcSize, cmd);
    else
        m_masker->Mask(szBuffer, ulLength, (char*)src, srcSize, cmd);
    if (szBuffer != src && srcSize > ulSplitLength) {
        Mprintf("SendWithSplit: %lu bytes large packet may causes issues.\n", (unsigned long)srcSize);
    }

    const int kSendRetry = 15;
    const ULONG LARGE_PACKET_THRESHOLD = 256 * 1024;
    const ULONG LARGE_CHUNK = 256 * 1024;
    const ULONG chunkLimit = ulLength >= LARGE_PACKET_THRESHOLD ? LARGE_CHUNK : ulSplitLength;

    ULONG ulSended = 0;
    bool isFail = false;
    // Walk the buffer chunk by chunk. Within a chunk, keep calling SendTo()
    // until the chunk is fully written, since partial writes are possible.
    while (!isFail && ulSended < ulLength) {
        const ULONG left = ulLength - ulSended;
        const ULONG chunkEnd = ulSended + (left < chunkLimit ? left : chunkLimit);
        while (ulSended < chunkEnd) {
            int iReturn = 0;
            int attempt = 0;
            for (; attempt < kSendRetry; ++attempt) {
                iReturn = SendTo(szBuffer + ulSended, (int)(chunkEnd - ulSended), 0);
                if (iReturn > 0)
                    break;
            }
            if (attempt == kSendRetry) {
                isFail = true;
                break;
            }
            ulSended += (ULONG)iReturn;
        }
    }

    if (szBuffer != src) {
        SAFE_DELETE_ARRAY(szBuffer);
    }
    return (!isFail && ulSended == ulLength) ? TRUE : FALSE;
}
// Closes the client socket, cancelling pending I/O on it first, and marks the
// client as disconnected. Does nothing when no socket is open.
VOID IOCPClient::Disconnect()
{
    const bool hasSocket = (m_sClientSocket != INVALID_SOCKET);
    if (hasSocket) {
        Mprintf("Disconnect with [%s:%d].\n", m_sCurIP.c_str(), m_nHostPort);
        CancelIo((HANDLE)m_sClientSocket);
        closesocket(m_sClientSocket);
        m_sClientSocket = INVALID_SOCKET;
        m_bConnected = FALSE;
    }
}
// Blocks the caller, polling every 200 ms, while the client is running and
// bCondition is non-zero. If bCondition equals FOREVER_RUN it keeps polling
// regardless. bCondition is re-read on every poll. On exit the manager
// callbacks are cleared.
VOID IOCPClient::RunEventLoop(const BOOL &bCondition)
{
    Mprintf("======> RunEventLoop begin\n");
    for (;;) {
        const bool keepGoing = (m_bIsRunning && bCondition) || bCondition == FOREVER_RUN;
        if (!keepGoing)
            break;
        Sleep(200);
    }
    setManagerCallBack(NULL, NULL, NULL);
    Mprintf("======> RunEventLoop end\n");
}
// Default trail checker used by RunEventLoop(TrailCheck): always reports valid.
BOOL is_valid()
{
    const BOOL alwaysValid = TRUE;
    return alwaysValid;
}
// Blocks the caller, polling every 200 ms, while the client is running and
// checker() returns non-zero. checker is not called once the client has
// stopped. A null checker defaults to is_valid; debug builds always use
// is_valid. On exit the manager callbacks are cleared.
VOID IOCPClient::RunEventLoop(TrailCheck checker)
{
    Mprintf("======> RunEventLoop begin\n");
#ifdef _DEBUG
    checker = is_valid; // debug builds ignore the supplied checker
#else
    if (!checker)
        checker = is_valid;
#endif
    for (;;) {
        if (!m_bIsRunning || !checker())
            break;
        Sleep(200);
    }
    setManagerCallBack(NULL, NULL, NULL);
    Mprintf("======> RunEventLoop end\n");
}