Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
henry-yeh committed Jul 21, 2024
1 parent 08a6e67 commit 38ddcb9
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions eval_atsp/test_glop.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@

##### GLOP parameters #####
N_REVISER = 50 # We only test on Reviser-50; using more revisers requires code modifications
N_REVISIONS = 3
N_SAMPLES = 100 # for sampling decoding during revision
N_REVISIONS = 3 # number of revision iterations
N_SAMPLES = {
150: 2000,
250: 1000,
1000: 500
} # for sampling decoding during revision



Expand All @@ -69,7 +73,7 @@
'int_max': 1000*1000,
'scaler': 1000*1000
},
'pomo_size': N_SAMPLES
'pomo_size': 500,
}

model_params = {
Expand All @@ -85,7 +89,7 @@
'ms_layer1_init': (1/2)**(1/2),
'ms_layer2_init': (1/16)**(1/2),
'eval_type': 'softmax', # note here, can be greedy
'one_hot_seed_cnt': 20, # must be >= node_cnt
'one_hot_seed_cnt': N_REVISER, # must be >= node_cnt
}

tester_params = {
Expand Down Expand Up @@ -121,7 +125,7 @@
def revision(tour, inst, tester):
sub_tours = tour.reshape(-1, N_REVISER) # shape: (batch, revision_len)
sub_insts = [inst[sub_tour][:, sub_tour] for sub_tour in sub_tours]
original_scores = torch.stack([inst[sub_tour[:-1], torch.roll(sub_tour, shifts=-1)[:-1]].sum() for sub_tour in sub_tours]) # note that original_scores are positive values
original_scores = torch.tensor([cal_len_shpp(sub_tour, inst) for sub_tour in sub_tours]) # note that original_scores are positive values
# Scale the sub_insts to make the largest value 1
scale_coef = [sub_inst.max() for sub_inst in sub_insts]
sub_insts = torch.stack(sub_insts)
Expand All @@ -136,16 +140,15 @@ def revision(tour, inst, tester):
# TODO: unmcomment to validate the subtours
for i in range(len(sub_insts)):
validate_subtour(solutions[i], sub_insts[i], revised_scores[i])

# Gather the subtours according to the solutions
revised_tours = sub_tours.gather(1, solutions)

# Compare the original scores and the revised scores
improved_scores = original_scores - revised_scores
# subtours should be aranged in the same order as the original tours, if the improved_scores <= 0
kept_subtour_idx = improved_scores <= 0
revised_tours[kept_subtour_idx] = sub_tours[kept_subtour_idx]

solutions[improved_scores <= 0] = torch.arange(sub_tours.shape[1])
# Gather the subtours according to the solutions
revised_tours = sub_tours.gather(1, solutions)
# Flatten the revised_tours
revised_tours = revised_tours.reshape(-1) # shape: (batch * revision_len) i.e. (node_cnt,)
return revised_tours

def validate_subtour(subtour, dist, cost):
Expand All @@ -156,7 +159,11 @@ def validate_subtour(subtour, dist, cost):
for i in range(1, len(subtour) - 1):
assert i in subtour

def calc_len(tour, dist):
def validate_tour(tour):
for i in range(1, len(tour) - 1):
assert i in tour

def cal_len(tour, dist):
cost = dist[tour, torch.roll(tour, -1, -1)].sum()
return cost.item()

Expand Down Expand Up @@ -187,16 +194,16 @@ def main(n):
for inst in dataset:
tour, cost = random_insertion_non_euclidean(inst, order)
original_costs.append(cost)

tour = torch.tensor(tour.astype(np.int64))

for revision_iter in range(N_REVISIONS):
tour = revision(tour, inst, tester)
# Shift the tour to the right by N_SHIFTS
tour = torch.roll(tour, shifts=N_SHIFTS, dims=-1)
# cost = calc_len(tour, inst)
# print(f"cost after revision {revision_iter}: {cost}")
cost = calc_len(tour, inst)

# TODO: unmcomment to validate the solution
# validate_tour(tour)
cost = cal_len(tour, inst)
revised_costs.append(cost)

total_duration = time.time() - start
Expand All @@ -207,4 +214,8 @@ def main(n):


if __name__ == "__main__":
main(int(sys.argv[1]))
N = int(sys.argv[1])
env_params['pomo_size'] = N_SAMPLES.get(N, 500)

main(N)

0 comments on commit 38ddcb9

Please sign in to comment.