|
|
@@ -0,0 +1,208 @@
|
|
|
+#include "zmq_client_impl.h"
|
|
|
+
|
|
|
+#include <boost/asio/post.hpp>
|
|
|
+#include <boost/endian.hpp>
|
|
|
+
|
|
|
+using boost::asio::buffer;
|
|
|
+using boost::asio::post;
|
|
|
+using boost::system::error_code;
|
|
|
+
|
|
|
+namespace zmq_client_impl {
|
|
|
+
|
|
|
+// zmq::context_t zmq_ctx;
|
|
|
+
|
|
|
+ const char *cv_type_to_dtype(int type) {
|
|
|
+ static constexpr bool le =
|
|
|
+ boost::endian::order::native == boost::endian::order::little;
|
|
|
+ switch (CV_MAT_DEPTH(type)) {
|
|
|
+ // @formatter:off
|
|
|
+ case CV_8U: { return "|u1"; }
|
|
|
+ case CV_16U: { return le ? "<u2" : ">u2"; }
|
|
|
+ case CV_32F: { return le ? "<f4" : ">f4"; }
|
|
|
+ // @formatter:on
|
|
|
+ default: {
|
|
|
+ RET_ERROR_E;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ int dtype_to_cv_type(const std::string &dtype, int c) {
|
|
|
+ // @formatter:off
|
|
|
+ if (dtype == "|b1") { return CV_8UC(c); }
|
|
|
+ if (dtype == "|u1") { return CV_8UC(c); }
|
|
|
+ if (dtype == "<u2") { return CV_16UC(c); }
|
|
|
+ if (dtype == "<f4") { return CV_32FC(c); }
|
|
|
+ // @formatter:on
|
|
|
+ RET_ERROR_E;
|
|
|
+ }
|
|
|
+
|
|
|
+ data_type encode_request(const zmq_client::request_type &req,
|
|
|
+ smart_cuda_stream *stream) {
|
|
|
+ using namespace nlohmann;
|
|
|
+
|
|
|
+ // request memory transfer from cuda to host in advance
|
|
|
+ for (auto &img: req.img_list) {
|
|
|
+ img->memory(MEM_HOST, stream);
|
|
|
+ }
|
|
|
+
|
|
|
+ auto img_list = json::array();
|
|
|
+ assert(!req.img_list.empty());
|
|
|
+ for (auto k = 0; k < req.img_list.size(); ++k) {
|
|
|
+ auto &img = req.img_list[k];
|
|
|
+ auto img_info = json::object();
|
|
|
+ img_info["dtype"] = cv_type_to_dtype(img->cv_type());
|
|
|
+ img_info["shape"] = json::array(
|
|
|
+ {img->height(), img->width(), CV_MAT_CN(img->cv_type())});
|
|
|
+ img_info["transfer"] = "direct";
|
|
|
+// img_info["id"] = k;
|
|
|
+ img_list.push_back(img_info);
|
|
|
+ }
|
|
|
+
|
|
|
+ auto head = nlohmann::json();
|
|
|
+ head["model"] = req.model_name;
|
|
|
+ head["images"] = img_list;
|
|
|
+ head["extra"] = req.extra;
|
|
|
+ auto head_str = head.dump();
|
|
|
+ size_t head_size = head_str.length();
|
|
|
+
|
|
|
+ size_t total_size = sizeof(size_t) + head_size;
|
|
|
+ for (auto &img: req.img_list) {
|
|
|
+ total_size += sizeof(size_t) + img->size_in_bytes();
|
|
|
+ }
|
|
|
+ auto ret = data_type(total_size);
|
|
|
+
|
|
|
+ auto writer = network_writer(ret);
|
|
|
+ writer << head_size << head_str;
|
|
|
+ for (auto &img: req.img_list) {
|
|
|
+ size_t img_size = img->size_in_bytes();
|
|
|
+ writer << img_size;
|
|
|
+ auto img_data = img->memory(MEM_HOST, stream);
|
|
|
+ assert(img_data.is_continuous());
|
|
|
+ assert(img_data.width * img_data.height == img->size_in_bytes());
|
|
|
+ CUDA_API_CHECK(cudaStreamSynchronize(stream->cuda));
|
|
|
+ writer.write_data((uint8_t *) img_data.start_ptr(),
|
|
|
+ img->size_in_bytes());
|
|
|
+ }
|
|
|
+
|
|
|
+ assert(writer.empty());
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
+
|
|
|
+ zmq_client::reply_type decode_reply(const data_type &data) {
|
|
|
+ using namespace nlohmann;
|
|
|
+
|
|
|
+ auto reader = network_reader(data);
|
|
|
+ auto head_size = reader.read_value<size_t>();
|
|
|
+ auto head_str = reader.read_std_string(head_size);
|
|
|
+ auto head = json::parse(head_str);
|
|
|
+
|
|
|
+ auto ret = zmq_client::reply_type();
|
|
|
+ ret.extra = head["extra"];
|
|
|
+
|
|
|
+ for (auto k = 0; k < head["images"].size(); ++k) {
|
|
|
+ auto &img_info = head["images"][k];
|
|
|
+// assert(img_info["id"] == k);
|
|
|
+ assert(img_info["transfer"] == "direct");
|
|
|
+
|
|
|
+ auto img_shape = img_info["shape"];
|
|
|
+ assert(img_shape.is_array());
|
|
|
+ assert(img_shape.size() == 2
|
|
|
+ || img_shape.size() == 3);
|
|
|
+ auto img_height = img_shape[0].get<int>();
|
|
|
+ auto img_width = img_shape[1].get<int>();
|
|
|
+ int img_channels = 1;
|
|
|
+ if (img_shape.size() == 3) {
|
|
|
+ img_channels = img_shape[2].get<int>();
|
|
|
+ }
|
|
|
+
|
|
|
+ auto img_dtype = img_info["dtype"].get<std::string>();
|
|
|
+ auto img_type = dtype_to_cv_type(img_dtype, img_channels);
|
|
|
+ auto img = create_image(cv::Size(img_width, img_height), img_type);
|
|
|
+
|
|
|
+ auto img_mem = img->memory(MEM_HOST);
|
|
|
+ auto img_bytes = reader.read_value<size_t>();
|
|
|
+ assert(img_bytes == img->size_in_bytes());
|
|
|
+ assert(img_mem.is_continuous());
|
|
|
+ reader.read_data((uint8_t *) img_mem.start_ptr(), img_bytes);
|
|
|
+ img->host_modified();
|
|
|
+ ret.img_list.push_back(img);
|
|
|
+ }
|
|
|
+
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+zmq_client::impl::impl(const create_config &conf) {
|
|
|
+ aux_ctx = std::make_unique<io_context>();
|
|
|
+ aux_thread = std::make_unique<std::thread>([=, this]() {
|
|
|
+ aux_thread_work(conf);
|
|
|
+ });
|
|
|
+}
|
|
|
+
|
|
|
+zmq_client::impl::~impl() {
|
|
|
+ aux_ctx->stop();
|
|
|
+ aux_thread->join();
|
|
|
+}
|
|
|
+
|
|
|
+void zmq_client::impl::aux_thread_work(const create_config &conf) {
|
|
|
+ CUDA_API_CHECK(cuCtxSetCurrent(*conf.cuda_ctx));
|
|
|
+ reply_cb = conf.reply_cb;
|
|
|
+ socket = std::make_unique<azmq::req_socket>(*aux_ctx, true);
|
|
|
+ socket->connect(conf.serv_addr);
|
|
|
+ auto blocker = boost::asio::make_work_guard(*aux_ctx);
|
|
|
+ aux_ctx->run();
|
|
|
+
|
|
|
+ // cleanup
|
|
|
+ socket = nullptr;
|
|
|
+}
|
|
|
+
|
|
|
+void zmq_client::impl::aux_on_request(const request_type &req) {
|
|
|
+ assert(!socket_busy.test());
|
|
|
+ socket_busy.test_and_set();
|
|
|
+
|
|
|
+ // send request
|
|
|
+ auto req_data = encode_request(req, &aux_stream);
|
|
|
+ auto req_buf = buffer(req_data.start_ptr(), req_data.size);
|
|
|
+ socket->async_send(req_buf, [=, this](error_code ec, size_t size) {
|
|
|
+ assert(!ec);
|
|
|
+ assert(size == req_data.size);
|
|
|
+ try_recv_rep();
|
|
|
+ });
|
|
|
+}
|
|
|
+
|
|
|
+void zmq_client::impl::try_recv_rep() {
|
|
|
+ static constexpr auto max_rep_size = 32 * 1024 * 1024; // 32MB
|
|
|
+ auto rep_data = data_type(max_rep_size);
|
|
|
+ auto rep_buf = buffer(rep_data.start_ptr(), rep_data.size);
|
|
|
+ socket->async_receive(rep_buf, [=, this](error_code ec, size_t size) mutable {
|
|
|
+ assert(!ec);
|
|
|
+ rep_data.shrink(size);
|
|
|
+ on_reply(rep_data);
|
|
|
+ });
|
|
|
+}
|
|
|
+
|
|
|
+void zmq_client::impl::on_reply(const data_type &data) {
|
|
|
+ auto rep = decode_reply(data);
|
|
|
+ reply_cb(rep);
|
|
|
+ socket_busy.clear();
|
|
|
+}
|
|
|
+
|
|
|
+void zmq_client::impl::on_request(const request_type &req) {
|
|
|
+ if (socket_busy.test()) return;
|
|
|
+ post(*aux_ctx, [req = req, this]() { aux_on_request(req); });
|
|
|
+}
|
|
|
+
|
|
|
+zmq_client::zmq_client(const create_config &conf)
|
|
|
+ : pimpl(std::make_unique<impl>(conf)) {
|
|
|
+}
|
|
|
+
|
|
|
+zmq_client::~zmq_client() = default;
|
|
|
+
|
|
|
+void zmq_client::request(const request_type &req) {
|
|
|
+ pimpl->on_request(req);
|
|
|
+}
|
|
|
+
|
|
|
+bool zmq_client::is_ideal() const {
|
|
|
+ return !pimpl->socket_busy.test();
|
|
|
+}
|