rename locfile to filelist, make the latitude/longitude/date values optional, and use it to filter available recordings if specified
jhuus committed Aug 2, 2024
1 parent 1d2fc2b commit dcf1b93
68 changes: 42 additions & 26 deletions analyze.py
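With this change, the CSV passed via --filelist still requires the four columns filename, latitude, longitude and recording_date, but the latitude/longitude/date values may now be left empty per row. A hypothetical filelist.csv (file names and coordinates invented for illustration):

    filename,latitude,longitude,recording_date
    dawn_chorus_01.wav,52.27,-113.81,2024-08-02
    dawn_chorus_02.wav,,,

The first row opts its recording into location/date frequency filtering; the second merely keeps the recording in the run without invoking location/date processing.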
@@ -46,7 +46,7 @@ def __init__(self, class_name, score, start_time, end_time):
self.end_time = end_time

class Analyzer:
-    def __init__(self, input_path, output_path, start_time, end_time, date_str, latitude, longitude, region, locfile, debug_mode, merge, overlap, device, thread_num=1):
+    def __init__(self, input_path, output_path, start_time, end_time, date_str, latitude, longitude, region, filelist, debug_mode, merge, overlap, device, thread_num=1):
self.input_path = input_path.strip()
self.output_path = output_path.strip()
self.start_seconds = self._get_seconds_from_time_string(start_time)
@@ -55,12 +55,13 @@ def __init__(self, input_path, output_path, start_time, end_time, date_str, lati
self.latitude = latitude
self.longitude = longitude
self.region = region
-        self.locfile = locfile
+        self.filelist = filelist
self.debug_mode = debug_mode
self.overlap = overlap
self.thread_num = thread_num
self.device = device
self.frequencies = {}
+        self.issued_skip_files_warning = False

if cfg.infer.min_score == 0:
self.merge_labels = False # merging all labels >= min_score makes no sense in this case
@@ -96,6 +97,9 @@ def _get_file_list(input_path):
# return week number in the range [1, 48] as used by eBird barcharts, i.e. 4 weeks per month
@staticmethod
def _get_week_num_from_date_str(date_str):
+        if not isinstance(date_str, str):
+            return None # e.g. if filelist is used to filter recordings and no date is specified
+
date_str = date_str.replace('-', '') # for case with yyyy-mm-dd dates in CSV file
if not date_str.isnumeric():
return None
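The isinstance guard added above matters because pandas reads empty CSV cells as NaN floats rather than strings; a standalone sketch of that behavior (not code from this repository):

    import io

    import pandas as pd

    # an empty recording_date cell comes back as float('nan'), not a str
    df = pd.read_csv(io.StringIO("filename,latitude,longitude,recording_date\nrec1.wav,,,\n"))
    value = df.iloc[0]["recording_date"]
    print(isinstance(value, str))  # False, so _get_week_num_from_date_str() returns None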
@@ -112,7 +116,7 @@ def _get_week_num_from_date_str(date_str):
# a region is an alternative to lat/lon, and may specify an eBird county (e.g. CA-AB-FN)
# or province (e.g. CA-AB)
def _process_location_and_date(self):
-        if self.locfile is None and self.region is None and (self.latitude is None or self.longitude is None):
+        if self.filelist is None and self.region is None and (self.latitude is None or self.longitude is None):
self.check_frequency = False
self.week_num = None
return
@@ -125,17 +129,17 @@ def _process_location_and_date(self):

# if a location file is specified, use that
self.location_date_dict = None
-        if self.locfile is not None:
-            if os.path.exists(self.locfile):
-                dataframe = pd.read_csv(self.locfile)
+        if self.filelist is not None:
+            if os.path.exists(self.filelist):
+                dataframe = pd.read_csv(self.filelist)
expected_column_names = ['filename', 'latitude', 'longitude', 'recording_date']
if len(dataframe.columns) != len(expected_column_names):
logging.error(f"Error: file {self.locfile} has {len(dataframe.columns)} columns but {len(expected_column_names)} were expected.")
logging.error(f"Error: file {self.filelist} has {len(dataframe.columns)} columns but {len(expected_column_names)} were expected.")
quit()

for i, column_name in enumerate(dataframe.columns):
if column_name != expected_column_names[i]:
logging.error(f"Error: file {self.locfile}, column {i} is {column_name} but {expected_column_names[i]} was expected.")
logging.error(f"Error: file {self.filelist}, column {i} is {column_name} but {expected_column_names[i]} was expected.")
quit()

self.location_date_dict = {}
@@ -145,7 +149,7 @@ def _process_location_and_date(self):

return
else:
logging.error(f"Error: file {self.locfile} not found.")
logging.error(f"Error: file {self.filelist} not found.")
quit()

if self.date_str == 'file':
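The loop that fills self.location_date_dict is elided from this hunk; given that _analyze_file later unpacks self.location_date_dict[filename] into (latitude, longitude, week_num), it presumably looks something like this sketch (an assumption, not the commit's exact code):

    for _, row in dataframe.iterrows():
        week_num = self._get_week_num_from_date_str(row["recording_date"])
        self.location_date_dict[row["filename"]] = (row["latitude"], row["longitude"], week_num)

With empty date cells, _get_week_num_from_date_str() now returns None, which is what the new week_num check in _analyze_file relies on.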
@@ -319,27 +323,36 @@ def _get_specs(self, start_seconds, end_seconds):
return spec_array

def _analyze_file(self, file_path):
logging.info(f"Thread {self.thread_num}: Analyzing {file_path}")

check_frequency = self.check_frequency
if check_frequency:
if self.location_date_dict is not None:
filename = Path(file_path).name
if filename in self.location_date_dict:
latitude, longitude, self.week_num = self.location_date_dict[filename]
county = None
for c in self.counties:
if latitude >= c.min_y and latitude <= c.max_y and longitude >= c.min_x and longitude <= c.max_x:
county = c
break

if county is None:
if self.week_num is None:
check_frequency = False
logging.warning(f"Warning: no matching county found for latitude={latitude} and longitude={longitude}")
else:
self._update_class_frequency_stats([county])
county = None
for c in self.counties:
if latitude >= c.min_y and latitude <= c.max_y and longitude >= c.min_x and longitude <= c.max_x:
county = c
break

if county is None:
check_frequency = False
logging.warning(f"Warning: no matching county found for latitude={latitude} and longitude={longitude}")
else:
self._update_class_frequency_stats([county])
else:
logging.warning(f"Warning: file {filename} not found in {self.locfile}")
# when a filelist is specified, only the recordings in that file are processed;
# so you can specify a filelist with no locations or dates if you want to restrict the recording
# list but not invoke location/date processing; you still need the standard CSV format
# with the expected number of columns, but latitude/longitude/date can be empty
if not self.issued_skip_files_warning:
logging.info(f"Thread {self.thread_num}: skipping some recordings that were not included in {self.filelist} (e.g. {filename})")
self.issued_skip_files_warning = True

return
elif self.get_date_from_file_name:
result = re.split(cfg.infer.file_date_regex, os.path.basename(file_path))
if len(result) > cfg.infer.file_date_regex_group:
@@ -349,6 +362,8 @@
logging.error(f'Error: invalid date string: {self.date_str} extracted from {file_path}')
check_frequency = False # ignore species frequencies for this file

logging.info(f"Thread {self.thread_num}: Analyzing {file_path}")

# clear info from previous recording, and mark classes where frequency of eBird reports is too low
for class_info in self.class_infos:
class_info.reset()
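In the get_date_from_file_name branch above, re.split with a capturing group returns the matched date as a list element; a hypothetical illustration (the real pattern and group index come from cfg.infer.file_date_regex and cfg.infer.file_date_regex_group, whose values are assumed here):

    import os
    import re

    file_date_regex, file_date_regex_group = r"(\d{8})", 1  # assumed values
    result = re.split(file_date_regex, os.path.basename("recordings/marsh_20240802.wav"))
    if len(result) > file_date_regex_group:
        date_str = result[file_date_regex_group]  # '20240802'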
@@ -481,7 +496,7 @@ def run(self, file_list):
if __name__ == '__main__':
# command-line arguments
parser = argparse.ArgumentParser()
-    parser.add_argument('-b', '--band', type=int, default=1 * cfg.infer.use_banding_codes, help="If 1, use banding codes labels. If 0, use common names. Default = {1 * cfg.infer.use_banding_codes}.")
+    parser.add_argument('-b', '--band', type=int, default=1 * cfg.infer.use_banding_codes, help=f"If 1, use banding codes labels. If 0, use common names. Default = {1 * cfg.infer.use_banding_codes}.")
parser.add_argument('-d', '--debug', default=False, action='store_true', help='Flag for debug mode (analyze one spectrogram only, and output several top candidates).')
parser.add_argument('-e', '--end', type=str, default='', help="Optional end time in hh:mm:ss format, where hh and mm are optional.")
parser.add_argument('-i', '--input', type=str, default='', help="Input path (single audio file or directory). No default.")
@@ -490,10 +505,10 @@ def run(self, file_list):
parser.add_argument('-m', '--merge', type=int, default=1, help=f'Specify 0 to not merge adjacent labels of same species. Default = 1, i.e. merge.')
parser.add_argument('-p', '--min_score', type=float, default=cfg.infer.min_score, help=f"Generate label if score >= this. Default = {cfg.infer.min_score}.")
parser.add_argument('-s', '--start', type=str, default='', help="Optional start time in hh:mm:ss format, where hh and mm are optional.")
-    parser.add_argument('--date', type=str, default=None, help=f'Date in yyyymmdd, mmdd, or file. Specifying file extracts the date from the file name, using the reg ex defined in config.py.')
+    parser.add_argument('--date', type=str, default=None, help=f'Date in yyyymmdd, mmdd, or file. Specifying file extracts the date from the file name, using the file_date_regex in base_config.py.')
parser.add_argument('--lat', type=float, default=None, help=f'Latitude. Use with longitude to identify an eBird county and ignore corresponding rarities.')
parser.add_argument('--lon', type=float, default=None, help=f'Longitude. Use with latitude to identify an eBird county and ignore corresponding rarities.')
-    parser.add_argument('--locfile', type=str, default=None, help=f'Path to optional CSV file containing file names, latitudes, longitudes and recording dates.')
+    parser.add_argument('--filelist', type=str, default=None, help=f'Path to optional CSV file containing input file names, latitudes, longitudes and recording dates.')
parser.add_argument('--threads', type=int, default=cfg.infer.num_threads, help=f'Number of threads. Default = {cfg.infer.num_threads}')
parser.add_argument('-r', '--region', type=str, default=None, help=f'eBird region code, e.g. "CA-AB" for Alberta. Use as an alternative to latitude/longitude.')
args = parser.parse_args()
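Hypothetical invocations after the rename (paths and coordinates invented; other arguments, such as the output path, are omitted):

    python analyze.py -i recordings --filelist filelist.csv
    python analyze.py -i recordings --lat 52.27 --lon -113.81 --date 20240802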
@@ -514,13 +529,14 @@ def run(self, file_list):
device = 'cuda'
logging.info(f"Using GPU")
else:
+        # TODO: use openvino to improve performance when no GPU is available
device = 'cpu'
logging.info(f"Using CPU")

file_list = Analyzer._get_file_list(args.input)
if num_threads == 1:
# keep it simple in case multithreading code has undesirable side-effects (e.g. disabling echo to terminal)
-        analyzer = Analyzer(args.input, args.output, args.start, args.end, args.date, args.lat, args.lon, args.region, args.locfile, args.debug, args.merge, args.overlap, device)
+        analyzer = Analyzer(args.input, args.output, args.start, args.end, args.date, args.lat, args.lon, args.region, args.filelist, args.debug, args.merge, args.overlap, device)
analyzer.run(file_list)
else:
# split input files into one group per thread
@@ -532,7 +548,7 @@ def run(self, file_list):
processes = []
for i in range(num_threads):
if len(file_lists[i]) > 0:
-            analyzer = Analyzer(args.input, args.output, args.start, args.end, args.date, args.lat, args.lon, args.region, args.locfile, args.debug, args.merge, args.overlap, device, i + 1)
+            analyzer = Analyzer(args.input, args.output, args.start, args.end, args.date, args.lat, args.lon, args.region, args.filelist, args.debug, args.merge, args.overlap, device, i + 1)
if os.name == "posix":
process = mp.Process(target=analyzer.run, args=(file_lists[i], ))
else:
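The code that builds file_lists is elided above; a round-robin partition is one minimal way to split the work across threads (a sketch, not necessarily the repository's implementation):

    # divide file_list into num_threads interleaved groups
    file_lists = [file_list[i::num_threads] for i in range(num_threads)]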
