
Improve organization of infer_rnn_rates #326

Open
abuzarmahmood opened this issue Jan 24, 2025 · 1 comment

Comments

@abuzarmahmood (Collaborator)

Remove redundancies in processing and convert processes into functions.

@abuzarmahmood (Collaborator, Author)

The issue involves improving the organization of the infer_rnn_rates script by removing redundancies and converting processes into functions. The relevant file identified for modification is:

  • File: /home/abuzarmahmood/projects/blech_clust/utils/infer_rnn_rates.py
  • Description: This module uses an Auto-regressive Recurrent Neural Network (RNN) to infer firing rates from electrophysiological data. It processes data for each taste separately, trains an RNN model, and saves the predicted firing rates and latent factors.

The suggested changes involve refactoring the script by modularizing its sections into functions, which improves readability and maintainability while reducing redundancy. The proposed functions are:

  1. Argument Parsing:

    # Assumes module-level imports: argparse, os, json, numpy as np,
    # tables, matplotlib.pyplot as plt
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Infer firing rates using RNN')
        parser.add_argument('data_dir', help='Path to data directory')
        # Add other arguments here...
        return parser.parse_args()
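
     The flags consumed by load_config and process_data below imply argparse entries along these lines (a hedged sketch; flag names and types are inferred from those snippets, not taken from the script):

    parser.add_argument('--override-config', action='store_true',
                        help='Ignore blechrnn_params.json and use CLI values')
    parser.add_argument('--train-steps', type=int)
    parser.add_argument('--hidden-size', type=int)
    parser.add_argument('--bin-size', type=int)
    parser.add_argument('--train-test-split', type=float)
    parser.add_argument('--no-pca', action='store_true')
    parser.add_argument('--time-lims', type=int, nargs=2)
    parser.add_argument('--separate-regions', action='store_true')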
  2. Configuration Loading:

    def load_config(args, blech_clust_path):
        if args.override_config:
            print('Overriding config file\nUsing provided arguments\n')
            return {
                'train_steps': args.train_steps,
                'hidden_size': args.hidden_size,
                'bin_size': args.bin_size,
                'train_test_split': args.train_test_split,
                'use_pca': not args.no_pca,
                'time_lims': args.time_lims
            }
        else:
            config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
            if not os.path.exists(config_path):
                raise FileNotFoundError(f'BlechRNN Config file not found @ {config_path}')
            with open(config_path, 'r') as f:
                return json.load(f)
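
     For reference, the keys load_config expects in blechrnn_params.json mirror the override dict above. A minimal sketch for generating a default file (the helper is hypothetical and the values are placeholders, not the project's defaults):

    import json
    import os

    def write_default_config(blech_clust_path):
        # Keys match those consumed by load_config(); values are placeholders
        default_config = {
            'train_steps': 15000,
            'hidden_size': 8,
            'bin_size': 25,
            'train_test_split': 0.75,
            'use_pca': True,
            'time_lims': [1500, 4500],
        }
        config_path = os.path.join(blech_clust_path, 'params', 'blechrnn_params.json')
        with open(config_path, 'w') as f:
            json.dump(default_config, f, indent=4)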
  3. Data Processing:

    def process_data(data, args, time_lims, bin_size):
        # time_lims and bin_size are accepted here for the downstream binning
        # step (see the binning sketch below)
        if args.separate_regions:
            data.get_region_units()
            region_names = data.region_names
            spike_arrays = [data.return_region_spikes(region) for region in region_names]
            # Drop regions that returned no units
            keep_inds = [i for i, x in enumerate(spike_arrays) if x is not None]
            region_names = [region_names[i] for i in keep_inds]
            spike_arrays = [spike_arrays[i] for i in keep_inds]
        else:
            region_names = ['all']
            spike_arrays = [np.stack(data.spikes)]
        return region_names, spike_arrays
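
     The binning itself could then live in its own helper. A sketch, assuming spike arrays shaped (..., time) with millisecond resolution (the helper name and shape convention are assumptions, not from the script):

    import numpy as np

    def bin_spikes(spike_array, time_lims, bin_size):
        # Restrict to the analysis window, then sum spikes within each bin
        cut = spike_array[..., time_lims[0]:time_lims[1]]
        n_bins = cut.shape[-1] // bin_size
        # Drop any remainder so the time axis divides evenly into bins
        cut = cut[..., :n_bins * bin_size]
        return cut.reshape(*cut.shape[:-1], n_bins, bin_size).sum(axis=-1)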
  4. Model Training:

    def train_rnn_model(inputs, labels, train_test_split, hidden_size,
                        loss_name, train_steps, device):
        input_size = inputs.shape[-1]
        # output_size drops 2 columns, presumably extra (non-spike)
        # input covariates that the decoder does not reconstruct
        output_size = input_size - 2
        # Split trials (axis 1) into train and test sets
        n_trials = inputs.shape[1]
        train_inds = np.random.choice(
            np.arange(n_trials), int(train_test_split * n_trials), replace=False)
        test_inds = np.setdiff1d(np.arange(n_trials), train_inds)
        train_inputs, train_labels = inputs[:, train_inds], labels[:, train_inds]
        test_inputs, test_labels = inputs[:, test_inds], labels[:, test_inds]
        train_inputs, train_labels = train_inputs.to(device), train_labels.to(device)
        test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
        net = autoencoderRNN(input_size=input_size, hidden_size=hidden_size,
                             output_size=output_size, rnn_layers=2, dropout=0.2)
        net.to(device)
        net, loss, cross_val_loss = train_model(
            net, train_inputs, train_labels, output_size=output_size, lr=0.001,
            train_steps=train_steps, loss=loss_name,
            test_inputs=test_inputs, test_labels=test_labels)
        return net, loss, cross_val_loss
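
     A usage sketch for this function, with toy data standing in for the real inputs (assumptions: tensors are shaped (time, trials, features) to match the inputs[:, train_inds] indexing above, the last 2 columns are covariates consistent with output_size = input_size - 2, and the loss string follows whatever train_model accepts):

    import torch

    # Use the GPU when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Toy stand-in: (time, trials, features), last 2 columns as covariates
    inputs = torch.rand(100, 30, 12)
    labels = inputs[..., :-2]  # spike columns only

    net, loss, cross_val_loss = train_rnn_model(
        inputs, labels,
        train_test_split=0.75,    # placeholder value
        hidden_size=8,            # placeholder value
        loss_name='poisson',      # assumed loss identifier, check train_model()
        train_steps=15000,        # placeholder value
        device=device)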
  5. Plotting and Saving:

    def plot_results(pred_firing, binned_spikes, plots_dir, iden_str):
        vz.firing_overview(pred_firing.swapaxes(0, 1))
        fig = plt.gcf()
        plt.suptitle('RNN Predicted Firing Rates')
        fig.savefig(os.path.join(plots_dir, f'firing_pred_{iden_str}.png'))
        plt.close(fig)
        # Add more plotting logic here...
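
     Since train_rnn_model already returns loss and cross_val_loss, this step could also save training diagnostics. A sketch (the helper is hypothetical; it assumes loss is a per-step sequence and cross_val_loss is a dict keyed by training step, which should be checked against train_model's actual return types):

    import os
    import matplotlib.pyplot as plt

    def plot_loss_curves(loss, cross_val_loss, plots_dir, iden_str):
        # Overlay training and cross-validation loss on one set of axes
        fig, ax = plt.subplots()
        ax.plot(loss, label='train')
        ax.plot(list(cross_val_loss.keys()),
                list(cross_val_loss.values()), label='cross-val')
        ax.set_xlabel('Training step')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(os.path.join(plots_dir, f'loss_curves_{iden_str}.png'))
        plt.close(fig)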
  6. HDF5 Writing:

    def write_to_hdf5(data, processing_items, pred_firing_list, latent_out_list, pred_x_list):
        # processing_items: (region_name, taste_idx) pairs, ordered to match
        # the three output lists
        hdf5_path = data.hdf5_path
        with tables.open_file(hdf5_path, 'r+') as hf5:
            if '/rnn_output' not in hf5:
                hf5.create_group('/', 'rnn_output', 'RNN Output')
            if '/rnn_output/regions' not in hf5:
                hf5.create_group('/rnn_output', 'regions', 'Region-specific RNN Output')
            rnn_output = hf5.get_node('/rnn_output/regions')
            for idx, (name, taste_idx) in enumerate(processing_items):
                group_name = f'region_{name}_taste_{taste_idx}'
                group_desc = f'Region {name} Taste {taste_idx}'
                taste_grp = hf5.create_group(rnn_output, group_name, group_desc)
                hf5.create_array(taste_grp, 'pred_firing', pred_firing_list[idx])
                hf5.create_array(taste_grp, 'latent_out', latent_out_list[idx])
                hf5.create_array(taste_grp, 'pred_x', pred_x_list[idx])
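
     Downstream consumers can then read the saved arrays back with pytables, e.g. (assumes a data object in scope; the group name follows the region_{name}_taste_{idx} pattern above):

    import tables

    with tables.open_file(data.hdf5_path, 'r') as hf5:
        grp = hf5.get_node('/rnn_output/regions/region_all_taste_0')
        pred_firing = grp.pred_firing[:]
        latent_out = grp.latent_out[:]
        pred_x = grp.pred_x[:]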

These changes aim to improve the script's structure by encapsulating distinct functionalities into separate functions, thereby enhancing the code's clarity and ease of maintenance.
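
To make the resulting control flow concrete, the refactored entry point could reduce to a sketch like the one below. The loop structure is illustrative only; prepare_inputs and run_inference are hypothetical helpers standing in for the script's binning and prediction steps, and data, device, plots_dir, and blech_clust_path are assumed to be set up at module level.

    def main():
        args = parse_arguments()
        params = load_config(args, blech_clust_path)
        region_names, spike_arrays = process_data(
            data, args, params['time_lims'], params['bin_size'])

        processing_items = []
        pred_firing_list, latent_out_list, pred_x_list = [], [], []
        for name, spikes in zip(region_names, spike_arrays):
            # Process each taste separately, per the module description
            for taste_idx, taste_spikes in enumerate(spikes):
                inputs, labels = prepare_inputs(taste_spikes, params)  # hypothetical
                net, loss, cross_val_loss = train_rnn_model(
                    inputs, labels, params['train_test_split'],
                    params['hidden_size'], 'poisson',  # loss name placeholder
                    params['train_steps'], device)
                pred_firing, latent_out, pred_x = run_inference(net, inputs)  # hypothetical
                plot_results(pred_firing, taste_spikes, plots_dir,
                             f'{name}_taste_{taste_idx}')
                processing_items.append((name, taste_idx))
                pred_firing_list.append(pred_firing)
                latent_out_list.append(latent_out)
                pred_x_list.append(pred_x)

        write_to_hdf5(data, processing_items, pred_firing_list,
                      latent_out_list, pred_x_list)

    if __name__ == '__main__':
        main()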


This response was automatically generated by blech_bot
