Init: Migrate SimpleRemoter (Since v1.3.1) to Gitea
This commit is contained in:
378
test/unit/network/GeoLocationTest.cpp
Normal file
378
test/unit/network/GeoLocationTest.cpp
Normal file
@@ -0,0 +1,378 @@
|
||||
// GeoLocationTest.cpp - IP地理位置API单元测试
|
||||
// 测试 location.h 中的 GetGeoLocation 功能
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#include <wininet.h>
|
||||
#include <ws2tcpip.h>
|
||||
#pragma comment(lib, "wininet.lib")
|
||||
#pragma comment(lib, "ws2_32.lib")
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// ============================================
|
||||
// 简化的测试版本 - 不依赖jsoncpp
|
||||
// ============================================
|
||||
|
||||
// UTF-8 转 ANSI
|
||||
inline std::string Utf8ToAnsi(const std::string& utf8)
|
||||
{
|
||||
if (utf8.empty()) return "";
|
||||
int wideLen = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, NULL, 0);
|
||||
if (wideLen <= 0) return utf8;
|
||||
std::wstring wideStr(wideLen, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wideStr[0], wideLen);
|
||||
int ansiLen = WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, NULL, 0, NULL, NULL);
|
||||
if (ansiLen <= 0) return utf8;
|
||||
std::string ansiStr(ansiLen, 0);
|
||||
WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, &ansiStr[0], ansiLen, NULL, NULL);
|
||||
if (!ansiStr.empty() && ansiStr.back() == '\0') ansiStr.pop_back();
|
||||
return ansiStr;
|
||||
}
|
||||
|
||||
// 简单JSON值提取 (仅用于测试,不依赖jsoncpp)
|
||||
std::string ExtractJsonString(const std::string& json, const std::string& key)
|
||||
{
|
||||
std::string searchKey = "\"" + key + "\"";
|
||||
size_t keyPos = json.find(searchKey);
|
||||
if (keyPos == std::string::npos) return "";
|
||||
|
||||
size_t colonPos = json.find(':', keyPos);
|
||||
if (colonPos == std::string::npos) return "";
|
||||
|
||||
size_t valueStart = json.find_first_not_of(" \t\n\r", colonPos + 1);
|
||||
if (valueStart == std::string::npos) return "";
|
||||
|
||||
if (json[valueStart] == '"') {
|
||||
size_t valueEnd = json.find('"', valueStart + 1);
|
||||
if (valueEnd == std::string::npos) return "";
|
||||
return json.substr(valueStart + 1, valueEnd - valueStart - 1);
|
||||
}
|
||||
|
||||
// 数字或布尔值
|
||||
size_t valueEnd = json.find_first_of(",}\n\r", valueStart);
|
||||
if (valueEnd == std::string::npos) valueEnd = json.length();
|
||||
std::string value = json.substr(valueStart, valueEnd - valueStart);
|
||||
// 去除尾部空格
|
||||
while (!value.empty() && (value.back() == ' ' || value.back() == '\t')) {
|
||||
value.pop_back();
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// API配置结构
|
||||
struct GeoApiConfig {
|
||||
const char* name;
|
||||
const char* urlFmt;
|
||||
const char* cityField;
|
||||
const char* countryField;
|
||||
const char* checkField;
|
||||
const char* checkValue;
|
||||
bool useHttps;
|
||||
};
|
||||
|
||||
// 测试用API配置
|
||||
static const GeoApiConfig testApis[] = {
|
||||
{"ip-api.com", "http://ip-api.com/json/%s?fields=status,country,city", "city", "country", "status", "success", false},
|
||||
{"ipinfo.io", "http://ipinfo.io/%s/json", "city", "country", "", "", false},
|
||||
{"ipapi.co", "https://ipapi.co/%s/json/", "city", "country_name", "error", "", true},
|
||||
};
|
||||
|
||||
// 测试单个API
|
||||
struct ApiTestResult {
|
||||
bool success;
|
||||
int httpStatus;
|
||||
std::string city;
|
||||
std::string country;
|
||||
std::string error;
|
||||
double latencyMs;
|
||||
};
|
||||
|
||||
ApiTestResult TestSingleApi(const GeoApiConfig& api, const std::string& ip)
|
||||
{
|
||||
ApiTestResult result = {false, 0, "", "", "", 0};
|
||||
|
||||
DWORD startTime = GetTickCount();
|
||||
|
||||
HINTERNET hInternet = InternetOpenA("GeoLocationTest", INTERNET_OPEN_TYPE_DIRECT, NULL, NULL, 0);
|
||||
if (!hInternet) {
|
||||
result.error = "InternetOpen failed";
|
||||
return result;
|
||||
}
|
||||
|
||||
DWORD timeout = 10000;
|
||||
InternetSetOptionA(hInternet, INTERNET_OPTION_CONNECT_TIMEOUT, &timeout, sizeof(timeout));
|
||||
InternetSetOptionA(hInternet, INTERNET_OPTION_SEND_TIMEOUT, &timeout, sizeof(timeout));
|
||||
InternetSetOptionA(hInternet, INTERNET_OPTION_RECEIVE_TIMEOUT, &timeout, sizeof(timeout));
|
||||
|
||||
char urlBuf[256];
|
||||
sprintf_s(urlBuf, api.urlFmt, ip.c_str());
|
||||
|
||||
DWORD flags = INTERNET_FLAG_RELOAD;
|
||||
if (api.useHttps) flags |= INTERNET_FLAG_SECURE;
|
||||
|
||||
HINTERNET hConnect = InternetOpenUrlA(hInternet, urlBuf, NULL, 0, flags, 0);
|
||||
if (!hConnect) {
|
||||
result.error = "InternetOpenUrl failed: " + std::to_string(GetLastError());
|
||||
InternetCloseHandle(hInternet);
|
||||
return result;
|
||||
}
|
||||
|
||||
// 获取HTTP状态码
|
||||
DWORD statusCode = 0;
|
||||
DWORD statusSize = sizeof(statusCode);
|
||||
if (HttpQueryInfoA(hConnect, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &statusSize, NULL)) {
|
||||
result.httpStatus = statusCode;
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
std::string readBuffer;
|
||||
char buffer[4096];
|
||||
DWORD bytesRead;
|
||||
while (InternetReadFile(hConnect, buffer, sizeof(buffer), &bytesRead) && bytesRead > 0) {
|
||||
readBuffer.append(buffer, bytesRead);
|
||||
}
|
||||
|
||||
InternetCloseHandle(hConnect);
|
||||
InternetCloseHandle(hInternet);
|
||||
|
||||
result.latencyMs = (double)(GetTickCount() - startTime);
|
||||
|
||||
// 检查HTTP错误
|
||||
if (result.httpStatus >= 400) {
|
||||
result.error = "HTTP " + std::to_string(result.httpStatus);
|
||||
if (result.httpStatus == 429) {
|
||||
result.error += " (Rate Limited)";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// 检查响应体错误
|
||||
if (readBuffer.find("Rate limit") != std::string::npos ||
|
||||
readBuffer.find("rate limit") != std::string::npos) {
|
||||
result.error = "Rate limited (body)";
|
||||
return result;
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
if (api.checkField && api.checkField[0]) {
|
||||
std::string checkVal = ExtractJsonString(readBuffer, api.checkField);
|
||||
if (api.checkValue && api.checkValue[0]) {
|
||||
if (checkVal != api.checkValue) {
|
||||
result.error = "Check failed: " + std::string(api.checkField) + "=" + checkVal;
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
if (checkVal == "true") {
|
||||
result.error = "Error flag set";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.city = Utf8ToAnsi(ExtractJsonString(readBuffer, api.cityField));
|
||||
result.country = Utf8ToAnsi(ExtractJsonString(readBuffer, api.countryField));
|
||||
result.success = !result.city.empty() || !result.country.empty();
|
||||
|
||||
if (!result.success) {
|
||||
result.error = "No city/country in response";
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 测试用例
|
||||
// ============================================
|
||||
|
||||
class GeoLocationTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// 初始化Winsock
|
||||
WSADATA wsaData;
|
||||
WSAStartup(MAKEWORD(2, 2), &wsaData);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
WSACleanup();
|
||||
}
|
||||
};
|
||||
|
||||
// 测试公网IP (Google DNS)
|
||||
TEST_F(GeoLocationTest, TestPublicIP_GoogleDNS) {
|
||||
std::string testIP = "8.8.8.8";
|
||||
|
||||
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
|
||||
|
||||
int successCount = 0;
|
||||
for (const auto& api : testApis) {
|
||||
auto result = TestSingleApi(api, testIP);
|
||||
|
||||
std::cout << "[" << api.name << "] ";
|
||||
if (result.success) {
|
||||
std::cout << "OK - " << result.city << ", " << result.country;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
successCount++;
|
||||
} else {
|
||||
std::cout << "FAIL - " << result.error;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
}
|
||||
}
|
||||
|
||||
// 至少一个API应该成功
|
||||
EXPECT_GE(successCount, 1) << "At least one API should succeed";
|
||||
}
|
||||
|
||||
// 测试另一个公网IP (Cloudflare DNS)
|
||||
TEST_F(GeoLocationTest, TestPublicIP_CloudflareDNS) {
|
||||
std::string testIP = "1.1.1.1";
|
||||
|
||||
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
|
||||
|
||||
int successCount = 0;
|
||||
for (const auto& api : testApis) {
|
||||
auto result = TestSingleApi(api, testIP);
|
||||
|
||||
std::cout << "[" << api.name << "] ";
|
||||
if (result.success) {
|
||||
std::cout << "OK - " << result.city << ", " << result.country;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
successCount++;
|
||||
} else {
|
||||
std::cout << "FAIL - " << result.error;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_GE(successCount, 1) << "At least one API should succeed";
|
||||
}
|
||||
|
||||
// 测试中国IP
|
||||
TEST_F(GeoLocationTest, TestPublicIP_ChinaIP) {
|
||||
std::string testIP = "114.114.114.114"; // 114DNS
|
||||
|
||||
std::cout << "\n=== Testing IP: " << testIP << " (China) ===\n";
|
||||
|
||||
int successCount = 0;
|
||||
for (const auto& api : testApis) {
|
||||
auto result = TestSingleApi(api, testIP);
|
||||
|
||||
std::cout << "[" << api.name << "] ";
|
||||
if (result.success) {
|
||||
std::cout << "OK - " << result.city << ", " << result.country;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
successCount++;
|
||||
} else {
|
||||
std::cout << "FAIL - " << result.error;
|
||||
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_GE(successCount, 1) << "At least one API should succeed";
|
||||
}
|
||||
|
||||
// 测试ip-api.com单独
|
||||
TEST_F(GeoLocationTest, TestIpApiCom) {
|
||||
std::string testIP = "8.8.8.8";
|
||||
auto result = TestSingleApi(testApis[0], testIP);
|
||||
|
||||
std::cout << "\n[ip-api.com] HTTP: " << result.httpStatus
|
||||
<< ", Latency: " << result.latencyMs << "ms\n";
|
||||
|
||||
if (result.success) {
|
||||
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
|
||||
EXPECT_FALSE(result.country.empty());
|
||||
} else {
|
||||
std::cout << "Error: " << result.error << "\n";
|
||||
// 如果是429就跳过,不算失败
|
||||
if (result.httpStatus == 429) {
|
||||
GTEST_SKIP() << "Rate limited, skipping";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 测试ipinfo.io单独
|
||||
TEST_F(GeoLocationTest, TestIpInfoIo) {
|
||||
std::string testIP = "8.8.8.8";
|
||||
auto result = TestSingleApi(testApis[1], testIP);
|
||||
|
||||
std::cout << "\n[ipinfo.io] HTTP: " << result.httpStatus
|
||||
<< ", Latency: " << result.latencyMs << "ms\n";
|
||||
|
||||
if (result.success) {
|
||||
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
|
||||
EXPECT_FALSE(result.country.empty());
|
||||
} else {
|
||||
std::cout << "Error: " << result.error << "\n";
|
||||
if (result.httpStatus == 429) {
|
||||
GTEST_SKIP() << "Rate limited, skipping";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 测试ipapi.co单独
|
||||
TEST_F(GeoLocationTest, TestIpApiCo) {
|
||||
std::string testIP = "8.8.8.8";
|
||||
auto result = TestSingleApi(testApis[2], testIP);
|
||||
|
||||
std::cout << "\n[ipapi.co] HTTP: " << result.httpStatus
|
||||
<< ", Latency: " << result.latencyMs << "ms\n";
|
||||
|
||||
if (result.success) {
|
||||
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
|
||||
EXPECT_FALSE(result.country.empty());
|
||||
} else {
|
||||
std::cout << "Error: " << result.error << "\n";
|
||||
if (result.httpStatus == 429) {
|
||||
GTEST_SKIP() << "Rate limited, skipping";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 测试HTTP状态码检测
|
||||
TEST_F(GeoLocationTest, TestHttpStatusCodeDetection) {
|
||||
// 使用一个会返回错误的请求
|
||||
std::string testIP = "invalid-ip";
|
||||
|
||||
for (const auto& api : testApis) {
|
||||
auto result = TestSingleApi(api, testIP);
|
||||
std::cout << "[" << api.name << "] Invalid IP test: HTTP "
|
||||
<< result.httpStatus << ", Error: " << result.error << "\n";
|
||||
|
||||
// 应该失败
|
||||
EXPECT_FALSE(result.success);
|
||||
}
|
||||
}
|
||||
|
||||
// 测试所有API的响应时间
|
||||
TEST_F(GeoLocationTest, TestApiLatency) {
|
||||
std::string testIP = "8.8.8.8";
|
||||
|
||||
std::cout << "\n=== API Latency Test ===\n";
|
||||
|
||||
for (const auto& api : testApis) {
|
||||
auto result = TestSingleApi(api, testIP);
|
||||
std::cout << "[" << api.name << "] " << result.latencyMs << "ms";
|
||||
if (result.success) {
|
||||
std::cout << " (OK)";
|
||||
} else {
|
||||
std::cout << " (FAIL: " << result.error << ")";
|
||||
}
|
||||
std::cout << "\n";
|
||||
|
||||
// 响应时间应该在合理范围内 (< 15秒)
|
||||
EXPECT_LT(result.latencyMs, 15000);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
// 非Windows平台 - 跳过测试
|
||||
TEST(GeoLocationTest, SkipNonWindows) {
|
||||
GTEST_SKIP() << "GeoLocation tests only run on Windows";
|
||||
}
|
||||
#endif
|
||||
|
||||
642
test/unit/network/HeaderTest.cpp
Normal file
642
test/unit/network/HeaderTest.cpp
Normal file
@@ -0,0 +1,642 @@
|
||||
/**
|
||||
* @file HeaderTest.cpp
|
||||
* @brief 协议头验证和加密测试
|
||||
*
|
||||
* 测试覆盖:
|
||||
* - 协议头格式和常量
|
||||
* - 加密/解密函数正确性
|
||||
* - 多版本头部识别
|
||||
* - 头部生成和验证往返
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
// ============================================
|
||||
// 协议常量定义(测试专用副本)
|
||||
// ============================================
|
||||
|
||||
#define MSG_HEADER "HELL"
|
||||
const int FLAG_COMPLEN = 4;
|
||||
const int FLAG_LENGTH = 8;
|
||||
const int HDR_LENGTH = FLAG_LENGTH + 2 * sizeof(unsigned int); // 16
|
||||
const int MIN_COMLEN = 12;
|
||||
|
||||
enum HeaderEncType {
|
||||
HeaderEncUnknown = -1,
|
||||
HeaderEncNone,
|
||||
HeaderEncV0,
|
||||
HeaderEncV1,
|
||||
HeaderEncV2,
|
||||
HeaderEncV3,
|
||||
HeaderEncV4,
|
||||
HeaderEncV5,
|
||||
HeaderEncV6,
|
||||
HeaderEncNum,
|
||||
};
|
||||
|
||||
enum FlagType {
|
||||
FLAG_WINOS = -1,
|
||||
FLAG_UNKNOWN = 0,
|
||||
FLAG_SHINE = 1,
|
||||
FLAG_FUCK = 2,
|
||||
FLAG_HELLO = 3,
|
||||
FLAG_HELL = 4,
|
||||
};
|
||||
|
||||
// ============================================
|
||||
// 加密/解密函数(测试专用副本)
|
||||
// ============================================
|
||||
|
||||
inline void default_encrypt(unsigned char* data, size_t length, unsigned char key)
|
||||
{
|
||||
data[FLAG_LENGTH - 2] = data[FLAG_LENGTH - 1] = 0;
|
||||
}
|
||||
|
||||
inline void default_decrypt(unsigned char* data, size_t length, unsigned char key)
|
||||
{
|
||||
}
|
||||
|
||||
inline void encrypt(unsigned char* data, size_t length, unsigned char key)
|
||||
{
|
||||
if (key == 0) return;
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
|
||||
int value = static_cast<int>(data[i]);
|
||||
switch (i % 4) {
|
||||
case 0:
|
||||
value += k;
|
||||
break;
|
||||
case 1:
|
||||
value = value ^ k;
|
||||
break;
|
||||
case 2:
|
||||
value -= k;
|
||||
break;
|
||||
case 3:
|
||||
value = ~(value ^ k);
|
||||
break;
|
||||
}
|
||||
data[i] = static_cast<unsigned char>(value & 0xFF);
|
||||
}
|
||||
}
|
||||
|
||||
inline void decrypt(unsigned char* data, size_t length, unsigned char key)
|
||||
{
|
||||
if (key == 0) return;
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
|
||||
int value = static_cast<int>(data[i]);
|
||||
switch (i % 4) {
|
||||
case 0:
|
||||
value -= k;
|
||||
break;
|
||||
case 1:
|
||||
value = value ^ k;
|
||||
break;
|
||||
case 2:
|
||||
value += k;
|
||||
break;
|
||||
case 3:
|
||||
value = ~(value) ^ k;
|
||||
break;
|
||||
}
|
||||
data[i] = static_cast<unsigned char>(value & 0xFF);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// HeaderFlag 结构体
|
||||
// ============================================
|
||||
|
||||
typedef struct HeaderFlag {
|
||||
char Data[FLAG_LENGTH + 1];
|
||||
HeaderFlag(const char header[FLAG_LENGTH + 1])
|
||||
{
|
||||
memcpy(Data, header, sizeof(Data));
|
||||
}
|
||||
char& operator[](int i)
|
||||
{
|
||||
return Data[i];
|
||||
}
|
||||
const char operator[](int i) const
|
||||
{
|
||||
return Data[i];
|
||||
}
|
||||
const char* data() const
|
||||
{
|
||||
return Data;
|
||||
}
|
||||
} HeaderFlag;
|
||||
|
||||
typedef void (*EncFun)(unsigned char* data, size_t length, unsigned char key);
|
||||
typedef void (*DecFun)(unsigned char* data, size_t length, unsigned char key);
|
||||
|
||||
// ============================================
|
||||
// 头部生成函数
|
||||
// ============================================
|
||||
|
||||
inline HeaderFlag GetHead(EncFun enc, unsigned char fixedKey = 0)
|
||||
{
|
||||
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
|
||||
HeaderFlag H(header);
|
||||
unsigned char key = (fixedKey != 0) ? fixedKey : (time(0) % 256);
|
||||
H[FLAG_LENGTH - 2] = key;
|
||||
H[FLAG_LENGTH - 1] = ~key;
|
||||
enc((unsigned char*)H.data(), FLAG_COMPLEN, H[FLAG_LENGTH - 2]);
|
||||
return H;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 头部验证函数
|
||||
// ============================================
|
||||
|
||||
inline int compare(const char *flag, const char *magic, int len, DecFun dec, unsigned char key)
|
||||
{
|
||||
unsigned char buf[32] = {};
|
||||
memcpy(buf, flag, MIN_COMLEN);
|
||||
dec(buf, len, key);
|
||||
if (memcmp(buf, magic, len) == 0) {
|
||||
memcpy((void*)flag, buf, MIN_COMLEN);
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
inline FlagType CheckHead(const char* flag, DecFun dec)
|
||||
{
|
||||
FlagType type = FLAG_UNKNOWN;
|
||||
if (compare(flag, MSG_HEADER, FLAG_COMPLEN, dec, flag[6]) == 0) {
|
||||
type = FLAG_HELL;
|
||||
} else if (compare(flag, "Shine", 5, dec, 0) == 0) {
|
||||
type = FLAG_SHINE;
|
||||
} else if (compare(flag, "<<FUCK>>", 8, dec, flag[9]) == 0) {
|
||||
type = FLAG_FUCK;
|
||||
} else if (compare(flag, "Hello?", 6, dec, flag[6]) == 0) {
|
||||
type = FLAG_HELLO;
|
||||
} else {
|
||||
type = FLAG_UNKNOWN;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
inline FlagType CheckHeadMulti(char* flag, HeaderEncType& funcHit)
|
||||
{
|
||||
static const DecFun methods[] = { default_decrypt, decrypt };
|
||||
static const int methodNum = sizeof(methods) / sizeof(DecFun);
|
||||
char buffer[MIN_COMLEN + 4] = {};
|
||||
for (int i = 0; i < methodNum; ++i) {
|
||||
memcpy(buffer, flag, MIN_COMLEN);
|
||||
FlagType type = CheckHead(buffer, methods[i]);
|
||||
if (type != FLAG_UNKNOWN) {
|
||||
memcpy(flag, buffer, MIN_COMLEN);
|
||||
funcHit = HeaderEncType(i);
|
||||
return type;
|
||||
}
|
||||
}
|
||||
funcHit = HeaderEncUnknown;
|
||||
return FLAG_UNKNOWN;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 协议常量测试
|
||||
// ============================================
|
||||
|
||||
class HeaderConstantsTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HeaderConstantsTest, FlagLength) {
|
||||
EXPECT_EQ(FLAG_LENGTH, 8);
|
||||
}
|
||||
|
||||
TEST_F(HeaderConstantsTest, FlagCompLen) {
|
||||
EXPECT_EQ(FLAG_COMPLEN, 4);
|
||||
}
|
||||
|
||||
TEST_F(HeaderConstantsTest, HdrLength) {
|
||||
// FLAG_LENGTH(8) + 2 * sizeof(uint32_t)(8) = 16
|
||||
EXPECT_EQ(HDR_LENGTH, 16);
|
||||
}
|
||||
|
||||
TEST_F(HeaderConstantsTest, MinComLen) {
|
||||
EXPECT_EQ(MIN_COMLEN, 12);
|
||||
}
|
||||
|
||||
TEST_F(HeaderConstantsTest, MsgHeader) {
|
||||
EXPECT_STREQ(MSG_HEADER, "HELL");
|
||||
EXPECT_EQ(strlen(MSG_HEADER), 4u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 加密/解密测试
|
||||
// ============================================
|
||||
|
||||
class EncryptionTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(EncryptionTest, ZeroKeyNoOp) {
|
||||
unsigned char data[] = {0x01, 0x02, 0x03, 0x04, 0x05};
|
||||
unsigned char original[] = {0x01, 0x02, 0x03, 0x04, 0x05};
|
||||
|
||||
encrypt(data, 5, 0);
|
||||
EXPECT_EQ(memcmp(data, original, 5), 0);
|
||||
|
||||
decrypt(data, 5, 0);
|
||||
EXPECT_EQ(memcmp(data, original, 5), 0);
|
||||
}
|
||||
|
||||
TEST_F(EncryptionTest, EncryptDecryptRoundTrip) {
|
||||
unsigned char original[] = {0x48, 0x45, 0x4C, 0x4C}; // "HELL"
|
||||
unsigned char data[4];
|
||||
memcpy(data, original, 4);
|
||||
|
||||
unsigned char key = 0x42;
|
||||
|
||||
encrypt(data, 4, key);
|
||||
// 加密后应该不同
|
||||
EXPECT_NE(memcmp(data, original, 4), 0);
|
||||
|
||||
decrypt(data, 4, key);
|
||||
// 解密后应该恢复
|
||||
EXPECT_EQ(memcmp(data, original, 4), 0);
|
||||
}
|
||||
|
||||
TEST_F(EncryptionTest, DifferentKeysProduceDifferentResults) {
|
||||
unsigned char data1[] = {0x48, 0x45, 0x4C, 0x4C};
|
||||
unsigned char data2[] = {0x48, 0x45, 0x4C, 0x4C};
|
||||
|
||||
encrypt(data1, 4, 0x10);
|
||||
encrypt(data2, 4, 0x20);
|
||||
|
||||
EXPECT_NE(memcmp(data1, data2, 4), 0);
|
||||
}
|
||||
|
||||
TEST_F(EncryptionTest, PositionDependentEncryption) {
|
||||
// 相同值在不同位置加密结果不同
|
||||
unsigned char data[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
|
||||
unsigned char original[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
|
||||
|
||||
encrypt(data, 8, 0x55);
|
||||
|
||||
// 验证加密后值不全相同(位置相关加密)
|
||||
bool allSame = true;
|
||||
for (int i = 1; i < 8; ++i) {
|
||||
if (data[i] != data[0]) {
|
||||
allSame = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_FALSE(allSame);
|
||||
|
||||
// 验证解密恢复
|
||||
decrypt(data, 8, 0x55);
|
||||
EXPECT_EQ(memcmp(data, original, 8), 0);
|
||||
}
|
||||
|
||||
TEST_F(EncryptionTest, AllKeyValues) {
|
||||
// 测试所有可能的密钥值
|
||||
unsigned char original[] = {0x12, 0x34, 0x56, 0x78};
|
||||
|
||||
for (int key = 1; key < 256; ++key) {
|
||||
unsigned char data[4];
|
||||
memcpy(data, original, 4);
|
||||
|
||||
encrypt(data, 4, static_cast<unsigned char>(key));
|
||||
decrypt(data, 4, static_cast<unsigned char>(key));
|
||||
|
||||
EXPECT_EQ(memcmp(data, original, 4), 0) << "Failed for key: " << key;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(EncryptionTest, LargeDataRoundTrip) {
|
||||
std::vector<unsigned char> original(1000);
|
||||
for (size_t i = 0; i < original.size(); ++i) {
|
||||
original[i] = static_cast<unsigned char>(i & 0xFF);
|
||||
}
|
||||
|
||||
std::vector<unsigned char> data = original;
|
||||
unsigned char key = 0x7F;
|
||||
|
||||
encrypt(data.data(), data.size(), key);
|
||||
decrypt(data.data(), data.size(), key);
|
||||
|
||||
EXPECT_EQ(data, original);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// HeaderFlag 测试
|
||||
// ============================================
|
||||
|
||||
class HeaderFlagTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HeaderFlagTest, Construction) {
|
||||
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
|
||||
HeaderFlag hf(header);
|
||||
|
||||
EXPECT_EQ(hf[0], 'H');
|
||||
EXPECT_EQ(hf[1], 'E');
|
||||
EXPECT_EQ(hf[2], 'L');
|
||||
EXPECT_EQ(hf[3], 'L');
|
||||
}
|
||||
|
||||
TEST_F(HeaderFlagTest, DataAccess) {
|
||||
char header[FLAG_LENGTH + 1] = { 'T','E','S','T', 0x12, 0x34, 0x56, 0x78, 0 };
|
||||
HeaderFlag hf(header);
|
||||
|
||||
EXPECT_EQ(memcmp(hf.data(), "TEST", 4), 0);
|
||||
EXPECT_EQ(static_cast<unsigned char>(hf[4]), 0x12);
|
||||
EXPECT_EQ(static_cast<unsigned char>(hf[5]), 0x34);
|
||||
}
|
||||
|
||||
TEST_F(HeaderFlagTest, Modification) {
|
||||
char header[FLAG_LENGTH + 1] = { 0 };
|
||||
HeaderFlag hf(header);
|
||||
|
||||
hf[0] = 'A';
|
||||
hf[1] = 'B';
|
||||
|
||||
EXPECT_EQ(hf[0], 'A');
|
||||
EXPECT_EQ(hf[1], 'B');
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// GetHead 测试
|
||||
// ============================================
|
||||
|
||||
class GetHeadTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(GetHeadTest, GeneratesValidHeader) {
|
||||
HeaderFlag hf = GetHead(default_encrypt, 0x42);
|
||||
|
||||
// 检查基础格式
|
||||
EXPECT_EQ(hf[0], 'H');
|
||||
EXPECT_EQ(hf[1], 'E');
|
||||
EXPECT_EQ(hf[2], 'L');
|
||||
EXPECT_EQ(hf[3], 'L');
|
||||
}
|
||||
|
||||
TEST_F(GetHeadTest, KeyAndInverseKey) {
|
||||
unsigned char key = 0x42;
|
||||
HeaderFlag hf = GetHead(default_encrypt, key);
|
||||
|
||||
// 使用 default_encrypt 时,key 位置被清零
|
||||
// 但我们需要验证生成逻辑
|
||||
char rawHeader[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
|
||||
HeaderFlag expected(rawHeader);
|
||||
expected[FLAG_LENGTH - 2] = key;
|
||||
expected[FLAG_LENGTH - 1] = ~key;
|
||||
default_encrypt((unsigned char*)expected.data(), FLAG_COMPLEN, expected[FLAG_LENGTH - 2]);
|
||||
|
||||
EXPECT_EQ(memcmp(hf.data(), expected.data(), FLAG_LENGTH), 0);
|
||||
}
|
||||
|
||||
TEST_F(GetHeadTest, EncryptedHeader) {
|
||||
unsigned char key = 0x55;
|
||||
HeaderFlag hf = GetHead(encrypt, key);
|
||||
|
||||
// 头部应该被加密,不再是明文 "HELL"
|
||||
EXPECT_NE(memcmp(hf.data(), "HELL", 4), 0);
|
||||
|
||||
// 密钥应该在正确位置
|
||||
unsigned char storedKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 2]);
|
||||
unsigned char inverseKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 1]);
|
||||
EXPECT_EQ(storedKey, key);
|
||||
EXPECT_EQ(inverseKey, static_cast<unsigned char>(~key));
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// CheckHead 测试
|
||||
// ============================================
|
||||
|
||||
class CheckHeadTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(CheckHeadTest, IdentifyHellFlag) {
|
||||
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
|
||||
|
||||
FlagType type = CheckHead(header, default_decrypt);
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
}
|
||||
|
||||
TEST_F(CheckHeadTest, IdentifyShineFlag) {
|
||||
char header[MIN_COMLEN + 4] = { 'S','h','i','n','e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
||||
|
||||
FlagType type = CheckHead(header, default_decrypt);
|
||||
EXPECT_EQ(type, FLAG_SHINE);
|
||||
}
|
||||
|
||||
TEST_F(CheckHeadTest, UnknownFlag) {
|
||||
char header[MIN_COMLEN + 4] = { 'X','Y','Z','W', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
||||
|
||||
FlagType type = CheckHead(header, default_decrypt);
|
||||
EXPECT_EQ(type, FLAG_UNKNOWN);
|
||||
}
|
||||
|
||||
TEST_F(CheckHeadTest, EncryptedHellFlag) {
|
||||
// 生成加密的头部
|
||||
unsigned char key = 0x33;
|
||||
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
|
||||
header[FLAG_LENGTH - 2] = key;
|
||||
header[FLAG_LENGTH - 1] = ~key;
|
||||
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
|
||||
|
||||
char buffer[MIN_COMLEN + 4] = {};
|
||||
memcpy(buffer, header, FLAG_LENGTH);
|
||||
|
||||
FlagType type = CheckHead(buffer, decrypt);
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// CheckHeadMulti 测试
|
||||
// ============================================
|
||||
|
||||
class CheckHeadMultiTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(CheckHeadMultiTest, PlainTextHeader) {
|
||||
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(header, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
EXPECT_EQ(encType, HeaderEncNone);
|
||||
}
|
||||
|
||||
TEST_F(CheckHeadMultiTest, EncryptedHeader) {
|
||||
unsigned char key = 0x77;
|
||||
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
||||
header[FLAG_LENGTH - 2] = key;
|
||||
header[FLAG_LENGTH - 1] = ~key;
|
||||
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(header, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
EXPECT_EQ(encType, HeaderEncV0); // encrypt 对应 V0
|
||||
}
|
||||
|
||||
TEST_F(CheckHeadMultiTest, UnrecognizedHeader) {
|
||||
char header[MIN_COMLEN + 4] = { 0xFF, 0xFE, 0xFD, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(header, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_UNKNOWN);
|
||||
EXPECT_EQ(encType, HeaderEncUnknown);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 头部生成和验证往返测试
|
||||
// ============================================
|
||||
|
||||
class HeaderRoundTripTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HeaderRoundTripTest, DefaultEncryptRoundTrip) {
|
||||
HeaderFlag hf = GetHead(default_encrypt, 0x42);
|
||||
|
||||
char buffer[MIN_COMLEN + 4] = {};
|
||||
memcpy(buffer, hf.data(), FLAG_LENGTH);
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(buffer, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
}
|
||||
|
||||
TEST_F(HeaderRoundTripTest, EncryptRoundTrip) {
|
||||
HeaderFlag hf = GetHead(encrypt, 0x88);
|
||||
|
||||
char buffer[MIN_COMLEN + 4] = {};
|
||||
memcpy(buffer, hf.data(), FLAG_LENGTH);
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(buffer, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_HELL);
|
||||
}
|
||||
|
||||
TEST_F(HeaderRoundTripTest, AllKeyValuesRoundTrip) {
|
||||
for (int key = 1; key < 256; ++key) {
|
||||
HeaderFlag hf = GetHead(encrypt, static_cast<unsigned char>(key));
|
||||
|
||||
char buffer[MIN_COMLEN + 4] = {};
|
||||
memcpy(buffer, hf.data(), FLAG_LENGTH);
|
||||
|
||||
HeaderEncType encType;
|
||||
FlagType type = CheckHeadMulti(buffer, encType);
|
||||
|
||||
EXPECT_EQ(type, FLAG_HELL) << "Failed for key: " << key;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 数据包长度字段测试
|
||||
// ============================================
|
||||
|
||||
class PacketLengthTest : public ::testing::Test {};
|
||||
|
||||
#pragma pack(push, 1)
|
||||
struct PacketHeader {
|
||||
char flag[FLAG_LENGTH];
|
||||
uint32_t packedLength; // 压缩后长度
|
||||
uint32_t originalLength; // 原始长度
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
TEST_F(PacketLengthTest, HeaderSize) {
|
||||
EXPECT_EQ(sizeof(PacketHeader), HDR_LENGTH);
|
||||
}
|
||||
|
||||
TEST_F(PacketLengthTest, BuildAndParsePacket) {
|
||||
PacketHeader pkt = {};
|
||||
memcpy(pkt.flag, "HELL", 4);
|
||||
pkt.flag[6] = 0x42;
|
||||
pkt.flag[7] = ~0x42;
|
||||
pkt.packedLength = 100;
|
||||
pkt.originalLength = 200;
|
||||
|
||||
// 解析
|
||||
uint32_t packedLen, origLen;
|
||||
memcpy(&packedLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH, sizeof(uint32_t));
|
||||
memcpy(&origLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH + sizeof(uint32_t), sizeof(uint32_t));
|
||||
|
||||
EXPECT_EQ(packedLen, 100u);
|
||||
EXPECT_EQ(origLen, 200u);
|
||||
}
|
||||
|
||||
TEST_F(PacketLengthTest, TotalPacketLength) {
|
||||
// 总包长度 = HDR_LENGTH + 压缩数据长度
|
||||
uint32_t dataLen = 1000;
|
||||
uint32_t totalLen = HDR_LENGTH + dataLen;
|
||||
|
||||
EXPECT_EQ(totalLen, 1016u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 边界条件测试
|
||||
// ============================================
|
||||
|
||||
class HeaderBoundaryTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HeaderBoundaryTest, MinimalPacket) {
|
||||
// 最小合法包:只有头部
|
||||
std::vector<uint8_t> packet(HDR_LENGTH);
|
||||
memcpy(packet.data(), "HELL", 4);
|
||||
packet[6] = 0x42;
|
||||
packet[7] = ~0x42;
|
||||
// packedLength = HDR_LENGTH
|
||||
uint32_t packedLen = HDR_LENGTH;
|
||||
memcpy(packet.data() + FLAG_LENGTH, &packedLen, sizeof(uint32_t));
|
||||
// originalLength = 0
|
||||
uint32_t origLen = 0;
|
||||
memcpy(packet.data() + FLAG_LENGTH + sizeof(uint32_t), &origLen, sizeof(uint32_t));
|
||||
|
||||
EXPECT_EQ(packet.size(), HDR_LENGTH);
|
||||
}
|
||||
|
||||
TEST_F(HeaderBoundaryTest, MaxPacketLength) {
|
||||
// 验证大包长度字段
|
||||
uint32_t maxDataLen = 100 * 1024 * 1024; // 100 MB
|
||||
uint32_t totalLen = HDR_LENGTH + maxDataLen;
|
||||
|
||||
PacketHeader pkt = {};
|
||||
pkt.packedLength = totalLen;
|
||||
pkt.originalLength = maxDataLen * 2; // 压缩前更大
|
||||
|
||||
EXPECT_EQ(pkt.packedLength, totalLen);
|
||||
EXPECT_EQ(pkt.originalLength, maxDataLen * 2);
|
||||
}
|
||||
|
||||
TEST_F(HeaderBoundaryTest, TruncatedHeader) {
|
||||
// 不完整的头部不应该被识别
|
||||
char truncated[4] = { 'H', 'E', 'L', 'L' };
|
||||
|
||||
// 不足以进行完整验证
|
||||
// 这里只是验证常量定义正确
|
||||
EXPECT_LT(sizeof(truncated), static_cast<size_t>(MIN_COMLEN));
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// FlagType 枚举测试
|
||||
// ============================================
|
||||
|
||||
class FlagTypeTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(FlagTypeTest, EnumValues) {
|
||||
EXPECT_EQ(FLAG_WINOS, -1);
|
||||
EXPECT_EQ(FLAG_UNKNOWN, 0);
|
||||
EXPECT_EQ(FLAG_SHINE, 1);
|
||||
EXPECT_EQ(FLAG_FUCK, 2);
|
||||
EXPECT_EQ(FLAG_HELLO, 3);
|
||||
EXPECT_EQ(FLAG_HELL, 4);
|
||||
}
|
||||
|
||||
TEST_F(FlagTypeTest, HeaderEncTypeValues) {
|
||||
EXPECT_EQ(HeaderEncUnknown, -1);
|
||||
EXPECT_EQ(HeaderEncNone, 0);
|
||||
EXPECT_EQ(HeaderEncV0, 1);
|
||||
EXPECT_EQ(HeaderEncNum, 8);
|
||||
}
|
||||
|
||||
534
test/unit/network/HttpMaskTest.cpp
Normal file
534
test/unit/network/HttpMaskTest.cpp
Normal file
@@ -0,0 +1,534 @@
|
||||
/**
|
||||
* @file HttpMaskTest.cpp
|
||||
* @brief HTTP 协议伪装测试
|
||||
*
|
||||
* 测试覆盖:
|
||||
* - HTTP 请求格式生成
|
||||
* - HTTP 头部解析和移除
|
||||
* - Mask/UnMask 往返测试
|
||||
* - 边界条件处理
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
// ============================================
|
||||
// 类型定义
|
||||
// ============================================
|
||||
|
||||
#ifdef _WIN32
|
||||
typedef unsigned long ULONG;
|
||||
#else
|
||||
typedef uint32_t ULONG;
|
||||
#endif
|
||||
|
||||
// ============================================
|
||||
// 协议伪装类型
|
||||
// ============================================
|
||||
|
||||
enum PkgMaskType {
|
||||
MaskTypeUnknown = -1,
|
||||
MaskTypeNone,
|
||||
MaskTypeHTTP,
|
||||
MaskTypeNum,
|
||||
};
|
||||
|
||||
// ============================================
|
||||
// HTTP 解除伪装函数
|
||||
// ============================================
|
||||
|
||||
inline ULONG UnMaskHttp(const char* src, ULONG srcSize)
|
||||
{
|
||||
const char* header_end_mark = "\r\n\r\n";
|
||||
const ULONG mark_len = 4;
|
||||
|
||||
for (ULONG i = 0; i + mark_len <= srcSize; ++i) {
|
||||
if (memcmp(src + i, header_end_mark, mark_len) == 0) {
|
||||
return i + mark_len;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline ULONG TryUnMask(const char* src, ULONG srcSize, PkgMaskType& maskHit)
|
||||
{
|
||||
if (srcSize >= 5 && memcmp(src, "POST ", 5) == 0) {
|
||||
maskHit = MaskTypeHTTP;
|
||||
return UnMaskHttp(src, srcSize);
|
||||
}
|
||||
if (srcSize >= 4 && memcmp(src, "GET ", 4) == 0) {
|
||||
maskHit = MaskTypeHTTP;
|
||||
return UnMaskHttp(src, srcSize);
|
||||
}
|
||||
maskHit = MaskTypeNone;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// HTTP Mask 类(简化版)
|
||||
// ============================================
|
||||
|
||||
class HttpMask {
|
||||
public:
|
||||
explicit HttpMask(const std::string& host = "example.com",
|
||||
const std::map<std::string, std::string>& headers = {})
|
||||
: host_(host)
|
||||
{
|
||||
for (const auto& kv : headers) {
|
||||
customHeaders_ += kv.first + ": " + kv.second + "\r\n";
|
||||
}
|
||||
}
|
||||
|
||||
void Mask(char*& dst, ULONG& dstSize, const char* src, ULONG srcSize, int cmd = -1)
|
||||
{
|
||||
std::string path = "/api/v1/" + std::to_string(cmd == -1 ? 0 : cmd);
|
||||
|
||||
std::string httpHeader =
|
||||
"POST " + path + " HTTP/1.1\r\n"
|
||||
"Host: " + host_ + "\r\n"
|
||||
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64)\r\n"
|
||||
"Content-Type: application/octet-stream\r\n"
|
||||
"Content-Length: " + std::to_string(srcSize) + "\r\n" + customHeaders_ +
|
||||
"Connection: keep-alive\r\n"
|
||||
"\r\n";
|
||||
|
||||
dstSize = static_cast<ULONG>(httpHeader.size()) + srcSize;
|
||||
dst = new char[dstSize];
|
||||
|
||||
memcpy(dst, httpHeader.data(), httpHeader.size());
|
||||
if (srcSize > 0) {
|
||||
memcpy(dst + httpHeader.size(), src, srcSize);
|
||||
}
|
||||
}
|
||||
|
||||
ULONG UnMask(const char* src, ULONG srcSize)
|
||||
{
|
||||
return UnMaskHttp(src, srcSize);
|
||||
}
|
||||
|
||||
void SetHost(const std::string& host)
|
||||
{
|
||||
host_ = host;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string host_;
|
||||
std::string customHeaders_;
|
||||
};
|
||||
|
||||
// ============================================
|
||||
// UnMaskHttp 测试
|
||||
// ============================================
|
||||
|
||||
class UnMaskHttpTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(UnMaskHttpTest, ValidHttpRequest) {
|
||||
std::string httpRequest =
|
||||
"POST /api HTTP/1.1\r\n"
|
||||
"Host: example.com\r\n"
|
||||
"Content-Length: 4\r\n"
|
||||
"\r\n"
|
||||
"DATA";
|
||||
|
||||
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
|
||||
|
||||
// 应该返回 body 起始位置
|
||||
EXPECT_GT(offset, 0u);
|
||||
EXPECT_STREQ(httpRequest.data() + offset, "DATA");
|
||||
}
|
||||
|
||||
TEST_F(UnMaskHttpTest, NoHeaderEnd) {
|
||||
std::string incomplete = "POST /api HTTP/1.1\r\nHost: example.com\r\n";
|
||||
|
||||
ULONG offset = UnMaskHttp(incomplete.data(), static_cast<ULONG>(incomplete.size()));
|
||||
EXPECT_EQ(offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(UnMaskHttpTest, EmptyBody) {
|
||||
std::string httpRequest =
|
||||
"POST /api HTTP/1.1\r\n"
|
||||
"Content-Length: 0\r\n"
|
||||
"\r\n";
|
||||
|
||||
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
|
||||
EXPECT_EQ(offset, static_cast<ULONG>(httpRequest.size()));
|
||||
}
|
||||
|
||||
TEST_F(UnMaskHttpTest, MultipleHeaderEndMarkers) {
|
||||
std::string httpRequest =
|
||||
"POST /api HTTP/1.1\r\n"
|
||||
"Content-Length: 8\r\n"
|
||||
"\r\n"
|
||||
"\r\n\r\nXX"; // body 中也有 \r\n\r\n
|
||||
|
||||
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
|
||||
|
||||
// 应该返回第一个 \r\n\r\n 之后
|
||||
std::string body(httpRequest.data() + offset);
|
||||
EXPECT_EQ(body, "\r\n\r\nXX");
|
||||
}
|
||||
|
||||
TEST_F(UnMaskHttpTest, MinimalInput) {
|
||||
// 小于 4 字节
|
||||
std::string tiny = "AB";
|
||||
ULONG offset = UnMaskHttp(tiny.data(), static_cast<ULONG>(tiny.size()));
|
||||
EXPECT_EQ(offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(UnMaskHttpTest, ExactlyHeaderEnd) {
|
||||
std::string justEnd = "\r\n\r\n";
|
||||
ULONG offset = UnMaskHttp(justEnd.data(), static_cast<ULONG>(justEnd.size()));
|
||||
EXPECT_EQ(offset, 4u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// TryUnMask 测试
|
||||
// ============================================
|
||||
|
||||
class TryUnMaskTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(TryUnMaskTest, DetectPOST) {
|
||||
std::string httpRequest =
|
||||
"POST /api HTTP/1.1\r\n"
|
||||
"Host: test.com\r\n"
|
||||
"\r\n"
|
||||
"body";
|
||||
|
||||
PkgMaskType maskType;
|
||||
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
|
||||
|
||||
EXPECT_EQ(maskType, MaskTypeHTTP);
|
||||
EXPECT_GT(offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(TryUnMaskTest, DetectGET) {
|
||||
std::string httpRequest =
|
||||
"GET /resource HTTP/1.1\r\n"
|
||||
"Host: test.com\r\n"
|
||||
"\r\n";
|
||||
|
||||
PkgMaskType maskType;
|
||||
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
|
||||
|
||||
EXPECT_EQ(maskType, MaskTypeHTTP);
|
||||
EXPECT_GT(offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(TryUnMaskTest, NonHttpData) {
|
||||
std::string binaryData = "HELL\x00\x00\x42\xBD";
|
||||
|
||||
PkgMaskType maskType;
|
||||
ULONG offset = TryUnMask(binaryData.data(), static_cast<ULONG>(binaryData.size()), maskType);
|
||||
|
||||
EXPECT_EQ(maskType, MaskTypeNone);
|
||||
EXPECT_EQ(offset, 0u);
|
||||
}
|
||||
|
||||
TEST_F(TryUnMaskTest, ShortInput) {
|
||||
std::string tooShort = "POS";
|
||||
|
||||
PkgMaskType maskType;
|
||||
ULONG offset = TryUnMask(tooShort.data(), static_cast<ULONG>(tooShort.size()), maskType);
|
||||
|
||||
EXPECT_EQ(maskType, MaskTypeNone);
|
||||
EXPECT_EQ(offset, 0u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// HttpMask 类测试
|
||||
// ============================================
|
||||
|
||||
class HttpMaskClassTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HttpMaskClassTest, MaskBasic) {
|
||||
HttpMask mask("api.example.com");
|
||||
|
||||
const char* data = "Hello";
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, data, 5);
|
||||
|
||||
ASSERT_NE(masked, nullptr);
|
||||
EXPECT_GT(maskedSize, 5u);
|
||||
|
||||
// 验证是 HTTP 格式
|
||||
EXPECT_EQ(memcmp(masked, "POST ", 5), 0);
|
||||
|
||||
// 验证 Host 头
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("Host: api.example.com"), std::string::npos);
|
||||
|
||||
// 验证 Content-Length
|
||||
EXPECT_NE(maskedStr.find("Content-Length: 5"), std::string::npos);
|
||||
|
||||
// 验证 body
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(memcmp(masked + offset, "Hello", 5), 0);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpMaskClassTest, MaskEmptyData) {
|
||||
HttpMask mask;
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, "", 0);
|
||||
|
||||
ASSERT_NE(masked, nullptr);
|
||||
|
||||
// 验证 Content-Length: 0
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("Content-Length: 0"), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpMaskClassTest, MaskWithCommand) {
|
||||
HttpMask mask;
|
||||
|
||||
const char* data = "X";
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, data, 1, 42);
|
||||
|
||||
// 验证路径包含命令号
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("/42"), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpMaskClassTest, MaskLargeData) {
|
||||
HttpMask mask;
|
||||
|
||||
std::vector<char> largeData(64 * 1024, 'X'); // 64 KB
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, largeData.data(), static_cast<ULONG>(largeData.size()));
|
||||
|
||||
ASSERT_NE(masked, nullptr);
|
||||
|
||||
// 验证 Content-Length
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("Content-Length: 65536"), std::string::npos);
|
||||
|
||||
// 验证 body 完整
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(maskedSize - offset, 64u * 1024u);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// Mask/UnMask 往返测试
|
||||
// ============================================
|
||||
|
||||
class MaskRoundTripTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(MaskRoundTripTest, SimpleRoundTrip) {
|
||||
HttpMask mask;
|
||||
|
||||
std::vector<char> original = {'H', 'E', 'L', 'L', 'O'};
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
|
||||
|
||||
// UnMask
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_GT(offset, 0u);
|
||||
|
||||
// 验证数据完整
|
||||
EXPECT_EQ(maskedSize - offset, original.size());
|
||||
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(MaskRoundTripTest, BinaryDataRoundTrip) {
|
||||
HttpMask mask;
|
||||
|
||||
// 包含所有字节值
|
||||
std::vector<char> original(256);
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
original[i] = static_cast<char>(i);
|
||||
}
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
|
||||
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(MaskRoundTripTest, NullBytesRoundTrip) {
|
||||
HttpMask mask;
|
||||
|
||||
std::vector<char> original = {'\0', '\0', 'A', '\0', 'B'};
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
|
||||
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(maskedSize - offset, original.size());
|
||||
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(MaskRoundTripTest, HttpLikeDataRoundTrip) {
|
||||
HttpMask mask;
|
||||
|
||||
// 数据本身看起来像 HTTP
|
||||
std::string httpLike = "POST /fake HTTP/1.1\r\n\r\nfake body";
|
||||
std::vector<char> original(httpLike.begin(), httpLike.end());
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
|
||||
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 自定义头部测试
|
||||
// ============================================
|
||||
|
||||
class CustomHeadersTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(CustomHeadersTest, AddCustomHeaders) {
|
||||
std::map<std::string, std::string> headers;
|
||||
headers["X-Custom-Header"] = "custom-value";
|
||||
headers["X-Request-ID"] = "12345";
|
||||
|
||||
HttpMask mask("test.com", headers);
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, "data", 4);
|
||||
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("X-Custom-Header: custom-value"), std::string::npos);
|
||||
EXPECT_NE(maskedStr.find("X-Request-ID: 12345"), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 边界条件测试
|
||||
// ============================================
|
||||
|
||||
class HttpMaskBoundaryTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HttpMaskBoundaryTest, VeryLongHost) {
|
||||
std::string longHost(1000, 'x');
|
||||
longHost += ".com";
|
||||
|
||||
HttpMask mask(longHost);
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, "test", 4);
|
||||
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find(longHost), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpMaskBoundaryTest, SpecialCharactersInHost) {
|
||||
HttpMask mask("api-v2.test-server.example.com");
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, "x", 1);
|
||||
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("Host: api-v2.test-server.example.com"), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpMaskBoundaryTest, MaxULONGContentLength) {
|
||||
// 测试大的 Content-Length 值
|
||||
HttpMask mask;
|
||||
|
||||
// 不实际分配这么大的内存,只是构造请求头
|
||||
std::string largeContentLengthStr = "Content-Length: 4294967295";
|
||||
|
||||
// 验证格式正确
|
||||
EXPECT_EQ(largeContentLengthStr.find("4294967295"), 16u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// HTTP 格式验证测试
|
||||
// ============================================
|
||||
|
||||
class HttpFormatTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(HttpFormatTest, ValidHttpRequestFormat) {
|
||||
HttpMask mask("test.com");
|
||||
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, "body", 4);
|
||||
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
|
||||
// 验证请求行
|
||||
EXPECT_EQ(maskedStr.substr(0, 5), "POST ");
|
||||
|
||||
// 验证 HTTP 版本
|
||||
EXPECT_NE(maskedStr.find("HTTP/1.1\r\n"), std::string::npos);
|
||||
|
||||
// 验证必需的头部
|
||||
EXPECT_NE(maskedStr.find("Host:"), std::string::npos);
|
||||
EXPECT_NE(maskedStr.find("Content-Length:"), std::string::npos);
|
||||
EXPECT_NE(maskedStr.find("Content-Type:"), std::string::npos);
|
||||
|
||||
// 验证头部和 body 之间的分隔符
|
||||
EXPECT_NE(maskedStr.find("\r\n\r\n"), std::string::npos);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
TEST_F(HttpFormatTest, ContentLengthMatchesBody) {
|
||||
HttpMask mask;
|
||||
|
||||
std::vector<char> body(123, 'X');
|
||||
char* masked = nullptr;
|
||||
ULONG maskedSize = 0;
|
||||
|
||||
mask.Mask(masked, maskedSize, body.data(), static_cast<ULONG>(body.size()));
|
||||
|
||||
std::string maskedStr(masked, maskedSize);
|
||||
EXPECT_NE(maskedStr.find("Content-Length: 123"), std::string::npos);
|
||||
|
||||
ULONG offset = mask.UnMask(masked, maskedSize);
|
||||
EXPECT_EQ(maskedSize - offset, 123u);
|
||||
|
||||
delete[] masked;
|
||||
}
|
||||
|
||||
660
test/unit/network/PacketFragmentTest.cpp
Normal file
660
test/unit/network/PacketFragmentTest.cpp
Normal file
@@ -0,0 +1,660 @@
|
||||
/**
|
||||
* @file PacketFragmentTest.cpp
|
||||
* @brief 粘包/分包处理测试
|
||||
*
|
||||
* 测试覆盖:
|
||||
* - 完整包接收和解析
|
||||
* - 分包(不完整包)处理
|
||||
* - 粘包(多个包粘在一起)处理
|
||||
* - 混合场景测试
|
||||
* - 边界条件处理
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <functional>
|
||||
|
||||
// ============================================
|
||||
// 协议常量
|
||||
// ============================================
|
||||
|
||||
const int FLAG_LENGTH = 8;
|
||||
const int HDR_LENGTH = 16; // FLAG(8) + PackedLen(4) + OrigLen(4)
|
||||
|
||||
// ============================================
|
||||
// 简化版 CBuffer(测试专用)
|
||||
// ============================================
|
||||
|
||||
class TestBuffer {
|
||||
public:
|
||||
TestBuffer() {}
|
||||
|
||||
void WriteBuffer(const uint8_t* data, size_t len) {
|
||||
m_data.insert(m_data.end(), data, data + len);
|
||||
}
|
||||
|
||||
size_t ReadBuffer(uint8_t* dst, size_t len) {
|
||||
size_t toRead = std::min(len, m_data.size());
|
||||
if (toRead > 0) {
|
||||
memcpy(dst, m_data.data(), toRead);
|
||||
m_data.erase(m_data.begin(), m_data.begin() + toRead);
|
||||
}
|
||||
return toRead;
|
||||
}
|
||||
|
||||
void Skip(size_t len) {
|
||||
size_t toSkip = std::min(len, m_data.size());
|
||||
m_data.erase(m_data.begin(), m_data.begin() + toSkip);
|
||||
}
|
||||
|
||||
size_t GetBufferLength() const {
|
||||
return m_data.size();
|
||||
}
|
||||
|
||||
const uint8_t* GetBuffer(size_t pos = 0) const {
|
||||
if (pos >= m_data.size()) return nullptr;
|
||||
return m_data.data() + pos;
|
||||
}
|
||||
|
||||
bool CopyBuffer(uint8_t* dst, size_t pos, size_t len) const {
|
||||
if (pos + len > m_data.size()) return false;
|
||||
memcpy(dst, m_data.data() + pos, len);
|
||||
return true;
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
m_data.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> m_data;
|
||||
};
|
||||
|
||||
// ============================================
|
||||
// 数据包构建辅助函数
|
||||
// ============================================
|
||||
|
||||
#pragma pack(push, 1)
|
||||
struct PacketHeader {
|
||||
char flag[FLAG_LENGTH];
|
||||
uint32_t packedLength; // 包含头部的总长度
|
||||
uint32_t originalLength; // 原始数据长度
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
std::vector<uint8_t> BuildPacket(const std::vector<uint8_t>& payload, uint8_t key = 0x42) {
|
||||
std::vector<uint8_t> packet(HDR_LENGTH + payload.size());
|
||||
|
||||
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
|
||||
memcpy(hdr->flag, "HELL", 4);
|
||||
hdr->flag[6] = key;
|
||||
hdr->flag[7] = ~key;
|
||||
hdr->packedLength = static_cast<uint32_t>(HDR_LENGTH + payload.size());
|
||||
hdr->originalLength = static_cast<uint32_t>(payload.size());
|
||||
|
||||
if (!payload.empty()) {
|
||||
memcpy(packet.data() + HDR_LENGTH, payload.data(), payload.size());
|
||||
}
|
||||
|
||||
return packet;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> BuildPacketWithData(size_t dataSize, uint8_t fillByte = 0xAA) {
|
||||
std::vector<uint8_t> payload(dataSize, fillByte);
|
||||
return BuildPacket(payload);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 粘包/分包处理器(模拟 OnServerReceiving 逻辑)
|
||||
// ============================================
|
||||
|
||||
class PacketProcessor {
|
||||
public:
|
||||
using PacketCallback = std::function<void(const std::vector<uint8_t>&)>;
|
||||
|
||||
PacketProcessor(PacketCallback callback) : m_callback(callback) {}
|
||||
|
||||
// 接收数据(模拟网络接收)
|
||||
void OnReceive(const uint8_t* data, size_t len) {
|
||||
m_buffer.WriteBuffer(data, len);
|
||||
ProcessBuffer();
|
||||
}
|
||||
|
||||
// 获取待处理的字节数
|
||||
size_t GetPendingBytes() const {
|
||||
return m_buffer.GetBufferLength();
|
||||
}
|
||||
|
||||
// 获取已处理的包数量
|
||||
size_t GetProcessedCount() const {
|
||||
return m_processedCount;
|
||||
}
|
||||
|
||||
private:
|
||||
void ProcessBuffer() {
|
||||
while (m_buffer.GetBufferLength() >= HDR_LENGTH) {
|
||||
// 验证头部
|
||||
const uint8_t* buf = m_buffer.GetBuffer();
|
||||
if (memcmp(buf, "HELL", 4) != 0) {
|
||||
// 无效头部,跳过一个字节重试
|
||||
m_buffer.Skip(1);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 读取包长度
|
||||
uint32_t packedLength;
|
||||
m_buffer.CopyBuffer(reinterpret_cast<uint8_t*>(&packedLength),
|
||||
FLAG_LENGTH, sizeof(uint32_t));
|
||||
|
||||
// 检查长度有效性
|
||||
if (packedLength < HDR_LENGTH || packedLength > 100 * 1024 * 1024) {
|
||||
// 无效长度,跳过头部
|
||||
m_buffer.Skip(FLAG_LENGTH);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 检查包是否完整
|
||||
if (m_buffer.GetBufferLength() < packedLength) {
|
||||
// 不完整,等待更多数据
|
||||
break;
|
||||
}
|
||||
|
||||
// 读取完整包
|
||||
std::vector<uint8_t> packet(packedLength);
|
||||
m_buffer.ReadBuffer(packet.data(), packedLength);
|
||||
|
||||
// 提取 payload
|
||||
std::vector<uint8_t> payload(packet.begin() + HDR_LENGTH, packet.end());
|
||||
m_callback(payload);
|
||||
m_processedCount++;
|
||||
}
|
||||
}
|
||||
|
||||
TestBuffer m_buffer;
|
||||
PacketCallback m_callback;
|
||||
size_t m_processedCount = 0;
|
||||
};
|
||||
|
||||
// ============================================
|
||||
// 完整包接收测试
|
||||
// ============================================
|
||||
|
||||
class CompletePacketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
void SetUp() override {
|
||||
receivedPackets.clear();
|
||||
}
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(CompletePacketTest, SinglePacket) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0x01, 0x02, 0x03, 0x04};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
EXPECT_EQ(processor.GetPendingBytes(), 0u);
|
||||
}
|
||||
|
||||
TEST_F(CompletePacketTest, EmptyPayload) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload;
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_TRUE(receivedPackets[0].empty());
|
||||
}
|
||||
|
||||
TEST_F(CompletePacketTest, LargePayload) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload(64 * 1024); // 64 KB
|
||||
for (size_t i = 0; i < payload.size(); ++i) {
|
||||
payload[i] = static_cast<uint8_t>(i & 0xFF);
|
||||
}
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 分包(不完整包)测试
|
||||
// ============================================
|
||||
|
||||
class FragmentedPacketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FragmentedPacketTest, TwoFragments) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0xAA, 0xBB, 0xCC, 0xDD};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 分两次发送
|
||||
size_t half = packet.size() / 2;
|
||||
processor.OnReceive(packet.data(), half);
|
||||
EXPECT_EQ(receivedPackets.size(), 0u);
|
||||
EXPECT_EQ(processor.GetPendingBytes(), half);
|
||||
|
||||
processor.OnReceive(packet.data() + half, packet.size() - half);
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
}
|
||||
|
||||
TEST_F(FragmentedPacketTest, ManyFragments) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload(100);
|
||||
for (size_t i = 0; i < payload.size(); ++i) {
|
||||
payload[i] = static_cast<uint8_t>(i);
|
||||
}
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 每次发送 10 字节
|
||||
for (size_t i = 0; i < packet.size(); i += 10) {
|
||||
size_t len = std::min(size_t(10), packet.size() - i);
|
||||
processor.OnReceive(packet.data() + i, len);
|
||||
}
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
}
|
||||
|
||||
TEST_F(FragmentedPacketTest, OnlyHeader) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0x01, 0x02};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 只发送头部
|
||||
processor.OnReceive(packet.data(), HDR_LENGTH);
|
||||
EXPECT_EQ(receivedPackets.size(), 0u);
|
||||
EXPECT_EQ(processor.GetPendingBytes(), HDR_LENGTH);
|
||||
|
||||
// 发送剩余数据
|
||||
processor.OnReceive(packet.data() + HDR_LENGTH, packet.size() - HDR_LENGTH);
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
}
|
||||
|
||||
TEST_F(FragmentedPacketTest, PartialHeader) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0xFF};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 发送不完整的头部
|
||||
processor.OnReceive(packet.data(), 4); // 只有 "HELL"
|
||||
EXPECT_EQ(receivedPackets.size(), 0u);
|
||||
|
||||
// 发送剩余部分
|
||||
processor.OnReceive(packet.data() + 4, packet.size() - 4);
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 粘包测试
|
||||
// ============================================
|
||||
|
||||
class StickyPacketTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(StickyPacketTest, TwoPacketsStuckTogether) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload1 = {0x11, 0x22};
|
||||
std::vector<uint8_t> payload2 = {0x33, 0x44, 0x55};
|
||||
auto packet1 = BuildPacket(payload1);
|
||||
auto packet2 = BuildPacket(payload2);
|
||||
|
||||
// 合并两个包
|
||||
std::vector<uint8_t> combined;
|
||||
combined.insert(combined.end(), packet1.begin(), packet1.end());
|
||||
combined.insert(combined.end(), packet2.begin(), packet2.end());
|
||||
|
||||
processor.OnReceive(combined.data(), combined.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 2u);
|
||||
EXPECT_EQ(receivedPackets[0], payload1);
|
||||
EXPECT_EQ(receivedPackets[1], payload2);
|
||||
}
|
||||
|
||||
TEST_F(StickyPacketTest, ThreePacketsStuckTogether) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload1 = {0x01};
|
||||
std::vector<uint8_t> payload2 = {0x02, 0x03};
|
||||
std::vector<uint8_t> payload3 = {0x04, 0x05, 0x06};
|
||||
auto packet1 = BuildPacket(payload1);
|
||||
auto packet2 = BuildPacket(payload2);
|
||||
auto packet3 = BuildPacket(payload3);
|
||||
|
||||
// 合并三个包
|
||||
std::vector<uint8_t> combined;
|
||||
combined.insert(combined.end(), packet1.begin(), packet1.end());
|
||||
combined.insert(combined.end(), packet2.begin(), packet2.end());
|
||||
combined.insert(combined.end(), packet3.begin(), packet3.end());
|
||||
|
||||
processor.OnReceive(combined.data(), combined.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 3u);
|
||||
EXPECT_EQ(receivedPackets[0], payload1);
|
||||
EXPECT_EQ(receivedPackets[1], payload2);
|
||||
EXPECT_EQ(receivedPackets[2], payload3);
|
||||
}
|
||||
|
||||
TEST_F(StickyPacketTest, ManyPacketsStuckTogether) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> combined;
|
||||
const int numPackets = 100;
|
||||
|
||||
for (int i = 0; i < numPackets; ++i) {
|
||||
std::vector<uint8_t> payload(i % 10 + 1, static_cast<uint8_t>(i));
|
||||
auto packet = BuildPacket(payload);
|
||||
combined.insert(combined.end(), packet.begin(), packet.end());
|
||||
}
|
||||
|
||||
processor.OnReceive(combined.data(), combined.size());
|
||||
|
||||
EXPECT_EQ(receivedPackets.size(), numPackets);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 混合场景测试
|
||||
// ============================================
|
||||
|
||||
class MixedScenarioTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MixedScenarioTest, OneAndHalfPackets) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload1 = {0xAA, 0xBB};
|
||||
std::vector<uint8_t> payload2 = {0xCC, 0xDD, 0xEE, 0xFF};
|
||||
auto packet1 = BuildPacket(payload1);
|
||||
auto packet2 = BuildPacket(payload2);
|
||||
|
||||
// 发送完整包1 + 半个包2
|
||||
std::vector<uint8_t> firstSend;
|
||||
firstSend.insert(firstSend.end(), packet1.begin(), packet1.end());
|
||||
firstSend.insert(firstSend.end(), packet2.begin(), packet2.begin() + packet2.size() / 2);
|
||||
|
||||
processor.OnReceive(firstSend.data(), firstSend.size());
|
||||
EXPECT_EQ(receivedPackets.size(), 1u); // 只处理了包1
|
||||
|
||||
// 发送剩余的半个包2
|
||||
processor.OnReceive(packet2.data() + packet2.size() / 2, packet2.size() - packet2.size() / 2);
|
||||
ASSERT_EQ(receivedPackets.size(), 2u);
|
||||
EXPECT_EQ(receivedPackets[0], payload1);
|
||||
EXPECT_EQ(receivedPackets[1], payload2);
|
||||
}
|
||||
|
||||
TEST_F(MixedScenarioTest, HalfPacketThenOneAndHalf) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload1 = {0x11};
|
||||
std::vector<uint8_t> payload2 = {0x22, 0x33};
|
||||
auto packet1 = BuildPacket(payload1);
|
||||
auto packet2 = BuildPacket(payload2);
|
||||
|
||||
// 发送半个包1
|
||||
processor.OnReceive(packet1.data(), packet1.size() / 2);
|
||||
EXPECT_EQ(receivedPackets.size(), 0u);
|
||||
|
||||
// 发送剩余包1 + 完整包2
|
||||
std::vector<uint8_t> secondSend;
|
||||
secondSend.insert(secondSend.end(), packet1.begin() + packet1.size() / 2, packet1.end());
|
||||
secondSend.insert(secondSend.end(), packet2.begin(), packet2.end());
|
||||
|
||||
processor.OnReceive(secondSend.data(), secondSend.size());
|
||||
ASSERT_EQ(receivedPackets.size(), 2u);
|
||||
}
|
||||
|
||||
TEST_F(MixedScenarioTest, RandomChunkSizes) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
// 准备多个包
|
||||
std::vector<std::vector<uint8_t>> payloads;
|
||||
std::vector<uint8_t> allData;
|
||||
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
std::vector<uint8_t> payload(i * 5 + 10, static_cast<uint8_t>(i));
|
||||
payloads.push_back(payload);
|
||||
auto packet = BuildPacket(payload);
|
||||
allData.insert(allData.end(), packet.begin(), packet.end());
|
||||
}
|
||||
|
||||
// 使用"随机"大小的块发送
|
||||
size_t chunkSizes[] = {1, 7, 15, 16, 17, 31, 32, 33, 64, 128};
|
||||
size_t pos = 0;
|
||||
size_t chunkIdx = 0;
|
||||
|
||||
while (pos < allData.size()) {
|
||||
size_t chunkSize = chunkSizes[chunkIdx % (sizeof(chunkSizes) / sizeof(chunkSizes[0]))];
|
||||
size_t len = std::min(chunkSize, allData.size() - pos);
|
||||
processor.OnReceive(allData.data() + pos, len);
|
||||
pos += len;
|
||||
chunkIdx++;
|
||||
}
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), payloads.size());
|
||||
for (size_t i = 0; i < payloads.size(); ++i) {
|
||||
EXPECT_EQ(receivedPackets[i], payloads[i]) << "Mismatch at packet " << i;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 边界条件测试
|
||||
// ============================================
|
||||
|
||||
class PacketBoundaryTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PacketBoundaryTest, ExactlyHdrLength) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
// 只有头部,无 payload
|
||||
std::vector<uint8_t> packet(HDR_LENGTH);
|
||||
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
|
||||
memcpy(hdr->flag, "HELL", 4);
|
||||
hdr->flag[6] = 0x42;
|
||||
hdr->flag[7] = ~0x42;
|
||||
hdr->packedLength = HDR_LENGTH;
|
||||
hdr->originalLength = 0;
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_TRUE(receivedPackets[0].empty());
|
||||
}
|
||||
|
||||
TEST_F(PacketBoundaryTest, SingleBytePayload) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0xFF};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0].size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0][0], 0xFF);
|
||||
}
|
||||
|
||||
TEST_F(PacketBoundaryTest, ByteByByteReceive) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload = {0x12, 0x34, 0x56};
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 每次接收 1 字节
|
||||
for (size_t i = 0; i < packet.size(); ++i) {
|
||||
processor.OnReceive(packet.data() + i, 1);
|
||||
}
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 数据完整性测试
|
||||
// ============================================
|
||||
|
||||
class DataIntegrityTest : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::vector<uint8_t>> receivedPackets;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
receivedPackets.push_back(payload);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(DataIntegrityTest, BinaryData) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
// 包含所有字节值的 payload
|
||||
std::vector<uint8_t> payload(256);
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
payload[i] = static_cast<uint8_t>(i);
|
||||
}
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
EXPECT_EQ(receivedPackets[0], payload);
|
||||
}
|
||||
|
||||
TEST_F(DataIntegrityTest, AllZeros) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload(100, 0x00);
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
for (uint8_t b : receivedPackets[0]) {
|
||||
EXPECT_EQ(b, 0x00);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DataIntegrityTest, AllOnes) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload(100, 0xFF);
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
processor.OnReceive(packet.data(), packet.size());
|
||||
|
||||
ASSERT_EQ(receivedPackets.size(), 1u);
|
||||
for (uint8_t b : receivedPackets[0]) {
|
||||
EXPECT_EQ(b, 0xFF);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 性能相关测试
|
||||
// ============================================
|
||||
|
||||
class PacketPerformanceTest : public ::testing::Test {
|
||||
protected:
|
||||
size_t packetCount = 0;
|
||||
|
||||
PacketProcessor::PacketCallback GetCallback() {
|
||||
return [this](const std::vector<uint8_t>& payload) {
|
||||
packetCount++;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PacketPerformanceTest, ManySmallPackets) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
const int numPackets = 10000;
|
||||
std::vector<uint8_t> allData;
|
||||
|
||||
for (int i = 0; i < numPackets; ++i) {
|
||||
std::vector<uint8_t> payload = {static_cast<uint8_t>(i & 0xFF)};
|
||||
auto packet = BuildPacket(payload);
|
||||
allData.insert(allData.end(), packet.begin(), packet.end());
|
||||
}
|
||||
|
||||
processor.OnReceive(allData.data(), allData.size());
|
||||
|
||||
EXPECT_EQ(packetCount, numPackets);
|
||||
}
|
||||
|
||||
TEST_F(PacketPerformanceTest, LargePacketInSmallChunks) {
|
||||
PacketProcessor processor(GetCallback());
|
||||
|
||||
std::vector<uint8_t> payload(100 * 1024); // 100 KB
|
||||
for (size_t i = 0; i < payload.size(); ++i) {
|
||||
payload[i] = static_cast<uint8_t>(i & 0xFF);
|
||||
}
|
||||
auto packet = BuildPacket(payload);
|
||||
|
||||
// 每次发送 1 KB
|
||||
const size_t chunkSize = 1024;
|
||||
for (size_t i = 0; i < packet.size(); i += chunkSize) {
|
||||
size_t len = std::min(chunkSize, packet.size() - i);
|
||||
processor.OnReceive(packet.data() + i, len);
|
||||
}
|
||||
|
||||
EXPECT_EQ(packetCount, 1u);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user