diff --git a/net/curl/inc/ROOT/RCurlConnection.hxx b/net/curl/inc/ROOT/RCurlConnection.hxx index 4a8857050fa8b..f2d7ee2fdabe2 100644 --- a/net/curl/inc/ROOT/RCurlConnection.hxx +++ b/net/curl/inc/ROOT/RCurlConnection.hxx @@ -116,6 +116,8 @@ public: /// a valid batching of requests into multiple multi-range requests takes place automatically. /// The fNBytesRecv member of the ranges is only well-defined on success. RStatus SendRangesReq(std::size_t N, RUserRange *ranges); + /// Uploads data to the URL using an HTTP PUT request. + RStatus SendPutReq(const unsigned char *data, std::size_t length); const std::string &GetEscapedUrl() const { return fEscapedUrl; } diff --git a/net/curl/src/RCurlConnection.cxx b/net/curl/src/RCurlConnection.cxx index 4a2315319b07b..51cc18dd4c588 100644 --- a/net/curl/src/RCurlConnection.cxx +++ b/net/curl/src/RCurlConnection.cxx @@ -552,6 +552,22 @@ void ReverseDisplacements(std::vector &displacements, ROOT::Interna } } +struct RPutReadState { + const unsigned char *fData; + std::size_t fLength; + std::size_t fOffset = 0; +}; + +std::size_t CallbackPutRead(char *buffer, std::size_t size, std::size_t nmemb, void *userdata) +{ + auto *state = static_cast(userdata); + std::size_t remaining = state->fLength - state->fOffset; + std::size_t toCopy = std::min(size * nmemb, remaining); + memcpy(buffer, state->fData + state->fOffset, toCopy); + state->fOffset += toCopy; + return toCopy; +} + std::string GetCurlErrorString(CURLcode code) { return std::string(curl_easy_strerror(code)) + " (" + std::to_string(code) + ")"; @@ -849,6 +865,42 @@ ROOT::Internal::RCurlConnection::SendRangesReq(std::size_t N, RUserRange *ranges return status; } +ROOT::Internal::RCurlConnection::RStatus +ROOT::Internal::RCurlConnection::SendPutReq(const unsigned char *data, std::size_t length) +{ + auto rc = curl_easy_setopt(fHandle, CURLOPT_UPLOAD, 1L); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_INFILESIZE_LARGE, static_cast(length)); + R__ASSERT(rc == CURLE_OK); + + // Reset sticky options that may have been set by previous HEAD or GET calls on this handle + rc = curl_easy_setopt(fHandle, CURLOPT_NOBODY, 0L); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_HTTPGET, 0L); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_RANGE, NULL); + R__ASSERT(rc == CURLE_OK); + + RPutReadState readState{data, length, 0}; + rc = curl_easy_setopt(fHandle, CURLOPT_READFUNCTION, CallbackPutRead); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_READDATA, &readState); + R__ASSERT(rc == CURLE_OK); + + RStatus status; + Perform(status); + + // Reset upload options so that subsequent GET/HEAD calls on this handle are not affected + rc = curl_easy_setopt(fHandle, CURLOPT_UPLOAD, 0L); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_READFUNCTION, NULL); + R__ASSERT(rc == CURLE_OK); + rc = curl_easy_setopt(fHandle, CURLOPT_READDATA, NULL); + R__ASSERT(rc == CURLE_OK); + + return status; +} + void ROOT::Internal::RCurlConnection::SetCredentials(const RS3Credentials &credentials) { ClearCredentials(); diff --git a/net/curl/test/curl_connection.cxx b/net/curl/test/curl_connection.cxx index 80fd922ae7afa..8fcd84ec27ece 100644 --- a/net/curl/test/curl_connection.cxx +++ b/net/curl/test/curl_connection.cxx @@ -9,6 +9,50 @@ #include #include +static void TaskRecvPut(TServerSocket *serverSocket, std::string *requestHeaders, std::string *requestBody) +{ + requestHeaders->clear(); + requestBody->clear(); + auto sock = serverSocket->Accept(); + + const char *eof = "\r\n\r\n"; + const std::size_t eofLen = strlen(eof); + std::size_t nextInEof = 0; + char c; + while (sock->RecvRaw(&c, 1)) { + requestHeaders->push_back(c); + if (c == eof[nextInEof]) { + if (++nextInEof == eofLen) + break; + } else { + nextInEof = 0; + } + } + + // Parse Content-Length from headers + std::size_t contentLength = 0; + std::string clHeader = "Content-Length: "; + auto pos = requestHeaders->find(clHeader); + if (pos == std::string::npos) { + clHeader = "content-length: "; + pos = requestHeaders->find(clHeader); + } + if (pos != std::string::npos) { + auto valStart = pos + clHeader.size(); + auto valEnd = requestHeaders->find("\r\n", valStart); + contentLength = std::stoul(requestHeaders->substr(valStart, valEnd - valStart)); + } + + if (contentLength > 0) { + requestBody->resize(contentLength); + sock->RecvRaw(&(*requestBody)[0], contentLength); + } + + const char *response = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"; + sock->SendRaw(response, strlen(response)); + sock->Close(); +} + static void TaskRecv(TServerSocket *serverSocket, std::string *request) { request->clear(); @@ -63,3 +107,26 @@ TEST(RCurlConnection, Cred) threadRecv.join(); EXPECT_EQ(std::string::npos, request.find("\r\nAuthorization: ")); } + +TEST(RCurlConnection, Put) +{ + TServerSocket sock(0, false, TServerSocket::kDefaultBacklog, -1, ESocketBindOption::kInaddrLoopback); + const std::string url = + std::string("http://") + sock.GetLocalInetAddress().GetHostAddress() + ":" + std::to_string(sock.GetLocalPort()); + + const unsigned char payload[] = "Hello, S3!"; + const std::size_t payloadLen = sizeof(payload) - 1; // exclude null terminator + + std::string headers; + std::string body; + std::thread threadRecv(TaskRecvPut, &sock, &headers, &body); + + ROOT::Internal::RCurlConnection conn(url); + auto status = conn.SendPutReq(payload, payloadLen); + + threadRecv.join(); + EXPECT_TRUE(static_cast(status)); + EXPECT_EQ(0u, headers.find("PUT ")); + EXPECT_NE(std::string::npos, headers.find("Content-Length: " + std::to_string(payloadLen))); + EXPECT_EQ(std::string(reinterpret_cast(payload), payloadLen), body); +}