Init: Migrate SimpleRemoter (Since v1.3.1) to Gitea

This commit is contained in:
yuanyuanxiang
2026-04-19 19:55:01 +02:00
commit 5a325a202b
744 changed files with 235562 additions and 0 deletions

395
test/CMakeLists.txt Normal file
View File

@@ -0,0 +1,395 @@
# SimpleRemoter 测试构建配置
cmake_minimum_required(VERSION 3.14)
project(SimpleRemoterTests)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# MSVC 编码设置
if(MSVC)
add_compile_options(/utf-8)
endif()
# 选项:是否构建测试
option(BUILD_TESTING "Build tests" ON)
if(NOT BUILD_TESTING)
return()
endif()
# 使用 FetchContent 获取 Google Test
include(FetchContent)
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG v1.14.0
)
# Windows 下避免覆盖 /MT /MD 设置
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
enable_testing()
# 包含项目头文件目录
include_directories(${CMAKE_SOURCE_DIR}/..)
include_directories(${CMAKE_SOURCE_DIR}/../client)
include_directories(${CMAKE_SOURCE_DIR}/../server/2015Remote)
include_directories(${CMAKE_SOURCE_DIR}/../common)
# ============================================
# 客户端 Buffer 测试
# ============================================
add_executable(client_buffer_test
unit/client/BufferTest.cpp
)
# 测试文件内联包含了 Buffer 实现的测试版本
# 无需链接原始源文件,支持跨平台测试
target_link_libraries(client_buffer_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(client_buffer_test PRIVATE _WIN32)
endif()
# ============================================
# 服务端 Buffer 测试
# ============================================
add_executable(server_buffer_test
unit/server/BufferTest.cpp
)
# 测试文件内联包含了 Buffer 实现的测试版本
# 支持跨平台测试
target_link_libraries(server_buffer_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(server_buffer_test PRIVATE _WIN32)
endif()
# Linux 需要 pthread
if(UNIX)
find_package(Threads REQUIRED)
target_link_libraries(client_buffer_test Threads::Threads)
target_link_libraries(server_buffer_test Threads::Threads)
endif()
# ============================================
# 协议测试 (Phase 1)
# ============================================
add_executable(protocol_test
unit/protocol/PacketTest.cpp
unit/protocol/PathUtilsTest.cpp
)
target_link_libraries(protocol_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(protocol_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(protocol_test Threads::Threads)
endif()
# ============================================
# 文件传输测试 (Phase 2)
# ============================================
add_executable(file_transfer_test
unit/file/FileTransferV2Test.cpp
)
target_link_libraries(file_transfer_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(file_transfer_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(file_transfer_test Threads::Threads)
endif()
# ============================================
# 分块管理测试 (Phase 2)
# ============================================
add_executable(chunk_manager_test
unit/file/ChunkManagerTest.cpp
)
target_link_libraries(chunk_manager_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(chunk_manager_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(chunk_manager_test Threads::Threads)
endif()
# ============================================
# SHA-256 校验测试 (Phase 2)
# ============================================
add_executable(sha256_verify_test
unit/file/SHA256VerifyTest.cpp
)
target_link_libraries(sha256_verify_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(sha256_verify_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(sha256_verify_test Threads::Threads)
endif()
# ============================================
# 断点续传状态测试 (Phase 2)
# ============================================
add_executable(resume_state_test
unit/file/ResumeStateTest.cpp
)
target_link_libraries(resume_state_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(resume_state_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(resume_state_test Threads::Threads)
endif()
# ============================================
# 协议头验证测试 (Phase 3)
# ============================================
add_executable(header_test
unit/network/HeaderTest.cpp
)
target_link_libraries(header_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(header_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(header_test Threads::Threads)
endif()
# ============================================
# 粘包/分包测试 (Phase 3)
# ============================================
add_executable(packet_fragment_test
unit/network/PacketFragmentTest.cpp
)
target_link_libraries(packet_fragment_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(packet_fragment_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(packet_fragment_test Threads::Threads)
endif()
# ============================================
# HTTP 伪装测试 (Phase 3)
# ============================================
add_executable(http_mask_test
unit/network/HttpMaskTest.cpp
)
target_link_libraries(http_mask_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(http_mask_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(http_mask_test Threads::Threads)
endif()
# ============================================
# 差分算法测试 (Phase 4)
# ============================================
add_executable(diff_algorithm_test
unit/screen/DiffAlgorithmTest.cpp
)
target_link_libraries(diff_algorithm_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(diff_algorithm_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(diff_algorithm_test Threads::Threads)
endif()
# ============================================
# RGB565 压缩测试 (Phase 4)
# ============================================
add_executable(rgb565_test
unit/screen/RGB565Test.cpp
)
target_link_libraries(rgb565_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(rgb565_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(rgb565_test Threads::Threads)
endif()
# ============================================
# 滚动检测测试 (Phase 4)
# ============================================
add_executable(scroll_detector_test
unit/screen/ScrollDetectorTest.cpp
)
target_link_libraries(scroll_detector_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(scroll_detector_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(scroll_detector_test Threads::Threads)
endif()
# ============================================
# 质量自适应测试 (Phase 4)
# ============================================
add_executable(quality_adaptive_test
unit/screen/QualityAdaptiveTest.cpp
)
target_link_libraries(quality_adaptive_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(quality_adaptive_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(quality_adaptive_test Threads::Threads)
endif()
# ============================================
# 时间验证测试 (授权模块)
# ============================================
add_executable(date_verify_test
unit/auth/DateVerifyTest.cpp
)
target_link_libraries(date_verify_test
GTest::gtest_main
)
if(WIN32)
target_compile_definitions(date_verify_test PRIVATE _WIN32)
endif()
if(UNIX)
target_link_libraries(date_verify_test Threads::Threads)
endif()
# ============================================
# IP地理位置测试 (网络API)
# ============================================
if(WIN32)
add_executable(geolocation_test
unit/network/GeoLocationTest.cpp
)
target_link_libraries(geolocation_test
GTest::gtest_main
wininet
ws2_32
)
target_compile_definitions(geolocation_test PRIVATE _WIN32)
endif()
# ============================================
# 注册表配置测试 (iniFile/binFile)
# ============================================
if(WIN32)
add_executable(registry_config_test
unit/config/RegistryConfigTest.cpp
)
target_link_libraries(registry_config_test
GTest::gtest_main
advapi32
)
target_compile_definitions(registry_config_test PRIVATE _WIN32)
endif()
# 注册测试
include(GoogleTest)
if(WIN32)
gtest_discover_tests(geolocation_test)
gtest_discover_tests(registry_config_test)
endif()
gtest_discover_tests(client_buffer_test)
gtest_discover_tests(server_buffer_test)
gtest_discover_tests(protocol_test)
gtest_discover_tests(file_transfer_test)
gtest_discover_tests(chunk_manager_test)
gtest_discover_tests(sha256_verify_test)
gtest_discover_tests(resume_state_test)
gtest_discover_tests(header_test)
gtest_discover_tests(packet_fragment_test)
gtest_discover_tests(http_mask_test)
gtest_discover_tests(diff_algorithm_test)
gtest_discover_tests(rgb565_test)
gtest_discover_tests(scroll_detector_test)
gtest_discover_tests(quality_adaptive_test)
gtest_discover_tests(date_verify_test)
# 自定义目标:运行所有测试
add_custom_target(run_tests
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
DEPENDS client_buffer_test server_buffer_test protocol_test
file_transfer_test chunk_manager_test sha256_verify_test resume_state_test
header_test packet_fragment_test http_mask_test
diff_algorithm_test rgb565_test scroll_detector_test quality_adaptive_test
date_verify_test
)

559
test/IniParser_test.cpp Normal file
View File

@@ -0,0 +1,559 @@
// IniParser_test.cpp - CIniParser 单元测试
// 编译: cl /EHsc /W4 IniParser_test.cpp /Fe:IniParser_test.exe
// 运行: IniParser_test.exe
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include "../common/IniParser.h"
static int g_total = 0;
static int g_passed = 0;
static int g_failed = 0;
#define TEST_ASSERT(expr, msg) do { \
g_total++; \
if (expr) { g_passed++; } \
else { g_failed++; printf(" FAIL: %s\n %s:%d\n", msg, __FILE__, __LINE__); } \
} while(0)
#define TEST_STR_EQ(actual, expected, msg) do { \
g_total++; \
if (std::string(actual) == std::string(expected)) { g_passed++; } \
else { g_failed++; printf(" FAIL: %s\n expected: \"%s\"\n actual: \"%s\"\n %s:%d\n", \
msg, expected, actual, __FILE__, __LINE__); } \
} while(0)
// 辅助:写入临时文件
static std::string WriteTempFile(const char* name, const char* content)
{
std::string path = std::string("_test_") + name + ".ini";
FILE* f = nullptr;
#ifdef _MSC_VER
fopen_s(&f, path.c_str(), "w");
#else
f = fopen(path.c_str(), "w");
#endif
if (f) {
fputs(content, f);
fclose(f);
}
return path;
}
static void CleanupFile(const std::string& path)
{
remove(path.c_str());
}
// ============================================
// Test 1: 基本 key=value 解析
// ============================================
void Test_BasicKeyValue()
{
printf("[Test 1] Basic key=value parsing\n");
std::string path = WriteTempFile("basic",
"[Strings]\n"
"hello=world\n"
"foo=bar\n"
);
CIniParser ini;
TEST_ASSERT(ini.LoadFile(path.c_str()), "LoadFile should succeed");
TEST_STR_EQ(ini.GetValue("Strings", "hello"), "world", "hello -> world");
TEST_STR_EQ(ini.GetValue("Strings", "foo"), "bar", "foo -> bar");
TEST_ASSERT(ini.GetSectionSize("Strings") == 2, "Section size should be 2");
CleanupFile(path);
}
// ============================================
// Test 2: key 尾部空格保留(核心特性)
// ============================================
void Test_KeyTrailingSpace()
{
printf("[Test 2] Key trailing space preserved\n");
// 模拟: "请输入主机备注: =Enter host note:"
// key 是 "请输入主机备注: "(冒号+空格),不能被 trim
std::string path = WriteTempFile("trailing_space",
"[Strings]\n"
"key_no_space=value1\n"
"key_with_space =value2\n"
"key_with_2spaces =value3\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "key_no_space"), "value1",
"key without trailing space");
TEST_STR_EQ(ini.GetValue("Strings", "key_with_space "), "value2",
"key with 1 trailing space (must preserve)");
TEST_STR_EQ(ini.GetValue("Strings", "key_with_2spaces "), "value3",
"key with 2 trailing spaces (must preserve)");
// 不带空格的查找应该找不到
TEST_STR_EQ(ini.GetValue("Strings", "key_with_space", "NOT_FOUND"), "NOT_FOUND",
"key without trailing space should NOT match");
CleanupFile(path);
}
// ============================================
// Test 3: value 中含特殊字符
// ============================================
void Test_SpecialCharsInValue()
{
printf("[Test 3] Special characters in value\n");
std::string path = WriteTempFile("special_chars",
"[Strings]\n"
"menu=Menu(&F)\n"
"addr=<IP:PORT>\n"
"fmt=%s connected %d times\n"
"paren=(auto-restore on expiry)\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "menu"), "Menu(&F)", "value with (&F)");
TEST_STR_EQ(ini.GetValue("Strings", "addr"), "<IP:PORT>", "value with <IP:PORT>");
TEST_STR_EQ(ini.GetValue("Strings", "fmt"), "%s connected %d times", "value with %s %d");
TEST_STR_EQ(ini.GetValue("Strings", "paren"), "(auto-restore on expiry)", "value with parens");
CleanupFile(path);
}
// ============================================
// Test 4: 注释行跳过
// ============================================
void Test_Comments()
{
printf("[Test 4] Comment lines skipped\n");
std::string path = WriteTempFile("comments",
"; This is a comment\n"
"# This is also a comment\n"
"[Strings]\n"
"; ============================================\n"
"# Section header comment\n"
"key1=value1\n"
"; key2=should_not_exist\n"
"key3=value3\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "key1"), "value1", "key1 exists");
TEST_STR_EQ(ini.GetValue("Strings", "key3"), "value3", "key3 exists");
TEST_STR_EQ(ini.GetValue("Strings", "key2", "NOT_FOUND"), "NOT_FOUND",
"commented key2 should not exist");
TEST_ASSERT(ini.GetSectionSize("Strings") == 2, "Only 2 keys (comments excluded)");
CleanupFile(path);
}
// ============================================
// Test 5: 空行跳过
// ============================================
void Test_EmptyLines()
{
printf("[Test 5] Empty lines skipped\n");
std::string path = WriteTempFile("empty_lines",
"\n"
"\n"
"[Strings]\n"
"\n"
"key1=value1\n"
"\n"
"\n"
"key2=value2\n"
"\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_ASSERT(ini.GetSectionSize("Strings") == 2, "2 keys despite empty lines");
TEST_STR_EQ(ini.GetValue("Strings", "key1"), "value1", "key1");
TEST_STR_EQ(ini.GetValue("Strings", "key2"), "value2", "key2");
CleanupFile(path);
}
// ============================================
// Test 6: section 切换
// ============================================
void Test_MultipleSections()
{
printf("[Test 6] Multiple sections\n");
std::string path = WriteTempFile("sections",
"[Strings]\n"
"key1=value1\n"
"key2=value2\n"
"[Other]\n"
"key1=other_value1\n"
"key3=other_value3\n"
"[Strings2]\n"
"keyA=valueA\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "key1"), "value1", "Strings.key1");
TEST_STR_EQ(ini.GetValue("Strings", "key2"), "value2", "Strings.key2");
TEST_STR_EQ(ini.GetValue("Other", "key1"), "other_value1", "Other.key1");
TEST_STR_EQ(ini.GetValue("Other", "key3"), "other_value3", "Other.key3");
TEST_STR_EQ(ini.GetValue("Strings2", "keyA"), "valueA", "Strings2.keyA");
// Strings section should not contain Other section's keys
TEST_STR_EQ(ini.GetValue("Strings", "key3", "NOT_FOUND"), "NOT_FOUND",
"Strings should not have Other's key3");
TEST_ASSERT(ini.GetSectionSize("Strings") == 2, "Strings has 2 keys");
TEST_ASSERT(ini.GetSectionSize("Other") == 2, "Other has 2 keys");
TEST_ASSERT(ini.GetSectionSize("Strings2") == 1, "Strings2 has 1 key");
CleanupFile(path);
}
// ============================================
// Test 7: 大文件(超过 32KB
// ============================================
void Test_LargeFile()
{
printf("[Test 7] Large file (>32KB)\n");
std::string path = std::string("_test_large.ini");
FILE* f = nullptr;
#ifdef _MSC_VER
fopen_s(&f, path.c_str(), "w");
#else
f = fopen(path.c_str(), "w");
#endif
if (!f) {
printf(" SKIP: Cannot create temp file\n");
return;
}
fputs("[Strings]\n", f);
// 写入大量条目使文件超过 32KB
const int entryCount = 2000;
for (int i = 0; i < entryCount; i++) {
fprintf(f, "key_%04d=value_for_entry_number_%04d_padding_text_here\n", i, i);
}
// 在文件末尾写一个特殊条目
fputs("last_key=last_value\n", f);
fclose(f);
CIniParser ini;
TEST_ASSERT(ini.LoadFile(path.c_str()), "LoadFile should succeed for large file");
// 验证首尾和中间的条目
TEST_STR_EQ(ini.GetValue("Strings", "key_0000"),
"value_for_entry_number_0000_padding_text_here",
"First entry");
TEST_STR_EQ(ini.GetValue("Strings", "key_0999"),
"value_for_entry_number_0999_padding_text_here",
"Middle entry");
TEST_STR_EQ(ini.GetValue("Strings", "key_1999"),
"value_for_entry_number_1999_padding_text_here",
"Last numbered entry");
TEST_STR_EQ(ini.GetValue("Strings", "last_key"), "last_value",
"Entry at very end of large file");
size_t size = ini.GetSectionSize("Strings");
TEST_ASSERT(size == entryCount + 1,
"Section size should be entryCount + 1 (last_key)");
printf(" File has %d entries, all readable\n", (int)size);
CleanupFile(path);
}
// ============================================
// Test 8: 文件不存在
// ============================================
void Test_FileNotExist()
{
printf("[Test 8] File not exist\n");
CIniParser ini;
TEST_ASSERT(!ini.LoadFile("_nonexistent_file_12345.ini"), "LoadFile should return false");
TEST_ASSERT(!ini.LoadFile(nullptr), "LoadFile(nullptr) should return false");
TEST_ASSERT(!ini.LoadFile(""), "LoadFile('') should return false");
TEST_ASSERT(ini.GetSection("Strings") == nullptr, "No sections after failed load");
}
// ============================================
// Test 9: 空文件
// ============================================
void Test_EmptyFile()
{
printf("[Test 9] Empty file\n");
std::string path = WriteTempFile("empty", "");
CIniParser ini;
TEST_ASSERT(ini.LoadFile(path.c_str()), "LoadFile should succeed for empty file");
TEST_ASSERT(ini.GetSection("Strings") == nullptr, "No Strings section in empty file");
TEST_ASSERT(ini.GetSectionSize("Strings") == 0, "Section size is 0");
CleanupFile(path);
}
// ============================================
// Test 10: value 中含 '='(只按第一个 '=' 分割)
// ============================================
void Test_EqualsInValue()
{
printf("[Test 10] Equals sign in value\n");
std::string path = WriteTempFile("equals",
"[Strings]\n"
"formula=a=b+c\n"
"equation=x=1=2=3\n"
"normal=hello\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "formula"), "a=b+c",
"value with one '=' should keep it");
TEST_STR_EQ(ini.GetValue("Strings", "equation"), "x=1=2=3",
"value with multiple '=' should keep all");
TEST_STR_EQ(ini.GetValue("Strings", "normal"), "hello",
"normal value unaffected");
CleanupFile(path);
}
// ============================================
// Test 11: key 中含 \r\n 转义序列
// ============================================
void Test_EscapeCRLF_InKey()
{
printf("[Test 11] Escape \\r\\n in key\n");
// INI 文件中写字面量 \r\n解析器应转为真正的 0x0D 0x0A
// 模拟代码中: _TR("\n编译日期: ") 和 _TR("操作失败\r\n请重试")
std::string path = WriteTempFile("escape_key",
"[Strings]\n"
"\\n compile date: =\\n Build Date: \n"
"fail\\r\\nretry=Fail\\r\\nRetry\n"
"line1\\nline2\\nline3=L1\\nL2\\nL3\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
// key "\n compile date: " (真正的换行 + 文本)
TEST_STR_EQ(ini.GetValue("Strings", "\n compile date: "), "\n Build Date: ",
"key with \\n at start");
// key "fail\r\nretry" (真正的 CR+LF)
TEST_STR_EQ(ini.GetValue("Strings", "fail\r\nretry"), "Fail\r\nRetry",
"key with \\r\\n in middle");
// key 含多个 \n
TEST_STR_EQ(ini.GetValue("Strings", "line1\nline2\nline3"), "L1\nL2\nL3",
"key with multiple \\n");
CleanupFile(path);
}
// ============================================
// Test 12: value 中含 \r\n 转义序列
// ============================================
void Test_EscapeCRLF_InValue()
{
printf("[Test 12] Escape \\r\\n in value\n");
std::string path = WriteTempFile("escape_value",
"[Strings]\n"
"msg=hello\\r\\nworld\n"
"multiline=line1\\nline2\\nline3\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "msg"), "hello\r\nworld",
"value with \\r\\n");
TEST_STR_EQ(ini.GetValue("Strings", "multiline"), "line1\nline2\nline3",
"value with multiple \\n");
CleanupFile(path);
}
// ============================================
// Test 13: \\ 和 \" 转义
// ============================================
void Test_EscapeBackslashAndQuote()
{
printf("[Test 13] Escape \\\\ and \\\" sequences\n");
std::string path = WriteTempFile("escape_bsq",
"[Strings]\n"
"path=C:\\\\Users\\\\test\n"
"quoted=say \\\"hello\\\"\n"
"mixed=\\\"line1\\n line2\\\"\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "path"), "C:\\Users\\test",
"double backslash -> single backslash");
TEST_STR_EQ(ini.GetValue("Strings", "quoted"), "say \"hello\"",
"escaped quotes");
TEST_STR_EQ(ini.GetValue("Strings", "mixed"), "\"line1\n line2\"",
"mixed \\\" and \\n");
CleanupFile(path);
}
// ============================================
// Test 14: \t 转义
// ============================================
void Test_EscapeTab()
{
printf("[Test 14] Escape \\t sequence\n");
std::string path = WriteTempFile("escape_tab",
"[Strings]\n"
"col=name\\tvalue\n"
"header=ID\\tName\\tStatus\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
TEST_STR_EQ(ini.GetValue("Strings", "col"), "name\tvalue",
"\\t -> tab");
TEST_STR_EQ(ini.GetValue("Strings", "header"), "ID\tName\tStatus",
"multiple \\t");
CleanupFile(path);
}
// ============================================
// Test 15: 未知转义保留原样
// ============================================
void Test_UnknownEscapePassthrough()
{
printf("[Test 15] Unknown escape passthrough\n");
std::string path = WriteTempFile("escape_unknown",
"[Strings]\n"
"unknown=hello\\xworld\n"
"trailing_bs=end\\\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
// \x 不是已知转义,应保留反斜杠
TEST_STR_EQ(ini.GetValue("Strings", "unknown"), "hello\\xworld",
"unknown \\x keeps backslash");
// 行尾的孤立反斜杠fgets 去掉换行后,最后一个字符是 \
TEST_STR_EQ(ini.GetValue("Strings", "trailing_bs"), "end\\",
"trailing backslash preserved");
CleanupFile(path);
}
// ============================================
// Test 16: key 中转义与尾部空格组合
// ============================================
void Test_EscapeWithTrailingSpace()
{
printf("[Test 16] Escape + trailing space in key\n");
// 模拟: _TR("\n编译日期: ") — key 以 \n 开头,以冒号+空格结尾
std::string path = WriteTempFile("escape_trail",
"[Strings]\n"
"\\n date: =\\n Date: \n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
// key 是 "\n date: "(真正换行 + 文本 + 尾部空格)
TEST_STR_EQ(ini.GetValue("Strings", "\n date: "), "\n Date: ",
"escape \\n + trailing space in key");
// 不带尾部空格应找不到
TEST_STR_EQ(ini.GetValue("Strings", "\n date:", "NOT_FOUND"), "NOT_FOUND",
"without trailing space should not match");
CleanupFile(path);
}
// ============================================
// Test 17: key 以 '[' 开头(不是 section 头)
// ============================================
void Test_BracketKey()
{
printf("[Test 17] Key starting with '[' (not a section header)\n");
// 模拟: _TR("[使用FRP]") 和 _TR("[未使用FRP]")
std::string path = WriteTempFile("bracket_key",
"[Strings]\n"
"normal=value1\n"
"[tag1]=[Tag One]\n"
"[tag2]=[Tag Two]\n"
"after=value2\n"
);
CIniParser ini;
ini.LoadFile(path.c_str());
// [tag1]=[Tag One] 应该是 key=value不是 section 头
TEST_STR_EQ(ini.GetValue("Strings", "[tag1]"), "[Tag One]",
"[tag1] parsed as key, not section");
TEST_STR_EQ(ini.GetValue("Strings", "[tag2]"), "[Tag Two]",
"[tag2] parsed as key, not section");
// 前后的普通 key 应仍在 Strings section
TEST_STR_EQ(ini.GetValue("Strings", "normal"), "value1",
"normal key before bracket keys");
TEST_STR_EQ(ini.GetValue("Strings", "after"), "value2",
"normal key after bracket keys still in Strings");
TEST_ASSERT(ini.GetSectionSize("Strings") == 4, "Strings has 4 keys");
// 不应该有 tag1 或 tag2 section
TEST_ASSERT(ini.GetSection("tag1") == nullptr, "no tag1 section");
TEST_ASSERT(ini.GetSection("tag2") == nullptr, "no tag2 section");
CleanupFile(path);
}
// ============================================
// main
// ============================================
int main()
{
printf("=== CIniParser Tests ===\n\n");
Test_BasicKeyValue();
Test_KeyTrailingSpace();
Test_SpecialCharsInValue();
Test_Comments();
Test_EmptyLines();
Test_MultipleSections();
Test_LargeFile();
Test_FileNotExist();
Test_EmptyFile();
Test_EqualsInValue();
Test_EscapeCRLF_InKey();
Test_EscapeCRLF_InValue();
Test_EscapeBackslashAndQuote();
Test_EscapeTab();
Test_UnknownEscapePassthrough();
Test_EscapeWithTrailingSpace();
Test_BracketKey();
printf("\n=== Results: %d/%d passed", g_passed, g_total);
if (g_failed > 0)
printf(", %d FAILED", g_failed);
printf(" ===\n");
return g_failed > 0 ? 1 : 0;
}

374
test/TestCompareBitmap.cpp Normal file
View File

@@ -0,0 +1,374 @@
// Image Diff Algorithm Benchmark
// Compile: cl /O2 /EHsc TestCompareBitmap.cpp
// Or: g++ -O2 -msse2 -o TestCompareBitmap.exe TestCompareBitmap.cpp
#include <windows.h>
#include <stdio.h>
#include <stdlib.h>
#include <emmintrin.h> // SSE2
#include <chrono>
typedef unsigned char BYTE;
typedef BYTE* LPBYTE;
typedef unsigned long ULONG;
typedef ULONG* LPDWORD;
#define ALGORITHM_DIFF 0
#define ALGORITHM_GRAY 1
//============================== Gray Conversion ==============================
inline void ConvertToGray_Original(LPBYTE dst, LPBYTE src, ULONG count)
{
for (ULONG i = 0; i < count; i += 4, src += 4, dst++) {
*dst = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
}
}
inline void ConvertToGray_SSE2(LPBYTE dst, LPBYTE src, ULONG count)
{
ULONG pixels = count / 4;
ULONG i = 0;
ULONG aligned = pixels & ~3;
for (; i < aligned; i += 4, src += 16, dst += 4) {
dst[0] = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
dst[1] = (306 * src[6] + 601 * src[4] + 117 * src[5]) >> 10;
dst[2] = (306 * src[10] + 601 * src[8] + 117 * src[9]) >> 10;
dst[3] = (306 * src[14] + 601 * src[12] + 117 * src[13]) >> 10;
}
for (; i < pixels; i++, src += 4, dst++) {
*dst = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
}
}
void ToGray_Original(LPBYTE dst, LPBYTE src, int biSizeImage)
{
for (ULONG i = 0; i < (ULONG)biSizeImage; i += 4, dst += 4, src += 4) {
dst[0] = dst[1] = dst[2] = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
}
}
void ToGray_SSE2(LPBYTE dst, LPBYTE src, int biSizeImage)
{
ULONG pixels = biSizeImage / 4;
for (ULONG i = 0; i < pixels; i++, src += 4, dst += 4) {
BYTE g = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
dst[0] = dst[1] = dst[2] = g;
dst[3] = 0xFF;
}
}
//============================== Original Version ==============================
ULONG CompareBitmap_Original(LPBYTE CompareSourData, LPBYTE CompareDestData, LPBYTE szBuffer,
DWORD ulCompareLength, BYTE algo, int startPostion = 0)
{
LPDWORD p1 = (LPDWORD)CompareDestData, p2 = (LPDWORD)CompareSourData;
LPBYTE p = szBuffer;
ULONG channel = algo == ALGORITHM_GRAY ? 1 : 4;
ULONG ratio = algo == ALGORITHM_GRAY ? 4 : 1;
for (ULONG i = 0; i < ulCompareLength; i += 4, ++p1, ++p2) {
if (*p1 == *p2)
continue;
ULONG index = i;
LPDWORD pos1 = p1++, pos2 = p2++;
for (i += 4; i < ulCompareLength && *p1 != *p2; i += 4, ++p1, ++p2);
ULONG ulCount = i - index;
memcpy(pos1, pos2, ulCount);
*(LPDWORD)(p) = index + startPostion;
*(LPDWORD)(p + sizeof(ULONG)) = ulCount / ratio;
p += 2 * sizeof(ULONG);
if (channel != 1) {
memcpy(p, pos2, ulCount);
p += ulCount;
} else {
for (LPBYTE end = p + ulCount / ratio; p < end; ++p, ++pos2) {
LPBYTE src = (LPBYTE)pos2;
*p = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
}
}
}
return (ULONG)(p - szBuffer);
}
//============================== SSE2 Version ==============================
ULONG CompareBitmap_SSE2(LPBYTE CompareSourData, LPBYTE CompareDestData, LPBYTE szBuffer,
DWORD ulCompareLength, BYTE algo, int startPostion = 0)
{
LPBYTE p = szBuffer;
ULONG channel = algo == ALGORITHM_GRAY ? 1 : 4;
ULONG ratio = algo == ALGORITHM_GRAY ? 4 : 1;
const ULONG SSE_BLOCK = 16;
const ULONG alignedLength = ulCompareLength & ~(SSE_BLOCK - 1);
__m128i* v1 = (__m128i*)CompareDestData;
__m128i* v2 = (__m128i*)CompareSourData;
ULONG i = 0;
while (i < alignedLength) {
__m128i cmp = _mm_cmpeq_epi32(*v1, *v2);
int mask = _mm_movemask_epi8(cmp);
if (mask == 0xFFFF) {
i += SSE_BLOCK;
++v1;
++v2;
continue;
}
ULONG index = i;
LPBYTE pos1 = (LPBYTE)v1;
LPBYTE pos2 = (LPBYTE)v2;
do {
i += SSE_BLOCK;
++v1;
++v2;
if (i >= alignedLength) break;
cmp = _mm_cmpeq_epi32(*v1, *v2);
mask = _mm_movemask_epi8(cmp);
} while (mask != 0xFFFF);
ULONG ulCount = i - index;
memcpy(pos1, pos2, ulCount);
*(LPDWORD)(p) = index + startPostion;
*(LPDWORD)(p + sizeof(ULONG)) = ulCount / ratio;
p += 2 * sizeof(ULONG);
if (channel != 1) {
memcpy(p, pos2, ulCount);
p += ulCount;
} else {
ConvertToGray_SSE2(p, pos2, ulCount);
p += ulCount / ratio;
}
}
// Handle remaining bytes
if (i < ulCompareLength) {
LPDWORD p1 = (LPDWORD)((LPBYTE)CompareDestData + i);
LPDWORD p2 = (LPDWORD)((LPBYTE)CompareSourData + i);
for (; i < ulCompareLength; i += 4, ++p1, ++p2) {
if (*p1 == *p2)
continue;
ULONG index = i;
LPDWORD pos1 = p1++;
LPDWORD pos2 = p2++;
for (i += 4; i < ulCompareLength && *p1 != *p2; i += 4, ++p1, ++p2);
ULONG ulCount = i - index;
memcpy(pos1, pos2, ulCount);
*(LPDWORD)(p) = index + startPostion;
*(LPDWORD)(p + sizeof(ULONG)) = ulCount / ratio;
p += 2 * sizeof(ULONG);
if (channel != 1) {
memcpy(p, pos2, ulCount);
p += ulCount;
} else {
LPDWORD srcPtr = pos2;
for (LPBYTE end = p + ulCount / ratio; p < end; ++p, ++srcPtr) {
LPBYTE src = (LPBYTE)srcPtr;
*p = (306 * src[2] + 601 * src[0] + 117 * src[1]) >> 10;
}
}
}
}
return (ULONG)(p - szBuffer);
}
//============================== Benchmark ==============================
void RunBenchmark(int width, int height, float diffRatio, int iterations, BYTE algo = ALGORITHM_DIFF)
{
ULONG dataSize = width * height * 4;
LPBYTE srcBuffer = (LPBYTE)_aligned_malloc(dataSize, 16);
LPBYTE dstBuffer = (LPBYTE)_aligned_malloc(dataSize, 16);
LPBYTE outBuffer1 = (LPBYTE)_aligned_malloc(dataSize * 2, 16);
LPBYTE outBuffer2 = (LPBYTE)_aligned_malloc(dataSize * 2, 16);
if (!srcBuffer || !dstBuffer || !outBuffer1 || !outBuffer2) {
printf("Memory allocation failed!\n");
return;
}
srand(12345);
for (ULONG i = 0; i < dataSize; i++) {
srcBuffer[i] = rand() % 256;
dstBuffer[i] = srcBuffer[i];
}
int diffPixels = (int)(width * height * diffRatio);
for (int i = 0; i < diffPixels; i++) {
int pos = (rand() % (width * height)) * 4;
srcBuffer[pos] = rand() % 256;
srcBuffer[pos + 1] = rand() % 256;
srcBuffer[pos + 2] = rand() % 256;
}
printf("\n========== Test Parameters ==========\n");
printf("Resolution: %d x %d\n", width, height);
printf("Data size: %.2f MB\n", dataSize / 1024.0 / 1024.0);
printf("Diff ratio: %.1f%%\n", diffRatio * 100);
printf("Algorithm: %s\n", algo == ALGORITHM_GRAY ? "Gray" : "Color");
printf("Iterations: %d\n", iterations);
printf("======================================\n\n");
// Test original version
LPBYTE testDst1 = (LPBYTE)_aligned_malloc(dataSize, 16);
memcpy(testDst1, dstBuffer, dataSize);
auto start1 = std::chrono::high_resolution_clock::now();
ULONG result1 = 0;
for (int i = 0; i < iterations; i++) {
memcpy(testDst1, dstBuffer, dataSize);
result1 = CompareBitmap_Original(srcBuffer, testDst1, outBuffer1, dataSize, algo);
}
auto end1 = std::chrono::high_resolution_clock::now();
double time1 = std::chrono::duration<double, std::milli>(end1 - start1).count();
// Test SSE2 version
LPBYTE testDst2 = (LPBYTE)_aligned_malloc(dataSize, 16);
memcpy(testDst2, dstBuffer, dataSize);
auto start2 = std::chrono::high_resolution_clock::now();
ULONG result2 = 0;
for (int i = 0; i < iterations; i++) {
memcpy(testDst2, dstBuffer, dataSize);
result2 = CompareBitmap_SSE2(srcBuffer, testDst2, outBuffer2, dataSize, algo);
}
auto end2 = std::chrono::high_resolution_clock::now();
double time2 = std::chrono::duration<double, std::milli>(end2 - start2).count();
printf("Original:\n");
printf(" Total: %.2f ms\n", time1);
printf(" Per frame: %.3f ms\n", time1 / iterations);
printf(" Output size: %lu bytes\n\n", result1);
printf("SSE2:\n");
printf(" Total: %.2f ms\n", time2);
printf(" Per frame: %.3f ms\n", time2 / iterations);
printf(" Output size: %lu bytes\n\n", result2);
printf("========== Performance ==========\n");
printf("Speedup: %.2fx\n", time1 / time2);
printf("Time saved: %.1f%%\n", (1.0 - time2 / time1) * 100);
if (result1 == result2 && memcmp(outBuffer1, outBuffer2, result1) == 0) {
printf("Verify: PASS\n");
} else {
printf("Verify: DIFF (size: %lu vs %lu)\n", result1, result2);
}
printf("=================================\n");
_aligned_free(srcBuffer);
_aligned_free(dstBuffer);
_aligned_free(outBuffer1);
_aligned_free(outBuffer2);
_aligned_free(testDst1);
_aligned_free(testDst2);
}
//============================== Gray Convert Benchmark ==============================
void RunGrayConvertBenchmark(int width, int height, int iterations)
{
ULONG dataSize = width * height * 4;
ULONG graySize = width * height;
LPBYTE srcBuffer = (LPBYTE)_aligned_malloc(dataSize, 16);
LPBYTE dstBuffer1 = (LPBYTE)_aligned_malloc(graySize, 16);
LPBYTE dstBuffer2 = (LPBYTE)_aligned_malloc(graySize, 16);
if (!srcBuffer || !dstBuffer1 || !dstBuffer2) {
printf("Memory allocation failed!\n");
return;
}
srand(12345);
for (ULONG i = 0; i < dataSize; i++) {
srcBuffer[i] = rand() % 256;
}
printf("\n========== BGRA->Gray Test ==========\n");
printf("Resolution: %d x %d\n", width, height);
printf("Input: %.2f MB, Output: %.2f MB\n", dataSize / 1024.0 / 1024.0, graySize / 1024.0 / 1024.0);
printf("Iterations: %d\n", iterations);
printf("=====================================\n\n");
// Test original version
auto start1 = std::chrono::high_resolution_clock::now();
for (int i = 0; i < iterations; i++) {
ConvertToGray_Original(dstBuffer1, srcBuffer, dataSize);
}
auto end1 = std::chrono::high_resolution_clock::now();
double time1 = std::chrono::duration<double, std::milli>(end1 - start1).count();
// Test SSE2 version
auto start2 = std::chrono::high_resolution_clock::now();
for (int i = 0; i < iterations; i++) {
ConvertToGray_SSE2(dstBuffer2, srcBuffer, dataSize);
}
auto end2 = std::chrono::high_resolution_clock::now();
double time2 = std::chrono::duration<double, std::milli>(end2 - start2).count();
printf("Original (per-pixel):\n");
printf(" Total: %.2f ms, Per frame: %.3f ms\n", time1, time1 / iterations);
printf("\nSSE2 (4-pixel batch):\n");
printf(" Total: %.2f ms, Per frame: %.3f ms\n", time2, time2 / iterations);
printf("\n========== Performance ==========\n");
printf("Speedup: %.2fx\n", time1 / time2);
printf("Time saved: %.1f%%\n", (1.0 - time2 / time1) * 100);
bool match = memcmp(dstBuffer1, dstBuffer2, graySize) == 0;
printf("Verify: %s\n", match ? "PASS" : "FAIL");
printf("=================================\n");
_aligned_free(srcBuffer);
_aligned_free(dstBuffer1);
_aligned_free(dstBuffer2);
}
int main()
{
printf("===== Image Diff Algorithm Benchmark =====\n");
printf("\n\n########## Color Mode ##########\n");
printf("\n[1080p 10%% diff - Color]");
RunBenchmark(1920, 1080, 0.10f, 100, ALGORITHM_DIFF);
printf("\n[1080p 30%% diff - Color]");
RunBenchmark(1920, 1080, 0.30f, 100, ALGORITHM_DIFF);
printf("\n\n########## Gray Mode ##########\n");
printf("\n[1080p 10%% diff - Gray]");
RunBenchmark(1920, 1080, 0.10f, 100, ALGORITHM_GRAY);
printf("\n[1080p 30%% diff - Gray]");
RunBenchmark(1920, 1080, 0.30f, 100, ALGORITHM_GRAY);
printf("\n\n########## BGRA->Gray Conversion ##########\n");
printf("\n[1080p BGRA->Gray]");
RunGrayConvertBenchmark(1920, 1080, 100);
printf("\n[4K BGRA->Gray]");
RunGrayConvertBenchmark(3840, 2160, 50);
printf("\nDone!\n");
return 0;
}

237
test/test.bat Normal file
View File

@@ -0,0 +1,237 @@
@echo off
chcp 65001 >nul 2>&1
setlocal enabledelayedexpansion
:: SimpleRemoter Test Management Script
:: Usage: test.bat [build|run|clean|rebuild|help] [options]
::
:: Test Phases:
:: Phase 1: Buffer + Protocol
:: Phase 2: File Transfer
:: Phase 3: Network
:: Phase 4: Screen/Image
:: Phase 5: Auth + Config
set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
set "BUILD_DIR=%SCRIPT_DIR%\build"
set "CONFIG=Release"
set "ACTION=%~1"
set "OPTION=%~2"
if "%ACTION%"=="" goto help
if "%ACTION%"=="build" goto build
if "%ACTION%"=="run" goto run
if "%ACTION%"=="clean" goto clean
if "%ACTION%"=="rebuild" goto rebuild
if "%ACTION%"=="help" goto help
echo Error: Unknown action "%ACTION%"
echo.
goto help
:build
echo ========================================
echo Building Tests (17 executables)
echo ========================================
if not exist "%BUILD_DIR%\CMakeCache.txt" (
echo [1/2] Configuring CMake...
cmake -B "%BUILD_DIR%" -S "%SCRIPT_DIR%"
if errorlevel 1 (
echo Error: CMake configuration failed
exit /b 1
)
) else (
echo [1/2] CMake already configured, skipping...
)
echo [2/2] Compiling tests...
cmake --build "%BUILD_DIR%" --config %CONFIG%
if errorlevel 1 (
echo Error: Build failed
exit /b 1
)
echo.
echo Build successful! (17 test executables)
echo.
echo Phase 1 - Buffer/Protocol:
echo - client_buffer_test.exe (33 tests)
echo - server_buffer_test.exe (40 tests)
echo - protocol_test.exe (58 tests)
echo.
echo Phase 2 - File Transfer:
echo - file_transfer_test.exe (37 tests)
echo - chunk_manager_test.exe (36 tests)
echo - sha256_verify_test.exe (28 tests)
echo - resume_state_test.exe (26 tests)
echo.
echo Phase 3 - Network:
echo - header_test.exe (29 tests)
echo - packet_fragment_test.exe (24 tests)
echo - http_mask_test.exe (27 tests)
echo - geolocation_test.exe (12 tests)
echo.
echo Phase 4 - Screen/Image:
echo - diff_algorithm_test.exe (32 tests)
echo - rgb565_test.exe (286 tests)
echo - scroll_detector_test.exe (43 tests)
echo - quality_adaptive_test.exe (64 tests)
echo.
echo Phase 5 - Auth/Config:
echo - date_verify_test.exe (18 tests)
echo - registry_config_test.exe (27 tests)
exit /b 0
:run
echo ========================================
echo Running Tests
echo ========================================
if not exist "%BUILD_DIR%\%CONFIG%\client_buffer_test.exe" (
echo Tests not built, building first...
call :build
if errorlevel 1 exit /b 1
echo.
)
if "%OPTION%"=="client" goto run_client
if "%OPTION%"=="server" goto run_server
if "%OPTION%"=="protocol" goto run_protocol
if "%OPTION%"=="file" goto run_file
if "%OPTION%"=="network" goto run_network
if "%OPTION%"=="screen" goto run_screen
if "%OPTION%"=="auth" goto run_auth
if "%OPTION%"=="verbose" goto run_verbose
goto run_all
:run_client
echo Running client Buffer tests [33]...
"%BUILD_DIR%\%CONFIG%\client_buffer_test.exe" --gtest_color=yes
goto check_result
:run_server
echo Running server Buffer tests [40]...
"%BUILD_DIR%\%CONFIG%\server_buffer_test.exe" --gtest_color=yes
goto check_result
:run_protocol
echo Running protocol tests [58]...
"%BUILD_DIR%\%CONFIG%\protocol_test.exe" --gtest_color=yes
goto check_result
:run_file
echo Running file transfer tests [127]...
"%BUILD_DIR%\%CONFIG%\file_transfer_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\chunk_manager_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\sha256_verify_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\resume_state_test.exe" --gtest_color=yes
goto check_result
:run_network
echo Running network tests [92]...
"%BUILD_DIR%\%CONFIG%\header_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\packet_fragment_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\http_mask_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\geolocation_test.exe" --gtest_color=yes
goto check_result
:run_screen
echo Running screen/image tests [425]...
"%BUILD_DIR%\%CONFIG%\diff_algorithm_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\rgb565_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\scroll_detector_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\quality_adaptive_test.exe" --gtest_color=yes
goto check_result
:run_auth
echo Running auth/config tests [45]...
"%BUILD_DIR%\%CONFIG%\date_verify_test.exe" --gtest_color=yes
"%BUILD_DIR%\%CONFIG%\registry_config_test.exe" --gtest_color=yes
goto check_result
:run_verbose
echo Running all tests [verbose]...
ctest --test-dir "%BUILD_DIR%" -C %CONFIG% -V
goto check_result
:run_all
echo Running all tests...
ctest --test-dir "%BUILD_DIR%" -C %CONFIG% --output-on-failure
:check_result
if errorlevel 1 (
echo.
echo Tests FAILED!
exit /b 1
)
echo.
echo All tests PASSED!
goto end
:clean
echo ========================================
echo Cleaning Build
echo ========================================
if exist "%BUILD_DIR%" (
echo Removing build directory...
rmdir /s /q "%BUILD_DIR%"
echo Clean complete!
) else (
echo Build directory does not exist
)
exit /b 0
:rebuild
echo ========================================
echo Rebuilding
echo ========================================
call :clean
echo.
call :build
goto end
:help
echo.
echo SimpleRemoter Test Management Script
echo ========================================
echo.
echo Usage: test.bat ^<command^> [options]
echo.
echo Commands:
echo build Build all 17 test executables
echo run Run all tests
echo run client Run client Buffer tests
echo run server Run server Buffer tests
echo run protocol Run protocol tests
echo run file Run file transfer tests
echo run network Run network tests (incl. geolocation)
echo run screen Run screen/image tests
echo run auth Run auth/config tests
echo run verbose Run all tests with verbose output
echo clean Clean build directory
echo rebuild Clean and rebuild
echo help Show this help
echo.
echo Test Phases:
echo Phase 1: Buffer + Protocol
echo Phase 2: File Transfer
echo Phase 3: Network
echo Phase 4: Screen/Image
echo Phase 5: Auth + Config
echo.
echo Examples:
echo test.bat build # Build all tests
echo test.bat run # Run all 820 tests
echo test.bat run screen # Run Phase 4 screen tests
echo test.bat run auth # Run Phase 5 auth/config tests
echo test.bat rebuild # Clean and rebuild
echo.
goto end
:end
endlocal

View File

@@ -0,0 +1,875 @@
/**
* @file DateVerifyTest.cpp
* @brief DateVerify 时间验证类测试
*
* 测试覆盖:
* - isTimeTampered() 时间篡改检测
* - getTimeOffset() 时间偏差获取
* - isTrial() 试用期检测
* - isTrail() 兼容性测试
* - 边界条件和缓存机制
*/
#include <gtest/gtest.h>
#include <cstring>
#include <ctime>
#include <map>
#include <chrono>
#include <cmath>
#include <string>
#include <functional>
// ============================================
// 可测试版本的 DateVerify 类
// 允许注入模拟的网络时间
// ============================================
class TestableDateVerify
{
private:
bool m_hasVerified = false;
bool m_lastTimeTampered = true;
time_t m_lastVerifyLocalTime = 0;
time_t m_lastNetworkTime = 0;
static const int VERIFY_INTERVAL = 6 * 3600; // 6小时
// 模拟的网络时间0表示网络不可用
time_t m_mockNetworkTime = 0;
bool m_useMockTime = false;
// 模拟的本地时间
time_t m_mockLocalTime = 0;
bool m_useMockLocalTime = false;
// 模拟的编译日期
std::string m_mockCompileDate;
time_t getNetworkTimeInChina()
{
if (m_useMockTime) {
return m_mockNetworkTime;
}
// 实际测试中不调用网络
return 0;
}
time_t getLocalTime()
{
if (m_useMockLocalTime) {
return m_mockLocalTime;
}
return time(nullptr);
}
int monthAbbrevToNumber(const std::string& month)
{
static const std::map<std::string, int> months = {
{"Jan", 1}, {"Feb", 2}, {"Mar", 3}, {"Apr", 4},
{"May", 5}, {"Jun", 6}, {"Jul", 7}, {"Aug", 8},
{"Sep", 9}, {"Oct", 10}, {"Nov", 11}, {"Dec", 12}
};
auto it = months.find(month);
return (it != months.end()) ? it->second : 0;
}
tm parseCompileDate(const char* compileDate)
{
tm tmCompile = { 0 };
std::string monthStr(compileDate, 3);
std::string dayStr(compileDate + 4, 2);
std::string yearStr(compileDate + 7, 4);
tmCompile.tm_year = std::stoi(yearStr) - 1900;
tmCompile.tm_mon = monthAbbrevToNumber(monthStr) - 1;
tmCompile.tm_mday = std::stoi(dayStr);
return tmCompile;
}
int daysBetweenDates(const tm& date1, const tm& date2)
{
auto timeToTimePoint = [](const tm& tmTime) {
std::time_t tt = mktime(const_cast<tm*>(&tmTime));
return std::chrono::system_clock::from_time_t(tt);
};
auto tp1 = timeToTimePoint(date1);
auto tp2 = timeToTimePoint(date2);
auto duration = tp1 > tp2 ? tp1 - tp2 : tp2 - tp1;
return static_cast<int>(std::chrono::duration_cast<std::chrono::hours>(duration).count() / 24);
}
tm getCurrentDate()
{
std::time_t now = getLocalTime();
tm tmNow = *std::localtime(&now);
tmNow.tm_hour = 0;
tmNow.tm_min = 0;
tmNow.tm_sec = 0;
return tmNow;
}
public:
// 测试辅助方法
void setMockNetworkTime(time_t t) { m_mockNetworkTime = t; m_useMockTime = true; }
void setNetworkUnavailable() { m_mockNetworkTime = 0; m_useMockTime = true; }
void setMockLocalTime(time_t t) { m_mockLocalTime = t; m_useMockLocalTime = true; }
void setMockCompileDate(const std::string& date) { m_mockCompileDate = date; }
void resetCache() { m_hasVerified = false; m_lastTimeTampered = true; }
bool hasVerified() const { return m_hasVerified; }
// 检测本地时间是否被篡改
bool isTimeTampered(int toleranceDays = 1)
{
time_t currentLocalTime = getLocalTime();
// 检查是否可以使用缓存
if (m_hasVerified) {
time_t localElapsed = currentLocalTime - m_lastVerifyLocalTime;
// 本地时间在合理范围内前进,使用缓存推算
if (localElapsed >= 0 && localElapsed < VERIFY_INTERVAL) {
time_t estimatedNetworkTime = m_lastNetworkTime + localElapsed;
double diffDays = difftime(estimatedNetworkTime, currentLocalTime) / 86400.0;
if (fabs(diffDays) <= toleranceDays) {
return false; // 未篡改
}
}
}
// 执行网络验证
time_t networkTime = getNetworkTimeInChina();
if (networkTime == 0) {
// 网络不可用:如果之前验证通过且本地时间没异常,暂时信任
if (m_hasVerified && !m_lastTimeTampered) {
time_t localElapsed = currentLocalTime - m_lastVerifyLocalTime;
// 允许5分钟回拨和24小时内的前进
if (localElapsed >= -300 && localElapsed < 24 * 3600) {
return false;
}
}
return true; // 无法验证,视为篡改
}
// 更新缓存
m_hasVerified = true;
m_lastVerifyLocalTime = currentLocalTime;
m_lastNetworkTime = networkTime;
double diffDays = difftime(networkTime, currentLocalTime) / 86400.0;
m_lastTimeTampered = fabs(diffDays) > toleranceDays;
return m_lastTimeTampered;
}
// 获取网络时间与本地时间的偏差(秒)
int getTimeOffset()
{
time_t networkTime = getNetworkTimeInChina();
if (networkTime == 0) return 0;
return static_cast<int>(difftime(networkTime, getLocalTime()));
}
bool isTrial(int trialDays = 7)
{
if (isTimeTampered())
return false;
const char* compileDate = m_mockCompileDate.empty() ? __DATE__ : m_mockCompileDate.c_str();
tm tmCompile = parseCompileDate(compileDate), tmCurrent = getCurrentDate();
int daysDiff = daysBetweenDates(tmCompile, tmCurrent);
return daysDiff <= trialDays;
}
// 兼容性函数
bool isTrail(int trailDays = 7) { return isTrial(trailDays); }
};
// ============================================
// 辅助函数
// ============================================
// 创建指定日期的 time_t
time_t makeTime(int year, int month, int day, int hour = 12, int minute = 0, int second = 0)
{
tm t = {};
t.tm_year = year - 1900;
t.tm_mon = month - 1;
t.tm_mday = day;
t.tm_hour = hour;
t.tm_min = minute;
t.tm_sec = second;
return mktime(&t);
}
// 格式化日期为 __DATE__ 格式 (例如 "Mar 13 2026")
std::string formatCompileDate(int year, int month, int day)
{
static const char* months[] = {
"Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
};
char buf[16];
snprintf(buf, sizeof(buf), "%s %2d %d", months[month - 1], day, year);
return std::string(buf);
}
// ============================================
// isTimeTampered 测试
// ============================================
class IsTimeTamperedTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(IsTimeTamperedTest, NetworkUnavailable_NoCache_ReturnsTampered)
{
// 网络不可用,无缓存时应返回 true视为篡改
verifier.setNetworkUnavailable();
EXPECT_TRUE(verifier.isTimeTampered());
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeMatch_ReturnsNotTampered)
{
// 网络时间与本地时间匹配
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeWithinTolerance_ReturnsNotTampered)
{
// 网络时间与本地时间差异在容忍范围内12小时差异1天容忍
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 12 * 3600); // 本地时间落后12小时
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkAvailable_TimeExceedsTolerance_ReturnsTampered)
{
// 网络时间与本地时间差异超过容忍范围2天差异1天容忍
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 2 * 86400); // 本地时间落后2天
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, ToleranceDays_Zero_StrictMode)
{
// 0天容忍度任何偏差都应报告为篡改
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400); // 1天偏差
EXPECT_TRUE(verifier.isTimeTampered(0));
}
TEST_F(IsTimeTamperedTest, ToleranceDays_Large_LenientMode)
{
// 大容忍度7天较大偏差仍应通过
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 5 * 86400); // 5天偏差
EXPECT_FALSE(verifier.isTimeTampered(7));
}
TEST_F(IsTimeTamperedTest, LocalTimeAhead_ExceedsTolerance_ReturnsTampered)
{
// 本地时间超前网络时间
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now + 3 * 86400); // 本地时间超前3天
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, CacheHit_WithinInterval_UsesCache)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
EXPECT_TRUE(verifier.hasVerified());
// 模拟时间前进1小时在6小时缓存间隔内
verifier.setMockLocalTime(now + 3600);
verifier.setNetworkUnavailable(); // 网络不可用
// 应该使用缓存,不报告篡改
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, CacheMiss_ExceedsInterval_RequiresNetwork)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间前进25小时超过24小时宽容期
// 注实现中网络不可用时有24小时宽容期超过后才报告篡改
verifier.setMockLocalTime(now + 25 * 3600);
verifier.setNetworkUnavailable();
// 超过24小时宽容期网络不可用应报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkUnavailable_WithGoodCache_AllowsSmallRewind)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间回拨2分钟在允许的5分钟范围内
verifier.setMockLocalTime(now - 120);
verifier.setNetworkUnavailable();
// 应该暂时信任
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(IsTimeTamperedTest, NetworkUnavailable_WithGoodCache_RejectsLargeRewind)
{
// 首次验证成功
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 模拟时间回拨10分钟超过允许的5分钟
verifier.setMockLocalTime(now - 600);
verifier.setNetworkUnavailable();
// 应该报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
// ============================================
// getTimeOffset 测试
// ============================================
class GetTimeOffsetTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(GetTimeOffsetTest, NetworkUnavailable_ReturnsZero)
{
verifier.setNetworkUnavailable();
EXPECT_EQ(verifier.getTimeOffset(), 0);
}
TEST_F(GetTimeOffsetTest, LocalTimeBehind_ReturnsPositive)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 3600); // 本地落后1小时
int offset = verifier.getTimeOffset();
EXPECT_GT(offset, 0);
EXPECT_NEAR(offset, 3600, 5); // 允许5秒误差
}
TEST_F(GetTimeOffsetTest, LocalTimeAhead_ReturnsNegative)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now + 3600); // 本地超前1小时
int offset = verifier.getTimeOffset();
EXPECT_LT(offset, 0);
EXPECT_NEAR(offset, -3600, 5);
}
TEST_F(GetTimeOffsetTest, TimeMatch_ReturnsNearZero)
{
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
int offset = verifier.getTimeOffset();
EXPECT_NEAR(offset, 0, 2); // 允许2秒误差
}
// ============================================
// isTrial / isTrail 测试
// ============================================
class IsTrialTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
void SetUp() override {
// 默认设置:网络正常,时间同步
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
}
};
TEST_F(IsTrialTest, WithinTrialPeriod_ReturnsTrue)
{
// 编译日期3天前
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
// 创建3天前的日期
time_t threeDaysAgo = now - 3 * 86400;
tm* tmCompile = localtime(&threeDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7)); // 7天试用期内
}
TEST_F(IsTrialTest, ExactTrialPeriod_ReturnsTrue)
{
// 编译日期正好7天前
time_t now = time(nullptr);
time_t sevenDaysAgo = now - 7 * 86400;
tm* tmCompile = localtime(&sevenDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7)); // 边界正好7天
}
TEST_F(IsTrialTest, ExceedsTrialPeriod_ReturnsFalse)
{
// 编译日期8天前
time_t now = time(nullptr);
time_t eightDaysAgo = now - 8 * 86400;
tm* tmCompile = localtime(&eightDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(7)); // 超过7天试用期
}
TEST_F(IsTrialTest, CompileToday_ReturnsTrue)
{
// 编译日期:今天
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_TRUE(verifier.isTrial(7));
EXPECT_TRUE(verifier.isTrial(1));
EXPECT_TRUE(verifier.isTrial(0));
}
TEST_F(IsTrialTest, TimeTampered_ReturnsFalse)
{
// 时间被篡改时,无论试用期如何都应返回 false
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 30 * 86400); // 本地时间落后30天
// 即使编译日期是今天,时间篡改也会导致返回 false
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(7));
}
TEST_F(IsTrialTest, CustomTrialDays_Zero)
{
// 0天试用期只有编译当天有效
time_t now = time(nullptr);
time_t yesterday = now - 86400;
tm* tmYesterday = localtime(&yesterday);
std::string compileDate = formatCompileDate(
tmYesterday->tm_year + 1900,
tmYesterday->tm_mon + 1,
tmYesterday->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(0)); // 昨天编译0天试用期已过
}
TEST_F(IsTrialTest, IsTrail_CompatibilityAlias)
{
// isTrail 应该与 isTrial 行为一致
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
std::string compileDate = formatCompileDate(
tmNow->tm_year + 1900,
tmNow->tm_mon + 1,
tmNow->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_EQ(verifier.isTrail(7), verifier.isTrial(7));
EXPECT_EQ(verifier.isTrail(0), verifier.isTrial(0));
}
// ============================================
// 边界条件测试
// ============================================
class BoundaryTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(BoundaryTest, ToleranceDays_ExactBoundary)
{
// 测试容忍度边界:偏差正好等于容忍天数
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400); // 正好1天偏差
// 1天容忍度1天偏差应该通过<= 比较)
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, ToleranceDays_JustOverBoundary)
{
// 测试容忍度边界:偏差刚刚超过容忍天数
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now - 86400 - 3600); // 1天+1小时偏差
// 1天容忍度偏差超过1天应该失败
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, CacheInterval_ExactBoundary)
{
// 测试网络不可用时的24小时宽容期边界
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 正好24小时后刚好到达宽容期边界
verifier.setMockLocalTime(now + 24 * 3600);
verifier.setNetworkUnavailable();
// 刚好到达24小时边界网络不可用应报告篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, CacheInterval_JustUnderBoundary)
{
// 测试缓存间隔边界略少于6小时
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 5小时59分钟后仍在缓存有效期内
verifier.setMockLocalTime(now + 6 * 3600 - 60);
verifier.setNetworkUnavailable();
// 仍在缓存有效期内
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, NetworkTimeRollback_AllowedMargin)
{
// 测试允许的时间回拨范围正好5分钟
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 回拨正好5分钟
verifier.setMockLocalTime(now - 300);
verifier.setNetworkUnavailable();
// 边界情况:-300 >= -300应该通过
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, NetworkTimeRollback_ExceedsMargin)
{
// 测试超过允许的时间回拨范围
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 回拨超过5分钟
verifier.setMockLocalTime(now - 301);
verifier.setNetworkUnavailable();
// 超过允许范围
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(BoundaryTest, TrialDays_LargeValue)
{
// 测试大试用期值
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
// 100天前编译
time_t hundredDaysAgo = now - 100 * 86400;
tm* tmCompile = localtime(&hundredDaysAgo);
std::string compileDate = formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
);
verifier.setMockCompileDate(compileDate);
EXPECT_FALSE(verifier.isTrial(99)); // 99天试用期已过
EXPECT_TRUE(verifier.isTrial(100)); // 正好100天
EXPECT_TRUE(verifier.isTrial(365)); // 365天试用期内
}
// ============================================
// 日期解析测试
// ============================================
class DateParsingTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
void SetUp() override {
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
}
};
TEST_F(DateParsingTest, AllMonths)
{
// 测试所有月份的解析
const char* months[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
for (int i = 0; i < 12; ++i) {
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "%s %2d %d", months[i], 15, tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
// 只要不是时间篡改,应该能正常解析
// 结果取决于当前日期与编译日期的差异
bool result = verifier.isTrial(365); // 使用较长试用期确保通过
EXPECT_TRUE(result) << "Failed for month: " << months[i];
}
}
TEST_F(DateParsingTest, SingleDigitDay)
{
// 测试单数字日期(如 "Mar 1 2026"
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "Mar 1 %d", tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
EXPECT_TRUE(verifier.isTrial(365));
}
TEST_F(DateParsingTest, DoubleDigitDay)
{
// 测试双数字日期(如 "Mar 15 2026"
time_t now = time(nullptr);
tm* tmNow = localtime(&now);
char dateStr[16];
snprintf(dateStr, sizeof(dateStr), "Mar 15 %d", tmNow->tm_year + 1900);
verifier.setMockCompileDate(dateStr);
EXPECT_TRUE(verifier.isTrial(365));
}
// ============================================
// 授权场景模拟测试
// ============================================
class AuthorizationScenarioTest : public ::testing::Test {
protected:
TestableDateVerify verifier;
};
TEST_F(AuthorizationScenarioTest, ValidLicense_TimeNotTampered)
{
// 场景:有效授权,时间未被篡改
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
// 授权检查应该通过
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, ExpiredLicense_UserRollsBackTime)
{
// 场景:授权过期,用户将时间回拨企图绕过
time_t now = time(nullptr);
time_t thirtyDaysAgo = now - 30 * 86400;
verifier.setMockNetworkTime(now); // 网络时间是真实时间
verifier.setMockLocalTime(thirtyDaysAgo); // 用户将本地时间回拨30天
// 应该检测到时间篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, ExpiredLicense_UserAdvancesTime)
{
// 场景:用户将时间提前(不太常见,但也应检测)
time_t now = time(nullptr);
time_t thirtyDaysAhead = now + 30 * 86400;
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(thirtyDaysAhead);
// 应该检测到时间篡改
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, OfflineUser_RecentValidation)
{
// 场景:用户刚刚验证通过后断网
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1)); // 首次验证通过
// 用户断网,但时间正常前进
verifier.setMockLocalTime(now + 3600); // 1小时后
verifier.setNetworkUnavailable();
// 应该允许(使用缓存)
EXPECT_FALSE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, OfflineUser_LongOffline)
{
// 场景:用户长时间离线后恢复
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
EXPECT_FALSE(verifier.isTimeTampered(1));
// 用户离线超过24小时
verifier.setMockLocalTime(now + 25 * 3600);
verifier.setNetworkUnavailable();
// 缓存已过期,网络不可用,应该拒绝
EXPECT_TRUE(verifier.isTimeTampered(1));
}
TEST_F(AuthorizationScenarioTest, TrialUser_WithinPeriod)
{
// 场景:试用用户在试用期内
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
time_t threeDaysAgo = now - 3 * 86400;
tm* tmCompile = localtime(&threeDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
EXPECT_TRUE(verifier.isTrial(7));
}
TEST_F(AuthorizationScenarioTest, TrialUser_Expired)
{
// 场景:试用用户试用期已过
time_t now = time(nullptr);
verifier.setMockNetworkTime(now);
verifier.setMockLocalTime(now);
time_t tenDaysAgo = now - 10 * 86400;
tm* tmCompile = localtime(&tenDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
EXPECT_FALSE(verifier.isTrial(7));
}
TEST_F(AuthorizationScenarioTest, TrialUser_RollsBackTime)
{
// 场景:试用用户回拨时间企图延长试用期
time_t now = time(nullptr);
time_t tenDaysAgo = now - 10 * 86400;
verifier.setMockNetworkTime(now); // 真实网络时间
verifier.setMockLocalTime(tenDaysAgo); // 用户回拨到10天前
// 编译日期设为15天前
time_t fifteenDaysAgo = now - 15 * 86400;
tm* tmCompile = localtime(&fifteenDaysAgo);
verifier.setMockCompileDate(formatCompileDate(
tmCompile->tm_year + 1900,
tmCompile->tm_mon + 1,
tmCompile->tm_mday
));
// 时间篡改应被检测isTrial 应返回 false
EXPECT_FALSE(verifier.isTrial(7));
}

View File

@@ -0,0 +1,556 @@
/**
* @file BufferTest.cpp
* @brief 客户端 CBuffer 类单元测试
*
* 测试覆盖:
* - 基本读写操作
* - 边界条件(空缓冲区、零长度、超长请求)
* - 内存管理(扩展、收缩)
* - 下溢防护Skip、ReadBuffer 边界)
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
// 模拟 Windows 类型定义(用于跨平台测试)
#ifndef _WIN32
typedef unsigned char BYTE;
typedef BYTE* PBYTE;
typedef BYTE* LPBYTE;
typedef unsigned long ULONG;
typedef int BOOL;
#define TRUE 1
#define FALSE 0
#define MEM_COMMIT 0x1000
#define MEM_RELEASE 0x8000
#define PAGE_READWRITE 0x04
// 跨平台内存分配模拟
inline void* MVirtualAlloc(void*, size_t size, int, int) {
return malloc(size);
}
inline void MVirtualFree(void* ptr, size_t, int) {
free(ptr);
}
inline void CopyMemory(void* dst, const void* src, size_t len) {
memcpy(dst, src, len);
}
inline void MoveMemory(void* dst, const void* src, size_t len) {
memmove(dst, src, len);
}
#else
#include <Windows.h>
// Windows 下的内存分配封装
inline void* MVirtualAlloc(void* addr, size_t size, int type, int protect) {
return VirtualAlloc(addr, size, type, protect);
}
inline void MVirtualFree(void* ptr, size_t size, int type) {
VirtualFree(ptr, size, type);
}
#endif
// 内联包含 Buffer 实现(测试专用)
// 这样可以避免复杂的链接问题
namespace ClientBuffer {
#define U_PAGE_ALIGNMENT 3
#define F_PAGE_ALIGNMENT 3.0
class CBuffer
{
public:
CBuffer() : m_ulMaxLength(0), m_Base(NULL), m_Ptr(NULL) {}
~CBuffer() {
if (m_Base) {
MVirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NULL;
}
m_Base = m_Ptr = NULL;
m_ulMaxLength = 0;
}
ULONG ReadBuffer(PBYTE Buffer, ULONG ulLength) {
ULONG dataLen = (ULONG)(m_Ptr - m_Base);
if (ulLength > dataLen) {
ulLength = dataLen;
}
if (ulLength) {
CopyMemory(Buffer, m_Base, ulLength);
ULONG remaining = dataLen - ulLength;
if (remaining > 0) {
MoveMemory(m_Base, m_Base + ulLength, remaining);
}
m_Ptr = m_Base + remaining;
}
DeAllocateBuffer((ULONG)(m_Ptr - m_Base));
return ulLength;
}
VOID DeAllocateBuffer(ULONG ulLength) {
int len = (int)(m_Ptr - m_Base);
if (ulLength < (ULONG)len)
return;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
if (m_ulMaxLength <= ulNewMaxLength) {
return;
}
PBYTE NewBase = (PBYTE)MVirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL)
return;
CopyMemory(NewBase, m_Base, len);
MVirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NewBase;
m_Ptr = m_Base + len;
m_ulMaxLength = ulNewMaxLength;
}
BOOL WriteBuffer(PBYTE Buffer, ULONG ulLength) {
if (ReAllocateBuffer(ulLength + (ULONG)(m_Ptr - m_Base)) == FALSE) {
return FALSE;
}
CopyMemory(m_Ptr, Buffer, ulLength);
m_Ptr += ulLength;
return TRUE;
}
BOOL ReAllocateBuffer(ULONG ulLength) {
if (ulLength < m_ulMaxLength)
return TRUE;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
PBYTE NewBase = (PBYTE)MVirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL) {
return FALSE;
}
ULONG len = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, len);
if (m_Base) {
MVirtualFree(m_Base, 0, MEM_RELEASE);
}
m_Base = NewBase;
m_Ptr = m_Base + len;
m_ulMaxLength = ulNewMaxLength;
return TRUE;
}
VOID ClearBuffer() {
m_Ptr = m_Base;
DeAllocateBuffer(1024);
}
ULONG GetBufferLength() const {
return (ULONG)(m_Ptr - m_Base);
}
void Skip(ULONG ulPos) {
if (ulPos == 0)
return;
ULONG dataLen = (ULONG)(m_Ptr - m_Base);
if (ulPos > dataLen) {
ulPos = dataLen;
}
if (ulPos > 0) {
ULONG remaining = dataLen - ulPos;
if (remaining > 0) {
MoveMemory(m_Base, m_Base + ulPos, remaining);
}
m_Ptr = m_Base + remaining;
}
}
PBYTE GetBuffer(ULONG ulPos = 0) const {
if (m_Base == NULL || ulPos >= (ULONG)(m_Ptr - m_Base)) {
return NULL;
}
return m_Base + ulPos;
}
protected:
PBYTE m_Base;
PBYTE m_Ptr;
ULONG m_ulMaxLength;
};
} // namespace ClientBuffer
using ClientBuffer::CBuffer;
// ============================================
// 测试夹具
// ============================================
class ClientBufferTest : public ::testing::Test {
protected:
CBuffer buffer;
void SetUp() override {
// 每个测试前重置
}
void TearDown() override {
// 每个测试后清理
}
// 辅助方法:写入测试数据
void WriteTestData(const std::vector<BYTE>& data) {
buffer.WriteBuffer(const_cast<BYTE*>(data.data()), (ULONG)data.size());
}
// 辅助方法:写入指定长度的填充数据
void WriteFillData(ULONG length, BYTE fillValue = 0x42) {
std::vector<BYTE> data(length, fillValue);
buffer.WriteBuffer(data.data(), length);
}
};
// ============================================
// 构造/析构测试
// ============================================
TEST_F(ClientBufferTest, Constructor_InitializesEmpty) {
CBuffer newBuffer;
EXPECT_EQ(newBuffer.GetBufferLength(), 0u);
EXPECT_EQ(newBuffer.GetBuffer(), nullptr);
}
// ============================================
// WriteBuffer 测试
// ============================================
TEST_F(ClientBufferTest, WriteBuffer_ValidData_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
EXPECT_TRUE(buffer.WriteBuffer(data, 5));
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ClientBufferTest, WriteBuffer_ZeroLength_ToEmptyBuffer) {
// 空缓冲区写入 0 字节:由于 VirtualAlloc(0) 返回 NULL会失败
BYTE data[] = {1};
// 实际行为:返回 FALSE无法分配 0 字节)
EXPECT_FALSE(buffer.WriteBuffer(data, 0));
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, WriteBuffer_ZeroLength_ToNonEmptyBuffer) {
// 非空缓冲区写入 0 字节应该成功
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_TRUE(buffer.WriteBuffer(data, 0));
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
TEST_F(ClientBufferTest, WriteBuffer_MultipleWrites_AccumulatesData) {
BYTE data1[] = {1, 2, 3};
BYTE data2[] = {4, 5};
buffer.WriteBuffer(data1, 3);
buffer.WriteBuffer(data2, 2);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
// 验证数据完整性
BYTE result[5];
buffer.ReadBuffer(result, 5);
EXPECT_EQ(result[0], 1);
EXPECT_EQ(result[1], 2);
EXPECT_EQ(result[2], 3);
EXPECT_EQ(result[3], 4);
EXPECT_EQ(result[4], 5);
}
TEST_F(ClientBufferTest, WriteBuffer_LargeData_HandlesCorrectly) {
const ULONG largeSize = 10000;
std::vector<BYTE> data(largeSize);
for (ULONG i = 0; i < largeSize; i++) {
data[i] = (BYTE)(i & 0xFF);
}
EXPECT_TRUE(buffer.WriteBuffer(data.data(), largeSize));
EXPECT_EQ(buffer.GetBufferLength(), largeSize);
}
// ============================================
// ReadBuffer 测试
// ============================================
TEST_F(ClientBufferTest, ReadBuffer_EmptyBuffer_ReturnsZero) {
BYTE result[10];
EXPECT_EQ(buffer.ReadBuffer(result, 10), 0u);
}
TEST_F(ClientBufferTest, ReadBuffer_ExactLength_ReturnsAll) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[5];
ULONG bytesRead = buffer.ReadBuffer(result, 5);
EXPECT_EQ(bytesRead, 5u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
for (int i = 0; i < 5; i++) {
EXPECT_EQ(result[i], data[i]);
}
}
TEST_F(ClientBufferTest, ReadBuffer_PartialRead_LeavesRemainder) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[3];
ULONG bytesRead = buffer.ReadBuffer(result, 3);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 2u);
// 验证剩余数据
EXPECT_EQ(*buffer.GetBuffer(0), 4);
EXPECT_EQ(*buffer.GetBuffer(1), 5);
}
TEST_F(ClientBufferTest, ReadBuffer_RequestExceedsAvailable_ReturnsAvailableOnly) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[10];
ULONG bytesRead = buffer.ReadBuffer(result, 10);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, ReadBuffer_ZeroLength_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1];
ULONG bytesRead = buffer.ReadBuffer(result, 0);
EXPECT_EQ(bytesRead, 0u);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
// ============================================
// Skip 测试
// ============================================
TEST_F(ClientBufferTest, Skip_ZeroPosition_NoChange) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.Skip(0);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ClientBufferTest, Skip_PartialSkip_RemovesPrefix) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.Skip(2);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
EXPECT_EQ(*buffer.GetBuffer(0), 3);
EXPECT_EQ(*buffer.GetBuffer(1), 4);
EXPECT_EQ(*buffer.GetBuffer(2), 5);
}
TEST_F(ClientBufferTest, Skip_ExactLength_ClearsBuffer) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
buffer.Skip(3);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, Skip_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
// 这是修复后的行为:不会下溢,而是限制到可用长度
buffer.Skip(100);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, Skip_EmptyBuffer_NoEffect) {
buffer.Skip(10);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// GetBuffer 测试
// ============================================
TEST_F(ClientBufferTest, GetBuffer_EmptyBuffer_ReturnsNull) {
EXPECT_EQ(buffer.GetBuffer(), nullptr);
}
TEST_F(ClientBufferTest, GetBuffer_ValidPosition_ReturnsCorrectPointer) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
EXPECT_EQ(*buffer.GetBuffer(0), 10);
EXPECT_EQ(*buffer.GetBuffer(2), 30);
EXPECT_EQ(*buffer.GetBuffer(4), 50);
}
TEST_F(ClientBufferTest, GetBuffer_PositionAtLength_ReturnsNull) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBuffer(3), nullptr);
}
TEST_F(ClientBufferTest, GetBuffer_PositionExceedsLength_ReturnsNull) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBuffer(100), nullptr);
}
// ============================================
// ClearBuffer 测试
// ============================================
TEST_F(ClientBufferTest, ClearBuffer_AfterWrite_ResetsLength) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, ClearBuffer_EmptyBuffer_NoEffect) {
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// 数据完整性测试
// ============================================
TEST_F(ClientBufferTest, DataIntegrity_WriteReadCycle_PreservesData) {
// 写入各种字节值
std::vector<BYTE> data(256);
for (int i = 0; i < 256; i++) {
data[i] = (BYTE)i;
}
buffer.WriteBuffer(data.data(), 256);
std::vector<BYTE> result(256);
ULONG bytesRead = buffer.ReadBuffer(result.data(), 256);
EXPECT_EQ(bytesRead, 256u);
for (int i = 0; i < 256; i++) {
EXPECT_EQ(result[i], data[i]) << "Mismatch at index " << i;
}
}
TEST_F(ClientBufferTest, DataIntegrity_MultipleReadWriteCycles) {
for (int cycle = 0; cycle < 10; cycle++) {
BYTE data[100];
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)(cycle * 10 + i);
}
buffer.WriteBuffer(data, 100);
BYTE result[100];
ULONG bytesRead = buffer.ReadBuffer(result, 100);
EXPECT_EQ(bytesRead, 100u);
for (int i = 0; i < 100; i++) {
EXPECT_EQ(result[i], data[i]);
}
}
}
// ============================================
// 边界条件和下溢防护测试
// ============================================
TEST_F(ClientBufferTest, UnderflowProtection_SkipMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
// 不应崩溃或产生意外行为
buffer.Skip(ULONG_MAX);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, UnderflowProtection_ReadMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1000];
ULONG bytesRead = buffer.ReadBuffer(result, ULONG_MAX);
// 应该只读取可用的 3 字节
EXPECT_EQ(bytesRead, 3u);
}
// ============================================
// 内存管理测试
// ============================================
TEST_F(ClientBufferTest, MemoryManagement_RepeatedAllocations_NoLeak) {
// 反复分配和释放,验证无内存泄漏
for (int i = 0; i < 100; i++) {
WriteFillData(1000);
buffer.ClearBuffer();
}
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ClientBufferTest, MemoryManagement_GrowingBuffer_HandlesReallocation) {
// 逐步增长缓冲区
for (ULONG size = 1; size <= 10000; size *= 2) {
WriteFillData(size);
}
EXPECT_GT(buffer.GetBufferLength(), 0u);
}
// ============================================
// 参数化测试
// ============================================
class BufferReadParameterizedTest
: public ::testing::TestWithParam<std::tuple<size_t, size_t, size_t>> {
protected:
CBuffer buffer;
};
TEST_P(BufferReadParameterizedTest, ReadBuffer_VariousLengths) {
auto [writeLen, readLen, expectedRead] = GetParam();
std::vector<BYTE> data(writeLen, 0x42);
if (writeLen > 0) {
buffer.WriteBuffer(data.data(), (ULONG)writeLen);
}
std::vector<BYTE> result(readLen > 0 ? readLen : 1);
ULONG actual = buffer.ReadBuffer(result.data(), (ULONG)readLen);
EXPECT_EQ(actual, expectedRead);
}
INSTANTIATE_TEST_SUITE_P(
ReadLengths,
BufferReadParameterizedTest,
::testing::Values(
std::make_tuple(10, 5, 5), // 正常读取
std::make_tuple(5, 10, 5), // 请求超过可用
std::make_tuple(0, 5, 0), // 空缓冲区
std::make_tuple(100, 0, 0), // 零长度读取
std::make_tuple(1000, 500, 500) // 大数据部分读取
)
);

View File

@@ -0,0 +1,345 @@
/**
* @file RegistryConfigTest.cpp
* @brief iniFile 和 binFile 注册表配置类单元测试
*
* 测试覆盖:
* - 基本读写操作(字符串、整数、浮点数)
* - 二进制数据读写
* - 键句柄缓存机制
* - 并发读写
* - 默认值处理
* - 边界条件
*
* 注意:此测试仅在 Windows 平台运行,使用真实注册表
*/
#include <gtest/gtest.h>
#ifdef _WIN32
// Windows 头文件必须最先包含
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <thread>
#include <atomic>
#include <chrono>
#include <vector>
#include <string>
#include <map>
// 定义 INIFILE_STANDALONE 跳过 commands.h 中的 MFC 依赖
#define INIFILE_STANDALONE 1
// 提供 iniFile.h 需要的最小依赖(来自 commands.h
#ifndef GET_FILEPATH
#define GET_FILEPATH(path, file) do { \
char* p = strrchr(path, '\\'); \
if (p) strcpy_s(p + 1, _MAX_PATH - (p - path + 1), file); \
} while(0)
#endif
inline std::vector<std::string> StringToVector(const std::string& s, char ch) {
std::vector<std::string> result;
std::string::size_type start = 0;
std::string::size_type pos = s.find(ch, start);
while (pos != std::string::npos) {
result.push_back(s.substr(start, pos - start));
start = pos + 1;
pos = s.find(ch, start);
}
result.push_back(s.substr(start));
if (result.empty()) result.push_back("");
return result;
}
// 包含实际的 iniFile.h 头文件(需要修改 iniFile.h 支持 INIFILE_STANDALONE
#include "common/iniFile.h"
// 测试用的注册表路径(与生产环境隔离)
#define TEST_INI_PATH "Software\\YAMA_TEST_INI"
#define TEST_BIN_PATH "Software\\YAMA_TEST_BIN"
// ============================================
// iniFile 测试夹具
// ============================================
class IniFileTest : public ::testing::Test {
protected:
std::unique_ptr<iniFile> cfg;
void SetUp() override {
cfg = std::make_unique<iniFile>(TEST_INI_PATH);
}
void TearDown() override {
cfg.reset();
// 清理测试注册表项
RegDeleteTreeA(HKEY_CURRENT_USER, TEST_INI_PATH);
}
};
// ============================================
// binFile 测试夹具
// ============================================
class BinFileTest : public ::testing::Test {
protected:
std::unique_ptr<binFile> cfg;
void SetUp() override {
cfg = std::make_unique<binFile>(TEST_BIN_PATH);
}
void TearDown() override {
cfg.reset();
// 清理测试注册表项
RegDeleteTreeA(HKEY_CURRENT_USER, TEST_BIN_PATH);
}
};
// ============================================
// iniFile 基本读写测试
// ============================================
TEST_F(IniFileTest, SetStr_GetStr_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetStr("TestSection", "TestKey", "HelloWorld"));
EXPECT_EQ(cfg->GetStr("TestSection", "TestKey"), "HelloWorld");
}
TEST_F(IniFileTest, SetStr_EmptyString_ReturnsEmpty) {
EXPECT_TRUE(cfg->SetStr("TestSection", "EmptyKey", ""));
EXPECT_EQ(cfg->GetStr("TestSection", "EmptyKey"), "");
}
TEST_F(IniFileTest, GetStr_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetStr("NonExistent", "Key", "DefaultValue"), "DefaultValue");
}
TEST_F(IniFileTest, SetInt_GetInt_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetInt("TestSection", "IntKey", 12345));
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 12345);
}
TEST_F(IniFileTest, SetInt_NegativeValue_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetInt("TestSection", "NegKey", -9999));
EXPECT_EQ(cfg->GetInt("TestSection", "NegKey"), -9999);
}
TEST_F(IniFileTest, GetInt_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetInt("NonExistent", "Key", 42), 42);
}
TEST_F(IniFileTest, GetInt_InvalidString_ReturnsDefault) {
cfg->SetStr("TestSection", "InvalidInt", "NotANumber");
EXPECT_EQ(cfg->GetInt("TestSection", "InvalidInt", 99), 99);
}
TEST_F(IniFileTest, SetStr_ChineseCharacters_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetStr("TestSection", "ChineseKey", "中文测试"));
EXPECT_EQ(cfg->GetStr("TestSection", "ChineseKey"), "中文测试");
}
TEST_F(IniFileTest, SetStr_LongString_ReturnsCorrect) {
std::string longStr(400, 'A'); // 400 字符长字符串
EXPECT_TRUE(cfg->SetStr("TestSection", "LongKey", longStr));
EXPECT_EQ(cfg->GetStr("TestSection", "LongKey"), longStr);
}
TEST_F(IniFileTest, SetStr_SpecialCharacters_ReturnsCorrect) {
std::string special = "Path\\With\\Backslashes=Value;With|Special<Chars>";
EXPECT_TRUE(cfg->SetStr("TestSection", "SpecialKey", special));
EXPECT_EQ(cfg->GetStr("TestSection", "SpecialKey"), special);
}
// ============================================
// iniFile 多 Section 测试
// ============================================
TEST_F(IniFileTest, MultipleSections_IsolatedValues) {
cfg->SetStr("Section1", "Key", "Value1");
cfg->SetStr("Section2", "Key", "Value2");
cfg->SetStr("Section3", "Key", "Value3");
EXPECT_EQ(cfg->GetStr("Section1", "Key"), "Value1");
EXPECT_EQ(cfg->GetStr("Section2", "Key"), "Value2");
EXPECT_EQ(cfg->GetStr("Section3", "Key"), "Value3");
}
TEST_F(IniFileTest, MultipleKeys_SameSection) {
cfg->SetStr("Settings", "Key1", "Value1");
cfg->SetStr("Settings", "Key2", "Value2");
cfg->SetInt("Settings", "Key3", 100);
EXPECT_EQ(cfg->GetStr("Settings", "Key1"), "Value1");
EXPECT_EQ(cfg->GetStr("Settings", "Key2"), "Value2");
EXPECT_EQ(cfg->GetInt("Settings", "Key3"), 100);
}
// ============================================
// iniFile 覆盖写入测试
// ============================================
TEST_F(IniFileTest, OverwriteValue_ReturnsNewValue) {
cfg->SetStr("TestSection", "Key", "OldValue");
cfg->SetStr("TestSection", "Key", "NewValue");
EXPECT_EQ(cfg->GetStr("TestSection", "Key"), "NewValue");
}
TEST_F(IniFileTest, OverwriteInt_ReturnsNewValue) {
cfg->SetInt("TestSection", "IntKey", 100);
cfg->SetInt("TestSection", "IntKey", 200);
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 200);
}
// ============================================
// binFile 基本读写测试
// ============================================
TEST_F(BinFileTest, SetStr_GetStr_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetStr("TestSection", "TestKey", "BinaryHello"));
EXPECT_EQ(cfg->GetStr("TestSection", "TestKey"), "BinaryHello");
}
TEST_F(BinFileTest, SetInt_GetInt_ReturnsCorrectValue) {
EXPECT_TRUE(cfg->SetInt("TestSection", "IntKey", 0x12345678));
EXPECT_EQ(cfg->GetInt("TestSection", "IntKey"), 0x12345678);
}
TEST_F(BinFileTest, SetInt_NegativeValue_ReturnsCorrect) {
EXPECT_TRUE(cfg->SetInt("TestSection", "NegKey", -12345));
EXPECT_EQ(cfg->GetInt("TestSection", "NegKey"), -12345);
}
TEST_F(BinFileTest, GetInt_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetInt("NonExistent", "Key", 999), 999);
}
TEST_F(BinFileTest, GetStr_NonExistent_ReturnsDefault) {
EXPECT_EQ(cfg->GetStr("NonExistent", "Key", "Default"), "Default");
}
// ============================================
// 并发测试(验证注册表操作的稳定性)
// ============================================
TEST_F(IniFileTest, Concurrent_ReadWrite_NoCorruption) {
std::atomic<bool> running{true};
std::atomic<int> writeCount{0};
std::atomic<int> readCount{0};
// 写线程
std::thread writer([&]() {
int i = 0;
while (running) {
std::string key = "Key" + std::to_string(i % 10);
std::string value = "Value" + std::to_string(i);
if (cfg->SetStr("ConcurrentTest", key, value)) {
writeCount++;
}
i++;
std::this_thread::yield();
}
});
// 读线程
std::thread reader([&]() {
while (running) {
for (int i = 0; i < 10; i++) {
std::string key = "Key" + std::to_string(i);
cfg->GetStr("ConcurrentTest", key, "");
readCount++;
}
std::this_thread::yield();
}
});
// 运行 100ms
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
writer.join();
reader.join();
EXPECT_GT(writeCount.load(), 0);
EXPECT_GT(readCount.load(), 0);
}
TEST_F(IniFileTest, Concurrent_MultipleWriters_NoCrash) {
std::atomic<bool> running{true};
std::vector<std::thread> writers;
for (int t = 0; t < 4; t++) {
writers.emplace_back([&, t]() {
int i = 0;
while (running) {
std::string section = "Section" + std::to_string(t);
std::string key = "Key" + std::to_string(i % 5);
cfg->SetInt(section, key, i);
cfg->GetInt(section, key);
i++;
std::this_thread::yield();
}
});
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
for (auto& t : writers) {
t.join();
}
// 无崩溃即为成功
SUCCEED();
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(IniFileTest, BoundaryCondition_MaxIntValue) {
cfg->SetInt("BoundaryTest", "MaxInt", INT_MAX);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "MaxInt"), INT_MAX);
}
TEST_F(IniFileTest, BoundaryCondition_MinIntValue) {
cfg->SetInt("BoundaryTest", "MinInt", INT_MIN);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "MinInt"), INT_MIN);
}
TEST_F(IniFileTest, BoundaryCondition_ZeroValue) {
cfg->SetInt("BoundaryTest", "Zero", 0);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "Zero", -1), 0);
}
TEST_F(BinFileTest, BoundaryCondition_ZeroInt) {
cfg->SetInt("BoundaryTest", "Zero", 0);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "Zero", -1), 0);
}
TEST_F(BinFileTest, BoundaryCondition_NegativeMax) {
cfg->SetInt("BoundaryTest", "NegMax", INT_MIN);
EXPECT_EQ(cfg->GetInt("BoundaryTest", "NegMax"), INT_MIN);
}
// ============================================
// 性能测试
// ============================================
TEST_F(IniFileTest, Performance_CachedAccess_FastEnough) {
// 预热缓存
cfg->SetStr("PerfTest", "Key", "InitialValue");
auto start = std::chrono::high_resolution_clock::now();
const int iterations = 1000;
for (int i = 0; i < iterations; i++) {
cfg->SetStr("PerfTest", "Key", "Value" + std::to_string(i));
cfg->GetStr("PerfTest", "Key");
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
// 1000 次读写应在 2 秒内完成
EXPECT_LT(duration.count(), 2000) << "Performance: " << duration.count() << "ms for 1000 iterations";
}
#else
// 非 Windows 平台跳过测试
TEST(RegistryConfigTest, SkippedOnNonWindows) {
GTEST_SKIP() << "Registry tests only run on Windows";
}
#endif

View File

@@ -0,0 +1,629 @@
/**
* @file ChunkManagerTest.cpp
* @brief 分块管理和区间跟踪测试
*
* 测试覆盖:
* - FileRangeV2 区间操作
* - 接收状态跟踪
* - 区间合并算法
* - 断点续传区间计算
*/
#include <gtest/gtest.h>
#include <vector>
#include <algorithm>
#include <cstdint>
// ============================================
// 区间结构(测试专用)
// ============================================
struct Range {
uint64_t offset;
uint64_t length;
Range(uint64_t o = 0, uint64_t l = 0) : offset(o), length(l) {}
uint64_t end() const { return offset + length; }
bool operator<(const Range& other) const {
return offset < other.offset;
}
bool operator==(const Range& other) const {
return offset == other.offset && length == other.length;
}
};
// ============================================
// 区间管理类(断点续传核心逻辑)
// ============================================
class RangeManager {
public:
RangeManager(uint64_t fileSize = 0) : m_fileSize(fileSize), m_receivedBytes(0) {}
// 添加已接收区间
void AddRange(uint64_t offset, uint64_t length) {
if (length == 0) return;
Range newRange(offset, length);
m_ranges.push_back(newRange);
MergeRanges();
UpdateReceivedBytes();
}
// 获取已接收区间列表
const std::vector<Range>& GetRanges() const {
return m_ranges;
}
// 获取缺失区间列表
std::vector<Range> GetMissingRanges() const {
std::vector<Range> missing;
if (m_fileSize == 0) return missing;
uint64_t currentPos = 0;
for (const auto& r : m_ranges) {
if (r.offset > currentPos) {
missing.emplace_back(currentPos, r.offset - currentPos);
}
currentPos = r.end();
}
if (currentPos < m_fileSize) {
missing.emplace_back(currentPos, m_fileSize - currentPos);
}
return missing;
}
// 获取已接收字节数
uint64_t GetReceivedBytes() const {
return m_receivedBytes;
}
// 是否完整接收
bool IsComplete() const {
return m_receivedBytes >= m_fileSize && m_fileSize > 0;
}
// 清空
void Clear() {
m_ranges.clear();
m_receivedBytes = 0;
}
// 设置文件大小
void SetFileSize(uint64_t size) {
m_fileSize = size;
}
uint64_t GetFileSize() const {
return m_fileSize;
}
private:
void MergeRanges() {
if (m_ranges.size() <= 1) return;
std::sort(m_ranges.begin(), m_ranges.end());
std::vector<Range> merged;
merged.push_back(m_ranges[0]);
for (size_t i = 1; i < m_ranges.size(); ++i) {
Range& last = merged.back();
const Range& current = m_ranges[i];
// 检查是否可以合并(相邻或重叠)
if (current.offset <= last.end()) {
// 扩展现有区间
uint64_t newEnd = std::max(last.end(), current.end());
last.length = newEnd - last.offset;
} else {
// 新的独立区间
merged.push_back(current);
}
}
m_ranges = std::move(merged);
}
void UpdateReceivedBytes() {
m_receivedBytes = 0;
for (const auto& r : m_ranges) {
m_receivedBytes += r.length;
}
}
std::vector<Range> m_ranges;
uint64_t m_fileSize;
uint64_t m_receivedBytes;
};
// ============================================
// 接收状态类(模拟 FileRecvStateV2
// ============================================
class FileRecvState {
public:
FileRecvState() : m_fileSize(0), m_fileIndex(0), m_transferID(0) {}
void Initialize(uint64_t transferID, uint32_t fileIndex, uint64_t fileSize, const std::string& filePath) {
m_transferID = transferID;
m_fileIndex = fileIndex;
m_fileSize = fileSize;
m_filePath = filePath;
m_rangeManager.SetFileSize(fileSize);
}
void AddChunk(uint64_t offset, uint64_t length) {
m_rangeManager.AddRange(offset, length);
}
bool IsComplete() const {
return m_rangeManager.IsComplete();
}
uint64_t GetReceivedBytes() const {
return m_rangeManager.GetReceivedBytes();
}
double GetProgress() const {
if (m_fileSize == 0) return 0.0;
return static_cast<double>(GetReceivedBytes()) / m_fileSize * 100.0;
}
std::vector<Range> GetMissingRanges() const {
return m_rangeManager.GetMissingRanges();
}
const std::vector<Range>& GetReceivedRanges() const {
return m_rangeManager.GetRanges();
}
uint64_t GetTransferID() const { return m_transferID; }
uint32_t GetFileIndex() const { return m_fileIndex; }
uint64_t GetFileSize() const { return m_fileSize; }
const std::string& GetFilePath() const { return m_filePath; }
private:
uint64_t m_transferID;
uint32_t m_fileIndex;
uint64_t m_fileSize;
std::string m_filePath;
RangeManager m_rangeManager;
};
// ============================================
// Range 基础测试
// ============================================
class RangeTest : public ::testing::Test {};
TEST_F(RangeTest, DefaultConstruction) {
Range r;
EXPECT_EQ(r.offset, 0u);
EXPECT_EQ(r.length, 0u);
EXPECT_EQ(r.end(), 0u);
}
TEST_F(RangeTest, Construction) {
Range r(100, 200);
EXPECT_EQ(r.offset, 100u);
EXPECT_EQ(r.length, 200u);
EXPECT_EQ(r.end(), 300u);
}
TEST_F(RangeTest, Comparison) {
Range r1(0, 100);
Range r2(100, 100);
Range r3(0, 100);
EXPECT_TRUE(r1 < r2);
EXPECT_FALSE(r2 < r1);
EXPECT_TRUE(r1 == r3);
EXPECT_FALSE(r1 == r2);
}
// ============================================
// RangeManager 测试
// ============================================
class RangeManagerTest : public ::testing::Test {};
TEST_F(RangeManagerTest, InitialState) {
RangeManager rm(1000);
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_EQ(rm.GetFileSize(), 1000u);
EXPECT_FALSE(rm.IsComplete());
EXPECT_TRUE(rm.GetRanges().empty());
}
TEST_F(RangeManagerTest, AddSingleRange) {
RangeManager rm(1000);
rm.AddRange(0, 500);
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 500u);
}
TEST_F(RangeManagerTest, AddZeroLengthRange) {
RangeManager rm(1000);
rm.AddRange(0, 0);
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_TRUE(rm.GetRanges().empty());
}
TEST_F(RangeManagerTest, AddNonOverlappingRanges) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(200, 100);
rm.AddRange(400, 100);
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
ASSERT_EQ(rm.GetRanges().size(), 3u);
}
TEST_F(RangeManagerTest, MergeAdjacentRanges) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(100, 100); // 紧邻
EXPECT_EQ(rm.GetReceivedBytes(), 200u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 200u);
}
TEST_F(RangeManagerTest, MergeOverlappingRanges) {
RangeManager rm(1000);
rm.AddRange(0, 150);
rm.AddRange(100, 150); // 重叠
EXPECT_EQ(rm.GetReceivedBytes(), 250u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 250u);
}
TEST_F(RangeManagerTest, MergeContainedRange) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.AddRange(100, 100); // 完全被包含
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].length, 500u);
}
TEST_F(RangeManagerTest, MergeOutOfOrder) {
RangeManager rm(1000);
rm.AddRange(500, 100);
rm.AddRange(100, 100);
rm.AddRange(0, 100);
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
// 0-100 和 100-200 被合并成 0-200加上 500-600共 2 个区间
ASSERT_EQ(rm.GetRanges().size(), 2u);
// 验证排序和合并
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 200u); // 合并后
EXPECT_EQ(rm.GetRanges()[1].offset, 500u);
EXPECT_EQ(rm.GetRanges()[1].length, 100u);
}
TEST_F(RangeManagerTest, MergeMultipleOverlapping) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(200, 100);
rm.AddRange(50, 200); // 跨越两个区间
EXPECT_EQ(rm.GetReceivedBytes(), 300u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetRanges()[0].offset, 0u);
EXPECT_EQ(rm.GetRanges()[0].length, 300u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Empty) {
RangeManager rm(1000);
auto missing = rm.GetMissingRanges();
ASSERT_EQ(missing.size(), 1u);
EXPECT_EQ(missing[0].offset, 0u);
EXPECT_EQ(missing[0].length, 1000u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Partial) {
RangeManager rm(1000);
rm.AddRange(0, 100);
rm.AddRange(500, 100);
auto missing = rm.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
EXPECT_EQ(missing[0].offset, 100u);
EXPECT_EQ(missing[0].length, 400u);
EXPECT_EQ(missing[1].offset, 600u);
EXPECT_EQ(missing[1].length, 400u);
}
TEST_F(RangeManagerTest, GetMissingRanges_Complete) {
RangeManager rm(1000);
rm.AddRange(0, 1000);
auto missing = rm.GetMissingRanges();
EXPECT_TRUE(missing.empty());
}
TEST_F(RangeManagerTest, IsComplete_Exact) {
RangeManager rm(1000);
rm.AddRange(0, 1000);
EXPECT_TRUE(rm.IsComplete());
}
TEST_F(RangeManagerTest, IsComplete_OverReceived) {
RangeManager rm(1000);
rm.AddRange(0, 1500); // 超过文件大小
EXPECT_TRUE(rm.IsComplete());
}
TEST_F(RangeManagerTest, IsComplete_Partial) {
RangeManager rm(1000);
rm.AddRange(0, 999);
EXPECT_FALSE(rm.IsComplete());
}
TEST_F(RangeManagerTest, Clear) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.Clear();
EXPECT_EQ(rm.GetReceivedBytes(), 0u);
EXPECT_TRUE(rm.GetRanges().empty());
}
// ============================================
// FileRecvState 测试
// ============================================
class FileRecvStateTest : public ::testing::Test {};
TEST_F(FileRecvStateTest, Initialize) {
FileRecvState state;
state.Initialize(12345, 0, 1024, "C:\\test\\file.txt");
EXPECT_EQ(state.GetTransferID(), 12345u);
EXPECT_EQ(state.GetFileIndex(), 0u);
EXPECT_EQ(state.GetFileSize(), 1024u);
EXPECT_EQ(state.GetFilePath(), "C:\\test\\file.txt");
EXPECT_EQ(state.GetReceivedBytes(), 0u);
EXPECT_FALSE(state.IsComplete());
}
TEST_F(FileRecvStateTest, AddChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(0, 100);
EXPECT_EQ(state.GetReceivedBytes(), 100u);
EXPECT_NEAR(state.GetProgress(), 10.0, 0.01);
state.AddChunk(100, 400);
EXPECT_EQ(state.GetReceivedBytes(), 500u);
EXPECT_NEAR(state.GetProgress(), 50.0, 0.01);
state.AddChunk(500, 500);
EXPECT_EQ(state.GetReceivedBytes(), 1000u);
EXPECT_TRUE(state.IsComplete());
EXPECT_NEAR(state.GetProgress(), 100.0, 0.01);
}
TEST_F(FileRecvStateTest, OutOfOrderChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(500, 200);
state.AddChunk(0, 200);
state.AddChunk(800, 200);
EXPECT_EQ(state.GetReceivedBytes(), 600u);
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
EXPECT_EQ(missing[0].offset, 200u);
EXPECT_EQ(missing[0].length, 300u);
EXPECT_EQ(missing[1].offset, 700u);
EXPECT_EQ(missing[1].length, 100u);
}
TEST_F(FileRecvStateTest, DuplicateChunks) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
state.AddChunk(0, 500);
state.AddChunk(0, 500); // 重复
state.AddChunk(250, 250); // 重叠
EXPECT_EQ(state.GetReceivedBytes(), 500u);
}
// ============================================
// 断点续传场景测试
// ============================================
class ResumeScenarioTest : public ::testing::Test {};
TEST_F(ResumeScenarioTest, SimulateInterruptedTransfer) {
FileRecvState state;
state.Initialize(12345, 0, 10000, "large_file.bin");
// 模拟接收了一些数据后中断
state.AddChunk(0, 2000);
state.AddChunk(2000, 2000);
state.AddChunk(5000, 1000);
EXPECT_EQ(state.GetReceivedBytes(), 5000u);
EXPECT_NEAR(state.GetProgress(), 50.0, 0.01);
// 获取需要续传的区间
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
// 验证缺失区间
EXPECT_EQ(missing[0].offset, 4000u);
EXPECT_EQ(missing[0].length, 1000u);
EXPECT_EQ(missing[1].offset, 6000u);
EXPECT_EQ(missing[1].length, 4000u);
}
TEST_F(ResumeScenarioTest, ResumeAndComplete) {
FileRecvState state;
state.Initialize(12345, 0, 10000, "large_file.bin");
// 初始接收
state.AddChunk(0, 3000);
state.AddChunk(7000, 3000);
EXPECT_FALSE(state.IsComplete());
// 续传缺失部分
auto missing = state.GetMissingRanges();
for (const auto& r : missing) {
state.AddChunk(r.offset, r.length);
}
EXPECT_TRUE(state.IsComplete());
EXPECT_EQ(state.GetReceivedBytes(), 10000u);
}
TEST_F(ResumeScenarioTest, SmallChunksReassembly) {
FileRecvState state;
state.Initialize(1, 0, 1000, "file.txt");
// 模拟接收很多小块
for (uint64_t i = 0; i < 1000; i += 10) {
state.AddChunk(i, 10);
}
EXPECT_TRUE(state.IsComplete());
// 验证区间已合并
const auto& ranges = state.GetReceivedRanges();
EXPECT_EQ(ranges.size(), 1u);
EXPECT_EQ(ranges[0].offset, 0u);
EXPECT_EQ(ranges[0].length, 1000u);
}
// ============================================
// 大文件场景测试
// ============================================
class LargeFileScenarioTest : public ::testing::Test {};
TEST_F(LargeFileScenarioTest, FileGreaterThan4GB) {
FileRecvState state;
uint64_t fileSize = 5ULL * 1024 * 1024 * 1024; // 5 GB
state.Initialize(1, 0, fileSize, "huge.bin");
uint64_t chunkSize = 64 * 1024; // 64 KB chunks
// 添加几个大区间
state.AddChunk(0, 1ULL * 1024 * 1024 * 1024); // 1 GB
state.AddChunk(3ULL * 1024 * 1024 * 1024, 1ULL * 1024 * 1024 * 1024); // 1 GB at 3GB
EXPECT_EQ(state.GetReceivedBytes(), 2ULL * 1024 * 1024 * 1024);
auto missing = state.GetMissingRanges();
ASSERT_EQ(missing.size(), 2u);
// 1GB-3GB 缺失
EXPECT_EQ(missing[0].offset, 1ULL * 1024 * 1024 * 1024);
EXPECT_EQ(missing[0].length, 2ULL * 1024 * 1024 * 1024);
// 4GB-5GB 缺失
EXPECT_EQ(missing[1].offset, 4ULL * 1024 * 1024 * 1024);
EXPECT_EQ(missing[1].length, 1ULL * 1024 * 1024 * 1024);
}
// ============================================
// 边界条件测试
// ============================================
class ChunkBoundaryTest : public ::testing::Test {};
TEST_F(ChunkBoundaryTest, ZeroFileSize) {
FileRecvState state;
state.Initialize(1, 0, 0, "empty.txt");
EXPECT_FALSE(state.IsComplete()); // 0大小文件不算完成
EXPECT_TRUE(state.GetMissingRanges().empty());
}
TEST_F(ChunkBoundaryTest, SingleByteFile) {
FileRecvState state;
state.Initialize(1, 0, 1, "tiny.txt");
state.AddChunk(0, 1);
EXPECT_TRUE(state.IsComplete());
}
TEST_F(ChunkBoundaryTest, MaxValues) {
RangeManager rm(UINT64_MAX);
// 添加接近最大值的区间
rm.AddRange(UINT64_MAX - 1000, 500);
EXPECT_EQ(rm.GetReceivedBytes(), 500u);
}
TEST_F(ChunkBoundaryTest, OverlappingAtBoundary) {
RangeManager rm(1000);
rm.AddRange(0, 500);
rm.AddRange(499, 2); // 重叠1字节
EXPECT_EQ(rm.GetReceivedBytes(), 501u);
ASSERT_EQ(rm.GetRanges().size(), 1u);
}
// ============================================
// 性能相关测试
// ============================================
class ChunkPerformanceTest : public ::testing::Test {};
TEST_F(ChunkPerformanceTest, ManySmallRanges) {
RangeManager rm(1000000);
// 添加大量不连续的小区间
for (uint64_t i = 0; i < 1000000; i += 20) {
rm.AddRange(i, 10);
}
// 验证区间数量合理
EXPECT_LE(rm.GetRanges().size(), 50000u);
EXPECT_EQ(rm.GetReceivedBytes(), 500000u);
}
TEST_F(ChunkPerformanceTest, ManyContiguousRanges) {
RangeManager rm(1000000);
// 添加大量连续的小区间(应该全部合并)
for (uint64_t i = 0; i < 1000; ++i) {
rm.AddRange(i * 1000, 1000);
}
// 应该合并成单个区间
ASSERT_EQ(rm.GetRanges().size(), 1u);
EXPECT_EQ(rm.GetReceivedBytes(), 1000000u);
EXPECT_TRUE(rm.IsComplete());
}

View File

@@ -0,0 +1,700 @@
/**
* @file FileTransferV2Test.cpp
* @brief V2 文件传输逻辑测试
*
* 测试覆盖:
* - TransferOptionsV2 结构体
* - 传输ID生成
* - 包头构建与解析
* - 变长数据包处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <map>
#include <random>
#include <chrono>
// ============================================
// 从 file_upload.h 复制的结构体定义(测试专用)
// ============================================
#pragma pack(push, 1)
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_LAST_CHUNK = 0x0001,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
FFV2_CANCEL = 0x0008,
FFV2_DIRECTORY = 0x0010,
FFV2_COMPRESSED = 0x0020,
FFV2_ERROR = 0x0040,
};
struct FileChunkPacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint32_t totalFiles;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
uint16_t flags;
uint16_t checksum;
uint8_t reserved[8];
};
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
};
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
};
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
};
struct FileQueryResumeV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileCount;
};
struct FileQueryResumeEntryV2 {
uint64_t fileSize;
uint16_t nameLength;
};
struct FileResumeResponseV2 {
uint8_t cmd;
uint64_t srcClientID;
uint64_t dstClientID;
uint16_t flags;
uint32_t fileCount;
};
struct FileResumeResponseEntryV2 {
uint32_t fileIndex;
uint64_t receivedBytes;
};
#pragma pack(pop)
// V2 传输选项
struct TransferOptionsV2 {
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
bool enableResume;
std::map<uint32_t, uint64_t> startOffsets;
TransferOptionsV2() : transferID(0), srcClientID(0), dstClientID(0), enableResume(true) {}
};
// ============================================
// 辅助函数(测试专用实现)
// ============================================
// 生成传输ID简化实现
uint64_t GenerateTransferID() {
static std::mt19937_64 rng(
static_cast<uint64_t>(std::chrono::high_resolution_clock::now().time_since_epoch().count())
);
return rng();
}
// 构建文件块包
std::vector<uint8_t> BuildFileChunkPacketV2(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint32_t totalFiles,
uint64_t fileSize,
uint64_t offset,
const std::string& filename,
const std::vector<uint8_t>& data,
uint16_t flags = FFV2_NONE)
{
size_t totalSize = sizeof(FileChunkPacketV2) + filename.size() + data.size();
std::vector<uint8_t> buffer(totalSize);
FileChunkPacketV2* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = 85; // COMMAND_SEND_FILE_V2
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->totalFiles = totalFiles;
pkt->fileSize = fileSize;
pkt->offset = offset;
pkt->dataLength = data.size();
pkt->nameLength = filename.size();
pkt->flags = flags;
pkt->checksum = 0;
memset(pkt->reserved, 0, sizeof(pkt->reserved));
// 追加文件名
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename.data(), filename.size());
// 追加数据
if (!data.empty()) {
memcpy(buffer.data() + sizeof(FileChunkPacketV2) + filename.size(), data.data(), data.size());
}
return buffer;
}
// 解析文件块包
bool ParseFileChunkPacketV2(
const uint8_t* buffer, size_t len,
FileChunkPacketV2& header,
std::string& filename,
std::vector<uint8_t>& data)
{
if (len < sizeof(FileChunkPacketV2)) {
return false;
}
memcpy(&header, buffer, sizeof(FileChunkPacketV2));
size_t expectedLen = sizeof(FileChunkPacketV2) + header.nameLength + header.dataLength;
if (len < expectedLen) {
return false;
}
const char* namePtr = reinterpret_cast<const char*>(buffer + sizeof(FileChunkPacketV2));
filename.assign(namePtr, header.nameLength);
if (header.dataLength > 0) {
const uint8_t* dataPtr = buffer + sizeof(FileChunkPacketV2) + header.nameLength;
data.assign(dataPtr, dataPtr + header.dataLength);
} else {
data.clear();
}
return true;
}
// 构建续传查询包
std::vector<uint8_t> BuildResumeQuery(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
const std::vector<std::pair<std::string, uint64_t>>& files)
{
// 计算总大小
size_t totalSize = sizeof(FileQueryResumeV2);
for (const auto& file : files) {
totalSize += sizeof(FileQueryResumeEntryV2) + file.first.size();
}
std::vector<uint8_t> buffer(totalSize);
FileQueryResumeV2* pkt = reinterpret_cast<FileQueryResumeV2*>(buffer.data());
pkt->cmd = 88; // COMMAND_FILE_QUERY_RESUME
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileCount = static_cast<uint32_t>(files.size());
uint8_t* ptr = buffer.data() + sizeof(FileQueryResumeV2);
for (const auto& file : files) {
FileQueryResumeEntryV2* entry = reinterpret_cast<FileQueryResumeEntryV2*>(ptr);
entry->fileSize = file.second;
entry->nameLength = static_cast<uint16_t>(file.first.size());
ptr += sizeof(FileQueryResumeEntryV2);
memcpy(ptr, file.first.data(), file.first.size());
ptr += file.first.size();
}
return buffer;
}
// 解析续传查询包
bool ParseResumeQuery(
const uint8_t* buffer, size_t len,
uint64_t& transferID,
uint64_t& srcClientID,
uint64_t& dstClientID,
std::vector<std::pair<std::string, uint64_t>>& files)
{
if (len < sizeof(FileQueryResumeV2)) {
return false;
}
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer);
transferID = pkt->transferID;
srcClientID = pkt->srcClientID;
dstClientID = pkt->dstClientID;
files.clear();
const uint8_t* ptr = buffer + sizeof(FileQueryResumeV2);
const uint8_t* end = buffer + len;
for (uint32_t i = 0; i < pkt->fileCount; ++i) {
if (ptr + sizeof(FileQueryResumeEntryV2) > end) {
return false;
}
const FileQueryResumeEntryV2* entry = reinterpret_cast<const FileQueryResumeEntryV2*>(ptr);
ptr += sizeof(FileQueryResumeEntryV2);
if (ptr + entry->nameLength > end) {
return false;
}
std::string name(reinterpret_cast<const char*>(ptr), entry->nameLength);
ptr += entry->nameLength;
files.emplace_back(name, entry->fileSize);
}
return true;
}
// ============================================
// TransferOptionsV2 测试
// ============================================
class TransferOptionsV2Test : public ::testing::Test {};
TEST_F(TransferOptionsV2Test, DefaultValues) {
TransferOptionsV2 options;
EXPECT_EQ(options.transferID, 0u);
EXPECT_EQ(options.srcClientID, 0u);
EXPECT_EQ(options.dstClientID, 0u);
EXPECT_TRUE(options.enableResume);
EXPECT_TRUE(options.startOffsets.empty());
}
TEST_F(TransferOptionsV2Test, SetStartOffsets) {
TransferOptionsV2 options;
options.startOffsets[0] = 1024;
options.startOffsets[1] = 2048;
options.startOffsets[2] = 0;
EXPECT_EQ(options.startOffsets.size(), 3u);
EXPECT_EQ(options.startOffsets[0], 1024u);
EXPECT_EQ(options.startOffsets[1], 2048u);
EXPECT_EQ(options.startOffsets[2], 0u);
}
TEST_F(TransferOptionsV2Test, C2CConfiguration) {
TransferOptionsV2 options;
options.transferID = 12345;
options.srcClientID = 100;
options.dstClientID = 200;
EXPECT_EQ(options.transferID, 12345u);
EXPECT_EQ(options.srcClientID, 100u);
EXPECT_EQ(options.dstClientID, 200u);
}
// ============================================
// TransferID 生成测试
// ============================================
class TransferIDTest : public ::testing::Test {};
TEST_F(TransferIDTest, GeneratesNonZero) {
// 多次生成应该都不为 0极低概率
for (int i = 0; i < 100; ++i) {
uint64_t id = GenerateTransferID();
// 不做强制断言,因为理论上可能为 0
// 但实际概率极低
if (id == 0) {
// 如果真的生成了 0再生成一次
id = GenerateTransferID();
}
}
SUCCEED();
}
TEST_F(TransferIDTest, GeneratesUnique) {
std::vector<uint64_t> ids;
for (int i = 0; i < 1000; ++i) {
ids.push_back(GenerateTransferID());
}
// 检查没有重复
std::sort(ids.begin(), ids.end());
auto it = std::unique(ids.begin(), ids.end());
EXPECT_EQ(it, ids.end()) << "Found duplicate transfer IDs";
}
// ============================================
// FileChunkPacketV2 构建测试
// ============================================
class FileChunkPacketV2BuildTest : public ::testing::Test {};
TEST_F(FileChunkPacketV2BuildTest, BuildSimplePacket) {
uint64_t transferID = 12345;
std::string filename = "test.txt";
std::vector<uint8_t> data = {0x01, 0x02, 0x03, 0x04};
auto buffer = BuildFileChunkPacketV2(
transferID, 0, 0, 0, 1, 100, 0, filename, data);
EXPECT_EQ(buffer.size(), sizeof(FileChunkPacketV2) + filename.size() + data.size());
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 85);
EXPECT_EQ(pkt->transferID, transferID);
EXPECT_EQ(pkt->nameLength, filename.size());
EXPECT_EQ(pkt->dataLength, data.size());
}
TEST_F(FileChunkPacketV2BuildTest, BuildWithFlags) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 100, 50, "file.bin", {}, FFV2_LAST_CHUNK);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags, FFV2_LAST_CHUNK);
EXPECT_EQ(pkt->offset, 50u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildC2CPacket) {
auto buffer = BuildFileChunkPacketV2(
999, 100, 200, 0, 3, 1024, 0, "shared/doc.pdf", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->srcClientID, 100u);
EXPECT_EQ(pkt->dstClientID, 200u);
EXPECT_EQ(pkt->totalFiles, 3u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildEmptyDataPacket) {
std::vector<uint8_t> emptyData;
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 0, 0, "empty.txt", emptyData);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->dataLength, 0u);
EXPECT_EQ(pkt->fileSize, 0u);
}
TEST_F(FileChunkPacketV2BuildTest, BuildDirectoryPacket) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 0, 0, "subdir/", {}, FFV2_DIRECTORY);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_DIRECTORY, FFV2_DIRECTORY);
EXPECT_EQ(pkt->dataLength, 0u);
}
// ============================================
// FileChunkPacketV2 解析测试
// ============================================
class FileChunkPacketV2ParseTest : public ::testing::Test {};
TEST_F(FileChunkPacketV2ParseTest, ParseValidPacket) {
std::string originalName = "test/file.txt";
std::vector<uint8_t> originalData = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE};
auto buffer = BuildFileChunkPacketV2(
12345, 100, 200, 3, 10, 1024 * 1024, 512, originalName, originalData, FFV2_COMPRESSED);
FileChunkPacketV2 header;
std::string parsedName;
std::vector<uint8_t> parsedData;
bool result = ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, parsedName, parsedData);
EXPECT_TRUE(result);
EXPECT_EQ(header.transferID, 12345u);
EXPECT_EQ(header.srcClientID, 100u);
EXPECT_EQ(header.dstClientID, 200u);
EXPECT_EQ(header.fileIndex, 3u);
EXPECT_EQ(header.totalFiles, 10u);
EXPECT_EQ(header.fileSize, 1024u * 1024u);
EXPECT_EQ(header.offset, 512u);
EXPECT_EQ(header.flags, FFV2_COMPRESSED);
EXPECT_EQ(parsedName, originalName);
EXPECT_EQ(parsedData, originalData);
}
TEST_F(FileChunkPacketV2ParseTest, ParseTruncatedHeader) {
std::vector<uint8_t> truncated(sizeof(FileChunkPacketV2) - 10);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
bool result = ParseFileChunkPacketV2(truncated.data(), truncated.size(), header, name, data);
EXPECT_FALSE(result);
}
TEST_F(FileChunkPacketV2ParseTest, ParseTruncatedData) {
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, 100, 0, "file.txt", {0x01, 0x02, 0x03});
// 截断数据部分
std::vector<uint8_t> truncated(buffer.begin(), buffer.end() - 2);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
bool result = ParseFileChunkPacketV2(truncated.data(), truncated.size(), header, name, data);
EXPECT_FALSE(result);
}
TEST_F(FileChunkPacketV2ParseTest, RoundTrip) {
// 构建各种类型的包并验证往返
struct TestCase {
uint64_t transferID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t offset;
std::string filename;
std::vector<uint8_t> data;
uint16_t flags;
};
std::vector<TestCase> cases = {
{1, 0, 100, 0, "a.txt", {0x01}, FFV2_NONE},
{UINT64_MAX, 999, UINT64_MAX, UINT64_MAX - 100, "path/to/file.bin", {}, FFV2_LAST_CHUNK},
{12345, 5, 1024, 512, "中文文件.txt", {0xAA, 0xBB, 0xCC}, FFV2_COMPRESSED | FFV2_LAST_CHUNK},
};
for (const auto& tc : cases) {
auto buffer = BuildFileChunkPacketV2(
tc.transferID, 0, 0, tc.fileIndex, 10, tc.fileSize, tc.offset, tc.filename, tc.data, tc.flags);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
ASSERT_TRUE(ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, name, data));
EXPECT_EQ(header.transferID, tc.transferID);
EXPECT_EQ(header.fileIndex, tc.fileIndex);
EXPECT_EQ(header.fileSize, tc.fileSize);
EXPECT_EQ(header.offset, tc.offset);
EXPECT_EQ(header.flags, tc.flags);
EXPECT_EQ(name, tc.filename);
EXPECT_EQ(data, tc.data);
}
}
// ============================================
// 续传查询包测试
// ============================================
class ResumeQueryTest : public ::testing::Test {};
TEST_F(ResumeQueryTest, BuildEmpty) {
std::vector<std::pair<std::string, uint64_t>> files;
auto buffer = BuildResumeQuery(12345, 100, 200, files);
EXPECT_EQ(buffer.size(), sizeof(FileQueryResumeV2));
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 88);
EXPECT_EQ(pkt->fileCount, 0u);
}
TEST_F(ResumeQueryTest, BuildSingleFile) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file.txt", 1024}
};
auto buffer = BuildResumeQuery(12345, 100, 200, files);
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->fileCount, 1u);
}
TEST_F(ResumeQueryTest, BuildMultipleFiles) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 1024},
{"dir/file2.bin", 2048},
{"another/path/file3.dat", 4096}
};
auto buffer = BuildResumeQuery(12345, 100, 200, files);
const FileQueryResumeV2* pkt = reinterpret_cast<const FileQueryResumeV2*>(buffer.data());
EXPECT_EQ(pkt->fileCount, 3u);
}
TEST_F(ResumeQueryTest, ParseRoundTrip) {
std::vector<std::pair<std::string, uint64_t>> original = {
{"file1.txt", 1024},
{"subdir/file2.bin", UINT64_MAX},
{"中文/文件.txt", 0}
};
auto buffer = BuildResumeQuery(99999, 111, 222, original);
uint64_t transferID, srcClientID, dstClientID;
std::vector<std::pair<std::string, uint64_t>> parsed;
bool result = ParseResumeQuery(buffer.data(), buffer.size(), transferID, srcClientID, dstClientID, parsed);
EXPECT_TRUE(result);
EXPECT_EQ(transferID, 99999u);
EXPECT_EQ(srcClientID, 111u);
EXPECT_EQ(dstClientID, 222u);
ASSERT_EQ(parsed.size(), original.size());
for (size_t i = 0; i < original.size(); ++i) {
EXPECT_EQ(parsed[i].first, original[i].first);
EXPECT_EQ(parsed[i].second, original[i].second);
}
}
TEST_F(ResumeQueryTest, ParseTruncated) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file.txt", 1024}
};
auto buffer = BuildResumeQuery(1, 0, 0, files);
// 截断
std::vector<uint8_t> truncated(buffer.begin(), buffer.begin() + sizeof(FileQueryResumeV2) + 5);
uint64_t transferID, srcClientID, dstClientID;
std::vector<std::pair<std::string, uint64_t>> parsed;
bool result = ParseResumeQuery(truncated.data(), truncated.size(), transferID, srcClientID, dstClientID, parsed);
EXPECT_FALSE(result);
}
// ============================================
// 大文件支持测试
// ============================================
class LargeFileTest : public ::testing::Test {};
TEST_F(LargeFileTest, FileSize_GreaterThan4GB) {
uint64_t largeSize = 5ULL * 1024 * 1024 * 1024; // 5 GB
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, largeSize, 0, "large.bin", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileSize, largeSize);
}
TEST_F(LargeFileTest, Offset_GreaterThan4GB) {
uint64_t largeOffset = 10ULL * 1024 * 1024 * 1024; // 10 GB
uint64_t fileSize = 20ULL * 1024 * 1024 * 1024; // 20 GB
auto buffer = BuildFileChunkPacketV2(
1, 0, 0, 0, 1, fileSize, largeOffset, "huge.bin", {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileSize, fileSize);
EXPECT_EQ(pkt->offset, largeOffset);
}
TEST_F(LargeFileTest, MaxValues) {
auto buffer = BuildFileChunkPacketV2(
UINT64_MAX, UINT64_MAX, UINT64_MAX,
UINT32_MAX, UINT32_MAX, UINT64_MAX, UINT64_MAX,
"max.bin", {}, UINT16_MAX);
FileChunkPacketV2 header;
std::string name;
std::vector<uint8_t> data;
ASSERT_TRUE(ParseFileChunkPacketV2(buffer.data(), buffer.size(), header, name, data));
EXPECT_EQ(header.transferID, UINT64_MAX);
EXPECT_EQ(header.srcClientID, UINT64_MAX);
EXPECT_EQ(header.dstClientID, UINT64_MAX);
EXPECT_EQ(header.fileIndex, UINT32_MAX);
EXPECT_EQ(header.totalFiles, UINT32_MAX);
EXPECT_EQ(header.fileSize, UINT64_MAX);
EXPECT_EQ(header.offset, UINT64_MAX);
}
// ============================================
// 多文件传输测试
// ============================================
class MultiFileTransferTest : public ::testing::Test {};
TEST_F(MultiFileTransferTest, SequentialFileIndices) {
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 100},
{"file2.txt", 200},
{"file3.txt", 300}
};
for (uint32_t i = 0; i < files.size(); ++i) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, i, static_cast<uint32_t>(files.size()),
files[i].second, 0, files[i].first, {}, FFV2_LAST_CHUNK);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->fileIndex, i);
EXPECT_EQ(pkt->totalFiles, files.size());
}
}
TEST_F(MultiFileTransferTest, ConsistentTransferID) {
uint64_t transferID = GenerateTransferID();
for (int i = 0; i < 5; ++i) {
auto buffer = BuildFileChunkPacketV2(
transferID, 0, 0, i, 5, 1000 * (i + 1), 0, "file" + std::to_string(i), {});
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->transferID, transferID);
}
}
// ============================================
// 错误处理测试
// ============================================
class ErrorHandlingTest : public ::testing::Test {};
TEST_F(ErrorHandlingTest, CancelFlag) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, FFV2_CANCEL);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_CANCEL, FFV2_CANCEL);
}
TEST_F(ErrorHandlingTest, ErrorFlag) {
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, FFV2_ERROR);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_EQ(pkt->flags & FFV2_ERROR, FFV2_ERROR);
}
TEST_F(ErrorHandlingTest, CombinedErrorFlags) {
uint16_t flags = FFV2_ERROR | FFV2_CANCEL | FFV2_LAST_CHUNK;
auto buffer = BuildFileChunkPacketV2(
12345, 0, 0, 0, 1, 0, 0, "", {}, flags);
const FileChunkPacketV2* pkt = reinterpret_cast<const FileChunkPacketV2*>(buffer.data());
EXPECT_TRUE(pkt->flags & FFV2_ERROR);
EXPECT_TRUE(pkt->flags & FFV2_CANCEL);
EXPECT_TRUE(pkt->flags & FFV2_LAST_CHUNK);
}

View File

@@ -0,0 +1,778 @@
/**
* @file ResumeStateTest.cpp
* @brief 断点续传状态管理测试
*
* 测试覆盖:
* - 续传状态序列化/反序列化
* - 续传请求/响应包构建
* - 状态文件格式
* - 多文件续传管理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <map>
#include <string>
#include <cstdint>
// ============================================
// 协议结构(测试专用)
// ============================================
#pragma pack(push, 1)
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
};
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
};
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
};
struct FileResumeResponseV2 {
uint8_t cmd;
uint64_t srcClientID;
uint64_t dstClientID;
uint16_t flags;
uint32_t fileCount;
};
struct FileResumeResponseEntryV2 {
uint32_t fileIndex;
uint64_t receivedBytes;
};
#pragma pack(pop)
// ============================================
// 续传状态管理类(测试专用实现)
// ============================================
struct FileResumeEntry {
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
std::string relativePath;
std::vector<std::pair<uint64_t, uint64_t>> receivedRanges;
};
class ResumeStateManager {
public:
ResumeStateManager() : m_transferID(0), m_srcClientID(0), m_dstClientID(0) {}
void Initialize(uint64_t transferID, uint64_t srcClientID, uint64_t dstClientID,
const std::string& targetDir) {
m_transferID = transferID;
m_srcClientID = srcClientID;
m_dstClientID = dstClientID;
m_targetDir = targetDir;
m_entries.clear();
}
void AddFile(uint32_t fileIndex, uint64_t fileSize, const std::string& path) {
FileResumeEntry entry;
entry.fileIndex = fileIndex;
entry.fileSize = fileSize;
entry.receivedBytes = 0;
entry.relativePath = path;
m_entries.push_back(entry);
}
void UpdateProgress(uint32_t fileIndex, uint64_t offset, uint64_t length) {
for (auto& entry : m_entries) {
if (entry.fileIndex == fileIndex) {
entry.receivedRanges.emplace_back(offset, length);
entry.receivedBytes += length;
break;
}
}
}
bool GetFileState(uint32_t fileIndex, FileResumeEntry& outEntry) const {
for (const auto& entry : m_entries) {
if (entry.fileIndex == fileIndex) {
outEntry = entry;
return true;
}
}
return false;
}
// 序列化为字节流
std::vector<uint8_t> Serialize() const {
std::vector<uint8_t> buffer;
// Header
auto appendU64 = [&buffer](uint64_t val) {
for (int i = 0; i < 8; ++i) {
buffer.push_back(static_cast<uint8_t>(val >> (i * 8)));
}
};
auto appendU32 = [&buffer](uint32_t val) {
for (int i = 0; i < 4; ++i) {
buffer.push_back(static_cast<uint8_t>(val >> (i * 8)));
}
};
auto appendU16 = [&buffer](uint16_t val) {
buffer.push_back(static_cast<uint8_t>(val & 0xFF));
buffer.push_back(static_cast<uint8_t>(val >> 8));
};
auto appendString = [&buffer, &appendU16](const std::string& str) {
appendU16(static_cast<uint16_t>(str.size()));
buffer.insert(buffer.end(), str.begin(), str.end());
};
// Magic
buffer.push_back('R');
buffer.push_back('S');
buffer.push_back('T');
buffer.push_back('V'); // Resume State V2
appendU64(m_transferID);
appendU64(m_srcClientID);
appendU64(m_dstClientID);
appendString(m_targetDir);
appendU32(static_cast<uint32_t>(m_entries.size()));
for (const auto& entry : m_entries) {
appendU32(entry.fileIndex);
appendU64(entry.fileSize);
appendU64(entry.receivedBytes);
appendString(entry.relativePath);
appendU16(static_cast<uint16_t>(entry.receivedRanges.size()));
for (const auto& range : entry.receivedRanges) {
appendU64(range.first);
appendU64(range.second);
}
}
return buffer;
}
// 从字节流反序列化
bool Deserialize(const std::vector<uint8_t>& buffer) {
if (buffer.size() < 8) return false;
size_t pos = 0;
auto readU64 = [&buffer, &pos]() -> uint64_t {
if (pos + 8 > buffer.size()) return 0;
uint64_t val = 0;
for (int i = 0; i < 8; ++i) {
val |= static_cast<uint64_t>(buffer[pos++]) << (i * 8);
}
return val;
};
auto readU32 = [&buffer, &pos]() -> uint32_t {
if (pos + 4 > buffer.size()) return 0;
uint32_t val = 0;
for (int i = 0; i < 4; ++i) {
val |= static_cast<uint32_t>(buffer[pos++]) << (i * 8);
}
return val;
};
auto readU16 = [&buffer, &pos]() -> uint16_t {
if (pos + 2 > buffer.size()) return 0;
uint16_t val = buffer[pos] | (buffer[pos + 1] << 8);
pos += 2;
return val;
};
auto readString = [&buffer, &pos, &readU16]() -> std::string {
uint16_t len = readU16();
if (pos + len > buffer.size()) return "";
std::string str(buffer.begin() + pos, buffer.begin() + pos + len);
pos += len;
return str;
};
// Check magic
if (buffer[0] != 'R' || buffer[1] != 'S' || buffer[2] != 'T' || buffer[3] != 'V') {
return false;
}
pos = 4;
m_transferID = readU64();
m_srcClientID = readU64();
m_dstClientID = readU64();
m_targetDir = readString();
uint32_t entryCount = readU32();
m_entries.clear();
for (uint32_t i = 0; i < entryCount; ++i) {
FileResumeEntry entry;
entry.fileIndex = readU32();
entry.fileSize = readU64();
entry.receivedBytes = readU64();
entry.relativePath = readString();
uint16_t rangeCount = readU16();
for (uint16_t j = 0; j < rangeCount; ++j) {
uint64_t offset = readU64();
uint64_t length = readU64();
entry.receivedRanges.emplace_back(offset, length);
}
m_entries.push_back(entry);
}
return true;
}
uint64_t GetTransferID() const { return m_transferID; }
uint64_t GetSrcClientID() const { return m_srcClientID; }
uint64_t GetDstClientID() const { return m_dstClientID; }
const std::string& GetTargetDir() const { return m_targetDir; }
size_t GetFileCount() const { return m_entries.size(); }
// 获取所有文件的接收偏移映射
std::map<uint32_t, uint64_t> GetReceivedOffsets() const {
std::map<uint32_t, uint64_t> offsets;
for (const auto& entry : m_entries) {
offsets[entry.fileIndex] = entry.receivedBytes;
}
return offsets;
}
private:
uint64_t m_transferID;
uint64_t m_srcClientID;
uint64_t m_dstClientID;
std::string m_targetDir;
std::vector<FileResumeEntry> m_entries;
};
// ============================================
// 续传包构建/解析辅助函数
// ============================================
std::vector<uint8_t> BuildResumeRequest(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint64_t fileSize,
uint64_t receivedBytes,
const std::vector<std::pair<uint64_t, uint64_t>>& ranges)
{
size_t size = sizeof(FileResumePacketV2) + ranges.size() * sizeof(FileRangeV2);
std::vector<uint8_t> buffer(size);
FileResumePacketV2* pkt = reinterpret_cast<FileResumePacketV2*>(buffer.data());
pkt->cmd = 86; // COMMAND_FILE_RESUME
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->fileSize = fileSize;
pkt->receivedBytes = receivedBytes;
pkt->flags = FFV2_RESUME_REQ;
pkt->rangeCount = static_cast<uint16_t>(ranges.size());
FileRangeV2* rangePtr = reinterpret_cast<FileRangeV2*>(buffer.data() + sizeof(FileResumePacketV2));
for (size_t i = 0; i < ranges.size(); ++i) {
rangePtr[i].offset = ranges[i].first;
rangePtr[i].length = ranges[i].second;
}
return buffer;
}
bool ParseResumeRequest(
const uint8_t* buffer, size_t len,
FileResumePacketV2& header,
std::vector<std::pair<uint64_t, uint64_t>>& ranges)
{
if (len < sizeof(FileResumePacketV2)) {
return false;
}
memcpy(&header, buffer, sizeof(FileResumePacketV2));
size_t expectedLen = sizeof(FileResumePacketV2) + header.rangeCount * sizeof(FileRangeV2);
if (len < expectedLen) {
return false;
}
ranges.clear();
const FileRangeV2* rangePtr = reinterpret_cast<const FileRangeV2*>(buffer + sizeof(FileResumePacketV2));
for (uint16_t i = 0; i < header.rangeCount; ++i) {
ranges.emplace_back(rangePtr[i].offset, rangePtr[i].length);
}
return true;
}
std::vector<uint8_t> BuildResumeResponse(
uint64_t srcClientID,
uint64_t dstClientID,
const std::map<uint32_t, uint64_t>& offsets)
{
size_t size = sizeof(FileResumeResponseV2) + offsets.size() * sizeof(FileResumeResponseEntryV2);
std::vector<uint8_t> buffer(size);
FileResumeResponseV2* pkt = reinterpret_cast<FileResumeResponseV2*>(buffer.data());
pkt->cmd = 86; // COMMAND_FILE_RESUME
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->flags = FFV2_RESUME_RESP;
pkt->fileCount = static_cast<uint32_t>(offsets.size());
FileResumeResponseEntryV2* entryPtr = reinterpret_cast<FileResumeResponseEntryV2*>(
buffer.data() + sizeof(FileResumeResponseV2));
size_t i = 0;
for (const auto& kv : offsets) {
entryPtr[i].fileIndex = kv.first;
entryPtr[i].receivedBytes = kv.second;
++i;
}
return buffer;
}
bool ParseResumeResponse(
const uint8_t* buffer, size_t len,
std::map<uint32_t, uint64_t>& offsets)
{
if (len < sizeof(FileResumeResponseV2)) {
return false;
}
const FileResumeResponseV2* pkt = reinterpret_cast<const FileResumeResponseV2*>(buffer);
if ((pkt->flags & FFV2_RESUME_RESP) == 0) {
return false;
}
size_t expectedLen = sizeof(FileResumeResponseV2) + pkt->fileCount * sizeof(FileResumeResponseEntryV2);
if (len < expectedLen) {
return false;
}
offsets.clear();
const FileResumeResponseEntryV2* entryPtr = reinterpret_cast<const FileResumeResponseEntryV2*>(
buffer + sizeof(FileResumeResponseV2));
for (uint32_t i = 0; i < pkt->fileCount; ++i) {
offsets[entryPtr[i].fileIndex] = entryPtr[i].receivedBytes;
}
return true;
}
// ============================================
// ResumeStateManager 测试
// ============================================
class ResumeStateManagerTest : public ::testing::Test {};
TEST_F(ResumeStateManagerTest, Initialize) {
ResumeStateManager mgr;
mgr.Initialize(12345, 100, 200, "C:\\Downloads\\");
EXPECT_EQ(mgr.GetTransferID(), 12345u);
EXPECT_EQ(mgr.GetSrcClientID(), 100u);
EXPECT_EQ(mgr.GetDstClientID(), 200u);
EXPECT_EQ(mgr.GetTargetDir(), "C:\\Downloads\\");
EXPECT_EQ(mgr.GetFileCount(), 0u);
}
TEST_F(ResumeStateManagerTest, AddFiles) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 1000, "file1.txt");
mgr.AddFile(1, 2000, "subdir/file2.bin");
mgr.AddFile(2, 3000, "another/path/file3.dat");
EXPECT_EQ(mgr.GetFileCount(), 3u);
FileResumeEntry entry;
ASSERT_TRUE(mgr.GetFileState(1, entry));
EXPECT_EQ(entry.fileSize, 2000u);
EXPECT_EQ(entry.relativePath, "subdir/file2.bin");
}
TEST_F(ResumeStateManagerTest, UpdateProgress) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 10000, "file.bin");
mgr.UpdateProgress(0, 0, 2000);
mgr.UpdateProgress(0, 2000, 3000);
FileResumeEntry entry;
ASSERT_TRUE(mgr.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 5000u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
}
TEST_F(ResumeStateManagerTest, GetReceivedOffsets) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 1000, "a.txt");
mgr.AddFile(1, 2000, "b.txt");
mgr.AddFile(2, 3000, "c.txt");
mgr.UpdateProgress(0, 0, 500);
mgr.UpdateProgress(1, 0, 1500);
mgr.UpdateProgress(2, 0, 2500);
auto offsets = mgr.GetReceivedOffsets();
EXPECT_EQ(offsets.size(), 3u);
EXPECT_EQ(offsets[0], 500u);
EXPECT_EQ(offsets[1], 1500u);
EXPECT_EQ(offsets[2], 2500u);
}
// ============================================
// 序列化/反序列化测试
// ============================================
class ResumeSerializationTest : public ::testing::Test {};
TEST_F(ResumeSerializationTest, EmptyState) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(12345, 100, 200, "C:\\Target\\");
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetTransferID(), 12345u);
EXPECT_EQ(mgr2.GetSrcClientID(), 100u);
EXPECT_EQ(mgr2.GetDstClientID(), 200u);
EXPECT_EQ(mgr2.GetTargetDir(), "C:\\Target\\");
EXPECT_EQ(mgr2.GetFileCount(), 0u);
}
TEST_F(ResumeSerializationTest, SingleFile) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(1, 0, 0, "/tmp/download/");
mgr1.AddFile(0, 10000, "test.bin");
mgr1.UpdateProgress(0, 0, 5000);
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 1u);
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.fileSize, 10000u);
EXPECT_EQ(entry.receivedBytes, 5000u);
EXPECT_EQ(entry.relativePath, "test.bin");
}
TEST_F(ResumeSerializationTest, MultipleFiles) {
ResumeStateManager mgr1, mgr2;
mgr1.Initialize(99999, 111, 222, "D:\\Backup\\");
mgr1.AddFile(0, 1000, "file1.txt");
mgr1.AddFile(1, 2000, "dir/file2.bin");
mgr1.AddFile(2, 3000, "path/to/file3.dat");
mgr1.UpdateProgress(0, 0, 1000); // 完成
mgr1.UpdateProgress(1, 0, 500);
mgr1.UpdateProgress(1, 500, 500);
mgr1.UpdateProgress(2, 0, 1000);
mgr1.UpdateProgress(2, 2000, 500);
auto buffer = mgr1.Serialize();
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 3u);
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 1000u);
ASSERT_TRUE(mgr2.GetFileState(1, entry));
EXPECT_EQ(entry.receivedBytes, 1000u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
ASSERT_TRUE(mgr2.GetFileState(2, entry));
EXPECT_EQ(entry.receivedBytes, 1500u);
EXPECT_EQ(entry.receivedRanges.size(), 2u);
}
TEST_F(ResumeSerializationTest, InvalidMagic) {
std::vector<uint8_t> invalidBuffer = {'X', 'X', 'X', 'X', 0, 0, 0, 0};
ResumeStateManager mgr;
EXPECT_FALSE(mgr.Deserialize(invalidBuffer));
}
TEST_F(ResumeSerializationTest, TruncatedBuffer) {
ResumeStateManager mgr1;
mgr1.Initialize(1, 0, 0, "test");
mgr1.AddFile(0, 1000, "file.txt");
auto buffer = mgr1.Serialize();
// 截断
std::vector<uint8_t> truncated(buffer.begin(), buffer.begin() + 10);
ResumeStateManager mgr2;
// 可能成功也可能失败,取决于截断位置
// 主要验证不会崩溃
mgr2.Deserialize(truncated);
SUCCEED();
}
// ============================================
// 续传请求/响应包测试
// ============================================
class ResumePacketTest : public ::testing::Test {};
TEST_F(ResumePacketTest, BuildResumeRequest_NoRanges) {
std::vector<std::pair<uint64_t, uint64_t>> ranges;
auto buffer = BuildResumeRequest(12345, 100, 200, 5, 10000, 0, ranges);
EXPECT_EQ(buffer.size(), sizeof(FileResumePacketV2));
const FileResumePacketV2* pkt = reinterpret_cast<const FileResumePacketV2*>(buffer.data());
EXPECT_EQ(pkt->cmd, 86);
EXPECT_EQ(pkt->flags, FFV2_RESUME_REQ);
EXPECT_EQ(pkt->rangeCount, 0);
}
TEST_F(ResumePacketTest, BuildResumeRequest_WithRanges) {
std::vector<std::pair<uint64_t, uint64_t>> ranges = {
{0, 1000},
{2000, 500},
{5000, 2000}
};
auto buffer = BuildResumeRequest(1, 0, 0, 0, 10000, 3500, ranges);
FileResumePacketV2 header;
std::vector<std::pair<uint64_t, uint64_t>> parsedRanges;
ASSERT_TRUE(ParseResumeRequest(buffer.data(), buffer.size(), header, parsedRanges));
EXPECT_EQ(header.fileSize, 10000u);
EXPECT_EQ(header.receivedBytes, 3500u);
ASSERT_EQ(parsedRanges.size(), 3u);
EXPECT_EQ(parsedRanges[0].first, 0u);
EXPECT_EQ(parsedRanges[0].second, 1000u);
EXPECT_EQ(parsedRanges[2].first, 5000u);
}
TEST_F(ResumePacketTest, BuildResumeResponse) {
std::map<uint32_t, uint64_t> offsets = {
{0, 1000},
{1, 0},
{2, 5000}
};
auto buffer = BuildResumeResponse(100, 200, offsets);
std::map<uint32_t, uint64_t> parsedOffsets;
ASSERT_TRUE(ParseResumeResponse(buffer.data(), buffer.size(), parsedOffsets));
EXPECT_EQ(parsedOffsets.size(), 3u);
EXPECT_EQ(parsedOffsets[0], 1000u);
EXPECT_EQ(parsedOffsets[1], 0u);
EXPECT_EQ(parsedOffsets[2], 5000u);
}
TEST_F(ResumePacketTest, ParseTruncatedRequest) {
auto buffer = BuildResumeRequest(1, 0, 0, 0, 1000, 0, {{0, 500}});
// 截断
FileResumePacketV2 header;
std::vector<std::pair<uint64_t, uint64_t>> ranges;
EXPECT_FALSE(ParseResumeRequest(buffer.data(), sizeof(FileResumePacketV2) - 5, header, ranges));
}
TEST_F(ResumePacketTest, ParseTruncatedResponse) {
std::map<uint32_t, uint64_t> offsets = {{0, 1000}};
auto buffer = BuildResumeResponse(0, 0, offsets);
// 截断
std::map<uint32_t, uint64_t> parsedOffsets;
EXPECT_FALSE(ParseResumeResponse(buffer.data(), sizeof(FileResumeResponseV2) - 5, parsedOffsets));
}
// ============================================
// 续传场景测试
// ============================================
class ResumeScenarioTest : public ::testing::Test {};
TEST_F(ResumeScenarioTest, SimulateInterruptAndResume) {
// 第一次传输,接收了部分数据
ResumeStateManager session1;
session1.Initialize(12345, 100, 0, "C:\\Downloads\\");
session1.AddFile(0, 10000, "large_file.bin");
session1.AddFile(1, 5000, "small_file.txt");
session1.UpdateProgress(0, 0, 3000); // 30%
session1.UpdateProgress(1, 0, 5000); // 100%
// 保存状态
auto savedState = session1.Serialize();
// 模拟程序重启,恢复状态
ResumeStateManager session2;
ASSERT_TRUE(session2.Deserialize(savedState));
// 获取续传偏移
auto offsets = session2.GetReceivedOffsets();
EXPECT_EQ(offsets[0], 3000u); // 从 3000 继续
EXPECT_EQ(offsets[1], 5000u); // 已完成
// 继续传输
session2.UpdateProgress(0, 3000, 7000);
FileResumeEntry entry;
ASSERT_TRUE(session2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedBytes, 10000u);
}
TEST_F(ResumeScenarioTest, C2CResumeNegotiation) {
// 源客户端查询续传状态
std::vector<std::pair<std::string, uint64_t>> files = {
{"file1.txt", 1000},
{"file2.txt", 2000},
{"file3.txt", 3000}
};
// 目标客户端已有部分数据
std::map<uint32_t, uint64_t> receivedOffsets = {
{0, 500}, // file1: 50%
{1, 2000}, // file2: 100%
{2, 0} // file3: 0%
};
// 构建响应
auto response = BuildResumeResponse(100, 200, receivedOffsets);
// 源客户端解析响应
std::map<uint32_t, uint64_t> parsedOffsets;
ASSERT_TRUE(ParseResumeResponse(response.data(), response.size(), parsedOffsets));
// 根据偏移决定发送策略
for (size_t i = 0; i < files.size(); ++i) {
uint32_t fileIndex = static_cast<uint32_t>(i);
uint64_t startOffset = parsedOffsets[fileIndex];
if (startOffset >= files[i].second) {
// 已完成,跳过
EXPECT_EQ(i, 1u) << "Only file2 should be complete";
} else if (startOffset > 0) {
// 部分完成,从偏移继续
EXPECT_EQ(i, 0u) << "Only file1 should be partial";
EXPECT_EQ(startOffset, 500u);
} else {
// 未开始,从头发送
EXPECT_EQ(i, 2u) << "Only file3 should start from beginning";
}
}
}
TEST_F(ResumeScenarioTest, LargeFileResume) {
ResumeStateManager mgr;
uint64_t fileSize = 10ULL * 1024 * 1024 * 1024; // 10 GB
mgr.Initialize(1, 0, 0, "/data/");
mgr.AddFile(0, fileSize, "huge.bin");
// 模拟接收了 5 GB
mgr.UpdateProgress(0, 0, 5ULL * 1024 * 1024 * 1024);
auto offsets = mgr.GetReceivedOffsets();
EXPECT_EQ(offsets[0], 5ULL * 1024 * 1024 * 1024);
// 序列化/反序列化
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.fileSize, 10ULL * 1024 * 1024 * 1024);
EXPECT_EQ(entry.receivedBytes, 5ULL * 1024 * 1024 * 1024);
}
// ============================================
// 边界条件测试
// ============================================
class ResumeBoundaryTest : public ::testing::Test {};
TEST_F(ResumeBoundaryTest, EmptyFileName) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 100, "");
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.relativePath, "");
}
TEST_F(ResumeBoundaryTest, SpecialCharactersInPath) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "C:\\My Files\\测试目录\\");
mgr.AddFile(0, 100, "文件 (1).txt");
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetTargetDir(), "C:\\My Files\\测试目录\\");
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.relativePath, "文件 (1).txt");
}
TEST_F(ResumeBoundaryTest, ManyRanges) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
mgr.AddFile(0, 100000, "fragmented.bin");
// 添加很多小区间
for (uint64_t i = 0; i < 100000; i += 20) {
mgr.UpdateProgress(0, i, 10);
}
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
FileResumeEntry entry;
ASSERT_TRUE(mgr2.GetFileState(0, entry));
EXPECT_EQ(entry.receivedRanges.size(), 5000u);
}
TEST_F(ResumeBoundaryTest, MaxFileCount) {
ResumeStateManager mgr;
mgr.Initialize(1, 0, 0, "");
// 添加大量文件
for (uint32_t i = 0; i < 1000; ++i) {
mgr.AddFile(i, 100 * (i + 1), "file" + std::to_string(i) + ".txt");
}
auto buffer = mgr.Serialize();
ResumeStateManager mgr2;
ASSERT_TRUE(mgr2.Deserialize(buffer));
EXPECT_EQ(mgr2.GetFileCount(), 1000u);
}

View File

@@ -0,0 +1,647 @@
/**
* @file SHA256VerifyTest.cpp
* @brief SHA-256 文件校验测试
*
* 测试覆盖:
* - SHA-256 哈希计算
* - FileCompletePacketV2 构建与解析
* - 文件完整性校验逻辑
* - 校验失败处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <array>
#include <string>
#include <iomanip>
#include <sstream>
// ============================================
// 简化版 SHA-256 实现(测试专用)
// 生产环境使用 OpenSSL 或 Windows CNG
// ============================================
class SHA256 {
public:
SHA256() { Reset(); }
void Reset() {
m_state[0] = 0x6a09e667;
m_state[1] = 0xbb67ae85;
m_state[2] = 0x3c6ef372;
m_state[3] = 0xa54ff53a;
m_state[4] = 0x510e527f;
m_state[5] = 0x9b05688c;
m_state[6] = 0x1f83d9ab;
m_state[7] = 0x5be0cd19;
m_bitLen = 0;
m_bufferLen = 0;
}
void Update(const uint8_t* data, size_t len) {
for (size_t i = 0; i < len; ++i) {
m_buffer[m_bufferLen++] = data[i];
if (m_bufferLen == 64) {
Transform();
m_bitLen += 512;
m_bufferLen = 0;
}
}
}
void Update(const std::vector<uint8_t>& data) {
Update(data.data(), data.size());
}
std::array<uint8_t, 32> Finalize() {
size_t i = m_bufferLen;
// Pad
if (m_bufferLen < 56) {
m_buffer[i++] = 0x80;
while (i < 56) m_buffer[i++] = 0x00;
} else {
m_buffer[i++] = 0x80;
while (i < 64) m_buffer[i++] = 0x00;
Transform();
memset(m_buffer, 0, 56);
}
// Append length
m_bitLen += m_bufferLen * 8;
m_buffer[63] = static_cast<uint8_t>(m_bitLen);
m_buffer[62] = static_cast<uint8_t>(m_bitLen >> 8);
m_buffer[61] = static_cast<uint8_t>(m_bitLen >> 16);
m_buffer[60] = static_cast<uint8_t>(m_bitLen >> 24);
m_buffer[59] = static_cast<uint8_t>(m_bitLen >> 32);
m_buffer[58] = static_cast<uint8_t>(m_bitLen >> 40);
m_buffer[57] = static_cast<uint8_t>(m_bitLen >> 48);
m_buffer[56] = static_cast<uint8_t>(m_bitLen >> 56);
Transform();
// Output
std::array<uint8_t, 32> hash;
for (int j = 0; j < 8; ++j) {
hash[j * 4 + 0] = (m_state[j] >> 24) & 0xff;
hash[j * 4 + 1] = (m_state[j] >> 16) & 0xff;
hash[j * 4 + 2] = (m_state[j] >> 8) & 0xff;
hash[j * 4 + 3] = m_state[j] & 0xff;
}
return hash;
}
static std::string HashToHex(const std::array<uint8_t, 32>& hash) {
std::ostringstream ss;
for (uint8_t b : hash) {
ss << std::hex << std::setfill('0') << std::setw(2) << static_cast<int>(b);
}
return ss.str();
}
private:
static uint32_t RotR(uint32_t x, uint32_t n) { return (x >> n) | (x << (32 - n)); }
static uint32_t Ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (~x & z); }
static uint32_t Maj(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
static uint32_t Sig0(uint32_t x) { return RotR(x, 2) ^ RotR(x, 13) ^ RotR(x, 22); }
static uint32_t Sig1(uint32_t x) { return RotR(x, 6) ^ RotR(x, 11) ^ RotR(x, 25); }
static uint32_t sig0(uint32_t x) { return RotR(x, 7) ^ RotR(x, 18) ^ (x >> 3); }
static uint32_t sig1(uint32_t x) { return RotR(x, 17) ^ RotR(x, 19) ^ (x >> 10); }
void Transform() {
static const uint32_t K[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};
uint32_t W[64];
for (int i = 0; i < 16; ++i) {
W[i] = (m_buffer[i * 4] << 24) | (m_buffer[i * 4 + 1] << 16) |
(m_buffer[i * 4 + 2] << 8) | m_buffer[i * 4 + 3];
}
for (int i = 16; i < 64; ++i) {
W[i] = sig1(W[i - 2]) + W[i - 7] + sig0(W[i - 15]) + W[i - 16];
}
uint32_t a = m_state[0], b = m_state[1], c = m_state[2], d = m_state[3];
uint32_t e = m_state[4], f = m_state[5], g = m_state[6], h = m_state[7];
for (int i = 0; i < 64; ++i) {
uint32_t t1 = h + Sig1(e) + Ch(e, f, g) + K[i] + W[i];
uint32_t t2 = Sig0(a) + Maj(a, b, c);
h = g; g = f; f = e; e = d + t1;
d = c; c = b; b = a; a = t1 + t2;
}
m_state[0] += a; m_state[1] += b; m_state[2] += c; m_state[3] += d;
m_state[4] += e; m_state[5] += f; m_state[6] += g; m_state[7] += h;
}
uint32_t m_state[8];
uint8_t m_buffer[64];
uint64_t m_bitLen;
size_t m_bufferLen;
};
// ============================================
// 协议结构(测试专用)
// ============================================
#pragma pack(push, 1)
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
};
enum FileErrorV2 : uint8_t {
FEV2_OK = 0,
FEV2_TARGET_OFFLINE = 1,
FEV2_VERSION_MISMATCH = 2,
FEV2_FILE_NOT_FOUND = 3,
FEV2_ACCESS_DENIED = 4,
FEV2_DISK_FULL = 5,
FEV2_TRANSFER_CANCEL = 6,
FEV2_CHECKSUM_ERROR = 7,
FEV2_HASH_MISMATCH = 8,
};
#pragma pack(pop)
// ============================================
// 辅助函数
// ============================================
std::vector<uint8_t> BuildFileCompletePacket(
uint64_t transferID,
uint64_t srcClientID,
uint64_t dstClientID,
uint32_t fileIndex,
uint64_t fileSize,
const std::array<uint8_t, 32>& sha256)
{
std::vector<uint8_t> buffer(sizeof(FileCompletePacketV2));
FileCompletePacketV2* pkt = reinterpret_cast<FileCompletePacketV2*>(buffer.data());
pkt->cmd = 91; // COMMAND_FILE_COMPLETE_V2
pkt->transferID = transferID;
pkt->srcClientID = srcClientID;
pkt->dstClientID = dstClientID;
pkt->fileIndex = fileIndex;
pkt->fileSize = fileSize;
memcpy(pkt->sha256, sha256.data(), 32);
return buffer;
}
bool ParseFileCompletePacket(
const uint8_t* buffer, size_t len,
FileCompletePacketV2& pkt)
{
if (len < sizeof(FileCompletePacketV2)) {
return false;
}
memcpy(&pkt, buffer, sizeof(FileCompletePacketV2));
return true;
}
// 校验文件完整性
FileErrorV2 VerifyFileIntegrity(
const std::array<uint8_t, 32>& expectedHash,
const std::vector<uint8_t>& fileData)
{
SHA256 sha;
sha.Update(fileData);
auto actualHash = sha.Finalize();
if (actualHash == expectedHash) {
return FEV2_OK;
}
return FEV2_HASH_MISMATCH;
}
// ============================================
// SHA-256 基础测试
// ============================================
class SHA256Test : public ::testing::Test {};
TEST_F(SHA256Test, EmptyString) {
SHA256 sha;
auto hash = sha.Finalize();
// SHA-256("") = e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
}
TEST_F(SHA256Test, HelloWorld) {
SHA256 sha;
const char* msg = "Hello, World!";
sha.Update(reinterpret_cast<const uint8_t*>(msg), strlen(msg));
auto hash = sha.Finalize();
// SHA-256("Hello, World!") = dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f");
}
TEST_F(SHA256Test, ABC) {
SHA256 sha;
const char* msg = "abc";
sha.Update(reinterpret_cast<const uint8_t*>(msg), strlen(msg));
auto hash = sha.Finalize();
// SHA-256("abc") = ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad");
}
TEST_F(SHA256Test, IncrementalUpdate) {
SHA256 sha1, sha2;
const char* part1 = "Hello, ";
const char* part2 = "World!";
const char* full = "Hello, World!";
sha1.Update(reinterpret_cast<const uint8_t*>(part1), strlen(part1));
sha1.Update(reinterpret_cast<const uint8_t*>(part2), strlen(part2));
sha2.Update(reinterpret_cast<const uint8_t*>(full), strlen(full));
auto hash1 = sha1.Finalize();
auto hash2 = sha2.Finalize();
EXPECT_EQ(hash1, hash2);
}
TEST_F(SHA256Test, LargeData) {
SHA256 sha;
// 1 MB of zeros
std::vector<uint8_t> data(1024 * 1024, 0);
sha.Update(data);
auto hash = sha.Finalize();
// 验证计算成功(不崩溃)
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(SHA256Test, BinaryData) {
SHA256 sha;
std::vector<uint8_t> data = {0x00, 0x01, 0x02, 0x03, 0xff, 0xfe, 0xfd, 0xfc};
sha.Update(data);
auto hash = sha.Finalize();
// 验证计算成功
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(SHA256Test, Reset) {
SHA256 sha;
sha.Update(reinterpret_cast<const uint8_t*>("test"), 4);
sha.Reset();
auto hash = sha.Finalize();
// 重置后应该等于空字符串的哈希
std::string hex = SHA256::HashToHex(hash);
EXPECT_EQ(hex, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
}
// ============================================
// FileCompletePacketV2 测试
// ============================================
class FileCompletePacketTest : public ::testing::Test {};
TEST_F(FileCompletePacketTest, BuildAndParse) {
std::array<uint8_t, 32> sha256 = {};
sha256[0] = 0xAA;
sha256[31] = 0xBB;
auto buffer = BuildFileCompletePacket(12345, 100, 200, 5, 1024 * 1024, sha256);
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(buffer.data(), buffer.size(), pkt));
EXPECT_EQ(pkt.cmd, 91);
EXPECT_EQ(pkt.transferID, 12345u);
EXPECT_EQ(pkt.srcClientID, 100u);
EXPECT_EQ(pkt.dstClientID, 200u);
EXPECT_EQ(pkt.fileIndex, 5u);
EXPECT_EQ(pkt.fileSize, 1024u * 1024u);
EXPECT_EQ(pkt.sha256[0], 0xAA);
EXPECT_EQ(pkt.sha256[31], 0xBB);
}
TEST_F(FileCompletePacketTest, TruncatedPacket) {
std::array<uint8_t, 32> sha256 = {};
auto buffer = BuildFileCompletePacket(1, 0, 0, 0, 100, sha256);
// 截断
FileCompletePacketV2 pkt;
EXPECT_FALSE(ParseFileCompletePacket(buffer.data(), sizeof(FileCompletePacketV2) - 10, pkt));
}
TEST_F(FileCompletePacketTest, LargeFileSize) {
std::array<uint8_t, 32> sha256 = {};
uint64_t largeSize = 100ULL * 1024 * 1024 * 1024; // 100 GB
auto buffer = BuildFileCompletePacket(1, 0, 0, 0, largeSize, sha256);
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(buffer.data(), buffer.size(), pkt));
EXPECT_EQ(pkt.fileSize, largeSize);
}
// ============================================
// 文件完整性校验测试
// ============================================
class FileIntegrityTest : public ::testing::Test {};
TEST_F(FileIntegrityTest, CorrectHash) {
std::vector<uint8_t> fileData = {'H', 'e', 'l', 'l', 'o'};
SHA256 sha;
sha.Update(fileData);
auto expectedHash = sha.Finalize();
auto result = VerifyFileIntegrity(expectedHash, fileData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(FileIntegrityTest, IncorrectHash) {
std::vector<uint8_t> fileData = {'H', 'e', 'l', 'l', 'o'};
std::array<uint8_t, 32> wrongHash = {}; // 全零的错误哈希
auto result = VerifyFileIntegrity(wrongHash, fileData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(FileIntegrityTest, CorruptedData) {
std::vector<uint8_t> originalData = {'H', 'e', 'l', 'l', 'o'};
SHA256 sha;
sha.Update(originalData);
auto originalHash = sha.Finalize();
// 修改一个字节
std::vector<uint8_t> corruptedData = originalData;
corruptedData[2] = 'X';
auto result = VerifyFileIntegrity(originalHash, corruptedData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(FileIntegrityTest, EmptyFile) {
std::vector<uint8_t> emptyData;
SHA256 sha;
auto emptyHash = sha.Finalize();
auto result = VerifyFileIntegrity(emptyHash, emptyData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(FileIntegrityTest, SingleBitDifference) {
std::vector<uint8_t> data1 = {0x00};
std::vector<uint8_t> data2 = {0x01}; // 单比特差异
SHA256 sha1, sha2;
sha1.Update(data1);
sha2.Update(data2);
auto hash1 = sha1.Finalize();
auto hash2 = sha2.Finalize();
// 哈希应该完全不同
EXPECT_NE(hash1, hash2);
// 验证错误检测
auto result = VerifyFileIntegrity(hash1, data2);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
// ============================================
// 流式哈希计算测试
// ============================================
class StreamingHashTest : public ::testing::Test {};
TEST_F(StreamingHashTest, ChunkedUpdate) {
std::vector<uint8_t> fullData(10000);
for (size_t i = 0; i < fullData.size(); ++i) {
fullData[i] = static_cast<uint8_t>(i & 0xFF);
}
// 一次性计算
SHA256 sha1;
sha1.Update(fullData);
auto hash1 = sha1.Finalize();
// 分块计算
SHA256 sha2;
size_t chunkSize = 64; // SHA-256 块大小
for (size_t i = 0; i < fullData.size(); i += chunkSize) {
size_t len = std::min(chunkSize, fullData.size() - i);
sha2.Update(fullData.data() + i, len);
}
auto hash2 = sha2.Finalize();
EXPECT_EQ(hash1, hash2);
}
TEST_F(StreamingHashTest, VariableChunkSizes) {
std::vector<uint8_t> data(1000);
for (size_t i = 0; i < data.size(); ++i) {
data[i] = static_cast<uint8_t>(i);
}
SHA256 sha1;
sha1.Update(data);
auto expected = sha1.Finalize();
// 不同大小的块
std::vector<size_t> chunkSizes = {1, 7, 63, 64, 65, 128, 1000};
for (size_t chunkSize : chunkSizes) {
SHA256 sha2;
for (size_t i = 0; i < data.size(); i += chunkSize) {
size_t len = std::min(chunkSize, data.size() - i);
sha2.Update(data.data() + i, len);
}
auto actual = sha2.Finalize();
EXPECT_EQ(expected, actual) << "Failed for chunk size: " << chunkSize;
}
}
// ============================================
// 文件传输完整性验证场景
// ============================================
class TransferVerificationTest : public ::testing::Test {};
TEST_F(TransferVerificationTest, SimulateSuccessfulTransfer) {
// 模拟文件数据
std::vector<uint8_t> fileData(1024 * 64); // 64 KB
for (size_t i = 0; i < fileData.size(); ++i) {
fileData[i] = static_cast<uint8_t>(i * 17); // 伪随机数据
}
// 发送方计算哈希
SHA256 senderSha;
senderSha.Update(fileData);
auto senderHash = senderSha.Finalize();
// 构建校验包
auto packet = BuildFileCompletePacket(12345, 100, 0, 0, fileData.size(), senderHash);
// 接收方解析并验证
FileCompletePacketV2 pkt;
ASSERT_TRUE(ParseFileCompletePacket(packet.data(), packet.size(), pkt));
EXPECT_EQ(pkt.fileSize, fileData.size());
// 接收方计算哈希并比较
std::array<uint8_t, 32> expectedHash;
memcpy(expectedHash.data(), pkt.sha256, 32);
auto result = VerifyFileIntegrity(expectedHash, fileData);
EXPECT_EQ(result, FEV2_OK);
}
TEST_F(TransferVerificationTest, SimulateCorruptedTransfer) {
// 原始文件
std::vector<uint8_t> originalData(1024);
for (size_t i = 0; i < originalData.size(); ++i) {
originalData[i] = static_cast<uint8_t>(i);
}
// 发送方计算哈希
SHA256 senderSha;
senderSha.Update(originalData);
auto senderHash = senderSha.Finalize();
// 模拟传输中数据损坏
std::vector<uint8_t> receivedData = originalData;
receivedData[512] ^= 0xFF; // 翻转一个字节
// 校验失败
auto result = VerifyFileIntegrity(senderHash, receivedData);
EXPECT_EQ(result, FEV2_HASH_MISMATCH);
}
TEST_F(TransferVerificationTest, MultipleFilesInTransfer) {
struct FileInfo {
std::vector<uint8_t> data;
std::array<uint8_t, 32> hash;
};
std::vector<FileInfo> files(5);
// 生成测试文件
for (int i = 0; i < 5; ++i) {
files[i].data.resize(1000 * (i + 1));
for (size_t j = 0; j < files[i].data.size(); ++j) {
files[i].data[j] = static_cast<uint8_t>(j + i * 7);
}
SHA256 sha;
sha.Update(files[i].data);
files[i].hash = sha.Finalize();
}
// 验证每个文件
for (int i = 0; i < 5; ++i) {
auto result = VerifyFileIntegrity(files[i].hash, files[i].data);
EXPECT_EQ(result, FEV2_OK) << "File " << i << " verification failed";
}
// 交叉验证应该失败
auto crossResult = VerifyFileIntegrity(files[0].hash, files[1].data);
EXPECT_EQ(crossResult, FEV2_HASH_MISMATCH);
}
// ============================================
// 边界条件测试
// ============================================
class HashBoundaryTest : public ::testing::Test {};
TEST_F(HashBoundaryTest, ExactBlockSize) {
// SHA-256 块大小是 64 字节
std::vector<uint8_t> data(64, 'A');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, BlockSizePlusOne) {
std::vector<uint8_t> data(65, 'B');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, BlockSizeMinusOne) {
std::vector<uint8_t> data(63, 'C');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, MultipleBlocks) {
std::vector<uint8_t> data(64 * 10, 'D');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, PaddingBoundary55) {
// 55 字节需要特殊处理padding + length 刚好 64
std::vector<uint8_t> data(55, 'E');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}
TEST_F(HashBoundaryTest, PaddingBoundary56) {
// 56 字节需要额外的块
std::vector<uint8_t> data(56, 'F');
SHA256 sha;
sha.Update(data);
auto hash = sha.Finalize();
EXPECT_EQ(hash.size(), 32u);
}

View File

@@ -0,0 +1,378 @@
// GeoLocationTest.cpp - IP地理位置API单元测试
// 测试 location.h 中的 GetGeoLocation 功能
#include <gtest/gtest.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#include <wininet.h>
#include <ws2tcpip.h>
#pragma comment(lib, "wininet.lib")
#pragma comment(lib, "ws2_32.lib")
#include <string>
#include <vector>
// ============================================
// 简化的测试版本 - 不依赖jsoncpp
// ============================================
// UTF-8 转 ANSI
inline std::string Utf8ToAnsi(const std::string& utf8)
{
if (utf8.empty()) return "";
int wideLen = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, NULL, 0);
if (wideLen <= 0) return utf8;
std::wstring wideStr(wideLen, 0);
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wideStr[0], wideLen);
int ansiLen = WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, NULL, 0, NULL, NULL);
if (ansiLen <= 0) return utf8;
std::string ansiStr(ansiLen, 0);
WideCharToMultiByte(CP_ACP, 0, wideStr.c_str(), -1, &ansiStr[0], ansiLen, NULL, NULL);
if (!ansiStr.empty() && ansiStr.back() == '\0') ansiStr.pop_back();
return ansiStr;
}
// 简单JSON值提取 (仅用于测试不依赖jsoncpp)
std::string ExtractJsonString(const std::string& json, const std::string& key)
{
std::string searchKey = "\"" + key + "\"";
size_t keyPos = json.find(searchKey);
if (keyPos == std::string::npos) return "";
size_t colonPos = json.find(':', keyPos);
if (colonPos == std::string::npos) return "";
size_t valueStart = json.find_first_not_of(" \t\n\r", colonPos + 1);
if (valueStart == std::string::npos) return "";
if (json[valueStart] == '"') {
size_t valueEnd = json.find('"', valueStart + 1);
if (valueEnd == std::string::npos) return "";
return json.substr(valueStart + 1, valueEnd - valueStart - 1);
}
// 数字或布尔值
size_t valueEnd = json.find_first_of(",}\n\r", valueStart);
if (valueEnd == std::string::npos) valueEnd = json.length();
std::string value = json.substr(valueStart, valueEnd - valueStart);
// 去除尾部空格
while (!value.empty() && (value.back() == ' ' || value.back() == '\t')) {
value.pop_back();
}
return value;
}
// API配置结构
struct GeoApiConfig {
const char* name;
const char* urlFmt;
const char* cityField;
const char* countryField;
const char* checkField;
const char* checkValue;
bool useHttps;
};
// 测试用API配置
static const GeoApiConfig testApis[] = {
{"ip-api.com", "http://ip-api.com/json/%s?fields=status,country,city", "city", "country", "status", "success", false},
{"ipinfo.io", "http://ipinfo.io/%s/json", "city", "country", "", "", false},
{"ipapi.co", "https://ipapi.co/%s/json/", "city", "country_name", "error", "", true},
};
// 测试单个API
struct ApiTestResult {
bool success;
int httpStatus;
std::string city;
std::string country;
std::string error;
double latencyMs;
};
ApiTestResult TestSingleApi(const GeoApiConfig& api, const std::string& ip)
{
ApiTestResult result = {false, 0, "", "", "", 0};
DWORD startTime = GetTickCount();
HINTERNET hInternet = InternetOpenA("GeoLocationTest", INTERNET_OPEN_TYPE_DIRECT, NULL, NULL, 0);
if (!hInternet) {
result.error = "InternetOpen failed";
return result;
}
DWORD timeout = 10000;
InternetSetOptionA(hInternet, INTERNET_OPTION_CONNECT_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_SEND_TIMEOUT, &timeout, sizeof(timeout));
InternetSetOptionA(hInternet, INTERNET_OPTION_RECEIVE_TIMEOUT, &timeout, sizeof(timeout));
char urlBuf[256];
sprintf_s(urlBuf, api.urlFmt, ip.c_str());
DWORD flags = INTERNET_FLAG_RELOAD;
if (api.useHttps) flags |= INTERNET_FLAG_SECURE;
HINTERNET hConnect = InternetOpenUrlA(hInternet, urlBuf, NULL, 0, flags, 0);
if (!hConnect) {
result.error = "InternetOpenUrl failed: " + std::to_string(GetLastError());
InternetCloseHandle(hInternet);
return result;
}
// 获取HTTP状态码
DWORD statusCode = 0;
DWORD statusSize = sizeof(statusCode);
if (HttpQueryInfoA(hConnect, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &statusSize, NULL)) {
result.httpStatus = statusCode;
}
// 读取响应
std::string readBuffer;
char buffer[4096];
DWORD bytesRead;
while (InternetReadFile(hConnect, buffer, sizeof(buffer), &bytesRead) && bytesRead > 0) {
readBuffer.append(buffer, bytesRead);
}
InternetCloseHandle(hConnect);
InternetCloseHandle(hInternet);
result.latencyMs = (double)(GetTickCount() - startTime);
// 检查HTTP错误
if (result.httpStatus >= 400) {
result.error = "HTTP " + std::to_string(result.httpStatus);
if (result.httpStatus == 429) {
result.error += " (Rate Limited)";
}
return result;
}
// 检查响应体错误
if (readBuffer.find("Rate limit") != std::string::npos ||
readBuffer.find("rate limit") != std::string::npos) {
result.error = "Rate limited (body)";
return result;
}
// 解析JSON
if (api.checkField && api.checkField[0]) {
std::string checkVal = ExtractJsonString(readBuffer, api.checkField);
if (api.checkValue && api.checkValue[0]) {
if (checkVal != api.checkValue) {
result.error = "Check failed: " + std::string(api.checkField) + "=" + checkVal;
return result;
}
} else {
if (checkVal == "true") {
result.error = "Error flag set";
return result;
}
}
}
result.city = Utf8ToAnsi(ExtractJsonString(readBuffer, api.cityField));
result.country = Utf8ToAnsi(ExtractJsonString(readBuffer, api.countryField));
result.success = !result.city.empty() || !result.country.empty();
if (!result.success) {
result.error = "No city/country in response";
}
return result;
}
// ============================================
// 测试用例
// ============================================
class GeoLocationTest : public ::testing::Test {
protected:
void SetUp() override {
// 初始化Winsock
WSADATA wsaData;
WSAStartup(MAKEWORD(2, 2), &wsaData);
}
void TearDown() override {
WSACleanup();
}
};
// 测试公网IP (Google DNS)
TEST_F(GeoLocationTest, TestPublicIP_GoogleDNS) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
// 至少一个API应该成功
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试另一个公网IP (Cloudflare DNS)
TEST_F(GeoLocationTest, TestPublicIP_CloudflareDNS) {
std::string testIP = "1.1.1.1";
std::cout << "\n=== Testing IP: " << testIP << " ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试中国IP
TEST_F(GeoLocationTest, TestPublicIP_ChinaIP) {
std::string testIP = "114.114.114.114"; // 114DNS
std::cout << "\n=== Testing IP: " << testIP << " (China) ===\n";
int successCount = 0;
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] ";
if (result.success) {
std::cout << "OK - " << result.city << ", " << result.country;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
successCount++;
} else {
std::cout << "FAIL - " << result.error;
std::cout << " (HTTP " << result.httpStatus << ", " << result.latencyMs << "ms)\n";
}
}
EXPECT_GE(successCount, 1) << "At least one API should succeed";
}
// 测试ip-api.com单独
TEST_F(GeoLocationTest, TestIpApiCom) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[0], testIP);
std::cout << "\n[ip-api.com] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
// 如果是429就跳过不算失败
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipinfo.io单独
TEST_F(GeoLocationTest, TestIpInfoIo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[1], testIP);
std::cout << "\n[ipinfo.io] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试ipapi.co单独
TEST_F(GeoLocationTest, TestIpApiCo) {
std::string testIP = "8.8.8.8";
auto result = TestSingleApi(testApis[2], testIP);
std::cout << "\n[ipapi.co] HTTP: " << result.httpStatus
<< ", Latency: " << result.latencyMs << "ms\n";
if (result.success) {
std::cout << "City: " << result.city << ", Country: " << result.country << "\n";
EXPECT_FALSE(result.country.empty());
} else {
std::cout << "Error: " << result.error << "\n";
if (result.httpStatus == 429) {
GTEST_SKIP() << "Rate limited, skipping";
}
}
}
// 测试HTTP状态码检测
TEST_F(GeoLocationTest, TestHttpStatusCodeDetection) {
// 使用一个会返回错误的请求
std::string testIP = "invalid-ip";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] Invalid IP test: HTTP "
<< result.httpStatus << ", Error: " << result.error << "\n";
// 应该失败
EXPECT_FALSE(result.success);
}
}
// 测试所有API的响应时间
TEST_F(GeoLocationTest, TestApiLatency) {
std::string testIP = "8.8.8.8";
std::cout << "\n=== API Latency Test ===\n";
for (const auto& api : testApis) {
auto result = TestSingleApi(api, testIP);
std::cout << "[" << api.name << "] " << result.latencyMs << "ms";
if (result.success) {
std::cout << " (OK)";
} else {
std::cout << " (FAIL: " << result.error << ")";
}
std::cout << "\n";
// 响应时间应该在合理范围内 (< 15秒)
EXPECT_LT(result.latencyMs, 15000);
}
}
#else
// 非Windows平台 - 跳过测试
TEST(GeoLocationTest, SkipNonWindows) {
GTEST_SKIP() << "GeoLocation tests only run on Windows";
}
#endif

View File

@@ -0,0 +1,642 @@
/**
* @file HeaderTest.cpp
* @brief 协议头验证和加密测试
*
* 测试覆盖:
* - 协议头格式和常量
* - 加密/解密函数正确性
* - 多版本头部识别
* - 头部生成和验证往返
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <array>
// ============================================
// 协议常量定义(测试专用副本)
// ============================================
#define MSG_HEADER "HELL"
const int FLAG_COMPLEN = 4;
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = FLAG_LENGTH + 2 * sizeof(unsigned int); // 16
const int MIN_COMLEN = 12;
enum HeaderEncType {
HeaderEncUnknown = -1,
HeaderEncNone,
HeaderEncV0,
HeaderEncV1,
HeaderEncV2,
HeaderEncV3,
HeaderEncV4,
HeaderEncV5,
HeaderEncV6,
HeaderEncNum,
};
enum FlagType {
FLAG_WINOS = -1,
FLAG_UNKNOWN = 0,
FLAG_SHINE = 1,
FLAG_FUCK = 2,
FLAG_HELLO = 3,
FLAG_HELL = 4,
};
// ============================================
// 加密/解密函数(测试专用副本)
// ============================================
inline void default_encrypt(unsigned char* data, size_t length, unsigned char key)
{
data[FLAG_LENGTH - 2] = data[FLAG_LENGTH - 1] = 0;
}
inline void default_decrypt(unsigned char* data, size_t length, unsigned char key)
{
}
inline void encrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value += k;
break;
case 1:
value = value ^ k;
break;
case 2:
value -= k;
break;
case 3:
value = ~(value ^ k);
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
inline void decrypt(unsigned char* data, size_t length, unsigned char key)
{
if (key == 0) return;
for (size_t i = 0; i < length; ++i) {
unsigned char k = static_cast<unsigned char>(key ^ (i * 31));
int value = static_cast<int>(data[i]);
switch (i % 4) {
case 0:
value -= k;
break;
case 1:
value = value ^ k;
break;
case 2:
value += k;
break;
case 3:
value = ~(value) ^ k;
break;
}
data[i] = static_cast<unsigned char>(value & 0xFF);
}
}
// ============================================
// HeaderFlag 结构体
// ============================================
typedef struct HeaderFlag {
char Data[FLAG_LENGTH + 1];
HeaderFlag(const char header[FLAG_LENGTH + 1])
{
memcpy(Data, header, sizeof(Data));
}
char& operator[](int i)
{
return Data[i];
}
const char operator[](int i) const
{
return Data[i];
}
const char* data() const
{
return Data;
}
} HeaderFlag;
typedef void (*EncFun)(unsigned char* data, size_t length, unsigned char key);
typedef void (*DecFun)(unsigned char* data, size_t length, unsigned char key);
// ============================================
// 头部生成函数
// ============================================
inline HeaderFlag GetHead(EncFun enc, unsigned char fixedKey = 0)
{
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag H(header);
unsigned char key = (fixedKey != 0) ? fixedKey : (time(0) % 256);
H[FLAG_LENGTH - 2] = key;
H[FLAG_LENGTH - 1] = ~key;
enc((unsigned char*)H.data(), FLAG_COMPLEN, H[FLAG_LENGTH - 2]);
return H;
}
// ============================================
// 头部验证函数
// ============================================
inline int compare(const char *flag, const char *magic, int len, DecFun dec, unsigned char key)
{
unsigned char buf[32] = {};
memcpy(buf, flag, MIN_COMLEN);
dec(buf, len, key);
if (memcmp(buf, magic, len) == 0) {
memcpy((void*)flag, buf, MIN_COMLEN);
return 0;
}
return -1;
}
inline FlagType CheckHead(const char* flag, DecFun dec)
{
FlagType type = FLAG_UNKNOWN;
if (compare(flag, MSG_HEADER, FLAG_COMPLEN, dec, flag[6]) == 0) {
type = FLAG_HELL;
} else if (compare(flag, "Shine", 5, dec, 0) == 0) {
type = FLAG_SHINE;
} else if (compare(flag, "<<FUCK>>", 8, dec, flag[9]) == 0) {
type = FLAG_FUCK;
} else if (compare(flag, "Hello?", 6, dec, flag[6]) == 0) {
type = FLAG_HELLO;
} else {
type = FLAG_UNKNOWN;
}
return type;
}
inline FlagType CheckHeadMulti(char* flag, HeaderEncType& funcHit)
{
static const DecFun methods[] = { default_decrypt, decrypt };
static const int methodNum = sizeof(methods) / sizeof(DecFun);
char buffer[MIN_COMLEN + 4] = {};
for (int i = 0; i < methodNum; ++i) {
memcpy(buffer, flag, MIN_COMLEN);
FlagType type = CheckHead(buffer, methods[i]);
if (type != FLAG_UNKNOWN) {
memcpy(flag, buffer, MIN_COMLEN);
funcHit = HeaderEncType(i);
return type;
}
}
funcHit = HeaderEncUnknown;
return FLAG_UNKNOWN;
}
// ============================================
// 协议常量测试
// ============================================
class HeaderConstantsTest : public ::testing::Test {};
TEST_F(HeaderConstantsTest, FlagLength) {
EXPECT_EQ(FLAG_LENGTH, 8);
}
TEST_F(HeaderConstantsTest, FlagCompLen) {
EXPECT_EQ(FLAG_COMPLEN, 4);
}
TEST_F(HeaderConstantsTest, HdrLength) {
// FLAG_LENGTH(8) + 2 * sizeof(uint32_t)(8) = 16
EXPECT_EQ(HDR_LENGTH, 16);
}
TEST_F(HeaderConstantsTest, MinComLen) {
EXPECT_EQ(MIN_COMLEN, 12);
}
TEST_F(HeaderConstantsTest, MsgHeader) {
EXPECT_STREQ(MSG_HEADER, "HELL");
EXPECT_EQ(strlen(MSG_HEADER), 4u);
}
// ============================================
// 加密/解密测试
// ============================================
class EncryptionTest : public ::testing::Test {};
TEST_F(EncryptionTest, ZeroKeyNoOp) {
unsigned char data[] = {0x01, 0x02, 0x03, 0x04, 0x05};
unsigned char original[] = {0x01, 0x02, 0x03, 0x04, 0x05};
encrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
decrypt(data, 5, 0);
EXPECT_EQ(memcmp(data, original, 5), 0);
}
TEST_F(EncryptionTest, EncryptDecryptRoundTrip) {
unsigned char original[] = {0x48, 0x45, 0x4C, 0x4C}; // "HELL"
unsigned char data[4];
memcpy(data, original, 4);
unsigned char key = 0x42;
encrypt(data, 4, key);
// 加密后应该不同
EXPECT_NE(memcmp(data, original, 4), 0);
decrypt(data, 4, key);
// 解密后应该恢复
EXPECT_EQ(memcmp(data, original, 4), 0);
}
TEST_F(EncryptionTest, DifferentKeysProduceDifferentResults) {
unsigned char data1[] = {0x48, 0x45, 0x4C, 0x4C};
unsigned char data2[] = {0x48, 0x45, 0x4C, 0x4C};
encrypt(data1, 4, 0x10);
encrypt(data2, 4, 0x20);
EXPECT_NE(memcmp(data1, data2, 4), 0);
}
TEST_F(EncryptionTest, PositionDependentEncryption) {
// 相同值在不同位置加密结果不同
unsigned char data[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
unsigned char original[] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
encrypt(data, 8, 0x55);
// 验证加密后值不全相同(位置相关加密)
bool allSame = true;
for (int i = 1; i < 8; ++i) {
if (data[i] != data[0]) {
allSame = false;
break;
}
}
EXPECT_FALSE(allSame);
// 验证解密恢复
decrypt(data, 8, 0x55);
EXPECT_EQ(memcmp(data, original, 8), 0);
}
TEST_F(EncryptionTest, AllKeyValues) {
// 测试所有可能的密钥值
unsigned char original[] = {0x12, 0x34, 0x56, 0x78};
for (int key = 1; key < 256; ++key) {
unsigned char data[4];
memcpy(data, original, 4);
encrypt(data, 4, static_cast<unsigned char>(key));
decrypt(data, 4, static_cast<unsigned char>(key));
EXPECT_EQ(memcmp(data, original, 4), 0) << "Failed for key: " << key;
}
}
TEST_F(EncryptionTest, LargeDataRoundTrip) {
std::vector<unsigned char> original(1000);
for (size_t i = 0; i < original.size(); ++i) {
original[i] = static_cast<unsigned char>(i & 0xFF);
}
std::vector<unsigned char> data = original;
unsigned char key = 0x7F;
encrypt(data.data(), data.size(), key);
decrypt(data.data(), data.size(), key);
EXPECT_EQ(data, original);
}
// ============================================
// HeaderFlag 测试
// ============================================
class HeaderFlagTest : public ::testing::Test {};
TEST_F(HeaderFlagTest, Construction) {
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
HeaderFlag hf(header);
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(HeaderFlagTest, DataAccess) {
char header[FLAG_LENGTH + 1] = { 'T','E','S','T', 0x12, 0x34, 0x56, 0x78, 0 };
HeaderFlag hf(header);
EXPECT_EQ(memcmp(hf.data(), "TEST", 4), 0);
EXPECT_EQ(static_cast<unsigned char>(hf[4]), 0x12);
EXPECT_EQ(static_cast<unsigned char>(hf[5]), 0x34);
}
TEST_F(HeaderFlagTest, Modification) {
char header[FLAG_LENGTH + 1] = { 0 };
HeaderFlag hf(header);
hf[0] = 'A';
hf[1] = 'B';
EXPECT_EQ(hf[0], 'A');
EXPECT_EQ(hf[1], 'B');
}
// ============================================
// GetHead 测试
// ============================================
class GetHeadTest : public ::testing::Test {};
TEST_F(GetHeadTest, GeneratesValidHeader) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
// 检查基础格式
EXPECT_EQ(hf[0], 'H');
EXPECT_EQ(hf[1], 'E');
EXPECT_EQ(hf[2], 'L');
EXPECT_EQ(hf[3], 'L');
}
TEST_F(GetHeadTest, KeyAndInverseKey) {
unsigned char key = 0x42;
HeaderFlag hf = GetHead(default_encrypt, key);
// 使用 default_encrypt 时key 位置被清零
// 但我们需要验证生成逻辑
char rawHeader[FLAG_LENGTH + 1] = { 'H','E','L','L', 0 };
HeaderFlag expected(rawHeader);
expected[FLAG_LENGTH - 2] = key;
expected[FLAG_LENGTH - 1] = ~key;
default_encrypt((unsigned char*)expected.data(), FLAG_COMPLEN, expected[FLAG_LENGTH - 2]);
EXPECT_EQ(memcmp(hf.data(), expected.data(), FLAG_LENGTH), 0);
}
TEST_F(GetHeadTest, EncryptedHeader) {
unsigned char key = 0x55;
HeaderFlag hf = GetHead(encrypt, key);
// 头部应该被加密,不再是明文 "HELL"
EXPECT_NE(memcmp(hf.data(), "HELL", 4), 0);
// 密钥应该在正确位置
unsigned char storedKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 2]);
unsigned char inverseKey = static_cast<unsigned char>(hf[FLAG_LENGTH - 1]);
EXPECT_EQ(storedKey, key);
EXPECT_EQ(inverseKey, static_cast<unsigned char>(~key));
}
// ============================================
// CheckHead 测试
// ============================================
class CheckHeadTest : public ::testing::Test {};
TEST_F(CheckHeadTest, IdentifyHellFlag) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(CheckHeadTest, IdentifyShineFlag) {
char header[MIN_COMLEN + 4] = { 'S','h','i','n','e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_SHINE);
}
TEST_F(CheckHeadTest, UnknownFlag) {
char header[MIN_COMLEN + 4] = { 'X','Y','Z','W', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
FlagType type = CheckHead(header, default_decrypt);
EXPECT_EQ(type, FLAG_UNKNOWN);
}
TEST_F(CheckHeadTest, EncryptedHellFlag) {
// 生成加密的头部
unsigned char key = 0x33;
char header[FLAG_LENGTH + 1] = { 'H','E','L','L', 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, header, FLAG_LENGTH);
FlagType type = CheckHead(buffer, decrypt);
EXPECT_EQ(type, FLAG_HELL);
}
// ============================================
// CheckHeadMulti 测试
// ============================================
class CheckHeadMultiTest : public ::testing::Test {};
TEST_F(CheckHeadMultiTest, PlainTextHeader) {
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0x42, static_cast<char>(~0x42), 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncNone);
}
TEST_F(CheckHeadMultiTest, EncryptedHeader) {
unsigned char key = 0x77;
char header[MIN_COMLEN + 4] = { 'H','E','L','L', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
header[FLAG_LENGTH - 2] = key;
header[FLAG_LENGTH - 1] = ~key;
encrypt((unsigned char*)header, FLAG_COMPLEN, key);
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_HELL);
EXPECT_EQ(encType, HeaderEncV0); // encrypt 对应 V0
}
TEST_F(CheckHeadMultiTest, UnrecognizedHeader) {
char header[MIN_COMLEN + 4] = { 0xFF, 0xFE, 0xFD, 0xFC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
HeaderEncType encType;
FlagType type = CheckHeadMulti(header, encType);
EXPECT_EQ(type, FLAG_UNKNOWN);
EXPECT_EQ(encType, HeaderEncUnknown);
}
// ============================================
// 头部生成和验证往返测试
// ============================================
class HeaderRoundTripTest : public ::testing::Test {};
TEST_F(HeaderRoundTripTest, DefaultEncryptRoundTrip) {
HeaderFlag hf = GetHead(default_encrypt, 0x42);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, EncryptRoundTrip) {
HeaderFlag hf = GetHead(encrypt, 0x88);
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL);
}
TEST_F(HeaderRoundTripTest, AllKeyValuesRoundTrip) {
for (int key = 1; key < 256; ++key) {
HeaderFlag hf = GetHead(encrypt, static_cast<unsigned char>(key));
char buffer[MIN_COMLEN + 4] = {};
memcpy(buffer, hf.data(), FLAG_LENGTH);
HeaderEncType encType;
FlagType type = CheckHeadMulti(buffer, encType);
EXPECT_EQ(type, FLAG_HELL) << "Failed for key: " << key;
}
}
// ============================================
// 数据包长度字段测试
// ============================================
class PacketLengthTest : public ::testing::Test {};
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 压缩后长度
uint32_t originalLength; // 原始长度
};
#pragma pack(pop)
TEST_F(PacketLengthTest, HeaderSize) {
EXPECT_EQ(sizeof(PacketHeader), HDR_LENGTH);
}
TEST_F(PacketLengthTest, BuildAndParsePacket) {
PacketHeader pkt = {};
memcpy(pkt.flag, "HELL", 4);
pkt.flag[6] = 0x42;
pkt.flag[7] = ~0x42;
pkt.packedLength = 100;
pkt.originalLength = 200;
// 解析
uint32_t packedLen, origLen;
memcpy(&packedLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH, sizeof(uint32_t));
memcpy(&origLen, reinterpret_cast<char*>(&pkt) + FLAG_LENGTH + sizeof(uint32_t), sizeof(uint32_t));
EXPECT_EQ(packedLen, 100u);
EXPECT_EQ(origLen, 200u);
}
TEST_F(PacketLengthTest, TotalPacketLength) {
// 总包长度 = HDR_LENGTH + 压缩数据长度
uint32_t dataLen = 1000;
uint32_t totalLen = HDR_LENGTH + dataLen;
EXPECT_EQ(totalLen, 1016u);
}
// ============================================
// 边界条件测试
// ============================================
class HeaderBoundaryTest : public ::testing::Test {};
TEST_F(HeaderBoundaryTest, MinimalPacket) {
// 最小合法包:只有头部
std::vector<uint8_t> packet(HDR_LENGTH);
memcpy(packet.data(), "HELL", 4);
packet[6] = 0x42;
packet[7] = ~0x42;
// packedLength = HDR_LENGTH
uint32_t packedLen = HDR_LENGTH;
memcpy(packet.data() + FLAG_LENGTH, &packedLen, sizeof(uint32_t));
// originalLength = 0
uint32_t origLen = 0;
memcpy(packet.data() + FLAG_LENGTH + sizeof(uint32_t), &origLen, sizeof(uint32_t));
EXPECT_EQ(packet.size(), HDR_LENGTH);
}
TEST_F(HeaderBoundaryTest, MaxPacketLength) {
// 验证大包长度字段
uint32_t maxDataLen = 100 * 1024 * 1024; // 100 MB
uint32_t totalLen = HDR_LENGTH + maxDataLen;
PacketHeader pkt = {};
pkt.packedLength = totalLen;
pkt.originalLength = maxDataLen * 2; // 压缩前更大
EXPECT_EQ(pkt.packedLength, totalLen);
EXPECT_EQ(pkt.originalLength, maxDataLen * 2);
}
TEST_F(HeaderBoundaryTest, TruncatedHeader) {
// 不完整的头部不应该被识别
char truncated[4] = { 'H', 'E', 'L', 'L' };
// 不足以进行完整验证
// 这里只是验证常量定义正确
EXPECT_LT(sizeof(truncated), static_cast<size_t>(MIN_COMLEN));
}
// ============================================
// FlagType 枚举测试
// ============================================
class FlagTypeTest : public ::testing::Test {};
TEST_F(FlagTypeTest, EnumValues) {
EXPECT_EQ(FLAG_WINOS, -1);
EXPECT_EQ(FLAG_UNKNOWN, 0);
EXPECT_EQ(FLAG_SHINE, 1);
EXPECT_EQ(FLAG_FUCK, 2);
EXPECT_EQ(FLAG_HELLO, 3);
EXPECT_EQ(FLAG_HELL, 4);
}
TEST_F(FlagTypeTest, HeaderEncTypeValues) {
EXPECT_EQ(HeaderEncUnknown, -1);
EXPECT_EQ(HeaderEncNone, 0);
EXPECT_EQ(HeaderEncV0, 1);
EXPECT_EQ(HeaderEncNum, 8);
}

View File

@@ -0,0 +1,534 @@
/**
* @file HttpMaskTest.cpp
* @brief HTTP 协议伪装测试
*
* 测试覆盖:
* - HTTP 请求格式生成
* - HTTP 头部解析和移除
* - Mask/UnMask 往返测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <string>
#include <vector>
#include <map>
// ============================================
// 类型定义
// ============================================
#ifdef _WIN32
typedef unsigned long ULONG;
#else
typedef uint32_t ULONG;
#endif
// ============================================
// 协议伪装类型
// ============================================
enum PkgMaskType {
MaskTypeUnknown = -1,
MaskTypeNone,
MaskTypeHTTP,
MaskTypeNum,
};
// ============================================
// HTTP 解除伪装函数
// ============================================
inline ULONG UnMaskHttp(const char* src, ULONG srcSize)
{
const char* header_end_mark = "\r\n\r\n";
const ULONG mark_len = 4;
for (ULONG i = 0; i + mark_len <= srcSize; ++i) {
if (memcmp(src + i, header_end_mark, mark_len) == 0) {
return i + mark_len;
}
}
return 0;
}
inline ULONG TryUnMask(const char* src, ULONG srcSize, PkgMaskType& maskHit)
{
if (srcSize >= 5 && memcmp(src, "POST ", 5) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
if (srcSize >= 4 && memcmp(src, "GET ", 4) == 0) {
maskHit = MaskTypeHTTP;
return UnMaskHttp(src, srcSize);
}
maskHit = MaskTypeNone;
return 0;
}
// ============================================
// HTTP Mask 类(简化版)
// ============================================
class HttpMask {
public:
explicit HttpMask(const std::string& host = "example.com",
const std::map<std::string, std::string>& headers = {})
: host_(host)
{
for (const auto& kv : headers) {
customHeaders_ += kv.first + ": " + kv.second + "\r\n";
}
}
void Mask(char*& dst, ULONG& dstSize, const char* src, ULONG srcSize, int cmd = -1)
{
std::string path = "/api/v1/" + std::to_string(cmd == -1 ? 0 : cmd);
std::string httpHeader =
"POST " + path + " HTTP/1.1\r\n"
"Host: " + host_ + "\r\n"
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64)\r\n"
"Content-Type: application/octet-stream\r\n"
"Content-Length: " + std::to_string(srcSize) + "\r\n" + customHeaders_ +
"Connection: keep-alive\r\n"
"\r\n";
dstSize = static_cast<ULONG>(httpHeader.size()) + srcSize;
dst = new char[dstSize];
memcpy(dst, httpHeader.data(), httpHeader.size());
if (srcSize > 0) {
memcpy(dst + httpHeader.size(), src, srcSize);
}
}
ULONG UnMask(const char* src, ULONG srcSize)
{
return UnMaskHttp(src, srcSize);
}
void SetHost(const std::string& host)
{
host_ = host;
}
private:
std::string host_;
std::string customHeaders_;
};
// ============================================
// UnMaskHttp 测试
// ============================================
class UnMaskHttpTest : public ::testing::Test {};
TEST_F(UnMaskHttpTest, ValidHttpRequest) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: example.com\r\n"
"Content-Length: 4\r\n"
"\r\n"
"DATA";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回 body 起始位置
EXPECT_GT(offset, 0u);
EXPECT_STREQ(httpRequest.data() + offset, "DATA");
}
TEST_F(UnMaskHttpTest, NoHeaderEnd) {
std::string incomplete = "POST /api HTTP/1.1\r\nHost: example.com\r\n";
ULONG offset = UnMaskHttp(incomplete.data(), static_cast<ULONG>(incomplete.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, EmptyBody) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 0\r\n"
"\r\n";
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
EXPECT_EQ(offset, static_cast<ULONG>(httpRequest.size()));
}
TEST_F(UnMaskHttpTest, MultipleHeaderEndMarkers) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Content-Length: 8\r\n"
"\r\n"
"\r\n\r\nXX"; // body 中也有 \r\n\r\n
ULONG offset = UnMaskHttp(httpRequest.data(), static_cast<ULONG>(httpRequest.size()));
// 应该返回第一个 \r\n\r\n 之后
std::string body(httpRequest.data() + offset);
EXPECT_EQ(body, "\r\n\r\nXX");
}
TEST_F(UnMaskHttpTest, MinimalInput) {
// 小于 4 字节
std::string tiny = "AB";
ULONG offset = UnMaskHttp(tiny.data(), static_cast<ULONG>(tiny.size()));
EXPECT_EQ(offset, 0u);
}
TEST_F(UnMaskHttpTest, ExactlyHeaderEnd) {
std::string justEnd = "\r\n\r\n";
ULONG offset = UnMaskHttp(justEnd.data(), static_cast<ULONG>(justEnd.size()));
EXPECT_EQ(offset, 4u);
}
// ============================================
// TryUnMask 测试
// ============================================
class TryUnMaskTest : public ::testing::Test {};
TEST_F(TryUnMaskTest, DetectPOST) {
std::string httpRequest =
"POST /api HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n"
"body";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, DetectGET) {
std::string httpRequest =
"GET /resource HTTP/1.1\r\n"
"Host: test.com\r\n"
"\r\n";
PkgMaskType maskType;
ULONG offset = TryUnMask(httpRequest.data(), static_cast<ULONG>(httpRequest.size()), maskType);
EXPECT_EQ(maskType, MaskTypeHTTP);
EXPECT_GT(offset, 0u);
}
TEST_F(TryUnMaskTest, NonHttpData) {
std::string binaryData = "HELL\x00\x00\x42\xBD";
PkgMaskType maskType;
ULONG offset = TryUnMask(binaryData.data(), static_cast<ULONG>(binaryData.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
TEST_F(TryUnMaskTest, ShortInput) {
std::string tooShort = "POS";
PkgMaskType maskType;
ULONG offset = TryUnMask(tooShort.data(), static_cast<ULONG>(tooShort.size()), maskType);
EXPECT_EQ(maskType, MaskTypeNone);
EXPECT_EQ(offset, 0u);
}
// ============================================
// HttpMask 类测试
// ============================================
class HttpMaskClassTest : public ::testing::Test {};
TEST_F(HttpMaskClassTest, MaskBasic) {
HttpMask mask("api.example.com");
const char* data = "Hello";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 5);
ASSERT_NE(masked, nullptr);
EXPECT_GT(maskedSize, 5u);
// 验证是 HTTP 格式
EXPECT_EQ(memcmp(masked, "POST ", 5), 0);
// 验证 Host 头
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api.example.com"), std::string::npos);
// 验证 Content-Length
EXPECT_NE(maskedStr.find("Content-Length: 5"), std::string::npos);
// 验证 body
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, "Hello", 5), 0);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskEmptyData) {
HttpMask mask;
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "", 0);
ASSERT_NE(masked, nullptr);
// 验证 Content-Length: 0
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 0"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskWithCommand) {
HttpMask mask;
const char* data = "X";
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, data, 1, 42);
// 验证路径包含命令号
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("/42"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskClassTest, MaskLargeData) {
HttpMask mask;
std::vector<char> largeData(64 * 1024, 'X'); // 64 KB
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, largeData.data(), static_cast<ULONG>(largeData.size()));
ASSERT_NE(masked, nullptr);
// 验证 Content-Length
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 65536"), std::string::npos);
// 验证 body 完整
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 64u * 1024u);
delete[] masked;
}
// ============================================
// Mask/UnMask 往返测试
// ============================================
class MaskRoundTripTest : public ::testing::Test {};
TEST_F(MaskRoundTripTest, SimpleRoundTrip) {
HttpMask mask;
std::vector<char> original = {'H', 'E', 'L', 'L', 'O'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
// UnMask
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_GT(offset, 0u);
// 验证数据完整
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, BinaryDataRoundTrip) {
HttpMask mask;
// 包含所有字节值
std::vector<char> original(256);
for (int i = 0; i < 256; ++i) {
original[i] = static_cast<char>(i);
}
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, NullBytesRoundTrip) {
HttpMask mask;
std::vector<char> original = {'\0', '\0', 'A', '\0', 'B'};
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, original.size());
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
TEST_F(MaskRoundTripTest, HttpLikeDataRoundTrip) {
HttpMask mask;
// 数据本身看起来像 HTTP
std::string httpLike = "POST /fake HTTP/1.1\r\n\r\nfake body";
std::vector<char> original(httpLike.begin(), httpLike.end());
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, original.data(), static_cast<ULONG>(original.size()));
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(memcmp(masked + offset, original.data(), original.size()), 0);
delete[] masked;
}
// ============================================
// 自定义头部测试
// ============================================
class CustomHeadersTest : public ::testing::Test {};
TEST_F(CustomHeadersTest, AddCustomHeaders) {
std::map<std::string, std::string> headers;
headers["X-Custom-Header"] = "custom-value";
headers["X-Request-ID"] = "12345";
HttpMask mask("test.com", headers);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "data", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("X-Custom-Header: custom-value"), std::string::npos);
EXPECT_NE(maskedStr.find("X-Request-ID: 12345"), std::string::npos);
delete[] masked;
}
// ============================================
// 边界条件测试
// ============================================
class HttpMaskBoundaryTest : public ::testing::Test {};
TEST_F(HttpMaskBoundaryTest, VeryLongHost) {
std::string longHost(1000, 'x');
longHost += ".com";
HttpMask mask(longHost);
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "test", 4);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find(longHost), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, SpecialCharactersInHost) {
HttpMask mask("api-v2.test-server.example.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "x", 1);
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Host: api-v2.test-server.example.com"), std::string::npos);
delete[] masked;
}
TEST_F(HttpMaskBoundaryTest, MaxULONGContentLength) {
// 测试大的 Content-Length 值
HttpMask mask;
// 不实际分配这么大的内存,只是构造请求头
std::string largeContentLengthStr = "Content-Length: 4294967295";
// 验证格式正确
EXPECT_EQ(largeContentLengthStr.find("4294967295"), 16u);
}
// ============================================
// HTTP 格式验证测试
// ============================================
class HttpFormatTest : public ::testing::Test {};
TEST_F(HttpFormatTest, ValidHttpRequestFormat) {
HttpMask mask("test.com");
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, "body", 4);
std::string maskedStr(masked, maskedSize);
// 验证请求行
EXPECT_EQ(maskedStr.substr(0, 5), "POST ");
// 验证 HTTP 版本
EXPECT_NE(maskedStr.find("HTTP/1.1\r\n"), std::string::npos);
// 验证必需的头部
EXPECT_NE(maskedStr.find("Host:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Length:"), std::string::npos);
EXPECT_NE(maskedStr.find("Content-Type:"), std::string::npos);
// 验证头部和 body 之间的分隔符
EXPECT_NE(maskedStr.find("\r\n\r\n"), std::string::npos);
delete[] masked;
}
TEST_F(HttpFormatTest, ContentLengthMatchesBody) {
HttpMask mask;
std::vector<char> body(123, 'X');
char* masked = nullptr;
ULONG maskedSize = 0;
mask.Mask(masked, maskedSize, body.data(), static_cast<ULONG>(body.size()));
std::string maskedStr(masked, maskedSize);
EXPECT_NE(maskedStr.find("Content-Length: 123"), std::string::npos);
ULONG offset = mask.UnMask(masked, maskedSize);
EXPECT_EQ(maskedSize - offset, 123u);
delete[] masked;
}

View File

@@ -0,0 +1,660 @@
/**
* @file PacketFragmentTest.cpp
* @brief 粘包/分包处理测试
*
* 测试覆盖:
* - 完整包接收和解析
* - 分包(不完整包)处理
* - 粘包(多个包粘在一起)处理
* - 混合场景测试
* - 边界条件处理
*/
#include <gtest/gtest.h>
#include <cstring>
#include <cstdint>
#include <vector>
#include <queue>
#include <functional>
// ============================================
// 协议常量
// ============================================
const int FLAG_LENGTH = 8;
const int HDR_LENGTH = 16; // FLAG(8) + PackedLen(4) + OrigLen(4)
// ============================================
// 简化版 CBuffer测试专用
// ============================================
class TestBuffer {
public:
TestBuffer() {}
void WriteBuffer(const uint8_t* data, size_t len) {
m_data.insert(m_data.end(), data, data + len);
}
size_t ReadBuffer(uint8_t* dst, size_t len) {
size_t toRead = std::min(len, m_data.size());
if (toRead > 0) {
memcpy(dst, m_data.data(), toRead);
m_data.erase(m_data.begin(), m_data.begin() + toRead);
}
return toRead;
}
void Skip(size_t len) {
size_t toSkip = std::min(len, m_data.size());
m_data.erase(m_data.begin(), m_data.begin() + toSkip);
}
size_t GetBufferLength() const {
return m_data.size();
}
const uint8_t* GetBuffer(size_t pos = 0) const {
if (pos >= m_data.size()) return nullptr;
return m_data.data() + pos;
}
bool CopyBuffer(uint8_t* dst, size_t pos, size_t len) const {
if (pos + len > m_data.size()) return false;
memcpy(dst, m_data.data() + pos, len);
return true;
}
void Clear() {
m_data.clear();
}
private:
std::vector<uint8_t> m_data;
};
// ============================================
// 数据包构建辅助函数
// ============================================
#pragma pack(push, 1)
struct PacketHeader {
char flag[FLAG_LENGTH];
uint32_t packedLength; // 包含头部的总长度
uint32_t originalLength; // 原始数据长度
};
#pragma pack(pop)
std::vector<uint8_t> BuildPacket(const std::vector<uint8_t>& payload, uint8_t key = 0x42) {
std::vector<uint8_t> packet(HDR_LENGTH + payload.size());
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = key;
hdr->flag[7] = ~key;
hdr->packedLength = static_cast<uint32_t>(HDR_LENGTH + payload.size());
hdr->originalLength = static_cast<uint32_t>(payload.size());
if (!payload.empty()) {
memcpy(packet.data() + HDR_LENGTH, payload.data(), payload.size());
}
return packet;
}
std::vector<uint8_t> BuildPacketWithData(size_t dataSize, uint8_t fillByte = 0xAA) {
std::vector<uint8_t> payload(dataSize, fillByte);
return BuildPacket(payload);
}
// ============================================
// 粘包/分包处理器(模拟 OnServerReceiving 逻辑)
// ============================================
class PacketProcessor {
public:
using PacketCallback = std::function<void(const std::vector<uint8_t>&)>;
PacketProcessor(PacketCallback callback) : m_callback(callback) {}
// 接收数据(模拟网络接收)
void OnReceive(const uint8_t* data, size_t len) {
m_buffer.WriteBuffer(data, len);
ProcessBuffer();
}
// 获取待处理的字节数
size_t GetPendingBytes() const {
return m_buffer.GetBufferLength();
}
// 获取已处理的包数量
size_t GetProcessedCount() const {
return m_processedCount;
}
private:
void ProcessBuffer() {
while (m_buffer.GetBufferLength() >= HDR_LENGTH) {
// 验证头部
const uint8_t* buf = m_buffer.GetBuffer();
if (memcmp(buf, "HELL", 4) != 0) {
// 无效头部,跳过一个字节重试
m_buffer.Skip(1);
continue;
}
// 读取包长度
uint32_t packedLength;
m_buffer.CopyBuffer(reinterpret_cast<uint8_t*>(&packedLength),
FLAG_LENGTH, sizeof(uint32_t));
// 检查长度有效性
if (packedLength < HDR_LENGTH || packedLength > 100 * 1024 * 1024) {
// 无效长度,跳过头部
m_buffer.Skip(FLAG_LENGTH);
continue;
}
// 检查包是否完整
if (m_buffer.GetBufferLength() < packedLength) {
// 不完整,等待更多数据
break;
}
// 读取完整包
std::vector<uint8_t> packet(packedLength);
m_buffer.ReadBuffer(packet.data(), packedLength);
// 提取 payload
std::vector<uint8_t> payload(packet.begin() + HDR_LENGTH, packet.end());
m_callback(payload);
m_processedCount++;
}
}
TestBuffer m_buffer;
PacketCallback m_callback;
size_t m_processedCount = 0;
};
// ============================================
// 完整包接收测试
// ============================================
class CompletePacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
void SetUp() override {
receivedPackets.clear();
}
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(CompletePacketTest, SinglePacket) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02, 0x03, 0x04};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
EXPECT_EQ(processor.GetPendingBytes(), 0u);
}
TEST_F(CompletePacketTest, EmptyPayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload;
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(CompletePacketTest, LargePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(64 * 1024); // 64 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 分包(不完整包)测试
// ============================================
class FragmentedPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(FragmentedPacketTest, TwoFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xAA, 0xBB, 0xCC, 0xDD};
auto packet = BuildPacket(payload);
// 分两次发送
size_t half = packet.size() / 2;
processor.OnReceive(packet.data(), half);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), half);
processor.OnReceive(packet.data() + half, packet.size() - half);
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, ManyFragments) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100);
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
// 每次发送 10 字节
for (size_t i = 0; i < packet.size(); i += 10) {
size_t len = std::min(size_t(10), packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(FragmentedPacketTest, OnlyHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x01, 0x02};
auto packet = BuildPacket(payload);
// 只发送头部
processor.OnReceive(packet.data(), HDR_LENGTH);
EXPECT_EQ(receivedPackets.size(), 0u);
EXPECT_EQ(processor.GetPendingBytes(), HDR_LENGTH);
// 发送剩余数据
processor.OnReceive(packet.data() + HDR_LENGTH, packet.size() - HDR_LENGTH);
ASSERT_EQ(receivedPackets.size(), 1u);
}
TEST_F(FragmentedPacketTest, PartialHeader) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
// 发送不完整的头部
processor.OnReceive(packet.data(), 4); // 只有 "HELL"
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余部分
processor.OnReceive(packet.data() + 4, packet.size() - 4);
ASSERT_EQ(receivedPackets.size(), 1u);
}
// ============================================
// 粘包测试
// ============================================
class StickyPacketTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(StickyPacketTest, TwoPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11, 0x22};
std::vector<uint8_t> payload2 = {0x33, 0x44, 0x55};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 合并两个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(StickyPacketTest, ThreePacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x01};
std::vector<uint8_t> payload2 = {0x02, 0x03};
std::vector<uint8_t> payload3 = {0x04, 0x05, 0x06};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
auto packet3 = BuildPacket(payload3);
// 合并三个包
std::vector<uint8_t> combined;
combined.insert(combined.end(), packet1.begin(), packet1.end());
combined.insert(combined.end(), packet2.begin(), packet2.end());
combined.insert(combined.end(), packet3.begin(), packet3.end());
processor.OnReceive(combined.data(), combined.size());
ASSERT_EQ(receivedPackets.size(), 3u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
EXPECT_EQ(receivedPackets[2], payload3);
}
TEST_F(StickyPacketTest, ManyPacketsStuckTogether) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> combined;
const int numPackets = 100;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload(i % 10 + 1, static_cast<uint8_t>(i));
auto packet = BuildPacket(payload);
combined.insert(combined.end(), packet.begin(), packet.end());
}
processor.OnReceive(combined.data(), combined.size());
EXPECT_EQ(receivedPackets.size(), numPackets);
}
// ============================================
// 混合场景测试
// ============================================
class MixedScenarioTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(MixedScenarioTest, OneAndHalfPackets) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0xAA, 0xBB};
std::vector<uint8_t> payload2 = {0xCC, 0xDD, 0xEE, 0xFF};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送完整包1 + 半个包2
std::vector<uint8_t> firstSend;
firstSend.insert(firstSend.end(), packet1.begin(), packet1.end());
firstSend.insert(firstSend.end(), packet2.begin(), packet2.begin() + packet2.size() / 2);
processor.OnReceive(firstSend.data(), firstSend.size());
EXPECT_EQ(receivedPackets.size(), 1u); // 只处理了包1
// 发送剩余的半个包2
processor.OnReceive(packet2.data() + packet2.size() / 2, packet2.size() - packet2.size() / 2);
ASSERT_EQ(receivedPackets.size(), 2u);
EXPECT_EQ(receivedPackets[0], payload1);
EXPECT_EQ(receivedPackets[1], payload2);
}
TEST_F(MixedScenarioTest, HalfPacketThenOneAndHalf) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload1 = {0x11};
std::vector<uint8_t> payload2 = {0x22, 0x33};
auto packet1 = BuildPacket(payload1);
auto packet2 = BuildPacket(payload2);
// 发送半个包1
processor.OnReceive(packet1.data(), packet1.size() / 2);
EXPECT_EQ(receivedPackets.size(), 0u);
// 发送剩余包1 + 完整包2
std::vector<uint8_t> secondSend;
secondSend.insert(secondSend.end(), packet1.begin() + packet1.size() / 2, packet1.end());
secondSend.insert(secondSend.end(), packet2.begin(), packet2.end());
processor.OnReceive(secondSend.data(), secondSend.size());
ASSERT_EQ(receivedPackets.size(), 2u);
}
TEST_F(MixedScenarioTest, RandomChunkSizes) {
PacketProcessor processor(GetCallback());
// 准备多个包
std::vector<std::vector<uint8_t>> payloads;
std::vector<uint8_t> allData;
for (int i = 0; i < 10; ++i) {
std::vector<uint8_t> payload(i * 5 + 10, static_cast<uint8_t>(i));
payloads.push_back(payload);
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
// 使用"随机"大小的块发送
size_t chunkSizes[] = {1, 7, 15, 16, 17, 31, 32, 33, 64, 128};
size_t pos = 0;
size_t chunkIdx = 0;
while (pos < allData.size()) {
size_t chunkSize = chunkSizes[chunkIdx % (sizeof(chunkSizes) / sizeof(chunkSizes[0]))];
size_t len = std::min(chunkSize, allData.size() - pos);
processor.OnReceive(allData.data() + pos, len);
pos += len;
chunkIdx++;
}
ASSERT_EQ(receivedPackets.size(), payloads.size());
for (size_t i = 0; i < payloads.size(); ++i) {
EXPECT_EQ(receivedPackets[i], payloads[i]) << "Mismatch at packet " << i;
}
}
// ============================================
// 边界条件测试
// ============================================
class PacketBoundaryTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(PacketBoundaryTest, ExactlyHdrLength) {
PacketProcessor processor(GetCallback());
// 只有头部,无 payload
std::vector<uint8_t> packet(HDR_LENGTH);
PacketHeader* hdr = reinterpret_cast<PacketHeader*>(packet.data());
memcpy(hdr->flag, "HELL", 4);
hdr->flag[6] = 0x42;
hdr->flag[7] = ~0x42;
hdr->packedLength = HDR_LENGTH;
hdr->originalLength = 0;
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_TRUE(receivedPackets[0].empty());
}
TEST_F(PacketBoundaryTest, SingleBytePayload) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0xFF};
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0].size(), 1u);
EXPECT_EQ(receivedPackets[0][0], 0xFF);
}
TEST_F(PacketBoundaryTest, ByteByByteReceive) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload = {0x12, 0x34, 0x56};
auto packet = BuildPacket(payload);
// 每次接收 1 字节
for (size_t i = 0; i < packet.size(); ++i) {
processor.OnReceive(packet.data() + i, 1);
}
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
// ============================================
// 数据完整性测试
// ============================================
class DataIntegrityTest : public ::testing::Test {
protected:
std::vector<std::vector<uint8_t>> receivedPackets;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
receivedPackets.push_back(payload);
};
}
};
TEST_F(DataIntegrityTest, BinaryData) {
PacketProcessor processor(GetCallback());
// 包含所有字节值的 payload
std::vector<uint8_t> payload(256);
for (int i = 0; i < 256; ++i) {
payload[i] = static_cast<uint8_t>(i);
}
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
EXPECT_EQ(receivedPackets[0], payload);
}
TEST_F(DataIntegrityTest, AllZeros) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0x00);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0x00);
}
}
TEST_F(DataIntegrityTest, AllOnes) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100, 0xFF);
auto packet = BuildPacket(payload);
processor.OnReceive(packet.data(), packet.size());
ASSERT_EQ(receivedPackets.size(), 1u);
for (uint8_t b : receivedPackets[0]) {
EXPECT_EQ(b, 0xFF);
}
}
// ============================================
// 性能相关测试
// ============================================
class PacketPerformanceTest : public ::testing::Test {
protected:
size_t packetCount = 0;
PacketProcessor::PacketCallback GetCallback() {
return [this](const std::vector<uint8_t>& payload) {
packetCount++;
};
}
};
TEST_F(PacketPerformanceTest, ManySmallPackets) {
PacketProcessor processor(GetCallback());
const int numPackets = 10000;
std::vector<uint8_t> allData;
for (int i = 0; i < numPackets; ++i) {
std::vector<uint8_t> payload = {static_cast<uint8_t>(i & 0xFF)};
auto packet = BuildPacket(payload);
allData.insert(allData.end(), packet.begin(), packet.end());
}
processor.OnReceive(allData.data(), allData.size());
EXPECT_EQ(packetCount, numPackets);
}
TEST_F(PacketPerformanceTest, LargePacketInSmallChunks) {
PacketProcessor processor(GetCallback());
std::vector<uint8_t> payload(100 * 1024); // 100 KB
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] = static_cast<uint8_t>(i & 0xFF);
}
auto packet = BuildPacket(payload);
// 每次发送 1 KB
const size_t chunkSize = 1024;
for (size_t i = 0; i < packet.size(); i += chunkSize) {
size_t len = std::min(chunkSize, packet.size() - i);
processor.OnReceive(packet.data() + i, len);
}
EXPECT_EQ(packetCount, 1u);
}

View File

@@ -0,0 +1,520 @@
/**
* @file PacketTest.cpp
* @brief 协议数据包结构测试
*
* 测试覆盖:
* - 数据包结构大小验证
* - 字段偏移和对齐
* - 序列化/反序列化
* - 边界条件
*/
#include <gtest/gtest.h>
#include <cstring>
#include <vector>
#include <cstdint>
// ============================================
// 协议数据结构定义(测试专用副本)
// ============================================
#pragma pack(push, 1)
// V1 文件传输包
struct FileChunkPacket {
uint8_t cmd;
uint32_t fileIndex;
uint32_t totalNum;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
}; // 应该是 41 bytes
// V2 标志位
enum FileFlagsV2 : uint16_t {
FFV2_NONE = 0x0000,
FFV2_LAST_CHUNK = 0x0001,
FFV2_RESUME_REQ = 0x0002,
FFV2_RESUME_RESP = 0x0004,
FFV2_CANCEL = 0x0008,
FFV2_DIRECTORY = 0x0010,
FFV2_COMPRESSED = 0x0020,
FFV2_ERROR = 0x0040,
};
// V2 文件传输包
struct FileChunkPacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint32_t totalFiles;
uint64_t fileSize;
uint64_t offset;
uint64_t dataLength;
uint64_t nameLength;
uint16_t flags;
uint16_t checksum;
uint8_t reserved[8];
}; // 应该是 77 bytes
// V2 断点续传区间
struct FileRangeV2 {
uint64_t offset;
uint64_t length;
}; // 16 bytes
// V2 断点续传控制包
struct FileResumePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint64_t receivedBytes;
uint16_t flags;
uint16_t rangeCount;
}; // 49 bytes
// C2C 准备包
struct C2CPreparePacket {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
}; // 17 bytes
// C2C 准备响应包
struct C2CPrepareRespPacket {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint16_t pathLength;
}; // 19 bytes
// V2 文件完成校验包
struct FileCompletePacketV2 {
uint8_t cmd;
uint64_t transferID;
uint64_t srcClientID;
uint64_t dstClientID;
uint32_t fileIndex;
uint64_t fileSize;
uint8_t sha256[32];
}; // 69 bytes
#pragma pack(pop)
// 命令常量
enum Commands {
COMMAND_SEND_FILE = 68,
COMMAND_SEND_FILE_V2 = 85,
COMMAND_FILE_RESUME = 86,
COMMAND_CLIPBOARD_V2 = 87,
COMMAND_FILE_QUERY_RESUME = 88,
COMMAND_C2C_PREPARE = 89,
COMMAND_C2C_TEXT = 90,
COMMAND_FILE_COMPLETE_V2 = 91,
COMMAND_C2C_PREPARE_RESP = 92,
};
// ============================================
// 结构大小测试
// ============================================
TEST(PacketSizeTest, FileChunkPacket_Size) {
EXPECT_EQ(sizeof(FileChunkPacket), 41u);
}
TEST(PacketSizeTest, FileChunkPacketV2_Size) {
EXPECT_EQ(sizeof(FileChunkPacketV2), 77u);
}
TEST(PacketSizeTest, FileRangeV2_Size) {
EXPECT_EQ(sizeof(FileRangeV2), 16u);
}
TEST(PacketSizeTest, FileResumePacketV2_Size) {
EXPECT_EQ(sizeof(FileResumePacketV2), 49u);
}
TEST(PacketSizeTest, C2CPreparePacket_Size) {
EXPECT_EQ(sizeof(C2CPreparePacket), 17u);
}
TEST(PacketSizeTest, C2CPrepareRespPacket_Size) {
EXPECT_EQ(sizeof(C2CPrepareRespPacket), 19u);
}
TEST(PacketSizeTest, FileCompletePacketV2_Size) {
EXPECT_EQ(sizeof(FileCompletePacketV2), 69u);
}
// ============================================
// 字段偏移测试
// ============================================
TEST(PacketOffsetTest, FileChunkPacketV2_FieldOffsets) {
EXPECT_EQ(offsetof(FileChunkPacketV2, cmd), 0u);
EXPECT_EQ(offsetof(FileChunkPacketV2, transferID), 1u);
EXPECT_EQ(offsetof(FileChunkPacketV2, srcClientID), 9u);
EXPECT_EQ(offsetof(FileChunkPacketV2, dstClientID), 17u);
EXPECT_EQ(offsetof(FileChunkPacketV2, fileIndex), 25u);
EXPECT_EQ(offsetof(FileChunkPacketV2, totalFiles), 29u);
EXPECT_EQ(offsetof(FileChunkPacketV2, fileSize), 33u);
EXPECT_EQ(offsetof(FileChunkPacketV2, offset), 41u);
EXPECT_EQ(offsetof(FileChunkPacketV2, dataLength), 49u);
EXPECT_EQ(offsetof(FileChunkPacketV2, nameLength), 57u);
EXPECT_EQ(offsetof(FileChunkPacketV2, flags), 65u);
EXPECT_EQ(offsetof(FileChunkPacketV2, checksum), 67u);
EXPECT_EQ(offsetof(FileChunkPacketV2, reserved), 69u);
}
TEST(PacketOffsetTest, FileResumePacketV2_FieldOffsets) {
EXPECT_EQ(offsetof(FileResumePacketV2, cmd), 0u);
EXPECT_EQ(offsetof(FileResumePacketV2, transferID), 1u);
EXPECT_EQ(offsetof(FileResumePacketV2, srcClientID), 9u);
EXPECT_EQ(offsetof(FileResumePacketV2, dstClientID), 17u);
EXPECT_EQ(offsetof(FileResumePacketV2, fileIndex), 25u);
EXPECT_EQ(offsetof(FileResumePacketV2, fileSize), 29u);
EXPECT_EQ(offsetof(FileResumePacketV2, receivedBytes), 37u);
EXPECT_EQ(offsetof(FileResumePacketV2, flags), 45u);
EXPECT_EQ(offsetof(FileResumePacketV2, rangeCount), 47u);
}
// ============================================
// 序列化/反序列化测试
// ============================================
class PacketSerializationTest : public ::testing::Test {
protected:
std::vector<uint8_t> buffer;
template<typename T>
void SerializePacket(const T& pkt) {
buffer.resize(sizeof(T));
memcpy(buffer.data(), &pkt, sizeof(T));
}
template<typename T>
T DeserializePacket() {
T pkt;
memcpy(&pkt, buffer.data(), sizeof(T));
return pkt;
}
};
TEST_F(PacketSerializationTest, FileChunkPacketV2_RoundTrip) {
FileChunkPacketV2 original = {};
original.cmd = COMMAND_SEND_FILE_V2;
original.transferID = 0x123456789ABCDEF0ULL;
original.srcClientID = 0x1111111111111111ULL;
original.dstClientID = 0x2222222222222222ULL;
original.fileIndex = 42;
original.totalFiles = 100;
original.fileSize = 1024 * 1024 * 1024; // 1GB
original.offset = 512 * 1024;
original.dataLength = 4096;
original.nameLength = 256;
original.flags = FFV2_LAST_CHUNK | FFV2_COMPRESSED;
original.checksum = 0xABCD;
memset(original.reserved, 0x55, sizeof(original.reserved));
SerializePacket(original);
auto restored = DeserializePacket<FileChunkPacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.srcClientID, original.srcClientID);
EXPECT_EQ(restored.dstClientID, original.dstClientID);
EXPECT_EQ(restored.fileIndex, original.fileIndex);
EXPECT_EQ(restored.totalFiles, original.totalFiles);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(restored.offset, original.offset);
EXPECT_EQ(restored.dataLength, original.dataLength);
EXPECT_EQ(restored.nameLength, original.nameLength);
EXPECT_EQ(restored.flags, original.flags);
EXPECT_EQ(restored.checksum, original.checksum);
EXPECT_EQ(memcmp(restored.reserved, original.reserved, 8), 0);
}
TEST_F(PacketSerializationTest, FileResumePacketV2_RoundTrip) {
FileResumePacketV2 original = {};
original.cmd = COMMAND_FILE_RESUME;
original.transferID = 0xDEADBEEFCAFEBABEULL;
original.srcClientID = 12345;
original.dstClientID = 67890;
original.fileIndex = 5;
original.fileSize = 0xFFFFFFFFFFFFFFFFULL; // 最大值
original.receivedBytes = 0x8000000000000000ULL; // 大值
original.flags = FFV2_RESUME_RESP;
original.rangeCount = 10;
SerializePacket(original);
auto restored = DeserializePacket<FileResumePacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.srcClientID, original.srcClientID);
EXPECT_EQ(restored.dstClientID, original.dstClientID);
EXPECT_EQ(restored.fileIndex, original.fileIndex);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(restored.receivedBytes, original.receivedBytes);
EXPECT_EQ(restored.flags, original.flags);
EXPECT_EQ(restored.rangeCount, original.rangeCount);
}
TEST_F(PacketSerializationTest, FileCompletePacketV2_RoundTrip) {
FileCompletePacketV2 original = {};
original.cmd = COMMAND_FILE_COMPLETE_V2;
original.transferID = 99999;
original.srcClientID = 1;
original.dstClientID = 2;
original.fileIndex = 0;
original.fileSize = 12345678;
// 填充 SHA-256
for (int i = 0; i < 32; i++) {
original.sha256[i] = (uint8_t)(i * 8);
}
SerializePacket(original);
auto restored = DeserializePacket<FileCompletePacketV2>();
EXPECT_EQ(restored.cmd, original.cmd);
EXPECT_EQ(restored.transferID, original.transferID);
EXPECT_EQ(restored.fileSize, original.fileSize);
EXPECT_EQ(memcmp(restored.sha256, original.sha256, 32), 0);
}
// ============================================
// 标志位测试
// ============================================
TEST(FlagsTest, FileFlagsV2_SingleFlags) {
EXPECT_EQ(FFV2_NONE, 0x0000);
EXPECT_EQ(FFV2_LAST_CHUNK, 0x0001);
EXPECT_EQ(FFV2_RESUME_REQ, 0x0002);
EXPECT_EQ(FFV2_RESUME_RESP, 0x0004);
EXPECT_EQ(FFV2_CANCEL, 0x0008);
EXPECT_EQ(FFV2_DIRECTORY, 0x0010);
EXPECT_EQ(FFV2_COMPRESSED, 0x0020);
EXPECT_EQ(FFV2_ERROR, 0x0040);
}
TEST(FlagsTest, FileFlagsV2_Combinations) {
uint16_t flags = FFV2_LAST_CHUNK | FFV2_COMPRESSED;
EXPECT_TRUE(flags & FFV2_LAST_CHUNK);
EXPECT_TRUE(flags & FFV2_COMPRESSED);
EXPECT_FALSE(flags & FFV2_DIRECTORY);
EXPECT_FALSE(flags & FFV2_ERROR);
}
TEST(FlagsTest, FileFlagsV2_NoBitOverlap) {
// 验证各标志位不重叠
uint16_t allFlags[] = {
FFV2_LAST_CHUNK, FFV2_RESUME_REQ, FFV2_RESUME_RESP,
FFV2_CANCEL, FFV2_DIRECTORY, FFV2_COMPRESSED, FFV2_ERROR
};
for (size_t i = 0; i < sizeof(allFlags)/sizeof(allFlags[0]); i++) {
for (size_t j = i + 1; j < sizeof(allFlags)/sizeof(allFlags[0]); j++) {
EXPECT_EQ(allFlags[i] & allFlags[j], 0)
<< "Flags overlap: " << allFlags[i] << " & " << allFlags[j];
}
}
}
// ============================================
// 命令常量测试
// ============================================
TEST(CommandTest, CommandValues_AreUnique) {
std::vector<int> commands = {
COMMAND_SEND_FILE,
COMMAND_SEND_FILE_V2,
COMMAND_FILE_RESUME,
COMMAND_CLIPBOARD_V2,
COMMAND_FILE_QUERY_RESUME,
COMMAND_C2C_PREPARE,
COMMAND_C2C_TEXT,
COMMAND_FILE_COMPLETE_V2,
COMMAND_C2C_PREPARE_RESP
};
// 验证无重复
for (size_t i = 0; i < commands.size(); i++) {
for (size_t j = i + 1; j < commands.size(); j++) {
EXPECT_NE(commands[i], commands[j])
<< "Duplicate command values at index " << i << " and " << j;
}
}
}
TEST(CommandTest, CommandValues_FitInByte) {
EXPECT_LE(COMMAND_SEND_FILE, 255);
EXPECT_LE(COMMAND_SEND_FILE_V2, 255);
EXPECT_LE(COMMAND_FILE_RESUME, 255);
EXPECT_LE(COMMAND_FILE_COMPLETE_V2, 255);
}
// ============================================
// 边界条件测试
// ============================================
TEST(BoundaryTest, MaxFileSize) {
FileChunkPacketV2 pkt = {};
pkt.fileSize = UINT64_MAX;
EXPECT_EQ(pkt.fileSize, 0xFFFFFFFFFFFFFFFFULL);
}
TEST(BoundaryTest, MaxOffset) {
FileChunkPacketV2 pkt = {};
pkt.offset = UINT64_MAX;
EXPECT_EQ(pkt.offset, 0xFFFFFFFFFFFFFFFFULL);
}
TEST(BoundaryTest, ZeroValues) {
FileChunkPacketV2 pkt = {};
memset(&pkt, 0, sizeof(pkt));
EXPECT_EQ(pkt.cmd, 0);
EXPECT_EQ(pkt.transferID, 0u);
EXPECT_EQ(pkt.fileSize, 0u);
EXPECT_EQ(pkt.flags, FFV2_NONE);
}
// ============================================
// 带变长数据的包测试
// ============================================
TEST(VariableLengthTest, FileChunkPacketV2_WithFilename) {
const char* filename = "test/folder/file.txt";
size_t nameLen = strlen(filename);
size_t totalSize = sizeof(FileChunkPacketV2) + nameLen;
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = COMMAND_SEND_FILE_V2;
pkt->nameLength = static_cast<uint64_t>(nameLen);
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename, nameLen);
// 验证可以正确读取
EXPECT_EQ(pkt->nameLength, nameLen);
std::string extractedName(
reinterpret_cast<char*>(buffer.data() + sizeof(FileChunkPacketV2)),
pkt->nameLength
);
EXPECT_EQ(extractedName, filename);
}
TEST(VariableLengthTest, FileChunkPacketV2_WithFilenameAndData) {
const char* filename = "data.bin";
size_t nameLen = strlen(filename);
std::vector<uint8_t> fileData = {0x01, 0x02, 0x03, 0x04, 0x05};
size_t dataLen = fileData.size();
size_t totalSize = sizeof(FileChunkPacketV2) + nameLen + dataLen;
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileChunkPacketV2*>(buffer.data());
pkt->cmd = COMMAND_SEND_FILE_V2;
pkt->nameLength = nameLen;
pkt->dataLength = dataLen;
// 写入文件名
memcpy(buffer.data() + sizeof(FileChunkPacketV2), filename, nameLen);
// 写入数据
memcpy(buffer.data() + sizeof(FileChunkPacketV2) + nameLen, fileData.data(), dataLen);
// 验证
EXPECT_EQ(buffer.size(), totalSize);
// 提取文件名
std::string extractedName(
reinterpret_cast<char*>(buffer.data() + sizeof(FileChunkPacketV2)),
pkt->nameLength
);
EXPECT_EQ(extractedName, filename);
// 提取数据
std::vector<uint8_t> extractedData(
buffer.begin() + sizeof(FileChunkPacketV2) + nameLen,
buffer.end()
);
EXPECT_EQ(extractedData, fileData);
}
// ============================================
// 断点续传区间测试
// ============================================
TEST(ResumeRangeTest, SingleRange) {
FileRangeV2 range = {0, 1024};
EXPECT_EQ(range.offset, 0u);
EXPECT_EQ(range.length, 1024u);
}
TEST(ResumeRangeTest, MultipleRanges) {
std::vector<FileRangeV2> ranges = {
{0, 1024},
{2048, 512},
{4096, 2048}
};
// 计算总接收字节数
uint64_t totalReceived = 0;
for (const auto& r : ranges) {
totalReceived += r.length;
}
EXPECT_EQ(totalReceived, 1024u + 512u + 2048u);
}
TEST(ResumeRangeTest, PacketWithRanges) {
uint16_t rangeCount = 3;
size_t totalSize = sizeof(FileResumePacketV2) + rangeCount * sizeof(FileRangeV2);
std::vector<uint8_t> buffer(totalSize);
auto* pkt = reinterpret_cast<FileResumePacketV2*>(buffer.data());
pkt->cmd = COMMAND_FILE_RESUME;
pkt->rangeCount = rangeCount;
auto* ranges = reinterpret_cast<FileRangeV2*>(buffer.data() + sizeof(FileResumePacketV2));
ranges[0] = {0, 1024};
ranges[1] = {2048, 512};
ranges[2] = {4096, 2048};
// 验证
EXPECT_EQ(pkt->rangeCount, 3);
EXPECT_EQ(ranges[0].offset, 0u);
EXPECT_EQ(ranges[1].offset, 2048u);
EXPECT_EQ(ranges[2].length, 2048u);
}
// ============================================
// 字节序测试(假设小端)
// ============================================
TEST(EndiannessTest, LittleEndian_Uint64) {
uint64_t value = 0x0102030405060708ULL;
uint8_t bytes[8];
memcpy(bytes, &value, 8);
// 小端:低字节在前
EXPECT_EQ(bytes[0], 0x08);
EXPECT_EQ(bytes[1], 0x07);
EXPECT_EQ(bytes[7], 0x01);
}
TEST(EndiannessTest, FileChunkPacketV2_TransferID) {
FileChunkPacketV2 pkt = {};
pkt.transferID = 0x0102030405060708ULL;
uint8_t* raw = reinterpret_cast<uint8_t*>(&pkt);
// transferID 从 offset 1 开始
EXPECT_EQ(raw[1], 0x08); // 低字节
EXPECT_EQ(raw[8], 0x01); // 高字节
}

View File

@@ -0,0 +1,360 @@
/**
* @file PathUtilsTest.cpp
* @brief 路径处理工具函数测试
*
* 测试覆盖:
* - GetCommonRoot: 计算多文件的公共根目录
* - GetRelativePath: 获取相对路径
* - 边界条件和特殊字符处理
*/
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include <algorithm>
#include <cctype>
// ============================================
// 路径工具函数实现(测试专用副本)
// ============================================
// 计算最长公共前缀作为根目录
std::string GetCommonRoot(const std::vector<std::string>& files)
{
if (files.empty())
return "";
std::string root = files[0];
for (size_t i = 1; i < files.size(); ++i)
{
const std::string& path = files[i];
size_t len = (std::min)(root.size(), path.size());
size_t j = 0;
for (; j < len; ++j)
{
if (std::tolower(static_cast<unsigned char>(root[j])) !=
std::tolower(static_cast<unsigned char>(path[j])))
{
break;
}
}
root = root.substr(0, j);
size_t pos = root.find_last_of('\\');
if (pos != std::string::npos)
root = root.substr(0, pos + 1);
}
return root;
}
// 获取相对路径
std::string GetRelativePath(const std::string& root, const std::string& fullPath)
{
if (fullPath.compare(0, root.size(), root) == 0)
{
std::string rel = fullPath.substr(root.size());
if (rel.empty()) // root 就是完整文件路径
{
size_t pos = fullPath.find_last_of('\\');
if (pos != std::string::npos)
rel = fullPath.substr(pos + 1); // 文件名
else
rel = fullPath;
}
return rel;
}
return fullPath;
}
// ============================================
// GetCommonRoot 测试
// ============================================
class GetCommonRootTest : public ::testing::Test {};
TEST_F(GetCommonRootTest, EmptyList_ReturnsEmpty) {
std::vector<std::string> files;
EXPECT_EQ(GetCommonRoot(files), "");
}
TEST_F(GetCommonRootTest, SingleFile_ReturnsParentDir) {
std::vector<std::string> files = {
"C:\\Users\\Test\\file.txt"
};
// 单个文件时,返回整个路径(因为没有其他文件比较)
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\file.txt");
}
TEST_F(GetCommonRootTest, TwoFilesInSameDir_ReturnsDir) {
std::vector<std::string> files = {
"C:\\Users\\Test\\file1.txt",
"C:\\Users\\Test\\file2.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\");
}
TEST_F(GetCommonRootTest, FilesInNestedDirs_ReturnsCommonAncestor) {
std::vector<std::string> files = {
"C:\\Users\\Test\\Documents\\a.txt",
"C:\\Users\\Test\\Downloads\\b.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test\\");
}
TEST_F(GetCommonRootTest, FilesInDeeplyNestedDirs) {
std::vector<std::string> files = {
"C:\\a\\b\\c\\d\\e\\file1.txt",
"C:\\a\\b\\c\\x\\y\\file2.txt",
"C:\\a\\b\\c\\z\\file3.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\a\\b\\c\\");
}
TEST_F(GetCommonRootTest, CaseInsensitive) {
std::vector<std::string> files = {
"C:\\Users\\TEST\\file1.txt",
"C:\\users\\test\\file2.txt"
};
// 应该忽略大小写进行比较
std::string root = GetCommonRoot(files);
// 结果应该是原始路径的前缀
EXPECT_TRUE(root.size() >= strlen("C:\\Users\\TEST\\") ||
root.size() >= strlen("C:\\users\\test\\"));
}
TEST_F(GetCommonRootTest, DifferentDrives_ReturnsEmpty) {
std::vector<std::string> files = {
"C:\\Users\\file1.txt",
"D:\\Data\\file2.txt"
};
// 不同驱动器,公共根可能为空或只有共同前缀
std::string root = GetCommonRoot(files);
EXPECT_TRUE(root.empty() || root.find('\\') == std::string::npos);
}
TEST_F(GetCommonRootTest, SameDirectory_MultipleFiles) {
std::vector<std::string> files = {
"C:\\Temp\\a.txt",
"C:\\Temp\\b.txt",
"C:\\Temp\\c.txt",
"C:\\Temp\\d.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Temp\\");
}
TEST_F(GetCommonRootTest, DirectoryAndFiles) {
std::vector<std::string> files = {
"C:\\Temp\\folder",
"C:\\Temp\\folder\\file.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Temp\\");
}
TEST_F(GetCommonRootTest, ChineseCharacters) {
std::vector<std::string> files = {
"C:\\Users\\测试\\文档\\file1.txt",
"C:\\Users\\测试\\下载\\file2.txt"
};
std::string root = GetCommonRoot(files);
// 应该能正确处理中文路径
EXPECT_TRUE(root.find("测试") != std::string::npos ||
root == "C:\\Users\\");
}
TEST_F(GetCommonRootTest, SpacesInPath) {
std::vector<std::string> files = {
"C:\\Program Files\\App\\file1.txt",
"C:\\Program Files\\App\\file2.txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Program Files\\App\\");
}
// ============================================
// GetRelativePath 测试
// ============================================
class GetRelativePathTest : public ::testing::Test {};
TEST_F(GetRelativePathTest, SimpleRelativePath) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "C:\\Users\\Test\\Documents\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "Documents\\file.txt");
}
TEST_F(GetRelativePathTest, FileInRoot) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "C:\\Users\\Test\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
TEST_F(GetRelativePathTest, RootEqualsFullPath) {
std::string root = "C:\\Users\\Test\\file.txt";
std::string fullPath = "C:\\Users\\Test\\file.txt";
// 当 root 就是完整路径时,应该返回文件名
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
TEST_F(GetRelativePathTest, NoCommonPrefix_ReturnsFullPath) {
std::string root = "C:\\Users\\Test\\";
std::string fullPath = "D:\\Other\\file.txt";
// 没有公共前缀时返回完整路径
EXPECT_EQ(GetRelativePath(root, fullPath), fullPath);
}
TEST_F(GetRelativePathTest, NestedPath) {
std::string root = "C:\\a\\";
std::string fullPath = "C:\\a\\b\\c\\d\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "b\\c\\d\\file.txt");
}
TEST_F(GetRelativePathTest, EmptyRoot) {
std::string root = "";
std::string fullPath = "C:\\Users\\file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), fullPath);
}
TEST_F(GetRelativePathTest, RootWithoutTrailingSlash) {
std::string root = "C:\\Users\\Test";
std::string fullPath = "C:\\Users\\Test\\file.txt";
// 没有尾部斜杠也应该工作
std::string result = GetRelativePath(root, fullPath);
EXPECT_TRUE(result == "\\file.txt" || result == fullPath);
}
TEST_F(GetRelativePathTest, DirectoryPath) {
std::string root = "C:\\Users\\";
std::string fullPath = "C:\\Users\\Test\\Documents\\";
EXPECT_EQ(GetRelativePath(root, fullPath), "Test\\Documents\\");
}
TEST_F(GetRelativePathTest, FileNameOnly) {
std::string root = "";
std::string fullPath = "file.txt";
EXPECT_EQ(GetRelativePath(root, fullPath), "file.txt");
}
// ============================================
// 组合测试
// ============================================
class PathUtilsCombinedTest : public ::testing::Test {};
TEST_F(PathUtilsCombinedTest, GetCommonRootThenRelative) {
std::vector<std::string> files = {
"C:\\Project\\src\\main.cpp",
"C:\\Project\\src\\utils.cpp",
"C:\\Project\\include\\header.h"
};
std::string root = GetCommonRoot(files);
EXPECT_EQ(root, "C:\\Project\\");
// 获取各文件的相对路径
EXPECT_EQ(GetRelativePath(root, files[0]), "src\\main.cpp");
EXPECT_EQ(GetRelativePath(root, files[1]), "src\\utils.cpp");
EXPECT_EQ(GetRelativePath(root, files[2]), "include\\header.h");
}
TEST_F(PathUtilsCombinedTest, SimulateFileTransfer) {
// 模拟文件传输场景
std::vector<std::string> selectedFiles = {
"D:\\Downloads\\Project\\doc\\readme.md",
"D:\\Downloads\\Project\\src\\app.py",
"D:\\Downloads\\Project\\src\\lib\\util.py"
};
std::string rootDir = GetCommonRoot(selectedFiles);
EXPECT_EQ(rootDir, "D:\\Downloads\\Project\\");
std::string targetDir = "C:\\Backup\\";
for (const auto& file : selectedFiles) {
std::string relPath = GetRelativePath(rootDir, file);
std::string destPath = targetDir + relPath;
// 验证目标路径正确
EXPECT_TRUE(destPath.find("C:\\Backup\\") == 0);
}
}
// ============================================
// 边界条件测试
// ============================================
TEST(PathBoundaryTest, VeryLongPath) {
// 创建一个很长的路径
std::string basePath = "C:\\";
for (int i = 0; i < 50; i++) {
basePath += "folder" + std::to_string(i) + "\\";
}
basePath += "file.txt";
std::vector<std::string> files = {basePath};
std::string root = GetCommonRoot(files);
EXPECT_EQ(root, basePath);
}
TEST(PathBoundaryTest, SpecialCharacters) {
std::vector<std::string> files = {
"C:\\Users\\Test (1)\\file[1].txt",
"C:\\Users\\Test (1)\\file[2].txt"
};
EXPECT_EQ(GetCommonRoot(files), "C:\\Users\\Test (1)\\");
}
TEST(PathBoundaryTest, UNCPath) {
std::vector<std::string> files = {
"\\\\Server\\Share\\folder\\file1.txt",
"\\\\Server\\Share\\folder\\file2.txt"
};
std::string root = GetCommonRoot(files);
EXPECT_TRUE(root.find("\\\\Server\\Share\\folder\\") != std::string::npos ||
root.find("\\\\Server\\Share\\") != std::string::npos);
}
TEST(PathBoundaryTest, ForwardSlashes) {
// 测试正斜杠(虽然 Windows 主要用反斜杠)
std::vector<std::string> files = {
"C:/Users/Test/file1.txt",
"C:/Users/Test/file2.txt"
};
// 函数使用反斜杠,正斜杠可能不被正确处理
std::string root = GetCommonRoot(files);
// 验证不会崩溃
EXPECT_FALSE(root.empty());
}
// ============================================
// 参数化测试
// ============================================
class GetRelativePathParameterizedTest
: public ::testing::TestWithParam<std::tuple<std::string, std::string, std::string>> {};
TEST_P(GetRelativePathParameterizedTest, VariousPaths) {
auto [root, fullPath, expected] = GetParam();
EXPECT_EQ(GetRelativePath(root, fullPath), expected);
}
INSTANTIATE_TEST_SUITE_P(
PathCases,
GetRelativePathParameterizedTest,
::testing::Values(
std::make_tuple("C:\\a\\", "C:\\a\\b.txt", "b.txt"),
std::make_tuple("C:\\a\\", "C:\\a\\b\\c.txt", "b\\c.txt"),
std::make_tuple("C:\\a\\b\\", "C:\\a\\b\\c\\d\\e.txt", "c\\d\\e.txt"),
std::make_tuple("", "file.txt", "file.txt"),
std::make_tuple("C:\\", "C:\\file.txt", "file.txt")
)
);

View File

@@ -0,0 +1,725 @@
// DiffAlgorithmTest.cpp - Phase 4: 差分算法单元测试
// 测试帧间差异计算和编码逻辑
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <algorithm>
#include <random>
// ============================================
// 算法常量定义 (来自 CursorInfo.h)
// ============================================
#define ALGORITHM_GRAY 0 // 灰度算法
#define ALGORITHM_DIFF 1 // 差分算法(默认)
#define ALGORITHM_H264 2 // H264 视频编码
#define ALGORITHM_RGB565 3 // RGB565 压缩
// ============================================
// 差分输出结构
// ============================================
struct DiffRegion {
uint32_t offset; // 字节偏移
uint32_t length; // 长度(含义取决于算法)
std::vector<uint8_t> data; // 差异数据
};
// ============================================
// 测试用差分算法实现
// 模拟 ScreenCapture.h 中的 CompareBitmapDXGI
// ============================================
class DiffAlgorithm {
public:
// 比较两帧,返回差异区域列表
// 输出格式: [offset:4][length:4][data:N]...
static std::vector<DiffRegion> CompareBitmap(
const uint8_t* srcData, // 新帧
const uint8_t* dstData, // 旧帧
size_t dataLength, // 数据长度字节必须是4的倍数
int algorithm // 压缩算法
) {
std::vector<DiffRegion> regions;
if (dataLength == 0 || dataLength % 4 != 0) {
return regions;
}
const uint32_t* src32 = reinterpret_cast<const uint32_t*>(srcData);
const uint32_t* dst32 = reinterpret_cast<const uint32_t*>(dstData);
size_t pixelCount = dataLength / 4;
size_t i = 0;
while (i < pixelCount) {
// 找到差异起始点
while (i < pixelCount && src32[i] == dst32[i]) {
i++;
}
if (i >= pixelCount) break;
// 记录起始位置
size_t startPos = i;
// 找到差异结束点
while (i < pixelCount && src32[i] != dst32[i]) {
i++;
}
// 创建差异区域
DiffRegion region;
region.offset = static_cast<uint32_t>(startPos * 4); // 字节偏移
size_t diffPixels = i - startPos;
const uint8_t* pixelStart = srcData + startPos * 4;
switch (algorithm) {
case ALGORITHM_GRAY: {
// 灰度: 1字节/像素
region.length = static_cast<uint32_t>(diffPixels); // 像素数
region.data.resize(diffPixels);
for (size_t p = 0; p < diffPixels; p++) {
const uint8_t* pixel = pixelStart + p * 4;
// BGRA格式: B=0, G=1, R=2, A=3
// 灰度公式: Y = 0.299*R + 0.587*G + 0.114*B
int gray = (306 * pixel[2] + 601 * pixel[1] + 117 * pixel[0]) >> 10;
region.data[p] = static_cast<uint8_t>(std::min(255, std::max(0, gray)));
}
break;
}
case ALGORITHM_RGB565: {
// RGB565: 2字节/像素
region.length = static_cast<uint32_t>(diffPixels); // 像素数
region.data.resize(diffPixels * 2);
uint16_t* out = reinterpret_cast<uint16_t*>(region.data.data());
for (size_t p = 0; p < diffPixels; p++) {
const uint8_t* pixel = pixelStart + p * 4;
// BGRA -> RGB565
out[p] = ((pixel[2] >> 3) << 11) | // R: 5位
((pixel[1] >> 2) << 5) | // G: 6位
(pixel[0] >> 3); // B: 5位
}
break;
}
case ALGORITHM_DIFF:
case ALGORITHM_H264:
default: {
// DIFF/H264: 4字节/像素原始BGRA
region.length = static_cast<uint32_t>(diffPixels * 4); // 字节数
region.data.resize(diffPixels * 4);
memcpy(region.data.data(), pixelStart, diffPixels * 4);
break;
}
}
regions.push_back(region);
}
return regions;
}
// 序列化差异区域到缓冲区
static size_t SerializeDiffRegions(
const std::vector<DiffRegion>& regions,
uint8_t* buffer,
size_t bufferSize
) {
size_t offset = 0;
for (const auto& region : regions) {
size_t needed = 8 + region.data.size(); // offset(4) + length(4) + data
if (offset + needed > bufferSize) break;
memcpy(buffer + offset, &region.offset, 4);
offset += 4;
memcpy(buffer + offset, &region.length, 4);
offset += 4;
memcpy(buffer + offset, region.data.data(), region.data.size());
offset += region.data.size();
}
return offset;
}
// 应用差异到目标帧
static void ApplyDiff(
uint8_t* dstData,
size_t dstLength,
const std::vector<DiffRegion>& regions,
int algorithm
) {
for (const auto& region : regions) {
if (region.offset >= dstLength) continue;
uint8_t* dst = dstData + region.offset;
switch (algorithm) {
case ALGORITHM_GRAY: {
// 灰度 -> BGRA
for (uint32_t p = 0; p < region.length && region.offset + p * 4 < dstLength; p++) {
uint8_t gray = region.data[p];
dst[p * 4 + 0] = gray; // B
dst[p * 4 + 1] = gray; // G
dst[p * 4 + 2] = gray; // R
dst[p * 4 + 3] = 0xFF; // A
}
break;
}
case ALGORITHM_RGB565: {
// RGB565 -> BGRA
const uint16_t* src = reinterpret_cast<const uint16_t*>(region.data.data());
for (uint32_t p = 0; p < region.length && region.offset + p * 4 < dstLength; p++) {
uint16_t c = src[p];
uint8_t r5 = (c >> 11) & 0x1F;
uint8_t g6 = (c >> 5) & 0x3F;
uint8_t b5 = c & 0x1F;
dst[p * 4 + 0] = (b5 << 3) | (b5 >> 2); // B
dst[p * 4 + 1] = (g6 << 2) | (g6 >> 4); // G
dst[p * 4 + 2] = (r5 << 3) | (r5 >> 2); // R
dst[p * 4 + 3] = 0xFF; // A
}
break;
}
case ALGORITHM_DIFF:
case ALGORITHM_H264:
default: {
// 原始BGRA
size_t copyLen = std::min(static_cast<size_t>(region.length),
dstLength - region.offset);
memcpy(dst, region.data.data(), copyLen);
break;
}
}
}
}
};
// ============================================
// 测试夹具
// ============================================
class DiffAlgorithmTest : public ::testing::Test {
protected:
// 创建纯色帧 (BGRA格式)
static std::vector<uint8_t> CreateSolidFrame(int width, int height,
uint8_t b, uint8_t g,
uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> frame(width * height * 4);
for (int i = 0; i < width * height; i++) {
frame[i * 4 + 0] = b;
frame[i * 4 + 1] = g;
frame[i * 4 + 2] = r;
frame[i * 4 + 3] = a;
}
return frame;
}
// 创建渐变帧
static std::vector<uint8_t> CreateGradientFrame(int width, int height) {
std::vector<uint8_t> frame(width * height * 4);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = static_cast<uint8_t>(x * 255 / width); // B
frame[idx + 1] = static_cast<uint8_t>(y * 255 / height); // G
frame[idx + 2] = static_cast<uint8_t>((x + y) * 128 / (width + height)); // R
frame[idx + 3] = 0xFF;
}
}
return frame;
}
// 创建带随机区域变化的帧
static std::vector<uint8_t> CreateFrameWithChanges(
const std::vector<uint8_t>& baseFrame,
int width, int height,
int changeX, int changeY,
int changeW, int changeH,
uint8_t newB, uint8_t newG, uint8_t newR
) {
std::vector<uint8_t> frame = baseFrame;
for (int y = changeY; y < changeY + changeH && y < height; y++) {
for (int x = changeX; x < changeX + changeW && x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = newB;
frame[idx + 1] = newG;
frame[idx + 2] = newR;
frame[idx + 3] = 0xFF;
}
}
return frame;
}
};
// ============================================
// 基础功能测试
// ============================================
TEST_F(DiffAlgorithmTest, IdenticalFrames_NoDifference) {
auto frame = CreateSolidFrame(100, 100, 128, 128, 128);
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), frame.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, CompletelyDifferent_SingleRegion) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 100u * 4); // 100像素 * 4字节
}
TEST_F(DiffAlgorithmTest, PartialChange_SingleRegion) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 0, 0, 0);
auto frame2 = CreateFrameWithChanges(frame1, WIDTH, HEIGHT,
10, 10, 20, 20, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应该检测到变化区域
EXPECT_GT(regions.size(), 0u);
// 验证总变化像素数
size_t totalChangedPixels = 0;
for (const auto& r : regions) {
totalChangedPixels += r.length / 4; // DIFF算法length是字节数
}
EXPECT_EQ(totalChangedPixels, 20u * 20u); // 20x20区域
}
TEST_F(DiffAlgorithmTest, MultipleRegions_NonContiguous) {
const int WIDTH = 100, HEIGHT = 10;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 128, 128, 128);
auto frame2 = frame1;
// 创建两个不相邻的变化区域
// 区域1: 像素 5-14
for (int i = 5; i < 15; i++) {
frame2[i * 4 + 0] = 0;
frame2[i * 4 + 1] = 0;
frame2[i * 4 + 2] = 255;
}
// 区域2: 像素 50-59 (与区域1不相邻)
for (int i = 50; i < 60; i++) {
frame2[i * 4 + 0] = 255;
frame2[i * 4 + 1] = 0;
frame2[i * 4 + 2] = 0;
}
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 2u);
}
// ============================================
// 算法特定测试
// ============================================
TEST_F(DiffAlgorithmTest, GrayAlgorithm_CorrectOutput) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0); // 黑色
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255); // 白色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 100u); // 100像素
EXPECT_EQ(regions[0].data.size(), 100u); // 1字节/像素
// 白色应该转换为灰度255
EXPECT_EQ(regions[0].data[0], 255);
}
TEST_F(DiffAlgorithmTest, GrayAlgorithm_GrayConversionFormula) {
// 测试灰度转换公式: Y = 0.299*R + 0.587*G + 0.114*B
std::vector<uint8_t> frame1(4, 0); // 1像素黑色
std::vector<uint8_t> frame2(4);
// 测试纯红色 (R=255, G=0, B=0)
frame2[0] = 0; // B
frame2[1] = 0; // G
frame2[2] = 255; // R
frame2[3] = 255; // A
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_GRAY);
ASSERT_EQ(regions.size(), 1u);
// 期望: (306 * 255 + 601 * 0 + 117 * 0) >> 10 ≈ 76
uint8_t expectedGray = (306 * 255) >> 10;
EXPECT_NEAR(regions[0].data[0], expectedGray, 1);
}
TEST_F(DiffAlgorithmTest, RGB565Algorithm_CorrectOutput) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 100u); // 100像素
EXPECT_EQ(regions[0].data.size(), 200u); // 2字节/像素
// 白色 RGB565 = 0xFFFF
uint16_t* rgb565 = reinterpret_cast<uint16_t*>(regions[0].data.data());
EXPECT_EQ(rgb565[0], 0xFFFF);
}
TEST_F(DiffAlgorithmTest, RGB565Algorithm_ColorConversion) {
std::vector<uint8_t> frame1(4, 0); // 1像素黑色
std::vector<uint8_t> frame2(4);
// 纯红色 (R=255, G=0, B=0) -> RGB565 = 0xF800
frame2[0] = 0; // B
frame2[1] = 0; // G
frame2[2] = 255; // R
frame2[3] = 255; // A
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_RGB565);
ASSERT_EQ(regions.size(), 1u);
uint16_t* rgb565 = reinterpret_cast<uint16_t*>(regions[0].data.data());
EXPECT_EQ(rgb565[0], 0xF800); // 纯红色
}
TEST_F(DiffAlgorithmTest, DiffAlgorithm_PreservesOriginalData) {
std::vector<uint8_t> frame1(8, 0); // 2像素黑色
std::vector<uint8_t> frame2 = {
0x12, 0x34, 0x56, 0x78, // 像素1
0xAB, 0xCD, 0xEF, 0xFF // 像素2
};
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 8, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].length, 8u); // 8字节
EXPECT_EQ(regions[0].data, frame2); // 原始数据完整保留
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(DiffAlgorithmTest, EmptyInput_NoRegions) {
auto regions = DiffAlgorithm::CompareBitmap(nullptr, nullptr, 0, ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, SinglePixel_Difference) {
std::vector<uint8_t> frame1 = {0, 0, 0, 255};
std::vector<uint8_t> frame2 = {255, 255, 255, 255};
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 4, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 4u);
}
TEST_F(DiffAlgorithmTest, SinglePixel_NoDifference) {
std::vector<uint8_t> frame = {100, 150, 200, 255};
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), 4, ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, NonAlignedLength_Rejected) {
std::vector<uint8_t> data(7); // 不是4的倍数
auto regions = DiffAlgorithm::CompareBitmap(
data.data(), data.data(), data.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
}
TEST_F(DiffAlgorithmTest, FirstPixelOnly_Changed) {
std::vector<uint8_t> frame1(40, 128); // 10像素
std::vector<uint8_t> frame2 = frame1;
frame2[0] = 0; // 只改变第一个像素的B通道
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 40, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, 4u); // 只有1个像素
}
TEST_F(DiffAlgorithmTest, LastPixelOnly_Changed) {
std::vector<uint8_t> frame1(40, 128); // 10像素
std::vector<uint8_t> frame2 = frame1;
frame2[36] = 0; // 只改变最后一个像素的B通道
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), 40, ALGORITHM_DIFF);
ASSERT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 36u); // 第9个像素的偏移
EXPECT_EQ(regions[0].length, 4u);
}
// ============================================
// 序列化测试
// ============================================
TEST_F(DiffAlgorithmTest, Serialize_SingleRegion) {
std::vector<DiffRegion> regions;
DiffRegion r;
r.offset = 100;
r.length = 16;
r.data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
regions.push_back(r);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 8u + 16u); // offset(4) + length(4) + data(16)
// 验证偏移
uint32_t readOffset;
memcpy(&readOffset, buffer.data(), 4);
EXPECT_EQ(readOffset, 100u);
// 验证长度
uint32_t readLength;
memcpy(&readLength, buffer.data() + 4, 4);
EXPECT_EQ(readLength, 16u);
}
TEST_F(DiffAlgorithmTest, Serialize_MultipleRegions) {
std::vector<DiffRegion> regions;
DiffRegion r1;
r1.offset = 0;
r1.length = 4;
r1.data = {1, 2, 3, 4};
regions.push_back(r1);
DiffRegion r2;
r2.offset = 100;
r2.length = 8;
r2.data = {5, 6, 7, 8, 9, 10, 11, 12};
regions.push_back(r2);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, (8u + 4u) + (8u + 8u)); // 两个区域
}
TEST_F(DiffAlgorithmTest, Serialize_BufferTooSmall) {
std::vector<DiffRegion> regions;
DiffRegion r;
r.offset = 0;
r.length = 100;
r.data.resize(100, 0xFF);
regions.push_back(r);
std::vector<uint8_t> buffer(10); // 太小
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 0u); // 无法写入
}
// ============================================
// 应用差异测试
// ============================================
TEST_F(DiffAlgorithmTest, ApplyDiff_DIFF_Reconstructs) {
auto frame1 = CreateGradientFrame(50, 50);
auto frame2 = CreateFrameWithChanges(frame1, 50, 50, 10, 10, 20, 20, 255, 0, 0);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应用差异到frame1的副本
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_DIFF);
// 应该完全重建frame2
EXPECT_EQ(reconstructed, frame2);
}
TEST_F(DiffAlgorithmTest, ApplyDiff_GRAY_ApproximateReconstruction) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
auto frame2 = CreateSolidFrame(10, 10, 200, 200, 200); // 浅灰色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_GRAY);
// 灰度重建R=G=B
for (size_t i = 0; i < reconstructed.size(); i += 4) {
EXPECT_EQ(reconstructed[i], reconstructed[i + 1]); // B == G
EXPECT_EQ(reconstructed[i + 1], reconstructed[i + 2]); // G == R
}
}
TEST_F(DiffAlgorithmTest, ApplyDiff_RGB565_ApproximateReconstruction) {
auto frame1 = CreateSolidFrame(10, 10, 0, 0, 0);
// RGB565有量化误差使用能精确表示的颜色
auto frame2 = CreateSolidFrame(10, 10, 248, 252, 248); // 接近白色
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_RGB565);
// 验证重建颜色接近原始(允许量化误差)
for (size_t i = 0; i < reconstructed.size(); i += 4) {
EXPECT_NEAR(reconstructed[i], frame2[i], 8); // B
EXPECT_NEAR(reconstructed[i + 1], frame2[i + 1], 4); // G (6位精度)
EXPECT_NEAR(reconstructed[i + 2], frame2[i + 2], 8); // R
}
}
// ============================================
// 性能相关测试(不计时,只验证正确性)
// ============================================
TEST_F(DiffAlgorithmTest, LargeFrame_1080p_Correctness) {
const int WIDTH = 1920, HEIGHT = 1080;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 128, 128, 128);
auto frame2 = CreateFrameWithChanges(frame1, WIDTH, HEIGHT,
100, 100, 200, 200, 255, 0, 0);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
// 应该有变化区域
EXPECT_GT(regions.size(), 0u);
// 重建验证
std::vector<uint8_t> reconstructed = frame1;
DiffAlgorithm::ApplyDiff(reconstructed.data(), reconstructed.size(),
regions, ALGORITHM_DIFF);
EXPECT_EQ(reconstructed, frame2);
}
TEST_F(DiffAlgorithmTest, RandomChanges_AllAlgorithms) {
const int WIDTH = 100, HEIGHT = 100;
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
auto frame1 = CreateGradientFrame(WIDTH, HEIGHT);
auto frame2 = frame1;
// 随机修改10%的像素
for (int i = 0; i < WIDTH * HEIGHT / 10; i++) {
int idx = (rng() % (WIDTH * HEIGHT)) * 4;
frame2[idx + 0] = dist(rng);
frame2[idx + 1] = dist(rng);
frame2[idx + 2] = dist(rng);
}
// 测试所有算法都能产生输出
for (int algo : {ALGORITHM_GRAY, ALGORITHM_DIFF, ALGORITHM_RGB565}) {
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), algo);
EXPECT_GT(regions.size(), 0u) << "Algorithm " << algo << " failed";
}
}
// ============================================
// 压缩效率测试
// ============================================
TEST_F(DiffAlgorithmTest, CompressionRatio_GRAY) {
auto frame1 = CreateSolidFrame(100, 100, 0, 0, 0);
auto frame2 = CreateSolidFrame(100, 100, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_GRAY);
size_t originalSize = 100 * 100 * 4; // 40000 字节
size_t compressedSize = 8 + regions[0].data.size(); // offset + length + data
// GRAY应该是 100*100*1 = 10000 字节数据
EXPECT_EQ(regions[0].data.size(), 10000u);
// 压缩比约 4:1
EXPECT_LT(compressedSize, originalSize / 3);
}
TEST_F(DiffAlgorithmTest, CompressionRatio_RGB565) {
auto frame1 = CreateSolidFrame(100, 100, 0, 0, 0);
auto frame2 = CreateSolidFrame(100, 100, 255, 255, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_RGB565);
// RGB565应该是 100*100*2 = 20000 字节数据
EXPECT_EQ(regions[0].data.size(), 20000u);
}
TEST_F(DiffAlgorithmTest, NoChange_ZeroOverhead) {
auto frame = CreateGradientFrame(100, 100);
auto regions = DiffAlgorithm::CompareBitmap(
frame.data(), frame.data(), frame.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 0u);
std::vector<uint8_t> buffer(1024);
size_t written = DiffAlgorithm::SerializeDiffRegions(
regions, buffer.data(), buffer.size());
EXPECT_EQ(written, 0u); // 无变化时零开销
}
// ============================================
// 参数化测试:不同分辨率
// ============================================
class DiffAlgorithmResolutionTest : public ::testing::TestWithParam<std::tuple<int, int>> {};
TEST_P(DiffAlgorithmResolutionTest, Resolution_Correctness) {
auto [width, height] = GetParam();
auto frame1 = std::vector<uint8_t>(width * height * 4, 0);
auto frame2 = std::vector<uint8_t>(width * height * 4, 255);
auto regions = DiffAlgorithm::CompareBitmap(
frame2.data(), frame1.data(), frame1.size(), ALGORITHM_DIFF);
EXPECT_EQ(regions.size(), 1u);
EXPECT_EQ(regions[0].offset, 0u);
EXPECT_EQ(regions[0].length, static_cast<uint32_t>(width * height * 4));
}
INSTANTIATE_TEST_SUITE_P(
Resolutions,
DiffAlgorithmResolutionTest,
::testing::Values(
std::make_tuple(1, 1), // 最小
std::make_tuple(10, 10), // 小
std::make_tuple(100, 100), // 中
std::make_tuple(640, 480), // VGA
std::make_tuple(1280, 720), // 720p
std::make_tuple(1920, 1080) // 1080p
)
);

View File

@@ -0,0 +1,691 @@
// QualityAdaptiveTest.cpp - Phase 4: 质量自适应单元测试
// 测试 RTT 到质量等级的映射和防抖策略
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <climits>
// ============================================
// 质量等级枚举 (来自 commands.h)
// ============================================
enum QualityLevel {
QUALITY_DISABLED = -2, // 关闭质量控制
QUALITY_ADAPTIVE = -1, // 自适应模式
QUALITY_ULTRA = 0, // 极佳(局域网)
QUALITY_HIGH = 1, // 优秀
QUALITY_GOOD = 2, // 良好
QUALITY_MEDIUM = 3, // 一般
QUALITY_LOW = 4, // 较差
QUALITY_MINIMAL = 5, // 最低
QUALITY_COUNT = 6,
};
// ============================================
// 算法枚举
// ============================================
#define ALGORITHM_GRAY 0
#define ALGORITHM_DIFF 1
#define ALGORITHM_H264 2
#define ALGORITHM_RGB565 3
// ============================================
// 质量配置结构体 (来自 commands.h)
// ============================================
struct QualityProfile {
int maxFPS; // 最大帧率
int maxWidth; // 最大宽度 (0=不限)
int algorithm; // 压缩算法
int bitRate; // kbps (仅H264使用)
};
// 默认质量配置表
static const QualityProfile g_QualityProfiles[QUALITY_COUNT] = {
{25, 0, ALGORITHM_DIFF, 0 }, // Ultra: 25FPS, 原始, DIFF
{20, 0, ALGORITHM_RGB565, 0 }, // High: 20FPS, 原始, RGB565
{20, 1920, ALGORITHM_H264, 3000}, // Good: 20FPS, 1080P, H264
{15, 1600, ALGORITHM_H264, 2000}, // Medium: 15FPS, 900P, H264
{12, 1280, ALGORITHM_H264, 1200}, // Low: 12FPS, 720P, H264
{8, 1024, ALGORITHM_H264, 800 }, // Minimal: 8FPS, 540P, H264
};
// ============================================
// RTT 阈值表 (来自 commands.h)
// ============================================
// 行0: 直连模式, 行1: FRP代理模式
static const int g_RttThresholds[2][QUALITY_COUNT] = {
/* DIRECT */ { 30, 80, 150, 250, 400, INT_MAX },
/* PROXY */ { 60, 160, 300, 500, 800, INT_MAX },
};
// ============================================
// RTT到质量等级映射函数 (来自 commands.h)
// ============================================
inline int GetTargetQualityLevel(int rtt, int usingFRP)
{
int row = usingFRP ? 1 : 0;
for (int level = 0; level < QUALITY_COUNT; level++) {
if (rtt < g_RttThresholds[row][level]) {
return level;
}
}
return QUALITY_MINIMAL;
}
// ============================================
// 防抖策略模拟器
// ============================================
class QualityDebouncer {
public:
// 防抖参数
static const int DOWNGRADE_STABLE_COUNT = 2; // 降级所需稳定次数
static const int UPGRADE_STABLE_COUNT = 5; // 升级所需稳定次数
static const int DEFAULT_COOLDOWN_MS = 3000; // 默认冷却时间
static const int RES_CHANGE_DOWNGRADE_COOLDOWN_MS = 15000; // 分辨率降级冷却
static const int RES_CHANGE_UPGRADE_COOLDOWN_MS = 30000; // 分辨率升级冷却
static const int STARTUP_DELAY_MS = 60000; // 启动延迟
QualityDebouncer()
: m_currentLevel(QUALITY_HIGH)
, m_stableCount(0)
, m_lastChangeTime(0)
, m_startTime(0)
, m_enabled(true)
{}
void Reset() {
m_currentLevel = QUALITY_HIGH;
m_stableCount = 0;
m_lastChangeTime = 0;
m_startTime = 0;
}
void SetStartTime(uint64_t time) {
m_startTime = time;
}
void SetEnabled(bool enabled) {
m_enabled = enabled;
}
// 评估并返回新的质量等级
// 返回 -1 表示不改变
int Evaluate(int targetLevel, uint64_t currentTime, bool resolutionChange = false) {
if (!m_enabled) return -1;
// 启动延迟
if (currentTime - m_startTime < STARTUP_DELAY_MS) {
return -1;
}
// 冷却时间检查
uint64_t cooldown = DEFAULT_COOLDOWN_MS;
if (resolutionChange) {
cooldown = (targetLevel > m_currentLevel)
? RES_CHANGE_DOWNGRADE_COOLDOWN_MS
: RES_CHANGE_UPGRADE_COOLDOWN_MS;
}
if (currentTime - m_lastChangeTime < cooldown) {
return -1;
}
// 降级: 快速响应
if (targetLevel > m_currentLevel) {
m_stableCount++;
if (m_stableCount >= DOWNGRADE_STABLE_COUNT) {
int newLevel = targetLevel;
m_currentLevel = newLevel;
m_stableCount = 0;
m_lastChangeTime = currentTime;
return newLevel;
}
}
// 升级: 谨慎处理,每次只升一级
else if (targetLevel < m_currentLevel) {
m_stableCount++;
if (m_stableCount >= UPGRADE_STABLE_COUNT) {
int newLevel = m_currentLevel - 1; // 只升一级
m_currentLevel = newLevel;
m_stableCount = 0;
m_lastChangeTime = currentTime;
return newLevel;
}
}
// 目标等级等于当前等级
else {
m_stableCount = 0;
}
return -1;
}
int GetCurrentLevel() const { return m_currentLevel; }
int GetStableCount() const { return m_stableCount; }
private:
int m_currentLevel;
int m_stableCount;
uint64_t m_lastChangeTime;
uint64_t m_startTime;
bool m_enabled;
};
// ============================================
// 测试夹具
// ============================================
class QualityAdaptiveTest : public ::testing::Test {
protected:
void SetUp() override {
debouncer.Reset();
}
QualityDebouncer debouncer;
};
// ============================================
// RTT 映射测试 - 直连模式
// ============================================
TEST_F(QualityAdaptiveTest, RTT_Direct_Ultra) {
// RTT < 30ms -> Ultra
EXPECT_EQ(GetTargetQualityLevel(0, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(10, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(29, 0), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_High) {
// 30ms <= RTT < 80ms -> High
EXPECT_EQ(GetTargetQualityLevel(30, 0), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(50, 0), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(79, 0), QUALITY_HIGH);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Good) {
// 80ms <= RTT < 150ms -> Good
EXPECT_EQ(GetTargetQualityLevel(80, 0), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(100, 0), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(149, 0), QUALITY_GOOD);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Medium) {
// 150ms <= RTT < 250ms -> Medium
EXPECT_EQ(GetTargetQualityLevel(150, 0), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(200, 0), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(249, 0), QUALITY_MEDIUM);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Low) {
// 250ms <= RTT < 400ms -> Low
EXPECT_EQ(GetTargetQualityLevel(250, 0), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(300, 0), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(399, 0), QUALITY_LOW);
}
TEST_F(QualityAdaptiveTest, RTT_Direct_Minimal) {
// RTT >= 400ms -> Minimal
EXPECT_EQ(GetTargetQualityLevel(400, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(500, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(1000, 0), QUALITY_MINIMAL);
}
// ============================================
// RTT 映射测试 - FRP代理模式
// ============================================
TEST_F(QualityAdaptiveTest, RTT_FRP_Ultra) {
// RTT < 60ms -> Ultra (FRP模式阈值更宽松)
EXPECT_EQ(GetTargetQualityLevel(0, 1), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(30, 1), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(59, 1), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_High) {
// 60ms <= RTT < 160ms -> High
EXPECT_EQ(GetTargetQualityLevel(60, 1), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(100, 1), QUALITY_HIGH);
EXPECT_EQ(GetTargetQualityLevel(159, 1), QUALITY_HIGH);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Good) {
// 160ms <= RTT < 300ms -> Good
EXPECT_EQ(GetTargetQualityLevel(160, 1), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(200, 1), QUALITY_GOOD);
EXPECT_EQ(GetTargetQualityLevel(299, 1), QUALITY_GOOD);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Medium) {
// 300ms <= RTT < 500ms -> Medium
EXPECT_EQ(GetTargetQualityLevel(300, 1), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(400, 1), QUALITY_MEDIUM);
EXPECT_EQ(GetTargetQualityLevel(499, 1), QUALITY_MEDIUM);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Low) {
// 500ms <= RTT < 800ms -> Low
EXPECT_EQ(GetTargetQualityLevel(500, 1), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(600, 1), QUALITY_LOW);
EXPECT_EQ(GetTargetQualityLevel(799, 1), QUALITY_LOW);
}
TEST_F(QualityAdaptiveTest, RTT_FRP_Minimal) {
// RTT >= 800ms -> Minimal
EXPECT_EQ(GetTargetQualityLevel(800, 1), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(1000, 1), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(2000, 1), QUALITY_MINIMAL);
}
// ============================================
// 质量配置表测试
// ============================================
TEST_F(QualityAdaptiveTest, Profile_Ultra) {
const auto& p = g_QualityProfiles[QUALITY_ULTRA];
EXPECT_EQ(p.maxFPS, 25);
EXPECT_EQ(p.maxWidth, 0); // 无限制
EXPECT_EQ(p.algorithm, ALGORITHM_DIFF);
EXPECT_EQ(p.bitRate, 0);
}
TEST_F(QualityAdaptiveTest, Profile_High) {
const auto& p = g_QualityProfiles[QUALITY_HIGH];
EXPECT_EQ(p.maxFPS, 20);
EXPECT_EQ(p.maxWidth, 0); // 无限制
EXPECT_EQ(p.algorithm, ALGORITHM_RGB565);
EXPECT_EQ(p.bitRate, 0);
}
TEST_F(QualityAdaptiveTest, Profile_Good) {
const auto& p = g_QualityProfiles[QUALITY_GOOD];
EXPECT_EQ(p.maxFPS, 20);
EXPECT_EQ(p.maxWidth, 1920); // 1080p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 3000);
}
TEST_F(QualityAdaptiveTest, Profile_Medium) {
const auto& p = g_QualityProfiles[QUALITY_MEDIUM];
EXPECT_EQ(p.maxFPS, 15);
EXPECT_EQ(p.maxWidth, 1600); // 900p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 2000);
}
TEST_F(QualityAdaptiveTest, Profile_Low) {
const auto& p = g_QualityProfiles[QUALITY_LOW];
EXPECT_EQ(p.maxFPS, 12);
EXPECT_EQ(p.maxWidth, 1280); // 720p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 1200);
}
TEST_F(QualityAdaptiveTest, Profile_Minimal) {
const auto& p = g_QualityProfiles[QUALITY_MINIMAL];
EXPECT_EQ(p.maxFPS, 8);
EXPECT_EQ(p.maxWidth, 1024); // 540p
EXPECT_EQ(p.algorithm, ALGORITHM_H264);
EXPECT_EQ(p.bitRate, 800);
}
// ============================================
// 防抖策略测试 - 启动延迟
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_StartupDelay) {
debouncer.SetStartTime(0);
// 启动后60秒内不应改变
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 1000), -1);
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 30000), -1);
EXPECT_EQ(debouncer.Evaluate(QUALITY_MINIMAL, 59999), -1);
}
TEST_F(QualityAdaptiveTest, Debounce_AfterStartupDelay) {
debouncer.SetStartTime(0);
// 启动60秒后应该可以改变
// 需要连续2次降级请求
debouncer.Evaluate(QUALITY_MINIMAL, 60000); // 第1次
int result = debouncer.Evaluate(QUALITY_MINIMAL, 60001); // 第2次
EXPECT_EQ(result, QUALITY_MINIMAL);
}
// ============================================
// 防抖策略测试 - 降级
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Downgrade_RequiresTwice) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 第1次降级请求 - 不应立即执行
int result = debouncer.Evaluate(QUALITY_MINIMAL, time);
EXPECT_EQ(result, -1);
EXPECT_EQ(debouncer.GetStableCount(), 1);
// 第2次降级请求 - 应该执行
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, QUALITY_MINIMAL);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
}
TEST_F(QualityAdaptiveTest, Debounce_Downgrade_ResetOnStable) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 第1次降级请求
debouncer.Evaluate(QUALITY_LOW, time);
EXPECT_EQ(debouncer.GetStableCount(), 1);
// 目标等级恢复 - 计数应重置
debouncer.Evaluate(QUALITY_HIGH, time + 100); // 当前等级
EXPECT_EQ(debouncer.GetStableCount(), 0);
}
// ============================================
// 防抖策略测试 - 升级
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Upgrade_RequiresFiveTimes) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 先降级到 MINIMAL
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 尝试升级到 ULTRA (需要5次)
time += 5000; // 冷却后
for (int i = 0; i < 4; i++) {
int result = debouncer.Evaluate(QUALITY_ULTRA, time + i * 100);
EXPECT_EQ(result, -1); // 前4次不应执行
EXPECT_EQ(debouncer.GetStableCount(), i + 1);
}
// 第5次应该执行但只升一级
int result = debouncer.Evaluate(QUALITY_ULTRA, time + 500);
EXPECT_EQ(result, QUALITY_LOW); // MINIMAL -> LOW (只升一级)
}
TEST_F(QualityAdaptiveTest, Debounce_Upgrade_OneStepAtATime) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 降级到 MINIMAL
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 多次升级请求,验证每次只升一级
// 从 MINIMAL(5) 升到 HIGH(1) 需要4次升级每次需要5个稳定请求
time += 5000;
int upgradeCount = 0;
while (debouncer.GetCurrentLevel() > QUALITY_HIGH && upgradeCount < 10) {
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time); // 请求最高等级
time += 100;
}
time += 5000; // 冷却
upgradeCount++;
}
// 最终应该回到 HIGH (或更高)
EXPECT_LE(debouncer.GetCurrentLevel(), QUALITY_HIGH);
}
// ============================================
// 防抖策略测试 - 冷却时间
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_DefaultCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次降级
debouncer.Evaluate(QUALITY_LOW, time);
debouncer.Evaluate(QUALITY_LOW, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW);
// 冷却期内不应再次改变
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 200);
EXPECT_EQ(result, -1);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 2999);
EXPECT_EQ(result, -1);
}
TEST_F(QualityAdaptiveTest, Debounce_AfterCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次降级
debouncer.Evaluate(QUALITY_LOW, time);
debouncer.Evaluate(QUALITY_LOW, time + 100);
// 冷却后应该可以再次改变
time += 3100;
debouncer.Evaluate(QUALITY_MINIMAL, time);
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, QUALITY_MINIMAL);
}
// ============================================
// 防抖策略测试 - 分辨率变化冷却
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_ResolutionChange_DowngradeCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 执行一次带分辨率变化的降级
debouncer.Evaluate(QUALITY_LOW, time, true); // resolutionChange=true
debouncer.Evaluate(QUALITY_LOW, time + 100, true);
// 分辨率降级冷却15秒
int result = debouncer.Evaluate(QUALITY_MINIMAL, time + 14000, true);
EXPECT_EQ(result, -1);
// 15秒后可以
time += 16000;
debouncer.Evaluate(QUALITY_MINIMAL, time, true);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100, true);
EXPECT_EQ(result, QUALITY_MINIMAL);
}
TEST_F(QualityAdaptiveTest, Debounce_ResolutionChange_UpgradeCooldown) {
debouncer.SetStartTime(0);
uint64_t time = 60000;
// 先降级到 MINIMAL (不带分辨率变化,使用默认冷却)
debouncer.Evaluate(QUALITY_MINIMAL, time);
debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_MINIMAL);
// 等待足够长时间(>30秒进行带分辨率变化的升级
time += 35000; // 超过30秒冷却
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time + i * 100, true); // resolutionChange=true
}
// 应该成功升级一级 (MINIMAL -> LOW)
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW);
// 30秒内不应再次升级带分辨率变化的升级需要30秒冷却
time += 1000;
for (int i = 0; i < 5; i++) {
debouncer.Evaluate(QUALITY_ULTRA, time + i * 100, true);
}
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_LOW); // 未改变还在30秒冷却期内
}
// ============================================
// 禁用自适应测试
// ============================================
TEST_F(QualityAdaptiveTest, Debounce_Disabled) {
debouncer.SetStartTime(0);
debouncer.SetEnabled(false);
uint64_t time = 60000;
int result = debouncer.Evaluate(QUALITY_MINIMAL, time);
EXPECT_EQ(result, -1);
result = debouncer.Evaluate(QUALITY_MINIMAL, time + 100);
EXPECT_EQ(result, -1);
// 质量等级应保持不变
EXPECT_EQ(debouncer.GetCurrentLevel(), QUALITY_HIGH);
}
// ============================================
// 边界值测试
// ============================================
TEST_F(QualityAdaptiveTest, RTT_Boundary_Zero) {
EXPECT_EQ(GetTargetQualityLevel(0, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(0, 1), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Boundary_Negative) {
// 负值RTT应该仍返回最高质量
EXPECT_EQ(GetTargetQualityLevel(-1, 0), QUALITY_ULTRA);
EXPECT_EQ(GetTargetQualityLevel(-100, 0), QUALITY_ULTRA);
}
TEST_F(QualityAdaptiveTest, RTT_Boundary_VeryHigh) {
// 非常高的RTT
EXPECT_EQ(GetTargetQualityLevel(10000, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(100000, 0), QUALITY_MINIMAL);
EXPECT_EQ(GetTargetQualityLevel(INT_MAX - 1, 0), QUALITY_MINIMAL);
}
// ============================================
// 常量验证测试
// ============================================
TEST(QualityConstantsTest, QualityLevelEnum) {
EXPECT_EQ(QUALITY_DISABLED, -2);
EXPECT_EQ(QUALITY_ADAPTIVE, -1);
EXPECT_EQ(QUALITY_ULTRA, 0);
EXPECT_EQ(QUALITY_HIGH, 1);
EXPECT_EQ(QUALITY_GOOD, 2);
EXPECT_EQ(QUALITY_MEDIUM, 3);
EXPECT_EQ(QUALITY_LOW, 4);
EXPECT_EQ(QUALITY_MINIMAL, 5);
EXPECT_EQ(QUALITY_COUNT, 6);
}
TEST(QualityConstantsTest, ProfileCount) {
// 应该有 QUALITY_COUNT 个配置
EXPECT_EQ(sizeof(g_QualityProfiles) / sizeof(g_QualityProfiles[0]),
static_cast<size_t>(QUALITY_COUNT));
}
TEST(QualityConstantsTest, ThresholdCount) {
// 每行应该有 QUALITY_COUNT 个阈值
EXPECT_EQ(sizeof(g_RttThresholds[0]) / sizeof(g_RttThresholds[0][0]),
static_cast<size_t>(QUALITY_COUNT));
}
TEST(QualityConstantsTest, ThresholdIncreasing) {
// 阈值应该递增
for (int row = 0; row < 2; row++) {
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_LT(g_RttThresholds[row][i], g_RttThresholds[row][i + 1])
<< "Row " << row << ", index " << i;
}
}
}
TEST(QualityConstantsTest, FRPThresholdsHigher) {
// FRP模式阈值应该比直连模式高
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_GT(g_RttThresholds[1][i], g_RttThresholds[0][i])
<< "Index " << i;
}
}
// ============================================
// 参数化测试RTT值遍历
// ============================================
class RTTMappingTest : public ::testing::TestWithParam<std::tuple<int, int, int>> {};
TEST_P(RTTMappingTest, RTT_Mapping) {
auto [rtt, usingFRP, expectedLevel] = GetParam();
EXPECT_EQ(GetTargetQualityLevel(rtt, usingFRP), expectedLevel);
}
INSTANTIATE_TEST_SUITE_P(
DirectMode,
RTTMappingTest,
::testing::Values(
std::make_tuple(0, 0, QUALITY_ULTRA),
std::make_tuple(29, 0, QUALITY_ULTRA),
std::make_tuple(30, 0, QUALITY_HIGH),
std::make_tuple(79, 0, QUALITY_HIGH),
std::make_tuple(80, 0, QUALITY_GOOD),
std::make_tuple(149, 0, QUALITY_GOOD),
std::make_tuple(150, 0, QUALITY_MEDIUM),
std::make_tuple(249, 0, QUALITY_MEDIUM),
std::make_tuple(250, 0, QUALITY_LOW),
std::make_tuple(399, 0, QUALITY_LOW),
std::make_tuple(400, 0, QUALITY_MINIMAL),
std::make_tuple(1000, 0, QUALITY_MINIMAL)
)
);
INSTANTIATE_TEST_SUITE_P(
FRPMode,
RTTMappingTest,
::testing::Values(
std::make_tuple(0, 1, QUALITY_ULTRA),
std::make_tuple(59, 1, QUALITY_ULTRA),
std::make_tuple(60, 1, QUALITY_HIGH),
std::make_tuple(159, 1, QUALITY_HIGH),
std::make_tuple(160, 1, QUALITY_GOOD),
std::make_tuple(299, 1, QUALITY_GOOD),
std::make_tuple(300, 1, QUALITY_MEDIUM),
std::make_tuple(499, 1, QUALITY_MEDIUM),
std::make_tuple(500, 1, QUALITY_LOW),
std::make_tuple(799, 1, QUALITY_LOW),
std::make_tuple(800, 1, QUALITY_MINIMAL),
std::make_tuple(2000, 1, QUALITY_MINIMAL)
)
);
// ============================================
// 质量配置合理性测试
// ============================================
TEST(QualityProfileTest, FPSDecreasing) {
// FPS 应该随质量降低而减少
for (int i = 0; i < QUALITY_COUNT - 1; i++) {
EXPECT_GE(g_QualityProfiles[i].maxFPS, g_QualityProfiles[i + 1].maxFPS)
<< "Level " << i;
}
}
TEST(QualityProfileTest, MaxWidthDecreasing) {
// maxWidth 应该随质量降低而减少除了0表示不限
int prevWidth = INT_MAX;
for (int i = 0; i < QUALITY_COUNT; i++) {
int width = g_QualityProfiles[i].maxWidth;
if (width > 0) {
EXPECT_LE(width, prevWidth) << "Level " << i;
prevWidth = width;
}
}
}
TEST(QualityProfileTest, BitRateDecreasing) {
// H264 bitRate 应该随质量降低而减少
int prevBitRate = INT_MAX;
for (int i = 0; i < QUALITY_COUNT; i++) {
if (g_QualityProfiles[i].algorithm == ALGORITHM_H264) {
int bitRate = g_QualityProfiles[i].bitRate;
EXPECT_LE(bitRate, prevBitRate) << "Level " << i;
prevBitRate = bitRate;
}
}
}

View File

@@ -0,0 +1,597 @@
// RGB565Test.cpp - Phase 4: RGB565压缩单元测试
// 测试 BGRA <-> RGB565 颜色空间转换
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <random>
#include <cmath>
// ============================================
// RGB565 颜色空间转换实现
// 来源: client/ScreenCapture.h, server/ScreenSpyDlg.cpp
// ============================================
// RGB565 格式说明:
// 16位: RRRRRGGG GGGBBBBB
// R: 5位 (0-31)
// G: 6位 (0-63)
// B: 5位 (0-31)
class RGB565Converter {
public:
// ============================================
// BGRA -> RGB565 转换 (标量版本)
// ============================================
static void ConvertBGRAtoRGB565_Scalar(
const uint8_t* src,
uint16_t* dst,
size_t pixelCount
) {
for (size_t i = 0; i < pixelCount; i++) {
// BGRA 格式: B=0, G=1, R=2, A=3
uint8_t b = src[i * 4 + 0];
uint8_t g = src[i * 4 + 1];
uint8_t r = src[i * 4 + 2];
// A 通道被忽略
// RGB565: RRRRRGGG GGGBBBBB
dst[i] = static_cast<uint16_t>(
((r >> 3) << 11) | // R: 高5位 -> 位11-15
((g >> 2) << 5) | // G: 高6位 -> 位5-10
(b >> 3) // B: 高5位 -> 位0-4
);
}
}
// ============================================
// RGB565 -> BGRA 转换
// 位复制填充低位以提高精度
// ============================================
static void ConvertRGB565ToBGRA(
const uint16_t* src,
uint8_t* dst,
size_t pixelCount
) {
for (size_t i = 0; i < pixelCount; i++) {
uint16_t c = src[i];
// 提取各通道
uint8_t r5 = (c >> 11) & 0x1F; // 5位红色
uint8_t g6 = (c >> 5) & 0x3F; // 6位绿色
uint8_t b5 = c & 0x1F; // 5位蓝色
// 扩展到8位使用位复制填充低位
// 例如: 5位 11111 -> 8位 11111111 (通过 (x << 3) | (x >> 2))
dst[i * 4 + 0] = (b5 << 3) | (b5 >> 2); // B: 5->8位
dst[i * 4 + 1] = (g6 << 2) | (g6 >> 4); // G: 6->8位
dst[i * 4 + 2] = (r5 << 3) | (r5 >> 2); // R: 5->8位
dst[i * 4 + 3] = 0xFF; // A: 不透明
}
}
// ============================================
// 单像素转换辅助函数
// ============================================
static uint16_t BGRAToRGB565(uint8_t b, uint8_t g, uint8_t r) {
return static_cast<uint16_t>(
((r >> 3) << 11) |
((g >> 2) << 5) |
(b >> 3)
);
}
static void RGB565ToBGRA(uint16_t c, uint8_t& b, uint8_t& g, uint8_t& r, uint8_t& a) {
uint8_t r5 = (c >> 11) & 0x1F;
uint8_t g6 = (c >> 5) & 0x3F;
uint8_t b5 = c & 0x1F;
b = (b5 << 3) | (b5 >> 2);
g = (g6 << 2) | (g6 >> 4);
r = (r5 << 3) | (r5 >> 2);
a = 0xFF;
}
// ============================================
// 计算量化误差
// ============================================
static int CalculateError(uint8_t original, uint8_t converted) {
return std::abs(static_cast<int>(original) - static_cast<int>(converted));
}
// 计算最大理论误差
// 5位通道: 量化+位填充可能产生最大误差 7
// 6位通道: 量化+位填充可能产生最大误差 3
static int MaxError5Bit() { return 7; }
static int MaxError6Bit() { return 3; }
};
// ============================================
// 测试夹具
// ============================================
class RGB565Test : public ::testing::Test {
protected:
// 创建BGRA像素数组
static std::vector<uint8_t> CreateBGRAPixels(size_t count, uint8_t b, uint8_t g, uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> pixels(count * 4);
for (size_t i = 0; i < count; i++) {
pixels[i * 4 + 0] = b;
pixels[i * 4 + 1] = g;
pixels[i * 4 + 2] = r;
pixels[i * 4 + 3] = a;
}
return pixels;
}
};
// ============================================
// 基础转换测试
// ============================================
TEST_F(RGB565Test, SinglePixel_Black) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 0);
EXPECT_EQ(result, 0x0000);
}
TEST_F(RGB565Test, SinglePixel_White) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 255, 255);
// R=31<<11=0xF800, G=63<<5=0x07E0, B=31=0x001F
// 合计: 0xFFFF
EXPECT_EQ(result, 0xFFFF);
}
TEST_F(RGB565Test, SinglePixel_Red) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 255);
// R=31<<11=0xF800, G=0, B=0
EXPECT_EQ(result, 0xF800);
}
TEST_F(RGB565Test, SinglePixel_Green) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 255, 0);
// R=0, G=63<<5=0x07E0, B=0
EXPECT_EQ(result, 0x07E0);
}
TEST_F(RGB565Test, SinglePixel_Blue) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 0, 0);
// R=0, G=0, B=31=0x001F
EXPECT_EQ(result, 0x001F);
}
// ============================================
// 反向转换测试
// ============================================
TEST_F(RGB565Test, Reverse_Black) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x0000, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_White) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0xFFFF, b, g, r, a);
EXPECT_EQ(b, 255);
EXPECT_EQ(g, 255);
EXPECT_EQ(r, 255);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Red) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0xF800, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 255);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Green) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x07E0, b, g, r, a);
EXPECT_EQ(b, 0);
EXPECT_EQ(g, 255);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
TEST_F(RGB565Test, Reverse_Blue) {
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(0x001F, b, g, r, a);
EXPECT_EQ(b, 255);
EXPECT_EQ(g, 0);
EXPECT_EQ(r, 0);
EXPECT_EQ(a, 255);
}
// ============================================
// 往返转换测试
// ============================================
TEST_F(RGB565Test, RoundTrip_ExactColors) {
// 测试能精确表示的颜色 (只有 0 和 255 能完美往返)
// 位填充公式: (x << 3) | (x >> 2) 只有 x=0 -> 0, x=31 -> 255 是精确的
struct TestCase {
uint8_t b, g, r;
} testCases[] = {
{0, 0, 0}, // 黑色 - 精确
{255, 255, 255}, // 白色 - 精确
{0, 0, 255}, // 红色 - 精确
{0, 255, 0}, // 绿色 - 精确
{255, 0, 0}, // 蓝色 - 精确
{255, 255, 0}, // 青色 - 精确
{0, 255, 255}, // 黄色 - 精确
{255, 0, 255}, // 品红 - 精确
};
for (const auto& tc : testCases) {
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(tc.b, tc.g, tc.r);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
// 能精确表示的颜色应该完美还原
EXPECT_EQ(b, tc.b) << "B channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
EXPECT_EQ(g, tc.g) << "G channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
EXPECT_EQ(r, tc.r) << "R channel mismatch for (" << (int)tc.r << "," << (int)tc.g << "," << (int)tc.b << ")";
}
}
TEST_F(RGB565Test, RoundTrip_QuantizationError) {
// 测试所有颜色的量化误差在允许范围内
std::mt19937 rng(12345);
std::uniform_int_distribution<int> dist(0, 255);
for (int i = 0; i < 1000; i++) {
uint8_t origB = dist(rng);
uint8_t origG = dist(rng);
uint8_t origR = dist(rng);
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(origB, origG, origR);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
// 验证误差在允许范围内
EXPECT_LE(RGB565Converter::CalculateError(origB, b), RGB565Converter::MaxError5Bit());
EXPECT_LE(RGB565Converter::CalculateError(origG, g), RGB565Converter::MaxError6Bit());
EXPECT_LE(RGB565Converter::CalculateError(origR, r), RGB565Converter::MaxError5Bit());
}
}
// ============================================
// 批量转换测试
// ============================================
TEST_F(RGB565Test, BatchConvert_SinglePixel) {
auto bgra = CreateBGRAPixels(1, 128, 64, 192);
std::vector<uint16_t> rgb565(1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), 1);
// 验证转换结果
uint16_t expected = RGB565Converter::BGRAToRGB565(128, 64, 192);
EXPECT_EQ(rgb565[0], expected);
}
TEST_F(RGB565Test, BatchConvert_MultiplePixels) {
const size_t COUNT = 100;
auto bgra = CreateBGRAPixels(COUNT, 100, 150, 200);
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), COUNT);
uint16_t expected = RGB565Converter::BGRAToRGB565(100, 150, 200);
for (size_t i = 0; i < COUNT; i++) {
EXPECT_EQ(rgb565[i], expected);
}
}
TEST_F(RGB565Test, BatchReverse_MultiplePixels) {
const size_t COUNT = 100;
std::vector<uint16_t> rgb565(COUNT, 0xF800); // 红色
std::vector<uint8_t> bgra(COUNT * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), COUNT);
for (size_t i = 0; i < COUNT; i++) {
EXPECT_EQ(bgra[i * 4 + 0], 0); // B
EXPECT_EQ(bgra[i * 4 + 1], 0); // G
EXPECT_EQ(bgra[i * 4 + 2], 255); // R
EXPECT_EQ(bgra[i * 4 + 3], 255); // A
}
}
TEST_F(RGB565Test, BatchRoundTrip) {
const size_t COUNT = 1000;
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
// 创建随机BGRA数据
std::vector<uint8_t> original(COUNT * 4);
for (size_t i = 0; i < COUNT * 4; i++) {
original[i] = (i % 4 == 3) ? 255 : dist(rng);
}
// 转换到 RGB565
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(original.data(), rgb565.data(), COUNT);
// 转换回 BGRA
std::vector<uint8_t> reconstructed(COUNT * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), reconstructed.data(), COUNT);
// 验证所有像素误差在允许范围内
for (size_t i = 0; i < COUNT; i++) {
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 0], reconstructed[i * 4 + 0]),
RGB565Converter::MaxError5Bit()) << "Pixel " << i << " B";
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 1], reconstructed[i * 4 + 1]),
RGB565Converter::MaxError6Bit()) << "Pixel " << i << " G";
EXPECT_LE(RGB565Converter::CalculateError(original[i * 4 + 2], reconstructed[i * 4 + 2]),
RGB565Converter::MaxError5Bit()) << "Pixel " << i << " R";
EXPECT_EQ(reconstructed[i * 4 + 3], 255) << "Pixel " << i << " A";
}
}
// ============================================
// 边界值测试
// ============================================
TEST_F(RGB565Test, Boundary_AllZeros) {
uint16_t result = RGB565Converter::BGRAToRGB565(0, 0, 0);
EXPECT_EQ(result, 0);
}
TEST_F(RGB565Test, Boundary_AllOnes) {
uint16_t result = RGB565Converter::BGRAToRGB565(255, 255, 255);
EXPECT_EQ(result, 0xFFFF);
}
TEST_F(RGB565Test, Boundary_ChannelMax) {
// 单通道最大值
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 0, 0), 0x001F); // B
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 255, 0), 0x07E0); // G
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 0, 255), 0xF800); // R
}
TEST_F(RGB565Test, Boundary_ChannelMin) {
// 单通道最小值(其他为最大)
EXPECT_EQ(RGB565Converter::BGRAToRGB565(0, 255, 255), 0xFFE0); // 无B
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 0, 255), 0xF81F); // 无G
EXPECT_EQ(RGB565Converter::BGRAToRGB565(255, 255, 0), 0x07FF); // 无R
}
// ============================================
// 位填充测试
// ============================================
TEST_F(RGB565Test, BitFilling_5BitExpansion) {
// 测试5位扩展到8位的位填充策略
// 5位值 11111 (31) -> 8位值 11111111 (255)
// 公式: (x << 3) | (x >> 2)
// 最大值: 31 -> 255
uint8_t expanded = (31 << 3) | (31 >> 2);
EXPECT_EQ(expanded, 255);
// 最小值: 0 -> 0
expanded = (0 << 3) | (0 >> 2);
EXPECT_EQ(expanded, 0);
// 中间值: 16 -> 132 (10000 -> 10000100)
expanded = (16 << 3) | (16 >> 2);
EXPECT_EQ(expanded, 132);
}
TEST_F(RGB565Test, BitFilling_6BitExpansion) {
// 测试6位扩展到8位的位填充策略
// 6位值 111111 (63) -> 8位值 11111111 (255)
// 公式: (x << 2) | (x >> 4)
// 最大值: 63 -> 255
uint8_t expanded = (63 << 2) | (63 >> 4);
EXPECT_EQ(expanded, 255);
// 最小值: 0 -> 0
expanded = (0 << 2) | (0 >> 4);
EXPECT_EQ(expanded, 0);
// 中间值: 32 -> 130 (100000 -> 10000010)
expanded = (32 << 2) | (32 >> 4);
EXPECT_EQ(expanded, 130);
}
// ============================================
// Alpha通道测试
// ============================================
TEST_F(RGB565Test, Alpha_Ignored) {
// 不同alpha值应该产生相同的RGB565
auto bgra1 = CreateBGRAPixels(1, 100, 100, 100, 255);
auto bgra2 = CreateBGRAPixels(1, 100, 100, 100, 128);
auto bgra3 = CreateBGRAPixels(1, 100, 100, 100, 0);
std::vector<uint16_t> rgb565_1(1), rgb565_2(1), rgb565_3(1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra1.data(), rgb565_1.data(), 1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra2.data(), rgb565_2.data(), 1);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra3.data(), rgb565_3.data(), 1);
EXPECT_EQ(rgb565_1[0], rgb565_2[0]);
EXPECT_EQ(rgb565_2[0], rgb565_3[0]);
}
TEST_F(RGB565Test, Alpha_RestoredToOpaque) {
// 反向转换时Alpha应该恢复为255
std::vector<uint16_t> rgb565 = {0x1234, 0x5678, 0x9ABC};
std::vector<uint8_t> bgra(3 * 4);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), 3);
for (size_t i = 0; i < 3; i++) {
EXPECT_EQ(bgra[i * 4 + 3], 255);
}
}
// ============================================
// 压缩率测试
// ============================================
TEST_F(RGB565Test, CompressionRatio) {
// BGRA: 4字节/像素
// RGB565: 2字节/像素
// 压缩率: 50%
const size_t PIXEL_COUNT = 1920 * 1080;
size_t bgraSize = PIXEL_COUNT * 4;
size_t rgb565Size = PIXEL_COUNT * 2;
EXPECT_EQ(rgb565Size, bgraSize / 2);
EXPECT_DOUBLE_EQ(static_cast<double>(rgb565Size) / bgraSize, 0.5);
}
// ============================================
// 特殊颜色测试
// ============================================
TEST_F(RGB565Test, CommonColors) {
struct ColorTest {
const char* name;
uint8_t r, g, b;
uint16_t expectedRGB565;
} colors[] = {
{"Black", 0, 0, 0, 0x0000},
{"White", 255, 255, 255, 0xFFFF},
{"Red", 255, 0, 0, 0xF800},
{"Green", 0, 255, 0, 0x07E0},
{"Blue", 0, 0, 255, 0x001F},
{"Yellow", 255, 255, 0, 0xFFE0},
{"Cyan", 0, 255, 255, 0x07FF},
{"Magenta", 255, 0, 255, 0xF81F},
};
for (const auto& c : colors) {
uint16_t result = RGB565Converter::BGRAToRGB565(c.b, c.g, c.r);
EXPECT_EQ(result, c.expectedRGB565) << "Color: " << c.name;
}
}
TEST_F(RGB565Test, GrayScales) {
// 测试灰度值转换
for (int gray = 0; gray <= 255; gray += 17) {
uint8_t g = static_cast<uint8_t>(gray);
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(g, g, g);
uint8_t b, gr, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, gr, r, a);
// 灰度值在量化误差范围内应该保持一致
EXPECT_NEAR(b, g, 8);
EXPECT_NEAR(gr, g, 4);
EXPECT_NEAR(r, g, 8);
}
}
// ============================================
// 参数化测试
// ============================================
class RGB565ChannelTest : public ::testing::TestWithParam<int> {};
TEST_P(RGB565ChannelTest, ChannelValueRange) {
int value = GetParam();
uint8_t v = static_cast<uint8_t>(value);
// 测试R通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(0, 0, v);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, r), RGB565Converter::MaxError5Bit());
}
// 测试G通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(0, v, 0);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, g), RGB565Converter::MaxError6Bit());
}
// 测试B通道
{
uint16_t rgb565 = RGB565Converter::BGRAToRGB565(v, 0, 0);
uint8_t b, g, r, a;
RGB565Converter::RGB565ToBGRA(rgb565, b, g, r, a);
EXPECT_LE(RGB565Converter::CalculateError(v, b), RGB565Converter::MaxError5Bit());
}
}
INSTANTIATE_TEST_SUITE_P(
AllValues,
RGB565ChannelTest,
::testing::Range(0, 256, 1) // 测试0-255所有值
);
// ============================================
// 大数据量测试
// ============================================
TEST_F(RGB565Test, LargeFrame_1080p) {
const size_t WIDTH = 1920, HEIGHT = 1080;
const size_t COUNT = WIDTH * HEIGHT;
// 创建渐变图像
std::vector<uint8_t> bgra(COUNT * 4);
for (size_t y = 0; y < HEIGHT; y++) {
for (size_t x = 0; x < WIDTH; x++) {
size_t idx = (y * WIDTH + x) * 4;
bgra[idx + 0] = static_cast<uint8_t>(x * 255 / WIDTH);
bgra[idx + 1] = static_cast<uint8_t>(y * 255 / HEIGHT);
bgra[idx + 2] = static_cast<uint8_t>((x + y) * 127 / (WIDTH + HEIGHT));
bgra[idx + 3] = 255;
}
}
// 转换
std::vector<uint16_t> rgb565(COUNT);
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), COUNT);
// 验证大小
EXPECT_EQ(rgb565.size() * sizeof(uint16_t), COUNT * 2);
// 抽样验证
for (size_t i = 0; i < COUNT; i += COUNT / 100) {
uint16_t expected = RGB565Converter::BGRAToRGB565(bgra[i * 4 + 0], bgra[i * 4 + 1], bgra[i * 4 + 2]);
EXPECT_EQ(rgb565[i], expected) << "Mismatch at pixel " << i;
}
}
// ============================================
// 端序测试
// ============================================
TEST_F(RGB565Test, Endianness_LittleEndian) {
// RGB565 应该以小端存储
uint16_t rgb565 = 0x1234;
uint8_t* bytes = reinterpret_cast<uint8_t*>(&rgb565);
// 在小端系统上: bytes[0] = 0x34, bytes[1] = 0x12
EXPECT_EQ(bytes[0], 0x34);
EXPECT_EQ(bytes[1], 0x12);
}
// ============================================
// 错误处理测试
// ============================================
TEST_F(RGB565Test, ZeroPixelCount) {
std::vector<uint8_t> bgra(4);
std::vector<uint16_t> rgb565(1);
// 零像素不应崩溃
RGB565Converter::ConvertBGRAtoRGB565_Scalar(bgra.data(), rgb565.data(), 0);
RGB565Converter::ConvertRGB565ToBGRA(rgb565.data(), bgra.data(), 0);
}

View File

@@ -0,0 +1,686 @@
// ScrollDetectorTest.cpp - Phase 4: 滚动检测单元测试
// 测试屏幕滚动检测和优化传输
#include <gtest/gtest.h>
#include <vector>
#include <cstdint>
#include <cstring>
#include <random>
#include <algorithm>
// ============================================
// 滚动检测常量定义 (来自 ScrollDetector.h)
// ============================================
#define MIN_SCROLL_LINES 16 // 最小滚动行数
#define MAX_SCROLL_RATIO 4 // 最大滚动 = 高度 / 4
#define MATCH_THRESHOLD 85 // 行匹配百分比阈值 (85%)
// 滚动方向常量
#define SCROLL_DIR_UP 0 // 向上滚动(内容向下移动)
#define SCROLL_DIR_DOWN 1 // 向下滚动(内容向上移动)
// ============================================
// CRC32 哈希计算
// ============================================
class CRC32 {
public:
static uint32_t Calculate(const uint8_t* data, size_t length) {
static uint32_t table[256] = {0};
static bool tableInit = false;
if (!tableInit) {
for (uint32_t i = 0; i < 256; i++) {
uint32_t c = i;
for (int j = 0; j < 8; j++) {
c = (c & 1) ? (0xEDB88320 ^ (c >> 1)) : (c >> 1);
}
table[i] = c;
}
tableInit = true;
}
uint32_t crc = 0xFFFFFFFF;
for (size_t i = 0; i < length; i++) {
crc = table[(crc ^ data[i]) & 0xFF] ^ (crc >> 8);
}
return crc ^ 0xFFFFFFFF;
}
};
// ============================================
// 滚动检测器实现 (模拟 ScrollDetector.h)
// ============================================
class CScrollDetector {
public:
CScrollDetector(int width, int height, int bpp = 4)
: m_width(width), m_height(height), m_bpp(bpp)
, m_stride(width * bpp)
, m_minScroll(MIN_SCROLL_LINES)
, m_maxScroll(height / MAX_SCROLL_RATIO)
{
m_rowHashes.resize(height);
}
// 检测垂直滚动
// 返回: >0 向下滚动, <0 向上滚动, 0 无滚动
int DetectVerticalScroll(const uint8_t* prevFrame, const uint8_t* currFrame) {
if (!prevFrame || !currFrame) return 0;
// 计算当前帧的行哈希
std::vector<uint32_t> currHashes(m_height);
for (int y = 0; y < m_height; y++) {
currHashes[y] = CRC32::Calculate(currFrame + y * m_stride, m_stride);
}
// 计算前一帧的行哈希
std::vector<uint32_t> prevHashes(m_height);
for (int y = 0; y < m_height; y++) {
prevHashes[y] = CRC32::Calculate(prevFrame + y * m_stride, m_stride);
}
int bestScroll = 0;
int bestMatchCount = 0;
// 尝试各种滚动量
for (int scroll = m_minScroll; scroll <= m_maxScroll; scroll++) {
// 向下滚动 (正值)
int matchDown = CountMatchingRows(prevHashes, currHashes, scroll);
if (matchDown > bestMatchCount) {
bestMatchCount = matchDown;
bestScroll = scroll;
}
// 向上滚动 (负值)
int matchUp = CountMatchingRows(prevHashes, currHashes, -scroll);
if (matchUp > bestMatchCount) {
bestMatchCount = matchUp;
bestScroll = -scroll;
}
}
// 检查是否达到匹配阈值
int scrollAbs = std::abs(bestScroll);
int totalRows = m_height - scrollAbs;
int threshold = totalRows * MATCH_THRESHOLD / 100;
if (bestMatchCount >= threshold) {
return bestScroll;
}
return 0;
}
// 获取边缘区域信息
void GetEdgeRegion(int scrollAmount, int* outOffset, int* outPixelCount) const {
if (scrollAmount > 0) {
// 向下滚动: 新内容在底部 (BMP底上格式: 低地址)
*outOffset = 0;
*outPixelCount = scrollAmount * m_width;
} else if (scrollAmount < 0) {
// 向上滚动: 新内容在顶部 (BMP底上格式: 高地址)
*outOffset = (m_height + scrollAmount) * m_stride;
*outPixelCount = (-scrollAmount) * m_width;
} else {
*outOffset = 0;
*outPixelCount = 0;
}
}
int GetMinScroll() const { return m_minScroll; }
int GetMaxScroll() const { return m_maxScroll; }
int GetWidth() const { return m_width; }
int GetHeight() const { return m_height; }
private:
int CountMatchingRows(const std::vector<uint32_t>& prevHashes,
const std::vector<uint32_t>& currHashes,
int scroll) const {
int matchCount = 0;
if (scroll > 0) {
// 向下滚动: prev[y] 对应 curr[y + scroll]
for (int y = 0; y < m_height - scroll; y++) {
if (prevHashes[y] == currHashes[y + scroll]) {
matchCount++;
}
}
} else if (scroll < 0) {
// 向上滚动: prev[y] 对应 curr[y + scroll] (scroll为负)
int absScroll = -scroll;
for (int y = absScroll; y < m_height; y++) {
if (prevHashes[y] == currHashes[y + scroll]) {
matchCount++;
}
}
}
return matchCount;
}
int m_width;
int m_height;
int m_bpp;
int m_stride;
int m_minScroll;
int m_maxScroll;
std::vector<uint32_t> m_rowHashes;
};
// ============================================
// 测试夹具
// ============================================
class ScrollDetectorTest : public ::testing::Test {
public:
// 创建纯色帧
static std::vector<uint8_t> CreateSolidFrame(int width, int height,
uint8_t b, uint8_t g,
uint8_t r, uint8_t a = 0xFF) {
std::vector<uint8_t> frame(width * height * 4);
for (int i = 0; i < width * height; i++) {
frame[i * 4 + 0] = b;
frame[i * 4 + 1] = g;
frame[i * 4 + 2] = r;
frame[i * 4 + 3] = a;
}
return frame;
}
// 创建带条纹的帧 (每行不同颜色,便于检测滚动)
static std::vector<uint8_t> CreateStripedFrame(int width, int height) {
std::vector<uint8_t> frame(width * height * 4);
for (int y = 0; y < height; y++) {
uint8_t color = static_cast<uint8_t>(y % 256);
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
frame[idx + 0] = color;
frame[idx + 1] = color;
frame[idx + 2] = color;
frame[idx + 3] = 0xFF;
}
}
return frame;
}
// 模拟向下滚动 (内容向上移动)
static std::vector<uint8_t> SimulateScrollDown(const std::vector<uint8_t>& frame,
int width, int height,
int scrollAmount) {
std::vector<uint8_t> result(frame.size());
int stride = width * 4;
// 复制滚动后的内容
for (int y = scrollAmount; y < height; y++) {
memcpy(result.data() + (y - scrollAmount) * stride,
frame.data() + y * stride, stride);
}
// 底部新内容用不同颜色填充
for (int y = height - scrollAmount; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
result[idx + 0] = 0xFF; // 新内容用白色
result[idx + 1] = 0xFF;
result[idx + 2] = 0xFF;
result[idx + 3] = 0xFF;
}
}
return result;
}
// 模拟向上滚动 (内容向下移动)
static std::vector<uint8_t> SimulateScrollUp(const std::vector<uint8_t>& frame,
int width, int height,
int scrollAmount) {
std::vector<uint8_t> result(frame.size());
int stride = width * 4;
// 复制滚动后的内容
for (int y = 0; y < height - scrollAmount; y++) {
memcpy(result.data() + (y + scrollAmount) * stride,
frame.data() + y * stride, stride);
}
// 顶部新内容用不同颜色填充
for (int y = 0; y < scrollAmount; y++) {
for (int x = 0; x < width; x++) {
int idx = (y * width + x) * 4;
result[idx + 0] = 0x00; // 新内容用黑色
result[idx + 1] = 0x00;
result[idx + 2] = 0x00;
result[idx + 3] = 0xFF;
}
}
return result;
}
};
// ============================================
// 基础功能测试
// ============================================
TEST_F(ScrollDetectorTest, Constructor_ValidParameters) {
CScrollDetector detector(100, 200);
EXPECT_EQ(detector.GetWidth(), 100);
EXPECT_EQ(detector.GetHeight(), 200);
EXPECT_EQ(detector.GetMinScroll(), MIN_SCROLL_LINES);
EXPECT_EQ(detector.GetMaxScroll(), 200 / MAX_SCROLL_RATIO);
}
TEST_F(ScrollDetectorTest, IdenticalFrames_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame = CreateStripedFrame(WIDTH, HEIGHT);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame.data(), frame.data());
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, CompletelyDifferent_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateSolidFrame(WIDTH, HEIGHT, 0, 0, 0);
auto frame2 = CreateSolidFrame(WIDTH, HEIGHT, 255, 255, 255);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_EQ(scroll, 0);
}
// ============================================
// 向下滚动检测测试
// ============================================
TEST_F(ScrollDetectorTest, ScrollDown_MinimalScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, MIN_SCROLL_LINES);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 检测到滚动(绝对值匹配,方向取决于实现)
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), MIN_SCROLL_LINES);
}
TEST_F(ScrollDetectorTest, ScrollDown_MediumScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
TEST_F(ScrollDetectorTest, ScrollDown_MaxScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int MAX_SCROLL = HEIGHT / MAX_SCROLL_RATIO;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, MAX_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_LE(std::abs(scroll), MAX_SCROLL);
}
// ============================================
// 向上滚动检测测试
// ============================================
TEST_F(ScrollDetectorTest, ScrollUp_MinimalScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollUp(frame1, WIDTH, HEIGHT, MIN_SCROLL_LINES);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 检测到滚动(方向与 ScrollDown 相反)
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), MIN_SCROLL_LINES);
}
TEST_F(ScrollDetectorTest, ScrollUp_MediumScroll) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollUp(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
// ============================================
// 边缘区域测试
// ============================================
TEST_F(ScrollDetectorTest, EdgeRegion_ScrollDown) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(SCROLL, &offset, &pixelCount);
// 向下滚动: 新内容在底部
EXPECT_EQ(offset, 0);
EXPECT_EQ(pixelCount, SCROLL * WIDTH);
}
TEST_F(ScrollDetectorTest, EdgeRegion_ScrollUp) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(-SCROLL, &offset, &pixelCount);
// 向上滚动: 新内容在顶部
EXPECT_EQ(offset, (HEIGHT - SCROLL) * WIDTH * 4);
EXPECT_EQ(pixelCount, SCROLL * WIDTH);
}
TEST_F(ScrollDetectorTest, EdgeRegion_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(0, &offset, &pixelCount);
EXPECT_EQ(offset, 0);
EXPECT_EQ(pixelCount, 0);
}
// ============================================
// 边界条件测试
// ============================================
TEST_F(ScrollDetectorTest, NullInput_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
auto frame = CreateStripedFrame(WIDTH, HEIGHT);
CScrollDetector detector(WIDTH, HEIGHT);
EXPECT_EQ(detector.DetectVerticalScroll(nullptr, frame.data()), 0);
EXPECT_EQ(detector.DetectVerticalScroll(frame.data(), nullptr), 0);
EXPECT_EQ(detector.DetectVerticalScroll(nullptr, nullptr), 0);
}
TEST_F(ScrollDetectorTest, SmallScroll_BelowMinimum_NoDetection) {
const int WIDTH = 100, HEIGHT = 100;
const int SMALL_SCROLL = MIN_SCROLL_LINES - 1; // 低于最小滚动量
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SMALL_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 小于最小滚动量不应被检测
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, LargeScroll_AboveMaximum_NotDetected) {
const int WIDTH = 100, HEIGHT = 100;
const int MAX_SCROLL = HEIGHT / MAX_SCROLL_RATIO;
const int LARGE_SCROLL = MAX_SCROLL + 10; // 超过最大滚动量
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, LARGE_SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 超过最大滚动量可能不被检测或返回最大值
EXPECT_LE(std::abs(scroll), MAX_SCROLL);
}
// ============================================
// CRC32 哈希测试
// ============================================
TEST_F(ScrollDetectorTest, CRC32_EmptyData) {
uint32_t hash = CRC32::Calculate(nullptr, 0);
// CRC32 的空数据哈希值
EXPECT_EQ(hash, 0); // 实现相关
}
TEST_F(ScrollDetectorTest, CRC32_KnownVector) {
// "123456789" 的 CRC32 应该是 0xCBF43926
const char* testData = "123456789";
uint32_t hash = CRC32::Calculate(reinterpret_cast<const uint8_t*>(testData), 9);
EXPECT_EQ(hash, 0xCBF43926);
}
TEST_F(ScrollDetectorTest, CRC32_SameData_SameHash) {
std::vector<uint8_t> data1(100, 0x42);
std::vector<uint8_t> data2(100, 0x42);
uint32_t hash1 = CRC32::Calculate(data1.data(), data1.size());
uint32_t hash2 = CRC32::Calculate(data2.data(), data2.size());
EXPECT_EQ(hash1, hash2);
}
TEST_F(ScrollDetectorTest, CRC32_DifferentData_DifferentHash) {
std::vector<uint8_t> data1(100, 0x42);
std::vector<uint8_t> data2(100, 0x43);
uint32_t hash1 = CRC32::Calculate(data1.data(), data1.size());
uint32_t hash2 = CRC32::Calculate(data2.data(), data2.size());
EXPECT_NE(hash1, hash2);
}
// ============================================
// 性能相关测试(验证正确性)
// ============================================
TEST_F(ScrollDetectorTest, LargeFrame_720p) {
const int WIDTH = 1280, HEIGHT = 720;
const int SCROLL = 50;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
TEST_F(ScrollDetectorTest, LargeFrame_1080p) {
const int WIDTH = 1920, HEIGHT = 1080;
const int SCROLL = 100;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), SCROLL);
}
// ============================================
// 带宽节省计算测试
// ============================================
TEST_F(ScrollDetectorTest, BandwidthSaving_ScrollDetected) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
// 完整帧大小
size_t fullFrameSize = WIDTH * HEIGHT * 4;
// 边缘区域大小
size_t edgeSize = SCROLL * WIDTH * 4;
// 带宽节省
double saving = 1.0 - static_cast<double>(edgeSize) / fullFrameSize;
// 20行滚动应该节省约80%带宽
EXPECT_GT(saving, 0.7);
}
TEST_F(ScrollDetectorTest, BandwidthSaving_NoScroll) {
const int WIDTH = 100, HEIGHT = 100;
CScrollDetector detector(WIDTH, HEIGHT);
int offset, pixelCount;
detector.GetEdgeRegion(0, &offset, &pixelCount);
// 无滚动时没有边缘区域
EXPECT_EQ(pixelCount, 0);
}
// ============================================
// 参数化测试:不同滚动量
// ============================================
class ScrollAmountTest : public ::testing::TestWithParam<int> {};
TEST_P(ScrollAmountTest, DetectScrollAmount) {
const int WIDTH = 100, HEIGHT = 100;
int scrollAmount = GetParam();
if (scrollAmount < MIN_SCROLL_LINES || scrollAmount > HEIGHT / MAX_SCROLL_RATIO) {
GTEST_SKIP() << "Scroll amount out of valid range";
}
auto frame1 = ScrollDetectorTest::CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = ScrollDetectorTest::SimulateScrollDown(frame1, WIDTH, HEIGHT, scrollAmount);
CScrollDetector detector(WIDTH, HEIGHT);
int detected = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(detected, 0);
EXPECT_EQ(std::abs(detected), scrollAmount);
}
INSTANTIATE_TEST_SUITE_P(
ScrollAmounts,
ScrollAmountTest,
::testing::Values(16, 17, 18, 19, 20, 21, 22, 23, 24, 25)
);
// ============================================
// 匹配阈值测试
// ============================================
TEST_F(ScrollDetectorTest, MatchThreshold_HighNoise_NoDetection) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
// 添加大量噪声,使匹配率低于阈值
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
for (size_t i = 0; i < frame2.size(); i++) {
if (rng() % 2 == 0) { // 50% 的像素被随机化
frame2[i] = dist(rng);
}
}
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 高噪声情况下不应检测到滚动
EXPECT_EQ(scroll, 0);
}
TEST_F(ScrollDetectorTest, MatchThreshold_LowNoise_DetectionOK) {
const int WIDTH = 100, HEIGHT = 100;
const int SCROLL = 20;
auto frame1 = CreateStripedFrame(WIDTH, HEIGHT);
auto frame2 = SimulateScrollDown(frame1, WIDTH, HEIGHT, SCROLL);
// 只添加少量噪声10%
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
for (size_t i = 0; i < frame2.size(); i++) {
if (rng() % 10 == 0) { // 10% 的像素被随机化
frame2[i] = dist(rng);
}
}
CScrollDetector detector(WIDTH, HEIGHT);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
// 低噪声情况下仍应检测到滚动(取决于具体行噪声分布)
// 这个测试可能不稳定,因为噪声是随机的
// EXPECT_GT(scroll, 0) 或 EXPECT_EQ(scroll, 0) 取决于实际噪声分布
}
// ============================================
// 常量验证测试
// ============================================
TEST(ScrollConstantsTest, MinScrollLines) {
EXPECT_EQ(MIN_SCROLL_LINES, 16);
}
TEST(ScrollConstantsTest, MaxScrollRatio) {
EXPECT_EQ(MAX_SCROLL_RATIO, 4);
}
TEST(ScrollConstantsTest, MatchThreshold) {
EXPECT_EQ(MATCH_THRESHOLD, 85);
}
TEST(ScrollConstantsTest, ScrollDirections) {
EXPECT_EQ(SCROLL_DIR_UP, 0);
EXPECT_EQ(SCROLL_DIR_DOWN, 1);
}
// ============================================
// 分辨率参数化测试
// ============================================
class ScrollResolutionTest : public ::testing::TestWithParam<std::tuple<int, int>> {};
TEST_P(ScrollResolutionTest, DetectScrollAtResolution) {
auto [width, height] = GetParam();
int scrollAmount = std::max(MIN_SCROLL_LINES, height / 10);
if (scrollAmount > height / MAX_SCROLL_RATIO) {
scrollAmount = height / MAX_SCROLL_RATIO;
}
auto frame1 = ScrollDetectorTest::CreateStripedFrame(width, height);
auto frame2 = ScrollDetectorTest::SimulateScrollDown(frame1, width, height, scrollAmount);
CScrollDetector detector(width, height);
int scroll = detector.DetectVerticalScroll(frame1.data(), frame2.data());
EXPECT_NE(scroll, 0);
EXPECT_EQ(std::abs(scroll), scrollAmount);
}
INSTANTIATE_TEST_SUITE_P(
Resolutions,
ScrollResolutionTest,
::testing::Values(
std::make_tuple(640, 480), // VGA
std::make_tuple(800, 600), // SVGA
std::make_tuple(1024, 768), // XGA
std::make_tuple(1280, 720), // 720p
std::make_tuple(1920, 1080) // 1080p
)
);

View File

@@ -0,0 +1,912 @@
/**
* @file BufferTest.cpp
* @brief 服务端 CBuffer 类单元测试
*
* 测试覆盖:
* - 基本读写操作
* - 延迟读取偏移机制 (m_ulReadOffset)
* - 压缩/紧凑策略 (CompactBuffer)
* - 边界条件和下溢防护
* - 线程安全(并发读写)
* - 零拷贝写入接口
*/
#include <gtest/gtest.h>
#include <thread>
#include <atomic>
#include <chrono>
#include <vector>
#include <string>
// Windows 头文件
#ifdef _WIN32
#include <Windows.h>
#else
// Linux 模拟 Windows API
#include <cstring>
#include <cstdlib>
#include <pthread.h>
typedef unsigned char BYTE;
typedef BYTE* PBYTE;
typedef BYTE* LPBYTE;
typedef unsigned long ULONG;
typedef void VOID;
typedef int BOOL;
typedef void* PVOID;
#define TRUE 1
#define FALSE 0
#define MEM_COMMIT 0x1000
#define MEM_RELEASE 0x8000
#define PAGE_READWRITE 0x04
struct CRITICAL_SECTION {
pthread_mutex_t mutex;
};
inline void InitializeCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_init(&cs->mutex, NULL);
}
inline void DeleteCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_destroy(&cs->mutex);
}
inline void EnterCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_lock(&cs->mutex);
}
inline void LeaveCriticalSection(CRITICAL_SECTION* cs) {
pthread_mutex_unlock(&cs->mutex);
}
inline void* VirtualAlloc(void*, size_t size, int, int) {
return malloc(size);
}
inline void VirtualFree(void* ptr, size_t, int) {
free(ptr);
}
inline void CopyMemory(void* dst, const void* src, size_t len) {
memcpy(dst, src, len);
}
inline void MoveMemory(void* dst, const void* src, size_t len) {
memmove(dst, src, len);
}
#endif
#include <cmath>
// 服务端 Buffer 实现(测试专用内联版本)
namespace ServerBuffer {
#define U_PAGE_ALIGNMENT 4096
#define F_PAGE_ALIGNMENT 4096.0
#define COMPACT_THRESHOLD 0.5
// 简化的 Buffer 类(用于 GetMyBuffer 返回)
class Buffer {
private:
PBYTE buf;
ULONG len;
public:
Buffer() : buf(NULL), len(0) {}
Buffer(const BYTE* b, ULONG n) : len(n) {
if (n > 0 && b) {
buf = new BYTE[n];
memcpy(buf, b, n);
} else {
buf = NULL;
}
}
~Buffer() {
if (buf) {
delete[] buf;
buf = NULL;
}
}
Buffer(const Buffer& o) : len(o.len) {
if (o.buf && o.len > 0) {
buf = new BYTE[o.len];
memcpy(buf, o.buf, o.len);
} else {
buf = NULL;
}
}
ULONG length() const { return len; }
LPBYTE GetBuffer(int idx = 0) const {
return (idx >= (int)len) ? NULL : buf + idx;
}
};
class CBuffer {
public:
CBuffer() : m_ulMaxLength(0), m_ulReadOffset(0), m_Base(NULL), m_Ptr(NULL) {
InitializeCriticalSection(&m_cs);
}
~CBuffer() {
if (m_Base) {
VirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NULL;
}
DeleteCriticalSection(&m_cs);
m_Base = m_Ptr = NULL;
m_ulMaxLength = 0;
m_ulReadOffset = 0;
}
ULONG RemoveCompletedBuffer(ULONG ulLength) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulLength > effectiveDataLen) {
ulLength = effectiveDataLen;
}
if (ulLength) {
m_ulReadOffset += ulLength;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
}
LeaveCriticalSection(&m_cs);
return ulLength;
}
VOID CompactBuffer() {
if (m_ulReadOffset > 0 && m_Base) {
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG remainingData = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (remainingData > 0) {
MoveMemory(m_Base, m_Base + m_ulReadOffset, remainingData);
}
m_Ptr = m_Base + remainingData;
m_ulReadOffset = 0;
DeAllocateBuffer(remainingData);
}
}
ULONG ReadBuffer(PBYTE Buffer, ULONG ulLength) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulLength > effectiveDataLen) {
ulLength = effectiveDataLen;
}
if (ulLength) {
CopyMemory(Buffer, m_Base + m_ulReadOffset, ulLength);
m_ulReadOffset += ulLength;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
}
LeaveCriticalSection(&m_cs);
return ulLength;
}
ULONG DeAllocateBuffer(ULONG ulLength) {
if (ulLength < (ULONG)(m_Ptr - m_Base))
return 0;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
if (m_ulMaxLength <= ulNewMaxLength) {
return 0;
}
PBYTE NewBase = (PBYTE)VirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
ULONG ulv1 = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, ulv1);
VirtualFree(m_Base, 0, MEM_RELEASE);
m_Base = NewBase;
m_Ptr = m_Base + ulv1;
m_ulMaxLength = ulNewMaxLength;
return m_ulMaxLength;
}
BOOL WriteBuffer(PBYTE Buffer, ULONG ulLength) {
EnterCriticalSection(&m_cs);
if (ReAllocateBuffer(ulLength + (ULONG)(m_Ptr - m_Base)) == (ULONG)-1) {
LeaveCriticalSection(&m_cs);
return FALSE;
}
CopyMemory(m_Ptr, Buffer, ulLength);
m_Ptr += ulLength;
LeaveCriticalSection(&m_cs);
return TRUE;
}
ULONG ReAllocateBuffer(ULONG ulLength) {
if (ulLength < m_ulMaxLength)
return 0;
ULONG ulNewMaxLength = (ULONG)(ceil(ulLength / F_PAGE_ALIGNMENT) * U_PAGE_ALIGNMENT);
PBYTE NewBase = (PBYTE)VirtualAlloc(NULL, ulNewMaxLength, MEM_COMMIT, PAGE_READWRITE);
if (NewBase == NULL) {
return (ULONG)-1;
}
ULONG ulv1 = (ULONG)(m_Ptr - m_Base);
CopyMemory(NewBase, m_Base, ulv1);
if (m_Base) {
VirtualFree(m_Base, 0, MEM_RELEASE);
}
m_Base = NewBase;
m_Ptr = m_Base + ulv1;
m_ulMaxLength = ulNewMaxLength;
return m_ulMaxLength;
}
VOID ClearBuffer() {
EnterCriticalSection(&m_cs);
m_Ptr = m_Base;
m_ulReadOffset = 0;
DeAllocateBuffer(1024);
LeaveCriticalSection(&m_cs);
}
ULONG GetBufferLength() {
EnterCriticalSection(&m_cs);
if (m_Base == NULL) {
LeaveCriticalSection(&m_cs);
return 0;
}
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG len = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
LeaveCriticalSection(&m_cs);
return len;
}
std::string Skip(ULONG ulPos) {
if (ulPos == 0)
return "";
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (ulPos > effectiveDataLen) {
ulPos = effectiveDataLen;
}
std::string ret((char*)(m_Base + m_ulReadOffset), (char*)(m_Base + m_ulReadOffset + ulPos));
m_ulReadOffset += ulPos;
if (m_ulReadOffset > (ULONG)(m_ulMaxLength * COMPACT_THRESHOLD)) {
CompactBuffer();
}
LeaveCriticalSection(&m_cs);
return ret;
}
LPBYTE GetBuffer(ULONG ulPos = 0) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return NULL;
}
LPBYTE result = m_Base + m_ulReadOffset + ulPos;
LeaveCriticalSection(&m_cs);
return result;
}
Buffer GetMyBuffer(ULONG ulPos = 0) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return Buffer();
}
Buffer result(m_Base + m_ulReadOffset + ulPos, effectiveDataLen - ulPos);
LeaveCriticalSection(&m_cs);
return result;
}
BYTE GetBYTE(ULONG ulPos) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || ulPos >= effectiveDataLen) {
LeaveCriticalSection(&m_cs);
return 0;
}
BYTE p = *(m_Base + m_ulReadOffset + ulPos);
LeaveCriticalSection(&m_cs);
return p;
}
BOOL CopyBuffer(PVOID pDst, ULONG nLen, ULONG ulPos) {
EnterCriticalSection(&m_cs);
ULONG totalDataLen = (ULONG)(m_Ptr - m_Base);
ULONG effectiveDataLen = (totalDataLen > m_ulReadOffset) ? (totalDataLen - m_ulReadOffset) : 0;
if (m_Base == NULL || pDst == NULL || ulPos >= effectiveDataLen || (effectiveDataLen - ulPos) < nLen) {
LeaveCriticalSection(&m_cs);
return FALSE;
}
memcpy(pDst, m_Base + m_ulReadOffset + ulPos, nLen);
LeaveCriticalSection(&m_cs);
return TRUE;
}
LPBYTE GetWriteBuffer(ULONG requiredSize, ULONG& availableSize) {
EnterCriticalSection(&m_cs);
if (m_ulReadOffset > 0) {
CompactBuffer();
}
ULONG currentDataLen = (ULONG)(m_Ptr - m_Base);
if (ReAllocateBuffer(currentDataLen + requiredSize) == (ULONG)-1) {
LeaveCriticalSection(&m_cs);
availableSize = 0;
return NULL;
}
availableSize = m_ulMaxLength - currentDataLen;
LPBYTE result = m_Ptr;
LeaveCriticalSection(&m_cs);
return result;
}
VOID CommitWrite(ULONG writtenSize) {
EnterCriticalSection(&m_cs);
m_Ptr += writtenSize;
LeaveCriticalSection(&m_cs);
}
// 测试辅助:获取内部状态
ULONG GetReadOffset() const { return m_ulReadOffset; }
ULONG GetMaxLength() const { return m_ulMaxLength; }
protected:
PBYTE m_Base;
PBYTE m_Ptr;
ULONG m_ulMaxLength;
ULONG m_ulReadOffset;
CRITICAL_SECTION m_cs;
};
} // namespace ServerBuffer
using ServerBuffer::CBuffer;
using ServerBuffer::Buffer;
// ============================================
// 测试夹具
// ============================================
class ServerBufferTest : public ::testing::Test {
protected:
CBuffer buffer;
void SetUp() override {}
void TearDown() override {}
void WriteFillData(ULONG length, BYTE fillValue = 0x42) {
std::vector<BYTE> data(length, fillValue);
buffer.WriteBuffer(data.data(), length);
}
};
// ============================================
// 构造/析构测试
// ============================================
TEST_F(ServerBufferTest, Constructor_InitializesEmpty) {
CBuffer newBuffer;
EXPECT_EQ(newBuffer.GetBufferLength(), 0u);
EXPECT_EQ(newBuffer.GetBuffer(), nullptr);
}
// ============================================
// WriteBuffer 测试
// ============================================
TEST_F(ServerBufferTest, WriteBuffer_ValidData_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
EXPECT_TRUE(buffer.WriteBuffer(data, 5));
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ServerBufferTest, WriteBuffer_MultipleWrites_AccumulatesData) {
BYTE data1[] = {1, 2, 3};
BYTE data2[] = {4, 5};
buffer.WriteBuffer(data1, 3);
buffer.WriteBuffer(data2, 2);
EXPECT_EQ(buffer.GetBufferLength(), 5u);
}
TEST_F(ServerBufferTest, WriteBuffer_LargeData_HandlesCorrectly) {
const ULONG largeSize = 100000;
std::vector<BYTE> data(largeSize, 0xAB);
EXPECT_TRUE(buffer.WriteBuffer(data.data(), largeSize));
EXPECT_EQ(buffer.GetBufferLength(), largeSize);
}
// ============================================
// ReadBuffer 测试
// ============================================
TEST_F(ServerBufferTest, ReadBuffer_EmptyBuffer_ReturnsZero) {
BYTE result[10];
EXPECT_EQ(buffer.ReadBuffer(result, 10), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_ExactLength_ReturnsAll) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[5];
ULONG bytesRead = buffer.ReadBuffer(result, 5);
EXPECT_EQ(bytesRead, 5u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_PartialRead_UsesReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE result[2];
buffer.ReadBuffer(result, 2);
// 使用延迟偏移,不立即移动数据
EXPECT_EQ(buffer.GetBufferLength(), 3u);
EXPECT_GT(buffer.GetReadOffset(), 0u);
}
TEST_F(ServerBufferTest, ReadBuffer_RequestExceedsAvailable_ReturnsAvailableOnly) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[10];
ULONG bytesRead = buffer.ReadBuffer(result, 10);
EXPECT_EQ(bytesRead, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// RemoveCompletedBuffer 测试
// ============================================
TEST_F(ServerBufferTest, RemoveCompletedBuffer_PartialRemove_UpdatesOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
ULONG removed = buffer.RemoveCompletedBuffer(2);
EXPECT_EQ(removed, 2u);
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
TEST_F(ServerBufferTest, RemoveCompletedBuffer_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
ULONG removed = buffer.RemoveCompletedBuffer(100);
EXPECT_EQ(removed, 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
// ============================================
// Skip 测试
// ============================================
TEST_F(ServerBufferTest, Skip_ReturnsSkippedData) {
BYTE data[] = {'H', 'e', 'l', 'l', 'o'};
buffer.WriteBuffer(data, 5);
std::string skipped = buffer.Skip(3);
EXPECT_EQ(skipped, "Hel");
EXPECT_EQ(buffer.GetBufferLength(), 2u);
}
TEST_F(ServerBufferTest, Skip_ExceedsLength_ClampsToAvailable) {
BYTE data[] = {'A', 'B', 'C'};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(100);
EXPECT_EQ(skipped, "ABC");
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, Skip_ZeroLength_ReturnsEmpty) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(0);
EXPECT_EQ(skipped, "");
EXPECT_EQ(buffer.GetBufferLength(), 3u);
}
// ============================================
// GetBuffer 测试
// ============================================
TEST_F(ServerBufferTest, GetBuffer_RespectsReadOffset) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
// 读取前两个字节,更新偏移
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
// GetBuffer(0) 应该返回第三个字节30
EXPECT_EQ(*buffer.GetBuffer(0), 30);
EXPECT_EQ(*buffer.GetBuffer(1), 40);
EXPECT_EQ(*buffer.GetBuffer(2), 50);
}
TEST_F(ServerBufferTest, GetBuffer_PositionExceedsEffectiveLength_ReturnsNull) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[3];
buffer.ReadBuffer(temp, 3); // 有效长度变为 2
EXPECT_NE(buffer.GetBuffer(0), nullptr);
EXPECT_NE(buffer.GetBuffer(1), nullptr);
EXPECT_EQ(buffer.GetBuffer(2), nullptr); // 超出有效范围
}
// ============================================
// GetMyBuffer 测试
// ============================================
TEST_F(ServerBufferTest, GetMyBuffer_ReturnsCorrectBuffer) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
Buffer buf = buffer.GetMyBuffer(0);
EXPECT_EQ(buf.length(), 5u);
}
TEST_F(ServerBufferTest, GetMyBuffer_RespectsReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
Buffer buf = buffer.GetMyBuffer(0);
EXPECT_EQ(buf.length(), 3u);
EXPECT_EQ(*buf.GetBuffer(0), 3);
}
// ============================================
// GetBYTE 测试
// ============================================
TEST_F(ServerBufferTest, GetBYTE_ReturnsCorrectByte) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
EXPECT_EQ(buffer.GetBYTE(0), 10);
EXPECT_EQ(buffer.GetBYTE(2), 30);
EXPECT_EQ(buffer.GetBYTE(4), 50);
}
TEST_F(ServerBufferTest, GetBYTE_RespectsReadOffset) {
BYTE data[] = {10, 20, 30, 40, 50};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
EXPECT_EQ(buffer.GetBYTE(0), 30);
EXPECT_EQ(buffer.GetBYTE(1), 40);
}
TEST_F(ServerBufferTest, GetBYTE_OutOfRange_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBYTE(100), 0);
}
// ============================================
// CopyBuffer 测试
// ============================================
TEST_F(ServerBufferTest, CopyBuffer_ValidRange_ReturnsTrue) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE dest[3];
EXPECT_TRUE(buffer.CopyBuffer(dest, 3, 1));
EXPECT_EQ(dest[0], 2);
EXPECT_EQ(dest[1], 3);
EXPECT_EQ(dest[2], 4);
}
TEST_F(ServerBufferTest, CopyBuffer_RespectsReadOffset) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2);
BYTE dest[2];
EXPECT_TRUE(buffer.CopyBuffer(dest, 2, 0));
EXPECT_EQ(dest[0], 3);
EXPECT_EQ(dest[1], 4);
}
TEST_F(ServerBufferTest, CopyBuffer_ExceedsRange_ReturnsFalse) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE dest[10];
EXPECT_FALSE(buffer.CopyBuffer(dest, 10, 0));
}
// ============================================
// 零拷贝写入测试
// ============================================
TEST_F(ServerBufferTest, GetWriteBuffer_ReturnsValidPointer) {
ULONG availableSize = 0;
LPBYTE writePtr = buffer.GetWriteBuffer(100, availableSize);
EXPECT_NE(writePtr, nullptr);
EXPECT_GE(availableSize, 100u);
}
TEST_F(ServerBufferTest, CommitWrite_UpdatesLength) {
ULONG availableSize = 0;
LPBYTE writePtr = buffer.GetWriteBuffer(100, availableSize);
// 直接写入
for (int i = 0; i < 50; i++) {
writePtr[i] = (BYTE)i;
}
buffer.CommitWrite(50);
EXPECT_EQ(buffer.GetBufferLength(), 50u);
EXPECT_EQ(buffer.GetBYTE(0), 0);
EXPECT_EQ(buffer.GetBYTE(49), 49);
}
// ============================================
// ClearBuffer 测试
// ============================================
TEST_F(ServerBufferTest, ClearBuffer_ResetsEverything) {
BYTE data[] = {1, 2, 3, 4, 5};
buffer.WriteBuffer(data, 5);
BYTE temp[2];
buffer.ReadBuffer(temp, 2); // 创建读取偏移
buffer.ClearBuffer();
EXPECT_EQ(buffer.GetBufferLength(), 0u);
EXPECT_EQ(buffer.GetReadOffset(), 0u);
}
// ============================================
// 下溢防护测试
// ============================================
TEST_F(ServerBufferTest, UnderflowProtection_ReadMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
BYTE result[1000];
ULONG bytesRead = buffer.ReadBuffer(result, ULONG_MAX - 1);
EXPECT_EQ(bytesRead, 3u);
}
TEST_F(ServerBufferTest, UnderflowProtection_SkipMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
std::string skipped = buffer.Skip(ULONG_MAX - 1);
EXPECT_EQ(skipped.length(), 3u);
EXPECT_EQ(buffer.GetBufferLength(), 0u);
}
TEST_F(ServerBufferTest, UnderflowProtection_RemoveMoreThanLength_NoUnderflow) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
ULONG removed = buffer.RemoveCompletedBuffer(ULONG_MAX - 1);
EXPECT_EQ(removed, 3u);
}
TEST_F(ServerBufferTest, UnderflowProtection_GetByteOutOfRange_ReturnsZero) {
BYTE data[] = {1, 2, 3};
buffer.WriteBuffer(data, 3);
EXPECT_EQ(buffer.GetBYTE(ULONG_MAX - 1), 0);
}
// ============================================
// 压缩策略测试
// ============================================
TEST_F(ServerBufferTest, Compaction_TriggersAtThreshold) {
// 写入足够数据
// m_ulMaxLength 会被页对齐到 ceil(10000/4096)*4096 = 12288
// 压缩阈值 = 12288 * 0.5 = 6144
WriteFillData(10000);
ULONG initialOffset = buffer.GetReadOffset();
EXPECT_EQ(initialOffset, 0u);
// 读取超过阈值的数据(需要 > 6144
std::vector<BYTE> temp(7000);
buffer.ReadBuffer(temp.data(), 7000);
// 压缩后偏移应该重置
EXPECT_EQ(buffer.GetReadOffset(), 0u);
EXPECT_EQ(buffer.GetBufferLength(), 3000u);
}
// ============================================
// 线程安全测试
// ============================================
TEST_F(ServerBufferTest, ThreadSafety_ConcurrentReadWrite_NoDataCorruption) {
std::atomic<bool> running{true};
std::atomic<int> writeCount{0};
std::atomic<int> readCount{0};
// 写线程
std::thread writer([&]() {
BYTE data[100];
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)i;
}
while (running) {
if (buffer.WriteBuffer(data, 100)) {
writeCount++;
}
std::this_thread::yield();
}
});
// 读线程
std::thread reader([&]() {
BYTE result[50];
while (running) {
if (buffer.ReadBuffer(result, 50) > 0) {
readCount++;
}
std::this_thread::yield();
}
});
// 运行一段时间
std::this_thread::sleep_for(std::chrono::milliseconds(100));
running = false;
writer.join();
reader.join();
// 验证无崩溃,且有数据交换
EXPECT_GT(writeCount.load(), 0);
EXPECT_GT(readCount.load(), 0);
}
TEST_F(ServerBufferTest, ThreadSafety_MultipleReaders_NoDeadlock) {
// 预填充数据
WriteFillData(10000);
std::atomic<bool> running{true};
std::vector<std::thread> readers;
for (int i = 0; i < 4; i++) {
readers.emplace_back([&]() {
while (running) {
buffer.GetBufferLength();
buffer.GetBYTE(0);
buffer.GetBuffer(0);
std::this_thread::yield();
}
});
}
std::this_thread::sleep_for(std::chrono::milliseconds(50));
running = false;
for (auto& t : readers) {
t.join();
}
// 无死锁即为成功
SUCCEED();
}
// ============================================
// 数据完整性测试
// ============================================
TEST_F(ServerBufferTest, DataIntegrity_WriteReadCycle_PreservesData) {
std::vector<BYTE> data(256);
for (int i = 0; i < 256; i++) {
data[i] = (BYTE)i;
}
buffer.WriteBuffer(data.data(), 256);
std::vector<BYTE> result(256);
ULONG bytesRead = buffer.ReadBuffer(result.data(), 256);
EXPECT_EQ(bytesRead, 256u);
for (int i = 0; i < 256; i++) {
EXPECT_EQ(result[i], data[i]) << "Mismatch at index " << i;
}
}
TEST_F(ServerBufferTest, DataIntegrity_PartialReads_PreservesSequence) {
// 写入 1-100
std::vector<BYTE> data(100);
for (int i = 0; i < 100; i++) {
data[i] = (BYTE)(i + 1);
}
buffer.WriteBuffer(data.data(), 100);
// 分多次读取
BYTE result[100];
ULONG totalRead = 0;
totalRead += buffer.ReadBuffer(result, 30);
totalRead += buffer.ReadBuffer(result + 30, 30);
totalRead += buffer.ReadBuffer(result + 60, 40);
EXPECT_EQ(totalRead, 100u);
for (int i = 0; i < 100; i++) {
EXPECT_EQ(result[i], (BYTE)(i + 1));
}
}
// ============================================
// 参数化测试
// ============================================
class ServerBufferParameterizedTest
: public ::testing::TestWithParam<std::tuple<size_t, size_t, size_t>> {
protected:
CBuffer buffer;
};
TEST_P(ServerBufferParameterizedTest, ReadBuffer_VariousLengths) {
auto [writeLen, readLen, expectedRead] = GetParam();
std::vector<BYTE> data(writeLen, 0x42);
if (writeLen > 0) {
buffer.WriteBuffer(data.data(), (ULONG)writeLen);
}
std::vector<BYTE> result(readLen > 0 ? readLen : 1);
ULONG actual = buffer.ReadBuffer(result.data(), (ULONG)readLen);
EXPECT_EQ(actual, expectedRead);
}
INSTANTIATE_TEST_SUITE_P(
ReadLengths,
ServerBufferParameterizedTest,
::testing::Values(
std::make_tuple(10, 5, 5),
std::make_tuple(5, 10, 5),
std::make_tuple(0, 5, 0),
std::make_tuple(100, 0, 0),
std::make_tuple(10000, 5000, 5000)
)
);