Skip to content

Commit

Permalink
add rtdp generic experiment; reveal bug in sm.SelfishMining.honest()
Browse files Browse the repository at this point in the history
  • Loading branch information
pkel committed May 8, 2024
1 parent b1b55ef commit 74b9f62
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 67 deletions.
148 changes: 81 additions & 67 deletions mdp/measure-rtdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import argparse
import joblib
import pandas
import random
import pickle
from time import time
import traceback
from tqdm import tqdm
Expand All @@ -31,88 +31,113 @@
]

rows = [
dict(row=1, protocol="bitcoin", model="fc16", truncated=True, algo="aft20", ref=1),
dict(row=2, protocol="bitcoin", model="aft20", truncated=True, algo="aft20", ref=1),
dict(row=3, protocol="bitcoin", model="fc16", truncated=True, algo="rtdp", ref=1),
dict(row=4, protocol="bitcoin", model="aft20", truncated=True, algo="rtdp", ref=1),
dict(row=5, protocol="bitcoin", model="fc16", truncated=False, algo="rtdp", ref=1),
dict(row=6, protocol="bitcoin", model="aft20", truncated=False, algo="rtdp", ref=5),
dict(row=1, protocol="bitcoin", model="fc16", trunc=40, algo="aft20", ref=1),
dict(row=2, protocol="bitcoin", model="aft20", trunc=40, algo="aft20", ref=1),
dict(row=3, protocol="bitcoin", model="fc16", trunc=40, algo="rtdp", ref=1),
dict(row=4, protocol="bitcoin", model="aft20", trunc=40, algo="rtdp", ref=1),
dict(row=5, protocol="bitcoin", model="fc16", trunc=0, algo="rtdp", ref=1),
dict(row=6, protocol="bitcoin", model="aft20", trunc=0, algo="rtdp", ref=1),
dict(row=7, protocol="bitcoin", model="generic", trunc=8, algo="aft20", ref=1),
dict(row=8, protocol="bitcoin", model="generic", trunc=8, algo="rtdp", ref=1),
dict(row=9, protocol="bitcoin", model="generic", trunc=0, algo="rtdp", ref=5),
]

horizon = 100

# Algorithms
# TODO, I think it's interesting to report the number of states explored / visited!
# TODO it might be instructive to track/report the size of the policy-induced markov chain
# TODO it might appropriate to derive steady states and report value/progress on that


def algo_aft20(model):
eps = 0.01 # termination epsilon

def algo_aft20(implicit_mdp, *args, horizon, vi_delta, **kwargs):
# Compile Full MDP
mdp = Compiler(model).mdp()
mdp = Compiler(implicit_mdp).mdp()

# Derive PTO MDP
mdp = aft20barzur.ptmdp(mdp, horizon=horizon)

# Solve PTO MDP
vi = mdp.value_iteration(stop_delta=eps, eps=None, discount=1)
vi = mdp.value_iteration(stop_delta=vi_delta, eps=None, discount=1)

value = 0.0
progress = 0.0
for state, prob in mdp.start.items():
value += vi["vi_value"][state] * prob
progress += vi["vi_progress"][state] * prob

return dict(value=value, progress=progress, rpp=value / progress)
return dict(
value=value,
progress=progress,
n_states=mdp.n_states,
)


def algo_rtdp(model):
steps = 1_000_000
eps = 0.3 # exploration epsilon
def algo_rtdp(implicit_mdp, *args, horizon, rtdp_steps, rtdp_eps, **kwargs):
agent = RTDP(implicit_mdp, eps=rtdp_eps, eps_honest=0, horizon=horizon)

agent = RTDP(model, eps=eps, eps_honest=0, horizon=horizon)
for i in range(steps):
for i in range(rtdp_steps):
agent.step()

value, progress = agent.start_value_and_progress()
return dict(value=value, progress=progress, rpp=value / progress)
return dict(
value=value,
progress=progress,
n_states=len(agent.states),
)


# How do we instantiate the models and run the algo?


def instanciate_model(*args, model, protocol, truncated, alpha, gamma, **kwargs):
def implicit_mdp(*args, model, protocol, trunc, alpha, gamma, **kwargs):
if model in ["fc16", "aft20"]:
assert protocol == "bitcoin", "fc16 and aft20 model are bitcoin-only"

common = dict(alpha=alpha, gamma=gamma)

if model == "fc16" and truncated:
return fc16sapirshtein.BitcoinSM(**common, maximum_fork_length=40)
if trunc <= 0:
trunc = 100_000
# TODO disable truncation completely

if model == "fc16":
return fc16sapirshtein.BitcoinSM(**common, maximum_fork_length=trunc)

if model == "fc16" and not truncated:
return fc16sapirshtein.BitcoinSM(
**common, maximum_fork_length=10000
) # TODO, disable truncation completely
if model == "aft20":
return aft20barzur.BitcoinSM(**common, maximum_fork_length=trunc)

if model == "aft20" and truncated:
return aft20barzur.BitcoinSM(**common, maximum_fork_length=40)
if model == "generic":
common["merge_isomorphic"] = False
common["maximum_size"] = trunc

if model == "aft20" and not truncated:
return aft20barzur.BitcoinSM(
**common, maximum_fork_length=10000
) # TODO, disable truncation completely
if protocol == "bitcoin":
return SelfishMining(Bitcoin(), **common)

raise ValueError(f"unknown protocol: {protocol}")

raise ValueError(f"unknown model: {model}")


def measure_unsafe(*args, algo, **kwargs):
model = instanciate_model(**kwargs)
# Command line arguments

argp = argparse.ArgumentParser()
argp.add_argument("-j", "--n_jobs", type=int, default=1, metavar="INT")
argp.add_argument("-H", "--horizon", type=int, default=30, metavar="INT")
argp.add_argument("--rtdp_eps", type=float, default=0.25, metavar="FLOAT")
argp.add_argument("--rtdp_steps", type=int, default=100_000, metavar="INT")
argp.add_argument("--vi_delta", type=float, default=0.01, metavar="FLOAT")
args = argp.parse_args()

# Single measurement


def measure_unsafe(*_args, algo, **kwargs):
mdp = implicit_mdp(**kwargs)
kwargs["horizon"] = args.horizon
if algo == "aft20":
return algo_aft20(model)
hp = dict(vi_delta=args.vi_delta)
return algo_aft20(mdp, **hp, **kwargs) | dict(hyperparams=hp)
if algo == "rtdp":
return algo_rtdp(model)
hp = dict(rtdp_eps=args.rtdp_eps, rtdp_steps=args.rtdp_steps)
return algo_rtdp(mdp, **hp, **kwargs) | dict(hyperparams=hp)

raise ValueError(f"unknown algo: {algo}")

Expand All @@ -124,19 +149,12 @@ def measure(*args, **kwargs):
return dict(error=str(e), traceback=traceback.format_exc())


# Command line arguments
argp = argparse.ArgumentParser()
argp.add_argument("-j", "--n_jobs", type=int, default=1, metavar="INT")
args = argp.parse_args()


# Multicore measurement loop


def job(*args, **kwargs):
start_time = time()
result = measure(**kwargs)
return kwargs | result | dict(time=time() - start_time)
return kwargs | measure(**kwargs) | dict(time=time() - start_time)


def job_gen():
Expand All @@ -146,7 +164,6 @@ def job_gen():


jobs = list(job_gen())
jobs = random.sample(jobs, len(jobs))

res_gen = joblib.Parallel(n_jobs=args.n_jobs, return_as="generator")(jobs)

Expand All @@ -162,6 +179,23 @@ def job_gen():

df = pandas.DataFrame(rows)

# Print

print(df)

# Save to disk

fname = "measure-rtdp.pkl"
print()
print(f"storing results in {fname}")

results = dict(
data=df,
)

with open(fname, "wb") as pkl:
pickle.dump(results, pkl)

# Error handling

if "error" in df.columns:
Expand All @@ -171,23 +205,3 @@ def job_gen():
print()

raise Exception("errors during measurements")


def tabulate(df, key):
return (
df.pivot(
columns=["attacker", "alpha", "gamma"],
index=["row", "protocol", "model", "truncated", "algo", "ref"],
values=key,
)
.reset_index()
.set_index(["row"])
)


print()
print("rpp")
print(tabulate(df, "rpp"))
print()
print("time")
print(tabulate(df, "time"))
3 changes: 3 additions & 0 deletions mdp/sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,9 @@ def honest(self, s: State) -> Action:
return Consider(0)
return Continue()

# TODO Continue might be unavailable, if Communicate is the only remaining
# action during truncation; get rid of communicate action

def apply(self, a: Action, s: State) -> list[Transition]:
if isinstance(a, Release):
return self.apply_release(a.i, s)
Expand Down
32 changes: 32 additions & 0 deletions mdp/tab-rtdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pickle

fname = "measure-rtdp.pkl"
print()
print(f"load results in {fname}")

with open(fname, "rb") as pkl:
results = pickle.load(pkl)

df = results["data"]

df = df.assign(rpp=lambda x: x.value / x.progress)


def tabulate(df, key):
return (
df.pivot(
columns=["attacker", "alpha", "gamma"],
index=["row", "protocol", "model", "truncated", "algo", "ref"],
values=key,
)
.reset_index()
.set_index(["row"])
)


print()
print("rpp")
print(tabulate(df, "rpp"))
print()
print("time")
print(tabulate(df, "time"))

0 comments on commit 74b9f62

Please sign in to comment.