diff --git a/CMakeLists.txt b/CMakeLists.txt index c0735e5b1..bb01373c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ endif() # # general -#option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) +option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) option(SD_CUDA "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) @@ -145,6 +145,11 @@ target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) +if (SD_BUILD_TESTS) + enable_testing() + add_subdirectory(tests) +endif() + if (SD_BUILD_EXAMPLES) add_subdirectory(examples) endif() diff --git a/format-code.sh b/format-code.sh old mode 100644 new mode 100755 diff --git a/model.cpp b/model.cpp index 992a02dbc..22395eaef 100644 --- a/model.cpp +++ b/model.cpp @@ -40,38 +40,6 @@ #define ST_HEADER_SIZE_LEN 8 -uint64_t read_u64(uint8_t* buffer) { - // little endian - uint64_t value = 0; - value |= static_cast(buffer[7]) << 56; - value |= static_cast(buffer[6]) << 48; - value |= static_cast(buffer[5]) << 40; - value |= static_cast(buffer[4]) << 32; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - -int32_t read_int(uint8_t* buffer) { - // little endian - int value = 0; - value |= buffer[3] << 24; - value |= buffer[2] << 16; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - -uint16_t read_short(uint8_t* buffer) { - // little endian - uint16_t value = 0; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - /*================================================= Preprocess ==================================================*/ std::string self_attn_names[] = { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b68ba4fb8..c313368d9 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ 
-346,9 +346,9 @@ class StableDiffusionGGML { offload_params_to_cpu, model_loader.tensor_storages_types); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - sd_ctx_params->diffusion_flash_attn, - model_loader.tensor_storages_types); + offload_params_to_cpu, + sd_ctx_params->diffusion_flash_attn, + model_loader.tensor_storages_types); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : model_loader.tensor_storages_types) { @@ -391,11 +391,11 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - model_loader.tensor_storages_types, - "model.diffusion_model", - version, - sd_ctx_params->diffusion_flash_attn); + offload_params_to_cpu, + model_loader.tensor_storages_types, + "model.diffusion_model", + version, + sd_ctx_params->diffusion_flash_attn); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -1416,12 +1416,12 @@ class StableDiffusionGGML { -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f, 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f}; latents_std_vec = { - 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, - 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, - 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, - 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, - 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, - 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; + 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, + 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, + 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, + 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, + 0.5709f, 
0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, + 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; } for (int i = 0; i < latent->ne[3]; i++) { float mean = latents_mean_vec[i]; @@ -1456,12 +1456,12 @@ class StableDiffusionGGML { -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f, 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f}; latents_std_vec = { - 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, - 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, - 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, - 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, - 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, - 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; + 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, + 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, + 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, + 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, + 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, + 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; } for (int i = 0; i < latent->ne[3]; i++) { float mean = latents_mean_vec[i]; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 000000000..ed2839fa2 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET test-binary-reader) + +add_executable(${TARGET} test_binary_reader.cpp) +target_link_libraries(${TARGET} PRIVATE stable-diffusion) +target_include_directories(${TARGET} PRIVATE ..) 
+target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
+
+add_test(NAME ${TARGET} COMMAND ${TARGET})
diff --git a/tests/test_binary_reader.cpp b/tests/test_binary_reader.cpp
new file mode 100644
index 000000000..8e8dd6838
--- /dev/null
+++ b/tests/test_binary_reader.cpp
@@ -0,0 +1,101 @@
+#include <cstdint>
+#include <cstdlib>
+#include <iostream>
+#include "util.h"
+
+// Minimal check macro: prints the failing expression with file and line, then exits with status 1.
+// Wrapped in do { } while (0) so `ASSERT(x);` is one statement and stays safe inside an unbraced if/else.
+#define ASSERT(cond) \
+    do { \
+        if (!(cond)) { \
+            std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__ << ": " << #cond << std::endl; \
+            std::exit(1); \
+        } \
+    } while (0)
+
+// Exercises read_u64 with hand-encoded little-endian byte patterns.
+void test_read_u64() {
+    std::cout << "Testing read_u64..." << std::endl;
+
+    // Case 1: 0
+    uint8_t buf1[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+    ASSERT(read_u64(buf1) == 0);
+
+    // Case 2: 1
+    uint8_t buf2[8] = {1, 0, 0, 0, 0, 0, 0, 0};
+    ASSERT(read_u64(buf2) == 1);
+
+    // Case 3: 0x0102030405060708
+    uint8_t buf3[8] = {0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01};
+    ASSERT(read_u64(buf3) == 0x0102030405060708ULL);
+
+    // Case 4: Max value
+    uint8_t buf4[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
+    ASSERT(read_u64(buf4) == 0xFFFFFFFFFFFFFFFFULL);
+
+    // Case 5: Pattern with high bits
+    uint8_t buf5[8] = {0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11};
+    ASSERT(read_u64(buf5) == 0x1100FFEEDDCCBBAAULL);
+
+    std::cout << "read_u64 passed!" << std::endl;
+}
+
+void test_read_int() {
+    std::cout << "Testing read_int..."
<< std::endl;
+
+    // Case 1: all-zero bytes decode to 0
+    uint8_t buf1[4] = {0, 0, 0, 0};
+    ASSERT(read_int(buf1) == 0);
+
+    // Case 2: only the first (least-significant) byte is set
+    uint8_t buf2[4] = {1, 0, 0, 0};
+    ASSERT(read_int(buf2) == 1);
+
+    // Case 3: 0x01020304 stored least-significant byte first
+    uint8_t buf3[4] = {0x04, 0x03, 0x02, 0x01};
+    ASSERT(read_int(buf3) == 0x01020304);
+
+    // Case 4: all bits set is expected to read back as -1 through the signed return type
+    uint8_t buf4[4] = {0xFF, 0xFF, 0xFF, 0xFF};
+    ASSERT(read_int(buf4) == -1);
+
+    // Case 5: 0x7FFFFFFF (largest positive int32)
+    uint8_t buf5[4] = {0xFF, 0xFF, 0xFF, 0x7F};
+    ASSERT(read_int(buf5) == 2147483647);
+
+    // Case 6: 0x80000000 (most negative int32)
+    uint8_t buf6[4] = {0x00, 0x00, 0x00, 0x80};
+    ASSERT(read_int(buf6) == (int32_t)0x80000000);
+
+    std::cout << "read_int passed!" << std::endl;
+}
+
+void test_read_short() {
+    std::cout << "Testing read_short..." << std::endl;
+
+    // Case 1: all-zero bytes decode to 0
+    uint8_t buf1[2] = {0, 0};
+    ASSERT(read_short(buf1) == 0);
+
+    // Case 2: only the first (least-significant) byte is set
+    uint8_t buf2[2] = {1, 0};
+    ASSERT(read_short(buf2) == 1);
+
+    // Case 3: 0x0102 stored least-significant byte first
+    uint8_t buf3[2] = {0x02, 0x01};
+    ASSERT(read_short(buf3) == 0x0102);
+
+    // Case 4: largest 16-bit value (both bytes 0xFF)
+    uint8_t buf4[2] = {0xFF, 0xFF};
+    ASSERT(read_short(buf4) == 0xFFFF);
+
+    std::cout << "read_short passed!" << std::endl;
+}
+
+int main() {
+    test_read_u64();
+    test_read_int();
+    test_read_short();
+    std::cout << "All tests passed!"
<< std::endl;
+    return 0;
+}
diff --git a/util.cpp b/util.cpp
index 5af6b1ec1..69fcd8966 100644
--- a/util.cpp
+++ b/util.cpp
@@ -605,4 +605,39 @@ std::vector> parse_prompt_attention(const std::str
     }
 
     return res;
-}
\ No newline at end of file
+}
+
+// Decodes an unsigned 64-bit little-endian integer from buffer[0..7].
+uint64_t read_u64(const uint8_t* buffer) {
+    // little endian: buffer[0] is the least-significant byte
+    uint64_t value = 0;
+    value |= static_cast<uint64_t>(buffer[7]) << 56;
+    value |= static_cast<uint64_t>(buffer[6]) << 48;
+    value |= static_cast<uint64_t>(buffer[5]) << 40;
+    value |= static_cast<uint64_t>(buffer[4]) << 32;
+    value |= static_cast<uint64_t>(buffer[3]) << 24;
+    value |= static_cast<uint64_t>(buffer[2]) << 16;
+    value |= static_cast<uint64_t>(buffer[1]) << 8;
+    value |= static_cast<uint64_t>(buffer[0]);
+    return value;
+}
+
+// Decodes a signed 32-bit little-endian integer from buffer[0..3].
+int32_t read_int(const uint8_t* buffer) {
+    // little endian; assembled as unsigned so the top byte is never shifted into a sign bit
+    uint32_t value = 0;
+    value |= static_cast<uint32_t>(buffer[3]) << 24;
+    value |= static_cast<uint32_t>(buffer[2]) << 16;
+    value |= static_cast<uint32_t>(buffer[1]) << 8;
+    value |= static_cast<uint32_t>(buffer[0]);
+    return static_cast<int32_t>(value);
+}
+
+// Decodes an unsigned 16-bit little-endian integer from buffer[0..1].
+uint16_t read_short(const uint8_t* buffer) {
+    // little endian: buffer[0] is the least-significant byte
+    uint16_t value = 0;
+    value |= static_cast<uint16_t>(buffer[1]) << 8;
+    value |= static_cast<uint16_t>(buffer[0]);
+    return value;
+}
diff --git a/util.h b/util.h
index 1e8db6e3b..07d2f95c7 100644
--- a/util.h
+++ b/util.h
@@ -48,6 +48,11 @@ std::string path_join(const std::string& p1, const std::string& p2);
 std::vector split_string(const std::string& str, char delimiter);
 void pretty_progress(int step, int steps, float time);
 
+// Little-endian integer decoders; buffer must provide at least 8, 4 and 2 readable bytes respectively.
+uint64_t read_u64(const uint8_t* buffer);
+int32_t read_int(const uint8_t* buffer);
+uint16_t read_short(const uint8_t* buffer);
+
 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
 
 std::string trim(const std::string& s);
diff --git a/vae.hpp b/vae.hpp index dd982ab7a..a6a6021b8 100644 --- a/vae.hpp +++ b/vae.hpp @@ -529,7 +529,7 @@ struct VAE : public GGMLRunner { struct ggml_tensor** output, struct ggml_context* output_ctx) = 0; virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; - virtual void enable_conv2d_direct(){}; + virtual void
enable_conv2d_direct() {}; }; struct AutoEncoderKL : public VAE {