diff --git a/.gitignore b/.gitignore index 30131ff..2384521 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ /target /.idea /config.toml -/db +/data /logs .env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index e70671b..9d03924 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -194,12 +194,52 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "attohttpc" +version = "0.28.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07a9b245ba0739fc90935094c29adbaee3f977218b5fb95e822e261cda7f56a3" +dependencies = [ + "http 1.3.1", + "log", + "native-tls", + "serde", + "serde_json", + "url", +] + [[package]] name = "autocfg" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-creds" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f84143206b9c72b3c5cb65415de60c7539c79cd1559290fddec657939131be0" +dependencies = [ + "attohttpc", + "home", + "log", + "quick-xml", + "rust-ini", + "serde", + "thiserror 1.0.69", + "time", + "url", +] + +[[package]] +name = "aws-region" +version = "0.25.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9aed3f9c7eac9be28662fdb3b0f4d1951e812f7c64fed4f0327ba702f459b3b" +dependencies = [ + "thiserror 1.0.69", +] + [[package]] name = "axum" version = "0.8.4" @@ -212,10 +252,10 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.6.0", "hyper-util", "itoa", "matchit", @@ -247,8 +287,8 @@ checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -270,8 +310,8 @@ dependencies = [ "bytes", "futures-util", "headers", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -384,9 +424,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" dependencies = [ "serde", ] @@ -513,9 +553,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.22" +version = "1.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32db95edf998450acc7881c932f94cd9b05c87b4b2599e8bab064753da4acfd1" +checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" dependencies = [ "shlex", ] @@ -641,6 +681,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -855,6 +905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -922,6 +973,7 @@ dependencies = [ "mime", "rand 0.9.1", "regex", + "rust-s3", "scc", "serde", "serde_json", @@ -1025,9 +1077,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", "windows-sys 0.59.0", @@ -1100,6 +1152,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1322,7 +1389,7 @@ dependencies = [ "base64 0.21.7", "bytes", "headers-core", - "http", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -1334,7 +1401,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http", + "http 1.3.1", ] [[package]] @@ -1376,6 +1443,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.3.1" @@ -1387,6 +1465,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1394,7 +1483,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.3.1", ] [[package]] @@ -1405,8 +1494,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -1431,6 +1520,29 @@ dependencies = [ "typenum", ] +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.6.0" @@ -1440,8 +1552,8 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1450,6 +1562,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.32", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "hyper-util" version = "0.1.11" @@ -1458,9 +1583,9 @@ checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" dependencies = [ "bytes", "futures-util", - "http", - "http-body", - "hyper", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", "pin-project-lite", "tokio", "tower-service", @@ -1765,6 +1890,17 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "maybe-async" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cf92c10c7e361d6b99666ec1c6f9805b0bea2c3bd8c78dc6fe98ac5bd78db11" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "md-5" version = "0.10.6" @@ -1775,6 +1911,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.4" @@ -1796,6 +1938,15 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minidom" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f45614075738ce1b77a1768912a60c0227525971b03e09122a05b8a34a2a6278" +dependencies = [ + "rxml", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1831,7 +1982,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-util", - "http", + "http 1.3.1", "httparse", "memchr", "mime", @@ -1839,6 +1990,23 @@ dependencies = [ "version_check", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.26.4" @@ -1965,6 +2133,50 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl" +version = "0.10.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ordered-multimap" version = "0.7.3" @@ -2271,6 +2483,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "quick-xml" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3a6e5838b60e0e8fa7a43f22ade549a37d61f8bdbe636d0d7816191de969c2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "quote" version = "1.0.40" @@ -2371,7 +2593,7 @@ version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" dependencies = [ - "bitflags 2.9.0", + "bitflags 2.9.1", ] [[package]] @@ -2487,7 +2709,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" dependencies = [ "base64 0.21.7", - "bitflags 2.9.0", + "bitflags 2.9.1", "serde", "serde_derive", ] @@ -2549,6 +2771,43 @@ dependencies = [ "trim-in-place", ] +[[package]] +name = "rust-s3" +version = "0.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3df3f353b1f4209dcf437d777cda90279c397ab15a0cd6fd06bd32c88591533" +dependencies = [ + "async-trait", + "aws-creds", + "aws-region", + "base64 0.22.1", + "bytes", + "cfg-if", + "futures", + "hex", + "hmac", + "http 0.2.12", + "hyper 0.14.32", + "hyper-tls", + "log", + "maybe-async", + "md5", + "minidom", + "native-tls", + "percent-encoding", + "quick-xml", + "serde", + "serde_derive", + "serde_json", + "sha2", + "thiserror 1.0.69", + "time", + "tokio", + "tokio-native-tls", + "tokio-stream", + "url", +] + [[package]] name = "rust_decimal" version = "1.37.1" @@ -2595,7 +2854,7 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ - "bitflags 2.9.0", + "bitflags 2.9.1", "errno", "libc", "linux-raw-sys", @@ -2642,6 +2901,23 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +[[package]] +name = "rxml" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a98f186c7a2f3abbffb802984b7f1dfd65dac8be1aafdaabbca4137f53f0dff7" +dependencies = [ + "bytes", + "rxml_validation", + "smartstring", +] + +[[package]] +name = "rxml_validation" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22a197350ece202f19a166d1ad6d9d6de145e1d2a8ef47db299abe164dbd7530" + [[package]] name = "ryu" version = "1.0.20" @@ -2657,6 +2933,15 @@ dependencies = [ "sdd", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2701,6 +2986,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.9.1", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.26" @@ -2862,6 +3170,17 @@ dependencies = [ "serde", ] +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg", + "static_assertions", + "version_check", +] + [[package]] name = "smol_str" version = "0.2.2" @@ -2996,7 +3315,7 @@ checksum = "0afdd3aa7a629683c2d750c2df343025545087081ab5942593a5288855b1b7a7" dependencies = [ "atoi", "base64 0.22.1", - "bitflags 2.9.0", + "bitflags 2.9.1", "byteorder", "bytes", "chrono", @@ -3040,7 +3359,7 @@ checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6" dependencies = [ "atoi", "base64 0.22.1", - "bitflags 2.9.0", + "bitflags 2.9.1", "byteorder", "chrono", "crc", @@ -3103,6 +3422,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stringprep" version = "0.1.5" @@ -3201,9 +3526,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", "getrandom 0.3.3", @@ -3356,6 +3681,16 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -3444,14 +3779,14 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +checksum = "0fdb0c213ca27a9f57ab69ddb290fd80d970922355b83ae380b395d3986b8a2e" dependencies = [ - "bitflags 2.9.0", + "bitflags 2.9.1", "bytes", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", "tower-layer", "tower-service", @@ -3552,6 +3887,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343e926fc669bc8cde4fa3129ab681c63671bae288b1f1081ceee6d9d37904fc" +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "tungstenite" version = "0.26.2" @@ -3560,7 +3901,7 @@ checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" dependencies = [ "bytes", "data-encoding", - "http", + "http 1.3.1", "httparse", "log", "rand 0.9.1", @@ -3755,6 +4096,15 @@ dependencies = [ "atomic-waker", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -4077,9 +4427,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-core" -version = "0.61.0" +version = "0.61.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" +checksum = "46ec44dc15085cea82cf9c78f85a9114c463a369786585ad2882d1ff0b0acf40" dependencies = [ "windows-implement", "windows-interface", @@ -4118,18 +4468,18 @@ checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] name = "windows-result" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" +checksum = "4b895b5356fc36103d0f64dd1e94dfa7ac5633f1c9dd6e80fe9ec4adef69e09d" dependencies = [ "windows-link", ] [[package]] name = "windows-strings" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +checksum = "2a7ab927b2637c19b3dbe0965e75d8f2d30bdd697a1516191cad2ec4df8fb28a" dependencies = [ "windows-link", ] @@ -4297,7 +4647,7 @@ version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.9.0", + "bitflags 2.9.1", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c6f8d42..0d09b95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ rand = "0.9" sha2 = "0.10" base64 = "0.22" scc = "2.3" +rust-s3 = "0.35.1" diff --git a/docker-compose.yml b/docker-compose.yml index 32fac71..44952df 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,5 +6,17 @@ services: env_file: - .env volumes: - - ${PWD}/db/:/var/lib/postgresql/data/ + - ${PWD}/data/db/:/var/lib/postgresql/data/ + user: "1000:1000" + + object_storage: + image: 'quay.io/minio/minio:latest' + ports: + - "9000:9000" + - "9001:9001" + env_file: + - .env + volumes: + - ${PWD}/data/minio/:/data + command: server /data --console-address ":9001" user: "1000:1000" \ No newline at end of file diff --git a/migrations/20250510183101_uuidv7.up.sql b/migrations/20250510180101_uuidv7.sql similarity index 100% rename from migrations/20250510183101_uuidv7.up.sql rename to migrations/20250510180101_uuidv7.sql diff --git a/migrations/20250510182916_file.sql b/migrations/20250510182916_file.sql new file mode 100644 index 0000000..e2ef3b7 --- /dev/null +++ b/migrations/20250510182916_file.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS "file" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "filename" VARCHAR NOT NULL, + "content_type" VARCHAR NOT NULL, + "size" INT8 NOT NULL +); + diff --git a/migrations/20250510183101_uuidv7.down.sql b/migrations/20250510183101_uuidv7.down.sql deleted file mode 100644 index 6960f25..0000000 --- a/migrations/20250510183101_uuidv7.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP EXTENSION pg_uuidv7; \ No newline at end of file diff --git a/migrations/20250510183102_user.down.sql b/migrations/20250510183102_user.down.sql deleted file mode 100644 index 4b100bb..0000000 --- a/migrations/20250510183102_user.down.sql +++ /dev/null @@ -1,4 +0,0 @@ -DROP TRIGGER trg_user_relation_update ON "user_relation"; -DROP FUNCTION fn_on_user_relation_update(); -DROP TABLE "user_relation"; -DROP TABLE "user"; \ No newline at end of file diff --git a/migrations/20250510183102_user.up.sql b/migrations/20250510183102_user.sql similarity index 62% rename from migrations/20250510183102_user.up.sql rename to migrations/20250510183102_user.sql index 8e5e907..0c39a64 100644 --- a/migrations/20250510183102_user.up.sql +++ b/migrations/20250510183102_user.sql @@ -1,7 +1,7 @@ CREATE TABLE IF NOT EXISTS "user" ( "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "avatar_url" VARCHAR, + "avatar_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL, "username" VARCHAR NOT NULL UNIQUE, "display_name" VARCHAR, "email" VARCHAR NOT NULL, @@ -25,6 +25,38 @@ CREATE TABLE IF NOT EXISTS "user_relation" INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system") VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE); +CREATE OR REPLACE FUNCTION check_avatar_is_image() + RETURNS TRIGGER AS +$$ +DECLARE + file_content_type VARCHAR; +BEGIN + -- Skip check if icon_id is null + IF NEW.avatar_id IS NULL THEN + RETURN NEW; + END IF; + + -- Retrieve content_type from file table + SELECT content_type + INTO file_content_type + FROM file + WHERE id = NEW.avatar_id; + + -- Raise exception if content_type does not start with 'image/' + IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN + RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/'; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_check_icon_is_image + BEFORE INSERT OR UPDATE + ON "user" + FOR EACH ROW +EXECUTE FUNCTION check_avatar_is_image(); + CREATE OR REPLACE FUNCTION fn_on_user_relation_update() RETURNS TRIGGER LANGUAGE plpgsql diff --git a/migrations/20250510183125_server.down.sql b/migrations/20250510183125_server.down.sql deleted file mode 100644 index 3271cb2..0000000 --- a/migrations/20250510183125_server.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE "server"; \ No newline at end of file diff --git a/migrations/20250510183125_server.up.sql b/migrations/20250510183125_server.sql similarity index 73% rename from migrations/20250510183125_server.up.sql rename to migrations/20250510183125_server.sql index 8d56ff0..df17f24 100644 --- a/migrations/20250510183125_server.up.sql +++ b/migrations/20250510183125_server.sql @@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS "server" "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), "owner_id" UUID NOT NULL REFERENCES "user" ("id"), "name" VARCHAR NOT NULL, - "icon_url" VARCHAR + "icon_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL ); CREATE TABLE IF NOT EXISTS "server_role" @@ -42,6 +42,38 @@ CREATE TABLE IF NOT EXISTS "server_invite" "expires_at" TIMESTAMPTZ ); +CREATE OR REPLACE FUNCTION check_icon_is_image() + RETURNS TRIGGER AS +$$ +DECLARE + file_content_type VARCHAR; +BEGIN + -- Skip check if icon_id is null + IF NEW.icon_id IS NULL THEN + RETURN NEW; + END IF; + + -- Retrieve content_type from file table + SELECT content_type + INTO file_content_type + FROM file + WHERE id = NEW.icon_id; + + -- Raise exception if content_type does not start with 'image/' + IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN + RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/'; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_check_icon_is_image + BEFORE INSERT OR UPDATE + ON "server" + FOR EACH ROW +EXECUTE FUNCTION check_icon_is_image(); + CREATE OR REPLACE FUNCTION check_server_user_role_server_id() RETURNS TRIGGER AS $$ diff --git a/migrations/20250510184011_channel_message.down.sql b/migrations/20250510184011_channel_message.down.sql deleted file mode 100644 index e03cd29..0000000 --- a/migrations/20250510184011_channel_message.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -DROP TABLE "message"; -DROP TABLE "channel"; \ No newline at end of file diff --git a/migrations/20250510184011_channel_message.up.sql b/migrations/20250510184011_channel_message.sql similarity index 88% rename from migrations/20250510184011_channel_message.up.sql rename to migrations/20250510184011_channel_message.sql index 489ce45..93a0cbf 100644 --- a/migrations/20250510184011_channel_message.up.sql +++ b/migrations/20250510184011_channel_message.sql @@ -20,11 +20,19 @@ CREATE TABLE IF NOT EXISTS "channel_recipient" CREATE TABLE IF NOT EXISTS "message" ( "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "author_id" UUID NOT NULL REFERENCES "user" ("id"), - "channel_id" UUID NOT NULL REFERENCES "channel" ("id"), + "author_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, + "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE, "content" TEXT NOT NULL ); +CREATE TABLE IF NOT EXISTS "message_attachment" +( + "message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE, + "file_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE, + "order" INT2 NOT NULL, + PRIMARY KEY ("message_id", "file_id") +); + ALTER TABLE "channel" ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL; diff --git a/migrations/20250510184916_file.down.sql b/migrations/20250510184916_file.down.sql deleted file mode 100644 index 7731a0b..0000000 --- a/migrations/20250510184916_file.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -DROP TABLE "message_attachment"; -DROP TABLE "file"; \ No newline at end of file diff --git a/migrations/20250510184916_file.up.sql b/migrations/20250510184916_file.up.sql deleted file mode 100644 index 61da927..0000000 --- a/migrations/20250510184916_file.up.sql +++ /dev/null @@ -1,16 +0,0 @@ -CREATE TABLE IF NOT EXISTS "file" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "filename" VARCHAR NOT NULL, - "content_type" VARCHAR NOT NULL, - "url" VARCHAR NOT NULL, - "size" INT8 NOT NULL -); - -CREATE TABLE IF NOT EXISTS "message_attachment" -( - "message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE, - "attachment_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE, - "order" INT2 NOT NULL, - PRIMARY KEY ("message_id", "attachment_id") -); \ No newline at end of file diff --git a/migrations/20250517190855_util.sql b/migrations/20250517190855_util.sql new file mode 100644 index 0000000..f3354fe --- /dev/null +++ b/migrations/20250517190855_util.sql @@ -0,0 +1,35 @@ +CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID) + RETURNS TABLE (user_id UUID) AS $$ +BEGIN + RETURN QUERY + -- Users directly related to the target user + SELECT ur.user_id + FROM user_relation ur + WHERE ur.other_id = target_user_id + + UNION + + -- Users where target user is related to them + SELECT ur.other_id AS user_id + FROM user_relation ur + WHERE ur.user_id = target_user_id + + UNION + + -- Users who share a server with the target user + SELECT sm.user_id + FROM server_member sm + JOIN server_member sm2 ON sm.server_id = sm2.server_id + WHERE sm2.user_id = target_user_id + AND sm.user_id != target_user_id + + UNION + + -- Users who share a channel with the target user (DM or group) + SELECT cr.user_id + FROM channel_recipient cr + JOIN channel_recipient cr2 ON cr.channel_id = cr2.channel_id + WHERE cr2.user_id = target_user_id + AND cr.user_id != target_user_id; +END; +$$ LANGUAGE plpgsql; \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 26e0240..96604bf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -22,10 +22,12 @@ pub struct Config { pub security: SecurityConfig, pub gateway: GatewayConfig, pub database: DatabaseConfig, + pub object_store: ObjectStoreConfig, } #[derive(Deserialize)] pub struct ServerConfig { + pub hostname: url::Url, pub host: std::net::Ipv4Addr, pub port: u16, } @@ -59,6 +61,16 @@ pub enum DatabaseConfig { }, } +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ObjectStoreConfig { + pub endpoint: url::Url, + pub region: String, + pub bucket: String, + pub access_key: String, + pub secret_key: String, +} + impl DatabaseConfig { pub fn url(&self) -> Option { match self { diff --git a/src/database.rs b/src/database.rs index c73db5d..be0af84 100644 --- a/src/database.rs +++ b/src/database.rs @@ -21,9 +21,15 @@ pub enum Error { ServerDoesNotExists, + MemberAlreadyExists, + ChannelDoesNotExists, + InviteDoesNotExists, + MessageDoesNotExists, + + FileDoesNotExists, } impl Database { @@ -81,6 +87,25 @@ impl Database { Ok(user) } + pub async fn update_user_by_id( + &self, + user_id: entity::user::Id, + display_name: Option<&str>, + avatar_id: Option, + ) -> Result { + let user = sqlx::query_as!( + entity::user::User, + r#"UPDATE "user" SET "display_name" = COALESCE($2, "display_name"), "avatar_id" = COALESCE($3, "avatar_id") WHERE "id" = $1 RETURNING "user".*"#, + user_id, + display_name, + avatar_id + ) + .fetch_one(&self.pool) + .await?; + + Ok(user) + } + pub async fn select_users_by_ids( &self, user_ids: &[entity::user::Id], @@ -146,17 +171,54 @@ impl Database { Ok(servers) } + pub async fn select_server_members( + &self, + server_id: entity::server::Id, + ) -> Result> { + let users = sqlx::query_as!( + entity::user::User, + r#"SELECT * FROM "user" WHERE "id" IN ( + SELECT "user_id" FROM "server_member" WHERE "server_id" = $1 + )"#, + server_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn select_channel_members( + &self, + channel_id: entity::channel::Id, + ) -> Result> { + let users = sqlx::query_as!( + entity::user::User, + r#"SELECT * FROM "user" WHERE "id" IN ( + SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1 + UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN ( + SELECT "server_id" FROM "channel" WHERE "id" = $1 + ) + )"#, + channel_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + pub async fn select_user_channels( &self, user_id: entity::user::Id, ) -> Result> { - // for some reason using macro overflows tokio stack - let channels = sqlx::query_as( + let channels = sqlx::query_as!( + entity::channel::Channel, r#"SELECT * FROM "channel" WHERE "id" IN ( SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1 )"#, + user_id ) - .bind(user_id) .fetch_all(&self.pool) .await?; @@ -166,14 +228,14 @@ impl Database { pub async fn insert_server( &self, name: &str, - icon_url: Option<&str>, + icon_id: Option, owner_id: entity::user::Id, ) -> Result { let server = sqlx::query_as!( entity::server::Server, - r#"INSERT INTO "server"("name", "icon_url", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#, + r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#, name, - icon_url, + icon_id, owner_id ) .fetch_one(&self.pool) @@ -214,14 +276,20 @@ impl Database { server_id: entity::server::Id, user_id: entity::user::Id, ) -> Result { - let member = sqlx::query_as!( + let member = match sqlx::query_as!( entity::server::member::ServerMember, r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#, server_id, user_id ) .fetch_one(&self.pool) - .await?; + .await { + Ok(member) => member, + Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => { + return Err(Error::MemberAlreadyExists); + } + Err(e) => return Err(e.into()), + }; Ok(member) } @@ -242,38 +310,18 @@ impl Database { Ok(()) } - pub async fn insert_server_channel( - &self, - server_id: entity::server::Id, - name: &str, - channel_type: entity::channel::ChannelType, - position: u16, - parent: Option, - ) -> Result { - let channel = sqlx::query_as!( - entity::channel::Channel, - r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#, - name, - channel_type as i16, - position as i16, - server_id, - parent - ) - .fetch_one(&self.pool) - .await?; - - Ok(channel) - } - pub async fn select_channel_by_id( &self, channel_id: entity::channel::Id, ) -> Result { - let channel = sqlx::query_as(r#"SELECT * FROM "channel" WHERE "id" = $1"#) - .bind(channel_id) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::ChannelDoesNotExists)?; + let channel = sqlx::query_as!( + entity::channel::Channel, + r#"SELECT * FROM "channel" WHERE "id" = $1"#, + channel_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ChannelDoesNotExists)?; Ok(channel) } @@ -314,6 +362,217 @@ impl Database { Ok(channels) } + pub async fn delete_server_by_id( + &self, + server_id: entity::server::Id, + ) -> Result { + let server = sqlx::query_as!( + entity::server::Server, + r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#, + server_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ServerDoesNotExists)?; + + Ok(server) + } + + pub async fn delete_channel_by_id( + &self, + channel_id: entity::channel::Id, + ) -> Result { + let channel = sqlx::query_as!( + entity::channel::Channel, + r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#, + channel_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ChannelDoesNotExists)?; + + Ok(channel) + } + + pub async fn insert_server_channel( + &self, + name: &str, + position: u16, + r#type: entity::channel::ChannelType, + server_id: entity::server::Id, + parent: Option, + ) -> Result { + let channel = sqlx::query_as!( + entity::channel::Channel, + r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#, + name, + r#type as i16, + position as i16, + server_id, + parent + ) + .fetch_one(&self.pool) + .await?; + + Ok(channel) + } + + pub async fn insert_server_invite( + &self, + code: &str, + server_id: entity::server::Id, + inviter_id: Option, + expires_at: Option>, + ) -> Result { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#, + code, + server_id, + inviter_id, + expires_at + ) + .fetch_one(&self.pool) + .await?; + + Ok(invite) + } + + pub async fn select_server_invite_by_code( + &self, + code: &str, + ) -> Result { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"SELECT * FROM "server_invite" WHERE "code" = $1"#, + code + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::InviteDoesNotExists)?; + + Ok(invite) + } + + pub async fn delete_server_invite_by_code( + &self, + code: &str, + ) -> Result> { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#, + code + ) + .fetch_optional(&self.pool) + .await?; + + Ok(invite) + } + + pub async fn select_channel_messages_paginated( + &self, + channel_id: entity::channel::Id, + before: Option, + limit: i64, + ) -> Result> { + let messages = sqlx::query_as!( + entity::message::Message, + r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#, + channel_id, + before, + limit + ) + .fetch_all(&self.pool) + .await?; + + Ok(messages) + } + + pub async fn insert_channel_message( + &self, + user_id: entity::user::Id, + channel_id: entity::channel::Id, + content: &str, + ) -> Result { + let message = sqlx::query_as!( + entity::message::Message, + r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#, + channel_id, + user_id, + content + ) + .fetch_one(&self.pool) + .await?; + + Ok(message) + } + + pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"SELECT * FROM "file" WHERE "id" = $1"#, + file_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::FileDoesNotExists)?; + + Ok(file) + } + + pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#, + file_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::FileDoesNotExists)?; + + Ok(file) + } + + pub async fn insert_file( + &self, + filename: &str, + content_type: &str, + size: usize, + ) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#, + filename, + content_type, + size as i64 + ) + .fetch_one(&self.pool) + .await?; + + Ok(file) + } + + pub async fn select_related_user_ids( + &self, + user_id: entity::user::Id, + ) -> Result> { + #[derive(sqlx::FromRow)] + struct UserId { + user_id: entity::user::Id, + } + + let user_ids = + sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#) + .bind(user_id) + .fetch_all(&self.pool) + .await? + .into_iter() + .map(|row| row.user_id) + .collect(); + + Ok(user_ids) + } + pub async fn procedure_create_dm_channel( &self, user1_id: entity::user::Id, diff --git a/src/entity/channel.rs b/src/entity/channel.rs index 0becc30..94150c8 100644 --- a/src/entity/channel.rs +++ b/src/entity/channel.rs @@ -1,4 +1,4 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::entity::{message, server, user}; @@ -21,7 +21,7 @@ pub struct Channel { pub last_message_id: Option, } -#[derive(Debug, Clone, sqlx::Type, Serialize)] +#[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)] #[non_exhaustive] #[serde(rename_all = "snake_case")] #[repr(i16)] @@ -36,6 +36,13 @@ pub enum ChannelType { impl From for ChannelType { fn from(value: i16) -> Self { - value.try_into().unwrap_or(ChannelType::ServerText) + match value { + 1 => ChannelType::ServerText, + 2 => ChannelType::ServerVoice, + 3 => ChannelType::ServerCategory, + 4 => ChannelType::DirectMessage, + 5 => ChannelType::Group, + _ => ChannelType::ServerText, + } } } diff --git a/src/entity/attachment.rs b/src/entity/file.rs similarity index 73% rename from src/entity/attachment.rs rename to src/entity/file.rs index 7e83736..f9780ab 100644 --- a/src/entity/attachment.rs +++ b/src/entity/file.rs @@ -3,10 +3,9 @@ use serde::Serialize; pub type Id = uuid::Uuid; #[derive(Debug, Clone, sqlx::FromRow, Serialize)] -pub struct Attachment { +pub struct File { pub id: Id, pub filename: String, pub content_type: String, - pub url: String, - pub size: u64, + pub size: i64, } diff --git a/src/entity/message.rs b/src/entity/message.rs index 6597ffc..ccb8cf8 100644 --- a/src/entity/message.rs +++ b/src/entity/message.rs @@ -1,15 +1,11 @@ -use serde::Serialize; - use crate::entity::{channel, user}; pub type Id = uuid::Uuid; -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] +#[derive(Debug, Clone, sqlx::FromRow)] pub struct Message { pub id: Id, pub author_id: user::Id, pub channel_id: channel::Id, pub content: String, - pub timestamp: chrono::DateTime, } diff --git a/src/entity/mod.rs b/src/entity/mod.rs index 0767d0e..cfe1ef6 100644 --- a/src/entity/mod.rs +++ b/src/entity/mod.rs @@ -1,4 +1,4 @@ -pub mod attachment; +pub mod file; pub mod channel; pub mod message; pub mod server; diff --git a/src/entity/server.rs b/src/entity/server.rs index 4ed1f37..6a35104 100644 --- a/src/entity/server.rs +++ b/src/entity/server.rs @@ -1,18 +1,15 @@ -mod invite; +pub mod invite; pub mod member; pub mod role; -use serde::Serialize; - -use crate::entity::user; +use crate::entity::{file, user}; pub type Id = uuid::Uuid; -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] +#[derive(Debug, Clone, sqlx::FromRow)] pub struct Server { pub id: Id, pub owner_id: user::Id, pub name: String, - pub icon_url: Option, + pub icon_id: Option, } diff --git a/src/entity/user.rs b/src/entity/user.rs index 15a0260..abf6dea 100644 --- a/src/entity/user.rs +++ b/src/entity/user.rs @@ -1,6 +1,7 @@ use std::sync::LazyLock; use regex::Regex; +use crate::entity::file; pub static USERNAME_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap()); @@ -10,7 +11,7 @@ pub type Id = uuid::Uuid; #[derive(Debug, Clone, sqlx::FromRow)] pub struct User { pub id: Id, - pub avatar_url: Option, + pub avatar_id: Option, pub username: String, pub display_name: Option, pub email: String, diff --git a/src/jwt.rs b/src/jwt.rs index d569f6b..1be8ad6 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -15,7 +15,7 @@ pub struct Claims { pub iat: i64, } -pub fn generate_jwt(data: T) -> Result { +pub fn generate_jwt(data: T, secret: &[u8]) -> Result { let claims = Claims { data, iat: Utc::now().timestamp_millis(), @@ -24,14 +24,14 @@ pub fn generate_jwt(data: T) -> Result { let token = jsonwebtoken::encode( &jsonwebtoken::Header::default(), &claims, - &jsonwebtoken::EncodingKey::from_secret(config::config().security.auth_secret.as_ref()), + &jsonwebtoken::EncodingKey::from_secret(secret), ) .map_err(|_| Error::CouldNotEncodeToken)?; Ok(token) } -pub fn verify_jwt(token: &str) -> Result { +pub fn verify_jwt(token: &str, secret: &[u8]) -> Result { tracing::debug!("verifying token: {}", token); let mut validation = jsonwebtoken::Validation::default(); @@ -39,9 +39,12 @@ pub fn verify_jwt(token: &str) -> Result { let token_data = jsonwebtoken::decode::>( token, - &jsonwebtoken::DecodingKey::from_secret(config::config().security.auth_secret.as_ref()), + &jsonwebtoken::DecodingKey::from_secret(secret), &validation, ) + .inspect_err(|err| { + tracing::error!("Failed to decode JWT: {:?}", err); + }) .map_err(|_| Error::CouldNotVerifyToken)?; Ok(token_data.claims.data) diff --git a/src/main.rs b/src/main.rs index 03d472c..9118ca5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ mod database; mod entity; mod jwt; mod log; +mod object_store; mod state; mod util; mod web; @@ -20,8 +21,10 @@ async fn main() -> anyhow::Result<()> { let _guard = log::init_logging()?; let database = Database::connect(&config::config().database).await?; + let object_store = object_store::ObjectStore::connect(&config::config().object_store).await?; let state = AppState { database, + object_store, hasher: Arc::new(Argon2::default()), gateway_state: Default::default(), voice_rooms: Default::default(), diff --git a/src/object_store.rs b/src/object_store.rs new file mode 100644 index 0000000..2976f9a --- /dev/null +++ b/src/object_store.rs @@ -0,0 +1,51 @@ +use crate::config::ObjectStoreConfig; + +#[derive(Clone, derive_more::AsRef, derive_more::Deref)] +pub struct ObjectStore { + inner: Box, +} + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)] +pub enum Error { + #[from] + Credentials(s3::creds::error::CredentialsError), + + #[from] + S3(s3::error::S3Error), +} + +impl ObjectStore { + pub async fn connect(config: &ObjectStoreConfig) -> Result { + let region = s3::region::Region::Custom { + region: config.region.clone(), + endpoint: config.endpoint.origin().ascii_serialization(), + }; + + let credentials = s3::creds::Credentials::new( + Some(&config.access_key), + Some(&config.secret_key), + None, + None, + None, + )?; + + let mut bucket = + s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())? + .with_path_style(); + + if !bucket.exists().await? { + bucket = s3::bucket::Bucket::create_with_path_style( + &config.bucket, + region, + credentials, + s3::BucketConfiguration::default(), + ) + .await? + .bucket; + } + + Ok(Self { inner: bucket }) + } +} diff --git a/src/state.rs b/src/state.rs index 9074860..49499ac 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,17 +6,19 @@ use tokio::sync::{RwLock, mpsc}; use uuid::Uuid; use crate::database::Database; +use crate::object_store::ObjectStore; use crate::web::ws::gateway::{GatewayWsState, SessionKey, event}; -use crate::webrtc::OfferSignal; +use crate::webrtc::WebRtcSignal; #[derive(Clone)] pub struct AppState { pub database: Database, + pub object_store: ObjectStore, pub hasher: Arc>, pub gateway_state: Arc, - pub voice_rooms: Arc>>>, + pub voice_rooms: Arc>>>, } #[derive(Debug, Default)] @@ -60,4 +62,16 @@ impl AppState { self.gateway_state.connected.remove_async(&user_id).await; } } + + pub async fn register_voice_room( + &self, + room_id: Uuid, + sender: mpsc::UnboundedSender, + ) { + self.voice_rooms.write().await.insert(room_id, sender); + } + + pub async fn unregister_voice_room(&self, room_id: Uuid) { + self.voice_rooms.write().await.remove(&room_id); + } } diff --git a/src/util.rs b/src/util.rs index 3feeade..9587904 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,6 +2,21 @@ use axum::extract::multipart::Field; use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; use serde::{Deserialize, Serialize}; +use crate::entity; + +pub fn file_id_to_url(file_id: &entity::file::Id) -> Option { + Some( + crate::config::config() + .server + .hostname + .join("files/") + .ok()? + .join(&file_id.to_string()) + .ok()? + .to_string(), + ) +} + #[derive(Debug, derive_more::Deref)] pub struct SerdeFieldData(pub FieldData); @@ -60,3 +75,25 @@ where let seconds = u64::deserialize(deserializer)?; Ok(std::time::Duration::from_secs(seconds)) } + +pub fn serialize_duration_seconds_option( + duration: &Option, + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + match duration { + Some(duration) => serialize_duration_seconds(duration, serializer), + None => serializer.serialize_none(), + } +} + +pub fn deserialize_duration_seconds_option<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + Ok(deserialize_duration_seconds(deserializer).ok()) +} diff --git a/src/web/entity/message.rs b/src/web/entity/message.rs new file mode 100644 index 0000000..0f932cc --- /dev/null +++ b/src/web/entity/message.rs @@ -0,0 +1,35 @@ +use serde::Serialize; + +use crate::entity::message::Id; +use crate::entity::{channel, user}; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Message { + pub id: Id, + pub author_id: user::Id, + pub channel_id: channel::Id, + pub content: String, + pub created_at: chrono::DateTime, +} + +impl From for Message { + fn from(message: crate::entity::message::Message) -> Self { + Self { + id: message.id, + author_id: message.author_id, + channel_id: message.channel_id, + content: message.content, + created_at: message + .id + .get_timestamp() + .as_ref() + .map(uuid::Timestamp::to_unix) + .map(|(secs, nsecs)| { + chrono::DateTime::::from_timestamp(secs as i64, nsecs) + }) + .flatten() + .unwrap_or_default(), + } + } +} diff --git a/src/web/entity/mod.rs b/src/web/entity/mod.rs new file mode 100644 index 0000000..6747680 --- /dev/null +++ b/src/web/entity/mod.rs @@ -0,0 +1,3 @@ +pub mod message; +pub mod user; +pub mod server; \ No newline at end of file diff --git a/src/web/entity/server.rs b/src/web/entity/server.rs new file mode 100644 index 0000000..76616e3 --- /dev/null +++ b/src/web/entity/server.rs @@ -0,0 +1,25 @@ +use serde::Serialize; + +use crate::entity::server::Id; +use crate::entity::user; +use crate::util; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Server { + pub id: Id, + pub owner_id: user::Id, + pub name: String, + pub icon_url: Option, +} + +impl From for Server { + fn from(server: crate::entity::server::Server) -> Self { + Self { + id: server.id, + owner_id: server.owner_id, + name: server.name, + icon_url: server.icon_id.as_ref().map(util::file_id_to_url).flatten(), + } + } +} diff --git a/src/web/entity/user.rs b/src/web/entity/user.rs new file mode 100644 index 0000000..14b4f14 --- /dev/null +++ b/src/web/entity/user.rs @@ -0,0 +1,54 @@ +use crate::entity::user; +use crate::util; + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct FullUser { + pub id: user::Id, + pub avatar_url: Option, + pub username: String, + pub display_name: Option, + pub email: String, + pub bot: bool, + pub system: bool, + pub settings: serde_json::Value, +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PartialUser { + pub id: user::Id, + pub avatar_url: Option, + pub username: String, + pub display_name: Option, + pub bot: bool, + pub system: bool, +} + +impl From for FullUser { + fn from(user: user::User) -> Self { + Self { + id: user.id, + avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), + username: user.username, + display_name: user.display_name, + email: user.email, + bot: user.bot, + system: user.system, + settings: user.settings, + } + } +} + +impl From for PartialUser { + fn from(user: user::User) -> Self { + Self { + id: user.id, + avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), + username: user.username, + display_name: user.display_name, + bot: user.bot, + system: user.system, + } + } +} diff --git a/src/web/error.rs b/src/web/error.rs index 50ff203..1476f7d 100644 --- a/src/web/error.rs +++ b/src/web/error.rs @@ -4,7 +4,7 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use crate::web::context; -use crate::{database, jwt}; +use crate::{database, jwt, object_store}; pub type Result = std::result::Result; @@ -22,6 +22,9 @@ pub enum Error { #[from] Database(database::Error), + #[from] + ObjectStore(object_store::Error), + #[from] Json(serde_json::error::Error), @@ -55,6 +58,8 @@ pub enum ClientError { ValidationFailed(validator::ValidationErrors), InternalServerError, + + Unknown, } #[derive(derive_more::Debug, Clone, serde::Serialize)] diff --git a/src/web/middleware/auth.rs b/src/web/middleware/auth.rs index e8604c0..7396226 100644 --- a/src/web/middleware/auth.rs +++ b/src/web/middleware/auth.rs @@ -52,7 +52,11 @@ async fn get_context(state: &AppState, request: &mut Request) -> context::UserCo } pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult { - let context = jwt::verify_jwt::(token).map_err(|_| context::Error::BadToken)?; + let context = jwt::verify_jwt::( + token, + crate::config::config().security.auth_secret.as_ref(), + ) + .map_err(|_| context::Error::BadToken)?; let _ = state .database diff --git a/src/web/mod.rs b/src/web/mod.rs index 0044361..452f33d 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,4 +1,5 @@ mod context; +mod entity; mod error; mod middleware; mod route; @@ -38,6 +39,8 @@ fn router(state: state::AppState) -> axum::Router { // websocket .route("/gateway/ws", get(ws::gateway::ws_handler)) .route("/voice/ws", get(ws::voice::ws_handler)) + // file + .route("/files/{file_id}", get(file::get)) // api .nest( "/api/v1", @@ -66,13 +69,41 @@ fn protected_router() -> axum::Router { Router::new() // user .route("/users/@me", get(user::me)) + .route("/users/@me", patch(user::patch)) .route("/users/@me/channels", get(user::channel::list)) .route("/users/{id}", get(user::get_by_id)) + // channel + .route( + "/channels/{channel_id}/messages", + get(channel::message::page), + ) + .route( + "/channels/{channel_id}/messages", + post(channel::message::create), + ) // server .route("/servers", get(server::list)) .route("/servers", post(server::create)) .route("/servers/{server_id}", get(server::get)) + .route("/servers/{server_id}", delete(server::delete)) .route("/servers/{server_id}/channels", get(server::channel::list)) + .route( + "/servers/{server_id}/channels", + post(server::channel::create), + ) + .route( + "/servers/{server_id}/channels/{channel_id}", + get(server::channel::get), + ) + .route( + "/servers/{server_id}/channels/{channel_id}", + delete(server::channel::delete), + ) + // invite + .route("/servers/{server_id}/invites", post(server::invite::create)) + .route("/invites/{code}", get(server::invite::get)) + // file + .route("/files", post(file::upload)) // middleware .route_layer(axum::middleware::from_fn(middleware::require_context)) } diff --git a/src/web/route/auth/login.rs b/src/web/route/auth/login.rs index 88ef85b..8baafb2 100644 --- a/src/web/route/auth/login.rs +++ b/src/web/route/auth/login.rs @@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::state::AppState; use crate::web::context::UserContext; -use crate::web::route::user::FullUser; use crate::{jwt, web}; +use crate::web::entity::user::FullUser; #[derive(Deserialize)] #[serde(rename_all = "camelCase")] @@ -39,7 +39,7 @@ pub async fn login( .verify_password(payload.password.as_bytes(), &password_hash) .map_err(|_| web::error::ClientError::WrongPassword)?; - let token = jwt::generate_jwt(UserContext { user_id: user.id })?; + let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?; let response = LoginResponse { user: user.into(), diff --git a/src/web/route/auth/register.rs b/src/web/route/auth/register.rs index 5759d94..214601c 100644 --- a/src/web/route/auth/register.rs +++ b/src/web/route/auth/register.rs @@ -9,8 +9,8 @@ use validator::Validate; use crate::state::AppState; use crate::web; +use crate::web::entity::user::FullUser; use crate::web::error::ClientError; -use crate::web::route::user::FullUser; #[derive(Validate, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/src/web/route/channel/message/create.rs b/src/web/route/channel/message/create.rs new file mode 100644 index 0000000..aa75250 --- /dev/null +++ b/src/web/route/channel/message/create.rs @@ -0,0 +1,49 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::message::Message; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, serde::Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +pub struct CreatePayload { + #[validate(length(min = 1, max = 2000))] + pub content: String, +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(channel_id): Path, + Json(payload): Json, +) -> web::Result { + // TODO: check permissions + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let message = state + .database + .insert_channel_message(context.user_id, channel_id, &payload.content) + .await?; + + let message = Message::from(message); + + // TODO: check permissions + ws::gateway::util::send_message_channel( + state, + message.channel_id, + ws::gateway::event::Event::AddMessage { + channel_id, + message: message.clone(), + }, + ); + + Ok(Json(message)) +} diff --git a/src/web/route/channel/message/mod.rs b/src/web/route/channel/message/mod.rs new file mode 100644 index 0000000..29a11aa --- /dev/null +++ b/src/web/route/channel/message/mod.rs @@ -0,0 +1,5 @@ +mod page; +mod create; + +pub use page::page; +pub use create::create; \ No newline at end of file diff --git a/src/web/route/channel/message/page.rs b/src/web/route/channel/message/page.rs new file mode 100644 index 0000000..a96d0b8 --- /dev/null +++ b/src/web/route/channel/message/page.rs @@ -0,0 +1,47 @@ +use axum::Json; +use axum::extract::{Path, Query, State}; +use axum::response::IntoResponse; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::message::Message; +use crate::{entity, web}; + +#[derive(Debug, Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +pub struct PageParams { + #[serde(default = "limit_default")] + #[validate(range(min = 1, max = 100))] + pub limit: u32, + #[serde(default)] + pub before: Option, +} + +fn limit_default() -> u32 { + 50 +} + +pub async fn page( + State(state): State, + context: UserContext, + Path(channel_id): Path, + Query(params): Query, +) -> web::Result { + // TODO: check permissions + match params.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let messages = state + .database + .select_channel_messages_paginated(channel_id, params.before, params.limit as i64) + .await? + .into_iter() + .map(Message::from) + .collect::>(); + + Ok(Json(messages)) +} diff --git a/src/web/route/channel/mod.rs b/src/web/route/channel/mod.rs new file mode 100644 index 0000000..e216a50 --- /dev/null +++ b/src/web/route/channel/mod.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/src/web/route/file/get.rs b/src/web/route/file/get.rs new file mode 100644 index 0000000..4e329cc --- /dev/null +++ b/src/web/route/file/get.rs @@ -0,0 +1,43 @@ +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::{entity, object_store, web}; + +pub async fn get( + State(state): State, + Path(file_id): Path, +) -> web::Result { + let file = match state.database.select_file_by_id(file_id).await { + Ok(file) => file, + Err(e) => { + return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); + }, + }; + + let data = match state + .object_store + .get_object_stream(&file.id.to_string()) + .await + { + Ok(data) => data, + Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => { + let _ = state.database.delete_file_by_id(file.id).await?; + + return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); + }, + Err(e) => { + return Err(object_store::Error::from(e).into()); + }, + }; + + let headers = axum::response::AppendHeaders([ + (axum::http::header::CONTENT_TYPE, file.content_type.clone()), + ( + axum::http::header::CONTENT_DISPOSITION, + format!("filename=\"{}\"", file.filename), + ), + ]); + + Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response()) +} diff --git a/src/web/route/file/mod.rs b/src/web/route/file/mod.rs new file mode 100644 index 0000000..a0c29e7 --- /dev/null +++ b/src/web/route/file/mod.rs @@ -0,0 +1,5 @@ +mod get; +mod upload; + +pub use get::get; +pub use upload::upload; diff --git a/src/web/route/file/upload.rs b/src/web/route/file/upload.rs new file mode 100644 index 0000000..094a5ad --- /dev/null +++ b/src/web/route/file/upload.rs @@ -0,0 +1,54 @@ +use axum::Json; +use axum::body::Bytes; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; +use validator::Validate; + +use crate::state::AppState; +use crate::util::SerdeFieldData; +use crate::{object_store, web}; + +#[derive(Debug, Validate, TryFromMultipart)] +#[try_from_multipart(rename_all = "camelCase")] +pub struct UploadPayload { + #[form_data(limit = "50MB")] + #[validate(length(min = 1, max = 16))] + files: Vec>, +} + +pub async fn upload( + State(state): State, + TypedMultipart(payload): TypedMultipart, +) -> web::Result { + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let mut file_ids = Vec::new(); + + for file in payload.files { + let db_file = state + .database + .insert_file( + file.metadata + .file_name + .as_deref() + .unwrap_or_else(|| "unknown"), + file.metadata.content_type.as_deref().unwrap_or_default(), + file.contents.len(), + ) + .await?; + + state + .object_store + .put_object(&db_file.id.to_string(), &file.contents) + .await + .map_err(object_store::Error::from)?; + + file_ids.push(db_file.id); + } + + Ok(Json(file_ids)) +} diff --git a/src/web/route/mod.rs b/src/web/route/mod.rs index 5805195..55cc745 100644 --- a/src/web/route/mod.rs +++ b/src/web/route/mod.rs @@ -1,4 +1,5 @@ pub mod auth; +pub mod channel; +pub mod file; pub mod server; pub mod user; -pub mod voice; diff --git a/src/web/route/server/channel/create.rs b/src/web/route/server/channel/create.rs new file mode 100644 index 0000000..6a7b6af --- /dev/null +++ b/src/web/route/server/channel/create.rs @@ -0,0 +1,58 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::error::ClientError; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, Validate, Deserialize)] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + name: String, + + #[validate(custom(function = "validate_server_channel_type"))] + r#type: entity::channel::ChannelType, +} + +fn validate_server_channel_type( + r#type: &entity::channel::ChannelType, +) -> Result<(), validator::ValidationError> { + match r#type { + entity::channel::ChannelType::ServerText => Ok(()), + entity::channel::ChannelType::ServerVoice => Ok(()), + entity::channel::ChannelType::ServerCategory => Ok(()), + _ => Err(validator::ValidationError::new("invalid_channel_type")), + } +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(server_id): Path, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + payload.validate().map_err(ClientError::ValidationFailed)?; + + // TODO: check permissions + let server = state.database.select_server_by_id(server_id).await?; + let channel = state + .database + .insert_server_channel(&payload.name, 0, payload.r#type, server_id, None) + .await?; + + ws::gateway::util::send_message_server( + state, + server_id, + ws::gateway::event::Event::AddServerChannel { + channel: channel.clone(), + }, + ); + + Ok(Json(channel)) +} diff --git a/src/web/route/server/channel/delete.rs b/src/web/route/server/channel/delete.rs new file mode 100644 index 0000000..a8832a1 --- /dev/null +++ b/src/web/route/server/channel/delete.rs @@ -0,0 +1,39 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::{entity, web}; + +pub async fn delete( + State(state): State, + context: UserContext, + Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, +) -> web::Result { + // TODO: check permissions + + let channel = state.database.select_channel_by_id(channel_id).await?; + + if let Some(channel_server_id) = channel.server_id { + if channel_server_id != server_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + } else { + return Err(web::error::ClientError::NotAllowed.into()); + } + + let channel = state.database.delete_channel_by_id(channel_id).await?; + + ws::gateway::util::send_message_server( + state, + server_id, + ws::gateway::event::Event::RemoveServerChannel { + server_id: server_id.clone(), + channel_id: channel.id.clone(), + }, + ); + + Ok(Json(channel)) +} diff --git a/src/web/route/server/channel/get.rs b/src/web/route/server/channel/get.rs new file mode 100644 index 0000000..ba6638c --- /dev/null +++ b/src/web/route/server/channel/get.rs @@ -0,0 +1,27 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{entity, web}; + +pub async fn get( + State(state): State, + context: UserContext, + Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, +) -> web::Result { + // TODO: check permissions + + let channel = state.database.select_channel_by_id(channel_id).await?; + + if let Some(channel_server_id) = channel.server_id { + if channel_server_id != server_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + } else { + return Err(web::error::ClientError::NotAllowed.into()); + } + + Ok(Json(channel)) +} diff --git a/src/web/route/server/channel/list.rs b/src/web/route/server/channel/list.rs index bf1330e..b4f4286 100644 --- a/src/web/route/server/channel/list.rs +++ b/src/web/route/server/channel/list.rs @@ -9,9 +9,9 @@ use crate::{entity, web}; pub async fn list( State(state): State, context: UserContext, - Path(id): Path, + Path(server_id): Path, ) -> web::Result { - let channels = state.database.select_server_channels(id).await?; + let channels = state.database.select_server_channels(server_id).await?; Ok(Json(channels)) } diff --git a/src/web/route/server/channel/mod.rs b/src/web/route/server/channel/mod.rs index 19f2172..327982c 100644 --- a/src/web/route/server/channel/mod.rs +++ b/src/web/route/server/channel/mod.rs @@ -1,3 +1,9 @@ +mod create; +mod delete; +mod get; mod list; +pub use create::create; +pub use delete::delete; +pub use get::get; pub use list::list; diff --git a/src/web/route/server/create.rs b/src/web/route/server/create.rs index 7e2f869..71116d3 100644 --- a/src/web/route/server/create.rs +++ b/src/web/route/server/create.rs @@ -1,50 +1,37 @@ use axum::Json; -use axum::body::Bytes; use axum::extract::State; use axum::response::IntoResponse; -use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; -use validator::{Validate, ValidationError}; +use axum_extra::extract::WithRejection; +use axum_typed_multipart::TryFromMultipart; +use serde::Deserialize; +use validator::Validate; use crate::state::AppState; -use crate::util::SerdeFieldData; -use crate::web; use crate::web::context::UserContext; use crate::web::error::ClientError; use crate::web::ws; +use crate::{entity, web}; +use crate::web::entity::server::Server; -#[derive(Debug, Validate, TryFromMultipart)] -#[try_from_multipart(rename_all = "camelCase")] +#[derive(Debug, Validate, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct CreatePayload { #[validate(length(min = 1, max = 32))] name: String, - #[validate(custom(function = "validate_icon_content_type"))] - #[form_data(limit = "10MB")] - icon: Option>, -} - -fn validate_icon_content_type(icon: &SerdeFieldData) -> Result<(), ValidationError> { - if let Some(content_type) = icon.metadata.content_type.as_deref() { - if !content_type.starts_with("image/") { - return Err(ValidationError::new("invalid_icon_content_type")); - } - } else { - return Err(ValidationError::new("missing_icon_content_type")); - } - - Ok(()) + icon_id: Option, } pub async fn create( State(state): State, context: UserContext, - TypedMultipart(payload): TypedMultipart, + WithRejection(Json(payload), _): WithRejection, web::Error>, ) -> web::Result { payload.validate().map_err(ClientError::ValidationFailed)?; let server = state .database - .insert_server(&payload.name, None, context.user_id) + .insert_server(&payload.name, payload.icon_id, context.user_id) .await?; let role = state @@ -69,15 +56,16 @@ pub async fn create( .database .insert_server_member_role(member.id, role.id) .await?; + + let server = Server::from(server); ws::gateway::util::send_message( - &state, + state, context.user_id, ws::gateway::event::Event::AddServer { server: server.clone(), }, - ) - .await; + ); Ok(Json(server)) } diff --git a/src/web/route/server/delete.rs b/src/web/route/server/delete.rs new file mode 100644 index 0000000..1e85005 --- /dev/null +++ b/src/web/route/server/delete.rs @@ -0,0 +1,60 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::webrtc::WebRtcSignal; +use crate::{entity, web}; +use crate::web::entity::server::Server; + +pub async fn delete( + State(state): State, + context: UserContext, + Path(server_id): Path, +) -> web::Result { + let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + + let members = state + .database + .select_server_members(server_id) + .await? + .iter() + .map(|u| u.id) + .collect::>(); + + let channels = state + .database + .select_server_channels(server_id) + .await? + .iter() + .map(|c| c.id) + .collect::>(); + + let state_clone = state.clone(); + tokio::spawn(async move { + let voice_rooms = state_clone.voice_rooms.read().await; + for channel_id in channels { + if let Some(voice_room) = voice_rooms.get(&channel_id) { + let _ = voice_room.send(WebRtcSignal::Close); + } + } + }); + + let server = state.database.delete_server_by_id(server_id).await?; + + ws::gateway::util::send_message_many( + state.clone(), + &members, + ws::gateway::event::Event::RemoveServer { + server_id: server.id, + }, + ); + + Ok(Json(Server::from(server))) +} diff --git a/src/web/route/server/get.rs b/src/web/route/server/get.rs index 483b199..8b3ac74 100644 --- a/src/web/route/server/get.rs +++ b/src/web/route/server/get.rs @@ -4,12 +4,15 @@ use axum::response::IntoResponse; use crate::state::AppState; use crate::{entity, web}; +use crate::web::entity::server::Server; pub async fn get( State(state): State, - Path(id): Path, + Path(server_id): Path, ) -> web::Result { - let server = state.database.select_server_by_id(id).await?; + // TODO: check permissions + + let server = state.database.select_server_by_id(server_id).await?; - Ok(Json(server)) + Ok(Json(Server::from(server))) } diff --git a/src/web/route/server/invite/create.rs b/src/web/route/server/invite/create.rs new file mode 100644 index 0000000..acd6e2c --- /dev/null +++ b/src/web/route/server/invite/create.rs @@ -0,0 +1,45 @@ +use axum::Json; +use axum::extract::{Path, Query, State}; +use axum::response::IntoResponse; +use base64::Engine; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{entity, web}; + +#[derive(serde::Deserialize, Debug)] +pub struct CreateParams { + #[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")] + #[serde(default)] + pub expires_in: Option, +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(server_id): Path, + Query(params): Query, +) -> web::Result { + // TODO: check permissions + + let code = { + use rand::Rng; + + let mut rng = rand::rng(); + let mut code = [0u8; 32]; + rng.fill(&mut code); + base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code) + }; + + let expires_at = params.expires_in.map(|d| { + let now = chrono::Utc::now(); + now + chrono::Duration::from_std(d).expect("valid duration") + }); + + let invite = state + .database + .insert_server_invite(&code, server_id, Some(context.user_id), expires_at) + .await?; + + Ok(Json(invite)) +} diff --git a/src/web/route/server/invite/get.rs b/src/web/route/server/invite/get.rs new file mode 100644 index 0000000..0155dc9 --- /dev/null +++ b/src/web/route/server/invite/get.rs @@ -0,0 +1,46 @@ +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::{database, web}; +use crate::web::entity::server::Server; + +pub async fn get( + State(state): State, + context: UserContext, + Path(code): Path, +) -> web::Result { + let invite = state.database.select_server_invite_by_code(&code).await?; + let server = state.database.select_server_by_id(invite.server_id).await?; + + let member = match state + .database + .insert_server_member(invite.server_id, context.user_id) + .await + { + Ok(member) => member, + Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))), + Err(e) => return Err(e.into()), + }; + + state + .database + .insert_server_member_role(member.id, invite.server_id) + .await?; + + let user = state.database.select_user_by_id(context.user_id).await?; + + ws::gateway::util::send_message_server( + state, + invite.server_id, + ws::gateway::event::Event::AddServerMember { + server_id: server.id, + member: user.into(), + }, + ); + + Ok(Json(Server::from(server))) +} diff --git a/src/web/route/server/invite/mod.rs b/src/web/route/server/invite/mod.rs new file mode 100644 index 0000000..6ff327c --- /dev/null +++ b/src/web/route/server/invite/mod.rs @@ -0,0 +1,5 @@ +mod create; +mod get; + +pub use create::create; +pub use get::get; diff --git a/src/web/route/server/list.rs b/src/web/route/server/list.rs index 8d482c1..cc41e81 100644 --- a/src/web/route/server/list.rs +++ b/src/web/route/server/list.rs @@ -5,12 +5,19 @@ use axum::response::IntoResponse; use crate::state::AppState; use crate::web; use crate::web::context::UserContext; +use crate::web::entity::server::Server; pub async fn list( State(state): State, context: UserContext, ) -> web::Result { - let servers = state.database.select_user_servers(context.user_id).await?; + let servers = state + .database + .select_user_servers(context.user_id) + .await? + .into_iter() + .map(Server::from) + .collect::>(); Ok(Json(servers)) } diff --git a/src/web/route/server/mod.rs b/src/web/route/server/mod.rs index 4f7e49a..ae43af5 100644 --- a/src/web/route/server/mod.rs +++ b/src/web/route/server/mod.rs @@ -1,8 +1,11 @@ pub mod channel; mod create; +mod delete; mod get; +pub mod invite; mod list; pub use create::create; +pub use delete::delete; pub use get::get; pub use list::list; diff --git a/src/web/route/user/channel/list.rs b/src/web/route/user/channel/list.rs index 8e146af..79e9450 100644 --- a/src/web/route/user/channel/list.rs +++ b/src/web/route/user/channel/list.rs @@ -7,7 +7,7 @@ use crate::entity::channel; use crate::state::AppState; use crate::web; use crate::web::context::UserContext; -use crate::web::route::user::PartialUser; +use crate::web::entity::user::PartialUser; #[derive(Debug, sqlx::FromRow, Serialize)] #[serde(rename_all = "camelCase")] diff --git a/src/web/route/user/get.rs b/src/web/route/user/get.rs index d017ea7..139540c 100644 --- a/src/web/route/user/get.rs +++ b/src/web/route/user/get.rs @@ -4,13 +4,13 @@ use axum::response::IntoResponse; use crate::state::AppState; use crate::web; -use crate::web::route::user::PartialUser; +use crate::web::entity::user::PartialUser; pub async fn get_by_id( - Path(id): Path, State(state): State, + Path(user_id): Path, ) -> web::Result { - let user = state.database.select_user_by_id(id).await?; + let user = state.database.select_user_by_id(user_id).await?; Ok(Json(PartialUser::from(user))) } diff --git a/src/web/route/user/me.rs b/src/web/route/user/me.rs index d9e81b5..8745382 100644 --- a/src/web/route/user/me.rs +++ b/src/web/route/user/me.rs @@ -5,11 +5,11 @@ use axum::response::IntoResponse; use crate::state::AppState; use crate::web; use crate::web::context::UserContext; -use crate::web::route::user::FullUser; +use crate::web::entity::user::FullUser; pub async fn me( - context: UserContext, State(state): State, + context: UserContext, ) -> web::Result { let user = state.database.select_user_by_id(context.user_id).await?; diff --git a/src/web/route/user/mod.rs b/src/web/route/user/mod.rs index 08dd7d6..640a054 100644 --- a/src/web/route/user/mod.rs +++ b/src/web/route/user/mod.rs @@ -1,60 +1,9 @@ pub mod channel; mod get; mod me; +mod patch; pub use get::get_by_id; pub use me::me; +pub use patch::patch; -use crate::entity::user; - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct FullUser { - pub id: user::Id, - pub avatar_url: Option, - pub username: String, - pub display_name: Option, - pub email: String, - pub bot: bool, - pub system: bool, - pub settings: serde_json::Value, -} - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct PartialUser { - pub id: user::Id, - pub avatar_url: Option, - pub username: String, - pub display_name: Option, - pub bot: bool, - pub system: bool, -} - -impl From for FullUser { - fn from(user: user::User) -> Self { - Self { - id: user.id, - avatar_url: user.avatar_url, - username: user.username, - display_name: user.display_name, - email: user.email, - bot: user.bot, - system: user.system, - settings: user.settings, - } - } -} - -impl From for PartialUser { - fn from(user: user::User) -> Self { - Self { - id: user.id, - avatar_url: user.avatar_url, - username: user.username, - display_name: user.display_name, - bot: user.bot, - system: user.system, - } - } -} diff --git a/src/web/route/user/patch.rs b/src/web/route/user/patch.rs new file mode 100644 index 0000000..c23e187 --- /dev/null +++ b/src/web/route/user/patch.rs @@ -0,0 +1,61 @@ +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::user::{FullUser, PartialUser}; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, Validate, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + #[serde(default)] + display_name: Option, + + #[serde(default)] + avatar_id: Option, +} + +pub async fn patch( + State(state): State, + context: UserContext, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let user = state + .database + .update_user_by_id( + context.user_id, + payload.display_name.as_deref(), + payload.avatar_id, + ) + .await?; + + ws::gateway::util::send_message( + state.clone(), + context.user_id, + ws::gateway::event::Event::AddUser { + user: PartialUser::from(user.clone()), + }, + ); + + ws::gateway::util::send_message_related( + state.clone(), + context.user_id, + ws::gateway::event::Event::AddUser { + user: PartialUser::from(user.clone()), + }, + ); + + Ok(Json(FullUser::from(user))) +} diff --git a/src/web/route/voice/connect.rs b/src/web/route/voice/connect.rs deleted file mode 100644 index 04b9b2d..0000000 --- a/src/web/route/voice/connect.rs +++ /dev/null @@ -1,91 +0,0 @@ -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::{entity, web}; - -#[derive(Debug, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Payload { - sdp: RTCSessionDescription, -} - -#[derive(Debug, serde::Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Response { - sdp: RTCSessionDescription, -} - -pub async fn connect( - State(state): State, - context: UserContext, - Path(channel_id): Path, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - tracing::debug!("connect to voice channel: {:?}", channel_id); - - let channel = state.database.select_channel_by_id(channel_id).await?; - let channel_id = channel.id; - - let room_sender = { - state - .voice_rooms - .read() - .await - .get(&channel_id) - .map(|room| room.clone()) - }; - - let room_sender = match room_sender { - Some(room) => room, - None => { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - let rooms = state.voice_rooms.clone(); - tokio::spawn(async move { - crate::webrtc::webrtc_task(channel_id, rx) - .await - .unwrap_or_else(|err| { - tracing::error!("webrtc task error: {:?}", err); - }); - - { - let mut rooms = rooms.write().await; - rooms.remove(&channel_id); - } - }); - - { - let mut rooms = state.voice_rooms.write().await; - rooms.insert(channel_id, tx.clone()); - } - - tx - }, - }; - - let offer = crate::webrtc::Offer { - peer_id: context.user_id, - sdp_offer: payload.sdp, - }; - - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let _ = room_sender.send(crate::webrtc::OfferSignal { - offer, - response: response_tx, - }); - - let answer = response_rx - .await - .map_err(|_| web::error::ClientError::InternalServerError)?; - - let response = Response { - sdp: answer.sdp_answer, - }; - - Ok(Json(response)) -} diff --git a/src/web/route/voice/mod.rs b/src/web/route/voice/mod.rs deleted file mode 100644 index fc9caa8..0000000 --- a/src/web/route/voice/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod connect; - -pub use connect::connect; diff --git a/src/web/ws/error.rs b/src/web/ws/error.rs index 55eb7d5..ee70b70 100644 --- a/src/web/ws/error.rs +++ b/src/web/ws/error.rs @@ -1,14 +1,21 @@ -pub type Result = std::result::Result; +pub type Result = std::result::Result>; #[derive(Debug, derive_more::From, derive_more::Display)] -pub enum Error { +pub enum Error { + #[from] + Custom(T), + #[from] Json(serde_json::Error), #[from] AcknowledgementError(tokio::sync::oneshot::error::RecvError), - + WrongMessageType, WebSocketClosed, + + UnknownError, } + +pub trait CustomError {} diff --git a/src/web/ws/gateway/connection.rs b/src/web/ws/gateway/connection.rs index bb678b2..14fb04b 100644 --- a/src/web/ws/gateway/connection.rs +++ b/src/web/ws/gateway/connection.rs @@ -1,101 +1,75 @@ -use std::ops::ControlFlow; - -use axum::extract::ws::{Message as AxumMessage, WebSocket}; +use axum::extract::ws::Message as AxumMessage; use base64::Engine as _; -use futures::stream::SplitStream; -use futures::{Sink, SinkExt, StreamExt}; -use serde::Serialize; +use futures::{Stream, StreamExt}; use sha2::{Digest, Sha256}; -use tokio::time::Instant; +use tokio::sync::mpsc; -use super::error::{self, Error as WsError}; +use super::error::Error as WsError; use super::event::Event as WsEvent; use super::protocol::{WsClientMessage, WsServerMessage}; use super::state::{WsContext, WsState, WsUserContext}; use crate::jwt; use crate::state::AppState; use crate::web::ws::gateway::SessionKey; -use crate::web::ws::util::{SendWsMessage, deserialize_ws_message, serialize_ws_message}; -use crate::web::ws::{util, voice}; +use crate::web::ws::general::WebSocketHandler; +use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; +use crate::web::ws::voice; +use crate::webrtc::WebRtcSignal; -/// Main handler for an individual WebSocket connection's lifecycle. -/// Spawned by Axum upon successful WebSocket upgrade. -#[tracing::instrument(skip_all, name = "ws_connection_handler")] -pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) { - let (ws_sink, ws_stream) = websocket.split(); +impl WebSocketHandler for WsContext { + type ServerMessage = WsServerMessage; + type ClientMessage = WsClientMessage; + type Error = WsError; - let (internal_send_tx, internal_send_rx) = tokio::sync::mpsc::unbounded_channel(); + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> crate::web::ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin, + { + process_websocket_messages(self, stream, sender, app_state).await?; - let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); - - let mut context = WsContext { - connection_state: WsState::Initialize, - user_context: None, - event_channel: None, - }; - - let processing_result = process_websocket_messages( - &mut context, - ws_stream, - &internal_send_tx, - &app_state, - ) - .await; - - // --- Cleanup --- - if let Some(user_ctx_data) = &context.user_context { - app_state - .unregister_gateway_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key) - .await; - tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user."); + Ok(()) } - // Drop our sender for the event channel; receiver in `process_websocket_messages` will see this. - drop(context.event_channel.take()); - - // If processing loop exited with an error (not a graceful close like WebSocketClosed or HeartbeatTimeout), - // try to send a final error message to the client. - if let Err(err_to_report) = &processing_result { - if !matches!( - err_to_report, - WsError::WebSocketClosed - ) { - tracing::warn!( - "WebSocket processing error, attempting to notify client: {:?}", - err_to_report - ); - let client_err_code = err_to_report.as_client_error(); - let error_ws_message = WsServerMessage::Error { - code: client_err_code, - }; - // Use new_no_response for best-effort send during shutdown. - // Ignore result as internal_send_tx might already be closed if writer_task ended. - let _ = internal_send_tx.send(SendWsMessage::new_no_response(error_ws_message)); + async fn cleanup(&mut self, app_state: &AppState) { + if let Some(user_ctx_data) = &self.user_context { + app_state + .unregister_gateway_connected_user( + user_ctx_data.user_id, + &user_ctx_data.session_key, + ) + .await; } + + drop(self.event_channel.take()); } - // Signal writer task to stop by dropping the MPSC sender. - drop(internal_send_tx); - // Wait for the writer task to complete its shutdown. - if let Err(e) = writer_task.await { - tracing::error!( - "WebSocket writer task panicked or encountered an error: {:?}", - e - ); + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ) { + let error_ws_message = WsServerMessage::Error { code: error }; + + let _ = sender.send(SendWsMessage::new_no_response(error_ws_message)); } - tracing::debug!(result = ?processing_result, "WebSocket connection handler finished."); } -/// Main loop for processing incoming WebSocket messages and outgoing application events. -/// Manages state transitions (Initialize -> Connected) and heartbeating. #[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) ))] -async fn process_websocket_messages( +async fn process_websocket_messages( context: &mut WsContext, - mut ws_stream: SplitStream, - sender: &tokio::sync::mpsc::UnboundedSender>, + mut ws_stream: S, + sender: &mpsc::UnboundedSender>, app_state: &AppState, -) -> error::Result<()> { +) -> crate::web::ws::error::Result<(), WsError> +where + S: Stream> + Unpin, +{ loop { match context.connection_state { WsState::Initialize => { @@ -104,24 +78,17 @@ async fn process_websocket_messages( maybe_message = ws_stream.next() => { match maybe_message { Some(Ok(message)) => { - match handle_initial_message(context, message, sender, app_state).await { - Ok(ControlFlow::Continue(())) => {}, - Ok(ControlFlow::Break(new_state)) => { // Authenticated - context.connection_state = new_state; - tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected."); - }, - Err(e) => { // Auth failed critically or other error - return Err(e); - } - } + handle_initial_message(context, message, sender, app_state).await?; + context.connection_state = WsState::Connected; + tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected."); } Some(Err(axum_ws_err)) => { tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err); - return Err(WsError::WebSocketClosed); + return Err(crate::web::ws::error::Error::WebSocketClosed); } None => { // Stream closed by client tracing::debug!("WebSocket stream ended by client during Initialize state."); - return Err(WsError::WebSocketClosed); + return Err(crate::web::ws::error::Error::WebSocketClosed); } } } @@ -139,30 +106,26 @@ async fn process_websocket_messages( tokio::select! { biased; - // Listen for application events to send to the client maybe_app_event = event_rx.recv() => { if let Some(app_event_data) = maybe_app_event { SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?; - // Sending an app event doesn't reset the client's ping requirement. } else { - // Event channel closed (e.g., AppState unregistered, or system shutdown signal) tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket."); - return Ok(()); // Graceful shutdown signaled by closed event channel + return Ok(()); } } - // Listen for messages from the client (e.g., Ping) maybe_ws_message = ws_stream.next() => { match maybe_ws_message { Some(Ok(message)) => { - handle_connected_message(context, message, sender).await?; + handle_connected_message(context, message, sender, &app_state).await?; } Some(Err(axum_ws_err)) => { tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err); - return Err(WsError::WebSocketClosed); + return Err(crate::web::ws::error::Error::WebSocketClosed); } None => { // Stream closed by client tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state."); - return Err(WsError::WebSocketClosed); + return Err(crate::web::ws::error::Error::WebSocketClosed); } } } @@ -172,16 +135,13 @@ async fn process_websocket_messages( } } -/// Handles messages received when the connection is in the `Initialize` state. -/// Expects `Authenticate` to transition to `Connected`, or `Ping` to stay in `Initialize`. #[tracing::instrument(skip_all, fields(state = ?context.connection_state))] async fn handle_initial_message( context: &mut WsContext, message: AxumMessage, - sender: &tokio::sync::mpsc::UnboundedSender>, // Changed to reference + sender: &mpsc::UnboundedSender>, // Changed to reference app_state: &AppState, -) -> error::Result> { - // Break(NewState) or Continue(()) +) -> crate::web::ws::error::Result<(), WsError> { match deserialize_ws_message(message)? { WsClientMessage::Authenticate { token } => { match crate::web::middleware::get_context_from_token(&app_state, &token).await { @@ -212,7 +172,7 @@ async fn handle_initial_message( .register_gateway_connected_user( user_id, current_session_key.clone(), - event_tx, // This is ws::state::EventSender -> mpsc::UnboundedSender + event_tx, ) .await; @@ -224,41 +184,35 @@ async fn handle_initial_message( }, ) .await?; - // Deadline is reset by the caller upon ControlFlow::Break - Ok(ControlFlow::Break(WsState::Connected)) + Ok(()) }, Err(_auth_err) => { tracing::warn!(token = %token, "Authentication failed for token."); - // Send AuthenticateDenied, then the connection will be closed by HeartbeatTimeout or by returning error. - // We send response to ensure client gets the denial before we might drop connection. let _ = SendWsMessage::send_with_response( sender, WsServerMessage::AuthenticateDenied, ) .await; - Err(WsError::AuthenticationFailed) // This will terminate process_websocket_messages + Err(WsError::AuthenticationFailed.into()) }, } }, - // Per original code, only Authenticate and Ping are expected in Initialize. - // If WsClientMessage has other variants, this might need adjustment. #[allow(unreachable_patterns)] _ => { tracing::warn!("Unexpected message type received during Initialize state."); - Err(WsError::UnexpectedMessageType) + Err(crate::web::ws::error::Error::WrongMessageType) }, } } -/// Handles messages received when the connection is in the `Connected` state. -/// Primarily expects `Ping` messages to keep the connection alive. #[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) ))] async fn handle_connected_message( context: &mut WsContext, message: AxumMessage, - sender: &tokio::sync::mpsc::UnboundedSender>, -) -> error::Result<()> { + sender: &mpsc::UnboundedSender>, // Changed to reference + app_state: &AppState, +) -> crate::web::ws::error::Result<(), WsError> { match deserialize_ws_message(message)? { WsClientMessage::VoiceStateUpdate { server_id, @@ -274,11 +228,15 @@ async fn handle_connected_message( .user_id, server_id, channel_id, - iat: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) + exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) .timestamp(), }; - let token = jwt::generate_jwt(claims).map_err(|_| WsError::TokenGenerationFailed)?; + let token = jwt::generate_jwt( + claims, + crate::config::config().security.voice_secret.as_ref(), + ) + .map_err(|_| WsError::TokenGenerationFailed)?; SendWsMessage::send_with_response( sender, @@ -294,9 +252,47 @@ async fn handle_connected_message( Ok(()) }, + WsClientMessage::RequestVoiceStates { server_id } => { + let channels = app_state + .database + .select_server_channels(server_id) + .await + .map_err(|_| crate::web::ws::error::Error::UnknownError)?; + + for channel in channels { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let webrtc_sender = + { app_state.voice_rooms.read().await.get(&channel.id).cloned() }; + + if let Some(voice_room) = webrtc_sender { + let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx }); + + let peers = match rx.await { + Ok(peers) => peers, + Err(_) => { + continue; + }, + }; + + for peer in peers { + let _ = + sender.send(SendWsMessage::new_no_response(WsServerMessage::Event { + event: WsEvent::VoiceChannelConnected { + server_id, + channel_id: channel.id, + user_id: peer, + }, + })); + } + } + } + + Ok(()) + }, other_message => { tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); - Err(WsError::UnexpectedMessageType) + Err(crate::web::ws::error::Error::WrongMessageType) }, } } diff --git a/src/web/ws/gateway/error.rs b/src/web/ws/gateway/error.rs index 1d1eb94..b4822ae 100644 --- a/src/web/ws/gateway/error.rs +++ b/src/web/ws/gateway/error.rs @@ -1,56 +1,12 @@ +use crate::web::ws::error::CustomError; + pub type Result = std::result::Result; -#[derive(Debug, derive_more::From, derive_more::Display)] +#[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Error { - #[from] - Axum(axum::Error), - - #[from] - Json(serde_json::Error), - - #[from] - AcknowledgementError(tokio::sync::oneshot::error::RecvError), - - UnexpectedMessageType, - - WrongMessageType, - - WebSocketClosed, - AuthenticationFailed, - TokenGenerationFailed, } -#[derive(Debug, Clone, serde::Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum ClientError { - DeserializationError, - NotAuthenticated, - AlreadyAuthenticated, - - Unknown, -} - -impl Error { - pub fn as_client_error(&self) -> ClientError { - match self { - Error::Json(_) => ClientError::DeserializationError, - Error::UnexpectedMessageType => ClientError::Unknown, - Error::WrongMessageType => ClientError::Unknown, - Error::WebSocketClosed => ClientError::Unknown, - _ => ClientError::Unknown, - } - } -} - -impl From for Error { - fn from(err: crate::web::ws::error::Error) -> Self { - match err { - crate::web::ws::error::Error::Json(e) => Error::Json(e), - crate::web::ws::error::Error::AcknowledgementError(e) => Error::AcknowledgementError(e), - crate::web::ws::error::Error::WrongMessageType => Error::WrongMessageType, - crate::web::ws::error::Error::WebSocketClosed => Error::WebSocketClosed, - } - } -} +impl CustomError for Error {} diff --git a/src/web/ws/gateway/event.rs b/src/web/ws/gateway/event.rs index 624b20c..6734621 100644 --- a/src/web/ws/gateway/event.rs +++ b/src/web/ws/gateway/event.rs @@ -1,23 +1,77 @@ -use crate::entity; +use crate::{entity, web}; #[derive(Debug, Clone, serde::Serialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Event { #[serde(rename_all = "camelCase")] - AddServer { server: entity::server::Server }, + AddServer { server: web::entity::server::Server }, + #[serde(rename_all = "camelCase")] RemoveServer { server_id: entity::server::Id }, #[serde(rename_all = "camelCase")] AddDmChannel { channel: entity::channel::Channel }, + #[serde(rename_all = "camelCase")] RemoveDmChannel { channel_id: entity::channel::Id }, #[serde(rename_all = "camelCase")] AddServerChannel { channel: entity::channel::Channel }, + #[serde(rename_all = "camelCase")] - RemoveServerChannel { channel_id: entity::channel::Id }, + RemoveServerChannel { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + }, + + #[serde(rename_all = "camelCase")] + AddUser { + user: web::entity::user::PartialUser, + }, + + #[serde(rename_all = "camelCase")] + RemoveUser { + user_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + AddServerMember { + server_id: entity::server::Id, + member: web::entity::user::PartialUser, + }, + + #[serde(rename_all = "camelCase")] + RemoveServerMember { + server_id: entity::server::Id, + member_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + AddMessage { + channel_id: entity::channel::Id, + message: web::entity::message::Message, + }, + + #[serde(rename_all = "camelCase")] + RemoveMessage { + channel_id: entity::channel::Id, + message_id: entity::message::Id, + }, + + #[serde(rename_all = "camelCase")] + VoiceChannelConnected { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + VoiceChannelDisconnected { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, #[serde(rename_all = "camelCase")] VoiceServerUpdate { diff --git a/src/web/ws/gateway/mod.rs b/src/web/ws/gateway/mod.rs index 49137f6..8203657 100644 --- a/src/web/ws/gateway/mod.rs +++ b/src/web/ws/gateway/mod.rs @@ -3,8 +3,8 @@ use axum::response::IntoResponse; use dashmap::DashMap; use crate::state::AppState; -use crate::web::ws::gateway::connection::handle_socket_connection; -use crate::web::ws::gateway::state::EventSender; +use crate::web::ws::gateway::state::{EventSender, WsContext}; +use crate::web::ws::general; mod connection; mod error; @@ -37,5 +37,7 @@ pub async fn ws_handler( State(app_state): State, ws: WebSocketUpgrade, ) -> crate::web::error::Result { - Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state))) + Ok(ws.on_upgrade(|socket| { + general::handle_websocket_connection(socket, app_state, WsContext::default()) + })) } diff --git a/src/web/ws/gateway/protocol.rs b/src/web/ws/gateway/protocol.rs index 39ffdb5..819101b 100644 --- a/src/web/ws/gateway/protocol.rs +++ b/src/web/ws/gateway/protocol.rs @@ -1,12 +1,7 @@ -use std::time::Duration; - -use serde::{Deserialize, Serialize}; - -use super::error::ClientError; -use super::{SessionKey, event as ws_local_message}; +use super::{SessionKey, error, event}; use crate::entity; -#[derive(Debug, Serialize)] +#[derive(Debug, serde::Serialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum WsServerMessage { @@ -20,27 +15,28 @@ pub enum WsServerMessage { #[serde(rename_all = "camelCase")] Event { - event: ws_local_message::Event, + event: event::Event, }, #[serde(rename_all = "camelCase")] Error { - code: ClientError, + code: error::Error, }, } -#[derive(Debug, Deserialize)] +#[derive(Debug, serde::Deserialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum WsClientMessage { #[serde(rename_all = "camelCase")] - Authenticate { - token: String, - }, + Authenticate { token: String }, #[serde(rename_all = "camelCase")] VoiceStateUpdate { server_id: entity::server::Id, channel_id: entity::channel::Id, }, + + #[serde(rename_all = "camelCase")] + RequestVoiceStates { server_id: entity::server::Id }, } diff --git a/src/web/ws/gateway/state.rs b/src/web/ws/gateway/state.rs index 273f427..6e79e2f 100644 --- a/src/web/ws/gateway/state.rs +++ b/src/web/ws/gateway/state.rs @@ -1,37 +1,35 @@ -use std::time::Duration; - use tokio::sync::mpsc; -use super::{event, SessionKey}; -use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver +use super::{SessionKey, event}; +use crate::entity; -/// Represents the current state of a single WebSocket connection. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum WsState { - Initialize, // Connection established, awaiting authentication - Connected, // Authenticated and operational + Initialize, + Connected, } -/// Contextual information for an authenticated WebSocket user session. #[derive(Debug)] pub struct WsUserContext { pub user_id: entity::user::Id, pub session_key: SessionKey, // Unique key for this specific WebSocket session instance } -/// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client. pub type EventSender = mpsc::UnboundedSender; -/// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s. pub type EventReceiver = mpsc::UnboundedReceiver; -/// Holds the full context for a single WebSocket connection's lifecycle. -/// This struct is managed per-connection. pub struct WsContext { pub connection_state: WsState, pub user_context: Option, - /// Channel for receiving application-specific events to be sent to this client. - /// The `EventSender` (tx) part is given to `AppState` for broadcasting. - /// The `EventReceiver` (rx) part is polled by the connection task. pub event_channel: Option<(EventSender, EventReceiver)>, } +impl Default for WsContext { + fn default() -> Self { + Self { + connection_state: WsState::Initialize, + user_context: None, + event_channel: None, + } + } +} diff --git a/src/web/ws/gateway/util.rs b/src/web/ws/gateway/util.rs index d8d798a..8657382 100644 --- a/src/web/ws/gateway/util.rs +++ b/src/web/ws/gateway/util.rs @@ -2,13 +2,67 @@ use crate::entity; use crate::state::AppState; use crate::web::ws::gateway::event; -pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: event::Event) { - let connected_users = state.gateway_state.connected.get_async(&user_id).await; - if let Some(session) = connected_users { - for instance in session.instances.iter() { - if let Err(e) = instance.send(message.clone()) { - tracing::error!("failed to send message: {}", e); +pub fn send_message(state: AppState, user_id: entity::user::Id, message: event::Event) { + tokio::spawn(async move { + let connected_users = state.gateway_state.connected.get_async(&user_id).await; + if let Some(session) = connected_users { + for instance in session.instances.iter() { + if let Err(e) = instance.send(message.clone()) { + tracing::error!("failed to send message: {}", e); + } } } + }); +} + +pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) { + for id in user_ids.iter() { + send_message(state.clone(), *id, message.clone()); } } + +pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) { + tokio::spawn(async move { + let users = state + .database + .select_server_members(server_id) + .await + .unwrap_or_else(|_| vec![]) + .iter() + .map(|u| u.id) + .collect::>(); + + send_message_many(state, &users, message); + }); +} + +pub fn send_message_channel( + state: AppState, + channel_id: entity::channel::Id, + message: event::Event, +) { + tokio::spawn(async move { + let users = state + .database + .select_channel_members(channel_id) + .await + .unwrap_or_else(|_| vec![]) + .iter() + .map(|u| u.id) + .collect::>(); + + send_message_many(state, &users, message); + }); +} + +pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) { + tokio::spawn(async move { + let users = state + .database + .select_related_user_ids(user_id) + .await + .unwrap_or_else(|_| vec![]); + + send_message_many(state, &users, message); + }); +} diff --git a/src/web/ws/general.rs b/src/web/ws/general.rs new file mode 100644 index 0000000..fd3103a --- /dev/null +++ b/src/web/ws/general.rs @@ -0,0 +1,74 @@ +use std::fmt::Debug; + +use axum::extract::ws::WebSocket; +use futures::{Stream, StreamExt}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::sync::mpsc; + +use crate::state::AppState; +use crate::web::ws::error::CustomError; +use crate::web::ws::util; +use crate::web::ws::util::SendWsMessage; + +pub trait WebSocketHandler { + type ServerMessage: Serialize + Send; + type ClientMessage: DeserializeOwned; + type Error: CustomError + Send + Debug; + + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> crate::web::ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin; + + async fn cleanup(&mut self, app_state: &AppState); + + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ); +} + +#[tracing::instrument(skip_all)] +pub async fn handle_websocket_connection( + websocket: WebSocket, + app_state: AppState, + mut handler: impl WebSocketHandler + 'static, +) { + let (ws_sink, ws_stream) = websocket.split(); + + let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel(); + + let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); + + let processing_result = handler + .handle_stream(ws_stream, &internal_send_tx, &app_state) + .await; + + handler.cleanup(&app_state).await; + + match processing_result { + Ok(_) => {}, + Err(crate::web::ws::error::Error::Custom(err_to_report)) => { + handler + .handle_result_error(err_to_report, &internal_send_tx) + .await; + }, + Err(e) => { + tracing::info!("WebSocket connection closed: {:?}", e); + }, + } + + drop(internal_send_tx); + if let Err(e) = writer_task.await { + tracing::error!( + "WebSocket writer task panicked or encountered an error: {:?}", + e + ); + } +} diff --git a/src/web/ws/mod.rs b/src/web/ws/mod.rs index 658f016..039ed06 100644 --- a/src/web/ws/mod.rs +++ b/src/web/ws/mod.rs @@ -2,3 +2,4 @@ mod error; pub mod gateway; mod util; pub mod voice; +mod general; diff --git a/src/web/ws/util.rs b/src/web/ws/util.rs index 8817c77..74ba350 100644 --- a/src/web/ws/util.rs +++ b/src/web/ws/util.rs @@ -4,13 +4,16 @@ use serde::Serialize; use serde::de::DeserializeOwned; use tokio::sync::{mpsc, oneshot}; -pub fn spawn_writer_task( +use crate::web::ws::error::CustomError; + +pub fn spawn_writer_task( mut ws_sink: S, - mut writer_rx: mpsc::UnboundedReceiver>, + mut writer_rx: mpsc::UnboundedReceiver>, ) -> tokio::task::JoinHandle<()> where S: Sink + Unpin + Send + 'static, T: Serialize + Send + 'static, + E: CustomError + Send + 'static, { tokio::spawn(async move { while let Some(SendWsMessage { @@ -39,9 +42,9 @@ where } /// Deserializes an Axum WebSocket message into a `WsClientMessage`. -pub fn deserialize_ws_message( +pub fn deserialize_ws_message( message: AxumMessage, -) -> super::error::Result { +) -> super::error::Result { match message { AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from), AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed), @@ -50,7 +53,9 @@ pub fn deserialize_ws_message( } /// Serializes a `WsServerMessage` into an Axum WebSocket message. -pub fn serialize_ws_message(message: T) -> super::error::Result { +pub fn serialize_ws_message( + message: T, +) -> super::error::Result { serde_json::to_string(&message) .map(Into::into) .map(AxumMessage::Text) @@ -59,17 +64,17 @@ pub fn serialize_ws_message(message: T) -> super::error::Result { +pub struct SendWsMessage { pub message: T, - pub response_ch: Option>>, + pub response_ch: Option>>, } -impl SendWsMessage { +impl SendWsMessage { /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. pub async fn send_with_response( tx: &mpsc::UnboundedSender, // Changed to reference message: T, - ) -> super::error::Result<()> { + ) -> super::error::Result<(), E> { let (response_tx, response_rx) = oneshot::channel(); let send_message = SendWsMessage { message, @@ -87,7 +92,7 @@ impl SendWsMessage { /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). pub fn new_no_response(message: T) -> Self { - SendWsMessage { + Self { message, response_ch: None, } diff --git a/src/web/ws/voice/claims.rs b/src/web/ws/voice/claims.rs index e1d3ea7..7c4c40b 100644 --- a/src/web/ws/voice/claims.rs +++ b/src/web/ws/voice/claims.rs @@ -5,5 +5,5 @@ pub struct VoiceClaims { pub user_id: entity::user::Id, pub server_id: entity::server::Id, pub channel_id: entity::channel::Id, - pub iat: i64, + pub exp: i64, } diff --git a/src/web/ws/voice/connection.rs b/src/web/ws/voice/connection.rs index c1a9b07..4fb2a82 100644 --- a/src/web/ws/voice/connection.rs +++ b/src/web/ws/voice/connection.rs @@ -1,6 +1,211 @@ -use axum::extract::ws::WebSocket; +use axum::extract::ws::Message as AxumMessage; +use futures::{Stream, StreamExt}; +use tokio::sync::{mpsc, oneshot}; +use super::error::{self, Error as WsError}; +use super::protocol::{WsClientMessage, WsServerMessage}; +use crate::jwt; use crate::state::AppState; +use crate::web::ws; +use crate::web::ws::general::WebSocketHandler; +use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; +use crate::web::ws::voice::claims::VoiceClaims; +use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer; +use crate::web::ws::voice::state::{WsContext, WsState}; +use crate::webrtc::{Offer, OfferSignal, WebRtcSignal}; -#[tracing::instrument(skip_all, name = "ws_connection_handler")] -pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {} +impl WebSocketHandler for WsContext { + type ServerMessage = WsServerMessage; + type ClientMessage = WsClientMessage; + type Error = WsError; + + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin, + { + process_websocket_messages(self, stream, sender, app_state).await?; + + Ok(()) + } + + async fn cleanup(&mut self, app_state: &AppState) { + tracing::debug!("Cleaning up WebSocket connection."); + + match &self.connection_state { + WsState::Connected { + signal_channel, + server_id, + channel_id, + user_id, + } => { + ws::gateway::util::send_message_server( + app_state.clone(), + *server_id, + ws::gateway::event::Event::VoiceChannelDisconnected { + server_id: *server_id, + channel_id: *channel_id, + user_id: *user_id, + }, + ); + + let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone())); + }, + WsState::Initialize => {}, + } + } + + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ) { + tracing::error!("WebSocket error: {:?}", error); + } +} + +#[tracing::instrument(skip_all)] +async fn process_websocket_messages( + context: &mut WsContext, + mut ws_stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), WsError> +where + S: Stream> + Unpin, +{ + loop { + match &context.connection_state { + WsState::Initialize => { + while let Some(Ok(message)) = ws_stream.next().await { + handle_initial_message(context, message, sender, &app_state).await?; + break; + } + }, + WsState::Connected { signal_channel, .. } => { + let signal_channel = signal_channel.clone(); + loop { + tokio::select! { + biased; + _ = signal_channel.closed() => { + tracing::debug!("Signal channel closed."); + break; + } + Some(Ok(message)) = ws_stream.next() => { + handle_connected_message(context, message, sender, &app_state).await?; + } + else => { + break; + } + } + } + + return Err(ws::error::Error::WebSocketClosed); + }, + } + } +} + +#[tracing::instrument(skip_all)] +async fn handle_initial_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), error::Error> { + match deserialize_ws_message(message)? { + WsClientMessage::Authenticate { token } => match jwt::verify_jwt::( + &token, + crate::config::config().security.voice_secret.as_ref(), + ) { + Ok(claims) => { + SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted) + .await?; + + let signal_channel = + ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await; + + context.connection_state = WsState::Connected { + signal_channel, + server_id: claims.server_id, + channel_id: claims.channel_id, + user_id: claims.user_id, + }; + + ws::gateway::util::send_message_server( + app_state.clone(), + claims.server_id, + ws::gateway::event::Event::VoiceChannelConnected { + server_id: claims.server_id, + channel_id: claims.channel_id, + user_id: claims.user_id, + }, + ); + + Ok(()) + }, + Err(auth_err) => { + tracing::warn!("Authentication failed: {:?}", auth_err); + + let _ = + SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied) + .await; + Err(error::Error::AuthenticationFailed.into()) + }, + }, + #[allow(unreachable_patterns)] + _ => { + tracing::warn!("Unexpected message type received during Initialize state."); + Err(ws::error::Error::WrongMessageType) + }, + } +} + +#[tracing::instrument(skip_all)] +async fn handle_connected_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), error::Error> { + match deserialize_ws_message(message)? { + WsClientMessage::SdpOffer { sdp } => { + let (signal_channel, user_id) = match &context.connection_state { + WsState::Connected { + signal_channel, + user_id, + .. + } => (signal_channel.clone(), *user_id), + _ => return Err(ws::error::Error::WrongMessageType), + }; + + let (tx, rx) = oneshot::channel(); + + let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal { + offer: Offer { + peer_id: user_id, + sdp_offer: sdp, + }, + response: tx, + })); + + let answer_signal = rx.await?; + + sender + .send(SendWsMessage::new_no_response(SdpAnswer { + sdp: answer_signal.sdp_answer, + })) + .map_err(|_| ws::error::Error::WebSocketClosed)?; + + Ok(()) + }, + other_message => { + tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); + Err(ws::error::Error::WrongMessageType) + }, + } +} diff --git a/src/web/ws/voice/error.rs b/src/web/ws/voice/error.rs index 5766e1d..3f20a07 100644 --- a/src/web/ws/voice/error.rs +++ b/src/web/ws/voice/error.rs @@ -1,48 +1,10 @@ +use crate::web::ws::error::CustomError; + pub type Result = std::result::Result; #[derive(Debug, derive_more::From, derive_more::Display)] pub enum Error { - #[from] - Axum(axum::Error), - - #[from] - Json(serde_json::Error), - - #[from] - AcknowledgementError(tokio::sync::oneshot::error::RecvError), - - UnexpectedMessageType, - - WrongMessageType, - - WebSocketClosed, - - HeartbeatTimeout, AuthenticationFailed, - - TokenGenerationFailed, } -#[derive(Debug, Clone, serde::Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum ClientError { - DeserializationError, - NotAuthenticated, - AlreadyAuthenticated, - HeartbeatTimeout, - - Unknown, -} - -impl Error { - pub fn into_client_error(&self) -> ClientError { - match self { - Error::HeartbeatTimeout => ClientError::HeartbeatTimeout, - Error::Json(_) => ClientError::DeserializationError, - Error::UnexpectedMessageType => ClientError::Unknown, - Error::WrongMessageType => ClientError::Unknown, - Error::WebSocketClosed => ClientError::Unknown, - _ => ClientError::Unknown, - } - } -} +impl CustomError for Error {} diff --git a/src/web/ws/voice/mod.rs b/src/web/ws/voice/mod.rs index bddb0d6..4d43718 100644 --- a/src/web/ws/voice/mod.rs +++ b/src/web/ws/voice/mod.rs @@ -2,16 +2,20 @@ pub mod claims; mod connection; mod error; mod protocol; +mod state; use axum::extract::{State, WebSocketUpgrade}; use axum::response::IntoResponse; use crate::state::AppState; -use crate::web::ws::voice::connection::handle_socket_connection; +use crate::web::ws::general; +use crate::web::ws::voice::state::WsContext; pub async fn ws_handler( State(app_state): State, ws: WebSocketUpgrade, ) -> crate::web::error::Result { - Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state))) + Ok(ws.on_upgrade(|socket| { + general::handle_websocket_connection(socket, app_state, WsContext::default()) + })) } diff --git a/src/web/ws/voice/protocol.rs b/src/web/ws/voice/protocol.rs index 118bdb6..b658203 100644 --- a/src/web/ws/voice/protocol.rs +++ b/src/web/ws/voice/protocol.rs @@ -1,21 +1,9 @@ -use std::time::Duration; - -use axum::extract::ws::Message as AxumMessage; -use serde::{Deserialize, Serialize}; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use super::error::{self, ClientError, Error as WsError}; -use crate::{entity, util as crate_root_util}; // For crate::util::serialize_duration_seconds - -#[derive(Debug, Serialize)] +#[derive(Debug, serde::Serialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum WsServerMessage { - HeartbeatInterval { - #[serde(serialize_with = "crate_root_util::serialize_duration_seconds")] - interval: Duration, - }, - AuthenticateDenied, AuthenticateAccepted, @@ -24,82 +12,15 @@ pub enum WsServerMessage { SdpAnswer { sdp: RTCSessionDescription, }, - - #[serde(rename_all = "camelCase")] - Error { - code: ClientError, - }, - - Pong, } -#[derive(Debug, Deserialize)] +#[derive(Debug, serde::Deserialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum WsClientMessage { #[serde(rename_all = "camelCase")] - Authenticate { - token: String, - }, + Authenticate { token: String }, #[serde(rename_all = "camelCase")] - SdpOffer { - sdp: RTCSessionDescription, - }, - - Ping, -} - -/// Deserializes an Axum WebSocket message into a `WsClientMessage`. -pub fn deserialize_ws_message(message: AxumMessage) -> error::Result { - match message { - AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from), - AxumMessage::Close(_) => Err(WsError::WebSocketClosed), - _ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message - } -} - -/// Serializes a `WsServerMessage` into an Axum WebSocket message. -pub fn serialize_ws_message(message: WsServerMessage) -> error::Result { - serde_json::to_string(&message) - .map(Into::into) - .map(AxumMessage::Text) - .map_err(WsError::from) -} - -/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. -/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. -pub struct SendWsMessage { - pub message: WsServerMessage, - pub response_ch: Option>>, -} - -impl SendWsMessage { - /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. - pub async fn send_with_response( - tx: &tokio::sync::mpsc::UnboundedSender, // Changed to reference - message: WsServerMessage, - ) -> error::Result<()> { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let send_message = SendWsMessage { - message, - response_ch: Some(response_tx), - }; - - if tx.send(send_message).is_err() { - Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead - } else { - // Wait for the writer task to acknowledge the send attempt. - // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. - response_rx.await? // Propagates RecvError into WsError::AcknowledgementError - } - } - - /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). - pub fn new_no_response(message: WsServerMessage) -> Self { - SendWsMessage { - message, - response_ch: None, - } - } + SdpOffer { sdp: RTCSessionDescription }, } diff --git a/src/web/ws/voice/state.rs b/src/web/ws/voice/state.rs new file mode 100644 index 0000000..5485edf --- /dev/null +++ b/src/web/ws/voice/state.rs @@ -0,0 +1,68 @@ +use tokio::sync::mpsc; + +use crate::entity; +use crate::state::AppState; +use crate::webrtc::{OfferSignal, WebRtcSignal}; + +#[derive(Debug, Clone)] +pub enum WsState { + Initialize, + Connected { + signal_channel: mpsc::UnboundedSender, + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, +} + +pub struct WsContext { + pub connection_state: WsState, +} + +impl Default for WsContext { + fn default() -> Self { + Self { + connection_state: WsState::Initialize, + } + } +} + +pub async fn get_signaling_channel( + app_state: &AppState, + channel_id: entity::channel::Id, +) -> mpsc::UnboundedSender { + let room_sender = { + app_state + .voice_rooms + .read() + .await + .get(&channel_id) + .map(|room| room.clone()) + }; + + match room_sender { + Some(room) => room, + None => { + let (tx, rx) = mpsc::unbounded_channel(); + + let app_state_ = app_state.clone(); + tokio::spawn(async move { + crate::webrtc::webrtc_task(channel_id, rx) + .await + .unwrap_or_else(|err| { + tracing::error!("webrtc task error: {:?}", err); + }); + + { + app_state_.unregister_voice_room(channel_id).await; + } + }); + + { + app_state.register_voice_room(channel_id, tx.clone()).await; + } + + tx + }, + } +} diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs index 50fbcfb..9d92671 100644 --- a/src/webrtc/mod.rs +++ b/src/webrtc/mod.rs @@ -42,6 +42,16 @@ pub struct AnswerSignal { pub sdp_answer: RTCSessionDescription, } +#[derive(Debug)] +pub enum WebRtcSignal { + Offer(OfferSignal), + Disconnect(PeerId), + RequestPeers { + response: tokio::sync::oneshot::Sender>, + }, + Close, +} + #[derive(Debug)] pub struct OfferSignal { pub offer: Offer, @@ -51,12 +61,14 @@ pub struct OfferSignal { #[tracing::instrument(skip(signal))] pub async fn webrtc_task( room_id: RoomId, - signal: tokio::sync::mpsc::UnboundedReceiver, + signal: tokio::sync::mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { tracing::info!("Starting WebRTC task"); let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel(); + let mut skip_timeout = false; + let state = Arc::new(RoomState { room_id, peers: DashMap::new(), @@ -78,20 +90,46 @@ pub async fn webrtc_task( loop { tokio::select! { - Some(signal) = signal.recv() => { - let room_state = state.clone(); - let api = api.clone(); - - tokio::spawn(async move { - if let Err(e) = handle_peer(api, room_state, signal).await { - tracing::error!("error handling peer: {}", e); - } - }.instrument(tracing::Span::current())); + biased; + _ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => { + tracing::debug!("initial timeout reached"); + break; } _ = close_receiver.recv() => { tracing::debug!("WebRTC task stopped"); break; } + Some(signal) = signal.recv() => { + skip_timeout = true; + match signal { + WebRtcSignal::Offer(offer_signal) => { + let room_state = state.clone(); + let api = api.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_peer(api, room_state, offer_signal).await { + tracing::error!("error handling peer: {}", e); + } + }.instrument(tracing::Span::current())); + } + WebRtcSignal::RequestPeers { response } => { + let peers = state + .peers + .iter() + .map(|pair| pair.key().clone()) + .collect::>(); + + let _ = response.send(peers); + } + WebRtcSignal::Disconnect(peer_id) => { + tracing::debug!("received disconnect signal for peer {}", peer_id); + cleanup_peer(state.clone(), peer_id).await; + } + WebRtcSignal::Close => { + break; + } + } + } } } @@ -105,6 +143,7 @@ async fn handle_peer( offer_signal: OfferSignal, ) -> anyhow::Result<()> { tracing::debug!("handling peer"); + let config = RTCConfiguration { ..Default::default() };