diff --git a/protocol/payload_update_test.cc b/protocol/payload_update_test.cc index 3631252..4de9b47 100644 --- a/protocol/payload_update_test.cc +++ b/protocol/payload_update_test.cc @@ -16,9 +16,13 @@ #include #include +#include #include +#include +#include #include +#include #include "command_version.h" #include "payload_info.h" @@ -38,15 +42,35 @@ constexpr int64_t kAlign = 1 << 16; constexpr int64_t kDummy = 0; MATCHER_P2(IsEraseRequest, offset, len, "") { - const uint8_t* data = static_cast(arg); + const struct hoth_host_request* req = + static_cast(arg); + if (req->struct_version != HOTH_HOST_REQUEST_VERSION || + req->command != kCmd) { + return false; + } const struct payload_update_packet* p = reinterpret_cast( - data + sizeof(struct hoth_host_request)); + static_cast(arg) + sizeof(struct hoth_host_request)); return p->type == PAYLOAD_UPDATE_ERASE && p->offset == static_cast(offset) && p->len == static_cast(len); } +MATCHER_P2(IsReadRequest, offset, len, "") { + const struct hoth_host_request* req = + static_cast(arg); + if (req->struct_version != HOTH_HOST_REQUEST_VERSION || + req->command != kCmd) { + return false; + } + const struct payload_update_packet* p = + reinterpret_cast( + static_cast(arg) + sizeof(struct hoth_host_request)); + return p->type == PAYLOAD_UPDATE_READ && + p->offset == static_cast(offset) && + p->len == static_cast(len); +} + TEST_F(LibHothTest, payload_update_bad_image_test) { EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) .WillRepeatedly(Return(LIBHOTH_OK)); @@ -496,3 +520,31 @@ TEST_F(LibHothTest, payload_update_erase_cmd_range_overflow_test) { EXPECT_EQ(libhoth_payload_update_erase(&hoth_dev_, kOffset, kSize), PAYLOAD_UPDATE_INVALID_ARGS); } + +TEST_F(LibHothTest, payload_update_read_chunk_test) { + uint8_t expected_data[] = {0x11, 0x22, 0x33, 0x44, 0x55}; + size_t offset = 0x100; + size_t len = sizeof(expected_data); + + EXPECT_CALL(mock_, send(_, IsReadRequest(offset, len), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(expected_data, len), Return(LIBHOTH_OK))); + + std::string temp_path = testing::TempDir() + "libhoth_payload_read_XXXXXX"; + int fd = mkstemp(&temp_path[0]); + ASSERT_GE(fd, 0); + unlink(temp_path.c_str()); + struct FdGuard { + int fd; + ~FdGuard() { close(fd); } + } guard = {fd}; + + EXPECT_EQ(libhoth_payload_update_read_chunk(&hoth_dev_, fd, len, offset), + PAYLOAD_UPDATE_OK); + + ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0); + uint8_t actual_data[sizeof(expected_data)] = {0}; + ASSERT_EQ(read(fd, actual_data, len), static_cast(len)); + EXPECT_THAT(actual_data, ::testing::ElementsAreArray(expected_data)); +}