Fix: keep Linux/macOS client alive across server restarts; gate all commands on auth-verified state to neutralize unauthorized servers
348 lines
10 KiB
C++
348 lines
10 KiB
C++
// IOCPClient.h: interface for the IOCPClient class.
|
||
//
|
||
//////////////////////////////////////////////////////////////////////
|
||
|
||
#pragma once
|
||
|
||
#ifdef _WIN32
|
||
#include "stdafx.h"
|
||
#include <WinSock2.h>
|
||
#include <MSTcpIP.h>
|
||
#pragma comment(lib,"ws2_32.lib")
|
||
#endif
|
||
|
||
#include "Buffer.h"
|
||
#include "zstd/zstd.h"
|
||
#include "domain_pool.h"
|
||
#include "common/mask.h"
|
||
#include "common/header.h"
|
||
#define NO_AES
|
||
#include "common/encrypt.h"
|
||
#ifdef _WIN32
|
||
#include "SafeThread.h"
|
||
#else
|
||
#ifndef SAFE_DELETE
// Deletes a single heap object (allocated with `new`) and resets the pointer to NULL.
// NOTE: expands to a bare `if` statement (it is not wrapped in do { } while (0));
// keep the trailing semicolon at call sites and do not use it as the body of an
// unbraced if/else.
#define SAFE_DELETE(p) if ((p) != NULL) { delete (p); (p) = NULL; }
#endif

#ifndef SAFE_DELETE_ARRAY
// Array counterpart of SAFE_DELETE, for memory allocated with `new[]`.
#define SAFE_DELETE_ARRAY(p) if ((p) != NULL) { delete[] (p); (p) = NULL; }
#endif
|
||
#include <sys/socket.h>
|
||
#include <netinet/in.h>
|
||
#endif
|
||
#include "IOCPBase.h"

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <map>
#include <mutex>
#include <string>
|
||
|
||
// Receive buffer size in bytes (32 KiB).
// Parenthesized so the macro expands safely inside larger expressions:
// without the parentheses `n / MAX_RECV_BUFFER` would evaluate as `(n / 1024) * 32`.
#define MAX_RECV_BUFFER (1024*32)

// Send chunk size in bytes (128 KiB); enlarged to improve send throughput.
#define MAX_SEND_BUFFER (1024*128)

// Run-state constants (names suggest stopped / running / ended).
enum { S_STOP = 0, S_RUN, S_END };
|
||
|
||
// Packet callback: receives the opaque `userData` pointer registered through
// IOCPClient::setManagerCallBack plus one data buffer and its length.
// IOCPManager::DataProcess has exactly this signature.
using DataProcessCB = int (*)(void* userData, PBYTE szBuffer, ULONG ulLength);

// Disconnect callback: receives the same opaque `userData` pointer.
// IOCPManager::ReconnectProcess has exactly this signature.
using OnDisconnectCB = int (*)(void* userData);
|
||
|
||
class ProtocolEncoder
|
||
{
|
||
public:
|
||
virtual ~ProtocolEncoder() {}
|
||
virtual HeaderFlag GetHead() const
|
||
{
|
||
return "Shine";
|
||
}
|
||
virtual int GetHeadLen() const
|
||
{
|
||
return 13;
|
||
}
|
||
virtual int GetFlagLen() const
|
||
{
|
||
return 5;
|
||
}
|
||
virtual void Encode(unsigned char* data, int len, unsigned char* param = 0) {}
|
||
virtual void Decode(unsigned char* data, int len, unsigned char* param = 0) {}
|
||
virtual EncFun GetHeaderEncoder() const
|
||
{
|
||
return nullptr;
|
||
}
|
||
};
|
||
|
||
class HellEncoder : public ProtocolEncoder
|
||
{
|
||
private:
|
||
EncFun m_HeaderEnc;
|
||
Encoder *m_BodyEnc;
|
||
public:
|
||
HellEncoder(EncFun head, Encoder *body)
|
||
{
|
||
m_HeaderEnc = head;
|
||
m_BodyEnc = body;
|
||
}
|
||
~HellEncoder()
|
||
{
|
||
SAFE_DELETE(m_BodyEnc);
|
||
}
|
||
virtual HeaderFlag GetHead() const override
|
||
{
|
||
return ::GetHead(m_HeaderEnc);
|
||
}
|
||
virtual int GetHeadLen() const override
|
||
{
|
||
return 16;
|
||
}
|
||
virtual int GetFlagLen() const override
|
||
{
|
||
return 8;
|
||
}
|
||
virtual void Encode(unsigned char* data, int len, unsigned char* param = 0) override
|
||
{
|
||
return m_BodyEnc->Encode(data, len, param);
|
||
}
|
||
virtual void Decode(unsigned char* data, int len, unsigned char* param = 0) override
|
||
{
|
||
return m_BodyEnc->Decode(data, len, param);
|
||
}
|
||
virtual EncFun GetHeaderEncoder() const override
|
||
{
|
||
return m_HeaderEnc;
|
||
}
|
||
};
|
||
|
||
class IOCPManager
|
||
{
|
||
public:
|
||
virtual ~IOCPManager() {}
|
||
virtual BOOL IsAlive() const
|
||
{
|
||
return TRUE;
|
||
}
|
||
virtual BOOL IsReady() const
|
||
{
|
||
return TRUE;
|
||
}
|
||
virtual VOID OnReceive(PBYTE szBuffer, ULONG ulLength) { }
|
||
|
||
// Tip: 在派生类实现该函数以便支持断线重连
|
||
virtual BOOL OnReconnect()
|
||
{
|
||
return FALSE;
|
||
}
|
||
|
||
static int DataProcess(void* user, PBYTE szBuffer, ULONG ulLength)
|
||
{
|
||
IOCPManager* m_Manager = (IOCPManager*)user;
|
||
if (nullptr == m_Manager) {
|
||
Mprintf("IOCPManager DataProcess on NULL ptr: %d\n", unsigned(szBuffer[0]));
|
||
return FALSE;
|
||
}
|
||
// 等待子类准备就绪才能处理数据, 1秒足够了
|
||
int i = 0;
|
||
for (; i < 1000 && !m_Manager->IsReady(); ++i)
|
||
Sleep(1);
|
||
if (!m_Manager->IsReady()) {
|
||
Mprintf("IOCPManager DataProcess is NOT ready: %d\n", unsigned(szBuffer[0]));
|
||
return FALSE;
|
||
}
|
||
if (i) {
|
||
Mprintf("IOCPManager DataProcess wait for %dms: %d\n", i, unsigned(szBuffer[0]));
|
||
}
|
||
m_Manager->OnReceive(szBuffer, ulLength);
|
||
return TRUE;
|
||
}
|
||
static int ReconnectProcess(void* user)
|
||
{
|
||
IOCPManager* m_Manager = (IOCPManager*)user;
|
||
if (nullptr == m_Manager) {
|
||
return FALSE;
|
||
}
|
||
return m_Manager->OnReconnect();
|
||
}
|
||
};
|
||
|
||
typedef BOOL(*TrailCheck)(void);
|
||
|
||
class IOCPClient : public IOCPBase
|
||
{
|
||
public:
|
||
IOCPClient(const State& bExit, bool exit_while_disconnect = false, int mask=0, CONNECT_ADDRESS *conn=0,
|
||
const std::string&pubIP="", void*main=0);
|
||
virtual ~IOCPClient();
|
||
|
||
int SendLoginInfo(const LOGIN_INFOR& logInfo)
|
||
{
|
||
LOGIN_INFOR tmp = logInfo;
|
||
int iRet = Send2Server((char*)&tmp, sizeof(LOGIN_INFOR));
|
||
|
||
return iRet;
|
||
}
|
||
virtual BOOL ConnectServer(const char* szServerIP, unsigned short uPort);
|
||
|
||
std::string GetClientIP() const
|
||
{
|
||
return m_sLocPublicIP;
|
||
}
|
||
|
||
std::map<std::string, std::string> GetClientIPHeader() const
|
||
{
|
||
return m_sLocPublicIP.empty() ? std::map<std::string, std::string> {} :
|
||
std::map<std::string, std::string> { {"X-Forwarded-For", m_sLocPublicIP} };
|
||
}
|
||
|
||
BOOL Send2Server(const char* szBuffer, ULONG ulOriginalLength, PkgMask* mask = NULL)
|
||
{
|
||
return OnServerSending(szBuffer, ulOriginalLength, mask);
|
||
}
|
||
|
||
void SetServerAddress(const char* szServerIP, unsigned short uPort)
|
||
{
|
||
m_Domain = szServerIP ? szServerIP : "127.0.0.1";
|
||
m_nHostPort = uPort;
|
||
}
|
||
|
||
std::string ServerIP() const
|
||
{
|
||
return m_sCurIP;
|
||
}
|
||
|
||
int ServerPort() const
|
||
{
|
||
return m_nHostPort;
|
||
}
|
||
|
||
BOOL IsRunning() const
|
||
{
|
||
return m_bIsRunning;
|
||
}
|
||
VOID StopRunning()
|
||
{
|
||
m_ReconnectFunc = NULL;
|
||
m_bIsRunning = FALSE;
|
||
}
|
||
VOID setManagerCallBack(void* Manager, DataProcessCB dataProcess, OnDisconnectCB reconnect);
|
||
VOID RunEventLoop(TrailCheck checker);
|
||
VOID RunEventLoop(const BOOL &bCondition);
|
||
bool IsConnected() const
|
||
{
|
||
return m_bConnected == TRUE;
|
||
}
|
||
BOOL Reconnect(void* manager)
|
||
{
|
||
Disconnect();
|
||
if (manager) m_Manager = manager;
|
||
return ConnectServer(NULL, 0);
|
||
}
|
||
const State& GetState() const
|
||
{
|
||
return g_bExit;
|
||
}
|
||
void SetMultiThreadCompress(int threadNum=0);
|
||
std::string GetClientID() const
|
||
{
|
||
return m_conn ? std::to_string(m_conn->clientID) : "";
|
||
}
|
||
std::string GetPublicIP() const
|
||
{
|
||
return m_sLocPublicIP;
|
||
}
|
||
CONNECT_ADDRESS* GetConnectionAddress() const
|
||
{
|
||
return m_conn;
|
||
}
|
||
IOCPManager* GetManager() const
|
||
{
|
||
return (IOCPManager*)m_Manager;
|
||
}
|
||
void* GetMain() const
|
||
{
|
||
return m_main;
|
||
}
|
||
void SetVerifyInfo(const std::string& msg, const std::string& hmac) {
|
||
m_LoginMsg = msg;
|
||
m_LoginSignature = hmac;
|
||
}
|
||
|
||
// 子连接身份校验:发 TOKEN_CONN_AUTH 包,等服务端 ConnAuthAck 响应。
|
||
// 返回 true 表示通过,false 表示超时/失败/网络错误。
|
||
// 主连接不调用此方法。新客户端必须调用并校验成功后才能继续后续命令。
|
||
// 已实现的协议扩展(如 KeyBoard 子连接的 cap word)保留不变,与本机制并行工作。
|
||
bool PerformConnAuth(uint64_t clientID, int timeoutMs);
|
||
|
||
// 让 ConnectServer 在每次成功后自动调一次 PerformConnAuth(opt-in)。
|
||
// 子连接构造后调用此方法启用。
|
||
// - clientID == 0:每次 auth 时从 m_conn->clientID 现取(Windows 客户端走此路径)。
|
||
// 这样即便 IOCPClient 创建时主连接还没拿到 ID,真正连上时也能用到最新值。
|
||
// - clientID != 0:显式指定(Linux/macOS 客户端 IOCPClient 不带 m_conn 时用此参数)。
|
||
void EnableSubConnAuth(bool enabled = true, uint64_t clientID = 0) {
|
||
m_subConnAuthEnabled = enabled;
|
||
m_subConnAuthClientID = clientID;
|
||
}
|
||
|
||
// 内部:在收到的数据帧分发到 manager 之前,尝试识别并消费 TOKEN_CONN_AUTH ack。
|
||
// 仅在我们正在等待 auth 响应时(m_authPending=true)才消费;否则透传给 manager。
|
||
bool TryHandleAuthResponse(PBYTE buf, ULONG len);
|
||
|
||
// 主动断开当前连接,关闭 socket。提到 public 让外层(如 Linux/macOS main 的心跳
|
||
// 循环检测到服务端身份校验超时)能在重连前显式关闭旧 fd,避免泄漏。
|
||
virtual VOID Disconnect(); // 函数支持 TCP/UDP
|
||
|
||
protected:
|
||
virtual int ReceiveData(char* buffer, int bufSize, int flags)
|
||
{
|
||
// TCP版本调用 recv
|
||
return recv(m_sClientSocket, buffer, bufSize - 1, 0);
|
||
}
|
||
virtual bool ProcessRecvData(CBuffer* m_CompressedBuffer, char* szBuffer, int len, int flag);
|
||
virtual int SendTo(const char* buf, int len, int flags)
|
||
{
|
||
return ::send(m_sClientSocket, buf, len, flags);
|
||
}
|
||
BOOL OnServerSending(const char* szBuffer, ULONG ulOriginalLength, PkgMask* mask);
|
||
static DWORD WINAPI WorkThreadProc(LPVOID lParam);
|
||
VOID OnServerReceiving(CBuffer *m_CompressedBuffer, char* szBuffer, ULONG ulReceivedLength);
|
||
BOOL SendWithSplit(const char* src, ULONG srcSize, ULONG ulSplitLength, int cmd, PkgMask* mask);
|
||
|
||
protected:
|
||
sockaddr_in m_ServerAddr;
|
||
SOCKET m_sClientSocket;
|
||
BOOL m_bWorkThread;
|
||
HANDLE m_hWorkThread;
|
||
BOOL m_bIsRunning;
|
||
BOOL m_bConnected;
|
||
|
||
std::mutex m_Locker;
|
||
|
||
// 子连接身份校验同步状态。仅在 PerformConnAuth 调用期间生效。
|
||
std::mutex m_authMtx;
|
||
std::condition_variable m_authCv;
|
||
int m_authStatus = -1; // -1 = 未启动;其它 = ConnAuthStatus
|
||
bool m_authPending = false; // true 时 TryHandleAuthResponse 才消费 ack
|
||
|
||
// ConnectServer 成功后自动 auth 的 opt-in 标志。子连接构造后调 EnableSubConnAuth() 设为 true。
|
||
bool m_subConnAuthEnabled = false;
|
||
uint64_t m_subConnAuthClientID = 0; // 0 表示从 m_conn->clientID 现取
|
||
#if USING_CTX
|
||
ZSTD_CCtx* m_Cctx; // 压缩上下文
|
||
ZSTD_DCtx* m_Dctx; // 解压上下文
|
||
#endif
|
||
|
||
const State& g_bExit; // 全局状态量
|
||
void* m_Manager; // 用户数据
|
||
DataProcessCB m_DataProcess; // 处理用户数据
|
||
OnDisconnectCB m_ReconnectFunc; // 断线重连逻辑
|
||
ProtocolEncoder* m_Encoder; // 加密
|
||
DomainPool m_Domain;
|
||
std::string m_sCurIP;
|
||
int m_nHostPort;
|
||
bool m_exit_while_disconnect;
|
||
PkgMask* m_masker;
|
||
BOOL m_EncoderType;
|
||
std::string m_sLocPublicIP;
|
||
CONNECT_ADDRESS *m_conn = NULL;
|
||
|
||
void *m_main = NULL;
|
||
public:
|
||
std::string m_LoginMsg; // 登录消息摘要
|
||
std::string m_LoginSignature; // 登录消息签名
|
||
};
|