diff --git a/src/deepforest/dataset.py b/src/deepforest/dataset.py index 1cc595ce..f0421465 100644 --- a/src/deepforest/dataset.py +++ b/src/deepforest/dataset.py @@ -66,7 +66,11 @@ def __init__(self, Returns: If train, path, image, targets else image """ - self.annotations = pd.read_csv(csv_file) + # Check if csv_file is a DataFrame or a file path + if isinstance(csv_file, pd.DataFrame): + self.annotations = csv_file + else: + self.annotations = pd.read_csv(csv_file) self.root_dir = root_dir if transforms is None: self.transform = get_transform(augment=train) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 320beca0..82ca4f9e 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -420,6 +420,7 @@ def predict_image(self, result = utilities.read_file(result, root_dir=root_dir) return result + def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1): """Create a dataset and predict entire annotation file Csv file format @@ -431,7 +432,7 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 Deprecation warning: The return_plot argument is deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead. Args: - csv_file: path to csv file + csv_file (str or pd.DataFrame): Path to a CSV file or a DataFrame with annotations. root_dir: directory of images. If none, uses "image_dir" in config (deprecated) savedir: directory to save images with bounding boxes (deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255) @@ -441,11 +442,17 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 df: pandas dataframe with bounding boxes, label and scores for each image in the csv file """ - df = utilities.read_file(csv_file) - ds = dataset.TreeDataset(csv_file=csv_file, + # Use DataFrame directly if provided, otherwise treat as file path + if isinstance(csv_file, pd.DataFrame): + df = csv_file + else: + df = utilities.read_file(csv_file) + + ds = dataset.TreeDataset(csv_file=df, root_dir=root_dir, transforms=None, train=False) + dataloader = self.predict_dataloader(ds) results = predict._dataloader_wrapper_(model=self, diff --git a/tests/test_main.py b/tests/test_main.py index a8c3543e..f4019335 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -649,8 +649,7 @@ def test_predict_tile_with_crop_model(m, config): "xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label", "geometry", "cropmodel_score", "image_path" } - - + def test_predict_tile_with_crop_model_empty(): """If the model return is empty, the crop model should return an empty dataframe""" raster_path = get_data("SOAP_061.png") @@ -673,4 +672,4 @@ def test_predict_tile_with_crop_model_empty(): crop_model=crop_model) # Assert the result - assert result is None + assert result is None \ No newline at end of file