
Reworked the ZeroMQ client and the FastSAM module: zmq_client moved to module_v5 and FastSAM is now exposed through image_process_v5/process_python.

jcsyshc, 1 year ago
Parent
Commit b55aa0c7af

+ 0 - 1
CMakeLists.txt

@@ -5,7 +5,6 @@ set(CMAKE_CXX_STANDARD 23)
 
 add_executable(${PROJECT_NAME} src/main.cpp
         src/ai/impl/fast_sam.cpp
-        src/ai/impl/zmq_client.cpp
         src/image_process/impl/camera_calibrator.cpp
         src/image_process/impl/image_process_ui.cpp
         src/image_process/impl/process_funcs.cpp

+ 5 - 5
src/ai/fast_sam.h

@@ -8,16 +8,16 @@
 
 #include <memory>
 
-class fast_sam {
+class [[deprecated]] fast_sam {
 public:
 
     using io_context = boost::asio::io_context;
 
     struct create_config {
-        obj_name_type img_in = invalid_obj_name; // image_u8c3
-        obj_name_type mask_out = invalid_obj_name; // image_u8c1
-        CUcontext *cuda_ctx = nullptr;
-        io_context *ctx = nullptr;
+        obj_name_type img_in = invalid_obj_name; // sp_image
+        obj_name_type mask_out = invalid_obj_name; // sp_image
+        // CUcontext *cuda_ctx = nullptr;
+        // io_context *ctx = nullptr;
     };
 
     explicit fast_sam(create_config conf);

+ 15 - 18
src/ai/impl/fast_sam.cpp

@@ -8,40 +8,39 @@
 using boost::asio::post;
 namespace fs = std::filesystem;
 
-fast_sam::impl::impl(create_config _conf) {
-    conf = _conf;
-    ctx = conf.ctx;
-}
+extern boost::asio::io_context *main_ctx;
+
+fast_sam::impl::impl(create_config _conf)
+    : conf(_conf) { (void) 0; }
 
 void fast_sam::impl::image_callback(obj_name_type name) {
     assert(client != nullptr);
-    if (!client->is_ideal()) return;
+    if (client->queued_requests() != 0) return;
 
     assert(name == conf.img_in);
-    auto img = OBJ_QUERY(image_u8c3, name);
-    auto img_v2 = create_image(img);
+    const auto img = OBJ_QUERY(sp_image, name);
     auto req = zmq_client::request_type();
-    req.img_list.push_back(img_v2);
+    req.img_list.push_back(img);
 
     using namespace nlohmann;
     auto extra = json();
     extra["prompt_type"] = "point";
     extra["points"] = json::array({json::array(
-            {img_v2->width() / 2, img_v2->height() / 2})}); // center point
+            {img.width() / 2, img.height() / 2})}); // center point
     extra["point_label"] = json::array({1});
     extra["imgsz"] = 1024;
 
+    req.reply_cb = [this](auto rep) { reply_callback(rep); };
     req.extra = extra;
     client->request(req);
 }
 
 void fast_sam::impl::reply_callback(const zmq_client::reply_type &rep) {
     assert(!rep.img_list.empty());
-    auto img_v2 = rep.img_list.front();
-    img_v2->set_meta(META_COLOR_FMT, COLOR_BW); // TODO: deprecated, remove this
-    img_v2->set_meta(META_IMAGE_DISPLAY_FMT, DISP_MASK);
-//    auto img = img_v2->v1<uchar1>();
-    OBJ_SAVE(conf.mask_out, img_v2);
+    auto img = rep.img_list.front();
+    // img.insert_meta(META_COLOR_FMT, COLOR_BW); // TODO: deprecated, remove this
+    // img.insert_meta(META_IMAGE_DISPLAY_FMT, DISP_MASK);
+    OBJ_SAVE(conf.mask_out, img);
 }
 
 void fast_sam::impl::start() {
@@ -54,8 +53,6 @@ void fast_sam::impl::start() {
             .server_script_path = script_path,
             .serv_addr = "ipc:///tmp/fast_sam_v1",
 //            .serv_addr = "tcp://127.0.0.1:7899",
-            .reply_cb = [this](auto rep) { reply_callback(rep); },
-            .cuda_ctx  = conf.cuda_ctx,
     };
     assert(client == nullptr);
     client = std::make_unique<zmq_client>(client_conf);
@@ -72,11 +69,11 @@ void fast_sam::impl::stop() {
 void fast_sam::impl::show() {
     if (client == nullptr) {
         if (ImGui::Button("Start")) {
-            post(*ctx, [this] { start(); });
+            post(*main_ctx, [this] { start(); });
         }
     } else {
         if (ImGui::Button("Stop")) {
-            post(*ctx, [this] { stop(); });
+            post(*main_ctx, [this] { stop(); });
         }
     }
 }

+ 3 - 4
src/ai/impl/fast_sam_impl.h

@@ -2,18 +2,17 @@
 #define DEPTHGUIDE_FAST_SAM_IMPL_H
 
 #include "ai/fast_sam.h"
-#include "ai/zmq_client.h"
-#include "core/image_utility.hpp"
+#include "module_v5/zmq_client.h"
 
 struct fast_sam::impl {
 
     create_config conf;
-    io_context *ctx = nullptr;
+    // io_context *ctx = nullptr;
 
     obj_conn_type img_conn;
     std::unique_ptr<zmq_client> client;
 
-    explicit impl(create_config conf);
+    explicit impl(create_config _conf);
 
     void image_callback(obj_name_type name);
 

+ 0 - 229
src/ai/impl/zmq_client.cpp

@@ -1,229 +0,0 @@
-#include "zmq_client_impl.h"
-
-#include <boost/asio/post.hpp>
-#include <boost/endian.hpp>
-
-using boost::asio::buffer;
-using boost::asio::post;
-using boost::system::error_code;
-
-namespace zmq_client_impl {
-
-    const char *cv_type_to_dtype(int type) {
-        static constexpr bool le =
-                boost::endian::order::native == boost::endian::order::little;
-        switch (CV_MAT_DEPTH(type)) {
-            // @formatter:off
-            case CV_8U:  { return "|u1"; }
-            case CV_16U: { return le ? "<u2" : ">u2"; }
-            case CV_32F: { return le ? "<f4" : ">f4"; }
-            // @formatter:on
-            default: {
-                RET_ERROR_E;
-            }
-        }
-    }
-
-    int dtype_to_cv_type(const std::string &dtype, int c) {
-        // @formatter:off
-        if (dtype == "|b1") { return CV_8UC(c); }
-        if (dtype == "|u1") { return CV_8UC(c); }
-        if (dtype == "<u2") { return CV_16UC(c); }
-        if (dtype == "<f4") { return CV_32FC(c); }
-        // @formatter:on
-        RET_ERROR_E;
-    }
-
-    data_type encode_request(const zmq_client::request_type &req,
-                             smart_cuda_stream *stream) {
-        using namespace nlohmann;
-
-        // request memory transfer from cuda to host in advance
-        for (auto &img: req.img_list) {
-            img->memory(MEM_HOST, stream);
-        }
-
-        auto img_list = json::array();
-        assert(!req.img_list.empty());
-        for (auto k = 0; k < req.img_list.size(); ++k) {
-            auto &img = req.img_list[k];
-            auto img_info = json::object();
-            img_info["dtype"] = cv_type_to_dtype(img->cv_type());
-            img_info["shape"] = json::array(
-                    {img->height(), img->width(), CV_MAT_CN(img->cv_type())});
-            img_info["transfer"] = "direct";
-//            img_info["id"] = k;
-            img_list.push_back(img_info);
-        }
-
-        auto head = nlohmann::json();
-        head["model"] = req.model_name;
-        head["images"] = img_list;
-        head["extra"] = req.extra;
-        auto head_str = head.dump();
-        size_t head_size = head_str.length();
-
-        size_t total_size = sizeof(size_t) + head_size;
-        for (auto &img: req.img_list) {
-            total_size += sizeof(size_t) + img->size_in_bytes();
-        }
-        auto ret = data_type(total_size);
-
-        auto writer = network_writer(ret);
-        writer << head_size << head_str;
-        for (auto &img: req.img_list) {
-            size_t img_size = img->size_in_bytes();
-            writer << img_size;
-            auto img_data = img->memory(MEM_HOST, stream);
-            assert(img_data.is_continuous());
-            assert(img_data.width * img_data.height == img->size_in_bytes());
-            CUDA_API_CHECK(cudaStreamSynchronize(stream->cuda));
-            writer.write_data((uint8_t *) img_data.start_ptr(),
-                              img->size_in_bytes());
-        }
-
-        assert(writer.empty());
-        return ret;
-    }
-
-    zmq_client::reply_type decode_reply(const data_type &data) {
-        using namespace nlohmann;
-
-        auto reader = network_reader(data);
-        auto head_size = reader.read_value<size_t>();
-        auto head_str = reader.read_std_string(head_size);
-        auto head = json::parse(head_str);
-
-        auto ret = zmq_client::reply_type();
-        ret.extra = head["extra"];
-
-        for (auto k = 0; k < head["images"].size(); ++k) {
-            auto &img_info = head["images"][k];
-//            assert(img_info["id"] == k);
-            assert(img_info["transfer"] == "direct");
-
-            auto img_shape = img_info["shape"];
-            assert(img_shape.is_array());
-            if (img_shape.size() == 1) { // no object found
-                assert(img_shape[0].get<int>() == 0);
-                return {};
-            }
-            assert(img_shape.size() == 2
-                   || img_shape.size() == 3);
-            auto img_height = img_shape[0].get<int>();
-            auto img_width = img_shape[1].get<int>();
-            int img_channels = 1;
-            if (img_shape.size() == 3) {
-                img_channels = img_shape[2].get<int>();
-            }
-
-            auto img_dtype = img_info["dtype"].get<std::string>();
-            auto img_type = dtype_to_cv_type(img_dtype, img_channels);
-            auto img = create_image(cv::Size(img_width, img_height), img_type);
-
-            auto img_mem = img->memory(MEM_HOST);
-            auto img_bytes = reader.read_value<size_t>();
-            assert(img_bytes == img->size_in_bytes());
-            assert(img_mem.is_continuous());
-            reader.read_data((uint8_t *) img_mem.start_ptr(), img_bytes);
-            img->host_modified();
-            ret.img_list.push_back(img);
-        }
-
-        return ret;
-    }
-
-}
-
-zmq_client::impl::impl(const create_config &conf) {
-    // create python process
-    static constexpr auto zmq_server_path =
-            "/home/tpx/project/DepthGuide/src/ai/impl/python";
-    bp::environment aux_env =
-            boost::this_process::environment();
-    aux_env["PYTHONPATH"] += zmq_server_path;
-    aux_proc = std::make_unique<bp::child>(
-            conf.python_interpreter, conf.server_script_path,
-            aux_env, bp::start_dir(conf.server_working_dir));
-
-    // create auxiliary thread
-    aux_ctx = std::make_unique<io_context>();
-    aux_thread = std::make_unique<std::thread>([=, this]() {
-        aux_thread_work(conf);
-    });
-}
-
-zmq_client::impl::~impl() {
-    aux_ctx->stop();
-    aux_thread->join();
-    aux_proc->terminate();
-}
-
-void zmq_client::impl::aux_thread_work(const create_config &conf) {
-    CUDA_API_CHECK(cuCtxSetCurrent(*conf.cuda_ctx));
-    reply_cb = conf.reply_cb;
-    socket = std::make_unique<azmq::req_socket>(*aux_ctx, true);
-    socket->connect(conf.serv_addr);
-    auto blocker = boost::asio::make_work_guard(*aux_ctx);
-    aux_ctx->run();
-
-    // cleanup
-    socket = nullptr;
-}
-
-void zmq_client::impl::aux_on_request(const request_type &req) {
-//    assert(!socket_busy.test()); // TODO: fix socket_busy
-    if (socket_busy.test()) return;
-    socket_busy.test_and_set();
-
-    // send request
-    auto req_data = encode_request(req, &aux_stream);
-    auto req_buf = buffer(req_data.start_ptr(), req_data.size);
-    socket->async_send(req_buf, [=, this](error_code ec, size_t size) {
-        assert(!ec);
-        assert(size == req_data.size);
-        try_recv_rep();
-    });
-}
-
-void zmq_client::impl::try_recv_rep() {
-    static constexpr auto max_rep_size = 32 * 1024 * 1024; // 32MB
-    auto rep_data = data_type(max_rep_size);
-    auto rep_buf = buffer(rep_data.start_ptr(), rep_data.size);
-    socket->async_receive(rep_buf, [=, this](error_code ec, size_t size) mutable {
-        assert(!ec);
-        rep_data.shrink(size);
-        on_reply(rep_data);
-    });
-}
-
-void zmq_client::impl::on_reply(const data_type &data) {
-    auto rep = decode_reply(data);
-    if (!rep.img_list.empty()) {
-        reply_cb(rep);
-    }
-    socket_busy.clear();
-}
-
-bool zmq_client::impl::is_running() const {
-    return aux_proc->running();
-}
-
-void zmq_client::impl::on_request(const request_type &req) {
-    if (socket_busy.test() || !is_running()) return;
-    post(*aux_ctx, [req = req, this]() { aux_on_request(req); });
-}
-
-zmq_client::zmq_client(const create_config &conf)
-        : pimpl(std::make_unique<impl>(conf)) {
-}
-
-zmq_client::~zmq_client() = default;
-
-void zmq_client::request(const request_type &req) {
-    pimpl->on_request(req);
-}
-
-bool zmq_client::is_ideal() const {
-    return !pimpl->socket_busy.test();
-}

+ 0 - 77
src/ai/impl/zmq_client_impl.h

@@ -1,77 +0,0 @@
-#ifndef DEPTHGUIDE_ZMQ_CLIENT_IMPL_H
-#define DEPTHGUIDE_ZMQ_CLIENT_IMPL_H
-
-#include "ai/zmq_client.h"
-#include "network/binary_utility.hpp"
-
-#include <boost/process.hpp>
-
-#include <azmq/socket.hpp>
-
-#include <atomic>
-#include <thread>
-
-namespace zmq_client_impl {
-
-    enum image_transfer_mode {
-        DIRECT = 0,
-        HOST_SHARED = 1,
-        CUDA_SHARED = 2
-    };
-
-    const char *cv_type_to_dtype(int type);
-
-    int dtype_to_cv_type(const std::string &dtype, int channels);
-
-    /* Head Size [4 bytes]
-     * Head Json [n bytes]
-     * Img#1 Size [4 bytes]
-     * Img#1 Data [n bytes]
-     * Img#n ... */
-    data_type encode_request(const zmq_client::request_type &req,
-                             smart_cuda_stream *stream);
-
-    zmq_client::reply_type decode_reply(const data_type &data);
-
-}
-
-using namespace zmq_client_impl;
-namespace bp = boost::process;
-
-struct zmq_client::impl {
-
-    std::unique_ptr<bp::child> aux_proc;
-    std::unique_ptr<std::thread> aux_thread;
-//    image_transfer_mode trans_mode = DIRECT;
-
-    // for aux thread
-    using io_context = boost::asio::io_context;
-    std::unique_ptr<io_context> aux_ctx;
-
-    std::unique_ptr<azmq::req_socket> socket;
-    std::atomic_flag socket_busy = false;
-
-    smart_cuda_stream aux_stream;
-    reply_cb_type reply_cb;
-
-    explicit impl(const create_config &conf);
-
-    ~impl();
-
-    bool is_running() const;
-
-    void on_request(const request_type &req);
-
-    // for auxiliary thread
-
-    void aux_thread_work(const create_config &conf);
-
-    void aux_on_request(const request_type &req);
-
-    void try_recv_rep();
-
-    void on_reply(const data_type &data);
-
-};
-
-#endif //DEPTHGUIDE_ZMQ_CLIENT_IMPL_H

+ 9 - 0
src/core/math_helper.hpp

@@ -10,6 +10,15 @@
 
 #include <Eigen/Geometry>
 
+template<size_t Align = 1>
+size_t alignment_round(size_t size, const size_t align = Align) {
+    assert(std::popcount(align) == 1);
+    if (size & (align - 1)) {
+        size = (size + align) & ~(align - 1);
+    }
+    return size;
+}
+
 // r in radius
 inline glm::mat4 to_transform_mat(glm::vec3 t, glm::vec3 r) {
     static constexpr auto unit_x = glm::vec3(1.0f, 0.0f, 0.0f);

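The relocated alignment_round() helper rounds a byte count up to the next multiple of a power-of-two alignment. A standalone sketch with a few worked values, for illustration only (the concrete numbers are examples, not taken from the project):

#include <bit>
#include <cassert>
#include <cstddef>

template<size_t Align = 1>
size_t alignment_round(size_t size, const size_t align = Align) {
    assert(std::popcount(align) == 1); // alignment must be a power of two
    if (size & (align - 1)) {
        size = (size + align) & ~(align - 1); // round up to the next multiple of align
    }
    return size;
}

int main() {
    assert(alignment_round(100, 64) == 128); // bumped up to the next 64-byte boundary
    assert(alignment_round(128, 64) == 128); // already aligned, returned unchanged
    assert(alignment_round<256>(1) == 256);  // alignment supplied as a template argument
}
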
+ 1 - 1
src/core_v2/memory_manager.cpp

@@ -1,6 +1,6 @@
 #include "memory_manager.h"
 #include "memory_utility.h"
-#include "utility.hpp"
+#include "core/math_helper.hpp"
 
 #include <map>
 #include <ranges>

+ 7 - 12
src/core_v2/utility.hpp

@@ -1,24 +1,19 @@
 #ifndef UTILITY_HPP
 #define UTILITY_HPP
 
-#include <cassert>
-#include <cstdint>
-
 #include <BS_thread_pool.hpp>
 
-template<size_t Align = 1>
-size_t alignment_round(size_t size, const size_t align = Align) {
-    assert(std::popcount(align) == 1);
-    if (size & (align - 1)) {
-        size = (size + align) & ~(align - 1);
-    }
-    return size;
-}
-
 extern BS::thread_pool *g_thread_pool;
 
 #define TP_DETACH(func) g_thread_pool->detach_task(func)
 #define TP_SUBMIT(func) g_thread_pool->submit_task(func)
 #define TP_SYNC g_thread_pool->wait()
 
+#include <boost/asio/io_context.hpp>
+#include <boost/asio/post.hpp>
+
+extern boost::asio::io_context *main_ctx;
+
+#define MAIN_DETACH(func) boost::asio::post(*main_ctx, func)
+
 #endif //UTILITY_HPP

+ 3 - 1
src/image_process_v5/CMakeLists.txt

@@ -2,4 +2,6 @@ target_sources(${PROJECT_NAME} PRIVATE
         image_viewer.cpp
         image_process.cpp
         osg_helper.cpp
-        sp_image.cpp)
+        sp_image.cpp
+        process_python/fast_sam.cpp
+)

+ 74 - 16
src/image_process_v5/image_process.cpp

@@ -1,5 +1,8 @@
 #include "image_process.h"
 
+#include <glm/ext/matrix_transform.hpp>
+#include <glm/gtx/matrix_transform_2d.hpp>
+
 #include <opencv2/cudaarithm.hpp>
 #include <opencv2/cudaimgproc.hpp>
 #include <opencv2/cudawarping.hpp>
@@ -50,25 +53,44 @@ sp_image nv12_chrome_view(const sp_image &img) {
     return img_chrome.cast_view(CV_8UC2);
 }
 
+namespace {
+    struct image_opencv_cuda_helper {
+        const sp_image *read;
+        sp_image *write;
+        using proxy_type = auto_memory_info::cuda_proxy;
+        cuda_stream_guard stream_guard;
+        pair_access_helper<proxy_type, proxy_type> access_helper;
+
+        image_opencv_cuda_helper(const sp_image &src, sp_image &dst)
+            : read(&src), write(&dst),
+              stream_guard((cudaStream_t) get_cv_stream().cudaPtr()),
+              access_helper(read->cuda(), write->cuda()) { (void) 0; }
+
+        [[nodiscard]] cv::cuda::GpuMat input() const {
+            return read->cv_gpu_mat(access_helper.read_ptr());
+        }
+
+        [[nodiscard]] cv::cuda::GpuMat output() const {
+            return write->cv_gpu_mat(access_helper.write_ptr());
+        }
+    };
+}
+
 sp_image image_debayer(const sp_image &img) {
     assert(img.cv_type() == CV_8UC1);
     auto ret = sp_image::create<uchar3>(img.cv_size());
-    auto stream_guard = cuda_stream_guard((cudaStream_t) get_cv_stream().cudaPtr());
-    const auto pair_helper = pair_access_helper(img.cuda(), ret.cuda());
-    const auto in_mat = img.cv_gpu_mat(pair_helper.read_ptr());
-    auto out_mat = ret.cv_gpu_mat(pair_helper.write_ptr());
-    cv::cuda::cvtColor(in_mat, out_mat, cv::COLOR_BayerRG2BGR, 3, get_cv_stream());
+    const auto helper = image_opencv_cuda_helper(img, ret);
+    cv::cuda::cvtColor(helper.input(), helper.output(),
+                       cv::COLOR_BayerRG2BGR, 3, get_cv_stream());
     ret.merge_meta(img);
     return ret;
 }
 
 void image_resize(const sp_image &src, sp_image &dst) {
     assert(src.cv_type() == dst.cv_type());
-    auto stream_guard = cuda_stream_guard((cudaStream_t) get_cv_stream().cudaPtr());
-    const auto pair_helper = pair_access_helper(src.cuda(), dst.cuda());
-    const auto in_mat = src.cv_gpu_mat(pair_helper.read_ptr());
-    auto out_mat = dst.cv_gpu_mat(pair_helper.write_ptr());
-    cv::cuda::resize(in_mat, out_mat, dst.cv_size(), 0, 0, cv::INTER_LINEAR, get_cv_stream());
+    const auto helper = image_opencv_cuda_helper(src, dst);
+    cv::cuda::resize(helper.input(), helper.output(),
+                     dst.cv_size(), 0, 0, cv::INTER_LINEAR, get_cv_stream());
     dst.merge_meta(src);
 }
 
@@ -78,18 +100,54 @@ sp_image image_resize(const sp_image &img, const cv::Size size) {
     return ret;
 }
 
-// TODO: create a helper class to simplify this type of operation
 sp_image image_flip_y(const sp_image &img) {
     auto ret = sp_image::create(img.cv_type(), img.cv_size());
-    auto stream_guard = cuda_stream_guard((cudaStream_t) get_cv_stream().cudaPtr());
-    const auto pair_helper = pair_access_helper(img.cuda(), ret.cuda());
-    const auto in_mat = img.cv_gpu_mat(pair_helper.read_ptr());
-    auto out_mat = ret.cv_gpu_mat(pair_helper.write_ptr());
-    cv::cuda::flip(in_mat, out_mat, 1, get_cv_stream()); // flip vertically
+    const auto helper = image_opencv_cuda_helper(img, ret);
+    cv::cuda::flip(helper.input(), helper.output(), 1, get_cv_stream()); // flip vertically
+    ret.merge_meta(img);
+    return ret;
+}
+
+sp_image image_warp_affine(const sp_image &img, const glm::mat3 &matrix) {
+    auto cv_matrix = cv::Mat(2, 3, CV_32FC1);
+    for (auto i = 0; i < 3; ++i)
+        for (auto j = 0; j < 2; ++j) {
+            cv_matrix.at<float>(j, i) = matrix[i][j];
+        }
+    auto ret = sp_image::create_like(img);
+    const auto helper = image_opencv_cuda_helper(img, ret);
+    cv::cuda::warpAffine(helper.input(), helper.output(),
+                         cv_matrix, img.cv_size(), cv::INTER_LINEAR,
+                         cv::BORDER_CONSTANT, {},
+                         get_cv_stream());
     ret.merge_meta(img);
     return ret;
 }
 
+namespace {
+    float pixel_center(const float size) {
+        return 0.5f * size - 0.5f;
+    }
+}
+
+sp_image image_rotate(const sp_image &img, const float angle,
+                      std::optional<glm::vec2> center) {
+    if (!center) {
+        center = glm::vec2(pixel_center(img.width()),
+                           pixel_center(img.height()));
+    }
+    auto matrix = glm::identity<glm::mat3>();
+    matrix = glm::translate(matrix, -*center);
+    matrix = glm::rotate(matrix, angle);
+    matrix = glm::translate(matrix, *center);
+    return image_warp_affine(img, matrix);
+}
+
+sp_image image_translate(const sp_image &img, const glm::vec2 offset) {
+    const auto matrix = glm::translate(glm::identity<glm::mat3>(), offset);
+    return image_warp_affine(img, matrix);
+}
+
 #include "image_process/cuda_impl/pixel_convert.cuh"
 
 namespace {

+ 6 - 0
src/image_process_v5/image_process.h

@@ -3,6 +3,8 @@
 
 #include "sp_image.h"
 
+#include <glm/glm.hpp>
+
 size_t normal_height_to_nv12(size_t height);
 size_t nv12_height_to_normal(size_t height);
 cv::Size normal_size_to_nv12(cv::Size size);
@@ -21,6 +23,10 @@ sp_image image_rgb_to_nv12(const sp_image &img);
 sp_image image_nv12_to_rgb(const sp_image &img);
 sp_image image_yuyv_to_rgb(const sp_image &img);
 
+sp_image image_warp_affine(const sp_image &img, const glm::mat3 &matrix);
+sp_image image_rotate(const sp_image &img, float angle, std::optional<glm::vec2> center);
+sp_image image_translate(const sp_image &img, glm::vec2 offset);
+
 void image_save_jpg(const sp_image &img, const std::string &filename); // filename without extension
 void image_save_png(const sp_image &img, const std::string &filename);
 

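The new warp helpers declared above build a 2-D affine transform as a glm::mat3 and run it through cv::cuda::warpAffine. A hedged usage sketch (the input path is a placeholder, and the angle is assumed to be in radians since it is forwarded to glm::rotate):

#include "image_process_v5/image_process.h"

#include <glm/gtc/constants.hpp>

#include <optional>

void rotate_and_shift_demo() {
    // Placeholder input path, for illustration only.
    const auto img = sp_image::from_file("/path/to/input.png");

    // Rotate a quarter turn about the image centre (std::nullopt selects the centre).
    const auto rotated = image_rotate(img, glm::half_pi<float>(), std::nullopt);

    // Shift the rotated image 10 px along x and 20 px along y.
    const auto shifted = image_translate(rotated, glm::vec2(10.0f, 20.0f));

    image_save_png(shifted, "rotated_shifted"); // filename without extension
}
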
+ 49 - 0
src/image_process_v5/process_python.h

@@ -0,0 +1,49 @@
+#ifndef PROCESS_PYTHON_H
+#define PROCESS_PYTHON_H
+
+#include "sp_image.h"
+
+#include <glm/glm.hpp>
+
+#include <functional>
+#include <string>
+
+using image_callback_type = std::function<void(sp_image)>;
+
+namespace process_python {
+    struct server_config {
+        std::string interpreter_path;
+        std::string script_dir;
+    };
+}
+
+struct fast_sam_point_prompt {
+    struct point_type {
+        glm::uvec2 pos;
+        bool is_positive;
+    };
+
+    using points_type = std::vector<point_type>;
+    points_type points;
+};
+
+struct fast_sam_box_prompt {
+};
+
+struct fast_sam_text_prompt {
+};
+
+struct fast_sam_options {
+    using prompt_type = std::variant<
+        fast_sam_point_prompt,
+        fast_sam_box_prompt,
+        fast_sam_text_prompt>;
+    prompt_type prompt;
+    size_t size_level = 1024;
+};
+
+// call before any usage
+void set_fast_sam_config(const process_python::server_config &conf);
+void image_fast_sam(const sp_image& img, const fast_sam_options &opts, const image_callback_type& cb);
+
+#endif //PROCESS_PYTHON_H

+ 79 - 0
src/image_process_v5/process_python/fast_sam.cpp

@@ -0,0 +1,79 @@
+#include "../process_python.h"
+#include "core/math_helper.hpp"
+#include "impl/main_impl.h"
+#include "module_v5/zmq_client.h"
+#include "third_party/static_block.hpp"
+
+#include <nlohmann/json.hpp>
+
+#include <filesystem>
+
+namespace fs = std::filesystem;
+using namespace nlohmann;
+
+namespace {
+    constexpr auto default_interpreter = "/home/tpx/ext/anaconda3/envs/FastSAM/bin/python";
+    constexpr auto default_script_path = "/home/tpx/ext/code/FastSAM";
+    constexpr auto server_script_name = "fastsam_server.py";
+    constexpr auto server_address = "ipc:///tmp/fast_sam_v1";
+
+    zmq_client::create_config server_conf;
+    static_block {
+        set_fast_sam_config({default_interpreter, default_script_path,});
+    }
+
+    std::optional<zmq_client> server;
+
+    void create_server() {
+        if (server) [[likely]] return;
+        server.emplace(server_conf);
+        register_cleanup_func([] { server.reset(); });
+    }
+
+    template<Vec2Type Vec>
+    auto to_json(Vec vec) {
+        return json::array({vec.x, vec.y});
+    }
+
+    json encode_prompt(const fast_sam_point_prompt &prompt) {
+        auto points_j = json::array();
+        std::ranges::transform(
+            prompt.points, std::back_inserter(points_j),
+            [](auto p) { return to_json(p.pos); });
+        auto label_j = json::array();
+        std::ranges::transform(
+            prompt.points, std::back_inserter(label_j),
+            [](auto p) { return (int) p.is_positive; });
+
+        auto ret = json();
+        ret["prompt_type"] = "point";
+        ret["points"] = points_j;
+        ret["point_label"] = label_j;
+        return ret;
+    }
+
+    json encode_prompt(fast_sam_box_prompt) { assert(false); return {}; }
+    json encode_prompt(fast_sam_text_prompt) { assert(false); return {}; }
+}
+
+void set_fast_sam_config(const process_python::server_config &conf) {
+    server_conf.python_interpreter = conf.interpreter_path;
+    server_conf.server_working_dir = conf.script_dir;
+    server_conf.server_script_path = fs::path(conf.script_dir) / server_script_name;
+    server_conf.serv_addr = server_address;
+}
+
+void image_fast_sam(const sp_image &img,
+                    const fast_sam_options &opts,
+                    const image_callback_type &cb) {
+    auto req = zmq_client::request_type();
+    req.img_list.push_back(img);
+    req.extra = std::visit([](auto &&p) { return encode_prompt(p); }, opts.prompt);
+    req.extra["imgsz"] = opts.size_level;
+    req.reply_cb = [cb = cb](const zmq_client::reply_type &rep) {
+        assert(!rep.img_list.empty());
+        cb(*rep.img_list.begin());
+    };
+    create_server();
+    server->request(req);
+}

+ 23 - 0
src/image_process_v5/sp_image.cpp

@@ -1,6 +1,9 @@
 #include "sp_image.h"
+#include "core/math_helper.hpp"
 #include "third_party/static_block.hpp"
 
+#include <opencv2/imgcodecs.hpp>
+
 #include <unordered_map>
 
 namespace {
@@ -107,6 +110,10 @@ sp_image sp_image::create_impl(const cv::Size size, const void *ptr,
     return ret;
 }
 
+sp_image sp_image::create_impl(const cv::Size size, const void *ptr, int cv_type) {
+    return create_impl(size, ptr, cv_map.query(cv_type).type);
+}
+
 sp_image sp_image::cast_view_impl(const std::type_index type) const {
     auto ret = *this;
     const auto type_size = type_map.query(type).size;
@@ -127,6 +134,16 @@ sp_image sp_image::create(const cv::Mat &mat) {
                        cv_map.query(mat.type()).type);
 }
 
+sp_image sp_image::create_like(const sp_image &img, int cv_type) {
+    if (cv_type == 0) {
+        cv_type = img.cv_type();
+    }
+    return create(cv_type, img.cv_size());
+}
+
+sp_image sp_image::from_file(const std::string &path) {
+    return create(cv::imread(path));
+}
 
 using image_ndarray_proxy = ndarray_proxy<image_rank>;
 using image_index_pack = index_pack<image_rank>;
@@ -162,6 +179,12 @@ void copy_sp_image(const sp_image &src, sp_image &dst, const cudaMemcpyKind kind
     copy_ndarray(src, dst, kind);
 }
 
+sp_image create_dense(const sp_image &src) {
+    sp_image ret = src;
+    *ret.array_base() = create_dense(*src.array_base());
+    return ret;
+}
+
 image_mem_info to_mem_v1(const sp_image &img, void *ptr,
                          const memory_location loc) {
     auto ret = image_mem_info();

+ 12 - 0
src/image_process_v5/sp_image.h

@@ -15,6 +15,7 @@ struct sp_image : ndarray_proxy<image_rank>,
     //@formatter:off
     using base_type = ndarray_proxy;
     base_type *array_base() { return this; }
+    const base_type *array_base() const { return this; }
     [[nodiscard]] cv::Size cv_size() const;
     [[nodiscard]] int cv_type() const;
     [[nodiscard]] cv::Mat cv_mat(void *ptr) const;
@@ -36,8 +37,16 @@ struct sp_image : ndarray_proxy<image_rank>,
         return create_impl(size, ptr, typeid(T));
     }
 
+    static sp_image create(const int cv_type, const cv::Size size, const void *ptr) {
+        return create_impl(size, ptr, cv_type);
+    }
+
     static sp_image create(const cv::Mat &mat);
 
+    static sp_image create_like(const sp_image &img, int cv_type = 0);
+
+    static sp_image from_file(const std::string &path);
+
     template<typename T>
     [[nodiscard]] sp_image cast_view() const {
         return cast_view_impl(typeid(T));
@@ -50,6 +59,7 @@ protected:
     static sp_image create_impl(cv::Size size, size_t align, std::type_index type);
     static sp_image create_impl(cv::Size size, size_t align, int cv_type);
     static sp_image create_impl(cv::Size size, const void *ptr, std::type_index type);
+    static sp_image create_impl(cv::Size size, const void *ptr, int cv_type);
     [[nodiscard]] sp_image cast_view_impl(std::type_index type) const;
     //@formatter:on
 };
@@ -57,6 +67,8 @@ protected:
 void copy_sp_image(const sp_image &src, sp_image &dst,
                    cudaMemcpyKind kind = cudaMemcpyDefault);
 
+sp_image create_dense(const sp_image &src);
+
 template<typename T>
 using image_ndarray = ndarray<T, image_rank>;
 

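A short sketch of the new sp_image helpers declared above (the path is a placeholder; what create_dense() guarantees beyond a contiguous layout is inferred from its use in the zmq_client.cpp hunk below):

#include "image_process_v5/sp_image.h"

#include <opencv2/core.hpp>

void sp_image_factory_demo() {
    // from_file() is a thin wrapper over cv::imread; the path is illustrative.
    const auto photo = sp_image::from_file("/path/to/photo.jpg");

    // Same geometry as 'photo' but single channel; cv_type = 0 would keep the source type.
    const auto mask = sp_image::create_like(photo, CV_8UC1);

    // Densely packed copy (assumed: no row padding), as done in encode_request()
    // before streaming raw pixel bytes over ZeroMQ.
    const auto packed = create_dense(mask);
    (void) packed;
}
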
+ 25 - 13
src/impl/apps/debug/app_debug.cpp

@@ -1,21 +1,33 @@
 #include "app_debug.h"
 #include "core/math_helper.hpp"
-#include "module_v5/algorithms/algorithms.h"
+#include "image_process_v5/process_python.h"
+#include "image_process_v5/image_process.h"
 
 #include <GLFW/glfw3.h>
 
 app_debug::app_debug(const create_config &conf) {
-    auto input = fitting_circle_3d::input_type();
-    auto points = std::vector<glm::vec3>();
-    points.emplace_back(100, 1, 2);
-    points.emplace_back(100, 2, 0);
-    points.emplace_back(100, 1, -1);
-    points.emplace_back(100, 0, 0);
-    input.points = to_eigen(points);
+    if (true) {
+        const auto img = sp_image::from_file("/home/tpx/ext/code/FastSAM/images/dogs.jpg");
+        auto prompt = fast_sam_point_prompt();
+        prompt.points.emplace_back(glm::uvec2(520, 360), true);
+        prompt.points.emplace_back(glm::uvec2(620, 300), false);
+        auto opts = fast_sam_options();
+        opts.prompt = prompt;
+        image_fast_sam(img, opts, [](const sp_image &mask) {
+            image_save_png(mask, "dogs");
+            // MAIN_DETACH([] { glfwSetWindowShouldClose(glfwGetCurrentContext(), true); });
+        });
+    }
 
-    auto result = fitting_circle_3d()(input);
-    SPDLOG_DEBUG("{}, {}, {}", result.center[0], result.center[1], result.radius);
-
-
-    glfwSetWindowShouldClose(glfwGetCurrentContext(), true);
+    if (true) {
+        const auto img = sp_image::from_file("/home/tpx/ext/code/FastSAM/images/cat.jpg");
+        auto prompt = fast_sam_point_prompt();
+        prompt.points.emplace_back(glm::uvec2(540, 960), true);
+        auto opts = fast_sam_options();
+        opts.prompt = prompt;
+        image_fast_sam(img, opts, [](const sp_image &mask) {
+            image_save_png(mask, "cat");
+            MAIN_DETACH([] { glfwSetWindowShouldClose(glfwGetCurrentContext(), true); });
+        });
+    }
 }

+ 1 - 1
src/impl/apps/depth_guide/depth_guide.cpp

@@ -51,7 +51,7 @@ app_depth_guide::app_depth_guide(const create_config &_conf) {
 
     auto ai_conf = fast_sam::create_config{
             .img_in = img_color, .mask_out = img_obj_mask,
-            .cuda_ctx = conf.cuda_ctx, .ctx = conf.asio_ctx,
+            // .cuda_ctx = conf.cuda_ctx, .ctx = conf.asio_ctx,
     };
     ai_seg = std::make_unique<fast_sam>(ai_conf);
 

+ 5 - 2
src/module_v5/CMakeLists.txt

@@ -1,6 +1,9 @@
 target_sources(${PROJECT_NAME} PRIVATE
         versatile_saver.cpp
         transform_provider.cpp
-        oblique_calibrator.cpp)
+        oblique_calibrator.cpp
+        zmq_client.cpp)
 
-add_subdirectory(algorithms)
+add_subdirectory(algorithms)
+
+file(COPY python/zmq_server.py DESTINATION ${CMAKE_BINARY_DIR})

+ 0 - 0
src/module_v5/algorithms/algorithms.h → src/module_v5/algorithms.h


+ 1 - 1
src/module_v5/algorithms/fitting_circle_3d.cpp

@@ -1,4 +1,4 @@
-#include "algorithms.h"
+#include "../algorithms.h"
 #include "core/math_helper.hpp"
 
 using namespace Eigen;

+ 1 - 1
src/module_v5/oblique_calibrator.cpp

@@ -4,7 +4,7 @@
 #include "core_v2/utility.hpp"
 #include "device_v5/ndi_stray_point_tracker.h"
 #include "module_v5/transform_provider.h"
-#include "module_v5/algorithms/algorithms.h"
+#include "module_v5/algorithms.h"
 #include "third_party/scope_guard.hpp"
 
 // from sophiar

+ 0 - 0
src/ai/impl/python/zmq_server.py → src/module_v5/python/zmq_server.py


+ 274 - 0
src/module_v5/zmq_client.cpp

@@ -0,0 +1,274 @@
+#include "zmq_client.h"
+#include "network/binary_utility.hpp"
+
+// from sophiar
+#include "utility/coro_signal2.hpp"
+#include "utility/coro_worker.hpp"
+#include "utility/coro_worker_helper_func.hpp"
+
+#include <boost/asio/experimental/concurrent_channel.hpp>
+#include <boost/endian.hpp>
+#include <boost/process.hpp>
+
+#include <azmq/socket.hpp>
+
+#include <thread>
+#include <utility>
+
+namespace bp = boost::process;
+namespace ba = boost::asio;
+using namespace sophiar;
+using boost::system::error_code;
+
+namespace {
+    const char *cv_type_to_dtype(const int type) {
+        static constexpr bool le =
+                boost::endian::order::native == boost::endian::order::little;
+        switch (CV_MAT_DEPTH(type)) {
+            // @formatter:off
+            case CV_8U:  { return "|u1"; }
+            case CV_16U: { return le ? "<u2" : ">u2"; }
+            case CV_32F: { return le ? "<f4" : ">f4"; }
+            // @formatter:on
+            default: {
+                RET_ERROR_E;
+            }
+        }
+    }
+
+    int dtype_to_cv_type(const std::string &dtype, int c) {
+        // @formatter:off
+        if (dtype == "|b1") { return CV_8UC(c); }
+        if (dtype == "|u1") { return CV_8UC(c); }
+        if (dtype == "<u2") { return CV_16UC(c); }
+        if (dtype == "<f4") { return CV_32FC(c); }
+        // @formatter:on
+        RET_ERROR_E;
+    }
+
+    std::filesystem::path get_executable_folder() {
+        constexpr auto path_buf_size = 512;
+        char path_buf[path_buf_size];
+        const auto len = readlink("/proc/self/exe", path_buf, path_buf_size - 1);
+        assert(len != -1 && len < path_buf_size - 1);
+        path_buf[len] = '\0';
+        const auto path = std::filesystem::path(path_buf);
+        return path.parent_path();
+    }
+
+    constexpr auto max_reply_size = 32 * 1024 * 1024; // 32MB
+    constexpr auto max_channel_buffer = 32;
+
+    struct channel_is_full_error final : std::exception {
+        [[nodiscard]] const char *what() const noexcept override {
+            return "Request buffer channel is full.";
+        }
+    };
+}
+
+struct zmq_client::impl {
+    create_config conf;
+
+    ba::io_context aux_ctx;
+    std::optional<std::thread> aux_thread;
+    std::optional<azmq::req_socket> socket;
+
+    using request_channel_type =
+        boost::asio::experimental::concurrent_channel<void(error_code, request_type)>;
+    std::optional<request_channel_type> request_chan;
+    coro_worker::pointer request_worker;
+    std::optional<coro_signal2> socket_signal;
+    std::optional<signal_watcher> socket_watcher;
+
+    std::atomic<size_t> queued_requests = 0;
+
+    static data_type encode_request(const request_type &req) {
+        using namespace nlohmann;
+
+        auto img_list = json::array();
+        assert(!req.img_list.empty());
+        for (auto &img: req.img_list) {
+            auto img_info = json::object();
+            img_info["dtype"] = cv_type_to_dtype(img.cv_type());
+            img_info["shape"] = json::array(
+                {img.height(), img.width(), CV_MAT_CN(img.cv_type())});
+            img_info["transfer"] = "direct";
+            img_list.push_back(img_info);
+        }
+
+        auto head = nlohmann::json();
+        head["model"] = req.model_name;
+        head["images"] = img_list;
+        head["extra"] = req.extra;
+        const auto head_str = head.dump();
+        const size_t head_size = head_str.length();
+
+        size_t total_size = sizeof(size_t) + head_size;
+        for (auto &img: req.img_list) {
+            total_size += sizeof(size_t) + img.byte_size();
+        }
+        auto ret = data_type(total_size);
+
+        auto writer = network_writer(ret);
+        writer << head_size << head_str;
+        for (auto &img: req.img_list) {
+            size_t img_size = img.byte_size();
+            writer << img_size;
+            auto img_dense = create_dense(img);
+            auto helper = read_access_helper(img_dense.host());
+            writer.write_data(helper.ptr(), img_dense.byte_size());
+        }
+
+        assert(writer.empty());
+        return ret;
+    }
+
+    static reply_type decode_reply(const data_type &data) {
+        using namespace nlohmann;
+
+        auto reader = network_reader(data);
+        const auto head_size = reader.read_value<size_t>();
+        auto head_str = reader.read_std_string(head_size);
+        auto head = json::parse(head_str);
+
+        auto ret = zmq_client::reply_type();
+        ret.extra = head["extra"];
+
+        for (auto &img_info: head["images"]) {
+            assert(img_info["transfer"] == "direct");
+
+            auto img_shape = img_info["shape"];
+            assert(img_shape.is_array());
+            if (img_shape.size() == 1) {
+                // no object found
+                assert(img_shape[0].get<int>() == 0);
+                return {};
+            }
+            assert(img_shape.size() == 2
+                || img_shape.size() == 3);
+            const auto img_height = img_shape[0].get<int>();
+            const auto img_width = img_shape[1].get<int>();
+            int img_channels = 1;
+            if (img_shape.size() == 3) {
+                img_channels = img_shape[2].get<int>();
+            }
+
+            const auto img_bytes = reader.read_value<size_t>();
+            const auto img_dtype = img_info["dtype"].get<std::string>();
+            auto img = sp_image::create(dtype_to_cv_type(img_dtype, img_channels),
+                                        cv::Size(img_width, img_height), reader.current_ptr());
+            assert(img.byte_size() == img_bytes);
+            reader.manual_offset(img_bytes);
+            ret.img_list.push_back(img);
+        }
+
+        return ret;
+    }
+
+    ba::awaitable<bool> handle_request() {
+        const auto req = co_await request_chan->async_receive(ba::use_awaitable);
+        const auto req_data = encode_request(req);
+        const auto req_buf = ba::buffer(req_data.start_ptr(), req_data.size);
+        --queued_requests;
+
+        socket->async_send(req_buf, [&](const error_code &ec, const size_t size) {
+            if (ec.failed()) {
+                SPDLOG_ERROR("Zmq client send failed: {}", ec.message());
+                assert(false);
+            }
+            assert(size == req_data.size);
+            socket_signal->try_notify_all();
+        });
+        co_await socket_watcher->coro_wait();
+
+        auto rep_data = data_type(max_reply_size);
+        const auto rep_buf = ba::buffer(rep_data.start_ptr(), rep_data.size);
+        socket->async_receive(rep_buf, [=, this](const error_code &ec, const size_t size) mutable {
+            if (ec.failed()) {
+                SPDLOG_ERROR("Zmq client receive failed: {}", ec.message());
+                assert(false);
+            }
+            rep_data.shrink(size);
+            socket_signal->try_notify_all();
+        });
+        co_await socket_watcher->coro_wait();
+
+        const auto rep = decode_reply(rep_data);
+        req.reply_cb(rep);
+        co_return true;
+    }
+
+    void start_server() {
+        // create python process
+        bp::environment aux_env =
+                boost::this_process::environment();
+        aux_env["PYTHONPATH"] += get_executable_folder();
+        auto aux_proc = bp::child(
+            conf.python_interpreter, conf.server_script_path,
+            aux_env, bp::start_dir(conf.server_working_dir));
+
+        socket.emplace(aux_ctx, true);
+        socket->connect(conf.serv_addr);
+
+        socket_signal.emplace(&aux_ctx);
+        socket_watcher.emplace(socket_signal->new_watcher());
+        request_worker = make_infinite_coro_worker(
+            [this] { return handle_request(); },
+            coro_worker::empty_func, &aux_ctx);
+        request_worker->run();
+
+        if (true) {
+            auto blocker = boost::asio::make_work_guard(aux_ctx);
+            aux_ctx.run();
+        }
+
+        // wait coro worker to exit
+        request_worker->cancel();
+        aux_ctx.restart();
+        aux_ctx.run();
+
+        // cleanup
+        aux_proc.terminate();
+    }
+
+    void on_request(const request_type &req) {
+        if (!request_chan->try_send(error_code(), req)) {
+            throw channel_is_full_error();
+        }
+        ++queued_requests;
+    }
+
+    ba::awaitable<void> on_request_async(const request_type &req) {
+        auto closer = sg::make_scope_guard([this] { ++queued_requests; });
+        return request_chan->async_send(error_code(), req, ba::use_awaitable);
+    }
+
+    explicit impl(create_config _conf)
+        : conf(std::move(_conf)) {
+        request_chan.emplace(aux_ctx, max_channel_buffer);
+        aux_thread.emplace([this] { start_server(); });
+    }
+
+    ~impl() {
+        aux_ctx.stop();
+        aux_thread->join();
+    }
+};
+
+zmq_client::zmq_client(const create_config &conf)
+        : pimpl(std::make_unique<impl>(conf)) {
+}
+
+zmq_client::~zmq_client() = default;
+
+void zmq_client::request(const request_type &req) const {
+    pimpl->on_request(req);
+}
+
+ba::awaitable<void> zmq_client::coro_request(const request_type &req) const {
+    return pimpl->on_request_async(req);
+}
+
+size_t zmq_client::queued_requests() const {
+    return pimpl->queued_requests;
+}

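The rewritten client replaces the old socket_busy flag with a bounded request queue: on_request() pushes into an asio concurrent_channel and a coroutine worker drains it on the client's private io_context. A standalone sketch of that producer/consumer shape in plain Boost.Asio (the sophiar coro_worker/coro_signal2 helpers are replaced by a simple counted loop, so this is illustrative only):

#include <boost/asio.hpp>
#include <boost/asio/experimental/concurrent_channel.hpp>

#include <iostream>
#include <string>

namespace ba = boost::asio;
using boost::system::error_code;

// Channel carrying "requests"; zmq_client uses request_type instead of std::string.
using request_channel =
    ba::experimental::concurrent_channel<void(error_code, std::string)>;

ba::awaitable<void> request_worker(request_channel &chan, int expected) {
    for (int i = 0; i < expected; ++i) {
        // Mirrors handle_request(): pull one queued request, then talk to the socket.
        auto req = co_await chan.async_receive(ba::use_awaitable);
        std::cout << "handling request: " << req << '\n';
    }
}

int main() {
    ba::io_context ctx;
    request_channel chan(ctx, 32); // cf. max_channel_buffer in zmq_client.cpp

    // on_request(): non-blocking enqueue that reports a full buffer instead of blocking.
    if (!chan.try_send(error_code(), std::string("segment frame #1")))
        std::cerr << "channel full, request dropped\n";
    chan.try_send(error_code(), std::string("segment frame #2"));

    ba::co_spawn(ctx, request_worker(chan, 2), ba::detached);
    ctx.run(); // zmq_client runs the equivalent loop on its dedicated aux_thread
}
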
+ 17 - 19
src/ai/zmq_client.h → src/module_v5/zmq_client.h

@@ -1,9 +1,9 @@
-#ifndef DEPTHGUIDE_ZMQ_CLIENT_H
-#define DEPTHGUIDE_ZMQ_CLIENT_H
+#ifndef ZMQ_CLIENT_H
+#define ZMQ_CLIENT_H
 
-#include "core/cuda_helper.hpp"
-#include "core/image_utility_v2.h"
+#include "image_process_v5/sp_image.h"
 
+#include <boost/asio/awaitable.hpp>
 #include <boost/container/static_vector.hpp>
 
 #include <nlohmann/json.hpp>
@@ -12,16 +12,9 @@
 
 class zmq_client {
 public:
-
     static constexpr auto max_img_cnt = 2;
     using img_list_type =
-            boost::container::static_vector<image_ptr, max_img_cnt>;
-
-    struct request_type {
-        img_list_type img_list;
-        std::string model_name;
-        nlohmann::json extra;
-    };
+        boost::container::static_vector<sp_image, max_img_cnt>;
 
     struct reply_type {
         img_list_type img_list;
@@ -31,29 +24,34 @@ public:
 // reply will be called in another thread
     using reply_cb_type = std::function<void(reply_type)>;
 
+    struct request_type {
+        img_list_type img_list;
+        std::string model_name;
+        nlohmann::json extra;
+        reply_cb_type reply_cb;
+    };
+
     struct create_config {
         std::string python_interpreter;
         std::string server_working_dir;
         std::string server_script_path;
         std::string serv_addr; // like "ipc://fast_sam_v1"
-        reply_cb_type reply_cb;
-        CUcontext *cuda_ctx = nullptr;
     };
 
     explicit zmq_client(const create_config &conf);
 
     ~zmq_client();
 
-    // requests will be dropped if the previous request has not been finished.
-    void request(const request_type &req);
+    void request(const request_type &req) const;
 
-    bool is_ideal() const;
+    [[nodiscard]] boost::asio::awaitable<void>
+    coro_request(const request_type &req) const;
 
-    bool is_running() const;
+    [[nodiscard]] size_t queued_requests() const;
 
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
 };
 
-#endif //DEPTHGUIDE_ZMQ_CLIENT_H
+#endif //ZMQ_CLIENT_H

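For reference, a hedged sketch of driving the relocated public interface: the callback-style request() path is what the FastSAM code above uses, while coro_request() lets a Boost.Asio coroutine suspend until the request has been queued (the paths below are placeholders, not the project's real interpreter or server locations):

#include "module_v5/zmq_client.h"

#include <boost/asio/co_spawn.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/io_context.hpp>

#include <utility>

namespace ba = boost::asio;

ba::awaitable<void> run_inference(zmq_client &client, sp_image img) {
    auto req = zmq_client::request_type();
    req.img_list.push_back(std::move(img));
    req.extra["imgsz"] = 1024;
    req.reply_cb = [](const zmq_client::reply_type &rep) {
        // Replies still arrive through reply_cb, on the client's worker thread.
        if (!rep.img_list.empty()) { /* consume the returned mask */ }
    };
    // coro_request() suspends only until the request is queued, not until the reply.
    co_await client.coro_request(req);
}

void inference_demo(ba::io_context &ctx, sp_image img) {
    auto client = zmq_client(zmq_client::create_config{
            .python_interpreter = "/path/to/python",        // placeholder
            .server_working_dir = "/path/to/server",        // placeholder
            .server_script_path = "/path/to/server/app.py", // placeholder
            .serv_addr = "ipc:///tmp/example_v1",
    });
    ba::co_spawn(ctx, run_inference(client, std::move(img)), ba::detached);
    ctx.run(); // the reply callback may fire after this returns, on the worker thread
}
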
+ 4 - 0
src/network/binary_utility.hpp

@@ -185,6 +185,10 @@ public:
         assert(cur_ptr <= end_ptr());
     }
 
+    auto current_ptr() const {
+        return cur_ptr;
+    }
+
     auto current_offset() const {
         return cur_ptr - start_ptr();
     }