@@ -125,6 +125,18 @@ def program_transform(program):
125125 _create_op_desc_ ('tanh_p' , {'X' : [in_names [0 ]]},
126126 {'Y' : [out_names [0 ]]}, {}))
127127
128+ elif op_desc .type () == 'assign' :
129+ tmp_1 = name_gen .get_var (new_block , block .var (in_names [0 ]))
130+ to_insert .append (
131+ _create_op_desc_ ('fill_constant_p' , {}, {'Y' : [tmp_1 ]}, {
132+ 'shape' : block .var (in_names [0 ]).shape ,
133+ 'value' : 0.0
134+ }))
135+ to_insert .append (
136+ _create_op_desc_ ('add_p' , {'X' : [in_names [0 ]],
137+ 'Y' : [tmp_1 ]},
138+ {'Z' : [out_names [0 ]]}, {}))
139+
128140 elif op_desc .type () == 'reshape2' :
129141 to_insert .append (
130142 _create_op_desc_ ('reshape_p' , {'X' : [in_names [0 ]]}, {
@@ -138,14 +150,33 @@ def program_transform(program):
138150 }, {'axis' : op_desc .attr ('axis' )}))
139151
140152 elif op_desc .type () == 'slice' :
141- to_insert .append (
142- _create_op_desc_ ('slice_select_p' , {'X' : [in_names [0 ]]},
143- {'Y' : [out_names [0 ]]}, {
144- 'axis' : op_desc .attr ('axes' ),
145- 'starts' : op_desc .attr ('starts' ),
146- 'ends' : op_desc .attr ('ends' ),
147- 'strides' : op_desc .attr ('decrease_axis' )
148- }))
153+ if op_desc .attr ('decrease_axis' ) is None :
154+ to_insert .append (
155+ _create_op_desc_ ('slice_select_p' , {'X' : [in_names [0 ]]},
156+ {'Y' : [out_names [0 ]]}, {
157+ 'axis' : op_desc .attr ('axes' ),
158+ 'starts' : op_desc .attr ('starts' ),
159+ 'ends' : op_desc .attr ('ends' ),
160+ 'strides' : 1 ,
161+ }))
162+ else :
163+ tmp_shape = list (block .var (in_names [0 ]).shape )
164+ for axis in op_desc .attr ('decrease_axis' ):
165+ tmp_shape [axis ] = 1
166+ tmp_0 = name_gen .get_var (
167+ new_block , block .var (in_names [0 ]), shape = tuple (tmp_shape ))
168+ to_insert .append (
169+ _create_op_desc_ ('slice_select_p' , {'X' : [in_names [0 ]]},
170+ {'Y' : [tmp_0 ]}, {
171+ 'axis' : op_desc .attr ('axes' ),
172+ 'starts' : op_desc .attr ('starts' ),
173+ 'ends' : op_desc .attr ('ends' ),
174+ 'strides' : 1 ,
175+ }))
176+ to_insert .append (
177+ _create_op_desc_ ('reshape_p' , {'X' : [tmp_0 ]}, {
178+ 'Y' : [out_names [0 ]]
179+ }, {'shape' : block .var (out_names [0 ]).shape }))
149180
150181 elif op_desc .type () == 'slice_grad' :
151182 tmp_1 = name_gen .get_var (new_block , block .var (in_names [0 ]))
@@ -154,15 +185,36 @@ def program_transform(program):
154185 'shape' : block .var (in_names [0 ]).shape ,
155186 'value' : 0.0
156187 }))
157- to_insert .append (
158- _create_op_desc_ ('slice_assign_p' ,
159- {'X' : [tmp_1 ],
160- 'Y' : [in_names [1 ]]}, {'Z' : [out_names [0 ]]}, {
161- 'axis' : op_desc .attr ('axes' ),
162- 'starts' : op_desc .attr ('starts' ),
163- 'ends' : op_desc .attr ('ends' ),
164- 'strides' : op_desc .attr ('decrease_axis' )
165- }))
188+ if op_desc .attr ('decrease_axis' ) is None :
189+ to_insert .append (
190+ _create_op_desc_ ('slice_assign_p' , {
191+ 'X' : [tmp_1 ],
192+ 'Y' : [in_names [1 ]]
193+ }, {'Z' : [out_names [0 ]]}, {
194+ 'axis' : op_desc .attr ('axes' ),
195+ 'starts' : op_desc .attr ('starts' ),
196+ 'ends' : op_desc .attr ('ends' ),
197+ 'strides' : 1
198+ }))
199+ else :
200+ tmp_shape = list (block .var (in_names [1 ]).shape )
201+ for axis in op_desc .attr ('decrease_axis' ):
202+ assert axis == 1
203+ tmp_shape .append (1 )
204+ tmp_2 = name_gen .get_var (
205+ new_block , block .var (in_names [1 ]), shape = tuple (tmp_shape ))
206+ to_insert .append (
207+ _create_op_desc_ ('reshape_p' , {'X' : [in_names [1 ]]},
208+ {'Y' : [tmp_2 ]}, {'shape' : tmp_shape }))
209+ to_insert .append (
210+ _create_op_desc_ ('slice_assign_p' ,
211+ {'X' : [tmp_1 ],
212+ 'Y' : [tmp_2 ]}, {'Z' : [out_names [0 ]]}, {
213+ 'axis' : op_desc .attr ('axes' ),
214+ 'starts' : op_desc .attr ('starts' ),
215+ 'ends' : op_desc .attr ('ends' ),
216+ 'strides' : 1
217+ }))
166218
167219 elif op_desc .type () == 'concat_grad' :
168220 to_insert .append (
@@ -759,6 +811,7 @@ def program_transform(program):
759811 'add_p' , {'X' : [in_names [2 ]],
760812 'Y' : [tmp_1 ]}, {'Z' : [out_names [1 ]]}, {}))
761813 else :
814+ # print(op_desc.type())
762815 assert op_desc .type () in {'adam' , 'shape' , 'fill_constant' }
763816 to_insert .append (op_desc )
764817
0 commit comments