diff --git a/.github/workflows/bazel.yml b/.github/workflows/bazel.yml index 6483e64..ca5d28e 100644 --- a/.github/workflows/bazel.yml +++ b/.github/workflows/bazel.yml @@ -42,6 +42,20 @@ jobs: bazelrc: | build --cxxopt='/std:c++17' build --cxxopt='/utf-8' + build --cxxopt='/wd4514' + build --cxxopt='/wd4625' + build --cxxopt='/wd4582' + build --cxxopt='/wd4365' + build --cxxopt='/wd5045' + build --cxxopt='/wd4820' + build --cxxopt='/wd5031' + build --cxxopt='/wd4668' + build --cxxopt='/wd5027' + build --cxxopt='/wd5204' + build --cxxopt='/wd5206' + build --cxxopt='/wd4626' + build --cxxopt='/wd4623' + build --cxxopt='/wd4710' #build --@boost.mysql//:ssl=boringssl build --@boost.asio//:ssl=boringssl diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 46769bc..14be457 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -1,5 +1,5 @@ { - "lockFileVersion": 18, + "lockFileVersion": 24, "registryFileHashes": { "https://bcr.bazel.build/bazel_registry.json": "8a28e4aff06ee60aed2a8c281907fb8bcbf3b753c91fb5a5c57da3215d5b3497", "https://bcr.bazel.build/modules/abseil-cpp/20210324.2/MODULE.bazel": "7cd0312e064fde87c8d1cd79ba06c876bd23630c83466e9500321be55c96ace2", @@ -427,7 +427,7 @@ "moduleExtensions": { "@@rules_foreign_cc+//foreign_cc:extensions.bzl%tools": { "general": { - "bzlTransitiveDigest": "bzYvbsHj2ct8D8fQBqNNAJqAjmx6oxXp0japlQvDLjo=", + "bzlTransitiveDigest": "214a15Hi6YO0SxdwD2rGG5hBYv7/aQ5blgNKDcASQaM=", "usagesDigest": "Eyh4mAOi6L+Nn/lY/wQBJclQrmBnWdQM+B4lZeq6azA=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, @@ -811,7 +811,7 @@ }, "@@rules_kotlin+//src/main/starlark/core/repositories:bzlmod_setup.bzl%rules_kotlin_extensions": { "general": { - "bzlTransitiveDigest": "OlvsB0HsvxbR8ZN+J9Vf00X/+WVz/Y/5Xrq2LgcVfdo=", + "bzlTransitiveDigest": "rL/34P1aFDq2GqVC2zCFgQ8nTuOC6ziogocpvG50Qz8=", "usagesDigest": "QI2z8ZUR+mqtbwsf2fLqYdJAkPOHdOV+tF2yVAUgRzw=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, @@ -873,5 +873,6 @@ ] } } - } + }, + "facts": {} } diff --git a/framework/BUILD.bazel b/framework/BUILD.bazel index 991bc1e..6e513dc 100644 --- a/framework/BUILD.bazel +++ b/framework/BUILD.bazel @@ -9,17 +9,21 @@ cc_library( "session/*.cpp", "websocket/*.cpp", "context/*.cpp", + "client/*.cpp", ]), hdrs = glob([ "*.hpp", "context/*.hpp", "controller/*.hpp", "exception/*.hpp", + "cron/*.hpp", + "dto/*.hpp", "interceptor/*.hpp", "di/*.hpp", "router/*.hpp", "session/*.hpp", "websocket/*.hpp", + "client/*.hpp", ]), copts = [ "-std=c++17", diff --git a/framework/client/http_client.cpp b/framework/client/http_client.cpp new file mode 100644 index 0000000..2206470 --- /dev/null +++ b/framework/client/http_client.cpp @@ -0,0 +1,372 @@ +#include "http_client.hpp" +#include +#include +#include "io_context_pool.hpp" + +namespace khttpd::framework::client +{ + std::string replace_all(std::string str, const std::string& from, const std::string& to) + { + if (from.empty()) return str; + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) + { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + return str; + } + + // ========================================== + // Abstract Session to handle common logic + // ========================================== + class Session : public std::enable_shared_from_this + { + protected: + HttpClient::ResponseCallback callback_; + http::request req_; + http::response res_; + beast::flat_buffer buffer_; + std::chrono::seconds timeout_; + + public: + Session(HttpClient::ResponseCallback callback, std::chrono::seconds timeout) + : callback_(std::move(callback)), timeout_(timeout) + { + } + + virtual ~Session() = default; + virtual void run(const std::string& host, const std::string& port, http::request req) = 0; + + protected: + void on_fail(beast::error_code ec, const char* what) + { + // Log if needed: std::cerr << what << ": " << ec.message() << "\n"; + if (callback_) callback_(ec, {}); + } + }; + + // ========================================== + // Plain HTTP Session + // ========================================== + class HttpSession : public Session + { + beast::tcp_stream stream_; + tcp::resolver resolver_; + + // Helper: Downcast shared_from_this to avoid template deduction errors + std::shared_ptr get_shared() + { + return std::static_pointer_cast(shared_from_this()); + } + + public: + HttpSession(net::io_context& ioc, HttpClient::ResponseCallback cb, std::chrono::seconds timeout) + : Session(std::move(cb), timeout), stream_(ioc), resolver_(ioc) + { + } + + void run(const std::string& host, const std::string& port, http::request req) override + { + req_ = std::move(req); + stream_.expires_after(timeout_); + resolver_.async_resolve(host, port, + beast::bind_front_handler(&HttpSession::on_resolve, get_shared())); + } + + void on_resolve(beast::error_code ec, tcp::resolver::results_type results) + { + if (ec) return on_fail(ec, "resolve"); + stream_.expires_after(timeout_); + stream_.async_connect(results, + beast::bind_front_handler(&HttpSession::on_connect, get_shared())); + } + + void on_connect(beast::error_code ec, tcp::resolver::results_type::endpoint_type) + { + if (ec) return on_fail(ec, "connect"); + stream_.expires_after(timeout_); + http::async_write(stream_, req_, + beast::bind_front_handler(&HttpSession::on_write, get_shared())); + } + + void on_write(beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) return on_fail(ec, "write"); + + http::async_read(stream_, buffer_, res_, + beast::bind_front_handler(&HttpSession::on_read, get_shared())); + } + + void on_read(beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) return on_fail(ec, "read"); + + stream_.socket().shutdown(tcp::socket::shutdown_both, ec); + if (callback_) callback_(ec, std::move(res_)); + } + }; + + // ========================================== + // HTTPS Session + // ========================================== + class HttpsSession : public Session + { + beast::ssl_stream stream_; + tcp::resolver resolver_; + + std::shared_ptr get_shared() + { + return std::static_pointer_cast(shared_from_this()); + } + + public: + HttpsSession(net::io_context& ioc, ssl::context& ctx, HttpClient::ResponseCallback cb, std::chrono::seconds timeout) + : Session(std::move(cb), timeout), stream_(ioc, ctx), resolver_(ioc) + { + } + + void run(const std::string& host, const std::string& port, http::request req) override + { + req_ = std::move(req); + if (!SSL_set_tlsext_host_name(stream_.native_handle(), host.c_str())) + { + beast::error_code ec{static_cast(::ERR_get_error()), net::error::get_ssl_category()}; + return on_fail(ec, "ssl_setup"); + } + + stream_.next_layer().expires_after(timeout_); + resolver_.async_resolve(host, port, + beast::bind_front_handler(&HttpsSession::on_resolve, get_shared())); + } + + void on_resolve(beast::error_code ec, tcp::resolver::results_type results) + { + if (ec) return on_fail(ec, "resolve"); + stream_.next_layer().expires_after(timeout_); + beast::get_lowest_layer(stream_).async_connect(results, + beast::bind_front_handler( + &HttpsSession::on_connect, get_shared())); + } + + void on_connect(beast::error_code ec, tcp::resolver::results_type::endpoint_type) + { + if (ec) return on_fail(ec, "connect"); + stream_.next_layer().expires_after(timeout_); + stream_.async_handshake(ssl::stream_base::client, + beast::bind_front_handler(&HttpsSession::on_handshake, get_shared())); + } + + void on_handshake(beast::error_code ec) + { + if (ec) return on_fail(ec, "handshake"); + stream_.next_layer().expires_after(timeout_); + http::async_write(stream_, req_, + beast::bind_front_handler(&HttpsSession::on_write, get_shared())); + } + + void on_write(beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) return on_fail(ec, "write"); + http::async_read(stream_, buffer_, res_, + beast::bind_front_handler(&HttpsSession::on_read, get_shared())); + } + + void on_read(beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) return on_fail(ec, "read"); + + stream_.async_shutdown(beast::bind_front_handler(&HttpsSession::on_shutdown, get_shared())); + } + + void on_shutdown(beast::error_code ec) + { + if (ec == net::error::eof || ec == ssl::error::stream_truncated) + ec = {}; + if (callback_) callback_(ec, std::move(res_)); + } + }; + + // 1. 傻瓜式:全局 IO + 默认 SSL + HttpClient::HttpClient() + : ioc_(IoContextPool::instance().get_io_context()) // 从单例获取 + { + // 同样的默认 SSL 初始化逻辑 + own_ssl_ctx_ = std::make_shared(ssl::context::tls_client); + own_ssl_ctx_->set_default_verify_paths(); + own_ssl_ctx_->set_verify_mode(ssl::verify_none); + ssl_ctx_ptr_ = own_ssl_ctx_.get(); + } + + // 2. 全局 IO + 自定义 SSL + HttpClient::HttpClient(ssl::context& ssl_ctx) + : ioc_(IoContextPool::instance().get_io_context()) + , ssl_ctx_ptr_(&ssl_ctx) + { + } + + // 3. 自定义 IO + 默认 SSL (原逻辑) + HttpClient::HttpClient(net::io_context& ioc) + : ioc_(ioc) + { + own_ssl_ctx_ = std::make_shared(ssl::context::tls_client); + own_ssl_ctx_->set_default_verify_paths(); + own_ssl_ctx_->set_verify_mode(ssl::verify_none); + ssl_ctx_ptr_ = own_ssl_ctx_.get(); + } + + // 4. 全自定义 + HttpClient::HttpClient(net::io_context& ioc, ssl::context& ssl_ctx) + : ioc_(ioc) + , ssl_ctx_ptr_(&ssl_ctx) + { + } + + void HttpClient::set_base_url(const std::string& url) + { + auto result = boost::urls::parse_uri(url); + if (result.has_value()) + { + base_url_ = result.value(); + } + else + { + // Fallback for missing scheme + if (url.find("http") != 0) + { + auto res2 = boost::urls::parse_uri("http://" + url); + if (res2.has_value()) base_url_ = res2.value(); + } + } + } + + void HttpClient::set_default_header(const std::string& key, const std::string& value) + { + default_headers_[key] = value; + } + + void HttpClient::set_bearer_token(const std::string& token) + { + set_default_header("Authorization", "Bearer " + token); + } + + void HttpClient::set_timeout(std::chrono::seconds seconds) + { + timeout_ = seconds; + } + + HttpClient::UrlParts HttpClient::parse_target(const std::string& path_in, + const std::map& query) + { + boost::urls::url u; + + if (base_url_.has_value()) + { + u = base_url_.value(); + if (!path_in.empty()) + { + if (path_in.front() != '/') u.set_path(u.path() + "/" + path_in); + else u.set_path(path_in); + } + } + + auto parse_res = boost::urls::parse_uri(path_in); + if (parse_res.has_value()) + { + u = parse_res.value(); + } + + for (const auto& [k, v] : query) + { + u.params().append({k, v}); + } + + UrlParts parts; + parts.scheme = u.scheme(); + parts.host = u.host(); + parts.port = u.port(); + parts.target = u.encoded_target(); + + if (parts.scheme.empty()) parts.scheme = "http"; + if (parts.target.empty()) parts.target = "/"; + if (parts.port.empty()) parts.port = (parts.scheme == "https") ? "443" : "80"; + + return parts; + } + + void HttpClient::request(http::verb method, + std::string path, + const std::map& query_params, + const std::string& body, + const std::map& headers, + ResponseCallback callback) + { + try + { + auto parts = parse_target(path, query_params); + + http::request req{method, parts.target, 11}; + req.set(http::field::host, parts.host); + req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); + + for (const auto& h : default_headers_) req.set(h.first, h.second); + for (const auto& h : headers) req.set(h.first, h.second); + + if (!body.empty()) + { + req.body() = body; + req.prepare_payload(); + } + + std::shared_ptr session; + if (parts.scheme == "https") + { + if (!ssl_ctx_ptr_) + { + if (callback) callback(beast::error_code(beast::errc::operation_not_supported, beast::system_category()), {}); + return; + } + session = std::make_shared(ioc_, *ssl_ctx_ptr_, std::move(callback), timeout_); + } + else + { + session = std::make_shared(ioc_, std::move(callback), timeout_); + } + session->run(parts.host, parts.port, std::move(req)); + } + catch (const std::exception& e) + { + if (callback) callback(beast::error_code(beast::errc::invalid_argument, beast::system_category()), {}); + } + } + + http::response HttpClient::request_sync( + http::verb method, + std::string path, + const std::map& query_params, + const std::string& body, + const std::map& headers) + { + std::promise>> p; + auto f = p.get_future(); + + this->request(method, path, query_params, body, headers, + [&p](beast::error_code ec, http::response res) + { + p.set_value({ec, std::move(res)}); + }); + + f.wait(); + auto result = f.get(); + + if (result.first) + { + throw boost::system::system_error(result.first); + } + return result.second; + } +} diff --git a/framework/client/http_client.hpp b/framework/client/http_client.hpp new file mode 100644 index 0000000..88fd585 --- /dev/null +++ b/framework/client/http_client.hpp @@ -0,0 +1,134 @@ +#ifndef KHTTPD_FRAMEWORK_CLIENT_HTTP_CLIENT_HPP +#define KHTTPD_FRAMEWORK_CLIENT_HTTP_CLIENT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace khttpd::framework::client +{ + namespace beast = boost::beast; + namespace http = beast::http; + namespace net = boost::asio; + namespace ssl = boost::asio::ssl; + using tcp = boost::asio::ip::tcp; + + // Helper: String conversion + template + std::string to_string(const T& val) + { + if constexpr (std::is_convertible_v) + { + return std::string(val); + } + else if constexpr (std::is_arithmetic_v) + { + return std::to_string(val); + } + else + { + return std::to_string(val); + } + } + + inline std::string to_string(const std::string& val) { return val; } + + // Helper: Body serialization + template + std::string serialize_body(const T& value) + { + if constexpr (std::is_convertible_v) + { + return std::string(value); + } + else + { + // Assume it's serializable to JSON + return boost::json::serialize(boost::json::value_from(value)); + } + } + + // Helper: Replace function + std::string replace_all(std::string str, const std::string& from, const std::string& to); + + class HttpClient : public std::enable_shared_from_this + { + public: + using ResponseCallback = std::function)>; + + // 1. 【新增】傻瓜式构造函数:使用全局 IO 池,内部默认 SSL + HttpClient(); + + // 2. 【新增】使用全局 IO 池,但指定自定义 SSL + explicit HttpClient(ssl::context& ssl_ctx); + + // 3. 【保留】专家模式:指定外部 IO Context + explicit HttpClient(net::io_context& ioc); + HttpClient(net::io_context& ioc, ssl::context& ssl_ctx); + + virtual ~HttpClient() = default; + + // Configuration + void set_base_url(const std::string& url); + void set_default_header(const std::string& key, const std::string& value); + void set_bearer_token(const std::string& token); + void set_timeout(std::chrono::seconds seconds); + + // Core Request Method (Used by Macros) + void request(http::verb method, + std::string path, // relative path or full url + const std::map& query_params, + const std::string& body, + const std::map& headers, + ResponseCallback callback); + + // Sync Request Method + http::response request_sync( + http::verb method, + std::string path, + const std::map& query_params, + const std::string& body, + const std::map& headers); + + private: + struct UrlParts + { + std::string scheme; + std::string host; + std::string port; + std::string target; + }; + + UrlParts parse_target(const std::string& path, const std::map& query); + + net::io_context& ioc_; + + // SSL Context Management + std::shared_ptr own_ssl_ctx_; // Holds ownership if created internally + ssl::context* ssl_ctx_ptr_; // Points to the active context + + std::optional base_url_; + std::map default_headers_; + std::chrono::seconds timeout_{30}; + }; +} + +// Include macros at the end +#include "macros.hpp" + +#endif // KHTTPD_FRAMEWORK_CLIENT_HTTP_CLIENT_HPP diff --git a/framework/client/macros.hpp b/framework/client/macros.hpp new file mode 100644 index 0000000..78b7438 --- /dev/null +++ b/framework/client/macros.hpp @@ -0,0 +1,146 @@ +#ifndef KHTTPD_FRAMEWORK_CLIENT_MACROS_HPP +#define KHTTPD_FRAMEWORK_CLIENT_MACROS_HPP + +#include +#include +#include +#include +#include + +// ========================================================================= +// Compiler Warning Suppression +// ========================================================================= +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wvariadic-macro-arguments-omitted" +#pragma clang diagnostic ignored "-Wpedantic" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" +#endif + +// ========================================================================= +// MSVC Compatibility Helper (关键修复 1) +// ========================================================================= +// 用于强制 MSVC 展开 __VA_ARGS__ +#define EXPAND(x) x + +// ========================================================================= +// Argument Tags +// ========================================================================= +#define QUERY(Type, Name, Key) (QUERY_TAG, Type, Name, Key) +#define PATH(Type, Name) (PATH_TAG, Type, Name) +#define BODY(Type, Name) (BODY_TAG, Type, Name) +#define HEADER(Type, Name, Key) (HEADER_TAG, Type, Name, Key) + +// ========================================================================= +// Dispatching Logic (关键修复 2:简化解包逻辑) +// ========================================================================= + +// 之前的 POP_TAG 方式在 MSVC 上容易出错。 +// 我们改为直接展开 Tuple: +// SIG_DISPATCH((TAG, Type, Name)) -> SIG_DISPATCH_I(TAG, Type, Name) -> SIG_TAG(Type, Name) + +#define SIG_DISPATCH(Tuple) EXPAND(SIG_DISPATCH_I Tuple) +#define SIG_DISPATCH_I(Tag, ...) EXPAND(SIG_##Tag(__VA_ARGS__)) + +#define PROC_DISPATCH(Tuple) EXPAND(PROC_DISPATCH_I Tuple) +#define PROC_DISPATCH_I(Tag, ...) EXPAND(PROC_##Tag(__VA_ARGS__)) + +// ========================================================================= +// Implementation Logic +// ========================================================================= +// Signature Generation +#define SIG_QUERY_TAG(Type, Name, Key) Type Name +#define SIG_PATH_TAG(Type, Name) Type Name +#define SIG_BODY_TAG(Type, Name) Type Name +#define SIG_HEADER_TAG(Type, Name, Key) Type Name + +// Process Logic +#define PROC_QUERY_TAG(Type, Name, Key) query_params.emplace(Key, khttpd::framework::client::to_string(Name)) +#define PROC_PATH_TAG(Type, Name) path_str = khttpd::framework::client::replace_all(path_str, ":" #Name, khttpd::framework::client::to_string(Name)) +#define PROC_BODY_TAG(Type, Name) body_str = khttpd::framework::client::serialize_body(Name) +#define PROC_HEADER_TAG(Type, Name, Key) header_map.emplace(Key, khttpd::framework::client::to_string(Name)) + +// ========================================================================= +// API Function Body Generators +// ========================================================================= + +#define API_FUNC_BODY(METHOD, PATH_TEMPLATE, ...) \ + std::string path_str = PATH_TEMPLATE; \ + std::map query_params; \ + std::map header_map; \ + std::string body_str; \ + __VA_ARGS__ \ + this->request(METHOD, path_str, query_params, body_str, header_map, std::move(callback)); + +#define API_FUNC_BODY_SYNC(METHOD, PATH_TEMPLATE, ...) \ + std::string path_str = PATH_TEMPLATE; \ + std::map query_params; \ + std::map header_map; \ + std::string body_str; \ + __VA_ARGS__ \ + return this->request_sync(METHOD, path_str, query_params, body_str, header_map); + +// ========================================================================= +// N-Argument Macro Implementations +// ========================================================================= + +#define API_CALL_0(METHOD, PT, NAME) \ + void NAME(khttpd::framework::client::HttpClient::ResponseCallback callback) { \ + API_FUNC_BODY(METHOD, PT, ) \ + } \ + boost::beast::http::response NAME##_sync() { \ + API_FUNC_BODY_SYNC(METHOD, PT, ) \ + } + +#define API_CALL_1(METHOD, PT, NAME, A) \ + void NAME(SIG_DISPATCH(A), khttpd::framework::client::HttpClient::ResponseCallback callback) { \ + API_FUNC_BODY(METHOD, PT, PROC_DISPATCH(A);) \ + } \ + auto NAME##_sync(SIG_DISPATCH(A)) { \ + API_FUNC_BODY_SYNC(METHOD, PT, PROC_DISPATCH(A);) \ + } + +#define API_CALL_2(METHOD, PT, NAME, A, B) \ + void NAME(SIG_DISPATCH(A), SIG_DISPATCH(B), khttpd::framework::client::HttpClient::ResponseCallback callback) { \ + API_FUNC_BODY(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B);) \ + } \ + auto NAME##_sync(SIG_DISPATCH(A), SIG_DISPATCH(B)) { \ + API_FUNC_BODY_SYNC(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B);) \ + } + +#define API_CALL_3(METHOD, PT, NAME, A, B, C) \ + void NAME(SIG_DISPATCH(A), SIG_DISPATCH(B), SIG_DISPATCH(C), khttpd::framework::client::HttpClient::ResponseCallback callback) { \ + API_FUNC_BODY(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B); PROC_DISPATCH(C);) \ + } \ + auto NAME##_sync(SIG_DISPATCH(A), SIG_DISPATCH(B), SIG_DISPATCH(C)) { \ + API_FUNC_BODY_SYNC(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B); PROC_DISPATCH(C);) \ + } + +#define API_CALL_4(METHOD, PT, NAME, A, B, C, D) \ + void NAME(SIG_DISPATCH(A), SIG_DISPATCH(B), SIG_DISPATCH(C), SIG_DISPATCH(D), khttpd::framework::client::HttpClient::ResponseCallback callback) { \ + API_FUNC_BODY(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B); PROC_DISPATCH(C); PROC_DISPATCH(D);) \ + } \ + auto NAME##_sync(SIG_DISPATCH(A), SIG_DISPATCH(B), SIG_DISPATCH(C), SIG_DISPATCH(D)) { \ + API_FUNC_BODY_SYNC(METHOD, PT, PROC_DISPATCH(A); PROC_DISPATCH(B); PROC_DISPATCH(C); PROC_DISPATCH(D);) \ + } + +// ========================================================================= +// Dispatcher Logic (关键修复 3:修正宏选择计数) +// ========================================================================= + +#define GET_MACRO(_1, _2, _3, _4, _5, _6, _7, NAME, ...) NAME + +// 在这里使用 EXPAND 包裹整个 GET_MACRO 调用。 +// 这解决了 "not enough arguments" 警告,并确保正确选择 API_CALL_x 宏。 +#define API_CALL(...) EXPAND(GET_MACRO(__VA_ARGS__, API_CALL_4, API_CALL_3, API_CALL_2, API_CALL_1, API_CALL_0, DUMMY)(__VA_ARGS__)) + +#if defined(__clang__) +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // KHTTPD_FRAMEWORK_CLIENT_MACROS_HPP diff --git a/framework/client/websocket_client.cpp b/framework/client/websocket_client.cpp new file mode 100644 index 0000000..4b3ce7e --- /dev/null +++ b/framework/client/websocket_client.cpp @@ -0,0 +1,431 @@ +#include "websocket_client.hpp" +#include +#include + +#include "io_context_pool.hpp" + +namespace khttpd::framework::client +{ + // ========================================== + // Internal Session Abstraction + // ========================================== + struct WebsocketSessionImpl : public std::enable_shared_from_this + { + WebsocketClient* owner_; + std::string host_; + beast::flat_buffer buffer_; + std::deque write_queue_; // 写队列 + bool is_writing_ = false; + + explicit WebsocketSessionImpl(WebsocketClient* owner) : owner_(owner) + { + } + + virtual ~WebsocketSessionImpl() = default; + + virtual void run(const std::string& host, const std::string& port, const std::string& target, + const std::map& headers, WebsocketClient::ConnectCallback cb) = 0; + virtual void close() = 0; + + // 核心发送逻辑:入队 + void queue_write(std::string message) + { + net::post(get_executor(), beast::bind_front_handler( + &WebsocketSessionImpl::on_queue_write, shared_from_this(), std::move(message))); + } + + protected: + virtual net::any_io_executor get_executor() = 0; + virtual void do_write_from_queue() = 0; + + void on_queue_write(std::string message) + { + write_queue_.push_back(std::move(message)); + if (!is_writing_) + { + is_writing_ = true; + do_write_from_queue(); + } + } + + // 通用的读循环处理 + void process_read_result(beast::error_code ec, std::size_t bytes) + { + boost::ignore_unused(bytes); + if (ec) + { + // 修改:增加 operation_aborted 到关闭判定条件中 + // 当 async_read 被取消(例如正在关闭时),也应视为连接断开 + if (ec == websocket::error::closed || + ec == net::error::eof || + ec == ssl::error::stream_truncated || + ec == boost::asio::error::connection_reset || + ec == boost::asio::error::operation_aborted) + { + if (owner_->on_close_) owner_->on_close_(); + } + else + { + if (owner_->on_error_) owner_->on_error_(ec); + } + return; + } + + if (owner_->on_message_) + { + owner_->on_message_(beast::buffers_to_string(buffer_.data())); + } + buffer_.consume(buffer_.size()); + } + + // 通用的写完成处理 + void process_write_result(beast::error_code ec) + { + if (ec) + { + is_writing_ = false; // Stop writing on error + if (owner_->on_error_) owner_->on_error_(ec); + return; + } + + write_queue_.pop_front(); + + if (!write_queue_.empty()) + { + do_write_from_queue(); + } + else + { + is_writing_ = false; + } + } + }; + + // ========================================== + // Plain TCP Session (ws://) + // ========================================== + class PlainWebsocketSession : public WebsocketSessionImpl + { + websocket::stream ws_; + tcp::resolver resolver_; + WebsocketClient::ConnectCallback connect_cb_; + + public: + PlainWebsocketSession(net::io_context& ioc, WebsocketClient* owner) + : WebsocketSessionImpl(owner), ws_(net::make_strand(ioc)), resolver_(ioc) + { + } + + net::any_io_executor get_executor() override { return ws_.get_executor(); } + + void run(const std::string& host, const std::string& port, const std::string& target, + const std::map& headers, WebsocketClient::ConnectCallback cb) override + { + host_ = host; + connect_cb_ = std::move(cb); + + resolver_.async_resolve(host, port, beast::bind_front_handler(&PlainWebsocketSession::on_resolve, + std::static_pointer_cast( + shared_from_this()), target, headers)); + } + + void close() override + { + net::post(ws_.get_executor(), [self = std::static_pointer_cast(shared_from_this())]() + { + if (self->ws_.is_open()) + { + self->ws_.async_close(websocket::close_code::normal, [self](beast::error_code) + { + /* ignore close error */ + }); + } + }); + } + + protected: + void do_write_from_queue() override + { + ws_.async_write(net::buffer(write_queue_.front()), + beast::bind_front_handler(&PlainWebsocketSession::on_write, + std::static_pointer_cast(shared_from_this()))); + } + + private: + void on_resolve(std::string target, std::map headers, beast::error_code ec, + tcp::resolver::results_type results) + { + if (ec) return fail(ec); + beast::get_lowest_layer(ws_).async_connect(results, beast::bind_front_handler( + &PlainWebsocketSession::on_connect, + std::static_pointer_cast(shared_from_this()), + target, headers)); + } + + void on_connect(std::string target, std::map headers, beast::error_code ec, + tcp::resolver::results_type::endpoint_type) + { + if (ec) return fail(ec); + + ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + + // Set Headers + ws_.set_option(websocket::stream_base::decorator([headers](websocket::request_type& req) + { + req.set(beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING); + for (const auto& h : headers) req.set(h.first, h.second); + })); + + ws_.async_handshake(host_, target, + beast::bind_front_handler(&PlainWebsocketSession::on_handshake, + std::static_pointer_cast( + shared_from_this()))); + } + + void on_handshake(beast::error_code ec) + { + if (ec) return fail(ec); + if (connect_cb_) connect_cb_(ec); + do_read(); + } + + void do_read() + { + ws_.async_read(buffer_, beast::bind_front_handler(&PlainWebsocketSession::on_read, + std::static_pointer_cast( + shared_from_this()))); + } + + void on_read(beast::error_code ec, std::size_t bytes) + { + process_read_result(ec, bytes); + if (!ec) do_read(); + } + + void on_write(beast::error_code ec, std::size_t) + { + process_write_result(ec); + } + + void fail(beast::error_code ec) + { + if (connect_cb_) connect_cb_(ec); + } + }; + + // ========================================== + // SSL Session (wss://) + // ========================================== + class SslWebsocketSession : public WebsocketSessionImpl + { + websocket::stream> ws_; + tcp::resolver resolver_; + WebsocketClient::ConnectCallback connect_cb_; + + public: + SslWebsocketSession(net::io_context& ioc, ssl::context& ctx, WebsocketClient* owner) + : WebsocketSessionImpl(owner), ws_(net::make_strand(ioc), ctx), resolver_(ioc) + { + } + + net::any_io_executor get_executor() override { return ws_.get_executor(); } + + void run(const std::string& host, const std::string& port, const std::string& target, + const std::map& headers, WebsocketClient::ConnectCallback cb) override + { + host_ = host; + connect_cb_ = std::move(cb); + + if (!SSL_set_tlsext_host_name(ws_.next_layer().native_handle(), host.c_str())) + { + return fail(beast::error_code(static_cast(::ERR_get_error()), net::error::get_ssl_category())); + } + + resolver_.async_resolve(host, port, beast::bind_front_handler(&SslWebsocketSession::on_resolve, + std::static_pointer_cast( + shared_from_this()), target, headers)); + } + + void close() override + { + net::post(ws_.get_executor(), [self = std::static_pointer_cast(shared_from_this())]() + { + if (self->ws_.is_open()) + { + self->ws_.async_close(websocket::close_code::normal, [self](beast::error_code) + { + }); + } + }); + } + + protected: + void do_write_from_queue() override + { + ws_.async_write(net::buffer(write_queue_.front()), + beast::bind_front_handler(&SslWebsocketSession::on_write, + std::static_pointer_cast(shared_from_this()))); + } + + private: + void on_resolve(std::string target, std::map headers, beast::error_code ec, + tcp::resolver::results_type results) + { + if (ec) return fail(ec); + beast::get_lowest_layer(ws_).async_connect(results, beast::bind_front_handler( + &SslWebsocketSession::on_connect, + std::static_pointer_cast(shared_from_this()), + target, headers)); + } + + void on_connect(std::string target, std::map headers, beast::error_code ec, + tcp::resolver::results_type::endpoint_type) + { + if (ec) return fail(ec); + ws_.next_layer().async_handshake(ssl::stream_base::client, + beast::bind_front_handler(&SslWebsocketSession::on_ssl_handshake, + std::static_pointer_cast( + shared_from_this()), target, headers)); + } + + void on_ssl_handshake(std::string target, std::map headers, beast::error_code ec) + { + if (ec) return fail(ec); + + ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + ws_.set_option(websocket::stream_base::decorator([headers](websocket::request_type& req) + { + req.set(beast::http::field::user_agent, BOOST_BEAST_VERSION_STRING); + for (const auto& h : headers) req.set(h.first, h.second); + })); + + ws_.async_handshake(host_, target, + beast::bind_front_handler(&SslWebsocketSession::on_handshake, + std::static_pointer_cast(shared_from_this()))); + } + + void on_handshake(beast::error_code ec) + { + if (ec) return fail(ec); + if (connect_cb_) connect_cb_(ec); + do_read(); + } + + void do_read() + { + ws_.async_read(buffer_, beast::bind_front_handler(&SslWebsocketSession::on_read, + std::static_pointer_cast( + shared_from_this()))); + } + + void on_read(beast::error_code ec, std::size_t bytes) + { + process_read_result(ec, bytes); + if (!ec) do_read(); + } + + void on_write(beast::error_code ec, std::size_t) + { + process_write_result(ec); + } + + void fail(beast::error_code ec) + { + if (connect_cb_) connect_cb_(ec); + } + }; + + // ========================================== + // WebsocketClient Implementation + // ========================================== + + WebsocketClient::WebsocketClient() + : ioc_(IoContextPool::instance().get_io_context()) + { + own_ssl_ctx_ = std::make_shared(ssl::context::tls_client); + own_ssl_ctx_->set_default_verify_paths(); + own_ssl_ctx_->set_verify_mode(ssl::verify_none); + ssl_ctx_ptr_ = own_ssl_ctx_.get(); + } + + WebsocketClient::WebsocketClient(net::io_context& ioc) : ioc_(ioc) + { + // Default SSL Context + own_ssl_ctx_ = std::make_shared(ssl::context::tls_client); + own_ssl_ctx_->set_default_verify_paths(); + own_ssl_ctx_->set_verify_mode(ssl::verify_none); + ssl_ctx_ptr_ = own_ssl_ctx_.get(); + } + + WebsocketClient::WebsocketClient(net::io_context& ioc, ssl::context& ssl_ctx) + : ioc_(ioc), ssl_ctx_ptr_(&ssl_ctx) + { + } + + WebsocketClient::~WebsocketClient() + { + close(); + } + + void WebsocketClient::set_header(const std::string& key, const std::string& value) + { + headers_[key] = value; + } + + void WebsocketClient::connect(const std::string& url, ConnectCallback callback) + { + auto url_result = boost::urls::parse_uri(url); + if (!url_result.has_value()) + { + if (callback) callback(beast::error_code(beast::http::error::bad_target)); + return; + } + auto u = url_result.value(); + std::string host = u.host(); + std::string scheme = u.scheme(); + std::string port = u.port(); + std::string target = u.encoded_path().data(); + if (target.empty()) target = "/"; + + if (port.empty()) port = (scheme == "wss") ? "443" : "80"; + + if (scheme == "wss") + { + if (!ssl_ctx_ptr_) + { + if (callback) callback(beast::error_code(beast::errc::operation_not_supported, beast::system_category())); + return; + } + auto s = std::make_shared(ioc_, *ssl_ctx_ptr_, this); + session_ = s; + s->run(host, port, target, headers_, std::move(callback)); + } + else + { + auto s = std::make_shared(ioc_, this); + session_ = s; + s->run(host, port, target, headers_, std::move(callback)); + } + } + + void WebsocketClient::send(const std::string& message) + { + if (session_) + { + session_->queue_write(message); + } + } + + void WebsocketClient::close() + { + if (session_) + { + session_->close(); + // session_ = nullptr; // keep alive for handlers to finish + } + } + + void WebsocketClient::set_on_message(MessageHandler handler) { on_message_ = std::move(handler); } + void WebsocketClient::set_on_error(ErrorHandler handler) { on_error_ = std::move(handler); } + void WebsocketClient::set_on_close(CloseHandler handler) { on_close_ = std::move(handler); } +} diff --git a/framework/client/websocket_client.hpp b/framework/client/websocket_client.hpp new file mode 100644 index 0000000..15e275a --- /dev/null +++ b/framework/client/websocket_client.hpp @@ -0,0 +1,79 @@ +#ifndef KHTTPD_FRAMEWORK_CLIENT_WEBSOCKET_CLIENT_HPP +#define KHTTPD_FRAMEWORK_CLIENT_WEBSOCKET_CLIENT_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace khttpd::framework::client +{ + namespace beast = boost::beast; + namespace websocket = beast::websocket; + namespace net = boost::asio; + namespace ssl = boost::asio::ssl; + using tcp = boost::asio::ip::tcp; + + // 前置声明内部会话接口 + struct WebsocketSessionImpl; + + class WebsocketClient : public std::enable_shared_from_this + { + public: + using ConnectCallback = std::function; + using MessageHandler = std::function; + using ErrorHandler = std::function; + using CloseHandler = std::function; + + WebsocketClient(); + // 构造函数:支持默认 SSL 或 外部 SSL Context + explicit WebsocketClient(net::io_context& ioc); + WebsocketClient(net::io_context& ioc, ssl::context& ssl_ctx); + ~WebsocketClient(); + + // 连接 URL (支持 ws:// 和 wss://) + void connect(const std::string& url, ConnectCallback callback); + + // 发送消息 (线程安全,支持并发调用) + void send(const std::string& message); + + // 关闭连接 + void close(); + + // 配置 + void set_header(const std::string& key, const std::string& value); + void set_on_message(MessageHandler handler); + void set_on_error(ErrorHandler handler); + void set_on_close(CloseHandler handler); + + private: + friend WebsocketSessionImpl; + net::io_context& ioc_; + + // SSL Context Management + std::shared_ptr own_ssl_ctx_; + ssl::context* ssl_ctx_ptr_; + + // Callbacks + MessageHandler on_message_; + ErrorHandler on_error_; + CloseHandler on_close_; + + // Headers to send during handshake + std::map headers_; + + // 多态的内部会话 (持有实际的 websocket stream) + std::shared_ptr session_; + }; +} + +#endif // KHTTPD_FRAMEWORK_CLIENT_WEBSOCKET_CLIENT_HPP diff --git a/framework/cron/CronJob.hpp b/framework/cron/CronJob.hpp new file mode 100644 index 0000000..ba21ea2 --- /dev/null +++ b/framework/cron/CronJob.hpp @@ -0,0 +1,130 @@ +#ifndef KHTTPD_FRAMEWORK_CRON_JOB_HPP +#define KHTTPD_FRAMEWORK_CRON_JOB_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include "croncpp.hpp" +#include "io_context_pool.hpp" + +namespace khttpd::framework +{ + class CronJob : public std::enable_shared_from_this + { + public: + explicit CronJob(const std::string& expression) + : timer_(IoContextPool::instance().get_io_context()) + , expression_(expression) + , is_running_(false) + { + try + { + cron_expr_ = cron::make_cron(expression); + } + catch (const std::exception& e) + { + std::cerr << "[CronJob] Invalid expression '" << expression << "': " << e.what() << std::endl; + throw; + } + } + + virtual ~CronJob() + { + } + + /** + * @brief 启动任务 + * @param delay_ms 延迟启动时间(毫秒),默认为 0(立即计算下一次执行时间) + */ + void start(std::chrono::milliseconds delay_ms = std::chrono::milliseconds(0)) + { + bool expected = false; + if (is_running_.compare_exchange_strong(expected, true)) + { + if (delay_ms.count() > 0) + { + // 延迟启动逻辑 + timer_.expires_after(delay_ms); + auto self = shared_from_this(); + timer_.async_wait([this, self](const boost::system::error_code& ec) + { + if (!ec && is_running_) + { + schedule_next(); // 延迟结束后,开始正常的 cron 调度 + } + }); + } + else + { + // 立即启动 + schedule_next(); + } + } + } + + void stop() + { + is_running_ = false; + timer_.cancel(); + } + + // 判断当前是否在运行状态 + bool is_running() const { return is_running_; } + + protected: + virtual void run() = 0; + + private: + void schedule_next() + { + if (!is_running_) return; + + auto now_time_t = std::time(nullptr); + std::time_t next_time_t = cron::cron_next(cron_expr_, now_time_t); + auto next_time_point = std::chrono::system_clock::from_time_t(next_time_t); + + timer_.expires_at(next_time_point); + + auto self = shared_from_this(); + + timer_.async_wait([this, self](const boost::system::error_code& ec) + { + if (ec == boost::asio::error::operation_aborted) return; + if (!is_running_) return; + + if (ec) + { + std::cerr << "[CronJob] Timer error: " << ec.message() << std::endl; + return; + } + + try + { + this->run(); + } + catch (const std::exception& e) + { + std::cerr << "[CronJob] Task exception: " << e.what() << std::endl; + } + + if (is_running_) + { + schedule_next(); + } + }); + } + + private: + boost::asio::system_timer timer_; + std::string expression_; + cron::cronexpr cron_expr_; + std::atomic is_running_; + }; +} + +#endif diff --git a/framework/cron/CronScheduler.hpp b/framework/cron/CronScheduler.hpp new file mode 100644 index 0000000..69f420c --- /dev/null +++ b/framework/cron/CronScheduler.hpp @@ -0,0 +1,75 @@ +#ifndef KHTTPD_FRAMEWORK_CRON_SCHEDULER_HPP +#define KHTTPD_FRAMEWORK_CRON_SCHEDULER_HPP + +#include "CronJob.hpp" +#include +#include +#include +#include + +namespace khttpd::framework +{ + class CronScheduler + { + private: + // 内部通用实现类:LambdaCronJob + // 专门用于执行 std::function 的任务 + class LambdaCronJob : public CronJob + { + public: + LambdaCronJob(const std::string& expr, std::function func) + : CronJob(expr), func_(std::move(func)) + { + } + + protected: + void run() override + { + if (func_) func_(); + } + + private: + std::function func_; + }; + + public: + // 单例获取 + static CronScheduler& instance() + { + static CronScheduler instance; + return instance; + } + + // 禁止拷贝 + CronScheduler(const CronScheduler&) = delete; + CronScheduler& operator=(const CronScheduler&) = delete; + + /** + * @brief 调度一个 Cron 任务 + * + * @param expression Cron 表达式 (如 "* * * * * *") + * @param task 回调函数 + * @param delay_ms 首次启动延迟 (默认 0) + * @return std::shared_ptr 返回任务句柄。 + * 注意:即使忽略返回值,任务也会自动运行(因为 ASIO 内部持有了 shared_ptr)。 + * 保留返回值是为了让你有机会调用 .stop()。 + */ + std::shared_ptr schedule( + const std::string& expression, + std::function task, + std::chrono::milliseconds delay_ms = std::chrono::milliseconds(0)) + { + auto job = std::make_shared(expression, std::move(task)); + job->start(delay_ms); + + // 这里的 job 即使出了作用域,也会因为 CronJob 内部 async_wait 捕获了 shared_from_this 而存活。 + // 返回它只是为了让调用者有控制权。 + return job; + } + + private: + CronScheduler() = default; + }; +} + +#endif // KHTTPD_FRAMEWORK_CRON_SCHEDULER_HPP diff --git a/framework/cron/croncpp.hpp b/framework/cron/croncpp.hpp new file mode 100644 index 0000000..4a9e075 --- /dev/null +++ b/framework/cron/croncpp.hpp @@ -0,0 +1,937 @@ +// +// Created by Caesar on 2025/12/14. +// + +#ifndef KHTTPD_FRAMEWORK_CRONCPP_HPP +#define KHTTPD_FRAMEWORK_CRONCPP_HPP +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __cplusplus > 201402L +#include +#define CRONCPP_IS_CPP17 +#endif + +namespace cron +{ +#ifdef CRONCPP_IS_CPP17 +#define CRONCPP_STRING_VIEW std::string_view +#define CRONCPP_STRING_VIEW_NPOS std::string_view::npos +#define CRONCPP_CONSTEXPTR constexpr +#else +#define CRONCPP_STRING_VIEW std::string const & +#define CRONCPP_STRING_VIEW_NPOS std::string::npos +#define CRONCPP_CONSTEXPTR +#endif + + using cron_int = uint8_t; + + constexpr std::time_t INVALID_TIME = static_cast(-1); + + constexpr size_t INVALID_INDEX = static_cast(-1); + + class cronexpr; + + namespace detail + { + enum class cron_field + { + second, + minute, + hour_of_day, + day_of_week, + day_of_month, + month, + year + }; + + template + static bool find_next(cronexpr const& cex, + std::tm& date, + size_t const dot); + } + + struct bad_cronexpr : public std::runtime_error + { + public: + explicit bad_cronexpr(CRONCPP_STRING_VIEW message) : + std::runtime_error(message.data()) + { + } + }; + + + struct cron_standard_traits + { + static const cron_int CRON_MIN_SECONDS = 0; + static const cron_int CRON_MAX_SECONDS = 59; + + static const cron_int CRON_MIN_MINUTES = 0; + static const cron_int CRON_MAX_MINUTES = 59; + + static const cron_int CRON_MIN_HOURS = 0; + static const cron_int CRON_MAX_HOURS = 23; + + static const cron_int CRON_MIN_DAYS_OF_WEEK = 0; + static const cron_int CRON_MAX_DAYS_OF_WEEK = 6; + + static const cron_int CRON_MIN_DAYS_OF_MONTH = 1; + static const cron_int CRON_MAX_DAYS_OF_MONTH = 31; + + static const cron_int CRON_MIN_MONTHS = 1; + static const cron_int CRON_MAX_MONTHS = 12; + + static const cron_int CRON_MAX_YEARS_DIFF = 4; + +#ifdef CRONCPP_IS_CPP17 + static const inline std::vector DAYS = {"SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + static const inline std::vector MONTHS = { + "NIL", "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; +#else + static std::vector& DAYS() + { + static std::vector days = {"SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + return days; + } + + static std::vector& MONTHS() + { + static std::vector months = { + "NIL", "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; + return months; + } +#endif + }; + + struct cron_oracle_traits + { + static const cron_int CRON_MIN_SECONDS = 0; + static const cron_int CRON_MAX_SECONDS = 59; + + static const cron_int CRON_MIN_MINUTES = 0; + static const cron_int CRON_MAX_MINUTES = 59; + + static const cron_int CRON_MIN_HOURS = 0; + static const cron_int CRON_MAX_HOURS = 23; + + static const cron_int CRON_MIN_DAYS_OF_WEEK = 1; + static const cron_int CRON_MAX_DAYS_OF_WEEK = 7; + + static const cron_int CRON_MIN_DAYS_OF_MONTH = 1; + static const cron_int CRON_MAX_DAYS_OF_MONTH = 31; + + static const cron_int CRON_MIN_MONTHS = 0; + static const cron_int CRON_MAX_MONTHS = 11; + + static const cron_int CRON_MAX_YEARS_DIFF = 4; + +#ifdef CRONCPP_IS_CPP17 + static const inline std::vector DAYS = {"NIL", "SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + static const inline std::vector MONTHS = { + "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; +#else + + static std::vector& DAYS() + { + static std::vector days = {"NIL", "SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + return days; + } + + static std::vector& MONTHS() + { + static std::vector months = { + "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; + return months; + } +#endif + }; + + struct cron_quartz_traits + { + static const cron_int CRON_MIN_SECONDS = 0; + static const cron_int CRON_MAX_SECONDS = 59; + + static const cron_int CRON_MIN_MINUTES = 0; + static const cron_int CRON_MAX_MINUTES = 59; + + static const cron_int CRON_MIN_HOURS = 0; + static const cron_int CRON_MAX_HOURS = 23; + + static const cron_int CRON_MIN_DAYS_OF_WEEK = 1; + static const cron_int CRON_MAX_DAYS_OF_WEEK = 7; + + static const cron_int CRON_MIN_DAYS_OF_MONTH = 1; + static const cron_int CRON_MAX_DAYS_OF_MONTH = 31; + + static const cron_int CRON_MIN_MONTHS = 1; + static const cron_int CRON_MAX_MONTHS = 12; + + static const cron_int CRON_MAX_YEARS_DIFF = 4; + +#ifdef CRONCPP_IS_CPP17 + static const inline std::vector DAYS = {"NIL", "SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + static const inline std::vector MONTHS = { + "NIL", "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; +#else + static std::vector& DAYS() + { + static std::vector days = {"NIL", "SUN", "MON", "TUE", "WED", "THU", "FRI", "SAT"}; + return days; + } + + static std::vector& MONTHS() + { + static std::vector months = { + "NIL", "JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC" + }; + return months; + } +#endif + }; + + class cronexpr; + + template + static cronexpr make_cron(CRONCPP_STRING_VIEW expr); + + class cronexpr + { + std::bitset<60> seconds; + std::bitset<60> minutes; + std::bitset<24> hours; + std::bitset<7> days_of_week; + std::bitset<31> days_of_month; + std::bitset<12> months; + std::string expr; + + friend bool operator==(cronexpr const& e1, cronexpr const& e2); + friend bool operator!=(cronexpr const& e1, cronexpr const& e2); + + template + friend bool detail::find_next(cronexpr const& cex, + std::tm& date, + size_t const dot); + + friend std::string to_cronstr(cronexpr const& cex); + friend std::string to_string(cronexpr const& cex); + + template + friend cronexpr make_cron(CRONCPP_STRING_VIEW expr); + }; + + inline bool operator==(cronexpr const& e1, cronexpr const& e2) + { + return + e1.seconds == e2.seconds && + e1.minutes == e2.minutes && + e1.hours == e2.hours && + e1.days_of_week == e2.days_of_week && + e1.days_of_month == e2.days_of_month && + e1.months == e2.months; + } + + inline bool operator!=(cronexpr const& e1, cronexpr const& e2) + { + return !(e1 == e2); + } + + inline std::string to_string(cronexpr const& cex) + { + return + cex.seconds.to_string() + " " + + cex.minutes.to_string() + " " + + cex.hours.to_string() + " " + + cex.days_of_month.to_string() + " " + + cex.months.to_string() + " " + + cex.days_of_week.to_string(); + } + + inline std::string to_cronstr(cronexpr const& cex) + { + return cex.expr; + } + + namespace utils + { + inline std::time_t tm_to_time(std::tm& date) + { + return std::mktime(&date); + } + + inline std::tm* time_to_tm(std::time_t const* date, std::tm* const out) + { +#ifdef _WIN32 + errno_t err = localtime_s(out, date); + return 0 == err ? out : nullptr; +#else + return localtime_r(date, out); +#endif + } + + inline std::tm to_tm(CRONCPP_STRING_VIEW time) + { + std::tm result; +#if __cplusplus > 201103L + std::istringstream str(time.data()); + str.imbue(std::locale(setlocale(LC_ALL, nullptr))); + + str >> std::get_time(&result, "%Y-%m-%d %H:%M:%S"); + if (str.fail()) throw std::runtime_error("Parsing date failed!"); +#else + int year = 1900; + int month = 1; + int day = 1; + int hour = 0; + int minute = 0; + int second = 0; + sscanf(time.data(), "%d-%d-%d %d:%d:%d", &year, &month, &day, &hour, &minute, &second); + result.tm_year = year - 1900; + result.tm_mon = month - 1; + result.tm_mday = day; + result.tm_hour = hour; + result.tm_min = minute; + result.tm_sec = second; +#endif + result.tm_isdst = -1; // DST info not available + + return result; + } + + inline std::string to_string(std::tm const& tm) + { +#if __cplusplus > 201103L + std::ostringstream str; + str.imbue(std::locale(setlocale(LC_ALL, nullptr))); + str << std::put_time(&tm, "%Y-%m-%d %H:%M:%S"); + if (str.fail()) throw std::runtime_error("Writing date failed!"); + + return str.str(); +#else + char buff[70] = {0}; + strftime(buff, sizeof(buff), "%Y-%m-%d %H:%M:%S", &tm); + return std::string(buff); +#endif + } + + inline std::string to_upper(std::string text) + { + std::transform(std::begin(text), std::end(text), + std::begin(text), [](char const c) { return static_cast(std::toupper(c)); }); + + return text; + } + + static std::vector split(CRONCPP_STRING_VIEW text, char const delimiter) + { + std::vector tokens; + std::string token; + std::istringstream tokenStream(text.data()); + while (std::getline(tokenStream, token, delimiter)) + { + tokens.push_back(token); + } + return tokens; + } + + CRONCPP_CONSTEXPTR inline bool contains(CRONCPP_STRING_VIEW text, char const ch) noexcept + { + return CRONCPP_STRING_VIEW_NPOS != text.find_first_of(ch); + } + } + + namespace detail + { + inline cron_int to_cron_int(CRONCPP_STRING_VIEW text) + { + try + { + return static_cast(std::stoul(text.data())); + } + catch (std::exception const& ex) + { + throw bad_cronexpr(ex.what()); + } + } + + static std::string replace_ordinals( + std::string text, + std::vector const& replacement) + { + for (size_t i = 0; i < replacement.size(); ++i) + { + auto pos = text.find(replacement[i]); + if (std::string::npos != pos) + text.replace(pos, 3, std::to_string(i)); + } + + return text; + } + + static std::pair make_range( + CRONCPP_STRING_VIEW field, + cron_int const minval, + cron_int const maxval) + { + cron_int first = 0; + cron_int last = 0; + if (field.size() == 1 && field[0] == '*') + { + first = minval; + last = maxval; + } + else if (!utils::contains(field, '-')) + { + first = to_cron_int(field); + last = first; + } + else + { + auto parts = utils::split(field, '-'); + if (parts.size() != 2) + throw bad_cronexpr("Specified range requires two fields"); + + first = to_cron_int(parts[0]); + last = to_cron_int(parts[1]); + } + + if (first > maxval || last > maxval) + { + throw bad_cronexpr("Specified range exceeds maximum"); + } + if (first < minval || last < minval) + { + throw bad_cronexpr("Specified range is less than minimum"); + } + if (first > last) + { + throw bad_cronexpr("Specified range start exceeds range end"); + } + + return {first, last}; + } + + template + static void set_cron_field( + CRONCPP_STRING_VIEW value, + std::bitset& target, + cron_int const minval, + cron_int const maxval) + { + if (value.length() > 0 && value[value.length() - 1] == ',') + throw bad_cronexpr("Value cannot end with comma"); + + auto fields = utils::split(value, ','); + if (fields.empty()) + throw bad_cronexpr("Expression parsing error"); + + for (auto const& field : fields) + { + if (!utils::contains(field, '/')) + { +#ifdef CRONCPP_IS_CPP17 + auto [first, last] = detail::make_range(field, minval, maxval); +#else + auto range = detail::make_range(field, minval, maxval); + auto first = range.first; + auto last = range.second; +#endif + for (cron_int i = first - minval; i <= last - minval; ++i) + { + target.set(i); + } + } + else + { + auto parts = utils::split(field, '/'); + if (parts.size() != 2) + throw bad_cronexpr("Incrementer must have two fields"); + +#ifdef CRONCPP_IS_CPP17 + auto [first, last] = detail::make_range(parts[0], minval, maxval); +#else + auto range = detail::make_range(parts[0], minval, maxval); + auto first = range.first; + auto last = range.second; +#endif + + if (!utils::contains(parts[0], '-')) + { + last = maxval; + } + + auto delta = detail::to_cron_int(parts[1]); + if (delta <= 0) + throw bad_cronexpr("Incrementer must be a positive value"); + + for (cron_int i = first - minval; i <= last - minval; i += delta) + { + target.set(i); + } + } + } + } + + template + static void set_cron_days_of_week( + std::string value, + std::bitset<7>& target) + { + auto days = utils::to_upper(value); + auto days_replaced = detail::replace_ordinals( + days, +#ifdef CRONCPP_IS_CPP17 + Traits::DAYS +#else + Traits::DAYS() +#endif + ); + + if (days_replaced.size() == 1 && days_replaced[0] == '?') + days_replaced[0] = '*'; + + set_cron_field( + days_replaced, + target, + Traits::CRON_MIN_DAYS_OF_WEEK, + Traits::CRON_MAX_DAYS_OF_WEEK); + } + + template + static void set_cron_days_of_month( + std::string value, + std::bitset<31>& target) + { + if (value.size() == 1 && value[0] == '?') + value[0] = '*'; + + set_cron_field( + value, + target, + Traits::CRON_MIN_DAYS_OF_MONTH, + Traits::CRON_MAX_DAYS_OF_MONTH); + } + + template + static void set_cron_month( + std::string value, + std::bitset<12>& target) + { + auto month = utils::to_upper(value); + auto month_replaced = replace_ordinals( + month, +#ifdef CRONCPP_IS_CPP17 + Traits::MONTHS +#else + Traits::MONTHS() +#endif + ); + + set_cron_field( + month_replaced, + target, + Traits::CRON_MIN_MONTHS, + Traits::CRON_MAX_MONTHS); + } + + template + inline size_t next_set_bit( + std::bitset const& target, + size_t /*minimum*/, + size_t /*maximum*/, + size_t offset) + { + for (auto i = offset; i < N; ++i) + { + if (target.test(i)) return i; + } + + return INVALID_INDEX; + } + + inline void add_to_field( + std::tm& date, + cron_field const field, + int const val) + { + switch (field) + { + case cron_field::second: + date.tm_sec += val; + break; + case cron_field::minute: + date.tm_min += val; + break; + case cron_field::hour_of_day: + date.tm_hour += val; + break; + case cron_field::day_of_week: + case cron_field::day_of_month: + date.tm_mday += val; + date.tm_isdst = -1; + break; + case cron_field::month: + date.tm_mon += val; + date.tm_isdst = -1; + break; + case cron_field::year: + date.tm_year += val; + break; + } + + if (INVALID_TIME == utils::tm_to_time(date)) + throw bad_cronexpr("Invalid time expression"); + } + + inline void set_field( + std::tm& date, + cron_field const field, + int const val) + { + switch (field) + { + case cron_field::second: + date.tm_sec = val; + break; + case cron_field::minute: + date.tm_min = val; + break; + case cron_field::hour_of_day: + date.tm_hour = val; + break; + case cron_field::day_of_week: + date.tm_wday = val; + break; + case cron_field::day_of_month: + date.tm_mday = val; + date.tm_isdst = -1; + break; + case cron_field::month: + date.tm_mon = val; + date.tm_isdst = -1; + break; + case cron_field::year: + date.tm_year = val; + break; + } + + if (INVALID_TIME == utils::tm_to_time(date)) + throw bad_cronexpr("Invalid time expression"); + } + + inline void reset_field( + std::tm& date, + cron_field const field) + { + switch (field) + { + case cron_field::second: + date.tm_sec = 0; + break; + case cron_field::minute: + date.tm_min = 0; + break; + case cron_field::hour_of_day: + date.tm_hour = 0; + break; + case cron_field::day_of_week: + date.tm_wday = 0; + break; + case cron_field::day_of_month: + date.tm_mday = 1; + date.tm_isdst = -1; + break; + case cron_field::month: + date.tm_mon = 0; + date.tm_isdst = -1; + break; + case cron_field::year: + date.tm_year = 0; + break; + } + + if (INVALID_TIME == utils::tm_to_time(date)) + throw bad_cronexpr("Invalid time expression"); + } + + inline void reset_all_fields( + std::tm& date, + std::bitset<7> const& marked_fields) + { + for (size_t i = 0; i < marked_fields.size(); ++i) + { + if (marked_fields.test(i)) + reset_field(date, static_cast(i)); + } + } + + inline void mark_field( + std::bitset<7>& orders, + cron_field const field) + { + if (!orders.test(static_cast(field))) + orders.set(static_cast(field)); + } + + template + static size_t find_next( + std::bitset const& target, + std::tm& date, + unsigned int const minimum, + unsigned int const maximum, + unsigned int const value, + cron_field const field, + cron_field const next_field, + std::bitset<7> const& marked_fields) + { + auto next_value = next_set_bit(target, minimum, maximum, value); + if (INVALID_INDEX == next_value) + { + add_to_field(date, next_field, 1); + reset_field(date, field); + next_value = next_set_bit(target, minimum, maximum, 0); + } + + if (INVALID_INDEX == next_value || next_value != value) + { + set_field(date, field, static_cast(next_value)); + reset_all_fields(date, marked_fields); + } + + return next_value; + } + + template + static size_t find_next_day( + std::tm& date, + std::bitset<31> const& days_of_month, + size_t day_of_month, + std::bitset<7> const& days_of_week, + size_t day_of_week, + std::bitset<7> const& marked_fields) + { + unsigned int count = 0; + unsigned int maximum = 366; + while ( + (!days_of_month.test(day_of_month - Traits::CRON_MIN_DAYS_OF_MONTH) || + !days_of_week.test(day_of_week - Traits::CRON_MIN_DAYS_OF_WEEK)) + && count++ < maximum) + { + add_to_field(date, cron_field::day_of_month, 1); + + day_of_month = date.tm_mday; + day_of_week = date.tm_wday; + + reset_all_fields(date, marked_fields); + } + + return day_of_month; + } + + template + static bool find_next(cronexpr const& cex, + std::tm& date, + size_t const dot) + { + bool res = true; + + std::bitset<7> marked_fields{0}; + std::bitset<7> empty_list{0}; + + unsigned int second = date.tm_sec; + auto updated_second = find_next( + cex.seconds, + date, + Traits::CRON_MIN_SECONDS, + Traits::CRON_MAX_SECONDS, + second, + cron_field::second, + cron_field::minute, + empty_list); + + if (second == updated_second) + { + mark_field(marked_fields, cron_field::second); + } + + unsigned int minute = date.tm_min; + auto update_minute = find_next( + cex.minutes, + date, + Traits::CRON_MIN_MINUTES, + Traits::CRON_MAX_MINUTES, + minute, + cron_field::minute, + cron_field::hour_of_day, + marked_fields); + if (minute == update_minute) + { + mark_field(marked_fields, cron_field::minute); + } + else + { + res = find_next(cex, date, dot); + if (!res) return res; + } + + unsigned int hour = date.tm_hour; + auto updated_hour = find_next( + cex.hours, + date, + Traits::CRON_MIN_HOURS, + Traits::CRON_MAX_HOURS, + hour, + cron_field::hour_of_day, + cron_field::day_of_week, + marked_fields); + if (hour == updated_hour) + { + mark_field(marked_fields, cron_field::hour_of_day); + } + else + { + res = find_next(cex, date, dot); + if (!res) return res; + } + + unsigned int day_of_week = date.tm_wday; + unsigned int day_of_month = date.tm_mday; + auto updated_day_of_month = find_next_day( + date, + cex.days_of_month, + day_of_month, + cex.days_of_week, + day_of_week, + marked_fields); + if (day_of_month == updated_day_of_month) + { + mark_field(marked_fields, cron_field::day_of_month); + } + else + { + res = find_next(cex, date, dot); + if (!res) return res; + } + + unsigned int month = date.tm_mon; + auto updated_month = find_next( + cex.months, + date, + Traits::CRON_MIN_MONTHS, + Traits::CRON_MAX_MONTHS, + month, + cron_field::month, + cron_field::year, + marked_fields); + if (month != updated_month) + { + if (date.tm_year - dot > Traits::CRON_MAX_YEARS_DIFF) + return false; + + res = find_next(cex, date, dot); + if (!res) return res; + } + + return res; + } + } + + template + static cronexpr make_cron(CRONCPP_STRING_VIEW expr) + { + cronexpr cex; + + if (expr.empty()) + throw bad_cronexpr("Invalid empty cron expression"); + + auto fields = utils::split(expr, ' '); + fields.erase( + std::remove_if(std::begin(fields), std::end(fields), + [](CRONCPP_STRING_VIEW s) { return s.empty(); }), + std::end(fields)); + if (fields.size() != 6) + throw bad_cronexpr("cron expression must have six fields"); + + detail::set_cron_field(fields[0], cex.seconds, Traits::CRON_MIN_SECONDS, Traits::CRON_MAX_SECONDS); + detail::set_cron_field(fields[1], cex.minutes, Traits::CRON_MIN_MINUTES, Traits::CRON_MAX_MINUTES); + detail::set_cron_field(fields[2], cex.hours, Traits::CRON_MIN_HOURS, Traits::CRON_MAX_HOURS); + + detail::set_cron_days_of_week(fields[5], cex.days_of_week); + + detail::set_cron_days_of_month(fields[3], cex.days_of_month); + + detail::set_cron_month(fields[4], cex.months); + + cex.expr = expr; + + return cex; + } + + template + static std::tm cron_next(cronexpr const& cex, std::tm date) + { + time_t original = utils::tm_to_time(date); + if (INVALID_TIME == original) return {}; + + if (!detail::find_next(cex, date, date.tm_year)) + return {}; + + time_t calculated = utils::tm_to_time(date); + if (INVALID_TIME == calculated) return {}; + + if (calculated == original) + { + add_to_field(date, detail::cron_field::second, 1); + if (!detail::find_next(cex, date, date.tm_year)) + return {}; + } + + return date; + } + + template + static std::time_t cron_next(cronexpr const& cex, std::time_t const& date) + { + std::tm val; + std::tm* dt = utils::time_to_tm(&date, &val); + if (dt == nullptr) return INVALID_TIME; + + time_t original = utils::tm_to_time(*dt); + if (INVALID_TIME == original) return INVALID_TIME; + + if (!detail::find_next(cex, *dt, dt->tm_year)) + return INVALID_TIME; + + time_t calculated = utils::tm_to_time(*dt); + if (INVALID_TIME == calculated) return calculated; + + if (calculated == original) + { + add_to_field(*dt, detail::cron_field::second, 1); + if (!detail::find_next(cex, *dt, dt->tm_year)) + return INVALID_TIME; + } + + return utils::tm_to_time(*dt); + } + + template + static std::chrono::system_clock::time_point cron_next(cronexpr const& cex, + std::chrono::system_clock::time_point const& time_point) + { + return std::chrono::system_clock::from_time_t( + cron_next(cex, std::chrono::system_clock::to_time_t(time_point))); + } +} +#endif //KHTTPD_FRAMEWORK_CRONCPP_HPP diff --git a/framework/dto/TagInvoke.hpp b/framework/dto/TagInvoke.hpp new file mode 100644 index 0000000..201a32e --- /dev/null +++ b/framework/dto/TagInvoke.hpp @@ -0,0 +1,70 @@ +// +// Created by Caesar on 2025/12/7. +// + +#ifndef BOOST_TAGINVOKE_HPP +#define BOOST_TAGINVOKE_HPP + +#include +#include +#include +#include + +namespace boost::json +{ + namespace desc = boost::describe; + namespace mp11 = boost::mp11; + + // ================================================================= + // 1. 通用序列化 (Struct -> JSON) + // C++17 改进: 使用 std::enable_if_t 简化语法 + // ================================================================= + template + auto tag_invoke(value_from_tag, value& jv, T const& t) + -> std::enable_if_t::value> + { + auto& obj = jv.emplace_object(); + + using Md = desc::describe_members; + + mp11::mp_for_each([&](auto D) + { + // 使用 emplace 稍微高效一点,直接构造 value + obj.emplace(D.name, value_from(t.*D.pointer)); + }); + } + + // ================================================================= + // 2. 通用反序列化 (JSON -> Struct) + // C++17 改进: 使用 std::enable_if_t 和 std::remove_reference_t + // ================================================================= + template + auto tag_invoke(value_to_tag, value const& jv) + -> std::enable_if_t::value, T> + { + // 只有 T 是 DefaultConstructible 时才能这样写 + T t{}; + + // 如果 JSON 不是 object,as_object() 会抛出异常,这是符合预期的行为 + auto const& obj = jv.as_object(); + + using Md = desc::describe_members; + + mp11::mp_for_each([&](auto D) + { + // 查找 key 是否存在 + if (auto it = obj.find(D.name); it != obj.end()) + { + // C++17: 使用 remove_reference_t 简化类型获取 + using MemberT = std::remove_reference_t; + + // 将 json value 转换为具体的成员类型并赋值 + t.*D.pointer = value_to(it->value()); + } + }); + + return t; + } +} // namespace boost::json + +#endif //BOOST_TAGINVOKE_HPP diff --git a/framework/io_context_pool.hpp b/framework/io_context_pool.hpp new file mode 100644 index 0000000..3533e47 --- /dev/null +++ b/framework/io_context_pool.hpp @@ -0,0 +1,90 @@ +#ifndef KHTTPD_FRAMEWORK_CLIENT_IO_CONTEXT_POOL_HPP +#define KHTTPD_FRAMEWORK_CLIENT_IO_CONTEXT_POOL_HPP + +#include +#include +#include +#include +#include +#include + +namespace khttpd::framework +{ + class IoContextPool + { + public: + // 获取单例实例 + static IoContextPool& instance(unsigned int num_threads = 0) + { + static IoContextPool instance{num_threads}; + return instance; + } + + // 获取共享的 io_context + boost::asio::io_context& get_io_context() + { + return ioc_; + } + + // 获取当前运行的线程数量 + size_t get_thread_count() const + { + return threads_.size(); + } + + ~IoContextPool() + { + stop(); + } + + void stop() + { + // 确保只停止一次,防止析构和显式调用 stop 冲突 + std::call_once(stop_flag_, [this]() + { + work_guard_.reset(); // 允许 run() 退出 + ioc_.stop(); // 显式发出停止信号 + + // 等待所有线程结束 + for (auto& t : threads_) + { + if (t.joinable()) + { + t.join(); + } + } + threads_.clear(); + }); + } + + private: + explicit IoContextPool(unsigned int count = std::thread::hardware_concurrency()) + : work_guard_(boost::asio::make_work_guard(ioc_)) + { + // 如果检测失败(返回0)或者核心数少于1,保底使用 1 个线程 + // 如果为了提高并发吞吐量,也可以设为 count * 2 + if (count <= 0) count = 1; + + threads_.reserve(count * 2); + + // 2. 启动线程池 + for (unsigned int i = 0; i < count; ++i) + { + threads_.emplace_back([this]() + { + // 每个线程都运行同一个 io_context + // ASIO 会自动调度 handler 到空闲线程 + ioc_.run(); + }); + } + } + + + boost::asio::io_context ioc_; + boost::asio::executor_work_guard work_guard_; + std::vector threads_; + std::once_flag stop_flag_; + }; +} + +#endif // KHTTPD_FRAMEWORK_CLIENT_IO_CONTEXT_POOL_HPP diff --git a/framework/server.cpp b/framework/server.cpp index 7bb1a7d..f457b9d 100644 --- a/framework/server.cpp +++ b/framework/server.cpp @@ -5,14 +5,14 @@ #include #include +#include "io_context_pool.hpp" + namespace khttpd::framework { Server::Server(const tcp::endpoint& endpoint, std::string web_root, int num_threads) - : ioc_(std::in_place, num_threads), - num_threads_(num_threads), - signals_(*ioc_, SIGINT, SIGTERM), + : signals_(IoContextPool::instance(num_threads).get_io_context(), SIGINT, SIGTERM), web_root_(std::move(web_root)), - acceptor_(net::make_strand(*ioc_)) + acceptor_(net::make_strand(IoContextPool::instance().get_io_context())) { boost::beast::error_code ec; @@ -91,21 +91,8 @@ namespace khttpd::framework do_accept(); - threads_.reserve(num_threads_ - 1); - for (int i = 0; i < num_threads_ - 1; ++i) - { - threads_.emplace_back([&ioc = *ioc_] - { - ioc.run(); - }); - } - - (*ioc_).run(); + IoContextPool::instance().get_io_context().run(); - for (auto& t : threads_) - { - t.join(); - } fmt::print("Server workers stopped.\n"); } @@ -118,17 +105,14 @@ namespace khttpd::framework fmt::print(stderr, "Server acceptor close error: {}\n", ec.message()); } - if (ioc_.has_value()) - { - (*ioc_).stop(); - } + IoContextPool::instance().stop(); fmt::print("Server stopped.\n"); } void Server::do_accept() { acceptor_.async_accept( - net::make_strand(*ioc_), + net::make_strand(IoContextPool::instance().get_io_context()), beast::bind_front_handler(&Server::on_accept, shared_from_this())); } diff --git a/framework/server.hpp b/framework/server.hpp index 82ccf7a..d61f7f0 100644 --- a/framework/server.hpp +++ b/framework/server.hpp @@ -38,8 +38,8 @@ namespace khttpd::framework void stop(); private: - std::optional ioc_; - int num_threads_; + // std::optional ioc_; + // int num_threads_; std::vector threads_; net::signal_set signals_; const std::string web_root_; diff --git a/framework/tests/BUILD.bazel b/framework/tests/BUILD.bazel index 1e29a3f..5b84d0d 100644 --- a/framework/tests/BUILD.bazel +++ b/framework/tests/BUILD.bazel @@ -57,3 +57,33 @@ cc_test( "@googletest//:gtest_main", ], ) + +cc_test( + name = "client_test", + srcs = ["client_test.cpp"], + copts = [ + "-std=c++17", + "-Wall", + "-pedantic", + ], + deps = [ + "//framework", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + +cc_test( + name = "cronjob_test", + srcs = ["cronjob_test.cpp"], + copts = [ + "-std=c++17", + "-Wall", + "-pedantic", + ], + deps = [ + "//framework", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) diff --git a/framework/tests/client_test.cpp b/framework/tests/client_test.cpp new file mode 100644 index 0000000..3ca0b97 --- /dev/null +++ b/framework/tests/client_test.cpp @@ -0,0 +1,417 @@ +#include "framework/client/http_client.hpp" +#include "framework/client/websocket_client.hpp" +#include +#include +#include +#include +#include + +#include "io_context_pool.hpp" + +using namespace khttpd::framework::client; +namespace http = boost::beast::http; + +// ========================================== +// 1. 定义 PostmanEchoClient 类 +// ========================================== +class PostmanEchoClient : public HttpClient +{ +public: + // 构造函数:注入 ioc,并设置默认 Base URL + PostmanEchoClient() + { + set_base_url("https://postman-echo.com"); + // 设置一个较长的超时时间,防止 CI 环境网络慢 + set_timeout(std::chrono::seconds(10)); + } + + // ------------------------------------------------------------------ + // API 定义 + // ------------------------------------------------------------------ + + // 1. GET 请求,带查询参数 + // Endpoint: /get?foo=bar + API_CALL(http::verb::get, "/get", echo_get, + QUERY(std::string, foo_val, "foo"), + QUERY(int, id_val, "id")) + + // 2. POST 请求,带 JSON Body + // Endpoint: /post + API_CALL(http::verb::post, "/post", echo_post, + BODY(boost::json::object, json_body)) + + // 3. GET 请求,测试 Header 传递 + // Endpoint: /headers + // 我们定义一个名为 request_id 的参数,它会被映射为 HTTP Header "X-Request-Id" + API_CALL(http::verb::get, "/headers", echo_headers, + HEADER(std::string, request_id, "X-My-Request-Id"), + HEADER(std::string, user_token, "X-User-Token")) + + // 4. PUT 请求,带路径参数 + // Endpoint: /put (Postman echo 实际上忽略路径后的东西,但我们可以测试 URL 拼接) + API_CALL(http::verb::put, "/put", echo_put_dummy) +}; + +// ========================================== +// 2. 测试用例 +// ========================================== + +class ClientTest : public ::testing::Test +{ +protected: + boost::asio::io_context ioc; + std::shared_ptr client; + + // 辅助:用于在主线程等待异步结果 + void run_until_complete() + { + ioc.run(); + ioc.restart(); // 重置以便下次使用 + } + + void SetUp() override + { + client = std::make_shared(); + } +}; + + +// 辅助宏:等待异步结果 +// 如果 5 秒没结果,这就认为超时失败 +#define WAIT_FOR_ASYNC(future) \ + ASSERT_EQ(future.wait_for(std::chrono::seconds(5)), std::future_status::ready) << "Async operation timed out"; + +TEST_F(ClientTest, GetWithQueryParams) +{ + // 创建一个 promise 用于通知主线程任务完成 + std::promise promise; + auto future = promise.get_future(); + + client->echo_get("hello", 123, [&](auto ec, auto res) + { + // 这里的代码在后台线程运行 + if (!ec) + { + EXPECT_EQ(res.result(), http::status::ok); + std::string body = res.body(); + EXPECT_TRUE(body.find("\"foo\":\"hello\"") != std::string::npos); + EXPECT_TRUE(body.find("\"id\":\"123\"") != std::string::npos); + } + else + { + ADD_FAILURE() << "Network error: " << ec.message(); + } + + // 通知主线程:我做完了 + promise.set_value(); + }); + + // 主线程在此阻塞等待,直到 callback 执行完毕 + WAIT_FOR_ASYNC(future); +} + +TEST_F(ClientTest, PostJsonBody) +{ + std::promise promise; + auto future = promise.get_future(); + + boost::json::object jv; + jv["message"] = "test_payload"; + jv["count"] = 99; + + client->echo_post(jv, [&](auto ec, auto res) + { + if (!ec) + { + EXPECT_EQ(res.result(), http::status::ok); + std::string body = res.body(); + EXPECT_TRUE(body.find("test_payload") != std::string::npos); + } + else + { + ADD_FAILURE() << "Network error: " << ec.message(); + } + promise.set_value(); + }); + + WAIT_FOR_ASYNC(future); +} + +TEST_F(ClientTest, CustomHeaders) +{ + std::promise promise; + auto future = promise.get_future(); + + std::string rid = "req-unique-id-001"; + std::string token = "secret-token-abc"; + + client->echo_headers(rid, token, [&](auto ec, auto res) + { + if (!ec) + { + EXPECT_EQ(res.result(), http::status::ok); + std::string body = res.body(); + + bool has_rid = body.find("x-my-request-id") != std::string::npos || + body.find("X-My-Request-Id") != std::string::npos; + bool has_val = body.find(rid) != std::string::npos; + + EXPECT_TRUE(has_rid) << "Missing Header Key"; + EXPECT_TRUE(has_val) << "Missing Header Value"; + } + else + { + ADD_FAILURE() << "Network error: " << ec.message(); + } + promise.set_value(); + }); + + WAIT_FOR_ASYNC(future); +} + +TEST_F(ClientTest, GlobalDefaultHeader) +{ + std::promise promise; + auto future = promise.get_future(); + + client->set_default_header("X-App-Version", "v1.0.0-beta"); + + client->echo_headers("id-1", "token-1", [&](auto ec, auto res) + { + if (!ec) + { + std::string body = res.body(); + EXPECT_TRUE(body.find("v1.0.0-beta") != std::string::npos); + } + promise.set_value(); + }); + + WAIT_FOR_ASYNC(future); +} + +// 同步调用测试 (现在非常安全,不会死锁) +TEST_F(ClientTest, SyncCallSafe) +{ + try + { + // 主线程调用,后台线程执行,future wait 自动处理 + auto res = client->echo_get_sync("sync_world", 999); + + EXPECT_EQ(res.result(), http::status::ok); + std::string body = res.body(); + EXPECT_TRUE(body.find("sync_world") != std::string::npos); + } + catch (const std::exception& e) + { + ADD_FAILURE() << "Sync request exception: " << e.what(); + } +} + +TEST_F(ClientTest, SyncCall) +{ + // 重要:同步调用会阻塞当前线程等待 future, + // 所以 io_context 必须在另一个线程跑,否则死锁。 + auto work = boost::asio::make_work_guard(ioc); + std::thread ioc_thread([&] + { + ioc.run(); + }); + + try + { + // 使用同步生成的 API + auto res = client->echo_get_sync("sync_world", 999); + + EXPECT_EQ(res.result(), http::status::ok); + std::string body = res.body(); + EXPECT_TRUE(body.find("sync_world") != std::string::npos); + } + catch (const std::exception& e) + { + ADD_FAILURE() << "Sync request exception: " << e.what(); + } + + // 清理 + work.reset(); + ioc.stop(); + if (ioc_thread.joinable()) ioc_thread.join(); +} + +TEST(EasyModeTest, SyncRequestWithoutManualContext) +{ + // 不需要手动创建 ioc, work_guard, thread + auto client = std::make_shared(); // 使用默认构造 + + try + { + // 直接调用同步接口 + auto res = client->echo_get_sync("easy_mode", 1); + EXPECT_EQ(res.result(), http::status::ok); + EXPECT_TRUE(res.body().find("easy_mode") != std::string::npos); + } + catch (const std::exception& e) + { + ADD_FAILURE() << "Exception: " << e.what(); + } +} + +TEST(EasyModeTest, AsyncRequest) +{ + auto client = std::make_shared(); + + std::promise done; + auto future = done.get_future(); + + client->echo_get("async_easy", 2, [&](auto ec, auto res) + { + EXPECT_FALSE(ec); + done.set_value(); + }); + + // 等待异步结果 + // 因为 ioc 在后台线程跑,这里我们需要 wait + future.wait(); +} + + +// ========================================== +// WebSocket 测试 +// ========================================== + +class WebsocketTest : public ::testing::Test +{ +protected: + boost::asio::io_context ioc; + std::shared_ptr ws_client; + + void SetUp() override + { + ws_client = std::make_shared(ioc); + } + + void TearDown() override + { + if (ws_client) ws_client->close(); + } +}; + +TEST_F(WebsocketTest, WssEchoAndWriteQueue) +{ + std::string url = "wss://echo.websocket.org"; + + const int message_count = 5; + int received_count = 0; + bool closed_gracefully = false; + + // 增加一个 flag 标记是否发生严重错误 + bool has_error = false; + + // 创建定时器,但先不 async_wait,后面逻辑控制 + boost::asio::steady_timer timer(ioc, std::chrono::seconds(15)); + + ws_client->set_on_message([&](const std::string& msg) + { + // 过滤欢迎消息 + if (msg.find("Request served by") != std::string::npos) return; + + received_count++; + // std::cout << "Msg: " << msg << std::endl; + + if (received_count >= message_count) + { + ws_client->close(); + } + }); + + ws_client->set_on_close([&]() + { + closed_gracefully = true; + // 关键:连接关闭后,取消定时器,ioc.run() 就会立即返回 + timer.cancel(); + }); + + ws_client->set_on_error([&](boost::beast::error_code ec) + { + // 忽略操作取消(通常是 close() 导致的 pending read 取消) + if (ec == boost::asio::error::operation_aborted) return; + + std::cerr << "WS Error: " << ec.message() << std::endl; + has_error = true; + timer.cancel(); // 发生错误也停止测试 + }); + + ws_client->connect(url, [&](boost::beast::error_code ec) + { + if (ec) + { + ADD_FAILURE() << "WS Connect Failed: " << ec.message(); + timer.cancel(); + return; + } + + for (int i = 0; i < message_count; ++i) + { + ws_client->send("Msg-" + std::to_string(i)); + } + }); + + // 启动超时计时 + timer.async_wait([&](boost::system::error_code ec) + { + if (ec == boost::asio::error::operation_aborted) + { + // 定时器被取消,说明测试正常结束或提前出错 + return; + } + // 定时器真的触发了 -> 超时 + ws_client->close(); + ADD_FAILURE() << "Test Timed Out! Received: " << received_count << "/" << message_count; + }); + + ioc.run(); + + EXPECT_FALSE(has_error) << "Should not encounter network errors"; + EXPECT_EQ(received_count, message_count); + EXPECT_TRUE(closed_gracefully) << "on_close should be triggered"; +} + +TEST_F(WebsocketTest, ConnectFailure) +{ + // 测试连接不可达端口 + bool failed = false; + ws_client->connect("ws://localhost:59999", [&](boost::beast::error_code ec) + { + if (ec) + { + failed = true; + } + }); + + ioc.run(); + EXPECT_TRUE(failed); +} + +TEST_F(ClientTest, ThreadPoolVerify) +{ + std::cout << "Pool Size: " << khttpd::framework::IoContextPool::instance().get_thread_count() << std::endl; + + std::promise p1, p2; + auto f1 = p1.get_future(); + auto f2 = p2.get_future(); + + // 发起两个请求 + client->echo_get("A", 1, [&](auto, auto) + { + std::cout << "Req 1 processed on thread: " << std::this_thread::get_id() << std::endl; + p1.set_value(); + }); + + client->echo_get("B", 2, [&](auto, auto) + { + std::cout << "Req 2 processed on thread: " << std::this_thread::get_id() << std::endl; + p2.set_value(); + }); + + WAIT_FOR_ASYNC(f1); + WAIT_FOR_ASYNC(f2); +} diff --git a/framework/tests/cronjob_test.cpp b/framework/tests/cronjob_test.cpp new file mode 100644 index 0000000..120db10 --- /dev/null +++ b/framework/tests/cronjob_test.cpp @@ -0,0 +1,331 @@ +#include +#include +#include +#include +#include +#include + +// 包含你之前的头文件 +#include "cron/CronJob.hpp" +#include "io_context_pool.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// 引入你的头文件 +#include "cron/CronScheduler.hpp" +#include "io_context_pool.hpp" + +using namespace std::chrono_literals; +using namespace khttpd::framework; + +// --- 测试用的辅助类 --- +class TestableCronJob : public CronJob +{ +public: + TestableCronJob(const std::string& expr) + : CronJob(expr), run_count_(0) + { + } + + // 实现 run 方法 + void run() override + { + // 1. 增加计数 + run_count_++; + + // 2. 通知测试线程 + { + std::lock_guard lock(mutex_); + // 只需要通知,具体逻辑由测试线程判断 + } + cv_.notify_one(); + } + + // 辅助方法:等待任务执行 n 次 + // 返回 true 表示在超时前完成了任务,false 表示超时 + bool wait_for_runs(int expected_count, std::chrono::milliseconds timeout) + { + std::unique_lock lock(mutex_); + return cv_.wait_for(lock, timeout, [this, expected_count]() + { + return run_count_ >= expected_count; + }); + } + + int get_run_count() const + { + return run_count_; + } + +private: + std::atomic run_count_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +// --- 测试套件 --- + +class CronJobTest : public ::testing::Test +{ +protected: + static void SetUpTestSuite() + { + // 确保 IoContextPool 至少有一个线程在运行 + // 注意:单例模式下,这个池会在所有测试间共享 + IoContextPool::instance(1); + } + + static void TearDownTestSuite() + { + // 测试结束后停止池(可选,视具体需求而定) + // IoContextPool::instance().stop(); + } +}; + +// 测试 1: 验证无效的 Cron 表达式会抛出异常 +TEST_F(CronJobTest, ThrowsOnInvalidExpression) +{ + // 这是一个错误的表达式 (只有 5 个字段,或者是乱码) + std::string invalid_expr = "invalid cron string"; + + EXPECT_THROW({ + auto job = std::make_shared(invalid_expr); + }, std::runtime_error); // 这里的异常类型取决于 croncpp 具体抛出什么,通常是 std::runtime_error 或 croncpp::cron_exception +} + +// 测试 2: 验证任务是否能被调度和执行 +TEST_F(CronJobTest, RunsScheduleCorrectly) +{ + // 设置为每秒执行一次 ("* * * * * *") + // 注意:croncpp 能够处理秒级 + auto job = std::make_shared("* * * * * *"); + + job->start(); + + // 等待任务至少执行 1 次 + // 给它 2.5 秒的时间(理论上应该在第 1 秒或第 2 秒触发) + bool executed = job->wait_for_runs(1, std::chrono::milliseconds(2500)); + + EXPECT_TRUE(executed) << "Job did not run within timeout"; + EXPECT_GE(job->get_run_count(), 1); + + job->stop(); +} + +// 测试 3: 验证 Stop 后不再执行 +TEST_F(CronJobTest, StopPreventsFurtherExecution) +{ + auto job = std::make_shared("* * * * * *"); + job->start(); + + // 等待第 1 次 + ASSERT_TRUE(job->wait_for_runs(1, std::chrono::seconds(2))); + + // 停止 + job->stop(); + + // 获取当前快照 + int count_after_stop = job->get_run_count(); + + // 再等一会儿,看会不会偷偷跑 + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // 现在这里应该能通过了 + // 即使在 stop 瞬间正好有一次执行完成,检查3也会阻止下一次调度 + EXPECT_EQ(job->get_run_count(), count_after_stop); +} + +// 测试 4: 多个任务并发 +TEST_F(CronJobTest, MultipleJobs) +{ + auto job1 = std::make_shared("* * * * * *"); + auto job2 = std::make_shared("* * * * * *"); + + job1->start(); + job2->start(); + + // 等待两个任务都至少运行一次 + EXPECT_TRUE(job1->wait_for_runs(1, std::chrono::seconds(2))); + EXPECT_TRUE(job2->wait_for_runs(1, std::chrono::seconds(2))); + + job1->stop(); + job2->stop(); +} + + +// --- 辅助类:用于线程安全地计数和等待 --- +class AsyncCounter +{ +public: + void tick() + { + run_count_++; + cv_.notify_all(); + } + + int get_count() const + { + return run_count_; + } + + // 等待至少达到 expected_count 次执行 + // 返回 true 表示成功,false 表示超时 + bool wait_for_at_least(int expected_count, std::chrono::milliseconds timeout) + { + std::unique_lock lock(mtx_); + return cv_.wait_for(lock, timeout, [this, expected_count]() + { + return run_count_ >= expected_count; + }); + } + + // 等待指定的时间,确认在此期间计数器是否变化(用于验证 Stop 和 Delay) + // 如果计数器在 timeout 内没有增加,返回 true + bool ensure_no_execution_for(std::chrono::milliseconds duration) + { + int initial = run_count_; + std::unique_lock lock(mtx_); + // wait_for 返回 false 表示超时(即条件一直不满足),这意味着没有达到 initial + 1 + // 所以如果 wait_for 返回 false,说明没有执行,是我们想要的结果 + bool triggered = cv_.wait_for(lock, duration, [this, initial]() + { + return run_count_ > initial; + }); + return !triggered; + } + +private: + std::atomic run_count_{0}; + std::mutex mtx_; + std::condition_variable cv_; +}; + +// --- 测试套件 --- +class CronSchedulerTest : public ::testing::Test +{ +protected: + static void SetUpTestSuite() + { + // 初始化线程池,使用 2 个线程以支持并发测试 + IoContextPool::instance(2); + } + + static void TearDownTestSuite() + { + IoContextPool::instance().stop(); + } + + void SetUp() override + { + // 每个测试开始前重置计数器等(如果需要) + } +}; + +// 测试 1: 验证通过 Scheduler 调度的基础 Lambda 任务能正常运行 +TEST_F(CronSchedulerTest, ScheduleBasic) +{ + auto counter = std::make_shared(); + + // 每秒执行一次 + // 注意:持有返回的 job 指针,否则测试函数结束时如果 pool 还在跑,任务也会跑 + auto job = CronScheduler::instance().schedule("* * * * * *", [counter]() + { + counter->tick(); + }); + + // 等待至少执行 1 次,超时时间 2.5 秒 + ASSERT_TRUE(counter->wait_for_at_least(1, 2500ms)) << "Job failed to run within timeout"; + + // 停止任务 + job->stop(); +} + +// 测试 2: 验证手动停止 (Stop) 功能,并测试之前的竞态条件修复 +TEST_F(CronSchedulerTest, ScheduleStop) +{ + auto counter = std::make_shared(); + + // 极高频任务(每秒) + auto job = CronScheduler::instance().schedule("* * * * * *", [counter]() + { + counter->tick(); + }); + + // 1. 确保它跑起来了 + ASSERT_TRUE(counter->wait_for_at_least(1, 2000ms)); + + // 2. 停止任务 + job->stop(); + + // 3. 记录停止后的次数 + int count_after_stop = counter->get_count(); + + // 4. 等待一段时间,确保它真的停了 + // 之前修复了 atomic 标志位,这里应该非常稳定 + std::this_thread::sleep_for(2000ms); + + EXPECT_EQ(counter->get_count(), count_after_stop) + << "Job continued running after stop() was called"; +} + +// 测试 3: 验证延迟启动 (Delayed Start) +TEST_F(CronSchedulerTest, ScheduleDelay) +{ + auto counter = std::make_shared(); + + // 定义延迟时间:2秒 + auto delay_time = 2000ms; + + // 调度:每秒执行一次,但先延迟 2 秒 + auto job = CronScheduler::instance().schedule( + "* * * * * *", + [counter]() { counter->tick(); }, + delay_time + ); + + // 阶段 A: 验证在延迟期间(比如前 1 秒内),任务没有运行 + // Cron 是每秒一次,如果没有延迟,1秒内肯定会跑。 + bool no_run_early = counter->ensure_no_execution_for(1000ms); + EXPECT_TRUE(no_run_early) << "Job ran during the delay period!"; + + // 阶段 B: 验证延迟结束后,任务开始运行 + // 现在已经过了 1s,再等 2.5s (总共 3.5s),应该能覆盖 2s 延迟 + 1s 触发 + ASSERT_TRUE(counter->wait_for_at_least(1, 2500ms)) + << "Job failed to start after delay"; + + job->stop(); +} + +// 测试 4: 多个任务并发 +TEST_F(CronSchedulerTest, MultipleTasks) +{ + auto counter1 = std::make_shared(); + auto counter2 = std::make_shared(); + + auto job1 = CronScheduler::instance().schedule("* * * * * *", [counter1]() { counter1->tick(); }); + // job2 延迟 1 秒开始 + auto job2 = CronScheduler::instance().schedule("* * * * * *", [counter2]() { counter2->tick(); }, 1000ms); + + // 验证 job1 跑了 + EXPECT_TRUE(counter1->wait_for_at_least(1, 2000ms)); + + // 验证 job2 也跑了(需要多等一会儿因为有延迟) + EXPECT_TRUE(counter2->wait_for_at_least(1, 3000ms)); + + job1->stop(); + job2->stop(); +} + +// 测试 5: 验证错误的表达式不会导致 Crash,而是抛出异常 +TEST_F(CronSchedulerTest, InvalidExpression) +{ + EXPECT_THROW({ + CronScheduler::instance().schedule("invalid cron", [](){}); + }, std::exception); +}