Skip to content

Commit 65c473c

Browse files
committed
Add feature to use gradient norm for Hager-Zhang eps value
1 parent 4de3fd5 commit 65c473c

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

varipeps/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,16 @@ class VariPEPS_Config:
202202
Constant used in Hager-Zhang line search method.
203203
line_search_hager_zhang_rho (:obj:`float`):
204204
Constant used in Hager-Zhang line search method.
205+
line_search_hager_zhang_eps_use_grad_norm (:obj:`bool`):
206+
Use the norm of the gradient multiplied by
207+
:obj:`VariPEPS_Config.line_search_hager_zhang_eps_grad_norm_factor` to
208+
calculate the eps value in the Hager-Zhang line search. If disabled, the fixed
209+
value from config parameter
210+
:obj:`VariPEPS_Config.line_search_hager_zhang_eps` is used.
211+
line_search_hager_zhang_eps_grad_norm_factor (:obj:`float`):
212+
Factor used for gradient-based eps calculation. See parameter
213+
:obj:`VariPEPS_Config.line_search_hager_zhang_eps_use_grad_norm`
214+
for details.
205215
basinhopping_niter (:obj:`int`):
206216
Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
207217
See this function for details.
@@ -286,6 +296,8 @@ class VariPEPS_Config:
286296
line_search_hager_zhang_theta: float = 0.5
287297
line_search_hager_zhang_gamma: float = 0.66
288298
line_search_hager_zhang_rho: float = 5
299+
line_search_hager_zhang_eps_use_grad_norm: bool = True
300+
line_search_hager_zhang_eps_grad_norm_factor: float = 1e-2
289301

290302
# Basinhopping
291303
basinhopping_niter: int = 20

varipeps/optimization/line_search.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,12 @@ def line_search(
395395
hager_zhang_initial_found = _Hager_Zhang_Initial_State.NOT_FOUND
396396
hager_zhang_descent_grad = wolfe_descent_new_grad
397397
hager_zhang_state = _Hager_Zhang_State.NONE
398+
hager_zhang_eps = (
399+
jnp.linalg.norm(ravel_pytree(gradient)[0])
400+
* varipeps_config.line_search_hager_zhang_eps_grad_norm_factor
401+
if varipeps_config.line_search_hager_zhang_eps_use_grad_norm
402+
else varipeps_config.line_search_hager_zhang_eps
403+
)
398404

399405
new_value = current_value
400406

@@ -608,7 +614,7 @@ def line_search(
608614

609615
if descent_new_grad >= hz_wolfe_2_right:
610616
if hz_wolfe_1_left >= hz_wolfe_1_right and new_value <= (
611-
current_value + varipeps_config.line_search_hager_zhang_eps
617+
current_value + hager_zhang_eps
612618
):
613619
break
614620

@@ -617,7 +623,7 @@ def line_search(
617623
) * hager_zhang_descent_grad
618624

619625
if hz_approx_wolfe_left >= hager_zhang_descent_grad and new_value <= (
620-
current_value + varipeps_config.line_search_hager_zhang_eps
626+
current_value + hager_zhang_eps
621627
):
622628
break
623629

@@ -635,9 +641,7 @@ def line_search(
635641
hager_zhang_upper_bound_grad = new_gradient
636642
hager_zhang_upper_bound_des_grad = descent_new_grad
637643
hager_zhang_initial_found = _Hager_Zhang_Initial_State.FOUND
638-
elif new_value <= (
639-
current_value + varipeps_config.line_search_hager_zhang_eps
640-
):
644+
elif new_value <= (current_value + hager_zhang_eps):
641645
hager_zhang_lower_bound = alpha
642646
hager_zhang_lower_bound_value = new_value
643647
hager_zhang_lower_bound_grad = new_gradient
@@ -700,9 +704,7 @@ def line_search(
700704
hager_zhang_upper_bound_grad = new_gradient
701705
hager_zhang_upper_bound_des_grad = descent_new_grad
702706
hager_zhang_initial_found = _Hager_Zhang_Initial_State.FOUND
703-
elif descent_new_grad < 0 and new_value > (
704-
current_value + varipeps_config.line_search_hager_zhang_eps
705-
):
707+
elif descent_new_grad < 0 and new_value > (current_value + hager_zhang_eps):
706708
alpha = varipeps_config.line_search_hager_zhang_theta * alpha
707709
hager_zhang_initial_found = (
708710
_Hager_Zhang_Initial_State.SCALAR_LOWER_VALUE_GREATER
@@ -725,9 +727,7 @@ def line_search(
725727
count += 1
726728
continue
727729
else:
728-
if new_value <= (
729-
current_value + varipeps_config.line_search_hager_zhang_eps
730-
):
730+
if new_value <= (current_value + hager_zhang_eps):
731731
hager_zhang_lower_bound = alpha
732732
hager_zhang_lower_bound_value = new_value
733733
hager_zhang_lower_bound_grad = new_gradient
@@ -892,9 +892,7 @@ def line_search(
892892
hager_zhang_upper_bound_grad = new_gradient
893893
hager_zhang_upper_bound_des_grad = descent_new_grad
894894
hager_zhang_state = _Hager_Zhang_State.NONE
895-
elif new_value <= (
896-
current_value + varipeps_config.line_search_hager_zhang_eps
897-
):
895+
elif new_value <= (current_value + hager_zhang_eps):
898896
hager_zhang_lower_bound = alpha
899897
hager_zhang_lower_bound_value = new_value
900898
hager_zhang_lower_bound_grad = new_gradient
@@ -938,9 +936,7 @@ def line_search(
938936
hager_zhang_upper_bound_grad = new_gradient
939937
hager_zhang_upper_bound_des_grad = descent_new_grad
940938
hager_zhang_state = _Hager_Zhang_State.NONE
941-
elif new_value <= (
942-
current_value + varipeps_config.line_search_hager_zhang_eps
943-
):
939+
elif new_value <= (current_value + hager_zhang_eps):
944940
hager_zhang_lower_bound = alpha
945941
hager_zhang_lower_bound_value = new_value
946942
hager_zhang_lower_bound_grad = new_gradient

0 commit comments

Comments
 (0)