Skip to content

Commit

Permalink
Merge pull request #315 from cool-RR:2020-02-21-base
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 301865541
  • Loading branch information
copybara-github committed Mar 19, 2020
2 parents eb28db1 + b49268d commit 889afdf
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions trax/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,10 @@ def init(self, input_signature, rng=None):
return (weights, state)
else:
return (EMPTY_WEIGHTS, state)
except Exception:
except Exception as e:
name, trace = self.__class__.__name__, _short_traceback(skip=3)
raise LayerError(name, 'init', self._caller,
input_signature, trace)
input_signature, trace) from e

def init_from_file(self, file_name, weights_only=False):
"""Initializes this layer and its sublayers from a file.
Expand Down Expand Up @@ -445,10 +445,10 @@ def pure_fn(self, x, weights, state, rng):
self._state = s
return outputs, s

except Exception:
except Exception as e:
name, trace = self.__class__.__name__, _short_traceback()
raise LayerError(name, 'pure_fn',
self._caller, signature(x), trace)
self._caller, signature(x), trace) from e

def output_signature(self, input_signature):
"""Returns output signature this layer would give for `input_signature`."""
Expand Down Expand Up @@ -479,10 +479,10 @@ def call_on_input(x, weights, state, rng):
s = math.abstract_eval(call_on_input)(
input_signature, weight_signature, self.state, rng)
return s
except Exception:
except Exception as e:
name, trace = self.__class__.__name__, _short_traceback(skip=3)
raise LayerError(name, '_forward_abstract', self._caller, input_signature,
trace)
trace) from e

# pylint: disable=protected-access
def _set_rng_recursive(self, rng):
Expand Down Expand Up @@ -645,8 +645,8 @@ def Fn(f, n_in=None, n_out=None): # pylint: disable=invalid-name
dummy_args = [np.array([[0.0]]) for _ in range(n_in)]
res = f(*dummy_args)
n_out = len(res) if isinstance(res, (list, tuple)) else 1
except:
raise ValueError('n_out is not set and could not be determined')
except Exception as e:
raise ValueError('n_out is not set and could not be determined') from e

# Create the layer.
@layer(n_in=n_in, n_out=n_out)
Expand Down

0 comments on commit 889afdf

Please sign in to comment.