Skip to content

Commit

Permalink
Fixed parity experiment.
Browse files Browse the repository at this point in the history
  • Loading branch information
e2crawfo committed Oct 26, 2017
1 parent 0aded96 commit d8396f1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions dps/envs/grid_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,13 @@ def __init__(self, **kwargs):
else:
raise Exception("NotImplemented")

mnist_x, mnist_y, _ = load_emnist(
mnist_x, mnist_y, classmap = load_emnist(
cfg.data_dir, mnist_classes, balance=True,
downsample_factor=self.downsample_factor)
mnist_x = mnist_x.reshape(-1, self.image_width, self.image_width)
mnist_y = np.squeeze(mnist_y, 1)
inverted_classmap = {v: k for k, v in classmap.items()}
mnist_y = np.array([inverted_classmap[y] for y in mnist_y])

digit_reps = DataContainer(mnist_x, mnist_y)
blank_element = np.zeros((self.image_width, self.image_width))
Expand Down Expand Up @@ -308,10 +310,6 @@ def make_dataset(
padded_env[:env.shape[0], :env.shape[1]] = env
env = padded_env

if j % 10000 == 0:
print(image_to_string(env))
print("\n")

new_X.append(env)
y = func(ys)

Expand All @@ -323,6 +321,11 @@ def make_dataset(
_y[int(y)] = 1.0
y = _y

if j % 10000 == 0:
print(y)
print(image_to_string(env))
print("\n")

new_Y.append(y)

new_X = np.array(new_X).astype('f')
Expand Down

0 comments on commit d8396f1

Please sign in to comment.