frame_sender.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. #include "config.h"
  2. #include "frame_sender.h"
  3. #include "third_party/scope_guard.hpp"
  4. extern "C" {
  5. #include "third_party/rs.h"
  6. }
  7. #include <boost/asio/awaitable.hpp>
  8. #include <boost/asio/buffer.hpp>
  9. #include <boost/asio/deadline_timer.hpp>
  10. #include <boost/asio/detached.hpp>
  11. #include <boost/asio/experimental/awaitable_operators.hpp>
  12. #include <boost/asio/experimental/concurrent_channel.hpp>
  13. #include <boost/asio/io_context.hpp>
  14. #include <boost/asio/ip/udp.hpp>
  15. #include <boost/asio/redirect_error.hpp>
  16. #include <boost/asio/use_awaitable.hpp>
  17. #include <boost/endian.hpp>
  18. #include <boost/date_time/posix_time/posix_time.hpp>
  19. #include <boost/smart_ptr.hpp>
  20. #include <xxhash.h>
  21. #include <spdlog/spdlog.h>
  22. #include <cstdint>
  23. #include <deque>
  24. #include <numeric>
  25. #include <random>
  26. #include <thread>
  27. #include <tuple>
  28. #include <vector>
  29. using namespace boost::asio::experimental::awaitable_operators;
  30. using namespace boost::asio::ip;
  31. using namespace boost::posix_time;
  32. using boost::asio::awaitable;
  33. using boost::asio::buffer;
  34. using boost::asio::deadline_timer;
  35. using boost::asio::detached;
  36. using boost::asio::experimental::concurrent_channel;
  37. using boost::asio::io_context;
  38. using boost::asio::redirect_error;
  39. using boost::asio::use_awaitable;
  40. using boost::system::error_code;
  41. #define EXCEPTION_CHECK(api_call) \
  42. try { \
  43. api_call; \
  44. } catch (std::exception &e) { \
  45. SPDLOG_ERROR("Procedure {} failed at {}:{} with exception {}.", \
  46. #api_call, __FILE__, __LINE__, e.what()); \
  47. return false; \
  48. } void(0)
  49. #define CORO_CHECK(api_call) { \
  50. bool ok = co_await (api_call); \
  51. if (!ok) { \
  52. SPDLOG_ERROR("Coroutine {} failed at {}:{}.", \
  53. #api_call, __FILE__, __LINE__); \
  54. co_return false; \
  55. } \
  56. } void(0)
  57. struct frame_sender::impl {
  58. static constexpr auto buffer_size = 64 * 1024; // 64KB
  59. static constexpr auto rtt_probe_count = 30;
  60. static constexpr auto max_loss_rate = 0.2; // 20% packet loss rate
  61. static constexpr auto frag_header_size = 35;
  62. static constexpr auto channel_buffer_size = 16;
  63. struct frag_header {
  64. uint64_t frag_checksum;
  65. uint8_t frame_type;
  66. uint64_t frame_salt;
  67. uint32_t frame_id;
  68. uint32_t frame_length;
  69. uint32_t block_size;
  70. uint16_t block_count;
  71. uint16_t frame_decode_count;
  72. uint16_t block_id;
  73. };
  74. struct sent_frame_info {
  75. uint64_t salt;
  76. uint32_t id;
  77. ptime time;
  78. };
  79. uint16_t local_port = 5277;
  80. udp::endpoint remote_endpoint;
  81. uint64_t conn_rtt_us = 50; // connection round trip time (RTT)
  82. uint16_t conn_mtu = 1200;
  83. double parity_rate = 0.2;
  84. boost::scoped_ptr<io_context> context;
  85. boost::scoped_ptr<udp::socket> socket;
  86. using chan_type = concurrent_channel<void(error_code, frame_info)>;
  87. boost::scoped_ptr<chan_type> chan;
  88. char *in_data = nullptr, *out_data = nullptr;
  89. enum status_type {
  90. IDLE,
  91. CONNECTING,
  92. CONNECTED
  93. } status = IDLE;
  94. uint32_t frame_count = 0;
  95. std::atomic_flag *idr_flag = nullptr;
  96. int frame_rate = default_camera_fps;
  97. time_duration frame_timeout, conn_timeout;
  98. ptime last_confirm_time;
  99. boost::scoped_ptr<deadline_timer> keepalive_timer;
  100. std::deque<sent_frame_info> sent_list; // pending confirm list
  101. std::unique_ptr<std::thread> work_thread;
  102. impl() {
  103. in_data = (char *) malloc(buffer_size);
  104. out_data = (char *) malloc(buffer_size);
  105. context.reset(new io_context{});
  106. chan.reset(new chan_type{*context, channel_buffer_size});
  107. keepalive_timer.reset(new deadline_timer{*context});
  108. auto error_handler = [](std::exception_ptr ep) {
  109. if (!ep) {
  110. SPDLOG_ERROR("Infinite loop exited with no error.");
  111. return;
  112. }
  113. try {
  114. std::rethrow_exception(ep);
  115. } catch (std::exception &e) {
  116. SPDLOG_ERROR("Infinite loop exited with error: {}", e.what());
  117. }
  118. };
  119. co_spawn(*context, main_loop(), error_handler);
  120. co_spawn(*context, keepalive_loop(), error_handler);
  121. }
  122. ~impl() {
  123. stop();
  124. free(in_data);
  125. free(out_data);
  126. }
  127. static uint64_t generate_salt() {
  128. static std::random_device device;
  129. static std::default_random_engine engine{device()};
  130. static std::uniform_int_distribution<uint64_t> dist;
  131. return dist(engine);
  132. }
  133. template<typename T>
  134. static char *write_binary_number(char *ptr, T val) {
  135. static constexpr auto need_swap =
  136. (boost::endian::order::native != boost::endian::order::big);
  137. auto real_ptr = (T *) ptr;
  138. if constexpr (need_swap) {
  139. *real_ptr = boost::endian::endian_reverse(val);
  140. } else {
  141. *real_ptr = val;
  142. }
  143. return ptr + sizeof(T);
  144. }
  145. template<typename T>
  146. static char *read_binary_number(char *ptr, T *val) {
  147. static constexpr auto need_swap =
  148. (boost::endian::order::native != boost::endian::order::big);
  149. *val = *(T *) ptr;
  150. if constexpr (need_swap) {
  151. boost::endian::endian_reverse_inplace(*val);
  152. }
  153. return ptr + sizeof(T);
  154. }
  155. static char *write_frag_header(char *ptr, const frag_header *header) {
  156. #define WRITE(member) ptr = write_binary_number(ptr, header->member)
  157. WRITE(frag_checksum);
  158. WRITE(frame_type);
  159. WRITE(frame_salt);
  160. WRITE(frame_id);
  161. WRITE(frame_length);
  162. WRITE(block_size);
  163. WRITE(block_count);
  164. WRITE(frame_decode_count);
  165. WRITE(block_id);
  166. #undef WRITE
  167. return ptr;
  168. }
  169. // calculate and fill hash value for out buffer
  170. bool calc_out_hash(char *end_ptr) {
  171. assert(end_ptr - out_data > sizeof(uint64_t));
  172. static auto hash_state = XXH64_createState();
  173. auto out_ptr = out_data + sizeof(uint64_t);
  174. CALL_CHECK(XXH64_reset(hash_state, 0) != XXH_ERROR);
  175. CALL_CHECK(XXH64_update(hash_state, out_ptr, end_ptr - out_ptr) != XXH_ERROR);
  176. write_binary_number(out_data, XXH64_digest(hash_state));
  177. return true;
  178. }
  179. bool check_rtt_reply(uint64_t salt, uint16_t out_len, uint16_t in_len) {
  180. static constexpr auto desired_length =
  181. sizeof(uint8_t) + sizeof(uint64_t) + sizeof(uint16_t);
  182. if (in_len != desired_length) return false;
  183. // check frag type
  184. if (in_data[0] != 'R') return false;
  185. // check frame salt
  186. uint64_t in_salt;
  187. auto in_ptr = read_binary_number(in_data + sizeof(uint8_t), &in_salt);
  188. if (in_salt != salt) return false;
  189. // check returned length
  190. uint16_t in_frag_len;
  191. read_binary_number(in_ptr, &in_frag_len);
  192. if (in_frag_len != out_len) return false;
  193. return true;
  194. }
  195. template<typename T>
  196. static T power2(T x) { return x * x; }
  197. static uint64_t calc_upper_rtt(const std::vector<uint64_t> &v) {
  198. auto sum = std::accumulate(v.begin(), v.end(), 0.0);
  199. auto mean = sum / (double) v.size();
  200. auto sum2 = std::accumulate(v.begin(), v.end(), 0.0,
  201. [=](double a, uint64_t b) { return a + power2((double) b - mean); });
  202. auto std_var = std::sqrt(sum2 / (double) v.size());
  203. return (uint64_t) (mean + 5 * std_var);
  204. }
  205. awaitable<bool> probe_rtt() {
  206. static const auto max_rtt = seconds(1);
  207. auto timer = deadline_timer{*context};
  208. std::vector<uint64_t> rtt_result;
  209. auto in_buf = buffer(in_data, buffer_size);
  210. for (int k = 0; k < rtt_probe_count; ++k) {
  211. auto salt = generate_salt();
  212. // write probe frag data
  213. auto out_ptr = out_data;
  214. out_ptr = write_binary_number(out_ptr, (uint64_t) 0); // checksum placeholder
  215. out_ptr = write_binary_number(out_ptr, 'T');
  216. out_ptr = write_binary_number(out_ptr, salt);
  217. // fill frag with random data
  218. auto limit_ptr = out_data + conn_mtu;
  219. auto content_len = 0;
  220. while (out_ptr + sizeof(uint64_t) < limit_ptr) {
  221. out_ptr = write_binary_number(out_ptr, generate_salt());
  222. content_len += sizeof(uint64_t);
  223. }
  224. calc_out_hash(out_ptr);
  225. auto out_buf = buffer(out_data, out_ptr - out_data);
  226. socket->send_to(out_buf, remote_endpoint);
  227. // wait for reply or timeout
  228. auto start_time = microsec_clock::local_time();
  229. timer.expires_from_now(max_rtt);
  230. for (;;) {
  231. udp::endpoint sender_endpoint;
  232. auto ret = co_await (socket->async_receive_from(in_buf, sender_endpoint, use_awaitable) ||
  233. timer.async_wait(use_awaitable));
  234. if (ret.index() == 0) { // received reply
  235. if (sender_endpoint != remote_endpoint) continue;
  236. if (check_rtt_reply(salt, content_len, std::get<0>(ret))) {
  237. auto end_time = microsec_clock::local_time();
  238. auto rtt_us = (end_time - start_time).total_microseconds();
  239. rtt_result.push_back(rtt_us);
  240. SPDLOG_TRACE("RTT probe {}: {}us.", k, rtt_us);
  241. break;
  242. }
  243. } else { // timeout
  244. assert(ret.index() == 1);
  245. SPDLOG_TRACE("RTT probe {}: failed.", k);
  246. break;
  247. }
  248. }
  249. }
  250. if (rtt_result.size() <= (int) (rtt_probe_count * max_loss_rate)) {
  251. SPDLOG_WARN("Packet loss rate too high, cannot probe RTT.");
  252. co_return false;
  253. }
  254. conn_rtt_us = calc_upper_rtt(rtt_result);
  255. SPDLOG_INFO("Connection MaxRTT: {}us.", conn_rtt_us);
  256. co_return true;
  257. }
  258. awaitable<bool> setup_connection() {
  259. // socket->connect(remote_endpoint);
  260. CORO_CHECK(probe_rtt());
  261. // TODO: detect mtu
  262. // TODO: detect packet loss rate
  263. // reset timer
  264. frame_timeout = milliseconds(conn_rtt_us / 1000 + 3 * 1000 / frame_rate);
  265. conn_timeout = seconds(1); // TODO
  266. last_confirm_time = microsec_clock::local_time();
  267. // keepalive_timer->expires_at(boost::posix_time::pos_infin);
  268. sent_list.clear();
  269. idr_flag->test_and_set();
  270. co_return true;
  271. }
  272. void handle_frame_confirm(size_t msg_len) {
  273. static constexpr auto desired_length =
  274. sizeof(uint8_t) + sizeof(uint64_t);
  275. if (msg_len != desired_length) return;
  276. // read salt
  277. uint64_t frame_salt;
  278. read_binary_number(in_data + 1, &frame_salt);
  279. static uint64_t last_frame_salt;
  280. if (frame_salt == last_frame_salt) return; // already confirmed
  281. // erase confirmed frame
  282. auto iter = sent_list.begin();
  283. while (iter != sent_list.end() && iter->salt != frame_salt) ++iter;
  284. if (iter == sent_list.end()) return;
  285. SPDLOG_TRACE("Frame {} confirmed.", iter->id);
  286. sent_list.erase(sent_list.begin(), ++iter);
  287. // reset timer
  288. if (sent_list.empty()) {
  289. keepalive_timer->expires_at(pos_infin);
  290. } else {
  291. keepalive_timer->expires_at(sent_list.begin()->time + frame_timeout);
  292. }
  293. last_confirm_time = microsec_clock::local_time();
  294. }
  295. void handle_upd_message(size_t msg_len, const udp::endpoint &sender) {
  296. assert(status != CONNECTING);
  297. if (status == IDLE) {
  298. if (msg_len == 1 && in_data[0] == 'R') { // reset connection
  299. if (status == CONNECTING) {
  300. SPDLOG_WARN("Only one connection is supported.");
  301. return;
  302. }
  303. status = CONNECTING;
  304. remote_endpoint = sender;
  305. SPDLOG_INFO("Reset connection with {}:{}.", sender.address().to_string(), sender.port());
  306. co_spawn(*context, setup_connection(), [this](std::exception_ptr e, bool ok) {
  307. assert(!e);
  308. SPDLOG_INFO("Reset connection {}.", ok ? "succeeded" : "failed");
  309. if (ok) {
  310. status = CONNECTED;
  311. } else {
  312. status = IDLE;
  313. remote_endpoint = {};
  314. }
  315. });
  316. }
  317. } else if (status == CONNECTED) {
  318. if (sender != remote_endpoint) return;
  319. if (msg_len == 1 && in_data[0] == 'E') { // client exit
  320. keepalive_timer->expires_at(pos_infin);
  321. status = IDLE;
  322. SPDLOG_INFO("Client left.");
  323. } else if (in_data[0] == 'C') { // confirmation
  324. handle_frame_confirm(msg_len);
  325. }
  326. }
  327. }
  328. void handle_frame(const frame_info &info) {
  329. ++frame_count;
  330. auto frame_deleter = sg::make_scope_guard([&]() {
  331. free(info.data);
  332. });
  333. if (status != CONNECTED) {
  334. SPDLOG_TRACE("Frame {} received, but connection is not ready.");
  335. return;
  336. }
  337. // prepare buffer for frame
  338. auto block_size = (conn_mtu - frag_header_size) & 0xffffff00; // TODO: support for larger frame
  339. auto data_blocks = (info.length + block_size - 1) / block_size;
  340. auto parity_blocks = std::max(1, (int) (data_blocks * parity_rate));
  341. auto total_blocks = data_blocks + parity_blocks;
  342. auto block_data = (uint8_t *) malloc(total_blocks * block_size);
  343. auto block_ptr = (uint8_t **) malloc(total_blocks * sizeof(void *));
  344. for (int i = 0; i < total_blocks; ++i) {
  345. block_ptr[i] = block_data + block_size * i;
  346. }
  347. auto rs = reed_solomon_new(data_blocks, parity_blocks);
  348. assert(rs != nullptr);
  349. auto closer = sg::make_scope_guard([&]() {
  350. free(block_data);
  351. free(block_ptr);
  352. reed_solomon_release(rs);
  353. });
  354. // calc reed-solomon
  355. memcpy(block_data, info.data, info.length);
  356. auto ret = reed_solomon_encode2(rs, block_ptr, total_blocks, block_size);
  357. assert(ret == 0);
  358. // send encoded frames
  359. frag_header header;
  360. header.frame_type = info.is_idr_frame ? 'I' : 'P';
  361. header.frame_salt = generate_salt();
  362. header.frame_id = frame_count;
  363. header.frame_length = info.length;
  364. header.block_size = block_size;
  365. header.block_count = total_blocks;
  366. header.frame_decode_count = data_blocks;
  367. for (int i = 0; i < total_blocks; ++i) {
  368. header.block_id = i;
  369. auto out_ptr = write_frag_header(out_data, &header);
  370. assert(out_ptr - out_data == frag_header_size);
  371. memcpy(out_ptr, block_ptr[i], block_size);
  372. out_ptr += block_size;
  373. calc_out_hash(out_ptr);
  374. auto out_buf = buffer(out_data, out_ptr - out_data);
  375. socket->send_to(out_buf, remote_endpoint);
  376. }
  377. SPDLOG_TRACE("Frame {} is sent with {}+{} blocks.",
  378. header.frame_id, header.block_count, header.block_count - header.frame_decode_count);
  379. // config frame queue and timeout
  380. if (keepalive_timer->expires_at() == pos_infin) {
  381. keepalive_timer->expires_from_now(frame_timeout);
  382. SPDLOG_TRACE("Timer reset to {}.", to_simple_string(keepalive_timer->expires_at()));
  383. }
  384. sent_list.push_back({header.frame_salt, header.frame_id,
  385. microsec_clock::local_time()});
  386. }
  387. awaitable<void> main_loop() {
  388. auto in_buf = buffer(in_data, buffer_size);
  389. for (;;) {
  390. if (status == CONNECTING) {
  391. auto ret = co_await chan->async_receive(use_awaitable);
  392. handle_frame(ret);
  393. } else { // IDLE or CONNECTED
  394. udp::endpoint sender_endpoint;
  395. auto ret = co_await (socket->async_receive_from(in_buf, sender_endpoint, use_awaitable) ||
  396. chan->async_receive(use_awaitable));
  397. if (ret.index() == 0) { // udp message
  398. handle_upd_message(std::get<0>(ret), sender_endpoint);
  399. } else { // new frame
  400. assert(ret.index() == 1);
  401. handle_frame(std::get<1>(ret));
  402. }
  403. }
  404. }
  405. }
  406. awaitable<void> keepalive_loop() {
  407. for (;;) {
  408. error_code ec;
  409. co_await keepalive_timer->async_wait(redirect_error(use_awaitable, ec));
  410. if (ec == boost::asio::error::operation_aborted) {
  411. SPDLOG_TRACE("Timer aborted.");
  412. continue;
  413. }
  414. SPDLOG_WARN("Connection timeout.");
  415. keepalive_timer->expires_at(pos_infin);
  416. auto now = microsec_clock::local_time();
  417. if (now - last_confirm_time > conn_timeout) {
  418. status = IDLE;
  419. SPDLOG_WARN("Connection closed.");
  420. } else {
  421. idr_flag->test_and_set();
  422. }
  423. sent_list.clear();
  424. }
  425. }
  426. void start() {
  427. // clean channel
  428. if (chan != nullptr) {
  429. while (chan->try_receive([](error_code e, frame_info &&info) {
  430. free(info.data);
  431. }));
  432. }
  433. keepalive_timer->expires_at(pos_infin);
  434. auto local_endpoint = udp::endpoint{udp::v4(), local_port};
  435. socket.reset(new udp::socket{*context, local_endpoint});
  436. socket->set_option(udp::socket::send_buffer_size{10 * 1024 * 1024}); // 10MB send buffer
  437. assert(socket->is_open());
  438. if (context->stopped()) {
  439. context->restart();
  440. }
  441. // request idr frame
  442. idr_flag->test_and_set();
  443. assert(work_thread == nullptr);
  444. work_thread = std::make_unique<std::thread>([this]() {
  445. try {
  446. context->run();
  447. } catch (std::exception &e) {
  448. SPDLOG_ERROR("Frame sender error: {}", e.what());
  449. }
  450. });
  451. }
  452. void stop() {
  453. if (work_thread == nullptr) return;
  454. context->stop();
  455. work_thread->join();
  456. work_thread = nullptr;
  457. socket->close();
  458. }
  459. };
  460. frame_sender::frame_sender()
  461. : pimpl(std::make_unique<impl>()) {
  462. fec_init();
  463. }
  464. frame_sender::~frame_sender() = default;
  465. bool frame_sender::start(uint16_t local_port, std::atomic_flag *idr_flag, int fps) {
  466. pimpl->local_port = local_port;
  467. pimpl->idr_flag = idr_flag;
  468. pimpl->frame_rate = fps;
  469. EXCEPTION_CHECK(pimpl->start());
  470. return true;
  471. }
  472. void frame_sender::stop() {
  473. pimpl->stop();
  474. }
/// Queues a frame for the worker thread without blocking.
/// Once queued, the sender frees info.data itself: after sending it, when
/// dropping it, or when start() discards stale frames.
/// NOTE(review): returns false when try_send() rejects the frame — presumably
/// the 16-entry queue is full (CALL_CHECK is defined elsewhere; confirm). In
/// that case the frame was not queued, and info.data is not freed here.
bool frame_sender::send_frame(const frame_sender::frame_info &info) {
    CALL_CHECK(pimpl->chan->try_send(error_code{}, info));
    return true;
}
  479. bool frame_sender::is_running() {
  480. return pimpl->work_thread != nullptr;
  481. }