diff --git a/README-zh.md b/README-zh.md index b4c2c16..63dcd50 100644 --- a/README-zh.md +++ b/README-zh.md @@ -17,17 +17,18 @@ sgo sgo -d target -p 3001 ``` - ### Command help ```sh Usage: sgo [OPTIONS] Options: - -d, --dir Sets the directory to serve files from [default: ./static] - -p, --port Sets the port number to listen on [default: 3030] - -h, --help Print help - -V, --version Print version + -d, --dir Sets the directory to serve files from [default: ./static] + -p, --port Sets the port number to listen on [default: 3030] + -L, --no-request-logging Do not log any request information to the console + -C, --cors Enable CORS, sets `Access-Control-Allow-Origin` to `*` + -h, --help Print help + -V, --version Print version ```
diff --git a/README.md b/README.md index 4ccde69..74fc224 100644 --- a/README.md +++ b/README.md @@ -25,10 +25,12 @@ sgo -d target -p 3001 Usage: sgo [OPTIONS] Options: - -d, --dir Sets the directory to serve files from [default: ./static] - -p, --port Sets the port number to listen on [default: 3030] - -h, --help Print help - -V, --version Print version + -d, --dir Sets the directory to serve files from [default: ./static] + -p, --port Sets the port number to listen on [default: 3030] + -L, --no-request-logging Do not log any request information to the console + -C, --cors Enable CORS, sets `Access-Control-Allow-Origin` to `*` + -h, --help Print help + -V, --version Print version ```
diff --git a/src/cli.rs b/src/cli.rs index 1c3df77..c593f0a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -22,10 +22,17 @@ pub fn get_matches() -> ArgMatches { ) .arg( Arg::new("no-request-logging") + .short('L') .long("no-request-logging") - .value_name("LOGGING") .help("Do not log any request information to the console") .action(clap::ArgAction::SetTrue), // Define as a flag that sets the value to false ) + .arg( + Arg::new("cors") + .short('C') + .long("cors") + .help("Enable CORS, sets `Access-Control-Allow-Origin` to `*`") + .action(clap::ArgAction::SetTrue), + ) .get_matches() } \ No newline at end of file diff --git a/src/file_server.rs b/src/file_server.rs index 2d2f36f..6d65ca8 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -1,79 +1,86 @@ use std::convert::Infallible; use std::path::Path; use tokio::fs; -use warp::reply::{html, with_header, Response}; use warp::Reply; -use warp::path::Tail; use std::sync::Arc; use mime_guess::mime; -// MARK: - 处理请求 -pub async fn serve_files(path: Tail, css_content: Arc, base_dir: Arc) -> Result { - let path_str = path.as_str(); +pub async fn serve_files( + path: warp::path::Tail, + css_content: Arc, + base_dir: Arc, + enable_cors: bool, +) -> Result { + let path_str = path.as_str(); - let full_path = Path::new(&**base_dir).join(path_str); - let full_path_clone = full_path.clone(); // 克隆 PathBuf + let full_path = Path::new(&**base_dir).join(path_str); + let full_path_clone = full_path.clone(); // 克隆 PathBuf - if full_path.is_dir() { - match fs::read_dir(full_path).await { - Ok(mut entries) => { - let relative_path: String = Path::new(path_str).to_str().unwrap_or(&base_dir).to_string(); + let response = if full_path.is_dir() { + match fs::read_dir(full_path).await { + Ok(mut entries) => { + let relative_path: String = Path::new(path_str).to_str().unwrap_or(&base_dir).to_string(); + + // 检查 relative_path 是否为空 + let relative_path = if relative_path.is_empty() { + base_dir.to_string() + } else { + relative_path + }; - // 检查 relative_path 是否为空 - let relative_path = if relative_path.is_empty() { - base_dir.to_string() - } else { - relative_path - }; + let mut list = String::new(); + list.push_str(""); + list.push_str(&format!("", css_content)); + list.push_str(&format!("

Index of {}

    ", relative_path)); - let mut list = String::new(); - // 引入 CSS 文件 - list.push_str(""); - list.push_str(&format!("", css_content)); - list.push_str(&format!("

    Index of {}

      ", relative_path)); - - // 添加返回上一级目录的链接(如果不是根目录) - if !path_str.is_empty() { - let parent_path = Path::new(path_str).parent().unwrap_or(Path::new(&**base_dir)).to_str().unwrap(); - list.push_str(&format!("
    • ../
    • ", parent_path)); - } + // 添加返回上一级目录的链接(如果不是根目录) + if !path_str.is_empty() { + let parent_path = Path::new(path_str).parent().unwrap_or(Path::new(&**base_dir)).to_str().unwrap(); + list.push_str(&format!("
    • ../
    • ", parent_path)); + } - while let Some(entry) = entries.next_entry().await.unwrap() { - let file_name: String = entry.file_name().into_string().unwrap(); - let entry_path: std::path::PathBuf = entry.path(); - let relative_path: String = Path::new(path_str).join(&file_name).to_str().unwrap().to_string(); - - if entry_path.is_dir() { - list.push_str(&format!("
    • {}/
    • ", relative_path, file_name)); - } else { - list.push_str(&format!("
    • {}
    • ", relative_path, file_name)); - } - } - list.push_str("
    "); - Ok(html(list).into_response()) - }, - Err(_) => { - let error_message: String = "Directory not found".to_string(); - Ok(html(error_message).into_response()) - }, - } - } else { - match fs::read(full_path).await { - Ok(content) => { - let mime_type: mime::Mime = mime_guess::from_path(&full_path_clone).first_or_octet_stream(); - if mime_type == mime::TEXT_HTML || mime_type == mime::TEXT_PLAIN || mime_type == mime::TEXT_CSS || mime_type == mime::TEXT_JAVASCRIPT { - // 对于文本文件直接展示内容 - let content_str = String::from_utf8_lossy(&content).to_string(); - Ok(with_header(content_str, "Content-Type", mime_type.to_string()).into_response()) - } else { - // 对于其他文件,提供下载 - Ok(with_header(content, "Content-Type", mime_type.to_string()).into_response()) - } - }, - Err(_) => { - let error_message = "File not found".to_string(); - Ok(html(error_message).into_response()) - }, - } - } + while let Some(entry) = entries.next_entry().await.unwrap() { + let file_name: String = entry.file_name().into_string().unwrap(); + let entry_path: std::path::PathBuf = entry.path(); + let relative_path: String = Path::new(path_str).join(&file_name).to_str().unwrap().to_string(); + + if entry_path.is_dir() { + list.push_str(&format!("
  • {}/
  • ", relative_path, file_name)); + } else { + list.push_str(&format!("
  • {}
  • ", relative_path, file_name)); + } + } + list.push_str("
"); + warp::reply::html(list).into_response() + } + Err(_) => { + let error_message: String = "Directory not found".to_string(); + warp::reply::html(error_message).into_response() + } + } + } else { + match fs::read(full_path).await { + Ok(content) => { + let mime_type: mime::Mime = mime_guess::from_path(&full_path_clone).first_or_octet_stream(); + if mime_type == mime::TEXT_HTML || mime_type == mime::TEXT_PLAIN || mime_type == mime::TEXT_CSS || mime_type == mime::TEXT_JAVASCRIPT { + let content_str = String::from_utf8_lossy(&content).to_string(); + warp::reply::with_header(content_str, "Content-Type", mime_type.to_string()).into_response() + } else { + warp::reply::with_header(content, "Content-Type", mime_type.to_string()).into_response() + } + } + Err(_) => { + let error_message = "File not found".to_string(); + warp::reply::html(error_message).into_response() + } + } + }; + + let response = if enable_cors { + warp::reply::with_header(response, "Access-Control-Allow-Origin", "*").into_response() + } else { + response + }; + + Ok(response) } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index f74bd24..566fb02 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ async fn main() { let matches = cli::get_matches(); // 读取命令行参数 + let enable_cors = matches.get_flag("cors"); let no_request_logging = matches.get_flag("no-request-logging"); let base_dir = Arc::new(matches.get_one::("dir").unwrap().to_string()); let port: u16 = matches @@ -43,10 +44,19 @@ async fn main() { if path.as_str().is_empty() { "/".green() } else { path.as_str().green() } ); } - file_server::serve_files(path, css_content_arc.clone(), base_dir.clone()) + file_server::serve_files(path, css_content_arc.clone(), base_dir.clone(), enable_cors.clone()) } }); + // // Create CORS filter + // let cors_filter = warp::reply::with_header(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + // // Apply CORS filter conditionally + // let route = if enable_cors { + // route.with(cors_filter) + // } else { + // route + // }; + // 打印服务器启动信息 println!( "Starting server at http://{}:{}",