diff --git a/Tensile/Source/lib/source/msgpack/MessagePack.cpp b/Tensile/Source/lib/source/msgpack/MessagePack.cpp index de97929c..dbc397e0 100644 --- a/tensilelite/src/msgpack/MessagePack.cpp +++ b/tensilelite/src/msgpack/MessagePack.cpp @@ -28,6 +28,8 @@ #include +#include + #include namespace Tensile @@ -86,6 +88,34 @@ namespace Tensile return nullptr; } + // Check if the file is zstd compressed + char magic[4]; + in.read(magic, 4); + bool isCompressed = (in.gcount() == 4 && magic[0] == '\x28' && magic[1] == '\xB5' && magic[2] == '\x2F' && magic[3] == '\xFD'); + // Reset file pointer to the beginning + in.seekg(0, std::ios::beg); + + if (isCompressed) { + // Decompress zstd file + std::vector compressedData((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + + size_t decompressedSize = ZSTD_getFrameContentSize(compressedData.data(), compressedData.size()); + if (decompressedSize == ZSTD_CONTENTSIZE_ERROR || decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN) { + if(Debug::Instance().printDataInit()) + std::cout << "Error: Unable to determine decompressed size for " << filename << std::endl; + return false; + } + + std::vector decompressedData(decompressedSize); + size_t dSize = ZSTD_decompress(decompressedData.data(), decompressedSize, compressedData.data(), compressedData.size()); + if (ZSTD_isError(dSize)) { + if(Debug::Instance().printDataInit()) + std::cout << "Error: ZSTD decompression failed for " << filename << std::endl; + return false; + } + + msgpack::unpack(result, decompressedData.data(), dSize); + } else { msgpack::unpacker unp; bool finished_parsing; constexpr size_t buffer_size = 1 << 19; @@ -109,6 +139,7 @@ namespace Tensile return nullptr; } + } } catch(std::runtime_error const& exc) {