Using external activation functions #981

Closed

Maya7991 opened this issue Jul 4, 2024 · 4 comments

Comments

@Maya7991

Maya7991 commented Jul 4, 2024

I have a spiking convolutional neural network that uses the Leaky (Leaky Integrate-and-Fire) neuron from the SNNTorch library as its activation function. Is it possible to use activation functions such as those from SNNTorch together with Brevitas? An example architecture is given below.

import torch
import torch.nn as nn
import torch.nn.functional as F

import brevitas.nn as qnn
import snntorch as snn
from snntorch.functional import quant

class SpikingCNN(nn.Module):
    def __init__(self, beta=0.9, spike_grad=None, num_steps=25):  # placeholder hyperparameters
        super(SpikingCNN, self).__init__()

        self.num_steps = num_steps
        qlif = quant.state_quant(num_bits=8, uniform=False, thr_centered=True)

        self.conv1 = qnn.QuantConv2d(in_channels=1, out_channels=8, kernel_size=5, padding=0, bias=False, weight_bit_width=8)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = qnn.QuantConv2d(in_channels=8, out_channels=16, kernel_size=3, padding=0, bias=False, weight_bit_width=8)
        self.conv3 = qnn.QuantConv2d(in_channels=16, out_channels=16, kernel_size=1, padding=0, bias=False, weight_bit_width=8)
        self.fc1 = qnn.QuantLinear(3*3*16, 256, bias=True, weight_bit_width=8)
        self.fc2 = qnn.QuantLinear(256, 10, bias=True, weight_bit_width=8)
        self.lif = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, state_quant=qlif)

    def forward(self, x):
        self.lif.reset_mem()

        x = self.pool(F.relu(self.conv1(x)))          # First conv + ReLU + max pool
        spk_rec = []
        for step in range(self.num_steps):
            spk_rec.append(self.lif(self.conv2(x)))   # Second conv + LIF at each time step
        x = torch.sum(torch.stack(spk_rec), dim=0)    # Sum spikes over time steps
        x = F.relu(self.conv3(x))                     # Third conv + ReLU
        x = x.view(-1, 3 * 3 * 16)                    # Flatten to match fc1 input size
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

spikingcnn_model = SpikingCNN()
spikingcnn_loss = trainNet(spikingcnn_model, 0.0003)   # not including training loop here

quantized_weights = {}
for name, module in spikingcnn_model.named_modules():
    if isinstance(module, qnn.QuantConv2d) and 'conv2' in name:
        quantized_weights[f'{name}.weight'] = module.quant_weight().int().detach().cpu().numpy()

Is it possible to use Brevitas along with such custom activation functions?

The purpose of quantizing my model is to extract the INT8 weights and use them to simulate a VHDL design I have written. I have recorded the INT8 weights of the quantized spiking convolutional layer (conv2). However, I observe some difference between the values measured after the activation function and the values I expect. I would like to know whether Brevitas supports using custom activation functions and, if so, whether any additional configuration is needed.

@Giuseppe5
Collaborator

Hi,

Thanks for opening this issue.
Brevitas layers are designed to work as drop-in replacements for the corresponding PyTorch ones, and based on your example, I believe there should be no issue combining them with third-party libraries, even though we have no experience with SNNTorch in particular.

You mention observed vs. expected values. Where do the expected values come from?

@Maya7991
Author

Hi @Giuseppe5 ,

I apologize for the delay. I had to look into the basics of model quantization to be able to explain my doubts here.

Use case: train a spiking CNN with a Leaky (LIF) activation function in PyTorch and SNNTorch, and use the INT8 weights of the trained model in my VHDL design.

  1. Performed QAT of a vanilla spiking CNN model using Brevitas and extracted the INT8 weights from the trained model.
  2. Manually calculated the conv operation and the activation function output for a single channel.
  3. Compared this with the output of the model (see the sketch after this list).
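To make the comparison concrete, here is a minimal sketch of what I mean by steps 2 and 3, restricted to conv2 and accounting for the weight scale. The shapes and the random binary input are only illustrative, and it assumes the default per-tensor weight quantizer of QuantConv2d:

import torch
import torch.nn.functional as F

conv2 = spikingcnn_model.conv2                  # QuantConv2d from the model above
qw = conv2.quant_weight()                       # QuantTensor holding the fake-quantized weights
int_w = qw.int().float()                        # INT8 weight values (as float for F.conv2d)
scale = qw.scale                                # FP32 weight scale (per-tensor by default)

x = (torch.rand(1, 8, 5, 5) > 0.5).float()      # binary spike input, illustrative shape

layer_out = conv2(x)                            # what the fake-quantized layer computes

# "By hand" version: with 0/1 inputs the conv is just an accumulation of INT8 weights,
# emulated here with F.conv2d on the integer weights, then rescaled for comparison
manual_int = F.conv2d(x, int_w, bias=None)
manual_out = manual_int * scale

print(torch.allclose(layer_out, manual_out, atol=1e-5))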

From your previous reply, I understand that there is no problem using SNNTorch together with Brevitas. However, I have been trying to calculate the output of some channels manually and compare it with the output of the quantized model, and this is where I see a difference between observed and expected values.

  • Extracted the INT8 weights of a conv layer and, for a sample input, calculated the output of the conv operation.
  • The Leaky activation function checks whether the output of the conv operation is above the threshold value and produces a spike (1) if so, otherwise no spike (0).
  • I calculate this for one channel by hand using the INT8 weights, and the resulting spikes (0 or 1) do not match those of the quantized model.

I have a few assumptions about why this is happening. The quant and dequant steps between the layers of a fake-quantized model would not allow such a direct comparison: since this is a fake-quantized model, I would have to generate a true INT8 model to be able to compare it with my manual calculations.
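If that assumption is right, the relationship in the fake-quantized model should simply be dequantized_weight = int_value * scale, so my integer accumulation has to be multiplied by the weight scale before comparing it with the FP32 output of the layer. A quick sanity check, continuing from the snippet above and assuming the default symmetric per-tensor weight quantizer with zero point 0:

qw = spikingcnn_model.conv2.quant_weight()
print(torch.allclose(qw.value, qw.int().float() * qw.scale))   # dequantized = int * scale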

If the above is indeed the problem, can I generate a true INT8 model on which I can run an inference pass that uses only INT8 values and no FP32 values?
I did not post my thoughts here for so long because I could not decide how much of this falls under the scope of Brevitas.
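For what it is worth, my current understanding is that Brevitas keeps the model fake-quantized during QAT, and a genuinely integer-only representation is normally obtained by exporting it, for example to QONNX or to standard ONNX QCDQ, and then running inference with an integer backend. A hedged sketch, assuming a recent Brevitas release that exposes these export helpers; the exact signatures may differ between versions, and the stateful snn.Leaky module may need to be removed or replaced before the model can be traced for export:

import torch
from brevitas.export import export_onnx_qcdq   # or export_qonnx for the QONNX/FINN flow

spikingcnn_model.eval()
dummy_input = torch.randn(1, 1, 14, 14)         # illustrative input size matching the layer shapes above
export_onnx_qcdq(spikingcnn_model, args=dummy_input, export_path='spiking_cnn_qcdq.onnx')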

Note: the input to the spiking conv model consists of 0s and 1s (spike or no spike), which makes manual calculation easy; the MAC reduces to a plain accumulation.

Thank you!

@Giuseppe5
Collaborator

I still have a few questions about the set-up.
If you could share a reproducible script that shows how you compute the real vs. expected results, it would be easier for us to help.

@Giuseppe5
Collaborator

If this is still an issue, please feel free to re-open and we'd be more than happy to help!
