98 update configuration management to use confy for file based configuration #100

Merged
PatrykHegenberg merged 5 commits from 98-update-configuration-management-to-use-confy-for-file-based-configuration into main 2024-05-09 16:29:03 +02:00
16 changed files with 338 additions and 1704 deletions

210
Cargo.lock generated
View file

@ -413,7 +413,9 @@ dependencies = [
"axum-client-ip", "axum-client-ip",
"caesar-core", "caesar-core",
"clap 4.5.4", "clap 4.5.4",
"dotenv", "confy",
"dotenvy",
"lazy_static",
"serde", "serde",
"serde_json", "serde_json",
"tokio", "tokio",
@ -460,6 +462,20 @@ dependencies = [
"uuid", "uuid",
] ]
[[package]]
name = "caesar-desktop"
version = "0.0.1"
dependencies = [
"caesar-core",
]
[[package]]
name = "caesar-mobile"
version = "0.0.1"
dependencies = [
"caesar-core",
]
[[package]] [[package]]
name = "caesar-transfer-iu" name = "caesar-transfer-iu"
version = "0.3.1" version = "0.3.1"
@ -471,6 +487,13 @@ dependencies = [
"shuttle-runtime", "shuttle-runtime",
] ]
[[package]]
name = "caesar-tui"
version = "0.0.1"
dependencies = [
"caesar-core",
]
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.96" version = "1.0.96"
@ -522,11 +545,26 @@ dependencies = [
"atty", "atty",
"bitflags 1.3.2", "bitflags 1.3.2",
"strsim 0.8.0", "strsim 0.8.0",
"textwrap", "textwrap 0.11.0",
"unicode-width", "unicode-width",
"vec_map", "vec_map",
] ]
[[package]]
name = "clap"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"atty",
"bitflags 1.3.2",
"clap_lex 0.2.4",
"indexmap 1.9.3",
"strsim 0.10.0",
"termcolor",
"textwrap 0.16.1",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.4" version = "4.5.4"
@ -545,7 +583,7 @@ checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
"clap_lex", "clap_lex 0.7.0",
"strsim 0.11.1", "strsim 0.11.1",
] ]
@ -561,6 +599,15 @@ dependencies = [
"syn 2.0.60", "syn 2.0.60",
] ]
[[package]]
name = "clap_lex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.7.0" version = "0.7.0"
@ -595,6 +642,18 @@ dependencies = [
"unicode-width", "unicode-width",
] ]
[[package]]
name = "confy"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45b1f4c00870f07dc34adcac82bb6a72cc5aabca8536ba1797e01df51d2ce9a0"
dependencies = [
"directories",
"serde",
"thiserror",
"toml",
]
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
@ -759,6 +818,27 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "directories"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "dotenv" name = "dotenv"
version = "0.15.0" version = "0.15.0"
@ -768,6 +848,15 @@ dependencies = [
"clap 2.34.0", "clap 2.34.0",
] ]
[[package]]
name = "dotenvy"
version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
dependencies = [
"clap 3.2.25",
]
[[package]] [[package]]
name = "ecdsa" name = "ecdsa"
version = "0.16.9" version = "0.16.9"
@ -1439,6 +1528,16 @@ version = "0.2.154"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346"
[[package]]
name = "libredox"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags 2.5.0",
"libc",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.13" version = "0.4.13"
@ -1737,6 +1836,12 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "option-ext"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]] [[package]]
name = "ordered-float" name = "ordered-float"
version = "4.2.0" version = "4.2.0"
@ -1746,6 +1851,12 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "os_str_bytes"
version = "6.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@ -2050,6 +2161,17 @@ dependencies = [
"bitflags 2.5.0", "bitflags 2.5.0",
] ]
[[package]]
name = "redox_users"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
dependencies = [
"getrandom",
"libredox",
"thiserror",
]
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.10.4" version = "1.10.4"
@ -2346,6 +2468,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_spanned"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
@ -2583,6 +2714,12 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
@ -2703,6 +2840,15 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "textwrap" name = "textwrap"
version = "0.11.0" version = "0.11.0"
@ -2712,6 +2858,12 @@ dependencies = [
"unicode-width", "unicode-width",
] ]
[[package]]
name = "textwrap"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9"
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.59" version = "1.0.59"
@ -2858,6 +3010,40 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "toml"
version = "0.8.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3"
dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_edit",
]
[[package]]
name = "toml_datetime"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef"
dependencies = [
"indexmap 2.2.6",
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
]
[[package]] [[package]]
name = "tonic" name = "tonic"
version = "0.10.2" version = "0.10.2"
@ -3310,6 +3496,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "winapi-x86_64-pc-windows-gnu" name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0" version = "0.4.0"
@ -3464,6 +3659,15 @@ version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "winnow"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "winreg" name = "winreg"
version = "0.52.0" version = "0.52.0"

View file

@ -17,7 +17,9 @@ serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
dotenv = { version = "0.15.0", features = ["clap", "cli"] }
clap = { version = "4.5.4", features = ["derive"] } clap = { version = "4.5.4", features = ["derive"] }
axum = { version = "0.7.5", features = ["ws"] } axum = { version = "0.7.5", features = ["ws"] }
axum-client-ip = "0.6.0" axum-client-ip = "0.6.0"
confy = "0.6.1"
dotenvy = { version = "0.15.7", features = ["clap", "cli"] }
lazy_static = "1.4.0"

View file

@ -5,40 +5,12 @@ use clap::{Parser, Subcommand};
use std::{env, sync::Arc}; use std::{env, sync::Arc};
use tracing::debug; use tracing::debug;
/// This struct defines the CLI arguments and subcommands for the caesar command line application. use crate::config::GLOBAL_CONFIG;
///
/// The #[derive(Parser, Debug)] macro generates code that:
/// - parses the command line arguments using the clap library
/// - provides a Debug implementation for the struct
///
/// The #[command(version, about, long_about = None)] macro generates code that:
/// - defines the version and about strings for the application
/// - specifies that there is no long about help text
///
/// The #[command(subcommand)] macro generates code that:
/// - defines a subcommand for the caesar command line application.
/// Subcommands are used to break up a large number of options into
/// smaller, more manageable groups.
///
/// The #[command] macro is used to annotate the `command` field of the struct.
/// The `command` field is an Option<Commands> type, which means that the
/// subcommand is optional.
/// If the subcommand is not provided, the program will exit with a status code
/// of 0 and without printing any output.
///
/// The Commands enum defines the possible subcommands for the caesar command
/// line application.
/// See the Commands enum definition for more information about the available
/// subcommands.
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(version = env!("CARGO_PKG_VERSION"), about = "Send and receive files securely")] #[command(version = env!("CARGO_PKG_VERSION"), about = "Send and receive files securely")]
#[command(long_about = None)] #[command(long_about = None)]
pub struct Args { pub struct Args {
/// The subcommand for the caesar command line application.
/// Subcommands are used to break up a large number of options into smaller,
/// more manageable groups.
/// If no subcommand is provided, the program will exit with a status code
/// of 0 and without printing any output.
#[command(subcommand)] #[command(subcommand)]
pub command: Option<Commands>, pub command: Option<Commands>,
} }
@ -80,116 +52,47 @@ pub enum Commands {
} }
impl Default for Args { impl Default for Args {
// This function is called by the Default trait when no value is
// provided for a field of type Args. It returns an instance of
// Args that has been created by calling the new() function.
//
// The Default trait is used by various parts of the program to
// provide a sensible default value for a field when no value is
// provided. For example, the clap crate uses the Default trait when
// parsing command line arguments to provide a default value for
// a field.
//
// The new() function is a constructor function for Args that
// creates an instance of Args with default field values.
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
} }
impl Args { impl Args {
/// Creates a new instance of Args by parsing command line arguments
///
/// This function is a constructor for Args. It uses the clap crate to parse
/// command line arguments and creates an instance of Args with the values
/// provided by the user.
///
/// The clap crate is a command line argument parser that is well tested and
/// widely used. It provides a simple way to define command line
/// arguments and generate helpful documentation for the user.
///
/// The `parse()` function is used to parse the command line arguments and
/// return an instance of Args.
pub fn new() -> Self { pub fn new() -> Self {
Self::parse() Self::parse()
} }
/// Runs the command specified by the user
///
/// This function is called after the command line arguments have been
/// parsed. It matches on the `command` field of the Args struct to determine
/// what command the user wants to run.
///
/// The match statement checks the value of `command` and calls the
/// appropriate function to run the command. The functions that are called
/// are located in other modules of the program.
///
/// The `run()` function is called by the `main()` function of the program.
/// The program's entry point is the `main()` function, which parses the
/// command line arguments and then calls `run()` on the resulting Args
/// instance.
///
/// The `run()` function returns a Result. The error type is `Box<dyn
/// std::error::Error + Send + Sync>`. This means that the error type is a
/// trait object that represents an error that can be sent across threads
/// and sent over a network connection. The `Send` and `Sync` traits are part
/// of the standard library and are used to indicate that the error type can
/// be sent across threads and sent over a network connection.
///
/// The `run()` function does not return anything if the command is `None`.
/// This is because `command` is an `Option<Commands>`. If the user does
/// not specify a command, then `command` is `None`. In this case, there is
/// nothing to run, so `run()` returns early with no error.
pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let cfg = &GLOBAL_CONFIG;
debug!("args: {:#?}", self); debug!("args: {:#?}", self);
match &self.command { match &self.command {
// If the user wants to send files, call `start_sender()` in the
// `sender` module with the list of files that the user wants to
// send.
Some(Commands::Send { relay, files }) => { Some(Commands::Send { relay, files }) => {
let relay_string: String = relay let relay_string: String = relay.as_deref().unwrap_or(&cfg.app_origin).to_string();
.as_deref()
.unwrap_or(
&env::var("APP_ORIGIN")
.unwrap_or("wss://caesar-transfer-iu.shuttleapp.rs/ws".to_string()),
)
.to_string();
let relay_arc = Arc::new(relay_string); let relay_arc = Arc::new(relay_string);
let files_arc = Arc::new(files.to_vec()); let files_arc = Arc::new(files.to_vec());
sender::start_sender(relay_arc, files_arc).await; sender::start_sender(relay_arc, files_arc).await;
} }
// If the user wants to receive files, call `start_receiver()` in the
// `receiver` module with the name of the transfer that the user
// wants to download.
Some(Commands::Receive { Some(Commands::Receive {
relay, relay,
overwrite: _, overwrite: _,
name, name,
}) => { }) => {
println!("Receive for {name:?}"); println!("Receive for {name:?}");
receiver::start_receiver( receiver::start_receiver(relay.as_deref().unwrap_or(&cfg.app_origin), name).await;
relay.as_deref().unwrap_or(
env::var("APP_ORIGIN")
.unwrap_or("ws://0.0.0.0:8000/ws".to_string())
.as_str(),
),
name,
)
.await;
} }
// If the user wants to start a relay server, call `start_ws()` in the
// `relay` module with the port and listen address that the user
// specified.
Some(Commands::Serve { Some(Commands::Serve {
port, port,
listen_address, listen_address,
}) => { }) => {
println!("Serve with address '{listen_address:?}' and '{port:?}'"); println!("Serve with address '{listen_address:?}' and '{port:?}'");
relay::server::start_ws(port.as_ref(), listen_address.as_ref()).await; let address: String = listen_address
.as_deref()
.unwrap_or(&cfg.app_host)
.to_string();
let port_value = port.unwrap_or(cfg.app_port.parse::<i32>().unwrap_or(0));
let port: i32 = port_value;
relay::server::start_ws(&port, &address).await;
} }
// If the user does not specify a command, return early with no error.
// This is because `command` is an `Option<Commands>`. If the user does
// not specify a command, then `command` is `None`.
None => {} None => {}
} }
Ok(()) Ok(())

33
caesar-cli/src/config.rs Normal file
View file

@ -0,0 +1,33 @@
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CaesarConfig {
pub app_environment: String,
pub app_host: String,
pub app_port: String,
pub app_origin: String,
pub app_relay: String,
pub rust_log: String,
}
impl Default for CaesarConfig {
fn default() -> Self {
CaesarConfig {
app_environment: "production".to_string(),
app_host: "0.0.0.0".to_string(),
app_port: "8000".to_string(),
app_origin: "wss://caesar-transfer-iu.shuttleapp.rs".to_string(),
app_relay: "0.0.0.0:8000".to_string(),
rust_log: "info".to_string(),
}
}
}
lazy_static! {
pub static ref GLOBAL_CONFIG: CaesarConfig = {
let cfg: CaesarConfig =
confy::load("caesar", "caesar").expect("could not find config file");
cfg
};
}

View file

@ -1,33 +1,20 @@
use crate::cli::args::Args; use crate::cli::args::Args;
use dotenv::dotenv; use dotenvy::dotenv;
use tracing::error; use tracing::error;
use tracing_subscriber::filter::EnvFilter; use tracing_subscriber::filter::EnvFilter;
mod cli; mod cli;
mod config;
#[tokio::main] #[tokio::main]
// This is the entrypoint of caesar.
// The #[tokio::main] attribute is required for any async code, and it
// sets up the tokio runtime.
// The async fn main() is the entrypoint of the application, and it's where
// we kick off our program.
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Load environment variables from a .env file if one is present.
dotenv().ok(); dotenv().ok();
// Set up our logging subscriber.
// TheEnvFilter::from_default_env reads the env variable RUST_LOG
// and sets up the logging accordingly.
// The default is INFO level logging.
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env()) .with_env_filter(EnvFilter::from_default_env())
.init(); .init();
// Parse the command line arguments.
let args = Args::new(); let args = Args::new();
// Run the commands based on the parsed arguments.
// If there is an error, print it to the console with the error! macro.
if let Err(e) = args.run().await { if let Err(e) = args.run().await {
error!("{e}"); error!("{e}");
} }
// Return an Ok result, which just means that our program exited successfully.
Ok(()) Ok(())
} }

View file

@ -22,21 +22,6 @@ use tracing::error;
const DESTINATION: u8 = 0; const DESTINATION: u8 = 0;
const NONCE_SIZE: usize = 12; const NONCE_SIZE: usize = 12;
/// This struct represents a file that is being received.
///
/// The struct contains information about the file, such as its name, size,
/// and the handle of the file on the disk.
///
/// The `name` field contains the name of the file, which is the name of the
/// file on the disk.
///
/// The `size` field contains the size of the file in bytes.
///
/// The `progress` field contains the number of bytes that have already been
/// received for the file.
///
/// The `handle` field contains a handle to the file on the disk, which is
/// used to read the contents of the file.
struct File { struct File {
name: String, name: String,
size: u64, size: u64,
@ -44,224 +29,58 @@ struct File {
handle: fs::File, handle: fs::File,
} }
/// This struct contains the context for the receiver.
///
/// This structure is used to keep track of the state of the receiver, and to
/// pass information between functions.
struct Context { struct Context {
/// The HMAC key that is used to verify that the packets that are received
/// are authentic.
///
/// The HMAC key is generated by the sender, and is used to verify that the
/// packets that are received are authentic. If the HMAC of a packet does
/// not match the expected HMAC, then the packet is not processed.
hmac: Vec<u8>, hmac: Vec<u8>,
/// The sender that is used to send packets to the server.
///
/// The sender is used to send packets to the server. The sender is also
/// used to receive packets from the server.
sender: Sender, sender: Sender,
/// The ephemeral secret key that is used for key exchange with the sender.
///
/// The ephemeral secret key is generated by the receiver, and is used to
/// exchange a shared key with the sender. The shared key is used to
/// encrypt and decrypt packets.
key: EphemeralSecret, key: EphemeralSecret,
/// The shared key that is used to encrypt and decrypt packets.
///
/// The shared key is established between the receiver and the sender during
/// the key exchange. The shared key is used to encrypt and decrypt packets
/// between the receiver and the sender. If the shared key is `None`, then the
/// packets that are received are not processed.
shared_key: Option<Aes128Gcm>, shared_key: Option<Aes128Gcm>,
/// The files that are being received.
///
/// The files vector contains a list of all the files that are being
/// received. Each file is represented by a `File` struct. The `name` field
/// of the `File` struct contains the name of the file, which is the name of
/// the file on the disk. The `size` field of the `File` struct contains the
/// size of the file in bytes. The `progress` field of the `File` struct
/// contains the number of bytes that have already been received for the
/// file. The `handle` field of the `File` struct contains a handle to the
/// file on the disk, which is used to read the contents of the file.
files: Vec<File>, files: Vec<File>,
/// The sequence number of the next chunk that is expected to be received.
///
/// The sequence number is used to keep track of the sequence of chunks that
/// are received. If a chunk does not have the expected sequence number,
/// then the chunk is not processed.
sequence: u32, sequence: u32,
/// The index of the file that is currently being received.
///
/// The index is used to keep track of which file is currently being
/// received. The index is incremented after a file is completely received.
index: usize, index: usize,
/// The progress of the current file that is being received.
///
/// The progress is used to keep track of the progress of the current file
/// that is being received. The progress is calculated by dividing the
/// number of bytes that have been received by the size of the file. The
/// progress is sent to the server so that the sender knows how much of the
/// file has been received.
progress: u64, progress: u64,
/// The total length of the current file that is being received.
///
/// The length is used to keep track of the total length of the current file
/// that is being received. The length is used to calculate the progress of
/// the file.
length: u64, length: u64,
} }
/// This function is called when the receiver receives a join room packet from
/// the server. The packet contains the size of the list of files that the
/// sender is going to send.
///
/// If the packet does not contain the size of the list, then an error is
/// returned.
///
/// If the packet does contain the size of the list, then a message is printed
/// to the console indicating that the receiver has connected to the room.
///
/// The function does not do anything else. It returns a `Status::Continue`
/// variant to indicate that the event loop should continue processing events.
fn on_join_room(size: Option<usize>) -> Status { fn on_join_room(size: Option<usize>) -> Status {
// If the packet does not contain the size of the list, then return an error.
if size.is_none() { if size.is_none() {
return Status::Err("Invalid join room packet.".into()); return Status::Err("Invalid join room packet.".into());
} }
// If the packet contains the size of the list, then print a message to the
// console indicating that the receiver has connected to the room.
println!("Connected to room."); println!("Connected to room.");
// Return a `Status::Continue` variant to indicate that the event loop
// should continue processing events.
Status::Continue() Status::Continue()
} }
/// This function is called when the event loop receives an error packet from
/// the server. The packet contains a message with a description of the error.
///
/// When an error occurs, the server sends an error packet to the client. The
/// error packet contains a message with a description of the error. This
/// function extracts that message and creates a `Status::Err` variant with it,
/// which is then returned to be handled by the main event loop.
///
/// When the event loop receives a status variant that is an error, it exits
/// with an error message containing the message from the error packet.
///
/// The message from the error packet is the only information that the event
/// loop has about the error, so the message should be descriptive and
/// helpful to the user. The message should not contain technical details
/// about the error or how it occurred. Instead, the message should be
/// written from the perspective of the user and should give the user enough
/// information to understand what went wrong and how they might be able to
/// fix the problem.
///
/// This function takes the message from the error packet and creates a
/// `Status::Err` variant with it. The function returns this variant to be
/// handled by the main event loop.
fn on_error(message: String) -> Status { fn on_error(message: String) -> Status {
Status::Err(message) Status::Err(message)
} }
/// This function is called when the event loop receives a leave room packet from
/// the server. The packet contains the index of the file that was being
/// transferred when the receiver left the room.
///
/// When the receiver receives a leave room packet, it means that the sender
/// has disconnected from the room. In this case, the receiver should check if
/// there are any files that were being transferred but not yet complete. If
/// there are any incomplete files, the receiver should print a message to
/// the user indicating that the transfer was interrupted.
///
/// If there are no incomplete files, then the receiver should exit
/// normally. The `Status::Exit` variant is returned to the main event loop,
/// which will cause the event loop to exit normally.
///
/// This function checks if there are any incomplete files by iterating over
/// the list of files in the context. If there are any files with a progress
/// less than 100%, then the function prints a message to the user and returns
/// an error status.
///
/// If there are no incomplete files, then the function simply returns a
/// `Status::Exit` variant. This will cause the main event loop to exit
/// normally.
fn on_leave_room(context: &mut Context, _: usize) -> Status { fn on_leave_room(context: &mut Context, _: usize) -> Status {
// Check if there are any incomplete files.
if context.files.iter().any(|file| file.progress < 100) { if context.files.iter().any(|file| file.progress < 100) {
// If there are any incomplete files, print a message to the user.
println!(); println!();
println!("Transfer was interrupted because the host left the room."); println!("Transfer was interrupted because the host left the room.");
// Return an error status.
Status::Err("Transfer was interrupted because the host left the room.".into()) Status::Err("Transfer was interrupted because the host left the room.".into())
} else { } else {
// If there are no incomplete files, return a `Status::Exit` variant.
// This will cause the event loop to exit normally.
Status::Exit() Status::Exit()
} }
} }
/// This function is called when the event loop receives a list packet from
/// the server. The packet contains a list of files to be transferred.
///
/// When this function is called, we know that the sender has successfully
/// established a shared key with the receiver. Therefore, we can start
/// receiving encrypted files.
///
/// This function iterates over the list of files in the packet and creates a
/// file on disk for each file in the list. If a file with the same name already
/// exists, an error is returned and the event loop is exited with an error
/// message. This is because the receiver should not overwrite existing files
/// without the user's explicit permission.
///
/// Once all the files have been created, the function initializes the context
/// variables for the event loop. `index` is set to 0 to indicate that we are
/// currently transferring the first file. `progress` is set to 0 to indicate
/// that the progress of the first file is 0%. `sequence` is set to 0 to
/// indicate that the first chunk of data we receive will have a sequence
/// number of 0. `length` is set to 0 to indicate that we have not received
/// any data yet.
///
/// If there is an error creating any of the files, the function returns an
/// error status. This will cause the event loop to exit with an error message.
///
/// If there are no errors, the function returns a `Status::Continue()` variant.
/// This will cause the event loop to continue running and wait for more
/// packets from the sender.
fn on_list(context: &mut Context, list: ListPacket) -> Status { fn on_list(context: &mut Context, list: ListPacket) -> Status {
if context.shared_key.is_none() { if context.shared_key.is_none() {
return Status::Err("Invalid list packet: no shared key established".into()); return Status::Err("Invalid list packet: no shared key established".into());
} }
// Iterate over the list of files in the packet.
for entry in list.entries { for entry in list.entries {
// Sanitize the file name to remove any characters that are not valid in
// file names on the current platform.
let path = sanitize_filename::sanitize(entry.name.clone()); let path = sanitize_filename::sanitize(entry.name.clone());
// Check if a file with the same name already exists.
if Path::new(&path).exists() { if Path::new(&path).exists() {
// If the file already exists, return an error and exit the event loop
// with an error message.
return Status::Err(format!("The file '{}' already exists.", path)); return Status::Err(format!("The file '{}' already exists.", path));
} }
// Try to create a new file with the sanitized file name.
let handle = match fs::File::create(&path) { let handle = match fs::File::create(&path) {
Ok(handle) => handle, Ok(handle) => handle,
Err(error) => { Err(error) => {
// If there is an error creating the file, return an error and
// exit the event loop with an error message.
return Status::Err(format!( return Status::Err(format!(
"Error: Failed to create file '{}': {}", "Error: Failed to create file '{}': {}",
path, error path, error
@ -269,7 +88,6 @@ fn on_list(context: &mut Context, list: ListPacket) -> Status {
} }
}; };
// Create a new file struct for the file we just created.
let file = File { let file = File {
name: entry.name, name: entry.name,
size: entry.size, size: entry.size,
@ -277,55 +95,22 @@ fn on_list(context: &mut Context, list: ListPacket) -> Status {
progress: 0, progress: 0,
}; };
// Add the new file to the list of files in the context.
context.files.push(file); context.files.push(file);
} }
// Set the context variables for the event loop.
context.index = 0; context.index = 0;
context.progress = 0; context.progress = 0;
context.sequence = 0; context.sequence = 0;
context.length = 0; context.length = 0;
// Return a `Status::Continue()` variant to indicate that the event loop
// should continue running and wait for more packets from the sender.
Status::Continue() Status::Continue()
} }
/// This function handles a chunk packet received from the sender.
///
/// It checks that the shared key has been established, that the sequence number
/// of the chunk matches the expected sequence number in the context, and that
/// the index of the file in the context is valid.
///
/// If any of these checks fail, an error is returned and the event loop is
/// stopped.
///
/// The function updates the length of the file, increments the sequence number
/// in the context, and writes the contents of the chunk to the file.
///
/// The progress of the file is updated to be the ratio of the number of bytes
/// read so far to the total size of the file.
///
/// If the progress of the file is 100%, or if the difference in progress between
/// this chunk and the last chunk is greater than or equal to 1, or if this is the
/// first chunk, a ProgressPacket is sent to the sender with the index of the file
/// in the context and the progress of the file.
///
/// If the size of the file has been reached, the index of the current file is
/// incremented, the length of the current file is set to 0, the progress of the
/// current file is set to 0, and the sequence number is set to 0.
///
/// Finally, a Status::Continue() variant is returned to indicate that the event
/// loop should continue running and wait for more packets from the sender.
fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status { fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status {
// Check that the shared key has been established.
if context.shared_key.is_none() { if context.shared_key.is_none() {
return Status::Err("Invalid chunk packet: no shared key established".into()); return Status::Err("Invalid chunk packet: no shared key established".into());
} }
// Check that the sequence number of the chunk matches the expected sequence
// number in the context.
if chunk.sequence != context.sequence { if chunk.sequence != context.sequence {
return Status::Err(format!( return Status::Err(format!(
"Expected sequence {}, but got {}.", "Expected sequence {}, but got {}.",
@ -333,40 +118,26 @@ fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status {
)); ));
} }
// Get a mutable reference to the file in the context at the index of the
// file.
let Some(file) = context.files.get_mut(context.index) else { let Some(file) = context.files.get_mut(context.index) else {
// If the index of the file in the context is invalid, return an error and
// stop the event loop.
return Status::Err("Invalid file index.".into()); return Status::Err("Invalid file index.".into());
}; };
// Update the length of the file.
context.length += chunk.chunk.len() as u64; context.length += chunk.chunk.len() as u64;
// Increment the sequence number in the context.
context.sequence += 1; context.sequence += 1;
// Write the contents of the chunk to the file.
file.handle.write(&chunk.chunk).unwrap(); file.handle.write(&chunk.chunk).unwrap();
// Update the progress of the file.
file.progress = (context.length * 100) / file.size; file.progress = (context.length * 100) / file.size;
// If the progress of the file is 100%, or if the difference in progress between
// this chunk and the last chunk is greater than or equal to 1, or if this is the
// first chunk, send a ProgressPacket to the sender.
if file.progress == 100 || file.progress - context.progress >= 1 || chunk.sequence == 0 { if file.progress == 100 || file.progress - context.progress >= 1 || chunk.sequence == 0 {
context.progress = file.progress; context.progress = file.progress;
let progress = ProgressPacket { let progress = ProgressPacket {
// Convert the index of the file in the context to a u32.
index: context.index.try_into().unwrap(), index: context.index.try_into().unwrap(),
// Convert the progress of the file to a u32.
progress: context.progress.try_into().unwrap(), progress: context.progress.try_into().unwrap(),
}; };
// Send the ProgressPacket to the sender.
context.sender.send_encrypted_packet( context.sender.send_encrypted_packet(
&context.shared_key, &context.shared_key,
DESTINATION, DESTINATION,
@ -377,9 +148,6 @@ fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status {
std::io::Write::flush(&mut stdout()).unwrap(); std::io::Write::flush(&mut stdout()).unwrap();
} }
// If the size of the file has been reached, increment the index of the
// current file, set the length of the current file to 0, set the progress
// of the current file to 0, and resets the sequence number to 0.
if file.size == context.length { if file.size == context.length {
context.index += 1; context.index += 1;
context.length = 0; context.length = 0;
@ -389,65 +157,39 @@ fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status {
println!(); println!();
} }
// Return a Status::Continue() variant to indicate that the event loop should
// continue running and wait for more packets from the sender.
Status::Continue() Status::Continue()
} }
/// This function is called when the Receiver receives a HandshakePacket from the
/// Sender. It verifies the signature of the Sender's public key and generates its own
/// public key. It then generates a shared secret key between the Receiver and the Sender
/// using the Diffie-Hellman key exchange.
///
/// The Receiver sends back a HandshakeResponsePacket to the Sender with its own public
/// key and a signature created using the shared secret key and its own private key.
///
/// The shared secret key is used to encrypt packets sent between the Receiver and the
/// Sender.
fn on_handshake(context: &mut Context, handshake: HandshakePacket) -> Status { fn on_handshake(context: &mut Context, handshake: HandshakePacket) -> Status {
// If a shared key has already been established, this means that the Receiver
// has already performed the handshake, so return an error.
if context.shared_key.is_some() { if context.shared_key.is_some() {
return Status::Err("Already performed handshake.".into()); return Status::Err("Already performed handshake.".into());
} }
// Create a new HMAC using the hmac from the Context struct as the key.
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap(); let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
// Update the HMAC with the public key from the HandshakePacket.
mac.update(&handshake.public_key); mac.update(&handshake.public_key);
// Call verify_slice() on the HMAC to verify the signature from the Sender.
// If the signature is invalid, return an error.
let verification = mac.verify_slice(&handshake.signature); let verification = mac.verify_slice(&handshake.signature);
if verification.is_err() { if verification.is_err() {
return Status::Err("Invalid signature from the sender.".into()); return Status::Err("Invalid signature from the sender.".into());
} }
// Generate the Receiver's public key from the private key.
let public_key = context.key.public_key().to_sec1_bytes().into_vec(); let public_key = context.key.public_key().to_sec1_bytes().into_vec();
// Create a new HMAC using the hmac from the Context struct as the key.
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap(); let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
// Update the HMAC with the public key of the Receiver.
mac.update(&public_key); mac.update(&public_key);
// Serialize the resulting HMAC into a byte array and use it as the
// signature in the HandshakeResponsePacket.
let signature = mac.finalize().into_bytes().to_vec(); let signature = mac.finalize().into_bytes().to_vec();
// Create a new shared secret key between the Receiver and the Sender.
let shared_public_key = PublicKey::from_sec1_bytes(&handshake.public_key).unwrap(); let shared_public_key = PublicKey::from_sec1_bytes(&handshake.public_key).unwrap();
let shared_secret = context.key.diffie_hellman(&shared_public_key); let shared_secret = context.key.diffie_hellman(&shared_public_key);
let shared_secret = shared_secret.raw_secret_bytes(); let shared_secret = shared_secret.raw_secret_bytes();
let shared_secret = &shared_secret[0..16]; let shared_secret = &shared_secret[0..16];
// Create a new Aes128Gcm key from the shared secret.
let shared_key: &Key<Aes128Gcm> = shared_secret.into(); let shared_key: &Key<Aes128Gcm> = shared_secret.into();
let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key); let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key);
// Create the HandshakeResponsePacket and send it to the Sender.
let handshake_response = HandshakeResponsePacket { let handshake_response = HandshakeResponsePacket {
public_key, public_key,
signature, signature,
@ -457,81 +199,55 @@ fn on_handshake(context: &mut Context, handshake: HandshakePacket) -> Status {
.sender .sender
.send_packet(DESTINATION, Value::HandshakeResponse(handshake_response)); .send_packet(DESTINATION, Value::HandshakeResponse(handshake_response));
// Store the shared key in the Context struct.
context.shared_key = Some(shared_key); context.shared_key = Some(shared_key);
// Return a Status::Continue() variant to indicate that the event loop should
// continue running and wait for more packets from the Sender.
Status::Continue() Status::Continue()
} }
/// This function is called when a message is received from the Sender.
///
/// The message can be either text or binary. If it's text, we attempt to
/// parse it as a JsonPacketResponse and match on the type of response it is.
/// If it's binary, we attempt to decrypt it using the shared key (if it
/// exists) and then decode it into a Packet. We then match on the type of
/// value in the Packet and call the appropriate function with the relevant
/// data.
///
/// If the message is not text or binary, we return a Status::Err with an
/// appropriate error message.
fn on_message(context: &mut Context, message: WebSocketMessage) -> Status { fn on_message(context: &mut Context, message: WebSocketMessage) -> Status {
if message.is_text() { match message.clone() {
let text = message.into_text().unwrap(); WebSocketMessage::Text(text) => {
let packet = serde_json::from_str(&text).unwrap(); let packet = match serde_json::from_str(&text) {
Ok(packet) => packet,
Err(_) => {
return Status::Continue();
}
};
return match packet {
JsonPacketResponse::Join { size } => on_join_room(size),
JsonPacketResponse::Leave { index } => on_leave_room(context, index),
JsonPacketResponse::Error { message } => on_error(message),
_ => Status::Err(format!("Unexpected json packet: {:?}", packet)),
};
}
WebSocketMessage::Binary(data) => {
let data = &data[1..];
return match packet { let data = if let Some(shared_key) = &context.shared_key {
JsonPacketResponse::Join { size } => on_join_room(size), let nonce = &data[..NONCE_SIZE];
JsonPacketResponse::Leave { index } => on_leave_room(context, index), let ciphertext = &data[NONCE_SIZE..];
JsonPacketResponse::Error { message } => on_error(message),
_ => Status::Err(format!("Unexpected json packet: {:?}", packet)), shared_key.decrypt(nonce.into(), ciphertext).unwrap()
}; } else {
} else if message.is_binary() { data.to_vec()
let data = message.into_data(); };
let data = &data[1..];
let data = if let Some(shared_key) = &context.shared_key { let packet = Packet::decode(data.as_ref()).unwrap();
let nonce = &data[..NONCE_SIZE]; let value = packet.value.unwrap();
let ciphertext = &data[NONCE_SIZE..]; return match value {
Value::List(list) => on_list(context, list),
shared_key.decrypt(nonce.into(), ciphertext).unwrap() Value::Chunk(chunk) => on_chunk(context, chunk),
} else { Value::Handshake(handshake) => on_handshake(context, handshake),
data.to_vec() _ => Status::Err(format!("Unexpected packet: {:?}", value)),
}; };
}
let packet = Packet::decode(data.as_ref()).unwrap(); _ => (),
let value = packet.value.unwrap();
return match value {
Value::List(list) => on_list(context, list),
Value::Chunk(chunk) => on_chunk(context, chunk),
Value::Handshake(handshake) => on_handshake(context, handshake),
_ => Status::Err(format!("Unexpected packet: {:?}", value)),
};
} }
Status::Err("Invalid message type".into()) Status::Err("Invalid message type".into())
} }
/// This function takes a websocket connection and an invite code,
/// splits the connection into an outgoing and incoming part,
/// creates a context for the connection, sends a join room packet,
/// and starts two futures to handle incoming and outgoing messages.
///
/// The outgoing future reads from a channel and sends the messages
/// through the outgoing part of the connection. If the sending fails,
/// the future will print an error and exit.
///
/// The incoming future reads from the incoming part of the connection
/// and passes the messages to the `on_message` function. If the message
/// is an exit or an error, the function will print the error and exit.
/// If the message is any other type of packet, it will be handled by the
/// `on_message` function and the future will continue running.
pub async fn start(socket: Socket, fragment: &str) { pub async fn start(socket: Socket, fragment: &str) {
// Extract the room id and hmac from the invite code
let Some(index) = fragment.rfind('-') else { let Some(index) = fragment.rfind('-') else {
println!("Error: The invite code '{}' is not valid.", fragment); println!("Error: The invite code '{}' is not valid.", fragment);
return; return;
@ -544,16 +260,12 @@ pub async fn start(socket: Socket, fragment: &str) {
return; return;
}; };
// Create a new ephemeral key pair
let key = EphemeralSecret::random(&mut OsRng); let key = EphemeralSecret::random(&mut OsRng);
// Create a channel for sending messages
let (sender, receiver) = flume::bounded(1000); let (sender, receiver) = flume::bounded(1000);
// Split the websocket connection into an outgoing and incoming part
let (outgoing, incoming) = socket.split(); let (outgoing, incoming) = socket.split();
// Create a new context for the connection
let mut context = Context { let mut context = Context {
hmac, hmac,
sender, sender,
@ -570,40 +282,32 @@ pub async fn start(socket: Socket, fragment: &str) {
println!("Attempting to join room '{}'...", id); println!("Attempting to join room '{}'...", id);
// Send a join room packet to the server
context context
.sender .sender
.send_json_packet(JsonPacket::Join { id: id.to_string() }); .send_json_packet(JsonPacket::Join { id: id.to_string() });
// Create futures for handling incoming and outgoing messages
let outgoing_handler = receiver.stream().map(Ok).forward(outgoing); let outgoing_handler = receiver.stream().map(Ok).forward(outgoing);
let incoming_handler = incoming.try_for_each(|message| { let incoming_handler = incoming.try_for_each(|message| {
// Call the on_message function to handle the message
match on_message(&mut context, message) { match on_message(&mut context, message) {
// If the message is an exit, print a message and exit
Status::Exit() => { Status::Exit() => {
context.sender.send_json_packet(JsonPacket::Leave);
println!("Transfer has completed."); println!("Transfer has completed.");
return future::err(Error::ConnectionClosed); return future::err(Error::ConnectionClosed);
} }
// If the message is an error, print the error and exit
Status::Err(error) => { Status::Err(error) => {
println!("Error: {}", error); println!("Error: {}", error);
return future::err(Error::ConnectionClosed); return future::err(Error::ConnectionClosed);
} }
// If the message is any other type of packet, do nothing
_ => {} _ => {}
}; };
// Continue running the future
future::ok(()) future::ok(())
}); });
// Pin the futures to the stack so they can run concurrently
pin_mut!(incoming_handler, outgoing_handler); pin_mut!(incoming_handler, outgoing_handler);
// Wait for either future to complete
future::select(incoming_handler, outgoing_handler).await; future::select(incoming_handler, outgoing_handler).await;
} }
#[cfg(test)] #[cfg(test)]

View file

@ -1,40 +1,3 @@
/// This module is the entry point for the receiver command.
/// It contains a single function, `start_receiver`, which is the
/// entry point for the receiver program.
///
/// The `start_receiver` function takes a `String` which is the URL or
/// invite code for the room that the receiver should join. If the
/// URL is invalid or does not contain an invite code fragment,
/// the function falls back to using the command line arguments to get
/// the file paths to be sent.
///
/// The `start_receiver` function first creates a request to connect
/// to the WebSocket server with a specific origin. This is done to
/// prevent cross-origin requests, which are not allowed by the
/// WebSocket protocol.
///
/// If creating the request succeeds, the function inserts the origin
/// into the request headers. Then, it attempts to connect to the
/// server using the `connect_async` function from the
/// `tokio_tungstenite` crate.
///
/// If the connection attempt succeeds, the function extracts the
/// invite code fragment from the URL and passes it to the `start`
/// function in the `receiver::client` module. The `start` function is
/// defined in the `receiver::client` module and is the function that
/// interacts with the server to receive files.
///
/// If the connection attempt fails or the URL does not contain an
/// invite code fragment, the function falls back to using the command
/// line arguments to get the file paths to be sent. It then calls the
/// `start` function in the `sender::client` module with the
/// WebSocket stream and the file paths. The `start` function in the
/// `sender::client` module is defined in the `sender::client`
/// module and is the function that sends the files over the
/// WebSocket connection.
///
/// The `start` function takes ownership of the WebSocket stream and
/// the file paths, so we pass them by value.
pub mod client; pub mod client;
pub mod http_client; pub mod http_client;
@ -65,13 +28,6 @@ pub async fn start_receiver(relay: &str, name: &str) {
Ok(()) => debug!("Success"), Ok(()) => debug!("Success"),
Err(e) => error!("Error: {e:?}"), Err(e) => error!("Error: {e:?}"),
}; };
// if let Err(e) = start_ws_com(res_ip.as_str(), res.local_room_id.as_str()).await {
// debug!("Failed to connect local with first room_id: {e}");
// if let Err(e) = start_ws_com(relay, res.relay_room_id.as_str()).await {
// debug!("Failed to connect remote with first room_id: {e}");
// }
// }
} }
pub async fn start_ws_com(relay: &str, name: &str) -> Result<(), Box<dyn std::error::Error>> { pub async fn start_ws_com(relay: &str, name: &str) -> Result<(), Box<dyn std::error::Error>> {
@ -81,8 +37,6 @@ pub async fn start_ws_com(relay: &str, name: &str) -> Result<(), Box<dyn std::er
return Err("Failed to create request".into()); return Err("Failed to create request".into());
}; };
// Insert the origin into the request headers to prevent
// cross-origin requests.
request request
.headers_mut() .headers_mut()
.insert("Origin", HeaderValue::from_str(relay).unwrap()); .insert("Origin", HeaderValue::from_str(relay).unwrap());
@ -105,9 +59,5 @@ pub async fn start_ws_com(relay: &str, name: &str) -> Result<(), Box<dyn std::er
Err(Box::new(e)) Err(Box::new(e))
}?, }?,
}; };
// The start function is defined in the
// receiver::client module and is the function that interacts with
// the server to receive files.
// receiver::start(socket, name).await
Ok(()) Ok(())
} }

View file

@ -4,11 +4,6 @@ use tokio::sync::RwLock;
use crate::relay::room::Room; use crate::relay::room::Room;
use crate::relay::transfer::TransferResponse; use crate::relay::transfer::TransferResponse;
/// A struct that holds all of the rooms that the server knows about.
///
/// The rooms are stored in a `HashMap` with the room ID as the key and the
/// room as the value. This means that looking up a room by its ID is an O(1)
/// operation, which is very fast.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AppState { pub struct AppState {
pub rooms: HashMap<String, Room>, pub rooms: HashMap<String, Room>,
@ -16,35 +11,8 @@ pub struct AppState {
} }
impl AppState { impl AppState {
/// Creates a new `Server` with an empty list of rooms.
///
/// The `rooms` field of the returned `Server` is an empty `HashMap`.
/// This means that the server will not have any rooms when it is first
/// created.
///
/// This function returns an `Arc<RwLock<Server>>` because the server
/// needs to be shared between different parts of the program. The
/// `Arc` makes it so that the server can be shared by multiple threads,
/// and the `RwLock` makes it so that the server can be read from and
/// written to from multiple threads at the same time.
///
/// The `Arc` and `RwLock` are both parts of the `tokio` library, which
/// provides asynchronous programming tools for Rust.
///
/// The `Arc` and `RwLock` are used together to create a Mutex-like
/// object that can be shared between threads. The main difference
/// between a Mutex and an `Arc<RwLock<T>>` is that a Mutex can only be
/// locked by one thread at a time, while an `Arc<RwLock<T>>` can be
/// locked by multiple threads at the same time.
///
/// This function is used to create a new `Server` and share it between
/// different parts of the program. The `Server` is shared because it
/// needs to be able to handle connections from multiple clients at the
/// same time.
pub fn new() -> Arc<RwLock<AppState>> { pub fn new() -> Arc<RwLock<AppState>> {
// Create a new `Server` instance.
Arc::new(RwLock::new(AppState { Arc::new(RwLock::new(AppState {
// Initialize the list of rooms to be empty.
rooms: HashMap::new(), rooms: HashMap::new(),
transfers: Vec::new(), transfers: Vec::new(),
})) }))

View file

@ -11,27 +11,6 @@ use crate::relay::ResponsePacket;
use uuid::Uuid; use uuid::Uuid;
type Sender = Arc<Mutex<SplitSink<axum::extract::ws::WebSocket, Message>>>; type Sender = Arc<Mutex<SplitSink<axum::extract::ws::WebSocket, Message>>>;
/// This struct represents a single client connection to the server.
///
/// A `Client` instance holds a `Sender` and a `room_id`.
///
/// The `Sender` is a type alias for a `tokio::sync::mpsc::Sender<Message>`.
/// It is used to send messages to the client.
///
/// The `room_id` is an `Option<String>`. It is used to keep track of which
/// room the client is currently in. If the `room_id` is `None`, then the
/// client is not in any room. If the `room_id` is `Some(id)`, where `id` is a
/// `String`, then the client is in the room with the ID `id`.
///
/// The `room_id` is used to keep track of which room the client is in so
/// that the server knows which room to send messages to. When a client
/// joins a room, their `room_id` is set to the ID of the room that they
/// joined. When a client leaves a room, their `room_id` is set to `None`.
///
/// The `Client` struct is used to keep track of which room each client is
/// in. It is used by the `Server` to determine which room to send messages
/// to.
///
#[derive(Debug)] #[derive(Debug)]
pub struct Client { pub struct Client {
sender: Sender, sender: Sender,
@ -39,22 +18,6 @@ pub struct Client {
} }
impl Client { impl Client {
/// Creates a new `Client` instance.
///
/// The `sender` argument is a `Sender` for sending messages to the client.
/// It is used by the `Server` to send messages to the client.
///
/// The `room_id` field of the `Client` instance is set to `None` initially.
/// This is because the client is not in any room when they first connect
/// to the server.
///
/// The `sender` field of the `Client` instance is used to send messages to
/// the client. When the server wants to send a message to the client, it
/// uses the `sender` to send the message.
///
/// The `Client` instance is used by the `Server` to keep track of which
/// room each client is in. It is used by the `Server` to determine which
/// room to send messages to.
pub fn new(sender: Sender) -> Client { pub fn new(sender: Sender) -> Client {
Client { Client {
sender, sender,
@ -62,21 +25,6 @@ impl Client {
} }
} }
/// Sends a message to a client.
///
/// This function takes a `sender` argument, which is a `Mutex` guard
/// for a WebSocket connection. The `sender` is used to send a message
/// to the client.
///
/// The `message` argument is the message that is sent to the client. It
/// is a WebSocket message.
///
/// This function locks the `sender` Mutex to ensure that only one thread
/// can send a message at a time. This is because the SplitSink that the
/// `sender` mutex guards is not thread-safe, and sending a message from
/// multiple threads could result in the messages being sent out of order.
///
/// If sending the message fails, this function logs an error message.
async fn send(&self, sender: Sender, message: Message) { async fn send(&self, sender: Sender, message: Message) {
let mut sender = sender.lock().await; let mut sender = sender.lock().await;
if let Err(error) = sender.send(message).await { if let Err(error) = sender.send(message).await {
@ -84,82 +32,21 @@ impl Client {
} }
} }
/// Sends a packet to a client.
///
/// This function takes a `sender` argument, which is a `Mutex` guard
/// for a WebSocket connection. The `sender` is used to send a message
/// to the client.
///
/// The `packet` argument is the packet that is sent to the client. It
/// is a struct that contains the data that is being sent.
///
/// This function serializes the `packet` using serde_json and sends it
/// to the client as a WebSocket Text message.
///
/// This function locks the `sender` Mutex to ensure that only one thread
/// can send a message at a time. This is because the SplitSink that the
/// `sender` mutex guards is not thread-safe, and sending a message from
/// multiple threads could result in the messages being sent out of order.
async fn send_packet(&self, sender: Sender, packet: ResponsePacket) { async fn send_packet(&self, sender: Sender, packet: ResponsePacket) {
let serialized_packet = serde_json::to_string(&packet).unwrap(); let serialized_packet = serde_json::to_string(&packet).unwrap();
self.send(sender, Message::Text(serialized_packet)).await; self.send(sender, Message::Text(serialized_packet)).await;
} }
/// Sends an error packet to a client.
///
/// This function takes a `sender` argument, which is a `Mutex` guard
/// for a WebSocket connection. The `sender` is used to send a message
/// to the client.
///
/// The `message` argument is the message that is sent to the client. It
/// is a string that describes the error.
///
/// This function creates an error packet with the `message` and sends it
/// to the client using the `send_packet` function.
///
/// This function locks the `sender` Mutex to ensure that only one thread
/// can send a message at a time. This is because the SplitSink that the
/// `sender` mutex guards is not thread-safe, and sending a message from
/// multiple threads could result in the messages being sent out of order.
async fn send_error_packet(&self, sender: Sender, message: String) { async fn send_error_packet(&self, sender: Sender, message: String) {
let error_packet = ResponsePacket::Error { message }; let error_packet = ResponsePacket::Error { message };
self.send_packet(sender, error_packet).await self.send_packet(sender, error_packet).await
} }
/// Handles a CreateRoom request from a client.
///
/// This function is called when a client sends a CreateRoom request to
/// the server. The server will create a new room with the specified
/// size and return the room's identifier to the client.
///
/// This function takes a `server` argument, which is a `RwLock`
/// guard for the server's state. The `server` is used to check if the
/// current client is already in a room, and to insert the new room into
/// the server's state.
///
/// If the current client is already in a room, this function returns
/// without doing anything. This is to prevent a client from being in
/// multiple rooms at the same time.
///
/// If there is already a room with the same identifier as the one that
/// is being created, this function sends an error packet to the client
/// and returns.
///
/// If there is no existing room with the same identifier, this function
/// creates a new room with the specified size and inserts it into the
/// server's state. It then sends a CreateRoom response packet to the
/// client with the room's identifier.
///
/// This function locks the `server` RwLock to ensure that only one
/// thread can access the server's state at a time. This is because the
/// server's state is not thread-safe, and accessing it from multiple
/// threads could result in undefined behavior.
async fn handle_create_room(&mut self, server: &RwLock<AppState>, id: Option<String>) { async fn handle_create_room(&mut self, server: &RwLock<AppState>, id: Option<String>) {
let mut server = server.write().await; let mut server = server.write().await;
// If the current client is already in a room, do nothing.
if server.rooms.iter().any(|(_, room)| { if server.rooms.iter().any(|(_, room)| {
room.senders room.senders
.iter() .iter()
@ -168,15 +55,12 @@ impl Client {
return; return;
} }
// Generate a new room identifier.
let size = Room::DEFAULT_ROOM_SIZE; let size = Room::DEFAULT_ROOM_SIZE;
let room_id = match id { let room_id = match id {
Some(id) => id, Some(id) => id,
None => Uuid::new_v4().to_string(), None => Uuid::new_v4().to_string(),
}; };
// If there is already a room with the same identifier, send an error
// packet to the client and return.
if server.rooms.contains_key(&room_id) { if server.rooms.contains_key(&room_id) {
drop(server); drop(server);
@ -188,46 +72,23 @@ impl Client {
.await; .await;
} }
// Create a new room with the specified size and insert it into the
// server's state.
let mut room = Room::new(size); let mut room = Room::new(size);
room.senders.push(self.sender.clone()); room.senders.push(self.sender.clone());
server.rooms.insert(room_id.clone(), room); server.rooms.insert(room_id.clone(), room);
// Set the client's room ID to the new room's identifier.
self.room_id = Some(room_id.clone()); self.room_id = Some(room_id.clone());
drop(server); drop(server);
// Send a CreateRoom response packet to the client with the room's
// identifier.
debug!("Room created"); debug!("Room created");
self.send_packet(self.sender.clone(), ResponsePacket::Create { id: room_id }) self.send_packet(self.sender.clone(), ResponsePacket::Create { id: room_id })
.await .await
} }
/// This function is called when the client sends a JoinRoom packet.
///
/// If the client is already in a room, then this function does nothing.
///
/// If the client is not in a room, then the function checks if the room
/// specified in the packet exists. If the room does not exist, an error
/// packet is sent to the client with a message indicating that the room
/// does not exist.
///
/// If the room does exist, then the function checks if the room is full.
/// If the room is full, an error packet is sent to the client with a
/// message indicating that the room is full.
///
/// If the room is not full, then the client is added to the room and the
/// function sends a JoinRoom response packet to the client with the size
/// of the room (excluding the client itself) and a `size` field set to
/// `None`. The response packet is sent to all other clients in the room.
async fn handle_join_room(&mut self, server: &RwLock<AppState>, room_id: String) { async fn handle_join_room(&mut self, server: &RwLock<AppState>, room_id: String) {
let mut server = server.write().await; let mut server = server.write().await;
// If the client is already in a room, do nothing.
if server.rooms.iter().any(|(_, room)| { if server.rooms.iter().any(|(_, room)| {
room.senders room.senders
.iter() .iter()
@ -236,8 +97,6 @@ impl Client {
return; return;
} }
// Get a mutable reference to the room specified in the packet.
// If the room does not exist, return an error to the client.
let Some(room) = server.rooms.get_mut(&room_id) else { let Some(room) = server.rooms.get_mut(&room_id) else {
drop(server); drop(server);
@ -246,7 +105,6 @@ impl Client {
.await; .await;
}; };
// If the room is full, return an error to the client.
if room.senders.len() >= room.size { if room.senders.len() >= room.size {
drop(server); drop(server);
@ -255,16 +113,9 @@ impl Client {
.await; .await;
} }
// Add the client to the room and set the client's room ID to the new
// room's identifier.
room.senders.push(self.sender.clone()); room.senders.push(self.sender.clone());
self.room_id = Some(room_id); self.room_id = Some(room_id);
// Create a list of futures to send JoinRoom response packets to all
// other clients in the room. The `size` field of the response packet is
// set to `None` if the client sending the packet is the one joining the
// room. Otherwise, the `size` field is set to the number of clients in
// the room minus one (to exclude the client joining the room).
let mut futures = vec![]; let mut futures = vec![];
for sender in &room.senders { for sender in &room.senders {
if Arc::ptr_eq(sender, &self.sender) { if Arc::ptr_eq(sender, &self.sender) {
@ -283,120 +134,43 @@ impl Client {
join_all(futures).await; join_all(futures).await;
} }
/// Handles a request to leave a room.
///
/// This function is called when a client sends a `LeaveRoom` request
/// packet. The function obtains a write lock on the server's state and
/// does the following:
///
/// 1. Gets the room ID of the client who sent the request. If the client is
/// not in a room, the function returns early.
/// 2. Tries to get a mutable reference to the room with the obtained room
/// ID. If the room does not exist, the function returns early.
/// 3. Finds the index of the client's sender in the room's list of senders.
/// If the client is not in the room, the function returns early.
/// 4. Removes the client's sender from the room's list of senders.
/// 5. Sets the client's room ID to `None`.
/// 6. Creates a list of futures to send `LeaveRoom` response packets to
/// all other clients in the room. The `index` field of the response
/// packet is set to the index of the client's sender in the room's list
/// of senders.
/// 7. If the room is now empty, removes the room from the server's list
/// of rooms.
/// 8. Drops the write lock on the server's state.
/// 9. Waits for all futures to complete.
async fn handle_leave_room(&mut self, server: &RwLock<AppState>) { async fn handle_leave_room(&mut self, server: &RwLock<AppState>) {
// Obtain a write lock on the server's state.
let mut server = server.write().await; let mut server = server.write().await;
// Get the room ID of the client who sent the request.
let Some(room_id) = self.room_id.clone() else { let Some(room_id) = self.room_id.clone() else {
// If the client is not in a room, return early.
return; return;
}; };
// Try to get a mutable reference to the room with the obtained room ID.
let Some(room) = server.rooms.get_mut(&room_id) else { let Some(room) = server.rooms.get_mut(&room_id) else {
// If the room does not exist, return early.
return; return;
}; };
// Find the index of the client's sender in the room's list of senders.
let Some(index) = room let Some(index) = room
.senders .senders
.iter() .iter()
.position(|sender| Arc::ptr_eq(sender, &self.sender)) .position(|sender| Arc::ptr_eq(sender, &self.sender))
else { else {
// If the client is not in the room, return early.
return; return;
}; };
// Remove the client's sender from the room's list of senders.
room.senders.remove(index); room.senders.remove(index);
// Set the client's room ID to `None`.
self.room_id = None; self.room_id = None;
// Create a list of futures to send `LeaveRoom` response packets to
// all other clients in the room. The `index` field of the response
// packet is set to the index of the client's sender in the room's list
// of senders.
let mut futures = vec![]; let mut futures = vec![];
for sender in &room.senders { for sender in &room.senders {
futures.push(self.send_packet(sender.clone(), ResponsePacket::Leave { index })); futures.push(self.send_packet(sender.clone(), ResponsePacket::Leave { index }));
} }
// If the room is now empty, removes the room from the server's list
// of rooms.
if room.senders.is_empty() { if room.senders.is_empty() {
server.rooms.remove(&room_id); server.rooms.remove(&room_id);
} }
// Drop the write lock on the server's state.
drop(server); drop(server);
// Wait for all futures to complete.
join_all(futures).await; join_all(futures).await;
} }
/// This function handles an incoming message from a client.
///
/// The message can be one of four types: `Text`, `Binary`, `Ping`, or `Close`.
///
/// If the message is `Text`, the function parses the message as a `RequestPacket` and
/// calls the appropriate function to handle the request. If the message cannot be
/// parsed as a `RequestPacket`, the function does nothing and returns early.
///
/// If the message is `Binary`, the function first acquires a read lock on the server's
/// state. If the client is not currently in a room, the function drops the read lock and
/// returns early. If the client is not in a room, or if the room does not exist, the
/// function drops the read lock and returns early.
///
/// The function then finds the index of the client's sender in the room's list of
/// senders. If the client's sender is not in the room's list of senders, the function
/// drops the read lock and returns early.
///
/// The function then gets the binary data from the message and sets the first byte to
/// the index of the client's sender in the room's list of senders. If there is no
/// binary data in the message, the function drops the read lock and returns early.
///
/// The function then determines where to send the message. If the first byte of the
/// message is less than the number of clients in the room, the function sends the message
/// to the client at that index in the room's list of senders. If the first byte of the
/// message is equal to the number of clients in the room plus one, the function sends the
/// message to all clients in the room, excluding the client that sent the message.
///
/// If the first byte of the message is any other value, the function drops the read
/// lock and returns early.
///
/// Finally, the function drops the read lock and waits for all futures to complete.
///
/// If the message is `Ping`, the function prints a message to stdout.
///
/// If the message is `Pong`, the function prints a message to stdout.
///
/// If the message is `Close`, the function prints a message to stdout and calls the
/// `handle_close` function.
pub async fn handle_message(&mut self, server: &RwLock<AppState>, message: Message) { pub async fn handle_message(&mut self, server: &RwLock<AppState>, message: Message) {
match message { match message {
Message::Text(text) => { Message::Text(text) => {
@ -411,22 +185,18 @@ impl Client {
} }
} }
Message::Binary(_) => { Message::Binary(_) => {
// Acquire a read lock on the server's state.
let server = server.read().await; let server = server.read().await;
// If the client is not currently in a room, return early.
let Some(room_id) = &self.room_id else { let Some(room_id) = &self.room_id else {
drop(server); drop(server);
return; return;
}; };
// If the room does not exist, return early.
let Some(room) = server.rooms.get(room_id) else { let Some(room) = server.rooms.get(room_id) else {
drop(server); drop(server);
return; return;
}; };
// Find the index of the client's sender in the room's list of senders.
let Some(index) = room let Some(index) = room
.senders .senders
.iter() .iter()
@ -436,8 +206,6 @@ impl Client {
return; return;
}; };
// Get the binary data from the message and set the first byte to
// the index of the client's sender in the room's list of senders.
let mut data = message.into_data(); let mut data = message.into_data();
if data.is_empty() { if data.is_empty() {
drop(server); drop(server);
@ -446,12 +214,9 @@ impl Client {
let source = u8::try_from(index).unwrap(); let source = u8::try_from(index).unwrap();
// Determine where to send the message.
let destination = usize::from(data[0]); let destination = usize::from(data[0]);
data[0] = source; data[0] = source;
// Send the message to the client at the destination index in the
// room's list of senders.
if destination < room.senders.len() { if destination < room.senders.len() {
let sender = room.senders[destination].clone(); let sender = room.senders[destination].clone();
@ -459,8 +224,6 @@ impl Client {
return self.send(sender, Message::Binary(data)).await; return self.send(sender, Message::Binary(data)).await;
} }
// Send the message to all clients in the room, excluding the
// client that sent the message.
if destination == usize::from(u8::MAX) { if destination == usize::from(u8::MAX) {
let mut futures = vec![]; let mut futures = vec![];
for sender in &room.senders { for sender in &room.senders {

View file

@ -8,15 +8,6 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
// This enum is used to represent the different types of requests that a client
// can send to the server.
//
// The requests that a client can send are:
//
// * Join: A request to join a room. The request contains the ID of the room
// that the client wants to join.
// * Create: A request to create a new room.
// * Leave: A request to leave the current room.
pub enum RequestPacket { pub enum RequestPacket {
Join { Join {
// The ID of the room that the client wants to join. // The ID of the room that the client wants to join.
@ -28,45 +19,20 @@ pub enum RequestPacket {
Leave, Leave,
} }
/// This enum is used to represent the different types of responses that the
/// server can send to the client.
///
/// The responses that the server can send are:
///
/// * Join: A response to a `Join` request from the client. If the client
/// successfully joined a room, the `size` field will be `Some` and contain
/// the size of the room. If the client could not join a room, the `size` field
/// will be `None`.
/// * Create: A response to a `Create` request from the client. If the server
/// successfully created a room, the `id` field will contain the ID of the
/// room. If the server could not create a room, the `id` field will be empty.
/// * Leave: A response to a `Leave` request from the client. If the client
/// successfully left a room, the `index` field will contain the index of the
/// client that left the room. If the client could not leave a room, the
/// `index` field will be 0.
/// * Error: A response to indicate that an error occurred. The `message`
/// field will contain a description of the error.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
pub enum ResponsePacket { pub enum ResponsePacket {
Join { Join {
/// The size of the room that the client joined. If the client could
/// not join a room, this field will be `None`.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
size: Option<usize>, size: Option<usize>,
}, },
Create { Create {
/// The ID of the room that the server created. If the server could
/// not create a room, this field will be empty.
id: String, id: String,
}, },
Leave { Leave {
/// The index of the client that left the room. If the client could not
/// leave a room, this field will be 0.
index: usize, index: usize,
}, },
Error { Error {
/// A description of the error that occurred.
message: String, message: String,
}, },
} }

View file

@ -3,44 +3,8 @@ use futures_util::stream::SplitSink;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
// A type alias for a sender to a WebSocket connection.
//
// The sender is a mutex-guarded, split sink of a WebSocket stream and Message
// values. It is used to send messages to a client.
//
// The Mutex is used to ensure that only one thread can send a message at a
// time. This is because the SplitSink is not thread-safe, and sending a
// message from multiple threads could result in the messages being sent
// out of order.
//
// The SplitSink is used to send messages to a client. It is the part of the
// WebSocket stream that handles the sending of messages.
//
// The WebSocket stream is the underlying connection to the client. It is used
// to send and receive messages.
//
// The Message value is the type of data that is sent over the WebSocket
// connection. It is a struct that contains the data that is being sent.
//
// The type alias is used so that the type is not mentioned every time it is
// used. This makes the code easier to read and understand.
type Sender = Arc<Mutex<SplitSink<WebSocket, Message>>>; type Sender = Arc<Mutex<SplitSink<WebSocket, Message>>>;
/// A `Room` is a collection of clients that are connected to each other.
///
/// Each room has a set of clients, represented by a `Vec` of `Sender`
/// instances. The `Sender` instances are used to send messages to the
/// clients in the room.
///
/// The `senders` field is the list of senders that are connected to each
/// other. Each sender is a mutex-guarded, split sink of a WebSocket
/// stream and Message values. This is explained in more detail in the
/// documentation for the `Sender` type alias in the `packets` module.
///
/// The `size` field is the maximum number of clients that a room can have.
/// When a room reaches its maximum size, no more clients can join the room.
/// This is used to prevent rooms from getting too full and causing the
/// server to run out of memory.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Room { pub struct Room {
pub senders: Vec<Sender>, pub senders: Vec<Sender>,
@ -48,25 +12,11 @@ pub struct Room {
} }
impl Room { impl Room {
/// The default size of a room.
///
/// This is the size that a room will have when it is created.
pub const DEFAULT_ROOM_SIZE: usize = 2; pub const DEFAULT_ROOM_SIZE: usize = 2;
/// Creates a new `Room` with the given size.
///
/// The `size` parameter is the maximum number of clients that can join the
/// room. If `size` is 0, then the room will not be able to hold any
/// clients.
///
/// The `senders` field of the returned `Room` is an empty vector.
///
/// The `size` field of the returned `Room` is `size`.
pub fn new(size: usize) -> Room { pub fn new(size: usize) -> Room {
Room { Room {
// Initialize the list of senders to be empty.
senders: Vec::new(), senders: Vec::new(),
// Set the size of the room.
size, size,
} }
} }

View file

@ -1,20 +1,3 @@
/// This function starts the WebSocket server.
///
/// It configures the server to listen on the specified host and port. If
/// these values are not specified in the environment, it falls back to using
/// the defaults of "0.0.0.0" for the host and "8000" for the port.
///
/// It then sets up the application routes for the server. In this case, the
/// only route is for the WebSocket connection.
///
/// The WebSocket route requires a `ConnectInfo` extractor to get the client's
/// IP address, which is then used to store the client in a data structure
/// keyed by their IP address. This allows for efficient lookup of clients by
/// their IP address.
///
/// Finally, it starts the server by binding to the specified host and port,
/// and running the application. If the server fails to bind to the specified
/// host and port, it logs an error and exits.
use axum::{ use axum::{
extract::{ws::WebSocket, Json, Path, State, WebSocketUpgrade}, extract::{ws::WebSocket, Json, Path, State, WebSocketUpgrade},
http::StatusCode, http::StatusCode,
@ -25,7 +8,7 @@ use axum::{
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::json; use serde_json::json;
use std::{env, net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use tokio::{ use tokio::{
net::TcpListener, net::TcpListener,
signal, signal,
@ -38,64 +21,15 @@ use crate::relay::client::Client;
use crate::relay::transfer::TransferResponse; use crate::relay::transfer::TransferResponse;
use crate::relay::{appstate::AppState, transfer::TransferRequest}; use crate::relay::{appstate::AppState, transfer::TransferRequest};
/// This function starts the WebSocket server. pub async fn start_ws(port: &i32, listen_addr: &String) {
/// let app_host = listen_addr;
/// It retrieves the environment variables that define how the server should let app_port = port;
/// be configured. If any of these variables are not defined, it sets a
/// reasonable default value.
///
/// The environment variables are:
///
/// * `APP_ENVIRONMENT`: the environment the server is running in (defaults
/// to "development").
/// * `APP_HOST`: the host the server should listen on (defaults to "0.0.0.0").
/// * `APP_PORT`: the port the server should listen on (defaults to "8000").
/// * `APP_DOMAIN`: the domain the server is accessible at (defaults to "").
///
/// It then sets up the application routes for the server. In this case, the
/// only route is for the WebSocket connection.
///
/// The WebSocket route requires a `ConnectInfo` extractor to get the client's
/// IP address, which is then used to store the client in a data structure
/// keyed by their IP address. This allows for efficient lookup of clients by
/// their IP address.
///
/// Finally, it starts the server by binding to the specified host and port,
/// and running the application. If the server fails to bind to the specified
/// host and port, it logs an error and exits.
pub async fn start_ws(port: Option<&i32>, listen_addr: Option<&String>) {
// Retrieve environment variables and set defaults if necessary.
let app_environemt = env::var("APP_ENVIRONMENT").unwrap_or("development".to_string());
let app_host = match listen_addr {
Some(address) => address.to_string(),
None => env::var("APP_HOST").unwrap_or("0.0.0.0".to_string()),
};
let app_port = match port {
Some(port) => port.to_string(),
None => env::var("APP_PORT").unwrap_or("8000".to_string()),
};
// Log information about the server's configuration.
debug!("Server configured to accept connections on host {app_host}...",); debug!("Server configured to accept connections on host {app_host}...",);
debug!("Server configured to listen connections on port {app_port}...",); debug!("Server configured to listen connections on port {app_port}...",);
// Based on the environment variable, set the logging level.
match app_environemt.as_str() {
"development" => {
debug!("Running in development mode");
}
"production" => {
debug!("Running in production mode");
}
_ => {
debug!("Running in development mode");
}
}
// Create a new server data structure.
let server = AppState::new(); let server = AppState::new();
// Set up the application routes.
let app = Router::new() let app = Router::new()
.route("/ws", get(ws_handler)) .route("/ws", get(ws_handler))
.route("/upload", put(upload_info)) .route("/upload", put(upload_info))
@ -107,12 +41,9 @@ pub async fn start_ws(port: Option<&i32>, listen_addr: Option<&String>) {
.make_span_with(DefaultMakeSpan::default().include_headers(true)), .make_span_with(DefaultMakeSpan::default().include_headers(true)),
); );
// Attempt to bind to the specified host and port.
if let Ok(listener) = TcpListener::bind(&format!("{}:{}", app_host, app_port)).await { if let Ok(listener) = TcpListener::bind(&format!("{}:{}", app_host, app_port)).await {
// Log successful binding.
info!("Listening on: {}", listener.local_addr().unwrap()); info!("Listening on: {}", listener.local_addr().unwrap());
// Run the server.
axum::serve( axum::serve(
listener, listener,
app.into_make_service_with_connect_info::<SocketAddr>(), app.into_make_service_with_connect_info::<SocketAddr>(),
@ -121,92 +52,19 @@ pub async fn start_ws(port: Option<&i32>, listen_addr: Option<&String>) {
.await .await
.unwrap(); .unwrap();
} else { } else {
// Log binding failure and exit.
error!("Failed to listen on: {}:{}", app_host, app_port); error!("Failed to listen on: {}:{}", app_host, app_port);
} }
} }
/// This function is an endpoint for the WebSocket route.
///
/// This function is called whenever a client makes a WebSocket request to
/// the `/ws` endpoint.
///
/// The function takes two arguments:
///
/// - `ws`: This is the WebSocketUpgrade object, which is used to upgrade the
/// HTTP connection to a WebSocket connection.
/// - `State(shared_state)`: This is the state of the server, which is stored
/// in a read-write lock. The state is shared between all WebSocket
/// connections.
/// - `ConnectInfo(addr)`: This is the information about the client that
/// connected to the server. The function uses this information to log the
/// address of the client that connected to the server.
///
/// The function upgrades the HTTP connection to a WebSocket connection using
/// the `ws` argument. It then passes the upgraded WebSocket connection, along
/// with the state of the server, to the `handle_socket` function.
///
/// The `handle_socket` function is defined in the `src/relay/mod.rs` file. It
/// is the function that handles the WebSocket connection.
///
/// The `handle_socket` function takes three arguments:
///
/// - `socket`: This is the WebSocket connection that it should handle.
/// - `who`: This is the address of the client that connected to the server.
/// - `rooms`: This is the state of the server, which is stored in a read-write
/// lock. The state is shared between all WebSocket connections.
///
/// The `handle_socket` function handles the WebSocket connection by calling
/// the `handle_message` function on a `Client` object that it creates. The
/// `handle_message` function is defined in the `src/relay/client.rs` file. The
/// `handle_message` function handles incoming messages from the client and
/// takes care of sending the appropriate response back to the client.
pub async fn ws_handler( pub async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
State(shared_state): State<Arc<RwLock<AppState>>>, State(shared_state): State<Arc<RwLock<AppState>>>,
// ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse { ) -> impl IntoResponse {
debug!("Got Request on Websocket route"); debug!("Got Request on Websocket route");
// debug!("WebSocket connection established from:{}", addr.to_string());
debug!("Upgrading Connection"); debug!("Upgrading Connection");
ws.on_upgrade(move |socket| handle_socket(socket, shared_state)) ws.on_upgrade(move |socket| handle_socket(socket, shared_state))
} }
/// This function is called when a new WebSocket connection is established.
/// The function takes three arguments:
///
/// - `socket`: This is the WebSocket connection that it should handle.
/// - `who`: This is the address of the client that connected to the server.
/// - `rooms`: This is the state of the server, which is stored in a read-write
/// lock. The state is shared between all WebSocket connections.
///
/// The function creates a `Client` object, which will handle the WebSocket
/// connection. The `Client` object is created with an Arc-wrapped Mutex
/// containing the `sender` of the WebSocket connection. The `sender` is used to
/// send messages to the client.
///
/// The function then creates a new `split` of the WebSocket connection, which
/// is a pair of a `sender` and a `receiver`. The `sender` is used to send
/// messages to the client, and the `receiver` is used to receive messages from
/// the client. The `receiver` is wrapped in a `Stream` (which is an async
/// iterator) so that the function can use the `next` method to receive messages
/// from the client.
///
/// The function then enters a loop that receives incoming messages from the
/// client and handles them. For each received message, the function calls the
/// `handle_message` method on the `Client` object that it created. The
/// `handle_message` method is defined in the `src/relay/client.rs` file. The
/// `handle_message` method handles incoming messages from the client and
/// takes care of sending the appropriate response back to the client.
///
/// If the function encounters an error while reading a message from the
/// client, it logs the error and breaks out of the loop.
///
/// After the loop finishes (either because an error occurred or because the
/// client disconnected), the function calls the `handle_close` method on the
/// `Client` object that it created. The `handle_close` method is defined in the
/// `src/relay/client.rs` file. The `handle_close` method handles the close event
/// from the client.
async fn handle_socket(socket: WebSocket, rooms: Arc<RwLock<AppState>>) { async fn handle_socket(socket: WebSocket, rooms: Arc<RwLock<AppState>>) {
let (sender, mut receiver) = socket.split(); let (sender, mut receiver) = socket.split();
@ -227,32 +85,13 @@ async fn handle_socket(socket: WebSocket, rooms: Arc<RwLock<AppState>>) {
client.handle_close(&rooms).await client.handle_close(&rooms).await
} }
/// This function sets up a signal handler for SIGINT (Ctrl+C) and SIGTERM
/// (terminate) on Unix platforms. It does nothing on non-Unix platforms.
///
/// The function installs two signal handlers: one for SIGINT and one for
/// SIGTERM. When either of these signals is received, the signal handler
/// simply resolves the future with `()`. This allows the main function to
/// wait for the signal handler to trigger a shutdown.
///
/// The function uses the `tokio::select!` macro to wait for either of the
/// signal handlers to resolve. When the future returned by `tokio::select!`
/// resolves, the function simply drops the value and does nothing else.
///
/// The function does not actually do anything itself. It simply waits for
/// one of the signal handlers to trigger a shutdown.
async fn shutdown_signal() { async fn shutdown_signal() {
// Install a signal handler for SIGINT (Ctrl+C). This future resolves
// when the user presses Ctrl+C.
let ctrl_c = async { let ctrl_c = async {
signal::ctrl_c() signal::ctrl_c()
.await .await
.expect("failed to install Ctrl+C handler"); .expect("failed to install Ctrl+C handler");
}; };
// Install a signal handler for SIGTERM (terminate). This future
// resolves when the operating system sends a SIGTERM signal to the
// program.
#[cfg(unix)] #[cfg(unix)]
let terminate = async { let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate()) signal::unix::signal(signal::unix::SignalKind::terminate())
@ -261,30 +100,19 @@ async fn shutdown_signal() {
.await; .await;
}; };
// If we are not on a Unix platform, we don't need to install a signal
// handler for SIGTERM. Instead, we create a future that never resolves.
#[cfg(not(unix))] #[cfg(not(unix))]
let terminate = std::future::pending::<()>(); let terminate = std::future::pending::<()>();
// Wait for either of the two signal handlers to resolve. When one of them
// resolves, the other one may still be waiting, but it doesn't matter
// because we don't need to do anything else.
tokio::select! { tokio::select! {
// If the Ctrl+C signal handler resolves, drop the value and do
// nothing else.
_ = ctrl_c => {}, _ = ctrl_c => {},
// If the terminate signal handler resolves, drop the value and do
// nothing else.
_ = terminate => {}, _ = terminate => {},
} }
} }
pub async fn upload_info( pub async fn upload_info(
State(shared_state): State<Arc<RwLock<AppState>>>, State(shared_state): State<Arc<RwLock<AppState>>>,
// ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(payload): Json<TransferRequest>, Json(payload): Json<TransferRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
// debug!("Got upload request from {}", addr.ip().to_string());
let mut data = shared_state.write().await; let mut data = shared_state.write().await;
match data match data
.transfers .transfers

View file

@ -31,94 +31,22 @@ const NONCE_SIZE: usize = 12;
const MAX_CHUNK_SIZE: isize = u16::MAX as isize; const MAX_CHUNK_SIZE: isize = u16::MAX as isize;
const DELAY: Duration = Duration::from_millis(750); const DELAY: Duration = Duration::from_millis(750);
/// A file that is to be sent.
///
/// This structure contains all the information about a file that is to be
/// sent. It is used to keep track of the files that a user wants to send.
#[derive(Clone)] #[derive(Clone)]
struct File { struct File {
/// The path to the file on the file system.
///
/// This is the path to the file on the user's file system. The path is
/// used to open the file and read its contents.
path: String, path: String,
/// The name of the file.
///
/// This is the name that the file will have when it is received by the
/// receiver. This name is used when creating the file on the receiver's
/// file system.
name: String, name: String,
/// The size of the file in bytes.
///
/// This is the size of the file in bytes. The size is used to calculate
/// the number of chunks that the file will be split into, and is also
/// used to keep track of the progress of the file being sent.
size: u64, size: u64,
} }
/// The context for the sender.
///
/// This structure contains all the information that the sender needs in order
/// to function properly. It is used to keep track of the state of the
/// sender, and to pass information between functions.
struct Context { struct Context {
/// The HMAC key for the sender.
///
/// This is the key that is used to sign packets. The key is also used to
/// generate a URL that the receiver can use to join the session.
hmac: Vec<u8>, hmac: Vec<u8>,
/// The sender that is used to send packets to the receiver.
///
/// This sender is used to send handshake packets, list packets, chunk
/// packets, and progress packets to the receiver.
sender: Sender, sender: Sender,
/// The ephemeral keypair that is used to establish a shared key with the
/// receiver.
///
/// This key is used to establish a shared key between the sender and
/// receiver. The key is ephemeral, meaning that it is only used once in
/// the session. The key is generated when the sender is created, and is
/// then discarded after the session is complete.
key: EphemeralSecret, key: EphemeralSecret,
/// The files that the sender wants to send.
///
/// This vec contains all the information about the files that the sender
/// wants to send. The vec is filled when the user specifies the files to
/// send using the command line arguments.
files: Vec<File>, files: Vec<File>,
/// The shared key that is used to encrypt packets.
///
/// This value is set to `None` initially, and is set to `Some` when the
/// shared key is established with the receiver. The shared key is used to
/// encrypt packets that are sent to the receiver.
shared_key: Option<Aes128Gcm>, shared_key: Option<Aes128Gcm>,
/// The task that is running in the background to send chunks of files to
/// the receiver.
///
/// This task is created when the sender is created, and is used to send
/// chunks of files to the receiver in the background. The task is
/// initially set to `None`, but is set to `Some` when the task is spawned.
/// The task is used to cancel the background task when the sender is
/// dropped.
task: Option<JoinHandle<()>>, task: Option<JoinHandle<()>>,
} }
/// This function is called when the client receives a create room packet
/// from the server. The function is responsible for printing a URL to the
/// console that the user can use to join the room.
///
/// The function first generates a base64 string from the hmac value that is
/// used to verify the integrity of the room. The base64 string is then
/// appended to the room id to create a URL. The URL is then printed to the
/// console using the qr2term library. Finally, the function prints a
/// message to the console with the URL.
fn on_create_room( fn on_create_room(
context: &Context, context: &Context,
id: String, id: String,
@ -130,7 +58,6 @@ fn on_create_room(
let base64 = general_purpose::STANDARD.encode(&context.hmac); let base64 = general_purpose::STANDARD.encode(&context.hmac);
let url = format!("{}-{}", id, base64); let url = format!("{}-{}", id, base64);
// let rand_name = generate_random_name();
let hash_name = hash_random_name(transfer_name.clone()); let hash_name = hash_random_name(transfer_name.clone());
let send_url = url.to_string(); let send_url = url.to_string();
@ -146,8 +73,6 @@ fn on_create_room(
.join() .join()
.unwrap(); .unwrap();
debug!("Got Result: {:?}", res); debug!("Got Result: {:?}", res);
// Print a newline to the console to separate the output from the command
// line.
match res { match res {
Ok(transfer_response) => { Ok(transfer_response) => {
if !transfer_response.local_room_id.is_empty() if !transfer_response.local_room_id.is_empty()
@ -155,20 +80,11 @@ fn on_create_room(
{ {
println!(); println!();
// Try to generate a QR code from the URL. If the function fails for some
// reason, print an error message to the console.
// if let Err(error) = qr2term::print_qr(&url) {
// error!("Failed to generate QR code: {}", error);
// }
if let Err(error) = qr2term::print_qr(&transfer_name) { if let Err(error) = qr2term::print_qr(&transfer_name) {
error!("Failed to generate QR code: {}", error); error!("Failed to generate QR code: {}", error);
} }
// Print a newline to the console to separate the output from the command
// line.
println!(); println!();
// Print a message to the console with the URL.
println!("Created room: {}", url); println!("Created room: {}", url);
println!("Transfername is: {}", transfer_name); println!("Transfername is: {}", transfer_name);
} }
@ -178,43 +94,21 @@ fn on_create_room(
} }
} }
// Continue the event loop.
Status::Continue() Status::Continue()
} }
/// This function is called when the client receives a join room packet from
/// the server. The function is responsible for sending a handshake packet to
/// the server containing the client's public key and a signature generated
/// using the client's private key and the room's hmac value.
///
/// The function first generates the client's public key from the private key.
/// The public key is then serialized into a byte array.
///
/// Next, the function creates a HMAC object with the room's hmac value and
/// updates it with the serialized public key. The resulting HMAC is then
/// serialized into a byte array and used as the signature in the handshake
/// packet.
///
/// Finally, the function sends the handshake packet to the server using the
/// sender object.
fn on_join_room(context: &Context, size: Option<usize>) -> Status { fn on_join_room(context: &Context, size: Option<usize>) -> Status {
if size.is_some() { if size.is_some() {
return Status::Err("Invalid join room packet.".into()); return Status::Err("Invalid join room packet.".into());
} }
// Generate the client's public key from the private key.
let public_key = context.key.public_key().to_sec1_bytes().into_vec(); let public_key = context.key.public_key().to_sec1_bytes().into_vec();
// Create a HMAC object with the room's hmac value and update
// it with the serialized public key.
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap(); let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
mac.update(&public_key); mac.update(&public_key);
// Serialize the resulting HMAC into a byte array and use it as the
// signature in the handshake packet.
let signature = mac.finalize().into_bytes().to_vec(); let signature = mac.finalize().into_bytes().to_vec();
// Create the handshake packet and send it to the server.
let handshake = HandshakePacket { let handshake = HandshakePacket {
public_key, public_key,
signature, signature,
@ -227,106 +121,27 @@ fn on_join_room(context: &Context, size: Option<usize>) -> Status {
Status::Continue() Status::Continue()
} }
/// This function is called when an error packet is received from the
/// server. It creates a `Status::Err` variant containing the error
/// message from the server and returns it to be handled by the main
/// event loop.
///
/// When an error occurs, the server sends an error packet to the
/// client. The error packet contains a message with a description of
/// the error. This function extracts that message and creates a
/// `Status::Err` variant with it, which is then returned to be handled
/// by the main event loop.
///
/// The main event loop checks the status of the client and performs
/// the necessary actions based on its value. If the status is
/// `Status::Err`, the event loop exits with an error message
/// containing the error message from the server.
///
/// This function is called from the event loop when an error packet is
/// received from the server.
fn on_error(message: String) -> Status { fn on_error(message: String) -> Status {
Status::Err(message) Status::Err(message)
} }
/// This function is called when the server sends a leave room packet to
/// the client. It is responsible for aborting the file transfer task,
/// generating a new ECDH key pair for the next handshake, and setting the
/// shared key to `None`.
///
/// When the server sends a leave room packet to the client, it means that
/// the receiver has disconnected from the room. In this case, the client
/// should abort the file transfer task and print an error message to the
/// user.
///
/// If the client is currently transferring files, it should abort the task
/// by calling `AbortHandle::abort` on the task handle.
///
/// After that, the client should generate a new ECDH key pair using the
/// `EphemeralSecret::random` function from the `p256` crate. This key pair
/// will be used for the next handshake with the server.
///
/// Finally, the client should set the shared key to `None` to indicate that
/// there is no shared key established for the current room.
///
/// This function is called from the event loop when a leave room packet is
/// received from the server.
fn on_leave_room(context: &mut Context, _: usize) -> Status { fn on_leave_room(context: &mut Context, _: usize) -> Status {
if let Some(task) = &context.task { if let Some(task) = &context.task {
// If the client is currently transferring files, abort the task
// by calling `AbortHandle::abort` on the task handle.
task.abort(); task.abort();
} }
// Generate a new ECDH key pair for the next handshake.
context.key = EphemeralSecret::random(&mut OsRng); context.key = EphemeralSecret::random(&mut OsRng);
// Set the shared key to `None` to indicate that there is no shared key
// established for the current room.
context.shared_key = None; context.shared_key = None;
// Set the task handle to `None` to indicate that there is no task
// running.
context.task = None; context.task = None;
// Print an error message to the user indicating that the transfer was
// interrupted because the receiver disconnected.
println!(); println!();
error!("Transfer was interrupted because the receiver disconnected."); error!("Transfer was interrupted because the receiver disconnected.");
// Continue the event loop.
Status::Continue() Status::Continue()
} }
/// This function is called by the event loop when a progress packet is
/// received from the server.
///
/// The progress packet contains the index of the file that is being
/// transferred and the current progress of that file as a percentage.
///
/// If the client does not have a shared key established with the server,
/// the function returns an error and does not continue. This indicates
/// that the event loop should exit with an error message.
///
/// The function then retrieves the file at the index specified by the
/// progress packet from the context. If the index is out of bounds, the
/// function returns an error and does not continue. This indicates that
/// the event loop should exit with an error message.
///
/// The function then prints a message to the console indicating which file
/// is currently being transferred and what its progress is. The progress
/// message is printed to the same line as a carriage return (`\r`) so that
/// it overwrites the previous message.
///
/// If the progress of the file is 100%, the function prints a newline
/// (`\n`) to the console to move the cursor to the next line.
///
/// If the progress of the last file is 100%, the function returns
/// `Status::Exit()`. This indicates that the event loop should exit
/// successfully.
///
/// If any other condition is met, the function returns `Status::Continue()`.
/// This indicates that the event loop should continue running.
fn on_progress(context: &Context, progress: ProgressPacket) -> Status { fn on_progress(context: &Context, progress: ProgressPacket) -> Status {
if context.shared_key.is_none() { if context.shared_key.is_none() {
return Status::Err("Invalid progress packet: no shared key established".into()); return Status::Err("Invalid progress packet: no shared key established".into());
@ -351,50 +166,12 @@ fn on_progress(context: &Context, progress: ProgressPacket) -> Status {
Status::Continue() Status::Continue()
} }
/// This function reads a file in chunks, sends each chunk to the receiver over
/// the WebSocket connection, and then sleeps for a short amount of time
/// before sending the next chunk.
///
/// The function takes the sender, the shared key, and a vector of files to
/// transfer as arguments.
///
/// For each file in the vector of files, the function reads the file in
/// chunks, sends each chunk to the receiver over the WebSocket connection,
/// and then sleeps for a short amount of time before sending the next chunk.
///
/// The chunk size is set to the maximum chunk size. If the number of bytes
/// left to read in the file is less than the chunk size, the chunk size is set
/// to the number of bytes left to read.
///
/// The function opens the file for reading using the tokio::fs::File::open
/// function. If there is an error opening the file, the function prints an
/// error message to the console and returns.
///
/// The function reads the file in chunks using the read_exact function from
/// the tokio::io::AsyncReadExt trait. If there is an error reading from the
/// file, the function prints an error message to the console and returns.
///
/// The function sends each chunk to the receiver over the WebSocket
/// connection using the send_encrypted_packet function from the Sender struct.
/// The function also increments the sequence number for each chunk that is
/// sent.
///
/// After sending all of the chunks for a file, the function sleeps for a short
/// amount of time using the tokio::time::sleep function. This helps to prevent
/// the sender from overwhelming the receiver with too many messages.
///
/// The function repeats this process for all of the files in the vector of
/// files.
async fn on_chunk(sender: Sender, shared_key: Option<Aes128Gcm>, files: Vec<File>) { async fn on_chunk(sender: Sender, shared_key: Option<Aes128Gcm>, files: Vec<File>) {
for file in files { for file in files {
// Initialize a sequence number for the chunks of this file
let mut sequence = 0; let mut sequence = 0;
// Set the chunk size to the maximum chunk size
let mut chunk_size = MAX_CHUNK_SIZE; let mut chunk_size = MAX_CHUNK_SIZE;
// Set the number of bytes left to read in the file
let mut size = file.size as isize; let mut size = file.size as isize;
// Open the file for reading
let mut handle = match tokio::fs::File::open(file.path).await { let mut handle = match tokio::fs::File::open(file.path).await {
Ok(handle) => handle, Ok(handle) => handle,
Err(error) => { Err(error) => {
@ -404,64 +181,34 @@ async fn on_chunk(sender: Sender, shared_key: Option<Aes128Gcm>, files: Vec<File
}; };
while size > 0 { while size > 0 {
// If the number of bytes left to read in the file is less than the
// chunk size, set the chunk size to the number of bytes left to read
if size < chunk_size { if size < chunk_size {
chunk_size = size; chunk_size = size;
} }
// Create a vector to hold the chunk of data to be read from the file
let mut chunk = vec![0u8; chunk_size.try_into().unwrap()]; let mut chunk = vec![0u8; chunk_size.try_into().unwrap()];
// Read a chunk of data from the file into the vector
handle.read_exact(&mut chunk).await.unwrap(); handle.read_exact(&mut chunk).await.unwrap();
// Send the chunk to the receiver over the WebSocket connection
sender.send_encrypted_packet( sender.send_encrypted_packet(
&shared_key, &shared_key,
DESTINATION, DESTINATION,
Value::Chunk(ChunkPacket { sequence, chunk }), Value::Chunk(ChunkPacket { sequence, chunk }),
); );
// Increment the sequence number for the next chunk
sequence += 1; sequence += 1;
// Decrement the number of bytes left to read in the file
size -= chunk_size; size -= chunk_size;
} }
// Sleep for a short amount of time to prevent overwhelming the receiver
// with too many messages
sleep(DELAY).await; sleep(DELAY).await;
} }
} }
/// This function sends a ListPacket to the receiver containing the list of
/// files to be transferred. The ListPacket contains a vector of Entry structs,
/// each of which represents one file.
///
/// The function creates a vector of Entry structs from the vector of File structs
/// in the Context struct. Each Entry struct contains the index, name, and size
/// of the corresponding File struct.
///
/// The function then sends the ListPacket to the receiver using the send_encrypted_packet
/// function from the Sender struct.
///
/// After sending the ListPacket, the function spawns a task using tokio::spawn to
/// call the on_chunk function with the Sender, shared_key, and vector of File
/// structs as arguments. The on_chunk function will send each chunk of data for
/// each file to the receiver.
///
/// The function returns Status::Continue(), which tells the main loop to continue
/// running until all of the files have been transferred.
fn on_handshake_finalize(context: &mut Context) -> Status { fn on_handshake_finalize(context: &mut Context) -> Status {
let mut entries = vec![]; let mut entries = vec![];
for (index, file) in context.files.iter().enumerate() { for (index, file) in context.files.iter().enumerate() {
let entry = list_packet::Entry { let entry = list_packet::Entry {
// The index of the file in the vector of Files in the Context struct
index: index.try_into().unwrap(), index: index.try_into().unwrap(),
// The name of the file
name: file.name.clone(), name: file.name.clone(),
// The size of the file in bytes
size: file.size, size: file.size,
}; };
@ -483,90 +230,34 @@ fn on_handshake_finalize(context: &mut Context) -> Status {
Status::Continue() Status::Continue()
} }
/// Handshake function that is called when the Sender receives a HandshakeResponsePacket
/// from the Receiver. This function verifies the signature from the Receiver and if
/// successful, creates a shared key using the from the PublicKey struct.
///
/// The shared key is used to encrypt and decrypt packets sent between the Sender
/// and the Receiver.
///
/// This function is called by the main loop in client.rs.
fn on_handshake(context: &mut Context, handshake_response: HandshakeResponsePacket) -> Status { fn on_handshake(context: &mut Context, handshake_response: HandshakeResponsePacket) -> Status {
if context.shared_key.is_some() { if context.shared_key.is_some() {
// If the shared key is already established, this means that the Sender
// has already performed the handshake, so return an error.
return Status::Err("Already performed handshake.".into()); return Status::Err("Already performed handshake.".into());
} }
// Create a new HMAC using the hmac from the Context struct as the key.
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap(); let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
// Update the HMAC with the public key from the HandshakeResponsePacket.
mac.update(&handshake_response.public_key); mac.update(&handshake_response.public_key);
// Call verify_slice() on the HMAC to verify the signature from the Receiver.
// If the signature is invalid, return an error.
let verification = mac.verify_slice(&handshake_response.signature); let verification = mac.verify_slice(&handshake_response.signature);
if verification.is_err() { if verification.is_err() {
return Status::Err("Invalid signature from the receiver.".into()); return Status::Err("Invalid signature from the receiver.".into());
} }
// Create a new PublicKey struct from the public key bytes in the
// HandshakeResponsePacket.
let shared_public_key = PublicKey::from_sec1_bytes(&handshake_response.public_key).unwrap(); let shared_public_key = PublicKey::from_sec1_bytes(&handshake_response.public_key).unwrap();
// Use the diffie_hellman() method from the PublicKey struct to create a shared
// secret key between the Sender and the Receiver. The shared secret key is a
// 16 byte long slice of bytes.
let shared_secret = context.key.diffie_hellman(&shared_public_key); let shared_secret = context.key.diffie_hellman(&shared_public_key);
let shared_secret = shared_secret.raw_secret_bytes(); let shared_secret = shared_secret.raw_secret_bytes();
let shared_secret = &shared_secret[0..16]; let shared_secret = &shared_secret[0..16];
// Create a new Key struct from the shared secret key. The Key<Aes128Gcm> type
// is used to encrypt and decrypt packets.
let shared_key: &Key<Aes128Gcm> = shared_secret.into(); let shared_key: &Key<Aes128Gcm> = shared_secret.into();
let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key); let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key);
// Set the shared_key field of the Context struct to the shared key.
context.shared_key = Some(shared_key); context.shared_key = Some(shared_key);
// Call on_handshake_finalize() to start the transfer of files between the
// Sender and the Receiver.
on_handshake_finalize(context) on_handshake_finalize(context)
} }
/// This function is called by the `Sender` when a new message is received over
/// the WebSocket connection. The message could be a text message or a binary
/// message. If it is a text message, it will be deserialized into a
/// `JsonPacketResponse` enum. If it is a binary message, it will be decrypted
/// if necessary and then deserialized into a `Packet` struct.
///
/// The `JsonPacketResponse` enum will have one of the following variants:
///
/// * `Create { id }`: The Receiver has created a new room with the given ID.
/// * `Join { size }`: The Receiver has joined a room with `size` number of
/// files.
/// * `Leave { index }`: The Receiver has left a room.
/// * `Error { message }`: The Receiver has encountered an error.
///
/// If the message is a binary message, the `Packet` struct will have a
/// `Value` variant that will have one of the following variants:
///
/// * `HandshakeResponse`: The Receiver has responded to the Sender's
/// `Handshake` packet.
/// * `Progress`: The Receiver has sent progress information for one of the
/// files in the room.
///
/// This function does the following:
///
/// * If the message is a text message, it is deserialized into a
/// `JsonPacketResponse` enum and then matched on to call the appropriate
/// function.
/// * If the message is a binary message, it is decrypted if necessary and then
/// deserialized into a `Packet` struct. The `Value` variant of the `Packet`
/// struct is then matched on to call the appropriate function.
///
/// If the message is invalid, an error is returned.
fn on_message( fn on_message(
context: &mut Context, context: &mut Context,
message: WebSocketMessage, message: WebSocketMessage,
@ -574,60 +265,52 @@ fn on_message(
transfer_name: String, transfer_name: String,
is_local: bool, is_local: bool,
) -> Status { ) -> Status {
if message.is_text() { match message.clone() {
let text = message.into_text().unwrap(); WebSocketMessage::Text(text) => {
let packet = serde_json::from_str(&text).unwrap(); let packet = match serde_json::from_str(&text) {
Ok(packet) => packet,
Err(_) => {
return Status::Continue();
}
};
return match packet {
JsonPacketResponse::Create { id } => {
on_create_room(context, id, relay, transfer_name, is_local)
}
JsonPacketResponse::Join { size } => on_join_room(context, size),
JsonPacketResponse::Leave { index } => on_leave_room(context, index),
JsonPacketResponse::Error { message } => on_error(message),
};
}
WebSocketMessage::Binary(data) => {
let data = data[1..].to_vec();
return match packet { let data = if let Some(shared_key) = &context.shared_key {
JsonPacketResponse::Create { id } => { let nonce = &data[..NONCE_SIZE];
on_create_room(context, id, relay, transfer_name, is_local) let ciphertext = &data[NONCE_SIZE..];
shared_key.decrypt(nonce.into(), ciphertext).unwrap()
} else {
data
};
let packet = Packet::decode(data.as_ref()).unwrap();
let value = packet.value.unwrap();
return match value {
Value::HandshakeResponse(handshake_response) => {
on_handshake(context, handshake_response)
}
Value::Progress(progress) => on_progress(context, progress),
_ => Status::Err(format!("Unexpected packet: {:?}", value)),
} }
JsonPacketResponse::Join { size } => on_join_room(context, size), }
JsonPacketResponse::Leave { index } => on_leave_room(context, index), _ => (),
JsonPacketResponse::Error { message } => on_error(message),
};
} else if message.is_binary() {
let data = message.into_data();
let data = &data[1..];
let data = if let Some(shared_key) = &context.shared_key {
let nonce = &data[..NONCE_SIZE];
let ciphertext = &data[NONCE_SIZE..];
shared_key.decrypt(nonce.into(), ciphertext).unwrap()
} else {
data.to_vec()
};
let packet = Packet::decode(data.as_ref()).unwrap();
let value = packet.value.unwrap();
return match value {
Value::HandshakeResponse(handshake_response) => {
on_handshake(context, handshake_response)
}
Value::Progress(progress) => on_progress(context, progress),
_ => Status::Err(format!("Unexpected packet: {:?}", value)),
};
} }
Status::Err("Invalid message type".into()) Status::Err("Invalid message type".into())
} }
/// Starts the sender client. This function will attempt to create a room with a size of 2
/// (the number of clients that will be joining the room) and then it will open a file for
/// each of the paths provided. It will then read chunks of data from each file and send them
/// to the server.
///
/// This function takes two arguments:
/// 1. `socket`: A `Socket` that represents the connection to the server.
/// 2. `paths`: A `Vec` of `String`s that represent the paths to the files that will be sent
/// to the server.
///
/// When the function is finished, it will exit and the transfer will be complete. If there
/// is an error during the transfer, the function will print an error message to stdout and
/// exit.
pub async fn start( pub async fn start(
socket: Socket, socket: Socket,
paths: Vec<String>, paths: Vec<String>,
@ -636,44 +319,33 @@ pub async fn start(
transfer_name: String, transfer_name: String,
is_local: bool, is_local: bool,
) { ) {
// Create a vector to store metadata about each file that will be sent.
let mut files = vec![]; let mut files = vec![];
// For each path in the `paths` vector:
for path in paths { for path in paths {
// Attempt to open the file at the given path.
let handle = match fs::File::open(&path) { let handle = match fs::File::open(&path) {
// If the file is successfully opened, store it in the `handle` variable.
Ok(handle) => handle, Ok(handle) => handle,
// If there is an error, print an error message to stdout and exit the function.
Err(error) => { Err(error) => {
error!("Error: Failed to open file '{}': {}", path, error); error!("Error: Failed to open file '{}': {}", path, error);
return; return;
} }
}; };
// Get the metadata for the file.
let metadata = handle.metadata().unwrap(); let metadata = handle.metadata().unwrap();
// If the file is a directory, print an error message to stdout and exit the function.
if metadata.is_dir() { if metadata.is_dir() {
error!("Error: The path '{}' does not point to a file.", path); error!("Error: The path '{}' does not point to a file.", path);
return; return;
} }
// Get the file name from the path.
let name = Path::new(&path).file_name().unwrap().to_str().unwrap(); let name = Path::new(&path).file_name().unwrap().to_str().unwrap();
// Get the file size from the metadata.
let size = metadata.len(); let size = metadata.len();
// If the file is empty, print an error message to stdout and exit the function.
if size == 0 { if size == 0 {
error!("Error: The file '{}' is empty and cannot be sent.", name); error!("Error: The file '{}' is empty and cannot be sent.", name);
return; return;
} }
// Add the file metadata to the `files` vector.
files.push(File { files.push(File {
name: name.to_string(), name: name.to_string(),
path, path,
@ -681,54 +353,35 @@ pub async fn start(
}); });
} }
// Generate a random key for HMAC.
let mut hmac = [0u8; 32]; let mut hmac = [0u8; 32];
OsRng.fill_bytes(&mut hmac); OsRng.fill_bytes(&mut hmac);
// Generate a random key for AES-GCM.
let key = EphemeralSecret::random(&mut OsRng); let key = EphemeralSecret::random(&mut OsRng);
// Create a channel to send packets to the server.
let (sender, receiver) = flume::bounded(1000); let (sender, receiver) = flume::bounded(1000);
// Split the socket into separate send and receive streams.
let (outgoing, incoming) = socket.split(); let (outgoing, incoming) = socket.split();
// Create a context that will be used throughout the transfer.
let mut context = Context { let mut context = Context {
// Store the sender half of the channel to send packets to the server.
sender, sender,
// Store the ephemeral key for AES-GCM.
key, key,
// Store the files that will be sent to the server.
files, files,
// Store the HMAC key.
hmac: hmac.to_vec(), hmac: hmac.to_vec(),
// Set the shared key to None.
shared_key: None, shared_key: None,
// Set the current task to None.
task: None, task: None,
}; };
// Print a message to stdout indicating that the client is attempting to create a room.
debug!("Attempting to create room..."); debug!("Attempting to create room...");
// Send a JSON packet to the server to create a room with a size of 2.
debug!("With Room-ID: {:?}", room_id); debug!("With Room-ID: {:?}", room_id);
context.sender.send_json_packet(JsonPacket::Create { context.sender.send_json_packet(JsonPacket::Create {
id: room_id.clone(), id: room_id.clone(),
}); });
// context.sender.send_json_packet(JsonPacket::Create);
// Create a future that handles the outgoing stream of messages from the client to the
// server.
let outgoing_handler = receiver.stream().map(Ok).forward(outgoing); let outgoing_handler = receiver.stream().map(Ok).forward(outgoing);
// Create a future that handles the incoming stream of messages from the server to the
// client.
let incoming_handler = incoming.try_for_each(|message| { let incoming_handler = incoming.try_for_each(|message| {
// Call the `on_message` function to handle the incoming message.
match on_message( match on_message(
&mut context, &mut context,
message, message,
@ -736,37 +389,26 @@ pub async fn start(
transfer_name.clone(), transfer_name.clone(),
is_local, is_local,
) { ) {
// If the status is `Status::Exit`, the transfer is complete. Print a message to
// stdout and exit the function.
Status::Exit() => { Status::Exit() => {
// TODO: Signal Exit to the server // TODO: Signal Exit to the server
context.sender.send_json_packet(JsonPacket::Leave);
println!("Transfer has completed."); println!("Transfer has completed.");
// Exit the function with a `Result` of `Err`.
return future::err(Error::ConnectionClosed); return future::err(Error::ConnectionClosed);
} }
// If the status is `Status::Err`, there was an error. Print an error message to
// stdout and exit the function.
Status::Err(error) => { Status::Err(error) => {
error!("Error: {}", error); error!("Error: {}", error);
// Exit the function with a `Result` of `Err`.
return future::err(Error::ConnectionClosed); return future::err(Error::ConnectionClosed);
} }
// Otherwise, the message was handled successfully.
_ => {} _ => {}
}; };
// Continue handling the incoming messages.
future::ok(()) future::ok(())
}); });
// Pin the `incoming_handler` and `outgoing_handler` futures so that they do not move.
pin_mut!(incoming_handler, outgoing_handler); pin_mut!(incoming_handler, outgoing_handler);
// Wait for either the `incoming_handler` or `outgoing_handler` to complete. If the
// `incoming_handler` completes, return the result of the `incoming_handler`. If the
// `outgoing_handler` completes, return the result of the `outgoing_handler`.
future::select(incoming_handler, outgoing_handler).await; future::select(incoming_handler, outgoing_handler).await;
} }

View file

@ -1,43 +1,3 @@
/// Connects to the WebSocket server at `ws://0.0.0.0:8000/ws` with an
/// `Origin` header of `ws://0.0.0.0:8000/ws`. This is the URL that the
/// sender and receiver clients will connect to.
///
/// The `start_sender` function takes a reference to a vector of strings,
/// which are the paths to the files that the sender will send over the
/// WebSocket connection.
///
/// The function first creates a WebSocket request using the `IntoClientRequest`
/// trait from `tungstenite`, which is defined on the `IntoClientRequest` struct.
/// This struct is a type that represents a request to a WebSocket server.
///
/// The `into_client_request` function returns a `Result` because it may fail
/// to create the request. In this case, we do not handle the error, so we just
/// return if the result is an error.
///
/// Once we have a request, we insert the `Origin` header into the headers of
/// the request. This is necessary because the WebSocket protocol requires the
/// `Origin` header to be present in the handshake.
///
/// After that, we print out a message to the console indicating that we are
/// attempting to connect to the server.
///
/// Next, we call the `connect_async` function from `tokio_tungstenite` which
/// takes our request and attempts to connect to the server. This function
/// returns a `Future` that resolves to a tuple of a `WebSocketStream` and a
/// `Response` from the server. The `WebSocketStream` is a stream of
/// WebSocket messages from the server, and the `Response` is the response
/// from the server to our handshake request.
///
/// If connecting to the server fails, we print out an error message and
/// return.
///
/// If connecting to the server succeeds, we pass the `WebSocketStream` and
/// the paths to the files to the `start` function from the `sender` module.
/// The `start` function is defined in the `sender` module, and it is the
/// function that sends the files over the WebSocket connection.
///
/// The `start` function takes ownership of the `WebSocketStream` and the file
/// paths, so we pass it the `paths` vector by value.
pub mod client; pub mod client;
pub mod http_client; pub mod http_client;
pub mod util; pub mod util;
@ -106,10 +66,8 @@ pub async fn start_local_ws() {
let app_host = "0.0.0.0"; let app_host = "0.0.0.0";
let app_port = "9000"; let app_port = "9000";
// Create a new server data structure.
let server = AppState::new(); let server = AppState::new();
// Set up the application routes.
let app = Router::new() let app = Router::new()
.route("/ws", get(ws_handler)) .route("/ws", get(ws_handler))
.with_state(server) .with_state(server)
@ -124,7 +82,6 @@ pub async fn start_local_ws() {
listener.local_addr().unwrap() listener.local_addr().unwrap()
); );
// Run the server.
axum::serve( axum::serve(
listener, listener,
app.into_make_service_with_connect_info::<SocketAddr>(), app.into_make_service_with_connect_info::<SocketAddr>(),
@ -132,7 +89,6 @@ pub async fn start_local_ws() {
.await .await
.unwrap(); .unwrap();
} else { } else {
// Log binding failure and exit.
error!("Failed to listen on: {}:{}", app_host, app_port); error!("Failed to listen on: {}:{}", app_host, app_port);
} }
} }

View file

@ -51,6 +51,6 @@ mod tests {
assert!(name.contains('-')); assert!(name.contains('-'));
assert!(name.split('-').count() == 3); assert!(name.split('-').count() == 3);
assert!(name.len() > 0); assert!(name.is_empty());
} }
} }

View file

@ -14,180 +14,46 @@ use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::protocol::Message as WebSocketMessage; use tokio_tungstenite::tungstenite::protocol::Message as WebSocketMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
/// This struct is used to serialize/deserialize JSON packets sent
/// between the client and the server.
///
/// The `type` field is used to specify the type of packet that is being sent.
/// The possible values for this field are listed as variants of the enum.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
pub enum JsonPacket { pub enum JsonPacket {
/// Sent from the client to ask to join a room. Join { id: String },
///
/// The `id` field specifies the ID of the room that the client wants
/// to join.
Join {
/// The ID of the room that the client wants to join.
id: String,
},
/// Sent from the client to ask to create a new room.
Create { id: Option<String> }, Create { id: Option<String> },
// Create,
/// Sent from the client to ask to leave the current room.
Leave, Leave,
} }
/// This struct is used to serialize/deserialize JSON packets sent
/// from the server to the client.
///
/// The `type` field is used to specify the type of packet that is being
/// sent. The possible values for this field are listed as variants of the
/// enum.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
pub enum JsonPacketResponse { pub enum JsonPacketResponse {
/// Sent from the server to inform the client of the result of a `Join`
/// packet.
///
/// If the client successfully joined a room, the `size` field will be
/// `Some` and contain the size of the room. If the client could not join
/// a room, the `size` field will be `None`.
Join { Join {
/// The size of the room that the client joined. If the client could
/// not join a room, this field will be `None`.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
size: Option<usize>, size: Option<usize>,
}, },
/// Sent from the server to inform the client of the result of a `Create`
/// packet.
///
/// If the server successfully created a room, the `id` field will
/// contain the ID of the room. If the server could not create a room,
/// the `id` field will be empty.
Create { Create {
/// The ID of the room that the server created. If the server could
/// not create a room, this field will be empty.
id: String, id: String,
}, },
/// Sent from the server to inform the client of the result of a `Leave`
/// packet.
///
/// If the client successfully left a room, the `index` field will
/// contain the index of the client that left the room. If the client
/// could not leave a room, the `index` field will be 0.
Leave { Leave {
/// The index of the client that left the room. If the client could
/// not leave a room, this field will be 0.
index: usize, index: usize,
}, },
/// Sent from the server to inform the client of an error.
///
/// The `message` field contains a description of the error.
Error { Error {
/// A description of the error that occurred.
message: String, message: String,
}, },
} }
/// This enum represents the result of processing an event in the event loop.
///
/// The `Status` enum has three variants:
///
/// * `Continue` - This variant indicates that the event loop should
/// continue processing events. This is the most common result and is used
/// when the event loop has nothing special to do.
///
/// * `Exit` - This variant indicates that the event loop should exit. This
/// is used when the event loop should exit because of an error or
/// because the user has requested that the program exit.
///
/// * `Err` - This variant indicates that the event loop encountered an
/// error. When the event loop receives a `Status::Err` variant, it should
/// exit with an error message containing the message from the error packet.
/// The message from the error packet is the only information that the event
/// loop has about the error, so the message should be descriptive and
/// helpful to the user. The message should not contain technical details
/// about the error or how it occurred. Instead, the message should be
/// written from the perspective of the user and should give the user enough
/// information to understand what went wrong and how they might be able to
/// fix the problem.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Status { pub enum Status {
/// Indicates that the event loop should continue processing events.
Continue(), Continue(),
/// Indicates that the event loop should exit.
Exit(), Exit(),
/// Indicates that the event loop encountered an error.
Err(String), Err(String),
} }
/// A trait for sending JSON packets.
///
/// This trait provides a single method, `send_json_packet`, which sends a
/// JSON packet over some underlying transport.
pub trait JsonPacketSender { pub trait JsonPacketSender {
/// Sends a JSON packet.
///
/// This method takes a single argument, `packet`, which is the JSON packet
/// to send. The packet will be serialized into a JSON string and then sent
/// over the underlying transport.
///
/// Note that the exact semantics of what it means to "send a JSON packet"
/// will depend on the specific implementation of this trait. However, in
/// general, the packet will be sent as a single message over the
/// transport, and the transport will be responsible for ensuring that the
/// packet is delivered to the intended recipient.
///
/// # Errors
///
/// If there is an error serializing the JSON packet, or if there is an
/// error sending the serialized packet over the transport, this method
/// may return an error. The exact semantics of what constitutes an error
/// will depend on the specific implementation of this trait.
fn send_json_packet(&self, packet: JsonPacket); fn send_json_packet(&self, packet: JsonPacket);
} }
/// A trait for sending Protocol Buffers packets over some underlying transport.
///
/// This trait provides two methods for sending Protocol Buffers packets:
///
/// * `send_packet` sends a packet in the clear (i.e., not encrypted).
/// * `send_encrypted_packet` sends a packet encrypted using the AES-GCM
/// algorithm with a 128-bit key.
///
/// The exact semantics of what it means to "send a packet" will depend on the
/// specific implementation of this trait. However, in general, the packet will
/// be serialized into a binary message using the Protocol Buffers wire format,
/// and then sent over the underlying transport.
///
/// The `destination` argument specifies which recipient should receive the
/// packet. This is a 1-byte field that is prepended to the serialized packet
/// before it is sent.
///
/// The `key` argument is an optional AES-GCM key. If a key is provided, the
/// packet will be encrypted before being sent. If no key is provided, the
/// packet will be sent in the clear.
///
/// # Errors
///
/// If there is an error serializing the Protocol Buffers packet, or if there
/// is an error sending the serialized packet over the transport, either of
/// these methods may return an error. The exact semantics of what constitutes
/// an error will depend on the specific implementation of this trait.
pub trait PacketSender { pub trait PacketSender {
/// Sends a Protocol Buffers packet in the clear.
///
/// The packet will be serialized into a binary message using the Protocol
/// Buffers wire format, and then sent over the underlying transport.
fn send_packet(&self, destination: u8, packet: packets::packet::Value); fn send_packet(&self, destination: u8, packet: packets::packet::Value);
/// Sends a Protocol Buffers packet encrypted using AES-GCM.
///
/// The packet will be serialized into a binary message using the Protocol
/// Buffers wire format, encrypted using AES-GCM with a 128-bit key, and
/// then sent over the underlying transport.
///
/// If no key is provided, the packet will be sent in the clear.
fn send_encrypted_packet( fn send_encrypted_packet(
&self, &self,
key: &Option<Aes128Gcm>, key: &Option<Aes128Gcm>,
@ -197,22 +63,6 @@ pub trait PacketSender {
} }
impl JsonPacketSender for Sender { impl JsonPacketSender for Sender {
/// Serializes the given JSON packet into a string, and then sends it as a
/// text message over the underlying transport.
///
/// The `JsonPacket` type is defined in the `serde_json` crate, and it is a
/// simple wrapper around a JSON object with string keys and values. This
/// trait method is responsible for taking a `JsonPacket` and sending it
/// over the WebSocket connection.
///
/// The `serde_json::to_string` function is used to serialize the packet
/// into a JSON string. If this function returns an error, we panic
/// because there is no reasonable recovery behavior in this case.
///
/// Once we have the JSON string, we wrap it in a `WebSocketMessage::Text`
/// enum variant and send it over the WebSocket connection using the
/// `send` method. If this method returns an error, we panic because there
/// is no reasonable recovery behavior in this case.
fn send_json_packet(&self, packet: JsonPacket) { fn send_json_packet(&self, packet: JsonPacket) {
let serialized_packet = let serialized_packet =
serde_json::to_string(&packet).expect("Failed to serialize JSON packet."); serde_json::to_string(&packet).expect("Failed to serialize JSON packet.");
@ -223,26 +73,6 @@ impl JsonPacketSender for Sender {
} }
impl PacketSender for Sender { impl PacketSender for Sender {
/// Serializes the given packet value into a binary message, and then
/// sends it over the underlying transport.
///
/// The `destination` parameter specifies which client should receive
/// this message. The value of this parameter should be a byte that
/// represents the client's index in the list of connected clients.
///
/// The `value` parameter specifies the actual data that should be sent
/// to the client. This will be serialized into a `Packet` struct using
/// the Protocol Buffers wire format.
///
/// This function will first encode the `Packet` struct into a vector of
/// bytes using the Protocol Buffers wire format. It will then insert the
/// `destination` byte as the first element of the vector, so that the
/// receiving client knows which client this message is intended for.
///
/// Finally, this function will send the serialized packet over the
/// underlying transport, which is assumed to be a WebSocket connection.
/// If this send operation fails, this function will panic because there
/// is no reasonable recovery behavior in this case.
fn send_packet(&self, destination: u8, value: packets::packet::Value) { fn send_packet(&self, destination: u8, value: packets::packet::Value) {
let packet = Packet { value: Some(value) }; let packet = Packet { value: Some(value) };
@ -253,21 +83,6 @@ impl PacketSender for Sender {
.expect("Failed to send Packet."); .expect("Failed to send Packet.");
} }
/// Similar to `send_packet`, but the message is encrypted using AES-GCM
/// with a 128-bit key.
///
/// If no key is provided (i.e., if `key` is `None`), then the message will
/// be sent in the clear.
///
/// This function works by generating a random 12-byte nonce using the
/// `rand::OsRng` PRNG, encrypting the message using AES-GCM with the
/// provided key and nonce, and then prepending the nonce to the ciphertext
/// before sending it over the WebSocket connection. The receiving client
/// will use the same key and nonce to decrypt the message.
///
/// Note that this function does not actually check whether the provided
/// key is valid. If an invalid key is provided, the encryption will fail
/// and the receiver will not be able to decrypt the message.
fn send_encrypted_packet( fn send_encrypted_packet(
&self, &self,
key: &Option<Aes128Gcm>, key: &Option<Aes128Gcm>,
@ -293,43 +108,6 @@ impl PacketSender for Sender {
} }
} }
/// A sender is a type that allows us to send messages to a WebSocket client.
///
/// In this case, a sender is a channel that allows us to send WebSocket
/// messages to a client. The messages can be any type that implements the
/// `Into<WebSocketMessage>`.
///
/// The `WebSocketMessage` type represents any message that can be sent over a
/// WebSocket connection. It can be a binary message, a text message, or a
/// close message.
///
/// The `MaybeTlsStream` type is a stream that may or may not be encrypted.
/// If the connection is encrypted (e.g., via TLS), then the stream will be
/// encrypted. If the connection is not encrypted, then the stream will be
/// unencrypted.
///
/// The `TcpStream` type is a stream that is used to connect to a remote
/// server over a TCP connection.
///
/// The `WebSocketStream` type is a stream that is used to connect to a remote
/// WebSocket server. It is a wrapper around the `MaybeTlsStream` stream that
/// adds WebSocket-specific functionality.
pub type Sender = flume::Sender<WebSocketMessage>; pub type Sender = flume::Sender<WebSocketMessage>;
/// A socket is a type that represents a WebSocket connection.
///
/// In this case, a socket is a wrapper around a `MaybeTlsStream` stream that
/// adds WebSocket-specific functionality.
///
/// The `MaybeTlsStream` type is a stream that may or may not be encrypted.
/// If the connection is encrypted (e.g., via TLS), then the stream will be
/// encrypted. If the connection is not encrypted, then the stream will be
/// unencrypted.
///
/// The `TcpStream` type is a stream that is used to connect to a remote
/// server over a TCP connection.
///
/// The `WebSocketStream` type is a stream that is used to connect to a remote
/// WebSocket server. It is a wrapper around the `MaybeTlsStream` stream that
/// adds WebSocket-specific functionality.
pub type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>; pub type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>;