Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update README.md and train.py #252

Merged
merged 3 commits into from
Sep 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mt_with_external_memory/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ExternalMemory(object):
name,
mem_slot_size,
boot_layer,
initial_weight,
readonly=False,
enable_interpolation=True):
""" Initialization.
Expand All @@ -154,6 +155,8 @@ class ExternalMemory(object):
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param initial_weight: Initializer for addressing weights.
:type initial_weight: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
Expand Down Expand Up @@ -205,7 +208,7 @@ class ExternalMemory(object):

- `_content_addressing`: 通过基于内容的寻址,计算得到读写操作的寻址强度。
- `_interpolation`: 通过插值寻址(当前寻址强度和上一时间步寻址强度的线性加权),更新当前寻址强度。
- `_get_addressing_weight`: 调用上述两个寻址操作,获得对存储导员的读写操作的最终寻址强度
- `_get_addressing_weight`: 调用上述两个寻址操作,获得对存储单元的读写操作的最终寻址强度


对外接口包含:
Expand All @@ -214,6 +217,7 @@ class ExternalMemory(object):
- 输入参数 `name`: 外部记忆单元名,不同实例的相同命名将共享同一外部记忆单元。
- 输入参数 `mem_slot_size`: 单个记忆槽(向量)的维度。
- 输入参数 `boot_layer`: 用于内存槽初始化的层。需为序列类型,序列长度表明记忆槽的数量。
- 输入参数 `initial_weight`: 用于初始化寻址强度。
- 输入参数 `readonly`: 是否打开只读模式(例如打开只读模式,该实例可用于注意力机制)。打开只读模式,`write` 方法不可被调用。
- 输入参数 `enable_interpolation`: 是否允许插值寻址(例如当用于注意力机制时,需要关闭插值寻址)。
- `write`: 写操作。
Expand All @@ -230,7 +234,6 @@ class ExternalMemory(object):
self.external_memory = paddle.layer.memory(
name=self.name,
size=self.mem_slot_size,
is_seq=True,
boot_layer=boot_layer)
```
- `ExternalMemory`类的寻址逻辑通过 `_content_addressing` 和 `_interpolation` 两个私有方法实现。读和写操作通过 `read` 和 `write` 两个函数实现,包括上述的寻址操作。并且读和写的寻址独立进行,不同于 \[[2](#参考文献)\] 中的二者共享同一个寻址强度,目的是为了使得该类更通用。
Expand Down Expand Up @@ -349,6 +352,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="unbounded_memory",
mem_slot_size=size * 2,
boot_layer=unbounded_memory_init,
initial_weight=unbounded_memory_weight_init,
readonly=True,
enable_interpolation=False)
```
Expand All @@ -359,6 +363,7 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
name="bounded_memory",
mem_slot_size=size,
boot_layer=bounded_memory_init,
initial_weight=bounded_memory_weight_init,
readonly=False,
enable_interpolation=True)
```
Expand Down
2 changes: 1 addition & 1 deletion mt_with_external_memory/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def event_handler(event):
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_batch_reader, feeding=feeding)
print "Pass: %d, TestCost: %f, %s" % (event.pass_id, event.cost,
print "Pass: %d, TestCost: %f, %s" % (event.pass_id, result.cost,
result.metrics)
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f:
Expand Down