Init: Migrate SimpleRemoter (Since v1.3.1) to Gitea

This commit is contained in:
yuanyuanxiang
2026-04-19 19:55:01 +02:00
commit 5a325a202b
744 changed files with 235562 additions and 0 deletions

View File

@@ -0,0 +1,875 @@
/**
* @file DateVerifyTest.cpp
* @brief DateVerify 时间验证类测试
*
* 测试覆盖:
* - isTimeTampered() 时间篡改检测
* - getTimeOffset() 时间偏差获取
* - isTrial() 试用期检测
* - isTrail() 兼容性测试
* - 边界条件和缓存机制
*/
#include <gtest/gtest.h>
#include <cstring>
#include <ctime>
#include <map>
#include <chrono>
#include <cmath>
#include <string>
#include <functional>
// ============================================
// 可测试版本的 DateVerify 类
// 允许注入模拟的网络时间
// ============================================
class TestableDateVerify
{
private:
bool m_hasVerified = false;
bool m_lastTimeTampered = true;
time_t m_lastVerifyLocalTime = 0;
time_t m_lastNetworkTime = 0;
static const int VERIFY_INTERVAL = 6 * 3600; // 6小时
// 模拟的网络时间0表示网络不可用
time_t m_mockNetworkTime = 0;
bool m_useMockTime = false;
// 模拟的本地时间
time_t m_mockLocalTime = 0;
bool m_useMockLocalTime = false;
// 模拟的编译日期
std::string m_mockCompileDate;
time_t getNetworkTimeInChina()
{
if (m_useMockTime) {
return m_mockNetworkTime;
}
// 实际测试中不调用网络
return 0;
}
time_t getLocalTime()
{
if (m_useMockLocalTime) {
return m_mockLocalTime;
}
return time(nullptr);
}
int monthAbbrevToNumber(const std::string& month)
{
static const std::map<std::string, int> months = {
{"Jan", 1}, {"Feb", 2}, {"Mar", 3}, {"Apr", 4},
{"May", 5}, {"Jun", 6}, {"Jul", 7}, {"Aug", 8},
{"Sep", 9}, {"Oct", 10}, {"Nov", 11}, {"Dec", 12}
};
auto it = months.find(month);
return (it != months.end()) ? it->second : 0;
}
tm parseCompileDate(const char* compileDate)
{
tm tmCompile = { 0 };
std::string monthStr(compileDate, 3);
std::string dayStr(compileDate + 4, 2);
std::string yearStr(compileDate + 7, 4);
tmCompile.tm_year = std::stoi(yearStr) - 1900;
tmCompile.tm_mon = monthAbbrevToNumber(monthStr) - 1;
tmCompile.tm_mday = std::stoi(dayStr);
return tmCompile;
}
int daysBetweenDates(const tm& date1, const tm& date2)
{
auto timeToTimePoint = [](const tm& tmTime) {
std::time_t tt = mktime(const_cast<tm*>(&tmTime));
return std::chrono::system_clock::from_time_t(tt);
};
auto tp1 = timeToTimePoint(date1);
auto tp2 = timeToTimePoint(date2);
auto duration = tp1 > tp2 ? tp1 - tp2 : tp2 - tp1;
return static_cast<int>(std::chrono::duration_cast<std::chrono::hours>(duration).count() / 24);
}
tm getCurrentDate()
{
std::time_t now = getLocalTime();
tm tmNow = *std::localtime(&now);
tmNow.tm_hour = 0;
tmNow.tm_min = 0;
tmNow.tm_sec = 0;
return tmNow;
}
public:
// 测试辅助方法
void setMockNetworkTime(time_t t) { m_mockNetworkTime = t; m_useMockTime = true; }
void setNetworkUnavailable() { m_mockNetworkTime = 0; m_useMockTime = true; }
void setMockLocalTime(time_t t) { m_mockLocalTime = t; m_useMockLocalTime = true; }
void setMockCompileDate(const std::string& date) { m_mockCompileDate = date; }
void resetCache() { m_hasVerified = false; m_lastTimeTampered = true; }
bool hasVerified() const { return m_hasVerified; }
// 检测本地时间是否被篡改
bool isTimeTampered(int toleranceDays = 1)
{
time_t currentLocalTime = getLocalTime();
// 检查是否可以使用缓存
if (m_hasVerified) {
time_t localElapsed = currentLocalTime - m_lastVerifyLocalTime;
// 本地时间在合理范围内前进,使用缓存推算
if (localElapsed >= 0 && localElapsed < VERIFY_INTERVAL) {
time_t estimatedNetworkTime = m_lastNetworkTime + localElapsed;
double diffDays = difftime(estimatedNetworkTime, currentLocalTime) / 86400.0;
if (fabs(diffDays) <= toleranceDays) {
return false; // 未篡改
}
}
}
// 执行网络验证
time_t networkTime = getNetworkTimeInChina();
if (networkTime == 0) {
// 网络不可用:如果之前验证通过且本地时间没异常,暂时信任
if (m_hasVerified && !m_lastTimeTampered) {
time_t localElapsed = currentLocalTime - m_lastVerifyLocalTime;
// 允许5分钟回拨和24小时内的前进
if (localElapsed >= -300 && localElapsed < 24 * 3600) {
return false;
}
}
return true; // 无法验证,视为篡改
}
// 更新缓存
m_hasVerified = true;
m_lastVerifyLocalTime = currentLocalTime;
m_lastNetworkTime = networkTime;
double diffDays = difftime(networkTime, currentLocalTime) / 86400.0;
m_lastTimeTampered = fabs(diffDays) > toleranceDays;
return m_lastTimeTampered;
}
// 获取网络时间与本地时间的偏差(秒)
int getTimeOffset()
{
time_t networkTime = getNetworkTimeInChina();
if (networkTime == 0) return 0;
return static_cast<int>(difftime(networkTime, getLocalTime()));
}
bool isTrial(int trialDays = 7)
{
if (isTimeTampered())
return false;
const char* compileDate = m_mockCompileDate.empty() ? __DATE__ : m_mockCompileDate.c_str();
tm tmCompile = parseCompileDate(compileDate), tmCurrent = getCurrentDate();
int daysDiff = daysBetweenDates(tmCompile, tmCurrent);
return daysDiff <= trialDays;
}
// 兼容性函数
bool isTrail(int trailDays = 7) { return isTrial(trailDays); }
};
// ============================================
// 辅助函数
// ============================================
// 创建指定日期的 time_t
time_t makeTime(int year, int month, int day, int hour = 12, int minute = 0, int second = 0)
{
tm t = {};
t.tm_year = year - 1900;
t.tm_mon = month - 1;
t.tm_mday = day;
t.tm_hour = hour;
t.tm_min = minute;
t.tm_sec = second;
return mktime(&t);
}
// 格式化日期为 __DATE__ 格式 (例如 "Mar 13 2026")
std::string formatCompileDate(int year, int month, int day)
{
static const char* months[] = {
"Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
};
char buf[16];
snprintf(buf, sizeof(buf), "%s %2d %d", months[month - 1], day, year);
return std::string(buf);
}
// ============================================
// isTimeTampered 测试
// ============================================
class IsTimeTamperedTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(IsTimeTamperedTest, NetworkUnavailable_NoCache_ReturnsTampered)
{
// 网络不可用,无缓存时应返回 true视为篡改
verifier.setNetworkUnavailable();
EXPECT_TRUE(verifier.isTimeTampered());
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeMatch_ReturnsNotTampered)
{
// 网络时间与本地时间匹配
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeWithinTolerance_ReturnsNotTampered)
{
// 网络时间与本地时间差异在容忍范围内12小时差异1天容忍
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 12 * 3600); // 本地时间落后12小时
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeExceedsTolerance_ReturnsTampered)
{
// 网络时间与本地时间差异超过容忍范围2天差异1天容忍
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 2 * 86400); // 本地时间落后2天
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, ToleranceDays_Zero_StrictMode)
{
// 0天容忍度任何偏差都应报告为篡改
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400); // 1天偏差
EXPECT_TRUE(verifier.isTimeTampered(0));
}
TEST_F(IsTimeTamperedTest, ToleranceDays_Large_LenientMode)
{
// 大容忍度7天较大偏差仍应通过
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 5 * 86400); // 5天偏差
EXPECT_FALSE(verifier.isTimeTampered(7));
}
TEST_F(IsTimeTamperedTest, LocalTimeAhead_ExceedsTolerance_ReturnsTampered)
{
// 本地时间超前网络时间
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now + 3 * 86400); // 本地时间超前3天
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, CacheHit_WithinInterval_UsesCache)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
EXPECT_TRUE(verifier.hasVerified());
// 模拟时间前进1小时在6小时缓存间隔内
verifier.setMockLocalTime(now + 3600);
verifier.setNetworkUnavailable(); // 网络不可用
// 应该使用缓存,不报告篡改
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, CacheMiss_ExceedsInterval_RequiresNetwork)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间前进25小时超过24小时宽容期
// 注实现中网络不可用时有24小时宽容期超过后才报告篡改
verifier.setMockLocalTime(now + 25 * 3600);
verifier.setNetworkUnavailable();
// 超过24小时宽容期网络不可用应报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkUnavailable_WithGoodCache_AllowsSmallRewind)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间回拨2分钟在允许的5分钟范围内
verifier.setMockLocalTime(now - 120);
verifier.setNetworkUnavailable();
// 应该暂时信任
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkUnavailable_WithGoodCache_RejectsLargeRewind)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间回拨10分钟超过允许的5分钟
verifier.setMockLocalTime(now - 600);
verifier.setNetworkUnavailable();
// 应该报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
// ============================================
// getTimeOffset 测试
// ============================================
class GetTimeOffsetTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(GetTimeOffsetTest, NetworkUnavailable_ReturnsZero)
{
verifier.setNetworkUnavailable();
EXPECT_EQ(verifier.getTimeOffset(), 0);
}
TEST_F(GetTimeOffsetTest, LocalTimeBehind_ReturnsPositive)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 3600); // 本地落后1小时
int offset = verifier.getTimeOffset();
EXPECT_GT(offset, 0);
EXPECT_NEAR(offset, 3600, 5); // 允许5秒误差
}
TEST_F(GetTimeOffsetTest, LocalTimeAhead_ReturnsNegative)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now + 3600); // 本地超前1小时
int offset = verifier.getTimeOffset();
EXPECT_LT(offset, 0);
EXPECT_NEAR(offset, -3600, 5);
}
TEST_F(GetTimeOffsetTest, TimeMatch_ReturnsNearZero)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
int offset = verifier.getTimeOffset();
EXPECT_NEAR(offset, 0, 2); // 允许2秒误差
}
// ============================================
// isTrial / isTrail 测试
// ============================================
class IsTrialTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
void SetUp() override {
// 默认设置:网络正常,时间同步
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
}
};
TEST_F(IsTrialTest, WithinTrialPeriod_ReturnsTrue)
{
// 编译日期3天前
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
// 创建3天前的日期
time_t threeDaysAgo = now - 3 * 86400;
tm* tmCompile = localtime(&threeDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7)); // 7天试用期内
}
TEST_F(IsTrialTest, ExactTrialPeriod_ReturnsTrue)
{
// 编译日期正好7天前
time_t now = time(nullptr);
time_t sevenDaysAgo = now - 7 * 86400;
tm* tmCompile = localtime(&sevenDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7)); // 边界正好7天
}
TEST_F(IsTrialTest, ExceedsTrialPeriod_ReturnsFalse)
{
// 编译日期8天前
time_t now = time(nullptr);
time_t eightDaysAgo = now - 8 * 86400;
tm* tmCompile = localtime(&eightDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(7)); // 超过7天试用期
}
TEST_F(IsTrialTest, CompileToday_ReturnsTrue)
{
// 编译日期:今天
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7));
EXPECT_TRUE(verifier.isTrial(1));
EXPECT_TRUE(verifier.isTrial(0));
}
TEST_F(IsTrialTest, TimeTampered_ReturnsFalse)
{
// 时间被篡改时,无论试用期如何都应返回 false
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 30 * 86400); // 本地时间落后30天
// 即使编译日期是今天,时间篡改也会导致返回 false
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(7));
}
TEST_F(IsTrialTest, CustomTrialDays_Zero)
{
// 0天试用期只有编译当天有效
time_t now = time(nullptr);
time_t yesterday = now - 86400;
tm* tmYesterday = localtime(&yesterday);
std::string compileDate = formatCompileDate(
tmYesterday->tm_year + 1900,
tmYesterday->tm_mon + 1,
tmYesterday->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(0)); // 昨天编译0天试用期已过
}
TEST_F(IsTrialTest, IsTrail_CompatibilityAlias)
{
// isTrail 应该与 isTrial 行为一致
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_EQ(verifier.isTrail(7), verifier.isTrial(7));
EXPECT_EQ(verifier.isTrail(0), verifier.isTrial(0));
}
// ============================================
// 边界条件测试
// ============================================
class BoundaryTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(BoundaryTest, ToleranceDays_ExactBoundary)
{
// 测试容忍度边界:偏差正好等于容忍天数
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400); // 正好1天偏差
// 1天容忍度1天偏差应该通过<= 比较)
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, ToleranceDays_JustOverBoundary)
{
// 测试容忍度边界:偏差刚刚超过容忍天数
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400 - 3600); // 1天+1小时偏差
// 1天容忍度偏差超过1天应该失败
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, CacheInterval_ExactBoundary)
{
// 测试网络不可用时的24小时宽容期边界
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 正好24小时后刚好到达宽容期边界
verifier.setMockLocalTime(now + 24 * 3600);
verifier.setNetworkUnavailable();
// 刚好到达24小时边界网络不可用应报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, CacheInterval_JustUnderBoundary)
{
// 测试缓存间隔边界略少于6小时
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 5小时59分钟后仍在缓存有效期内
verifier.setMockLocalTime(now + 6 * 3600 - 60);
verifier.setNetworkUnavailable();
// 仍在缓存有效期内
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, NetworkTimeRollback_AllowedMargin)
{
// 测试允许的时间回拨范围正好5分钟
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 回拨正好5分钟
verifier.setMockLocalTime(now - 300);
verifier.setNetworkUnavailable();
// 边界情况:-300 >= -300应该通过
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, NetworkTimeRollback_ExceedsMargin)
{
// 测试超过允许的时间回拨范围
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 回拨超过5分钟
verifier.setMockLocalTime(now - 301);
verifier.setNetworkUnavailable();
// 超过允许范围
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, TrialDays_LargeValue)
{
// 测试大试用期值
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
// 100天前编译
time_t hundredDaysAgo = now - 100 * 86400;
tm* tmCompile = localtime(&hundredDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(99)); // 99天试用期已过
EXPECT_TRUE(verifier.isTrial(100)); // 正好100天
EXPECT_TRUE(verifier.isTrial(365)); // 365天试用期内
}
// ============================================
// 日期解析测试
// ============================================
class DateParsingTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
void SetUp() override {
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
}
};
TEST_F(DateParsingTest, AllMonths)
{
// 测试所有月份的解析
const char* months[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
for (int i = 0; i < 12; ++i) {
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "%s %2d %d", months[i], 15, tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
// 只要不是时间篡改,应该能正常解析
// 结果取决于当前日期与编译日期的差异
bool result = verifier.isTrial(365); // 使用较长试用期确保通过
EXPECT_TRUE(result) << "Failed for month: " << months[i];
}
}
TEST_F(DateParsingTest, SingleDigitDay)
{
// 测试单数字日期(如 "Mar 1 2026"
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "Mar 1 %d", tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
EXPECT_TRUE(verifier.isTrial(365));
}
TEST_F(DateParsingTest, DoubleDigitDay)
{
// 测试双数字日期(如 "Mar 15 2026"
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "Mar 15 %d", tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
EXPECT_TRUE(verifier.isTrial(365));
}
// ============================================
// 授权场景模拟测试
// ============================================
class AuthorizationScenarioTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(AuthorizationScenarioTest, ValidLicense_TimeNotTampered)
{
// 场景:有效授权,时间未被篡改
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
// 授权检查应该通过
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, ExpiredLicense_UserRollsBackTime)
{
// 场景:授权过期,用户将时间回拨企图绕过
time_t now = time(nullptr);
time_t thirtyDaysAgo = now - 30 * 86400;
verifier.setMockNetworkTime(now); // 网络时间是真实时间
verifier.setMockLocalTime(thirtyDaysAgo); // 用户将本地时间回拨30天
// 应该检测到时间篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, ExpiredLicense_UserAdvancesTime)
{
// 场景:用户将时间提前(不太常见,但也应检测)
time_t now = time(nullptr);
time_t thirtyDaysAhead = now + 30 * 86400;
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(thirtyDaysAhead);
// 应该检测到时间篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, OfflineUser_RecentValidation)
{
// 场景:用户刚刚验证通过后断网
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1)); // 首次验证通过
// 用户断网,但时间正常前进
verifier.setMockLocalTime(now + 3600); // 1小时后
verifier.setNetworkUnavailable();
// 应该允许(使用缓存)
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, OfflineUser_LongOffline)
{
// 场景:用户长时间离线后恢复
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 用户离线超过24小时
verifier.setMockLocalTime(now + 25 * 3600);
verifier.setNetworkUnavailable();
// 缓存已过期,网络不可用,应该拒绝
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, TrialUser_WithinPeriod)
{
// 场景:试用用户在试用期内
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
time_t threeDaysAgo = now - 3 * 86400;
tm* tmCompile = localtime(&threeDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
EXPECT_TRUE(verifier.isTrial(7));
}
TEST_F(AuthorizationScenarioTest, TrialUser_Expired)
{
// 场景:试用用户试用期已过
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
time_t tenDaysAgo = now - 10 * 86400;
tm* tmCompile = localtime(&tenDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
EXPECT_FALSE(verifier.isTrial(7));
}
TEST_F(AuthorizationScenarioTest, TrialUser_RollsBackTime)
{
// 场景:试用用户回拨时间企图延长试用期
time_t now = time(nullptr);
time_t tenDaysAgo = now - 10 * 86400;
verifier.setMockNetworkTime(now); // 真实网络时间
verifier.setMockLocalTime(tenDaysAgo); // 用户回拨到10天前
// 编译日期设为15天前
time_t fifteenDaysAgo = now - 15 * 86400;
tm* tmCompile = localtime(&fifteenDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
// 时间篡改应被检测isTrial 应返回 false
EXPECT_FALSE(verifier.isTrial(7));
}

View File

@@ -0,0 +1,556 @@
/**
* @file BufferTest.cpp
* @brief 客户端 CBuffer 类单元测试
*
* 测试覆盖:
* - 基本读写操作
* - 边界条件(空缓冲区、零长度、超长请求)
* - 内存管理(扩展、收缩)
* - 下溢防护Skip、ReadBuffer 边界)
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
// 模拟 Windows 类型定义(用于跨平台测试)
#ifndef _WIN32
typedef unsigned char BYTE;
typedef BYTE* PBYTE;
typedef BYTE* LPBYTE;
typedef unsigned long ULONG;
typedef int BOOL;
#define TRUE 1
#define FALSE 0
#define MEM_COMMIT 0x1000
#define MEM_RELEASE 0x8000
#define PAGE_READWRITE 0x04
// 跨平台内存分配模拟
inline void* MVirtualAlloc(void*, size_t size, int, int) {
return malloc(size);
}
inline void MVirtualFree(void* ptr, size_t, int) {
free(ptr);
}
inline void CopyMemory(void* dst, const void* src, size_t len) {
memcpy(dst, src, len);
}
inline void MoveMemory(void* dst, const void* src, size_t len) {
memmove(dst, src, len);
}
#else
#include <Windows.h>
// Windows 下的内存分配封装
inline void* MVirtualAlloc(void* addr, size_t size, int type, int protect) {
return VirtualAlloc(addr, size, type, protect);
}
inline void MVirtualFree(void* ptr, size_t size, int type) {
VirtualFree(ptr, size, type);
}
#endif
// 内联包含 Buffer 实现(测试专用)
// 这样可以避免复杂的链接问题
namespace ClientBuffer {
#define U_PAGE_ALIGNMENT 3
#define F_PAGE_ALIGNMENT 3.0
class CBuffer
{
public:
CBuffer() : m_ulMaxLength(0), m_Base(NULL), m_Ptr(NULL) {}
~CBuffer() {
if (m_Base) {
MVirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NULL;
}
m_Base = m_Ptr = NULL;
m_ulMaxLength = 0;
}
ULONG ReadBuffer(PBYTE Buffer, ULONG ulLength) {
ULONG dataLen = (ULONG)(m_Ptr - m_Base);
if (ulLength > dataLen) {
ulLength = dataLen;
}
if (ulLength) {
CopyMemory(Buffer, m_Base, ulLength);
ULONG remaining = dataLen - ulLength;
if (remaining > 0) {
MoveMemory(m_Base, m_Base + ulLength, remaining);
}
m_Ptr = m_Base + remaining;
}
DeAllocateBuffer((ULONG)(m_Ptr - m_Base));
return ulLength;
}
VOID DeAllocateBuffer(ULONG ulLength) {
int len = (int)(m_Ptr - m_Base);
if (ulLength < (ULONG)len)
return;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
if (m_ulMaxLength <= ulNewMaxLength) {
return;
}
PBYTE NewBase = (PBYTE)MVirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL)
return;
CopyMemory(NewBase, m_Base, len);
MVirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NewBase;
m_Ptr = m_Base + len;
m_ulMaxLength = ulNewMaxLength;
}
BOOL WriteBuffer(PBYTE Buffer, ULONG ulLength) {
if (ReAllocateBuffer(ulLength + (ULONG)(m_Ptr - m_Base)) == FALSE) {
return FALSE;
}
CopyMemory(m_Ptr, Buffer, ulLength);
m_Ptr += ulLength;
return TRUE;
}
BOOL ReAllocateBuffer(ULONG ulLength) {
if (ulLength < m_ulMaxLength)
return TRUE;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
PBYTE NewBase = (PBYTE)MVirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL) {
return FALSE;
}
ULONG len = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, len);
if (m_Base) {
MVirtualFree(m_Base, 0, MEM_RELEASE);
}
m_Base = NewBase;
m_Ptr = m_Base + len;
m_ulMaxLength = ulNewMaxLength;
return TRUE;
}
VOID ClearBuffer() {
m_Ptr = m_Base;
DeAllocateBuffer(1024);
}
ULONG GetBufferLength() const {
return (ULONG)(m_Ptr - m_Base);
}
void Skip(ULONG ulPos) {
if (ulPos == 0)
return;
ULONG dataLen = (ULONG)(m_Ptr - m_Base);
if (ulPos > dataLen) {
ulPos = dataLen;
}
if (ulPos > 0) {
ULONG remaining = dataLen - ulPos;
if (remaining > 0) {
MoveMemory(m_Base, m_Base + ulPos, remaining);
}
m_Ptr = m_Base + remaining;
}
}
PBYTE GetBuffer(ULONG ulPos = 0) const {
if (m_Base == NULL || ulPos >= (ULONG)(m_Ptr - m_Base)) {
return NULL;
}
return m_Base + ulPos;
}
protected:
PBYTE m_Base;
PBYTE m_Ptr;
ULONG m_ulMaxLength;
};
} // namespace ClientBuffer
using ClientBuffer::CBuffer;
// ============================================
// 测试夹具
// ============================================
class ClientBufferTest : public ::testing::Test {
protected:
CBuffer buffer;
void SetUp() override {
// 每个测试前重置
}
void TearDown() override {
// 每个测试后清理
}
// 辅助方法:写入测试数据
void WriteTestData(const std::vector<BYTE>& data) {
buffer.WriteBuffer(const_cast<BYTE*>(data.data()), (ULONG)data.size());
}
// 辅助方法:写入指定长度的填充数据
void WriteFillData(ULONG length, BYTE fillValue = 0x42) {
std::vector<BYTE> data(length, fillValue);
buffer.WriteBuffer(data.data(), length);
}
};
// ============================================
// 构造/析构测试
// ============================================
TEST_F(ClientBufferTest, Constructor_InitializesEmpty) {
CBuffer newBuffer;
EXPECT_EQ(newBuffer.GetBufferLength(), 0u);
EXPECT_EQ(newBuffer.GetBuffer(), nullptr);
}
// ============================================
// WriteBuffer 测试
// ============================================
TEST_F(ClientBufferTest, WriteBuffer_ValidData_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
EXPECT_TRUE(buffer.WriteBuffer(data, 5));
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ClientBufferTest, WriteBuffer_ZeroLength_ToEmptyBuffer) {
// 空缓冲区写入 0 字节:由于 VirtualAlloc(0) 返回 NULL会失败
BYTE data[] = {1};
// 实际行为:返回 FALSE无法分配 0 字节)
EXPECT_FALSE(buffer.WriteBuffer(data, 0));
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, WriteBuffer_ZeroLength_ToNonEmptyBuffer) {
// 非空缓冲区写入 0 字节应该成功
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_TRUE(buffer.WriteBuffer(data, 0));
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
TEST_F(ClientBufferTest, WriteBuffer_MultipleWrites_AccumulatesData) {
BYTE data1[] = {1, 2, 3};
BYTE data2[] = {4, 5};
buffer.WriteBuffer(data1, 3);
buffer.WriteBuffer(data2, 2);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
// 验证数据完整性
BYTE result[5];
buffer.ReadBuffer(result, 5);
EXPECT_EQ(result[0], 1);
EXPECT_EQ(result[1], 2);
EXPECT_EQ(result[2], 3);
EXPECT_EQ(result[3], 4);
EXPECT_EQ(result[4], 5);
}
TEST_F(ClientBufferTest, WriteBuffer_LargeData_HandlesCorrectly) {
const ULONG largeSize = 10000;
std::vector<BYTE> data(largeSize);
for (ULONG i = 0; i < largeSize; i++) {
data[i] = (BYTE)(i & 0xFF);
}
EXPECT_TRUE(buffer.WriteBuffer(data.data(), largeSize));
EXPECT_EQ(buffer.GetBufferLength(), largeSize);
}
// ============================================
// ReadBuffer 测试
// ============================================
TEST_F(ClientBufferTest, ReadBuffer_EmptyBuffer_ReturnsZero) {
BYTE result[10];
EXPECT_EQ(buffer.ReadBuffer(result, 10), 0u);
}
TEST_F(ClientBufferTest, ReadBuffer_ExactLength_ReturnsAll) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[5];
ULONG bytesRead = buffer.ReadBuffer(result, 5);
EXPECT_EQ(bytesRead, 5u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
for (int i = 0; i < 5; i++) {
EXPECT_EQ(result[i], data[i]);
}
}
TEST_F(ClientBufferTest, ReadBuffer_PartialRead_LeavesRemainder) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[3];
ULONG bytesRead = buffer.ReadBuffer(result, 3);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 2u);
// 验证剩余数据
EXPECT_EQ(*buffer.GetBuffer(0), 4);
EXPECT_EQ(*buffer.GetBuffer(1), 5);
}
TEST_F(ClientBufferTest, ReadBuffer_RequestExceedsAvailable_ReturnsAvailableOnly) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[10];
ULONG bytesRead = buffer.ReadBuffer(result, 10);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, ReadBuffer_ZeroLength_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1];
ULONG bytesRead = buffer.ReadBuffer(result, 0);
EXPECT_EQ(bytesRead, 0u);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
// ============================================
// Skip 测试
// ============================================
TEST_F(ClientBufferTest, Skip_ZeroPosition_NoChange) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.Skip(0);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ClientBufferTest, Skip_PartialSkip_RemovesPrefix) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.Skip(2);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
EXPECT_EQ(*buffer.GetBuffer(0), 3);
EXPECT_EQ(*buffer.GetBuffer(1), 4);
EXPECT_EQ(*buffer.GetBuffer(2), 5);
}
TEST_F(ClientBufferTest, Skip_ExactLength_ClearsBuffer) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
buffer.Skip(3);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, Skip_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
// 这是修复后的行为:不会下溢,而是限制到可用长度
buffer.Skip(100);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, Skip_EmptyBuffer_NoEffect) {
buffer.Skip(10);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// GetBuffer 测试
// ============================================
TEST_F(ClientBufferTest, GetBuffer_EmptyBuffer_ReturnsNull) {
EXPECT_EQ(buffer.GetBuffer(), nullptr);
}
TEST_F(ClientBufferTest, GetBuffer_ValidPosition_ReturnsCorrectPointer) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
EXPECT_EQ(*buffer.GetBuffer(0), 10);
EXPECT_EQ(*buffer.GetBuffer(2), 30);
EXPECT_EQ(*buffer.GetBuffer(4), 50);
}
TEST_F(ClientBufferTest, GetBuffer_PositionAtLength_ReturnsNull) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBuffer(3), nullptr);
}
TEST_F(ClientBufferTest, GetBuffer_PositionExceedsLength_ReturnsNull) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBuffer(100), nullptr);
}
// ============================================
// ClearBuffer 测试
// ============================================
TEST_F(ClientBufferTest, ClearBuffer_AfterWrite_ResetsLength) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, ClearBuffer_EmptyBuffer_NoEffect) {
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// 数据完整性测试
// ============================================
TEST_F(ClientBufferTest, DataIntegrity_WriteReadCycle_PreservesData) {
// 写入各种字节值
std::vector<BYTE> data(256);
for (int i = 0; i < 256; i++) {
data[i] = (BYTE)i;
}
buffer.WriteBuffer(data.data(), 256);
std::vector<BYTE> result(256);
ULONG bytesRead = buffer.ReadBuffer(result.data(), 256);
EXPECT_EQ(bytesRead, 256u);
for (int i = 0; i < 256; i++) {
EXPECT_EQ(result[i], data[i]) << "Mismatch at index " << i;
}
}
TEST_F(ClientBufferTest, DataIntegrity_MultipleReadWriteCycles) {
for (int cycle = 0; cycle < 10; cycle++) {
BYTE data[100];
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)(cycle * 10 + i);
}
buffer.WriteBuffer(data, 100);
BYTE result[100];
ULONG bytesRead = buffer.ReadBuffer(result, 100);
EXPECT_EQ(bytesRead, 100u);
for (int i = 0; i < 100; i++) {
EXPECT_EQ(result[i], data[i]);
}
}
}
// ============================================
// 边界条件和下溢防护测试
// ============================================
TEST_F(ClientBufferTest, UnderflowProtection_SkipMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
// 不应崩溃或产生意外行为
buffer.Skip(ULONG_MAX);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, UnderflowProtection_ReadMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1000];
ULONG bytesRead = buffer.ReadBuffer(result, ULONG_MAX);
// 应该只读取可用的 3 字节
EXPECT_EQ(bytesRead, 3u);
}
// ============================================
// 内存管理测试
// ============================================
TEST_F(ClientBufferTest, MemoryManagement_RepeatedAllocations_NoLeak) {
// 反复分配和释放,验证无内存泄漏
for (int i = 0; i < 100; i++) {
WriteFillData(1000);
buffer.ClearBuffer();
}
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, MemoryManagement_GrowingBuffer_HandlesReallocation) {
// 逐步增长缓冲区
for (ULONG size = 1; size <= 10000; size *= 2) {
WriteFillData(size);
}
EXPECT_GT(buffer.GetBufferLength(), 0u);
}
// ============================================
// 参数化测试
// ============================================
class BufferReadParameterizedTest
: public ::testing::TestWithParam<std::tuple<size_t, size_t, size_t>> {
protected:
CBuffer buffer;
};
TEST_P(BufferReadParameterizedTest, ReadBuffer_VariousLengths) {
auto [writeLen, readLen, expectedRead] = GetParam();
std::vector<BYTE> data(writeLen, 0x42);
if (writeLen > 0) {
buffer.WriteBuffer(data.data(), (ULONG)writeLen);
}
std::vector<BYTE> result(readLen > 0 ? readLen : 1);
ULONG actual = buffer.ReadBuffer(result.data(), (ULONG)readLen);
EXPECT_EQ(actual, expectedRead);
}
INSTANTIATE_TEST_SUITE_P(
ReadLengths,
BufferReadParameterizedTest,
::testing::Values(
std::make_tuple(10, 5, 5), // 正常读取
std::make_tuple(5, 10, 5), // 请求超过可用
std::make_tuple(0, 5, 0), // 空缓冲区
std::make_tuple(100, 0, 0), // 零长度读取
std::make_tuple(1000, 500, 500) // 大数据部分读取
)
);

View File

@@ -0,0 +1,345 @@
/**
* @file RegistryConfigTest.cpp
* @brief iniFile 和 binFile 注册表配置类单元测试
*
* 测试覆盖:
* - 基本读写操作(字符串、整数、浮点数)
* - 二进制数据读写
* - 键句柄缓存机制
* - 并发读写
* - 默认值处理
* - 边界条件
*
* 注意:此测试仅在 Windows 平台运行,使用真实注册表
*/
#include <gtest/gtest.h>
#ifdef _WIN32
// Windows 头文件必须最先包含
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <thread>
#include <atomic>
#include <chrono>
#include <vector>
#include <string>
#include <map>
// 定义 INIFILE_STANDALONE 跳过 commands.h 中的 MFC 依赖
#define INIFILE_STANDALONE 1
// 提供 iniFile.h 需要的最小依赖(来自 commands.h
#ifndef GET_FILEPATH
#define GET_FILEPATH(path, file) do { \
char* p = strrchr(path, '\\'); \
if (p) strcpy_s(p + 1, _MAX_PATH - (p - path + 1), file); \
} while(0)
#endif
inline std::vector<std::string> StringToVector(const std::string& s, char ch) {
std::vector<std::string> result;
std::string::size_type start = 0;
std::string::size_type pos = s.find(ch, start);
while (pos != std::string::npos) {
result.push_back(s.substr(start, pos - start));
start = pos + 1;
pos = s.find(ch, start);
}
result.push_back(s.substr(start));
if (result.empty()) result.push_back("");
return result;
}
// 包含实际的 iniFile.h 头文件(需要修改 iniFile.h 支持 INIFILE_STANDALONE
#include "common/iniFile.h"
// 测试用的注册表路径(与生产环境隔离)
#define TEST_INI_PATH "Software\\YAMA_TEST_INI"
#define TEST_BIN_PATH "Software\\YAMA_TEST_BIN"
// ============================================
// iniFile 测试夹具
// ============================================
class IniFileTest : public ::testing::Test {
protected:
std::unique_ptr<iniFile> cfg;
void SetUp() override {
cfg = std::make_unique<iniFile>(TEST_INI_PATH);
}
void TearDown() override {
cfg.reset();
// 清理测试注册表项
RegDeleteTreeA(HKEY_CURRENT_USER, TEST_INI_PATH);
}
};
// ============================================
// binFile 测试夹具
// ============================================
class BinFileTest : public ::testing::Test {
protected:
std::unique_ptr<binFile> cfg;
void SetUp() override {
cfg = std::make_unique<binFile>(TEST_BIN_PATH);
}
void TearDown() override {
cfg.reset();
// 清理测试注册表项
RegDeleteTreeA(HKEY_CURRENT_USER, TEST_BIN_PATH);
}
};
// ============================================
// iniFile 基本读写测试
// ============================================
TEST_F(IniFileTest, SetStr_GetStr_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetStr("TestSection", "TestKey", "HelloWorld"));
EXPECT_EQ(cfg->GetStr("TestSection", "TestKey"), "HelloWorld");
}
TEST_F(IniFileTest, SetStr_EmptyString_ReturnsEmpty) {
EXPECT_TRUE(cfg->SetStr("TestSection", "EmptyKey", ""));
EXPECT_EQ(cfg->GetStr("TestSection", "EmptyKey"), "");
}
TEST_F(IniFileTest, GetStr_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetStr("NonExistent", "Key", "DefaultValue"), "DefaultValue");
}
TEST_F(IniFileTest, SetInt_GetInt_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetInt("TestSection", "IntKey", 12345));
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 12345);
}
TEST_F(IniFileTest, SetInt_NegativeValue_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetInt("TestSection", "NegKey", -9999));
EXPECT_EQ(cfg->GetInt("TestSection", "NegKey"), -9999);
}
TEST_F(IniFileTest, GetInt_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetInt("NonExistent", "Key", 42), 42);
}
TEST_F(IniFileTest, GetInt_InvalidString_ReturnsDefault) {
cfg->SetStr("TestSection", "InvalidInt", "NotANumber");
EXPECT_EQ(cfg->GetInt("TestSection", "InvalidInt", 99), 99);
}
TEST_F(IniFileTest, SetStr_ChineseCharacters_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetStr("TestSection", "ChineseKey", "中文测试"));
EXPECT_EQ(cfg->GetStr("TestSection", "ChineseKey"), "中文测试");
}
TEST_F(IniFileTest, SetStr_LongString_ReturnsCorrect) {
std::string longStr(400, 'A'); // 400 字符长字符串
EXPECT_TRUE(cfg->SetStr("TestSection", "LongKey", longStr));
EXPECT_EQ(cfg->GetStr("TestSection", "LongKey"), longStr);
}
TEST_F(IniFileTest, SetStr_SpecialCharacters_ReturnsCorrect) {
std::string special = "Path\\With\\Backslashes=Value;With|Special<Chars>";
EXPECT_TRUE(cfg->SetStr("TestSection", "SpecialKey", special));
EXPECT_EQ(cfg->GetStr("TestSection", "SpecialKey"), special);
}
// ============================================
// iniFile 多 Section 测试
// ============================================
TEST_F(IniFileTest, MultipleSections_IsolatedValues) {
cfg->SetStr("Section1", "Key", "Value1");
cfg->SetStr("Section2", "Key", "Value2");
cfg->SetStr("Section3", "Key", "Value3");
EXPECT_EQ(cfg->GetStr("Section1", "Key"), "Value1");
EXPECT_EQ(cfg->GetStr("Section2", "Key"), "Value2");
EXPECT_EQ(cfg->GetStr("Section3", "Key"), "Value3");
}
TEST_F(IniFileTest, MultipleKeys_SameSection) {
cfg->SetStr("Settings", "Key1", "Value1");
cfg->SetStr("Settings", "Key2", "Value2");
cfg->SetInt("Settings", "Key3", 100);
EXPECT_EQ(cfg->GetStr("Settings", "Key1"), "Value1");
EXPECT_EQ(cfg->GetStr("Settings", "Key2"), "Value2");
EXPECT_EQ(cfg->GetInt("Settings", "Key3"), 100);
}
// ============================================
// iniFile 覆盖写入测试
// ============================================
TEST_F(IniFileTest, OverwriteValue_ReturnsNewValue) {
cfg->SetStr("TestSection", "Key", "OldValue");
cfg->SetStr("TestSection", "Key", "NewValue");
EXPECT_EQ(cfg->GetStr("TestSection", "Key"), "NewValue");
}
TEST_F(IniFileTest, OverwriteInt_ReturnsNewValue) {
cfg->SetInt("TestSection", "IntKey", 100);
cfg->SetInt("TestSection", "IntKey", 200);
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 200);
}
// ============================================
// binFile 基本读写测试
// ============================================
TEST_F(BinFileTest, SetStr_GetStr_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetStr("TestSection", "TestKey", "BinaryHello"));
EXPECT_EQ(cfg->GetStr("TestSection", "TestKey"), "BinaryHello");
}
TEST_F(BinFileTest, SetInt_GetInt_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetInt("TestSection", "IntKey", 0x12345678));
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 0x12345678);
}
TEST_F(BinFileTest, SetInt_NegativeValue_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetInt("TestSection", "NegKey", -12345));
EXPECT_EQ(cfg->GetInt("TestSection", "NegKey"), -12345);
}
TEST_F(BinFileTest, GetInt_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetInt("NonExistent", "Key", 999), 999);
}
TEST_F(BinFileTest, GetStr_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetStr("NonExistent", "Key", "Default"), "Default");
}
// ============================================
// 并发测试(验证注册表操作的稳定性)
// ============================================
TEST_F(IniFileTest, Concurrent_ReadWrite_NoCorruption) {
std::atomic<bool> running{true};
std::atomic<int> writeCount{0};
std::atomic<int> readCount{0};
// 写线程
std::thread writer([&]() {
int i = 0;
while (running) {
std::string key = "Key" + std::to_string(i % 10);
std::string value = "Value" + std::to_string(i);
if (cfg->SetStr("ConcurrentTest", key, value)) {
writeCount++;
}
i++;
std::this_thread::yield();
}
});
// 读线程
std::thread reader([&]() {
while (running) {
for (int i = 0; i < 10; i++) {
std::string key = "Key" + std::to_string(i);
cfg->GetStr("ConcurrentTest", key, "");
readCount++;
}
std::this_thread::yield();
}
});
// 运行 100ms
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
writer.join();
reader.join();
EXPECT_GT(writeCount.load(), 0);
EXPECT_GT(readCount.load(), 0);
}
TEST_F(IniFileTest, Concurrent_MultipleWriters_NoCrash) {
std::atomic<bool> running{true};
std::vector<std::thread> writers;
for (int t = 0; t < 4; t++) {
writers.emplace_back([&, t]() {
int i = 0;
while (running) {
std::string section = "Section" + std::to_string(t);
std::string key = "Key" + std::to_string(i % 5);
cfg->SetInt(section, key, i);
cfg->GetInt(section, key);
i++;
std::this_thread::yield();
}
});
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
for (auto& t : writers) {
t.join();
}
// 无崩溃即为成功
SUCCEED();
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(IniFileTest, BoundaryCondition_MaxIntValue) {
cfg->SetInt("BoundaryTest", "MaxInt", INT_MAX);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "MaxInt"), INT_MAX);
}
TEST_F(IniFileTest, BoundaryCondition_MinIntValue) {
cfg->SetInt("BoundaryTest", "MinInt", INT_MIN);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "MinInt"), INT_MIN);
}
TEST_F(IniFileTest, BoundaryCondition_ZeroValue) {
cfg->SetInt("BoundaryTest", "Zero", 0);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "Zero", -1), 0);
}
TEST_F(BinFileTest, BoundaryCondition_ZeroInt) {
cfg->SetInt("BoundaryTest", "Zero", 0);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "Zero", -1), 0);
}
TEST_F(BinFileTest, BoundaryCondition_NegativeMax) {
cfg->SetInt("BoundaryTest", "NegMax", INT_MIN);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "NegMax"), INT_MIN);
}
// ============================================
// 性能测试
// ============================================
TEST_F(IniFileTest, Performance_CachedAccess_FastEnough) {
// 预热缓存
cfg->SetStr("PerfTest", "Key", "InitialValue");
auto start = std::chrono::high_resolution_clock::now();
const int iterations = 1000;
for (int i = 0; i < iterations; i++) {
cfg->SetStr("PerfTest", "Key", "Value" + std::to_string(i));
cfg->GetStr("PerfTest", "Key");
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
// 1000 次读写应在 2 秒内完成
EXPECT_LT(duration.count(), 2000) << "Performance: " << duration.count() << "ms for 1000 iterations";
}
#else
// 非 Windows 平台跳过测试
TEST(RegistryConfigTest, SkippedOnNonWindows) {
GTEST_SKIP() << "Registry tests only run on Windows";
}
#endif

View File

@@ -0,0 +1,629 @@
/**
* @file ChunkManagerTest.cpp
* @brief 分块管理和区间跟踪测试
*
* 测试覆盖:
* - FileRangeV2 区间操作
* - 接收状态跟踪
* - 区间合并算法
* - 断点续传区间计算
*/
#include <gtest/gtest.h>
#include <vector>
#include <algorithm>
#include <cstdint>
// ============================================
// 区间结构(测试专用)
// ============================================
struct Range {
uint64_t offset;
uint64_t length;
Range(uint64_t o = 0, uint64_t l = 0) : offset(o), length(l) {}
uint64_t end() const { return offset + length; }
bool operator<(const Range& other) const {
return offset < other.offset;
}
bool operator==(const Range& other) const {
return offset == other.offset && length == other.length;
}
};
// ============================================
// 区间管理类(断点续传核心逻辑)
// ============================================
class RangeManager {
public:
RangeManager(uint64_t fileSize = 0) : m_fileSize(fileSize), m_receivedBytes(0) {}
// 添加已接收区间
void AddRange(uint64_t offset, uint64_t length) {
if (length == 0) return;
Range newRange(offset, length);
m_ranges.push_back(newRange);
MergeRanges();
UpdateReceivedBytes();
}
// 获取已接收区间列表
const std::vector<Range>& GetRanges() const {
return m_ranges;
}
// 获取缺失区间列表
std::vector<Range> GetMissingRanges() const {
std::vector<Range> missing;
if (m_fileSize == 0) return missing;
uint64_t currentPos = 0;
for (const auto& r : m_ranges) {
if (r.offset > currentPos) {
missing.emplace_back(currentPos, r.offset - currentPos);
}
currentPos = r.end();
}
if (currentPos < m_fileSize) {
missing.emplace_back(currentPos, m_fileSize - currentPos);
}
return missing;
}
// 获取已接收字节数
uint64_t GetReceivedBytes() const {
return m_receivedBytes;
}
// 是否完整接收
bool IsComplete() const {
return m_receivedBytes >= m_fileSize && m_fileSize > 0;
}
// 清空
void Clear() {
m_ranges.clear();
m_receivedBytes = 0;
}
// 设置文件大小
void SetFileSize(uint64_t size) {
m_fileSize = size;
}
uint64_t GetFileSize() const {
return m_fileSize;
}
private:
void MergeRanges() {
if (m_ranges.size() <= 1) return;
std::sort(m_ranges.begin(), m_ranges.end());
std::vector<Range> merged;
merged.push_back(m_ranges[0]);
for (size_t i = 1; i < m_ranges.size(); ++i) {
Range& last = merged.back();
const Range& current = m_ranges[i];
// 检查是否可以合并(相邻或重叠)
if (current.offset <= last.end()) {
// 扩展现有区间
uint64_t newEnd = std::max(last.end(), current.end());
last.length = newEnd - last.offset;
} else {
// 新的独立区间
merged.push_back(current);
}
}
m_ranges = std::move(merged);
}
void UpdateReceivedBytes() {
m_receivedBytes = 0;
for (const auto& r : m_ranges) {
m_receivedBytes += r.length;
}
}
std::vector<Range> m_ranges;
uint64_t m_fileSize;
uint64_t m_receivedBytes;
};
// ============================================
// 接收状态类(模拟 FileRecvStateV2
// ============================================
class FileRecvState {
public:
FileRecvState() : m_fileSize(0), m_fileIndex(0), m_transferID(0) {}
void Initialize(uint64_t transferID, uint32_t fileIndex, uint64_t fileSize, const std::string& filePath) {
m_transferID = transferID;
m_fileIndex = fileIndex;
m_fileSize = fileSize;
m_filePath = filePath;
m_rangeManager.SetFileSize(fileSize);
}
void AddChunk(uint64_t offset, uint64_t length) {
m_rangeManager.AddRange(offset, length);
}
bool IsComplete() const {
return m_rangeManager.IsComplete();
}
uint64_t GetReceivedBytes() const {
return m_rangeManager.GetReceivedBytes();
}
double GetProgress() const {
if (m_fileSize == 0) return 0.0;
return static_cast<double>(GetReceivedBytes()) / m_fileSize * 100.0;
}
std::vector<Range> GetMissingRanges() const {
return m_rangeManager.GetMissingRanges();
}
const std::vector<Range>& GetReceivedRanges() const {
return m_rangeManager.GetRanges();
}
uint64_t GetTransferID() const { return m_transferID; }
uint32_t GetFileIndex() const { return m_fileIndex; }
uint64_t GetFileSize() const { return m_fileSize; }
const std::string& GetFilePath() const { return m_filePath; }
private:
uint64_t m_transferID;
uint32_t m_fileIndex;
uint64_t m_fileSize;
std::string m_filePath;
RangeManager m_rangeManager;
};
// ============================================
// Range 基础测试
// ============================================
class RangeTest : public ::testing::Test {};
TEST_F(RangeTest, DefaultConstruction) {
Range r;
EXPECT_EQ(r.offset, 0u);
EXPECT_EQ(r.length, 0u);
EXPECT_EQ(r.end(), 0u);
}
TEST_F(RangeTest, Construction) {
Range r(100, 200);
EXPECT_EQ(r.offset, 100u);
EXPECT_EQ(r.length, 200u);
EXPECT_EQ(r.end(), 300u);
}
TEST_F(RangeTest, Comparison) {
Range r1(0, 100);
Range r2(100, 100);
Range r3(0, 100);
EXPECT_TRUE(r1 < r2);
EXPECT_FALSE(r2 < r1);
EXPECT_TRUE(r1 == r3);
EXPECT_FALSE(r1 == r2);
}
// ============================================
// RangeManager 测试
// ============================================
class RangeManagerTest : public ::testing::Test {};
TEST_F(RangeManagerTest, InitialState) {
RangeManager rm(1000);
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_EQ(rm.GetFileSize(), 1000u);
EXPECT_FALSE(rm.IsComplete());
EXPECT_TRUE(rm.GetRanges().empty());
}
TEST_F(RangeManagerTest, AddSingleRange) {
RangeManager rm(1000);
rm.AddRange(0, 500);
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 500u);
}
TEST_F(RangeManagerTest, AddZeroLengthRange) {
RangeManager rm(1000);
rm.AddRange(0, 0);
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_TRUE(rm.GetRanges().empty());
}
TEST_F(RangeManagerTest, AddNonOverlappingRanges) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(200, 100);
rm.AddRange(400, 100);
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
ASSERT_EQ(rm.GetRanges().size(), 3u);
}
TEST_F(RangeManagerTest, MergeAdjacentRanges) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(100, 100); // 紧邻
EXPECT_EQ(rm.GetReceivedBytes(), 200u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 200u);
}
TEST_F(RangeManagerTest, MergeOverlappingRanges) {
RangeManager rm(1000);
rm.AddRange(0, 150);
rm.AddRange(100, 150); // 重叠
EXPECT_EQ(rm.GetReceivedBytes(), 250u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 250u);
}
TEST_F(RangeManagerTest, MergeContainedRange) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.AddRange(100, 100); // 完全被包含
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].length, 500u);
}
TEST_F(RangeManagerTest, MergeOutOfOrder) {
RangeManager rm(1000);
rm.AddRange(500, 100);
rm.AddRange(100, 100);
rm.AddRange(0, 100);
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
// 0-100 和 100-200 被合并成 0-200加上 500-600共 2 个区间
ASSERT_EQ(rm.GetRanges().size(), 2u);
// 验证排序和合并
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 200u); // 合并后
EXPECT_EQ(rm.GetRanges()[1].offset, 500u);
EXPECT_EQ(rm.GetRanges()[1].length, 100u);
}
TEST_F(RangeManagerTest, MergeMultipleOverlapping) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(200, 100);
rm.AddRange(50, 200); // 跨越两个区间
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 300u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Empty) {
RangeManager rm(1000);
auto missing = rm.GetMissingRanges();
ASSERT_EQ(missing.size(), 1u);
EXPECT_EQ(missing[0].offset, 0u);
EXPECT_EQ(missing[0].length, 1000u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Partial) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(500, 100);
auto missing = rm.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
EXPECT_EQ(missing[0].offset, 100u);
EXPECT_EQ(missing[0].length, 400u);
EXPECT_EQ(missing[1].offset, 600u);
EXPECT_EQ(missing[1].length, 400u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Complete) {
RangeManager rm(1000);
rm.AddRange(0, 1000);
auto missing = rm.GetMissingRanges();
EXPECT_TRUE(missing.empty());
}
TEST_F(RangeManagerTest, IsComplete_Exact) {
RangeManager rm(1000);
rm.AddRange(0, 1000);
EXPECT_TRUE(rm.IsComplete());
}
TEST_F(RangeManagerTest, IsComplete_OverReceived) {
RangeManager rm(1000);
rm.AddRange(0, 1500); // 超过文件大小
EXPECT_TRUE(rm.IsComplete());
}
TEST_F(RangeManagerTest, IsComplete_Partial) {
RangeManager rm(1000);
rm.AddRange(0, 999);
EXPECT_FALSE(rm.IsComplete());
}
TEST_F(RangeManagerTest, Clear) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.Clear();
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_TRUE(rm.GetRanges().empty());
}
// ============================================
// FileRecvState 测试
// ============================================
class FileRecvStateTest : public ::testing::Test {};
TEST_F(FileRecvStateTest, Initialize) {
FileRecvState state;
state.Initialize(12345, 0, 1024, "C:\\test\\file.txt");
EXPECT_EQ(state.GetTransferID(), 12345u);
EXPECT_EQ(state.GetFileIndex(), 0u);
EXPECT_EQ(state.GetFileSize(), 1024u);
EXPECT_EQ(state.GetFilePath(), "C:\\test\\file.txt");
EXPECT_EQ(state.GetReceivedBytes(), 0u);
EXPECT_FALSE(state.IsComplete());
}
TEST_F(FileRecvStateTest, AddChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(0, 100);
EXPECT_EQ(state.GetReceivedBytes(), 100u);
EXPECT_NEAR(state.GetProgress(), 10.0, 0.01);
state.AddChunk(100, 400);
EXPECT_EQ(state.GetReceivedBytes(), 500u);
EXPECT_NEAR(state.GetProgress(), 50.0, 0.01);
state.AddChunk(500, 500);
EXPECT_EQ(state.GetReceivedBytes(), 1000u);
EXPECT_TRUE(state.IsComplete());
EXPECT_NEAR(state.GetProgress(), 100.0, 0.01);
}
TEST_F(FileRecvStateTest, OutOfOrderChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(500, 200);
state.AddChunk(0, 200);
state.AddChunk(800, 200);
EXPECT_EQ(state.GetReceivedBytes(), 600u);
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
EXPECT_EQ(missing[0].offset, 200u);
EXPECT_EQ(missing[0].length, 300u);
EXPECT_EQ(missing[1].offset, 700u);
EXPECT_EQ(missing[1].length, 100u);
}
TEST_F(FileRecvStateTest, DuplicateChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(0, 500);
state.AddChunk(0, 500); // 重复
state.AddChunk(250, 250); // 重叠
EXPECT_EQ(state.GetReceivedBytes(), 500u);
}
// ============================================
// 断点续传场景测试
// ============================================
class ResumeScenarioTest : public ::testing::Test {};
TEST_F(ResumeScenarioTest, SimulateInterruptedTransfer) {
FileRecvState state;
state.Initialize(12345, 0, 10000, "large_file.bin");
// 模拟接收了一些数据后中断
state.AddChunk(0, 2000);
state.AddChunk(2000, 2000);
state.AddChunk(5000, 1000);
EXPECT_EQ(state.GetReceivedBytes(), 5000u);
EXPECT_NEAR(state.GetProgress(), 50.0, 0.01);
// 获取需要续传的区间
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
// 验证缺失区间
EXPECT_EQ(missing[0].offset, 4000u);
EXPECT_EQ(missing[0].length, 1000u);
EXPECT_EQ(missing[1].offset, 6000u);
EXPECT_EQ(missing[1].length, 4000u);
}
TEST_F(ResumeScenarioTest, ResumeAndComplete) {
FileRecvState state;
state.Initialize(12345, 0, 10000, "large_file.bin");
// 初始接收
state.AddChunk(0, 3000);
state.AddChunk(7000, 3000);
EXPECT_FALSE(state.IsComplete());
// 续传缺失部分
auto missing = state.GetMissingRanges();
for (const auto& r : missing) {
state.AddChunk(r.offset, r.length);
}
EXPECT_TRUE(state.IsComplete());
EXPECT_EQ(state.GetReceivedBytes(), 10000u);
}
TEST_F(ResumeScenarioTest, SmallChunksReassembly) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
// 模拟接收很多小块
for (uint64_t i = 0; i < 1000; i += 10) {
state.AddChunk(i, 10);
}
EXPECT_TRUE(state.IsComplete());
// 验证区间已合并
const auto& ranges = state.GetReceivedRanges();
EXPECT_EQ(ranges.size(), 1u);
EXPECT_EQ(ranges[0].offset, 0u);
EXPECT_EQ(ranges[0].length, 1000u);
}
// ============================================
// 大文件场景测试
// ============================================
class LargeFileScenarioTest : public ::testing::Test {};
TEST_F(LargeFileScenarioTest, FileGreaterThan4GB) {
FileRecvState state;
uint64_t fileSize = 5ULL * 1024 * 1024 * 1024; // 5 GB
state.Initialize(1, 0, fileSize, "huge.bin");
uint64_t chunkSize = 64 * 1024; // 64 KB chunks
// 添加几个大区间
state.AddChunk(0, 1ULL * 1024 * 1024 * 1024); // 1 GB
state.AddChunk(3ULL * 1024 * 1024 * 1024, 1ULL * 1024 * 1024 * 1024); // 1 GB at 3GB
EXPECT_EQ(state.GetReceivedBytes(), 2ULL * 1024 * 1024 * 1024);
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
// 1GB-3GB 缺失
EXPECT_EQ(missing[0].offset, 1ULL * 1024 * 1024 * 1024);
EXPECT_EQ(missing[0].length, 2ULL * 1024 * 1024 * 1024);
// 4GB-5GB 缺失
EXPECT_EQ(missing[1].offset, 4ULL * 1024 * 1024 * 1024);
EXPECT_EQ(missing[1].length, 1ULL * 1024 * 1024 * 1024);
}
// ============================================
// 边界条件测试
// ============================================
class ChunkBoundaryTest : public ::testing::Test {};
TEST_F(ChunkBoundaryTest, ZeroFileSize) {
FileRecvState state;
state.Initialize(1, 0, 0, "empty.txt");
EXPECT_FALSE(state.IsComplete()); // 0大小文件不算完成
EXPECT_TRUE(state.GetMissingRanges().empty());
}
TEST_F(ChunkBoundaryTest, SingleByteFile) {
FileRecvState state;
state.Initialize(1, 0, 1, "tiny.txt");
state.AddChunk(0, 1);
EXPECT_TRUE(state.IsComplete());
}
TEST_F(ChunkBoundaryTest, MaxValues) {
RangeManager rm(UINT64_MAX);
// 添加接近最大值的区间
rm.AddRange(UINT64_MAX - 1000, 500);
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
}
TEST_F(ChunkBoundaryTest, OverlappingAtBoundary) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.AddRange(499, 2); // 重叠1字节
EXPECT_EQ(rm.GetReceivedBytes(), 501u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
}
// ============================================
// 性能相关测试
// ============================================
class ChunkPerformanceTest : public ::testing::Test {};
TEST_F(ChunkPerformanceTest, ManySmallRanges) {
RangeManager rm(1000000);
// 添加大量不连续的小区间
for (uint64_t i = 0; i < 1000000; i += 20) {
rm.AddRange(i, 10);
}
// 验证区间数量合理
EXPECT_LE(rm.GetRanges().size(), 50000u);
EXPECT_EQ(rm.GetReceivedBytes(), 500000u);
}
TEST_F(ChunkPerformanceTest, ManyContiguousRanges) {
RangeManager rm(1000000);
// 添加大量连续的小区间(应该全部合并)
for (uint64_t i = 0; i < 1000; ++i) {
rm.AddRange(i * 1000, 1000);
}
// 应该合并成单个区间
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetReceivedBytes(), 1000000u);
EXPECT_TRUE(rm.IsComplete());
}

View File

@@ -0,0 +1,700 @@
/**
* @file FileTransferV2Test.cpp
* @brief V2 文件传输逻辑测试
*
* 测试覆盖:
* - TransferOptionsV2 结构体
* - 传输ID生成
* - 包头构建与解析
* - 变长数据包处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <map>
#include <random>
#include <chrono>
// ============================================
// 从 file_upload.h 复制的结构体定义(测试专用)
// ============================================
#pragma pack(push, 1)
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_LAST_CHUNK = 0x0001,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
FFV2_CANCEL = 0x0008,
FFV2_DIRECTORY = 0x0010,
FFV2_COMPRESSED = 0x0020,
FFV2_ERROR = 0x0040,
};
struct FileChunkPacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint32_t totalFiles;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
uint16_t flags;
uint16_t checksum;
uint8_t reserved[8];
};
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
};
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
};
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
};
struct FileQueryResumeV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileCount;
};
struct FileQueryResumeEntryV2 {
uint64_t fileSize;
uint16_t nameLength;
};
struct FileResumeResponseV2 {
uint8_t cmd;
uint64_t srcClientID;
uint64_t dstClientID;
uint16_t flags;
uint32_t fileCount;
};
struct FileResumeResponseEntryV2 {
uint32_t fileIndex;
uint64_t receivedBytes;
};
#pragma pack(pop)
// V2 传输选项
struct TransferOptionsV2 {
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
bool enableResume;
std::map<uint32_t, uint64_t> startOffsets;
TransferOptionsV2() : transferID(0), srcClientID(0), dstClientID(0), enableResume(true) {}
};
// ============================================
// 辅助函数(测试专用实现)
// ============================================
// 生成传输ID简化实现
uint64_t GenerateTransferID() {
static std::mt19937_64 rng(
static_cast<uint64_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
);
return rng();
}
// 构建文件块包
std::vector<uint8_t> BuildFileChunkPacketV2(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint32_t totalFiles,
uint64_t fileSize,
uint64_t offset,
const std::string& filename,
const std::vector<uint8_t>& data,
uint16_t flags = FFV2_NONE)
{
size_t totalSize = sizeof(FileChunkPacketV2) + filename.size() + data.size();
std::vector<uint8_t> buffer(totalSize);
FileChunkPacketV2* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = 85; // COMMAND_SEND_FILE_V2
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->totalFiles = totalFiles;
pkt->fileSize = fileSize;
pkt->offset = offset;
pkt->dataLength = data.size();
pkt->nameLength = filename.size();
pkt->flags = flags;
pkt->checksum = 0;
memset(pkt->reserved, 0, sizeof(pkt->reserved));
// 追加文件名
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename.data(), filename.size());
// 追加数据
if (!data.empty()) {
memcpy(buffer.data() + sizeof(FileChunkPacketV2) + filename.size(), data.data(), data.size());
}
return buffer;
}
// 解析文件块包
bool ParseFileChunkPacketV2(
const uint8_t* buffer, size_t len,
FileChunkPacketV2& header,
std::string& filename,
std::vector<uint8_t>& data)
{
if (len < sizeof(FileChunkPacketV2)) {
return false;
}
memcpy(&header, buffer, sizeof(FileChunkPacketV2));
size_t expectedLen = sizeof(FileChunkPacketV2) + header.nameLength + header.dataLength;
if (len < expectedLen) {
return false;
}
const char* namePtr = reinterpret_cast<const char*>(buffer + sizeof(FileChunkPacketV2));
filename.assign(namePtr, header.nameLength);
if (header.dataLength > 0) {
const uint8_t* dataPtr = buffer + sizeof(FileChunkPacketV2) + header.nameLength;
data.assign(dataPtr, dataPtr + header.dataLength);
} else {
data.clear();
}
return true;
}
// 构建续传查询包
std::vector<uint8_t> BuildResumeQuery(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
const std::vector<std::pair<std::string, uint64_t>>& files)
{
// 计算总大小
size_t totalSize = sizeof(FileQueryResumeV2);
for (const auto& file : files) {
totalSize += sizeof(FileQueryResumeEntryV2) + file.first.size();
}
std::vector<uint8_t> buffer(totalSize);
FileQueryResumeV2* pkt = reinterpret_cast<FileQueryResumeV2*>(buffer.data());
pkt->cmd = 88; // COMMAND_FILE_QUERY_RESUME
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileCount = static_cast<uint32_t>(files.size());
uint8_t* ptr = buffer.data() + sizeof(FileQueryResumeV2);
for (const auto& file : files) {
FileQueryResumeEntryV2* entry = reinterpret_cast<FileQueryResumeEntryV2*>(ptr);
entry->fileSize = file.second;
entry->nameLength = static_cast<uint16_t>(file.first.size());
ptr += sizeof(FileQueryResumeEntryV2);
memcpy(ptr, file.first.data(), file.first.size());
ptr += file.first.size();
}
return buffer;
}
// 解析续传查询包
bool ParseResumeQuery(
const uint8_t* buffer, size_t len,
uint64_t& transferID,
uint64_t& srcClientID,
uint64_t& dstClientID,
std::vector<std::pair<std::string, uint64_t>>& files)
{
if (len < sizeof(FileQueryResumeV2)) {
return false;
}
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer);
transferID = pkt->transferID;
srcClientID = pkt->srcClientID;
dstClientID = pkt->dstClientID;
files.clear();
const uint8_t* ptr = buffer + sizeof(FileQueryResumeV2);
const uint8_t* end = buffer + len;
for (uint32_t i = 0; i < pkt->fileCount; ++i) {
if (ptr + sizeof(FileQueryResumeEntryV2) > end) {
return false;
}
const FileQueryResumeEntryV2* entry = reinterpret_cast<const FileQueryResumeEntryV2*>(ptr);
ptr += sizeof(FileQueryResumeEntryV2);
if (ptr + entry->nameLength > end) {
return false;
}
std::string name(reinterpret_cast<const char*>(ptr), entry->nameLength);
ptr += entry->nameLength;
files.emplace_back(name, entry->fileSize);
}
return true;
}
// ============================================
// TransferOptionsV2 测试
// ============================================
class TransferOptionsV2Test : public ::testing::Test {};
TEST_F(TransferOptionsV2Test, DefaultValues) {
TransferOptionsV2 options;
EXPECT_EQ(options.transferID, 0u);
EXPECT_EQ(options.srcClientID, 0u);
EXPECT_EQ(options.dstClientID, 0u);
EXPECT_TRUE(options.enableResume);
EXPECT_TRUE(options.startOffsets.empty());
}
TEST_F(TransferOptionsV2Test, SetStartOffsets) {
TransferOptionsV2 options;
options.startOffsets[0] = 1024;
options.startOffsets[1] = 2048;
options.startOffsets[2] = 0;
EXPECT_EQ(options.startOffsets.size(), 3u);
EXPECT_EQ(options.startOffsets[0], 1024u);
EXPECT_EQ(options.startOffsets[1], 2048u);
EXPECT_EQ(options.startOffsets[2], 0u);
}
TEST_F(TransferOptionsV2Test, C2CConfiguration) {
TransferOptionsV2 options;
options.transferID = 12345;
options.srcClientID = 100;
options.dstClientID = 200;
EXPECT_EQ(options.transferID, 12345u);
EXPECT_EQ(options.srcClientID, 100u);
EXPECT_EQ(options.dstClientID, 200u);
}
// ============================================
// TransferID 生成测试
// ============================================
class TransferIDTest : public ::testing::Test {};
TEST_F(TransferIDTest, GeneratesNonZero) {
// 多次生成应该都不为 0极低概率
for (int i = 0; i < 100; ++i) {
uint64_t id = GenerateTransferID();
// 不做强制断言,因为理论上可能为 0
// 但实际概率极低
if (id == 0) {
// 如果真的生成了 0再生成一次
id = GenerateTransferID();
}
}
SUCCEED();
}
TEST_F(TransferIDTest, GeneratesUnique) {
std::vector<uint64_t> ids;
for (int i = 0; i < 1000; ++i) {
ids.push_back(GenerateTransferID());
}
// 检查没有重复
std::sort(ids.begin(), ids.end());
auto it = std::unique(ids.begin(), ids.end());
EXPECT_EQ(it, ids.end()) << "Found duplicate transfer IDs";
}
// ============================================
// FileChunkPacketV2 构建测试
// ============================================
class FileChunkPacketV2BuildTest : public ::testing::Test {};
TEST_F(FileChunkPacketV2BuildTest, BuildSimplePacket) {
uint64_t transferID = 12345;
std::string filename = "test.txt";
std::vector<uint8_t> data = {0x01, 0x02, 0x03, 0x04};
auto buffer = BuildFileChunkPacketV2(
transferID, 0, 0, 0, 1, 100, 0, filename, data);
EXPECT_EQ(buffer.size(), sizeof(FileChunkPacketV2) + filename.size() + data.size());
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 85);
EXPECT_EQ(pkt->transferID, transferID);
EXPECT_EQ(pkt->nameLength, filename.size());
EXPECT_EQ(pkt->dataLength, data.size());
}
TEST_F(FileChunkPacketV2BuildTest, BuildWithFlags) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 100, 50, "file.bin", {}, FFV2_LAST_CHUNK);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags, FFV2_LAST_CHUNK);
EXPECT_EQ(pkt->offset, 50u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildC2CPacket) {
auto buffer = BuildFileChunkPacketV2(
999, 100, 200, 0, 3, 1024, 0, "shared/doc.pdf", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->srcClientID, 100u);
EXPECT_EQ(pkt->dstClientID, 200u);
EXPECT_EQ(pkt->totalFiles, 3u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildEmptyDataPacket) {
std::vector<uint8_t> emptyData;
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 0, 0, "empty.txt", emptyData);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->dataLength, 0u);
EXPECT_EQ(pkt->fileSize, 0u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildDirectoryPacket) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 0, 0, "subdir/", {}, FFV2_DIRECTORY);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_DIRECTORY, FFV2_DIRECTORY);
EXPECT_EQ(pkt->dataLength, 0u);
}
// ============================================
// FileChunkPacketV2 解析测试
// ============================================
class FileChunkPacketV2ParseTest : public ::testing::Test {};
TEST_F(FileChunkPacketV2ParseTest, ParseValidPacket) {
std::string originalName = "test/file.txt";
std::vector<uint8_t> originalData = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE};
auto buffer = BuildFileChunkPacketV2(
12345, 100, 200, 3, 10, 1024 * 1024, 512, originalName, originalData, FFV2_COMPRESSED);
FileChunkPacketV2 header;
std::string parsedName;
std::vector<uint8_t> parsedData;
bool result = ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, parsedName, parsedData);
EXPECT_TRUE(result);
EXPECT_EQ(header.transferID, 12345u);
EXPECT_EQ(header.srcClientID, 100u);
EXPECT_EQ(header.dstClientID, 200u);
EXPECT_EQ(header.fileIndex, 3u);
EXPECT_EQ(header.totalFiles, 10u);
EXPECT_EQ(header.fileSize, 1024u * 1024u);
EXPECT_EQ(header.offset, 512u);
EXPECT_EQ(header.flags, FFV2_COMPRESSED);
EXPECT_EQ(parsedName, originalName);
EXPECT_EQ(parsedData, originalData);
}
TEST_F(FileChunkPacketV2ParseTest, ParseTruncatedHeader) {
std::vector<uint8_t> truncated(sizeof(FileChunkPacketV2) - 10);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
bool result = ParseFileChunkPacketV2(truncated.data(), truncated.size(), header, name, data);
EXPECT_FALSE(result);
}
TEST_F(FileChunkPacketV2ParseTest, ParseTruncatedData) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 100, 0, "file.txt", {0x01, 0x02, 0x03});
// 截断数据部分
std::vector<uint8_t> truncated(buffer.begin(), buffer.end() - 2);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
bool result = ParseFileChunkPacketV2(truncated.data(), truncated.size(), header, name, data);
EXPECT_FALSE(result);
}
TEST_F(FileChunkPacketV2ParseTest, RoundTrip) {
// 构建各种类型的包并验证往返
struct TestCase {
uint64_t transferID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t offset;
std::string filename;
std::vector<uint8_t> data;
uint16_t flags;
};
std::vector<TestCase> cases = {
{1, 0, 100, 0, "a.txt", {0x01}, FFV2_NONE},
{UINT64_MAX, 999, UINT64_MAX, UINT64_MAX - 100, "path/to/file.bin", {}, FFV2_LAST_CHUNK},
{12345, 5, 1024, 512, "中文文件.txt", {0xAA, 0xBB, 0xCC}, FFV2_COMPRESSED | FFV2_LAST_CHUNK},
};
for (const auto& tc : cases) {
auto buffer = BuildFileChunkPacketV2(
tc.transferID, 0, 0, tc.fileIndex, 10, tc.fileSize, tc.offset, tc.filename, tc.data, tc.flags);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
ASSERT_TRUE(ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, name, data));
EXPECT_EQ(header.transferID, tc.transferID);
EXPECT_EQ(header.fileIndex, tc.fileIndex);
EXPECT_EQ(header.fileSize, tc.fileSize);
EXPECT_EQ(header.offset, tc.offset);
EXPECT_EQ(header.flags, tc.flags);
EXPECT_EQ(name, tc.filename);
EXPECT_EQ(data, tc.data);
}
}
// ============================================
// 续传查询包测试
// ============================================
class ResumeQueryTest : public ::testing::Test {};
TEST_F(ResumeQueryTest, BuildEmpty) {
std::vector<std::pair<std::string, uint64_t>> files;
auto buffer = BuildResumeQuery(12345, 100, 200, files);
EXPECT_EQ(buffer.size(), sizeof(FileQueryResumeV2));
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 88);
EXPECT_EQ(pkt->fileCount, 0u);
}
TEST_F(ResumeQueryTest, BuildSingleFile) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file.txt", 1024}
};
auto buffer = BuildResumeQuery(12345, 100, 200, files);
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->fileCount, 1u);
}
TEST_F(ResumeQueryTest, BuildMultipleFiles) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 1024},
{"dir/file2.bin", 2048},
{"another/path/file3.dat", 4096}
};
auto buffer = BuildResumeQuery(12345, 100, 200, files);
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->fileCount, 3u);
}
TEST_F(ResumeQueryTest, ParseRoundTrip) {
std::vector<std::pair<std::string, uint64_t>> original = {
{"file1.txt", 1024},
{"subdir/file2.bin", UINT64_MAX},
{"中文/文件.txt", 0}
};
auto buffer = BuildResumeQuery(99999, 111, 222, original);
uint64_t transferID, srcClientID, dstClientID;
std::vector<std::pair<std::string, uint64_t>> parsed;
bool result = ParseResumeQuery(buffer.data(), buffer.size(), transferID, srcClientID, dstClientID, parsed);
EXPECT_TRUE(result);
EXPECT_EQ(transferID, 99999u);
EXPECT_EQ(srcClientID, 111u);
EXPECT_EQ(dstClientID, 222u);
ASSERT_EQ(parsed.size(), original.size());
for (size_t i = 0; i < original.size(); ++i) {
EXPECT_EQ(parsed[i].first, original[i].first);
EXPECT_EQ(parsed[i].second, original[i].second);
}
}
TEST_F(ResumeQueryTest, ParseTruncated) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file.txt", 1024}
};
auto buffer = BuildResumeQuery(1, 0, 0, files);
// 截断
std::vector<uint8_t> truncated(buffer.begin(), buffer.begin() + sizeof(FileQueryResumeV2) + 5);
uint64_t transferID, srcClientID, dstClientID;
std::vector<std::pair<std::string, uint64_t>> parsed;
bool result = ParseResumeQuery(truncated.data(), truncated.size(), transferID, srcClientID, dstClientID, parsed);
EXPECT_FALSE(result);
}
// ============================================
// 大文件支持测试
// ============================================
class LargeFileTest : public ::testing::Test {};
TEST_F(LargeFileTest, FileSize_GreaterThan4GB) {
uint64_t largeSize = 5ULL * 1024 * 1024 * 1024; // 5 GB
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, largeSize, 0, "large.bin", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileSize, largeSize);
}
TEST_F(LargeFileTest, Offset_GreaterThan4GB) {
uint64_t largeOffset = 10ULL * 1024 * 1024 * 1024; // 10 GB
uint64_t fileSize = 20ULL * 1024 * 1024 * 1024; // 20 GB
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, fileSize, largeOffset, "huge.bin", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileSize, fileSize);
EXPECT_EQ(pkt->offset, largeOffset);
}
TEST_F(LargeFileTest, MaxValues) {
auto buffer = BuildFileChunkPacketV2(
UINT64_MAX, UINT64_MAX, UINT64_MAX,
UINT32_MAX, UINT32_MAX, UINT64_MAX, UINT64_MAX,
"max.bin", {}, UINT16_MAX);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
ASSERT_TRUE(ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, name, data));
EXPECT_EQ(header.transferID, UINT64_MAX);
EXPECT_EQ(header.srcClientID, UINT64_MAX);
EXPECT_EQ(header.dstClientID, UINT64_MAX);
EXPECT_EQ(header.fileIndex, UINT32_MAX);
EXPECT_EQ(header.totalFiles, UINT32_MAX);
EXPECT_EQ(header.fileSize, UINT64_MAX);
EXPECT_EQ(header.offset, UINT64_MAX);
}
// ============================================
// 多文件传输测试
// ============================================
class MultiFileTransferTest : public ::testing::Test {};
TEST_F(MultiFileTransferTest, SequentialFileIndices) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 100},
{"file2.txt", 200},
{"file3.txt", 300}
};
for (uint32_t i = 0; i < files.size(); ++i) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, i, static_cast<uint32_t>(files.size()),
files[i].second, 0, files[i].first, {}, FFV2_LAST_CHUNK);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileIndex, i);
EXPECT_EQ(pkt->totalFiles, files.size());
}
}
TEST_F(MultiFileTransferTest, ConsistentTransferID) {
uint64_t transferID = GenerateTransferID();
for (int i = 0; i < 5; ++i) {
auto buffer = BuildFileChunkPacketV2(
transferID, 0, 0, i, 5, 1000 * (i + 1), 0, "file" + std::to_string(i), {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->transferID, transferID);
}
}
// ============================================
// 错误处理测试
// ============================================
class ErrorHandlingTest : public ::testing::Test {};
TEST_F(ErrorHandlingTest, CancelFlag) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, FFV2_CANCEL);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_CANCEL, FFV2_CANCEL);
}
TEST_F(ErrorHandlingTest, ErrorFlag) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, FFV2_ERROR);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_ERROR, FFV2_ERROR);
}
TEST_F(ErrorHandlingTest, CombinedErrorFlags) {
uint16_t flags = FFV2_ERROR | FFV2_CANCEL | FFV2_LAST_CHUNK;
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, flags);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_TRUE(pkt->flags & FFV2_ERROR);
EXPECT_TRUE(pkt->flags & FFV2_CANCEL);
EXPECT_TRUE(pkt->flags & FFV2_LAST_CHUNK);
}

View File

@@ -0,0 +1,778 @@
/**
* @file ResumeStateTest.cpp
* @brief 断点续传状态管理测试
*
* 测试覆盖:
* - 续传状态序列化/反序列化
* - 续传请求/响应包构建
* - 状态文件格式
* - 多文件续传管理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <map>
#include <string>
#include <cstdint>
// ============================================
// 协议结构(测试专用)
// ============================================
#pragma pack(push, 1)
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
};
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
};
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
};
struct FileResumeResponseV2 {
uint8_t cmd;
uint64_t srcClientID;
uint64_t dstClientID;
uint16_t flags;
uint32_t fileCount;
};
struct FileResumeResponseEntryV2 {
uint32_t fileIndex;
uint64_t receivedBytes;
};
#pragma pack(pop)
// ============================================
// 续传状态管理类(测试专用实现)
// ============================================
struct FileResumeEntry {
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
std::string relativePath;
std::vector<std::pair<uint64_t, uint64_t>> receivedRanges;
};
class ResumeStateManager {
public:
ResumeStateManager() : m_transferID(0), m_srcClientID(0), m_dstClientID(0) {}
void Initialize(uint64_t transferID, uint64_t srcClientID, uint64_t dstClientID,
const std::string& targetDir) {
m_transferID = transferID;
m_srcClientID = srcClientID;
m_dstClientID = dstClientID;
m_targetDir = targetDir;
m_entries.clear();
}
void AddFile(uint32_t fileIndex, uint64_t fileSize, const std::string& path) {
FileResumeEntry entry;
entry.fileIndex = fileIndex;
entry.fileSize = fileSize;
entry.receivedBytes = 0;
entry.relativePath = path;
m_entries.push_back(entry);
}
void UpdateProgress(uint32_t fileIndex, uint64_t offset, uint64_t length) {
for (auto& entry : m_entries) {
if (entry.fileIndex == fileIndex) {
entry.receivedRanges.emplace_back(offset, length);
entry.receivedBytes += length;
break;
}
}
}
bool GetFileState(uint32_t fileIndex, FileResumeEntry& outEntry) const {
for (const auto& entry : m_entries) {
if (entry.fileIndex == fileIndex) {
outEntry = entry;
return true;
}
}
return false;
}
// 序列化为字节流
std::vector<uint8_t> Serialize() const {
std::vector<uint8_t> buffer;
// Header
auto appendU64 = [&buffer](uint64_t val) {
for (int i = 0; i < 8; ++i) {
buffer.push_back(static_cast<uint8_t>(val >> (i * 8)));
}
};
auto appendU32 = [&buffer](uint32_t val) {
for (int i = 0; i < 4; ++i) {
buffer.push_back(static_cast<uint8_t>(val >> (i * 8)));
}
};
auto appendU16 = [&buffer](uint16_t val) {
buffer.push_back(static_cast<uint8_t>(val & 0xFF));
buffer.push_back(static_cast<uint8_t>(val >> 8));
};
auto appendString = [&buffer, &appendU16](const std::string& str) {
appendU16(static_cast<uint16_t>(str.size()));
buffer.insert(buffer.end(), str.begin(), str.end());
};
// Magic
buffer.push_back('R');
buffer.push_back('S');
buffer.push_back('T');
buffer.push_back('V'); // Resume State V2
appendU64(m_transferID);
appendU64(m_srcClientID);
appendU64(m_dstClientID);
appendString(m_targetDir);
appendU32(static_cast<uint32_t>(m_entries.size()));
for (const auto& entry : m_entries) {
appendU32(entry.fileIndex);
appendU64(entry.fileSize);
appendU64(entry.receivedBytes);
appendString(entry.relativePath);
appendU16(static_cast<uint16_t>(entry.receivedRanges.size()));
for (const auto& range : entry.receivedRanges) {
appendU64(range.first);
appendU64(range.second);
}
}
return buffer;
}
// 从字节流反序列化
bool Deserialize(const std::vector<uint8_t>& buffer) {
if (buffer.size() < 8) return false;
size_t pos = 0;
auto readU64 = [&buffer, &pos]() -> uint64_t {
if (pos + 8 > buffer.size()) return 0;
uint64_t val = 0;
for (int i = 0; i < 8; ++i) {
val |= static_cast<uint64_t>(buffer[pos++]) << (i * 8);
}
return val;
};
auto readU32 = [&buffer, &pos]() -> uint32_t {
if (pos + 4 > buffer.size()) return 0;
uint32_t val = 0;
for (int i = 0; i < 4; ++i) {
val |= static_cast<uint32_t>(buffer[pos++]) << (i * 8);
}
return val;
};
auto readU16 = [&buffer, &pos]() -> uint16_t {
if (pos + 2 > buffer.size()) return 0;
uint16_t val = buffer[pos] | (buffer[pos + 1] << 8);
pos += 2;
return val;
};
auto readString = [&buffer, &pos, &readU16]() -> std::string {
uint16_t len = readU16();
if (pos + len > buffer.size()) return "";
std::string str(buffer.begin() + pos, buffer.begin() + pos + len);
pos += len;
return str;
};
// Check magic
if (buffer[0] != 'R' || buffer[1] != 'S' || buffer[2] != 'T' || buffer[3] != 'V') {
return false;
}
pos = 4;
m_transferID = readU64();
m_srcClientID = readU64();
m_dstClientID = readU64();
m_targetDir = readString();
uint32_t entryCount = readU32();
m_entries.clear();
for (uint32_t i = 0; i < entryCount; ++i) {
FileResumeEntry entry;
entry.fileIndex = readU32();
entry.fileSize = readU64();
entry.receivedBytes = readU64();
entry.relativePath = readString();
uint16_t rangeCount = readU16();
for (uint16_t j = 0; j < rangeCount; ++j) {
uint64_t offset = readU64();
uint64_t length = readU64();
entry.receivedRanges.emplace_back(offset, length);
}
m_entries.push_back(entry);
}
return true;
}
uint64_t GetTransferID() const { return m_transferID; }
uint64_t GetSrcClientID() const { return m_srcClientID; }
uint64_t GetDstClientID() const { return m_dstClientID; }
const std::string& GetTargetDir() const { return m_targetDir; }
size_t GetFileCount() const { return m_entries.size(); }
// 获取所有文件的接收偏移映射
std::map<uint32_t, uint64_t> GetReceivedOffsets() const {
std::map<uint32_t, uint64_t> offsets;
for (const auto& entry : m_entries) {
offsets[entry.fileIndex] = entry.receivedBytes;
}
return offsets;
}
private:
uint64_t m_transferID;
uint64_t m_srcClientID;
uint64_t m_dstClientID;
std::string m_targetDir;
std::vector<FileResumeEntry> m_entries;
};
// ============================================
// 续传包构建/解析辅助函数
// ============================================
std::vector<uint8_t> BuildResumeRequest(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint64_t fileSize,
uint64_t receivedBytes,
const std::vector<std::pair<uint64_t, uint64_t>>& ranges)
{
size_t size = sizeof(FileResumePacketV2) + ranges.size() * sizeof(FileRangeV2);
std::vector<uint8_t> buffer(size);
FileResumePacketV2* pkt = reinterpret_cast<FileResumePacketV2*>(buffer.data());
pkt->cmd = 86; // COMMAND_FILE_RESUME
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->fileSize = fileSize;
pkt->receivedBytes = receivedBytes;
pkt->flags = FFV2_RESUME_REQ;
pkt->rangeCount = static_cast<uint16_t>(ranges.size());
FileRangeV2* rangePtr = reinterpret_cast<FileRangeV2*>(buffer.data() + sizeof(FileResumePacketV2));
for (size_t i = 0; i < ranges.size(); ++i) {
rangePtr[i].offset = ranges[i].first;
rangePtr[i].length = ranges[i].second;
}
return buffer;
}
bool ParseResumeRequest(
const uint8_t* buffer, size_t len,
FileResumePacketV2& header,
std::vector<std::pair<uint64_t, uint64_t>>& ranges)
{
if (len < sizeof(FileResumePacketV2)) {
return false;
}
memcpy(&header, buffer, sizeof(FileResumePacketV2));
size_t expectedLen = sizeof(FileResumePacketV2) + header.rangeCount * sizeof(FileRangeV2);
if (len < expectedLen) {
return false;
}
ranges.clear();
const FileRangeV2* rangePtr = reinterpret_cast<const FileRangeV2*>(buffer + sizeof(FileResumePacketV2));
for (uint16_t i = 0; i < header.rangeCount; ++i) {
ranges.emplace_back(rangePtr[i].offset, rangePtr[i].length);
}
return true;
}
std::vector<uint8_t> BuildResumeResponse(
uint64_t srcClientID,
uint64_t dstClientID,
const std::map<uint32_t, uint64_t>& offsets)
{
size_t size = sizeof(FileResumeResponseV2) + offsets.size() * sizeof(FileResumeResponseEntryV2);
std::vector<uint8_t> buffer(size);
FileResumeResponseV2* pkt = reinterpret_cast<FileResumeResponseV2*>(buffer.data());
pkt->cmd = 86; // COMMAND_FILE_RESUME
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->flags = FFV2_RESUME_RESP;
pkt->fileCount = static_cast<uint32_t>(offsets.size());
FileResumeResponseEntryV2* entryPtr = reinterpret_cast<FileResumeResponseEntryV2*>(
buffer.data() + sizeof(FileResumeResponseV2));
size_t i = 0;
for (const auto& kv : offsets) {
entryPtr[i].fileIndex = kv.first;
entryPtr[i].receivedBytes = kv.second;
++i;
}
return buffer;
}
bool ParseResumeResponse(
const uint8_t* buffer, size_t len,
std::map<uint32_t, uint64_t>& offsets)
{
if (len < sizeof(FileResumeResponseV2)) {
return false;
}
const FileResumeResponseV2* pkt = reinterpret_cast<const FileResumeResponseV2*>(buffer);
if ((pkt->flags & FFV2_RESUME_RESP) == 0) {
return false;
}
size_t expectedLen = sizeof(FileResumeResponseV2) + pkt->fileCount * sizeof(FileResumeResponseEntryV2);
if (len < expectedLen) {
return false;
}
offsets.clear();
const FileResumeResponseEntryV2* entryPtr = reinterpret_cast<const FileResumeResponseEntryV2*>(
buffer + sizeof(FileResumeResponseV2));
for (uint32_t i = 0; i < pkt->fileCount; ++i) {
offsets[entryPtr[i].fileIndex] = entryPtr[i].receivedBytes;
}
return true;
}
// ============================================
// ResumeStateManager 测试
// ============================================
class ResumeStateManagerTest : public ::testing::Test {};
TEST_F(ResumeStateManagerTest, Initialize) {
ResumeStateManager mgr;
mgr.Initialize(12345, 100, 200, "C:\\Downloads\\");
EXPECT_EQ(mgr.GetTransferID(), 12345u);
EXPECT_EQ(mgr.GetSrcClientID(), 100u);
EXPECT_EQ(mgr.GetDstClientID(), 200u);
EXPECT_EQ(mgr.GetTargetDir(), "C:\\Downloads\\");
EXPECT_EQ(mgr.GetFileCount(), 0u);
}
TEST_F(ResumeStateManagerTest, AddFiles) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 1000, "file1.txt");
mgr.AddFile(1, 2000, "subdir/file2.bin");
mgr.AddFile(2, 3000, "another/path/file3.dat");
EXPECT_EQ(mgr.GetFileCount(), 3u);
FileResumeEntry entry;
ASSERT_TRUE(mgr.GetFileState(1, entry));
EXPECT_EQ(entry.fileSize, 2000u);
EXPECT_EQ(entry.relativePath, "subdir/file2.bin");
}
TEST_F(ResumeStateManagerTest, UpdateProgress) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 10000, "file.bin");
mgr.UpdateProgress(0, 0, 2000);
mgr.UpdateProgress(0, 2000, 3000);
FileResumeEntry entry;
ASSERT_TRUE(mgr.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 5000u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
}
TEST_F(ResumeStateManagerTest, GetReceivedOffsets) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 1000, "a.txt");
mgr.AddFile(1, 2000, "b.txt");
mgr.AddFile(2, 3000, "c.txt");
mgr.UpdateProgress(0, 0, 500);
mgr.UpdateProgress(1, 0, 1500);
mgr.UpdateProgress(2, 0, 2500);
auto offsets = mgr.GetReceivedOffsets();
EXPECT_EQ(offsets.size(), 3u);
EXPECT_EQ(offsets[0], 500u);
EXPECT_EQ(offsets[1], 1500u);
EXPECT_EQ(offsets[2], 2500u);
}
// ============================================
// 序列化/反序列化测试
// ============================================
class ResumeSerializationTest : public ::testing::Test {};
TEST_F(ResumeSerializationTest, EmptyState) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(12345, 100, 200, "C:\\Target\\");
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetTransferID(), 12345u);
EXPECT_EQ(mgr2.GetSrcClientID(), 100u);
EXPECT_EQ(mgr2.GetDstClientID(), 200u);
EXPECT_EQ(mgr2.GetTargetDir(), "C:\\Target\\");
EXPECT_EQ(mgr2.GetFileCount(), 0u);
}
TEST_F(ResumeSerializationTest, SingleFile) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(1, 0, 0, "/tmp/download/");
mgr1.AddFile(0, 10000, "test.bin");
mgr1.UpdateProgress(0, 0, 5000);
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 1u);
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.fileSize, 10000u);
EXPECT_EQ(entry.receivedBytes, 5000u);
EXPECT_EQ(entry.relativePath, "test.bin");
}
TEST_F(ResumeSerializationTest, MultipleFiles) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(99999, 111, 222, "D:\\Backup\\");
mgr1.AddFile(0, 1000, "file1.txt");
mgr1.AddFile(1, 2000, "dir/file2.bin");
mgr1.AddFile(2, 3000, "path/to/file3.dat");
mgr1.UpdateProgress(0, 0, 1000); // 完成
mgr1.UpdateProgress(1, 0, 500);
mgr1.UpdateProgress(1, 500, 500);
mgr1.UpdateProgress(2, 0, 1000);
mgr1.UpdateProgress(2, 2000, 500);
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 3u);
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 1000u);
ASSERT_TRUE(mgr2.GetFileState(1, entry));
EXPECT_EQ(entry.receivedBytes, 1000u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
ASSERT_TRUE(mgr2.GetFileState(2, entry));
EXPECT_EQ(entry.receivedBytes, 1500u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
}
TEST_F(ResumeSerializationTest, InvalidMagic) {
std::vector<uint8_t> invalidBuffer = {'X', 'X', 'X', 'X', 0, 0, 0, 0};
ResumeStateManager mgr;
EXPECT_FALSE(mgr.Deserialize(invalidBuffer));
}
TEST_F(ResumeSerializationTest, TruncatedBuffer) {
ResumeStateManager mgr1;
mgr1.Initialize(1, 0, 0, "test");
mgr1.AddFile(0, 1000, "file.txt");
auto buffer = mgr1.Serialize();
// 截断
std::vector<uint8_t> truncated(buffer.begin(), buffer.begin() + 10);
ResumeStateManager mgr2;
// 可能成功也可能失败,取决于截断位置
// 主要验证不会崩溃
mgr2.Deserialize(truncated);
SUCCEED();
}
// ============================================
// 续传请求/响应包测试
// ============================================
class ResumePacketTest : public ::testing::Test {};
TEST_F(ResumePacketTest, BuildResumeRequest_NoRanges) {
std::vector<std::pair<uint64_t, uint64_t>> ranges;
auto buffer = BuildResumeRequest(12345, 100, 200, 5, 10000, 0, ranges);
EXPECT_EQ(buffer.size(), sizeof(FileResumePacketV2));
const FileResumePacketV2* pkt = reinterpret_cast<const FileResumePacketV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 86);
EXPECT_EQ(pkt->flags, FFV2_RESUME_REQ);
EXPECT_EQ(pkt->rangeCount, 0);
}
TEST_F(ResumePacketTest, BuildResumeRequest_WithRanges) {
std::vector<std::pair<uint64_t, uint64_t>> ranges = {
{0, 1000},
{2000, 500},
{5000, 2000}
};
auto buffer = BuildResumeRequest(1, 0, 0, 0, 10000, 3500, ranges);
FileResumePacketV2 header;
std::vector<std::pair<uint64_t, uint64_t>> parsedRanges;
ASSERT_TRUE(ParseResumeRequest(buffer.data(), buffer.size(), header, parsedRanges));
EXPECT_EQ(header.fileSize, 10000u);
EXPECT_EQ(header.receivedBytes, 3500u);
ASSERT_EQ(parsedRanges.size(), 3u);
EXPECT_EQ(parsedRanges[0].first, 0u);
EXPECT_EQ(parsedRanges[0].second, 1000u);
EXPECT_EQ(parsedRanges[2].first, 5000u);
}
TEST_F(ResumePacketTest, BuildResumeResponse) {
std::map<uint32_t, uint64_t> offsets = {
{0, 1000},
{1, 0},
{2, 5000}
};
auto buffer = BuildResumeResponse(100, 200, offsets);
std::map<uint32_t, uint64_t> parsedOffsets;
ASSERT_TRUE(ParseResumeResponse(buffer.data(), buffer.size(), parsedOffsets));
EXPECT_EQ(parsedOffsets.size(), 3u);
EXPECT_EQ(parsedOffsets[0], 1000u);
EXPECT_EQ(parsedOffsets[1], 0u);
EXPECT_EQ(parsedOffsets[2], 5000u);
}
TEST_F(ResumePacketTest, ParseTruncatedRequest) {
auto buffer = BuildResumeRequest(1, 0, 0, 0, 1000, 0, {{0, 500}});
// 截断
FileResumePacketV2 header;
std::vector<std::pair<uint64_t, uint64_t>> ranges;
EXPECT_FALSE(ParseResumeRequest(buffer.data(), sizeof(FileResumePacketV2) - 5, header, ranges));
}
TEST_F(ResumePacketTest, ParseTruncatedResponse) {
std::map<uint32_t, uint64_t> offsets = {{0, 1000}};
auto buffer = BuildResumeResponse(0, 0, offsets);
// 截断
std::map<uint32_t, uint64_t> parsedOffsets;
EXPECT_FALSE(ParseResumeResponse(buffer.data(), sizeof(FileResumeResponseV2) - 5, parsedOffsets));
}
// ============================================
// 续传场景测试
// ============================================
class ResumeScenarioTest : public ::testing::Test {};
TEST_F(ResumeScenarioTest, SimulateInterruptAndResume) {
// 第一次传输,接收了部分数据
ResumeStateManager session1;
session1.Initialize(12345, 100, 0, "C:\\Downloads\\");
session1.AddFile(0, 10000, "large_file.bin");
session1.AddFile(1, 5000, "small_file.txt");
session1.UpdateProgress(0, 0, 3000); // 30%
session1.UpdateProgress(1, 0, 5000); // 100%
// 保存状态
auto savedState = session1.Serialize();
// 模拟程序重启,恢复状态
ResumeStateManager session2;
ASSERT_TRUE(session2.Deserialize(savedState));
// 获取续传偏移
auto offsets = session2.GetReceivedOffsets();
EXPECT_EQ(offsets[0], 3000u); // 从 3000 继续
EXPECT_EQ(offsets[1], 5000u); // 已完成
// 继续传输
session2.UpdateProgress(0, 3000, 7000);
FileResumeEntry entry;
ASSERT_TRUE(session2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 10000u);
}
TEST_F(ResumeScenarioTest, C2CResumeNegotiation) {
// 源客户端查询续传状态
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 1000},
{"file2.txt", 2000},
{"file3.txt", 3000}
};
// 目标客户端已有部分数据
std::map<uint32_t, uint64_t> receivedOffsets = {
{0, 500}, // file1: 50%
{1, 2000}, // file2: 100%
{2, 0} // file3: 0%
};
// 构建响应
auto response = BuildResumeResponse(100, 200, receivedOffsets);
// 源客户端解析响应
std::map<uint32_t, uint64_t> parsedOffsets;
ASSERT_TRUE(ParseResumeResponse(response.data(), response.size(), parsedOffsets));
// 根据偏移决定发送策略
for (size_t i = 0; i < files.size(); ++i) {
uint32_t fileIndex = static_cast<uint32_t>(i);
uint64_t startOffset = parsedOffsets[fileIndex];
if (startOffset >= files[i].second) {
// 已完成,跳过
EXPECT_EQ(i, 1u) << "Only file2 should be complete";
} else if (startOffset > 0) {
// 部分完成,从偏移继续
EXPECT_EQ(i, 0u) << "Only file1 should be partial";
EXPECT_EQ(startOffset, 500u);
} else {
// 未开始,从头发送
EXPECT_EQ(i, 2u) << "Only file3 should start from beginning";
}
}
}
TEST_F(ResumeScenarioTest, LargeFileResume) {
ResumeStateManager mgr;
uint64_t fileSize = 10ULL * 1024 * 1024 * 1024; // 10 GB
mgr.Initialize(1, 0, 0, "/data/");
mgr.AddFile(0, fileSize, "huge.bin");
// 模拟接收了 5 GB
mgr.UpdateProgress(0, 0, 5ULL * 1024 * 1024 * 1024);
auto offsets = mgr.GetReceivedOffsets();
EXPECT_EQ(offsets[0], 5ULL * 1024 * 1024 * 1024);
// 序列化/反序列化
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.fileSize, 10ULL * 1024 * 1024 * 1024);
EXPECT_EQ(entry.receivedBytes, 5ULL * 1024 * 1024 * 1024);
}
// ============================================
// 边界条件测试
// ============================================
class ResumeBoundaryTest : public ::testing::Test {};
TEST_F(ResumeBoundaryTest, EmptyFileName) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 100, "");
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.relativePath, "");
}
TEST_F(ResumeBoundaryTest, SpecialCharactersInPath) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "C:\\My Files\\测试目录\\");
mgr.AddFile(0, 100, "文件 (1).txt");
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetTargetDir(), "C:\\My Files\\测试目录\\");
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.relativePath, "文件 (1).txt");
}
TEST_F(ResumeBoundaryTest, ManyRanges) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 100000, "fragmented.bin");
// 添加很多小区间
for (uint64_t i = 0; i < 100000; i += 20) {
mgr.UpdateProgress(0, i, 10);
}
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedRanges.size(), 5000u);
}
TEST_F(ResumeBoundaryTest, MaxFileCount) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
// 添加大量文件
for (uint32_t i = 0; i < 1000; ++i) {
mgr.AddFile(i, 100 * (i + 1), "file" + std::to_string(i) + ".txt");
}
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 1000u);
}

View File

@@ -0,0 +1,647 @@
/**
* @file SHA256VerifyTest.cpp
* @brief SHA-256 文件校验测试
*
* 测试覆盖:
* - SHA-256 哈希计算
* - FileCompletePacketV2 构建与解析
* - 文件完整性校验逻辑
* - 校验失败处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <array>
#include <string>
#include <iomanip>
#include <sstream>
// ============================================
// 简化版 SHA-256 实现(测试专用)
// 生产环境使用 OpenSSL 或 Windows CNG
// ============================================
class SHA256 {
public:
SHA256() { Reset(); }
void Reset() {
m_state[0] = 0x6a09e667;
m_state[1] = 0xbb67ae85;
m_state[2] = 0x3c6ef372;
m_state[3] = 0xa54ff53a;
m_state[4] = 0x510e527f;
m_state[5] = 0x9b05688c;
m_state[6] = 0x1f83d9ab;
m_state[7] = 0x5be0cd19;
m_bitLen = 0;
m_bufferLen = 0;
}
void Update(const uint8_t* data, size_t len) {
for (size_t i = 0; i < len; ++i) {
m_buffer[m_bufferLen++] = data[i];
if (m_bufferLen == 64) {
Transform();
m_bitLen += 512;
m_bufferLen = 0;
}
}
}
void Update(const std::vector<uint8_t>& data) {
Update(data.data(), data.size());
}
std::array<uint8_t, 32> Finalize() {
size_t i = m_bufferLen;
// Pad
if (m_bufferLen < 56) {
m_buffer[i++] = 0x80;
while (i < 56) m_buffer[i++] = 0x00;
} else {
m_buffer[i++] = 0x80;
while (i < 64) m_buffer[i++] = 0x00;
Transform();
memset(m_buffer, 0, 56);
}
// Append length
m_bitLen += m_bufferLen * 8;
m_buffer[63] = static_cast<uint8_t>(m_bitLen);
m_buffer[62] = static_cast<uint8_t>(m_bitLen >> 8);
m_buffer[61] = static_cast<uint8_t>(m_bitLen >> 16);
m_buffer[60] = static_cast<uint8_t>(m_bitLen >> 24);
m_buffer[59] = static_cast<uint8_t>(m_bitLen >> 32);
m_buffer[58] = static_cast<uint8_t>(m_bitLen >> 40);
m_buffer[57] = static_cast<uint8_t>(m_bitLen >> 48);
m_buffer[56] = static_cast<uint8_t>(m_bitLen >> 56);
Transform();
// Output
std::array<uint8_t, 32> hash;
for (int j = 0; j < 8; ++j) {
hash[j * 4 + 0] = (m_state[j] >> 24) & 0xff;
hash[j * 4 + 1] = (m_state[j] >> 16) & 0xff;
hash[j * 4 + 2] = (m_state[j] >> 8) & 0xff;
hash[j * 4 + 3] = m_state[j] & 0xff;
}
return hash;
}
static std::string HashToHex(const std::array<uint8_t, 32>& hash) {
std::ostringstream ss;
for (uint8_t b : hash) {
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(b);
}
return ss.str();
}
private:
static uint32_t RotR(uint32_t x, uint32_t n) { return (x >> n) | (x << (32 - n)); }
static uint32_t Ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (~x & z); }
static uint32_t Maj(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
static uint32_t Sig0(uint32_t x) { return RotR(x, 2) ^ RotR(x, 13) ^ RotR(x, 22); }
static uint32_t Sig1(uint32_t x) { return RotR(x, 6) ^ RotR(x, 11) ^ RotR(x, 25); }
static uint32_t sig0(uint32_t x) { return RotR(x, 7) ^ RotR(x, 18) ^ (x >> 3); }
static uint32_t sig1(uint32_t x) { return RotR(x, 17) ^ RotR(x, 19) ^ (x >> 10); }
void Transform() {
static const uint32_t K[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};
uint32_t W[64];
for (int i = 0; i < 16; ++i) {
W[i] = (m_buffer[i * 4] << 24) | (m_buffer[i * 4 + 1] << 16) |
(m_buffer[i * 4 + 2] << 8) | m_buffer[i * 4 + 3];
}
for (int i = 16; i < 64; ++i) {
W[i] = sig1(W[i - 2]) + W[i - 7] + sig0(W[i - 15]) + W[i - 16];
}
uint32_t a = m_state[0], b = m_state[1], c = m_state[2], d = m_state[3];
uint32_t e = m_state[4], f = m_state[5], g = m_state[6], h = m_state[7];
for (int i = 0; i < 64; ++i) {
uint32_t t1 = h + Sig1(e) + Ch(e, f, g) + K[i] + W[i];
uint32_t t2 = Sig0(a) + Maj(a, b, c);
h = g; g = f; f = e; e = d + t1;
d = c; c = b; b = a; a = t1 + t2;
}
m_state[0] += a; m_state[1] += b; m_state[2] += c; m_state[3] += d;
m_state[4] += e; m_state[5] += f; m_state[6] += g; m_state[7] += h;
}
uint32_t m_state[8];
uint8_t m_buffer[64];
uint64_t m_bitLen;
size_t m_bufferLen;
};
// ============================================
// 协议结构(测试专用)
// ============================================
#pragma pack(push, 1)
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
};
enum FileErrorV2 : uint8_t {
FEV2_OK = 0,
FEV2_TARGET_OFFLINE = 1,
FEV2_VERSION_MISMATCH = 2,
FEV2_FILE_NOT_FOUND = 3,
FEV2_ACCESS_DENIED = 4,
FEV2_DISK_FULL = 5,
FEV2_TRANSFER_CANCEL = 6,
FEV2_CHECKSUM_ERROR = 7,
FEV2_HASH_MISMATCH = 8,
};
#pragma pack(pop)
// ============================================
// 辅助函数
// ============================================
std::vector<uint8_t> BuildFileCompletePacket(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint64_t fileSize,
const std::array<uint8_t, 32>& sha256)
{
std::vector<uint8_t> buffer(sizeof(FileCompletePacketV2));
FileCompletePacketV2* pkt = reinterpret_cast<FileCompletePacketV2*>(buffer.data());
pkt->cmd = 91; // COMMAND_FILE_COMPLETE_V2
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->fileSize = fileSize;
memcpy(pkt->sha256, sha256.data(), 32);
return buffer;
}
bool ParseFileCompletePacket(
const uint8_t* buffer, size_t len,
FileCompletePacketV2& pkt)
{
if (len < sizeof(FileCompletePacketV2)) {
return false;
}
memcpy(&pkt, buffer, sizeof(FileCompletePacketV2));
return true;
}
// 校验文件完整性
FileErrorV2 VerifyFileIntegrity(
const std::array<uint8_t, 32>& expectedHash,
const std::vector<uint8_t>& fileData)
{
SHA256 sha;
sha.Update(fileData);
auto actualHash = sha.Finalize();
if (actualHash == expectedHash) {
return FEV2_OK;
}
return FEV2_HASH_MISMATCH;
}
// ============================================
// SHA-256 基础测试
// ============================================
class SHA256Test : public ::testing::Test {};
TEST_F(SHA256Test, EmptyString) {
SHA256 sha;
auto hash = sha.Finalize();
// SHA-256("") = e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
}
TEST_F(SHA256Test, HelloWorld) {
SHA256 sha;
const char* msg = "Hello, World!";
sha.Update(reinterpret_cast<const uint8_t*>(msg), strlen(msg));
auto hash = sha.Finalize();
// SHA-256("Hello, World!") = dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f");
}
TEST_F(SHA256Test, ABC) {
SHA256 sha;
const char* msg = "abc";
sha.Update(reinterpret_cast<const uint8_t*>(msg), strlen(msg));
auto hash = sha.Finalize();
// SHA-256("abc") = ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad");
}
TEST_F(SHA256Test, IncrementalUpdate) {
SHA256 sha1, sha2;
const char* part1 = "Hello, ";
const char* part2 = "World!";
const char* full = "Hello, World!";
sha1.Update(reinterpret_cast<const uint8_t*>(part1), strlen(part1));
sha1.Update(reinterpret_cast<const uint8_t*>(part2), strlen(part2));
sha2.Update(reinterpret_cast<const uint8_t*>(full), strlen(full));
auto hash1 = sha1.Finalize();
auto hash2 = sha2.Finalize();
EXPECT_EQ(hash1, hash2);
}
TEST_F(SHA256Test, LargeData) {
SHA256 sha;
// 1 MB of zeros
std::vector<uint8_t> data(1024 * 1024, 0);
sha.Update(data);
auto hash = sha.Finalize();
// 验证计算成功(不崩溃)
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(SHA256Test, BinaryData) {
SHA256 sha;
std::vector<uint8_t> data = {0x00, 0x01, 0x02, 0x03, 0xff, 0xfe, 0xfd, 0xfc};
sha.Update(data);
auto hash = sha.Finalize();
// 验证计算成功
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(SHA256Test, Reset) {
SHA256 sha;
sha.Update(reinterpret_cast<const uint8_t*>("test"), 4);
sha.Reset();
auto hash = sha.Finalize();
// 重置后应该等于空字符串的哈希
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
}
// ============================================
// FileCompletePacketV2 测试
// ============================================
class FileCompletePacketTest : public ::testing::Test {};
TEST_F(FileCompletePacketTest, BuildAndParse) {
std::array<uint8_t, 32> sha256 = {};
sha256[0] = 0xAA;
sha256[31] = 0xBB;
auto buffer = BuildFileCompletePacket(12345, 100, 200, 5, 1024 * 1024, sha256);
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(buffer.data(), buffer.size(), pkt));
EXPECT_EQ(pkt.cmd, 91);
EXPECT_EQ(pkt.transferID, 12345u);
EXPECT_EQ(pkt.srcClientID, 100u);
EXPECT_EQ(pkt.dstClientID, 200u);
EXPECT_EQ(pkt.fileIndex, 5u);
EXPECT_EQ(pkt.fileSize, 1024u * 1024u);
EXPECT_EQ(pkt.sha256[0], 0xAA);
EXPECT_EQ(pkt.sha256[31], 0xBB);
}
TEST_F(FileCompletePacketTest, TruncatedPacket) {
std::array<uint8_t, 32> sha256 = {};
auto buffer = BuildFileCompletePacket(1, 0, 0, 0, 100, sha256);
// 截断
FileCompletePacketV2 pkt;
EXPECT_FALSE(ParseFileCompletePacket(buffer.data(), sizeof(FileCompletePacketV2) - 10, pkt));
}
TEST_F(FileCompletePacketTest, LargeFileSize) {
std::array<uint8_t, 32> sha256 = {};
uint64_t largeSize = 100ULL * 1024 * 1024 * 1024; // 100 GB
auto buffer = BuildFileCompletePacket(1, 0, 0, 0, largeSize, sha256);
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(buffer.data(), buffer.size(), pkt));
EXPECT_EQ(pkt.fileSize, largeSize);
}
// ============================================
// 文件完整性校验测试
// ============================================
class FileIntegrityTest : public ::testing::Test {};
TEST_F(FileIntegrityTest, CorrectHash) {
std::vector<uint8_t> fileData = {'H', 'e', 'l', 'l', 'o'};
SHA256 sha;
sha.Update(fileData);
auto expectedHash = sha.Finalize();
auto result = VerifyFileIntegrity(expectedHash, fileData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(FileIntegrityTest, IncorrectHash) {
std::vector<uint8_t> fileData = {'H', 'e', 'l', 'l', 'o'};
std::array<uint8_t, 32> wrongHash = {}; // 全零的错误哈希
auto result = VerifyFileIntegrity(wrongHash, fileData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(FileIntegrityTest, CorruptedData) {
std::vector<uint8_t> originalData = {'H', 'e', 'l', 'l', 'o'};
SHA256 sha;
sha.Update(originalData);
auto originalHash = sha.Finalize();
// 修改一个字节
std::vector<uint8_t> corruptedData = originalData;
corruptedData[2] = 'X';
auto result = VerifyFileIntegrity(originalHash, corruptedData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(FileIntegrityTest, EmptyFile) {
std::vector<uint8_t> emptyData;
SHA256 sha;
auto emptyHash = sha.Finalize();
auto result = VerifyFileIntegrity(emptyHash, emptyData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(FileIntegrityTest, SingleBitDifference) {
std::vector<uint8_t> data1 = {0x00};
std::vector<uint8_t> data2 = {0x01}; // 单比特差异
SHA256 sha1, sha2;
sha1.Update(data1);
sha2.Update(data2);
auto hash1 = sha1.Finalize();
auto hash2 = sha2.Finalize();
// 哈希应该完全不同
EXPECT_NE(hash1, hash2);
// 验证错误检测
auto result = VerifyFileIntegrity(hash1, data2);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
// ============================================
// 流式哈希计算测试
// ============================================
class StreamingHashTest : public ::testing::Test {};
TEST_F(StreamingHashTest, ChunkedUpdate) {
std::vector<uint8_t> fullData(10000);
for (size_t i = 0; i < fullData.size(); ++i) {
fullData[i] = static_cast<uint8_t>(i & 0xFF);
}
// 一次性计算
SHA256 sha1;
sha1.Update(fullData);
auto hash1 = sha1.Finalize();
// 分块计算
SHA256 sha2;
size_t chunkSize = 64; // SHA-256 块大小
for (size_t i = 0; i < fullData.size(); i += chunkSize) {
size_t len = std::min(chunkSize, fullData.size() - i);
sha2.Update(fullData.data() + i, len);
}
auto hash2 = sha2.Finalize();
EXPECT_EQ(hash1, hash2);
}
TEST_F(StreamingHashTest, VariableChunkSizes) {
std::vector<uint8_t> data(1000);
for (size_t i = 0; i < data.size(); ++i) {
data[i] = static_cast<uint8_t>(i);
}
SHA256 sha1;
sha1.Update(data);
auto expected = sha1.Finalize();
// 不同大小的块
std::vector<size_t> chunkSizes = {1, 7, 63, 64, 65, 128, 1000};
for (size_t chunkSize : chunkSizes) {
SHA256 sha2;
for (size_t i = 0; i < data.size(); i += chunkSize) {
size_t len = std::min(chunkSize, data.size() - i);
sha2.Update(data.data() + i, len);
}
auto actual = sha2.Finalize();
EXPECT_EQ(expected, actual) << "Failed for chunk size: " << chunkSize;
}
}
// ============================================
// 文件传输完整性验证场景
// ============================================
class TransferVerificationTest : public ::testing::Test {};
TEST_F(TransferVerificationTest, SimulateSuccessfulTransfer) {
// 模拟文件数据
std::vector<uint8_t> fileData(1024 * 64); // 64 KB
for (size_t i = 0; i < fileData.size(); ++i) {
fileData[i] = static_cast<uint8_t>(i * 17); // 伪随机数据
}
// 发送方计算哈希
SHA256 senderSha;
senderSha.Update(fileData);
auto senderHash = senderSha.Finalize();
// 构建校验包
auto packet = BuildFileCompletePacket(12345, 100, 0, 0, fileData.size(), senderHash);
// 接收方解析并验证
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(packet.data(), packet.size(), pkt));
EXPECT_EQ(pkt.fileSize, fileData.size());
// 接收方计算哈希并比较
std::array<uint8_t, 32> expectedHash;
memcpy(expectedHash.data(), pkt.sha256, 32);
auto result = VerifyFileIntegrity(expectedHash, fileData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(TransferVerificationTest, SimulateCorruptedTransfer) {
// 原始文件
std::vector<uint8_t> originalData(1024);
for (size_t i = 0; i < originalData.size(); ++i) {
originalData[i] = static_cast<uint8_t>(i);
}
// 发送方计算哈希
SHA256 senderSha;
senderSha.Update(originalData);
auto senderHash = senderSha.Finalize();
// 模拟传输中数据损坏
std::vector<uint8_t> receivedData = originalData;
receivedData[512] ^= 0xFF; // 翻转一个字节
// 校验失败
auto result = VerifyFileIntegrity(senderHash, receivedData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(TransferVerificationTest, MultipleFilesInTransfer) {
struct FileInfo {
std::vector<uint8_t> data;
std::array<uint8_t, 32> hash;
};
std::vector<FileInfo> files(5);
// 生成测试文件
for (int i = 0; i < 5; ++i) {
files[i].data.resize(1000 * (i + 1));
for (size_t j = 0; j < files[i].data.size(); ++j) {
files[i].data[j] = static_cast<uint8_t>(j + i * 7);
}
SHA256 sha;
sha.Update(files[i].data);
files[i].hash = sha.Finalize();
}
// 验证每个文件
for (int i = 0; i < 5; ++i) {
auto result = VerifyFileIntegrity(files[i].hash, files[i].data);
EXPECT_EQ(result, FEV2_OK) << "File " << i << " verification failed";
}
// 交叉验证应该失败
auto crossResult = VerifyFileIntegrity(files[0].hash, files[1].data);
EXPECT_EQ(crossResult, FEV2_HASH_MISMATCH);
}
// ============================================
// 边界条件测试
// ============================================
class HashBoundaryTest : public ::testing::Test {};
TEST_F(HashBoundaryTest, ExactBlockSize) {
// SHA-256 块大小是 64 字节
std::vector<uint8_t> data(64, 'A');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, BlockSizePlusOne) {
std::vector<uint8_t> data(65, 'B');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, BlockSizeMinusOne) {
std::vector<uint8_t> data(63, 'C');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, MultipleBlocks) {
std::vector<uint8_t> data(64 * 10, 'D');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, PaddingBoundary55) {
// 55 字节需要特殊处理padding + length 刚好 64
std::vector<uint8_t> data(55, 'E');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, PaddingBoundary56) {
// 56 字节需要额外的块
std::vector<uint8_t> data(56, 'F');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}

View File

@@ -0,0 +1,378 @@
// GeoLocationTest.cpp - IP地理位置API单元测试
// 测试 location.h 中的 GetGeoLocation 功能
#include <gtest/gtest.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#include <wininet.h>
#include <ws2tcpip.h>
#pragma comment(lib, "wininet.lib")
#pragma comment(lib, "ws2_32.lib")
#include <string>
#include <vector>
// ============================================
// 简化的测试版本 - 不依赖jsoncpp
// ============================================
// UTF-8 转 ANSI
inline std::string Utf8ToAnsi(const std::string& utf8)
{
if (utf8.empty()) return "";
int wideLen = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, NULL, 0);
if (wideLen <= 0) return utf8;
std::wstring wideStr(wideLen, 0);
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wideStr[0], wideLen);
int ansiLen = WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, NULL, 0, NULL, NULL);
if (ansiLen <= 0) return utf8;
std::string ansiStr(ansiLen, 0);
WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, &ansiStr[0], ansiLen, NULL, NULL);
if (!ansiStr.empty() && ansiStr.back() == '\0') ansiStr.pop_back();
return ansiStr;
}
// 简单JSON值提取 (仅用于测试不依赖jsoncpp)
std::string ExtractJsonString(const std::string& json, const std::string& key)
{
std::string searchKey = "\"" + key + "\"";
size_t keyPos = json.find(searchKey);
if (keyPos == std::string::npos) return "";
size_t colonPos = json.find(':', keyPos);
if (colonPos == std::string::npos) return "";
size_t valueStart = json.find_first_not_of(" \t\n\r", colonPos + 1);
if (valueStart == std::string::npos) return "";
if (json[valueStart] == '"') {
size_t valueEnd = json.find('"', valueStart + 1);
if (valueEnd == std::string::npos) return "";
return json.substr(valueStart + 1, valueEnd - valueStart - 1);
}
// 数字或布尔值
size_t valueEnd = json.find_first_of(",}\n\r", valueStart);
if (valueEnd == std::string::npos) valueEnd = json.length();
std::string value = json.substr(valueStart, valueEnd - valueStart);
// 去除尾部空格
while (!value.empty() && (value.back() == ' ' || value.back() == '\t')) {
value.pop_back();
}
return value;
}
// API配置结构
struct GeoApiConfig {
const char* name;
const char* urlFmt;
const char* cityField;
const char* countryField;
const char* checkField;
const char* checkValue;
bool useHttps;
};
// 测试用API配置
static const GeoApiConfig testApis[] = {
{"ip-api.com", "http://ip-api.com/json/%s?fields=status,country,city", "city", "country", "status", "success", false},
{"ipinfo.io", "http://ipinfo.io/%s/json", "city", "country", "", "", false},
{"ipapi.co", "https://ipapi.co/%s/json/", "city", "country_name", "error", "", true},
};
// 测试单个API
struct ApiTestResult {
bool success;
int httpStatus;
std::string city;
std::string country;
std::string error;
double latencyMs;
};
ApiTestResult TestSingleApi(const GeoApiConfig& api, const std::string& ip)
{
ApiTestResult result = {false, 0, "", "", "", 0};
DWORD startTime = GetTickCount();
HINTERNET hInternet = InternetOpenA("GeoLocationTest", INTERNET_OPEN_TYPE_DIRECT, NULL, NULL, 0);
if (!hInternet) {
result.error = "InternetOpen failed";
return result;
}
DWORD timeout = 10000;
InternetSetOptionA(hInternet, INTERNET_OPTION_CONNECT_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_SEND_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_RECEIVE_TIMEOUT, &timeout, sizeof(timeout));
char urlBuf[256];
sprintf_s(urlBuf, api.urlFmt, ip.c_str());
DWORD flags = INTERNET_FLAG_RELOAD;
if (api.useHttps) flags |= INTERNET_FLAG_SECURE;
HINTERNET hConnect = InternetOpenUrlA(hInternet, urlBuf, NULL, 0, flags, 0);
if (!hConnect) {
result.error = "InternetOpenUrl failed: " + std::to_string(GetLastError());
InternetCloseHandle(hInternet);
return result;
}
// 获取HTTP状态码
DWORD statusCode = 0;
DWORD statusSize = sizeof(statusCode);
if (HttpQueryInfoA(hConnect, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &statusSize, NULL)) {
result.httpStatus = statusCode;
}
// 读取响应
std::string readBuffer;
char buffer[4096];
DWORD bytesRead;
while (InternetReadFile(hConnect, buffer, sizeof(buffer), &bytesRead) && bytesRead > 0) {
readBuffer.append(buffer, bytesRead);
}
InternetCloseHandle(hConnect);
InternetCloseHandle(hInternet);
result.latencyMs = (double)(GetTickCount() - startTime);
// 检查HTTP错误
if (result.httpStatus >= 400) {
result.error = "HTTP " + std::to_string(result.httpStatus);
if (result.httpStatus == 429) {
result.error += " (Rate Limited)";
}
return result;
}
// 检查响应体错误
if (readBuffer.find("Rate limit") != std::string::npos ||
readBuffer.find("rate limit") != std::string::npos) {
result.error = "Rate limited (body)";
return result;
}
// 解析JSON
if (api.checkField && api.checkField[0]) {
std::string checkVal = ExtractJsonString(readBuffer, api.checkField);
if (api.checkValue && api.checkValue[0]) {
if (checkVal != api.checkValue) {
result.error = "Check failed: " + std::string(api.checkField) + "=" + checkVal;
return result;
}
} else {
if (checkVal == "true") {
result.error = "Error flag set";
return result;
}
}
}
result.city = Utf8ToAnsi(ExtractJsonString(readBuffer, api.cityField));
result.country = Utf8ToAnsi(ExtractJsonString(readBuffer, api.countryField));
result.success = !result.city.empty() || !result.country.empty();
if (!result.success) {
result.error = "No city/country in response";
}
return result;
}
// ============================================
// 测试用例
// ============================================
class GeoLocationTest : public ::testing::Test {
protected:
void SetUp() override {
// 初始化Winsock
WSADATA wsaData;
WSAStartup(MAKEWORD(2, 2), &wsaData);
}
void TearDown() override {
WSACleanup();
}
};
// 测试公网IP (Google DNS)
TEST_F(GeoLocationTest, TestPublicIP_GoogleDNS) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
// 至少一个API应该成功
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试另一个公网IP (Cloudflare DNS)
TEST_F(GeoLocationTest, TestPublicIP_CloudflareDNS) {
std::string testIP = "1.1.1.1";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试中国IP
TEST_F(GeoLocationTest, TestPublicIP_ChinaIP) {
std::string testIP = "114.114.114.114"; // 114DNS
std::cout << "\n=== Testing IP: " << testIP << " (China) ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试ip-api.com单独
TEST_F(GeoLocationTest, TestIpApiCom) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[0], testIP);
std::cout << "\n[ip-api.com] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
// 如果是429就跳过不算失败
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipinfo.io单独
TEST_F(GeoLocationTest, TestIpInfoIo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[1], testIP);
std::cout << "\n[ipinfo.io] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipapi.co单独
TEST_F(GeoLocationTest, TestIpApiCo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[2], testIP);
std::cout << "\n[ipapi.co] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试HTTP状态码检测
TEST_F(GeoLocationTest, TestHttpStatusCodeDetection) {
// 使用一个会返回错误的请求
std::string testIP = "invalid-ip";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] Invalid IP test: HTTP "
<< result.httpStatus << ", Error: " << result.error << "\n";
// 应该失败
EXPECT_FALSE(result.success);
}
}
// 测试所有API的响应时间
TEST_F(GeoLocationTest, TestApiLatency) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== API Latency Test ===\n";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] " << result.latencyMs << "ms";
if (result.success) {
std::cout << " (OK)";
} else {
std::cout << " (FAIL: " << result.error << ")";
}
std::cout << "\n";
// 响应时间应该在合理范围内 (< 15秒)
EXPECT_LT(result.latencyMs, 15000);
}
}
#else
// 非Windows平台 - 跳过测试
TEST(GeoLocationTest, SkipNonWindows) {
GTEST_SKIP() << "GeoLocation tests only run on Windows";
}
#endif

View File

@@ -0,0 +1,642 @@
/**
* @file HeaderTest.cpp
* @brief 协议头验证和加密测试
*
* 测试覆盖:
* - 协议头格式和常量
* - 加密/解密函数正确性
* - 多版本头部识别
* - 头部生成和验证往返
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <array>
// ============================================
// 协议常量定义(测试专用副本)
// ============================================
#define MSG_HEADER "HELL"
const int FLAG_COMPLEN = 4;
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = FLAG_LENGTH + 2 * sizeof(unsigned int); // 16
const int MIN_COMLEN = 12;
enum HeaderEncType {
HeaderEncUnknown = -1,
HeaderEncNone,
HeaderEncV0,
HeaderEncV1,
HeaderEncV2,
HeaderEncV3,
HeaderEncV4,
HeaderEncV5,
HeaderEncV6,
HeaderEncNum,
};
enum FlagType {
FLAG_WINOS = -1,
FLAG_UNKNOWN = 0,
FLAG_SHINE = 1,
FLAG_FUCK = 2,
FLAG_HELLO = 3,
FLAG_HELL = 4,
};
// ============================================
// 加密/解密函数(测试专用副本)
// ============================================
inline void default_encrypt(unsigned char* data, size_t length, unsigned char key)
{
data[FLAG_LENGTH - 2] = data[FLAG_LENGTH - 1] = 0;
}
inline void default_decrypt(unsigned char* data, size_t length, unsigned char key)
{
}
inline void encrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value += k;
break;
case 1:
value = value ^ k;
break;
case 2:
value -= k;
break;
case 3:
value = ~(value ^ k);
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
inline void decrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value -= k;
break;
case 1:
value = value ^ k;
break;
case 2:
value += k;
break;
case 3:
value = ~(value) ^ k;
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
// ============================================
// HeaderFlag 结构体
// ============================================
typedef struct HeaderFlag {
char Data[FLAG_LENGTH + 1];
HeaderFlag(const char header[FLAG_LENGTH + 1])
{
memcpy(Data, header, sizeof(Data));
}
char& operator[](int i)
{
return Data[i];
}
const char operator[](int i) const
{
return Data[i];
}
const char* data() const
{
return Data;
}
} HeaderFlag;
typedef void (*EncFun)(unsigned char* data, size_t length, unsigned char key);
typedef void (*DecFun)(unsigned char* data, size_t length, unsigned char key);
// ============================================
// 头部生成函数
// ============================================
inline HeaderFlag GetHead(EncFun enc, unsigned char fixedKey = 0)
{
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag H(header);
unsigned char key = (fixedKey != 0) ? fixedKey : (time(0) % 256);
H[FLAG_LENGTH - 2] = key;
H[FLAG_LENGTH - 1] = ~key;
enc((unsigned char*)H.data(), FLAG_COMPLEN, H[FLAG_LENGTH - 2]);
return H;
}
// ============================================
// 头部验证函数
// ============================================
inline int compare(const char *flag, const char *magic, int len, DecFun dec, unsigned char key)
{
unsigned char buf[32] = {};
memcpy(buf, flag, MIN_COMLEN);
dec(buf, len, key);
if (memcmp(buf, magic, len) == 0) {
memcpy((void*)flag, buf, MIN_COMLEN);
return 0;
}
return -1;
}
inline FlagType CheckHead(const char* flag, DecFun dec)
{
FlagType type = FLAG_UNKNOWN;
if (compare(flag, MSG_HEADER, FLAG_COMPLEN, dec, flag[6]) == 0) {
type = FLAG_HELL;
} else if (compare(flag, "Shine", 5, dec, 0) == 0) {
type = FLAG_SHINE;
} else if (compare(flag, "<<FUCK>>", 8, dec, flag[9]) == 0) {
type = FLAG_FUCK;
} else if (compare(flag, "Hello?", 6, dec, flag[6]) == 0) {
type = FLAG_HELLO;
} else {
type = FLAG_UNKNOWN;
}
return type;
}
inline FlagType CheckHeadMulti(char* flag, HeaderEncType& funcHit)
{
static const DecFun methods[] = { default_decrypt, decrypt };
static const int methodNum = sizeof(methods) / sizeof(DecFun);
char buffer[MIN_COMLEN + 4] = {};
for (int i = 0; i < methodNum; ++i) {
memcpy(buffer, flag, MIN_COMLEN);
FlagType type = CheckHead(buffer, methods[i]);
if (type != FLAG_UNKNOWN) {
memcpy(flag, buffer, MIN_COMLEN);
funcHit = HeaderEncType(i);
return type;
}
}
funcHit = HeaderEncUnknown;
return FLAG_UNKNOWN;
}
// ============================================
// 协议常量测试
// ============================================
class HeaderConstantsTest : public ::testing::Test {};
TEST_F(HeaderConstantsTest, FlagLength) {
EXPECT_EQ(FLAG_LENGTH, 8);
}
TEST_F(HeaderConstantsTest, FlagCompLen) {
EXPECT_EQ(FLAG_COMPLEN, 4);
}
TEST_F(HeaderConstantsTest, HdrLength) {
// FLAG_LENGTH(8) + 2 * sizeof(uint32_t)(8) = 16
EXPECT_EQ(HDR_LENGTH, 16);
}
TEST_F(HeaderConstantsTest, MinComLen) {
EXPECT_EQ(MIN_COMLEN, 12);
}
TEST_F(HeaderConstantsTest, MsgHeader) {
EXPECT_STREQ(MSG_HEADER, "HELL");
EXPECT_EQ(strlen(MSG_HEADER), 4u);
}
// ============================================
// 加密/解密测试
// ============================================
class EncryptionTest : public ::testing::Test {};
TEST_F(EncryptionTest, ZeroKeyNoOp) {
unsigned char data[] = {0x01, 0x02, 0x03, 0x04, 0x05};
unsigned char original[] = {0x01, 0x02, 0x03, 0x04, 0x05};
encrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
decrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
}
TEST_F(EncryptionTest, EncryptDecryptRoundTrip) {
unsigned char original[] = {0x48, 0x45, 0x4C, 0x4C}; // "HELL"
unsigned char data[4];
memcpy(data, original, 4);
unsigned char key = 0x42;
encrypt(data, 4, key);
// 加密后应该不同
EXPECT_NE(memcmp(data, original, 4), 0);
decrypt(data, 4, key);
// 解密后应该恢复
EXPECT_EQ(memcmp(data, original, 4), 0);
}
TEST_F(EncryptionTest, DifferentKeysProduceDifferentResults) {
unsigned char data1[] = {0x48, 0x45, 0x4C, 0x4C};
unsigned char data2[] = {0x48, 0x45, 0x4C, 0x4C};
encrypt(data1, 4, 0x10);
encrypt(data2, 4, 0x20);
EXPECT_NE(memcmp(data1, data2, 4), 0);
}
TEST_F(EncryptionTest, PositionDependentEncryption) {
// 相同值在不同位置加密结果不同
unsigned char data[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
unsigned char original[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
encrypt(data, 8, 0x55);
// 验证加密后值不全相同(位置相关加密)
bool allSame = true;
for (int i = 1; i < 8; ++i) {
if (data[i] != data[0]) {
allSame = false;
break;
}
}
EXPECT_FALSE(allSame);
// 验证解密恢复
decrypt(data, 8, 0x55);
EXPECT_EQ(memcmp(data, original, 8), 0);
}
TEST_F(EncryptionTest, AllKeyValues) {
// 测试所有可能的密钥值
unsigned char original[] = {0x12, 0x34, 0x56, 0x78};
for (int key = 1; key < 256; ++key) {
unsigned char data[4];
memcpy(data, original, 4);
encrypt(data, 4, static_cast<unsigned char>(key));
decrypt(data, 4, static_cast<unsigned char>(key));
EXPECT_EQ(memcmp(data, original, 4), 0) << "Failed for key: " << key;
}
}
TEST_F(EncryptionTest, LargeDataRoundTrip) {
std::vector<unsigned char> original(1000);
for (size_t i = 0; i < original.size(); ++i) {
original[i] = static_cast<unsigned char>(i & 0xFF);
}
std::vector<unsigned char> data = original;
unsigned char key = 0x7F;
encrypt(data.data(), data.size(), key);
decrypt(data.data(), data.size(), key);
EXPECT_EQ(data, original);
}
// ============================================
// HeaderFlag 测试
// ============================================
class HeaderFlagTest : public ::testing::Test {};
TEST_F(HeaderFlagTest, Construction) {
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
HeaderFlag hf(header);
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(HeaderFlagTest, DataAccess) {
char header[FLAG_LENGTH + 1] = { 'T','E','S','T', 0x12, 0x34, 0x56, 0x78, 0 };
HeaderFlag hf(header);
EXPECT_EQ(memcmp(hf.data(), "TEST", 4), 0);
EXPECT_EQ(static_cast<unsigned char>(hf[4]), 0x12);
EXPECT_EQ(static_cast<unsigned char>(hf[5]), 0x34);
}
TEST_F(HeaderFlagTest, Modification) {
char header[FLAG_LENGTH + 1] = { 0 };
HeaderFlag hf(header);
hf[0] = 'A';
hf[1] = 'B';
EXPECT_EQ(hf[0], 'A');
EXPECT_EQ(hf[1], 'B');
}
// ============================================
// GetHead 测试
// ============================================
class GetHeadTest : public ::testing::Test {};
TEST_F(GetHeadTest, GeneratesValidHeader) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
// 检查基础格式
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(GetHeadTest, KeyAndInverseKey) {
unsigned char key = 0x42;
HeaderFlag hf = GetHead(default_encrypt, key);
// 使用 default_encrypt 时key 位置被清零
// 但我们需要验证生成逻辑
char rawHeader[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag expected(rawHeader);
expected[FLAG_LENGTH - 2] = key;
expected[FLAG_LENGTH - 1] = ~key;
default_encrypt((unsigned char*)expected.data(), FLAG_COMPLEN, expected[FLAG_LENGTH - 2]);
EXPECT_EQ(memcmp(hf.data(), expected.data(), FLAG_LENGTH), 0);
}
TEST_F(GetHeadTest, EncryptedHeader) {
unsigned char key = 0x55;
HeaderFlag hf = GetHead(encrypt, key);
// 头部应该被加密,不再是明文 "HELL"
EXPECT_NE(memcmp(hf.data(), "HELL", 4), 0);
// 密钥应该在正确位置
unsigned char storedKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 2]);
unsigned char inverseKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 1]);
EXPECT_EQ(storedKey, key);
EXPECT_EQ(inverseKey, static_cast<unsigned char>(~key));
}
// ============================================
// CheckHead 测试
// ============================================
class CheckHeadTest : public ::testing::Test {};
TEST_F(CheckHeadTest, IdentifyHellFlag) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(CheckHeadTest, IdentifyShineFlag) {
char header[MIN_COMLEN + 4] = { 'S','h','i','n','e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_SHINE);
}
TEST_F(CheckHeadTest, UnknownFlag) {
char header[MIN_COMLEN + 4] = { 'X','Y','Z','W', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_UNKNOWN);
}
TEST_F(CheckHeadTest, EncryptedHellFlag) {
// 生成加密的头部
unsigned char key = 0x33;
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, header, FLAG_LENGTH);
FlagType type = CheckHead(buffer, decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
// ============================================
// CheckHeadMulti 测试
// ============================================
class CheckHeadMultiTest : public ::testing::Test {};
TEST_F(CheckHeadMultiTest, PlainTextHeader) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncNone);
}
TEST_F(CheckHeadMultiTest, EncryptedHeader) {
unsigned char key = 0x77;
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncV0); // encrypt 对应 V0
}
TEST_F(CheckHeadMultiTest, UnrecognizedHeader) {
char header[MIN_COMLEN + 4] = { 0xFF, 0xFE, 0xFD, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_UNKNOWN);
EXPECT_EQ(encType, HeaderEncUnknown);
}
// ============================================
// 头部生成和验证往返测试
// ============================================
class HeaderRoundTripTest : public ::testing::Test {};
TEST_F(HeaderRoundTripTest, DefaultEncryptRoundTrip) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, EncryptRoundTrip) {
HeaderFlag hf = GetHead(encrypt, 0x88);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, AllKeyValuesRoundTrip) {
for (int key = 1; key < 256; ++key) {
HeaderFlag hf = GetHead(encrypt, static_cast<unsigned char>(key));
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL) << "Failed for key: " << key;
}
}
// ============================================
// 数据包长度字段测试
// ============================================
class PacketLengthTest : public ::testing::Test {};
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 压缩后长度
uint32_t originalLength; // 原始长度
};
#pragma pack(pop)
TEST_F(PacketLengthTest, HeaderSize) {
EXPECT_EQ(sizeof(PacketHeader), HDR_LENGTH);
}
TEST_F(PacketLengthTest, BuildAndParsePacket) {
PacketHeader pkt = {};
memcpy(pkt.flag, "HELL", 4);
pkt.flag[6] = 0x42;
pkt.flag[7] = ~0x42;
pkt.packedLength = 100;
pkt.originalLength = 200;
// 解析
uint32_t packedLen, origLen;
memcpy(&packedLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH, sizeof(uint32_t));
memcpy(&origLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH + sizeof(uint32_t), sizeof(uint32_t));
EXPECT_EQ(packedLen, 100u);
EXPECT_EQ(origLen, 200u);
}
TEST_F(PacketLengthTest, TotalPacketLength) {
// 总包长度 = HDR_LENGTH + 压缩数据长度
uint32_t dataLen = 1000;
uint32_t totalLen = HDR_LENGTH + dataLen;
EXPECT_EQ(totalLen, 1016u);
}
// ============================================
// 边界条件测试
// ============================================
class HeaderBoundaryTest : public ::testing::Test {};
TEST_F(HeaderBoundaryTest, MinimalPacket) {
// 最小合法包:只有头部
std::vector<uint8_t> packet(HDR_LENGTH);
memcpy(packet.data(), "HELL", 4);
packet[6] = 0x42;
packet[7] = ~0x42;
// packedLength = HDR_LENGTH
uint32_t packedLen = HDR_LENGTH;
memcpy(packet.data() + FLAG_LENGTH, &packedLen, sizeof(uint32_t));
// originalLength = 0
uint32_t origLen = 0;
memcpy(packet.data() + FLAG_LENGTH + sizeof(uint32_t), &origLen, sizeof(uint32_t));
EXPECT_EQ(packet.size(), HDR_LENGTH);
}
TEST_F(HeaderBoundaryTest, MaxPacketLength) {
// 验证大包长度字段
uint32_t maxDataLen = 100 * 1024 * 1024; // 100 MB
uint32_t totalLen = HDR_LENGTH + maxDataLen;
PacketHeader pkt = {};
pkt.packedLength = totalLen;
pkt.originalLength = maxDataLen * 2; // 压缩前更大
EXPECT_EQ(pkt.packedLength, totalLen);
EXPECT_EQ(pkt.originalLength, maxDataLen * 2);
}
TEST_F(HeaderBoundaryTest, TruncatedHeader) {
// 不完整的头部不应该被识别
char truncated[4] = { 'H', 'E', 'L', 'L' };
// 不足以进行完整验证
// 这里只是验证常量定义正确
EXPECT_LT(sizeof(truncated), static_cast<size_t>(MIN_COMLEN));
}
// ============================================
// FlagType 枚举测试
// ============================================
class FlagTypeTest : public ::testing::Test {};
TEST_F(FlagTypeTest, EnumValues) {
EXPECT_EQ(FLAG_WINOS, -1);
EXPECT_EQ(FLAG_UNKNOWN, 0);
EXPECT_EQ(FLAG_SHINE, 1);
EXPECT_EQ(FLAG_FUCK, 2);
EXPECT_EQ(FLAG_HELLO, 3);
EXPECT_EQ(FLAG_HELL, 4);
}
TEST_F(FlagTypeTest, HeaderEncTypeValues) {
EXPECT_EQ(HeaderEncUnknown, -1);
EXPECT_EQ(HeaderEncNone, 0);
EXPECT_EQ(HeaderEncV0, 1);
EXPECT_EQ(HeaderEncNum, 8);
}

View File

@@ -0,0 +1,534 @@
/**
* @file HttpMaskTest.cpp
* @brief HTTP 协议伪装测试
*
* 测试覆盖:
* - HTTP 请求格式生成
* - HTTP 头部解析和移除
* - Mask/UnMask 往返测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <string>
#include <vector>
#include <map>
// ============================================
// 类型定义
// ============================================
#ifdef _WIN32
typedef unsigned long ULONG;
#else
typedef uint32_t ULONG;
#endif
// ============================================
// 协议伪装类型
// ============================================
enum PkgMaskType {
MaskTypeUnknown = -1,
MaskTypeNone,
MaskTypeHTTP,
MaskTypeNum,
};
// ============================================
// HTTP 解除伪装函数
// ============================================
inline ULONG UnMaskHttp(const char* src, ULONG srcSize)
{
const char* header_end_mark = "\r\n\r\n";
const ULONG mark_len = 4;
for (ULONG i = 0; i + mark_len <= srcSize; ++i) {
if (memcmp(src + i, header_end_mark, mark_len) == 0) {
return i + mark_len;
}
}
return 0;
}
inline ULONG TryUnMask(const char* src, ULONG srcSize, PkgMaskType& maskHit)
{
if (srcSize >= 5 && memcmp(src, "POST ", 5) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
if (srcSize >= 4 && memcmp(src, "GET ", 4) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
maskHit = MaskTypeNone;
return 0;
}
// ============================================
// HTTP Mask 类(简化版)
// ============================================
class HttpMask {
public:
explicit HttpMask(const std::string& host = "example.com",
const std::map<std::string, std::string>& headers = {})
: host_(host)
{
for (const auto& kv : headers) {
customHeaders_ += kv.first + ": " + kv.second + "\r\n";
}
}
void Mask(char*& dst, ULONG& dstSize, const char* src, ULONG srcSize, int cmd = -1)
{
std::string path = "/api/v1/" + std::to_string(cmd == -1 ? 0 : cmd);
std::string httpHeader =
"POST " + path + " HTTP/1.1\r\n"
"Host: " + host_ + "\r\n"
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64)\r\n"
"Content-Type: application/octet-stream\r\n"
"Content-Length: " + std::to_string(srcSize) + "\r\n" + customHeaders_ +
"Connection: keep-alive\r\n"
"\r\n";
dstSize = static_cast<ULONG>(httpHeader.size()) + srcSize;
dst = new char[dstSize];
memcpy(dst, httpHeader.data(), httpHeader.size());
if (srcSize > 0) {
memcpy(dst + httpHeader.size(), src, srcSize);
}
}
ULONG UnMask(const char* src, ULONG srcSize)
{
return UnMaskHttp(src, srcSize);
}
void SetHost(const std::string& host)
{
host_ = host;
}
private:
std::string host_;
std::string customHeaders_;
};
// ============================================
// UnMaskHttp 测试
// ============================================
class UnMaskHttpTest : public ::testing::Test {};
TEST_F(UnMaskHttpTest, ValidHttpRequest) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: example.com\r\n"
"Content-Length: 4\r\n"
"\r\n"
"DATA";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回 body 起始位置
EXPECT_GT(offset, 0u);
EXPECT_STREQ(httpRequest.data() + offset, "DATA");
}
TEST_F(UnMaskHttpTest, NoHeaderEnd) {
std::string incomplete = "POST /api HTTP/1.1\r\nHost: example.com\r\n";
ULONG offset = UnMaskHttp(incomplete.data(), static_cast<ULONG>(incomplete.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, EmptyBody) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 0\r\n"
"\r\n";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
EXPECT_EQ(offset, static_cast<ULONG>(httpRequest.size()));
}
TEST_F(UnMaskHttpTest, MultipleHeaderEndMarkers) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 8\r\n"
"\r\n"
"\r\n\r\nXX"; // body 中也有 \r\n\r\n
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回第一个 \r\n\r\n 之后
std::string body(httpRequest.data() + offset);
EXPECT_EQ(body, "\r\n\r\nXX");
}
TEST_F(UnMaskHttpTest, MinimalInput) {
// 小于 4 字节
std::string tiny = "AB";
ULONG offset = UnMaskHttp(tiny.data(), static_cast<ULONG>(tiny.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, ExactlyHeaderEnd) {
std::string justEnd = "\r\n\r\n";
ULONG offset = UnMaskHttp(justEnd.data(), static_cast<ULONG>(justEnd.size()));
EXPECT_EQ(offset, 4u);
}
// ============================================
// TryUnMask 测试
// ============================================
class TryUnMaskTest : public ::testing::Test {};
TEST_F(TryUnMaskTest, DetectPOST) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n"
"body";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, DetectGET) {
std::string httpRequest =
"GET /resource HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, NonHttpData) {
std::string binaryData = "HELL\x00\x00\x42\xBD";
PkgMaskType maskType;
ULONG offset = TryUnMask(binaryData.data(), static_cast<ULONG>(binaryData.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
TEST_F(TryUnMaskTest, ShortInput) {
std::string tooShort = "POS";
PkgMaskType maskType;
ULONG offset = TryUnMask(tooShort.data(), static_cast<ULONG>(tooShort.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
// ============================================
// HttpMask 类测试
// ============================================
class HttpMaskClassTest : public ::testing::Test {};
TEST_F(HttpMaskClassTest, MaskBasic) {
HttpMask mask("api.example.com");
const char* data = "Hello";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 5);
ASSERT_NE(masked, nullptr);
EXPECT_GT(maskedSize, 5u);
// 验证是 HTTP 格式
EXPECT_EQ(memcmp(masked, "POST ", 5), 0);
// 验证 Host 头
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api.example.com"), std::string::npos);
// 验证 Content-Length
EXPECT_NE(maskedStr.find("Content-Length: 5"), std::string::npos);
// 验证 body
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, "Hello", 5), 0);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskEmptyData) {
HttpMask mask;
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "", 0);
ASSERT_NE(masked, nullptr);
// 验证 Content-Length: 0
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 0"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskWithCommand) {
HttpMask mask;
const char* data = "X";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 1, 42);
// 验证路径包含命令号
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("/42"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskLargeData) {
HttpMask mask;
std::vector<char> largeData(64 * 1024, 'X'); // 64 KB
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, largeData.data(), static_cast<ULONG>(largeData.size()));
ASSERT_NE(masked, nullptr);
// 验证 Content-Length
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 65536"), std::string::npos);
// 验证 body 完整
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 64u * 1024u);
delete[] masked;
}
// ============================================
// Mask/UnMask 往返测试
// ============================================
class MaskRoundTripTest : public ::testing::Test {};
TEST_F(MaskRoundTripTest, SimpleRoundTrip) {
HttpMask mask;
std::vector<char> original = {'H', 'E', 'L', 'L', 'O'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
// UnMask
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_GT(offset, 0u);
// 验证数据完整
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, BinaryDataRoundTrip) {
HttpMask mask;
// 包含所有字节值
std::vector<char> original(256);
for (int i = 0; i < 256; ++i) {
original[i] = static_cast<char>(i);
}
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, NullBytesRoundTrip) {
HttpMask mask;
std::vector<char> original = {'\0', '\0', 'A', '\0', 'B'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, HttpLikeDataRoundTrip) {
HttpMask mask;
// 数据本身看起来像 HTTP
std::string httpLike = "POST /fake HTTP/1.1\r\n\r\nfake body";
std::vector<char> original(httpLike.begin(), httpLike.end());
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
// ============================================
// 自定义头部测试
// ============================================
class CustomHeadersTest : public ::testing::Test {};
TEST_F(CustomHeadersTest, AddCustomHeaders) {
std::map<std::string, std::string> headers;
headers["X-Custom-Header"] = "custom-value";
headers["X-Request-ID"] = "12345";
HttpMask mask("test.com", headers);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "data", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("X-Custom-Header: custom-value"), std::string::npos);
EXPECT_NE(maskedStr.find("X-Request-ID: 12345"), std::string::npos);
delete[] masked;
}
// ============================================
// 边界条件测试
// ============================================
class HttpMaskBoundaryTest : public ::testing::Test {};
TEST_F(HttpMaskBoundaryTest, VeryLongHost) {
std::string longHost(1000, 'x');
longHost += ".com";
HttpMask mask(longHost);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "test", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find(longHost), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, SpecialCharactersInHost) {
HttpMask mask("api-v2.test-server.example.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "x", 1);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api-v2.test-server.example.com"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, MaxULONGContentLength) {
// 测试大的 Content-Length 值
HttpMask mask;
// 不实际分配这么大的内存,只是构造请求头
std::string largeContentLengthStr = "Content-Length: 4294967295";
// 验证格式正确
EXPECT_EQ(largeContentLengthStr.find("4294967295"), 16u);
}
// ============================================
// HTTP 格式验证测试
// ============================================
class HttpFormatTest : public ::testing::Test {};
TEST_F(HttpFormatTest, ValidHttpRequestFormat) {
HttpMask mask("test.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "body", 4);
std::string maskedStr(masked, maskedSize);
// 验证请求行
EXPECT_EQ(maskedStr.substr(0, 5), "POST ");
// 验证 HTTP 版本
EXPECT_NE(maskedStr.find("HTTP/1.1\r\n"), std::string::npos);
// 验证必需的头部
EXPECT_NE(maskedStr.find("Host:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Length:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Type:"), std::string::npos);
// 验证头部和 body 之间的分隔符
EXPECT_NE(maskedStr.find("\r\n\r\n"), std::string::npos);
delete[] masked;
}
TEST_F(HttpFormatTest, ContentLengthMatchesBody) {
HttpMask mask;
std::vector<char> body(123, 'X');
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, body.data(), static_cast<ULONG>(body.size()));
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 123"), std::string::npos);
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 123u);
delete[] masked;
}

View File

@@ -0,0 +1,660 @@
/**
* @file PacketFragmentTest.cpp
* @brief 粘包/分包处理测试
*
* 测试覆盖:
* - 完整包接收和解析
* - 分包(不完整包)处理
* - 粘包(多个包粘在一起)处理
* - 混合场景测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <queue>
#include <functional>
// ============================================
// 协议常量
// ============================================
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = 16; // FLAG(8) + PackedLen(4) + OrigLen(4)
// ============================================
// 简化版 CBuffer测试专用
// ============================================
class TestBuffer {
public:
TestBuffer() {}
void WriteBuffer(const uint8_t* data, size_t len) {
m_data.insert(m_data.end(), data, data + len);
}
size_t ReadBuffer(uint8_t* dst, size_t len) {
size_t toRead = std::min(len, m_data.size());
if (toRead > 0) {
memcpy(dst, m_data.data(), toRead);
m_data.erase(m_data.begin(), m_data.begin() + toRead);
}
return toRead;
}
void Skip(size_t len) {
size_t toSkip = std::min(len, m_data.size());
m_data.erase(m_data.begin(), m_data.begin() + toSkip);
}
size_t GetBufferLength() const {
return m_data.size();
}
const uint8_t* GetBuffer(size_t pos = 0) const {
if (pos >= m_data.size()) return nullptr;
return m_data.data() + pos;
}
bool CopyBuffer(uint8_t* dst, size_t pos, size_t len) const {
if (pos + len > m_data.size()) return false;
memcpy(dst, m_data.data() + pos, len);
return true;
}
void Clear() {
m_data.clear();
}
private:
std::vector<uint8_t> m_data;
};
// ============================================
// 数据包构建辅助函数
// ============================================
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 包含头部的总长度
uint32_t originalLength; // 原始数据长度
};
#pragma pack(pop)
std::vector<uint8_t> BuildPacket(const std::vector<uint8_t>& payload, uint8_t key = 0x42) {
std::vector<uint8_t> packet(HDR_LENGTH + payload.size());
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = key;
hdr->flag[7] = ~key;
hdr->packedLength = static_cast<uint32_t>(HDR_LENGTH + payload.size());
hdr->originalLength = static_cast<uint32_t>(payload.size());
if (!payload.empty()) {
memcpy(packet.data() + HDR_LENGTH, payload.data(), payload.size());
}
return packet;
}
std::vector<uint8_t> BuildPacketWithData(size_t dataSize, uint8_t fillByte = 0xAA) {
std::vector<uint8_t> payload(dataSize, fillByte);
return BuildPacket(payload);
}
// ============================================
// 粘包/分包处理器(模拟 OnServerReceiving 逻辑)
// ============================================
class PacketProcessor {
public:
using PacketCallback = std::function<void(const std::vector<uint8_t>&)>;
PacketProcessor(PacketCallback callback) : m_callback(callback) {}
// 接收数据(模拟网络接收)
void OnReceive(const uint8_t* data, size_t len) {
m_buffer.WriteBuffer(data, len);
ProcessBuffer();
}
// 获取待处理的字节数
size_t GetPendingBytes() const {
return m_buffer.GetBufferLength();
}
// 获取已处理的包数量
size_t GetProcessedCount() const {
return m_processedCount;
}
private:
void ProcessBuffer() {
while (m_buffer.GetBufferLength() >= HDR_LENGTH) {
// 验证头部
const uint8_t* buf = m_buffer.GetBuffer();
if (memcmp(buf, "HELL", 4) != 0) {
// 无效头部,跳过一个字节重试
m_buffer.Skip(1);
continue;
}
// 读取包长度
uint32_t packedLength;
m_buffer.CopyBuffer(reinterpret_cast<uint8_t*>(&packedLength),
FLAG_LENGTH, sizeof(uint32_t));
// 检查长度有效性
if (packedLength < HDR_LENGTH || packedLength > 100 * 1024 * 1024) {
// 无效长度,跳过头部
m_buffer.Skip(FLAG_LENGTH);
continue;
}
// 检查包是否完整
if (m_buffer.GetBufferLength() < packedLength) {
// 不完整,等待更多数据
break;
}
// 读取完整包
std::vector<uint8_t> packet(packedLength);
m_buffer.ReadBuffer(packet.data(), packedLength);
// 提取 payload
std::vector<uint8_t> payload(packet.begin() + HDR_LENGTH, packet.end());
m_callback(payload);
m_processedCount++;
}
}
TestBuffer m_buffer;
PacketCallback m_callback;
size_t m_processedCount = 0;
};
// ============================================
// 完整包接收测试
// ============================================
class CompletePacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
void SetUp() override {
receivedPackets.clear();
}
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(CompletePacketTest, SinglePacket) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02, 0x03, 0x04};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
EXPECT_EQ(processor.GetPendingBytes(), 0u);
}
TEST_F(CompletePacketTest, EmptyPayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload;
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(CompletePacketTest, LargePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(64 * 1024); // 64 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 分包(不完整包)测试
// ============================================
class FragmentedPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(FragmentedPacketTest, TwoFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xAA, 0xBB, 0xCC, 0xDD};
auto packet = BuildPacket(payload);
// 分两次发送
size_t half = packet.size() / 2;
processor.OnReceive(packet.data(), half);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), half);
processor.OnReceive(packet.data() + half, packet.size() - half);
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, ManyFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100);
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
// 每次发送 10 字节
for (size_t i = 0; i < packet.size(); i += 10) {
size_t len = std::min(size_t(10), packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, OnlyHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02};
auto packet = BuildPacket(payload);
// 只发送头部
processor.OnReceive(packet.data(), HDR_LENGTH);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), HDR_LENGTH);
// 发送剩余数据
processor.OnReceive(packet.data() + HDR_LENGTH, packet.size() - HDR_LENGTH);
ASSERT_EQ(receivedPackets.size(), 1u);
}
TEST_F(FragmentedPacketTest, PartialHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
// 发送不完整的头部
processor.OnReceive(packet.data(), 4); // 只有 "HELL"
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余部分
processor.OnReceive(packet.data() + 4, packet.size() - 4);
ASSERT_EQ(receivedPackets.size(), 1u);
}
// ============================================
// 粘包测试
// ============================================
class StickyPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(StickyPacketTest, TwoPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11, 0x22};
std::vector<uint8_t> payload2 = {0x33, 0x44, 0x55};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 合并两个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(StickyPacketTest, ThreePacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x01};
std::vector<uint8_t> payload2 = {0x02, 0x03};
std::vector<uint8_t> payload3 = {0x04, 0x05, 0x06};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
auto packet3 = BuildPacket(payload3);
// 合并三个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
combined.insert(combined.end(), packet3.begin(), packet3.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 3u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
EXPECT_EQ(receivedPackets[2], payload3);
}
TEST_F(StickyPacketTest, ManyPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> combined;
const int numPackets = 100;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload(i % 10 + 1, static_cast<uint8_t>(i));
auto packet = BuildPacket(payload);
combined.insert(combined.end(), packet.begin(), packet.end());
}
processor.OnReceive(combined.data(), combined.size());
EXPECT_EQ(receivedPackets.size(), numPackets);
}
// ============================================
// 混合场景测试
// ============================================
class MixedScenarioTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(MixedScenarioTest, OneAndHalfPackets) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0xAA, 0xBB};
std::vector<uint8_t> payload2 = {0xCC, 0xDD, 0xEE, 0xFF};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送完整包1 + 半个包2
std::vector<uint8_t> firstSend;
firstSend.insert(firstSend.end(), packet1.begin(), packet1.end());
firstSend.insert(firstSend.end(), packet2.begin(), packet2.begin() + packet2.size() / 2);
processor.OnReceive(firstSend.data(), firstSend.size());
EXPECT_EQ(receivedPackets.size(), 1u); // 只处理了包1
// 发送剩余的半个包2
processor.OnReceive(packet2.data() + packet2.size() / 2, packet2.size() - packet2.size() / 2);
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(MixedScenarioTest, HalfPacketThenOneAndHalf) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11};
std::vector<uint8_t> payload2 = {0x22, 0x33};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送半个包1
processor.OnReceive(packet1.data(), packet1.size() / 2);
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余包1 + 完整包2
std::vector<uint8_t> secondSend;
secondSend.insert(secondSend.end(), packet1.begin() + packet1.size() / 2, packet1.end());
secondSend.insert(secondSend.end(), packet2.begin(), packet2.end());
processor.OnReceive(secondSend.data(), secondSend.size());
ASSERT_EQ(receivedPackets.size(), 2u);
}
TEST_F(MixedScenarioTest, RandomChunkSizes) {
PacketProcessor processor(GetCallback());
// 准备多个包
std::vector<std::vector<uint8_t>> payloads;
std::vector<uint8_t> allData;
for (int i = 0; i < 10; ++i) {
std::vector<uint8_t> payload(i * 5 + 10, static_cast<uint8_t>(i));
payloads.push_back(payload);
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
// 使用"随机"大小的块发送
size_t chunkSizes[] = {1, 7, 15, 16, 17, 31, 32, 33, 64, 128};
size_t pos = 0;
size_t chunkIdx = 0;
while (pos < allData.size()) {
size_t chunkSize = chunkSizes[chunkIdx % (sizeof(chunkSizes) / sizeof(chunkSizes[0]))];
size_t len = std::min(chunkSize, allData.size() - pos);
processor.OnReceive(allData.data() + pos, len);
pos += len;
chunkIdx++;
}
ASSERT_EQ(receivedPackets.size(), payloads.size());
for (size_t i = 0; i < payloads.size(); ++i) {
EXPECT_EQ(receivedPackets[i], payloads[i]) << "Mismatch at packet " << i;
}
}
// ============================================
// 边界条件测试
// ============================================
class PacketBoundaryTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(PacketBoundaryTest, ExactlyHdrLength) {
PacketProcessor processor(GetCallback());
// 只有头部,无 payload
std::vector<uint8_t> packet(HDR_LENGTH);
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = 0x42;
hdr->flag[7] = ~0x42;
hdr->packedLength = HDR_LENGTH;
hdr->originalLength = 0;
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(PacketBoundaryTest, SingleBytePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0].size(), 1u);
EXPECT_EQ(receivedPackets[0][0], 0xFF);
}
TEST_F(PacketBoundaryTest, ByteByByteReceive) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x12, 0x34, 0x56};
auto packet = BuildPacket(payload);
// 每次接收 1 字节
for (size_t i = 0; i < packet.size(); ++i) {
processor.OnReceive(packet.data() + i, 1);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 数据完整性测试
// ============================================
class DataIntegrityTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(DataIntegrityTest, BinaryData) {
PacketProcessor processor(GetCallback());
// 包含所有字节值的 payload
std::vector<uint8_t> payload(256);
for (int i = 0; i < 256; ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(DataIntegrityTest, AllZeros) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0x00);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0x00);
}
}
TEST_F(DataIntegrityTest, AllOnes) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0xFF);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0xFF);
}
}
// ============================================
// 性能相关测试
// ============================================
class PacketPerformanceTest : public ::testing::Test {
protected:
size_t packetCount = 0;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
packetCount++;
};
}
};
TEST_F(PacketPerformanceTest, ManySmallPackets) {
PacketProcessor processor(GetCallback());
const int numPackets = 10000;
std::vector<uint8_t> allData;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload = {static_cast<uint8_t>(i & 0xFF)};
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
processor.OnReceive(allData.data(), allData.size());
EXPECT_EQ(packetCount, numPackets);
}
TEST_F(PacketPerformanceTest, LargePacketInSmallChunks) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100 * 1024); // 100 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
// 每次发送 1 KB
const size_t chunkSize = 1024;
for (size_t i = 0; i < packet.size(); i += chunkSize) {
size_t len = std::min(chunkSize, packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
EXPECT_EQ(packetCount, 1u);
}

View File

@@ -0,0 +1,520 @@
/**
* @file PacketTest.cpp
* @brief 协议数据包结构测试
*
* 测试覆盖:
* - 数据包结构大小验证
* - 字段偏移和对齐
* - 序列化/反序列化
* - 边界条件
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <cstdint>
// ============================================
// 协议数据结构定义(测试专用副本)
// ============================================
#pragma pack(push, 1)
// V1 文件传输包
struct FileChunkPacket {
uint8_t cmd;
uint32_t fileIndex;
uint32_t totalNum;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
}; // 应该是 41 bytes
// V2 标志位
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_LAST_CHUNK = 0x0001,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
FFV2_CANCEL = 0x0008,
FFV2_DIRECTORY = 0x0010,
FFV2_COMPRESSED = 0x0020,
FFV2_ERROR = 0x0040,
};
// V2 文件传输包
struct FileChunkPacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint32_t totalFiles;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
uint16_t flags;
uint16_t checksum;
uint8_t reserved[8];
}; // 应该是 77 bytes
// V2 断点续传区间
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
}; // 16 bytes
// V2 断点续传控制包
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
}; // 49 bytes
// C2C 准备包
struct C2CPreparePacket {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
}; // 17 bytes
// C2C 准备响应包
struct C2CPrepareRespPacket {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint16_t pathLength;
}; // 19 bytes
// V2 文件完成校验包
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
}; // 69 bytes
#pragma pack(pop)
// 命令常量
enum Commands {
COMMAND_SEND_FILE = 68,
COMMAND_SEND_FILE_V2 = 85,
COMMAND_FILE_RESUME = 86,
COMMAND_CLIPBOARD_V2 = 87,
COMMAND_FILE_QUERY_RESUME = 88,
COMMAND_C2C_PREPARE = 89,
COMMAND_C2C_TEXT = 90,
COMMAND_FILE_COMPLETE_V2 = 91,
COMMAND_C2C_PREPARE_RESP = 92,
};
// ============================================
// 结构大小测试
// ============================================
TEST(PacketSizeTest, FileChunkPacket_Size) {
EXPECT_EQ(sizeof(FileChunkPacket), 41u);
}
TEST(PacketSizeTest, FileChunkPacketV2_Size) {
EXPECT_EQ(sizeof(FileChunkPacketV2), 77u);
}
TEST(PacketSizeTest, FileRangeV2_Size) {
EXPECT_EQ(sizeof(FileRangeV2), 16u);
}
TEST(PacketSizeTest, FileResumePacketV2_Size) {
EXPECT_EQ(sizeof(FileResumePacketV2), 49u);
}
TEST(PacketSizeTest, C2CPreparePacket_Size) {
EXPECT_EQ(sizeof(C2CPreparePacket), 17u);
}
TEST(PacketSizeTest, C2CPrepareRespPacket_Size) {
EXPECT_EQ(sizeof(C2CPrepareRespPacket), 19u);
}
TEST(PacketSizeTest, FileCompletePacketV2_Size) {
EXPECT_EQ(sizeof(FileCompletePacketV2), 69u);
}
// ============================================
// 字段偏移测试
// ============================================
TEST(PacketOffsetTest, FileChunkPacketV2_FieldOffsets) {
EXPECT_EQ(offsetof(FileChunkPacketV2, cmd), 0u);
EXPECT_EQ(offsetof(FileChunkPacketV2, transferID), 1u);
EXPECT_EQ(offsetof(FileChunkPacketV2, srcClientID), 9u);
EXPECT_EQ(offsetof(FileChunkPacketV2, dstClientID), 17u);
EXPECT_EQ(offsetof(FileChunkPacketV2, fileIndex), 25u);
EXPECT_EQ(offsetof(FileChunkPacketV2, totalFiles), 29u);
EXPECT_EQ(offsetof(FileChunkPacketV2, fileSize), 33u);
EXPECT_EQ(offsetof(FileChunkPacketV2, offset), 41u);
EXPECT_EQ(offsetof(FileChunkPacketV2, dataLength), 49u);
EXPECT_EQ(offsetof(FileChunkPacketV2, nameLength), 57u);
EXPECT_EQ(offsetof(FileChunkPacketV2, flags), 65u);
EXPECT_EQ(offsetof(FileChunkPacketV2, checksum), 67u);
EXPECT_EQ(offsetof(FileChunkPacketV2, reserved), 69u);
}
TEST(PacketOffsetTest, FileResumePacketV2_FieldOffsets) {
EXPECT_EQ(offsetof(FileResumePacketV2, cmd), 0u);
EXPECT_EQ(offsetof(FileResumePacketV2, transferID), 1u);
EXPECT_EQ(offsetof(FileResumePacketV2, srcClientID), 9u);
EXPECT_EQ(offsetof(FileResumePacketV2, dstClientID), 17u);
EXPECT_EQ(offsetof(FileResumePacketV2, fileIndex), 25u);
EXPECT_EQ(offsetof(FileResumePacketV2, fileSize), 29u);
EXPECT_EQ(offsetof(FileResumePacketV2, receivedBytes), 37u);
EXPECT_EQ(offsetof(FileResumePacketV2, flags), 45u);
EXPECT_EQ(offsetof(FileResumePacketV2, rangeCount), 47u);
}
// ============================================
// 序列化/反序列化测试
// ============================================
class PacketSerializationTest : public ::testing::Test {
protected:
std::vector<uint8_t> buffer;
template<typename T>
void SerializePacket(const T& pkt) {
buffer.resize(sizeof(T));
memcpy(buffer.data(), &pkt, sizeof(T));
}
template<typename T>
T DeserializePacket() {
T pkt;
memcpy(&pkt, buffer.data(), sizeof(T));
return pkt;
}
};
TEST_F(PacketSerializationTest, FileChunkPacketV2_RoundTrip) {
FileChunkPacketV2 original = {};
original.cmd = COMMAND_SEND_FILE_V2;
original.transferID = 0x123456789ABCDEF0ULL;
original.srcClientID = 0x1111111111111111ULL;
original.dstClientID = 0x2222222222222222ULL;
original.fileIndex = 42;
original.totalFiles = 100;
original.fileSize = 1024 * 1024 * 1024; // 1GB
original.offset = 512 * 1024;
original.dataLength = 4096;
original.nameLength = 256;
original.flags = FFV2_LAST_CHUNK | FFV2_COMPRESSED;
original.checksum = 0xABCD;
memset(original.reserved, 0x55, sizeof(original.reserved));
SerializePacket(original);
auto restored = DeserializePacket<FileChunkPacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.srcClientID, original.srcClientID);
EXPECT_EQ(restored.dstClientID, original.dstClientID);
EXPECT_EQ(restored.fileIndex, original.fileIndex);
EXPECT_EQ(restored.totalFiles, original.totalFiles);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(restored.offset, original.offset);
EXPECT_EQ(restored.dataLength, original.dataLength);
EXPECT_EQ(restored.nameLength, original.nameLength);
EXPECT_EQ(restored.flags, original.flags);
EXPECT_EQ(restored.checksum, original.checksum);
EXPECT_EQ(memcmp(restored.reserved, original.reserved, 8), 0);
}
TEST_F(PacketSerializationTest, FileResumePacketV2_RoundTrip) {
FileResumePacketV2 original = {};
original.cmd = COMMAND_FILE_RESUME;
original.transferID = 0xDEADBEEFCAFEBABEULL;
original.srcClientID = 12345;
original.dstClientID = 67890;
original.fileIndex = 5;
original.fileSize = 0xFFFFFFFFFFFFFFFFULL; // 最大值
original.receivedBytes = 0x8000000000000000ULL; // 大值
original.flags = FFV2_RESUME_RESP;
original.rangeCount = 10;
SerializePacket(original);
auto restored = DeserializePacket<FileResumePacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.srcClientID, original.srcClientID);
EXPECT_EQ(restored.dstClientID, original.dstClientID);
EXPECT_EQ(restored.fileIndex, original.fileIndex);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(restored.receivedBytes, original.receivedBytes);
EXPECT_EQ(restored.flags, original.flags);
EXPECT_EQ(restored.rangeCount, original.rangeCount);
}
TEST_F(PacketSerializationTest, FileCompletePacketV2_RoundTrip) {
FileCompletePacketV2 original = {};
original.cmd = COMMAND_FILE_COMPLETE_V2;
original.transferID = 99999;
original.srcClientID = 1;
original.dstClientID = 2;
original.fileIndex = 0;
original.fileSize = 12345678;
// 填充 SHA-256
for (int i = 0; i < 32; i++) {
original.sha256[i] = (uint8_t)(i * 8);
}
SerializePacket(original);
auto restored = DeserializePacket<FileCompletePacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(memcmp(restored.sha256, original.sha256, 32), 0);
}
// ============================================
// 标志位测试
// ============================================
TEST(FlagsTest, FileFlagsV2_SingleFlags) {
EXPECT_EQ(FFV2_NONE, 0x0000);
EXPECT_EQ(FFV2_LAST_CHUNK, 0x0001);
EXPECT_EQ(FFV2_RESUME_REQ, 0x0002);
EXPECT_EQ(FFV2_RESUME_RESP, 0x0004);
EXPECT_EQ(FFV2_CANCEL, 0x0008);
EXPECT_EQ(FFV2_DIRECTORY, 0x0010);
EXPECT_EQ(FFV2_COMPRESSED, 0x0020);
EXPECT_EQ(FFV2_ERROR, 0x0040);
}
TEST(FlagsTest, FileFlagsV2_Combinations) {
uint16_t flags = FFV2_LAST_CHUNK | FFV2_COMPRESSED;
EXPECT_TRUE(flags & FFV2_LAST_CHUNK);
EXPECT_TRUE(flags & FFV2_COMPRESSED);
EXPECT_FALSE(flags & FFV2_DIRECTORY);
EXPECT_FALSE(flags & FFV2_ERROR);
}
TEST(FlagsTest, FileFlagsV2_NoBitOverlap) {
// 验证各标志位不重叠
uint16_t allFlags[] = {
FFV2_LAST_CHUNK, FFV2_RESUME_REQ, FFV2_RESUME_RESP,
FFV2_CANCEL, FFV2_DIRECTORY, FFV2_COMPRESSED, FFV2_ERROR
};
for (size_t i = 0; i < sizeof(allFlags)/sizeof(allFlags[0]); i++) {
for (size_t j = i + 1; j < sizeof(allFlags)/sizeof(allFlags[0]); j++) {
EXPECT_EQ(allFlags[i] & allFlags[j], 0)
<< "Flags overlap: " << allFlags[i] << " & " << allFlags[j];
}
}
}
// ============================================
// 命令常量测试
// ============================================
TEST(CommandTest, CommandValues_AreUnique) {
std::vector<int> commands = {
COMMAND_SEND_FILE,
COMMAND_SEND_FILE_V2,
COMMAND_FILE_RESUME,
COMMAND_CLIPBOARD_V2,
COMMAND_FILE_QUERY_RESUME,
COMMAND_C2C_PREPARE,
COMMAND_C2C_TEXT,
COMMAND_FILE_COMPLETE_V2,
COMMAND_C2C_PREPARE_RESP
};
// 验证无重复
for (size_t i = 0; i < commands.size(); i++) {
for (size_t j = i + 1; j < commands.size(); j++) {
EXPECT_NE(commands[i], commands[j])
<< "Duplicate command values at index " << i << " and " << j;
}
}
}
TEST(CommandTest, CommandValues_FitInByte) {
EXPECT_LE(COMMAND_SEND_FILE, 255);
EXPECT_LE(COMMAND_SEND_FILE_V2, 255);
EXPECT_LE(COMMAND_FILE_RESUME, 255);
EXPECT_LE(COMMAND_FILE_COMPLETE_V2, 255);
}
// ============================================
// 边界条件测试
// ============================================
TEST(BoundaryTest, MaxFileSize) {
FileChunkPacketV2 pkt = {};
pkt.fileSize = UINT64_MAX;
EXPECT_EQ(pkt.fileSize, 0xFFFFFFFFFFFFFFFFULL);
}
TEST(BoundaryTest, MaxOffset) {
FileChunkPacketV2 pkt = {};
pkt.offset = UINT64_MAX;
EXPECT_EQ(pkt.offset, 0xFFFFFFFFFFFFFFFFULL);
}
TEST(BoundaryTest, ZeroValues) {
FileChunkPacketV2 pkt = {};
memset(&pkt, 0, sizeof(pkt));
EXPECT_EQ(pkt.cmd, 0);
EXPECT_EQ(pkt.transferID, 0u);
EXPECT_EQ(pkt.fileSize, 0u);
EXPECT_EQ(pkt.flags, FFV2_NONE);
}
// ============================================
// 带变长数据的包测试
// ============================================
TEST(VariableLengthTest, FileChunkPacketV2_WithFilename) {
const char* filename = "test/folder/file.txt";
size_t nameLen = strlen(filename);
size_t totalSize = sizeof(FileChunkPacketV2) + nameLen;
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = COMMAND_SEND_FILE_V2;
pkt->nameLength = static_cast<uint64_t>(nameLen);
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename, nameLen);
// 验证可以正确读取
EXPECT_EQ(pkt->nameLength, nameLen);
std::string extractedName(
reinterpret_cast<char*>(buffer.data() + sizeof(FileChunkPacketV2)),
pkt->nameLength
);
EXPECT_EQ(extractedName, filename);
}
TEST(VariableLengthTest, FileChunkPacketV2_WithFilenameAndData) {
const char* filename = "data.bin";
size_t nameLen = strlen(filename);
std::vector<uint8_t> fileData = {0x01, 0x02, 0x03, 0x04, 0x05};
size_t dataLen = fileData.size();
size_t totalSize = sizeof(FileChunkPacketV2) + nameLen + dataLen;
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = COMMAND_SEND_FILE_V2;
pkt->nameLength = nameLen;
pkt->dataLength = dataLen;
// 写入文件名
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename, nameLen);
// 写入数据
memcpy(buffer.data() + sizeof(FileChunkPacketV2) + nameLen, fileData.data(), dataLen);
// 验证
EXPECT_EQ(buffer.size(), totalSize);
// 提取文件名
std::string extractedName(
reinterpret_cast<char*>(buffer.data() + sizeof(FileChunkPacketV2)),
pkt->nameLength
);
EXPECT_EQ(extractedName, filename);
// 提取数据
std::vector<uint8_t> extractedData(
buffer.begin() + sizeof(FileChunkPacketV2) + nameLen,
buffer.end()
);
EXPECT_EQ(extractedData, fileData);
}
// ============================================
// 断点续传区间测试
// ============================================
TEST(ResumeRangeTest, SingleRange) {
FileRangeV2 range = {0, 1024};
EXPECT_EQ(range.offset, 0u);
EXPECT_EQ(range.length, 1024u);
}
TEST(ResumeRangeTest, MultipleRanges) {
std::vector<FileRangeV2> ranges = {
{0, 1024},
{2048, 512},
{4096, 2048}
};
// 计算总接收字节数
uint64_t totalReceived = 0;
for (const auto& r : ranges) {
totalReceived += r.length;
}
EXPECT_EQ(totalReceived, 1024u + 512u + 2048u);
}
TEST(ResumeRangeTest, PacketWithRanges) {
uint16_t rangeCount = 3;
size_t totalSize = sizeof(FileResumePacketV2) + rangeCount * sizeof(FileRangeV2);
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileResumePacketV2*>(buffer.data());
pkt->cmd = COMMAND_FILE_RESUME;
pkt->rangeCount = rangeCount;
auto* ranges = reinterpret_cast<FileRangeV2*>(buffer.data() + sizeof(FileResumePacketV2));
ranges[0] = {0, 1024};
ranges[1] = {2048, 512};
ranges[2] = {4096, 2048};
// 验证
EXPECT_EQ(pkt->rangeCount, 3);
EXPECT_EQ(ranges[0].offset, 0u);
EXPECT_EQ(ranges[1].offset, 2048u);
EXPECT_EQ(ranges[2].length, 2048u);
}
// ============================================
// 字节序测试(假设小端)
// ============================================
TEST(EndiannessTest, LittleEndian_Uint64) {
uint64_t value = 0x0102030405060708ULL;
uint8_t bytes[8];
memcpy(bytes, &value, 8);
// 小端:低字节在前
EXPECT_EQ(bytes[0], 0x08);
EXPECT_EQ(bytes[1], 0x07);
EXPECT_EQ(bytes[7], 0x01);
}
TEST(EndiannessTest, FileChunkPacketV2_TransferID) {
FileChunkPacketV2 pkt = {};
pkt.transferID = 0x0102030405060708ULL;
uint8_t* raw = reinterpret_cast<uint8_t*>(&pkt);
// transferID 从 offset 1 开始
EXPECT_EQ(raw[1], 0x08); // 低字节
EXPECT_EQ(raw[8], 0x01); // 高字节
}

View File

@@ -0,0 +1,360 @@
/**
* @file PathUtilsTest.cpp
* @brief 路径处理工具函数测试
*
* 测试覆盖:
* - GetCommonRoot: 计算多文件的公共根目录
* - GetRelativePath: 获取相对路径
* - 边界条件和特殊字符处理
*/
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <algorithm>
#include <cctype>
// ============================================
// 路径工具函数实现(测试专用副本)
// ============================================
// 计算最长公共前缀作为根目录
std::string GetCommonRoot(const std::vector<std::string>& files)
{
if (files.empty())
return "";
std::string root = files[0];
for (size_t i = 1; i < files.size(); ++i)
{
const std::string& path = files[i];
size_t len = (std::min)(root.size(), path.size());
size_t j = 0;
for (; j < len; ++j)
{
if (std::tolower(static_cast<unsigned char>(root[j])) !=
std::tolower(static_cast<unsigned char>(path[j])))
{
break;
}
}
root = root.substr(0, j);
size_t pos = root.find_last_of('\\');
if (pos != std::string::npos)
root = root.substr(0, pos + 1);
}
return root;
}
// 获取相对路径
std::string GetRelativePath(const std::string& root, const std::string& fullPath)
{
if (fullPath.compare(0, root.size(), root) == 0)
{
std::string rel = fullPath.substr(root.size());
if (rel.empty()) // root 就是完整文件路径
{
size_t pos = fullPath.find_last_of('\\');
if (pos != std::string::npos)
rel = fullPath.substr(pos + 1); // 文件名
else
rel = fullPath;
}
return rel;
}
return fullPath;
}
// ============================================
// GetCommonRoot 测试
// ============================================
class GetCommonRootTest : public ::testing::Test {};
TEST_F(GetCommonRootTest, EmptyList_ReturnsEmpty) {
std::vector<std::string> files;
EXPECT_EQ(GetCommonRoot(files), "");
}
TEST_F(GetCommonRootTest, SingleFile_ReturnsParentDir) {
std::vector<std::string> files = {
"C:\\Users\\Test\\file.txt"
};
// 单个文件时,返回整个路径(因为没有其他文件比较)
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\file.txt");
}
TEST_F(GetCommonRootTest, TwoFilesInSameDir_ReturnsDir) {
std::vector<std::string> files = {
"C:\\Users\\Test\\file1.txt",
"C:\\Users\\Test\\file2.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\");
}
TEST_F(GetCommonRootTest, FilesInNestedDirs_ReturnsCommonAncestor) {
std::vector<std::string> files = {
"C:\\Users\\Test\\Documents\\a.txt",
"C:\\Users\\Test\\Downloads\\b.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\");
}
TEST_F(GetCommonRootTest, FilesInDeeplyNestedDirs) {
std::vector<std::string> files = {
"C:\\a\\b\\c\\d\\e\\file1.txt",
"C:\\a\\b\\c\\x\\y\\file2.txt",
"C:\\a\\b\\c\\z\\file3.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\a\\b\\c\\");
}
TEST_F(GetCommonRootTest, CaseInsensitive) {
std::vector<std::string> files = {
"C:\\Users\\TEST\\file1.txt",
"C:\\users\\test\\file2.txt"
};
// 应该忽略大小写进行比较
std::string root = GetCommonRoot(files);
// 结果应该是原始路径的前缀
EXPECT_TRUE(root.size() >= strlen("C:\\Users\\TEST\\") ||
root.size() >= strlen("C:\\users\\test\\"));
}
TEST_F(GetCommonRootTest, DifferentDrives_ReturnsEmpty) {
std::vector<std::string> files = {
"C:\\Users\\file1.txt",
"D:\\Data\\file2.txt"
};
// 不同驱动器,公共根可能为空或只有共同前缀
std::string root = GetCommonRoot(files);
EXPECT_TRUE(root.empty() || root.find('\\') == std::string::npos);
}
TEST_F(GetCommonRootTest, SameDirectory_MultipleFiles) {
std::vector<std::string> files = {
"C:\\Temp\\a.txt",
"C:\\Temp\\b.txt",
"C:\\Temp\\c.txt",
"C:\\Temp\\d.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Temp\\");
}
TEST_F(GetCommonRootTest, DirectoryAndFiles) {
std::vector<std::string> files = {
"C:\\Temp\\folder",
"C:\\Temp\\folder\\file.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Temp\\");
}
TEST_F(GetCommonRootTest, ChineseCharacters) {
std::vector<std::string> files = {
"C:\\Users\\测试\\文档\\file1.txt",
"C:\\Users\\测试\\下载\\file2.txt"
};
std::string root = GetCommonRoot(files);
// 应该能正确处理中文路径
EXPECT_TRUE(root.find("测试") != std::string::npos ||
root == "C:\\Users\\");
}
TEST_F(GetCommonRootTest, SpacesInPath) {
std::vector<std::string> files = {
"C:\\Program Files\\App\\file1.txt",
"C:\\Program Files\\App\\file2.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Program Files\\App\\");
}
// ============================================
// GetRelativePath 测试
// ============================================
class GetRelativePathTest : public ::testing::Test {};
TEST_F(GetRelativePathTest, SimpleRelativePath) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "C:\\Users\\Test\\Documents\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "Documents\\file.txt");
}
TEST_F(GetRelativePathTest, FileInRoot) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "C:\\Users\\Test\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
TEST_F(GetRelativePathTest, RootEqualsFullPath) {
std::string root = "C:\\Users\\Test\\file.txt";
std::string fullPath = "C:\\Users\\Test\\file.txt";
// 当 root 就是完整路径时,应该返回文件名
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
TEST_F(GetRelativePathTest, NoCommonPrefix_ReturnsFullPath) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "D:\\Other\\file.txt";
// 没有公共前缀时返回完整路径
EXPECT_EQ(GetRelativePath(root, fullPath), fullPath);
}
TEST_F(GetRelativePathTest, NestedPath) {
std::string root = "C:\\a\\";
std::string fullPath = "C:\\a\\b\\c\\d\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "b\\c\\d\\file.txt");
}
TEST_F(GetRelativePathTest, EmptyRoot) {
std::string root = "";
std::string fullPath = "C:\\Users\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), fullPath);
}
TEST_F(GetRelativePathTest, RootWithoutTrailingSlash) {
std::string root = "C:\\Users\\Test";
std::string fullPath = "C:\\Users\\Test\\file.txt";
// 没有尾部斜杠也应该工作
std::string result = GetRelativePath(root, fullPath);
EXPECT_TRUE(result == "\\file.txt" || result == fullPath);
}
TEST_F(GetRelativePathTest, DirectoryPath) {
std::string root = "C:\\Users\\";
std::string fullPath = "C:\\Users\\Test\\Documents\\";
EXPECT_EQ(GetRelativePath(root, fullPath), "Test\\Documents\\");
}
TEST_F(GetRelativePathTest, FileNameOnly) {
std::string root = "";
std::string fullPath = "file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
// ============================================
// 组合测试
// ============================================
class PathUtilsCombinedTest : public ::testing::Test {};
TEST_F(PathUtilsCombinedTest, GetCommonRootThenRelative) {
std::vector<std::string> files = {
"C:\\Project\\src\\main.cpp",
"C:\\Project\\src\\utils.cpp",
"C:\\Project\\include\\header.h"
};
std::string root = GetCommonRoot(files);
EXPECT_EQ(root, "C:\\Project\\");
// 获取各文件的相对路径
EXPECT_EQ(GetRelativePath(root, files[0]), "src\\main.cpp");
EXPECT_EQ(GetRelativePath(root, files[1]), "src\\utils.cpp");
EXPECT_EQ(GetRelativePath(root, files[2]), "include\\header.h");
}
TEST_F(PathUtilsCombinedTest, SimulateFileTransfer) {
// 模拟文件传输场景
std::vector<std::string> selectedFiles = {
"D:\\Downloads\\Project\\doc\\readme.md",
"D:\\Downloads\\Project\\src\\app.py",
"D:\\Downloads\\Project\\src\\lib\\util.py"
};
std::string rootDir = GetCommonRoot(selectedFiles);
EXPECT_EQ(rootDir, "D:\\Downloads\\Project\\");
std::string targetDir = "C:\\Backup\\";
for (const auto& file : selectedFiles) {
std::string relPath = GetRelativePath(rootDir, file);
std::string destPath = targetDir + relPath;
// 验证目标路径正确
EXPECT_TRUE(destPath.find("C:\\Backup\\") == 0);
}
}
// ============================================
// 边界条件测试
// ============================================
TEST(PathBoundaryTest, VeryLongPath) {
// 创建一个很长的路径
std::string basePath = "C:\\";
for (int i = 0; i < 50; i++) {
basePath += "folder" + std::to_string(i) + "\\";
}
basePath += "file.txt";
std::vector<std::string> files = {basePath};
std::string root = GetCommonRoot(files);
EXPECT_EQ(root, basePath);
}
TEST(PathBoundaryTest, SpecialCharacters) {
std::vector<std::string> files = {
"C:\\Users\\Test (1)\\file[1].txt",
"C:\\Users\\Test (1)\\file[2].txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test (1)\\");
}
TEST(PathBoundaryTest, UNCPath) {
std::vector<std::string> files = {
"\\\\Server\\Share\\folder\\file1.txt",
"\\\\Server\\Share\\folder\\file2.txt"
};
std::string root = GetCommonRoot(files);
EXPECT_TRUE(root.find("\\\\Server\\Share\\folder\\") != std::string::npos ||
root.find("\\\\Server\\Share\\") != std::string::npos);
}
TEST(PathBoundaryTest, ForwardSlashes) {
// 测试正斜杠(虽然 Windows 主要用反斜杠)
std::vector<std::string> files = {
"C:/Users/Test/file1.txt",
"C:/Users/Test/file2.txt"
};
// 函数使用反斜杠,正斜杠可能不被正确处理
std::string root = GetCommonRoot(files);
// 验证不会崩溃
EXPECT_FALSE(root.empty());
}
// ============================================
// 参数化测试
// ============================================
class GetRelativePathParameterizedTest
: public ::testing::TestWithParam<std::tuple<std::string, std::string, std::string>> {};
TEST_P(GetRelativePathParameterizedTest, VariousPaths) {
auto [root, fullPath, expected] = GetParam();
EXPECT_EQ(GetRelativePath(root, fullPath), expected);
}
INSTANTIATE_TEST_SUITE_P(
PathCases,
GetRelativePathParameterizedTest,
::testing::Values(
std::make_tuple("C:\\a\\", "C:\\a\\b.txt", "b.txt"),
std::make_tuple("C:\\a\\", "C:\\a\\b\\c.txt", "b\\c.txt"),
std::make_tuple("C:\\a\\b\\", "C:\\a\\b\\c\\d\\e.txt", "c\\d\\e.txt"),
std::make_tuple("", "file.txt", "file.txt"),
std::make_tuple("C:\\", "C:\\file.txt", "file.txt")
)
);

View File

@@ -0,0 +1,725 @@
// DiffAlgorithmTest.cpp - Phase 4: 差分算法单元测试
// 测试帧间差异计算和编码逻辑
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <algorithm>
#include <random>
// ============================================
// 算法常量定义 (来自 CursorInfo.h)
// ============================================
#define ALGORITHM_GRAY 0 // 灰度算法
#define ALGORITHM_DIFF 1 // 差分算法(默认)
#define ALGORITHM_H264 2 // H264 视频编码
#define ALGORITHM_RGB565 3 // RGB565 压缩
// ============================================
// 差分输出结构
// ============================================
struct DiffRegion {
uint32_t offset; // 字节偏移
uint32_t length; // 长度(含义取决于算法)
std::vector<uint8_t> data; // 差异数据
};
// ============================================
// 测试用差分算法实现
// 模拟 ScreenCapture.h 中的 CompareBitmapDXGI
// ============================================
class DiffAlgorithm {
public:
// 比较两帧,返回差异区域列表
// 输出格式: [offset:4][length:4][data:N]...
static std::vector<DiffRegion> CompareBitmap(
const uint8_t* srcData, // 新帧
const uint8_t* dstData, // 旧帧
size_t dataLength, // 数据长度字节必须是4的倍数
int algorithm // 压缩算法
) {
std::vector<DiffRegion> regions;
if (dataLength == 0 || dataLength % 4 != 0) {
return regions;
}
const uint32_t* src32 = reinterpret_cast<const uint32_t*>(srcData);
const uint32_t* dst32 = reinterpret_cast<const uint32_t*>(dstData);
size_t pixelCount = dataLength / 4;
size_t i = 0;
while (i < pixelCount) {
// 找到差异起始点
while (i < pixelCount && src32[i] == dst32[i]) {
i++;
}
if (i >= pixelCount) break;
// 记录起始位置
size_t startPos = i;
// 找到差异结束点
while (i < pixelCount && src32[i] != dst32[i]) {
i++;
}
// 创建差异区域
DiffRegion region;
region.offset = static_cast<uint32_t>(startPos * 4); // 字节偏移
size_t diffPixels = i - startPos;
const uint8_t* pixelStart = srcData + startPos * 4;
switch (algorithm) {
case ALGORITHM_GRAY: {
// 灰度: 1字节/像素
region.length = static_cast<uint32_t>(diffPixels); // 像素数
region.data.resize(diffPixels);
for (size_t p = 0; p < diffPixels; p++) {
const uint8_t* pixel = pixelStart + p * 4;
// BGRA格式: B=0, G=1, R=2, A=3
// 灰度公式: Y = 0.299*R + 0.587*G + 0.114*B
int gray = (306 * pixel[2] + 601 * pixel[1] + 117 * pixel[0]) >> 10;
region.data[p] = static_cast<uint8_t>(std::min(255, std::max(0, gray)));
}
break;
}
case ALGORITHM_RGB565: {
// RGB565: 2字节/像素
region.length = static_cast<uint32_t>(diffPixels); // 像素数
region.data.resize(diffPixels * 2);
uint16_t* out = reinterpret_cast<uint16_t*>(region.data.data());
for (size_t p = 0; p < diffPixels; p++) {
const uint8_t* pixel = pixelStart + p * 4;
// BGRA -> RGB565
out[p] = ((pixel[2] >> 3) << 11) | // R: 5位
((pixel[1] >> 2) << 5) | // G: 6位
(pixel[0] >> 3); // B: 5位
}
break;
}
case ALGORITHM_DIFF:
case ALGORITHM_H264:
default: {
// DIFF/H264: 4字节/像素原始BGRA
region.length = static_cast<uint32_t>(diffPixels * 4); // 字节数
region.data.resize(diffPixels * 4);
memcpy(region.data.data(), pixelStart, diffPixels * 4);
break;
}
}
regions.push_back(region);
}
return regions;
}
// 序列化差异区域到缓冲区
static size_t SerializeDiffRegions(
const std::vector<DiffRegion>& regions,
uint8_t* buffer,
size_t bufferSize
) {
size_t offset = 0;
for (const auto& region : regions) {
size_t needed = 8 + region.data.size(); // offset(4) + length(4) + data
if (offset + needed > bufferSize) break;
memcpy(buffer + offset, &region.offset, 4);
offset += 4;
memcpy(buffer + offset, &region.length, 4);
offset += 4;
memcpy(buffer + offset, region.data.data(), region.data.size());
offset += region.data.size();
}
return offset;
}
// 应用差异到目标帧
static void ApplyDiff(
uint8_t* dstData,
size_t dstLength,
const std::vector<DiffRegion>& regions,
int algorithm
) {
for (const auto& region : regions) {
if (region.offset >= dstLength) continue;
uint8_t* dst = dstData + region.offset;
switch (algorithm) {
case ALGORITHM_GRAY: {
// 灰度 -> BGRA
for (uint32_t p = 0; p < region.length && region.offset + p * 4 < dstLength; p++) {
uint8_t gray = region.data[p];
dst[p * 4 + 0] = gray; // B
dst[p * 4 + 1] = gray; // G
dst[p * 4 + 2] = gray; // R
dst[p * 4 + 3] = 0xFF; // A
}
break;
}
case ALGORITHM_RGB565: {
// RGB565 -> BGRA
const uint16_t* src = reinterpret_cast<const uint16_t*>(region.data.data());
for (uint32_t p = 0; p < region.length && region.offset + p * 4 < dstLength; p++) {
uint16_t c = src[p];
uint8_t r5 = (c >> 11) & 0x1F;
uint8_t g6 = (c >> 5) & 0x3F;
uint8_t b5 = c & 0x1F;
dst[p * 4 + 0] = (b5 << 3) | (b5 >> 2); // B
dst[p * 4 + 1] = (g6 << 2) | (g6 >> 4); // G
dst[p * 4 + 2] = (r5 << 3) | (r5 >> 2); // R
dst[p * 4 + 3] = 0xFF; // A
}
break;
}
case ALGORITHM_DIFF:
case ALGORITHM_H264:
default: {
// 原始BGRA
size_t copyLen = std::min(static_cast<size_t>(region.length),
dstLength - region.offset);
memcpy(dst, region.data.data(), copyLen);
break;
}
}
}
}
};
// ============================================
// 测试夹具
// ============================================
class DiffAlgorithmTest : public ::testing::Test {
protected:
// 创建纯色帧 (BGRA格式)
static std::vector<uint8_t> CreateSolidFrame(int width, int height,
uint8_t b, uint8_t g,
uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> frame(width * height * 4);
for (int i = 0; i < width * height; i++) {
frame[i * 4 + 0] = b;
frame[i * 4 + 1] = g;
frame[i * 4 + 2] = r;
frame[i * 4 + 3] = a;
}
return frame;
}
// 创建渐变帧
static std::vector<uint8_t> CreateGradientFrame(int width, int height) {
std::vector<uint8_t> frame(width * height * 4);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = static_cast<uint8_t>(x * 255 / width); // B
frame[idx + 1] = static_cast<uint8_t>(y * 255 / height); // G
frame[idx + 2] = static_cast<uint8_t>((x + y) * 128 / (width + height)); // R
frame[idx + 3] = 0xFF;
}
}
return frame;
}
// 创建带随机区域变化的帧
static std::vector<uint8_t> CreateFrameWithChanges(
const std::vector<uint8_t>& baseFrame,
int width, int height,
int changeX, int changeY,
int changeW, int changeH,
uint8_t newB, uint8_t newG, uint8_t newR
) {
std::vector<uint8_t> frame = baseFrame;
for (int y = changeY; y < changeY + changeH && y < height; y++) {
for (int x = changeX; x < changeX + changeW && x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = newB;
frame[idx + 1] = newG;
frame[idx + 2] = newR;
frame[idx + 3] = 0xFF;
}
}
return frame;
}
};
// ============================================
// 基础功能测试
// ============================================
TEST_F(DiffAlgorithmTest, IdenticalFrames_NoDifference) {
auto frame = CreateSolidFrame(100, 100, 128, 128, 128);
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), frame.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, CompletelyDifferent_SingleRegion) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 100u * 4); // 100像素 * 4字节
}
TEST_F(DiffAlgorithmTest, PartialChange_SingleRegion) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 0, 0, 0);
auto frame2 = CreateFrameWithChanges(frame1, WIDTH, HEIGHT,
10, 10, 20, 20, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应该检测到变化区域
EXPECT_GT(regions.size(), 0u);
// 验证总变化像素数
size_t totalChangedPixels = 0;
for (const auto& r : regions) {
totalChangedPixels += r.length / 4; // DIFF算法length是字节数
}
EXPECT_EQ(totalChangedPixels, 20u * 20u); // 20x20区域
}
TEST_F(DiffAlgorithmTest, MultipleRegions_NonContiguous) {
const int WIDTH = 100, HEIGHT = 10;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 128, 128, 128);
auto frame2 = frame1;
// 创建两个不相邻的变化区域
// 区域1: 像素 5-14
for (int i = 5; i < 15; i++) {
frame2[i * 4 + 0] = 0;
frame2[i * 4 + 1] = 0;
frame2[i * 4 + 2] = 255;
}
// 区域2: 像素 50-59 (与区域1不相邻)
for (int i = 50; i < 60; i++) {
frame2[i * 4 + 0] = 255;
frame2[i * 4 + 1] = 0;
frame2[i * 4 + 2] = 0;
}
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 2u);
}
// ============================================
// 算法特定测试
// ============================================
TEST_F(DiffAlgorithmTest, GrayAlgorithm_CorrectOutput) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0); // 黑色
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255); // 白色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 100u); // 100像素
EXPECT_EQ(regions[0].data.size(), 100u); // 1字节/像素
// 白色应该转换为灰度255
EXPECT_EQ(regions[0].data[0], 255);
}
TEST_F(DiffAlgorithmTest, GrayAlgorithm_GrayConversionFormula) {
// 测试灰度转换公式: Y = 0.299*R + 0.587*G + 0.114*B
std::vector<uint8_t> frame1(4, 0); // 1像素黑色
std::vector<uint8_t> frame2(4);
// 测试纯红色 (R=255, G=0, B=0)
frame2[0] = 0; // B
frame2[1] = 0; // G
frame2[2] = 255; // R
frame2[3] = 255; // A
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_GRAY);
ASSERT_EQ(regions.size(), 1u);
// 期望: (306 * 255 + 601 * 0 + 117 * 0) >> 10 ≈ 76
uint8_t expectedGray = (306 * 255) >> 10;
EXPECT_NEAR(regions[0].data[0], expectedGray, 1);
}
TEST_F(DiffAlgorithmTest, RGB565Algorithm_CorrectOutput) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 100u); // 100像素
EXPECT_EQ(regions[0].data.size(), 200u); // 2字节/像素
// 白色 RGB565 = 0xFFFF
uint16_t* rgb565 = reinterpret_cast<uint16_t*>(regions[0].data.data());
EXPECT_EQ(rgb565[0], 0xFFFF);
}
TEST_F(DiffAlgorithmTest, RGB565Algorithm_ColorConversion) {
std::vector<uint8_t> frame1(4, 0); // 1像素黑色
std::vector<uint8_t> frame2(4);
// 纯红色 (R=255, G=0, B=0) -> RGB565 = 0xF800
frame2[0] = 0; // B
frame2[1] = 0; // G
frame2[2] = 255; // R
frame2[3] = 255; // A
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_RGB565);
ASSERT_EQ(regions.size(), 1u);
uint16_t* rgb565 = reinterpret_cast<uint16_t*>(regions[0].data.data());
EXPECT_EQ(rgb565[0], 0xF800); // 纯红色
}
TEST_F(DiffAlgorithmTest, DiffAlgorithm_PreservesOriginalData) {
std::vector<uint8_t> frame1(8, 0); // 2像素黑色
std::vector<uint8_t> frame2 = {
0x12, 0x34, 0x56, 0x78, // 像素1
0xAB, 0xCD, 0xEF, 0xFF // 像素2
};
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 8, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 8u); // 8字节
EXPECT_EQ(regions[0].data, frame2); // 原始数据完整保留
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(DiffAlgorithmTest, EmptyInput_NoRegions) {
auto regions = DiffAlgorithm::CompareBitmap(nullptr, nullptr, 0, ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, SinglePixel_Difference) {
std::vector<uint8_t> frame1 = {0, 0, 0, 255};
std::vector<uint8_t> frame2 = {255, 255, 255, 255};
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 4u);
}
TEST_F(DiffAlgorithmTest, SinglePixel_NoDifference) {
std::vector<uint8_t> frame = {100, 150, 200, 255};
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), 4, ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, NonAlignedLength_Rejected) {
std::vector<uint8_t> data(7); // 不是4的倍数
auto regions = DiffAlgorithm::CompareBitmap(
data.data(), data.data(), data.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, FirstPixelOnly_Changed) {
std::vector<uint8_t> frame1(40, 128); // 10像素
std::vector<uint8_t> frame2 = frame1;
frame2[0] = 0; // 只改变第一个像素的B通道
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 40, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 4u); // 只有1个像素
}
TEST_F(DiffAlgorithmTest, LastPixelOnly_Changed) {
std::vector<uint8_t> frame1(40, 128); // 10像素
std::vector<uint8_t> frame2 = frame1;
frame2[36] = 0; // 只改变最后一个像素的B通道
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 40, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 36u); // 第9个像素的偏移
EXPECT_EQ(regions[0].length, 4u);
}
// ============================================
// 序列化测试
// ============================================
TEST_F(DiffAlgorithmTest, Serialize_SingleRegion) {
std::vector<DiffRegion> regions;
DiffRegion r;
r.offset = 100;
r.length = 16;
r.data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
regions.push_back(r);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 8u + 16u); // offset(4) + length(4) + data(16)
// 验证偏移
uint32_t readOffset;
memcpy(&readOffset, buffer.data(), 4);
EXPECT_EQ(readOffset, 100u);
// 验证长度
uint32_t readLength;
memcpy(&readLength, buffer.data() + 4, 4);
EXPECT_EQ(readLength, 16u);
}
TEST_F(DiffAlgorithmTest, Serialize_MultipleRegions) {
std::vector<DiffRegion> regions;
DiffRegion r1;
r1.offset = 0;
r1.length = 4;
r1.data = {1, 2, 3, 4};
regions.push_back(r1);
DiffRegion r2;
r2.offset = 100;
r2.length = 8;
r2.data = {5, 6, 7, 8, 9, 10, 11, 12};
regions.push_back(r2);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, (8u + 4u) + (8u + 8u)); // 两个区域
}
TEST_F(DiffAlgorithmTest, Serialize_BufferTooSmall) {
std::vector<DiffRegion> regions;
DiffRegion r;
r.offset = 0;
r.length = 100;
r.data.resize(100, 0xFF);
regions.push_back(r);
std::vector<uint8_t> buffer(10); // 太小
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 0u); // 无法写入
}
// ============================================
// 应用差异测试
// ============================================
TEST_F(DiffAlgorithmTest, ApplyDiff_DIFF_Reconstructs) {
auto frame1 = CreateGradientFrame(50, 50);
auto frame2 = CreateFrameWithChanges(frame1, 50, 50, 10, 10, 20, 20, 255, 0, 0);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应用差异到frame1的副本
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_DIFF);
// 应该完全重建frame2
EXPECT_EQ(reconstructed, frame2);
}
TEST_F(DiffAlgorithmTest, ApplyDiff_GRAY_ApproximateReconstruction) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 200, 200, 200); // 浅灰色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_GRAY);
// 灰度重建R=G=B
for (size_t i = 0; i < reconstructed.size(); i += 4) {
EXPECT_EQ(reconstructed[i], reconstructed[i + 1]); // B == G
EXPECT_EQ(reconstructed[i + 1], reconstructed[i + 2]); // G == R
}
}
TEST_F(DiffAlgorithmTest, ApplyDiff_RGB565_ApproximateReconstruction) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
// RGB565有量化误差使用能精确表示的颜色
auto frame2 = CreateSolidFrame(10, 10, 248, 252, 248); // 接近白色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_RGB565);
// 验证重建颜色接近原始(允许量化误差)
for (size_t i = 0; i < reconstructed.size(); i += 4) {
EXPECT_NEAR(reconstructed[i], frame2[i], 8); // B
EXPECT_NEAR(reconstructed[i + 1], frame2[i + 1], 4); // G (6位精度)
EXPECT_NEAR(reconstructed[i + 2], frame2[i + 2], 8); // R
}
}
// ============================================
// 性能相关测试(不计时,只验证正确性)
// ============================================
TEST_F(DiffAlgorithmTest, LargeFrame_1080p_Correctness) {
const int WIDTH = 1920, HEIGHT = 1080;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 128, 128, 128);
auto frame2 = CreateFrameWithChanges(frame1, WIDTH, HEIGHT,
100, 100, 200, 200, 255, 0, 0);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应该有变化区域
EXPECT_GT(regions.size(), 0u);
// 重建验证
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_DIFF);
EXPECT_EQ(reconstructed, frame2);
}
TEST_F(DiffAlgorithmTest, RandomChanges_AllAlgorithms) {
const int WIDTH = 100, HEIGHT = 100;
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
auto frame1 = CreateGradientFrame(WIDTH, HEIGHT);
auto frame2 = frame1;
// 随机修改10%的像素
for (int i = 0; i < WIDTH * HEIGHT / 10; i++) {
int idx = (rng() % (WIDTH * HEIGHT)) * 4;
frame2[idx + 0] = dist(rng);
frame2[idx + 1] = dist(rng);
frame2[idx + 2] = dist(rng);
}
// 测试所有算法都能产生输出
for (int algo : {ALGORITHM_GRAY, ALGORITHM_DIFF, ALGORITHM_RGB565}) {
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), algo);
EXPECT_GT(regions.size(), 0u) << "Algorithm " << algo << " failed";
}
}
// ============================================
// 压缩效率测试
// ============================================
TEST_F(DiffAlgorithmTest, CompressionRatio_GRAY) {
auto frame1 = CreateSolidFrame(100, 100, 0, 0, 0);
auto frame2 = CreateSolidFrame(100, 100, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
size_t originalSize = 100 * 100 * 4; // 40000 字节
size_t compressedSize = 8 + regions[0].data.size(); // offset + length + data
// GRAY应该是 100*100*1 = 10000 字节数据
EXPECT_EQ(regions[0].data.size(), 10000u);
// 压缩比约 4:1
EXPECT_LT(compressedSize, originalSize / 3);
}
TEST_F(DiffAlgorithmTest, CompressionRatio_RGB565) {
auto frame1 = CreateSolidFrame(100, 100, 0, 0, 0);
auto frame2 = CreateSolidFrame(100, 100, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
// RGB565应该是 100*100*2 = 20000 字节数据
EXPECT_EQ(regions[0].data.size(), 20000u);
}
TEST_F(DiffAlgorithmTest, NoChange_ZeroOverhead) {
auto frame = CreateGradientFrame(100, 100);
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), frame.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 0u); // 无变化时零开销
}
// ============================================
// 参数化测试:不同分辨率
// ============================================
class DiffAlgorithmResolutionTest : public ::testing::TestWithParam<std::tuple<int, int>> {};
TEST_P(DiffAlgorithmResolutionTest, Resolution_Correctness) {
auto [width, height] = GetParam();
auto frame1 = std::vector<uint8_t>(width * height * 4, 0);
auto frame2 = std::vector<uint8_t>(width * height * 4, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, static_cast<uint32_t>(width * height * 4));
}
INSTANTIATE_TEST_SUITE_P(
Resolutions,
DiffAlgorithmResolutionTest,
::testing::Values(
std::make_tuple(1, 1), // 最小
std::make_tuple(10, 10), // 小
std::make_tuple(100, 100), // 中
std::make_tuple(640, 480), // VGA
std::make_tuple(1280, 720), // 720p
std::make_tuple(1920, 1080) // 1080p
)
);

View File

@@ -0,0 +1,691 @@
// QualityAdaptiveTest.cpp - Phase 4: 质量自适应单元测试
// 测试 RTT 到质量等级的映射和防抖策略
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <climits>
// ============================================
// 质量等级枚举 (来自 commands.h)
// ============================================
enum QualityLevel {
QUALITY_DISABLED = -2, // 关闭质量控制
QUALITY_ADAPTIVE = -1, // 自适应模式
QUALITY_ULTRA = 0, // 极佳(局域网)
QUALITY_HIGH = 1, // 优秀
QUALITY_GOOD = 2, // 良好
QUALITY_MEDIUM = 3, // 一般
QUALITY_LOW = 4, // 较差
QUALITY_MINIMAL = 5, // 最低
QUALITY_COUNT = 6,
};
// ============================================
// 算法枚举
// ============================================
#define ALGORITHM_GRAY 0
#define ALGORITHM_DIFF 1
#define ALGORITHM_H264 2
#define ALGORITHM_RGB565 3
// ============================================
// 质量配置结构体 (来自 commands.h)
// ============================================
struct QualityProfile {
int maxFPS; // 最大帧率
int maxWidth; // 最大宽度 (0=不限)
int algorithm; // 压缩算法
int bitRate; // kbps (仅H264使用)
};
// 默认质量配置表
static const QualityProfile g_QualityProfiles[QUALITY_COUNT] = {
{25, 0, ALGORITHM_DIFF, 0 }, // Ultra: 25FPS, 原始, DIFF
{20, 0, ALGORITHM_RGB565, 0 }, // High: 20FPS, 原始, RGB565
{20, 1920, ALGORITHM_H264, 3000}, // Good: 20FPS, 1080P, H264
{15, 1600, ALGORITHM_H264, 2000}, // Medium: 15FPS, 900P, H264
{12, 1280, ALGORITHM_H264, 1200}, // Low: 12FPS, 720P, H264
{8, 1024, ALGORITHM_H264, 800 }, // Minimal: 8FPS, 540P, H264
};
// ============================================
// RTT 阈值表 (来自 commands.h)
// ============================================
// 行0: 直连模式, 行1: FRP代理模式
static const int g_RttThresholds[2][QUALITY_COUNT] = {
/* DIRECT */ { 30, 80, 150, 250, 400, INT_MAX },
/* PROXY */ { 60, 160, 300, 500, 800, INT_MAX },
};
// ============================================
// RTT到质量等级映射函数 (来自 commands.h)
// ============================================
inline int GetTargetQualityLevel(int rtt, int usingFRP)
{
int row = usingFRP ? 1 : 0;
for (int level = 0; level < QUALITY_COUNT; level++) {
if (rtt < g_RttThresholds[row][level]) {
return level;
}
}
return QUALITY_MINIMAL;
}
// ============================================
// 防抖策略模拟器
// ============================================
class QualityDebouncer {
public:
// 防抖参数
static const int DOWNGRADE_STABLE_COUNT = 2; // 降级所需稳定次数
static const int UPGRADE_STABLE_COUNT = 5; // 升级所需稳定次数
static const int DEFAULT_COOLDOWN_MS = 3000; // 默认冷却时间
static const int RES_CHANGE_DOWNGRADE_COOLDOWN_MS = 15000; // 分辨率降级冷却
static const int RES_CHANGE_UPGRADE_COOLDOWN_MS = 30000; // 分辨率升级冷却
static const int STARTUP_DELAY_MS = 60000; // 启动延迟
QualityDebouncer()
: m_currentLevel(QUALITY_HIGH)
, m_stableCount(0)
, m_lastChangeTime(0)
, m_startTime(0)
, m_enabled(true)
{}
void Reset() {
m_currentLevel = QUALITY_HIGH;
m_stableCount = 0;
m_lastChangeTime = 0;
m_startTime = 0;
}
void SetStartTime(uint64_t time) {
m_startTime = time;
}
void SetEnabled(bool enabled) {
m_enabled = enabled;
}
// 评估并返回新的质量等级
// 返回 -1 表示不改变
int Evaluate(int targetLevel, uint64_t currentTime, bool resolutionChange = false) {
if (!m_enabled) return -1;
// 启动延迟
if (currentTime - m_startTime < STARTUP_DELAY_MS) {
return -1;
}
// 冷却时间检查
uint64_t cooldown = DEFAULT_COOLDOWN_MS;
if (resolutionChange) {
cooldown = (targetLevel > m_currentLevel)
? RES_CHANGE_DOWNGRADE_COOLDOWN_MS
: RES_CHANGE_UPGRADE_COOLDOWN_MS;
}
if (currentTime - m_lastChangeTime < cooldown) {
return -1;
}
// 降级: 快速响应
if (targetLevel > m_currentLevel) {
m_stableCount++;
if (m_stableCount >= DOWNGRADE_STABLE_COUNT) {
int newLevel = targetLevel;
m_currentLevel = newLevel;
m_stableCount = 0;
m_lastChangeTime = currentTime;
return newLevel;
}
}
// 升级: 谨慎处理,每次只升一级
else if (targetLevel < m_currentLevel) {
m_stableCount++;
if (m_stableCount >= UPGRADE_STABLE_COUNT) {
int newLevel = m_currentLevel - 1; // 只升一级
m_currentLevel = newLevel;
m_stableCount = 0;
m_lastChangeTime = currentTime;
return newLevel;
}
}
// 目标等级等于当前等级
else {
m_stableCount = 0;
}
return -1;
}
int GetCurrentLevel() const { return m_currentLevel; }
int GetStableCount() const { return m_stableCount; }
private:
int m_currentLevel;
int m_stableCount;
uint64_t m_lastChangeTime;
uint64_t m_startTime;
bool m_enabled;
};
// ============================================
// 测试夹具
// ============================================
class QualityAdaptiveTest : public ::testing::Test {
protected:
void SetUp() override {
debouncer.Reset();
}
QualityDebouncer debouncer;
};
// ============================================
// RTT 映射测试 - 直连模式
// ============================================
TEST_F(QualityAdaptiveTest, RTT_Direct_Ultra) {
// RTT < 30ms -> Ultra
EXPECT_EQ(GetTargetQualityLevel(0, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(10, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(29, 0), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_High) {
// 30ms <= RTT < 80ms -> High
EXPECT_EQ(GetTargetQualityLevel(30, 0), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(50, 0), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(79, 0), QUALITY_HIGH);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Good) {
// 80ms <= RTT < 150ms -> Good
EXPECT_EQ(GetTargetQualityLevel(80, 0), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(100, 0), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(149, 0), QUALITY_GOOD);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Medium) {
// 150ms <= RTT < 250ms -> Medium
EXPECT_EQ(GetTargetQualityLevel(150, 0), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(200, 0), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(249, 0), QUALITY_MEDIUM);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Low) {
// 250ms <= RTT < 400ms -> Low
EXPECT_EQ(GetTargetQualityLevel(250, 0), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(300, 0), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(399, 0), QUALITY_LOW);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Minimal) {
// RTT >= 400ms -> Minimal
EXPECT_EQ(GetTargetQualityLevel(400, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(500, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(1000, 0), QUALITY_MINIMAL);
}
// ============================================
// RTT 映射测试 - FRP代理模式
// ============================================
TEST_F(QualityAdaptiveTest, RTT_FRP_Ultra) {
// RTT < 60ms -> Ultra (FRP模式阈值更宽松)
EXPECT_EQ(GetTargetQualityLevel(0, 1), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(30, 1), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(59, 1), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_High) {
// 60ms <= RTT < 160ms -> High
EXPECT_EQ(GetTargetQualityLevel(60, 1), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(100, 1), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(159, 1), QUALITY_HIGH);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Good) {
// 160ms <= RTT < 300ms -> Good
EXPECT_EQ(GetTargetQualityLevel(160, 1), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(200, 1), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(299, 1), QUALITY_GOOD);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Medium) {
// 300ms <= RTT < 500ms -> Medium
EXPECT_EQ(GetTargetQualityLevel(300, 1), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(400, 1), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(499, 1), QUALITY_MEDIUM);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Low) {
// 500ms <= RTT < 800ms -> Low
EXPECT_EQ(GetTargetQualityLevel(500, 1), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(600, 1), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(799, 1), QUALITY_LOW);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Minimal) {
// RTT >= 800ms -> Minimal
EXPECT_EQ(GetTargetQualityLevel(800, 1), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(1000, 1), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(2000, 1), QUALITY_MINIMAL);
}
// ============================================
// 质量配置表测试
// ============================================
TEST_F(QualityAdaptiveTest, Profile_Ultra) {
const auto& p = g_QualityProfiles[QUALITY_ULTRA];
EXPECT_EQ(p.maxFPS, 25);
EXPECT_EQ(p.maxWidth, 0); // 无限制
EXPECT_EQ(p.algorithm, ALGORITHM_DIFF);
EXPECT_EQ(p.bitRate, 0);
}
TEST_F(QualityAdaptiveTest, Profile_High) {
const auto& p = g_QualityProfiles[QUALITY_HIGH];
EXPECT_EQ(p.maxFPS, 20);
EXPECT_EQ(p.maxWidth, 0); // 无限制
EXPECT_EQ(p.algorithm, ALGORITHM_RGB565);
EXPECT_EQ(p.bitRate, 0);
}
TEST_F(QualityAdaptiveTest, Profile_Good) {
const auto& p = g_QualityProfiles[QUALITY_GOOD];
EXPECT_EQ(p.maxFPS, 20);
EXPECT_EQ(p.maxWidth, 1920); // 1080p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 3000);
}
TEST_F(QualityAdaptiveTest, Profile_Medium) {
const auto& p = g_QualityProfiles[QUALITY_MEDIUM];
EXPECT_EQ(p.maxFPS, 15);
EXPECT_EQ(p.maxWidth, 1600); // 900p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 2000);
}
TEST_F(QualityAdaptiveTest, Profile_Low) {
const auto& p = g_QualityProfiles[QUALITY_LOW];
EXPECT_EQ(p.maxFPS, 12);
EXPECT_EQ(p.maxWidth, 1280); // 720p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 1200);
}
TEST_F(QualityAdaptiveTest, Profile_Minimal) {
const auto& p = g_QualityProfiles[QUALITY_MINIMAL];
EXPECT_EQ(p.maxFPS, 8);
EXPECT_EQ(p.maxWidth, 1024); // 540p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 800);
}
// ============================================
// 防抖策略测试 - 启动延迟
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_StartupDelay) {
debouncer.SetStartTime(0);
// 启动后60秒内不应改变
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 1000), -1);
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 30000), -1);
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 59999), -1);
}
TEST_F(QualityAdaptiveTest, Debounce_AfterStartupDelay) {
debouncer.SetStartTime(0);
// 启动60秒后应该可以改变
// 需要连续2次降级请求
debouncer.Evaluate(QUALITY_MINIMAL, 60000); // 第1次
int result = debouncer.Evaluate(QUALITY_MINIMAL, 60001); // 第2次
EXPECT_EQ(result, QUALITY_MINIMAL);
}
// ============================================
// 防抖策略测试 - 降级
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Downgrade_RequiresTwice) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 第1次降级请求 - 不应立即执行
int result = debouncer.Evaluate(QUALITY_MINIMAL, time);
EXPECT_EQ(result, -1);
EXPECT_EQ(debouncer.GetStableCount(), 1);
// 第2次降级请求 - 应该执行
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, QUALITY_MINIMAL);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
}
TEST_F(QualityAdaptiveTest, Debounce_Downgrade_ResetOnStable) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 第1次降级请求
debouncer.Evaluate(QUALITY_LOW, time);
EXPECT_EQ(debouncer.GetStableCount(), 1);
// 目标等级恢复 - 计数应重置
debouncer.Evaluate(QUALITY_HIGH, time + 100); // 当前等级
EXPECT_EQ(debouncer.GetStableCount(), 0);
}
// ============================================
// 防抖策略测试 - 升级
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Upgrade_RequiresFiveTimes) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 先降级到 MINIMAL
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 尝试升级到 ULTRA (需要5次)
time += 5000; // 冷却后
for (int i = 0; i < 4; i++) {
int result = debouncer.Evaluate(QUALITY_ULTRA, time + i * 100);
EXPECT_EQ(result, -1); // 前4次不应执行
EXPECT_EQ(debouncer.GetStableCount(), i + 1);
}
// 第5次应该执行但只升一级
int result = debouncer.Evaluate(QUALITY_ULTRA, time + 500);
EXPECT_EQ(result, QUALITY_LOW); // MINIMAL -> LOW (只升一级)
}
TEST_F(QualityAdaptiveTest, Debounce_Upgrade_OneStepAtATime) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 降级到 MINIMAL
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 多次升级请求,验证每次只升一级
// 从 MINIMAL(5) 升到 HIGH(1) 需要4次升级每次需要5个稳定请求
time += 5000;
int upgradeCount = 0;
while (debouncer.GetCurrentLevel() > QUALITY_HIGH && upgradeCount < 10) {
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time); // 请求最高等级
time += 100;
}
time += 5000; // 冷却
upgradeCount++;
}
// 最终应该回到 HIGH (或更高)
EXPECT_LE(debouncer.GetCurrentLevel(), QUALITY_HIGH);
}
// ============================================
// 防抖策略测试 - 冷却时间
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_DefaultCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次降级
debouncer.Evaluate(QUALITY_LOW, time);
debouncer.Evaluate(QUALITY_LOW, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW);
// 冷却期内不应再次改变
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 200);
EXPECT_EQ(result, -1);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 2999);
EXPECT_EQ(result, -1);
}
TEST_F(QualityAdaptiveTest, Debounce_AfterCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次降级
debouncer.Evaluate(QUALITY_LOW, time);
debouncer.Evaluate(QUALITY_LOW, time + 100);
// 冷却后应该可以再次改变
time += 3100;
debouncer.Evaluate(QUALITY_MINIMAL, time);
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, QUALITY_MINIMAL);
}
// ============================================
// 防抖策略测试 - 分辨率变化冷却
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_ResolutionChange_DowngradeCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次带分辨率变化的降级
debouncer.Evaluate(QUALITY_LOW, time, true); // resolutionChange=true
debouncer.Evaluate(QUALITY_LOW, time + 100, true);
// 分辨率降级冷却15秒
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 14000, true);
EXPECT_EQ(result, -1);
// 15秒后可以
time += 16000;
debouncer.Evaluate(QUALITY_MINIMAL, time, true);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100, true);
EXPECT_EQ(result, QUALITY_MINIMAL);
}
TEST_F(QualityAdaptiveTest, Debounce_ResolutionChange_UpgradeCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 先降级到 MINIMAL (不带分辨率变化,使用默认冷却)
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 等待足够长时间(>30秒进行带分辨率变化的升级
time += 35000; // 超过30秒冷却
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time + i * 100, true); // resolutionChange=true
}
// 应该成功升级一级 (MINIMAL -> LOW)
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW);
// 30秒内不应再次升级带分辨率变化的升级需要30秒冷却
time += 1000;
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time + i * 100, true);
}
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW); // 未改变还在30秒冷却期内
}
// ============================================
// 禁用自适应测试
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Disabled) {
debouncer.SetStartTime(0);
debouncer.SetEnabled(false);
uint64_t time = 60000;
int result = debouncer.Evaluate(QUALITY_MINIMAL, time);
EXPECT_EQ(result, -1);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, -1);
// 质量等级应保持不变
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_HIGH);
}
// ============================================
// 边界值测试
// ============================================
TEST_F(QualityAdaptiveTest, RTT_Boundary_Zero) {
EXPECT_EQ(GetTargetQualityLevel(0, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(0, 1), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Boundary_Negative) {
// 负值RTT应该仍返回最高质量
EXPECT_EQ(GetTargetQualityLevel(-1, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(-100, 0), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Boundary_VeryHigh) {
// 非常高的RTT
EXPECT_EQ(GetTargetQualityLevel(10000, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(100000, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(INT_MAX - 1, 0), QUALITY_MINIMAL);
}
// ============================================
// 常量验证测试
// ============================================
TEST(QualityConstantsTest, QualityLevelEnum) {
EXPECT_EQ(QUALITY_DISABLED, -2);
EXPECT_EQ(QUALITY_ADAPTIVE, -1);
EXPECT_EQ(QUALITY_ULTRA, 0);
EXPECT_EQ(QUALITY_HIGH, 1);
EXPECT_EQ(QUALITY_GOOD, 2);
EXPECT_EQ(QUALITY_MEDIUM, 3);
EXPECT_EQ(QUALITY_LOW, 4);
EXPECT_EQ(QUALITY_MINIMAL, 5);
EXPECT_EQ(QUALITY_COUNT, 6);
}
TEST(QualityConstantsTest, ProfileCount) {
// 应该有 QUALITY_COUNT 个配置
EXPECT_EQ(sizeof(g_QualityProfiles) / sizeof(g_QualityProfiles[0]),
static_cast<size_t>(QUALITY_COUNT));
}
TEST(QualityConstantsTest, ThresholdCount) {
// 每行应该有 QUALITY_COUNT 个阈值
EXPECT_EQ(sizeof(g_RttThresholds[0]) / sizeof(g_RttThresholds[0][0]),
static_cast<size_t>(QUALITY_COUNT));
}
TEST(QualityConstantsTest, ThresholdIncreasing) {
// 阈值应该递增
for (int row = 0; row < 2; row++) {
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_LT(g_RttThresholds[row][i], g_RttThresholds[row][i + 1])
<< "Row " << row << ", index " << i;
}
}
}
TEST(QualityConstantsTest, FRPThresholdsHigher) {
// FRP模式阈值应该比直连模式高
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_GT(g_RttThresholds[1][i], g_RttThresholds[0][i])
<< "Index " << i;
}
}
// ============================================
// 参数化测试RTT值遍历
// ============================================
class RTTMappingTest : public ::testing::TestWithParam<std::tuple<int, int, int>> {};
TEST_P(RTTMappingTest, RTT_Mapping) {
auto [rtt, usingFRP, expectedLevel] = GetParam();
EXPECT_EQ(GetTargetQualityLevel(rtt, usingFRP), expectedLevel);
}
INSTANTIATE_TEST_SUITE_P(
DirectMode,
RTTMappingTest,
::testing::Values(
std::make_tuple(0, 0, QUALITY_ULTRA),
std::make_tuple(29, 0, QUALITY_ULTRA),
std::make_tuple(30, 0, QUALITY_HIGH),
std::make_tuple(79, 0, QUALITY_HIGH),
std::make_tuple(80, 0, QUALITY_GOOD),
std::make_tuple(149, 0, QUALITY_GOOD),
std::make_tuple(150, 0, QUALITY_MEDIUM),
std::make_tuple(249, 0, QUALITY_MEDIUM),
std::make_tuple(250, 0, QUALITY_LOW),
std::make_tuple(399, 0, QUALITY_LOW),
std::make_tuple(400, 0, QUALITY_MINIMAL),
std::make_tuple(1000, 0, QUALITY_MINIMAL)
)
);
INSTANTIATE_TEST_SUITE_P(
FRPMode,
RTTMappingTest,
::testing::Values(
std::make_tuple(0, 1, QUALITY_ULTRA),
std::make_tuple(59, 1, QUALITY_ULTRA),
std::make_tuple(60, 1, QUALITY_HIGH),
std::make_tuple(159, 1, QUALITY_HIGH),
std::make_tuple(160, 1, QUALITY_GOOD),
std::make_tuple(299, 1, QUALITY_GOOD),
std::make_tuple(300, 1, QUALITY_MEDIUM),
std::make_tuple(499, 1, QUALITY_MEDIUM),
std::make_tuple(500, 1, QUALITY_LOW),
std::make_tuple(799, 1, QUALITY_LOW),
std::make_tuple(800, 1, QUALITY_MINIMAL),
std::make_tuple(2000, 1, QUALITY_MINIMAL)
)
);
// ============================================
// 质量配置合理性测试
// ============================================
TEST(QualityProfileTest, FPSDecreasing) {
// FPS 应该随质量降低而减少
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_GE(g_QualityProfiles[i].maxFPS, g_QualityProfiles[i + 1].maxFPS)
<< "Level " << i;
}
}
TEST(QualityProfileTest, MaxWidthDecreasing) {
// maxWidth 应该随质量降低而减少除了0表示不限
int prevWidth = INT_MAX;
for (int i = 0; i < QUALITY_COUNT; i++) {
int width = g_QualityProfiles[i].maxWidth;
if (width > 0) {
EXPECT_LE(width, prevWidth) << "Level " << i;
prevWidth = width;
}
}
}
TEST(QualityProfileTest, BitRateDecreasing) {
// H264 bitRate 应该随质量降低而减少
int prevBitRate = INT_MAX;
for (int i = 0; i < QUALITY_COUNT; i++) {
if (g_QualityProfiles[i].algorithm == ALGORITHM_H264) {
int bitRate = g_QualityProfiles[i].bitRate;
EXPECT_LE(bitRate, prevBitRate) << "Level " << i;
prevBitRate = bitRate;
}
}
}

View File

@@ -0,0 +1,597 @@
// RGB565Test.cpp - Phase 4: RGB565压缩单元测试
// 测试 BGRA <-> RGB565 颜色空间转换
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <random>
#include <cmath>
// ============================================
// RGB565 颜色空间转换实现
// 来源: client/ScreenCapture.h, server/ScreenSpyDlg.cpp
// ============================================
// RGB565 格式说明:
// 16位: RRRRRGGG GGGBBBBB
// R: 5位 (0-31)
// G: 6位 (0-63)
// B: 5位 (0-31)
class RGB565Converter {
public:
// ============================================
// BGRA -> RGB565 转换 (标量版本)
// ============================================
static void ConvertBGRAtoRGB565_Scalar(
const uint8_t* src,
uint16_t* dst,
size_t pixelCount
) {
for (size_t i = 0; i < pixelCount; i++) {
// BGRA 格式: B=0, G=1, R=2, A=3
uint8_t b = src[i * 4 + 0];
uint8_t g = src[i * 4 + 1];
uint8_t r = src[i * 4 + 2];
// A 通道被忽略
// RGB565: RRRRRGGG GGGBBBBB
dst[i] = static_cast<uint16_t>(
((r >> 3) << 11) | // R: 高5位 -> 位11-15
((g >> 2) << 5) | // G: 高6位 -> 位5-10
(b >> 3) // B: 高5位 -> 位0-4
);
}
}
// ============================================
// RGB565 -> BGRA 转换
// 位复制填充低位以提高精度
// ============================================
static void ConvertRGB565ToBGRA(
const uint16_t* src,
uint8_t* dst,
size_t pixelCount
) {
for (size_t i = 0; i < pixelCount; i++) {
uint16_t c = src[i];
// 提取各通道
uint8_t r5 = (c >> 11) & 0x1F; // 5位红色
uint8_t g6 = (c >> 5) & 0x3F; // 6位绿色
uint8_t b5 = c & 0x1F; // 5位蓝色
// 扩展到8位使用位复制填充低位
// 例如: 5位 11111 -> 8位 11111111 (通过 (x << 3) | (x >> 2))
dst[i * 4 + 0] = (b5 << 3) | (b5 >> 2); // B: 5->8位
dst[i * 4 + 1] = (g6 << 2) | (g6 >> 4); // G: 6->8位
dst[i * 4 + 2] = (r5 << 3) | (r5 >> 2); // R: 5->8位
dst[i * 4 + 3] = 0xFF; // A: 不透明
}
}
// ============================================
// 单像素转换辅助函数
// ============================================
static uint16_t BGRAToRGB565(uint8_t b, uint8_t g, uint8_t r) {
return static_cast<uint16_t>(
((r >> 3) << 11) |
((g >> 2) << 5) |
(b >> 3)
);
}
static void RGB565ToBGRA(uint16_t c, uint8_t& b, uint8_t& g, uint8_t& r, uint8_t& a) {
uint8_t r5 = (c >> 11) & 0x1F;
uint8_t g6 = (c >> 5) & 0x3F;
uint8_t b5 = c & 0x1F;
b = (b5 << 3) | (b5 >> 2);
g = (g6 << 2) | (g6 >> 4);
r = (r5 << 3) | (r5 >> 2);
a = 0xFF;
}
// ============================================
// 计算量化误差
// ============================================
static int CalculateError(uint8_t original, uint8_t converted) {
return std::abs(static_cast<int>(original) - static_cast<int>(converted));
}
// 计算最大理论误差
// 5位通道: 量化+位填充可能产生最大误差 7
// 6位通道: 量化+位填充可能产生最大误差 3
static int MaxError5Bit() { return 7; }
static int MaxError6Bit() { return 3; }
};
// ============================================
// 测试夹具
// ============================================
class RGB565Test : public ::testing::Test {
protected:
// 创建BGRA像素数组
static std::vector<uint8_t> CreateBGRAPixels(size_t count, uint8_t b, uint8_t g, uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> pixels(count * 4);
for (size_t i = 0; i < count; i++) {
pixels[i * 4 + 0] = b;
pixels[i * 4 + 1] = g;
pixels[i * 4 + 2] = r;
pixels[i * 4 + 3] = a;
}
return pixels;
}
};
// ============================================
// 基础转换测试
// ============================================
TEST_F(RGB565Test, SinglePixel_Black) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 0);
EXPECT_EQ(result, 0x0000);
}
TEST_F(RGB565Test, SinglePixel_White) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 255, 255);
// R=31<<11=0xF800, G=63<<5=0x07E0, B=31=0x001F
// 合计: 0xFFFF
EXPECT_EQ(result, 0xFFFF);
}
TEST_F(RGB565Test, SinglePixel_Red) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 255);
// R=31<<11=0xF800, G=0, B=0
EXPECT_EQ(result, 0xF800);
}
TEST_F(RGB565Test, SinglePixel_Green) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 255, 0);
// R=0, G=63<<5=0x07E0, B=0
EXPECT_EQ(result, 0x07E0);
}
TEST_F(RGB565Test, SinglePixel_Blue) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 0, 0);
// R=0, G=0, B=31=0x001F
EXPECT_EQ(result, 0x001F);
}
// ============================================
// 反向转换测试
// ============================================
TEST_F(RGB565Test, Reverse_Black) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x0000, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_White) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0xFFFF, b, g, r, a);
EXPECT_EQ(b, 255);
EXPECT_EQ(g, 255);
EXPECT_EQ(r, 255);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Red) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0xF800, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 255);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Green) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x07E0, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 255);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Blue) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x001F, b, g, r, a);
EXPECT_EQ(b, 255);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
// ============================================
// 往返转换测试
// ============================================
TEST_F(RGB565Test, RoundTrip_ExactColors) {
// 测试能精确表示的颜色 (只有 0 和 255 能完美往返)
// 位填充公式: (x << 3) | (x >> 2) 只有 x=0 -> 0, x=31 -> 255 是精确的
struct TestCase {
uint8_t b, g, r;
} testCases[] = {
{0, 0, 0}, // 黑色 - 精确
{255, 255, 255}, // 白色 - 精确
{0, 0, 255}, // 红色 - 精确
{0, 255, 0}, // 绿色 - 精确
{255, 0, 0}, // 蓝色 - 精确
{255, 255, 0}, // 青色 - 精确
{0, 255, 255}, // 黄色 - 精确
{255, 0, 255}, // 品红 - 精确
};
for (const auto& tc : testCases) {
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(tc.b, tc.g, tc.r);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
// 能精确表示的颜色应该完美还原
EXPECT_EQ(b, tc.b) << "B channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
EXPECT_EQ(g, tc.g) << "G channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
EXPECT_EQ(r, tc.r) << "R channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
}
}
TEST_F(RGB565Test, RoundTrip_QuantizationError) {
// 测试所有颜色的量化误差在允许范围内
std::mt19937 rng(12345);
std::uniform_int_distribution<int> dist(0, 255);
for (int i = 0; i < 1000; i++) {
uint8_t origB = dist(rng);
uint8_t origG = dist(rng);
uint8_t origR = dist(rng);
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(origB, origG, origR);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
// 验证误差在允许范围内
EXPECT_LE(RGB565Converter::CalculateError(origB, b), RGB565Converter::MaxError5Bit());
EXPECT_LE(RGB565Converter::CalculateError(origG, g), RGB565Converter::MaxError6Bit());
EXPECT_LE(RGB565Converter::CalculateError(origR, r), RGB565Converter::MaxError5Bit());
}
}
// ============================================
// 批量转换测试
// ============================================
TEST_F(RGB565Test, BatchConvert_SinglePixel) {
auto bgra = CreateBGRAPixels(1, 128, 64, 192);
std::vector<uint16_t> rgb565(1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), 1);
// 验证转换结果
uint16_t expected = RGB565Converter::BGRAToRGB565(128, 64, 192);
EXPECT_EQ(rgb565[0], expected);
}
TEST_F(RGB565Test, BatchConvert_MultiplePixels) {
const size_t COUNT = 100;
auto bgra = CreateBGRAPixels(COUNT, 100, 150, 200);
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), COUNT);
uint16_t expected = RGB565Converter::BGRAToRGB565(100, 150, 200);
for (size_t i = 0; i < COUNT; i++) {
EXPECT_EQ(rgb565[i], expected);
}
}
TEST_F(RGB565Test, BatchReverse_MultiplePixels) {
const size_t COUNT = 100;
std::vector<uint16_t> rgb565(COUNT, 0xF800); // 红色
std::vector<uint8_t> bgra(COUNT * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), COUNT);
for (size_t i = 0; i < COUNT; i++) {
EXPECT_EQ(bgra[i * 4 + 0], 0); // B
EXPECT_EQ(bgra[i * 4 + 1], 0); // G
EXPECT_EQ(bgra[i * 4 + 2], 255); // R
EXPECT_EQ(bgra[i * 4 + 3], 255); // A
}
}
TEST_F(RGB565Test, BatchRoundTrip) {
const size_t COUNT = 1000;
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
// 创建随机BGRA数据
std::vector<uint8_t> original(COUNT * 4);
for (size_t i = 0; i < COUNT * 4; i++) {
original[i] = (i % 4 == 3) ? 255 : dist(rng);
}
// 转换到 RGB565
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(original.data(), rgb565.data(), COUNT);
// 转换回 BGRA
std::vector<uint8_t> reconstructed(COUNT * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), reconstructed.data(), COUNT);
// 验证所有像素误差在允许范围内
for (size_t i = 0; i < COUNT; i++) {
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 0], reconstructed[i * 4 + 0]),
RGB565Converter::MaxError5Bit()) << "Pixel " << i << " B";
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 1], reconstructed[i * 4 + 1]),
RGB565Converter::MaxError6Bit()) << "Pixel " << i << " G";
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 2], reconstructed[i * 4 + 2]),
RGB565Converter::MaxError5Bit()) << "Pixel " << i << " R";
EXPECT_EQ(reconstructed[i * 4 + 3], 255) << "Pixel " << i << " A";
}
}
// ============================================
// 边界值测试
// ============================================
TEST_F(RGB565Test, Boundary_AllZeros) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 0);
EXPECT_EQ(result, 0);
}
TEST_F(RGB565Test, Boundary_AllOnes) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 255, 255);
EXPECT_EQ(result, 0xFFFF);
}
TEST_F(RGB565Test, Boundary_ChannelMax) {
// 单通道最大值
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 0, 0), 0x001F); // B
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 255, 0), 0x07E0); // G
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 0, 255), 0xF800); // R
}
TEST_F(RGB565Test, Boundary_ChannelMin) {
// 单通道最小值(其他为最大)
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 255, 255), 0xFFE0); // 无B
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 0, 255), 0xF81F); // 无G
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 255, 0), 0x07FF); // 无R
}
// ============================================
// 位填充测试
// ============================================
TEST_F(RGB565Test, BitFilling_5BitExpansion) {
// 测试5位扩展到8位的位填充策略
// 5位值 11111 (31) -> 8位值 11111111 (255)
// 公式: (x << 3) | (x >> 2)
// 最大值: 31 -> 255
uint8_t expanded = (31 << 3) | (31 >> 2);
EXPECT_EQ(expanded, 255);
// 最小值: 0 -> 0
expanded = (0 << 3) | (0 >> 2);
EXPECT_EQ(expanded, 0);
// 中间值: 16 -> 132 (10000 -> 10000100)
expanded = (16 << 3) | (16 >> 2);
EXPECT_EQ(expanded, 132);
}
TEST_F(RGB565Test, BitFilling_6BitExpansion) {
// 测试6位扩展到8位的位填充策略
// 6位值 111111 (63) -> 8位值 11111111 (255)
// 公式: (x << 2) | (x >> 4)
// 最大值: 63 -> 255
uint8_t expanded = (63 << 2) | (63 >> 4);
EXPECT_EQ(expanded, 255);
// 最小值: 0 -> 0
expanded = (0 << 2) | (0 >> 4);
EXPECT_EQ(expanded, 0);
// 中间值: 32 -> 130 (100000 -> 10000010)
expanded = (32 << 2) | (32 >> 4);
EXPECT_EQ(expanded, 130);
}
// ============================================
// Alpha通道测试
// ============================================
TEST_F(RGB565Test, Alpha_Ignored) {
// 不同alpha值应该产生相同的RGB565
auto bgra1 = CreateBGRAPixels(1, 100, 100, 100, 255);
auto bgra2 = CreateBGRAPixels(1, 100, 100, 100, 128);
auto bgra3 = CreateBGRAPixels(1, 100, 100, 100, 0);
std::vector<uint16_t> rgb565_1(1), rgb565_2(1), rgb565_3(1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra1.data(), rgb565_1.data(), 1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra2.data(), rgb565_2.data(), 1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra3.data(), rgb565_3.data(), 1);
EXPECT_EQ(rgb565_1[0], rgb565_2[0]);
EXPECT_EQ(rgb565_2[0], rgb565_3[0]);
}
TEST_F(RGB565Test, Alpha_RestoredToOpaque) {
// 反向转换时Alpha应该恢复为255
std::vector<uint16_t> rgb565 = {0x1234, 0x5678, 0x9ABC};
std::vector<uint8_t> bgra(3 * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), 3);
for (size_t i = 0; i < 3; i++) {
EXPECT_EQ(bgra[i * 4 + 3], 255);
}
}
// ============================================
// 压缩率测试
// ============================================
TEST_F(RGB565Test, CompressionRatio) {
// BGRA: 4字节/像素
// RGB565: 2字节/像素
// 压缩率: 50%
const size_t PIXEL_COUNT = 1920 * 1080;
size_t bgraSize = PIXEL_COUNT * 4;
size_t rgb565Size = PIXEL_COUNT * 2;
EXPECT_EQ(rgb565Size, bgraSize / 2);
EXPECT_DOUBLE_EQ(static_cast<double>(rgb565Size) / bgraSize, 0.5);
}
// ============================================
// 特殊颜色测试
// ============================================
TEST_F(RGB565Test, CommonColors) {
struct ColorTest {
const char* name;
uint8_t r, g, b;
uint16_t expectedRGB565;
} colors[] = {
{"Black", 0, 0, 0, 0x0000},
{"White", 255, 255, 255, 0xFFFF},
{"Red", 255, 0, 0, 0xF800},
{"Green", 0, 255, 0, 0x07E0},
{"Blue", 0, 0, 255, 0x001F},
{"Yellow", 255, 255, 0, 0xFFE0},
{"Cyan", 0, 255, 255, 0x07FF},
{"Magenta", 255, 0, 255, 0xF81F},
};
for (const auto& c : colors) {
uint16_t result = RGB565Converter::BGRAToRGB565(c.b, c.g, c.r);
EXPECT_EQ(result, c.expectedRGB565) << "Color: " << c.name;
}
}
TEST_F(RGB565Test, GrayScales) {
// 测试灰度值转换
for (int gray = 0; gray <= 255; gray += 17) {
uint8_t g = static_cast<uint8_t>(gray);
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(g, g, g);
uint8_t b, gr, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, gr, r, a);
// 灰度值在量化误差范围内应该保持一致
EXPECT_NEAR(b, g, 8);
EXPECT_NEAR(gr, g, 4);
EXPECT_NEAR(r, g, 8);
}
}
// ============================================
// 参数化测试
// ============================================
class RGB565ChannelTest : public ::testing::TestWithParam<int> {};
TEST_P(RGB565ChannelTest, ChannelValueRange) {
int value = GetParam();
uint8_t v = static_cast<uint8_t>(value);
// 测试R通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(0, 0, v);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, r), RGB565Converter::MaxError5Bit());
}
// 测试G通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(0, v, 0);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, g), RGB565Converter::MaxError6Bit());
}
// 测试B通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(v, 0, 0);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, b), RGB565Converter::MaxError5Bit());
}
}
INSTANTIATE_TEST_SUITE_P(
AllValues,
RGB565ChannelTest,
::testing::Range(0, 256, 1) // 测试0-255所有值
);
// ============================================
// 大数据量测试
// ============================================
TEST_F(RGB565Test, LargeFrame_1080p) {
const size_t WIDTH = 1920, HEIGHT = 1080;
const size_t COUNT = WIDTH * HEIGHT;
// 创建渐变图像
std::vector<uint8_t> bgra(COUNT * 4);
for (size_t y = 0; y < HEIGHT; y++) {
for (size_t x = 0; x < WIDTH; x++) {
size_t idx = (y * WIDTH + x) * 4;
bgra[idx + 0] = static_cast<uint8_t>(x * 255 / WIDTH);
bgra[idx + 1] = static_cast<uint8_t>(y * 255 / HEIGHT);
bgra[idx + 2] = static_cast<uint8_t>((x + y) * 127 / (WIDTH + HEIGHT));
bgra[idx + 3] = 255;
}
}
// 转换
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), COUNT);
// 验证大小
EXPECT_EQ(rgb565.size() * sizeof(uint16_t), COUNT * 2);
// 抽样验证
for (size_t i = 0; i < COUNT; i += COUNT / 100) {
uint16_t expected = RGB565Converter::BGRAToRGB565(bgra[i * 4 + 0], bgra[i * 4 + 1], bgra[i * 4 + 2]);
EXPECT_EQ(rgb565[i], expected) << "Mismatch at pixel " << i;
}
}
// ============================================
// 端序测试
// ============================================
TEST_F(RGB565Test, Endianness_LittleEndian) {
// RGB565 应该以小端存储
uint16_t rgb565 = 0x1234;
uint8_t* bytes = reinterpret_cast<uint8_t*>(&rgb565);
// 在小端系统上: bytes[0] = 0x34, bytes[1] = 0x12
EXPECT_EQ(bytes[0], 0x34);
EXPECT_EQ(bytes[1], 0x12);
}
// ============================================
// 错误处理测试
// ============================================
TEST_F(RGB565Test, ZeroPixelCount) {
std::vector<uint8_t> bgra(4);
std::vector<uint16_t> rgb565(1);
// 零像素不应崩溃
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), 0);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), 0);
}

View File

@@ -0,0 +1,686 @@
// ScrollDetectorTest.cpp - Phase 4: 滚动检测单元测试
// 测试屏幕滚动检测和优化传输
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <random>
#include <algorithm>
// ============================================
// 滚动检测常量定义 (来自 ScrollDetector.h)
// ============================================
#define MIN_SCROLL_LINES 16 // 最小滚动行数
#define MAX_SCROLL_RATIO 4 // 最大滚动 = 高度 / 4
#define MATCH_THRESHOLD 85 // 行匹配百分比阈值 (85%)
// 滚动方向常量
#define SCROLL_DIR_UP 0 // 向上滚动(内容向下移动)
#define SCROLL_DIR_DOWN 1 // 向下滚动(内容向上移动)
// ============================================
// CRC32 哈希计算
// ============================================
class CRC32 {
public:
static uint32_t Calculate(const uint8_t* data, size_t length) {
static uint32_t table[256] = {0};
static bool tableInit = false;
if (!tableInit) {
for (uint32_t i = 0; i < 256; i++) {
uint32_t c = i;
for (int j = 0; j < 8; j++) {
c = (c & 1) ? (0xEDB88320 ^ (c >> 1)) : (c >> 1);
}
table[i] = c;
}
tableInit = true;
}
uint32_t crc = 0xFFFFFFFF;
for (size_t i = 0; i < length; i++) {
crc = table[(crc ^ data[i]) & 0xFF] ^ (crc >> 8);
}
return crc ^ 0xFFFFFFFF;
}
};
// ============================================
// 滚动检测器实现 (模拟 ScrollDetector.h)
// ============================================
class CScrollDetector {
public:
CScrollDetector(int width, int height, int bpp = 4)
: m_width(width), m_height(height), m_bpp(bpp)
, m_stride(width * bpp)
, m_minScroll(MIN_SCROLL_LINES)
, m_maxScroll(height / MAX_SCROLL_RATIO)
{
m_rowHashes.resize(height);
}
// 检测垂直滚动
// 返回: >0 向下滚动, <0 向上滚动, 0 无滚动
int DetectVerticalScroll(const uint8_t* prevFrame, const uint8_t* currFrame) {
if (!prevFrame || !currFrame) return 0;
// 计算当前帧的行哈希
std::vector<uint32_t> currHashes(m_height);
for (int y = 0; y < m_height; y++) {
currHashes[y] = CRC32::Calculate(currFrame + y * m_stride, m_stride);
}
// 计算前一帧的行哈希
std::vector<uint32_t> prevHashes(m_height);
for (int y = 0; y < m_height; y++) {
prevHashes[y] = CRC32::Calculate(prevFrame + y * m_stride, m_stride);
}
int bestScroll = 0;
int bestMatchCount = 0;
// 尝试各种滚动量
for (int scroll = m_minScroll; scroll <= m_maxScroll; scroll++) {
// 向下滚动 (正值)
int matchDown = CountMatchingRows(prevHashes, currHashes, scroll);
if (matchDown > bestMatchCount) {
bestMatchCount = matchDown;
bestScroll = scroll;
}
// 向上滚动 (负值)
int matchUp = CountMatchingRows(prevHashes, currHashes, -scroll);
if (matchUp > bestMatchCount) {
bestMatchCount = matchUp;
bestScroll = -scroll;
}
}
// 检查是否达到匹配阈值
int scrollAbs = std::abs(bestScroll);
int totalRows = m_height - scrollAbs;
int threshold = totalRows * MATCH_THRESHOLD / 100;
if (bestMatchCount >= threshold) {
return bestScroll;
}
return 0;
}
// 获取边缘区域信息
void GetEdgeRegion(int scrollAmount, int* outOffset, int* outPixelCount) const {
if (scrollAmount > 0) {
// 向下滚动: 新内容在底部 (BMP底上格式: 低地址)
*outOffset = 0;
*outPixelCount = scrollAmount * m_width;
} else if (scrollAmount < 0) {
// 向上滚动: 新内容在顶部 (BMP底上格式: 高地址)
*outOffset = (m_height + scrollAmount) * m_stride;
*outPixelCount = (-scrollAmount) * m_width;
} else {
*outOffset = 0;
*outPixelCount = 0;
}
}
int GetMinScroll() const { return m_minScroll; }
int GetMaxScroll() const { return m_maxScroll; }
int GetWidth() const { return m_width; }
int GetHeight() const { return m_height; }
private:
int CountMatchingRows(const std::vector<uint32_t>& prevHashes,
const std::vector<uint32_t>& currHashes,
int scroll) const {
int matchCount = 0;
if (scroll > 0) {
// 向下滚动: prev[y] 对应 curr[y + scroll]
for (int y = 0; y < m_height - scroll; y++) {
if (prevHashes[y] == currHashes[y + scroll]) {
matchCount++;
}
}
} else if (scroll < 0) {
// 向上滚动: prev[y] 对应 curr[y + scroll] (scroll为负)
int absScroll = -scroll;
for (int y = absScroll; y < m_height; y++) {
if (prevHashes[y] == currHashes[y + scroll]) {
matchCount++;
}
}
}
return matchCount;
}
int m_width;
int m_height;
int m_bpp;
int m_stride;
int m_minScroll;
int m_maxScroll;
std::vector<uint32_t> m_rowHashes;
};
// ============================================
// 测试夹具
// ============================================
class ScrollDetectorTest : public ::testing::Test {
public:
// 创建纯色帧
static std::vector<uint8_t> CreateSolidFrame(int width, int height,
uint8_t b, uint8_t g,
uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> frame(width * height * 4);
for (int i = 0; i < width * height; i++) {
frame[i * 4 + 0] = b;
frame[i * 4 + 1] = g;
frame[i * 4 + 2] = r;
frame[i * 4 + 3] = a;
}
return frame;
}
// 创建带条纹的帧 (每行不同颜色,便于检测滚动)
static std::vector<uint8_t> CreateStripedFrame(int width, int height) {
std::vector<uint8_t> frame(width * height * 4);
for (int y = 0; y < height; y++) {
uint8_t color = static_cast<uint8_t>(y % 256);
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = color;
frame[idx + 1] = color;
frame[idx + 2] = color;
frame[idx + 3] = 0xFF;
}
}
return frame;
}
// 模拟向下滚动 (内容向上移动)
static std::vector<uint8_t> SimulateScrollDown(const std::vector<uint8_t>& frame,
int width, int height,
int scrollAmount) {
std::vector<uint8_t> result(frame.size());
int stride = width * 4;
// 复制滚动后的内容
for (int y = scrollAmount; y < height; y++) {
memcpy(result.data() + (y - scrollAmount) * stride,
frame.data() + y * stride, stride);
}
// 底部新内容用不同颜色填充
for (int y = height - scrollAmount; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
result[idx + 0] = 0xFF; // 新内容用白色
result[idx + 1] = 0xFF;
result[idx + 2] = 0xFF;
result[idx + 3] = 0xFF;
}
}
return result;
}
// 模拟向上滚动 (内容向下移动)
static std::vector<uint8_t> SimulateScrollUp(const std::vector<uint8_t>& frame,
int width, int height,
int scrollAmount) {
std::vector<uint8_t> result(frame.size());
int stride = width * 4;
// 复制滚动后的内容
for (int y = 0; y < height - scrollAmount; y++) {
memcpy(result.data() + (y + scrollAmount) * stride,
frame.data() + y * stride, stride);
}
// 顶部新内容用不同颜色填充
for (int y = 0; y < scrollAmount; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
result[idx + 0] = 0x00; // 新内容用黑色
result[idx + 1] = 0x00;
result[idx + 2] = 0x00;
result[idx + 3] = 0xFF;
}
}
return result;
}
};
// ============================================
// 基础功能测试
// ============================================
TEST_F(ScrollDetectorTest, Constructor_ValidParameters) {
CScrollDetector detector(100, 200);
EXPECT_EQ(detector.GetWidth(), 100);
EXPECT_EQ(detector.GetHeight(), 200);
EXPECT_EQ(detector.GetMinScroll(), MIN_SCROLL_LINES);
EXPECT_EQ(detector.GetMaxScroll(), 200 / MAX_SCROLL_RATIO);
}
TEST_F(ScrollDetectorTest, IdenticalFrames_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame = CreateStripedFrame(WIDTH, HEIGHT);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame.data(), frame.data());
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, CompletelyDifferent_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 0, 0, 0);
auto frame2 = CreateSolidFrame(WIDTH, HEIGHT, 255, 255, 255);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_EQ(scroll, 0);
}
// ============================================
// 向下滚动检测测试
// ============================================
TEST_F(ScrollDetectorTest, ScrollDown_MinimalScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, MIN_SCROLL_LINES);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 检测到滚动(绝对值匹配,方向取决于实现)
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), MIN_SCROLL_LINES);
}
TEST_F(ScrollDetectorTest, ScrollDown_MediumScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
TEST_F(ScrollDetectorTest, ScrollDown_MaxScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int MAX_SCROLL = HEIGHT / MAX_SCROLL_RATIO;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, MAX_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_LE(std::abs(scroll), MAX_SCROLL);
}
// ============================================
// 向上滚动检测测试
// ============================================
TEST_F(ScrollDetectorTest, ScrollUp_MinimalScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollUp(frame1, WIDTH, HEIGHT, MIN_SCROLL_LINES);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 检测到滚动(方向与 ScrollDown 相反)
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), MIN_SCROLL_LINES);
}
TEST_F(ScrollDetectorTest, ScrollUp_MediumScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollUp(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
// ============================================
// 边缘区域测试
// ============================================
TEST_F(ScrollDetectorTest, EdgeRegion_ScrollDown) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(SCROLL, &offset, &pixelCount);
// 向下滚动: 新内容在底部
EXPECT_EQ(offset, 0);
EXPECT_EQ(pixelCount, SCROLL * WIDTH);
}
TEST_F(ScrollDetectorTest, EdgeRegion_ScrollUp) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(-SCROLL, &offset, &pixelCount);
// 向上滚动: 新内容在顶部
EXPECT_EQ(offset, (HEIGHT - SCROLL) * WIDTH * 4);
EXPECT_EQ(pixelCount, SCROLL * WIDTH);
}
TEST_F(ScrollDetectorTest, EdgeRegion_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(0, &offset, &pixelCount);
EXPECT_EQ(offset, 0);
EXPECT_EQ(pixelCount, 0);
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(ScrollDetectorTest, NullInput_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame = CreateStripedFrame(WIDTH, HEIGHT);
CScrollDetector detector(WIDTH, HEIGHT);
EXPECT_EQ(detector.DetectVerticalScroll(nullptr, frame.data()), 0);
EXPECT_EQ(detector.DetectVerticalScroll(frame.data(), nullptr), 0);
EXPECT_EQ(detector.DetectVerticalScroll(nullptr, nullptr), 0);
}
TEST_F(ScrollDetectorTest, SmallScroll_BelowMinimum_NoDetection) {
const int WIDTH = 100, HEIGHT = 100;
const int SMALL_SCROLL = MIN_SCROLL_LINES - 1; // 低于最小滚动量
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SMALL_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 小于最小滚动量不应被检测
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, LargeScroll_AboveMaximum_NotDetected) {
const int WIDTH = 100, HEIGHT = 100;
const int MAX_SCROLL = HEIGHT / MAX_SCROLL_RATIO;
const int LARGE_SCROLL = MAX_SCROLL + 10; // 超过最大滚动量
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, LARGE_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 超过最大滚动量可能不被检测或返回最大值
EXPECT_LE(std::abs(scroll), MAX_SCROLL);
}
// ============================================
// CRC32 哈希测试
// ============================================
TEST_F(ScrollDetectorTest, CRC32_EmptyData) {
uint32_t hash = CRC32::Calculate(nullptr, 0);
// CRC32 的空数据哈希值
EXPECT_EQ(hash, 0); // 实现相关
}
TEST_F(ScrollDetectorTest, CRC32_KnownVector) {
// "123456789" 的 CRC32 应该是 0xCBF43926
const char* testData = "123456789";
uint32_t hash = CRC32::Calculate(reinterpret_cast<const uint8_t*>(testData), 9);
EXPECT_EQ(hash, 0xCBF43926);
}
TEST_F(ScrollDetectorTest, CRC32_SameData_SameHash) {
std::vector<uint8_t> data1(100, 0x42);
std::vector<uint8_t> data2(100, 0x42);
uint32_t hash1 = CRC32::Calculate(data1.data(), data1.size());
uint32_t hash2 = CRC32::Calculate(data2.data(), data2.size());
EXPECT_EQ(hash1, hash2);
}
TEST_F(ScrollDetectorTest, CRC32_DifferentData_DifferentHash) {
std::vector<uint8_t> data1(100, 0x42);
std::vector<uint8_t> data2(100, 0x43);
uint32_t hash1 = CRC32::Calculate(data1.data(), data1.size());
uint32_t hash2 = CRC32::Calculate(data2.data(), data2.size());
EXPECT_NE(hash1, hash2);
}
// ============================================
// 性能相关测试(验证正确性)
// ============================================
TEST_F(ScrollDetectorTest, LargeFrame_720p) {
const int WIDTH = 1280, HEIGHT = 720;
const int SCROLL = 50;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
TEST_F(ScrollDetectorTest, LargeFrame_1080p) {
const int WIDTH = 1920, HEIGHT = 1080;
const int SCROLL = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
// ============================================
// 带宽节省计算测试
// ============================================
TEST_F(ScrollDetectorTest, BandwidthSaving_ScrollDetected) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
// 完整帧大小
size_t fullFrameSize = WIDTH * HEIGHT * 4;
// 边缘区域大小
size_t edgeSize = SCROLL * WIDTH * 4;
// 带宽节省
double saving = 1.0 - static_cast<double>(edgeSize) / fullFrameSize;
// 20行滚动应该节省约80%带宽
EXPECT_GT(saving, 0.7);
}
TEST_F(ScrollDetectorTest, BandwidthSaving_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(0, &offset, &pixelCount);
// 无滚动时没有边缘区域
EXPECT_EQ(pixelCount, 0);
}
// ============================================
// 参数化测试:不同滚动量
// ============================================
class ScrollAmountTest : public ::testing::TestWithParam<int> {};
TEST_P(ScrollAmountTest, DetectScrollAmount) {
const int WIDTH = 100, HEIGHT = 100;
int scrollAmount = GetParam();
if (scrollAmount < MIN_SCROLL_LINES || scrollAmount > HEIGHT / MAX_SCROLL_RATIO) {
GTEST_SKIP() << "Scroll amount out of valid range";
}
auto frame1 = ScrollDetectorTest::CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = ScrollDetectorTest::SimulateScrollDown(frame1, WIDTH, HEIGHT, scrollAmount);
CScrollDetector detector(WIDTH, HEIGHT);
int detected = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(detected, 0);
EXPECT_EQ(std::abs(detected), scrollAmount);
}
INSTANTIATE_TEST_SUITE_P(
ScrollAmounts,
ScrollAmountTest,
::testing::Values(16, 17, 18, 19, 20, 21, 22, 23, 24, 25)
);
// ============================================
// 匹配阈值测试
// ============================================
TEST_F(ScrollDetectorTest, MatchThreshold_HighNoise_NoDetection) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
// 添加大量噪声,使匹配率低于阈值
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
for (size_t i = 0; i < frame2.size(); i++) {
if (rng() % 2 == 0) { // 50% 的像素被随机化
frame2[i] = dist(rng);
}
}
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 高噪声情况下不应检测到滚动
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, MatchThreshold_LowNoise_DetectionOK) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
// 只添加少量噪声10%
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
for (size_t i = 0; i < frame2.size(); i++) {
if (rng() % 10 == 0) { // 10% 的像素被随机化
frame2[i] = dist(rng);
}
}
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 低噪声情况下仍应检测到滚动(取决于具体行噪声分布)
// 这个测试可能不稳定,因为噪声是随机的
// EXPECT_GT(scroll, 0) 或 EXPECT_EQ(scroll, 0) 取决于实际噪声分布
}
// ============================================
// 常量验证测试
// ============================================
TEST(ScrollConstantsTest, MinScrollLines) {
EXPECT_EQ(MIN_SCROLL_LINES, 16);
}
TEST(ScrollConstantsTest, MaxScrollRatio) {
EXPECT_EQ(MAX_SCROLL_RATIO, 4);
}
TEST(ScrollConstantsTest, MatchThreshold) {
EXPECT_EQ(MATCH_THRESHOLD, 85);
}
TEST(ScrollConstantsTest, ScrollDirections) {
EXPECT_EQ(SCROLL_DIR_UP, 0);
EXPECT_EQ(SCROLL_DIR_DOWN, 1);
}
// ============================================
// 分辨率参数化测试
// ============================================
class ScrollResolutionTest : public ::testing::TestWithParam<std::tuple<int, int>> {};
TEST_P(ScrollResolutionTest, DetectScrollAtResolution) {
auto [width, height] = GetParam();
int scrollAmount = std::max(MIN_SCROLL_LINES, height / 10);
if (scrollAmount > height / MAX_SCROLL_RATIO) {
scrollAmount = height / MAX_SCROLL_RATIO;
}
auto frame1 = ScrollDetectorTest::CreateStripedFrame(width, height);
auto frame2 = ScrollDetectorTest::SimulateScrollDown(frame1, width, height, scrollAmount);
CScrollDetector detector(width, height);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), scrollAmount);
}
INSTANTIATE_TEST_SUITE_P(
Resolutions,
ScrollResolutionTest,
::testing::Values(
std::make_tuple(640, 480), // VGA
std::make_tuple(800, 600), // SVGA
std::make_tuple(1024, 768), // XGA
std::make_tuple(1280, 720), // 720p
std::make_tuple(1920, 1080) // 1080p
)
);

View File

@@ -0,0 +1,912 @@
/**
* @file BufferTest.cpp
* @brief 服务端 CBuffer 类单元测试
*
* 测试覆盖:
* - 基本读写操作
* - 延迟读取偏移机制 (m_ulReadOffset)
* - 压缩/紧凑策略 (CompactBuffer)
* - 边界条件和下溢防护
* - 线程安全(并发读写)
* - 零拷贝写入接口
*/
#include <gtest/gtest.h>
#include <thread>
#include <atomic>
#include <chrono>
#include <vector>
#include <string>
// Windows 头文件
#ifdef _WIN32
#include <Windows.h>
#else
// Linux 模拟 Windows API
#include <cstring>
#include <cstdlib>
#include <pthread.h>
typedef unsigned char BYTE;
typedef BYTE* PBYTE;
typedef BYTE* LPBYTE;
typedef unsigned long ULONG;
typedef void VOID;
typedef int BOOL;
typedef void* PVOID;
#define TRUE 1
#define FALSE 0
#define MEM_COMMIT 0x1000
#define MEM_RELEASE 0x8000
#define PAGE_READWRITE 0x04
struct CRITICAL_SECTION {
pthread_mutex_t mutex;
};
inline void InitializeCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_init(&cs->mutex, NULL);
}
inline void DeleteCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_destroy(&cs->mutex);
}
inline void EnterCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_lock(&cs->mutex);
}
inline void LeaveCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_unlock(&cs->mutex);
}
inline void* VirtualAlloc(void*, size_t size, int, int) {
return malloc(size);
}
inline void VirtualFree(void* ptr, size_t, int) {
free(ptr);
}
inline void CopyMemory(void* dst, const void* src, size_t len) {
memcpy(dst, src, len);
}
inline void MoveMemory(void* dst, const void* src, size_t len) {
memmove(dst, src, len);
}
#endif
#include <cmath>
// 服务端 Buffer 实现(测试专用内联版本)
namespace ServerBuffer {
#define U_PAGE_ALIGNMENT 4096
#define F_PAGE_ALIGNMENT 4096.0
#define COMPACT_THRESHOLD 0.5
// 简化的 Buffer 类(用于 GetMyBuffer 返回)
class Buffer {
private:
PBYTE buf;
ULONG len;
public:
Buffer() : buf(NULL), len(0) {}
Buffer(const BYTE* b, ULONG n) : len(n) {
if (n > 0 && b) {
buf = new BYTE[n];
memcpy(buf, b, n);
} else {
buf = NULL;
}
}
~Buffer() {
if (buf) {
delete[] buf;
buf = NULL;
}
}
Buffer(const Buffer& o) : len(o.len) {
if (o.buf && o.len > 0) {
buf = new BYTE[o.len];
memcpy(buf, o.buf, o.len);
} else {
buf = NULL;
}
}
ULONG length() const { return len; }
LPBYTE GetBuffer(int idx = 0) const {
return (idx >= (int)len) ? NULL : buf + idx;
}
};
class CBuffer {
public:
CBuffer() : m_ulMaxLength(0), m_ulReadOffset(0), m_Base(NULL), m_Ptr(NULL) {
InitializeCriticalSection(&m_cs);
}
~CBuffer() {
if (m_Base) {
VirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NULL;
}
DeleteCriticalSection(&m_cs);
m_Base = m_Ptr = NULL;
m_ulMaxLength = 0;
m_ulReadOffset = 0;
}
ULONG RemoveCompletedBuffer(ULONG ulLength) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulLength > effectiveDataLen) {
ulLength = effectiveDataLen;
}
if (ulLength) {
m_ulReadOffset += ulLength;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
}
LeaveCriticalSection(&m_cs);
return ulLength;
}
VOID CompactBuffer() {
if (m_ulReadOffset > 0 && m_Base) {
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG remainingData = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (remainingData > 0) {
MoveMemory(m_Base, m_Base + m_ulReadOffset, remainingData);
}
m_Ptr = m_Base + remainingData;
m_ulReadOffset = 0;
DeAllocateBuffer(remainingData);
}
}
ULONG ReadBuffer(PBYTE Buffer, ULONG ulLength) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulLength > effectiveDataLen) {
ulLength = effectiveDataLen;
}
if (ulLength) {
CopyMemory(Buffer, m_Base + m_ulReadOffset, ulLength);
m_ulReadOffset += ulLength;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
}
LeaveCriticalSection(&m_cs);
return ulLength;
}
ULONG DeAllocateBuffer(ULONG ulLength) {
if (ulLength < (ULONG)(m_Ptr - m_Base))
return 0;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
if (m_ulMaxLength <= ulNewMaxLength) {
return 0;
}
PBYTE NewBase = (PBYTE)VirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
ULONG ulv1 = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, ulv1);
VirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NewBase;
m_Ptr = m_Base + ulv1;
m_ulMaxLength = ulNewMaxLength;
return m_ulMaxLength;
}
BOOL WriteBuffer(PBYTE Buffer, ULONG ulLength) {
EnterCriticalSection(&m_cs);
if (ReAllocateBuffer(ulLength + (ULONG)(m_Ptr - m_Base)) == (ULONG)-1) {
LeaveCriticalSection(&m_cs);
return FALSE;
}
CopyMemory(m_Ptr, Buffer, ulLength);
m_Ptr += ulLength;
LeaveCriticalSection(&m_cs);
return TRUE;
}
ULONG ReAllocateBuffer(ULONG ulLength) {
if (ulLength < m_ulMaxLength)
return 0;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
PBYTE NewBase = (PBYTE)VirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL) {
return (ULONG)-1;
}
ULONG ulv1 = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, ulv1);
if (m_Base) {
VirtualFree(m_Base, 0, MEM_RELEASE);
}
m_Base = NewBase;
m_Ptr = m_Base + ulv1;
m_ulMaxLength = ulNewMaxLength;
return m_ulMaxLength;
}
VOID ClearBuffer() {
EnterCriticalSection(&m_cs);
m_Ptr = m_Base;
m_ulReadOffset = 0;
DeAllocateBuffer(1024);
LeaveCriticalSection(&m_cs);
}
ULONG GetBufferLength() {
EnterCriticalSection(&m_cs);
if (m_Base == NULL) {
LeaveCriticalSection(&m_cs);
return 0;
}
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG len = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
LeaveCriticalSection(&m_cs);
return len;
}
std::string Skip(ULONG ulPos) {
if (ulPos == 0)
return "";
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulPos > effectiveDataLen) {
ulPos = effectiveDataLen;
}
std::string ret((char*)(m_Base + m_ulReadOffset), (char*)(m_Base + m_ulReadOffset + ulPos));
m_ulReadOffset += ulPos;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
LeaveCriticalSection(&m_cs);
return ret;
}
LPBYTE GetBuffer(ULONG ulPos = 0) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return NULL;
}
LPBYTE result = m_Base + m_ulReadOffset + ulPos;
LeaveCriticalSection(&m_cs);
return result;
}
Buffer GetMyBuffer(ULONG ulPos = 0) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return Buffer();
}
Buffer result(m_Base + m_ulReadOffset + ulPos, effectiveDataLen - ulPos);
LeaveCriticalSection(&m_cs);
return result;
}
BYTE GetBYTE(ULONG ulPos) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return 0;
}
BYTE p = *(m_Base + m_ulReadOffset + ulPos);
LeaveCriticalSection(&m_cs);
return p;
}
BOOL CopyBuffer(PVOID pDst, ULONG nLen, ULONG ulPos) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || pDst == NULL || ulPos >= effectiveDataLen || (effectiveDataLen - ulPos) < nLen) {
LeaveCriticalSection(&m_cs);
return FALSE;
}
memcpy(pDst, m_Base + m_ulReadOffset + ulPos, nLen);
LeaveCriticalSection(&m_cs);
return TRUE;
}
LPBYTE GetWriteBuffer(ULONG requiredSize, ULONG& availableSize) {
EnterCriticalSection(&m_cs);
if (m_ulReadOffset > 0) {
CompactBuffer();
}
ULONG currentDataLen = (ULONG)(m_Ptr - m_Base);
if (ReAllocateBuffer(currentDataLen + requiredSize) == (ULONG)-1) {
LeaveCriticalSection(&m_cs);
availableSize = 0;
return NULL;
}
availableSize = m_ulMaxLength - currentDataLen;
LPBYTE result = m_Ptr;
LeaveCriticalSection(&m_cs);
return result;
}
VOID CommitWrite(ULONG writtenSize) {
EnterCriticalSection(&m_cs);
m_Ptr += writtenSize;
LeaveCriticalSection(&m_cs);
}
// 测试辅助:获取内部状态
ULONG GetReadOffset() const { return m_ulReadOffset; }
ULONG GetMaxLength() const { return m_ulMaxLength; }
protected:
PBYTE m_Base;
PBYTE m_Ptr;
ULONG m_ulMaxLength;
ULONG m_ulReadOffset;
CRITICAL_SECTION m_cs;
};
} // namespace ServerBuffer
using ServerBuffer::CBuffer;
using ServerBuffer::Buffer;
// ============================================
// 测试夹具
// ============================================
class ServerBufferTest : public ::testing::Test {
protected:
CBuffer buffer;
void SetUp() override {}
void TearDown() override {}
void WriteFillData(ULONG length, BYTE fillValue = 0x42) {
std::vector<BYTE> data(length, fillValue);
buffer.WriteBuffer(data.data(), length);
}
};
// ============================================
// 构造/析构测试
// ============================================
TEST_F(ServerBufferTest, Constructor_InitializesEmpty) {
CBuffer newBuffer;
EXPECT_EQ(newBuffer.GetBufferLength(), 0u);
EXPECT_EQ(newBuffer.GetBuffer(), nullptr);
}
// ============================================
// WriteBuffer 测试
// ============================================
TEST_F(ServerBufferTest, WriteBuffer_ValidData_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
EXPECT_TRUE(buffer.WriteBuffer(data, 5));
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ServerBufferTest, WriteBuffer_MultipleWrites_AccumulatesData) {
BYTE data1[] = {1, 2, 3};
BYTE data2[] = {4, 5};
buffer.WriteBuffer(data1, 3);
buffer.WriteBuffer(data2, 2);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ServerBufferTest, WriteBuffer_LargeData_HandlesCorrectly) {
const ULONG largeSize = 100000;
std::vector<BYTE> data(largeSize, 0xAB);
EXPECT_TRUE(buffer.WriteBuffer(data.data(), largeSize));
EXPECT_EQ(buffer.GetBufferLength(), largeSize);
}
// ============================================
// ReadBuffer 测试
// ============================================
TEST_F(ServerBufferTest, ReadBuffer_EmptyBuffer_ReturnsZero) {
BYTE result[10];
EXPECT_EQ(buffer.ReadBuffer(result, 10), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_ExactLength_ReturnsAll) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[5];
ULONG bytesRead = buffer.ReadBuffer(result, 5);
EXPECT_EQ(bytesRead, 5u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_PartialRead_UsesReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[2];
buffer.ReadBuffer(result, 2);
// 使用延迟偏移,不立即移动数据
EXPECT_EQ(buffer.GetBufferLength(), 3u);
EXPECT_GT(buffer.GetReadOffset(), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_RequestExceedsAvailable_ReturnsAvailableOnly) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[10];
ULONG bytesRead = buffer.ReadBuffer(result, 10);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// RemoveCompletedBuffer 测试
// ============================================
TEST_F(ServerBufferTest, RemoveCompletedBuffer_PartialRemove_UpdatesOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
ULONG removed = buffer.RemoveCompletedBuffer(2);
EXPECT_EQ(removed, 2u);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
TEST_F(ServerBufferTest, RemoveCompletedBuffer_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
ULONG removed = buffer.RemoveCompletedBuffer(100);
EXPECT_EQ(removed, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// Skip 测试
// ============================================
TEST_F(ServerBufferTest, Skip_ReturnsSkippedData) {
BYTE data[] = {'H', 'e', 'l', 'l', 'o'};
buffer.WriteBuffer(data, 5);
std::string skipped = buffer.Skip(3);
EXPECT_EQ(skipped, "Hel");
EXPECT_EQ(buffer.GetBufferLength(), 2u);
}
TEST_F(ServerBufferTest, Skip_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {'A', 'B', 'C'};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(100);
EXPECT_EQ(skipped, "ABC");
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, Skip_ZeroLength_ReturnsEmpty) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(0);
EXPECT_EQ(skipped, "");
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
// ============================================
// GetBuffer 测试
// ============================================
TEST_F(ServerBufferTest, GetBuffer_RespectsReadOffset) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
// 读取前两个字节,更新偏移
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
// GetBuffer(0) 应该返回第三个字节30
EXPECT_EQ(*buffer.GetBuffer(0), 30);
EXPECT_EQ(*buffer.GetBuffer(1), 40);
EXPECT_EQ(*buffer.GetBuffer(2), 50);
}
TEST_F(ServerBufferTest, GetBuffer_PositionExceedsEffectiveLength_ReturnsNull) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[3];
buffer.ReadBuffer(temp, 3); // 有效长度变为 2
EXPECT_NE(buffer.GetBuffer(0), nullptr);
EXPECT_NE(buffer.GetBuffer(1), nullptr);
EXPECT_EQ(buffer.GetBuffer(2), nullptr); // 超出有效范围
}
// ============================================
// GetMyBuffer 测试
// ============================================
TEST_F(ServerBufferTest, GetMyBuffer_ReturnsCorrectBuffer) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
Buffer buf = buffer.GetMyBuffer(0);
EXPECT_EQ(buf.length(), 5u);
}
TEST_F(ServerBufferTest, GetMyBuffer_RespectsReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
Buffer buf = buffer.GetMyBuffer(0);
EXPECT_EQ(buf.length(), 3u);
EXPECT_EQ(*buf.GetBuffer(0), 3);
}
// ============================================
// GetBYTE 测试
// ============================================
TEST_F(ServerBufferTest, GetBYTE_ReturnsCorrectByte) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
EXPECT_EQ(buffer.GetBYTE(0), 10);
EXPECT_EQ(buffer.GetBYTE(2), 30);
EXPECT_EQ(buffer.GetBYTE(4), 50);
}
TEST_F(ServerBufferTest, GetBYTE_RespectsReadOffset) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
EXPECT_EQ(buffer.GetBYTE(0), 30);
EXPECT_EQ(buffer.GetBYTE(1), 40);
}
TEST_F(ServerBufferTest, GetBYTE_OutOfRange_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBYTE(100), 0);
}
// ============================================
// CopyBuffer 测试
// ============================================
TEST_F(ServerBufferTest, CopyBuffer_ValidRange_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE dest[3];
EXPECT_TRUE(buffer.CopyBuffer(dest, 3, 1));
EXPECT_EQ(dest[0], 2);
EXPECT_EQ(dest[1], 3);
EXPECT_EQ(dest[2], 4);
}
TEST_F(ServerBufferTest, CopyBuffer_RespectsReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
BYTE dest[2];
EXPECT_TRUE(buffer.CopyBuffer(dest, 2, 0));
EXPECT_EQ(dest[0], 3);
EXPECT_EQ(dest[1], 4);
}
TEST_F(ServerBufferTest, CopyBuffer_ExceedsRange_ReturnsFalse) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE dest[10];
EXPECT_FALSE(buffer.CopyBuffer(dest, 10, 0));
}
// ============================================
// 零拷贝写入测试
// ============================================
TEST_F(ServerBufferTest, GetWriteBuffer_ReturnsValidPointer) {
ULONG availableSize = 0;
LPBYTE writePtr = buffer.GetWriteBuffer(100, availableSize);
EXPECT_NE(writePtr, nullptr);
EXPECT_GE(availableSize, 100u);
}
TEST_F(ServerBufferTest, CommitWrite_UpdatesLength) {
ULONG availableSize = 0;
LPBYTE writePtr = buffer.GetWriteBuffer(100, availableSize);
// 直接写入
for (int i = 0; i < 50; i++) {
writePtr[i] = (BYTE)i;
}
buffer.CommitWrite(50);
EXPECT_EQ(buffer.GetBufferLength(), 50u);
EXPECT_EQ(buffer.GetBYTE(0), 0);
EXPECT_EQ(buffer.GetBYTE(49), 49);
}
// ============================================
// ClearBuffer 测试
// ============================================
TEST_F(ServerBufferTest, ClearBuffer_ResetsEverything) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2); // 创建读取偏移
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
EXPECT_EQ(buffer.GetReadOffset(), 0u);
}
// ============================================
// 下溢防护测试
// ============================================
TEST_F(ServerBufferTest, UnderflowProtection_ReadMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1000];
ULONG bytesRead = buffer.ReadBuffer(result, ULONG_MAX - 1);
EXPECT_EQ(bytesRead, 3u);
}
TEST_F(ServerBufferTest, UnderflowProtection_SkipMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(ULONG_MAX - 1);
EXPECT_EQ(skipped.length(), 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, UnderflowProtection_RemoveMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
ULONG removed = buffer.RemoveCompletedBuffer(ULONG_MAX - 1);
EXPECT_EQ(removed, 3u);
}
TEST_F(ServerBufferTest, UnderflowProtection_GetByteOutOfRange_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBYTE(ULONG_MAX - 1), 0);
}
// ============================================
// 压缩策略测试
// ============================================
TEST_F(ServerBufferTest, Compaction_TriggersAtThreshold) {
// 写入足够数据
// m_ulMaxLength 会被页对齐到 ceil(10000/4096)*4096 = 12288
// 压缩阈值 = 12288 * 0.5 = 6144
WriteFillData(10000);
ULONG initialOffset = buffer.GetReadOffset();
EXPECT_EQ(initialOffset, 0u);
// 读取超过阈值的数据(需要 > 6144
std::vector<BYTE> temp(7000);
buffer.ReadBuffer(temp.data(), 7000);
// 压缩后偏移应该重置
EXPECT_EQ(buffer.GetReadOffset(), 0u);
EXPECT_EQ(buffer.GetBufferLength(), 3000u);
}
// ============================================
// 线程安全测试
// ============================================
TEST_F(ServerBufferTest, ThreadSafety_ConcurrentReadWrite_NoDataCorruption) {
std::atomic<bool> running{true};
std::atomic<int> writeCount{0};
std::atomic<int> readCount{0};
// 写线程
std::thread writer([&]() {
BYTE data[100];
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)i;
}
while (running) {
if (buffer.WriteBuffer(data, 100)) {
writeCount++;
}
std::this_thread::yield();
}
});
// 读线程
std::thread reader([&]() {
BYTE result[50];
while (running) {
if (buffer.ReadBuffer(result, 50) > 0) {
readCount++;
}
std::this_thread::yield();
}
});
// 运行一段时间
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
writer.join();
reader.join();
// 验证无崩溃,且有数据交换
EXPECT_GT(writeCount.load(), 0);
EXPECT_GT(readCount.load(), 0);
}
TEST_F(ServerBufferTest, ThreadSafety_MultipleReaders_NoDeadlock) {
// 预填充数据
WriteFillData(10000);
std::atomic<bool> running{true};
std::vector<std::thread> readers;
for (int i = 0; i < 4; i++) {
readers.emplace_back([&]() {
while (running) {
buffer.GetBufferLength();
buffer.GetBYTE(0);
buffer.GetBuffer(0);
std::this_thread::yield();
}
});
}
std::this_thread::sleep_for(std::chrono::milliseconds(50));
running = false;
for (auto& t : readers) {
t.join();
}
// 无死锁即为成功
SUCCEED();
}
// ============================================
// 数据完整性测试
// ============================================
TEST_F(ServerBufferTest, DataIntegrity_WriteReadCycle_PreservesData) {
std::vector<BYTE> data(256);
for (int i = 0; i < 256; i++) {
data[i] = (BYTE)i;
}
buffer.WriteBuffer(data.data(), 256);
std::vector<BYTE> result(256);
ULONG bytesRead = buffer.ReadBuffer(result.data(), 256);
EXPECT_EQ(bytesRead, 256u);
for (int i = 0; i < 256; i++) {
EXPECT_EQ(result[i], data[i]) << "Mismatch at index " << i;
}
}
TEST_F(ServerBufferTest, DataIntegrity_PartialReads_PreservesSequence) {
// 写入 1-100
std::vector<BYTE> data(100);
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)(i + 1);
}
buffer.WriteBuffer(data.data(), 100);
// 分多次读取
BYTE result[100];
ULONG totalRead = 0;
totalRead += buffer.ReadBuffer(result, 30);
totalRead += buffer.ReadBuffer(result + 30, 30);
totalRead += buffer.ReadBuffer(result + 60, 40);
EXPECT_EQ(totalRead, 100u);
for (int i = 0; i < 100; i++) {
EXPECT_EQ(result[i], (BYTE)(i + 1));
}
}
// ============================================
// 参数化测试
// ============================================
class ServerBufferParameterizedTest
: public ::testing::TestWithParam<std::tuple<size_t, size_t, size_t>> {
protected:
CBuffer buffer;
};
TEST_P(ServerBufferParameterizedTest, ReadBuffer_VariousLengths) {
auto [writeLen, readLen, expectedRead] = GetParam();
std::vector<BYTE> data(writeLen, 0x42);
if (writeLen > 0) {
buffer.WriteBuffer(data.data(), (ULONG)writeLen);
}
std::vector<BYTE> result(readLen > 0 ? readLen : 1);
ULONG actual = buffer.ReadBuffer(result.data(), (ULONG)readLen);
EXPECT_EQ(actual, expectedRead);
}
INSTANTIATE_TEST_SUITE_P(
ReadLengths,
ServerBufferParameterizedTest,
::testing::Values(
std::make_tuple(10, 5, 5),
std::make_tuple(5, 10, 5),
std::make_tuple(0, 5, 0),
std::make_tuple(100, 0, 0),
std::make_tuple(10000, 5000, 5000)
)
);