refactor(caesar): restructure the project to a monorepo
This commit is contained in:
parent
17ebd0261b
commit
b39e88107a
27 changed files with 195 additions and 142 deletions
45
caesar-core/Cargo.toml
Normal file
45
caesar-core/Cargo.toml
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
[package]
|
||||
name = "caesar-core"
|
||||
version = "0.3.1"
|
||||
edition = "2021"
|
||||
build = "src/build.rs"
|
||||
authors = ["Manuel Keidel", "Patryk Hegenberg", "Krzysztof Stankiewicz"]
|
||||
|
||||
[dependencies]
|
||||
futures-util = "0.3"
|
||||
tungstenite = "0.21.0"
|
||||
tokio = { version = "1.28.1", features = ["full"] }
|
||||
tokio-tungstenite = { version = "0.21.0", features = [
|
||||
"rustls-tls-webpki-roots",
|
||||
] }
|
||||
serde_json = { version = "1.0" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
uuid = { version = "1.3.2", features = ["v4"] }
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
|
||||
dotenv = { version = "0.15.0", features = ["clap", "cli"] }
|
||||
clap = { version = "4.5.4", features = ["derive"] }
|
||||
flume = { git = "https://github.com/zesterer/flume", rev = "80d19c49" }
|
||||
prost = "0.12.4"
|
||||
prost-types = "0.12.4"
|
||||
base64 = "0.22.0"
|
||||
url = "2.4.0"
|
||||
p256 = { version = "0.13.2", features = ["ecdh"] }
|
||||
hmac = "0.12.1"
|
||||
sha2 = "0.10.7"
|
||||
rand = { version = "0.8.5", features = ["getrandom"] }
|
||||
aes-gcm = "0.10.3"
|
||||
sanitize-filename = "0.5.0"
|
||||
qr2term = "0.3.1"
|
||||
axum = { version = "0.7.5", features = ["ws"] }
|
||||
tower-http = { version = "0.5.2", features = ["fs", "trace"] }
|
||||
axum-client-ip = "0.6.0"
|
||||
local-ip-address = "0.6.1"
|
||||
axum-extra = { version = "0.9.3", features = ["typed-header"] }
|
||||
headers = "0.4"
|
||||
tower = { version = "0.4", features = ["util"] }
|
||||
reqwest = { version = "0.12.4", features = ["blocking", "json"] }
|
||||
hex = "0.4.3"
|
||||
|
||||
[build-dependencies]
|
||||
prost-build = "0.12.4"
|
||||
42
caesar-core/packets.proto
Normal file
42
caesar-core/packets.proto
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package Packets;
|
||||
|
||||
message HandshakePacket {
|
||||
bytes publicKey = 1;
|
||||
bytes signature = 2;
|
||||
}
|
||||
|
||||
message HandshakeResponsePacket {
|
||||
bytes publicKey = 1;
|
||||
bytes signature = 2;
|
||||
}
|
||||
|
||||
message ListPacket {
|
||||
message Entry {
|
||||
uint32 index = 1;
|
||||
uint64 size = 2;
|
||||
string name = 3;
|
||||
}
|
||||
repeated Entry entries = 1;
|
||||
}
|
||||
|
||||
message ProgressPacket {
|
||||
uint32 index = 1;
|
||||
uint32 progress = 2;
|
||||
}
|
||||
|
||||
message ChunkPacket {
|
||||
uint32 sequence = 1;
|
||||
bytes chunk = 2;
|
||||
}
|
||||
|
||||
message Packet {
|
||||
oneof value {
|
||||
HandshakePacket handshake = 1;
|
||||
HandshakeResponsePacket handshakeResponse = 2;
|
||||
ListPacket list = 3;
|
||||
ProgressPacket progress = 4;
|
||||
ChunkPacket chunk = 5;
|
||||
}
|
||||
}
|
||||
22
caesar-core/src/build.rs
Normal file
22
caesar-core/src/build.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// This build script is invoked by cargo when it is building our crate. It is responsible for
|
||||
// generating code based on other protocol buffer definitions.
|
||||
//
|
||||
// Specifically, it generates Rust code from the `packets.proto` file in the root of our
|
||||
// crate. This generated code is then compiled into our final binary.
|
||||
//
|
||||
// The `prost_build` crate is responsible for doing the actual work of generating code from
|
||||
// protocol buffer definitions. We're passing it the path to our `.proto` file and the root
|
||||
// directory of our crate.
|
||||
extern crate prost_build;
|
||||
|
||||
fn main() {
|
||||
// Invoke the `compile_protos` function from the `prost_build` crate. This function takes
|
||||
// two arguments: a list of `.proto` files to compile and the root directory of our crate.
|
||||
// It returns a `Result` indicating whether the compilation was successful or not.
|
||||
//
|
||||
// The `.unwrap()` method is then called on the `Result` to panic if the compilation
|
||||
// failed. This is okay in a build script because it will stop the build process and
|
||||
// prevent our code from being built.
|
||||
prost_build::compile_protos(&["packets.proto"], &["."]).unwrap();
|
||||
}
|
||||
|
||||
4
caesar-core/src/lib.rs
Normal file
4
caesar-core/src/lib.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod receiver;
|
||||
pub mod relay;
|
||||
pub mod sender;
|
||||
pub mod shared;
|
||||
760
caesar-core/src/receiver/client.rs
Normal file
760
caesar-core/src/receiver/client.rs
Normal file
|
|
@ -0,0 +1,760 @@
|
|||
use std::{fs, io::stdout, path::Path};
|
||||
|
||||
use crate::shared::{
|
||||
packets::{
|
||||
packet::Value, ChunkPacket, HandshakePacket, HandshakeResponsePacket, ListPacket, Packet,
|
||||
ProgressPacket,
|
||||
},
|
||||
JsonPacket, JsonPacketResponse, JsonPacketSender, PacketSender, Sender, Socket, Status,
|
||||
};
|
||||
|
||||
use aes_gcm::{aead::Aead, Aes128Gcm, Key};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
|
||||
use hmac::{Hmac, Mac};
|
||||
use p256::{ecdh::EphemeralSecret, pkcs8::der::Writer, PublicKey};
|
||||
use prost::Message;
|
||||
use rand::rngs::OsRng;
|
||||
use sha2::Sha256;
|
||||
use tokio_tungstenite::tungstenite::{protocol::Message as WebSocketMessage, Error};
|
||||
use tracing::error;
|
||||
|
||||
const DESTINATION: u8 = 0;
|
||||
const NONCE_SIZE: usize = 12;
|
||||
|
||||
/// This struct represents a file that is being received.
|
||||
///
|
||||
/// The struct contains information about the file, such as its name, size,
|
||||
/// and the handle of the file on the disk.
|
||||
///
|
||||
/// The `name` field contains the name of the file, which is the name of the
|
||||
/// file on the disk.
|
||||
///
|
||||
/// The `size` field contains the size of the file in bytes.
|
||||
///
|
||||
/// The `progress` field contains the number of bytes that have already been
|
||||
/// received for the file.
|
||||
///
|
||||
/// The `handle` field contains a handle to the file on the disk, which is
|
||||
/// used to read the contents of the file.
|
||||
struct File {
|
||||
name: String,
|
||||
size: u64,
|
||||
progress: u64,
|
||||
handle: fs::File,
|
||||
}
|
||||
|
||||
/// This struct contains the context for the receiver.
|
||||
///
|
||||
/// This structure is used to keep track of the state of the receiver, and to
|
||||
/// pass information between functions.
|
||||
struct Context {
|
||||
/// The HMAC key that is used to verify that the packets that are received
|
||||
/// are authentic.
|
||||
///
|
||||
/// The HMAC key is generated by the sender, and is used to verify that the
|
||||
/// packets that are received are authentic. If the HMAC of a packet does
|
||||
/// not match the expected HMAC, then the packet is not processed.
|
||||
hmac: Vec<u8>,
|
||||
|
||||
/// The sender that is used to send packets to the server.
|
||||
///
|
||||
/// The sender is used to send packets to the server. The sender is also
|
||||
/// used to receive packets from the server.
|
||||
sender: Sender,
|
||||
|
||||
/// The ephemeral secret key that is used for key exchange with the sender.
|
||||
///
|
||||
/// The ephemeral secret key is generated by the receiver, and is used to
|
||||
/// exchange a shared key with the sender. The shared key is used to
|
||||
/// encrypt and decrypt packets.
|
||||
key: EphemeralSecret,
|
||||
|
||||
/// The shared key that is used to encrypt and decrypt packets.
|
||||
///
|
||||
/// The shared key is established between the receiver and the sender during
|
||||
/// the key exchange. The shared key is used to encrypt and decrypt packets
|
||||
/// between the receiver and the sender. If the shared key is `None`, then the
|
||||
/// packets that are received are not processed.
|
||||
shared_key: Option<Aes128Gcm>,
|
||||
|
||||
/// The files that are being received.
|
||||
///
|
||||
/// The files vector contains a list of all the files that are being
|
||||
/// received. Each file is represented by a `File` struct. The `name` field
|
||||
/// of the `File` struct contains the name of the file, which is the name of
|
||||
/// the file on the disk. The `size` field of the `File` struct contains the
|
||||
/// size of the file in bytes. The `progress` field of the `File` struct
|
||||
/// contains the number of bytes that have already been received for the
|
||||
/// file. The `handle` field of the `File` struct contains a handle to the
|
||||
/// file on the disk, which is used to read the contents of the file.
|
||||
files: Vec<File>,
|
||||
|
||||
/// The sequence number of the next chunk that is expected to be received.
|
||||
///
|
||||
/// The sequence number is used to keep track of the sequence of chunks that
|
||||
/// are received. If a chunk does not have the expected sequence number,
|
||||
/// then the chunk is not processed.
|
||||
sequence: u32,
|
||||
|
||||
/// The index of the file that is currently being received.
|
||||
///
|
||||
/// The index is used to keep track of which file is currently being
|
||||
/// received. The index is incremented after a file is completely received.
|
||||
index: usize,
|
||||
|
||||
/// The progress of the current file that is being received.
|
||||
///
|
||||
/// The progress is used to keep track of the progress of the current file
|
||||
/// that is being received. The progress is calculated by dividing the
|
||||
/// number of bytes that have been received by the size of the file. The
|
||||
/// progress is sent to the server so that the sender knows how much of the
|
||||
/// file has been received.
|
||||
progress: u64,
|
||||
|
||||
/// The total length of the current file that is being received.
|
||||
///
|
||||
/// The length is used to keep track of the total length of the current file
|
||||
/// that is being received. The length is used to calculate the progress of
|
||||
/// the file.
|
||||
length: u64,
|
||||
}
|
||||
|
||||
/// This function is called when the receiver receives a join room packet from
|
||||
/// the server. The packet contains the size of the list of files that the
|
||||
/// sender is going to send.
|
||||
///
|
||||
/// If the packet does not contain the size of the list, then an error is
|
||||
/// returned.
|
||||
///
|
||||
/// If the packet does contain the size of the list, then a message is printed
|
||||
/// to the console indicating that the receiver has connected to the room.
|
||||
///
|
||||
/// The function does not do anything else. It returns a `Status::Continue`
|
||||
/// variant to indicate that the event loop should continue processing events.
|
||||
fn on_join_room(size: Option<usize>) -> Status {
|
||||
// If the packet does not contain the size of the list, then return an error.
|
||||
if size.is_none() {
|
||||
return Status::Err("Invalid join room packet.".into());
|
||||
}
|
||||
|
||||
// If the packet contains the size of the list, then print a message to the
|
||||
// console indicating that the receiver has connected to the room.
|
||||
println!("Connected to room.");
|
||||
|
||||
// Return a `Status::Continue` variant to indicate that the event loop
|
||||
// should continue processing events.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called when the event loop receives an error packet from
|
||||
/// the server. The packet contains a message with a description of the error.
|
||||
///
|
||||
/// When an error occurs, the server sends an error packet to the client. The
|
||||
/// error packet contains a message with a description of the error. This
|
||||
/// function extracts that message and creates a `Status::Err` variant with it,
|
||||
/// which is then returned to be handled by the main event loop.
|
||||
///
|
||||
/// When the event loop receives a status variant that is an error, it exits
|
||||
/// with an error message containing the message from the error packet.
|
||||
///
|
||||
/// The message from the error packet is the only information that the event
|
||||
/// loop has about the error, so the message should be descriptive and
|
||||
/// helpful to the user. The message should not contain technical details
|
||||
/// about the error or how it occurred. Instead, the message should be
|
||||
/// written from the perspective of the user and should give the user enough
|
||||
/// information to understand what went wrong and how they might be able to
|
||||
/// fix the problem.
|
||||
///
|
||||
/// This function takes the message from the error packet and creates a
|
||||
/// `Status::Err` variant with it. The function returns this variant to be
|
||||
/// handled by the main event loop.
|
||||
fn on_error(message: String) -> Status {
|
||||
Status::Err(message)
|
||||
}
|
||||
|
||||
/// This function is called when the event loop receives a leave room packet from
|
||||
/// the server. The packet contains the index of the file that was being
|
||||
/// transferred when the receiver left the room.
|
||||
///
|
||||
/// When the receiver receives a leave room packet, it means that the sender
|
||||
/// has disconnected from the room. In this case, the receiver should check if
|
||||
/// there are any files that were being transferred but not yet complete. If
|
||||
/// there are any incomplete files, the receiver should print a message to
|
||||
/// the user indicating that the transfer was interrupted.
|
||||
///
|
||||
/// If there are no incomplete files, then the receiver should exit
|
||||
/// normally. The `Status::Exit` variant is returned to the main event loop,
|
||||
/// which will cause the event loop to exit normally.
|
||||
///
|
||||
/// This function checks if there are any incomplete files by iterating over
|
||||
/// the list of files in the context. If there are any files with a progress
|
||||
/// less than 100%, then the function prints a message to the user and returns
|
||||
/// an error status.
|
||||
///
|
||||
/// If there are no incomplete files, then the function simply returns a
|
||||
/// `Status::Exit` variant. This will cause the main event loop to exit
|
||||
/// normally.
|
||||
fn on_leave_room(context: &mut Context, _: usize) -> Status {
|
||||
// Check if there are any incomplete files.
|
||||
if context.files.iter().any(|file| file.progress < 100) {
|
||||
// If there are any incomplete files, print a message to the user.
|
||||
println!();
|
||||
println!("Transfer was interrupted because the host left the room.");
|
||||
|
||||
// Return an error status.
|
||||
Status::Err("Transfer was interrupted because the host left the room.".into())
|
||||
} else {
|
||||
// If there are no incomplete files, return a `Status::Exit` variant.
|
||||
// This will cause the event loop to exit normally.
|
||||
Status::Exit()
|
||||
}
|
||||
}
|
||||
|
||||
/// This function is called when the event loop receives a list packet from
|
||||
/// the server. The packet contains a list of files to be transferred.
|
||||
///
|
||||
/// When this function is called, we know that the sender has successfully
|
||||
/// established a shared key with the receiver. Therefore, we can start
|
||||
/// receiving encrypted files.
|
||||
///
|
||||
/// This function iterates over the list of files in the packet and creates a
|
||||
/// file on disk for each file in the list. If a file with the same name already
|
||||
/// exists, an error is returned and the event loop is exited with an error
|
||||
/// message. This is because the receiver should not overwrite existing files
|
||||
/// without the user's explicit permission.
|
||||
///
|
||||
/// Once all the files have been created, the function initializes the context
|
||||
/// variables for the event loop. `index` is set to 0 to indicate that we are
|
||||
/// currently transferring the first file. `progress` is set to 0 to indicate
|
||||
/// that the progress of the first file is 0%. `sequence` is set to 0 to
|
||||
/// indicate that the first chunk of data we receive will have a sequence
|
||||
/// number of 0. `length` is set to 0 to indicate that we have not received
|
||||
/// any data yet.
|
||||
///
|
||||
/// If there is an error creating any of the files, the function returns an
|
||||
/// error status. This will cause the event loop to exit with an error message.
|
||||
///
|
||||
/// If there are no errors, the function returns a `Status::Continue()` variant.
|
||||
/// This will cause the event loop to continue running and wait for more
|
||||
/// packets from the sender.
|
||||
fn on_list(context: &mut Context, list: ListPacket) -> Status {
|
||||
if context.shared_key.is_none() {
|
||||
return Status::Err("Invalid list packet: no shared key established".into());
|
||||
}
|
||||
|
||||
// Iterate over the list of files in the packet.
|
||||
for entry in list.entries {
|
||||
// Sanitize the file name to remove any characters that are not valid in
|
||||
// file names on the current platform.
|
||||
let path = sanitize_filename::sanitize(entry.name.clone());
|
||||
|
||||
// Check if a file with the same name already exists.
|
||||
if Path::new(&path).exists() {
|
||||
// If the file already exists, return an error and exit the event loop
|
||||
// with an error message.
|
||||
return Status::Err(format!("The file '{}' already exists.", path));
|
||||
}
|
||||
|
||||
// Try to create a new file with the sanitized file name.
|
||||
let handle = match fs::File::create(&path) {
|
||||
Ok(handle) => handle,
|
||||
Err(error) => {
|
||||
// If there is an error creating the file, return an error and
|
||||
// exit the event loop with an error message.
|
||||
return Status::Err(format!(
|
||||
"Error: Failed to create file '{}': {}",
|
||||
path, error
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Create a new file struct for the file we just created.
|
||||
let file = File {
|
||||
name: entry.name,
|
||||
size: entry.size,
|
||||
handle,
|
||||
progress: 0,
|
||||
};
|
||||
|
||||
// Add the new file to the list of files in the context.
|
||||
context.files.push(file);
|
||||
}
|
||||
|
||||
// Set the context variables for the event loop.
|
||||
context.index = 0;
|
||||
context.progress = 0;
|
||||
context.sequence = 0;
|
||||
context.length = 0;
|
||||
|
||||
// Return a `Status::Continue()` variant to indicate that the event loop
|
||||
// should continue running and wait for more packets from the sender.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function handles a chunk packet received from the sender.
|
||||
///
|
||||
/// It checks that the shared key has been established, that the sequence number
|
||||
/// of the chunk matches the expected sequence number in the context, and that
|
||||
/// the index of the file in the context is valid.
|
||||
///
|
||||
/// If any of these checks fail, an error is returned and the event loop is
|
||||
/// stopped.
|
||||
///
|
||||
/// The function updates the length of the file, increments the sequence number
|
||||
/// in the context, and writes the contents of the chunk to the file.
|
||||
///
|
||||
/// The progress of the file is updated to be the ratio of the number of bytes
|
||||
/// read so far to the total size of the file.
|
||||
///
|
||||
/// If the progress of the file is 100%, or if the difference in progress between
|
||||
/// this chunk and the last chunk is greater than or equal to 1, or if this is the
|
||||
/// first chunk, a ProgressPacket is sent to the sender with the index of the file
|
||||
/// in the context and the progress of the file.
|
||||
///
|
||||
/// If the size of the file has been reached, the index of the current file is
|
||||
/// incremented, the length of the current file is set to 0, the progress of the
|
||||
/// current file is set to 0, and the sequence number is set to 0.
|
||||
///
|
||||
/// Finally, a Status::Continue() variant is returned to indicate that the event
|
||||
/// loop should continue running and wait for more packets from the sender.
|
||||
fn on_chunk(context: &mut Context, chunk: ChunkPacket) -> Status {
|
||||
// Check that the shared key has been established.
|
||||
if context.shared_key.is_none() {
|
||||
return Status::Err("Invalid chunk packet: no shared key established".into());
|
||||
}
|
||||
|
||||
// Check that the sequence number of the chunk matches the expected sequence
|
||||
// number in the context.
|
||||
if chunk.sequence != context.sequence {
|
||||
return Status::Err(format!(
|
||||
"Expected sequence {}, but got {}.",
|
||||
context.sequence, chunk.sequence
|
||||
));
|
||||
}
|
||||
|
||||
// Get a mutable reference to the file in the context at the index of the
|
||||
// file.
|
||||
let Some(file) = context.files.get_mut(context.index) else {
|
||||
// If the index of the file in the context is invalid, return an error and
|
||||
// stop the event loop.
|
||||
return Status::Err("Invalid file index.".into());
|
||||
};
|
||||
|
||||
// Update the length of the file.
|
||||
context.length += chunk.chunk.len() as u64;
|
||||
|
||||
// Increment the sequence number in the context.
|
||||
context.sequence += 1;
|
||||
|
||||
// Write the contents of the chunk to the file.
|
||||
file.handle.write(&chunk.chunk).unwrap();
|
||||
|
||||
// Update the progress of the file.
|
||||
file.progress = (context.length * 100) / file.size;
|
||||
|
||||
// If the progress of the file is 100%, or if the difference in progress between
|
||||
// this chunk and the last chunk is greater than or equal to 1, or if this is the
|
||||
// first chunk, send a ProgressPacket to the sender.
|
||||
if file.progress == 100 || file.progress - context.progress >= 1 || chunk.sequence == 0 {
|
||||
context.progress = file.progress;
|
||||
|
||||
let progress = ProgressPacket {
|
||||
// Convert the index of the file in the context to a u32.
|
||||
index: context.index.try_into().unwrap(),
|
||||
// Convert the progress of the file to a u32.
|
||||
progress: context.progress.try_into().unwrap(),
|
||||
};
|
||||
|
||||
// Send the ProgressPacket to the sender.
|
||||
context.sender.send_encrypted_packet(
|
||||
&context.shared_key,
|
||||
DESTINATION,
|
||||
Value::Progress(progress),
|
||||
);
|
||||
|
||||
print!("\rTransferring '{}': {}%", file.name, file.progress);
|
||||
std::io::Write::flush(&mut stdout()).unwrap();
|
||||
}
|
||||
|
||||
// If the size of the file has been reached, increment the index of the
|
||||
// current file, set the length of the current file to 0, set the progress
|
||||
// of the current file to 0, and resets the sequence number to 0.
|
||||
if file.size == context.length {
|
||||
context.index += 1;
|
||||
context.length = 0;
|
||||
context.progress = 0;
|
||||
context.sequence = 0;
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
// Return a Status::Continue() variant to indicate that the event loop should
|
||||
// continue running and wait for more packets from the sender.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called when the Receiver receives a HandshakePacket from the
|
||||
/// Sender. It verifies the signature of the Sender's public key and generates its own
|
||||
/// public key. It then generates a shared secret key between the Receiver and the Sender
|
||||
/// using the Diffie-Hellman key exchange.
|
||||
///
|
||||
/// The Receiver sends back a HandshakeResponsePacket to the Sender with its own public
|
||||
/// key and a signature created using the shared secret key and its own private key.
|
||||
///
|
||||
/// The shared secret key is used to encrypt packets sent between the Receiver and the
|
||||
/// Sender.
|
||||
fn on_handshake(context: &mut Context, handshake: HandshakePacket) -> Status {
|
||||
// If a shared key has already been established, this means that the Receiver
|
||||
// has already performed the handshake, so return an error.
|
||||
if context.shared_key.is_some() {
|
||||
return Status::Err("Already performed handshake.".into());
|
||||
}
|
||||
|
||||
// Create a new HMAC using the hmac from the Context struct as the key.
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
|
||||
|
||||
// Update the HMAC with the public key from the HandshakePacket.
|
||||
mac.update(&handshake.public_key);
|
||||
|
||||
// Call verify_slice() on the HMAC to verify the signature from the Sender.
|
||||
// If the signature is invalid, return an error.
|
||||
let verification = mac.verify_slice(&handshake.signature);
|
||||
if verification.is_err() {
|
||||
return Status::Err("Invalid signature from the sender.".into());
|
||||
}
|
||||
|
||||
// Generate the Receiver's public key from the private key.
|
||||
let public_key = context.key.public_key().to_sec1_bytes().into_vec();
|
||||
|
||||
// Create a new HMAC using the hmac from the Context struct as the key.
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
|
||||
|
||||
// Update the HMAC with the public key of the Receiver.
|
||||
mac.update(&public_key);
|
||||
|
||||
// Serialize the resulting HMAC into a byte array and use it as the
|
||||
// signature in the HandshakeResponsePacket.
|
||||
let signature = mac.finalize().into_bytes().to_vec();
|
||||
// Create a new shared secret key between the Receiver and the Sender.
|
||||
let shared_public_key = PublicKey::from_sec1_bytes(&handshake.public_key).unwrap();
|
||||
|
||||
let shared_secret = context.key.diffie_hellman(&shared_public_key);
|
||||
let shared_secret = shared_secret.raw_secret_bytes();
|
||||
let shared_secret = &shared_secret[0..16];
|
||||
|
||||
// Create a new Aes128Gcm key from the shared secret.
|
||||
let shared_key: &Key<Aes128Gcm> = shared_secret.into();
|
||||
let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key);
|
||||
|
||||
// Create the HandshakeResponsePacket and send it to the Sender.
|
||||
let handshake_response = HandshakeResponsePacket {
|
||||
public_key,
|
||||
signature,
|
||||
};
|
||||
|
||||
context
|
||||
.sender
|
||||
.send_packet(DESTINATION, Value::HandshakeResponse(handshake_response));
|
||||
|
||||
// Store the shared key in the Context struct.
|
||||
context.shared_key = Some(shared_key);
|
||||
|
||||
// Return a Status::Continue() variant to indicate that the event loop should
|
||||
// continue running and wait for more packets from the Sender.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called when a message is received from the Sender.
|
||||
///
|
||||
/// The message can be either text or binary. If it's text, we attempt to
|
||||
/// parse it as a JsonPacketResponse and match on the type of response it is.
|
||||
/// If it's binary, we attempt to decrypt it using the shared key (if it
|
||||
/// exists) and then decode it into a Packet. We then match on the type of
|
||||
/// value in the Packet and call the appropriate function with the relevant
|
||||
/// data.
|
||||
///
|
||||
/// If the message is not text or binary, we return a Status::Err with an
|
||||
/// appropriate error message.
|
||||
fn on_message(context: &mut Context, message: WebSocketMessage) -> Status {
|
||||
if message.is_text() {
|
||||
let text = message.into_text().unwrap();
|
||||
let packet = serde_json::from_str(&text).unwrap();
|
||||
|
||||
return match packet {
|
||||
JsonPacketResponse::Join { size } => on_join_room(size),
|
||||
JsonPacketResponse::Leave { index } => on_leave_room(context, index),
|
||||
JsonPacketResponse::Error { message } => on_error(message),
|
||||
|
||||
_ => Status::Err(format!("Unexpected json packet: {:?}", packet)),
|
||||
};
|
||||
} else if message.is_binary() {
|
||||
let data = message.into_data();
|
||||
let data = &data[1..];
|
||||
|
||||
let data = if let Some(shared_key) = &context.shared_key {
|
||||
let nonce = &data[..NONCE_SIZE];
|
||||
let ciphertext = &data[NONCE_SIZE..];
|
||||
|
||||
shared_key.decrypt(nonce.into(), ciphertext).unwrap()
|
||||
} else {
|
||||
data.to_vec()
|
||||
};
|
||||
|
||||
let packet = Packet::decode(data.as_ref()).unwrap();
|
||||
let value = packet.value.unwrap();
|
||||
|
||||
return match value {
|
||||
Value::List(list) => on_list(context, list),
|
||||
Value::Chunk(chunk) => on_chunk(context, chunk),
|
||||
Value::Handshake(handshake) => on_handshake(context, handshake),
|
||||
|
||||
_ => Status::Err(format!("Unexpected packet: {:?}", value)),
|
||||
};
|
||||
}
|
||||
|
||||
Status::Err("Invalid message type".into())
|
||||
}
|
||||
|
||||
/// This function takes a websocket connection and an invite code,
|
||||
/// splits the connection into an outgoing and incoming part,
|
||||
/// creates a context for the connection, sends a join room packet,
|
||||
/// and starts two futures to handle incoming and outgoing messages.
|
||||
///
|
||||
/// The outgoing future reads from a channel and sends the messages
|
||||
/// through the outgoing part of the connection. If the sending fails,
|
||||
/// the future will print an error and exit.
|
||||
///
|
||||
/// The incoming future reads from the incoming part of the connection
|
||||
/// and passes the messages to the `on_message` function. If the message
|
||||
/// is an exit or an error, the function will print the error and exit.
|
||||
/// If the message is any other type of packet, it will be handled by the
|
||||
/// `on_message` function and the future will continue running.
|
||||
pub async fn start(socket: Socket, fragment: &str) {
|
||||
// Extract the room id and hmac from the invite code
|
||||
let Some(index) = fragment.rfind('-') else {
|
||||
println!("Error: The invite code '{}' is not valid.", fragment);
|
||||
return;
|
||||
};
|
||||
|
||||
let id = &fragment[..index];
|
||||
let hmac = &fragment[index + 1..];
|
||||
let Ok(hmac) = general_purpose::STANDARD.decode(hmac) else {
|
||||
error!("Error: Invalid base64 inside the invite code.");
|
||||
return;
|
||||
};
|
||||
|
||||
// Create a new ephemeral key pair
|
||||
let key = EphemeralSecret::random(&mut OsRng);
|
||||
|
||||
// Create a channel for sending messages
|
||||
let (sender, receiver) = flume::bounded(1000);
|
||||
|
||||
// Split the websocket connection into an outgoing and incoming part
|
||||
let (outgoing, incoming) = socket.split();
|
||||
|
||||
// Create a new context for the connection
|
||||
let mut context = Context {
|
||||
hmac,
|
||||
sender,
|
||||
key,
|
||||
|
||||
shared_key: None,
|
||||
files: vec![],
|
||||
|
||||
index: 0,
|
||||
sequence: 0,
|
||||
progress: 0,
|
||||
length: 0,
|
||||
};
|
||||
|
||||
println!("Attempting to join room '{}'...", id);
|
||||
|
||||
// Send a join room packet to the server
|
||||
context
|
||||
.sender
|
||||
.send_json_packet(JsonPacket::Join { id: id.to_string() });
|
||||
|
||||
// Create futures for handling incoming and outgoing messages
|
||||
let outgoing_handler = receiver.stream().map(Ok).forward(outgoing);
|
||||
let incoming_handler = incoming.try_for_each(|message| {
|
||||
// Call the on_message function to handle the message
|
||||
match on_message(&mut context, message) {
|
||||
// If the message is an exit, print a message and exit
|
||||
Status::Exit() => {
|
||||
println!("Transfer has completed.");
|
||||
|
||||
return future::err(Error::ConnectionClosed);
|
||||
}
|
||||
// If the message is an error, print the error and exit
|
||||
Status::Err(error) => {
|
||||
println!("Error: {}", error);
|
||||
|
||||
return future::err(Error::ConnectionClosed);
|
||||
}
|
||||
// If the message is any other type of packet, do nothing
|
||||
_ => {}
|
||||
};
|
||||
|
||||
// Continue running the future
|
||||
future::ok(())
|
||||
});
|
||||
|
||||
// Pin the futures to the stack so they can run concurrently
|
||||
pin_mut!(incoming_handler, outgoing_handler);
|
||||
|
||||
// Wait for either future to complete
|
||||
future::select(incoming_handler, outgoing_handler).await;
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
|
||||
#[test]
|
||||
fn test_on_join_room_valid_size() {
|
||||
assert_eq!(on_join_room(Some(10)), Status::Continue());
|
||||
}
|
||||
#[test]
|
||||
fn test_on_join_room_invalid_size() {
|
||||
assert_eq!(
|
||||
on_join_room(None),
|
||||
Status::Err("Invalid join room packet.".into())
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_on_error_with_message() {
|
||||
assert_eq!(
|
||||
on_error("Error message".to_string()),
|
||||
Status::Err("Error message".to_string())
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_on_leave_room() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let mut context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![
|
||||
File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
progress: 100,
|
||||
handle: fs::File::create("file1.txt").unwrap(),
|
||||
},
|
||||
File {
|
||||
name: "file2.txt".to_string(),
|
||||
size: 100,
|
||||
progress: 50,
|
||||
handle: fs::File::create("file2.txt").unwrap(),
|
||||
},
|
||||
],
|
||||
sequence: 0,
|
||||
index: 0,
|
||||
progress: 0,
|
||||
length: 0,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
on_leave_room(&mut context, 0),
|
||||
Status::Err("Transfer was interrupted because the host left the room.".into())
|
||||
);
|
||||
context.files[1].progress = 100;
|
||||
assert_eq!(on_leave_room(&mut context, 0), Status::Exit());
|
||||
}
|
||||
#[test]
|
||||
fn test_on_message_text_join() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let mut context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![],
|
||||
sequence: 0,
|
||||
index: 0,
|
||||
progress: 0,
|
||||
length: 0,
|
||||
};
|
||||
|
||||
let text_message = WebSocketMessage::Text(r#"{"type":"join","size":10}"#.to_string());
|
||||
assert_eq!(on_message(&mut context, text_message), Status::Continue());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_on_chunk() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
// let mut context = Context {
|
||||
// hmac: vec![],
|
||||
// sender: sender.clone(),
|
||||
// key: EphemeralSecret::random(&mut OsRng),
|
||||
// shared_key: Some(Aes128Gcm::new(Key::<Aes128Gcm>::from_slice(&[0u8; 16]))),
|
||||
// files: vec![File {
|
||||
// name: "file1.txt".to_string(),
|
||||
// size: 100,
|
||||
// progress: 0,
|
||||
// handle: fs::File::create("file1.txt").unwrap(),
|
||||
// }],
|
||||
// sequence: 0,
|
||||
// index: 0,
|
||||
// progress: 0,
|
||||
// length: 0,
|
||||
// };
|
||||
|
||||
// let chunk_packet = ChunkPacket {
|
||||
// sequence: 0,
|
||||
// chunk: b"Hello, world!".to_vec(),
|
||||
// };
|
||||
// assert_eq!(on_chunk(&mut context, chunk_packet), Status::Continue());
|
||||
// assert_eq!(context.sequence, 1);
|
||||
// assert_eq!(context.length, 14);
|
||||
// assert_eq!(context.progress, 14);
|
||||
|
||||
// let chunk_packet = ChunkPacket {
|
||||
// sequence: 1,
|
||||
// chunk: b"Hello, world!".to_vec(),
|
||||
// };
|
||||
// assert_eq!(
|
||||
// on_chunk(&mut context, chunk_packet),
|
||||
// Status::Err("Expected sequence 1, but got 1.".into())
|
||||
// );
|
||||
|
||||
// context.files.clear();
|
||||
// let chunk_packet = ChunkPacket {
|
||||
// sequence: 0,
|
||||
// chunk: b"Hello, world!".to_vec(),
|
||||
// };
|
||||
// assert_eq!(
|
||||
// on_chunk(&mut context, chunk_packet),
|
||||
// Status::Err("Invalid file index.".into())
|
||||
// );
|
||||
|
||||
// Test a chunk packet with no shared key
|
||||
let mut context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
progress: 0,
|
||||
handle: fs::File::create("file1.txt").unwrap(),
|
||||
}],
|
||||
sequence: 0,
|
||||
index: 0,
|
||||
progress: 0,
|
||||
length: 0,
|
||||
};
|
||||
let chunk_packet = ChunkPacket {
|
||||
sequence: 0,
|
||||
chunk: b"Hello, world!".to_vec(),
|
||||
};
|
||||
assert_eq!(
|
||||
on_chunk(&mut context, chunk_packet),
|
||||
Status::Err("Invalid chunk packet: no shared key established".into())
|
||||
);
|
||||
}
|
||||
}
|
||||
38
caesar-core/src/receiver/http_client.rs
Normal file
38
caesar-core/src/receiver/http_client.rs
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
use hex;
|
||||
use reqwest::{self, Client};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::error;
|
||||
|
||||
use crate::relay::transfer::TransferResponse;
|
||||
|
||||
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
pub async fn download_info(relay: &str, name: &str) -> Result<TransferResponse> {
|
||||
let url = String::from(relay);
|
||||
let hashed_name = Sha256::digest(name.as_bytes());
|
||||
let hashed_string = hex::encode(hashed_name);
|
||||
|
||||
match reqwest::get(format!("{}/download/{}", url, hashed_string)).await {
|
||||
Ok(resp) => match resp.json::<TransferResponse>().await {
|
||||
Ok(res) => Ok(res),
|
||||
Err(e) => Err(Box::new(e)),
|
||||
},
|
||||
Err(err) => {
|
||||
error!("Error: {err}");
|
||||
Err(Box::new(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_success(relay: &str, name: &str) -> Result<()> {
|
||||
let url = String::from(relay);
|
||||
let hashed_name = Sha256::digest(name.as_bytes());
|
||||
let hashed_string = hex::encode(hashed_name);
|
||||
|
||||
let client = Client::new();
|
||||
let _ = client
|
||||
.post(format!("{}/download_success/{}", url, hashed_string))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
113
caesar-core/src/receiver/mod.rs
Normal file
113
caesar-core/src/receiver/mod.rs
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
/// This module is the entry point for the receiver command.
|
||||
/// It contains a single function, `start_receiver`, which is the
|
||||
/// entry point for the receiver program.
|
||||
///
|
||||
/// The `start_receiver` function takes a `String` which is the URL or
|
||||
/// invite code for the room that the receiver should join. If the
|
||||
/// URL is invalid or does not contain an invite code fragment,
|
||||
/// the function falls back to using the command line arguments to get
|
||||
/// the file paths to be sent.
|
||||
///
|
||||
/// The `start_receiver` function first creates a request to connect
|
||||
/// to the WebSocket server with a specific origin. This is done to
|
||||
/// prevent cross-origin requests, which are not allowed by the
|
||||
/// WebSocket protocol.
|
||||
///
|
||||
/// If creating the request succeeds, the function inserts the origin
|
||||
/// into the request headers. Then, it attempts to connect to the
|
||||
/// server using the `connect_async` function from the
|
||||
/// `tokio_tungstenite` crate.
|
||||
///
|
||||
/// If the connection attempt succeeds, the function extracts the
|
||||
/// invite code fragment from the URL and passes it to the `start`
|
||||
/// function in the `receiver::client` module. The `start` function is
|
||||
/// defined in the `receiver::client` module and is the function that
|
||||
/// interacts with the server to receive files.
|
||||
///
|
||||
/// If the connection attempt fails or the URL does not contain an
|
||||
/// invite code fragment, the function falls back to using the command
|
||||
/// line arguments to get the file paths to be sent. It then calls the
|
||||
/// `start` function in the `sender::client` module with the
|
||||
/// WebSocket stream and the file paths. The `start` function in the
|
||||
/// `sender::client` module is defined in the `sender::client`
|
||||
/// module and is the function that sends the files over the
|
||||
/// WebSocket connection.
|
||||
///
|
||||
/// The `start` function takes ownership of the WebSocket stream and
|
||||
/// the file paths, so we pass them by value.
|
||||
pub mod client;
|
||||
pub mod http_client;
|
||||
|
||||
use crate::{receiver::client as receiver, sender::util::replace_protocol};
|
||||
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{client::IntoClientRequest, http::HeaderValue},
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
pub async fn start_receiver(relay: &str, name: &str) {
|
||||
let http_url = replace_protocol(relay);
|
||||
let res = http_client::download_info(http_url.as_str(), name)
|
||||
.await
|
||||
.unwrap();
|
||||
debug!("Got room_id from Server: {:?}", res);
|
||||
let res_ip = String::from("ws://") + res.ip.as_str() + ":9000";
|
||||
|
||||
if let Err(local_err) = start_ws_com(res_ip.as_str(), res.local_room_id.as_str()).await {
|
||||
debug!("Failed to connect local: {local_err}");
|
||||
if let Err(relay_err) = start_ws_com(relay, res.relay_room_id.as_str()).await {
|
||||
debug!("Failed to connect remote: {relay_err}");
|
||||
}
|
||||
}
|
||||
let success = http_client::download_success(http_url.as_str(), name).await;
|
||||
match success {
|
||||
Ok(()) => debug!("Success"),
|
||||
Err(e) => error!("Error: {e:?}"),
|
||||
};
|
||||
|
||||
// if let Err(e) = start_ws_com(res_ip.as_str(), res.local_room_id.as_str()).await {
|
||||
// debug!("Failed to connect local with first room_id: {e}");
|
||||
// if let Err(e) = start_ws_com(relay, res.relay_room_id.as_str()).await {
|
||||
// debug!("Failed to connect remote with first room_id: {e}");
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
pub async fn start_ws_com(relay: &str, name: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = String::from(relay) + "/ws";
|
||||
let Ok(mut request) = url.into_client_request() else {
|
||||
println!("Error: Failed to create request.");
|
||||
return Err("Failed to create request".into());
|
||||
};
|
||||
|
||||
// Insert the origin into the request headers to prevent
|
||||
// cross-origin requests.
|
||||
request
|
||||
.headers_mut()
|
||||
.insert("Origin", HeaderValue::from_str(relay).unwrap());
|
||||
|
||||
println!("Attempting to connect...");
|
||||
|
||||
let _ = match tokio::time::timeout(std::time::Duration::from_secs(5), connect_async(request))
|
||||
.await
|
||||
{
|
||||
Ok(Ok((socket, _))) => {
|
||||
receiver::start(socket, name).await;
|
||||
Ok(())
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Error: Failed to connect: {e:?}");
|
||||
Err(Box::new(e))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error: Timeout reached for local connection attempt");
|
||||
Err(Box::new(e))
|
||||
}?,
|
||||
};
|
||||
// The start function is defined in the
|
||||
// receiver::client module and is the function that interacts with
|
||||
// the server to receive files.
|
||||
// receiver::start(socket, name).await
|
||||
Ok(())
|
||||
}
|
||||
65
caesar-core/src/relay/appstate.rs
Normal file
65
caesar-core/src/relay/appstate.rs
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::relay::room::Room;
|
||||
use crate::relay::transfer::TransferResponse;
|
||||
|
||||
/// A struct that holds all of the rooms that the server knows about.
|
||||
///
|
||||
/// The rooms are stored in a `HashMap` with the room ID as the key and the
|
||||
/// room as the value. This means that looking up a room by its ID is an O(1)
|
||||
/// operation, which is very fast.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AppState {
|
||||
pub rooms: HashMap<String, Room>,
|
||||
pub transfers: Vec<TransferResponse>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
/// Creates a new `Server` with an empty list of rooms.
|
||||
///
|
||||
/// The `rooms` field of the returned `Server` is an empty `HashMap`.
|
||||
/// This means that the server will not have any rooms when it is first
|
||||
/// created.
|
||||
///
|
||||
/// This function returns an `Arc<RwLock<Server>>` because the server
|
||||
/// needs to be shared between different parts of the program. The
|
||||
/// `Arc` makes it so that the server can be shared by multiple threads,
|
||||
/// and the `RwLock` makes it so that the server can be read from and
|
||||
/// written to from multiple threads at the same time.
|
||||
///
|
||||
/// The `Arc` and `RwLock` are both parts of the `tokio` library, which
|
||||
/// provides asynchronous programming tools for Rust.
|
||||
///
|
||||
/// The `Arc` and `RwLock` are used together to create a Mutex-like
|
||||
/// object that can be shared between threads. The main difference
|
||||
/// between a Mutex and an `Arc<RwLock<T>>` is that a Mutex can only be
|
||||
/// locked by one thread at a time, while an `Arc<RwLock<T>>` can be
|
||||
/// locked by multiple threads at the same time.
|
||||
///
|
||||
/// This function is used to create a new `Server` and share it between
|
||||
/// different parts of the program. The `Server` is shared because it
|
||||
/// needs to be able to handle connections from multiple clients at the
|
||||
/// same time.
|
||||
pub fn new() -> Arc<RwLock<AppState>> {
|
||||
// Create a new `Server` instance.
|
||||
Arc::new(RwLock::new(AppState {
|
||||
// Initialize the list of rooms to be empty.
|
||||
rooms: HashMap::new(),
|
||||
transfers: Vec::new(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_new() {
|
||||
let app_state = AppState::new();
|
||||
|
||||
assert!(Arc::ptr_eq(&app_state, &app_state.clone()));
|
||||
}
|
||||
}
|
||||
499
caesar-core/src/relay/client.rs
Normal file
499
caesar-core/src/relay/client.rs
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
use axum::extract::ws::Message;
|
||||
use futures_util::{future::join_all, stream::SplitSink, SinkExt};
|
||||
use std::{sync::Arc, vec};
|
||||
use tokio::{sync::Mutex, sync::RwLock};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::relay::appstate::AppState;
|
||||
use crate::relay::room::Room;
|
||||
use crate::relay::RequestPacket;
|
||||
use crate::relay::ResponsePacket;
|
||||
use uuid::Uuid;
|
||||
|
||||
type Sender = Arc<Mutex<SplitSink<axum::extract::ws::WebSocket, Message>>>;
|
||||
/// This struct represents a single client connection to the server.
|
||||
///
|
||||
/// A `Client` instance holds a `Sender` and a `room_id`.
|
||||
///
|
||||
/// The `Sender` is a type alias for a `tokio::sync::mpsc::Sender<Message>`.
|
||||
/// It is used to send messages to the client.
|
||||
///
|
||||
/// The `room_id` is an `Option<String>`. It is used to keep track of which
|
||||
/// room the client is currently in. If the `room_id` is `None`, then the
|
||||
/// client is not in any room. If the `room_id` is `Some(id)`, where `id` is a
|
||||
/// `String`, then the client is in the room with the ID `id`.
|
||||
///
|
||||
/// The `room_id` is used to keep track of which room the client is in so
|
||||
/// that the server knows which room to send messages to. When a client
|
||||
/// joins a room, their `room_id` is set to the ID of the room that they
|
||||
/// joined. When a client leaves a room, their `room_id` is set to `None`.
|
||||
///
|
||||
/// The `Client` struct is used to keep track of which room each client is
|
||||
/// in. It is used by the `Server` to determine which room to send messages
|
||||
/// to.
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct Client {
|
||||
sender: Sender,
|
||||
room_id: Option<String>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Creates a new `Client` instance.
|
||||
///
|
||||
/// The `sender` argument is a `Sender` for sending messages to the client.
|
||||
/// It is used by the `Server` to send messages to the client.
|
||||
///
|
||||
/// The `room_id` field of the `Client` instance is set to `None` initially.
|
||||
/// This is because the client is not in any room when they first connect
|
||||
/// to the server.
|
||||
///
|
||||
/// The `sender` field of the `Client` instance is used to send messages to
|
||||
/// the client. When the server wants to send a message to the client, it
|
||||
/// uses the `sender` to send the message.
|
||||
///
|
||||
/// The `Client` instance is used by the `Server` to keep track of which
|
||||
/// room each client is in. It is used by the `Server` to determine which
|
||||
/// room to send messages to.
|
||||
pub fn new(sender: Sender) -> Client {
|
||||
Client {
|
||||
sender,
|
||||
room_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a message to a client.
|
||||
///
|
||||
/// This function takes a `sender` argument, which is a `Mutex` guard
|
||||
/// for a WebSocket connection. The `sender` is used to send a message
|
||||
/// to the client.
|
||||
///
|
||||
/// The `message` argument is the message that is sent to the client. It
|
||||
/// is a WebSocket message.
|
||||
///
|
||||
/// This function locks the `sender` Mutex to ensure that only one thread
|
||||
/// can send a message at a time. This is because the SplitSink that the
|
||||
/// `sender` mutex guards is not thread-safe, and sending a message from
|
||||
/// multiple threads could result in the messages being sent out of order.
|
||||
///
|
||||
/// If sending the message fails, this function logs an error message.
|
||||
async fn send(&self, sender: Sender, message: Message) {
|
||||
let mut sender = sender.lock().await;
|
||||
if let Err(error) = sender.send(message).await {
|
||||
error!("Failed to send message to the client: {}", error);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a packet to a client.
|
||||
///
|
||||
/// This function takes a `sender` argument, which is a `Mutex` guard
|
||||
/// for a WebSocket connection. The `sender` is used to send a message
|
||||
/// to the client.
|
||||
///
|
||||
/// The `packet` argument is the packet that is sent to the client. It
|
||||
/// is a struct that contains the data that is being sent.
|
||||
///
|
||||
/// This function serializes the `packet` using serde_json and sends it
|
||||
/// to the client as a WebSocket Text message.
|
||||
///
|
||||
/// This function locks the `sender` Mutex to ensure that only one thread
|
||||
/// can send a message at a time. This is because the SplitSink that the
|
||||
/// `sender` mutex guards is not thread-safe, and sending a message from
|
||||
/// multiple threads could result in the messages being sent out of order.
|
||||
async fn send_packet(&self, sender: Sender, packet: ResponsePacket) {
|
||||
let serialized_packet = serde_json::to_string(&packet).unwrap();
|
||||
|
||||
self.send(sender, Message::Text(serialized_packet)).await;
|
||||
}
|
||||
|
||||
/// Sends an error packet to a client.
|
||||
///
|
||||
/// This function takes a `sender` argument, which is a `Mutex` guard
|
||||
/// for a WebSocket connection. The `sender` is used to send a message
|
||||
/// to the client.
|
||||
///
|
||||
/// The `message` argument is the message that is sent to the client. It
|
||||
/// is a string that describes the error.
|
||||
///
|
||||
/// This function creates an error packet with the `message` and sends it
|
||||
/// to the client using the `send_packet` function.
|
||||
///
|
||||
/// This function locks the `sender` Mutex to ensure that only one thread
|
||||
/// can send a message at a time. This is because the SplitSink that the
|
||||
/// `sender` mutex guards is not thread-safe, and sending a message from
|
||||
/// multiple threads could result in the messages being sent out of order.
|
||||
async fn send_error_packet(&self, sender: Sender, message: String) {
|
||||
let error_packet = ResponsePacket::Error { message };
|
||||
|
||||
self.send_packet(sender, error_packet).await
|
||||
}
|
||||
|
||||
/// Handles a CreateRoom request from a client.
|
||||
///
|
||||
/// This function is called when a client sends a CreateRoom request to
|
||||
/// the server. The server will create a new room with the specified
|
||||
/// size and return the room's identifier to the client.
|
||||
///
|
||||
/// This function takes a `server` argument, which is a `RwLock`
|
||||
/// guard for the server's state. The `server` is used to check if the
|
||||
/// current client is already in a room, and to insert the new room into
|
||||
/// the server's state.
|
||||
///
|
||||
/// If the current client is already in a room, this function returns
|
||||
/// without doing anything. This is to prevent a client from being in
|
||||
/// multiple rooms at the same time.
|
||||
///
|
||||
/// If there is already a room with the same identifier as the one that
|
||||
/// is being created, this function sends an error packet to the client
|
||||
/// and returns.
|
||||
///
|
||||
/// If there is no existing room with the same identifier, this function
|
||||
/// creates a new room with the specified size and inserts it into the
|
||||
/// server's state. It then sends a CreateRoom response packet to the
|
||||
/// client with the room's identifier.
|
||||
///
|
||||
/// This function locks the `server` RwLock to ensure that only one
|
||||
/// thread can access the server's state at a time. This is because the
|
||||
/// server's state is not thread-safe, and accessing it from multiple
|
||||
/// threads could result in undefined behavior.
|
||||
async fn handle_create_room(&mut self, server: &RwLock<AppState>, id: Option<String>) {
|
||||
let mut server = server.write().await;
|
||||
|
||||
// If the current client is already in a room, do nothing.
|
||||
if server.rooms.iter().any(|(_, room)| {
|
||||
room.senders
|
||||
.iter()
|
||||
.any(|sender| Arc::ptr_eq(sender, &self.sender))
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate a new room identifier.
|
||||
let size = Room::DEFAULT_ROOM_SIZE;
|
||||
let room_id = match id {
|
||||
Some(id) => id,
|
||||
None => Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
// If there is already a room with the same identifier, send an error
|
||||
// packet to the client and return.
|
||||
if server.rooms.contains_key(&room_id) {
|
||||
drop(server);
|
||||
|
||||
return self
|
||||
.send_error_packet(
|
||||
self.sender.clone(),
|
||||
"A room with that identifier already exists.".to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Create a new room with the specified size and insert it into the
|
||||
// server's state.
|
||||
let mut room = Room::new(size);
|
||||
room.senders.push(self.sender.clone());
|
||||
|
||||
server.rooms.insert(room_id.clone(), room);
|
||||
|
||||
// Set the client's room ID to the new room's identifier.
|
||||
self.room_id = Some(room_id.clone());
|
||||
|
||||
drop(server);
|
||||
|
||||
// Send a CreateRoom response packet to the client with the room's
|
||||
// identifier.
|
||||
debug!("Room created");
|
||||
self.send_packet(self.sender.clone(), ResponsePacket::Create { id: room_id })
|
||||
.await
|
||||
}
|
||||
|
||||
/// This function is called when the client sends a JoinRoom packet.
|
||||
///
|
||||
/// If the client is already in a room, then this function does nothing.
|
||||
///
|
||||
/// If the client is not in a room, then the function checks if the room
|
||||
/// specified in the packet exists. If the room does not exist, an error
|
||||
/// packet is sent to the client with a message indicating that the room
|
||||
/// does not exist.
|
||||
///
|
||||
/// If the room does exist, then the function checks if the room is full.
|
||||
/// If the room is full, an error packet is sent to the client with a
|
||||
/// message indicating that the room is full.
|
||||
///
|
||||
/// If the room is not full, then the client is added to the room and the
|
||||
/// function sends a JoinRoom response packet to the client with the size
|
||||
/// of the room (excluding the client itself) and a `size` field set to
|
||||
/// `None`. The response packet is sent to all other clients in the room.
|
||||
async fn handle_join_room(&mut self, server: &RwLock<AppState>, room_id: String) {
|
||||
let mut server = server.write().await;
|
||||
|
||||
// If the client is already in a room, do nothing.
|
||||
if server.rooms.iter().any(|(_, room)| {
|
||||
room.senders
|
||||
.iter()
|
||||
.any(|sender| Arc::ptr_eq(sender, &self.sender))
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get a mutable reference to the room specified in the packet.
|
||||
// If the room does not exist, return an error to the client.
|
||||
let Some(room) = server.rooms.get_mut(&room_id) else {
|
||||
drop(server);
|
||||
|
||||
return self
|
||||
.send_error_packet(self.sender.clone(), "The room does not exist.".to_string())
|
||||
.await;
|
||||
};
|
||||
|
||||
// If the room is full, return an error to the client.
|
||||
if room.senders.len() >= room.size {
|
||||
drop(server);
|
||||
|
||||
return self
|
||||
.send_error_packet(self.sender.clone(), "The room is full.".to_string())
|
||||
.await;
|
||||
}
|
||||
|
||||
// Add the client to the room and set the client's room ID to the new
|
||||
// room's identifier.
|
||||
room.senders.push(self.sender.clone());
|
||||
self.room_id = Some(room_id);
|
||||
|
||||
// Create a list of futures to send JoinRoom response packets to all
|
||||
// other clients in the room. The `size` field of the response packet is
|
||||
// set to `None` if the client sending the packet is the one joining the
|
||||
// room. Otherwise, the `size` field is set to the number of clients in
|
||||
// the room minus one (to exclude the client joining the room).
|
||||
let mut futures = vec![];
|
||||
for sender in &room.senders {
|
||||
if Arc::ptr_eq(sender, &self.sender) {
|
||||
futures.push(self.send_packet(
|
||||
sender.clone(),
|
||||
ResponsePacket::Join {
|
||||
size: Some(room.senders.len() - 1),
|
||||
},
|
||||
));
|
||||
} else {
|
||||
futures.push(self.send_packet(sender.clone(), ResponsePacket::Join { size: None }));
|
||||
}
|
||||
}
|
||||
|
||||
drop(server);
|
||||
join_all(futures).await;
|
||||
}
|
||||
|
||||
/// Handles a request to leave a room.
|
||||
///
|
||||
/// This function is called when a client sends a `LeaveRoom` request
|
||||
/// packet. The function obtains a write lock on the server's state and
|
||||
/// does the following:
|
||||
///
|
||||
/// 1. Gets the room ID of the client who sent the request. If the client is
|
||||
/// not in a room, the function returns early.
|
||||
/// 2. Tries to get a mutable reference to the room with the obtained room
|
||||
/// ID. If the room does not exist, the function returns early.
|
||||
/// 3. Finds the index of the client's sender in the room's list of senders.
|
||||
/// If the client is not in the room, the function returns early.
|
||||
/// 4. Removes the client's sender from the room's list of senders.
|
||||
/// 5. Sets the client's room ID to `None`.
|
||||
/// 6. Creates a list of futures to send `LeaveRoom` response packets to
|
||||
/// all other clients in the room. The `index` field of the response
|
||||
/// packet is set to the index of the client's sender in the room's list
|
||||
/// of senders.
|
||||
/// 7. If the room is now empty, removes the room from the server's list
|
||||
/// of rooms.
|
||||
/// 8. Drops the write lock on the server's state.
|
||||
/// 9. Waits for all futures to complete.
|
||||
async fn handle_leave_room(&mut self, server: &RwLock<AppState>) {
|
||||
// Obtain a write lock on the server's state.
|
||||
let mut server = server.write().await;
|
||||
|
||||
// Get the room ID of the client who sent the request.
|
||||
let Some(room_id) = self.room_id.clone() else {
|
||||
// If the client is not in a room, return early.
|
||||
return;
|
||||
};
|
||||
|
||||
// Try to get a mutable reference to the room with the obtained room ID.
|
||||
let Some(room) = server.rooms.get_mut(&room_id) else {
|
||||
// If the room does not exist, return early.
|
||||
return;
|
||||
};
|
||||
|
||||
// Find the index of the client's sender in the room's list of senders.
|
||||
let Some(index) = room
|
||||
.senders
|
||||
.iter()
|
||||
.position(|sender| Arc::ptr_eq(sender, &self.sender))
|
||||
else {
|
||||
// If the client is not in the room, return early.
|
||||
return;
|
||||
};
|
||||
|
||||
// Remove the client's sender from the room's list of senders.
|
||||
room.senders.remove(index);
|
||||
|
||||
// Set the client's room ID to `None`.
|
||||
self.room_id = None;
|
||||
|
||||
// Create a list of futures to send `LeaveRoom` response packets to
|
||||
// all other clients in the room. The `index` field of the response
|
||||
// packet is set to the index of the client's sender in the room's list
|
||||
// of senders.
|
||||
let mut futures = vec![];
|
||||
for sender in &room.senders {
|
||||
futures.push(self.send_packet(sender.clone(), ResponsePacket::Leave { index }));
|
||||
}
|
||||
|
||||
// If the room is now empty, removes the room from the server's list
|
||||
// of rooms.
|
||||
if room.senders.is_empty() {
|
||||
server.rooms.remove(&room_id);
|
||||
}
|
||||
|
||||
// Drop the write lock on the server's state.
|
||||
drop(server);
|
||||
|
||||
// Wait for all futures to complete.
|
||||
join_all(futures).await;
|
||||
}
|
||||
|
||||
/// This function handles an incoming message from a client.
|
||||
///
|
||||
/// The message can be one of four types: `Text`, `Binary`, `Ping`, or `Close`.
|
||||
///
|
||||
/// If the message is `Text`, the function parses the message as a `RequestPacket` and
|
||||
/// calls the appropriate function to handle the request. If the message cannot be
|
||||
/// parsed as a `RequestPacket`, the function does nothing and returns early.
|
||||
///
|
||||
/// If the message is `Binary`, the function first acquires a read lock on the server's
|
||||
/// state. If the client is not currently in a room, the function drops the read lock and
|
||||
/// returns early. If the client is not in a room, or if the room does not exist, the
|
||||
/// function drops the read lock and returns early.
|
||||
///
|
||||
/// The function then finds the index of the client's sender in the room's list of
|
||||
/// senders. If the client's sender is not in the room's list of senders, the function
|
||||
/// drops the read lock and returns early.
|
||||
///
|
||||
/// The function then gets the binary data from the message and sets the first byte to
|
||||
/// the index of the client's sender in the room's list of senders. If there is no
|
||||
/// binary data in the message, the function drops the read lock and returns early.
|
||||
///
|
||||
/// The function then determines where to send the message. If the first byte of the
|
||||
/// message is less than the number of clients in the room, the function sends the message
|
||||
/// to the client at that index in the room's list of senders. If the first byte of the
|
||||
/// message is equal to the number of clients in the room plus one, the function sends the
|
||||
/// message to all clients in the room, excluding the client that sent the message.
|
||||
///
|
||||
/// If the first byte of the message is any other value, the function drops the read
|
||||
/// lock and returns early.
|
||||
///
|
||||
/// Finally, the function drops the read lock and waits for all futures to complete.
|
||||
///
|
||||
/// If the message is `Ping`, the function prints a message to stdout.
|
||||
///
|
||||
/// If the message is `Pong`, the function prints a message to stdout.
|
||||
///
|
||||
/// If the message is `Close`, the function prints a message to stdout and calls the
|
||||
/// `handle_close` function.
|
||||
pub async fn handle_message(&mut self, server: &RwLock<AppState>, message: Message) {
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
let packet = match serde_json::from_str(&text) {
|
||||
Ok(packet) => packet,
|
||||
Err(_) => return,
|
||||
};
|
||||
match packet {
|
||||
RequestPacket::Create { id } => self.handle_create_room(server, id).await,
|
||||
RequestPacket::Join { id } => self.handle_join_room(server, id).await,
|
||||
RequestPacket::Leave => self.handle_leave_room(server).await,
|
||||
}
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
// Acquire a read lock on the server's state.
|
||||
let server = server.read().await;
|
||||
|
||||
// If the client is not currently in a room, return early.
|
||||
let Some(room_id) = &self.room_id else {
|
||||
drop(server);
|
||||
return;
|
||||
};
|
||||
|
||||
// If the room does not exist, return early.
|
||||
let Some(room) = server.rooms.get(room_id) else {
|
||||
drop(server);
|
||||
return;
|
||||
};
|
||||
|
||||
// Find the index of the client's sender in the room's list of senders.
|
||||
let Some(index) = room
|
||||
.senders
|
||||
.iter()
|
||||
.position(|sender| Arc::ptr_eq(sender, &self.sender))
|
||||
else {
|
||||
drop(server);
|
||||
return;
|
||||
};
|
||||
|
||||
// Get the binary data from the message and set the first byte to
|
||||
// the index of the client's sender in the room's list of senders.
|
||||
let mut data = message.into_data();
|
||||
if data.is_empty() {
|
||||
drop(server);
|
||||
return;
|
||||
}
|
||||
|
||||
let source = u8::try_from(index).unwrap();
|
||||
|
||||
// Determine where to send the message.
|
||||
let destination = usize::from(data[0]);
|
||||
data[0] = source;
|
||||
|
||||
// Send the message to the client at the destination index in the
|
||||
// room's list of senders.
|
||||
if destination < room.senders.len() {
|
||||
let sender = room.senders[destination].clone();
|
||||
|
||||
drop(server);
|
||||
return self.send(sender, Message::Binary(data)).await;
|
||||
}
|
||||
|
||||
// Send the message to all clients in the room, excluding the
|
||||
// client that sent the message.
|
||||
if destination == usize::from(u8::MAX) {
|
||||
let mut futures = vec![];
|
||||
for sender in &room.senders {
|
||||
if Arc::ptr_eq(sender, &self.sender) {
|
||||
continue;
|
||||
}
|
||||
|
||||
futures.push(self.send(sender.clone(), Message::Binary(data.clone())));
|
||||
}
|
||||
|
||||
drop(server);
|
||||
join_all(futures).await;
|
||||
}
|
||||
}
|
||||
Message::Ping(_) => {
|
||||
println!("Got Message Type Ping");
|
||||
}
|
||||
Message::Pong(_) => {
|
||||
println!("Got Message Type Pong");
|
||||
}
|
||||
Message::Close(_) => {
|
||||
println!("Got Message Type Close");
|
||||
self.handle_close(server).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_close(&mut self, server: &RwLock<AppState>) {
|
||||
self.handle_leave_room(server).await
|
||||
}
|
||||
}
|
||||
// TODO: Add tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// use super::*;
|
||||
}
|
||||
72
caesar-core/src/relay/mod.rs
Normal file
72
caesar-core/src/relay/mod.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
pub mod appstate;
|
||||
pub mod client;
|
||||
pub mod room;
|
||||
pub mod server;
|
||||
pub mod transfer;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
// This enum is used to represent the different types of requests that a client
|
||||
// can send to the server.
|
||||
//
|
||||
// The requests that a client can send are:
|
||||
//
|
||||
// * Join: A request to join a room. The request contains the ID of the room
|
||||
// that the client wants to join.
|
||||
// * Create: A request to create a new room.
|
||||
// * Leave: A request to leave the current room.
|
||||
pub enum RequestPacket {
|
||||
Join {
|
||||
// The ID of the room that the client wants to join.
|
||||
id: String,
|
||||
},
|
||||
Create {
|
||||
id: Option<String>,
|
||||
},
|
||||
Leave,
|
||||
}
|
||||
|
||||
/// This enum is used to represent the different types of responses that the
|
||||
/// server can send to the client.
|
||||
///
|
||||
/// The responses that the server can send are:
|
||||
///
|
||||
/// * Join: A response to a `Join` request from the client. If the client
|
||||
/// successfully joined a room, the `size` field will be `Some` and contain
|
||||
/// the size of the room. If the client could not join a room, the `size` field
|
||||
/// will be `None`.
|
||||
/// * Create: A response to a `Create` request from the client. If the server
|
||||
/// successfully created a room, the `id` field will contain the ID of the
|
||||
/// room. If the server could not create a room, the `id` field will be empty.
|
||||
/// * Leave: A response to a `Leave` request from the client. If the client
|
||||
/// successfully left a room, the `index` field will contain the index of the
|
||||
/// client that left the room. If the client could not leave a room, the
|
||||
/// `index` field will be 0.
|
||||
/// * Error: A response to indicate that an error occurred. The `message`
|
||||
/// field will contain a description of the error.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum ResponsePacket {
|
||||
Join {
|
||||
/// The size of the room that the client joined. If the client could
|
||||
/// not join a room, this field will be `None`.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
size: Option<usize>,
|
||||
},
|
||||
Create {
|
||||
/// The ID of the room that the server created. If the server could
|
||||
/// not create a room, this field will be empty.
|
||||
id: String,
|
||||
},
|
||||
Leave {
|
||||
/// The index of the client that left the room. If the client could not
|
||||
/// leave a room, this field will be 0.
|
||||
index: usize,
|
||||
},
|
||||
Error {
|
||||
/// A description of the error that occurred.
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
86
caesar-core/src/relay/room.rs
Normal file
86
caesar-core/src/relay/room.rs
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
use axum::extract::ws::{Message, WebSocket};
|
||||
use futures_util::stream::SplitSink;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// A type alias for a sender to a WebSocket connection.
|
||||
//
|
||||
// The sender is a mutex-guarded, split sink of a WebSocket stream and Message
|
||||
// values. It is used to send messages to a client.
|
||||
//
|
||||
// The Mutex is used to ensure that only one thread can send a message at a
|
||||
// time. This is because the SplitSink is not thread-safe, and sending a
|
||||
// message from multiple threads could result in the messages being sent
|
||||
// out of order.
|
||||
//
|
||||
// The SplitSink is used to send messages to a client. It is the part of the
|
||||
// WebSocket stream that handles the sending of messages.
|
||||
//
|
||||
// The WebSocket stream is the underlying connection to the client. It is used
|
||||
// to send and receive messages.
|
||||
//
|
||||
// The Message value is the type of data that is sent over the WebSocket
|
||||
// connection. It is a struct that contains the data that is being sent.
|
||||
//
|
||||
// The type alias is used so that the type is not mentioned every time it is
|
||||
// used. This makes the code easier to read and understand.
|
||||
type Sender = Arc<Mutex<SplitSink<WebSocket, Message>>>;
|
||||
|
||||
/// A `Room` is a collection of clients that are connected to each other.
|
||||
///
|
||||
/// Each room has a set of clients, represented by a `Vec` of `Sender`
|
||||
/// instances. The `Sender` instances are used to send messages to the
|
||||
/// clients in the room.
|
||||
///
|
||||
/// The `senders` field is the list of senders that are connected to each
|
||||
/// other. Each sender is a mutex-guarded, split sink of a WebSocket
|
||||
/// stream and Message values. This is explained in more detail in the
|
||||
/// documentation for the `Sender` type alias in the `packets` module.
|
||||
///
|
||||
/// The `size` field is the maximum number of clients that a room can have.
|
||||
/// When a room reaches its maximum size, no more clients can join the room.
|
||||
/// This is used to prevent rooms from getting too full and causing the
|
||||
/// server to run out of memory.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Room {
|
||||
pub senders: Vec<Sender>,
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl Room {
|
||||
/// The default size of a room.
|
||||
///
|
||||
/// This is the size that a room will have when it is created.
|
||||
pub const DEFAULT_ROOM_SIZE: usize = 2;
|
||||
|
||||
/// Creates a new `Room` with the given size.
|
||||
///
|
||||
/// The `size` parameter is the maximum number of clients that can join the
|
||||
/// room. If `size` is 0, then the room will not be able to hold any
|
||||
/// clients.
|
||||
///
|
||||
/// The `senders` field of the returned `Room` is an empty vector.
|
||||
///
|
||||
/// The `size` field of the returned `Room` is `size`.
|
||||
pub fn new(size: usize) -> Room {
|
||||
Room {
|
||||
// Initialize the list of senders to be empty.
|
||||
senders: Vec::new(),
|
||||
// Set the size of the room.
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_room_new() {
|
||||
let room = Room::new(5);
|
||||
|
||||
assert_eq!(room.size, 5);
|
||||
|
||||
assert!(room.senders.is_empty());
|
||||
}
|
||||
}
|
||||
386
caesar-core/src/relay/server.rs
Normal file
386
caesar-core/src/relay/server.rs
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
/// This function starts the WebSocket server.
|
||||
///
|
||||
/// It configures the server to listen on the specified host and port. If
|
||||
/// these values are not specified in the environment, it falls back to using
|
||||
/// the defaults of "0.0.0.0" for the host and "8000" for the port.
|
||||
///
|
||||
/// It then sets up the application routes for the server. In this case, the
|
||||
/// only route is for the WebSocket connection.
|
||||
///
|
||||
/// The WebSocket route requires a `ConnectInfo` extractor to get the client's
|
||||
/// IP address, which is then used to store the client in a data structure
|
||||
/// keyed by their IP address. This allows for efficient lookup of clients by
|
||||
/// their IP address.
|
||||
///
|
||||
/// Finally, it starts the server by binding to the specified host and port,
|
||||
/// and running the application. If the server fails to bind to the specified
|
||||
/// host and port, it logs an error and exits.
|
||||
use axum::{
|
||||
extract::{ws::WebSocket, Json, Path, State, WebSocketUpgrade},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post, put},
|
||||
Router,
|
||||
};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::json;
|
||||
use std::{env, net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
net::TcpListener,
|
||||
signal,
|
||||
sync::{Mutex, RwLock},
|
||||
};
|
||||
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::relay::client::Client;
|
||||
use crate::relay::transfer::TransferResponse;
|
||||
use crate::relay::{appstate::AppState, transfer::TransferRequest};
|
||||
|
||||
/// This function starts the WebSocket server.
|
||||
///
|
||||
/// It retrieves the environment variables that define how the server should
|
||||
/// be configured. If any of these variables are not defined, it sets a
|
||||
/// reasonable default value.
|
||||
///
|
||||
/// The environment variables are:
|
||||
///
|
||||
/// * `APP_ENVIRONMENT`: the environment the server is running in (defaults
|
||||
/// to "development").
|
||||
/// * `APP_HOST`: the host the server should listen on (defaults to "0.0.0.0").
|
||||
/// * `APP_PORT`: the port the server should listen on (defaults to "8000").
|
||||
/// * `APP_DOMAIN`: the domain the server is accessible at (defaults to "").
|
||||
///
|
||||
/// It then sets up the application routes for the server. In this case, the
|
||||
/// only route is for the WebSocket connection.
|
||||
///
|
||||
/// The WebSocket route requires a `ConnectInfo` extractor to get the client's
|
||||
/// IP address, which is then used to store the client in a data structure
|
||||
/// keyed by their IP address. This allows for efficient lookup of clients by
|
||||
/// their IP address.
|
||||
///
|
||||
/// Finally, it starts the server by binding to the specified host and port,
|
||||
/// and running the application. If the server fails to bind to the specified
|
||||
/// host and port, it logs an error and exits.
|
||||
pub async fn start_ws(port: Option<&i32>, listen_addr: Option<&String>) {
|
||||
// Retrieve environment variables and set defaults if necessary.
|
||||
let app_environemt = env::var("APP_ENVIRONMENT").unwrap_or("development".to_string());
|
||||
let app_host = match listen_addr {
|
||||
Some(address) => address.to_string(),
|
||||
None => env::var("APP_HOST").unwrap_or("0.0.0.0".to_string()),
|
||||
};
|
||||
let app_port = match port {
|
||||
Some(port) => port.to_string(),
|
||||
None => env::var("APP_PORT").unwrap_or("8000".to_string()),
|
||||
};
|
||||
|
||||
// Log information about the server's configuration.
|
||||
debug!("Server configured to accept connections on host {app_host}...",);
|
||||
debug!("Server configured to listen connections on port {app_port}...",);
|
||||
|
||||
// Based on the environment variable, set the logging level.
|
||||
match app_environemt.as_str() {
|
||||
"development" => {
|
||||
debug!("Running in development mode");
|
||||
}
|
||||
"production" => {
|
||||
debug!("Running in production mode");
|
||||
}
|
||||
_ => {
|
||||
debug!("Running in development mode");
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new server data structure.
|
||||
let server = AppState::new();
|
||||
|
||||
// Set up the application routes.
|
||||
let app = Router::new()
|
||||
.route("/ws", get(ws_handler))
|
||||
.route("/upload", put(upload_info))
|
||||
.route("/download/:name", get(download_info))
|
||||
.route("/download_success/:name", post(download_success))
|
||||
.with_state(server)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||
);
|
||||
|
||||
// Attempt to bind to the specified host and port.
|
||||
if let Ok(listener) = TcpListener::bind(&format!("{}:{}", app_host, app_port)).await {
|
||||
// Log successful binding.
|
||||
info!("Listening on: {}", listener.local_addr().unwrap());
|
||||
|
||||
// Run the server.
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
// Log binding failure and exit.
|
||||
error!("Failed to listen on: {}:{}", app_host, app_port);
|
||||
}
|
||||
}
|
||||
|
||||
/// This function is an endpoint for the WebSocket route.
|
||||
///
|
||||
/// This function is called whenever a client makes a WebSocket request to
|
||||
/// the `/ws` endpoint.
|
||||
///
|
||||
/// The function takes two arguments:
|
||||
///
|
||||
/// - `ws`: This is the WebSocketUpgrade object, which is used to upgrade the
|
||||
/// HTTP connection to a WebSocket connection.
|
||||
/// - `State(shared_state)`: This is the state of the server, which is stored
|
||||
/// in a read-write lock. The state is shared between all WebSocket
|
||||
/// connections.
|
||||
/// - `ConnectInfo(addr)`: This is the information about the client that
|
||||
/// connected to the server. The function uses this information to log the
|
||||
/// address of the client that connected to the server.
|
||||
///
|
||||
/// The function upgrades the HTTP connection to a WebSocket connection using
|
||||
/// the `ws` argument. It then passes the upgraded WebSocket connection, along
|
||||
/// with the state of the server, to the `handle_socket` function.
|
||||
///
|
||||
/// The `handle_socket` function is defined in the `src/relay/mod.rs` file. It
|
||||
/// is the function that handles the WebSocket connection.
|
||||
///
|
||||
/// The `handle_socket` function takes three arguments:
|
||||
///
|
||||
/// - `socket`: This is the WebSocket connection that it should handle.
|
||||
/// - `who`: This is the address of the client that connected to the server.
|
||||
/// - `rooms`: This is the state of the server, which is stored in a read-write
|
||||
/// lock. The state is shared between all WebSocket connections.
|
||||
///
|
||||
/// The `handle_socket` function handles the WebSocket connection by calling
|
||||
/// the `handle_message` function on a `Client` object that it creates. The
|
||||
/// `handle_message` function is defined in the `src/relay/client.rs` file. The
|
||||
/// `handle_message` function handles incoming messages from the client and
|
||||
/// takes care of sending the appropriate response back to the client.
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(shared_state): State<Arc<RwLock<AppState>>>,
|
||||
// ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
) -> impl IntoResponse {
|
||||
debug!("Got Request on Websocket route");
|
||||
// debug!("WebSocket connection established from:{}", addr.to_string());
|
||||
debug!("Upgrading Connection");
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, shared_state))
|
||||
}
|
||||
|
||||
/// This function is called when a new WebSocket connection is established.
|
||||
/// The function takes three arguments:
|
||||
///
|
||||
/// - `socket`: This is the WebSocket connection that it should handle.
|
||||
/// - `who`: This is the address of the client that connected to the server.
|
||||
/// - `rooms`: This is the state of the server, which is stored in a read-write
|
||||
/// lock. The state is shared between all WebSocket connections.
|
||||
///
|
||||
/// The function creates a `Client` object, which will handle the WebSocket
|
||||
/// connection. The `Client` object is created with an Arc-wrapped Mutex
|
||||
/// containing the `sender` of the WebSocket connection. The `sender` is used to
|
||||
/// send messages to the client.
|
||||
///
|
||||
/// The function then creates a new `split` of the WebSocket connection, which
|
||||
/// is a pair of a `sender` and a `receiver`. The `sender` is used to send
|
||||
/// messages to the client, and the `receiver` is used to receive messages from
|
||||
/// the client. The `receiver` is wrapped in a `Stream` (which is an async
|
||||
/// iterator) so that the function can use the `next` method to receive messages
|
||||
/// from the client.
|
||||
///
|
||||
/// The function then enters a loop that receives incoming messages from the
|
||||
/// client and handles them. For each received message, the function calls the
|
||||
/// `handle_message` method on the `Client` object that it created. The
|
||||
/// `handle_message` method is defined in the `src/relay/client.rs` file. The
|
||||
/// `handle_message` method handles incoming messages from the client and
|
||||
/// takes care of sending the appropriate response back to the client.
|
||||
///
|
||||
/// If the function encounters an error while reading a message from the
|
||||
/// client, it logs the error and breaks out of the loop.
|
||||
///
|
||||
/// After the loop finishes (either because an error occurred or because the
|
||||
/// client disconnected), the function calls the `handle_close` method on the
|
||||
/// `Client` object that it created. The `handle_close` method is defined in the
|
||||
/// `src/relay/client.rs` file. The `handle_close` method handles the close event
|
||||
/// from the client.
|
||||
async fn handle_socket(socket: WebSocket, rooms: Arc<RwLock<AppState>>) {
|
||||
let (sender, mut receiver) = socket.split();
|
||||
|
||||
let sender = Arc::new(Mutex::new(sender));
|
||||
let mut client = Client::new(sender.clone());
|
||||
while let Some(message) = receiver.next().await {
|
||||
match message {
|
||||
Ok(message) => {
|
||||
client.handle_message(&rooms, message).await;
|
||||
}
|
||||
Err(error) => {
|
||||
warn!("Failed to read message from client: {}", error);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle the close event from the client.
|
||||
client.handle_close(&rooms).await
|
||||
}
|
||||
|
||||
/// This function sets up a signal handler for SIGINT (Ctrl+C) and SIGTERM
|
||||
/// (terminate) on Unix platforms. It does nothing on non-Unix platforms.
|
||||
///
|
||||
/// The function installs two signal handlers: one for SIGINT and one for
|
||||
/// SIGTERM. When either of these signals is received, the signal handler
|
||||
/// simply resolves the future with `()`. This allows the main function to
|
||||
/// wait for the signal handler to trigger a shutdown.
|
||||
///
|
||||
/// The function uses the `tokio::select!` macro to wait for either of the
|
||||
/// signal handlers to resolve. When the future returned by `tokio::select!`
|
||||
/// resolves, the function simply drops the value and does nothing else.
|
||||
///
|
||||
/// The function does not actually do anything itself. It simply waits for
|
||||
/// one of the signal handlers to trigger a shutdown.
|
||||
async fn shutdown_signal() {
|
||||
// Install a signal handler for SIGINT (Ctrl+C). This future resolves
|
||||
// when the user presses Ctrl+C.
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
// Install a signal handler for SIGTERM (terminate). This future
|
||||
// resolves when the operating system sends a SIGTERM signal to the
|
||||
// program.
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
// If we are not on a Unix platform, we don't need to install a signal
|
||||
// handler for SIGTERM. Instead, we create a future that never resolves.
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
// Wait for either of the two signal handlers to resolve. When one of them
|
||||
// resolves, the other one may still be waiting, but it doesn't matter
|
||||
// because we don't need to do anything else.
|
||||
tokio::select! {
|
||||
// If the Ctrl+C signal handler resolves, drop the value and do
|
||||
// nothing else.
|
||||
_ = ctrl_c => {},
|
||||
// If the terminate signal handler resolves, drop the value and do
|
||||
// nothing else.
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn upload_info(
|
||||
State(shared_state): State<Arc<RwLock<AppState>>>,
|
||||
// ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Json(payload): Json<TransferRequest>,
|
||||
) -> impl IntoResponse {
|
||||
// debug!("Got upload request from {}", addr.ip().to_string());
|
||||
let mut data = shared_state.write().await;
|
||||
match data
|
||||
.transfers
|
||||
.iter_mut()
|
||||
.find(|request| request.name == payload.name)
|
||||
{
|
||||
Some(request) => {
|
||||
debug!("Found Transfer");
|
||||
debug!("Request is: {:?}", request);
|
||||
if request.relay_room_id.is_empty() {
|
||||
request.relay_room_id = payload.relay_room_id;
|
||||
debug!("Found Transfer and updated");
|
||||
debug!("request is: {:#?}", request);
|
||||
(StatusCode::OK, Json(request.clone()))
|
||||
} else {
|
||||
request.local_room_id = payload.local_room_id;
|
||||
debug!("Found Transfer and updated");
|
||||
debug!("request is: {:#?}", request);
|
||||
(StatusCode::OK, Json(request.clone()))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut local = String::from("");
|
||||
let mut relay = String::from("");
|
||||
if payload.relay_room_id.is_empty() {
|
||||
local = payload.local_room_id;
|
||||
} else {
|
||||
relay = payload.relay_room_id;
|
||||
}
|
||||
let t_request = TransferResponse {
|
||||
name: payload.name,
|
||||
ip: payload.ip,
|
||||
local_room_id: local,
|
||||
relay_room_id: relay,
|
||||
};
|
||||
data.transfers.push(t_request.clone());
|
||||
|
||||
debug!("New TransferRequest created");
|
||||
debug!("Actual AppState is {:#?}", *data);
|
||||
|
||||
(StatusCode::CREATED, Json(t_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_info(
|
||||
State(shared_state): State<Arc<RwLock<AppState>>>,
|
||||
Path(name): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let data = shared_state.write().await;
|
||||
match data.transfers.iter().find(|request| request.name == name) {
|
||||
Some(request) => {
|
||||
debug!("Found transfer name.");
|
||||
(StatusCode::OK, Json(request.clone()))
|
||||
}
|
||||
None => {
|
||||
warn!("couldn't find transfer-name: {}", name);
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(TransferResponse {
|
||||
name: String::from(""),
|
||||
ip: String::from(""),
|
||||
local_room_id: String::from(""),
|
||||
relay_room_id: String::from(""),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_success(
|
||||
State(shared_state): State<Arc<RwLock<AppState>>>,
|
||||
Path(name): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let mut data = shared_state.write().await;
|
||||
if let Some(index) = data
|
||||
.transfers
|
||||
.iter()
|
||||
.position(|request| request.name == name)
|
||||
{
|
||||
debug!("Found Transfer by name '{name}'");
|
||||
data.transfers.remove(index);
|
||||
debug!("Transfer deleted");
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"message": "transfer deleted"
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
warn!("couldn't find transfer-name: {}", name);
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({
|
||||
"message": "transfer not found"
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
61
caesar-core/src/relay/transfer.rs
Normal file
61
caesar-core/src/relay/transfer.rs
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct TransferRequest {
|
||||
pub name: String,
|
||||
pub ip: String,
|
||||
pub local_room_id: String,
|
||||
pub relay_room_id: String,
|
||||
}
|
||||
impl TransferRequest {
|
||||
pub fn new(name: String, ip: String, local_room_id: String, relay_room_id: String) -> Self {
|
||||
Self {
|
||||
name,
|
||||
ip,
|
||||
local_room_id,
|
||||
relay_room_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct TransferResponse {
|
||||
pub name: String,
|
||||
pub ip: String,
|
||||
pub local_room_id: String,
|
||||
pub relay_room_id: String,
|
||||
}
|
||||
|
||||
impl TransferResponse {
|
||||
pub fn new(name: String, ip: String, local_room_id: String, relay_room_id: String) -> Self {
|
||||
Self {
|
||||
name,
|
||||
ip,
|
||||
local_room_id,
|
||||
relay_room_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_new() {
|
||||
let transfer = TransferResponse {
|
||||
name: "Test".to_string(),
|
||||
ip: "127.0.0.1".to_string(),
|
||||
local_room_id: "This_is_a_test_room_id".to_string(),
|
||||
relay_room_id: "This_is_a_test_room_id".to_string(),
|
||||
};
|
||||
assert_eq!(
|
||||
TransferResponse::new(
|
||||
"Test".to_string(),
|
||||
"127.0.0.1".to_string(),
|
||||
"This_is_a_test_room_id".to_string(),
|
||||
"This_is_a_test_room_id".to_string(),
|
||||
),
|
||||
transfer
|
||||
)
|
||||
}
|
||||
}
|
||||
969
caesar-core/src/sender/client.rs
Normal file
969
caesar-core/src/sender/client.rs
Normal file
|
|
@ -0,0 +1,969 @@
|
|||
use crate::sender::http_client::send_info;
|
||||
use crate::sender::util::{hash_random_name, replace_protocol};
|
||||
use crate::shared::{
|
||||
packets::{
|
||||
list_packet, packet::Value, ChunkPacket, HandshakePacket, HandshakeResponsePacket,
|
||||
ListPacket, Packet, ProgressPacket,
|
||||
},
|
||||
JsonPacket, JsonPacketResponse, JsonPacketSender, PacketSender, Sender, Socket, Status,
|
||||
};
|
||||
|
||||
use aes_gcm::{aead::Aead, Aes128Gcm, Key};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
|
||||
use hmac::{Hmac, Mac};
|
||||
use p256::{ecdh::EphemeralSecret, PublicKey};
|
||||
use prost::Message;
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
use sha2::Sha256;
|
||||
use std::{
|
||||
fs,
|
||||
io::{stdout, Write},
|
||||
path::Path,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{io::AsyncReadExt, task::JoinHandle, time::sleep};
|
||||
use tokio_tungstenite::tungstenite::{protocol::Message as WebSocketMessage, Error};
|
||||
use tracing::{debug, error};
|
||||
|
||||
const DESTINATION: u8 = 1;
|
||||
const NONCE_SIZE: usize = 12;
|
||||
const MAX_CHUNK_SIZE: isize = u16::MAX as isize;
|
||||
const DELAY: Duration = Duration::from_millis(750);
|
||||
|
||||
/// A file that is to be sent.
|
||||
///
|
||||
/// This structure contains all the information about a file that is to be
|
||||
/// sent. It is used to keep track of the files that a user wants to send.
|
||||
#[derive(Clone)]
|
||||
struct File {
|
||||
/// The path to the file on the file system.
|
||||
///
|
||||
/// This is the path to the file on the user's file system. The path is
|
||||
/// used to open the file and read its contents.
|
||||
path: String,
|
||||
|
||||
/// The name of the file.
|
||||
///
|
||||
/// This is the name that the file will have when it is received by the
|
||||
/// receiver. This name is used when creating the file on the receiver's
|
||||
/// file system.
|
||||
name: String,
|
||||
|
||||
/// The size of the file in bytes.
|
||||
///
|
||||
/// This is the size of the file in bytes. The size is used to calculate
|
||||
/// the number of chunks that the file will be split into, and is also
|
||||
/// used to keep track of the progress of the file being sent.
|
||||
size: u64,
|
||||
}
|
||||
|
||||
/// The context for the sender.
|
||||
///
|
||||
/// This structure contains all the information that the sender needs in order
|
||||
/// to function properly. It is used to keep track of the state of the
|
||||
/// sender, and to pass information between functions.
|
||||
struct Context {
|
||||
/// The HMAC key for the sender.
|
||||
///
|
||||
/// This is the key that is used to sign packets. The key is also used to
|
||||
/// generate a URL that the receiver can use to join the session.
|
||||
hmac: Vec<u8>,
|
||||
|
||||
/// The sender that is used to send packets to the receiver.
|
||||
///
|
||||
/// This sender is used to send handshake packets, list packets, chunk
|
||||
/// packets, and progress packets to the receiver.
|
||||
sender: Sender,
|
||||
|
||||
/// The ephemeral keypair that is used to establish a shared key with the
|
||||
/// receiver.
|
||||
///
|
||||
/// This key is used to establish a shared key between the sender and
|
||||
/// receiver. The key is ephemeral, meaning that it is only used once in
|
||||
/// the session. The key is generated when the sender is created, and is
|
||||
/// then discarded after the session is complete.
|
||||
key: EphemeralSecret,
|
||||
|
||||
/// The files that the sender wants to send.
|
||||
///
|
||||
/// This vec contains all the information about the files that the sender
|
||||
/// wants to send. The vec is filled when the user specifies the files to
|
||||
/// send using the command line arguments.
|
||||
files: Vec<File>,
|
||||
|
||||
/// The shared key that is used to encrypt packets.
|
||||
///
|
||||
/// This value is set to `None` initially, and is set to `Some` when the
|
||||
/// shared key is established with the receiver. The shared key is used to
|
||||
/// encrypt packets that are sent to the receiver.
|
||||
shared_key: Option<Aes128Gcm>,
|
||||
|
||||
/// The task that is running in the background to send chunks of files to
|
||||
/// the receiver.
|
||||
///
|
||||
/// This task is created when the sender is created, and is used to send
|
||||
/// chunks of files to the receiver in the background. The task is
|
||||
/// initially set to `None`, but is set to `Some` when the task is spawned.
|
||||
/// The task is used to cancel the background task when the sender is
|
||||
/// dropped.
|
||||
task: Option<JoinHandle<()>>,
|
||||
}
|
||||
|
||||
/// This function is called when the client receives a create room packet
|
||||
/// from the server. The function is responsible for printing a URL to the
|
||||
/// console that the user can use to join the room.
|
||||
///
|
||||
/// The function first generates a base64 string from the hmac value that is
|
||||
/// used to verify the integrity of the room. The base64 string is then
|
||||
/// appended to the room id to create a URL. The URL is then printed to the
|
||||
/// console using the qr2term library. Finally, the function prints a
|
||||
/// message to the console with the URL.
|
||||
fn on_create_room(
|
||||
context: &Context,
|
||||
id: String,
|
||||
relay: String,
|
||||
transfer_name: String,
|
||||
is_local: bool,
|
||||
) -> Status {
|
||||
debug!("Creating room on: {relay}");
|
||||
let base64 = general_purpose::STANDARD.encode(&context.hmac);
|
||||
let url = format!("{}-{}", id, base64);
|
||||
|
||||
// let rand_name = generate_random_name();
|
||||
let hash_name = hash_random_name(transfer_name.clone());
|
||||
|
||||
let send_url = url.to_string();
|
||||
let h_name = hash_name.to_string();
|
||||
let server_url = replace_protocol(relay.as_str());
|
||||
let res = std::thread::spawn(move || {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(send_info(&server_url, &h_name, send_url.as_str(), is_local))
|
||||
})
|
||||
.join()
|
||||
.unwrap();
|
||||
debug!("Got Result: {:?}", res);
|
||||
// Print a newline to the console to separate the output from the command
|
||||
// line.
|
||||
match res {
|
||||
Ok(transfer_response) => {
|
||||
if !transfer_response.local_room_id.is_empty()
|
||||
&& !transfer_response.relay_room_id.is_empty()
|
||||
{
|
||||
println!();
|
||||
|
||||
// Try to generate a QR code from the URL. If the function fails for some
|
||||
// reason, print an error message to the console.
|
||||
// if let Err(error) = qr2term::print_qr(&url) {
|
||||
// error!("Failed to generate QR code: {}", error);
|
||||
// }
|
||||
|
||||
if let Err(error) = qr2term::print_qr(&transfer_name) {
|
||||
error!("Failed to generate QR code: {}", error);
|
||||
}
|
||||
// Print a newline to the console to separate the output from the command
|
||||
// line.
|
||||
println!();
|
||||
|
||||
// Print a message to the console with the URL.
|
||||
println!("Created room: {}", url);
|
||||
println!("Transfername is: {}", transfer_name);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error sending info: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
// Continue the event loop.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called when the client receives a join room packet from
|
||||
/// the server. The function is responsible for sending a handshake packet to
|
||||
/// the server containing the client's public key and a signature generated
|
||||
/// using the client's private key and the room's hmac value.
|
||||
///
|
||||
/// The function first generates the client's public key from the private key.
|
||||
/// The public key is then serialized into a byte array.
|
||||
///
|
||||
/// Next, the function creates a HMAC object with the room's hmac value and
|
||||
/// updates it with the serialized public key. The resulting HMAC is then
|
||||
/// serialized into a byte array and used as the signature in the handshake
|
||||
/// packet.
|
||||
///
|
||||
/// Finally, the function sends the handshake packet to the server using the
|
||||
/// sender object.
|
||||
fn on_join_room(context: &Context, size: Option<usize>) -> Status {
|
||||
if size.is_some() {
|
||||
return Status::Err("Invalid join room packet.".into());
|
||||
}
|
||||
|
||||
// Generate the client's public key from the private key.
|
||||
let public_key = context.key.public_key().to_sec1_bytes().into_vec();
|
||||
|
||||
// Create a HMAC object with the room's hmac value and update
|
||||
// it with the serialized public key.
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
|
||||
mac.update(&public_key);
|
||||
|
||||
// Serialize the resulting HMAC into a byte array and use it as the
|
||||
// signature in the handshake packet.
|
||||
let signature = mac.finalize().into_bytes().to_vec();
|
||||
|
||||
// Create the handshake packet and send it to the server.
|
||||
let handshake = HandshakePacket {
|
||||
public_key,
|
||||
signature,
|
||||
};
|
||||
|
||||
context
|
||||
.sender
|
||||
.send_packet(DESTINATION, Value::Handshake(handshake));
|
||||
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called when an error packet is received from the
|
||||
/// server. It creates a `Status::Err` variant containing the error
|
||||
/// message from the server and returns it to be handled by the main
|
||||
/// event loop.
|
||||
///
|
||||
/// When an error occurs, the server sends an error packet to the
|
||||
/// client. The error packet contains a message with a description of
|
||||
/// the error. This function extracts that message and creates a
|
||||
/// `Status::Err` variant with it, which is then returned to be handled
|
||||
/// by the main event loop.
|
||||
///
|
||||
/// The main event loop checks the status of the client and performs
|
||||
/// the necessary actions based on its value. If the status is
|
||||
/// `Status::Err`, the event loop exits with an error message
|
||||
/// containing the error message from the server.
|
||||
///
|
||||
/// This function is called from the event loop when an error packet is
|
||||
/// received from the server.
|
||||
fn on_error(message: String) -> Status {
|
||||
Status::Err(message)
|
||||
}
|
||||
|
||||
/// This function is called when the server sends a leave room packet to
|
||||
/// the client. It is responsible for aborting the file transfer task,
|
||||
/// generating a new ECDH key pair for the next handshake, and setting the
|
||||
/// shared key to `None`.
|
||||
///
|
||||
/// When the server sends a leave room packet to the client, it means that
|
||||
/// the receiver has disconnected from the room. In this case, the client
|
||||
/// should abort the file transfer task and print an error message to the
|
||||
/// user.
|
||||
///
|
||||
/// If the client is currently transferring files, it should abort the task
|
||||
/// by calling `AbortHandle::abort` on the task handle.
|
||||
///
|
||||
/// After that, the client should generate a new ECDH key pair using the
|
||||
/// `EphemeralSecret::random` function from the `p256` crate. This key pair
|
||||
/// will be used for the next handshake with the server.
|
||||
///
|
||||
/// Finally, the client should set the shared key to `None` to indicate that
|
||||
/// there is no shared key established for the current room.
|
||||
///
|
||||
/// This function is called from the event loop when a leave room packet is
|
||||
/// received from the server.
|
||||
fn on_leave_room(context: &mut Context, _: usize) -> Status {
|
||||
if let Some(task) = &context.task {
|
||||
// If the client is currently transferring files, abort the task
|
||||
// by calling `AbortHandle::abort` on the task handle.
|
||||
task.abort();
|
||||
}
|
||||
|
||||
// Generate a new ECDH key pair for the next handshake.
|
||||
context.key = EphemeralSecret::random(&mut OsRng);
|
||||
|
||||
// Set the shared key to `None` to indicate that there is no shared key
|
||||
// established for the current room.
|
||||
context.shared_key = None;
|
||||
|
||||
// Set the task handle to `None` to indicate that there is no task
|
||||
// running.
|
||||
context.task = None;
|
||||
|
||||
// Print an error message to the user indicating that the transfer was
|
||||
// interrupted because the receiver disconnected.
|
||||
println!();
|
||||
error!("Transfer was interrupted because the receiver disconnected.");
|
||||
|
||||
// Continue the event loop.
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function is called by the event loop when a progress packet is
|
||||
/// received from the server.
|
||||
///
|
||||
/// The progress packet contains the index of the file that is being
|
||||
/// transferred and the current progress of that file as a percentage.
|
||||
///
|
||||
/// If the client does not have a shared key established with the server,
|
||||
/// the function returns an error and does not continue. This indicates
|
||||
/// that the event loop should exit with an error message.
|
||||
///
|
||||
/// The function then retrieves the file at the index specified by the
|
||||
/// progress packet from the context. If the index is out of bounds, the
|
||||
/// function returns an error and does not continue. This indicates that
|
||||
/// the event loop should exit with an error message.
|
||||
///
|
||||
/// The function then prints a message to the console indicating which file
|
||||
/// is currently being transferred and what its progress is. The progress
|
||||
/// message is printed to the same line as a carriage return (`\r`) so that
|
||||
/// it overwrites the previous message.
|
||||
///
|
||||
/// If the progress of the file is 100%, the function prints a newline
|
||||
/// (`\n`) to the console to move the cursor to the next line.
|
||||
///
|
||||
/// If the progress of the last file is 100%, the function returns
|
||||
/// `Status::Exit()`. This indicates that the event loop should exit
|
||||
/// successfully.
|
||||
///
|
||||
/// If any other condition is met, the function returns `Status::Continue()`.
|
||||
/// This indicates that the event loop should continue running.
|
||||
fn on_progress(context: &Context, progress: ProgressPacket) -> Status {
|
||||
if context.shared_key.is_none() {
|
||||
return Status::Err("Invalid progress packet: no shared key established".into());
|
||||
}
|
||||
|
||||
let file = match context.files.get(progress.index as usize) {
|
||||
Some(file) => file,
|
||||
None => return Status::Err("Invalid index in progress packet.".into()),
|
||||
};
|
||||
|
||||
print!("\rTransferring '{}': {}%", file.name, progress.progress);
|
||||
stdout().flush().unwrap();
|
||||
|
||||
if progress.progress == 100 {
|
||||
println!();
|
||||
|
||||
if progress.index as usize == context.files.len() - 1 {
|
||||
return Status::Exit();
|
||||
}
|
||||
}
|
||||
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// This function reads a file in chunks, sends each chunk to the receiver over
|
||||
/// the WebSocket connection, and then sleeps for a short amount of time
|
||||
/// before sending the next chunk.
|
||||
///
|
||||
/// The function takes the sender, the shared key, and a vector of files to
|
||||
/// transfer as arguments.
|
||||
///
|
||||
/// For each file in the vector of files, the function reads the file in
|
||||
/// chunks, sends each chunk to the receiver over the WebSocket connection,
|
||||
/// and then sleeps for a short amount of time before sending the next chunk.
|
||||
///
|
||||
/// The chunk size is set to the maximum chunk size. If the number of bytes
|
||||
/// left to read in the file is less than the chunk size, the chunk size is set
|
||||
/// to the number of bytes left to read.
|
||||
///
|
||||
/// The function opens the file for reading using the tokio::fs::File::open
|
||||
/// function. If there is an error opening the file, the function prints an
|
||||
/// error message to the console and returns.
|
||||
///
|
||||
/// The function reads the file in chunks using the read_exact function from
|
||||
/// the tokio::io::AsyncReadExt trait. If there is an error reading from the
|
||||
/// file, the function prints an error message to the console and returns.
|
||||
///
|
||||
/// The function sends each chunk to the receiver over the WebSocket
|
||||
/// connection using the send_encrypted_packet function from the Sender struct.
|
||||
/// The function also increments the sequence number for each chunk that is
|
||||
/// sent.
|
||||
///
|
||||
/// After sending all of the chunks for a file, the function sleeps for a short
|
||||
/// amount of time using the tokio::time::sleep function. This helps to prevent
|
||||
/// the sender from overwhelming the receiver with too many messages.
|
||||
///
|
||||
/// The function repeats this process for all of the files in the vector of
|
||||
/// files.
|
||||
async fn on_chunk(sender: Sender, shared_key: Option<Aes128Gcm>, files: Vec<File>) {
|
||||
for file in files {
|
||||
// Initialize a sequence number for the chunks of this file
|
||||
let mut sequence = 0;
|
||||
// Set the chunk size to the maximum chunk size
|
||||
let mut chunk_size = MAX_CHUNK_SIZE;
|
||||
// Set the number of bytes left to read in the file
|
||||
let mut size = file.size as isize;
|
||||
|
||||
// Open the file for reading
|
||||
let mut handle = match tokio::fs::File::open(file.path).await {
|
||||
Ok(handle) => handle,
|
||||
Err(error) => {
|
||||
println!("Error: Unable to open file '{}': {}", file.name, error);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while size > 0 {
|
||||
// If the number of bytes left to read in the file is less than the
|
||||
// chunk size, set the chunk size to the number of bytes left to read
|
||||
if size < chunk_size {
|
||||
chunk_size = size;
|
||||
}
|
||||
|
||||
// Create a vector to hold the chunk of data to be read from the file
|
||||
let mut chunk = vec![0u8; chunk_size.try_into().unwrap()];
|
||||
// Read a chunk of data from the file into the vector
|
||||
handle.read_exact(&mut chunk).await.unwrap();
|
||||
|
||||
// Send the chunk to the receiver over the WebSocket connection
|
||||
sender.send_encrypted_packet(
|
||||
&shared_key,
|
||||
DESTINATION,
|
||||
Value::Chunk(ChunkPacket { sequence, chunk }),
|
||||
);
|
||||
|
||||
// Increment the sequence number for the next chunk
|
||||
sequence += 1;
|
||||
// Decrement the number of bytes left to read in the file
|
||||
size -= chunk_size;
|
||||
}
|
||||
|
||||
// Sleep for a short amount of time to prevent overwhelming the receiver
|
||||
// with too many messages
|
||||
sleep(DELAY).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// This function sends a ListPacket to the receiver containing the list of
|
||||
/// files to be transferred. The ListPacket contains a vector of Entry structs,
|
||||
/// each of which represents one file.
|
||||
///
|
||||
/// The function creates a vector of Entry structs from the vector of File structs
|
||||
/// in the Context struct. Each Entry struct contains the index, name, and size
|
||||
/// of the corresponding File struct.
|
||||
///
|
||||
/// The function then sends the ListPacket to the receiver using the send_encrypted_packet
|
||||
/// function from the Sender struct.
|
||||
///
|
||||
/// After sending the ListPacket, the function spawns a task using tokio::spawn to
|
||||
/// call the on_chunk function with the Sender, shared_key, and vector of File
|
||||
/// structs as arguments. The on_chunk function will send each chunk of data for
|
||||
/// each file to the receiver.
|
||||
///
|
||||
/// The function returns Status::Continue(), which tells the main loop to continue
|
||||
/// running until all of the files have been transferred.
|
||||
fn on_handshake_finalize(context: &mut Context) -> Status {
|
||||
let mut entries = vec![];
|
||||
|
||||
for (index, file) in context.files.iter().enumerate() {
|
||||
let entry = list_packet::Entry {
|
||||
// The index of the file in the vector of Files in the Context struct
|
||||
index: index.try_into().unwrap(),
|
||||
// The name of the file
|
||||
name: file.name.clone(),
|
||||
// The size of the file in bytes
|
||||
size: file.size,
|
||||
};
|
||||
|
||||
entries.push(entry);
|
||||
}
|
||||
|
||||
context.sender.send_encrypted_packet(
|
||||
&context.shared_key,
|
||||
DESTINATION,
|
||||
Value::List(ListPacket { entries }),
|
||||
);
|
||||
|
||||
context.task = Some(tokio::spawn(on_chunk(
|
||||
context.sender.clone(),
|
||||
context.shared_key.clone(),
|
||||
context.files.clone(),
|
||||
)));
|
||||
|
||||
Status::Continue()
|
||||
}
|
||||
|
||||
/// Handshake function that is called when the Sender receives a HandshakeResponsePacket
|
||||
/// from the Receiver. This function verifies the signature from the Receiver and if
|
||||
/// successful, creates a shared key using the from the PublicKey struct.
|
||||
///
|
||||
/// The shared key is used to encrypt and decrypt packets sent between the Sender
|
||||
/// and the Receiver.
|
||||
///
|
||||
/// This function is called by the main loop in client.rs.
|
||||
fn on_handshake(context: &mut Context, handshake_response: HandshakeResponsePacket) -> Status {
|
||||
if context.shared_key.is_some() {
|
||||
// If the shared key is already established, this means that the Sender
|
||||
// has already performed the handshake, so return an error.
|
||||
return Status::Err("Already performed handshake.".into());
|
||||
}
|
||||
|
||||
// Create a new HMAC using the hmac from the Context struct as the key.
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(&context.hmac).unwrap();
|
||||
|
||||
// Update the HMAC with the public key from the HandshakeResponsePacket.
|
||||
mac.update(&handshake_response.public_key);
|
||||
|
||||
// Call verify_slice() on the HMAC to verify the signature from the Receiver.
|
||||
// If the signature is invalid, return an error.
|
||||
let verification = mac.verify_slice(&handshake_response.signature);
|
||||
if verification.is_err() {
|
||||
return Status::Err("Invalid signature from the receiver.".into());
|
||||
}
|
||||
|
||||
// Create a new PublicKey struct from the public key bytes in the
|
||||
// HandshakeResponsePacket.
|
||||
let shared_public_key = PublicKey::from_sec1_bytes(&handshake_response.public_key).unwrap();
|
||||
|
||||
// Use the diffie_hellman() method from the PublicKey struct to create a shared
|
||||
// secret key between the Sender and the Receiver. The shared secret key is a
|
||||
// 16 byte long slice of bytes.
|
||||
let shared_secret = context.key.diffie_hellman(&shared_public_key);
|
||||
let shared_secret = shared_secret.raw_secret_bytes();
|
||||
let shared_secret = &shared_secret[0..16];
|
||||
|
||||
// Create a new Key struct from the shared secret key. The Key<Aes128Gcm> type
|
||||
// is used to encrypt and decrypt packets.
|
||||
let shared_key: &Key<Aes128Gcm> = shared_secret.into();
|
||||
let shared_key = <Aes128Gcm as aes_gcm::KeyInit>::new(shared_key);
|
||||
|
||||
// Set the shared_key field of the Context struct to the shared key.
|
||||
context.shared_key = Some(shared_key);
|
||||
|
||||
// Call on_handshake_finalize() to start the transfer of files between the
|
||||
// Sender and the Receiver.
|
||||
on_handshake_finalize(context)
|
||||
}
|
||||
|
||||
/// This function is called by the `Sender` when a new message is received over
|
||||
/// the WebSocket connection. The message could be a text message or a binary
|
||||
/// message. If it is a text message, it will be deserialized into a
|
||||
/// `JsonPacketResponse` enum. If it is a binary message, it will be decrypted
|
||||
/// if necessary and then deserialized into a `Packet` struct.
|
||||
///
|
||||
/// The `JsonPacketResponse` enum will have one of the following variants:
|
||||
///
|
||||
/// * `Create { id }`: The Receiver has created a new room with the given ID.
|
||||
/// * `Join { size }`: The Receiver has joined a room with `size` number of
|
||||
/// files.
|
||||
/// * `Leave { index }`: The Receiver has left a room.
|
||||
/// * `Error { message }`: The Receiver has encountered an error.
|
||||
///
|
||||
/// If the message is a binary message, the `Packet` struct will have a
|
||||
/// `Value` variant that will have one of the following variants:
|
||||
///
|
||||
/// * `HandshakeResponse`: The Receiver has responded to the Sender's
|
||||
/// `Handshake` packet.
|
||||
/// * `Progress`: The Receiver has sent progress information for one of the
|
||||
/// files in the room.
|
||||
///
|
||||
/// This function does the following:
|
||||
///
|
||||
/// * If the message is a text message, it is deserialized into a
|
||||
/// `JsonPacketResponse` enum and then matched on to call the appropriate
|
||||
/// function.
|
||||
/// * If the message is a binary message, it is decrypted if necessary and then
|
||||
/// deserialized into a `Packet` struct. The `Value` variant of the `Packet`
|
||||
/// struct is then matched on to call the appropriate function.
|
||||
///
|
||||
/// If the message is invalid, an error is returned.
|
||||
fn on_message(
|
||||
context: &mut Context,
|
||||
message: WebSocketMessage,
|
||||
relay: String,
|
||||
transfer_name: String,
|
||||
is_local: bool,
|
||||
) -> Status {
|
||||
if message.is_text() {
|
||||
let text = message.into_text().unwrap();
|
||||
let packet = serde_json::from_str(&text).unwrap();
|
||||
|
||||
return match packet {
|
||||
JsonPacketResponse::Create { id } => {
|
||||
on_create_room(context, id, relay, transfer_name, is_local)
|
||||
}
|
||||
JsonPacketResponse::Join { size } => on_join_room(context, size),
|
||||
JsonPacketResponse::Leave { index } => on_leave_room(context, index),
|
||||
JsonPacketResponse::Error { message } => on_error(message),
|
||||
};
|
||||
} else if message.is_binary() {
|
||||
let data = message.into_data();
|
||||
let data = &data[1..];
|
||||
|
||||
let data = if let Some(shared_key) = &context.shared_key {
|
||||
let nonce = &data[..NONCE_SIZE];
|
||||
let ciphertext = &data[NONCE_SIZE..];
|
||||
|
||||
shared_key.decrypt(nonce.into(), ciphertext).unwrap()
|
||||
} else {
|
||||
data.to_vec()
|
||||
};
|
||||
|
||||
let packet = Packet::decode(data.as_ref()).unwrap();
|
||||
let value = packet.value.unwrap();
|
||||
|
||||
return match value {
|
||||
Value::HandshakeResponse(handshake_response) => {
|
||||
on_handshake(context, handshake_response)
|
||||
}
|
||||
Value::Progress(progress) => on_progress(context, progress),
|
||||
|
||||
_ => Status::Err(format!("Unexpected packet: {:?}", value)),
|
||||
};
|
||||
}
|
||||
|
||||
Status::Err("Invalid message type".into())
|
||||
}
|
||||
|
||||
/// Starts the sender client. This function will attempt to create a room with a size of 2
|
||||
/// (the number of clients that will be joining the room) and then it will open a file for
|
||||
/// each of the paths provided. It will then read chunks of data from each file and send them
|
||||
/// to the server.
|
||||
///
|
||||
/// This function takes two arguments:
|
||||
/// 1. `socket`: A `Socket` that represents the connection to the server.
|
||||
/// 2. `paths`: A `Vec` of `String`s that represent the paths to the files that will be sent
|
||||
/// to the server.
|
||||
///
|
||||
/// When the function is finished, it will exit and the transfer will be complete. If there
|
||||
/// is an error during the transfer, the function will print an error message to stdout and
|
||||
/// exit.
|
||||
pub async fn start(
|
||||
socket: Socket,
|
||||
paths: Vec<String>,
|
||||
room_id: Option<String>,
|
||||
relay: String,
|
||||
transfer_name: String,
|
||||
is_local: bool,
|
||||
) {
|
||||
// Create a vector to store metadata about each file that will be sent.
|
||||
let mut files = vec![];
|
||||
|
||||
// For each path in the `paths` vector:
|
||||
for path in paths {
|
||||
// Attempt to open the file at the given path.
|
||||
let handle = match fs::File::open(&path) {
|
||||
// If the file is successfully opened, store it in the `handle` variable.
|
||||
Ok(handle) => handle,
|
||||
// If there is an error, print an error message to stdout and exit the function.
|
||||
Err(error) => {
|
||||
error!("Error: Failed to open file '{}': {}", path, error);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Get the metadata for the file.
|
||||
let metadata = handle.metadata().unwrap();
|
||||
|
||||
// If the file is a directory, print an error message to stdout and exit the function.
|
||||
if metadata.is_dir() {
|
||||
error!("Error: The path '{}' does not point to a file.", path);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the file name from the path.
|
||||
let name = Path::new(&path).file_name().unwrap().to_str().unwrap();
|
||||
|
||||
// Get the file size from the metadata.
|
||||
let size = metadata.len();
|
||||
|
||||
// If the file is empty, print an error message to stdout and exit the function.
|
||||
if size == 0 {
|
||||
error!("Error: The file '{}' is empty and cannot be sent.", name);
|
||||
return;
|
||||
}
|
||||
|
||||
// Add the file metadata to the `files` vector.
|
||||
files.push(File {
|
||||
name: name.to_string(),
|
||||
path,
|
||||
size,
|
||||
});
|
||||
}
|
||||
|
||||
// Generate a random key for HMAC.
|
||||
let mut hmac = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut hmac);
|
||||
|
||||
// Generate a random key for AES-GCM.
|
||||
let key = EphemeralSecret::random(&mut OsRng);
|
||||
|
||||
// Create a channel to send packets to the server.
|
||||
let (sender, receiver) = flume::bounded(1000);
|
||||
|
||||
// Split the socket into separate send and receive streams.
|
||||
let (outgoing, incoming) = socket.split();
|
||||
|
||||
// Create a context that will be used throughout the transfer.
|
||||
let mut context = Context {
|
||||
// Store the sender half of the channel to send packets to the server.
|
||||
sender,
|
||||
// Store the ephemeral key for AES-GCM.
|
||||
key,
|
||||
// Store the files that will be sent to the server.
|
||||
files,
|
||||
|
||||
// Store the HMAC key.
|
||||
hmac: hmac.to_vec(),
|
||||
// Set the shared key to None.
|
||||
shared_key: None,
|
||||
// Set the current task to None.
|
||||
task: None,
|
||||
};
|
||||
|
||||
// Print a message to stdout indicating that the client is attempting to create a room.
|
||||
debug!("Attempting to create room...");
|
||||
|
||||
// Send a JSON packet to the server to create a room with a size of 2.
|
||||
debug!("With Room-ID: {:?}", room_id);
|
||||
context.sender.send_json_packet(JsonPacket::Create {
|
||||
id: room_id.clone(),
|
||||
});
|
||||
// context.sender.send_json_packet(JsonPacket::Create);
|
||||
|
||||
// Create a future that handles the outgoing stream of messages from the client to the
|
||||
// server.
|
||||
let outgoing_handler = receiver.stream().map(Ok).forward(outgoing);
|
||||
|
||||
// Create a future that handles the incoming stream of messages from the server to the
|
||||
// client.
|
||||
let incoming_handler = incoming.try_for_each(|message| {
|
||||
// Call the `on_message` function to handle the incoming message.
|
||||
match on_message(
|
||||
&mut context,
|
||||
message,
|
||||
relay.clone(),
|
||||
transfer_name.clone(),
|
||||
is_local,
|
||||
) {
|
||||
// If the status is `Status::Exit`, the transfer is complete. Print a message to
|
||||
// stdout and exit the function.
|
||||
Status::Exit() => {
|
||||
// TODO: Signal Exit to the server
|
||||
println!("Transfer has completed.");
|
||||
|
||||
// Exit the function with a `Result` of `Err`.
|
||||
return future::err(Error::ConnectionClosed);
|
||||
}
|
||||
// If the status is `Status::Err`, there was an error. Print an error message to
|
||||
// stdout and exit the function.
|
||||
Status::Err(error) => {
|
||||
error!("Error: {}", error);
|
||||
|
||||
// Exit the function with a `Result` of `Err`.
|
||||
return future::err(Error::ConnectionClosed);
|
||||
}
|
||||
// Otherwise, the message was handled successfully.
|
||||
_ => {}
|
||||
};
|
||||
|
||||
// Continue handling the incoming messages.
|
||||
future::ok(())
|
||||
});
|
||||
|
||||
// Pin the `incoming_handler` and `outgoing_handler` futures so that they do not move.
|
||||
pin_mut!(incoming_handler, outgoing_handler);
|
||||
|
||||
// Wait for either the `incoming_handler` or `outgoing_handler` to complete. If the
|
||||
// `incoming_handler` completes, return the result of the `incoming_handler`. If the
|
||||
// `outgoing_handler` completes, return the result of the `outgoing_handler`.
|
||||
future::select(incoming_handler, outgoing_handler).await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use aes_gcm::KeyInit;
|
||||
|
||||
// #[test]
|
||||
// fn test_on_chunk() {
|
||||
// let (sender, _) = flume::bounded(1000);
|
||||
// let context = Context {
|
||||
// hmac: vec![],
|
||||
// sender,
|
||||
// key: EphemeralSecret::random(&mut OsRng),
|
||||
// shared_key: None,
|
||||
// files: vec![
|
||||
// File {
|
||||
// name: "file1.txt".to_string(),
|
||||
// size: 100,
|
||||
// path: "file1.txt".to_string(),
|
||||
// },
|
||||
// File {
|
||||
// name: "file2.txt".to_string(),
|
||||
// size: 100,
|
||||
// path: "file2.txt".to_string(),
|
||||
// },
|
||||
// ],
|
||||
// task: None,
|
||||
// };
|
||||
// }
|
||||
#[test]
|
||||
fn test_on_progress() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: Some(Aes128Gcm::new(Key::<Aes128Gcm>::from_slice(&[0u8; 16]))),
|
||||
files: vec![
|
||||
File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file1.txt".to_string(),
|
||||
},
|
||||
File {
|
||||
name: "file2.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file2.txt".to_string(),
|
||||
},
|
||||
],
|
||||
task: None,
|
||||
};
|
||||
assert_eq!(
|
||||
on_progress(
|
||||
&context,
|
||||
ProgressPacket {
|
||||
index: 0,
|
||||
progress: 50
|
||||
}
|
||||
),
|
||||
Status::Continue()
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_on_create_room() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![
|
||||
File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file1.txt".to_string(),
|
||||
},
|
||||
File {
|
||||
name: "file2.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file2.txt".to_string(),
|
||||
},
|
||||
],
|
||||
task: None,
|
||||
};
|
||||
assert_eq!(
|
||||
on_create_room(
|
||||
&context,
|
||||
"b531e87d-e51a-4507-94f4-335cbe2d32f3-Nc5skZReq7qJN7INwckyAZLWEEbxsrFfH/692tUNgkM="
|
||||
.to_string(),
|
||||
String::from("0.0.0.0:8000"),
|
||||
String::from("Test"),
|
||||
true,
|
||||
),
|
||||
Status::Continue()
|
||||
);
|
||||
}
|
||||
// #[test]
|
||||
// fn test_on_join_room(){
|
||||
// let (sender, _) = flume::bounded(1000);
|
||||
// let mut context = Context {
|
||||
// hmac: vec![],
|
||||
// sender: sender,
|
||||
// key: EphemeralSecret::random(&mut OsRng),
|
||||
// shared_key: None,
|
||||
// files: vec![
|
||||
// File {
|
||||
// name: "file1.txt".to_string(),
|
||||
// size: 100,
|
||||
// path: "file1.txt".to_string(),
|
||||
// },
|
||||
// File {
|
||||
// name: "file2.txt".to_string(),
|
||||
// size: 100,
|
||||
// path: "file2.txt".to_string(),
|
||||
// },
|
||||
// ],
|
||||
// task: None,
|
||||
// };
|
||||
// assert_eq!(on_join_room(&context, None), Status::Continue());
|
||||
// }
|
||||
#[test]
|
||||
fn test_on_error() {
|
||||
assert_eq!(
|
||||
on_error("Error message".to_string()),
|
||||
Status::Err("Error message".to_string())
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_on_leave_room() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let mut context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![
|
||||
File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file1.txt".to_string(),
|
||||
},
|
||||
File {
|
||||
name: "file2.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file2.txt".to_string(),
|
||||
},
|
||||
],
|
||||
task: None,
|
||||
};
|
||||
assert_eq!(on_leave_room(&mut context, 5), Status::Continue());
|
||||
}
|
||||
#[test]
|
||||
fn test_on_message() {
|
||||
let (sender, _) = flume::bounded(1000);
|
||||
let mut context = Context {
|
||||
hmac: vec![],
|
||||
sender,
|
||||
key: EphemeralSecret::random(&mut OsRng),
|
||||
shared_key: None,
|
||||
files: vec![
|
||||
File {
|
||||
name: "file1.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file1.txt".to_string(),
|
||||
},
|
||||
File {
|
||||
name: "file2.txt".to_string(),
|
||||
size: 100,
|
||||
path: "file2.txt".to_string(),
|
||||
},
|
||||
],
|
||||
task: None,
|
||||
};
|
||||
assert_eq!(
|
||||
on_message(
|
||||
&mut context,
|
||||
WebSocketMessage::Text(r#"{"type":"leave","index":5}"#.to_string()),
|
||||
String::from("0.0.0.0:8000"),
|
||||
String::from("Test"),
|
||||
true,
|
||||
),
|
||||
Status::Continue()
|
||||
);
|
||||
assert_eq!(on_message(&mut context, WebSocketMessage::Text(r#"{"type":"create","id":"b531e87d-e51a-4507-94f4-335cbe2d32f3-Nc5skZReq7qJN7INwckyAZLWEEbxsrFfH/692tUNgkM="}"#.to_string()), String::from("0.0.0.0:8000"), String::from("Test"), true), Status::Continue());
|
||||
assert_eq!(
|
||||
on_message(
|
||||
&mut context,
|
||||
WebSocketMessage::Text(
|
||||
r#"{"type":"error","message":"Error Message: Test"}"#.to_string()
|
||||
),
|
||||
String::from("0.0.0.0:8000"),
|
||||
String::from("Test"),
|
||||
true
|
||||
),
|
||||
Status::Err("Error Message: Test".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
58
caesar-core/src/sender/http_client.rs
Normal file
58
caesar-core/src/sender/http_client.rs
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
use tracing::{debug, error};
|
||||
|
||||
use local_ip_address::{local_ip, local_ipv6};
|
||||
use reqwest::blocking::Client;
|
||||
use tokio::task;
|
||||
|
||||
use crate::relay::transfer::{TransferRequest, TransferResponse};
|
||||
|
||||
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
pub async fn send_info(
|
||||
relay: &str,
|
||||
name: &str,
|
||||
room_id: &str,
|
||||
is_local: bool,
|
||||
) -> Result<TransferResponse> {
|
||||
let url = relay.to_string();
|
||||
let sender_ip = match local_ipv6() {
|
||||
Ok(ip) => ip,
|
||||
Err(_) => match local_ip() {
|
||||
Ok(ip) => ip,
|
||||
Err(e) => {
|
||||
error!("Error getting local ip: {e:?}");
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
},
|
||||
};
|
||||
let ip_str = sender_ip.to_owned().to_string();
|
||||
|
||||
let transfer_request = TransferRequest {
|
||||
name: String::from(name),
|
||||
ip: ip_str,
|
||||
local_room_id: if is_local {
|
||||
String::from(room_id)
|
||||
} else {
|
||||
String::from("")
|
||||
},
|
||||
relay_room_id: if !is_local {
|
||||
String::from(room_id)
|
||||
} else {
|
||||
String::from("")
|
||||
},
|
||||
};
|
||||
|
||||
debug!("Trying to send Request.");
|
||||
let result: Result<TransferResponse> = task::spawn_blocking(move || {
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.put(format!("{}/upload", url))
|
||||
.json(&transfer_request)
|
||||
.send()?
|
||||
.json()?;
|
||||
Ok(response)
|
||||
})
|
||||
.await?;
|
||||
|
||||
result
|
||||
}
|
||||
187
caesar-core/src/sender/mod.rs
Normal file
187
caesar-core/src/sender/mod.rs
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
/// Connects to the WebSocket server at `ws://0.0.0.0:8000/ws` with an
|
||||
/// `Origin` header of `ws://0.0.0.0:8000/ws`. This is the URL that the
|
||||
/// sender and receiver clients will connect to.
|
||||
///
|
||||
/// The `start_sender` function takes a reference to a vector of strings,
|
||||
/// which are the paths to the files that the sender will send over the
|
||||
/// WebSocket connection.
|
||||
///
|
||||
/// The function first creates a WebSocket request using the `IntoClientRequest`
|
||||
/// trait from `tungstenite`, which is defined on the `IntoClientRequest` struct.
|
||||
/// This struct is a type that represents a request to a WebSocket server.
|
||||
///
|
||||
/// The `into_client_request` function returns a `Result` because it may fail
|
||||
/// to create the request. In this case, we do not handle the error, so we just
|
||||
/// return if the result is an error.
|
||||
///
|
||||
/// Once we have a request, we insert the `Origin` header into the headers of
|
||||
/// the request. This is necessary because the WebSocket protocol requires the
|
||||
/// `Origin` header to be present in the handshake.
|
||||
///
|
||||
/// After that, we print out a message to the console indicating that we are
|
||||
/// attempting to connect to the server.
|
||||
///
|
||||
/// Next, we call the `connect_async` function from `tokio_tungstenite` which
|
||||
/// takes our request and attempts to connect to the server. This function
|
||||
/// returns a `Future` that resolves to a tuple of a `WebSocketStream` and a
|
||||
/// `Response` from the server. The `WebSocketStream` is a stream of
|
||||
/// WebSocket messages from the server, and the `Response` is the response
|
||||
/// from the server to our handshake request.
|
||||
///
|
||||
/// If connecting to the server fails, we print out an error message and
|
||||
/// return.
|
||||
///
|
||||
/// If connecting to the server succeeds, we pass the `WebSocketStream` and
|
||||
/// the paths to the files to the `start` function from the `sender` module.
|
||||
/// The `start` function is defined in the `sender` module, and it is the
|
||||
/// function that sends the files over the WebSocket connection.
|
||||
///
|
||||
/// The `start` function takes ownership of the `WebSocketStream` and the file
|
||||
/// paths, so we pass it the `paths` vector by value.
|
||||
pub mod client;
|
||||
pub mod http_client;
|
||||
pub mod util;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
relay::{appstate::AppState, server::ws_handler},
|
||||
sender::{client as sender, util::generate_random_name},
|
||||
};
|
||||
use axum::{routing::get, Router};
|
||||
use tokio::{net::TcpListener, sync::mpsc, task};
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{client::IntoClientRequest, http::HeaderValue},
|
||||
};
|
||||
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
|
||||
use tracing::{debug, error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn start_sender(relay: Arc<String>, files: Arc<Vec<String>>) {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
debug!("Got relay: {relay}");
|
||||
let room_id = Uuid::new_v4().to_string();
|
||||
let rand_name = generate_random_name();
|
||||
let local_room_id = room_id.clone();
|
||||
let local_files = files.clone();
|
||||
let local_relay = relay.clone();
|
||||
let local_rand_name = rand_name.clone();
|
||||
let local_tx = tx.clone();
|
||||
let local_ws_thread = task::spawn(async move {
|
||||
start_local_ws().await;
|
||||
});
|
||||
let relay_thread = task::spawn(async move {
|
||||
connect_to_server(
|
||||
relay.clone(),
|
||||
files.clone(),
|
||||
Some(room_id),
|
||||
relay.clone(),
|
||||
Arc::new(rand_name.clone()),
|
||||
tx.clone(),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
let local_thread = task::spawn(async move {
|
||||
connect_to_server(
|
||||
Arc::new(String::from("ws://0.0.0.0:9000")),
|
||||
local_files.clone(),
|
||||
Some(local_room_id),
|
||||
local_relay.clone(),
|
||||
Arc::new(local_rand_name.clone()),
|
||||
local_tx.clone(),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
rx.recv().await.unwrap();
|
||||
local_ws_thread.abort();
|
||||
relay_thread.abort();
|
||||
local_thread.abort();
|
||||
}
|
||||
|
||||
pub async fn start_local_ws() {
|
||||
let app_host = "0.0.0.0";
|
||||
let app_port = "9000";
|
||||
|
||||
// Create a new server data structure.
|
||||
let server = AppState::new();
|
||||
|
||||
// Set up the application routes.
|
||||
let app = Router::new()
|
||||
.route("/ws", get(ws_handler))
|
||||
.with_state(server)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||
);
|
||||
|
||||
if let Ok(listener) = TcpListener::bind(&format!("{}:{}", app_host, app_port)).await {
|
||||
info!(
|
||||
"Local Websocket listening on: {}",
|
||||
listener.local_addr().unwrap()
|
||||
);
|
||||
|
||||
// Run the server.
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
// Log binding failure and exit.
|
||||
error!("Failed to listen on: {}:{}", app_host, app_port);
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_to_server(
|
||||
relay: Arc<String>,
|
||||
files: Arc<Vec<String>>,
|
||||
room_id: Option<String>,
|
||||
message_server: Arc<String>,
|
||||
transfer_name: Arc<String>,
|
||||
tx: mpsc::Sender<()>,
|
||||
is_local: bool,
|
||||
) {
|
||||
let url = format!("{}/ws", relay);
|
||||
let message_relay = format!("{}", message_server);
|
||||
let transfer_name = format!("{}", transfer_name);
|
||||
match url.clone().into_client_request() {
|
||||
Ok(mut request) => {
|
||||
request
|
||||
.headers_mut()
|
||||
.insert("Origin", HeaderValue::from_str(relay.as_ref()).unwrap());
|
||||
|
||||
debug!("Attempting to connect to {url}...");
|
||||
let room_id = match room_id {
|
||||
Some(id) => id,
|
||||
None => Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
match connect_async(request).await {
|
||||
Ok((socket, _)) => {
|
||||
let paths = files.to_vec();
|
||||
sender::start(
|
||||
socket,
|
||||
paths,
|
||||
Some(room_id),
|
||||
message_relay.to_string(),
|
||||
transfer_name.clone(),
|
||||
is_local,
|
||||
)
|
||||
.await;
|
||||
tx.send(()).await.unwrap();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error: Failed to connect with error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error: failed to create request with reason: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
56
caesar-core/src/sender/util.rs
Normal file
56
caesar-core/src/sender/util.rs
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
use hex;
|
||||
use rand::{seq::SliceRandom, thread_rng};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
pub fn generate_random_name() -> String {
|
||||
let mut rng = thread_rng();
|
||||
let adjective = adjectives().choose(&mut rng).unwrap();
|
||||
// let adjective = adjectives().sample(&mut rng).unwrap();
|
||||
let noun1 = nouns1().choose(&mut rng).unwrap();
|
||||
let noun2 = nouns2().choose(&mut rng).unwrap();
|
||||
|
||||
format!("{adjective}-{noun1}-{noun2}")
|
||||
}
|
||||
|
||||
fn adjectives() -> &'static [&'static str] {
|
||||
static ADJECTIVES: &[&str] = &["funny", "smart", "creative", "friendly", "great"];
|
||||
ADJECTIVES
|
||||
}
|
||||
|
||||
fn nouns1() -> &'static [&'static str] {
|
||||
static NOUNS1: &[&str] = &["dog", "cat", "flower", "tree", "house"];
|
||||
NOUNS1
|
||||
}
|
||||
|
||||
fn nouns2() -> &'static [&'static str] {
|
||||
static NOUNS2: &[&str] = &["cookie", "cake", "frosting"];
|
||||
NOUNS2
|
||||
}
|
||||
|
||||
pub fn hash_random_name(name: String) -> String {
|
||||
let hashed_name = Sha256::digest(name.as_bytes());
|
||||
hex::encode(hashed_name)
|
||||
}
|
||||
|
||||
pub fn replace_protocol(address: &str) -> String {
|
||||
let mut result = address.to_string();
|
||||
result = result.replace("ws://", "http://");
|
||||
|
||||
result = result.replace("wss://", "https://");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_random_name() {
|
||||
let name = generate_random_name();
|
||||
|
||||
assert!(name.contains('-'));
|
||||
assert!(name.split('-').count() == 3);
|
||||
assert!(name.len() > 0);
|
||||
}
|
||||
}
|
||||
335
caesar-core/src/shared.rs
Normal file
335
caesar-core/src/shared.rs
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
pub mod packets {
|
||||
include!(concat!(env!("OUT_DIR"), "/packets.rs"));
|
||||
}
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore},
|
||||
Aes128Gcm,
|
||||
};
|
||||
use packets::Packet;
|
||||
use prost::Message;
|
||||
use rand::rngs::OsRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::tungstenite::protocol::Message as WebSocketMessage;
|
||||
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
||||
|
||||
/// This struct is used to serialize/deserialize JSON packets sent
|
||||
/// between the client and the server.
|
||||
///
|
||||
/// The `type` field is used to specify the type of packet that is being sent.
|
||||
/// The possible values for this field are listed as variants of the enum.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum JsonPacket {
|
||||
/// Sent from the client to ask to join a room.
|
||||
///
|
||||
/// The `id` field specifies the ID of the room that the client wants
|
||||
/// to join.
|
||||
Join {
|
||||
/// The ID of the room that the client wants to join.
|
||||
id: String,
|
||||
},
|
||||
/// Sent from the client to ask to create a new room.
|
||||
Create { id: Option<String> },
|
||||
// Create,
|
||||
/// Sent from the client to ask to leave the current room.
|
||||
Leave,
|
||||
}
|
||||
|
||||
/// This struct is used to serialize/deserialize JSON packets sent
|
||||
/// from the server to the client.
|
||||
///
|
||||
/// The `type` field is used to specify the type of packet that is being
|
||||
/// sent. The possible values for this field are listed as variants of the
|
||||
/// enum.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum JsonPacketResponse {
|
||||
/// Sent from the server to inform the client of the result of a `Join`
|
||||
/// packet.
|
||||
///
|
||||
/// If the client successfully joined a room, the `size` field will be
|
||||
/// `Some` and contain the size of the room. If the client could not join
|
||||
/// a room, the `size` field will be `None`.
|
||||
Join {
|
||||
/// The size of the room that the client joined. If the client could
|
||||
/// not join a room, this field will be `None`.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
size: Option<usize>,
|
||||
},
|
||||
/// Sent from the server to inform the client of the result of a `Create`
|
||||
/// packet.
|
||||
///
|
||||
/// If the server successfully created a room, the `id` field will
|
||||
/// contain the ID of the room. If the server could not create a room,
|
||||
/// the `id` field will be empty.
|
||||
Create {
|
||||
/// The ID of the room that the server created. If the server could
|
||||
/// not create a room, this field will be empty.
|
||||
id: String,
|
||||
},
|
||||
/// Sent from the server to inform the client of the result of a `Leave`
|
||||
/// packet.
|
||||
///
|
||||
/// If the client successfully left a room, the `index` field will
|
||||
/// contain the index of the client that left the room. If the client
|
||||
/// could not leave a room, the `index` field will be 0.
|
||||
Leave {
|
||||
/// The index of the client that left the room. If the client could
|
||||
/// not leave a room, this field will be 0.
|
||||
index: usize,
|
||||
},
|
||||
/// Sent from the server to inform the client of an error.
|
||||
///
|
||||
/// The `message` field contains a description of the error.
|
||||
Error {
|
||||
/// A description of the error that occurred.
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// This enum represents the result of processing an event in the event loop.
|
||||
///
|
||||
/// The `Status` enum has three variants:
|
||||
///
|
||||
/// * `Continue` - This variant indicates that the event loop should
|
||||
/// continue processing events. This is the most common result and is used
|
||||
/// when the event loop has nothing special to do.
|
||||
///
|
||||
/// * `Exit` - This variant indicates that the event loop should exit. This
|
||||
/// is used when the event loop should exit because of an error or
|
||||
/// because the user has requested that the program exit.
|
||||
///
|
||||
/// * `Err` - This variant indicates that the event loop encountered an
|
||||
/// error. When the event loop receives a `Status::Err` variant, it should
|
||||
/// exit with an error message containing the message from the error packet.
|
||||
/// The message from the error packet is the only information that the event
|
||||
/// loop has about the error, so the message should be descriptive and
|
||||
/// helpful to the user. The message should not contain technical details
|
||||
/// about the error or how it occurred. Instead, the message should be
|
||||
/// written from the perspective of the user and should give the user enough
|
||||
/// information to understand what went wrong and how they might be able to
|
||||
/// fix the problem.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Status {
|
||||
/// Indicates that the event loop should continue processing events.
|
||||
Continue(),
|
||||
/// Indicates that the event loop should exit.
|
||||
Exit(),
|
||||
/// Indicates that the event loop encountered an error.
|
||||
Err(String),
|
||||
}
|
||||
|
||||
/// A trait for sending JSON packets.
|
||||
///
|
||||
/// This trait provides a single method, `send_json_packet`, which sends a
|
||||
/// JSON packet over some underlying transport.
|
||||
pub trait JsonPacketSender {
|
||||
/// Sends a JSON packet.
|
||||
///
|
||||
/// This method takes a single argument, `packet`, which is the JSON packet
|
||||
/// to send. The packet will be serialized into a JSON string and then sent
|
||||
/// over the underlying transport.
|
||||
///
|
||||
/// Note that the exact semantics of what it means to "send a JSON packet"
|
||||
/// will depend on the specific implementation of this trait. However, in
|
||||
/// general, the packet will be sent as a single message over the
|
||||
/// transport, and the transport will be responsible for ensuring that the
|
||||
/// packet is delivered to the intended recipient.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If there is an error serializing the JSON packet, or if there is an
|
||||
/// error sending the serialized packet over the transport, this method
|
||||
/// may return an error. The exact semantics of what constitutes an error
|
||||
/// will depend on the specific implementation of this trait.
|
||||
fn send_json_packet(&self, packet: JsonPacket);
|
||||
}
|
||||
|
||||
/// A trait for sending Protocol Buffers packets over some underlying transport.
|
||||
///
|
||||
/// This trait provides two methods for sending Protocol Buffers packets:
|
||||
///
|
||||
/// * `send_packet` sends a packet in the clear (i.e., not encrypted).
|
||||
/// * `send_encrypted_packet` sends a packet encrypted using the AES-GCM
|
||||
/// algorithm with a 128-bit key.
|
||||
///
|
||||
/// The exact semantics of what it means to "send a packet" will depend on the
|
||||
/// specific implementation of this trait. However, in general, the packet will
|
||||
/// be serialized into a binary message using the Protocol Buffers wire format,
|
||||
/// and then sent over the underlying transport.
|
||||
///
|
||||
/// The `destination` argument specifies which recipient should receive the
|
||||
/// packet. This is a 1-byte field that is prepended to the serialized packet
|
||||
/// before it is sent.
|
||||
///
|
||||
/// The `key` argument is an optional AES-GCM key. If a key is provided, the
|
||||
/// packet will be encrypted before being sent. If no key is provided, the
|
||||
/// packet will be sent in the clear.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If there is an error serializing the Protocol Buffers packet, or if there
|
||||
/// is an error sending the serialized packet over the transport, either of
|
||||
/// these methods may return an error. The exact semantics of what constitutes
|
||||
/// an error will depend on the specific implementation of this trait.
|
||||
pub trait PacketSender {
|
||||
/// Sends a Protocol Buffers packet in the clear.
|
||||
///
|
||||
/// The packet will be serialized into a binary message using the Protocol
|
||||
/// Buffers wire format, and then sent over the underlying transport.
|
||||
fn send_packet(&self, destination: u8, packet: packets::packet::Value);
|
||||
|
||||
/// Sends a Protocol Buffers packet encrypted using AES-GCM.
|
||||
///
|
||||
/// The packet will be serialized into a binary message using the Protocol
|
||||
/// Buffers wire format, encrypted using AES-GCM with a 128-bit key, and
|
||||
/// then sent over the underlying transport.
|
||||
///
|
||||
/// If no key is provided, the packet will be sent in the clear.
|
||||
fn send_encrypted_packet(
|
||||
&self,
|
||||
key: &Option<Aes128Gcm>,
|
||||
destination: u8,
|
||||
value: packets::packet::Value,
|
||||
);
|
||||
}
|
||||
|
||||
impl JsonPacketSender for Sender {
|
||||
/// Serializes the given JSON packet into a string, and then sends it as a
|
||||
/// text message over the underlying transport.
|
||||
///
|
||||
/// The `JsonPacket` type is defined in the `serde_json` crate, and it is a
|
||||
/// simple wrapper around a JSON object with string keys and values. This
|
||||
/// trait method is responsible for taking a `JsonPacket` and sending it
|
||||
/// over the WebSocket connection.
|
||||
///
|
||||
/// The `serde_json::to_string` function is used to serialize the packet
|
||||
/// into a JSON string. If this function returns an error, we panic
|
||||
/// because there is no reasonable recovery behavior in this case.
|
||||
///
|
||||
/// Once we have the JSON string, we wrap it in a `WebSocketMessage::Text`
|
||||
/// enum variant and send it over the WebSocket connection using the
|
||||
/// `send` method. If this method returns an error, we panic because there
|
||||
/// is no reasonable recovery behavior in this case.
|
||||
fn send_json_packet(&self, packet: JsonPacket) {
|
||||
let serialized_packet =
|
||||
serde_json::to_string(&packet).expect("Failed to serialize JSON packet.");
|
||||
|
||||
self.send(WebSocketMessage::Text(serialized_packet))
|
||||
.expect("Failed to send JSON packet.");
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketSender for Sender {
|
||||
/// Serializes the given packet value into a binary message, and then
|
||||
/// sends it over the underlying transport.
|
||||
///
|
||||
/// The `destination` parameter specifies which client should receive
|
||||
/// this message. The value of this parameter should be a byte that
|
||||
/// represents the client's index in the list of connected clients.
|
||||
///
|
||||
/// The `value` parameter specifies the actual data that should be sent
|
||||
/// to the client. This will be serialized into a `Packet` struct using
|
||||
/// the Protocol Buffers wire format.
|
||||
///
|
||||
/// This function will first encode the `Packet` struct into a vector of
|
||||
/// bytes using the Protocol Buffers wire format. It will then insert the
|
||||
/// `destination` byte as the first element of the vector, so that the
|
||||
/// receiving client knows which client this message is intended for.
|
||||
///
|
||||
/// Finally, this function will send the serialized packet over the
|
||||
/// underlying transport, which is assumed to be a WebSocket connection.
|
||||
/// If this send operation fails, this function will panic because there
|
||||
/// is no reasonable recovery behavior in this case.
|
||||
fn send_packet(&self, destination: u8, value: packets::packet::Value) {
|
||||
let packet = Packet { value: Some(value) };
|
||||
|
||||
let mut serialized_packet = packet.encode_to_vec();
|
||||
serialized_packet.insert(0, destination);
|
||||
|
||||
self.send(WebSocketMessage::Binary(serialized_packet))
|
||||
.expect("Failed to send Packet.");
|
||||
}
|
||||
|
||||
/// Similar to `send_packet`, but the message is encrypted using AES-GCM
|
||||
/// with a 128-bit key.
|
||||
///
|
||||
/// If no key is provided (i.e., if `key` is `None`), then the message will
|
||||
/// be sent in the clear.
|
||||
///
|
||||
/// This function works by generating a random 12-byte nonce using the
|
||||
/// `rand::OsRng` PRNG, encrypting the message using AES-GCM with the
|
||||
/// provided key and nonce, and then prepending the nonce to the ciphertext
|
||||
/// before sending it over the WebSocket connection. The receiving client
|
||||
/// will use the same key and nonce to decrypt the message.
|
||||
///
|
||||
/// Note that this function does not actually check whether the provided
|
||||
/// key is valid. If an invalid key is provided, the encryption will fail
|
||||
/// and the receiver will not be able to decrypt the message.
|
||||
fn send_encrypted_packet(
|
||||
&self,
|
||||
key: &Option<Aes128Gcm>,
|
||||
destination: u8,
|
||||
value: packets::packet::Value,
|
||||
) {
|
||||
let packet = Packet { value: Some(value) };
|
||||
|
||||
let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
|
||||
let plaintext = packet.encode_to_vec();
|
||||
let mut ciphertext = key
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.encrypt(&nonce, plaintext.as_ref())
|
||||
.expect("Failed to encrypt Packet.");
|
||||
|
||||
let mut serialized_packet = nonce.to_vec();
|
||||
serialized_packet.append(&mut ciphertext);
|
||||
serialized_packet.insert(0, destination);
|
||||
|
||||
self.send(WebSocketMessage::Binary(serialized_packet))
|
||||
.expect("Failed to send encrypted Packet.");
|
||||
}
|
||||
}
|
||||
|
||||
/// A sender is a type that allows us to send messages to a WebSocket client.
|
||||
///
|
||||
/// In this case, a sender is a channel that allows us to send WebSocket
|
||||
/// messages to a client. The messages can be any type that implements the
|
||||
/// `Into<WebSocketMessage>`.
|
||||
///
|
||||
/// The `WebSocketMessage` type represents any message that can be sent over a
|
||||
/// WebSocket connection. It can be a binary message, a text message, or a
|
||||
/// close message.
|
||||
///
|
||||
/// The `MaybeTlsStream` type is a stream that may or may not be encrypted.
|
||||
/// If the connection is encrypted (e.g., via TLS), then the stream will be
|
||||
/// encrypted. If the connection is not encrypted, then the stream will be
|
||||
/// unencrypted.
|
||||
///
|
||||
/// The `TcpStream` type is a stream that is used to connect to a remote
|
||||
/// server over a TCP connection.
|
||||
///
|
||||
/// The `WebSocketStream` type is a stream that is used to connect to a remote
|
||||
/// WebSocket server. It is a wrapper around the `MaybeTlsStream` stream that
|
||||
/// adds WebSocket-specific functionality.
|
||||
pub type Sender = flume::Sender<WebSocketMessage>;
|
||||
|
||||
/// A socket is a type that represents a WebSocket connection.
|
||||
///
|
||||
/// In this case, a socket is a wrapper around a `MaybeTlsStream` stream that
|
||||
/// adds WebSocket-specific functionality.
|
||||
///
|
||||
/// The `MaybeTlsStream` type is a stream that may or may not be encrypted.
|
||||
/// If the connection is encrypted (e.g., via TLS), then the stream will be
|
||||
/// encrypted. If the connection is not encrypted, then the stream will be
|
||||
/// unencrypted.
|
||||
///
|
||||
/// The `TcpStream` type is a stream that is used to connect to a remote
|
||||
/// server over a TCP connection.
|
||||
///
|
||||
/// The `WebSocketStream` type is a stream that is used to connect to a remote
|
||||
/// WebSocket server. It is a wrapper around the `MaybeTlsStream` stream that
|
||||
/// adds WebSocket-specific functionality.
|
||||
pub type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
Loading…
Add table
Add a link
Reference in a new issue