You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

520 lines
13 KiB

4 years ago
4 years ago
  1. /*
  2. Copyright Jeroen Vreeken (jeroen@vreeken.net), 2017
  3. This program is free software: you can redistribute it and/or modify
  4. it under the terms of the GNU General Public License as published by
  5. the Free Software Foundation, either version 3 of the License, or
  6. (at your option) any later version.
  7. This program is distributed in the hope that it will be useful,
  8. but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. GNU General Public License for more details.
  11. You should have received a copy of the GNU General Public License
  12. along with this program. If not, see <http://www.gnu.org/licenses/>.
  13. */
#include <dml/dml_host.h>
#include <dml/dml_client.h>
#include <dml/dml_connection.h>
#include <dml/dml_crypto.h>
#include <dml/dml_packet.h>
#include <dml/dml_poll.h>

#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
  22. struct dml_host {
  23. struct dml_client *client;
  24. struct dml_connection *connection;
  25. char **mime_filter;
  26. int mime_filter_nr;
  27. void (*connection_closed_cb)(struct dml_host *host, void *arg);
  28. void *connection_closed_cb_arg;
  29. void (*stream_added_cb)(struct dml_host *host, struct dml_stream *ds, void *arg);
  30. void *stream_added_cb_arg;
  31. void (*stream_removed_cb)(struct dml_host *host, struct dml_stream *ds, void *arg);
  32. void *stream_removed_cb_arg;
  33. void (*stream_header_cb)(struct dml_host *host, struct dml_stream *ds, void *header, size_t header_size, void *arg);
  34. void *stream_header_cb_arg;
  35. void (*stream_data_cb)(struct dml_host *host, struct dml_stream *ds, uint64_t timestamp, void *data, size_t data_size, void *arg);
  36. void *stream_data_cb_arg;
  37. void (*stream_req_reverse_connect_cb)(struct dml_host *host, struct dml_stream *ds, struct dml_stream *ds_rev, int status, void *arg);
  38. void *stream_req_reverse_connect_cb_arg;
  39. void (*stream_req_reverse_disconnect_cb)(struct dml_host *host, struct dml_stream *ds, struct dml_stream *ds_rev, int status, void *arg);
  40. void *stream_req_reverse_disconnect_cb_arg;
  41. };
  42. bool dml_host_mime_filter(struct dml_host *host, struct dml_stream *ds)
  43. {
  44. char *dmime = dml_stream_mime_get(ds);
  45. int i;
  46. for (i = 0; i < host->mime_filter_nr; i++) {
  47. if (!strcmp(host->mime_filter[i], dmime)) {
  48. return true;
  49. }
  50. }
  51. return false;
  52. }
  53. static void rx_packet(struct dml_connection *dc, void *arg,
  54. uint16_t id, uint16_t len, uint8_t *data)
  55. {
  56. struct dml_host *host = arg;
  57. switch(id) {
  58. case DML_PACKET_ROUTE: {
  59. uint8_t hops;
  60. uint8_t rid[DML_ID_SIZE];
  61. struct dml_stream *ds;
  62. if (dml_packet_parse_route(data, len, rid, &hops))
  63. break;
  64. if (hops == 255) {
  65. ds = dml_stream_by_id(rid);
  66. if (ds) {
  67. if (dml_stream_mine_get(ds))
  68. break;
  69. if (host->stream_removed_cb)
  70. host->stream_removed_cb(host, ds, host->stream_removed_cb_arg);
  71. dml_stream_remove(ds);
  72. }
  73. } else {
  74. ds = dml_stream_by_id_alloc(rid);
  75. if (!ds)
  76. break;
  77. char *mime = dml_stream_mime_get(ds);
  78. if (!mime)
  79. dml_packet_send_req_description(dc, rid);
  80. else if (dml_host_mime_filter(host, ds)) {
  81. struct dml_crypto_key *ck = dml_stream_crypto_get(ds);
  82. if (!ck)
  83. dml_packet_send_req_certificate(dc, rid);
  84. }
  85. }
  86. break;
  87. }
  88. case DML_PACKET_REQ_DESCRIPTION: {
  89. uint8_t rid[DML_ID_SIZE];
  90. if (dml_packet_parse_req_description(data, len, rid))
  91. break;
  92. struct dml_stream *ds;
  93. if ((ds = dml_stream_by_id(rid))) {
  94. dml_packet_send_description(dc, rid,
  95. DML_PACKET_DESCRIPTION_VERSION_0,
  96. dml_stream_bps_get(ds),
  97. dml_stream_mime_get(ds),
  98. dml_stream_name_get(ds),
  99. dml_stream_alias_get(ds),
  100. dml_stream_description_get(ds));
  101. }
  102. break;
  103. }
  104. case DML_PACKET_DESCRIPTION: {
  105. bool new_stream = false;
  106. struct dml_stream *ds;
  107. if (!(ds = dml_stream_update_description(data, len, &new_stream)))
  108. break;
  109. uint8_t *rid = dml_stream_id_get(ds);
  110. if (dml_host_mime_filter(host, ds)) {
  111. struct dml_crypto_key *ck = dml_stream_crypto_get(ds);
  112. if (!ck)
  113. dml_packet_send_req_certificate(dc, rid);
  114. if (new_stream && host->stream_added_cb)
  115. host->stream_added_cb(host, ds, host->stream_added_cb_arg);
  116. }
  117. break;
  118. }
  119. case DML_PACKET_REQ_CERTIFICATE: {
  120. void *cert;
  121. size_t cert_size;
  122. uint8_t rid[DML_ID_SIZE];
  123. if (dml_packet_parse_req_certificate(data, len, rid))
  124. break;
  125. if (dml_crypto_cert_get(&cert, &cert_size))
  126. break;
  127. dml_packet_send_certificate(dc, rid, cert, cert_size);
  128. break;
  129. }
  130. case DML_PACKET_CERTIFICATE: {
  131. uint8_t cid[DML_ID_SIZE];
  132. void *cert;
  133. size_t size;
  134. struct dml_stream *ds;
  135. if (dml_packet_parse_certificate(data, len, cid, &cert, &size))
  136. break;
  137. if ((ds = dml_stream_by_id(cid))) {
  138. if (dml_host_mime_filter(host, ds)) {
  139. if (dml_crypto_cert_add_verify(cert, size, cid)) {
  140. printf("Not accepting certificate for %s\n",
  141. dml_stream_name_get(ds));
  142. }
  143. }
  144. }
  145. free(cert);
  146. break;
  147. }
  148. case DML_PACKET_REQ_HEADER: {
  149. uint8_t rid[DML_ID_SIZE];
  150. if (dml_packet_parse_req_header(data, len, rid))
  151. break;
  152. struct dml_stream *ds;
  153. if ((ds = dml_stream_by_id(rid))) {
  154. uint8_t header_sig[DML_SIG_SIZE];
  155. uint8_t *header;
  156. size_t header_size;
  157. struct dml_crypto_key *dk = dml_stream_crypto_get(ds);
  158. dml_stream_header_get(ds, &header, &header_size);
  159. dml_crypto_sign(header_sig, header, header_size, dk);
  160. dml_packet_send_header(dc, rid, header_sig, header, header_size);
  161. printf("Header requested\n");
  162. }
  163. break;
  164. }
  165. case DML_PACKET_HEADER: {
  166. uint8_t hid[DML_ID_SIZE];
  167. uint8_t sig[DML_SIG_SIZE];
  168. void *header;
  169. size_t header_size;
  170. struct dml_stream *ds;
  171. struct dml_crypto_key *dk;
  172. if (dml_packet_parse_header(data, len, hid, sig, &header, &header_size))
  173. break;
  174. if ((ds = dml_stream_by_id(hid))) {
  175. if ((dk = dml_stream_crypto_get(ds))) {
  176. bool verified = dml_crypto_verify(header, header_size, sig, dk);
  177. if (verified) {
  178. if (host->stream_header_cb)
  179. host->stream_header_cb(host, ds, header, header_size, host->stream_header_cb_arg);
  180. } else {
  181. fprintf(stderr, "Failed to verify header signature (%zd bytes)\n", header_size);
  182. }
  183. }
  184. }
  185. free(header);
  186. break;
  187. }
  188. case DML_PACKET_CONNECT: {
  189. uint16_t connect_packet_id;
  190. uint8_t connect_id[DML_ID_SIZE];
  191. dml_packet_parse_connect(data, len, connect_id, &connect_packet_id);
  192. printf("Received connect, packet_id: %d\n", connect_packet_id);
  193. struct dml_stream *ds;
  194. if ((ds = dml_stream_by_id(connect_id))) {
  195. if (!dml_stream_mine_get(ds))
  196. break;
  197. dml_stream_data_id_set(ds, connect_packet_id);
  198. }
  199. break;
  200. }
  201. case DML_PACKET_REQ_DISC: {
  202. uint8_t rid[DML_ID_SIZE];
  203. if (dml_packet_parse_req_disc(data, len, rid))
  204. break;
  205. struct dml_stream *ds;
  206. if ((ds = dml_stream_by_id(rid))) {
  207. if (!dml_stream_mine_get(ds))
  208. break;
  209. dml_stream_data_id_set(ds, 0);
  210. dml_packet_send_disc(dc, rid, DML_PACKET_DISC_REQUESTED);
  211. }
  212. break;
  213. }
  214. case DML_PACKET_REQ_REVERSE: {
  215. uint8_t id_me[DML_ID_SIZE];
  216. uint8_t id_rev[DML_ID_SIZE];
  217. uint8_t action;
  218. uint16_t status;
  219. if (dml_packet_parse_req_reverse(data, len, id_me, id_rev, &action, &status))
  220. break;
  221. printf("Received reverse request: %d status: %d\n", action, status);
  222. struct dml_stream *ds_rev = dml_stream_by_id(id_rev);
  223. struct dml_stream *ds = dml_stream_by_id(id_me);
  224. if (!ds_rev || !ds)
  225. break;
  226. if (action & DML_PACKET_REQ_REVERSE_CONNECT) {
  227. if (host->stream_req_reverse_connect_cb)
  228. host->stream_req_reverse_connect_cb(host, ds, ds_rev, status, host->stream_req_reverse_connect_cb_arg);
  229. } else if (action & DML_PACKET_REQ_REVERSE_DISC) {
  230. if (host->stream_req_reverse_disconnect_cb)
  231. host->stream_req_reverse_disconnect_cb(host, ds, ds_rev, status, host->stream_req_reverse_disconnect_cb_arg);
  232. }
  233. break;
  234. }
  235. default: {
  236. if (id < DML_PACKET_DATA)
  237. break;
  238. if (len < DML_SIG_SIZE + sizeof(uint64_t))
  239. break;
  240. uint64_t timestamp;
  241. size_t payload_len;
  242. void *payload_data;
  243. struct dml_crypto_key *dk;
  244. struct dml_stream *ds;
  245. ds = dml_stream_by_data_id(id);
  246. if (!ds) {
  247. fprintf(stderr, "Could not find dml stream\n");
  248. break;
  249. }
  250. dk = dml_stream_crypto_get(ds);
  251. if (!dk) {
  252. fprintf(stderr, "Could not find key for stream %s id %d\n", dml_stream_name_get(ds), id);
  253. break;
  254. }
  255. if (dml_packet_parse_data(data, len,
  256. &payload_data, &payload_len, &timestamp, dk)) {
  257. fprintf(stderr, "Decoding failed\n");
  258. } else {
  259. if (timestamp <= dml_stream_timestamp_get(ds)) {
  260. fprintf(stderr, "Timestamp mismatch %"PRIx64" <= %"PRIx64"\n",
  261. timestamp, dml_stream_timestamp_get(ds));
  262. } else {
  263. dml_stream_timestamp_set(ds, timestamp);
  264. if (host->stream_data_cb)
  265. host->stream_data_cb(host, ds, timestamp, payload_data, payload_len, host->stream_data_cb_arg);
  266. }
  267. }
  268. break;
  269. }
  270. }
  271. }
  272. static uint16_t alloc_data_id(void)
  273. {
  274. uint16_t id;
  275. for (id = DML_PACKET_DATA; id >= DML_PACKET_DATA; id++)
  276. if (!dml_stream_by_data_id(id))
  277. return id;
  278. return 0;
  279. }
  280. int dml_host_connect(struct dml_host *host, struct dml_stream *ds)
  281. {
  282. if (!host->connection)
  283. return -1;
  284. uint16_t data_id = dml_stream_data_id_get(ds);
  285. if (!data_id) {
  286. data_id = alloc_data_id();
  287. if (!data_id)
  288. return -1;
  289. }
  290. printf("Connect to %s (data id %d)\n", dml_stream_name_get(ds), data_id);
  291. dml_stream_data_id_set(ds, data_id);
  292. dml_packet_send_connect(host->connection, dml_stream_id_get(ds), data_id);
  293. return 0;
  294. }
  295. static int client_reconnect(void *arg)
  296. {
  297. struct dml_host *host = arg;
  298. if (dml_client_connect(host->client)) {
  299. printf("Reconnect to DML server failed\n");
  300. dml_poll_timeout(host, &(struct timespec){ 2, 0 });
  301. }
  302. return 0;
  303. }
  304. static int client_connection_close(struct dml_connection *dc, void *arg)
  305. {
  306. struct dml_host *host = arg;
  307. host->connection = NULL;
  308. struct dml_stream *ds = NULL;
  309. while ((ds = dml_stream_iterate(ds))) {
  310. if (!dml_stream_mine_get(ds))
  311. continue;
  312. dml_stream_data_id_set(ds, 0);
  313. }
  314. if (host->connection_closed_cb)
  315. host->connection_closed_cb(host, host->connection_closed_cb_arg);
  316. dml_poll_add(host, NULL, NULL, client_reconnect);
  317. dml_poll_timeout(host, &(struct timespec){ 1, 0 });
  318. if (dc) {
  319. return dml_connection_destroy(dc);
  320. } else
  321. return 0;
  322. }
  323. static void client_connect(struct dml_client *client, void *arg)
  324. {
  325. struct dml_host *host = arg;
  326. struct dml_connection *dc;
  327. int fd;
  328. printf("Connected to DML server\n");
  329. fd = dml_client_fd_get(client);
  330. dc = dml_connection_create(fd, host, rx_packet, client_connection_close);
  331. host->connection = dc;
  332. dml_packet_send_hello(dc,
  333. DML_PACKET_HELLO_LEAF | DML_PACKET_HELLO_UPDATES,
  334. "dml_host " DML_VERSION);
  335. struct dml_stream *ds = NULL;
  336. while ((ds = dml_stream_iterate(ds))) {
  337. if (!dml_stream_mine_get(ds))
  338. continue;
  339. dml_packet_send_route(dc, dml_stream_id_get(ds), 0);
  340. }
  341. }
  342. struct dml_connection *dml_host_connection_get(struct dml_host *host)
  343. {
  344. return host->connection;
  345. }
  346. int dml_host_mime_filter_set(struct dml_host *host, int nr, char **mimetypes)
  347. {
  348. host->mime_filter = mimetypes;
  349. host->mime_filter_nr = nr;
  350. return 0;
  351. }
  352. int dml_host_stream_added_cb_set(struct dml_host *host,
  353. void(*cb)(struct dml_host *host, struct dml_stream *ds, void *arg), void *arg)
  354. {
  355. host->stream_added_cb = cb;
  356. host->stream_added_cb_arg = arg;
  357. return 0;
  358. }
  359. int dml_host_stream_removed_cb_set(struct dml_host *host,
  360. void(*cb)(struct dml_host *host, struct dml_stream *ds, void *arg), void *arg)
  361. {
  362. host->stream_removed_cb = cb;
  363. host->stream_removed_cb_arg = arg;
  364. return 0;
  365. }
  366. int dml_host_stream_header_cb_set(struct dml_host *host,
  367. void (*cb)(struct dml_host *host, struct dml_stream *ds, void *header, size_t header_size, void *arg), void *arg)
  368. {
  369. host->stream_header_cb = cb;
  370. host->stream_header_cb_arg = arg;
  371. return 0;
  372. }
  373. int dml_host_stream_data_cb_set(struct dml_host *host,
  374. void (*cb)(struct dml_host *host, struct dml_stream *ds, uint64_t timestamp, void *data, size_t data_size, void *arg), void *arg)
  375. {
  376. host->stream_data_cb = cb;
  377. host->stream_data_cb_arg = arg;
  378. return 0;
  379. }
  380. int dml_host_stream_req_reverse_connect_cb_set(struct dml_host *host,
  381. void (*cb)(struct dml_host *host, struct dml_stream *ds, struct dml_stream *ds_rev, int status, void *arg), void *arg)
  382. {
  383. host->stream_req_reverse_connect_cb = cb;
  384. host->stream_req_reverse_connect_cb_arg = arg;
  385. return 0;
  386. }
  387. int dml_host_stream_req_reverse_disconnect_cb_set(struct dml_host *host,
  388. void (*cb)(struct dml_host *host, struct dml_stream *ds, struct dml_stream *ds_rev, int status, void *arg), void *arg)
  389. {
  390. host->stream_req_reverse_disconnect_cb = cb;
  391. host->stream_req_reverse_disconnect_cb_arg = arg;
  392. return 0;
  393. }
  394. int dml_host_connection_closed_cb_set(struct dml_host *host,
  395. void(*cb)(struct dml_host *host, void *arg), void *arg)
  396. {
  397. host->connection_closed_cb = cb;
  398. host->connection_closed_cb_arg = arg;
  399. return 0;
  400. }
  401. struct dml_host *dml_host_create(char *server)
  402. {
  403. struct dml_host *host = calloc(1, sizeof(struct dml_host));
  404. if (!host)
  405. goto err_alloc;
  406. host->client = dml_client_create(server, 0, client_connect, host);
  407. if (dml_client_connect(host->client)) {
  408. printf("Could not connect to server\n");
  409. dml_poll_add(host, NULL, NULL, client_reconnect);
  410. dml_poll_timeout(host, &(struct timespec){ 2, 0 });
  411. }
  412. err_alloc:
  413. return host;
  414. }