// registration.cpp (16 KB)
  1. #include "registration.h"
  2. #include "core/utility.hpp"
  3. #include "render_v3/vtk_viewer.h"
  4. #include <vtkActor.h>
  5. #include <vtkCellLocator.h>
  6. #include <vtkIterativeClosestPointTransform.h>
  7. #include <vtkLandmarkTransform.h>
  8. #include <vtkMatrix4x4.h>
  9. #include <vtkNamedColors.h>
  10. #include <vtkPolyData.h>
  11. #include <vtkProperty.h>
  12. #include <vtkSmartPointer.h>
  13. #include <queue>
  14. #include <vector>
  15. using namespace vtk_viewer_helper;
  16. struct registration::impl {
  /// Fewest points needed before a registration is computed or collection can be started.
  static constexpr int MIN_REG_POINTS = 3;
  /// Closest-point distances (mm) below this value are sorted into level_points[0].
  static constexpr double DIS_LIMIT_1 = 0.5;
  /// Closest-point distances (mm) below this value (and not below DIS_LIMIT_1) go to level_points[1];
  /// everything else goes to level_points[2].
  static constexpr double DIS_LIMIT_2 = 1.0;
  /// Placeholder stored in probe_test_error while no probe pose is available.
  static constexpr float NAN_VALUE = std::numeric_limits<float>::quiet_NaN();
  /// Selects which point-set actors switch_viewer_mode() adds to the viewer.
  enum show_mode {
    CONFIG = 0,      // show the configured fiducial points (all_points)
    COLLECTING = 1,  // show the pending / current / finished point sets
    FINISHED = 2,    // show the collected points grouped by error level
  };
  26. struct target_store_type {
  27. std::string name;
  28. vtkSmartPointer<vtkPolyData> model;
  29. vtkSmartPointer<vtkActor> model_actor;
  30. vtkSmartPointer<vtkCellLocator> model_locator;
  31. std::unique_ptr<smart_point_sets> all_points;
  32. std::unique_ptr<smart_point_sets> pending_points;
  33. std::unique_ptr<smart_point_sets> current_point;
  34. std::unique_ptr<smart_point_sets> finished_points;
  35. std::unique_ptr<smart_point_sets> level_points[3];
  36. std::string target_var_name;
  37. std::string point_var_name;
  38. std::string collect_obj_name;
  39. std::string probe_var_name;
  40. bool is_finished = false;
  41. float max_error = 0;
  42. };
  43. int cur_target_id = -1;
  44. std::vector<target_store_type> targets;
  45. target_store_type *cur_target = nullptr;
  46. std::unique_ptr<vtk_viewer> viewer;
  47. bool is_picking = false;
  48. bool is_collecting = false;
  49. std::vector<Eigen::Vector3d> collected_points;
  50. std::queue<std::function<void()>> eq;
  51. sophiar::local_connection *conn = nullptr;
  52. vtkSmartPointer<vtkActor> probe_actor;
  53. float probe_test_error = NAN_VALUE;
  54. // colors
  55. vtkColor4d target_color;
  56. vtkColor4d idle_color;
  57. vtkColor4d pending_color;
  58. vtkColor4d current_color;
  59. vtkColor4d finished_color;
  60. vtkColor4d level_color[3]; // - 0.5mm - 1.0mm -
  61. impl() {
  62. viewer = std::make_unique<vtk_viewer>();
  63. // preload colors
  64. vtkNew<vtkNamedColors> colors;
  65. target_color = colors->GetColor4d("silver");
  66. idle_color = colors->GetColor4d("white");
  67. pending_color = colors->GetColor4d("red");
  68. current_color = colors->GetColor4d("yellow");
  69. finished_color = colors->GetColor4d("lime");
  70. level_color[0] = colors->GetColor4d("lime");
  71. level_color[1] = colors->GetColor4d("yellow");
  72. level_color[2] = colors->GetColor4d("red");
  73. }
  74. void switch_viewer_mode(show_mode mode) {
  75. assert(cur_target != nullptr);
  76. viewer->clear_actor();
  77. viewer->add_actor(cur_target->model_actor);
  78. viewer->add_actor(probe_actor);
  79. switch (mode) {
  80. case CONFIG: {
  81. viewer->add_actor(cur_target->all_points->get_actor());
  82. break;
  83. }
  84. case COLLECTING: {
  85. viewer->add_actor(cur_target->pending_points->get_actor());
  86. viewer->add_actor(cur_target->current_point->get_actor());
  87. viewer->add_actor(cur_target->finished_points->get_actor());
  88. break;
  89. }
  90. case FINISHED: {
  91. for (auto &points: cur_target->level_points) {
  92. viewer->add_actor(points->get_actor());
  93. }
  94. break;
  95. }
  96. }
  97. }
  98. void add_target(const registration_target &conf) {
  99. auto &target = targets.emplace_back();
  100. target.name = conf.name;
  101. target.model = load_any(conf.model_path);
  102. target.model_actor = create_actor(target.model);
  103. target.model_locator = vtkSmartPointer<vtkCellLocator>::New();
  104. target.model_locator->SetDataSet(target.model);
  105. target.model_locator->BuildLocator();
  106. // copy information
  107. target.target_var_name = conf.target_var_name;
  108. target.collect_obj_name = conf.collect_obj_name;
  109. target.point_var_name = conf.collect_var_name;
  110. target.probe_var_name = conf.probe_var_name;
  111. // create point sets
  112. target.all_points = std::make_unique<smart_point_sets>();
  113. target.pending_points = std::make_unique<smart_point_sets>();
  114. target.current_point = std::make_unique<smart_point_sets>();
  115. target.finished_points = std::make_unique<smart_point_sets>();
  116. for (auto &points: target.level_points) {
  117. points = std::make_unique<smart_point_sets>();
  118. }
  119. // set colors
  120. target.model_actor->GetProperty()->SetColor(target_color.GetData());
  121. target.all_points->get_actor()->GetProperty()->SetColor(idle_color.GetData());
  122. target.pending_points->get_actor()->GetProperty()->SetColor(pending_color.GetData());
  123. target.current_point->get_actor()->GetProperty()->SetColor(current_color.GetData());
  124. target.finished_points->get_actor()->GetProperty()->SetColor(finished_color.GetData());
  125. for (auto k = 0; k < 3; ++k) {
  126. target.level_points[k]->get_actor()->GetProperty()->SetColor(level_color[k].GetData());
  127. }
  128. }
  129. void change_target() {
  130. assert(cur_target_id != -1);
  131. cur_target = &targets[cur_target_id];
  132. if (cur_target->is_finished) {
  133. switch_viewer_mode(FINISHED);
  134. } else {
  135. switch_viewer_mode(CONFIG);
  136. }
  137. viewer->reset_camera();
  138. }
  139. bool progress_reg_point() {
  140. if (!cur_target->current_point->empty()) {
  141. cur_target->finished_points->add_point(
  142. cur_target->current_point->pop_front());
  143. }
  144. if (cur_target->pending_points->empty()) return false;
  145. cur_target->current_point->add_point(
  146. cur_target->pending_points->pop_front());
  147. return true;
  148. }
  149. void start() {
  150. assert(cur_target != nullptr);
  151. assert(!is_collecting);
  152. is_collecting = true;
  153. cur_target->is_finished = false;
  154. conn->mark_variable_disposal(cur_target->point_var_name);
  155. CALL_CHECK(conn->start_object(cur_target->collect_obj_name));
  156. collected_points.clear();
  157. // copy points
  158. cur_target->all_points->for_each([this](void *, const Eigen::Vector3d &point) {
  159. cur_target->pending_points->add_point(point);
  160. });
  161. progress_reg_point();
  162. // switch actor
  163. switch_viewer_mode(COLLECTING);
  164. }
  165. void stop() {
  166. assert(is_collecting);
  167. is_collecting = false;
  168. CALL_CHECK(conn->stop_object(cur_target->collect_obj_name));
  169. // clear point set
  170. cur_target->pending_points->clear();
  171. cur_target->current_point->clear();
  172. cur_target->finished_points->clear();
  173. // switch actor
  174. if (cur_target->is_finished) {
  175. switch_viewer_mode(FINISHED);
  176. } else {
  177. switch_viewer_mode(CONFIG);
  178. }
  179. }
  180. auto calc_closest_point(const Eigen::Vector3d &point) {
  181. Eigen::Vector3d close_point;
  182. vtkIdType cell_id;
  183. int sub_id;
  184. double dis2;
  185. cur_target->model_locator->FindClosestPoint(
  186. point.data(), close_point.data(), cell_id, sub_id, dis2);
  187. return std::make_tuple(close_point, std::sqrt(dis2));
  188. }
  189. // return if it needs to continue
  190. bool calc_result() {
  191. // prepare landmark
  192. auto num_points = collected_points.size();
  193. if (num_points < MIN_REG_POINTS) return true;
  194. auto source_points = Eigen::Matrix3Xd{3, num_points};
  195. auto target_points = Eigen::Matrix3Xd{3, num_points};
  196. for (auto k = 0; k < num_points; ++k) {
  197. source_points.col(k) = (*cur_target->all_points)[k];
  198. }
  199. for (auto k = 0; k < num_points; ++k) {
  200. target_points.col(k) = collected_points[k];
  201. }
  202. // calculate landmark
  203. auto result = (Eigen::Isometry3d) Eigen::umeyama(source_points, target_points, false);
  204. // prepare icp
  205. vtkNew<vtkPoints> icp_points;
  206. auto landmark_inv = result.inverse();
  207. for (auto k = 0; k < num_points; ++k) {
  208. Eigen::Vector3d point = landmark_inv * collected_points[k];
  209. icp_points->InsertNextPoint(point.data());
  210. }
  211. // calculate icp
  212. vtkNew<vtkIterativeClosestPointTransform> icp;
  213. vtkNew<vtkPolyData> tmp_poly;
  214. tmp_poly->SetPoints(icp_points);
  215. icp->GetLandmarkTransform()->SetModeToRigidBody();
  216. icp->SetSource(tmp_poly);
  217. icp->SetTarget(cur_target->model);
  218. icp->Modified();
  219. icp->Update();
  220. // refine result
  221. Eigen::Isometry3d trans_delta;
  222. for (auto i = 0; i < 4; ++i)
  223. for (auto j = 0; j < 4; ++j) {
  224. trans_delta(i, j) = icp->GetMatrix()->GetElement(i, j);
  225. }
  226. result = result * trans_delta.inverse();
  227. // commit result
  228. conn->update_transform_variable(cur_target->target_var_name, result);
  229. // calculate error only when all points are collected
  230. if (num_points != cur_target->all_points->size()) return true;
  231. // calculate error
  232. cur_target->max_error = 0;
  233. for (auto &points: cur_target->level_points) {
  234. points->clear();
  235. }
  236. auto result_inv = result.inverse();
  237. for (auto k = 0; k < num_points; ++k) {
  238. Eigen::Vector3d point = result_inv * collected_points[k];
  239. // find the closest point
  240. double dis;
  241. std::tie(std::ignore, dis) = calc_closest_point(point);
  242. // update results
  243. if (dis > cur_target->max_error) {
  244. cur_target->max_error = dis;
  245. }
  246. if (dis < DIS_LIMIT_1) {
  247. cur_target->level_points[0]->add_point(point);
  248. } else if (dis < DIS_LIMIT_2) {
  249. cur_target->level_points[1]->add_point(point);
  250. } else {
  251. cur_target->level_points[2]->add_point(point);
  252. }
  253. }
  254. cur_target->is_finished = true;
  255. return false;
  256. }
  257. void try_picking() {
  258. assert(is_picking);
  259. auto val = viewer->get_picked_point();
  260. if (!val.has_value()) return;
  261. cur_target->all_points->add_point(val.value());
  262. }
  263. void try_collect() {
  264. assert(is_collecting);
  265. auto val = conn->query_scalarxyz_variable(cur_target->point_var_name);
  266. if (!val.has_value()) return;
  267. collected_points.emplace_back(val.value());
  268. SPDLOG_INFO("Collected point ({}, {}, {}).", val->x(), val->y(), val->z());
  269. calc_result();
  270. if (!progress_reg_point()) {
  271. stop();
  272. }
  273. }
  274. void show() {
  275. if (ImGui::Begin("Registration Control")) {
  276. if (is_collecting || is_picking) {
  277. ImGui::BeginDisabled();
  278. }
  279. const char *target_name = (cur_target == nullptr) ? nullptr : cur_target->name.c_str();
  280. if (ImGui::BeginCombo("Target", target_name)) {
  281. for (auto k = 0; k < targets.size(); ++k) {
  282. bool is_selected = (k == cur_target_id);
  283. if (ImGui::Selectable(targets[k].name.c_str(), is_selected)) {
  284. cur_target_id = k;
  285. eq.emplace([this] { change_target(); });
  286. }
  287. if (is_selected) {
  288. ImGui::SetItemDefaultFocus();
  289. }
  290. }
  291. ImGui::EndCombo();
  292. }
  293. if (is_collecting || is_picking) {
  294. ImGui::EndDisabled();
  295. }
  296. if (cur_target != nullptr) {
  297. auto point_set = cur_target->all_points.get();
  298. if (ImGui::CollapsingHeader("Actions")) {
  299. if (is_collecting) {
  300. ImGui::BeginDisabled();
  301. }
  302. if (ImGui::Checkbox("Config Points", &is_picking)) {
  303. if (is_picking) {
  304. eq.emplace([this] {
  305. switch_viewer_mode(CONFIG);
  306. viewer->start_picking();
  307. });
  308. } else {
  309. eq.emplace([this] { viewer->stop_picking(); });
  310. }
  311. }
  312. if (is_collecting) {
  313. ImGui::EndDisabled();
  314. }
  315. if (!is_picking && point_set->size() >= MIN_REG_POINTS) {
  316. ImGui::SameLine();
  317. if (!is_collecting) {
  318. if (ImGui::Button("Start")) {
  319. eq.emplace([this] { start(); });
  320. }
  321. } else {
  322. if (ImGui::Button("Stop")) {
  323. eq.emplace([this] { stop(); });
  324. }
  325. }
  326. }
  327. }
  328. if (ImGui::CollapsingHeader("Infos")) {
  329. void *token_delete = nullptr;
  330. if (ImGui::TreeNode("Fiducial Points")) {
  331. ImGui::PushItemWidth(200);
  332. point_set->for_each([&](void *token, const Eigen::Vector3d &point) {
  333. Eigen::Vector3f point_f = point.cast<float>();
  334. ImGui::PushID(token);
  335. ImGui::Bullet();
  336. ImGui::BeginDisabled();
  337. ImGui::InputFloat3("", point_f.data());
  338. ImGui::EndDisabled();
  339. if (is_picking) {
  340. ImGui::SameLine();
  341. if (ImGui::SmallButton("Delete")) {
  342. token_delete = token;
  343. }
  344. }
  345. ImGui::PopID();
  346. });
  347. ImGui::PopItemWidth();
  348. ImGui::TreePop();
  349. }
  350. if (token_delete != nullptr) {
  351. eq.emplace([=] { point_set->remove_point(token_delete); });
  352. }
  353. ImGui::PushItemWidth(100);
  354. ImGui::Bullet();
  355. ImGui::InputFloat("Probe Test Error (mm)", &probe_test_error);
  356. if (cur_target->is_finished) {
  357. ImGui::Bullet();
  358. ImGui::InputFloat("Max Fiducial Error (mm)", &cur_target->max_error);
  359. }
  360. ImGui::PopItemWidth();
  361. }
  362. }
  363. }
  364. ImGui::End();
  365. if (cur_target == nullptr) return;
  366. if (ImGui::Begin("Registration View", nullptr, vtk_viewer::no_scroll_flag)) {
  367. viewer->show();
  368. }
  369. ImGui::End();
  370. }
  371. void process() {
  372. while (!eq.empty()) {
  373. eq.front()();
  374. eq.pop();
  375. }
  376. if (is_picking) {
  377. try_picking();
  378. } else if (is_collecting) {
  379. try_collect();
  380. }
  381. // update probe transform
  382. if (cur_target != nullptr) {
  383. auto trans = conn->query_transform_variable(cur_target->probe_var_name);
  384. update_actor_pose(probe_actor, trans);
  385. if (trans.has_value()) {
  386. Eigen::Vector3d point = trans.value().translation();
  387. std::tie(std::ignore, probe_test_error) = calc_closest_point(point);
  388. } else {
  389. probe_test_error = NAN_VALUE;
  390. }
  391. }
  392. }
  393. };
// Defaulted in this translation unit, after registration::impl has been fully defined above.
// NOTE(review): presumably required because `pimpl` is a smart pointer to a type that is
// only forward-declared in registration.h — confirm against the header.
registration::~registration() = default;
  395. registration *registration::create(const registration_config &conf) {
  396. auto pimpl = new impl{};
  397. pimpl->conn = conf.conn;
  398. pimpl->probe_actor = create_actor(conf.probe_model_path);
  399. auto ret = new registration{};
  400. ret->pimpl.reset(pimpl);
  401. return ret;
  402. }
  403. void registration::add_target(const registration_target &item) {
  404. pimpl->add_target(item);
  405. }
  406. void registration::show() {
  407. pimpl->show();
  408. }
  409. void registration::process() {
  410. pimpl->process();
  411. }