Commit

Add files via upload

mobicom24 authored Aug 21, 2023
0 parents commit ee1e2e5
Showing 9 changed files with 1,450 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -0,0 +1,2 @@
# TFDiff

Empty file added complex/__init__.py
Empty file.
Binary file added complex/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added complex/__pycache__/complex_module.cpython-37.pyc
Binary file not shown.
204 changes: 204 additions & 0 deletions complex/complex_functions.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: spopoff
"""
import torch
from torch.nn.functional import relu, max_pool2d, max_pool3d, avg_pool2d, avg_pool1d, dropout, dropout2d, dropout3d, interpolate, sigmoid, tanh, leaky_relu

def complex_matmul(A, B):
'''
    Performs the matrix product between two complex matrices.
'''

outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)

return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
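# Illustrative sketch (not part of the original file): complex_matmul expands
# (A_r + i*A_i)(B_r + i*B_i) into four real matrix products, so on complex64
# inputs it should agree with torch.matmul up to floating-point error.
def _demo_complex_matmul():
    A = torch.randn(3, 4, dtype=torch.complex64)
    B = torch.randn(4, 5, dtype=torch.complex64)
    # compare the hand-rolled product against PyTorch's native complex matmul
    assert torch.allclose(complex_matmul(A, B), torch.matmul(A, B), atol=1e-5)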

def complex_avg_pool1d(input, *args, **kwargs):
    '''
    Perform complex average pooling by pooling the real and imaginary parts separately.
    '''
    outp_real = avg_pool1d(input.real, *args, **kwargs)
    outp_imag = avg_pool1d(input.imag, *args, **kwargs)

    return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)

def complex_avg_pool2d(input, *args, **kwargs):
    '''
    Perform complex average pooling by pooling the real and imaginary parts separately.
    '''
    outp_real = avg_pool2d(input.real, *args, **kwargs)
    outp_imag = avg_pool2d(input.imag, *args, **kwargs)

    return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
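# Illustrative sketch (not part of the original file): since averaging is linear,
# pooling the real and imaginary parts separately is equivalent to averaging the
# complex values directly; the output keeps the complex dtype.
def _demo_complex_avg_pool2d():
    x = torch.randn(1, 2, 8, 8, dtype=torch.complex64)  # (N, C, H, W)
    y = complex_avg_pool2d(x, kernel_size=2)
    assert y.dtype == torch.complex64 and y.shape == (1, 2, 4, 4)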

def complex_normalize(input):
'''
Perform complex normalization
'''
real_value, imag_value = input.real, input.imag
real_norm = (real_value - real_value.mean()) / real_value.std()
imag_norm = (imag_value - imag_value.mean()) / imag_value.std()

return real_norm.type(torch.complex64) + 1j*imag_norm.type(torch.complex64)
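# Illustrative sketch (not part of the original file): complex_normalize
# standardizes the real and imaginary components independently, so each
# component of the output has (approximately) zero mean and unit std.
def _demo_complex_normalize():
    x = torch.randn(4, 16, dtype=torch.complex64) * 3 + (2 + 5j)
    y = complex_normalize(x)
    assert torch.allclose(y.real.mean(), torch.tensor(0.0), atol=1e-5)
    assert torch.allclose(y.imag.std(), torch.tensor(1.0), atol=1e-4)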

def complex_leaky_relu(input, negative_slope):
return leaky_relu(input.real, negative_slope).type(torch.complex64)+1j*leaky_relu(input.imag, negative_slope).type(torch.complex64)

def complex_relu(input):
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)

def complex_sigmoid(input):
return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64)

def complex_tanh(input):
return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64)

def complex_opposite(input):
return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64))

def complex_stack(input, dim):
input_real = [x.real for x in input]
input_imag = [x.imag for x in input]
return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64)

def _retrieve_elements_from_indices(tensor, indices):
flattened_tensor = tensor.flatten(start_dim=-2)
output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices)
return output

def _retrieve_elements_from_indices3d(tensor, indices):
flattened_tensor = tensor.flatten(start_dim=-3)
output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-3)).view_as(indices)
return output

def complex_upsample(input, size=None, scale_factor=None, mode='nearest',
align_corners=None, recompute_scale_factor=None):
'''
Performs upsampling by separately interpolating the real and imaginary part and recombining
'''
outp_real = interpolate(input.real, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
outp_imag = interpolate(input.imag, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)

return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
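# Illustrative sketch (not part of the original file): complex_upsample
# interpolates the real and imaginary parts separately, so the output keeps
# the complex dtype while the spatial size grows by scale_factor.
def _demo_complex_upsample():
    x = torch.randn(1, 3, 8, 8, dtype=torch.complex64)
    y = complex_upsample(x, scale_factor=2, mode='nearest')
    assert y.dtype == torch.complex64 and y.shape == (1, 3, 16, 16)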

def complex_upsample2(input, size=None, scale_factor=None, mode='nearest',
align_corners=None, recompute_scale_factor=None):
'''
Performs upsampling by separately interpolating the amplitude and phase part and recombining
'''
outp_abs = interpolate(input.abs(), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
angle = torch.atan2(input.imag,input.real)
outp_angle = interpolate(angle, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)

    return outp_abs \
        * (torch.cos(outp_angle).type(torch.complex64) + 1j * torch.sin(outp_angle).type(torch.complex64))
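# Illustrative sketch (not part of the original file): complex_upsample2
# interpolates magnitude and phase instead of the real/imaginary parts; its
# output shape and dtype match those of complex_upsample above.
def _demo_complex_upsample2():
    x = torch.randn(1, 3, 8, 8, dtype=torch.complex64)
    y = complex_upsample2(x, scale_factor=2, mode='nearest')
    assert y.dtype == torch.complex64 and y.shape == (1, 3, 16, 16)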


def complex_max_pool2d(input,kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):
'''
    Perform complex max pooling by selecting, within each window, the complex value with the largest absolute value.
'''
absolute_value, indices = max_pool2d(
input.abs(),
kernel_size = kernel_size,
stride = stride,
padding = padding,
dilation = dilation,
ceil_mode = ceil_mode,
return_indices = True
)
# performs the selection on the absolute values
absolute_value = absolute_value.type(torch.complex64)
# retrieve the corresponding phase value using the indices
# unfortunately, the derivative for 'angle' is not implemented
angle = torch.atan2(input.imag,input.real)
# get only the phase values selected by max pool
angle = _retrieve_elements_from_indices(angle, indices)
return absolute_value \
* (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64))
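# Illustrative sketch (not part of the original file): within each pooling
# window the complex value with the largest magnitude is kept (its phase is
# recovered through the max-pool indices), so the pooled magnitudes equal an
# ordinary max pooling of the input magnitudes.
def _demo_complex_max_pool2d():
    x = torch.randn(1, 1, 4, 4, dtype=torch.complex64)
    y = complex_max_pool2d(x, kernel_size=2)
    assert torch.allclose(y.abs(), max_pool2d(x.abs(), kernel_size=2))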

def complex_max_pool3d(input,kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):
'''
    Perform complex max pooling by selecting, within each window, the complex value with the largest absolute value.
'''
absolute_value, indices = max_pool3d(
input.abs(),
kernel_size = kernel_size,
stride = stride,
padding = padding,
dilation = dilation,
ceil_mode = ceil_mode,
return_indices = True
)
# performs the selection on the absolute values
absolute_value = absolute_value.type(torch.complex64)
# retrieve the corresponding phase value using the indices
# unfortunately, the derivative for 'angle' is not implemented
angle = torch.atan2(input.imag,input.real)
# get only the phase values selected by max pool
angle = _retrieve_elements_from_indices3d(angle, indices)
return absolute_value \
* (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64))


def complex_adaptive_avg_pool3d(input,kernel_size, stride=None, padding=0,
dilation=1, ceil_mode=False, return_indices=False):
'''
    Perform complex max pooling by selecting, within each window, the complex value with the largest absolute value.
    Note: despite its name, this function performs 3D max pooling (as complex_max_pool3d above), not adaptive average pooling.
'''
absolute_value, indices = max_pool3d(
input.abs(),
kernel_size = kernel_size,
stride = stride,
padding = padding,
dilation = dilation,
ceil_mode = ceil_mode,
return_indices = True
)
# performs the selection on the absolute values
absolute_value = absolute_value.type(torch.complex64)
# retrieve the corresponding phase value using the indices
# unfortunately, the derivative for 'angle' is not implemented
angle = torch.atan2(input.imag,input.real)
# get only the phase values selected by max pool
    angle = _retrieve_elements_from_indices3d(angle, indices)
return absolute_value \
* (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64))


def complex_dropout(input, p=0.5, training=True):
    # the real and imaginary parts need to share the same dropout mask;
    # this is not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # F.dropout already rescales surviving entries by 1/(1-p), so no extra scaling is applied
    mask = dropout(mask, p, training)
    mask = mask.type(input.dtype)
    return mask * input
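# Illustrative sketch (not part of the original file, and assuming the
# standard inverted-dropout scaling of F.dropout): the shared mask drops the
# real and imaginary parts of an element together, so dropped entries are
# exactly 0 + 0j while surviving entries are the input rescaled by 1/(1-p).
def _demo_complex_dropout():
    torch.manual_seed(0)
    x = torch.randn(2, 8, dtype=torch.complex64)
    y = complex_dropout(x, p=0.5, training=True)
    kept = y != 0
    assert torch.allclose(y[kept], x[kept] / (1 - 0.5))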


def complex_dropout2d(input, p=0.5, training=True):
    # the real and imaginary parts need to share the same dropout mask;
    # this is not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # F.dropout2d zeroes whole channels and already rescales by 1/(1-p), so no extra scaling is applied
    mask = dropout2d(mask, p, training)
    mask = mask.type(input.dtype)
    return mask * input

def complex_dropout3d(input, p=0.5, training=True):
    # the real and imaginary parts need to share the same dropout mask;
    # this is not a clean solution!
    device = input.device
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # F.dropout3d zeroes whole channels and already rescales by 1/(1-p), so no extra scaling is applied
    mask = dropout3d(mask, p, training)
    mask = mask.type(input.dtype)
    return mask * input

