This commit is contained in:
2025-05-17 23:52:20 +03:00
parent 02f45aeac6
commit de11b5a4c3
84 changed files with 2425 additions and 687 deletions

2
.gitignore vendored
View File

@@ -1,7 +1,7 @@
/target /target
/.idea /.idea
/config.toml /config.toml
/db /data
/logs /logs
.env .env

438
Cargo.lock generated
View File

@@ -194,12 +194,52 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "attohttpc"
version = "0.28.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07a9b245ba0739fc90935094c29adbaee3f977218b5fb95e822e261cda7f56a3"
dependencies = [
"http 1.3.1",
"log",
"native-tls",
"serde",
"serde_json",
"url",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.4.0" version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-creds"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f84143206b9c72b3c5cb65415de60c7539c79cd1559290fddec657939131be0"
dependencies = [
"attohttpc",
"home",
"log",
"quick-xml",
"rust-ini",
"serde",
"thiserror 1.0.69",
"time",
"url",
]
[[package]]
name = "aws-region"
version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9aed3f9c7eac9be28662fdb3b0f4d1951e812f7c64fed4f0327ba702f459b3b"
dependencies = [
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.8.4" version = "0.8.4"
@@ -212,10 +252,10 @@ dependencies = [
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"http-body-util", "http-body-util",
"hyper", "hyper 1.6.0",
"hyper-util", "hyper-util",
"itoa", "itoa",
"matchit", "matchit",
@@ -247,8 +287,8 @@ checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"http-body-util", "http-body-util",
"mime", "mime",
"pin-project-lite", "pin-project-lite",
@@ -270,8 +310,8 @@ dependencies = [
"bytes", "bytes",
"futures-util", "futures-util",
"headers", "headers",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"http-body-util", "http-body-util",
"mime", "mime",
"pin-project-lite", "pin-project-lite",
@@ -384,9 +424,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.9.0" version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
dependencies = [ dependencies = [
"serde", "serde",
] ]
@@ -513,9 +553,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.22" version = "1.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32db95edf998450acc7881c932f94cd9b05c87b4b2599e8bab064753da4acfd1" checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@@ -641,6 +681,16 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.7" version = "0.8.7"
@@ -855,6 +905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e"
dependencies = [ dependencies = [
"powerfmt", "powerfmt",
"serde",
] ]
[[package]] [[package]]
@@ -922,6 +973,7 @@ dependencies = [
"mime", "mime",
"rand 0.9.1", "rand 0.9.1",
"regex", "regex",
"rust-s3",
"scc", "scc",
"serde", "serde",
"serde_json", "serde_json",
@@ -1025,9 +1077,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]] [[package]]
name = "errno" name = "errno"
version = "0.3.11" version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.59.0", "windows-sys 0.59.0",
@@ -1100,6 +1152,21 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "form_urlencoded" name = "form_urlencoded"
version = "1.2.1" version = "1.2.1"
@@ -1322,7 +1389,7 @@ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bytes", "bytes",
"headers-core", "headers-core",
"http", "http 1.3.1",
"httpdate", "httpdate",
"mime", "mime",
"sha1", "sha1",
@@ -1334,7 +1401,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [ dependencies = [
"http", "http 1.3.1",
] ]
[[package]] [[package]]
@@ -1376,6 +1443,17 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "http"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]] [[package]]
name = "http" name = "http"
version = "1.3.1" version = "1.3.1"
@@ -1387,6 +1465,17 @@ dependencies = [
"itoa", "itoa",
] ]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http 0.2.12",
"pin-project-lite",
]
[[package]] [[package]]
name = "http-body" name = "http-body"
version = "1.0.1" version = "1.0.1"
@@ -1394,7 +1483,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [ dependencies = [
"bytes", "bytes",
"http", "http 1.3.1",
] ]
[[package]] [[package]]
@@ -1405,8 +1494,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"pin-project-lite", "pin-project-lite",
] ]
@@ -1431,6 +1520,29 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "hyper"
version = "0.14.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.6.0" version = "1.6.0"
@@ -1440,8 +1552,8 @@ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"httparse", "httparse",
"httpdate", "httpdate",
"itoa", "itoa",
@@ -1450,6 +1562,19 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "hyper-tls"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper 0.14.32",
"native-tls",
"tokio",
"tokio-native-tls",
]
[[package]] [[package]]
name = "hyper-util" name = "hyper-util"
version = "0.1.11" version = "0.1.11"
@@ -1458,9 +1583,9 @@ checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-util", "futures-util",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"hyper", "hyper 1.6.0",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
"tower-service", "tower-service",
@@ -1765,6 +1890,17 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "maybe-async"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cf92c10c7e361d6b99666ec1c6f9805b0bea2c3bd8c78dc6fe98ac5bd78db11"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]] [[package]]
name = "md-5" name = "md-5"
version = "0.10.6" version = "0.10.6"
@@ -1775,6 +1911,12 @@ dependencies = [
"digest 0.10.7", "digest 0.10.7",
] ]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.7.4" version = "2.7.4"
@@ -1796,6 +1938,15 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minidom"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f45614075738ce1b77a1768912a60c0227525971b03e09122a05b8a34a2a6278"
dependencies = [
"rxml",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@@ -1831,7 +1982,7 @@ dependencies = [
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-util", "futures-util",
"http", "http 1.3.1",
"httparse", "httparse",
"memchr", "memchr",
"mime", "mime",
@@ -1839,6 +1990,23 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.26.4" version = "0.26.4"
@@ -1965,6 +2133,50 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl"
version = "0.10.72"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da"
dependencies = [
"bitflags 2.9.1",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.108"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "ordered-multimap" name = "ordered-multimap"
version = "0.7.3" version = "0.7.3"
@@ -2271,6 +2483,16 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "quick-xml"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d3a6e5838b60e0e8fa7a43f22ade549a37d61f8bdbe636d0d7816191de969c2"
dependencies = [
"memchr",
"serde",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.40" version = "1.0.40"
@@ -2371,7 +2593,7 @@ version = "0.5.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.1",
] ]
[[package]] [[package]]
@@ -2487,7 +2709,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94"
dependencies = [ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bitflags 2.9.0", "bitflags 2.9.1",
"serde", "serde",
"serde_derive", "serde_derive",
] ]
@@ -2549,6 +2771,43 @@ dependencies = [
"trim-in-place", "trim-in-place",
] ]
[[package]]
name = "rust-s3"
version = "0.35.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3df3f353b1f4209dcf437d777cda90279c397ab15a0cd6fd06bd32c88591533"
dependencies = [
"async-trait",
"aws-creds",
"aws-region",
"base64 0.22.1",
"bytes",
"cfg-if",
"futures",
"hex",
"hmac",
"http 0.2.12",
"hyper 0.14.32",
"hyper-tls",
"log",
"maybe-async",
"md5",
"minidom",
"native-tls",
"percent-encoding",
"quick-xml",
"serde",
"serde_derive",
"serde_json",
"sha2",
"thiserror 1.0.69",
"time",
"tokio",
"tokio-native-tls",
"tokio-stream",
"url",
]
[[package]] [[package]]
name = "rust_decimal" name = "rust_decimal"
version = "1.37.1" version = "1.37.1"
@@ -2595,7 +2854,7 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.1",
"errno", "errno",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
@@ -2642,6 +2901,23 @@ version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
[[package]]
name = "rxml"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a98f186c7a2f3abbffb802984b7f1dfd65dac8be1aafdaabbca4137f53f0dff7"
dependencies = [
"bytes",
"rxml_validation",
"smartstring",
]
[[package]]
name = "rxml_validation"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22a197350ece202f19a166d1ad6d9d6de145e1d2a8ef47db299abe164dbd7530"
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.20" version = "1.0.20"
@@ -2657,6 +2933,15 @@ dependencies = [
"sdd", "sdd",
] ]
[[package]]
name = "schannel"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.2.0" version = "1.2.0"
@@ -2701,6 +2986,29 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.9.1",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "1.0.26" version = "1.0.26"
@@ -2862,6 +3170,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "smartstring"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29"
dependencies = [
"autocfg",
"static_assertions",
"version_check",
]
[[package]] [[package]]
name = "smol_str" name = "smol_str"
version = "0.2.2" version = "0.2.2"
@@ -2996,7 +3315,7 @@ checksum = "0afdd3aa7a629683c2d750c2df343025545087081ab5942593a5288855b1b7a7"
dependencies = [ dependencies = [
"atoi", "atoi",
"base64 0.22.1", "base64 0.22.1",
"bitflags 2.9.0", "bitflags 2.9.1",
"byteorder", "byteorder",
"bytes", "bytes",
"chrono", "chrono",
@@ -3040,7 +3359,7 @@ checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6"
dependencies = [ dependencies = [
"atoi", "atoi",
"base64 0.22.1", "base64 0.22.1",
"bitflags 2.9.0", "bitflags 2.9.1",
"byteorder", "byteorder",
"chrono", "chrono",
"crc", "crc",
@@ -3103,6 +3422,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "stringprep" name = "stringprep"
version = "0.1.5" version = "0.1.5"
@@ -3201,9 +3526,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.19.1" version = "3.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1"
dependencies = [ dependencies = [
"fastrand", "fastrand",
"getrandom 0.3.3", "getrandom 0.3.3",
@@ -3356,6 +3681,16 @@ dependencies = [
"syn 2.0.101", "syn 2.0.101",
] ]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.17" version = "0.1.17"
@@ -3444,14 +3779,14 @@ dependencies = [
[[package]] [[package]]
name = "tower-http" name = "tower-http"
version = "0.6.2" version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" checksum = "0fdb0c213ca27a9f57ab69ddb290fd80d970922355b83ae380b395d3986b8a2e"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.1",
"bytes", "bytes",
"http", "http 1.3.1",
"http-body", "http-body 1.0.1",
"pin-project-lite", "pin-project-lite",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@@ -3552,6 +3887,12 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343e926fc669bc8cde4fa3129ab681c63671bae288b1f1081ceee6d9d37904fc" checksum = "343e926fc669bc8cde4fa3129ab681c63671bae288b1f1081ceee6d9d37904fc"
[[package]]
name = "try-lock"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]] [[package]]
name = "tungstenite" name = "tungstenite"
version = "0.26.2" version = "0.26.2"
@@ -3560,7 +3901,7 @@ checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [ dependencies = [
"bytes", "bytes",
"data-encoding", "data-encoding",
"http", "http 1.3.1",
"httparse", "httparse",
"log", "log",
"rand 0.9.1", "rand 0.9.1",
@@ -3755,6 +4096,15 @@ dependencies = [
"atomic-waker", "atomic-waker",
] ]
[[package]]
name = "want"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e"
dependencies = [
"try-lock",
]
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
@@ -4077,9 +4427,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-core" name = "windows-core"
version = "0.61.0" version = "0.61.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" checksum = "46ec44dc15085cea82cf9c78f85a9114c463a369786585ad2882d1ff0b0acf40"
dependencies = [ dependencies = [
"windows-implement", "windows-implement",
"windows-interface", "windows-interface",
@@ -4118,18 +4468,18 @@ checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
[[package]] [[package]]
name = "windows-result" name = "windows-result"
version = "0.3.2" version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" checksum = "4b895b5356fc36103d0f64dd1e94dfa7ac5633f1c9dd6e80fe9ec4adef69e09d"
dependencies = [ dependencies = [
"windows-link", "windows-link",
] ]
[[package]] [[package]]
name = "windows-strings" name = "windows-strings"
version = "0.4.0" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" checksum = "2a7ab927b2637c19b3dbe0965e75d8f2d30bdd697a1516191cad2ec4df8fb28a"
dependencies = [ dependencies = [
"windows-link", "windows-link",
] ]
@@ -4297,7 +4647,7 @@ version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.1",
] ]
[[package]] [[package]]

View File

@@ -34,3 +34,4 @@ rand = "0.9"
sha2 = "0.10" sha2 = "0.10"
base64 = "0.22" base64 = "0.22"
scc = "2.3" scc = "2.3"
rust-s3 = "0.35.1"

View File

@@ -6,5 +6,17 @@ services:
env_file: env_file:
- .env - .env
volumes: volumes:
- ${PWD}/db/:/var/lib/postgresql/data/ - ${PWD}/data/db/:/var/lib/postgresql/data/
user: "1000:1000"
object_storage:
image: 'quay.io/minio/minio:latest'
ports:
- "9000:9000"
- "9001:9001"
env_file:
- .env
volumes:
- ${PWD}/data/minio/:/data
command: server /data --console-address ":9001"
user: "1000:1000" user: "1000:1000"

View File

@@ -0,0 +1,8 @@
CREATE TABLE IF NOT EXISTS "file"
(
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
"filename" VARCHAR NOT NULL,
"content_type" VARCHAR NOT NULL,
"size" INT8 NOT NULL
);

View File

@@ -1 +0,0 @@
DROP EXTENSION pg_uuidv7;

View File

@@ -1,4 +0,0 @@
DROP TRIGGER trg_user_relation_update ON "user_relation";
DROP FUNCTION fn_on_user_relation_update();
DROP TABLE "user_relation";
DROP TABLE "user";

View File

@@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS "user" CREATE TABLE IF NOT EXISTS "user"
( (
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
"avatar_url" VARCHAR, "avatar_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL,
"username" VARCHAR NOT NULL UNIQUE, "username" VARCHAR NOT NULL UNIQUE,
"display_name" VARCHAR, "display_name" VARCHAR,
"email" VARCHAR NOT NULL, "email" VARCHAR NOT NULL,
@@ -25,6 +25,38 @@ CREATE TABLE IF NOT EXISTS "user_relation"
INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system") INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system")
VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE); VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE);
CREATE OR REPLACE FUNCTION check_avatar_is_image()
RETURNS TRIGGER AS
$$
DECLARE
file_content_type VARCHAR;
BEGIN
-- Skip check if icon_id is null
IF NEW.avatar_id IS NULL THEN
RETURN NEW;
END IF;
-- Retrieve content_type from file table
SELECT content_type
INTO file_content_type
FROM file
WHERE id = NEW.avatar_id;
-- Raise exception if content_type does not start with 'image/'
IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN
RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/';
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_check_icon_is_image
BEFORE INSERT OR UPDATE
ON "user"
FOR EACH ROW
EXECUTE FUNCTION check_avatar_is_image();
CREATE OR REPLACE FUNCTION fn_on_user_relation_update() CREATE OR REPLACE FUNCTION fn_on_user_relation_update()
RETURNS TRIGGER RETURNS TRIGGER
LANGUAGE plpgsql LANGUAGE plpgsql

View File

@@ -1 +0,0 @@
DROP TABLE "server";

View File

@@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS "server"
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
"owner_id" UUID NOT NULL REFERENCES "user" ("id"), "owner_id" UUID NOT NULL REFERENCES "user" ("id"),
"name" VARCHAR NOT NULL, "name" VARCHAR NOT NULL,
"icon_url" VARCHAR "icon_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL
); );
CREATE TABLE IF NOT EXISTS "server_role" CREATE TABLE IF NOT EXISTS "server_role"
@@ -42,6 +42,38 @@ CREATE TABLE IF NOT EXISTS "server_invite"
"expires_at" TIMESTAMPTZ "expires_at" TIMESTAMPTZ
); );
CREATE OR REPLACE FUNCTION check_icon_is_image()
RETURNS TRIGGER AS
$$
DECLARE
file_content_type VARCHAR;
BEGIN
-- Skip check if icon_id is null
IF NEW.icon_id IS NULL THEN
RETURN NEW;
END IF;
-- Retrieve content_type from file table
SELECT content_type
INTO file_content_type
FROM file
WHERE id = NEW.icon_id;
-- Raise exception if content_type does not start with 'image/'
IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN
RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/';
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_check_icon_is_image
BEFORE INSERT OR UPDATE
ON "server"
FOR EACH ROW
EXECUTE FUNCTION check_icon_is_image();
CREATE OR REPLACE FUNCTION check_server_user_role_server_id() CREATE OR REPLACE FUNCTION check_server_user_role_server_id()
RETURNS TRIGGER AS RETURNS TRIGGER AS
$$ $$

View File

@@ -1,2 +0,0 @@
DROP TABLE "message";
DROP TABLE "channel";

View File

@@ -20,11 +20,19 @@ CREATE TABLE IF NOT EXISTS "channel_recipient"
CREATE TABLE IF NOT EXISTS "message" CREATE TABLE IF NOT EXISTS "message"
( (
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
"author_id" UUID NOT NULL REFERENCES "user" ("id"), "author_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
"channel_id" UUID NOT NULL REFERENCES "channel" ("id"), "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE,
"content" TEXT NOT NULL "content" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "message_attachment"
(
"message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE,
"file_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE,
"order" INT2 NOT NULL,
PRIMARY KEY ("message_id", "file_id")
);
ALTER TABLE "channel" ALTER TABLE "channel"
ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL; ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL;

View File

@@ -1,2 +0,0 @@
DROP TABLE "message_attachment";
DROP TABLE "file";

View File

@@ -1,16 +0,0 @@
CREATE TABLE IF NOT EXISTS "file"
(
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
"filename" VARCHAR NOT NULL,
"content_type" VARCHAR NOT NULL,
"url" VARCHAR NOT NULL,
"size" INT8 NOT NULL
);
CREATE TABLE IF NOT EXISTS "message_attachment"
(
"message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE,
"attachment_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE,
"order" INT2 NOT NULL,
PRIMARY KEY ("message_id", "attachment_id")
);

View File

@@ -0,0 +1,35 @@
CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID)
RETURNS TABLE (user_id UUID) AS $$
BEGIN
RETURN QUERY
-- Users directly related to the target user
SELECT ur.user_id
FROM user_relation ur
WHERE ur.other_id = target_user_id
UNION
-- Users where target user is related to them
SELECT ur.other_id AS user_id
FROM user_relation ur
WHERE ur.user_id = target_user_id
UNION
-- Users who share a server with the target user
SELECT sm.user_id
FROM server_member sm
JOIN server_member sm2 ON sm.server_id = sm2.server_id
WHERE sm2.user_id = target_user_id
AND sm.user_id != target_user_id
UNION
-- Users who share a channel with the target user (DM or group)
SELECT cr.user_id
FROM channel_recipient cr
JOIN channel_recipient cr2 ON cr.channel_id = cr2.channel_id
WHERE cr2.user_id = target_user_id
AND cr.user_id != target_user_id;
END;
$$ LANGUAGE plpgsql;

View File

@@ -22,10 +22,12 @@ pub struct Config {
pub security: SecurityConfig, pub security: SecurityConfig,
pub gateway: GatewayConfig, pub gateway: GatewayConfig,
pub database: DatabaseConfig, pub database: DatabaseConfig,
pub object_store: ObjectStoreConfig,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct ServerConfig { pub struct ServerConfig {
pub hostname: url::Url,
pub host: std::net::Ipv4Addr, pub host: std::net::Ipv4Addr,
pub port: u16, pub port: u16,
} }
@@ -59,6 +61,16 @@ pub enum DatabaseConfig {
}, },
} }
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ObjectStoreConfig {
pub endpoint: url::Url,
pub region: String,
pub bucket: String,
pub access_key: String,
pub secret_key: String,
}
impl DatabaseConfig { impl DatabaseConfig {
pub fn url(&self) -> Option<url::Url> { pub fn url(&self) -> Option<url::Url> {
match self { match self {

View File

@@ -21,9 +21,15 @@ pub enum Error {
ServerDoesNotExists, ServerDoesNotExists,
MemberAlreadyExists,
ChannelDoesNotExists, ChannelDoesNotExists,
InviteDoesNotExists,
MessageDoesNotExists, MessageDoesNotExists,
FileDoesNotExists,
} }
impl Database { impl Database {
@@ -81,6 +87,25 @@ impl Database {
Ok(user) Ok(user)
} }
pub async fn update_user_by_id(
&self,
user_id: entity::user::Id,
display_name: Option<&str>,
avatar_id: Option<entity::file::Id>,
) -> Result<entity::user::User> {
let user = sqlx::query_as!(
entity::user::User,
r#"UPDATE "user" SET "display_name" = COALESCE($2, "display_name"), "avatar_id" = COALESCE($3, "avatar_id") WHERE "id" = $1 RETURNING "user".*"#,
user_id,
display_name,
avatar_id
)
.fetch_one(&self.pool)
.await?;
Ok(user)
}
pub async fn select_users_by_ids( pub async fn select_users_by_ids(
&self, &self,
user_ids: &[entity::user::Id], user_ids: &[entity::user::Id],
@@ -146,17 +171,54 @@ impl Database {
Ok(servers) Ok(servers)
} }
pub async fn select_server_members(
&self,
server_id: entity::server::Id,
) -> Result<Vec<entity::user::User>> {
let users = sqlx::query_as!(
entity::user::User,
r#"SELECT * FROM "user" WHERE "id" IN (
SELECT "user_id" FROM "server_member" WHERE "server_id" = $1
)"#,
server_id
)
.fetch_all(&self.pool)
.await?;
Ok(users)
}
pub async fn select_channel_members(
&self,
channel_id: entity::channel::Id,
) -> Result<Vec<entity::user::User>> {
let users = sqlx::query_as!(
entity::user::User,
r#"SELECT * FROM "user" WHERE "id" IN (
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN (
SELECT "server_id" FROM "channel" WHERE "id" = $1
)
)"#,
channel_id
)
.fetch_all(&self.pool)
.await?;
Ok(users)
}
pub async fn select_user_channels( pub async fn select_user_channels(
&self, &self,
user_id: entity::user::Id, user_id: entity::user::Id,
) -> Result<Vec<entity::channel::Channel>> { ) -> Result<Vec<entity::channel::Channel>> {
// for some reason using macro overflows tokio stack let channels = sqlx::query_as!(
let channels = sqlx::query_as( entity::channel::Channel,
r#"SELECT * FROM "channel" WHERE "id" IN ( r#"SELECT * FROM "channel" WHERE "id" IN (
SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1 SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1
)"#, )"#,
user_id
) )
.bind(user_id)
.fetch_all(&self.pool) .fetch_all(&self.pool)
.await?; .await?;
@@ -166,14 +228,14 @@ impl Database {
pub async fn insert_server( pub async fn insert_server(
&self, &self,
name: &str, name: &str,
icon_url: Option<&str>, icon_id: Option<entity::file::Id>,
owner_id: entity::user::Id, owner_id: entity::user::Id,
) -> Result<entity::server::Server> { ) -> Result<entity::server::Server> {
let server = sqlx::query_as!( let server = sqlx::query_as!(
entity::server::Server, entity::server::Server,
r#"INSERT INTO "server"("name", "icon_url", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#, r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#,
name, name,
icon_url, icon_id,
owner_id owner_id
) )
.fetch_one(&self.pool) .fetch_one(&self.pool)
@@ -214,14 +276,20 @@ impl Database {
server_id: entity::server::Id, server_id: entity::server::Id,
user_id: entity::user::Id, user_id: entity::user::Id,
) -> Result<entity::server::member::ServerMember> { ) -> Result<entity::server::member::ServerMember> {
let member = sqlx::query_as!( let member = match sqlx::query_as!(
entity::server::member::ServerMember, entity::server::member::ServerMember,
r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#, r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#,
server_id, server_id,
user_id user_id
) )
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await?; .await {
Ok(member) => member,
Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
return Err(Error::MemberAlreadyExists);
}
Err(e) => return Err(e.into()),
};
Ok(member) Ok(member)
} }
@@ -242,35 +310,15 @@ impl Database {
Ok(()) Ok(())
} }
pub async fn insert_server_channel(
&self,
server_id: entity::server::Id,
name: &str,
channel_type: entity::channel::ChannelType,
position: u16,
parent: Option<entity::channel::Id>,
) -> Result<entity::channel::Channel> {
let channel = sqlx::query_as!(
entity::channel::Channel,
r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
name,
channel_type as i16,
position as i16,
server_id,
parent
)
.fetch_one(&self.pool)
.await?;
Ok(channel)
}
pub async fn select_channel_by_id( pub async fn select_channel_by_id(
&self, &self,
channel_id: entity::channel::Id, channel_id: entity::channel::Id,
) -> Result<entity::channel::Channel> { ) -> Result<entity::channel::Channel> {
let channel = sqlx::query_as(r#"SELECT * FROM "channel" WHERE "id" = $1"#) let channel = sqlx::query_as!(
.bind(channel_id) entity::channel::Channel,
r#"SELECT * FROM "channel" WHERE "id" = $1"#,
channel_id
)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
.await? .await?
.ok_or(Error::ChannelDoesNotExists)?; .ok_or(Error::ChannelDoesNotExists)?;
@@ -314,6 +362,217 @@ impl Database {
Ok(channels) Ok(channels)
} }
pub async fn delete_server_by_id(
&self,
server_id: entity::server::Id,
) -> Result<entity::server::Server> {
let server = sqlx::query_as!(
entity::server::Server,
r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#,
server_id
)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::ServerDoesNotExists)?;
Ok(server)
}
pub async fn delete_channel_by_id(
&self,
channel_id: entity::channel::Id,
) -> Result<entity::channel::Channel> {
let channel = sqlx::query_as!(
entity::channel::Channel,
r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#,
channel_id
)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::ChannelDoesNotExists)?;
Ok(channel)
}
pub async fn insert_server_channel(
&self,
name: &str,
position: u16,
r#type: entity::channel::ChannelType,
server_id: entity::server::Id,
parent: Option<entity::channel::Id>,
) -> Result<entity::channel::Channel> {
let channel = sqlx::query_as!(
entity::channel::Channel,
r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
name,
r#type as i16,
position as i16,
server_id,
parent
)
.fetch_one(&self.pool)
.await?;
Ok(channel)
}
pub async fn insert_server_invite(
&self,
code: &str,
server_id: entity::server::Id,
inviter_id: Option<entity::user::Id>,
expires_at: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<entity::server::invite::ServerInvite> {
let invite = sqlx::query_as!(
entity::server::invite::ServerInvite,
r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#,
code,
server_id,
inviter_id,
expires_at
)
.fetch_one(&self.pool)
.await?;
Ok(invite)
}
pub async fn select_server_invite_by_code(
&self,
code: &str,
) -> Result<entity::server::invite::ServerInvite> {
let invite = sqlx::query_as!(
entity::server::invite::ServerInvite,
r#"SELECT * FROM "server_invite" WHERE "code" = $1"#,
code
)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::InviteDoesNotExists)?;
Ok(invite)
}
pub async fn delete_server_invite_by_code(
&self,
code: &str,
) -> Result<Option<entity::server::invite::ServerInvite>> {
let invite = sqlx::query_as!(
entity::server::invite::ServerInvite,
r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#,
code
)
.fetch_optional(&self.pool)
.await?;
Ok(invite)
}
pub async fn select_channel_messages_paginated(
&self,
channel_id: entity::channel::Id,
before: Option<entity::message::Id>,
limit: i64,
) -> Result<Vec<entity::message::Message>> {
let messages = sqlx::query_as!(
entity::message::Message,
r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#,
channel_id,
before,
limit
)
.fetch_all(&self.pool)
.await?;
Ok(messages)
}
pub async fn insert_channel_message(
&self,
user_id: entity::user::Id,
channel_id: entity::channel::Id,
content: &str,
) -> Result<entity::message::Message> {
let message = sqlx::query_as!(
entity::message::Message,
r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#,
channel_id,
user_id,
content
)
.fetch_one(&self.pool)
.await?;
Ok(message)
}
pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
let file = sqlx::query_as!(
entity::file::File,
r#"SELECT * FROM "file" WHERE "id" = $1"#,
file_id
)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::FileDoesNotExists)?;
Ok(file)
}
pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
let file = sqlx::query_as!(
entity::file::File,
r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#,
file_id
)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::FileDoesNotExists)?;
Ok(file)
}
pub async fn insert_file(
&self,
filename: &str,
content_type: &str,
size: usize,
) -> Result<entity::file::File> {
let file = sqlx::query_as!(
entity::file::File,
r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#,
filename,
content_type,
size as i64
)
.fetch_one(&self.pool)
.await?;
Ok(file)
}
pub async fn select_related_user_ids(
&self,
user_id: entity::user::Id,
) -> Result<Vec<entity::user::Id>> {
#[derive(sqlx::FromRow)]
struct UserId {
user_id: entity::user::Id,
}
let user_ids =
sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#)
.bind(user_id)
.fetch_all(&self.pool)
.await?
.into_iter()
.map(|row| row.user_id)
.collect();
Ok(user_ids)
}
pub async fn procedure_create_dm_channel( pub async fn procedure_create_dm_channel(
&self, &self,
user1_id: entity::user::Id, user1_id: entity::user::Id,

View File

@@ -1,4 +1,4 @@
use serde::Serialize; use serde::{Deserialize, Serialize};
use crate::entity::{message, server, user}; use crate::entity::{message, server, user};
@@ -21,7 +21,7 @@ pub struct Channel {
pub last_message_id: Option<message::Id>, pub last_message_id: Option<message::Id>,
} }
#[derive(Debug, Clone, sqlx::Type, Serialize)] #[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)]
#[non_exhaustive] #[non_exhaustive]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
#[repr(i16)] #[repr(i16)]
@@ -36,6 +36,13 @@ pub enum ChannelType {
impl From<i16> for ChannelType { impl From<i16> for ChannelType {
fn from(value: i16) -> Self { fn from(value: i16) -> Self {
value.try_into().unwrap_or(ChannelType::ServerText) match value {
1 => ChannelType::ServerText,
2 => ChannelType::ServerVoice,
3 => ChannelType::ServerCategory,
4 => ChannelType::DirectMessage,
5 => ChannelType::Group,
_ => ChannelType::ServerText,
}
} }
} }

View File

@@ -3,10 +3,9 @@ use serde::Serialize;
pub type Id = uuid::Uuid; pub type Id = uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)] #[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct Attachment { pub struct File {
pub id: Id, pub id: Id,
pub filename: String, pub filename: String,
pub content_type: String, pub content_type: String,
pub url: String, pub size: i64,
pub size: u64,
} }

View File

@@ -1,15 +1,11 @@
use serde::Serialize;
use crate::entity::{channel, user}; use crate::entity::{channel, user};
pub type Id = uuid::Uuid; pub type Id = uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)] #[derive(Debug, Clone, sqlx::FromRow)]
#[serde(rename_all = "camelCase")]
pub struct Message { pub struct Message {
pub id: Id, pub id: Id,
pub author_id: user::Id, pub author_id: user::Id,
pub channel_id: channel::Id, pub channel_id: channel::Id,
pub content: String, pub content: String,
pub timestamp: chrono::DateTime<chrono::Utc>,
} }

View File

@@ -1,4 +1,4 @@
pub mod attachment; pub mod file;
pub mod channel; pub mod channel;
pub mod message; pub mod message;
pub mod server; pub mod server;

View File

@@ -1,18 +1,15 @@
mod invite; pub mod invite;
pub mod member; pub mod member;
pub mod role; pub mod role;
use serde::Serialize; use crate::entity::{file, user};
use crate::entity::user;
pub type Id = uuid::Uuid; pub type Id = uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)] #[derive(Debug, Clone, sqlx::FromRow)]
#[serde(rename_all = "camelCase")]
pub struct Server { pub struct Server {
pub id: Id, pub id: Id,
pub owner_id: user::Id, pub owner_id: user::Id,
pub name: String, pub name: String,
pub icon_url: Option<String>, pub icon_id: Option<file::Id>,
} }

View File

@@ -1,6 +1,7 @@
use std::sync::LazyLock; use std::sync::LazyLock;
use regex::Regex; use regex::Regex;
use crate::entity::file;
pub static USERNAME_REGEX: LazyLock<Regex> = pub static USERNAME_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap()); LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());
@@ -10,7 +11,7 @@ pub type Id = uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow)] #[derive(Debug, Clone, sqlx::FromRow)]
pub struct User { pub struct User {
pub id: Id, pub id: Id,
pub avatar_url: Option<String>, pub avatar_id: Option<file::Id>,
pub username: String, pub username: String,
pub display_name: Option<String>, pub display_name: Option<String>,
pub email: String, pub email: String,

View File

@@ -15,7 +15,7 @@ pub struct Claims<T> {
pub iat: i64, pub iat: i64,
} }
pub fn generate_jwt<T: Serialize>(data: T) -> Result<String> { pub fn generate_jwt<T: Serialize>(data: T, secret: &[u8]) -> Result<String> {
let claims = Claims { let claims = Claims {
data, data,
iat: Utc::now().timestamp_millis(), iat: Utc::now().timestamp_millis(),
@@ -24,14 +24,14 @@ pub fn generate_jwt<T: Serialize>(data: T) -> Result<String> {
let token = jsonwebtoken::encode( let token = jsonwebtoken::encode(
&jsonwebtoken::Header::default(), &jsonwebtoken::Header::default(),
&claims, &claims,
&jsonwebtoken::EncodingKey::from_secret(config::config().security.auth_secret.as_ref()), &jsonwebtoken::EncodingKey::from_secret(secret),
) )
.map_err(|_| Error::CouldNotEncodeToken)?; .map_err(|_| Error::CouldNotEncodeToken)?;
Ok(token) Ok(token)
} }
pub fn verify_jwt<T: DeserializeOwned>(token: &str) -> Result<T> { pub fn verify_jwt<T: DeserializeOwned>(token: &str, secret: &[u8]) -> Result<T> {
tracing::debug!("verifying token: {}", token); tracing::debug!("verifying token: {}", token);
let mut validation = jsonwebtoken::Validation::default(); let mut validation = jsonwebtoken::Validation::default();
@@ -39,9 +39,12 @@ pub fn verify_jwt<T: DeserializeOwned>(token: &str) -> Result<T> {
let token_data = jsonwebtoken::decode::<Claims<T>>( let token_data = jsonwebtoken::decode::<Claims<T>>(
token, token,
&jsonwebtoken::DecodingKey::from_secret(config::config().security.auth_secret.as_ref()), &jsonwebtoken::DecodingKey::from_secret(secret),
&validation, &validation,
) )
.inspect_err(|err| {
tracing::error!("Failed to decode JWT: {:?}", err);
})
.map_err(|_| Error::CouldNotVerifyToken)?; .map_err(|_| Error::CouldNotVerifyToken)?;
Ok(token_data.claims.data) Ok(token_data.claims.data)

View File

@@ -10,6 +10,7 @@ mod database;
mod entity; mod entity;
mod jwt; mod jwt;
mod log; mod log;
mod object_store;
mod state; mod state;
mod util; mod util;
mod web; mod web;
@@ -20,8 +21,10 @@ async fn main() -> anyhow::Result<()> {
let _guard = log::init_logging()?; let _guard = log::init_logging()?;
let database = Database::connect(&config::config().database).await?; let database = Database::connect(&config::config().database).await?;
let object_store = object_store::ObjectStore::connect(&config::config().object_store).await?;
let state = AppState { let state = AppState {
database, database,
object_store,
hasher: Arc::new(Argon2::default()), hasher: Arc::new(Argon2::default()),
gateway_state: Default::default(), gateway_state: Default::default(),
voice_rooms: Default::default(), voice_rooms: Default::default(),

51
src/object_store.rs Normal file
View File

@@ -0,0 +1,51 @@
use crate::config::ObjectStoreConfig;
#[derive(Clone, derive_more::AsRef, derive_more::Deref)]
pub struct ObjectStore {
inner: Box<s3::Bucket>,
}
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
pub enum Error {
#[from]
Credentials(s3::creds::error::CredentialsError),
#[from]
S3(s3::error::S3Error),
}
impl ObjectStore {
pub async fn connect(config: &ObjectStoreConfig) -> Result<Self> {
let region = s3::region::Region::Custom {
region: config.region.clone(),
endpoint: config.endpoint.origin().ascii_serialization(),
};
let credentials = s3::creds::Credentials::new(
Some(&config.access_key),
Some(&config.secret_key),
None,
None,
None,
)?;
let mut bucket =
s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())?
.with_path_style();
if !bucket.exists().await? {
bucket = s3::bucket::Bucket::create_with_path_style(
&config.bucket,
region,
credentials,
s3::BucketConfiguration::default(),
)
.await?
.bucket;
}
Ok(Self { inner: bucket })
}
}

View File

@@ -6,17 +6,19 @@ use tokio::sync::{RwLock, mpsc};
use uuid::Uuid; use uuid::Uuid;
use crate::database::Database; use crate::database::Database;
use crate::object_store::ObjectStore;
use crate::web::ws::gateway::{GatewayWsState, SessionKey, event}; use crate::web::ws::gateway::{GatewayWsState, SessionKey, event};
use crate::webrtc::OfferSignal; use crate::webrtc::WebRtcSignal;
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub database: Database, pub database: Database,
pub object_store: ObjectStore,
pub hasher: Arc<Argon2<'static>>, pub hasher: Arc<Argon2<'static>>,
pub gateway_state: Arc<GatewayState>, pub gateway_state: Arc<GatewayState>,
pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<OfferSignal>>>>, pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<WebRtcSignal>>>>,
} }
#[derive(Debug, Default)] #[derive(Debug, Default)]
@@ -60,4 +62,16 @@ impl AppState {
self.gateway_state.connected.remove_async(&user_id).await; self.gateway_state.connected.remove_async(&user_id).await;
} }
} }
pub async fn register_voice_room(
&self,
room_id: Uuid,
sender: mpsc::UnboundedSender<WebRtcSignal>,
) {
self.voice_rooms.write().await.insert(room_id, sender);
}
pub async fn unregister_voice_room(&self, room_id: Uuid) {
self.voice_rooms.write().await.remove(&room_id);
}
} }

View File

@@ -2,6 +2,21 @@ use axum::extract::multipart::Field;
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::entity;
pub fn file_id_to_url(file_id: &entity::file::Id) -> Option<String> {
Some(
crate::config::config()
.server
.hostname
.join("files/")
.ok()?
.join(&file_id.to_string())
.ok()?
.to_string(),
)
}
#[derive(Debug, derive_more::Deref)] #[derive(Debug, derive_more::Deref)]
pub struct SerdeFieldData<T>(pub FieldData<T>); pub struct SerdeFieldData<T>(pub FieldData<T>);
@@ -60,3 +75,25 @@ where
let seconds = u64::deserialize(deserializer)?; let seconds = u64::deserialize(deserializer)?;
Ok(std::time::Duration::from_secs(seconds)) Ok(std::time::Duration::from_secs(seconds))
} }
pub fn serialize_duration_seconds_option<S>(
duration: &Option<std::time::Duration>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match duration {
Some(duration) => serialize_duration_seconds(duration, serializer),
None => serializer.serialize_none(),
}
}
pub fn deserialize_duration_seconds_option<'de, D>(
deserializer: D,
) -> Result<Option<std::time::Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(deserialize_duration_seconds(deserializer).ok())
}

35
src/web/entity/message.rs Normal file
View File

@@ -0,0 +1,35 @@
use serde::Serialize;
use crate::entity::message::Id;
use crate::entity::{channel, user};
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
pub id: Id,
pub author_id: user::Id,
pub channel_id: channel::Id,
pub content: String,
pub created_at: chrono::DateTime<chrono::Utc>,
}
impl From<crate::entity::message::Message> for Message {
fn from(message: crate::entity::message::Message) -> Self {
Self {
id: message.id,
author_id: message.author_id,
channel_id: message.channel_id,
content: message.content,
created_at: message
.id
.get_timestamp()
.as_ref()
.map(uuid::Timestamp::to_unix)
.map(|(secs, nsecs)| {
chrono::DateTime::<chrono::Utc>::from_timestamp(secs as i64, nsecs)
})
.flatten()
.unwrap_or_default(),
}
}
}

3
src/web/entity/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod message;
pub mod user;
pub mod server;

25
src/web/entity/server.rs Normal file
View File

@@ -0,0 +1,25 @@
use serde::Serialize;
use crate::entity::server::Id;
use crate::entity::user;
use crate::util;
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
pub id: Id,
pub owner_id: user::Id,
pub name: String,
pub icon_url: Option<String>,
}
impl From<crate::entity::server::Server> for Server {
fn from(server: crate::entity::server::Server) -> Self {
Self {
id: server.id,
owner_id: server.owner_id,
name: server.name,
icon_url: server.icon_id.as_ref().map(util::file_id_to_url).flatten(),
}
}
}

54
src/web/entity/user.rs Normal file
View File

@@ -0,0 +1,54 @@
use crate::entity::user;
use crate::util;
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct FullUser {
pub id: user::Id,
pub avatar_url: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub email: String,
pub bot: bool,
pub system: bool,
pub settings: serde_json::Value,
}
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PartialUser {
pub id: user::Id,
pub avatar_url: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub bot: bool,
pub system: bool,
}
impl From<user::User> for FullUser {
fn from(user: user::User) -> Self {
Self {
id: user.id,
avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(),
username: user.username,
display_name: user.display_name,
email: user.email,
bot: user.bot,
system: user.system,
settings: user.settings,
}
}
}
impl From<user::User> for PartialUser {
fn from(user: user::User) -> Self {
Self {
id: user.id,
avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(),
username: user.username,
display_name: user.display_name,
bot: user.bot,
system: user.system,
}
}
}

View File

@@ -4,7 +4,7 @@ use axum::http::StatusCode;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use crate::web::context; use crate::web::context;
use crate::{database, jwt}; use crate::{database, jwt, object_store};
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@@ -22,6 +22,9 @@ pub enum Error {
#[from] #[from]
Database(database::Error), Database(database::Error),
#[from]
ObjectStore(object_store::Error),
#[from] #[from]
Json(serde_json::error::Error), Json(serde_json::error::Error),
@@ -55,6 +58,8 @@ pub enum ClientError {
ValidationFailed(validator::ValidationErrors), ValidationFailed(validator::ValidationErrors),
InternalServerError, InternalServerError,
Unknown,
} }
#[derive(derive_more::Debug, Clone, serde::Serialize)] #[derive(derive_more::Debug, Clone, serde::Serialize)]

View File

@@ -52,7 +52,11 @@ async fn get_context(state: &AppState, request: &mut Request) -> context::UserCo
} }
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult { pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
let context = jwt::verify_jwt::<UserContext>(token).map_err(|_| context::Error::BadToken)?; let context = jwt::verify_jwt::<UserContext>(
token,
crate::config::config().security.auth_secret.as_ref(),
)
.map_err(|_| context::Error::BadToken)?;
let _ = state let _ = state
.database .database

View File

@@ -1,4 +1,5 @@
mod context; mod context;
mod entity;
mod error; mod error;
mod middleware; mod middleware;
mod route; mod route;
@@ -38,6 +39,8 @@ fn router(state: state::AppState) -> axum::Router {
// websocket // websocket
.route("/gateway/ws", get(ws::gateway::ws_handler)) .route("/gateway/ws", get(ws::gateway::ws_handler))
.route("/voice/ws", get(ws::voice::ws_handler)) .route("/voice/ws", get(ws::voice::ws_handler))
// file
.route("/files/{file_id}", get(file::get))
// api // api
.nest( .nest(
"/api/v1", "/api/v1",
@@ -66,13 +69,41 @@ fn protected_router() -> axum::Router<state::AppState> {
Router::new() Router::new()
// user // user
.route("/users/@me", get(user::me)) .route("/users/@me", get(user::me))
.route("/users/@me", patch(user::patch))
.route("/users/@me/channels", get(user::channel::list)) .route("/users/@me/channels", get(user::channel::list))
.route("/users/{id}", get(user::get_by_id)) .route("/users/{id}", get(user::get_by_id))
// channel
.route(
"/channels/{channel_id}/messages",
get(channel::message::page),
)
.route(
"/channels/{channel_id}/messages",
post(channel::message::create),
)
// server // server
.route("/servers", get(server::list)) .route("/servers", get(server::list))
.route("/servers", post(server::create)) .route("/servers", post(server::create))
.route("/servers/{server_id}", get(server::get)) .route("/servers/{server_id}", get(server::get))
.route("/servers/{server_id}", delete(server::delete))
.route("/servers/{server_id}/channels", get(server::channel::list)) .route("/servers/{server_id}/channels", get(server::channel::list))
.route(
"/servers/{server_id}/channels",
post(server::channel::create),
)
.route(
"/servers/{server_id}/channels/{channel_id}",
get(server::channel::get),
)
.route(
"/servers/{server_id}/channels/{channel_id}",
delete(server::channel::delete),
)
// invite
.route("/servers/{server_id}/invites", post(server::invite::create))
.route("/invites/{code}", get(server::invite::get))
// file
.route("/files", post(file::upload))
// middleware // middleware
.route_layer(axum::middleware::from_fn(middleware::require_context)) .route_layer(axum::middleware::from_fn(middleware::require_context))
} }

View File

@@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
use crate::state::AppState; use crate::state::AppState;
use crate::web::context::UserContext; use crate::web::context::UserContext;
use crate::web::route::user::FullUser;
use crate::{jwt, web}; use crate::{jwt, web};
use crate::web::entity::user::FullUser;
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@@ -39,7 +39,7 @@ pub async fn login(
.verify_password(payload.password.as_bytes(), &password_hash) .verify_password(payload.password.as_bytes(), &password_hash)
.map_err(|_| web::error::ClientError::WrongPassword)?; .map_err(|_| web::error::ClientError::WrongPassword)?;
let token = jwt::generate_jwt(UserContext { user_id: user.id })?; let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?;
let response = LoginResponse { let response = LoginResponse {
user: user.into(), user: user.into(),

View File

@@ -9,8 +9,8 @@ use validator::Validate;
use crate::state::AppState; use crate::state::AppState;
use crate::web; use crate::web;
use crate::web::entity::user::FullUser;
use crate::web::error::ClientError; use crate::web::error::ClientError;
use crate::web::route::user::FullUser;
#[derive(Validate, Deserialize)] #[derive(Validate, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]

View File

@@ -0,0 +1,49 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::message::Message;
use crate::web::ws;
use crate::{entity, web};
#[derive(Debug, serde::Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct CreatePayload {
#[validate(length(min = 1, max = 2000))]
pub content: String,
}
pub async fn create(
State(state): State<AppState>,
context: UserContext,
Path(channel_id): Path<entity::channel::Id>,
Json(payload): Json<CreatePayload>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let message = state
.database
.insert_channel_message(context.user_id, channel_id, &payload.content)
.await?;
let message = Message::from(message);
// TODO: check permissions
ws::gateway::util::send_message_channel(
state,
message.channel_id,
ws::gateway::event::Event::AddMessage {
channel_id,
message: message.clone(),
},
);
Ok(Json(message))
}

View File

@@ -0,0 +1,5 @@
mod page;
mod create;
pub use page::page;
pub use create::create;

View File

@@ -0,0 +1,47 @@
use axum::Json;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::message::Message;
use crate::{entity, web};
#[derive(Debug, Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct PageParams {
#[serde(default = "limit_default")]
#[validate(range(min = 1, max = 100))]
pub limit: u32,
#[serde(default)]
pub before: Option<entity::message::Id>,
}
fn limit_default() -> u32 {
50
}
pub async fn page(
State(state): State<AppState>,
context: UserContext,
Path(channel_id): Path<entity::channel::Id>,
Query(params): Query<PageParams>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
match params.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let messages = state
.database
.select_channel_messages_paginated(channel_id, params.before, params.limit as i64)
.await?
.into_iter()
.map(Message::from)
.collect::<Vec<_>>();
Ok(Json(messages))
}

View File

@@ -0,0 +1 @@
pub mod message;

43
src/web/route/file/get.rs Normal file
View File

@@ -0,0 +1,43 @@
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::{entity, object_store, web};
pub async fn get(
State(state): State<AppState>,
Path(file_id): Path<entity::file::Id>,
) -> web::Result<impl IntoResponse> {
let file = match state.database.select_file_by_id(file_id).await {
Ok(file) => file,
Err(e) => {
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
},
};
let data = match state
.object_store
.get_object_stream(&file.id.to_string())
.await
{
Ok(data) => data,
Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => {
let _ = state.database.delete_file_by_id(file.id).await?;
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
},
Err(e) => {
return Err(object_store::Error::from(e).into());
},
};
let headers = axum::response::AppendHeaders([
(axum::http::header::CONTENT_TYPE, file.content_type.clone()),
(
axum::http::header::CONTENT_DISPOSITION,
format!("filename=\"{}\"", file.filename),
),
]);
Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response())
}

View File

@@ -0,0 +1,5 @@
mod get;
mod upload;
pub use get::get;
pub use upload::upload;

View File

@@ -0,0 +1,54 @@
use axum::Json;
use axum::body::Bytes;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
use validator::Validate;
use crate::state::AppState;
use crate::util::SerdeFieldData;
use crate::{object_store, web};
#[derive(Debug, Validate, TryFromMultipart)]
#[try_from_multipart(rename_all = "camelCase")]
pub struct UploadPayload {
#[form_data(limit = "50MB")]
#[validate(length(min = 1, max = 16))]
files: Vec<SerdeFieldData<Bytes>>,
}
pub async fn upload(
State(state): State<AppState>,
TypedMultipart(payload): TypedMultipart<UploadPayload>,
) -> web::Result<impl IntoResponse> {
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let mut file_ids = Vec::new();
for file in payload.files {
let db_file = state
.database
.insert_file(
file.metadata
.file_name
.as_deref()
.unwrap_or_else(|| "unknown"),
file.metadata.content_type.as_deref().unwrap_or_default(),
file.contents.len(),
)
.await?;
state
.object_store
.put_object(&db_file.id.to_string(), &file.contents)
.await
.map_err(object_store::Error::from)?;
file_ids.push(db_file.id);
}
Ok(Json(file_ids))
}

View File

@@ -1,4 +1,5 @@
pub mod auth; pub mod auth;
pub mod channel;
pub mod file;
pub mod server; pub mod server;
pub mod user; pub mod user;
pub mod voice;

View File

@@ -0,0 +1,58 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::error::ClientError;
use crate::web::ws;
use crate::{entity, web};
#[derive(Debug, Validate, Deserialize)]
pub struct CreatePayload {
#[validate(length(min = 1, max = 32))]
name: String,
#[validate(custom(function = "validate_server_channel_type"))]
r#type: entity::channel::ChannelType,
}
fn validate_server_channel_type(
r#type: &entity::channel::ChannelType,
) -> Result<(), validator::ValidationError> {
match r#type {
entity::channel::ChannelType::ServerText => Ok(()),
entity::channel::ChannelType::ServerVoice => Ok(()),
entity::channel::ChannelType::ServerCategory => Ok(()),
_ => Err(validator::ValidationError::new("invalid_channel_type")),
}
}
pub async fn create(
State(state): State<AppState>,
context: UserContext,
Path(server_id): Path<entity::server::Id>,
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
payload.validate().map_err(ClientError::ValidationFailed)?;
// TODO: check permissions
let server = state.database.select_server_by_id(server_id).await?;
let channel = state
.database
.insert_server_channel(&payload.name, 0, payload.r#type, server_id, None)
.await?;
ws::gateway::util::send_message_server(
state,
server_id,
ws::gateway::event::Event::AddServerChannel {
channel: channel.clone(),
},
);
Ok(Json(channel))
}

View File

@@ -0,0 +1,39 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::{entity, web};
pub async fn delete(
State(state): State<AppState>,
context: UserContext,
Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
let channel = state.database.select_channel_by_id(channel_id).await?;
if let Some(channel_server_id) = channel.server_id {
if channel_server_id != server_id {
return Err(web::error::ClientError::NotAllowed.into());
}
} else {
return Err(web::error::ClientError::NotAllowed.into());
}
let channel = state.database.delete_channel_by_id(channel_id).await?;
ws::gateway::util::send_message_server(
state,
server_id,
ws::gateway::event::Event::RemoveServerChannel {
server_id: server_id.clone(),
channel_id: channel.id.clone(),
},
);
Ok(Json(channel))
}

View File

@@ -0,0 +1,27 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
pub async fn get(
State(state): State<AppState>,
context: UserContext,
Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
let channel = state.database.select_channel_by_id(channel_id).await?;
if let Some(channel_server_id) = channel.server_id {
if channel_server_id != server_id {
return Err(web::error::ClientError::NotAllowed.into());
}
} else {
return Err(web::error::ClientError::NotAllowed.into());
}
Ok(Json(channel))
}

View File

@@ -9,9 +9,9 @@ use crate::{entity, web};
pub async fn list( pub async fn list(
State(state): State<AppState>, State(state): State<AppState>,
context: UserContext, context: UserContext,
Path(id): Path<entity::server::Id>, Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
let channels = state.database.select_server_channels(id).await?; let channels = state.database.select_server_channels(server_id).await?;
Ok(Json(channels)) Ok(Json(channels))
} }

View File

@@ -1,3 +1,9 @@
mod create;
mod delete;
mod get;
mod list; mod list;
pub use create::create;
pub use delete::delete;
pub use get::get;
pub use list::list; pub use list::list;

View File

@@ -1,50 +1,37 @@
use axum::Json; use axum::Json;
use axum::body::Bytes;
use axum::extract::State; use axum::extract::State;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; use axum_extra::extract::WithRejection;
use validator::{Validate, ValidationError}; use axum_typed_multipart::TryFromMultipart;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState; use crate::state::AppState;
use crate::util::SerdeFieldData;
use crate::web;
use crate::web::context::UserContext; use crate::web::context::UserContext;
use crate::web::error::ClientError; use crate::web::error::ClientError;
use crate::web::ws; use crate::web::ws;
use crate::{entity, web};
use crate::web::entity::server::Server;
#[derive(Debug, Validate, TryFromMultipart)] #[derive(Debug, Validate, Deserialize)]
#[try_from_multipart(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CreatePayload { pub struct CreatePayload {
#[validate(length(min = 1, max = 32))] #[validate(length(min = 1, max = 32))]
name: String, name: String,
#[validate(custom(function = "validate_icon_content_type"))] icon_id: Option<entity::file::Id>,
#[form_data(limit = "10MB")]
icon: Option<SerdeFieldData<Bytes>>,
}
fn validate_icon_content_type(icon: &SerdeFieldData<Bytes>) -> Result<(), ValidationError> {
if let Some(content_type) = icon.metadata.content_type.as_deref() {
if !content_type.starts_with("image/") {
return Err(ValidationError::new("invalid_icon_content_type"));
}
} else {
return Err(ValidationError::new("missing_icon_content_type"));
}
Ok(())
} }
pub async fn create( pub async fn create(
State(state): State<AppState>, State(state): State<AppState>,
context: UserContext, context: UserContext,
TypedMultipart(payload): TypedMultipart<CreatePayload>, WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
payload.validate().map_err(ClientError::ValidationFailed)?; payload.validate().map_err(ClientError::ValidationFailed)?;
let server = state let server = state
.database .database
.insert_server(&payload.name, None, context.user_id) .insert_server(&payload.name, payload.icon_id, context.user_id)
.await?; .await?;
let role = state let role = state
@@ -70,14 +57,15 @@ pub async fn create(
.insert_server_member_role(member.id, role.id) .insert_server_member_role(member.id, role.id)
.await?; .await?;
let server = Server::from(server);
ws::gateway::util::send_message( ws::gateway::util::send_message(
&state, state,
context.user_id, context.user_id,
ws::gateway::event::Event::AddServer { ws::gateway::event::Event::AddServer {
server: server.clone(), server: server.clone(),
}, },
) );
.await;
Ok(Json(server)) Ok(Json(server))
} }

View File

@@ -0,0 +1,60 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::webrtc::WebRtcSignal;
use crate::{entity, web};
use crate::web::entity::server::Server;
pub async fn delete(
State(state): State<AppState>,
context: UserContext,
Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
let server = state.database.select_server_by_id(server_id).await?;
if server.owner_id != context.user_id {
return Err(web::error::ClientError::NotAllowed.into());
}
let members = state
.database
.select_server_members(server_id)
.await?
.iter()
.map(|u| u.id)
.collect::<Vec<_>>();
let channels = state
.database
.select_server_channels(server_id)
.await?
.iter()
.map(|c| c.id)
.collect::<Vec<_>>();
let state_clone = state.clone();
tokio::spawn(async move {
let voice_rooms = state_clone.voice_rooms.read().await;
for channel_id in channels {
if let Some(voice_room) = voice_rooms.get(&channel_id) {
let _ = voice_room.send(WebRtcSignal::Close);
}
}
});
let server = state.database.delete_server_by_id(server_id).await?;
ws::gateway::util::send_message_many(
state.clone(),
&members,
ws::gateway::event::Event::RemoveServer {
server_id: server.id,
},
);
Ok(Json(Server::from(server)))
}

View File

@@ -4,12 +4,15 @@ use axum::response::IntoResponse;
use crate::state::AppState; use crate::state::AppState;
use crate::{entity, web}; use crate::{entity, web};
use crate::web::entity::server::Server;
pub async fn get( pub async fn get(
State(state): State<AppState>, State(state): State<AppState>,
Path(id): Path<entity::server::Id>, Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
let server = state.database.select_server_by_id(id).await?; // TODO: check permissions
Ok(Json(server)) let server = state.database.select_server_by_id(server_id).await?;
Ok(Json(Server::from(server)))
} }

View File

@@ -0,0 +1,45 @@
use axum::Json;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use base64::Engine;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
#[derive(serde::Deserialize, Debug)]
pub struct CreateParams {
#[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")]
#[serde(default)]
pub expires_in: Option<std::time::Duration>,
}
pub async fn create(
State(state): State<AppState>,
context: UserContext,
Path(server_id): Path<entity::server::Id>,
Query(params): Query<CreateParams>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
let code = {
use rand::Rng;
let mut rng = rand::rng();
let mut code = [0u8; 32];
rng.fill(&mut code);
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code)
};
let expires_at = params.expires_in.map(|d| {
let now = chrono::Utc::now();
now + chrono::Duration::from_std(d).expect("valid duration")
});
let invite = state
.database
.insert_server_invite(&code, server_id, Some(context.user_id), expires_at)
.await?;
Ok(Json(invite))
}

View File

@@ -0,0 +1,46 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::{database, web};
use crate::web::entity::server::Server;
pub async fn get(
State(state): State<AppState>,
context: UserContext,
Path(code): Path<String>,
) -> web::Result<impl IntoResponse> {
let invite = state.database.select_server_invite_by_code(&code).await?;
let server = state.database.select_server_by_id(invite.server_id).await?;
let member = match state
.database
.insert_server_member(invite.server_id, context.user_id)
.await
{
Ok(member) => member,
Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))),
Err(e) => return Err(e.into()),
};
state
.database
.insert_server_member_role(member.id, invite.server_id)
.await?;
let user = state.database.select_user_by_id(context.user_id).await?;
ws::gateway::util::send_message_server(
state,
invite.server_id,
ws::gateway::event::Event::AddServerMember {
server_id: server.id,
member: user.into(),
},
);
Ok(Json(Server::from(server)))
}

View File

@@ -0,0 +1,5 @@
mod create;
mod get;
pub use create::create;
pub use get::get;

View File

@@ -5,12 +5,19 @@ use axum::response::IntoResponse;
use crate::state::AppState; use crate::state::AppState;
use crate::web; use crate::web;
use crate::web::context::UserContext; use crate::web::context::UserContext;
use crate::web::entity::server::Server;
pub async fn list( pub async fn list(
State(state): State<AppState>, State(state): State<AppState>,
context: UserContext, context: UserContext,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
let servers = state.database.select_user_servers(context.user_id).await?; let servers = state
.database
.select_user_servers(context.user_id)
.await?
.into_iter()
.map(Server::from)
.collect::<Vec<_>>();
Ok(Json(servers)) Ok(Json(servers))
} }

View File

@@ -1,8 +1,11 @@
pub mod channel; pub mod channel;
mod create; mod create;
mod delete;
mod get; mod get;
pub mod invite;
mod list; mod list;
pub use create::create; pub use create::create;
pub use delete::delete;
pub use get::get; pub use get::get;
pub use list::list; pub use list::list;

View File

@@ -7,7 +7,7 @@ use crate::entity::channel;
use crate::state::AppState; use crate::state::AppState;
use crate::web; use crate::web;
use crate::web::context::UserContext; use crate::web::context::UserContext;
use crate::web::route::user::PartialUser; use crate::web::entity::user::PartialUser;
#[derive(Debug, sqlx::FromRow, Serialize)] #[derive(Debug, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]

View File

@@ -4,13 +4,13 @@ use axum::response::IntoResponse;
use crate::state::AppState; use crate::state::AppState;
use crate::web; use crate::web;
use crate::web::route::user::PartialUser; use crate::web::entity::user::PartialUser;
pub async fn get_by_id( pub async fn get_by_id(
Path(id): Path<uuid::Uuid>,
State(state): State<AppState>, State(state): State<AppState>,
Path(user_id): Path<uuid::Uuid>,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
let user = state.database.select_user_by_id(id).await?; let user = state.database.select_user_by_id(user_id).await?;
Ok(Json(PartialUser::from(user))) Ok(Json(PartialUser::from(user)))
} }

View File

@@ -5,11 +5,11 @@ use axum::response::IntoResponse;
use crate::state::AppState; use crate::state::AppState;
use crate::web; use crate::web;
use crate::web::context::UserContext; use crate::web::context::UserContext;
use crate::web::route::user::FullUser; use crate::web::entity::user::FullUser;
pub async fn me( pub async fn me(
context: UserContext,
State(state): State<AppState>, State(state): State<AppState>,
context: UserContext,
) -> web::Result<impl IntoResponse> { ) -> web::Result<impl IntoResponse> {
let user = state.database.select_user_by_id(context.user_id).await?; let user = state.database.select_user_by_id(context.user_id).await?;

View File

@@ -1,60 +1,9 @@
pub mod channel; pub mod channel;
mod get; mod get;
mod me; mod me;
mod patch;
pub use get::get_by_id; pub use get::get_by_id;
pub use me::me; pub use me::me;
pub use patch::patch;
use crate::entity::user;
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct FullUser {
pub id: user::Id,
pub avatar_url: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub email: String,
pub bot: bool,
pub system: bool,
pub settings: serde_json::Value,
}
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct PartialUser {
pub id: user::Id,
pub avatar_url: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub bot: bool,
pub system: bool,
}
impl From<user::User> for FullUser {
fn from(user: user::User) -> Self {
Self {
id: user.id,
avatar_url: user.avatar_url,
username: user.username,
display_name: user.display_name,
email: user.email,
bot: user.bot,
system: user.system,
settings: user.settings,
}
}
}
impl From<user::User> for PartialUser {
fn from(user: user::User) -> Self {
Self {
id: user.id,
avatar_url: user.avatar_url,
username: user.username,
display_name: user.display_name,
bot: user.bot,
system: user.system,
}
}
}

View File

@@ -0,0 +1,61 @@
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::user::{FullUser, PartialUser};
use crate::web::ws;
use crate::{entity, web};
#[derive(Debug, Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreatePayload {
#[validate(length(min = 1, max = 32))]
#[serde(default)]
display_name: Option<String>,
#[serde(default)]
avatar_id: Option<entity::file::Id>,
}
pub async fn patch(
State(state): State<AppState>,
context: UserContext,
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let user = state
.database
.update_user_by_id(
context.user_id,
payload.display_name.as_deref(),
payload.avatar_id,
)
.await?;
ws::gateway::util::send_message(
state.clone(),
context.user_id,
ws::gateway::event::Event::AddUser {
user: PartialUser::from(user.clone()),
},
);
ws::gateway::util::send_message_related(
state.clone(),
context.user_id,
ws::gateway::event::Event::AddUser {
user: PartialUser::from(user.clone()),
},
);
Ok(Json(FullUser::from(user)))
}

View File

@@ -1,91 +0,0 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Payload {
sdp: RTCSessionDescription,
}
#[derive(Debug, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Response {
sdp: RTCSessionDescription,
}
pub async fn connect(
State(state): State<AppState>,
context: UserContext,
Path(channel_id): Path<entity::channel::Id>,
WithRejection(Json(payload), _): WithRejection<Json<Payload>, web::Error>,
) -> web::Result<impl IntoResponse> {
tracing::debug!("connect to voice channel: {:?}", channel_id);
let channel = state.database.select_channel_by_id(channel_id).await?;
let channel_id = channel.id;
let room_sender = {
state
.voice_rooms
.read()
.await
.get(&channel_id)
.map(|room| room.clone())
};
let room_sender = match room_sender {
Some(room) => room,
None => {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let rooms = state.voice_rooms.clone();
tokio::spawn(async move {
crate::webrtc::webrtc_task(channel_id, rx)
.await
.unwrap_or_else(|err| {
tracing::error!("webrtc task error: {:?}", err);
});
{
let mut rooms = rooms.write().await;
rooms.remove(&channel_id);
}
});
{
let mut rooms = state.voice_rooms.write().await;
rooms.insert(channel_id, tx.clone());
}
tx
},
};
let offer = crate::webrtc::Offer {
peer_id: context.user_id,
sdp_offer: payload.sdp,
};
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let _ = room_sender.send(crate::webrtc::OfferSignal {
offer,
response: response_tx,
});
let answer = response_rx
.await
.map_err(|_| web::error::ClientError::InternalServerError)?;
let response = Response {
sdp: answer.sdp_answer,
};
Ok(Json(response))
}

View File

@@ -1,3 +0,0 @@
mod connect;
pub use connect::connect;

View File

@@ -1,7 +1,10 @@
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T, E> = std::result::Result<T, Error<E>>;
#[derive(Debug, derive_more::From, derive_more::Display)] #[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error { pub enum Error<T: CustomError> {
#[from]
Custom(T),
#[from] #[from]
Json(serde_json::Error), Json(serde_json::Error),
@@ -11,4 +14,8 @@ pub enum Error {
WrongMessageType, WrongMessageType,
WebSocketClosed, WebSocketClosed,
UnknownError,
} }
pub trait CustomError {}

View File

@@ -1,101 +1,75 @@
use std::ops::ControlFlow; use axum::extract::ws::Message as AxumMessage;
use axum::extract::ws::{Message as AxumMessage, WebSocket};
use base64::Engine as _; use base64::Engine as _;
use futures::stream::SplitStream; use futures::{Stream, StreamExt};
use futures::{Sink, SinkExt, StreamExt};
use serde::Serialize;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use tokio::time::Instant; use tokio::sync::mpsc;
use super::error::{self, Error as WsError}; use super::error::Error as WsError;
use super::event::Event as WsEvent; use super::event::Event as WsEvent;
use super::protocol::{WsClientMessage, WsServerMessage}; use super::protocol::{WsClientMessage, WsServerMessage};
use super::state::{WsContext, WsState, WsUserContext}; use super::state::{WsContext, WsState, WsUserContext};
use crate::jwt; use crate::jwt;
use crate::state::AppState; use crate::state::AppState;
use crate::web::ws::gateway::SessionKey; use crate::web::ws::gateway::SessionKey;
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message, serialize_ws_message}; use crate::web::ws::general::WebSocketHandler;
use crate::web::ws::{util, voice}; use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
use crate::web::ws::voice;
use crate::webrtc::WebRtcSignal;
/// Main handler for an individual WebSocket connection's lifecycle. impl WebSocketHandler for WsContext {
/// Spawned by Axum upon successful WebSocket upgrade. type ServerMessage = WsServerMessage;
#[tracing::instrument(skip_all, name = "ws_connection_handler")] type ClientMessage = WsClientMessage;
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) { type Error = WsError;
let (ws_sink, ws_stream) = websocket.split();
let (internal_send_tx, internal_send_rx) = tokio::sync::mpsc::unbounded_channel(); async fn handle_stream<S>(
&mut self,
stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), Self::Error>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
process_websocket_messages(self, stream, sender, app_state).await?;
let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); Ok(())
}
let mut context = WsContext { async fn cleanup(&mut self, app_state: &AppState) {
connection_state: WsState::Initialize, if let Some(user_ctx_data) = &self.user_context {
user_context: None, app_state
event_channel: None, .unregister_gateway_connected_user(
}; user_ctx_data.user_id,
&user_ctx_data.session_key,
let processing_result = process_websocket_messages(
&mut context,
ws_stream,
&internal_send_tx,
&app_state,
) )
.await; .await;
// --- Cleanup ---
if let Some(user_ctx_data) = &context.user_context {
app_state
.unregister_gateway_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key)
.await;
tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user.");
} }
// Drop our sender for the event channel; receiver in `process_websocket_messages` will see this. drop(self.event_channel.take());
drop(context.event_channel.take()); }
// If processing loop exited with an error (not a graceful close like WebSocketClosed or HeartbeatTimeout), async fn handle_result_error(
// try to send a final error message to the client. &mut self,
if let Err(err_to_report) = &processing_result { error: Self::Error,
if !matches!( sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
err_to_report,
WsError::WebSocketClosed
) { ) {
tracing::warn!( let error_ws_message = WsServerMessage::Error { code: error };
"WebSocket processing error, attempting to notify client: {:?}",
err_to_report
);
let client_err_code = err_to_report.as_client_error();
let error_ws_message = WsServerMessage::Error {
code: client_err_code,
};
// Use new_no_response for best-effort send during shutdown.
// Ignore result as internal_send_tx might already be closed if writer_task ended.
let _ = internal_send_tx.send(SendWsMessage::new_no_response(error_ws_message));
}
}
// Signal writer task to stop by dropping the MPSC sender. let _ = sender.send(SendWsMessage::new_no_response(error_ws_message));
drop(internal_send_tx);
// Wait for the writer task to complete its shutdown.
if let Err(e) = writer_task.await {
tracing::error!(
"WebSocket writer task panicked or encountered an error: {:?}",
e
);
} }
tracing::debug!(result = ?processing_result, "WebSocket connection handler finished.");
} }
/// Main loop for processing incoming WebSocket messages and outgoing application events.
/// Manages state transitions (Initialize -> Connected) and heartbeating.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) #[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))] ))]
async fn process_websocket_messages( async fn process_websocket_messages<S>(
context: &mut WsContext, context: &mut WsContext,
mut ws_stream: SplitStream<WebSocket>, mut ws_stream: S,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>, sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState, app_state: &AppState,
) -> error::Result<()> { ) -> crate::web::ws::error::Result<(), WsError>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
loop { loop {
match context.connection_state { match context.connection_state {
WsState::Initialize => { WsState::Initialize => {
@@ -104,24 +78,17 @@ async fn process_websocket_messages(
maybe_message = ws_stream.next() => { maybe_message = ws_stream.next() => {
match maybe_message { match maybe_message {
Some(Ok(message)) => { Some(Ok(message)) => {
match handle_initial_message(context, message, sender, app_state).await { handle_initial_message(context, message, sender, app_state).await?;
Ok(ControlFlow::Continue(())) => {}, context.connection_state = WsState::Connected;
Ok(ControlFlow::Break(new_state)) => { // Authenticated
context.connection_state = new_state;
tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected."); tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected.");
},
Err(e) => { // Auth failed critically or other error
return Err(e);
}
}
} }
Some(Err(axum_ws_err)) => { Some(Err(axum_ws_err)) => {
tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err); tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
return Err(WsError::WebSocketClosed); return Err(crate::web::ws::error::Error::WebSocketClosed);
} }
None => { // Stream closed by client None => { // Stream closed by client
tracing::debug!("WebSocket stream ended by client during Initialize state."); tracing::debug!("WebSocket stream ended by client during Initialize state.");
return Err(WsError::WebSocketClosed); return Err(crate::web::ws::error::Error::WebSocketClosed);
} }
} }
} }
@@ -139,30 +106,26 @@ async fn process_websocket_messages(
tokio::select! { tokio::select! {
biased; biased;
// Listen for application events to send to the client
maybe_app_event = event_rx.recv() => { maybe_app_event = event_rx.recv() => {
if let Some(app_event_data) = maybe_app_event { if let Some(app_event_data) = maybe_app_event {
SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?; SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?;
// Sending an app event doesn't reset the client's ping requirement.
} else { } else {
// Event channel closed (e.g., AppState unregistered, or system shutdown signal)
tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket."); tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket.");
return Ok(()); // Graceful shutdown signaled by closed event channel return Ok(());
} }
} }
// Listen for messages from the client (e.g., Ping)
maybe_ws_message = ws_stream.next() => { maybe_ws_message = ws_stream.next() => {
match maybe_ws_message { match maybe_ws_message {
Some(Ok(message)) => { Some(Ok(message)) => {
handle_connected_message(context, message, sender).await?; handle_connected_message(context, message, sender, &app_state).await?;
} }
Some(Err(axum_ws_err)) => { Some(Err(axum_ws_err)) => {
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err); tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
return Err(WsError::WebSocketClosed); return Err(crate::web::ws::error::Error::WebSocketClosed);
} }
None => { // Stream closed by client None => { // Stream closed by client
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state."); tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
return Err(WsError::WebSocketClosed); return Err(crate::web::ws::error::Error::WebSocketClosed);
} }
} }
} }
@@ -172,16 +135,13 @@ async fn process_websocket_messages(
} }
} }
/// Handles messages received when the connection is in the `Initialize` state.
/// Expects `Authenticate` to transition to `Connected`, or `Ping` to stay in `Initialize`.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state))] #[tracing::instrument(skip_all, fields(state = ?context.connection_state))]
async fn handle_initial_message( async fn handle_initial_message(
context: &mut WsContext, context: &mut WsContext,
message: AxumMessage, message: AxumMessage,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>, // Changed to reference sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
app_state: &AppState, app_state: &AppState,
) -> error::Result<ControlFlow<WsState, ()>> { ) -> crate::web::ws::error::Result<(), WsError> {
// Break(NewState) or Continue(())
match deserialize_ws_message(message)? { match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => { WsClientMessage::Authenticate { token } => {
match crate::web::middleware::get_context_from_token(&app_state, &token).await { match crate::web::middleware::get_context_from_token(&app_state, &token).await {
@@ -212,7 +172,7 @@ async fn handle_initial_message(
.register_gateway_connected_user( .register_gateway_connected_user(
user_id, user_id,
current_session_key.clone(), current_session_key.clone(),
event_tx, // This is ws::state::EventSender -> mpsc::UnboundedSender<ws::message::Event> event_tx,
) )
.await; .await;
@@ -224,41 +184,35 @@ async fn handle_initial_message(
}, },
) )
.await?; .await?;
// Deadline is reset by the caller upon ControlFlow::Break Ok(())
Ok(ControlFlow::Break(WsState::Connected))
}, },
Err(_auth_err) => { Err(_auth_err) => {
tracing::warn!(token = %token, "Authentication failed for token."); tracing::warn!(token = %token, "Authentication failed for token.");
// Send AuthenticateDenied, then the connection will be closed by HeartbeatTimeout or by returning error.
// We send response to ensure client gets the denial before we might drop connection.
let _ = SendWsMessage::send_with_response( let _ = SendWsMessage::send_with_response(
sender, sender,
WsServerMessage::AuthenticateDenied, WsServerMessage::AuthenticateDenied,
) )
.await; .await;
Err(WsError::AuthenticationFailed) // This will terminate process_websocket_messages Err(WsError::AuthenticationFailed.into())
}, },
} }
}, },
// Per original code, only Authenticate and Ping are expected in Initialize.
// If WsClientMessage has other variants, this might need adjustment.
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
_ => { _ => {
tracing::warn!("Unexpected message type received during Initialize state."); tracing::warn!("Unexpected message type received during Initialize state.");
Err(WsError::UnexpectedMessageType) Err(crate::web::ws::error::Error::WrongMessageType)
}, },
} }
} }
/// Handles messages received when the connection is in the `Connected` state.
/// Primarily expects `Ping` messages to keep the connection alive.
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) #[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))] ))]
async fn handle_connected_message( async fn handle_connected_message(
context: &mut WsContext, context: &mut WsContext,
message: AxumMessage, message: AxumMessage,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>, sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
) -> error::Result<()> { app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError> {
match deserialize_ws_message(message)? { match deserialize_ws_message(message)? {
WsClientMessage::VoiceStateUpdate { WsClientMessage::VoiceStateUpdate {
server_id, server_id,
@@ -274,11 +228,15 @@ async fn handle_connected_message(
.user_id, .user_id,
server_id, server_id,
channel_id, channel_id,
iat: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime)
.timestamp(), .timestamp(),
}; };
let token = jwt::generate_jwt(claims).map_err(|_| WsError::TokenGenerationFailed)?; let token = jwt::generate_jwt(
claims,
crate::config::config().security.voice_secret.as_ref(),
)
.map_err(|_| WsError::TokenGenerationFailed)?;
SendWsMessage::send_with_response( SendWsMessage::send_with_response(
sender, sender,
@@ -294,9 +252,47 @@ async fn handle_connected_message(
Ok(()) Ok(())
}, },
WsClientMessage::RequestVoiceStates { server_id } => {
let channels = app_state
.database
.select_server_channels(server_id)
.await
.map_err(|_| crate::web::ws::error::Error::UnknownError)?;
for channel in channels {
let (tx, rx) = tokio::sync::oneshot::channel();
let webrtc_sender =
{ app_state.voice_rooms.read().await.get(&channel.id).cloned() };
if let Some(voice_room) = webrtc_sender {
let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx });
let peers = match rx.await {
Ok(peers) => peers,
Err(_) => {
continue;
},
};
for peer in peers {
let _ =
sender.send(SendWsMessage::new_no_response(WsServerMessage::Event {
event: WsEvent::VoiceChannelConnected {
server_id,
channel_id: channel.id,
user_id: peer,
},
}));
}
}
}
Ok(())
},
other_message => { other_message => {
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
Err(WsError::UnexpectedMessageType) Err(crate::web::ws::error::Error::WrongMessageType)
}, },
} }
} }

View File

@@ -1,56 +1,12 @@
use crate::web::ws::error::CustomError;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::From, derive_more::Display)] #[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Error { pub enum Error {
#[from]
Axum(axum::Error),
#[from]
Json(serde_json::Error),
#[from]
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
UnexpectedMessageType,
WrongMessageType,
WebSocketClosed,
AuthenticationFailed, AuthenticationFailed,
TokenGenerationFailed, TokenGenerationFailed,
} }
#[derive(Debug, Clone, serde::Serialize)] impl CustomError for Error {}
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
DeserializationError,
NotAuthenticated,
AlreadyAuthenticated,
Unknown,
}
impl Error {
pub fn as_client_error(&self) -> ClientError {
match self {
Error::Json(_) => ClientError::DeserializationError,
Error::UnexpectedMessageType => ClientError::Unknown,
Error::WrongMessageType => ClientError::Unknown,
Error::WebSocketClosed => ClientError::Unknown,
_ => ClientError::Unknown,
}
}
}
impl From<crate::web::ws::error::Error> for Error {
fn from(err: crate::web::ws::error::Error) -> Self {
match err {
crate::web::ws::error::Error::Json(e) => Error::Json(e),
crate::web::ws::error::Error::AcknowledgementError(e) => Error::AcknowledgementError(e),
crate::web::ws::error::Error::WrongMessageType => Error::WrongMessageType,
crate::web::ws::error::Error::WebSocketClosed => Error::WebSocketClosed,
}
}
}

View File

@@ -1,23 +1,77 @@
use crate::entity; use crate::{entity, web};
#[derive(Debug, Clone, serde::Serialize)] #[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", content = "data")] #[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Event { pub enum Event {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
AddServer { server: entity::server::Server }, AddServer { server: web::entity::server::Server },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
RemoveServer { server_id: entity::server::Id }, RemoveServer { server_id: entity::server::Id },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
AddDmChannel { channel: entity::channel::Channel }, AddDmChannel { channel: entity::channel::Channel },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
RemoveDmChannel { channel_id: entity::channel::Id }, RemoveDmChannel { channel_id: entity::channel::Id },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
AddServerChannel { channel: entity::channel::Channel }, AddServerChannel { channel: entity::channel::Channel },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
RemoveServerChannel { channel_id: entity::channel::Id }, RemoveServerChannel {
server_id: entity::server::Id,
channel_id: entity::channel::Id,
},
#[serde(rename_all = "camelCase")]
AddUser {
user: web::entity::user::PartialUser,
},
#[serde(rename_all = "camelCase")]
RemoveUser {
user_id: entity::user::Id,
},
#[serde(rename_all = "camelCase")]
AddServerMember {
server_id: entity::server::Id,
member: web::entity::user::PartialUser,
},
#[serde(rename_all = "camelCase")]
RemoveServerMember {
server_id: entity::server::Id,
member_id: entity::user::Id,
},
#[serde(rename_all = "camelCase")]
AddMessage {
channel_id: entity::channel::Id,
message: web::entity::message::Message,
},
#[serde(rename_all = "camelCase")]
RemoveMessage {
channel_id: entity::channel::Id,
message_id: entity::message::Id,
},
#[serde(rename_all = "camelCase")]
VoiceChannelConnected {
server_id: entity::server::Id,
channel_id: entity::channel::Id,
user_id: entity::user::Id,
},
#[serde(rename_all = "camelCase")]
VoiceChannelDisconnected {
server_id: entity::server::Id,
channel_id: entity::channel::Id,
user_id: entity::user::Id,
},
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
VoiceServerUpdate { VoiceServerUpdate {

View File

@@ -3,8 +3,8 @@ use axum::response::IntoResponse;
use dashmap::DashMap; use dashmap::DashMap;
use crate::state::AppState; use crate::state::AppState;
use crate::web::ws::gateway::connection::handle_socket_connection; use crate::web::ws::gateway::state::{EventSender, WsContext};
use crate::web::ws::gateway::state::EventSender; use crate::web::ws::general;
mod connection; mod connection;
mod error; mod error;
@@ -37,5 +37,7 @@ pub async fn ws_handler(
State(app_state): State<AppState>, State(app_state): State<AppState>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> { ) -> crate::web::error::Result<impl IntoResponse> {
Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state))) Ok(ws.on_upgrade(|socket| {
general::handle_websocket_connection(socket, app_state, WsContext::default())
}))
} }

View File

@@ -1,12 +1,7 @@
use std::time::Duration; use super::{SessionKey, error, event};
use serde::{Deserialize, Serialize};
use super::error::ClientError;
use super::{SessionKey, event as ws_local_message};
use crate::entity; use crate::entity;
#[derive(Debug, Serialize)] #[derive(Debug, serde::Serialize)]
#[serde(tag = "type", content = "data")] #[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage { pub enum WsServerMessage {
@@ -20,27 +15,28 @@ pub enum WsServerMessage {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Event { Event {
event: ws_local_message::Event, event: event::Event,
}, },
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Error { Error {
code: ClientError, code: error::Error,
}, },
} }
#[derive(Debug, Deserialize)] #[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")] #[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage { pub enum WsClientMessage {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Authenticate { Authenticate { token: String },
token: String,
},
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
VoiceStateUpdate { VoiceStateUpdate {
server_id: entity::server::Id, server_id: entity::server::Id,
channel_id: entity::channel::Id, channel_id: entity::channel::Id,
}, },
#[serde(rename_all = "camelCase")]
RequestVoiceStates { server_id: entity::server::Id },
} }

View File

@@ -1,37 +1,35 @@
use std::time::Duration;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use super::{event, SessionKey}; use super::{SessionKey, event};
use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver use crate::entity;
/// Represents the current state of a single WebSocket connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)] #[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum WsState { pub enum WsState {
Initialize, // Connection established, awaiting authentication Initialize,
Connected, // Authenticated and operational Connected,
} }
/// Contextual information for an authenticated WebSocket user session.
#[derive(Debug)] #[derive(Debug)]
pub struct WsUserContext { pub struct WsUserContext {
pub user_id: entity::user::Id, pub user_id: entity::user::Id,
pub session_key: SessionKey, // Unique key for this specific WebSocket session instance pub session_key: SessionKey, // Unique key for this specific WebSocket session instance
} }
/// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client.
pub type EventSender = mpsc::UnboundedSender<event::Event>; pub type EventSender = mpsc::UnboundedSender<event::Event>;
/// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s.
pub type EventReceiver = mpsc::UnboundedReceiver<event::Event>; pub type EventReceiver = mpsc::UnboundedReceiver<event::Event>;
/// Holds the full context for a single WebSocket connection's lifecycle.
/// This struct is managed per-connection.
pub struct WsContext { pub struct WsContext {
pub connection_state: WsState, pub connection_state: WsState,
pub user_context: Option<WsUserContext>, pub user_context: Option<WsUserContext>,
/// Channel for receiving application-specific events to be sent to this client.
/// The `EventSender` (tx) part is given to `AppState` for broadcasting.
/// The `EventReceiver` (rx) part is polled by the connection task.
pub event_channel: Option<(EventSender, EventReceiver)>, pub event_channel: Option<(EventSender, EventReceiver)>,
} }
impl Default for WsContext {
fn default() -> Self {
Self {
connection_state: WsState::Initialize,
user_context: None,
event_channel: None,
}
}
}

View File

@@ -2,7 +2,8 @@ use crate::entity;
use crate::state::AppState; use crate::state::AppState;
use crate::web::ws::gateway::event; use crate::web::ws::gateway::event;
pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: event::Event) { pub fn send_message(state: AppState, user_id: entity::user::Id, message: event::Event) {
tokio::spawn(async move {
let connected_users = state.gateway_state.connected.get_async(&user_id).await; let connected_users = state.gateway_state.connected.get_async(&user_id).await;
if let Some(session) = connected_users { if let Some(session) = connected_users {
for instance in session.instances.iter() { for instance in session.instances.iter() {
@@ -11,4 +12,57 @@ pub async fn send_message(state: &AppState, user_id: entity::user::Id, message:
} }
} }
} }
});
}
pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) {
for id in user_ids.iter() {
send_message(state.clone(), *id, message.clone());
}
}
pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) {
tokio::spawn(async move {
let users = state
.database
.select_server_members(server_id)
.await
.unwrap_or_else(|_| vec![])
.iter()
.map(|u| u.id)
.collect::<Vec<_>>();
send_message_many(state, &users, message);
});
}
pub fn send_message_channel(
state: AppState,
channel_id: entity::channel::Id,
message: event::Event,
) {
tokio::spawn(async move {
let users = state
.database
.select_channel_members(channel_id)
.await
.unwrap_or_else(|_| vec![])
.iter()
.map(|u| u.id)
.collect::<Vec<_>>();
send_message_many(state, &users, message);
});
}
pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) {
tokio::spawn(async move {
let users = state
.database
.select_related_user_ids(user_id)
.await
.unwrap_or_else(|_| vec![]);
send_message_many(state, &users, message);
});
} }

74
src/web/ws/general.rs Normal file
View File

@@ -0,0 +1,74 @@
use std::fmt::Debug;
use axum::extract::ws::WebSocket;
use futures::{Stream, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use crate::state::AppState;
use crate::web::ws::error::CustomError;
use crate::web::ws::util;
use crate::web::ws::util::SendWsMessage;
pub trait WebSocketHandler {
type ServerMessage: Serialize + Send;
type ClientMessage: DeserializeOwned;
type Error: CustomError + Send + Debug;
async fn handle_stream<S>(
&mut self,
stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), Self::Error>
where
S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>> + Unpin;
async fn cleanup(&mut self, app_state: &AppState);
async fn handle_result_error(
&mut self,
error: Self::Error,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
);
}
#[tracing::instrument(skip_all)]
pub async fn handle_websocket_connection(
websocket: WebSocket,
app_state: AppState,
mut handler: impl WebSocketHandler + 'static,
) {
let (ws_sink, ws_stream) = websocket.split();
let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel();
let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx);
let processing_result = handler
.handle_stream(ws_stream, &internal_send_tx, &app_state)
.await;
handler.cleanup(&app_state).await;
match processing_result {
Ok(_) => {},
Err(crate::web::ws::error::Error::Custom(err_to_report)) => {
handler
.handle_result_error(err_to_report, &internal_send_tx)
.await;
},
Err(e) => {
tracing::info!("WebSocket connection closed: {:?}", e);
},
}
drop(internal_send_tx);
if let Err(e) = writer_task.await {
tracing::error!(
"WebSocket writer task panicked or encountered an error: {:?}",
e
);
}
}

View File

@@ -2,3 +2,4 @@ mod error;
pub mod gateway; pub mod gateway;
mod util; mod util;
pub mod voice; pub mod voice;
mod general;

View File

@@ -4,13 +4,16 @@ use serde::Serialize;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
pub fn spawn_writer_task<S, T>( use crate::web::ws::error::CustomError;
pub fn spawn_writer_task<S, T, E>(
mut ws_sink: S, mut ws_sink: S,
mut writer_rx: mpsc::UnboundedReceiver<SendWsMessage<T>>, mut writer_rx: mpsc::UnboundedReceiver<SendWsMessage<T, E>>,
) -> tokio::task::JoinHandle<()> ) -> tokio::task::JoinHandle<()>
where where
S: Sink<axum::extract::ws::Message> + Unpin + Send + 'static, S: Sink<axum::extract::ws::Message> + Unpin + Send + 'static,
T: Serialize + Send + 'static, T: Serialize + Send + 'static,
E: CustomError + Send + 'static,
{ {
tokio::spawn(async move { tokio::spawn(async move {
while let Some(SendWsMessage { while let Some(SendWsMessage {
@@ -39,9 +42,9 @@ where
} }
/// Deserializes an Axum WebSocket message into a `WsClientMessage`. /// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message<T: DeserializeOwned>( pub fn deserialize_ws_message<T: DeserializeOwned, E: CustomError>(
message: AxumMessage, message: AxumMessage,
) -> super::error::Result<T> { ) -> super::error::Result<T, E> {
match message { match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from), AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from),
AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed), AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed),
@@ -50,7 +53,9 @@ pub fn deserialize_ws_message<T: DeserializeOwned>(
} }
/// Serializes a `WsServerMessage` into an Axum WebSocket message. /// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message<T: Serialize>(message: T) -> super::error::Result<AxumMessage> { pub fn serialize_ws_message<T: Serialize, E: CustomError>(
message: T,
) -> super::error::Result<AxumMessage, E> {
serde_json::to_string(&message) serde_json::to_string(&message)
.map(Into::into) .map(Into::into)
.map(AxumMessage::Text) .map(AxumMessage::Text)
@@ -59,17 +64,17 @@ pub fn serialize_ws_message<T: Serialize>(message: T) -> super::error::Result<Ax
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. /// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. /// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage<T> { pub struct SendWsMessage<T, E: CustomError> {
pub message: T, pub message: T,
pub response_ch: Option<oneshot::Sender<super::error::Result<()>>>, pub response_ch: Option<oneshot::Sender<super::error::Result<(), E>>>,
} }
impl<T> SendWsMessage<T> { impl<T, E: CustomError> SendWsMessage<T, E> {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel. /// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response( pub async fn send_with_response(
tx: &mpsc::UnboundedSender<Self>, // Changed to reference tx: &mpsc::UnboundedSender<Self>, // Changed to reference
message: T, message: T,
) -> super::error::Result<()> { ) -> super::error::Result<(), E> {
let (response_tx, response_rx) = oneshot::channel(); let (response_tx, response_rx) = oneshot::channel();
let send_message = SendWsMessage { let send_message = SendWsMessage {
message, message,
@@ -87,7 +92,7 @@ impl<T> SendWsMessage<T> {
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: T) -> Self { pub fn new_no_response(message: T) -> Self {
SendWsMessage { Self {
message, message,
response_ch: None, response_ch: None,
} }

View File

@@ -5,5 +5,5 @@ pub struct VoiceClaims {
pub user_id: entity::user::Id, pub user_id: entity::user::Id,
pub server_id: entity::server::Id, pub server_id: entity::server::Id,
pub channel_id: entity::channel::Id, pub channel_id: entity::channel::Id,
pub iat: i64, pub exp: i64,
} }

View File

@@ -1,6 +1,211 @@
use axum::extract::ws::WebSocket; use axum::extract::ws::Message as AxumMessage;
use futures::{Stream, StreamExt};
use tokio::sync::{mpsc, oneshot};
use super::error::{self, Error as WsError};
use super::protocol::{WsClientMessage, WsServerMessage};
use crate::jwt;
use crate::state::AppState; use crate::state::AppState;
use crate::web::ws;
use crate::web::ws::general::WebSocketHandler;
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
use crate::web::ws::voice::claims::VoiceClaims;
use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer;
use crate::web::ws::voice::state::{WsContext, WsState};
use crate::webrtc::{Offer, OfferSignal, WebRtcSignal};
#[tracing::instrument(skip_all, name = "ws_connection_handler")] impl WebSocketHandler for WsContext {
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {} type ServerMessage = WsServerMessage;
type ClientMessage = WsClientMessage;
type Error = WsError;
async fn handle_stream<S>(
&mut self,
stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), Self::Error>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
process_websocket_messages(self, stream, sender, app_state).await?;
Ok(())
}
async fn cleanup(&mut self, app_state: &AppState) {
tracing::debug!("Cleaning up WebSocket connection.");
match &self.connection_state {
WsState::Connected {
signal_channel,
server_id,
channel_id,
user_id,
} => {
ws::gateway::util::send_message_server(
app_state.clone(),
*server_id,
ws::gateway::event::Event::VoiceChannelDisconnected {
server_id: *server_id,
channel_id: *channel_id,
user_id: *user_id,
},
);
let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone()));
},
WsState::Initialize => {},
}
}
async fn handle_result_error(
&mut self,
error: Self::Error,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
) {
tracing::error!("WebSocket error: {:?}", error);
}
}
#[tracing::instrument(skip_all)]
async fn process_websocket_messages<S>(
context: &mut WsContext,
mut ws_stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState,
) -> ws::error::Result<(), WsError>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
loop {
match &context.connection_state {
WsState::Initialize => {
while let Some(Ok(message)) = ws_stream.next().await {
handle_initial_message(context, message, sender, &app_state).await?;
break;
}
},
WsState::Connected { signal_channel, .. } => {
let signal_channel = signal_channel.clone();
loop {
tokio::select! {
biased;
_ = signal_channel.closed() => {
tracing::debug!("Signal channel closed.");
break;
}
Some(Ok(message)) = ws_stream.next() => {
handle_connected_message(context, message, sender, &app_state).await?;
}
else => {
break;
}
}
}
return Err(ws::error::Error::WebSocketClosed);
},
}
}
}
#[tracing::instrument(skip_all)]
async fn handle_initial_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), error::Error> {
match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => match jwt::verify_jwt::<VoiceClaims>(
&token,
crate::config::config().security.voice_secret.as_ref(),
) {
Ok(claims) => {
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted)
.await?;
let signal_channel =
ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await;
context.connection_state = WsState::Connected {
signal_channel,
server_id: claims.server_id,
channel_id: claims.channel_id,
user_id: claims.user_id,
};
ws::gateway::util::send_message_server(
app_state.clone(),
claims.server_id,
ws::gateway::event::Event::VoiceChannelConnected {
server_id: claims.server_id,
channel_id: claims.channel_id,
user_id: claims.user_id,
},
);
Ok(())
},
Err(auth_err) => {
tracing::warn!("Authentication failed: {:?}", auth_err);
let _ =
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied)
.await;
Err(error::Error::AuthenticationFailed.into())
},
},
#[allow(unreachable_patterns)]
_ => {
tracing::warn!("Unexpected message type received during Initialize state.");
Err(ws::error::Error::WrongMessageType)
},
}
}
#[tracing::instrument(skip_all)]
async fn handle_connected_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), error::Error> {
match deserialize_ws_message(message)? {
WsClientMessage::SdpOffer { sdp } => {
let (signal_channel, user_id) = match &context.connection_state {
WsState::Connected {
signal_channel,
user_id,
..
} => (signal_channel.clone(), *user_id),
_ => return Err(ws::error::Error::WrongMessageType),
};
let (tx, rx) = oneshot::channel();
let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal {
offer: Offer {
peer_id: user_id,
sdp_offer: sdp,
},
response: tx,
}));
let answer_signal = rx.await?;
sender
.send(SendWsMessage::new_no_response(SdpAnswer {
sdp: answer_signal.sdp_answer,
}))
.map_err(|_| ws::error::Error::WebSocketClosed)?;
Ok(())
},
other_message => {
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
Err(ws::error::Error::WrongMessageType)
},
}
}

View File

@@ -1,48 +1,10 @@
use crate::web::ws::error::CustomError;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::From, derive_more::Display)] #[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error { pub enum Error {
#[from]
Axum(axum::Error),
#[from]
Json(serde_json::Error),
#[from]
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
UnexpectedMessageType,
WrongMessageType,
WebSocketClosed,
HeartbeatTimeout,
AuthenticationFailed, AuthenticationFailed,
TokenGenerationFailed,
} }
#[derive(Debug, Clone, serde::Serialize)] impl CustomError for Error {}
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
DeserializationError,
NotAuthenticated,
AlreadyAuthenticated,
HeartbeatTimeout,
Unknown,
}
impl Error {
pub fn into_client_error(&self) -> ClientError {
match self {
Error::HeartbeatTimeout => ClientError::HeartbeatTimeout,
Error::Json(_) => ClientError::DeserializationError,
Error::UnexpectedMessageType => ClientError::Unknown,
Error::WrongMessageType => ClientError::Unknown,
Error::WebSocketClosed => ClientError::Unknown,
_ => ClientError::Unknown,
}
}
}

View File

@@ -2,16 +2,20 @@ pub mod claims;
mod connection; mod connection;
mod error; mod error;
mod protocol; mod protocol;
mod state;
use axum::extract::{State, WebSocketUpgrade}; use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use crate::state::AppState; use crate::state::AppState;
use crate::web::ws::voice::connection::handle_socket_connection; use crate::web::ws::general;
use crate::web::ws::voice::state::WsContext;
pub async fn ws_handler( pub async fn ws_handler(
State(app_state): State<AppState>, State(app_state): State<AppState>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> { ) -> crate::web::error::Result<impl IntoResponse> {
Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state))) Ok(ws.on_upgrade(|socket| {
general::handle_websocket_connection(socket, app_state, WsContext::default())
}))
} }

View File

@@ -1,21 +1,9 @@
use std::time::Duration;
use axum::extract::ws::Message as AxumMessage;
use serde::{Deserialize, Serialize};
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use super::error::{self, ClientError, Error as WsError}; #[derive(Debug, serde::Serialize)]
use crate::{entity, util as crate_root_util}; // For crate::util::serialize_duration_seconds
#[derive(Debug, Serialize)]
#[serde(tag = "type", content = "data")] #[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage { pub enum WsServerMessage {
HeartbeatInterval {
#[serde(serialize_with = "crate_root_util::serialize_duration_seconds")]
interval: Duration,
},
AuthenticateDenied, AuthenticateDenied,
AuthenticateAccepted, AuthenticateAccepted,
@@ -24,82 +12,15 @@ pub enum WsServerMessage {
SdpAnswer { SdpAnswer {
sdp: RTCSessionDescription, sdp: RTCSessionDescription,
}, },
#[serde(rename_all = "camelCase")]
Error {
code: ClientError,
},
Pong,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")] #[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage { pub enum WsClientMessage {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
Authenticate { Authenticate { token: String },
token: String,
},
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
SdpOffer { SdpOffer { sdp: RTCSessionDescription },
sdp: RTCSessionDescription,
},
Ping,
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message(message: AxumMessage) -> error::Result<WsClientMessage> {
match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from),
AxumMessage::Close(_) => Err(WsError::WebSocketClosed),
_ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message
}
}
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message(message: WsServerMessage) -> error::Result<AxumMessage> {
serde_json::to_string(&message)
.map(Into::into)
.map(AxumMessage::Text)
.map_err(WsError::from)
}
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage {
pub message: WsServerMessage,
pub response_ch: Option<tokio::sync::oneshot::Sender<error::Result<()>>>,
}
impl SendWsMessage {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response(
tx: &tokio::sync::mpsc::UnboundedSender<Self>, // Changed to reference
message: WsServerMessage,
) -> error::Result<()> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let send_message = SendWsMessage {
message,
response_ch: Some(response_tx),
};
if tx.send(send_message).is_err() {
Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead
} else {
// Wait for the writer task to acknowledge the send attempt.
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
}
}
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: WsServerMessage) -> Self {
SendWsMessage {
message,
response_ch: None,
}
}
} }

68
src/web/ws/voice/state.rs Normal file
View File

@@ -0,0 +1,68 @@
use tokio::sync::mpsc;
use crate::entity;
use crate::state::AppState;
use crate::webrtc::{OfferSignal, WebRtcSignal};
#[derive(Debug, Clone)]
pub enum WsState {
Initialize,
Connected {
signal_channel: mpsc::UnboundedSender<WebRtcSignal>,
server_id: entity::server::Id,
channel_id: entity::channel::Id,
user_id: entity::user::Id,
},
}
pub struct WsContext {
pub connection_state: WsState,
}
impl Default for WsContext {
fn default() -> Self {
Self {
connection_state: WsState::Initialize,
}
}
}
pub async fn get_signaling_channel(
app_state: &AppState,
channel_id: entity::channel::Id,
) -> mpsc::UnboundedSender<WebRtcSignal> {
let room_sender = {
app_state
.voice_rooms
.read()
.await
.get(&channel_id)
.map(|room| room.clone())
};
match room_sender {
Some(room) => room,
None => {
let (tx, rx) = mpsc::unbounded_channel();
let app_state_ = app_state.clone();
tokio::spawn(async move {
crate::webrtc::webrtc_task(channel_id, rx)
.await
.unwrap_or_else(|err| {
tracing::error!("webrtc task error: {:?}", err);
});
{
app_state_.unregister_voice_room(channel_id).await;
}
});
{
app_state.register_voice_room(channel_id, tx.clone()).await;
}
tx
},
}
}

View File

@@ -42,6 +42,16 @@ pub struct AnswerSignal {
pub sdp_answer: RTCSessionDescription, pub sdp_answer: RTCSessionDescription,
} }
#[derive(Debug)]
pub enum WebRtcSignal {
Offer(OfferSignal),
Disconnect(PeerId),
RequestPeers {
response: tokio::sync::oneshot::Sender<Vec<PeerId>>,
},
Close,
}
#[derive(Debug)] #[derive(Debug)]
pub struct OfferSignal { pub struct OfferSignal {
pub offer: Offer, pub offer: Offer,
@@ -51,12 +61,14 @@ pub struct OfferSignal {
#[tracing::instrument(skip(signal))] #[tracing::instrument(skip(signal))]
pub async fn webrtc_task( pub async fn webrtc_task(
room_id: RoomId, room_id: RoomId,
signal: tokio::sync::mpsc::UnboundedReceiver<OfferSignal>, signal: tokio::sync::mpsc::UnboundedReceiver<WebRtcSignal>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
tracing::info!("Starting WebRTC task"); tracing::info!("Starting WebRTC task");
let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel(); let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
let mut skip_timeout = false;
let state = Arc::new(RoomState { let state = Arc::new(RoomState {
room_id, room_id,
peers: DashMap::new(), peers: DashMap::new(),
@@ -78,20 +90,46 @@ pub async fn webrtc_task(
loop { loop {
tokio::select! { tokio::select! {
Some(signal) = signal.recv() => { biased;
let room_state = state.clone(); _ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => {
let api = api.clone(); tracing::debug!("initial timeout reached");
break;
tokio::spawn(async move {
if let Err(e) = handle_peer(api, room_state, signal).await {
tracing::error!("error handling peer: {}", e);
}
}.instrument(tracing::Span::current()));
} }
_ = close_receiver.recv() => { _ = close_receiver.recv() => {
tracing::debug!("WebRTC task stopped"); tracing::debug!("WebRTC task stopped");
break; break;
} }
Some(signal) = signal.recv() => {
skip_timeout = true;
match signal {
WebRtcSignal::Offer(offer_signal) => {
let room_state = state.clone();
let api = api.clone();
tokio::spawn(async move {
if let Err(e) = handle_peer(api, room_state, offer_signal).await {
tracing::error!("error handling peer: {}", e);
}
}.instrument(tracing::Span::current()));
}
WebRtcSignal::RequestPeers { response } => {
let peers = state
.peers
.iter()
.map(|pair| pair.key().clone())
.collect::<Vec<_>>();
let _ = response.send(peers);
}
WebRtcSignal::Disconnect(peer_id) => {
tracing::debug!("received disconnect signal for peer {}", peer_id);
cleanup_peer(state.clone(), peer_id).await;
}
WebRtcSignal::Close => {
break;
}
}
}
} }
} }
@@ -105,6 +143,7 @@ async fn handle_peer(
offer_signal: OfferSignal, offer_signal: OfferSignal,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
tracing::debug!("handling peer"); tracing::debug!("handling peer");
let config = RTCConfiguration { let config = RTCConfiguration {
..Default::default() ..Default::default()
}; };