Skip to content

Commit

Permalink
修改: doc/tutorial/quant/ptq-eager.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxinwei committed Jan 18, 2024
1 parent 934a5df commit 0291c47
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
4 changes: 2 additions & 2 deletions doc/paper/quant/ptq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@
"\\widehat{\\mathbf{x}} = s \\mathbf{x}_{\\operatorname{int}}\n",
"$$ (q8.1)\n",
"$$\n",
"\\mathbf{x}_{\\operatorname{int}} = \\operatorname{clamp}(\\lfloor \\cfrac{\\mathbf{x}}{s} \\rceil + z; 0, 2^b -1), \\text{针对无符号整型}\n",
"\\mathbf{x}_{\\operatorname{int}} = \\operatorname{clamp}(\\lfloor \\cfrac{\\mathbf{x}}{s} \\rceil; 0, 2^b -1), \\text{针对无符号整型}\n",
"$$ (q8.2)\n",
"$$\n",
"\\mathbf{x}_{\\operatorname{int}} = \\operatorname{clamp}(\\lfloor \\cfrac{\\mathbf{x}}{s} \\rceil + z; -2^{b-1}, 2^{b-1} -1), \\text{针对有符号整型}\n",
"\\mathbf{x}_{\\operatorname{int}} = \\operatorname{clamp}(\\lfloor \\cfrac{\\mathbf{x}}{s} \\rceil; -2^{b-1}, 2^{b-1} -1), \\text{针对有符号整型}\n",
"$$ (q8.3)"
]
},
Expand Down
5 changes: 0 additions & 5 deletions doc/paper/quant/qat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@
"\\cfrac{\\partial{\\lfloor y \\rceil}}{\\partial{y}} = 1\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
Expand Down
20 changes: 13 additions & 7 deletions doc/tutorial/quant/ptq-eager.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
" self.dequant = DeQuantStub() # 将张量从量化转换为浮点\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" if self.is_print:\n",
" print('原始类型:', x.dtype)\n",
" # 手动指定张量将在量化模型中从浮点模块转换为量化模块的位置\n",
" x = self.quant(x)\n",
" if self.is_print:\n",
Expand Down Expand Up @@ -171,6 +173,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"原始类型: torch.float32\n",
"量化前的类型: torch.float32\n",
"量化中的类型: torch.float32\n",
"量化后的类型: torch.float32\n"
Expand Down Expand Up @@ -264,6 +267,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"原始类型: torch.float32\n",
"量化前的类型: torch.float32\n",
"量化中的类型: torch.float32\n",
"量化后的类型: torch.float32\n",
Expand Down Expand Up @@ -406,6 +410,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"原始类型: torch.float32\n",
"量化前的类型: torch.float32\n",
"量化中的类型: torch.float32\n",
"量化后的类型: torch.float32\n"
Expand All @@ -418,15 +423,15 @@
" (conv): ConvReLU2d(\n",
" (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
" (1): ReLU()\n",
" (activation_post_process): HistogramObserver(min_val=0.0, max_val=1.3095310926437378)\n",
" (activation_post_process): HistogramObserver(min_val=0.0, max_val=1.430925965309143)\n",
" )\n",
" (relu): Identity()\n",
" (conv2): Conv2d(\n",
" 3, 16, kernel_size=(1, 1), stride=(1, 1)\n",
" (activation_post_process): HistogramObserver(min_val=-0.9652463793754578, max_val=0.972123384475708)\n",
" (activation_post_process): HistogramObserver(min_val=-0.8590439558029175, max_val=0.9416270852088928)\n",
" )\n",
" (quant): QuantStub(\n",
" (activation_post_process): HistogramObserver(min_val=-2.4491002559661865, max_val=1.9835344552993774)\n",
" (activation_post_process): HistogramObserver(min_val=-1.7396681308746338, max_val=2.5327765941619873)\n",
" )\n",
" (dequant): DeQuantStub()\n",
")"
Expand Down Expand Up @@ -471,10 +476,10 @@
"data": {
"text/plain": [
"QM(\n",
" (conv): QuantizedConvReLU2d(1, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.005132908467203379, zero_point=0)\n",
" (conv): QuantizedConvReLU2d(1, 3, kernel_size=(3, 3), stride=(1, 1), scale=0.005608734209090471, zero_point=0)\n",
" (relu): Identity()\n",
" (conv2): QuantizedConv2d(3, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.007593819405883551, zero_point=127)\n",
" (quant): Quantize(scale=tensor([0.0174]), zero_point=tensor([141]), dtype=torch.quint8)\n",
" (conv2): QuantizedConv2d(3, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.007058007176965475, zero_point=122)\n",
" (quant): Quantize(scale=tensor([0.0167]), zero_point=tensor([104]), dtype=torch.quint8)\n",
" (dequant): DeQuantize()\n",
")"
]
Expand Down Expand Up @@ -559,6 +564,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"原始类型: torch.float32\n",
"量化前的类型: torch.quint8\n",
"量化中的类型: torch.quint8\n",
"量化后的类型: torch.float32\n"
Expand All @@ -582,7 +588,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit 0291c47

Please sign in to comment.