@@ -0,0 +1,274 @@
+#include "zmq_client.h"
+#include "network/binary_utility.hpp"
+
+// from sophiar
+#include "utility/coro_signal2.hpp"
+#include "utility/coro_worker.hpp"
+#include "utility/coro_worker_helper_func.hpp"
+
+#include <boost/asio/experimental/concurrent_channel.hpp>
+#include <boost/endian.hpp>
+#include <boost/process.hpp>
+
+#include <azmq/socket.hpp>
+
+#include <unistd.h> // readlink
+
+#include <atomic>
+#include <cassert>
+#include <filesystem>
+#include <optional>
+#include <string>
+#include <thread>
+#include <utility>
+
+namespace bp = boost::process;
+namespace ba = boost::asio;
+using namespace sophiar;
+using boost::system::error_code;
+
+namespace {
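+    // Map an OpenCV element depth to a NumPy-style typestr such as "|u1" or "<f4",
+    // taking the native byte order into account.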
+    const char *cv_type_to_dtype(const int type) {
+        static constexpr bool le =
+            boost::endian::order::native == boost::endian::order::little;
+        switch (CV_MAT_DEPTH(type)) {
+            // @formatter:off
+            case CV_8U:  { return "|u1"; }
+            case CV_16U: { return le ? "<u2" : ">u2"; }
+            case CV_32F: { return le ? "<f4" : ">f4"; }
+            // @formatter:on
+            default: {
+                RET_ERROR_E;
+            }
+        }
+    }
+
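+    // Inverse mapping: NumPy-style typestr plus channel count back to an OpenCV type.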
+    int dtype_to_cv_type(const std::string &dtype, int c) {
+        // @formatter:off
+        if (dtype == "|b1") { return CV_8UC(c); }
+        if (dtype == "|u1") { return CV_8UC(c); }
+        if (dtype == "<u2") { return CV_16UC(c); }
+        if (dtype == "<f4") { return CV_32FC(c); }
+        // @formatter:on
+        RET_ERROR_E;
+    }
+
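+    // Resolve the directory containing the current executable via /proc/self/exe (Linux only).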
+    std::filesystem::path get_executable_folder() {
+        constexpr auto path_buf_size = 512;
+        char path_buf[path_buf_size];
+        const auto len = readlink("/proc/self/exe", path_buf, path_buf_size - 1);
+        assert(len != -1 && len < path_buf_size - 1);
+        path_buf[len] = '\0';
+        const auto path = std::filesystem::path(path_buf);
+        return path.parent_path();
+    }
+
+    constexpr auto max_reply_size = 32 * 1024 * 1024; // 32 MiB
+    constexpr auto max_channel_buffer = 32;
+
+    struct channel_is_full_error final : std::exception {
+        [[nodiscard]] const char *what() const noexcept override {
+            return "Request buffer channel is full.";
+        }
+    };
+}
+
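+// Private implementation: owns the auxiliary io_context/thread, the azmq REQ socket,
+// and the channel that queues requests for the Python server process.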
+struct zmq_client::impl {
+    create_config conf;
+
+    ba::io_context aux_ctx;
+    std::optional<std::thread> aux_thread;
+    std::optional<azmq::req_socket> socket;
+
+    using request_channel_type =
+        boost::asio::experimental::concurrent_channel<void(error_code, request_type)>;
+    std::optional<request_channel_type> request_chan;
+    coro_worker::pointer request_worker;
+    std::optional<coro_signal2> socket_signal;
+    std::optional<signal_watcher> socket_watcher;
+
+    std::atomic<size_t> queued_requests = 0;
+
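+    // Serialize a request as: header length (size_t), JSON header describing the model name,
+    // image dtypes/shapes and extra payload, then a size_t byte count plus raw pixels per image.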
+    static data_type encode_request(const request_type &req) {
+        using namespace nlohmann;
+
+        auto img_list = json::array();
+        assert(!req.img_list.empty());
+        for (auto &img: req.img_list) {
+            auto img_info = json::object();
+            img_info["dtype"] = cv_type_to_dtype(img.cv_type());
+            img_info["shape"] = json::array(
+                {img.height(), img.width(), CV_MAT_CN(img.cv_type())});
+            img_info["transfer"] = "direct";
+            img_list.push_back(img_info);
+        }
+
+        auto head = nlohmann::json();
+        head["model"] = req.model_name;
+        head["images"] = img_list;
+        head["extra"] = req.extra;
+        const auto head_str = head.dump();
+        const size_t head_size = head_str.length();
+
+        size_t total_size = sizeof(size_t) + head_size;
+        for (auto &img: req.img_list) {
+            total_size += sizeof(size_t) + img.byte_size();
+        }
+        auto ret = data_type(total_size);
+
+        auto writer = network_writer(ret);
+        writer << head_size << head_str;
+        for (auto &img: req.img_list) {
+            size_t img_size = img.byte_size();
+            writer << img_size;
+            auto img_dense = create_dense(img);
+            auto helper = read_access_helper(img_dense.host());
+            writer.write_data(helper.ptr(), img_dense.byte_size());
+        }
+
+        assert(writer.empty());
+        return ret;
+    }
+
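+    // Parse a reply using the same framing: JSON header first, then the raw pixel blocks
+    // it describes. A one-element shape of {0} means the server found nothing.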
+    static reply_type decode_reply(const data_type &data) {
+        using namespace nlohmann;
+
+        auto reader = network_reader(data);
+        const auto head_size = reader.read_value<size_t>();
+        auto head_str = reader.read_std_string(head_size);
+        auto head = json::parse(head_str);
+
+        auto ret = zmq_client::reply_type();
+        ret.extra = head["extra"];
+
+        for (auto &img_info: head["images"]) {
+            assert(img_info["transfer"] == "direct");
+
+            auto img_shape = img_info["shape"];
+            assert(img_shape.is_array());
+            if (img_shape.size() == 1) {
+                // no object found
+                assert(img_shape[0].get<int>() == 0);
+                return {};
+            }
+            assert(img_shape.size() == 2 || img_shape.size() == 3);
+            const auto img_height = img_shape[0].get<int>();
+            const auto img_width = img_shape[1].get<int>();
+            int img_channels = 1;
+            if (img_shape.size() == 3) {
+                img_channels = img_shape[2].get<int>();
+            }
+
+            const auto img_bytes = reader.read_value<size_t>();
+            const auto img_dtype = img_info["dtype"].get<std::string>();
+            auto img = sp_image::create(dtype_to_cv_type(img_dtype, img_channels),
+                                        cv::Size(img_width, img_height), reader.current_ptr());
+            assert(img.byte_size() == img_bytes);
+            reader.manual_offset(img_bytes);
+            ret.img_list.push_back(img);
+        }
+
+        return ret;
+    }
+
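+    // One request/reply round trip: pull a request from the channel, send it on the REQ
+    // socket, wait for the completion signal, receive the reply, and invoke the callback.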
+    ba::awaitable<bool> handle_request() {
+        const auto req = co_await request_chan->async_receive(ba::use_awaitable);
+        const auto req_data = encode_request(req);
+        const auto req_buf = ba::buffer(req_data.start_ptr(), req_data.size);
+        --queued_requests;
+
+        socket->async_send(req_buf, [&](const error_code &ec, const size_t size) {
+            if (ec.failed()) {
+                SPDLOG_ERROR("Zmq client send failed: {}", ec.message());
+                assert(false);
+            }
+            assert(size == req_data.size);
+            socket_signal->try_notify_all();
+        });
+        co_await socket_watcher->coro_wait();
+
+        auto rep_data = data_type(max_reply_size);
+        const auto rep_buf = ba::buffer(rep_data.start_ptr(), rep_data.size);
+        // capture by reference so shrink() applies to the buffer decoded below
+        socket->async_receive(rep_buf, [&](const error_code &ec, const size_t size) {
+            if (ec.failed()) {
+                SPDLOG_ERROR("Zmq client receive failed: {}", ec.message());
+                assert(false);
+            }
+            rep_data.shrink(size);
+            socket_signal->try_notify_all();
+        });
+        co_await socket_watcher->coro_wait();
+
+        const auto rep = decode_reply(rep_data);
+        req.reply_cb(rep);
+        co_return true;
+    }
+
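+    // Runs on the auxiliary thread: spawn the Python server process, connect the REQ
+    // socket, start the request worker, and drive the io_context until stop() is requested.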
+    void start_server() {
+        // create the Python process
+        bp::environment aux_env =
+            boost::this_process::environment();
+        aux_env["PYTHONPATH"] += get_executable_folder().string();
+        auto aux_proc = bp::child(
+            conf.python_interpreter, conf.server_script_path,
+            aux_env, bp::start_dir(conf.server_working_dir));
+
+        socket.emplace(aux_ctx, true);
+        socket->connect(conf.serv_addr);
+
+        socket_signal.emplace(&aux_ctx);
+        socket_watcher.emplace(socket_signal->new_watcher());
+        request_worker = make_infinite_coro_worker(
+            [this] { return handle_request(); },
+            coro_worker::empty_func, &aux_ctx);
+        request_worker->run();
+
+        {
+            // keep the io_context alive until zmq_client's destructor calls stop()
+            auto blocker = boost::asio::make_work_guard(aux_ctx);
+            aux_ctx.run();
+        }
+
+        // wait for the coro worker to exit
+        request_worker->cancel();
+        aux_ctx.restart();
+        aux_ctx.run();
+
+        // cleanup
+        aux_proc.terminate();
+    }
+
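+    // Non-blocking enqueue used by zmq_client::request(); throws if the channel is full.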
+    void on_request(const request_type &req) {
+        if (!request_chan->try_send(error_code(), req)) {
+            throw channel_is_full_error();
+        }
+        ++queued_requests;
+    }
+
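+    // Awaitable enqueue used by zmq_client::coro_request(); suspends instead of throwing
+    // when the channel is full.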
+    ba::awaitable<void> on_request_async(request_type req) {
+        // take the request by value so the coroutine frame owns its own copy,
+        // and only count it once the channel has actually accepted it
+        co_await request_chan->async_send(error_code(), std::move(req), ba::use_awaitable);
+        ++queued_requests;
+    }
+
+    explicit impl(create_config _conf)
+        : conf(std::move(_conf)) {
+        request_chan.emplace(aux_ctx, max_channel_buffer);
+        aux_thread.emplace([this] { start_server(); });
+    }
+
+    ~impl() {
+        aux_ctx.stop();
+        aux_thread->join();
+    }
+};
+
+zmq_client::zmq_client(const create_config &conf)
+    : pimpl(std::make_unique<impl>(conf)) {
+}
+
+zmq_client::~zmq_client() = default;
+
+void zmq_client::request(const request_type &req) const {
+    pimpl->on_request(req);
+}
+
+ba::awaitable<void> zmq_client::coro_request(const request_type &req) const {
+    return pimpl->on_request_async(req);
+}
+
+size_t zmq_client::queued_requests() const {
+    return pimpl->queued_requests;
+}