https://arxiv.org/abs/2410.20672
regular transformer: L layers, each with its own weight matrices
recursive version with L layers and B blocks: only L/B unique layers, looped B times so the effective depth is still L (weights tied across loop iterations)
the rrt (relaxed recursive transformer): keeps the tied weights but relaxes the sharing with position-specific LoRA modules
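a toy illustration of the tying pattern (helper name and example sizes are mine; the index rule is the one used in the residual formula below):

```python
# which shared layer serves depth l (1-indexed), for L total layers and B blocks
def shared_index(l: int, L: int, B: int) -> int:
    return (l - 1) % (L // B) + 1

# e.g. L = 6, B = 2: depths 1..6 reuse shared layers 1, 2, 3, 1, 2, 3
print([shared_index(l, 6, 2) for l in range(1, 7)])
```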
for each weight matrix at each layer:
- $W^\prime$ is the learned shared weight matrix
- $B^l A^l$ is a position-specific LoRA update (initialized via SVD)
- compute the residual between the original and tied weights for each position: $R^l = W^l - W^\prime_{((l-1) \bmod L/B + 1)}$
- get initial LoRA weights via truncated SVD: $U_r^l, \Sigma_r^l, V_r^l = \text{TruncatedSVD}(R^l; r)$, then $B^l = U_r^l \Sigma_r^l$ and $A^l = (V_r^l)^T$
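a minimal PyTorch sketch of this SVD-based initialization (function name and signature are mine, not from the paper's code):

```python
import torch

def init_lora_from_residual(W_orig: torch.Tensor, W_shared: torch.Tensor, r: int):
    """Initialize LoRA factors so that B @ A approximates W_orig - W_shared."""
    R = W_orig - W_shared                              # residual R^l = W^l - W'
    U, S, Vh = torch.linalg.svd(R, full_matrices=False)
    U_r, S_r, Vh_r = U[:, :r], S[:r], Vh[:r, :]        # keep top-r singular triplets
    B = U_r * S_r                                      # B^l = U_r Sigma_r  (d_out x r)
    A = Vh_r                                           # A^l = V_r^T        (r x d_in)
    return B, A
```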
during training:
- forward: $h = W^\prime x + B^l A^l x$
- backward: update BOTH $W^\prime$ AND the $B^l, A^l$ matrices
- $W^\prime$ learns the optimal shared representation
- $B^l, A^l$ learn position-specific adjustments
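a minimal PyTorch sketch of such a layer, assuming a single linear projection (class and parameter names are mine; a real transformer block would share all attention/MLP projections this way):

```python
import torch
import torch.nn as nn

class RelaxedSharedLinear(nn.Module):
    """One shared weight W' plus a rank-r LoRA pair (B^l, A^l) per depth position."""

    def __init__(self, d_in: int, d_out: int, num_positions: int, r: int):
        super().__init__()
        self.W_shared = nn.Parameter(torch.empty(d_out, d_in))
        nn.init.normal_(self.W_shared, std=0.02)
        # one LoRA pair per depth position; zero-initialized here for brevity,
        # whereas the paper initializes them from the truncated SVD of the residuals
        self.B = nn.ParameterList(
            [nn.Parameter(torch.zeros(d_out, r)) for _ in range(num_positions)])
        self.A = nn.ParameterList(
            [nn.Parameter(torch.zeros(r, d_in)) for _ in range(num_positions)])

    def forward(self, x: torch.Tensor, l: int) -> torch.Tensor:
        # h = W' x + B^l A^l x ; gradients flow to both W' and (B^l, A^l)
        return x @ self.W_shared.T + (x @ self.A[l].T) @ self.B[l].T
```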
so the final learned mapping approximates the original layer: $W^l \approx W^\prime + B^l A^l$
the rank $r$ controls the trade-off:
- $r = 0$: pure recursive ($W^l \approx W^\prime$)
- small $r$: mostly shared, with slight position-specific adjustments
- full $r$: can recover the original weights exactly ($W^l = W^\prime + B^l A^l$)
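a quick numerical check of this spectrum on toy matrices, reusing the truncated-SVD init from above (illustrative only):

```python
import torch

torch.manual_seed(0)
d = 8
W_orig = torch.randn(d, d)     # stand-in for an original layer weight W^l
W_shared = torch.randn(d, d)   # stand-in for the tied weight W'

for r in (0, 2, d):            # pure recursive, small rank, full rank
    U, S, Vh = torch.linalg.svd(W_orig - W_shared, full_matrices=False)
    B, A = U[:, :r] * S[:r], Vh[:r, :]                 # truncated-SVD LoRA init
    err = torch.linalg.norm(W_orig - (W_shared + B @ A)).item()
    print(f"r={r}: reconstruction error {err:.2e}")
# r=0 leaves W' untouched; r=d makes W' + B A reproduce W^l up to float error
```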