@@ -395,6 +395,12 @@ def line_search(
395
395
hager_zhang_initial_found = _Hager_Zhang_Initial_State .NOT_FOUND
396
396
hager_zhang_descent_grad = wolfe_descent_new_grad
397
397
hager_zhang_state = _Hager_Zhang_State .NONE
398
+ hager_zhang_eps = (
399
+ jnp .linalg .norm (ravel_pytree (gradient )[0 ])
400
+ * varipeps_config .line_search_hager_zhang_eps_grad_norm_factor
401
+ if varipeps_config .line_search_hager_zhang_eps_use_grad_norm
402
+ else varipeps_config .line_search_hager_zhang_eps
403
+ )
398
404
399
405
new_value = current_value
400
406
@@ -608,7 +614,7 @@ def line_search(
608
614
609
615
if descent_new_grad >= hz_wolfe_2_right :
610
616
if hz_wolfe_1_left >= hz_wolfe_1_right and new_value <= (
611
- current_value + varipeps_config . line_search_hager_zhang_eps
617
+ current_value + hager_zhang_eps
612
618
):
613
619
break
614
620
@@ -617,7 +623,7 @@ def line_search(
617
623
) * hager_zhang_descent_grad
618
624
619
625
if hz_approx_wolfe_left >= hager_zhang_descent_grad and new_value <= (
620
- current_value + varipeps_config . line_search_hager_zhang_eps
626
+ current_value + hager_zhang_eps
621
627
):
622
628
break
623
629
@@ -635,9 +641,7 @@ def line_search(
635
641
hager_zhang_upper_bound_grad = new_gradient
636
642
hager_zhang_upper_bound_des_grad = descent_new_grad
637
643
hager_zhang_initial_found = _Hager_Zhang_Initial_State .FOUND
638
- elif new_value <= (
639
- current_value + varipeps_config .line_search_hager_zhang_eps
640
- ):
644
+ elif new_value <= (current_value + hager_zhang_eps ):
641
645
hager_zhang_lower_bound = alpha
642
646
hager_zhang_lower_bound_value = new_value
643
647
hager_zhang_lower_bound_grad = new_gradient
@@ -700,9 +704,7 @@ def line_search(
700
704
hager_zhang_upper_bound_grad = new_gradient
701
705
hager_zhang_upper_bound_des_grad = descent_new_grad
702
706
hager_zhang_initial_found = _Hager_Zhang_Initial_State .FOUND
703
- elif descent_new_grad < 0 and new_value > (
704
- current_value + varipeps_config .line_search_hager_zhang_eps
705
- ):
707
+ elif descent_new_grad < 0 and new_value > (current_value + hager_zhang_eps ):
706
708
alpha = varipeps_config .line_search_hager_zhang_theta * alpha
707
709
hager_zhang_initial_found = (
708
710
_Hager_Zhang_Initial_State .SCALAR_LOWER_VALUE_GREATER
@@ -725,9 +727,7 @@ def line_search(
725
727
count += 1
726
728
continue
727
729
else :
728
- if new_value <= (
729
- current_value + varipeps_config .line_search_hager_zhang_eps
730
- ):
730
+ if new_value <= (current_value + hager_zhang_eps ):
731
731
hager_zhang_lower_bound = alpha
732
732
hager_zhang_lower_bound_value = new_value
733
733
hager_zhang_lower_bound_grad = new_gradient
@@ -892,9 +892,7 @@ def line_search(
892
892
hager_zhang_upper_bound_grad = new_gradient
893
893
hager_zhang_upper_bound_des_grad = descent_new_grad
894
894
hager_zhang_state = _Hager_Zhang_State .NONE
895
- elif new_value <= (
896
- current_value + varipeps_config .line_search_hager_zhang_eps
897
- ):
895
+ elif new_value <= (current_value + hager_zhang_eps ):
898
896
hager_zhang_lower_bound = alpha
899
897
hager_zhang_lower_bound_value = new_value
900
898
hager_zhang_lower_bound_grad = new_gradient
@@ -938,9 +936,7 @@ def line_search(
938
936
hager_zhang_upper_bound_grad = new_gradient
939
937
hager_zhang_upper_bound_des_grad = descent_new_grad
940
938
hager_zhang_state = _Hager_Zhang_State .NONE
941
- elif new_value <= (
942
- current_value + varipeps_config .line_search_hager_zhang_eps
943
- ):
939
+ elif new_value <= (current_value + hager_zhang_eps ):
944
940
hager_zhang_lower_bound = alpha
945
941
hager_zhang_lower_bound_value = new_value
946
942
hager_zhang_lower_bound_grad = new_gradient
0 commit comments