Commit 72031a5

dmitrysarov committed Oct 3, 2019
1 parent 49a3d22 commit 72031a5
Showing 5 changed files with 410 additions and 89 deletions.
8 changes: 5 additions & 3 deletions LRP/lrp.py
@@ -7,10 +7,12 @@ class LRP():
     :param rule:str: name of used rule
     '''
 
-    def __init__(self, model, rule):
+    def __init__(self, model, rule, input_lowest=-1, input_highest=1):
         self.model = copy.deepcopy(model)
         self.model = self.model.eval()
-        self.model = utils.redefine_nn(self.model, rule=rule) #redefine each layer(module) of model, to set custom autograd func
+        self.model = utils.redefine_nn(self.model, rule=rule,
+                                       input_lowest=input_lowest,
+                                       input_highest=input_highest) #redefine each layer(module) of model, to set custom autograd func
         self.output = None
 
     def forward(self, input_):
@@ -35,7 +37,7 @@ def relprop(self, input_, R=None):
         C = self.local_input.grad.clone().detach()
         assert C is not None, 'obtained relevance is None'
         self.local_input.grad = None
-        R = C*input_.clone().detach()
+        R = C#*input_.clone().detach()
         return R
 
     __call__ = relprop
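Taken together, the two hunks above change both the constructor contract and the meaning of the returned relevance: __init__ now threads the input-domain bounds down to the rules, and relprop returns the raw gradient-like term C instead of C * input (note the commented-out multiplication). A minimal usage sketch, assuming only the API visible in this diff; the model choice and input shape are illustrative:

import torch
import torchvision
from LRP.lrp import LRP

model = torchvision.models.vgg16(pretrained=True)
x = torch.randn(1, 3, 224, 224)  # assumed to live in [-1, 1], matching the new defaults

lrp = LRP(model, rule='z_rule', input_lowest=-1, input_highest=1)
R = lrp(x)  # __call__ = relprop; after this commit R = C rather than C * input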
19 changes: 11 additions & 8 deletions LRP/rules.py
@@ -58,8 +58,9 @@ def z_rule(func, input, R, func_args, keep_bias=False):
         Z.backward(S)
         assert input.grad is not None
         C = input.grad
-        R = input * C
-        return R
+        Ri = input * C
+        print(Ri.sum())
+        return Ri
 
     @staticmethod
     def z_plus(func, input, R, func_args, keep_bias=False):
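The body of z_rule (partly outside this hunk) is the standard autograd shortcut for LRP: run the layer to get Z, divide the incoming relevance element-wise by Z, backpropagate that quotient through the layer, and multiply by the input. A self-contained sketch for a single linear layer; the function and variable names are mine, not the repo's:

import torch
import torch.nn.functional as F
import numpy as np

def z_rule_linear(weight, x, R):
    # x: layer input, R: relevance arriving from the layer above
    x = x.clone().detach().requires_grad_(True)
    Z = F.linear(x, weight)                                    # z_j = sum_i x_i * w_ji
    S = R / (Z + (Z == 0).float() * np.finfo(np.float32).eps)  # stabilize exact zeros
    Z.backward(S)                                              # x.grad_i = sum_j w_ji * S_j
    return x * x.grad                                          # R_i = x_i * sum_j w_ji * R_j / z_j

The print(Ri.sum()) added in this hunk is evidently a debugging aid: the z-rule is conservative, so the printed total should stay approximately equal to the relevance injected at the output.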
@@ -70,15 +71,16 @@ def z_plus(func, input, R, func_args, keep_bias=False):
         if func_args.get('bias', None) is not None:
             if not keep_bias:
                 func_args['bias'] = None
-        func_args['weight'].clamp_(0, float('inf'))
+        if func_args.get('weight', None) is not None:
+            func_args['weight'].clamp_(0, float('inf'))
         with torch.enable_grad():
             Z = func(input, **func_args)
         S = R /(Z + (Z==0).float()*np.finfo(np.float32).eps)
         Z.backward(S)
         assert input.grad is not None
         C = input.grad
-        R = input * C
-        return R
+        Ri = input * C
+        return Ri
 
     @staticmethod
     def z_box(func, input, R, func_args, lowest, highest, keep_bias=False):
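z_plus is the same computation with the weights clamped to their positive part (and the bias dropped unless keep_bias is set), which keeps all propagated relevances nonnegative; the guard added in this hunk simply skips the clamp for layers that carry no 'weight' entry, such as pooling. The only change relative to the previous sketch, with the same caveats:

import torch
import torch.nn.functional as F
import numpy as np

def z_plus_linear(weight, x, R):
    w_plus = weight.clamp(min=0)  # keep only the positive weight part
    x = x.clone().detach().requires_grad_(True)
    Z = F.linear(x, w_plus)
    S = R / (Z + (Z == 0).float() * np.finfo(np.float32).eps)
    Z.backward(S)
    return x * x.grad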
@@ -111,14 +113,15 @@ def z_box(func, input, R, func_args, lowest, highest, keep_bias=False):
         if input.grad is not None: input.grad.zero_()
         input.retain_grad()
         with torch.enable_grad():
-            Z = func(input, **ifunc_args) - func(L, **pfunc_args) - func(H, **nfunc_args) + np.finfo(np.float32).eps
+            Z = func(input, **ifunc_args) - func(L, **pfunc_args) - func(H, **nfunc_args)
         S = R / (Z + (Z==0).float()*np.finfo(np.float32).eps)
         Z.backward(S)
         assert input.grad is not None
         assert L.grad is not None
         assert H.grad is not None
-        R = input * input.grad - L * L.grad - H * H.grad
-        return R
+        import ipdb; ipdb.set_trace() # BREAKPOINT
+        Ri = input * input.grad - L * L.grad - H * H.grad
+        return Ri
 #
 #
 #def z_epsilon_rule(module, input_, R, keep_bias=True):
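Two things happen in this hunk. First, the stray eps added to Z itself is removed; the stabilizer now lives only in the denominator of S, which is the conventional placement. Second, an ipdb breakpoint is committed just before the result is assembled; it is debugging residue and will halt every call to z_box until it is removed. The rule itself is the z^B ("box") rule for the pixel layer, whose inputs live in the bounded box [lowest, highest]: the input goes through the layer as-is, while the bounds L and H go through the positive and negative weight parts respectively. A self-contained sketch for a linear layer, with illustrative names:

import torch
import torch.nn.functional as F
import numpy as np

def z_box_linear(weight, x, R, lowest=-1.0, highest=1.0):
    x = x.clone().detach().requires_grad_(True)
    L = torch.full_like(x, lowest).requires_grad_(True)   # lower bound of the input domain
    H = torch.full_like(x, highest).requires_grad_(True)  # upper bound of the input domain
    w_pos, w_neg = weight.clamp(min=0), weight.clamp(max=0)
    Z = F.linear(x, weight) - F.linear(L, w_pos) - F.linear(H, w_neg)
    S = R / (Z + (Z == 0).float() * np.finfo(np.float32).eps)
    Z.backward(S)
    return x * x.grad - L * L.grad - H * H.grad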
2 changes: 1 addition & 1 deletion LRP/utils.py
@@ -29,7 +29,7 @@ def copy_module(module):
         module._backward_hooks.popitem() # remove hooks from module copy
     return module
 
-def redefine_nn(model, rule='z_rule', input_lowest=-1, input_highest=1):
+def redefine_nn(model, rule, input_lowest, input_highest):
     '''
     go over model layers and overload chosen instance methods (e.g. forward()).
     New methods come from classes of layers module
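Dropping the default values here is a deliberate tightening: after this commit LRP.__init__ owns the defaults for input_lowest and input_highest, and redefine_nn must receive them explicitly, so the two call sites cannot drift apart silently. The body of redefine_nn is outside this hunk; per its docstring it walks the model's layers and overloads their forward methods with rule-aware replacements. A bare-bones sketch of that rebinding pattern, purely illustrative since the repo's actual mechanism is not shown in this diff:

import types
import torch.nn as nn

def rebind_forward(module):
    # replace module.forward with a wrapper that stashes a grad-enabled
    # copy of the input, so a later relevance pass can reuse it
    original = module.forward
    def forward(self, x, _orig=original):
        self.saved_input = x.detach().requires_grad_(True)
        return _orig(self.saved_input)
    module.forward = types.MethodType(forward, module)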
Empty file added: models/__init__.py
