diff --git a/swiftide/src/loaders/file_loader.rs b/swiftide/src/loaders/file_loader.rs index 7d355e84..086f140b 100644 --- a/swiftide/src/loaders/file_loader.rs +++ b/swiftide/src/loaders/file_loader.rs @@ -1,11 +1,12 @@ use crate::{ingestion::IngestionNode, ingestion::IngestionStream, Loader}; -use std::path::PathBuf; +use anyhow::Context as _; +use std::path::{Path, PathBuf}; /// The `FileLoader` struct is responsible for loading files from a specified directory, /// filtering them based on their extensions, and creating a stream of these files for further processing. pub struct FileLoader { pub(crate) path: PathBuf, - pub(crate) extensions: Vec, + pub(crate) extensions: Option>, } impl FileLoader { @@ -19,7 +20,7 @@ impl FileLoader { pub fn new(path: impl Into) -> Self { Self { path: path.into(), - extensions: vec![], + extensions: None, } } @@ -30,9 +31,14 @@ impl FileLoader { /// /// # Returns /// The `FileLoader` instance with the added extensions. - pub fn with_extensions(mut self, extensions: &[&str]) -> Self { - self.extensions - .extend(extensions.iter().map(ToString::to_string)); + pub fn with_extensions(mut self, extensions: &[impl AsRef]) -> Self { + self.extensions = Some( + self.extensions + .unwrap_or_default() + .into_iter() + .chain(extensions.iter().map(|ext| ext.as_ref().to_string())) + .collect(), + ); self } @@ -47,15 +53,7 @@ impl FileLoader { ignore::Walk::new(&self.path) .filter_map(|entry| entry.ok()) .filter(|entry| entry.file_type().map(|ft| ft.is_file()).unwrap_or(false)) - .filter(move |entry| { - let extensions = self.extensions.clone(); - - entry - .path() - .extension() - .map(|ext| extensions.contains(&ext.to_string_lossy().to_string())) - .unwrap_or(false) - }) + .filter(move |entry| self.file_has_extension(entry.path())) .map(|entry| entry.into_path()) .map(|entry| { tracing::debug!("Reading file: {:?}", entry); @@ -68,6 +66,21 @@ impl FileLoader { }) .collect() } + + // Helper function to check if a file has the specified extension. + // If no extensions are specified, this function will return true. + // If the file has no extension, this function will return false. + fn file_has_extension(&self, path: &Path) -> bool { + self.extensions + .as_ref() + .map(|exts| { + let Some(ext) = path.extension() else { + return false; + }; + exts.iter().any(|e| e == ext.to_string_lossy().as_ref()) + }) + .unwrap_or(true) + } } impl Loader for FileLoader { @@ -79,30 +92,22 @@ impl Loader for FileLoader { /// # Errors /// This method will return an error if it fails to read a file's content. fn into_stream(self) -> IngestionStream { - let file_paths = ignore::Walk::new(self.path) + let files = ignore::Walk::new(&self.path) .filter_map(|entry| entry.ok()) .filter(|entry| entry.file_type().map(|ft| ft.is_file()).unwrap_or(false)) - .filter(move |entry| { - let extensions = self.extensions.clone(); - - entry - .path() - .extension() - .map(|ext| extensions.contains(&ext.to_string_lossy().to_string())) - .unwrap_or(false) - }) - .map(|entry| entry.into_path()) + .filter(move |entry| self.file_has_extension(entry.path())) .map(|entry| { - let content = std::fs::read_to_string(&entry)?; tracing::debug!("Reading file: {:?}", entry); + let content = + std::fs::read_to_string(entry.path()).context("Failed to read file")?; Ok(IngestionNode { - path: entry, + path: entry.path().into(), chunk: content, ..Default::default() }) }); - IngestionStream::iter(file_paths) + IngestionStream::iter(files) } } @@ -113,6 +118,6 @@ mod test { #[test] fn test_with_extensions() { let loader = FileLoader::new("/tmp").with_extensions(&["rs"]); - assert_eq!(loader.extensions, vec!["rs".to_string()]); + assert_eq!(loader.extensions, Some(vec!["rs".to_string()])); } }