Init: Migrate SimpleRemoter (Since v1.3.1) to Gitea

This commit is contained in:
yuanyuanxiang
2026-04-19 19:55:01 +02:00
commit 5a325a202b
744 changed files with 235562 additions and 0 deletions

View File

@@ -0,0 +1,378 @@
// GeoLocationTest.cpp - IP地理位置API单元测试
// 测试 location.h 中的 GetGeoLocation 功能
#include <gtest/gtest.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#include <wininet.h>
#include <ws2tcpip.h>
#pragma comment(lib, "wininet.lib")
#pragma comment(lib, "ws2_32.lib")
#include <string>
#include <vector>
// ============================================
// 简化的测试版本 - 不依赖jsoncpp
// ============================================
// UTF-8 转 ANSI
inline std::string Utf8ToAnsi(const std::string& utf8)
{
if (utf8.empty()) return "";
int wideLen = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, NULL, 0);
if (wideLen <= 0) return utf8;
std::wstring wideStr(wideLen, 0);
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wideStr[0], wideLen);
int ansiLen = WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, NULL, 0, NULL, NULL);
if (ansiLen <= 0) return utf8;
std::string ansiStr(ansiLen, 0);
WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, &ansiStr[0], ansiLen, NULL, NULL);
if (!ansiStr.empty() && ansiStr.back() == '\0') ansiStr.pop_back();
return ansiStr;
}
// 简单JSON值提取 (仅用于测试不依赖jsoncpp)
std::string ExtractJsonString(const std::string& json, const std::string& key)
{
std::string searchKey = "\"" + key + "\"";
size_t keyPos = json.find(searchKey);
if (keyPos == std::string::npos) return "";
size_t colonPos = json.find(':', keyPos);
if (colonPos == std::string::npos) return "";
size_t valueStart = json.find_first_not_of(" \t\n\r", colonPos + 1);
if (valueStart == std::string::npos) return "";
if (json[valueStart] == '"') {
size_t valueEnd = json.find('"', valueStart + 1);
if (valueEnd == std::string::npos) return "";
return json.substr(valueStart + 1, valueEnd - valueStart - 1);
}
// 数字或布尔值
size_t valueEnd = json.find_first_of(",}\n\r", valueStart);
if (valueEnd == std::string::npos) valueEnd = json.length();
std::string value = json.substr(valueStart, valueEnd - valueStart);
// 去除尾部空格
while (!value.empty() && (value.back() == ' ' || value.back() == '\t')) {
value.pop_back();
}
return value;
}
// API配置结构
struct GeoApiConfig {
const char* name;
const char* urlFmt;
const char* cityField;
const char* countryField;
const char* checkField;
const char* checkValue;
bool useHttps;
};
// 测试用API配置
static const GeoApiConfig testApis[] = {
{"ip-api.com", "http://ip-api.com/json/%s?fields=status,country,city", "city", "country", "status", "success", false},
{"ipinfo.io", "http://ipinfo.io/%s/json", "city", "country", "", "", false},
{"ipapi.co", "https://ipapi.co/%s/json/", "city", "country_name", "error", "", true},
};
// 测试单个API
struct ApiTestResult {
bool success;
int httpStatus;
std::string city;
std::string country;
std::string error;
double latencyMs;
};
ApiTestResult TestSingleApi(const GeoApiConfig& api, const std::string& ip)
{
ApiTestResult result = {false, 0, "", "", "", 0};
DWORD startTime = GetTickCount();
HINTERNET hInternet = InternetOpenA("GeoLocationTest", INTERNET_OPEN_TYPE_DIRECT, NULL, NULL, 0);
if (!hInternet) {
result.error = "InternetOpen failed";
return result;
}
DWORD timeout = 10000;
InternetSetOptionA(hInternet, INTERNET_OPTION_CONNECT_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_SEND_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_RECEIVE_TIMEOUT, &timeout, sizeof(timeout));
char urlBuf[256];
sprintf_s(urlBuf, api.urlFmt, ip.c_str());
DWORD flags = INTERNET_FLAG_RELOAD;
if (api.useHttps) flags |= INTERNET_FLAG_SECURE;
HINTERNET hConnect = InternetOpenUrlA(hInternet, urlBuf, NULL, 0, flags, 0);
if (!hConnect) {
result.error = "InternetOpenUrl failed: " + std::to_string(GetLastError());
InternetCloseHandle(hInternet);
return result;
}
// 获取HTTP状态码
DWORD statusCode = 0;
DWORD statusSize = sizeof(statusCode);
if (HttpQueryInfoA(hConnect, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &statusSize, NULL)) {
result.httpStatus = statusCode;
}
// 读取响应
std::string readBuffer;
char buffer[4096];
DWORD bytesRead;
while (InternetReadFile(hConnect, buffer, sizeof(buffer), &bytesRead) && bytesRead > 0) {
readBuffer.append(buffer, bytesRead);
}
InternetCloseHandle(hConnect);
InternetCloseHandle(hInternet);
result.latencyMs = (double)(GetTickCount() - startTime);
// 检查HTTP错误
if (result.httpStatus >= 400) {
result.error = "HTTP " + std::to_string(result.httpStatus);
if (result.httpStatus == 429) {
result.error += " (Rate Limited)";
}
return result;
}
// 检查响应体错误
if (readBuffer.find("Rate limit") != std::string::npos ||
readBuffer.find("rate limit") != std::string::npos) {
result.error = "Rate limited (body)";
return result;
}
// 解析JSON
if (api.checkField && api.checkField[0]) {
std::string checkVal = ExtractJsonString(readBuffer, api.checkField);
if (api.checkValue && api.checkValue[0]) {
if (checkVal != api.checkValue) {
result.error = "Check failed: " + std::string(api.checkField) + "=" + checkVal;
return result;
}
} else {
if (checkVal == "true") {
result.error = "Error flag set";
return result;
}
}
}
result.city = Utf8ToAnsi(ExtractJsonString(readBuffer, api.cityField));
result.country = Utf8ToAnsi(ExtractJsonString(readBuffer, api.countryField));
result.success = !result.city.empty() || !result.country.empty();
if (!result.success) {
result.error = "No city/country in response";
}
return result;
}
// ============================================
// 测试用例
// ============================================
class GeoLocationTest : public ::testing::Test {
protected:
void SetUp() override {
// 初始化Winsock
WSADATA wsaData;
WSAStartup(MAKEWORD(2, 2), &wsaData);
}
void TearDown() override {
WSACleanup();
}
};
// 测试公网IP (Google DNS)
TEST_F(GeoLocationTest, TestPublicIP_GoogleDNS) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
// 至少一个API应该成功
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试另一个公网IP (Cloudflare DNS)
TEST_F(GeoLocationTest, TestPublicIP_CloudflareDNS) {
std::string testIP = "1.1.1.1";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试中国IP
TEST_F(GeoLocationTest, TestPublicIP_ChinaIP) {
std::string testIP = "114.114.114.114"; // 114DNS
std::cout << "\n=== Testing IP: " << testIP << " (China) ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试ip-api.com单独
TEST_F(GeoLocationTest, TestIpApiCom) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[0], testIP);
std::cout << "\n[ip-api.com] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
// 如果是429就跳过不算失败
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipinfo.io单独
TEST_F(GeoLocationTest, TestIpInfoIo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[1], testIP);
std::cout << "\n[ipinfo.io] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipapi.co单独
TEST_F(GeoLocationTest, TestIpApiCo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[2], testIP);
std::cout << "\n[ipapi.co] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试HTTP状态码检测
TEST_F(GeoLocationTest, TestHttpStatusCodeDetection) {
// 使用一个会返回错误的请求
std::string testIP = "invalid-ip";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] Invalid IP test: HTTP "
<< result.httpStatus << ", Error: " << result.error << "\n";
// 应该失败
EXPECT_FALSE(result.success);
}
}
// 测试所有API的响应时间
TEST_F(GeoLocationTest, TestApiLatency) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== API Latency Test ===\n";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] " << result.latencyMs << "ms";
if (result.success) {
std::cout << " (OK)";
} else {
std::cout << " (FAIL: " << result.error << ")";
}
std::cout << "\n";
// 响应时间应该在合理范围内 (< 15秒)
EXPECT_LT(result.latencyMs, 15000);
}
}
#else
// 非Windows平台 - 跳过测试
TEST(GeoLocationTest, SkipNonWindows) {
GTEST_SKIP() << "GeoLocation tests only run on Windows";
}
#endif

View File

@@ -0,0 +1,642 @@
/**
* @file HeaderTest.cpp
* @brief 协议头验证和加密测试
*
* 测试覆盖:
* - 协议头格式和常量
* - 加密/解密函数正确性
* - 多版本头部识别
* - 头部生成和验证往返
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <array>
// ============================================
// 协议常量定义(测试专用副本)
// ============================================
#define MSG_HEADER "HELL"
const int FLAG_COMPLEN = 4;
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = FLAG_LENGTH + 2 * sizeof(unsigned int); // 16
const int MIN_COMLEN = 12;
enum HeaderEncType {
HeaderEncUnknown = -1,
HeaderEncNone,
HeaderEncV0,
HeaderEncV1,
HeaderEncV2,
HeaderEncV3,
HeaderEncV4,
HeaderEncV5,
HeaderEncV6,
HeaderEncNum,
};
enum FlagType {
FLAG_WINOS = -1,
FLAG_UNKNOWN = 0,
FLAG_SHINE = 1,
FLAG_FUCK = 2,
FLAG_HELLO = 3,
FLAG_HELL = 4,
};
// ============================================
// 加密/解密函数(测试专用副本)
// ============================================
inline void default_encrypt(unsigned char* data, size_t length, unsigned char key)
{
data[FLAG_LENGTH - 2] = data[FLAG_LENGTH - 1] = 0;
}
inline void default_decrypt(unsigned char* data, size_t length, unsigned char key)
{
}
inline void encrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value += k;
break;
case 1:
value = value ^ k;
break;
case 2:
value -= k;
break;
case 3:
value = ~(value ^ k);
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
inline void decrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value -= k;
break;
case 1:
value = value ^ k;
break;
case 2:
value += k;
break;
case 3:
value = ~(value) ^ k;
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
// ============================================
// HeaderFlag 结构体
// ============================================
typedef struct HeaderFlag {
char Data[FLAG_LENGTH + 1];
HeaderFlag(const char header[FLAG_LENGTH + 1])
{
memcpy(Data, header, sizeof(Data));
}
char& operator[](int i)
{
return Data[i];
}
const char operator[](int i) const
{
return Data[i];
}
const char* data() const
{
return Data;
}
} HeaderFlag;
typedef void (*EncFun)(unsigned char* data, size_t length, unsigned char key);
typedef void (*DecFun)(unsigned char* data, size_t length, unsigned char key);
// ============================================
// 头部生成函数
// ============================================
inline HeaderFlag GetHead(EncFun enc, unsigned char fixedKey = 0)
{
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag H(header);
unsigned char key = (fixedKey != 0) ? fixedKey : (time(0) % 256);
H[FLAG_LENGTH - 2] = key;
H[FLAG_LENGTH - 1] = ~key;
enc((unsigned char*)H.data(), FLAG_COMPLEN, H[FLAG_LENGTH - 2]);
return H;
}
// ============================================
// 头部验证函数
// ============================================
inline int compare(const char *flag, const char *magic, int len, DecFun dec, unsigned char key)
{
unsigned char buf[32] = {};
memcpy(buf, flag, MIN_COMLEN);
dec(buf, len, key);
if (memcmp(buf, magic, len) == 0) {
memcpy((void*)flag, buf, MIN_COMLEN);
return 0;
}
return -1;
}
inline FlagType CheckHead(const char* flag, DecFun dec)
{
FlagType type = FLAG_UNKNOWN;
if (compare(flag, MSG_HEADER, FLAG_COMPLEN, dec, flag[6]) == 0) {
type = FLAG_HELL;
} else if (compare(flag, "Shine", 5, dec, 0) == 0) {
type = FLAG_SHINE;
} else if (compare(flag, "<<FUCK>>", 8, dec, flag[9]) == 0) {
type = FLAG_FUCK;
} else if (compare(flag, "Hello?", 6, dec, flag[6]) == 0) {
type = FLAG_HELLO;
} else {
type = FLAG_UNKNOWN;
}
return type;
}
inline FlagType CheckHeadMulti(char* flag, HeaderEncType& funcHit)
{
static const DecFun methods[] = { default_decrypt, decrypt };
static const int methodNum = sizeof(methods) / sizeof(DecFun);
char buffer[MIN_COMLEN + 4] = {};
for (int i = 0; i < methodNum; ++i) {
memcpy(buffer, flag, MIN_COMLEN);
FlagType type = CheckHead(buffer, methods[i]);
if (type != FLAG_UNKNOWN) {
memcpy(flag, buffer, MIN_COMLEN);
funcHit = HeaderEncType(i);
return type;
}
}
funcHit = HeaderEncUnknown;
return FLAG_UNKNOWN;
}
// ============================================
// 协议常量测试
// ============================================
class HeaderConstantsTest : public ::testing::Test {};
TEST_F(HeaderConstantsTest, FlagLength) {
EXPECT_EQ(FLAG_LENGTH, 8);
}
TEST_F(HeaderConstantsTest, FlagCompLen) {
EXPECT_EQ(FLAG_COMPLEN, 4);
}
TEST_F(HeaderConstantsTest, HdrLength) {
// FLAG_LENGTH(8) + 2 * sizeof(uint32_t)(8) = 16
EXPECT_EQ(HDR_LENGTH, 16);
}
TEST_F(HeaderConstantsTest, MinComLen) {
EXPECT_EQ(MIN_COMLEN, 12);
}
TEST_F(HeaderConstantsTest, MsgHeader) {
EXPECT_STREQ(MSG_HEADER, "HELL");
EXPECT_EQ(strlen(MSG_HEADER), 4u);
}
// ============================================
// 加密/解密测试
// ============================================
class EncryptionTest : public ::testing::Test {};
TEST_F(EncryptionTest, ZeroKeyNoOp) {
unsigned char data[] = {0x01, 0x02, 0x03, 0x04, 0x05};
unsigned char original[] = {0x01, 0x02, 0x03, 0x04, 0x05};
encrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
decrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
}
TEST_F(EncryptionTest, EncryptDecryptRoundTrip) {
unsigned char original[] = {0x48, 0x45, 0x4C, 0x4C}; // "HELL"
unsigned char data[4];
memcpy(data, original, 4);
unsigned char key = 0x42;
encrypt(data, 4, key);
// 加密后应该不同
EXPECT_NE(memcmp(data, original, 4), 0);
decrypt(data, 4, key);
// 解密后应该恢复
EXPECT_EQ(memcmp(data, original, 4), 0);
}
TEST_F(EncryptionTest, DifferentKeysProduceDifferentResults) {
unsigned char data1[] = {0x48, 0x45, 0x4C, 0x4C};
unsigned char data2[] = {0x48, 0x45, 0x4C, 0x4C};
encrypt(data1, 4, 0x10);
encrypt(data2, 4, 0x20);
EXPECT_NE(memcmp(data1, data2, 4), 0);
}
TEST_F(EncryptionTest, PositionDependentEncryption) {
// 相同值在不同位置加密结果不同
unsigned char data[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
unsigned char original[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
encrypt(data, 8, 0x55);
// 验证加密后值不全相同(位置相关加密)
bool allSame = true;
for (int i = 1; i < 8; ++i) {
if (data[i] != data[0]) {
allSame = false;
break;
}
}
EXPECT_FALSE(allSame);
// 验证解密恢复
decrypt(data, 8, 0x55);
EXPECT_EQ(memcmp(data, original, 8), 0);
}
TEST_F(EncryptionTest, AllKeyValues) {
// 测试所有可能的密钥值
unsigned char original[] = {0x12, 0x34, 0x56, 0x78};
for (int key = 1; key < 256; ++key) {
unsigned char data[4];
memcpy(data, original, 4);
encrypt(data, 4, static_cast<unsigned char>(key));
decrypt(data, 4, static_cast<unsigned char>(key));
EXPECT_EQ(memcmp(data, original, 4), 0) << "Failed for key: " << key;
}
}
TEST_F(EncryptionTest, LargeDataRoundTrip) {
std::vector<unsigned char> original(1000);
for (size_t i = 0; i < original.size(); ++i) {
original[i] = static_cast<unsigned char>(i & 0xFF);
}
std::vector<unsigned char> data = original;
unsigned char key = 0x7F;
encrypt(data.data(), data.size(), key);
decrypt(data.data(), data.size(), key);
EXPECT_EQ(data, original);
}
// ============================================
// HeaderFlag 测试
// ============================================
class HeaderFlagTest : public ::testing::Test {};
TEST_F(HeaderFlagTest, Construction) {
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
HeaderFlag hf(header);
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(HeaderFlagTest, DataAccess) {
char header[FLAG_LENGTH + 1] = { 'T','E','S','T', 0x12, 0x34, 0x56, 0x78, 0 };
HeaderFlag hf(header);
EXPECT_EQ(memcmp(hf.data(), "TEST", 4), 0);
EXPECT_EQ(static_cast<unsigned char>(hf[4]), 0x12);
EXPECT_EQ(static_cast<unsigned char>(hf[5]), 0x34);
}
TEST_F(HeaderFlagTest, Modification) {
char header[FLAG_LENGTH + 1] = { 0 };
HeaderFlag hf(header);
hf[0] = 'A';
hf[1] = 'B';
EXPECT_EQ(hf[0], 'A');
EXPECT_EQ(hf[1], 'B');
}
// ============================================
// GetHead 测试
// ============================================
class GetHeadTest : public ::testing::Test {};
TEST_F(GetHeadTest, GeneratesValidHeader) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
// 检查基础格式
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(GetHeadTest, KeyAndInverseKey) {
unsigned char key = 0x42;
HeaderFlag hf = GetHead(default_encrypt, key);
// 使用 default_encrypt 时key 位置被清零
// 但我们需要验证生成逻辑
char rawHeader[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag expected(rawHeader);
expected[FLAG_LENGTH - 2] = key;
expected[FLAG_LENGTH - 1] = ~key;
default_encrypt((unsigned char*)expected.data(), FLAG_COMPLEN, expected[FLAG_LENGTH - 2]);
EXPECT_EQ(memcmp(hf.data(), expected.data(), FLAG_LENGTH), 0);
}
TEST_F(GetHeadTest, EncryptedHeader) {
unsigned char key = 0x55;
HeaderFlag hf = GetHead(encrypt, key);
// 头部应该被加密,不再是明文 "HELL"
EXPECT_NE(memcmp(hf.data(), "HELL", 4), 0);
// 密钥应该在正确位置
unsigned char storedKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 2]);
unsigned char inverseKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 1]);
EXPECT_EQ(storedKey, key);
EXPECT_EQ(inverseKey, static_cast<unsigned char>(~key));
}
// ============================================
// CheckHead 测试
// ============================================
class CheckHeadTest : public ::testing::Test {};
TEST_F(CheckHeadTest, IdentifyHellFlag) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(CheckHeadTest, IdentifyShineFlag) {
char header[MIN_COMLEN + 4] = { 'S','h','i','n','e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_SHINE);
}
TEST_F(CheckHeadTest, UnknownFlag) {
char header[MIN_COMLEN + 4] = { 'X','Y','Z','W', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_UNKNOWN);
}
TEST_F(CheckHeadTest, EncryptedHellFlag) {
// 生成加密的头部
unsigned char key = 0x33;
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, header, FLAG_LENGTH);
FlagType type = CheckHead(buffer, decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
// ============================================
// CheckHeadMulti 测试
// ============================================
class CheckHeadMultiTest : public ::testing::Test {};
TEST_F(CheckHeadMultiTest, PlainTextHeader) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncNone);
}
TEST_F(CheckHeadMultiTest, EncryptedHeader) {
unsigned char key = 0x77;
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncV0); // encrypt 对应 V0
}
TEST_F(CheckHeadMultiTest, UnrecognizedHeader) {
char header[MIN_COMLEN + 4] = { 0xFF, 0xFE, 0xFD, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_UNKNOWN);
EXPECT_EQ(encType, HeaderEncUnknown);
}
// ============================================
// 头部生成和验证往返测试
// ============================================
class HeaderRoundTripTest : public ::testing::Test {};
TEST_F(HeaderRoundTripTest, DefaultEncryptRoundTrip) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, EncryptRoundTrip) {
HeaderFlag hf = GetHead(encrypt, 0x88);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, AllKeyValuesRoundTrip) {
for (int key = 1; key < 256; ++key) {
HeaderFlag hf = GetHead(encrypt, static_cast<unsigned char>(key));
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL) << "Failed for key: " << key;
}
}
// ============================================
// 数据包长度字段测试
// ============================================
class PacketLengthTest : public ::testing::Test {};
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 压缩后长度
uint32_t originalLength; // 原始长度
};
#pragma pack(pop)
TEST_F(PacketLengthTest, HeaderSize) {
EXPECT_EQ(sizeof(PacketHeader), HDR_LENGTH);
}
TEST_F(PacketLengthTest, BuildAndParsePacket) {
PacketHeader pkt = {};
memcpy(pkt.flag, "HELL", 4);
pkt.flag[6] = 0x42;
pkt.flag[7] = ~0x42;
pkt.packedLength = 100;
pkt.originalLength = 200;
// 解析
uint32_t packedLen, origLen;
memcpy(&packedLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH, sizeof(uint32_t));
memcpy(&origLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH + sizeof(uint32_t), sizeof(uint32_t));
EXPECT_EQ(packedLen, 100u);
EXPECT_EQ(origLen, 200u);
}
TEST_F(PacketLengthTest, TotalPacketLength) {
// 总包长度 = HDR_LENGTH + 压缩数据长度
uint32_t dataLen = 1000;
uint32_t totalLen = HDR_LENGTH + dataLen;
EXPECT_EQ(totalLen, 1016u);
}
// ============================================
// 边界条件测试
// ============================================
class HeaderBoundaryTest : public ::testing::Test {};
TEST_F(HeaderBoundaryTest, MinimalPacket) {
// 最小合法包:只有头部
std::vector<uint8_t> packet(HDR_LENGTH);
memcpy(packet.data(), "HELL", 4);
packet[6] = 0x42;
packet[7] = ~0x42;
// packedLength = HDR_LENGTH
uint32_t packedLen = HDR_LENGTH;
memcpy(packet.data() + FLAG_LENGTH, &packedLen, sizeof(uint32_t));
// originalLength = 0
uint32_t origLen = 0;
memcpy(packet.data() + FLAG_LENGTH + sizeof(uint32_t), &origLen, sizeof(uint32_t));
EXPECT_EQ(packet.size(), HDR_LENGTH);
}
TEST_F(HeaderBoundaryTest, MaxPacketLength) {
// 验证大包长度字段
uint32_t maxDataLen = 100 * 1024 * 1024; // 100 MB
uint32_t totalLen = HDR_LENGTH + maxDataLen;
PacketHeader pkt = {};
pkt.packedLength = totalLen;
pkt.originalLength = maxDataLen * 2; // 压缩前更大
EXPECT_EQ(pkt.packedLength, totalLen);
EXPECT_EQ(pkt.originalLength, maxDataLen * 2);
}
TEST_F(HeaderBoundaryTest, TruncatedHeader) {
// 不完整的头部不应该被识别
char truncated[4] = { 'H', 'E', 'L', 'L' };
// 不足以进行完整验证
// 这里只是验证常量定义正确
EXPECT_LT(sizeof(truncated), static_cast<size_t>(MIN_COMLEN));
}
// ============================================
// FlagType 枚举测试
// ============================================
class FlagTypeTest : public ::testing::Test {};
TEST_F(FlagTypeTest, EnumValues) {
EXPECT_EQ(FLAG_WINOS, -1);
EXPECT_EQ(FLAG_UNKNOWN, 0);
EXPECT_EQ(FLAG_SHINE, 1);
EXPECT_EQ(FLAG_FUCK, 2);
EXPECT_EQ(FLAG_HELLO, 3);
EXPECT_EQ(FLAG_HELL, 4);
}
TEST_F(FlagTypeTest, HeaderEncTypeValues) {
EXPECT_EQ(HeaderEncUnknown, -1);
EXPECT_EQ(HeaderEncNone, 0);
EXPECT_EQ(HeaderEncV0, 1);
EXPECT_EQ(HeaderEncNum, 8);
}

View File

@@ -0,0 +1,534 @@
/**
* @file HttpMaskTest.cpp
* @brief HTTP 协议伪装测试
*
* 测试覆盖:
* - HTTP 请求格式生成
* - HTTP 头部解析和移除
* - Mask/UnMask 往返测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <string>
#include <vector>
#include <map>
// ============================================
// 类型定义
// ============================================
#ifdef _WIN32
typedef unsigned long ULONG;
#else
typedef uint32_t ULONG;
#endif
// ============================================
// 协议伪装类型
// ============================================
enum PkgMaskType {
MaskTypeUnknown = -1,
MaskTypeNone,
MaskTypeHTTP,
MaskTypeNum,
};
// ============================================
// HTTP 解除伪装函数
// ============================================
inline ULONG UnMaskHttp(const char* src, ULONG srcSize)
{
const char* header_end_mark = "\r\n\r\n";
const ULONG mark_len = 4;
for (ULONG i = 0; i + mark_len <= srcSize; ++i) {
if (memcmp(src + i, header_end_mark, mark_len) == 0) {
return i + mark_len;
}
}
return 0;
}
inline ULONG TryUnMask(const char* src, ULONG srcSize, PkgMaskType& maskHit)
{
if (srcSize >= 5 && memcmp(src, "POST ", 5) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
if (srcSize >= 4 && memcmp(src, "GET ", 4) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
maskHit = MaskTypeNone;
return 0;
}
// ============================================
// HTTP Mask 类(简化版)
// ============================================
class HttpMask {
public:
explicit HttpMask(const std::string& host = "example.com",
const std::map<std::string, std::string>& headers = {})
: host_(host)
{
for (const auto& kv : headers) {
customHeaders_ += kv.first + ": " + kv.second + "\r\n";
}
}
void Mask(char*& dst, ULONG& dstSize, const char* src, ULONG srcSize, int cmd = -1)
{
std::string path = "/api/v1/" + std::to_string(cmd == -1 ? 0 : cmd);
std::string httpHeader =
"POST " + path + " HTTP/1.1\r\n"
"Host: " + host_ + "\r\n"
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64)\r\n"
"Content-Type: application/octet-stream\r\n"
"Content-Length: " + std::to_string(srcSize) + "\r\n" + customHeaders_ +
"Connection: keep-alive\r\n"
"\r\n";
dstSize = static_cast<ULONG>(httpHeader.size()) + srcSize;
dst = new char[dstSize];
memcpy(dst, httpHeader.data(), httpHeader.size());
if (srcSize > 0) {
memcpy(dst + httpHeader.size(), src, srcSize);
}
}
ULONG UnMask(const char* src, ULONG srcSize)
{
return UnMaskHttp(src, srcSize);
}
void SetHost(const std::string& host)
{
host_ = host;
}
private:
std::string host_;
std::string customHeaders_;
};
// ============================================
// UnMaskHttp 测试
// ============================================
class UnMaskHttpTest : public ::testing::Test {};
TEST_F(UnMaskHttpTest, ValidHttpRequest) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: example.com\r\n"
"Content-Length: 4\r\n"
"\r\n"
"DATA";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回 body 起始位置
EXPECT_GT(offset, 0u);
EXPECT_STREQ(httpRequest.data() + offset, "DATA");
}
TEST_F(UnMaskHttpTest, NoHeaderEnd) {
std::string incomplete = "POST /api HTTP/1.1\r\nHost: example.com\r\n";
ULONG offset = UnMaskHttp(incomplete.data(), static_cast<ULONG>(incomplete.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, EmptyBody) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 0\r\n"
"\r\n";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
EXPECT_EQ(offset, static_cast<ULONG>(httpRequest.size()));
}
TEST_F(UnMaskHttpTest, MultipleHeaderEndMarkers) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 8\r\n"
"\r\n"
"\r\n\r\nXX"; // body 中也有 \r\n\r\n
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回第一个 \r\n\r\n 之后
std::string body(httpRequest.data() + offset);
EXPECT_EQ(body, "\r\n\r\nXX");
}
TEST_F(UnMaskHttpTest, MinimalInput) {
// 小于 4 字节
std::string tiny = "AB";
ULONG offset = UnMaskHttp(tiny.data(), static_cast<ULONG>(tiny.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, ExactlyHeaderEnd) {
std::string justEnd = "\r\n\r\n";
ULONG offset = UnMaskHttp(justEnd.data(), static_cast<ULONG>(justEnd.size()));
EXPECT_EQ(offset, 4u);
}
// ============================================
// TryUnMask 测试
// ============================================
class TryUnMaskTest : public ::testing::Test {};
TEST_F(TryUnMaskTest, DetectPOST) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n"
"body";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, DetectGET) {
std::string httpRequest =
"GET /resource HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, NonHttpData) {
std::string binaryData = "HELL\x00\x00\x42\xBD";
PkgMaskType maskType;
ULONG offset = TryUnMask(binaryData.data(), static_cast<ULONG>(binaryData.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
TEST_F(TryUnMaskTest, ShortInput) {
std::string tooShort = "POS";
PkgMaskType maskType;
ULONG offset = TryUnMask(tooShort.data(), static_cast<ULONG>(tooShort.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
// ============================================
// HttpMask 类测试
// ============================================
class HttpMaskClassTest : public ::testing::Test {};
TEST_F(HttpMaskClassTest, MaskBasic) {
HttpMask mask("api.example.com");
const char* data = "Hello";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 5);
ASSERT_NE(masked, nullptr);
EXPECT_GT(maskedSize, 5u);
// 验证是 HTTP 格式
EXPECT_EQ(memcmp(masked, "POST ", 5), 0);
// 验证 Host 头
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api.example.com"), std::string::npos);
// 验证 Content-Length
EXPECT_NE(maskedStr.find("Content-Length: 5"), std::string::npos);
// 验证 body
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, "Hello", 5), 0);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskEmptyData) {
HttpMask mask;
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "", 0);
ASSERT_NE(masked, nullptr);
// 验证 Content-Length: 0
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 0"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskWithCommand) {
HttpMask mask;
const char* data = "X";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 1, 42);
// 验证路径包含命令号
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("/42"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskLargeData) {
HttpMask mask;
std::vector<char> largeData(64 * 1024, 'X'); // 64 KB
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, largeData.data(), static_cast<ULONG>(largeData.size()));
ASSERT_NE(masked, nullptr);
// 验证 Content-Length
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 65536"), std::string::npos);
// 验证 body 完整
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 64u * 1024u);
delete[] masked;
}
// ============================================
// Mask/UnMask 往返测试
// ============================================
class MaskRoundTripTest : public ::testing::Test {};
TEST_F(MaskRoundTripTest, SimpleRoundTrip) {
HttpMask mask;
std::vector<char> original = {'H', 'E', 'L', 'L', 'O'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
// UnMask
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_GT(offset, 0u);
// 验证数据完整
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, BinaryDataRoundTrip) {
HttpMask mask;
// 包含所有字节值
std::vector<char> original(256);
for (int i = 0; i < 256; ++i) {
original[i] = static_cast<char>(i);
}
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, NullBytesRoundTrip) {
HttpMask mask;
std::vector<char> original = {'\0', '\0', 'A', '\0', 'B'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, HttpLikeDataRoundTrip) {
HttpMask mask;
// 数据本身看起来像 HTTP
std::string httpLike = "POST /fake HTTP/1.1\r\n\r\nfake body";
std::vector<char> original(httpLike.begin(), httpLike.end());
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
// ============================================
// 自定义头部测试
// ============================================
class CustomHeadersTest : public ::testing::Test {};
TEST_F(CustomHeadersTest, AddCustomHeaders) {
std::map<std::string, std::string> headers;
headers["X-Custom-Header"] = "custom-value";
headers["X-Request-ID"] = "12345";
HttpMask mask("test.com", headers);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "data", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("X-Custom-Header: custom-value"), std::string::npos);
EXPECT_NE(maskedStr.find("X-Request-ID: 12345"), std::string::npos);
delete[] masked;
}
// ============================================
// 边界条件测试
// ============================================
class HttpMaskBoundaryTest : public ::testing::Test {};
TEST_F(HttpMaskBoundaryTest, VeryLongHost) {
std::string longHost(1000, 'x');
longHost += ".com";
HttpMask mask(longHost);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "test", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find(longHost), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, SpecialCharactersInHost) {
HttpMask mask("api-v2.test-server.example.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "x", 1);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api-v2.test-server.example.com"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, MaxULONGContentLength) {
// 测试大的 Content-Length 值
HttpMask mask;
// 不实际分配这么大的内存,只是构造请求头
std::string largeContentLengthStr = "Content-Length: 4294967295";
// 验证格式正确
EXPECT_EQ(largeContentLengthStr.find("4294967295"), 16u);
}
// ============================================
// HTTP 格式验证测试
// ============================================
class HttpFormatTest : public ::testing::Test {};
TEST_F(HttpFormatTest, ValidHttpRequestFormat) {
HttpMask mask("test.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "body", 4);
std::string maskedStr(masked, maskedSize);
// 验证请求行
EXPECT_EQ(maskedStr.substr(0, 5), "POST ");
// 验证 HTTP 版本
EXPECT_NE(maskedStr.find("HTTP/1.1\r\n"), std::string::npos);
// 验证必需的头部
EXPECT_NE(maskedStr.find("Host:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Length:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Type:"), std::string::npos);
// 验证头部和 body 之间的分隔符
EXPECT_NE(maskedStr.find("\r\n\r\n"), std::string::npos);
delete[] masked;
}
TEST_F(HttpFormatTest, ContentLengthMatchesBody) {
HttpMask mask;
std::vector<char> body(123, 'X');
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, body.data(), static_cast<ULONG>(body.size()));
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 123"), std::string::npos);
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 123u);
delete[] masked;
}

View File

@@ -0,0 +1,660 @@
/**
* @file PacketFragmentTest.cpp
* @brief 粘包/分包处理测试
*
* 测试覆盖:
* - 完整包接收和解析
* - 分包(不完整包)处理
* - 粘包(多个包粘在一起)处理
* - 混合场景测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <queue>
#include <functional>
// ============================================
// 协议常量
// ============================================
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = 16; // FLAG(8) + PackedLen(4) + OrigLen(4)
// ============================================
// 简化版 CBuffer测试专用
// ============================================
class TestBuffer {
public:
TestBuffer() {}
void WriteBuffer(const uint8_t* data, size_t len) {
m_data.insert(m_data.end(), data, data + len);
}
size_t ReadBuffer(uint8_t* dst, size_t len) {
size_t toRead = std::min(len, m_data.size());
if (toRead > 0) {
memcpy(dst, m_data.data(), toRead);
m_data.erase(m_data.begin(), m_data.begin() + toRead);
}
return toRead;
}
void Skip(size_t len) {
size_t toSkip = std::min(len, m_data.size());
m_data.erase(m_data.begin(), m_data.begin() + toSkip);
}
size_t GetBufferLength() const {
return m_data.size();
}
const uint8_t* GetBuffer(size_t pos = 0) const {
if (pos >= m_data.size()) return nullptr;
return m_data.data() + pos;
}
bool CopyBuffer(uint8_t* dst, size_t pos, size_t len) const {
if (pos + len > m_data.size()) return false;
memcpy(dst, m_data.data() + pos, len);
return true;
}
void Clear() {
m_data.clear();
}
private:
std::vector<uint8_t> m_data;
};
// ============================================
// 数据包构建辅助函数
// ============================================
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 包含头部的总长度
uint32_t originalLength; // 原始数据长度
};
#pragma pack(pop)
std::vector<uint8_t> BuildPacket(const std::vector<uint8_t>& payload, uint8_t key = 0x42) {
std::vector<uint8_t> packet(HDR_LENGTH + payload.size());
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = key;
hdr->flag[7] = ~key;
hdr->packedLength = static_cast<uint32_t>(HDR_LENGTH + payload.size());
hdr->originalLength = static_cast<uint32_t>(payload.size());
if (!payload.empty()) {
memcpy(packet.data() + HDR_LENGTH, payload.data(), payload.size());
}
return packet;
}
std::vector<uint8_t> BuildPacketWithData(size_t dataSize, uint8_t fillByte = 0xAA) {
std::vector<uint8_t> payload(dataSize, fillByte);
return BuildPacket(payload);
}
// ============================================
// 粘包/分包处理器(模拟 OnServerReceiving 逻辑)
// ============================================
class PacketProcessor {
public:
using PacketCallback = std::function<void(const std::vector<uint8_t>&)>;
PacketProcessor(PacketCallback callback) : m_callback(callback) {}
// 接收数据(模拟网络接收)
void OnReceive(const uint8_t* data, size_t len) {
m_buffer.WriteBuffer(data, len);
ProcessBuffer();
}
// 获取待处理的字节数
size_t GetPendingBytes() const {
return m_buffer.GetBufferLength();
}
// 获取已处理的包数量
size_t GetProcessedCount() const {
return m_processedCount;
}
private:
void ProcessBuffer() {
while (m_buffer.GetBufferLength() >= HDR_LENGTH) {
// 验证头部
const uint8_t* buf = m_buffer.GetBuffer();
if (memcmp(buf, "HELL", 4) != 0) {
// 无效头部,跳过一个字节重试
m_buffer.Skip(1);
continue;
}
// 读取包长度
uint32_t packedLength;
m_buffer.CopyBuffer(reinterpret_cast<uint8_t*>(&packedLength),
FLAG_LENGTH, sizeof(uint32_t));
// 检查长度有效性
if (packedLength < HDR_LENGTH || packedLength > 100 * 1024 * 1024) {
// 无效长度,跳过头部
m_buffer.Skip(FLAG_LENGTH);
continue;
}
// 检查包是否完整
if (m_buffer.GetBufferLength() < packedLength) {
// 不完整,等待更多数据
break;
}
// 读取完整包
std::vector<uint8_t> packet(packedLength);
m_buffer.ReadBuffer(packet.data(), packedLength);
// 提取 payload
std::vector<uint8_t> payload(packet.begin() + HDR_LENGTH, packet.end());
m_callback(payload);
m_processedCount++;
}
}
TestBuffer m_buffer;
PacketCallback m_callback;
size_t m_processedCount = 0;
};
// ============================================
// 完整包接收测试
// ============================================
class CompletePacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
void SetUp() override {
receivedPackets.clear();
}
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(CompletePacketTest, SinglePacket) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02, 0x03, 0x04};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
EXPECT_EQ(processor.GetPendingBytes(), 0u);
}
TEST_F(CompletePacketTest, EmptyPayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload;
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(CompletePacketTest, LargePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(64 * 1024); // 64 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 分包(不完整包)测试
// ============================================
class FragmentedPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(FragmentedPacketTest, TwoFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xAA, 0xBB, 0xCC, 0xDD};
auto packet = BuildPacket(payload);
// 分两次发送
size_t half = packet.size() / 2;
processor.OnReceive(packet.data(), half);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), half);
processor.OnReceive(packet.data() + half, packet.size() - half);
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, ManyFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100);
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
// 每次发送 10 字节
for (size_t i = 0; i < packet.size(); i += 10) {
size_t len = std::min(size_t(10), packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, OnlyHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02};
auto packet = BuildPacket(payload);
// 只发送头部
processor.OnReceive(packet.data(), HDR_LENGTH);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), HDR_LENGTH);
// 发送剩余数据
processor.OnReceive(packet.data() + HDR_LENGTH, packet.size() - HDR_LENGTH);
ASSERT_EQ(receivedPackets.size(), 1u);
}
TEST_F(FragmentedPacketTest, PartialHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
// 发送不完整的头部
processor.OnReceive(packet.data(), 4); // 只有 "HELL"
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余部分
processor.OnReceive(packet.data() + 4, packet.size() - 4);
ASSERT_EQ(receivedPackets.size(), 1u);
}
// ============================================
// 粘包测试
// ============================================
class StickyPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(StickyPacketTest, TwoPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11, 0x22};
std::vector<uint8_t> payload2 = {0x33, 0x44, 0x55};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 合并两个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(StickyPacketTest, ThreePacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x01};
std::vector<uint8_t> payload2 = {0x02, 0x03};
std::vector<uint8_t> payload3 = {0x04, 0x05, 0x06};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
auto packet3 = BuildPacket(payload3);
// 合并三个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
combined.insert(combined.end(), packet3.begin(), packet3.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 3u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
EXPECT_EQ(receivedPackets[2], payload3);
}
TEST_F(StickyPacketTest, ManyPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> combined;
const int numPackets = 100;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload(i % 10 + 1, static_cast<uint8_t>(i));
auto packet = BuildPacket(payload);
combined.insert(combined.end(), packet.begin(), packet.end());
}
processor.OnReceive(combined.data(), combined.size());
EXPECT_EQ(receivedPackets.size(), numPackets);
}
// ============================================
// 混合场景测试
// ============================================
class MixedScenarioTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(MixedScenarioTest, OneAndHalfPackets) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0xAA, 0xBB};
std::vector<uint8_t> payload2 = {0xCC, 0xDD, 0xEE, 0xFF};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送完整包1 + 半个包2
std::vector<uint8_t> firstSend;
firstSend.insert(firstSend.end(), packet1.begin(), packet1.end());
firstSend.insert(firstSend.end(), packet2.begin(), packet2.begin() + packet2.size() / 2);
processor.OnReceive(firstSend.data(), firstSend.size());
EXPECT_EQ(receivedPackets.size(), 1u); // 只处理了包1
// 发送剩余的半个包2
processor.OnReceive(packet2.data() + packet2.size() / 2, packet2.size() - packet2.size() / 2);
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(MixedScenarioTest, HalfPacketThenOneAndHalf) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11};
std::vector<uint8_t> payload2 = {0x22, 0x33};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送半个包1
processor.OnReceive(packet1.data(), packet1.size() / 2);
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余包1 + 完整包2
std::vector<uint8_t> secondSend;
secondSend.insert(secondSend.end(), packet1.begin() + packet1.size() / 2, packet1.end());
secondSend.insert(secondSend.end(), packet2.begin(), packet2.end());
processor.OnReceive(secondSend.data(), secondSend.size());
ASSERT_EQ(receivedPackets.size(), 2u);
}
TEST_F(MixedScenarioTest, RandomChunkSizes) {
PacketProcessor processor(GetCallback());
// 准备多个包
std::vector<std::vector<uint8_t>> payloads;
std::vector<uint8_t> allData;
for (int i = 0; i < 10; ++i) {
std::vector<uint8_t> payload(i * 5 + 10, static_cast<uint8_t>(i));
payloads.push_back(payload);
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
// 使用"随机"大小的块发送
size_t chunkSizes[] = {1, 7, 15, 16, 17, 31, 32, 33, 64, 128};
size_t pos = 0;
size_t chunkIdx = 0;
while (pos < allData.size()) {
size_t chunkSize = chunkSizes[chunkIdx % (sizeof(chunkSizes) / sizeof(chunkSizes[0]))];
size_t len = std::min(chunkSize, allData.size() - pos);
processor.OnReceive(allData.data() + pos, len);
pos += len;
chunkIdx++;
}
ASSERT_EQ(receivedPackets.size(), payloads.size());
for (size_t i = 0; i < payloads.size(); ++i) {
EXPECT_EQ(receivedPackets[i], payloads[i]) << "Mismatch at packet " << i;
}
}
// ============================================
// 边界条件测试
// ============================================
class PacketBoundaryTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(PacketBoundaryTest, ExactlyHdrLength) {
PacketProcessor processor(GetCallback());
// 只有头部,无 payload
std::vector<uint8_t> packet(HDR_LENGTH);
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = 0x42;
hdr->flag[7] = ~0x42;
hdr->packedLength = HDR_LENGTH;
hdr->originalLength = 0;
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(PacketBoundaryTest, SingleBytePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0].size(), 1u);
EXPECT_EQ(receivedPackets[0][0], 0xFF);
}
TEST_F(PacketBoundaryTest, ByteByByteReceive) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x12, 0x34, 0x56};
auto packet = BuildPacket(payload);
// 每次接收 1 字节
for (size_t i = 0; i < packet.size(); ++i) {
processor.OnReceive(packet.data() + i, 1);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 数据完整性测试
// ============================================
class DataIntegrityTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(DataIntegrityTest, BinaryData) {
PacketProcessor processor(GetCallback());
// 包含所有字节值的 payload
std::vector<uint8_t> payload(256);
for (int i = 0; i < 256; ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(DataIntegrityTest, AllZeros) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0x00);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0x00);
}
}
TEST_F(DataIntegrityTest, AllOnes) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0xFF);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0xFF);
}
}
// ============================================
// 性能相关测试
// ============================================
class PacketPerformanceTest : public ::testing::Test {
protected:
size_t packetCount = 0;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
packetCount++;
};
}
};
TEST_F(PacketPerformanceTest, ManySmallPackets) {
PacketProcessor processor(GetCallback());
const int numPackets = 10000;
std::vector<uint8_t> allData;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload = {static_cast<uint8_t>(i & 0xFF)};
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
processor.OnReceive(allData.data(), allData.size());
EXPECT_EQ(packetCount, numPackets);
}
TEST_F(PacketPerformanceTest, LargePacketInSmallChunks) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100 * 1024); // 100 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
// 每次发送 1 KB
const size_t chunkSize = 1024;
for (size_t i = 0; i < packet.size(); i += chunkSize) {
size_t len = std::min(chunkSize, packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
EXPECT_EQ(packetCount, 1u);
}