diff --git a/tests/python/contrib/test_clml/menangerie.py b/tests/python/contrib/test_clml/menangerie.py index d146b84c46b62..fd1b30bf51aab 100644 --- a/tests/python/contrib/test_clml/menangerie.py +++ b/tests/python/contrib/test_clml/menangerie.py @@ -30,2289 +30,6 @@ def make_consts(dtype, shapes): return [make_const(dtype, shape) for shape in shapes] -def mnist_consts(dtype): - return make_consts( - dtype, - [ - (8, 1, 5, 5), # 0 - (8, 1, 1), # 1 - (16, 8, 5, 5), # 2 - (16, 1, 1), # 3 - (10, 256), # 4 - (1, 10), # 5 - ], - ) - - -def mnist(): - metatable = {"relay.Constant": mnist_consts("float32")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1, 1, 28, 28), float32]) -> Tensor[(1, 10), float32] { - %0 = nn.pad(%x, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); - %1 = nn.conv2d(%0, meta[relay.Constant][0], padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]); - %2 = add(%1, meta[relay.Constant][1]); - %3 = nn.relu(%2); - %4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]); - %5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]); - %6 = nn.conv2d(%5, meta[relay.Constant][2], padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]); - %7 = add(%6, meta[relay.Constant][3]); - %8 = nn.relu(%7); - %9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]); - %10 = reshape(%9, newshape=[1, 256]); - %11 = nn.dense(%10, meta[relay.Constant][4], units=None, out_dtype="float32"); - add(%11, meta[relay.Constant][5]) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "mnist", - "input_shapes": {"x": [1, 1, 28, 28]}, - "input_dtypes": {"x": "float32"}, - "mod": mod, - "params": None, - "main_dtype": "float32", - } - - -def gpt2_consts(dtype): - return make_consts( - dtype, - [ - (50257, 768), # 0 - (1, 32, 768), # 1 - (768,), # 2 - (768,), # 3 - (2304, 768), # 4 - (2304,), # 5 - (1, 1, 32, 32), # 6 - (1, 1, 32, 32), # 7 - (768, 768), # 8 - (768,), # 9 - (768,), # 10 - (768,), # 11 - (3072, 768), # 12 - (3072,), # 13 - (768, 3072), # 14 - (768,), # 15 - (768,), # 16 - (768,), # 17 - (2304, 768), # 18 - (2304,), # 19 - (1, 1, 32, 32), # 20 - (1, 1, 32, 32), # 21 - (768, 768), # 22 - (768,), # 23 - (768,), # 24 - (768,), # 25 - (3072, 768), # 26 - (3072,), # 27 - (768, 3072), # 28 - (768,), # 29 - (768,), # 30 - (768,), # 31 - (2304, 768), # 32 - (2304,), # 33 - (1, 1, 32, 32), # 34 - (1, 1, 32, 32), # 35 - (768, 768), # 36 - (768,), # 37 - (768,), # 38 - (768,), # 39 - (3072, 768), # 40 - (3072,), # 41 - (768, 3072), # 42 - (768,), # 43 - (768,), # 44 - (768,), # 45 - (2304, 768), # 46 - (2304,), # 47 - (1, 1, 32, 32), # 48 - (1, 1, 32, 32), # 49 - (768, 768), # 50 - (768,), # 51 - (768,), # 52 - (768,), # 53 - (3072, 768), # 54 - (3072,), # 55 - (768, 3072), # 56 - (768,), # 57 - (768,), # 58 - (768,), # 59 - (2304, 768), # 60 - (2304,), # 61 - (1, 1, 32, 32), # 62 - (1, 1, 32, 32), # 63 - (768, 768), # 64 - (768,), # 65 - (768,), # 66 - (768,), # 67 - (3072, 768), # 68 - (3072,), # 69 - (768, 3072), # 70 - (768,), # 71 - (768,), # 72 - (768,), # 73 - (2304, 768), # 74 - (2304,), # 75 - (1, 1, 32, 32), # 76 - (1, 1, 32, 32), # 77 - (768, 768), # 78 - (768,), # 79 - (768,), # 80 - (768,), # 81 - (3072, 768), # 82 - (3072,), # 83 - (768, 3072), # 84 - (768,), # 85 - (768,), # 86 - (768,), # 87 - (2304, 768), # 88 - (2304,), # 89 - (1, 1, 32, 32), # 90 - (1, 1, 32, 32), # 91 - (768, 768), # 92 - (768,), # 93 - (768,), # 94 - (768,), # 95 - (3072, 768), # 96 - (3072,), # 97 - (768, 3072), # 98 - (768,), # 99 - (768,), # 100 - (768,), # 101 - (2304, 768), # 102 - (2304,), # 103 - (1, 1, 32, 32), # 104 - (1, 1, 32, 32), # 105 - (768, 768), # 106 - (768,), # 107 - (768,), # 108 - (768,), # 109 - (3072, 768), # 110 - (3072,), # 111 - (768, 3072), # 112 - (768,), # 113 - (768,), # 114 - (768,), # 115 - (2304, 768), # 116 - (2304,), # 117 - (1, 1, 32, 32), # 118 - (1, 1, 32, 32), # 119 - (768, 768), # 120 - (768,), # 121 - (768,), # 122 - (768,), # 123 - (3072, 768), # 124 - (3072,), # 125 - (768, 3072), # 126 - (768,), # 127 - (768,), # 128 - (768,), # 129 - (2304, 768), # 130 - (2304,), # 131 - (1, 1, 32, 32), # 132 - (1, 1, 32, 32), # 133 - (768, 768), # 134 - (768,), # 135 - (768,), # 136 - (768,), # 137 - (3072, 768), # 138 - (3072,), # 139 - (768, 3072), # 140 - (768,), # 141 - (768,), # 142 - (768,), # 143 - (2304, 768), # 144 - (2304,), # 145 - (1, 1, 32, 32), # 146 - (1, 1, 32, 32), # 147 - (768, 768), # 148 - (768,), # 149 - (768,), # 150 - (768,), # 151 - (3072, 768), # 152 - (3072,), # 153 - (768, 3072), # 154 - (768,), # 155 - (768,), # 156 - (768,), # 157 - (2304, 768), # 158 - (2304,), # 159 - (1, 1, 32, 32), # 160 - (1, 1, 32, 32), # 161 - (768, 768), # 162 - (768,), # 163 - (768,), # 164 - (768,), # 165 - (3072, 768), # 166 - (3072,), # 167 - (768, 3072), # 168 - (768,), # 169 - (768,), # 170 - (768,), # 171 - ], - ) - - -def gpt2(): - metatable = {"relay.Constant": gpt2_consts("float32")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1, 50, 32), int64]) -> (Tensor[(1, 50, 32, 768), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32], - Tensor[(2, 50, 12, 32, 64), float32]) { - %0 = reshape(%x, newshape=[-1, 32]); - %1 = less(%0, 0i64); - %2 = add(%0, 50257i64); - %3 = where(%1, %2, %0); - %4 = take(meta[relay.Constant][0], %3, axis=0); - %5 = add(%4, meta[relay.Constant][1]); - %6 = mean(%5, axis=[-1], keepdims=True); - %7 = subtract(%5, %6); - %8 = power(%7, 2f); - %9 = mean(%8, axis=[-1], keepdims=True); - %10 = add(%9, 1e-05f); - %11 = sqrt(%10); - %12 = divide(%7, %11); - %13 = multiply(%12, meta[relay.Constant][2]); - %14 = add(%13, meta[relay.Constant][3]); - %15 = reshape(%14, newshape=[-1, 768]); - %16 = nn.dense(%15, meta[relay.Constant][4], units=2304); - %17 = add(%16, meta[relay.Constant][5]); - %18 = reshape(%17, newshape=[50, 32, 2304]); - %19 = split(%18, indices_or_sections=[768, 1536], axis=2); - %20 = %19.0; - %21 = reshape(%20, newshape=[50, 32, 12, 64]); - %22 = transpose(%21, axes=[0, 2, 1, 3]); - %23 = %19.1; - %24 = reshape(%23, newshape=[50, 32, 12, 64]); - %25 = transpose(%24, axes=[0, 2, 3, 1]); - %26 = reshape(%25, newshape=[-1, 64, 32]); - %27 = reshape(%22, newshape=[-1, 32, 64]); - %28 = transpose(%26, axes=[0, 2, 1]); - %29 = nn.batch_matmul(%27, %28, out_dtype="float32", transpose_b=True); - %30 = reshape(%29, newshape=[50, 12, 32, 32]); - %31 = divide(%30, 8f); - %32 = multiply(%31, meta[relay.Constant][6]); - %33 = subtract(%32, meta[relay.Constant][7]); - %34 = nn.softmax(%33, axis=3); - %35 = %19.2; - %36 = reshape(%35, newshape=[50, 32, 12, 64]); - %37 = transpose(%36, axes=[0, 2, 1, 3]); - %38 = reshape(%37, newshape=[-1, 32, 64]); - %39 = reshape(%34, newshape=[-1, 32, 32]); - %40 = transpose(%38, axes=[0, 2, 1]); - %41 = nn.batch_matmul(%39, %40, out_dtype="float32", transpose_b=True); - %42 = reshape(%41, newshape=[50, 12, 32, 64]); - %43 = transpose(%42, axes=[0, 2, 1, 3]); - %44 = reshape(%43, newshape=[50, 32, 768]); - %45 = reshape(%44, newshape=[-1, 768]); - %46 = nn.dense(%45, meta[relay.Constant][8], units=768); - %47 = add(%46, meta[relay.Constant][9]); - %48 = reshape(%47, newshape=[50, 32, 768]); - %49 = add(%5, %48); - %50 = mean(%49, axis=[-1], keepdims=True); - %51 = subtract(%49, %50); - %52 = power(%51, 2f); - %53 = mean(%52, axis=[-1], keepdims=True); - %54 = add(%53, 1e-05f); - %55 = sqrt(%54); - %56 = divide(%51, %55); - %57 = multiply(%56, meta[relay.Constant][10]); - %58 = add(%57, meta[relay.Constant][11]); - %59 = reshape(%58, newshape=[-1, 768]); - %60 = nn.dense(%59, meta[relay.Constant][12], units=3072); - %61 = add(%60, meta[relay.Constant][13]); - %62 = reshape(%61, newshape=[50, 32, 3072]); - %63 = power(%62, 3f); - %64 = multiply(%63, 0.044715f); - %65 = add(%62, %64); - %66 = multiply(%65, 0.797885f); - %67 = tanh(%66); - %68 = multiply(%62, 0.5f); - %69 = add(%67, 1f); - %70 = multiply(%68, %69); - %71 = reshape(%70, newshape=[-1, 3072]); - %72 = nn.dense(%71, meta[relay.Constant][14], units=768); - %73 = add(%72, meta[relay.Constant][15]); - %74 = reshape(%73, newshape=[50, 32, 768]); - %75 = add(%49, %74); - %76 = mean(%75, axis=[-1], keepdims=True); - %77 = subtract(%75, %76); - %78 = power(%77, 2f); - %79 = mean(%78, axis=[-1], keepdims=True); - %80 = add(%79, 1e-05f); - %81 = sqrt(%80); - %82 = divide(%77, %81); - %83 = multiply(%82, meta[relay.Constant][16]); - %84 = add(%83, meta[relay.Constant][17]); - %85 = reshape(%84, newshape=[-1, 768]); - %86 = nn.dense(%85, meta[relay.Constant][18], units=2304); - %87 = add(%86, meta[relay.Constant][19]); - %88 = reshape(%87, newshape=[50, 32, 2304]); - %89 = split(%88, indices_or_sections=[768, 1536], axis=2); - %90 = %89.0; - %91 = reshape(%90, newshape=[50, 32, 12, 64]); - %92 = transpose(%91, axes=[0, 2, 1, 3]); - %93 = %89.1; - %94 = reshape(%93, newshape=[50, 32, 12, 64]); - %95 = transpose(%94, axes=[0, 2, 3, 1]); - %96 = reshape(%95, newshape=[-1, 64, 32]); - %97 = reshape(%92, newshape=[-1, 32, 64]); - %98 = transpose(%96, axes=[0, 2, 1]); - %99 = nn.batch_matmul(%97, %98, out_dtype="float32", transpose_b=True); - %100 = reshape(%99, newshape=[50, 12, 32, 32]); - %101 = divide(%100, 8f); - %102 = multiply(%101, meta[relay.Constant][20]); - %103 = subtract(%102, meta[relay.Constant][21]); - %104 = nn.softmax(%103, axis=3); - %105 = %89.2; - %106 = reshape(%105, newshape=[50, 32, 12, 64]); - %107 = transpose(%106, axes=[0, 2, 1, 3]); - %108 = reshape(%107, newshape=[-1, 32, 64]); - %109 = reshape(%104, newshape=[-1, 32, 32]); - %110 = transpose(%108, axes=[0, 2, 1]); - %111 = nn.batch_matmul(%109, %110, out_dtype="float32", transpose_b=True); - %112 = reshape(%111, newshape=[50, 12, 32, 64]); - %113 = transpose(%112, axes=[0, 2, 1, 3]); - %114 = reshape(%113, newshape=[50, 32, 768]); - %115 = reshape(%114, newshape=[-1, 768]); - %116 = nn.dense(%115, meta[relay.Constant][22], units=768); - %117 = add(%116, meta[relay.Constant][23]); - %118 = reshape(%117, newshape=[50, 32, 768]); - %119 = add(%75, %118); - %120 = mean(%119, axis=[-1], keepdims=True); - %121 = subtract(%119, %120); - %122 = power(%121, 2f); - %123 = mean(%122, axis=[-1], keepdims=True); - %124 = add(%123, 1e-05f); - %125 = sqrt(%124); - %126 = divide(%121, %125); - %127 = multiply(%126, meta[relay.Constant][24]); - %128 = add(%127, meta[relay.Constant][25]); - %129 = reshape(%128, newshape=[-1, 768]); - %130 = nn.dense(%129, meta[relay.Constant][26], units=3072); - %131 = add(%130, meta[relay.Constant][27]); - %132 = reshape(%131, newshape=[50, 32, 3072]); - %133 = power(%132, 3f); - %134 = multiply(%133, 0.044715f); - %135 = add(%132, %134); - %136 = multiply(%135, 0.797885f); - %137 = tanh(%136); - %138 = multiply(%132, 0.5f); - %139 = add(%137, 1f); - %140 = multiply(%138, %139); - %141 = reshape(%140, newshape=[-1, 3072]); - %142 = nn.dense(%141, meta[relay.Constant][28], units=768); - %143 = add(%142, meta[relay.Constant][29]); - %144 = reshape(%143, newshape=[50, 32, 768]); - %145 = add(%119, %144); - %146 = mean(%145, axis=[-1], keepdims=True); - %147 = subtract(%145, %146); - %148 = power(%147, 2f); - %149 = mean(%148, axis=[-1], keepdims=True); - %150 = add(%149, 1e-05f); - %151 = sqrt(%150); - %152 = divide(%147, %151); - %153 = multiply(%152, meta[relay.Constant][30]); - %154 = add(%153, meta[relay.Constant][31]); - %155 = reshape(%154, newshape=[-1, 768]); - %156 = nn.dense(%155, meta[relay.Constant][32], units=2304); - %157 = add(%156, meta[relay.Constant][33]); - %158 = reshape(%157, newshape=[50, 32, 2304]); - %159 = split(%158, indices_or_sections=[768, 1536], axis=2); - %160 = %159.0; - %161 = reshape(%160, newshape=[50, 32, 12, 64]); - %162 = transpose(%161, axes=[0, 2, 1, 3]); - %163 = %159.1; - %164 = reshape(%163, newshape=[50, 32, 12, 64]); - %165 = transpose(%164, axes=[0, 2, 3, 1]); - %166 = reshape(%165, newshape=[-1, 64, 32]); - %167 = reshape(%162, newshape=[-1, 32, 64]); - %168 = transpose(%166, axes=[0, 2, 1]); - %169 = nn.batch_matmul(%167, %168, out_dtype="float32", transpose_b=True); - %170 = reshape(%169, newshape=[50, 12, 32, 32]); - %171 = divide(%170, 8f); - %172 = multiply(%171, meta[relay.Constant][34]); - %173 = subtract(%172, meta[relay.Constant][35]); - %174 = nn.softmax(%173, axis=3); - %175 = %159.2; - %176 = reshape(%175, newshape=[50, 32, 12, 64]); - %177 = transpose(%176, axes=[0, 2, 1, 3]); - %178 = reshape(%177, newshape=[-1, 32, 64]); - %179 = reshape(%174, newshape=[-1, 32, 32]); - %180 = transpose(%178, axes=[0, 2, 1]); - %181 = nn.batch_matmul(%179, %180, out_dtype="float32", transpose_b=True); - %182 = reshape(%181, newshape=[50, 12, 32, 64]); - %183 = transpose(%182, axes=[0, 2, 1, 3]); - %184 = reshape(%183, newshape=[50, 32, 768]); - %185 = reshape(%184, newshape=[-1, 768]); - %186 = nn.dense(%185, meta[relay.Constant][36], units=768); - %187 = add(%186, meta[relay.Constant][37]); - %188 = reshape(%187, newshape=[50, 32, 768]); - %189 = add(%145, %188); - %190 = mean(%189, axis=[-1], keepdims=True); - %191 = subtract(%189, %190); - %192 = power(%191, 2f); - %193 = mean(%192, axis=[-1], keepdims=True); - %194 = add(%193, 1e-05f); - %195 = sqrt(%194); - %196 = divide(%191, %195); - %197 = multiply(%196, meta[relay.Constant][38]); - %198 = add(%197, meta[relay.Constant][39]); - %199 = reshape(%198, newshape=[-1, 768]); - %200 = nn.dense(%199, meta[relay.Constant][40], units=3072); - %201 = add(%200, meta[relay.Constant][41]); - %202 = reshape(%201, newshape=[50, 32, 3072]); - %203 = power(%202, 3f); - %204 = multiply(%203, 0.044715f); - %205 = add(%202, %204); - %206 = multiply(%205, 0.797885f); - %207 = tanh(%206); - %208 = multiply(%202, 0.5f); - %209 = add(%207, 1f); - %210 = multiply(%208, %209); - %211 = reshape(%210, newshape=[-1, 3072]); - %212 = nn.dense(%211, meta[relay.Constant][42], units=768); - %213 = add(%212, meta[relay.Constant][43]); - %214 = reshape(%213, newshape=[50, 32, 768]); - %215 = add(%189, %214); - %216 = mean(%215, axis=[-1], keepdims=True); - %217 = subtract(%215, %216); - %218 = power(%217, 2f); - %219 = mean(%218, axis=[-1], keepdims=True); - %220 = add(%219, 1e-05f); - %221 = sqrt(%220); - %222 = divide(%217, %221); - %223 = multiply(%222, meta[relay.Constant][44]); - %224 = add(%223, meta[relay.Constant][45]); - %225 = reshape(%224, newshape=[-1, 768]); - %226 = nn.dense(%225, meta[relay.Constant][46], units=2304); - %227 = add(%226, meta[relay.Constant][47]); - %228 = reshape(%227, newshape=[50, 32, 2304]); - %229 = split(%228, indices_or_sections=[768, 1536], axis=2); - %230 = %229.0; - %231 = reshape(%230, newshape=[50, 32, 12, 64]); - %232 = transpose(%231, axes=[0, 2, 1, 3]); - %233 = %229.1; - %234 = reshape(%233, newshape=[50, 32, 12, 64]); - %235 = transpose(%234, axes=[0, 2, 3, 1]); - %236 = reshape(%235, newshape=[-1, 64, 32]); - %237 = reshape(%232, newshape=[-1, 32, 64]); - %238 = transpose(%236, axes=[0, 2, 1]); - %239 = nn.batch_matmul(%237, %238, out_dtype="float32", transpose_b=True); - %240 = reshape(%239, newshape=[50, 12, 32, 32]); - %241 = divide(%240, 8f); - %242 = multiply(%241, meta[relay.Constant][48]); - %243 = subtract(%242, meta[relay.Constant][49]); - %244 = nn.softmax(%243, axis=3); - %245 = %229.2; - %246 = reshape(%245, newshape=[50, 32, 12, 64]); - %247 = transpose(%246, axes=[0, 2, 1, 3]); - %248 = reshape(%247, newshape=[-1, 32, 64]); - %249 = reshape(%244, newshape=[-1, 32, 32]); - %250 = transpose(%248, axes=[0, 2, 1]); - %251 = nn.batch_matmul(%249, %250, out_dtype="float32", transpose_b=True); - %252 = reshape(%251, newshape=[50, 12, 32, 64]); - %253 = transpose(%252, axes=[0, 2, 1, 3]); - %254 = reshape(%253, newshape=[50, 32, 768]); - %255 = reshape(%254, newshape=[-1, 768]); - %256 = nn.dense(%255, meta[relay.Constant][50], units=768); - %257 = add(%256, meta[relay.Constant][51]); - %258 = reshape(%257, newshape=[50, 32, 768]); - %259 = add(%215, %258); - %260 = mean(%259, axis=[-1], keepdims=True); - %261 = subtract(%259, %260); - %262 = power(%261, 2f); - %263 = mean(%262, axis=[-1], keepdims=True); - %264 = add(%263, 1e-05f); - %265 = sqrt(%264); - %266 = divide(%261, %265); - %267 = multiply(%266, meta[relay.Constant][52]); - %268 = add(%267, meta[relay.Constant][53]); - %269 = reshape(%268, newshape=[-1, 768]); - %270 = nn.dense(%269, meta[relay.Constant][54], units=3072); - %271 = add(%270, meta[relay.Constant][55]); - %272 = reshape(%271, newshape=[50, 32, 3072]); - %273 = power(%272, 3f); - %274 = multiply(%273, 0.044715f); - %275 = add(%272, %274); - %276 = multiply(%275, 0.797885f); - %277 = tanh(%276); - %278 = multiply(%272, 0.5f); - %279 = add(%277, 1f); - %280 = multiply(%278, %279); - %281 = reshape(%280, newshape=[-1, 3072]); - %282 = nn.dense(%281, meta[relay.Constant][56], units=768); - %283 = add(%282, meta[relay.Constant][57]); - %284 = reshape(%283, newshape=[50, 32, 768]); - %285 = add(%259, %284); - %286 = mean(%285, axis=[-1], keepdims=True); - %287 = subtract(%285, %286); - %288 = power(%287, 2f); - %289 = mean(%288, axis=[-1], keepdims=True); - %290 = add(%289, 1e-05f); - %291 = sqrt(%290); - %292 = divide(%287, %291); - %293 = multiply(%292, meta[relay.Constant][58]); - %294 = add(%293, meta[relay.Constant][59]); - %295 = reshape(%294, newshape=[-1, 768]); - %296 = nn.dense(%295, meta[relay.Constant][60], units=2304); - %297 = add(%296, meta[relay.Constant][61]); - %298 = reshape(%297, newshape=[50, 32, 2304]); - %299 = split(%298, indices_or_sections=[768, 1536], axis=2); - %300 = %299.0; - %301 = reshape(%300, newshape=[50, 32, 12, 64]); - %302 = transpose(%301, axes=[0, 2, 1, 3]); - %303 = %299.1; - %304 = reshape(%303, newshape=[50, 32, 12, 64]); - %305 = transpose(%304, axes=[0, 2, 3, 1]); - %306 = reshape(%305, newshape=[-1, 64, 32]); - %307 = reshape(%302, newshape=[-1, 32, 64]); - %308 = transpose(%306, axes=[0, 2, 1]); - %309 = nn.batch_matmul(%307, %308, out_dtype="float32", transpose_b=True); - %310 = reshape(%309, newshape=[50, 12, 32, 32]); - %311 = divide(%310, 8f); - %312 = multiply(%311, meta[relay.Constant][62]); - %313 = subtract(%312, meta[relay.Constant][63]); - %314 = nn.softmax(%313, axis=3); - %315 = %299.2; - %316 = reshape(%315, newshape=[50, 32, 12, 64]); - %317 = transpose(%316, axes=[0, 2, 1, 3]); - %318 = reshape(%317, newshape=[-1, 32, 64]); - %319 = reshape(%314, newshape=[-1, 32, 32]); - %320 = transpose(%318, axes=[0, 2, 1]); - %321 = nn.batch_matmul(%319, %320, out_dtype="float32", transpose_b=True); - %322 = reshape(%321, newshape=[50, 12, 32, 64]); - %323 = transpose(%322, axes=[0, 2, 1, 3]); - %324 = reshape(%323, newshape=[50, 32, 768]); - %325 = reshape(%324, newshape=[-1, 768]); - %326 = nn.dense(%325, meta[relay.Constant][64], units=768); - %327 = add(%326, meta[relay.Constant][65]); - %328 = reshape(%327, newshape=[50, 32, 768]); - %329 = add(%285, %328); - %330 = mean(%329, axis=[-1], keepdims=True); - %331 = subtract(%329, %330); - %332 = power(%331, 2f); - %333 = mean(%332, axis=[-1], keepdims=True); - %334 = add(%333, 1e-05f); - %335 = sqrt(%334); - %336 = divide(%331, %335); - %337 = multiply(%336, meta[relay.Constant][66]); - %338 = add(%337, meta[relay.Constant][67]); - %339 = reshape(%338, newshape=[-1, 768]); - %340 = nn.dense(%339, meta[relay.Constant][68], units=3072); - %341 = add(%340, meta[relay.Constant][69]); - %342 = reshape(%341, newshape=[50, 32, 3072]); - %343 = power(%342, 3f); - %344 = multiply(%343, 0.044715f); - %345 = add(%342, %344); - %346 = multiply(%345, 0.797885f); - %347 = tanh(%346); - %348 = multiply(%342, 0.5f); - %349 = add(%347, 1f); - %350 = multiply(%348, %349); - %351 = reshape(%350, newshape=[-1, 3072]); - %352 = nn.dense(%351, meta[relay.Constant][70], units=768); - %353 = add(%352, meta[relay.Constant][71]); - %354 = reshape(%353, newshape=[50, 32, 768]); - %355 = add(%329, %354); - %356 = mean(%355, axis=[-1], keepdims=True); - %357 = subtract(%355, %356); - %358 = power(%357, 2f); - %359 = mean(%358, axis=[-1], keepdims=True); - %360 = add(%359, 1e-05f); - %361 = sqrt(%360); - %362 = divide(%357, %361); - %363 = multiply(%362, meta[relay.Constant][72]); - %364 = add(%363, meta[relay.Constant][73]); - %365 = reshape(%364, newshape=[-1, 768]); - %366 = nn.dense(%365, meta[relay.Constant][74], units=2304); - %367 = add(%366, meta[relay.Constant][75]); - %368 = reshape(%367, newshape=[50, 32, 2304]); - %369 = split(%368, indices_or_sections=[768, 1536], axis=2); - %370 = %369.0; - %371 = reshape(%370, newshape=[50, 32, 12, 64]); - %372 = transpose(%371, axes=[0, 2, 1, 3]); - %373 = %369.1; - %374 = reshape(%373, newshape=[50, 32, 12, 64]); - %375 = transpose(%374, axes=[0, 2, 3, 1]); - %376 = reshape(%375, newshape=[-1, 64, 32]); - %377 = reshape(%372, newshape=[-1, 32, 64]); - %378 = transpose(%376, axes=[0, 2, 1]); - %379 = nn.batch_matmul(%377, %378, out_dtype="float32", transpose_b=True); - %380 = reshape(%379, newshape=[50, 12, 32, 32]); - %381 = divide(%380, 8f); - %382 = multiply(%381, meta[relay.Constant][76]); - %383 = subtract(%382, meta[relay.Constant][77]); - %384 = nn.softmax(%383, axis=3); - %385 = %369.2; - %386 = reshape(%385, newshape=[50, 32, 12, 64]); - %387 = transpose(%386, axes=[0, 2, 1, 3]); - %388 = reshape(%387, newshape=[-1, 32, 64]); - %389 = reshape(%384, newshape=[-1, 32, 32]); - %390 = transpose(%388, axes=[0, 2, 1]); - %391 = nn.batch_matmul(%389, %390, out_dtype="float32", transpose_b=True); - %392 = reshape(%391, newshape=[50, 12, 32, 64]); - %393 = transpose(%392, axes=[0, 2, 1, 3]); - %394 = reshape(%393, newshape=[50, 32, 768]); - %395 = reshape(%394, newshape=[-1, 768]); - %396 = nn.dense(%395, meta[relay.Constant][78], units=768); - %397 = add(%396, meta[relay.Constant][79]); - %398 = reshape(%397, newshape=[50, 32, 768]); - %399 = add(%355, %398); - %400 = mean(%399, axis=[-1], keepdims=True); - %401 = subtract(%399, %400); - %402 = power(%401, 2f); - %403 = mean(%402, axis=[-1], keepdims=True); - %404 = add(%403, 1e-05f); - %405 = sqrt(%404); - %406 = divide(%401, %405); - %407 = multiply(%406, meta[relay.Constant][80]); - %408 = add(%407, meta[relay.Constant][81]); - %409 = reshape(%408, newshape=[-1, 768]); - %410 = nn.dense(%409, meta[relay.Constant][82], units=3072); - %411 = add(%410, meta[relay.Constant][83]); - %412 = reshape(%411, newshape=[50, 32, 3072]); - %413 = power(%412, 3f); - %414 = multiply(%413, 0.044715f); - %415 = add(%412, %414); - %416 = multiply(%415, 0.797885f); - %417 = tanh(%416); - %418 = multiply(%412, 0.5f); - %419 = add(%417, 1f); - %420 = multiply(%418, %419); - %421 = reshape(%420, newshape=[-1, 3072]); - %422 = nn.dense(%421, meta[relay.Constant][84], units=768); - %423 = add(%422, meta[relay.Constant][85]); - %424 = reshape(%423, newshape=[50, 32, 768]); - %425 = add(%399, %424); - %426 = mean(%425, axis=[-1], keepdims=True); - %427 = subtract(%425, %426); - %428 = power(%427, 2f); - %429 = mean(%428, axis=[-1], keepdims=True); - %430 = add(%429, 1e-05f); - %431 = sqrt(%430); - %432 = divide(%427, %431); - %433 = multiply(%432, meta[relay.Constant][86]); - %434 = add(%433, meta[relay.Constant][87]); - %435 = reshape(%434, newshape=[-1, 768]); - %436 = nn.dense(%435, meta[relay.Constant][88], units=2304); - %437 = add(%436, meta[relay.Constant][89]); - %438 = reshape(%437, newshape=[50, 32, 2304]); - %439 = split(%438, indices_or_sections=[768, 1536], axis=2); - %440 = %439.0; - %441 = reshape(%440, newshape=[50, 32, 12, 64]); - %442 = transpose(%441, axes=[0, 2, 1, 3]); - %443 = %439.1; - %444 = reshape(%443, newshape=[50, 32, 12, 64]); - %445 = transpose(%444, axes=[0, 2, 3, 1]); - %446 = reshape(%445, newshape=[-1, 64, 32]); - %447 = reshape(%442, newshape=[-1, 32, 64]); - %448 = transpose(%446, axes=[0, 2, 1]); - %449 = nn.batch_matmul(%447, %448, out_dtype="float32", transpose_b=True); - %450 = reshape(%449, newshape=[50, 12, 32, 32]); - %451 = divide(%450, 8f); - %452 = multiply(%451, meta[relay.Constant][90]); - %453 = subtract(%452, meta[relay.Constant][91]); - %454 = nn.softmax(%453, axis=3); - %455 = %439.2; - %456 = reshape(%455, newshape=[50, 32, 12, 64]); - %457 = transpose(%456, axes=[0, 2, 1, 3]); - %458 = reshape(%457, newshape=[-1, 32, 64]); - %459 = reshape(%454, newshape=[-1, 32, 32]); - %460 = transpose(%458, axes=[0, 2, 1]); - %461 = nn.batch_matmul(%459, %460, out_dtype="float32", transpose_b=True); - %462 = reshape(%461, newshape=[50, 12, 32, 64]); - %463 = transpose(%462, axes=[0, 2, 1, 3]); - %464 = reshape(%463, newshape=[50, 32, 768]); - %465 = reshape(%464, newshape=[-1, 768]); - %466 = nn.dense(%465, meta[relay.Constant][92], units=768); - %467 = add(%466, meta[relay.Constant][93]); - %468 = reshape(%467, newshape=[50, 32, 768]); - %469 = add(%425, %468); - %470 = mean(%469, axis=[-1], keepdims=True); - %471 = subtract(%469, %470); - %472 = power(%471, 2f); - %473 = mean(%472, axis=[-1], keepdims=True); - %474 = add(%473, 1e-05f); - %475 = sqrt(%474); - %476 = divide(%471, %475); - %477 = multiply(%476, meta[relay.Constant][94]); - %478 = add(%477, meta[relay.Constant][95]); - %479 = reshape(%478, newshape=[-1, 768]); - %480 = nn.dense(%479, meta[relay.Constant][96], units=3072); - %481 = add(%480, meta[relay.Constant][97]); - %482 = reshape(%481, newshape=[50, 32, 3072]); - %483 = power(%482, 3f); - %484 = multiply(%483, 0.044715f); - %485 = add(%482, %484); - %486 = multiply(%485, 0.797885f); - %487 = tanh(%486); - %488 = multiply(%482, 0.5f); - %489 = add(%487, 1f); - %490 = multiply(%488, %489); - %491 = reshape(%490, newshape=[-1, 3072]); - %492 = nn.dense(%491, meta[relay.Constant][98], units=768); - %493 = add(%492, meta[relay.Constant][99]); - %494 = reshape(%493, newshape=[50, 32, 768]); - %495 = add(%469, %494); - %496 = mean(%495, axis=[-1], keepdims=True); - %497 = subtract(%495, %496); - %498 = power(%497, 2f); - %499 = mean(%498, axis=[-1], keepdims=True); - %500 = add(%499, 1e-05f); - %501 = sqrt(%500); - %502 = divide(%497, %501); - %503 = multiply(%502, meta[relay.Constant][100]); - %504 = add(%503, meta[relay.Constant][101]); - %505 = reshape(%504, newshape=[-1, 768]); - %506 = nn.dense(%505, meta[relay.Constant][102], units=2304); - %507 = add(%506, meta[relay.Constant][103]); - %508 = reshape(%507, newshape=[50, 32, 2304]); - %509 = split(%508, indices_or_sections=[768, 1536], axis=2); - %510 = %509.0; - %511 = reshape(%510, newshape=[50, 32, 12, 64]); - %512 = transpose(%511, axes=[0, 2, 1, 3]); - %513 = %509.1; - %514 = reshape(%513, newshape=[50, 32, 12, 64]); - %515 = transpose(%514, axes=[0, 2, 3, 1]); - %516 = reshape(%515, newshape=[-1, 64, 32]); - %517 = reshape(%512, newshape=[-1, 32, 64]); - %518 = transpose(%516, axes=[0, 2, 1]); - %519 = nn.batch_matmul(%517, %518, out_dtype="float32", transpose_b=True); - %520 = reshape(%519, newshape=[50, 12, 32, 32]); - %521 = divide(%520, 8f); - %522 = multiply(%521, meta[relay.Constant][104]); - %523 = subtract(%522, meta[relay.Constant][105]); - %524 = nn.softmax(%523, axis=3); - %525 = %509.2; - %526 = reshape(%525, newshape=[50, 32, 12, 64]); - %527 = transpose(%526, axes=[0, 2, 1, 3]); - %528 = reshape(%527, newshape=[-1, 32, 64]); - %529 = reshape(%524, newshape=[-1, 32, 32]); - %530 = transpose(%528, axes=[0, 2, 1]); - %531 = nn.batch_matmul(%529, %530, out_dtype="float32", transpose_b=True); - %532 = reshape(%531, newshape=[50, 12, 32, 64]); - %533 = transpose(%532, axes=[0, 2, 1, 3]); - %534 = reshape(%533, newshape=[50, 32, 768]); - %535 = reshape(%534, newshape=[-1, 768]); - %536 = nn.dense(%535, meta[relay.Constant][106], units=768); - %537 = add(%536, meta[relay.Constant][107]); - %538 = reshape(%537, newshape=[50, 32, 768]); - %539 = add(%495, %538); - %540 = mean(%539, axis=[-1], keepdims=True); - %541 = subtract(%539, %540); - %542 = power(%541, 2f); - %543 = mean(%542, axis=[-1], keepdims=True); - %544 = add(%543, 1e-05f); - %545 = sqrt(%544); - %546 = divide(%541, %545); - %547 = multiply(%546, meta[relay.Constant][108]); - %548 = add(%547, meta[relay.Constant][109]); - %549 = reshape(%548, newshape=[-1, 768]); - %550 = nn.dense(%549, meta[relay.Constant][110], units=3072); - %551 = add(%550, meta[relay.Constant][111]); - %552 = reshape(%551, newshape=[50, 32, 3072]); - %553 = power(%552, 3f); - %554 = multiply(%553, 0.044715f); - %555 = add(%552, %554); - %556 = multiply(%555, 0.797885f); - %557 = tanh(%556); - %558 = multiply(%552, 0.5f); - %559 = add(%557, 1f); - %560 = multiply(%558, %559); - %561 = reshape(%560, newshape=[-1, 3072]); - %562 = nn.dense(%561, meta[relay.Constant][112], units=768); - %563 = add(%562, meta[relay.Constant][113]); - %564 = reshape(%563, newshape=[50, 32, 768]); - %565 = add(%539, %564); - %566 = mean(%565, axis=[-1], keepdims=True); - %567 = subtract(%565, %566); - %568 = power(%567, 2f); - %569 = mean(%568, axis=[-1], keepdims=True); - %570 = add(%569, 1e-05f); - %571 = sqrt(%570); - %572 = divide(%567, %571); - %573 = multiply(%572, meta[relay.Constant][114]); - %574 = add(%573, meta[relay.Constant][115]); - %575 = reshape(%574, newshape=[-1, 768]); - %576 = nn.dense(%575, meta[relay.Constant][116], units=2304); - %577 = add(%576, meta[relay.Constant][117]); - %578 = reshape(%577, newshape=[50, 32, 2304]); - %579 = split(%578, indices_or_sections=[768, 1536], axis=2); - %580 = %579.0; - %581 = reshape(%580, newshape=[50, 32, 12, 64]); - %582 = transpose(%581, axes=[0, 2, 1, 3]); - %583 = %579.1; - %584 = reshape(%583, newshape=[50, 32, 12, 64]); - %585 = transpose(%584, axes=[0, 2, 3, 1]); - %586 = reshape(%585, newshape=[-1, 64, 32]); - %587 = reshape(%582, newshape=[-1, 32, 64]); - %588 = transpose(%586, axes=[0, 2, 1]); - %589 = nn.batch_matmul(%587, %588, out_dtype="float32", transpose_b=True); - %590 = reshape(%589, newshape=[50, 12, 32, 32]); - %591 = divide(%590, 8f); - %592 = multiply(%591, meta[relay.Constant][118]); - %593 = subtract(%592, meta[relay.Constant][119]); - %594 = nn.softmax(%593, axis=3); - %595 = %579.2; - %596 = reshape(%595, newshape=[50, 32, 12, 64]); - %597 = transpose(%596, axes=[0, 2, 1, 3]); - %598 = reshape(%597, newshape=[-1, 32, 64]); - %599 = reshape(%594, newshape=[-1, 32, 32]); - %600 = transpose(%598, axes=[0, 2, 1]); - %601 = nn.batch_matmul(%599, %600, out_dtype="float32", transpose_b=True); - %602 = reshape(%601, newshape=[50, 12, 32, 64]); - %603 = transpose(%602, axes=[0, 2, 1, 3]); - %604 = reshape(%603, newshape=[50, 32, 768]); - %605 = reshape(%604, newshape=[-1, 768]); - %606 = nn.dense(%605, meta[relay.Constant][120], units=768); - %607 = add(%606, meta[relay.Constant][121]); - %608 = reshape(%607, newshape=[50, 32, 768]); - %609 = add(%565, %608); - %610 = mean(%609, axis=[-1], keepdims=True); - %611 = subtract(%609, %610); - %612 = power(%611, 2f); - %613 = mean(%612, axis=[-1], keepdims=True); - %614 = add(%613, 1e-05f); - %615 = sqrt(%614); - %616 = divide(%611, %615); - %617 = multiply(%616, meta[relay.Constant][122]); - %618 = add(%617, meta[relay.Constant][123]); - %619 = reshape(%618, newshape=[-1, 768]); - %620 = nn.dense(%619, meta[relay.Constant][124], units=3072); - %621 = add(%620, meta[relay.Constant][125]); - %622 = reshape(%621, newshape=[50, 32, 3072]); - %623 = power(%622, 3f); - %624 = multiply(%623, 0.044715f); - %625 = add(%622, %624); - %626 = multiply(%625, 0.797885f); - %627 = tanh(%626); - %628 = multiply(%622, 0.5f); - %629 = add(%627, 1f); - %630 = multiply(%628, %629); - %631 = reshape(%630, newshape=[-1, 3072]); - %632 = nn.dense(%631, meta[relay.Constant][126], units=768); - %633 = add(%632, meta[relay.Constant][127]); - %634 = reshape(%633, newshape=[50, 32, 768]); - %635 = add(%609, %634); - %636 = mean(%635, axis=[-1], keepdims=True); - %637 = subtract(%635, %636); - %638 = power(%637, 2f); - %639 = mean(%638, axis=[-1], keepdims=True); - %640 = add(%639, 1e-05f); - %641 = sqrt(%640); - %642 = divide(%637, %641); - %643 = multiply(%642, meta[relay.Constant][128]); - %644 = add(%643, meta[relay.Constant][129]); - %645 = reshape(%644, newshape=[-1, 768]); - %646 = nn.dense(%645, meta[relay.Constant][130], units=2304); - %647 = add(%646, meta[relay.Constant][131]); - %648 = reshape(%647, newshape=[50, 32, 2304]); - %649 = split(%648, indices_or_sections=[768, 1536], axis=2); - %650 = %649.0; - %651 = reshape(%650, newshape=[50, 32, 12, 64]); - %652 = transpose(%651, axes=[0, 2, 1, 3]); - %653 = %649.1; - %654 = reshape(%653, newshape=[50, 32, 12, 64]); - %655 = transpose(%654, axes=[0, 2, 3, 1]); - %656 = reshape(%655, newshape=[-1, 64, 32]); - %657 = reshape(%652, newshape=[-1, 32, 64]); - %658 = transpose(%656, axes=[0, 2, 1]); - %659 = nn.batch_matmul(%657, %658, out_dtype="float32", transpose_b=True); - %660 = reshape(%659, newshape=[50, 12, 32, 32]); - %661 = divide(%660, 8f); - %662 = multiply(%661, meta[relay.Constant][132]); - %663 = subtract(%662, meta[relay.Constant][133]); - %664 = nn.softmax(%663, axis=3); - %665 = %649.2; - %666 = reshape(%665, newshape=[50, 32, 12, 64]); - %667 = transpose(%666, axes=[0, 2, 1, 3]); - %668 = reshape(%667, newshape=[-1, 32, 64]); - %669 = reshape(%664, newshape=[-1, 32, 32]); - %670 = transpose(%668, axes=[0, 2, 1]); - %671 = nn.batch_matmul(%669, %670, out_dtype="float32", transpose_b=True); - %672 = reshape(%671, newshape=[50, 12, 32, 64]); - %673 = transpose(%672, axes=[0, 2, 1, 3]); - %674 = reshape(%673, newshape=[50, 32, 768]); - %675 = reshape(%674, newshape=[-1, 768]); - %676 = nn.dense(%675, meta[relay.Constant][134], units=768); - %677 = add(%676, meta[relay.Constant][135]); - %678 = reshape(%677, newshape=[50, 32, 768]); - %679 = add(%635, %678); - %680 = mean(%679, axis=[-1], keepdims=True); - %681 = subtract(%679, %680); - %682 = power(%681, 2f); - %683 = mean(%682, axis=[-1], keepdims=True); - %684 = add(%683, 1e-05f); - %685 = sqrt(%684); - %686 = divide(%681, %685); - %687 = multiply(%686, meta[relay.Constant][136]); - %688 = add(%687, meta[relay.Constant][137]); - %689 = reshape(%688, newshape=[-1, 768]); - %690 = nn.dense(%689, meta[relay.Constant][138], units=3072); - %691 = add(%690, meta[relay.Constant][139]); - %692 = reshape(%691, newshape=[50, 32, 3072]); - %693 = power(%692, 3f); - %694 = multiply(%693, 0.044715f); - %695 = add(%692, %694); - %696 = multiply(%695, 0.797885f); - %697 = tanh(%696); - %698 = multiply(%692, 0.5f); - %699 = add(%697, 1f); - %700 = multiply(%698, %699); - %701 = reshape(%700, newshape=[-1, 3072]); - %702 = nn.dense(%701, meta[relay.Constant][140], units=768); - %703 = add(%702, meta[relay.Constant][141]); - %704 = reshape(%703, newshape=[50, 32, 768]); - %705 = add(%679, %704); - %706 = mean(%705, axis=[-1], keepdims=True); - %707 = subtract(%705, %706); - %708 = power(%707, 2f); - %709 = mean(%708, axis=[-1], keepdims=True); - %710 = add(%709, 1e-05f); - %711 = sqrt(%710); - %712 = divide(%707, %711); - %713 = multiply(%712, meta[relay.Constant][142]); - %714 = add(%713, meta[relay.Constant][143]); - %715 = reshape(%714, newshape=[-1, 768]); - %716 = nn.dense(%715, meta[relay.Constant][144], units=2304); - %717 = add(%716, meta[relay.Constant][145]); - %718 = reshape(%717, newshape=[50, 32, 2304]); - %719 = split(%718, indices_or_sections=[768, 1536], axis=2); - %720 = %719.0; - %721 = reshape(%720, newshape=[50, 32, 12, 64]); - %722 = transpose(%721, axes=[0, 2, 1, 3]); - %723 = %719.1; - %724 = reshape(%723, newshape=[50, 32, 12, 64]); - %725 = transpose(%724, axes=[0, 2, 3, 1]); - %726 = reshape(%725, newshape=[-1, 64, 32]); - %727 = reshape(%722, newshape=[-1, 32, 64]); - %728 = transpose(%726, axes=[0, 2, 1]); - %729 = nn.batch_matmul(%727, %728, out_dtype="float32", transpose_b=True); - %730 = reshape(%729, newshape=[50, 12, 32, 32]); - %731 = divide(%730, 8f); - %732 = multiply(%731, meta[relay.Constant][146]); - %733 = subtract(%732, meta[relay.Constant][147]); - %734 = nn.softmax(%733, axis=3); - %735 = %719.2; - %736 = reshape(%735, newshape=[50, 32, 12, 64]); - %737 = transpose(%736, axes=[0, 2, 1, 3]); - %738 = reshape(%737, newshape=[-1, 32, 64]); - %739 = reshape(%734, newshape=[-1, 32, 32]); - %740 = transpose(%738, axes=[0, 2, 1]); - %741 = nn.batch_matmul(%739, %740, out_dtype="float32", transpose_b=True); - %742 = reshape(%741, newshape=[50, 12, 32, 64]); - %743 = transpose(%742, axes=[0, 2, 1, 3]); - %744 = reshape(%743, newshape=[50, 32, 768]); - %745 = reshape(%744, newshape=[-1, 768]); - %746 = nn.dense(%745, meta[relay.Constant][148], units=768); - %747 = add(%746, meta[relay.Constant][149]); - %748 = reshape(%747, newshape=[50, 32, 768]); - %749 = add(%705, %748); - %750 = mean(%749, axis=[-1], keepdims=True); - %751 = subtract(%749, %750); - %752 = power(%751, 2f); - %753 = mean(%752, axis=[-1], keepdims=True); - %754 = add(%753, 1e-05f); - %755 = sqrt(%754); - %756 = divide(%751, %755); - %757 = multiply(%756, meta[relay.Constant][150]); - %758 = add(%757, meta[relay.Constant][151]); - %759 = reshape(%758, newshape=[-1, 768]); - %760 = nn.dense(%759, meta[relay.Constant][152], units=3072); - %761 = add(%760, meta[relay.Constant][153]); - %762 = reshape(%761, newshape=[50, 32, 3072]); - %763 = power(%762, 3f); - %764 = multiply(%763, 0.044715f); - %765 = add(%762, %764); - %766 = multiply(%765, 0.797885f); - %767 = tanh(%766); - %768 = multiply(%762, 0.5f); - %769 = add(%767, 1f); - %770 = multiply(%768, %769); - %771 = reshape(%770, newshape=[-1, 3072]); - %772 = nn.dense(%771, meta[relay.Constant][154], units=768); - %773 = add(%772, meta[relay.Constant][155]); - %774 = reshape(%773, newshape=[50, 32, 768]); - %775 = add(%749, %774); - %776 = mean(%775, axis=[-1], keepdims=True); - %777 = subtract(%775, %776); - %778 = power(%777, 2f); - %779 = mean(%778, axis=[-1], keepdims=True); - %780 = add(%779, 1e-05f); - %781 = sqrt(%780); - %782 = divide(%777, %781); - %783 = multiply(%782, meta[relay.Constant][156]); - %784 = add(%783, meta[relay.Constant][157]); - %785 = reshape(%784, newshape=[-1, 768]); - %786 = nn.dense(%785, meta[relay.Constant][158], units=2304); - %787 = add(%786, meta[relay.Constant][159]); - %788 = reshape(%787, newshape=[50, 32, 2304]); - %789 = split(%788, indices_or_sections=[768, 1536], axis=2); - %790 = %789.0; - %791 = reshape(%790, newshape=[50, 32, 12, 64]); - %792 = transpose(%791, axes=[0, 2, 1, 3]); - %793 = %789.1; - %794 = reshape(%793, newshape=[50, 32, 12, 64]); - %795 = transpose(%794, axes=[0, 2, 3, 1]); - %796 = reshape(%795, newshape=[-1, 64, 32]); - %797 = reshape(%792, newshape=[-1, 32, 64]); - %798 = transpose(%796, axes=[0, 2, 1]); - %799 = nn.batch_matmul(%797, %798, out_dtype="float32", transpose_b=True); - %800 = reshape(%799, newshape=[50, 12, 32, 32]); - %801 = divide(%800, 8f); - %802 = multiply(%801, meta[relay.Constant][160]); - %803 = subtract(%802, meta[relay.Constant][161]); - %804 = nn.softmax(%803, axis=3); - %805 = %789.2; - %806 = reshape(%805, newshape=[50, 32, 12, 64]); - %807 = transpose(%806, axes=[0, 2, 1, 3]); - %808 = reshape(%807, newshape=[-1, 32, 64]); - %809 = reshape(%804, newshape=[-1, 32, 32]); - %810 = transpose(%808, axes=[0, 2, 1]); - %811 = nn.batch_matmul(%809, %810, out_dtype="float32", transpose_b=True); - %812 = reshape(%811, newshape=[50, 12, 32, 64]); - %813 = transpose(%812, axes=[0, 2, 1, 3]); - %814 = reshape(%813, newshape=[50, 32, 768]); - %815 = reshape(%814, newshape=[-1, 768]); - %816 = nn.dense(%815, meta[relay.Constant][162], units=768); - %817 = add(%816, meta[relay.Constant][163]); - %818 = reshape(%817, newshape=[50, 32, 768]); - %819 = add(%775, %818); - %820 = mean(%819, axis=[-1], keepdims=True); - %821 = subtract(%819, %820); - %822 = power(%821, 2f); - %823 = mean(%822, axis=[-1], keepdims=True); - %824 = add(%823, 1e-05f); - %825 = sqrt(%824); - %826 = divide(%821, %825); - %827 = multiply(%826, meta[relay.Constant][164]); - %828 = add(%827, meta[relay.Constant][165]); - %829 = reshape(%828, newshape=[-1, 768]); - %830 = nn.dense(%829, meta[relay.Constant][166], units=3072); - %831 = add(%830, meta[relay.Constant][167]); - %832 = reshape(%831, newshape=[50, 32, 3072]); - %833 = power(%832, 3f); - %834 = multiply(%833, 0.044715f); - %835 = add(%832, %834); - %836 = multiply(%835, 0.797885f); - %837 = tanh(%836); - %838 = multiply(%832, 0.5f); - %839 = add(%837, 1f); - %840 = multiply(%838, %839); - %841 = reshape(%840, newshape=[-1, 3072]); - %842 = nn.dense(%841, meta[relay.Constant][168], units=768); - %843 = add(%842, meta[relay.Constant][169]); - %844 = reshape(%843, newshape=[50, 32, 768]); - %845 = add(%819, %844); - %846 = mean(%845, axis=[-1], keepdims=True); - %847 = subtract(%845, %846); - %848 = power(%847, 2f); - %849 = mean(%848, axis=[-1], keepdims=True); - %850 = add(%849, 1e-05f); - %851 = sqrt(%850); - %852 = divide(%847, %851); - %853 = multiply(%852, meta[relay.Constant][170]); - %854 = add(%853, meta[relay.Constant][171]); - %855 = transpose(%24, axes=[0, 2, 1, 3]); - %856 = expand_dims(%855, axis=0); - %857 = expand_dims(%37, axis=0); - %858 = (%856, %857); - %859 = transpose(%94, axes=[0, 2, 1, 3]); - %860 = expand_dims(%859, axis=0); - %861 = expand_dims(%107, axis=0); - %862 = (%860, %861); - %863 = transpose(%164, axes=[0, 2, 1, 3]); - %864 = expand_dims(%863, axis=0); - %865 = expand_dims(%177, axis=0); - %866 = (%864, %865); - %867 = transpose(%234, axes=[0, 2, 1, 3]); - %868 = expand_dims(%867, axis=0); - %869 = expand_dims(%247, axis=0); - %870 = (%868, %869); - %871 = transpose(%304, axes=[0, 2, 1, 3]); - %872 = expand_dims(%871, axis=0); - %873 = expand_dims(%317, axis=0); - %874 = (%872, %873); - %875 = transpose(%374, axes=[0, 2, 1, 3]); - %876 = expand_dims(%875, axis=0); - %877 = expand_dims(%387, axis=0); - %878 = (%876, %877); - %879 = transpose(%444, axes=[0, 2, 1, 3]); - %880 = expand_dims(%879, axis=0); - %881 = expand_dims(%457, axis=0); - %882 = (%880, %881); - %883 = transpose(%514, axes=[0, 2, 1, 3]); - %884 = expand_dims(%883, axis=0); - %885 = expand_dims(%527, axis=0); - %886 = (%884, %885); - %887 = transpose(%584, axes=[0, 2, 1, 3]); - %888 = expand_dims(%887, axis=0); - %889 = expand_dims(%597, axis=0); - %890 = (%888, %889); - %891 = transpose(%654, axes=[0, 2, 1, 3]); - %892 = expand_dims(%891, axis=0); - %893 = expand_dims(%667, axis=0); - %894 = (%892, %893); - %895 = transpose(%724, axes=[0, 2, 1, 3]); - %896 = expand_dims(%895, axis=0); - %897 = expand_dims(%737, axis=0); - %898 = (%896, %897); - %899 = transpose(%794, axes=[0, 2, 1, 3]); - %900 = expand_dims(%899, axis=0); - %901 = expand_dims(%807, axis=0); - %902 = (%900, %901); - %903 = reshape(%854, newshape=[1, 50, 32, 768]); - %904 = concatenate(%858); - %905 = concatenate(%862); - %906 = concatenate(%866); - %907 = concatenate(%870); - %908 = concatenate(%874); - %909 = concatenate(%878); - %910 = concatenate(%882); - %911 = concatenate(%886); - %912 = concatenate(%890); - %913 = concatenate(%894); - %914 = concatenate(%898); - %915 = concatenate(%902); - (%903, %904, %905, %906, %907, %908, %909, %910, %911, %912, %913, %914, %915) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "gpt2", - "input_shapes": {"x": [1, 50, 32]}, - "input_dtypes": {"x": "int64"}, - "mod": mod, - "params": None, - "main_dtype": "float32", - } - - -def gpt2_16(): - metatable = {"relay.Constant": gpt2_consts("float16")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1, 50, 32), int64]) -> (Tensor[(1, 50, 32, 768), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16], - Tensor[(2, 50, 12, 32, 64), float16]) { - %0 = reshape(%x, newshape=[-1, 32]); - %1 = less(%0, 0i64); - %2 = add(%0, 50257i64); - %3 = where(%1, %2, %0); - %4 = take(meta[relay.Constant][0], %3, axis=0); - %5 = add(%4, meta[relay.Constant][1]); - %6 = mean(%5, axis=[-1], keepdims=True); - %7 = subtract(%5, %6); - %8 = power(%7, 2f16); - %9 = mean(%8, axis=[-1], keepdims=True); - %10 = add(%9, 1e-05f16); - %11 = sqrt(%10); - %12 = divide(%7, %11); - %13 = multiply(%12, meta[relay.Constant][2]); - %14 = add(%13, meta[relay.Constant][3]); - %15 = reshape(%14, newshape=[-1, 768]); - %16 = nn.dense(%15, meta[relay.Constant][4], units=2304); - %17 = add(%16, meta[relay.Constant][5]); - %18 = reshape(%17, newshape=[50, 32, 2304]); - %19 = split(%18, indices_or_sections=[768, 1536], axis=2); - %20 = %19.0; - %21 = reshape(%20, newshape=[50, 32, 12, 64]); - %22 = transpose(%21, axes=[0, 2, 1, 3]); - %23 = %19.1; - %24 = reshape(%23, newshape=[50, 32, 12, 64]); - %25 = transpose(%24, axes=[0, 2, 3, 1]); - %26 = reshape(%25, newshape=[-1, 64, 32]); - %27 = reshape(%22, newshape=[-1, 32, 64]); - %28 = transpose(%26, axes=[0, 2, 1]); - %29 = nn.batch_matmul(%27, %28, out_dtype="float16", transpose_b=True); - %30 = reshape(%29, newshape=[50, 12, 32, 32]); - %31 = divide(%30, 8f16); - %32 = multiply(%31, meta[relay.Constant][6]); - %33 = subtract(%32, meta[relay.Constant][7]); - %34 = nn.softmax(%33, axis=3); - %35 = %19.2; - %36 = reshape(%35, newshape=[50, 32, 12, 64]); - %37 = transpose(%36, axes=[0, 2, 1, 3]); - %38 = reshape(%37, newshape=[-1, 32, 64]); - %39 = reshape(%34, newshape=[-1, 32, 32]); - %40 = transpose(%38, axes=[0, 2, 1]); - %41 = nn.batch_matmul(%39, %40, out_dtype="float16", transpose_b=True); - %42 = reshape(%41, newshape=[50, 12, 32, 64]); - %43 = transpose(%42, axes=[0, 2, 1, 3]); - %44 = reshape(%43, newshape=[50, 32, 768]); - %45 = reshape(%44, newshape=[-1, 768]); - %46 = nn.dense(%45, meta[relay.Constant][8], units=768); - %47 = add(%46, meta[relay.Constant][9]); - %48 = reshape(%47, newshape=[50, 32, 768]); - %49 = add(%5, %48); - %50 = mean(%49, axis=[-1], keepdims=True); - %51 = subtract(%49, %50); - %52 = power(%51, 2f16); - %53 = mean(%52, axis=[-1], keepdims=True); - %54 = add(%53, 1e-05f16); - %55 = sqrt(%54); - %56 = divide(%51, %55); - %57 = multiply(%56, meta[relay.Constant][10]); - %58 = add(%57, meta[relay.Constant][11]); - %59 = reshape(%58, newshape=[-1, 768]); - %60 = nn.dense(%59, meta[relay.Constant][12], units=3072); - %61 = add(%60, meta[relay.Constant][13]); - %62 = reshape(%61, newshape=[50, 32, 3072]); - %63 = power(%62, 3f16); - %64 = multiply(%63, 0.044715f16); - %65 = add(%62, %64); - %66 = multiply(%65, 0.797885f16); - %67 = tanh(%66); - %68 = multiply(%62, 0.5f16); - %69 = add(%67, 1f16); - %70 = multiply(%68, %69); - %71 = reshape(%70, newshape=[-1, 3072]); - %72 = nn.dense(%71, meta[relay.Constant][14], units=768); - %73 = add(%72, meta[relay.Constant][15]); - %74 = reshape(%73, newshape=[50, 32, 768]); - %75 = add(%49, %74); - %76 = mean(%75, axis=[-1], keepdims=True); - %77 = subtract(%75, %76); - %78 = power(%77, 2f16); - %79 = mean(%78, axis=[-1], keepdims=True); - %80 = add(%79, 1e-05f16); - %81 = sqrt(%80); - %82 = divide(%77, %81); - %83 = multiply(%82, meta[relay.Constant][16]); - %84 = add(%83, meta[relay.Constant][17]); - %85 = reshape(%84, newshape=[-1, 768]); - %86 = nn.dense(%85, meta[relay.Constant][18], units=2304); - %87 = add(%86, meta[relay.Constant][19]); - %88 = reshape(%87, newshape=[50, 32, 2304]); - %89 = split(%88, indices_or_sections=[768, 1536], axis=2); - %90 = %89.0; - %91 = reshape(%90, newshape=[50, 32, 12, 64]); - %92 = transpose(%91, axes=[0, 2, 1, 3]); - %93 = %89.1; - %94 = reshape(%93, newshape=[50, 32, 12, 64]); - %95 = transpose(%94, axes=[0, 2, 3, 1]); - %96 = reshape(%95, newshape=[-1, 64, 32]); - %97 = reshape(%92, newshape=[-1, 32, 64]); - %98 = transpose(%96, axes=[0, 2, 1]); - %99 = nn.batch_matmul(%97, %98, out_dtype="float16", transpose_b=True); - %100 = reshape(%99, newshape=[50, 12, 32, 32]); - %101 = divide(%100, 8f16); - %102 = multiply(%101, meta[relay.Constant][20]); - %103 = subtract(%102, meta[relay.Constant][21]); - %104 = nn.softmax(%103, axis=3); - %105 = %89.2; - %106 = reshape(%105, newshape=[50, 32, 12, 64]); - %107 = transpose(%106, axes=[0, 2, 1, 3]); - %108 = reshape(%107, newshape=[-1, 32, 64]); - %109 = reshape(%104, newshape=[-1, 32, 32]); - %110 = transpose(%108, axes=[0, 2, 1]); - %111 = nn.batch_matmul(%109, %110, out_dtype="float16", transpose_b=True); - %112 = reshape(%111, newshape=[50, 12, 32, 64]); - %113 = transpose(%112, axes=[0, 2, 1, 3]); - %114 = reshape(%113, newshape=[50, 32, 768]); - %115 = reshape(%114, newshape=[-1, 768]); - %116 = nn.dense(%115, meta[relay.Constant][22], units=768); - %117 = add(%116, meta[relay.Constant][23]); - %118 = reshape(%117, newshape=[50, 32, 768]); - %119 = add(%75, %118); - %120 = mean(%119, axis=[-1], keepdims=True); - %121 = subtract(%119, %120); - %122 = power(%121, 2f16); - %123 = mean(%122, axis=[-1], keepdims=True); - %124 = add(%123, 1e-05f16); - %125 = sqrt(%124); - %126 = divide(%121, %125); - %127 = multiply(%126, meta[relay.Constant][24]); - %128 = add(%127, meta[relay.Constant][25]); - %129 = reshape(%128, newshape=[-1, 768]); - %130 = nn.dense(%129, meta[relay.Constant][26], units=3072); - %131 = add(%130, meta[relay.Constant][27]); - %132 = reshape(%131, newshape=[50, 32, 3072]); - %133 = power(%132, 3f16); - %134 = multiply(%133, 0.044715f16); - %135 = add(%132, %134); - %136 = multiply(%135, 0.797885f16); - %137 = tanh(%136); - %138 = multiply(%132, 0.5f16); - %139 = add(%137, 1f16); - %140 = multiply(%138, %139); - %141 = reshape(%140, newshape=[-1, 3072]); - %142 = nn.dense(%141, meta[relay.Constant][28], units=768); - %143 = add(%142, meta[relay.Constant][29]); - %144 = reshape(%143, newshape=[50, 32, 768]); - %145 = add(%119, %144); - %146 = mean(%145, axis=[-1], keepdims=True); - %147 = subtract(%145, %146); - %148 = power(%147, 2f16); - %149 = mean(%148, axis=[-1], keepdims=True); - %150 = add(%149, 1e-05f16); - %151 = sqrt(%150); - %152 = divide(%147, %151); - %153 = multiply(%152, meta[relay.Constant][30]); - %154 = add(%153, meta[relay.Constant][31]); - %155 = reshape(%154, newshape=[-1, 768]); - %156 = nn.dense(%155, meta[relay.Constant][32], units=2304); - %157 = add(%156, meta[relay.Constant][33]); - %158 = reshape(%157, newshape=[50, 32, 2304]); - %159 = split(%158, indices_or_sections=[768, 1536], axis=2); - %160 = %159.0; - %161 = reshape(%160, newshape=[50, 32, 12, 64]); - %162 = transpose(%161, axes=[0, 2, 1, 3]); - %163 = %159.1; - %164 = reshape(%163, newshape=[50, 32, 12, 64]); - %165 = transpose(%164, axes=[0, 2, 3, 1]); - %166 = reshape(%165, newshape=[-1, 64, 32]); - %167 = reshape(%162, newshape=[-1, 32, 64]); - %168 = transpose(%166, axes=[0, 2, 1]); - %169 = nn.batch_matmul(%167, %168, out_dtype="float16", transpose_b=True); - %170 = reshape(%169, newshape=[50, 12, 32, 32]); - %171 = divide(%170, 8f16); - %172 = multiply(%171, meta[relay.Constant][34]); - %173 = subtract(%172, meta[relay.Constant][35]); - %174 = nn.softmax(%173, axis=3); - %175 = %159.2; - %176 = reshape(%175, newshape=[50, 32, 12, 64]); - %177 = transpose(%176, axes=[0, 2, 1, 3]); - %178 = reshape(%177, newshape=[-1, 32, 64]); - %179 = reshape(%174, newshape=[-1, 32, 32]); - %180 = transpose(%178, axes=[0, 2, 1]); - %181 = nn.batch_matmul(%179, %180, out_dtype="float16", transpose_b=True); - %182 = reshape(%181, newshape=[50, 12, 32, 64]); - %183 = transpose(%182, axes=[0, 2, 1, 3]); - %184 = reshape(%183, newshape=[50, 32, 768]); - %185 = reshape(%184, newshape=[-1, 768]); - %186 = nn.dense(%185, meta[relay.Constant][36], units=768); - %187 = add(%186, meta[relay.Constant][37]); - %188 = reshape(%187, newshape=[50, 32, 768]); - %189 = add(%145, %188); - %190 = mean(%189, axis=[-1], keepdims=True); - %191 = subtract(%189, %190); - %192 = power(%191, 2f16); - %193 = mean(%192, axis=[-1], keepdims=True); - %194 = add(%193, 1e-05f16); - %195 = sqrt(%194); - %196 = divide(%191, %195); - %197 = multiply(%196, meta[relay.Constant][38]); - %198 = add(%197, meta[relay.Constant][39]); - %199 = reshape(%198, newshape=[-1, 768]); - %200 = nn.dense(%199, meta[relay.Constant][40], units=3072); - %201 = add(%200, meta[relay.Constant][41]); - %202 = reshape(%201, newshape=[50, 32, 3072]); - %203 = power(%202, 3f16); - %204 = multiply(%203, 0.044715f16); - %205 = add(%202, %204); - %206 = multiply(%205, 0.797885f16); - %207 = tanh(%206); - %208 = multiply(%202, 0.5f16); - %209 = add(%207, 1f16); - %210 = multiply(%208, %209); - %211 = reshape(%210, newshape=[-1, 3072]); - %212 = nn.dense(%211, meta[relay.Constant][42], units=768); - %213 = add(%212, meta[relay.Constant][43]); - %214 = reshape(%213, newshape=[50, 32, 768]); - %215 = add(%189, %214); - %216 = mean(%215, axis=[-1], keepdims=True); - %217 = subtract(%215, %216); - %218 = power(%217, 2f16); - %219 = mean(%218, axis=[-1], keepdims=True); - %220 = add(%219, 1e-05f16); - %221 = sqrt(%220); - %222 = divide(%217, %221); - %223 = multiply(%222, meta[relay.Constant][44]); - %224 = add(%223, meta[relay.Constant][45]); - %225 = reshape(%224, newshape=[-1, 768]); - %226 = nn.dense(%225, meta[relay.Constant][46], units=2304); - %227 = add(%226, meta[relay.Constant][47]); - %228 = reshape(%227, newshape=[50, 32, 2304]); - %229 = split(%228, indices_or_sections=[768, 1536], axis=2); - %230 = %229.0; - %231 = reshape(%230, newshape=[50, 32, 12, 64]); - %232 = transpose(%231, axes=[0, 2, 1, 3]); - %233 = %229.1; - %234 = reshape(%233, newshape=[50, 32, 12, 64]); - %235 = transpose(%234, axes=[0, 2, 3, 1]); - %236 = reshape(%235, newshape=[-1, 64, 32]); - %237 = reshape(%232, newshape=[-1, 32, 64]); - %238 = transpose(%236, axes=[0, 2, 1]); - %239 = nn.batch_matmul(%237, %238, out_dtype="float16", transpose_b=True); - %240 = reshape(%239, newshape=[50, 12, 32, 32]); - %241 = divide(%240, 8f16); - %242 = multiply(%241, meta[relay.Constant][48]); - %243 = subtract(%242, meta[relay.Constant][49]); - %244 = nn.softmax(%243, axis=3); - %245 = %229.2; - %246 = reshape(%245, newshape=[50, 32, 12, 64]); - %247 = transpose(%246, axes=[0, 2, 1, 3]); - %248 = reshape(%247, newshape=[-1, 32, 64]); - %249 = reshape(%244, newshape=[-1, 32, 32]); - %250 = transpose(%248, axes=[0, 2, 1]); - %251 = nn.batch_matmul(%249, %250, out_dtype="float16", transpose_b=True); - %252 = reshape(%251, newshape=[50, 12, 32, 64]); - %253 = transpose(%252, axes=[0, 2, 1, 3]); - %254 = reshape(%253, newshape=[50, 32, 768]); - %255 = reshape(%254, newshape=[-1, 768]); - %256 = nn.dense(%255, meta[relay.Constant][50], units=768); - %257 = add(%256, meta[relay.Constant][51]); - %258 = reshape(%257, newshape=[50, 32, 768]); - %259 = add(%215, %258); - %260 = mean(%259, axis=[-1], keepdims=True); - %261 = subtract(%259, %260); - %262 = power(%261, 2f16); - %263 = mean(%262, axis=[-1], keepdims=True); - %264 = add(%263, 1e-05f16); - %265 = sqrt(%264); - %266 = divide(%261, %265); - %267 = multiply(%266, meta[relay.Constant][52]); - %268 = add(%267, meta[relay.Constant][53]); - %269 = reshape(%268, newshape=[-1, 768]); - %270 = nn.dense(%269, meta[relay.Constant][54], units=3072); - %271 = add(%270, meta[relay.Constant][55]); - %272 = reshape(%271, newshape=[50, 32, 3072]); - %273 = power(%272, 3f16); - %274 = multiply(%273, 0.044715f16); - %275 = add(%272, %274); - %276 = multiply(%275, 0.797885f16); - %277 = tanh(%276); - %278 = multiply(%272, 0.5f16); - %279 = add(%277, 1f16); - %280 = multiply(%278, %279); - %281 = reshape(%280, newshape=[-1, 3072]); - %282 = nn.dense(%281, meta[relay.Constant][56], units=768); - %283 = add(%282, meta[relay.Constant][57]); - %284 = reshape(%283, newshape=[50, 32, 768]); - %285 = add(%259, %284); - %286 = mean(%285, axis=[-1], keepdims=True); - %287 = subtract(%285, %286); - %288 = power(%287, 2f16); - %289 = mean(%288, axis=[-1], keepdims=True); - %290 = add(%289, 1e-05f16); - %291 = sqrt(%290); - %292 = divide(%287, %291); - %293 = multiply(%292, meta[relay.Constant][58]); - %294 = add(%293, meta[relay.Constant][59]); - %295 = reshape(%294, newshape=[-1, 768]); - %296 = nn.dense(%295, meta[relay.Constant][60], units=2304); - %297 = add(%296, meta[relay.Constant][61]); - %298 = reshape(%297, newshape=[50, 32, 2304]); - %299 = split(%298, indices_or_sections=[768, 1536], axis=2); - %300 = %299.0; - %301 = reshape(%300, newshape=[50, 32, 12, 64]); - %302 = transpose(%301, axes=[0, 2, 1, 3]); - %303 = %299.1; - %304 = reshape(%303, newshape=[50, 32, 12, 64]); - %305 = transpose(%304, axes=[0, 2, 3, 1]); - %306 = reshape(%305, newshape=[-1, 64, 32]); - %307 = reshape(%302, newshape=[-1, 32, 64]); - %308 = transpose(%306, axes=[0, 2, 1]); - %309 = nn.batch_matmul(%307, %308, out_dtype="float16", transpose_b=True); - %310 = reshape(%309, newshape=[50, 12, 32, 32]); - %311 = divide(%310, 8f16); - %312 = multiply(%311, meta[relay.Constant][62]); - %313 = subtract(%312, meta[relay.Constant][63]); - %314 = nn.softmax(%313, axis=3); - %315 = %299.2; - %316 = reshape(%315, newshape=[50, 32, 12, 64]); - %317 = transpose(%316, axes=[0, 2, 1, 3]); - %318 = reshape(%317, newshape=[-1, 32, 64]); - %319 = reshape(%314, newshape=[-1, 32, 32]); - %320 = transpose(%318, axes=[0, 2, 1]); - %321 = nn.batch_matmul(%319, %320, out_dtype="float16", transpose_b=True); - %322 = reshape(%321, newshape=[50, 12, 32, 64]); - %323 = transpose(%322, axes=[0, 2, 1, 3]); - %324 = reshape(%323, newshape=[50, 32, 768]); - %325 = reshape(%324, newshape=[-1, 768]); - %326 = nn.dense(%325, meta[relay.Constant][64], units=768); - %327 = add(%326, meta[relay.Constant][65]); - %328 = reshape(%327, newshape=[50, 32, 768]); - %329 = add(%285, %328); - %330 = mean(%329, axis=[-1], keepdims=True); - %331 = subtract(%329, %330); - %332 = power(%331, 2f16); - %333 = mean(%332, axis=[-1], keepdims=True); - %334 = add(%333, 1e-05f16); - %335 = sqrt(%334); - %336 = divide(%331, %335); - %337 = multiply(%336, meta[relay.Constant][66]); - %338 = add(%337, meta[relay.Constant][67]); - %339 = reshape(%338, newshape=[-1, 768]); - %340 = nn.dense(%339, meta[relay.Constant][68], units=3072); - %341 = add(%340, meta[relay.Constant][69]); - %342 = reshape(%341, newshape=[50, 32, 3072]); - %343 = power(%342, 3f16); - %344 = multiply(%343, 0.044715f16); - %345 = add(%342, %344); - %346 = multiply(%345, 0.797885f16); - %347 = tanh(%346); - %348 = multiply(%342, 0.5f16); - %349 = add(%347, 1f16); - %350 = multiply(%348, %349); - %351 = reshape(%350, newshape=[-1, 3072]); - %352 = nn.dense(%351, meta[relay.Constant][70], units=768); - %353 = add(%352, meta[relay.Constant][71]); - %354 = reshape(%353, newshape=[50, 32, 768]); - %355 = add(%329, %354); - %356 = mean(%355, axis=[-1], keepdims=True); - %357 = subtract(%355, %356); - %358 = power(%357, 2f16); - %359 = mean(%358, axis=[-1], keepdims=True); - %360 = add(%359, 1e-05f16); - %361 = sqrt(%360); - %362 = divide(%357, %361); - %363 = multiply(%362, meta[relay.Constant][72]); - %364 = add(%363, meta[relay.Constant][73]); - %365 = reshape(%364, newshape=[-1, 768]); - %366 = nn.dense(%365, meta[relay.Constant][74], units=2304); - %367 = add(%366, meta[relay.Constant][75]); - %368 = reshape(%367, newshape=[50, 32, 2304]); - %369 = split(%368, indices_or_sections=[768, 1536], axis=2); - %370 = %369.0; - %371 = reshape(%370, newshape=[50, 32, 12, 64]); - %372 = transpose(%371, axes=[0, 2, 1, 3]); - %373 = %369.1; - %374 = reshape(%373, newshape=[50, 32, 12, 64]); - %375 = transpose(%374, axes=[0, 2, 3, 1]); - %376 = reshape(%375, newshape=[-1, 64, 32]); - %377 = reshape(%372, newshape=[-1, 32, 64]); - %378 = transpose(%376, axes=[0, 2, 1]); - %379 = nn.batch_matmul(%377, %378, out_dtype="float16", transpose_b=True); - %380 = reshape(%379, newshape=[50, 12, 32, 32]); - %381 = divide(%380, 8f16); - %382 = multiply(%381, meta[relay.Constant][76]); - %383 = subtract(%382, meta[relay.Constant][77]); - %384 = nn.softmax(%383, axis=3); - %385 = %369.2; - %386 = reshape(%385, newshape=[50, 32, 12, 64]); - %387 = transpose(%386, axes=[0, 2, 1, 3]); - %388 = reshape(%387, newshape=[-1, 32, 64]); - %389 = reshape(%384, newshape=[-1, 32, 32]); - %390 = transpose(%388, axes=[0, 2, 1]); - %391 = nn.batch_matmul(%389, %390, out_dtype="float16", transpose_b=True); - %392 = reshape(%391, newshape=[50, 12, 32, 64]); - %393 = transpose(%392, axes=[0, 2, 1, 3]); - %394 = reshape(%393, newshape=[50, 32, 768]); - %395 = reshape(%394, newshape=[-1, 768]); - %396 = nn.dense(%395, meta[relay.Constant][78], units=768); - %397 = add(%396, meta[relay.Constant][79]); - %398 = reshape(%397, newshape=[50, 32, 768]); - %399 = add(%355, %398); - %400 = mean(%399, axis=[-1], keepdims=True); - %401 = subtract(%399, %400); - %402 = power(%401, 2f16); - %403 = mean(%402, axis=[-1], keepdims=True); - %404 = add(%403, 1e-05f16); - %405 = sqrt(%404); - %406 = divide(%401, %405); - %407 = multiply(%406, meta[relay.Constant][80]); - %408 = add(%407, meta[relay.Constant][81]); - %409 = reshape(%408, newshape=[-1, 768]); - %410 = nn.dense(%409, meta[relay.Constant][82], units=3072); - %411 = add(%410, meta[relay.Constant][83]); - %412 = reshape(%411, newshape=[50, 32, 3072]); - %413 = power(%412, 3f16); - %414 = multiply(%413, 0.044715f16); - %415 = add(%412, %414); - %416 = multiply(%415, 0.797885f16); - %417 = tanh(%416); - %418 = multiply(%412, 0.5f16); - %419 = add(%417, 1f16); - %420 = multiply(%418, %419); - %421 = reshape(%420, newshape=[-1, 3072]); - %422 = nn.dense(%421, meta[relay.Constant][84], units=768); - %423 = add(%422, meta[relay.Constant][85]); - %424 = reshape(%423, newshape=[50, 32, 768]); - %425 = add(%399, %424); - %426 = mean(%425, axis=[-1], keepdims=True); - %427 = subtract(%425, %426); - %428 = power(%427, 2f16); - %429 = mean(%428, axis=[-1], keepdims=True); - %430 = add(%429, 1e-05f16); - %431 = sqrt(%430); - %432 = divide(%427, %431); - %433 = multiply(%432, meta[relay.Constant][86]); - %434 = add(%433, meta[relay.Constant][87]); - %435 = reshape(%434, newshape=[-1, 768]); - %436 = nn.dense(%435, meta[relay.Constant][88], units=2304); - %437 = add(%436, meta[relay.Constant][89]); - %438 = reshape(%437, newshape=[50, 32, 2304]); - %439 = split(%438, indices_or_sections=[768, 1536], axis=2); - %440 = %439.0; - %441 = reshape(%440, newshape=[50, 32, 12, 64]); - %442 = transpose(%441, axes=[0, 2, 1, 3]); - %443 = %439.1; - %444 = reshape(%443, newshape=[50, 32, 12, 64]); - %445 = transpose(%444, axes=[0, 2, 3, 1]); - %446 = reshape(%445, newshape=[-1, 64, 32]); - %447 = reshape(%442, newshape=[-1, 32, 64]); - %448 = transpose(%446, axes=[0, 2, 1]); - %449 = nn.batch_matmul(%447, %448, out_dtype="float16", transpose_b=True); - %450 = reshape(%449, newshape=[50, 12, 32, 32]); - %451 = divide(%450, 8f16); - %452 = multiply(%451, meta[relay.Constant][90]); - %453 = subtract(%452, meta[relay.Constant][91]); - %454 = nn.softmax(%453, axis=3); - %455 = %439.2; - %456 = reshape(%455, newshape=[50, 32, 12, 64]); - %457 = transpose(%456, axes=[0, 2, 1, 3]); - %458 = reshape(%457, newshape=[-1, 32, 64]); - %459 = reshape(%454, newshape=[-1, 32, 32]); - %460 = transpose(%458, axes=[0, 2, 1]); - %461 = nn.batch_matmul(%459, %460, out_dtype="float16", transpose_b=True); - %462 = reshape(%461, newshape=[50, 12, 32, 64]); - %463 = transpose(%462, axes=[0, 2, 1, 3]); - %464 = reshape(%463, newshape=[50, 32, 768]); - %465 = reshape(%464, newshape=[-1, 768]); - %466 = nn.dense(%465, meta[relay.Constant][92], units=768); - %467 = add(%466, meta[relay.Constant][93]); - %468 = reshape(%467, newshape=[50, 32, 768]); - %469 = add(%425, %468); - %470 = mean(%469, axis=[-1], keepdims=True); - %471 = subtract(%469, %470); - %472 = power(%471, 2f16); - %473 = mean(%472, axis=[-1], keepdims=True); - %474 = add(%473, 1e-05f16); - %475 = sqrt(%474); - %476 = divide(%471, %475); - %477 = multiply(%476, meta[relay.Constant][94]); - %478 = add(%477, meta[relay.Constant][95]); - %479 = reshape(%478, newshape=[-1, 768]); - %480 = nn.dense(%479, meta[relay.Constant][96], units=3072); - %481 = add(%480, meta[relay.Constant][97]); - %482 = reshape(%481, newshape=[50, 32, 3072]); - %483 = power(%482, 3f16); - %484 = multiply(%483, 0.044715f16); - %485 = add(%482, %484); - %486 = multiply(%485, 0.797885f16); - %487 = tanh(%486); - %488 = multiply(%482, 0.5f16); - %489 = add(%487, 1f16); - %490 = multiply(%488, %489); - %491 = reshape(%490, newshape=[-1, 3072]); - %492 = nn.dense(%491, meta[relay.Constant][98], units=768); - %493 = add(%492, meta[relay.Constant][99]); - %494 = reshape(%493, newshape=[50, 32, 768]); - %495 = add(%469, %494); - %496 = mean(%495, axis=[-1], keepdims=True); - %497 = subtract(%495, %496); - %498 = power(%497, 2f16); - %499 = mean(%498, axis=[-1], keepdims=True); - %500 = add(%499, 1e-05f16); - %501 = sqrt(%500); - %502 = divide(%497, %501); - %503 = multiply(%502, meta[relay.Constant][100]); - %504 = add(%503, meta[relay.Constant][101]); - %505 = reshape(%504, newshape=[-1, 768]); - %506 = nn.dense(%505, meta[relay.Constant][102], units=2304); - %507 = add(%506, meta[relay.Constant][103]); - %508 = reshape(%507, newshape=[50, 32, 2304]); - %509 = split(%508, indices_or_sections=[768, 1536], axis=2); - %510 = %509.0; - %511 = reshape(%510, newshape=[50, 32, 12, 64]); - %512 = transpose(%511, axes=[0, 2, 1, 3]); - %513 = %509.1; - %514 = reshape(%513, newshape=[50, 32, 12, 64]); - %515 = transpose(%514, axes=[0, 2, 3, 1]); - %516 = reshape(%515, newshape=[-1, 64, 32]); - %517 = reshape(%512, newshape=[-1, 32, 64]); - %518 = transpose(%516, axes=[0, 2, 1]); - %519 = nn.batch_matmul(%517, %518, out_dtype="float16", transpose_b=True); - %520 = reshape(%519, newshape=[50, 12, 32, 32]); - %521 = divide(%520, 8f16); - %522 = multiply(%521, meta[relay.Constant][104]); - %523 = subtract(%522, meta[relay.Constant][105]); - %524 = nn.softmax(%523, axis=3); - %525 = %509.2; - %526 = reshape(%525, newshape=[50, 32, 12, 64]); - %527 = transpose(%526, axes=[0, 2, 1, 3]); - %528 = reshape(%527, newshape=[-1, 32, 64]); - %529 = reshape(%524, newshape=[-1, 32, 32]); - %530 = transpose(%528, axes=[0, 2, 1]); - %531 = nn.batch_matmul(%529, %530, out_dtype="float16", transpose_b=True); - %532 = reshape(%531, newshape=[50, 12, 32, 64]); - %533 = transpose(%532, axes=[0, 2, 1, 3]); - %534 = reshape(%533, newshape=[50, 32, 768]); - %535 = reshape(%534, newshape=[-1, 768]); - %536 = nn.dense(%535, meta[relay.Constant][106], units=768); - %537 = add(%536, meta[relay.Constant][107]); - %538 = reshape(%537, newshape=[50, 32, 768]); - %539 = add(%495, %538); - %540 = mean(%539, axis=[-1], keepdims=True); - %541 = subtract(%539, %540); - %542 = power(%541, 2f16); - %543 = mean(%542, axis=[-1], keepdims=True); - %544 = add(%543, 1e-05f16); - %545 = sqrt(%544); - %546 = divide(%541, %545); - %547 = multiply(%546, meta[relay.Constant][108]); - %548 = add(%547, meta[relay.Constant][109]); - %549 = reshape(%548, newshape=[-1, 768]); - %550 = nn.dense(%549, meta[relay.Constant][110], units=3072); - %551 = add(%550, meta[relay.Constant][111]); - %552 = reshape(%551, newshape=[50, 32, 3072]); - %553 = power(%552, 3f16); - %554 = multiply(%553, 0.044715f16); - %555 = add(%552, %554); - %556 = multiply(%555, 0.797885f16); - %557 = tanh(%556); - %558 = multiply(%552, 0.5f16); - %559 = add(%557, 1f16); - %560 = multiply(%558, %559); - %561 = reshape(%560, newshape=[-1, 3072]); - %562 = nn.dense(%561, meta[relay.Constant][112], units=768); - %563 = add(%562, meta[relay.Constant][113]); - %564 = reshape(%563, newshape=[50, 32, 768]); - %565 = add(%539, %564); - %566 = mean(%565, axis=[-1], keepdims=True); - %567 = subtract(%565, %566); - %568 = power(%567, 2f16); - %569 = mean(%568, axis=[-1], keepdims=True); - %570 = add(%569, 1e-05f16); - %571 = sqrt(%570); - %572 = divide(%567, %571); - %573 = multiply(%572, meta[relay.Constant][114]); - %574 = add(%573, meta[relay.Constant][115]); - %575 = reshape(%574, newshape=[-1, 768]); - %576 = nn.dense(%575, meta[relay.Constant][116], units=2304); - %577 = add(%576, meta[relay.Constant][117]); - %578 = reshape(%577, newshape=[50, 32, 2304]); - %579 = split(%578, indices_or_sections=[768, 1536], axis=2); - %580 = %579.0; - %581 = reshape(%580, newshape=[50, 32, 12, 64]); - %582 = transpose(%581, axes=[0, 2, 1, 3]); - %583 = %579.1; - %584 = reshape(%583, newshape=[50, 32, 12, 64]); - %585 = transpose(%584, axes=[0, 2, 3, 1]); - %586 = reshape(%585, newshape=[-1, 64, 32]); - %587 = reshape(%582, newshape=[-1, 32, 64]); - %588 = transpose(%586, axes=[0, 2, 1]); - %589 = nn.batch_matmul(%587, %588, out_dtype="float16", transpose_b=True); - %590 = reshape(%589, newshape=[50, 12, 32, 32]); - %591 = divide(%590, 8f16); - %592 = multiply(%591, meta[relay.Constant][118]); - %593 = subtract(%592, meta[relay.Constant][119]); - %594 = nn.softmax(%593, axis=3); - %595 = %579.2; - %596 = reshape(%595, newshape=[50, 32, 12, 64]); - %597 = transpose(%596, axes=[0, 2, 1, 3]); - %598 = reshape(%597, newshape=[-1, 32, 64]); - %599 = reshape(%594, newshape=[-1, 32, 32]); - %600 = transpose(%598, axes=[0, 2, 1]); - %601 = nn.batch_matmul(%599, %600, out_dtype="float16", transpose_b=True); - %602 = reshape(%601, newshape=[50, 12, 32, 64]); - %603 = transpose(%602, axes=[0, 2, 1, 3]); - %604 = reshape(%603, newshape=[50, 32, 768]); - %605 = reshape(%604, newshape=[-1, 768]); - %606 = nn.dense(%605, meta[relay.Constant][120], units=768); - %607 = add(%606, meta[relay.Constant][121]); - %608 = reshape(%607, newshape=[50, 32, 768]); - %609 = add(%565, %608); - %610 = mean(%609, axis=[-1], keepdims=True); - %611 = subtract(%609, %610); - %612 = power(%611, 2f16); - %613 = mean(%612, axis=[-1], keepdims=True); - %614 = add(%613, 1e-05f16); - %615 = sqrt(%614); - %616 = divide(%611, %615); - %617 = multiply(%616, meta[relay.Constant][122]); - %618 = add(%617, meta[relay.Constant][123]); - %619 = reshape(%618, newshape=[-1, 768]); - %620 = nn.dense(%619, meta[relay.Constant][124], units=3072); - %621 = add(%620, meta[relay.Constant][125]); - %622 = reshape(%621, newshape=[50, 32, 3072]); - %623 = power(%622, 3f16); - %624 = multiply(%623, 0.044715f16); - %625 = add(%622, %624); - %626 = multiply(%625, 0.797885f16); - %627 = tanh(%626); - %628 = multiply(%622, 0.5f16); - %629 = add(%627, 1f16); - %630 = multiply(%628, %629); - %631 = reshape(%630, newshape=[-1, 3072]); - %632 = nn.dense(%631, meta[relay.Constant][126], units=768); - %633 = add(%632, meta[relay.Constant][127]); - %634 = reshape(%633, newshape=[50, 32, 768]); - %635 = add(%609, %634); - %636 = mean(%635, axis=[-1], keepdims=True); - %637 = subtract(%635, %636); - %638 = power(%637, 2f16); - %639 = mean(%638, axis=[-1], keepdims=True); - %640 = add(%639, 1e-05f16); - %641 = sqrt(%640); - %642 = divide(%637, %641); - %643 = multiply(%642, meta[relay.Constant][128]); - %644 = add(%643, meta[relay.Constant][129]); - %645 = reshape(%644, newshape=[-1, 768]); - %646 = nn.dense(%645, meta[relay.Constant][130], units=2304); - %647 = add(%646, meta[relay.Constant][131]); - %648 = reshape(%647, newshape=[50, 32, 2304]); - %649 = split(%648, indices_or_sections=[768, 1536], axis=2); - %650 = %649.0; - %651 = reshape(%650, newshape=[50, 32, 12, 64]); - %652 = transpose(%651, axes=[0, 2, 1, 3]); - %653 = %649.1; - %654 = reshape(%653, newshape=[50, 32, 12, 64]); - %655 = transpose(%654, axes=[0, 2, 3, 1]); - %656 = reshape(%655, newshape=[-1, 64, 32]); - %657 = reshape(%652, newshape=[-1, 32, 64]); - %658 = transpose(%656, axes=[0, 2, 1]); - %659 = nn.batch_matmul(%657, %658, out_dtype="float16", transpose_b=True); - %660 = reshape(%659, newshape=[50, 12, 32, 32]); - %661 = divide(%660, 8f16); - %662 = multiply(%661, meta[relay.Constant][132]); - %663 = subtract(%662, meta[relay.Constant][133]); - %664 = nn.softmax(%663, axis=3); - %665 = %649.2; - %666 = reshape(%665, newshape=[50, 32, 12, 64]); - %667 = transpose(%666, axes=[0, 2, 1, 3]); - %668 = reshape(%667, newshape=[-1, 32, 64]); - %669 = reshape(%664, newshape=[-1, 32, 32]); - %670 = transpose(%668, axes=[0, 2, 1]); - %671 = nn.batch_matmul(%669, %670, out_dtype="float16", transpose_b=True); - %672 = reshape(%671, newshape=[50, 12, 32, 64]); - %673 = transpose(%672, axes=[0, 2, 1, 3]); - %674 = reshape(%673, newshape=[50, 32, 768]); - %675 = reshape(%674, newshape=[-1, 768]); - %676 = nn.dense(%675, meta[relay.Constant][134], units=768); - %677 = add(%676, meta[relay.Constant][135]); - %678 = reshape(%677, newshape=[50, 32, 768]); - %679 = add(%635, %678); - %680 = mean(%679, axis=[-1], keepdims=True); - %681 = subtract(%679, %680); - %682 = power(%681, 2f16); - %683 = mean(%682, axis=[-1], keepdims=True); - %684 = add(%683, 1e-05f16); - %685 = sqrt(%684); - %686 = divide(%681, %685); - %687 = multiply(%686, meta[relay.Constant][136]); - %688 = add(%687, meta[relay.Constant][137]); - %689 = reshape(%688, newshape=[-1, 768]); - %690 = nn.dense(%689, meta[relay.Constant][138], units=3072); - %691 = add(%690, meta[relay.Constant][139]); - %692 = reshape(%691, newshape=[50, 32, 3072]); - %693 = power(%692, 3f16); - %694 = multiply(%693, 0.044715f16); - %695 = add(%692, %694); - %696 = multiply(%695, 0.797885f16); - %697 = tanh(%696); - %698 = multiply(%692, 0.5f16); - %699 = add(%697, 1f16); - %700 = multiply(%698, %699); - %701 = reshape(%700, newshape=[-1, 3072]); - %702 = nn.dense(%701, meta[relay.Constant][140], units=768); - %703 = add(%702, meta[relay.Constant][141]); - %704 = reshape(%703, newshape=[50, 32, 768]); - %705 = add(%679, %704); - %706 = mean(%705, axis=[-1], keepdims=True); - %707 = subtract(%705, %706); - %708 = power(%707, 2f16); - %709 = mean(%708, axis=[-1], keepdims=True); - %710 = add(%709, 1e-05f16); - %711 = sqrt(%710); - %712 = divide(%707, %711); - %713 = multiply(%712, meta[relay.Constant][142]); - %714 = add(%713, meta[relay.Constant][143]); - %715 = reshape(%714, newshape=[-1, 768]); - %716 = nn.dense(%715, meta[relay.Constant][144], units=2304); - %717 = add(%716, meta[relay.Constant][145]); - %718 = reshape(%717, newshape=[50, 32, 2304]); - %719 = split(%718, indices_or_sections=[768, 1536], axis=2); - %720 = %719.0; - %721 = reshape(%720, newshape=[50, 32, 12, 64]); - %722 = transpose(%721, axes=[0, 2, 1, 3]); - %723 = %719.1; - %724 = reshape(%723, newshape=[50, 32, 12, 64]); - %725 = transpose(%724, axes=[0, 2, 3, 1]); - %726 = reshape(%725, newshape=[-1, 64, 32]); - %727 = reshape(%722, newshape=[-1, 32, 64]); - %728 = transpose(%726, axes=[0, 2, 1]); - %729 = nn.batch_matmul(%727, %728, out_dtype="float16", transpose_b=True); - %730 = reshape(%729, newshape=[50, 12, 32, 32]); - %731 = divide(%730, 8f16); - %732 = multiply(%731, meta[relay.Constant][146]); - %733 = subtract(%732, meta[relay.Constant][147]); - %734 = nn.softmax(%733, axis=3); - %735 = %719.2; - %736 = reshape(%735, newshape=[50, 32, 12, 64]); - %737 = transpose(%736, axes=[0, 2, 1, 3]); - %738 = reshape(%737, newshape=[-1, 32, 64]); - %739 = reshape(%734, newshape=[-1, 32, 32]); - %740 = transpose(%738, axes=[0, 2, 1]); - %741 = nn.batch_matmul(%739, %740, out_dtype="float16", transpose_b=True); - %742 = reshape(%741, newshape=[50, 12, 32, 64]); - %743 = transpose(%742, axes=[0, 2, 1, 3]); - %744 = reshape(%743, newshape=[50, 32, 768]); - %745 = reshape(%744, newshape=[-1, 768]); - %746 = nn.dense(%745, meta[relay.Constant][148], units=768); - %747 = add(%746, meta[relay.Constant][149]); - %748 = reshape(%747, newshape=[50, 32, 768]); - %749 = add(%705, %748); - %750 = mean(%749, axis=[-1], keepdims=True); - %751 = subtract(%749, %750); - %752 = power(%751, 2f16); - %753 = mean(%752, axis=[-1], keepdims=True); - %754 = add(%753, 1e-05f16); - %755 = sqrt(%754); - %756 = divide(%751, %755); - %757 = multiply(%756, meta[relay.Constant][150]); - %758 = add(%757, meta[relay.Constant][151]); - %759 = reshape(%758, newshape=[-1, 768]); - %760 = nn.dense(%759, meta[relay.Constant][152], units=3072); - %761 = add(%760, meta[relay.Constant][153]); - %762 = reshape(%761, newshape=[50, 32, 3072]); - %763 = power(%762, 3f16); - %764 = multiply(%763, 0.044715f16); - %765 = add(%762, %764); - %766 = multiply(%765, 0.797885f16); - %767 = tanh(%766); - %768 = multiply(%762, 0.5f16); - %769 = add(%767, 1f16); - %770 = multiply(%768, %769); - %771 = reshape(%770, newshape=[-1, 3072]); - %772 = nn.dense(%771, meta[relay.Constant][154], units=768); - %773 = add(%772, meta[relay.Constant][155]); - %774 = reshape(%773, newshape=[50, 32, 768]); - %775 = add(%749, %774); - %776 = mean(%775, axis=[-1], keepdims=True); - %777 = subtract(%775, %776); - %778 = power(%777, 2f16); - %779 = mean(%778, axis=[-1], keepdims=True); - %780 = add(%779, 1e-05f16); - %781 = sqrt(%780); - %782 = divide(%777, %781); - %783 = multiply(%782, meta[relay.Constant][156]); - %784 = add(%783, meta[relay.Constant][157]); - %785 = reshape(%784, newshape=[-1, 768]); - %786 = nn.dense(%785, meta[relay.Constant][158], units=2304); - %787 = add(%786, meta[relay.Constant][159]); - %788 = reshape(%787, newshape=[50, 32, 2304]); - %789 = split(%788, indices_or_sections=[768, 1536], axis=2); - %790 = %789.0; - %791 = reshape(%790, newshape=[50, 32, 12, 64]); - %792 = transpose(%791, axes=[0, 2, 1, 3]); - %793 = %789.1; - %794 = reshape(%793, newshape=[50, 32, 12, 64]); - %795 = transpose(%794, axes=[0, 2, 3, 1]); - %796 = reshape(%795, newshape=[-1, 64, 32]); - %797 = reshape(%792, newshape=[-1, 32, 64]); - %798 = transpose(%796, axes=[0, 2, 1]); - %799 = nn.batch_matmul(%797, %798, out_dtype="float16", transpose_b=True); - %800 = reshape(%799, newshape=[50, 12, 32, 32]); - %801 = divide(%800, 8f16); - %802 = multiply(%801, meta[relay.Constant][160]); - %803 = subtract(%802, meta[relay.Constant][161]); - %804 = nn.softmax(%803, axis=3); - %805 = %789.2; - %806 = reshape(%805, newshape=[50, 32, 12, 64]); - %807 = transpose(%806, axes=[0, 2, 1, 3]); - %808 = reshape(%807, newshape=[-1, 32, 64]); - %809 = reshape(%804, newshape=[-1, 32, 32]); - %810 = transpose(%808, axes=[0, 2, 1]); - %811 = nn.batch_matmul(%809, %810, out_dtype="float16", transpose_b=True); - %812 = reshape(%811, newshape=[50, 12, 32, 64]); - %813 = transpose(%812, axes=[0, 2, 1, 3]); - %814 = reshape(%813, newshape=[50, 32, 768]); - %815 = reshape(%814, newshape=[-1, 768]); - %816 = nn.dense(%815, meta[relay.Constant][162], units=768); - %817 = add(%816, meta[relay.Constant][163]); - %818 = reshape(%817, newshape=[50, 32, 768]); - %819 = add(%775, %818); - %820 = mean(%819, axis=[-1], keepdims=True); - %821 = subtract(%819, %820); - %822 = power(%821, 2f16); - %823 = mean(%822, axis=[-1], keepdims=True); - %824 = add(%823, 1e-05f16); - %825 = sqrt(%824); - %826 = divide(%821, %825); - %827 = multiply(%826, meta[relay.Constant][164]); - %828 = add(%827, meta[relay.Constant][165]); - %829 = reshape(%828, newshape=[-1, 768]); - %830 = nn.dense(%829, meta[relay.Constant][166], units=3072); - %831 = add(%830, meta[relay.Constant][167]); - %832 = reshape(%831, newshape=[50, 32, 3072]); - %833 = power(%832, 3f16); - %834 = multiply(%833, 0.044715f16); - %835 = add(%832, %834); - %836 = multiply(%835, 0.797885f16); - %837 = tanh(%836); - %838 = multiply(%832, 0.5f16); - %839 = add(%837, 1f16); - %840 = multiply(%838, %839); - %841 = reshape(%840, newshape=[-1, 3072]); - %842 = nn.dense(%841, meta[relay.Constant][168], units=768); - %843 = add(%842, meta[relay.Constant][169]); - %844 = reshape(%843, newshape=[50, 32, 768]); - %845 = add(%819, %844); - %846 = mean(%845, axis=[-1], keepdims=True); - %847 = subtract(%845, %846); - %848 = power(%847, 2f16); - %849 = mean(%848, axis=[-1], keepdims=True); - %850 = add(%849, 1e-05f16); - %851 = sqrt(%850); - %852 = divide(%847, %851); - %853 = multiply(%852, meta[relay.Constant][170]); - %854 = add(%853, meta[relay.Constant][171]); - %855 = transpose(%24, axes=[0, 2, 1, 3]); - %856 = expand_dims(%855, axis=0); - %857 = expand_dims(%37, axis=0); - %858 = (%856, %857); - %859 = transpose(%94, axes=[0, 2, 1, 3]); - %860 = expand_dims(%859, axis=0); - %861 = expand_dims(%107, axis=0); - %862 = (%860, %861); - %863 = transpose(%164, axes=[0, 2, 1, 3]); - %864 = expand_dims(%863, axis=0); - %865 = expand_dims(%177, axis=0); - %866 = (%864, %865); - %867 = transpose(%234, axes=[0, 2, 1, 3]); - %868 = expand_dims(%867, axis=0); - %869 = expand_dims(%247, axis=0); - %870 = (%868, %869); - %871 = transpose(%304, axes=[0, 2, 1, 3]); - %872 = expand_dims(%871, axis=0); - %873 = expand_dims(%317, axis=0); - %874 = (%872, %873); - %875 = transpose(%374, axes=[0, 2, 1, 3]); - %876 = expand_dims(%875, axis=0); - %877 = expand_dims(%387, axis=0); - %878 = (%876, %877); - %879 = transpose(%444, axes=[0, 2, 1, 3]); - %880 = expand_dims(%879, axis=0); - %881 = expand_dims(%457, axis=0); - %882 = (%880, %881); - %883 = transpose(%514, axes=[0, 2, 1, 3]); - %884 = expand_dims(%883, axis=0); - %885 = expand_dims(%527, axis=0); - %886 = (%884, %885); - %887 = transpose(%584, axes=[0, 2, 1, 3]); - %888 = expand_dims(%887, axis=0); - %889 = expand_dims(%597, axis=0); - %890 = (%888, %889); - %891 = transpose(%654, axes=[0, 2, 1, 3]); - %892 = expand_dims(%891, axis=0); - %893 = expand_dims(%667, axis=0); - %894 = (%892, %893); - %895 = transpose(%724, axes=[0, 2, 1, 3]); - %896 = expand_dims(%895, axis=0); - %897 = expand_dims(%737, axis=0); - %898 = (%896, %897); - %899 = transpose(%794, axes=[0, 2, 1, 3]); - %900 = expand_dims(%899, axis=0); - %901 = expand_dims(%807, axis=0); - %902 = (%900, %901); - %903 = reshape(%854, newshape=[1, 50, 32, 768]); - %904 = concatenate(%858); - %905 = concatenate(%862); - %906 = concatenate(%866); - %907 = concatenate(%870); - %908 = concatenate(%874); - %909 = concatenate(%878); - %910 = concatenate(%882); - %911 = concatenate(%886); - %912 = concatenate(%890); - %913 = concatenate(%894); - %914 = concatenate(%898); - %915 = concatenate(%902); - (%903, %904, %905, %906, %907, %908, %909, %910, %911, %912, %913, %914, %915) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "gpt2_16", - "input_shapes": {"x": [1, 50, 32]}, - "input_dtypes": {"x": "int64"}, - "mod": mod, - "params": None, - "main_dtype": "float16", - } - - -def gpt2_extract_consts(dtype): - return make_consts( - dtype, - [ - (768, 768), # 0 - (768,), # 1 - (768,), # 2 - (768,), # 3 - (3072, 768), # 4 - (3072,), # 5 - (1, 32, 768), # 6 - ], - ) - - -def gpt2_extract(): - metatable = {"relay.Constant": gpt2_extract_consts("float32")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1600, 768), float32]) -> Tensor[(50, 32, 3072), float32] { - %46 = nn.dense(%x, meta[relay.Constant][0], units=768); - %47 = add(%46, meta[relay.Constant][1]); - %48 = reshape(%47, newshape=[50, 32, 768]); - %49 = add(meta[relay.Constant][6], %48); - %50 = mean(%49, axis=[-1], keepdims=True); - %51 = subtract(%49, %50); - %52 = power(%51, 2f); - %53 = mean(%52, axis=[-1], keepdims=True); - %54 = add(%53, 1e-05f); - %55 = sqrt(%54); - %56 = divide(%51, %55); - %57 = multiply(%56, meta[relay.Constant][2]); - %58 = add(%57, meta[relay.Constant][3]); - %59 = reshape(%58, newshape=[-1, 768]); - %60 = nn.dense(%59, meta[relay.Constant][4], units=3072); - %61 = add(%60, meta[relay.Constant][5]); - %62 = reshape(%61, newshape=[50, 32, 3072]); - %63 = power(%62, 3f); - %64 = multiply(%63, 0.044715f); - %65 = add(%62, %64); - %66 = multiply(%65, 0.797885f); - %67 = tanh(%66); - %68 = multiply(%62, 0.5f); - %69 = add(%67, 1f); - %70 = multiply(%68, %69); - %70 - } - """, - "from_string", - None, - metatable, - ) - return { - "input_shapes": {"x": [1600, 768]}, - "input_dtypes": {"x": "float32"}, - "mod": mod, - "params": None, - "main_dtype": "float32", - } - - -def gpt2_extract_16(): - metatable = {"relay.Constant": gpt2_extract_consts("float16")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1600, 768), float16]) -> Tensor[(50, 32, 3072), float16] { - %46 = nn.dense(%x, meta[relay.Constant][0], units=768); - %47 = add(%46, meta[relay.Constant][1]); - %48 = reshape(%47, newshape=[50, 32, 768]); - %49 = add(meta[relay.Constant][6], %48); - %50 = mean(%49, axis=[-1], keepdims=True); - %51 = subtract(%49, %50); - %52 = power(%51, 2f16); - %53 = mean(%52, axis=[-1], keepdims=True); - %54 = add(%53, 1e-05f16); - %55 = sqrt(%54); - %56 = divide(%51, %55); - %57 = multiply(%56, meta[relay.Constant][2]); - %58 = add(%57, meta[relay.Constant][3]); - %59 = reshape(%58, newshape=[-1, 768]); - %60 = nn.dense(%59, meta[relay.Constant][4], units=3072); - %61 = add(%60, meta[relay.Constant][5]); - %62 = reshape(%61, newshape=[50, 32, 3072]); - %63 = power(%62, 3f16); - %64 = multiply(%63, 0.044715f16); - %65 = add(%62, %64); - %66 = multiply(%65, 0.797885f16); - %67 = tanh(%66); - %68 = multiply(%62, 0.5f16); - %69 = add(%67, 1f16); - %70 = multiply(%68, %69); - %70 - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "gpt2_extract_16", - "input_shapes": {"x": [1600, 768]}, - "input_dtypes": {"x": "float16"}, - "mod": mod, - "params": None, - "main_dtype": "float16", - } - - -def gpt2_16_for_cutlass_extract_consts(dtype): - return make_consts( - "float16", - [ - (2304, 768), # 0 - (2304,), # 1 - (600, 32, 64), # 2 - (600, 32, 32), # 3 - ], - ) - - -def gpt2_16_for_cutlass_extract(): - metatable = {"relay.Constant": gpt2_16_for_cutlass_extract_consts("float16")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x0: Tensor[(1600, 768), float16], - %x3: Tensor[(600, 32, 64), float16]) - -> (Tensor[(1600, 2304), float16], Tensor[(1200, 32, 32), float16]) { - %0 = nn.dense(%x0, meta[relay.Constant][0], units=2304); - %1 = add(%0, meta[relay.Constant][1]); - %2 = nn.batch_matmul(%x3, meta[relay.Constant][2], out_dtype="float16", transpose_b=True); - %3 = (%2, meta[relay.Constant][3]); - %4 = concatenate(%3); - (%1, %4) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "gpt2_16_for_cutlass_extract", - "input_shapes": {"x0": (1600, 768), "x3": (600, 32, 64)}, - "input_dtypes": {"x0": "float16", "x3": "float16"}, - "mod": mod, - "params": None, - "main_dtype": "float16", - } - - def resnet50_consts(dtype): return make_consts( dtype, @@ -3808,372 +1525,3 @@ def @main(%data: Tensor[(1, 3, 224, 224), float16]) -> Tensor[(1, 1000), float16 "params": None, "main_dtype": "float16", } - - -def batch_norm_extract(): - consts = make_consts( - "float32", - [ - (32,), # 0 - (32,), # 1 - (32,), # 2 - (32,), # 3 - ], - ) - metatable = {"relay.Constant": consts} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%FunctionVar_0: Tensor[(1, 32, 112, 112), float32]) -> Tensor[(1, 32, 112, 112), float32] { - %3 = nn.batch_norm(%FunctionVar_0, meta[relay.Constant][0], meta[relay.Constant][1], meta[relay.Constant][2], meta[relay.Constant][3]); - %3.0 - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "batch_norm_extract", - "input_shapes": {"FunctionVar_0": [1, 32, 112, 112]}, - "input_dtypes": {"FunctionVar_0": "float32"}, - "mod": mod, - "params": None, - "main_dtype": "float32", - } - - -def resnext50_32x4d_consts(dtype): - return make_consts( - dtype, - [ - (128, 64, 1, 1), # 0 - (128, 4, 3, 3), # 1 - (256, 128, 1, 1), # 2 - (256, 64, 1, 1), # 3 - (128, 256, 1, 1), # 4 - (128, 4, 3, 3), # 5 - (256, 128, 1, 1), # 6 - (128, 256, 1, 1), # 7 - (128, 4, 3, 3), # 8 - (256, 128, 1, 1), # 9 - (256, 256, 1, 1), # 10 - (256, 8, 3, 3), # 11 - (512, 256, 1, 1), # 12 - (512, 256, 1, 1), # 13 - (256, 512, 1, 1), # 14 - (256, 8, 3, 3), # 15 - (512, 256, 1, 1), # 16 - (256, 512, 1, 1), # 17 - (256, 8, 3, 3), # 18 - (512, 256, 1, 1), # 19 - (256, 512, 1, 1), # 20 - (256, 8, 3, 3), # 21 - (512, 256, 1, 1), # 22 - (512, 512, 1, 1), # 23 - (512, 16, 3, 3), # 24 - (1024, 512, 1, 1), # 25 - (1024, 512, 1, 1), # 26 - (512, 1024, 1, 1), # 27 - (512, 16, 3, 3), # 28 - (1024, 512, 1, 1), # 29 - (512, 1024, 1, 1), # 30 - (512, 16, 3, 3), # 31 - (1024, 512, 1, 1), # 32 - (512, 1024, 1, 1), # 33 - (512, 16, 3, 3), # 34 - (1024, 512, 1, 1), # 35 - (512, 1024, 1, 1), # 36 - (512, 16, 3, 3), # 37 - (1024, 512, 1, 1), # 38 - (512, 1024, 1, 1), # 39 - (512, 16, 3, 3), # 40 - (1024, 512, 1, 1), # 41 - (1024, 1024, 1, 1), # 42 - (1024, 32, 3, 3), # 43 - (2048, 1024, 1, 1), # 44 - (2048, 1024, 1, 1), # 45 - (1024, 2048, 1, 1), # 46 - (1024, 32, 3, 3), # 47 - (2048, 1024, 1, 1), # 48 - (1024, 2048, 1, 1), # 49 - (1024, 32, 3, 3), # 50 - (2048, 1024, 1, 1), # 51 - ], - ) - - -def resnext50_32x4d(): - metatable = {"relay.Constant": resnext50_32x4d_consts("float32")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1, 64, 56, 56), float32]) { - %0 = nn.conv2d(%x, meta[relay.Constant][0], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %1 = nn.relu(%0); - %2 = nn.conv2d(%1, meta[relay.Constant][1], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %3 = nn.relu(%2); - %4 = nn.conv2d(%3, meta[relay.Constant][2], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %5 = nn.conv2d(%x, meta[relay.Constant][3], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %6 = add(%4, %5); - %7 = nn.relu(%6); - %8 = nn.conv2d(%7, meta[relay.Constant][4], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %9 = nn.relu(%8); - %10 = nn.conv2d(%9, meta[relay.Constant][5], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %11 = nn.relu(%10); - %12 = nn.conv2d(%11, meta[relay.Constant][6], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %13 = add(%12, %7); - %14 = nn.relu(%13); - %15 = nn.conv2d(%14, meta[relay.Constant][7], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %16 = nn.relu(%15); - %17 = nn.conv2d(%16, meta[relay.Constant][8], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %18 = nn.relu(%17); - %19 = nn.conv2d(%18, meta[relay.Constant][9], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %20 = add(%19, %14); - %21 = nn.relu(%20); - %22 = nn.conv2d(%21, meta[relay.Constant][10], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %23 = nn.relu(%22); - %24 = nn.conv2d(%23, meta[relay.Constant][11], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %25 = nn.relu(%24); - %26 = nn.conv2d(%25, meta[relay.Constant][12], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %27 = nn.conv2d(%21, meta[relay.Constant][13], strides=[2, 2], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %28 = add(%26, %27); - %29 = nn.relu(%28); - %30 = nn.conv2d(%29, meta[relay.Constant][14], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %31 = nn.relu(%30); - %32 = nn.conv2d(%31, meta[relay.Constant][15], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %33 = nn.relu(%32); - %34 = nn.conv2d(%33, meta[relay.Constant][16], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %35 = add(%34, %29); - %36 = nn.relu(%35); - %37 = nn.conv2d(%36, meta[relay.Constant][17], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %38 = nn.relu(%37); - %39 = nn.conv2d(%38, meta[relay.Constant][18], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %40 = nn.relu(%39); - %41 = nn.conv2d(%40, meta[relay.Constant][19], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %42 = add(%41, %36); - %43 = nn.relu(%42); - %44 = nn.conv2d(%43, meta[relay.Constant][20], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %45 = nn.relu(%44); - %46 = nn.conv2d(%45, meta[relay.Constant][21], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %47 = nn.relu(%46); - %48 = nn.conv2d(%47, meta[relay.Constant][22], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %49 = add(%48, %43); - %50 = nn.relu(%49); - %51 = nn.conv2d(%50, meta[relay.Constant][23], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %52 = nn.relu(%51); - %53 = nn.conv2d(%52, meta[relay.Constant][24], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %54 = nn.relu(%53); - %55 = nn.conv2d(%54, meta[relay.Constant][25], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %56 = nn.conv2d(%50, meta[relay.Constant][26], strides=[2, 2], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %57 = add(%55, %56); - %58 = nn.relu(%57); - %59 = nn.conv2d(%58, meta[relay.Constant][27], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %60 = nn.relu(%59); - %61 = nn.conv2d(%60, meta[relay.Constant][28], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %62 = nn.relu(%61); - %63 = nn.conv2d(%62, meta[relay.Constant][29], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %64 = add(%63, %58); - %65 = nn.relu(%64); - %66 = nn.conv2d(%65, meta[relay.Constant][30], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %67 = nn.relu(%66); - %68 = nn.conv2d(%67, meta[relay.Constant][31], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %69 = nn.relu(%68); - %70 = nn.conv2d(%69, meta[relay.Constant][32], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %71 = add(%70, %65); - %72 = nn.relu(%71); - %73 = nn.conv2d(%72, meta[relay.Constant][33], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %74 = nn.relu(%73); - %75 = nn.conv2d(%74, meta[relay.Constant][34], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %76 = nn.relu(%75); - %77 = nn.conv2d(%76, meta[relay.Constant][35], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %78 = add(%77, %72); - %79 = nn.relu(%78); - %80 = nn.conv2d(%79, meta[relay.Constant][36], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %81 = nn.relu(%80); - %82 = nn.conv2d(%81, meta[relay.Constant][37], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %83 = nn.relu(%82); - %84 = nn.conv2d(%83, meta[relay.Constant][38], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %85 = add(%84, %79); - %86 = nn.relu(%85); - %87 = nn.conv2d(%86, meta[relay.Constant][39], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %88 = nn.relu(%87); - %89 = nn.conv2d(%88, meta[relay.Constant][40], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %90 = nn.relu(%89); - %91 = nn.conv2d(%90, meta[relay.Constant][41], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %92 = add(%91, %86); - %93 = nn.relu(%92); - %94 = nn.conv2d(%93, meta[relay.Constant][42], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %95 = nn.relu(%94); - %96 = nn.conv2d(%95, meta[relay.Constant][43], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %97 = nn.relu(%96); - %98 = nn.conv2d(%97, meta[relay.Constant][44], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %99 = nn.conv2d(%93, meta[relay.Constant][45], strides=[2, 2], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %100 = add(%98, %99); - %101 = nn.relu(%100); - %102 = nn.conv2d(%101, meta[relay.Constant][46], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %103 = nn.relu(%102); - %104 = nn.conv2d(%103, meta[relay.Constant][47], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %105 = nn.relu(%104); - %106 = nn.conv2d(%105, meta[relay.Constant][48], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %107 = add(%106, %101); - %108 = nn.relu(%107); - %109 = nn.conv2d(%108, meta[relay.Constant][49], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %110 = nn.relu(%109); - %111 = nn.conv2d(%110, meta[relay.Constant][50], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %112 = nn.relu(%111); - %113 = nn.conv2d(%112, meta[relay.Constant][51], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %114 = add(%113, %108); - nn.relu(%114) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "resnext50_32x4d", - "input_shapes": {"x": [1, 64, 56, 56]}, - "input_dtypes": {"x": "float32"}, - "mod": mod, - "params": None, - "main_dtype": "float32", - } - - -def resnext50_32x4d_16(): - metatable = {"relay.Constant": resnext50_32x4d_consts("float16")} - mod = tvm.parser.parse( - """ - #[version = "0.0.5"] - def @main(%x: Tensor[(1, 64, 56, 56), float16]) { - %0 = nn.conv2d(%x, meta[relay.Constant][0], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %1 = nn.relu(%0); - %2 = nn.conv2d(%1, meta[relay.Constant][1], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %3 = nn.relu(%2); - %4 = nn.conv2d(%3, meta[relay.Constant][2], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %5 = nn.conv2d(%x, meta[relay.Constant][3], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %6 = add(%4, %5); - %7 = nn.relu(%6); - %8 = nn.conv2d(%7, meta[relay.Constant][4], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %9 = nn.relu(%8); - %10 = nn.conv2d(%9, meta[relay.Constant][5], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %11 = nn.relu(%10); - %12 = nn.conv2d(%11, meta[relay.Constant][6], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %13 = add(%12, %7); - %14 = nn.relu(%13); - %15 = nn.conv2d(%14, meta[relay.Constant][7], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]); - %16 = nn.relu(%15); - %17 = nn.conv2d(%16, meta[relay.Constant][8], padding=[1, 1, 1, 1], groups=32, channels=128, kernel_size=[3, 3]); - %18 = nn.relu(%17); - %19 = nn.conv2d(%18, meta[relay.Constant][9], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %20 = add(%19, %14); - %21 = nn.relu(%20); - %22 = nn.conv2d(%21, meta[relay.Constant][10], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %23 = nn.relu(%22); - %24 = nn.conv2d(%23, meta[relay.Constant][11], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %25 = nn.relu(%24); - %26 = nn.conv2d(%25, meta[relay.Constant][12], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %27 = nn.conv2d(%21, meta[relay.Constant][13], strides=[2, 2], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %28 = add(%26, %27); - %29 = nn.relu(%28); - %30 = nn.conv2d(%29, meta[relay.Constant][14], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %31 = nn.relu(%30); - %32 = nn.conv2d(%31, meta[relay.Constant][15], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %33 = nn.relu(%32); - %34 = nn.conv2d(%33, meta[relay.Constant][16], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %35 = add(%34, %29); - %36 = nn.relu(%35); - %37 = nn.conv2d(%36, meta[relay.Constant][17], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %38 = nn.relu(%37); - %39 = nn.conv2d(%38, meta[relay.Constant][18], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %40 = nn.relu(%39); - %41 = nn.conv2d(%40, meta[relay.Constant][19], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %42 = add(%41, %36); - %43 = nn.relu(%42); - %44 = nn.conv2d(%43, meta[relay.Constant][20], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); - %45 = nn.relu(%44); - %46 = nn.conv2d(%45, meta[relay.Constant][21], padding=[1, 1, 1, 1], groups=32, channels=256, kernel_size=[3, 3]); - %47 = nn.relu(%46); - %48 = nn.conv2d(%47, meta[relay.Constant][22], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %49 = add(%48, %43); - %50 = nn.relu(%49); - %51 = nn.conv2d(%50, meta[relay.Constant][23], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %52 = nn.relu(%51); - %53 = nn.conv2d(%52, meta[relay.Constant][24], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %54 = nn.relu(%53); - %55 = nn.conv2d(%54, meta[relay.Constant][25], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %56 = nn.conv2d(%50, meta[relay.Constant][26], strides=[2, 2], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %57 = add(%55, %56); - %58 = nn.relu(%57); - %59 = nn.conv2d(%58, meta[relay.Constant][27], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %60 = nn.relu(%59); - %61 = nn.conv2d(%60, meta[relay.Constant][28], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %62 = nn.relu(%61); - %63 = nn.conv2d(%62, meta[relay.Constant][29], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %64 = add(%63, %58); - %65 = nn.relu(%64); - %66 = nn.conv2d(%65, meta[relay.Constant][30], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %67 = nn.relu(%66); - %68 = nn.conv2d(%67, meta[relay.Constant][31], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %69 = nn.relu(%68); - %70 = nn.conv2d(%69, meta[relay.Constant][32], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %71 = add(%70, %65); - %72 = nn.relu(%71); - %73 = nn.conv2d(%72, meta[relay.Constant][33], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %74 = nn.relu(%73); - %75 = nn.conv2d(%74, meta[relay.Constant][34], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %76 = nn.relu(%75); - %77 = nn.conv2d(%76, meta[relay.Constant][35], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %78 = add(%77, %72); - %79 = nn.relu(%78); - %80 = nn.conv2d(%79, meta[relay.Constant][36], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %81 = nn.relu(%80); - %82 = nn.conv2d(%81, meta[relay.Constant][37], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %83 = nn.relu(%82); - %84 = nn.conv2d(%83, meta[relay.Constant][38], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %85 = add(%84, %79); - %86 = nn.relu(%85); - %87 = nn.conv2d(%86, meta[relay.Constant][39], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1]); - %88 = nn.relu(%87); - %89 = nn.conv2d(%88, meta[relay.Constant][40], padding=[1, 1, 1, 1], groups=32, channels=512, kernel_size=[3, 3]); - %90 = nn.relu(%89); - %91 = nn.conv2d(%90, meta[relay.Constant][41], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %92 = add(%91, %86); - %93 = nn.relu(%92); - %94 = nn.conv2d(%93, meta[relay.Constant][42], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %95 = nn.relu(%94); - %96 = nn.conv2d(%95, meta[relay.Constant][43], strides=[2, 2], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %97 = nn.relu(%96); - %98 = nn.conv2d(%97, meta[relay.Constant][44], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %99 = nn.conv2d(%93, meta[relay.Constant][45], strides=[2, 2], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %100 = add(%98, %99); - %101 = nn.relu(%100); - %102 = nn.conv2d(%101, meta[relay.Constant][46], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %103 = nn.relu(%102); - %104 = nn.conv2d(%103, meta[relay.Constant][47], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %105 = nn.relu(%104); - %106 = nn.conv2d(%105, meta[relay.Constant][48], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %107 = add(%106, %101); - %108 = nn.relu(%107); - %109 = nn.conv2d(%108, meta[relay.Constant][49], padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]); - %110 = nn.relu(%109); - %111 = nn.conv2d(%110, meta[relay.Constant][50], padding=[1, 1, 1, 1], groups=32, channels=1024, kernel_size=[3, 3]); - %112 = nn.relu(%111); - %113 = nn.conv2d(%112, meta[relay.Constant][51], padding=[0, 0, 0, 0], channels=2048, kernel_size=[1, 1]); - %114 = add(%113, %108); - nn.relu(%114) - } - """, - "from_string", - None, - metatable, - ) - return { - "name": "resnext50_32x4d_16", - "input_shapes": {"x": [1, 64, 56, 56]}, - "input_dtypes": {"x": "float16"}, - "mod": mod, - "params": None, - "main_dtype": "float16", - } diff --git a/tests/python/contrib/test_clml/test_adreno_collage_targets.py b/tests/python/contrib/test_clml/test_adreno_collage_targets.py index bcd3e30c75c92..b777b2ca1196c 100644 --- a/tests/python/contrib/test_clml/test_adreno_collage_targets.py +++ b/tests/python/contrib/test_clml/test_adreno_collage_targets.py @@ -17,10 +17,6 @@ """Compares Collage with various other baselines.""" -# CAUTION: Requires some changes in python/tvm/autotvm/task/dispatcher.py -# so that AutoTVM tuning records can be cached between runs and between -# models. See https://github.com/mbs-octoml/mbs-tvm/tree/mbs-collage-hacks. - import tvm import logging import tempfile @@ -42,15 +38,10 @@ ########### Configuration ########### ### -### Rename to match your hardware, eg ..._vt100... +### TVM opencl autotvm log file name ### TUNING_LOG = "" -### -### If true, runs final model under nvprof -### -PROFILE = True - ### ### If true, run all models ###