Data must be in the `str` format as detailed in the example below:

```python
from pytree.data import prepare_input_from_constituency_tree

parse_tree_example = '(TOP (S (NP (_ I)) (VP (_ saw) (NP (_ Sarah)) (PP (_ with) (NP (_ a) (_ telescope)))) (_ .)))'
input_test, head_idx_test, head_idx_r_test, head_idx_l_test = prepare_input_from_constituency_tree(parse_tree_example)

print(input_test)
# ['[CLS]', 'I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.', '[S]', '[S]', '[VP]', '[VP]', '[PP]', '[NP]']

print(head_idx_test)
# [0, 9, 11, 11, 12, 13, 13, 8, 0, 8, 9, 10, 10, 12]

print(head_idx_r_test)
# [0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0]

print(head_idx_l_test)
# [0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1]
```
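
To make these outputs concrete: each entry of `head_idx_test` points to the position of that node's parent in `input_test`, while `head_idx_r_test` and `head_idx_l_test` appear to flag which child slot (right or left) each node occupies in the binarized tree. The sketch below recovers the parent of every node under that reading; the `edges` helper is ours, not part of pytree.

```python
# Hypothetical helper (not part of pytree): pair each node with its parent,
# assuming head_idx_test[i] holds the position of node i's parent in
# input_test, with 0 marking the root (or the [CLS] placeholder).
def edges(tokens, head_idx):
    return [(tokens[i], tokens[h]) for i, h in enumerate(head_idx) if h != 0]

for child, parent in edges(input_test, head_idx_test):
    print(child, '->', parent)
# e.g. I -> [S], saw -> [VP], ... (the bracketed constituents act as parents)
```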

### Prepare dependency tree data

The prepared tokens can then be converted to ids with the GloVe vocabulary:

```python
from pytree.data.glove_tokenizer import GloveTokenizer

glove_tokenizer = GloveTokenizer(glove_file_path='./glove.6B.300d.txt', vocab_size=10000)
input_test = glove_tokenizer.convert_tokens_to_ids(input_test)
print(input_test)
# [1, 1, 824, 1, 19, 9, 1, 4]
```
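
The `glove.6B.300d.txt` file is the 300-dimensional GloVe 6B embedding file distributed by the Stanford NLP group (https://nlp.stanford.edu/projects/glove/). Judging by the output above, tokens outside the 10,000-word vocabulary ('[CLS]', 'I', 'Sarah', and 'telescope' here, since the 6B vocabulary is lowercased) share the unknown id 1; that is an observation from this example rather than a documented contract. A quick check, reusing the tokenizer above:

```python
# Out-of-vocabulary words collapse to the shared unknown id
# (1 in this example, assuming the behaviour observed above):
print(glove_tokenizer.convert_tokens_to_ids(['saw', 'Sarah', 'telescope']))
# e.g. [824, 1, 1]
```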

Then prepare the data:

```python
import torch
from pytree.data.utils import build_tree_ids_n_ary

tree_ids_test, tree_ids_test_r, tree_ids_test_l = build_tree_ids_n_ary(head_idx_test, head_idx_r_test, head_idx_l_test)
inputs = {'input_ids': torch.tensor(input_test).unsqueeze(0),
          'tree_ids': torch.tensor(tree_ids_test).unsqueeze(0),
          'tree_ids_r': torch.tensor(tree_ids_test_r).unsqueeze(0),
          'tree_ids_l': torch.tensor(tree_ids_test_l).unsqueeze(0)}
```
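
`unsqueeze(0)` gives every tensor a leading batch dimension, so the single sentence is processed as a batch of size one. A quick shape check (eight ids in the running example):

```python
print(inputs['input_ids'].shape)
# torch.Size([1, 8])
```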

And apply the model:

```python
from pytree.models import NaryConfig, NaryTree

config = NaryConfig()
tree_encoder = NaryTree(config)

(h, c), h_root = tree_encoder(inputs)
print(h)
# tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#          [ 0.0113, -0.0066,  0.0089,  ...,  0.0064,  0.0076, -0.0048],
#          [ 0.0110, -0.0073,  0.0110,  ...,  0.0070,  0.0046, -0.0049],
#          ...,
#          [ 0.0254, -0.0138,  0.0224,  ...,  0.0131,  0.0148, -0.0143],
#          [ 0.0346, -0.0172,  0.0281,  ...,  0.0140,  0.0198, -0.0267],
#          [ 0.0247, -0.0126,  0.0201,  ...,  0.0116,  0.0162, -0.0184]]],
#        grad_fn=<SWhereBackward>)

print(h_root.shape)
# torch.Size([150])
```
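
The encoder returns per-node hidden and cell states `(h, c)` together with `h_root`, the root node's hidden state (150-dimensional with the default config). `h_root` can serve as a fixed-size sentence embedding; below is a sketch of one common use, comparing two sentences by cosine similarity. The second root state is a placeholder standing in for another encoded sentence.

```python
import torch
import torch.nn.functional as F

h_root_a = h_root            # root state from the example above
h_root_b = torch.randn(150)  # placeholder for a second sentence's root state
print(F.cosine_similarity(h_root_a.unsqueeze(0), h_root_b.unsqueeze(0)))
```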

We also provide a full demonstration with the SICK dataset and batched processing in the [examples folder](https://github.com/AntoineSimoulin/pytree/tree/main/examples).