Skipping Module Parameter Initialization
========================================

Introduction
------------

When a module is created, its learnable parameters are initialized according
to a default initialization scheme associated with the module type. For example, the `weight`
parameter for a :class:`torch.nn.Linear` module is initialized from a
`uniform(-1/sqrt(in_features), 1/sqrt(in_features))` distribution. If some other initialization
scheme is desired, this has traditionally required re-initializing the parameters
after module instantiation:

::

    from torch import nn

    # Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
    m = nn.Linear(10, 5)

    # Re-initialize weight from a different distribution.
    nn.init.orthogonal_(m.weight)

In this case, the initialization done during construction is wasted computation, and it may be non-trivial if
the `weight` parameter is large.

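To get a rough sense of that cost (an illustrative sketch, not part of the original text; absolute
numbers depend on the machine), the default initialization of a large layer can be timed on its own:

::

    import timeit
    from torch import nn

    # Constructing a 10000 x 10000 Linear samples ~100M values from the default
    # distribution, all of which are thrown away if the weight is re-initialized.
    print(timeit.timeit(lambda: nn.Linear(10000, 10000), number=1))
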
Skipping Initialization
-----------------------

It is now possible to skip parameter initialization during module construction, avoiding
wasted computation. This is easily accomplished using the :func:`torch.nn.utils.skip_init` function:

::

    from torch import nn
    from torch.nn.utils import skip_init

    m = skip_init(nn.Linear, 10, 5)

    # Example: Do custom, non-default parameter initialization.
    nn.init.orthogonal_(m.weight)

This can be applied to any module that satisfies the conditions described in the
:ref:`Updating` section below. Note that all modules provided by
`torch.nn` satisfy these conditions and thus support skipping init.

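For instance (an illustrative sketch, not from the original text), the same call works for other
`torch.nn` modules, with any constructor arguments forwarded through :func:`torch.nn.utils.skip_init`:

::

    from torch import nn
    from torch.nn.utils import skip_init

    # Positional and keyword constructor arguments are passed through to the module.
    conv = skip_init(nn.Conv2d, 16, 32, kernel_size=3, padding=1)

    # The parameters hold uninitialized memory until explicitly initialized.
    nn.init.dirac_(conv.weight)
    nn.init.zeros_(conv.bias)
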
.. _Updating:

Updating Modules to Support Skipping Initialization
---------------------------------------------------

Due to the way :func:`torch.nn.utils.skip_init` is implemented (see :ref:`Details`), there are
two requirements that a module must meet to be compatible with the function.
You can opt in to the parameter initialization skipping functionality for your custom module
simply by adhering to these requirements:

1. The module must accept a `device` kwarg in its constructor that is passed to any parameters
   or buffers created during construction.

2. The module must not perform any computation on parameters or buffers in its constructor except
   initialization (i.e. functions from `torch.nn.init`); a sketch of a module that violates this
   appears just below.

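Here is a hypothetical module (not from the original text) that breaks the second requirement: it
performs a non-`torch.nn.init` computation on a parameter in its constructor. Under
:func:`torch.nn.utils.skip_init`, construction happens on the meta device, where no real values
exist, so the effect of that computation is silently lost:

::

    import torch
    from torch import nn

    class NotSkippable(nn.Module):
        def __init__(self, size, device=None):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(size, size, device=device))
            with torch.no_grad():
                nn.init.uniform_(self.weight)
                # Not a torch.nn.init op: on the meta device this produces no real
                # values, so the normalization is effectively lost under skip_init.
                self.weight /= self.weight.norm()
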
The following example demonstrates a module updated to support the `device`
kwarg by passing it along to any created parameters, buffers, or submodules:

::

    import torch
    from torch import nn

    class MyModule(torch.nn.Module):
        def __init__(self, foo, bar, device=None):
            super().__init__()

            # ==== Case 1: Module creates parameters directly. ====
            # Pass device along to any created parameters.
            self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
            self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))

            # To ensure support for the meta device, avoid using ops except those in
            # torch.nn.init on parameters in your module's constructor.
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.param1)
                nn.init.uniform_(self.param2)

            # ==== Case 2: Module creates submodules. ====
            # Pass device along recursively. All submodules will need to support
            # it as well; this is the case for all torch.nn provided modules.
            self.fc = nn.Linear(bar, 5, device=device)

            # This also works with containers.
            self.linears = nn.Sequential(
                nn.Linear(5, 5, device=device),
                nn.Linear(5, 1, device=device)
            )

            # ==== Case 3: Module creates buffers. ====
            # Pass device along during buffer tensor creation.
            self.register_buffer('some_buffer', torch.ones(7, device=device))

        ...

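With those requirements satisfied, the module can be constructed without its parameters being
initialized; a minimal usage sketch, assuming the `MyModule` definition above:

::

    from torch import nn
    from torch.nn.utils import skip_init

    # Constructor arguments are forwarded by skip_init; the returned instance has
    # uninitialized (empty) parameters and buffers on the CPU by default.
    m = skip_init(MyModule, 4, 3)

    # Apply whichever initialization is actually desired afterwards.
    nn.init.orthogonal_(m.fc.weight)
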
.. _Details:

Implementation Details
----------------------

Behind the scenes, the :func:`torch.nn.utils.skip_init` function is implemented in terms of a two-step pattern:

::

    # 1. Initialize module on the meta device; all torch.nn.init ops have
    # no-op behavior on the meta device.
    m = nn.Linear(10, 5, device='meta')

    # 2. Materialize an uninitialized (empty) form of the module on the CPU device.
    # The result of this is a module instance with uninitialized parameters.
    m.to_empty(device='cpu')

It works by instantiating the module onto a "meta" device, which has tensor shape information
but does not allocate any storage. The `torch.nn.init` ops are specially implemented for this meta device
so that they have no-op behavior. This results in the parameter initialization logic being essentially skipped.

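As a quick illustration of the meta device (a sketch, not from the original text), tensors created
there carry shape and dtype but own no data, and `torch.nn.init` calls on them perform no work:

::

    import torch
    from torch import nn

    # A meta tensor knows its shape but allocates no storage.
    t = torch.empty(3, 5, device='meta')
    print(t.shape, t.device)  # torch.Size([3, 5]) meta

    # torch.nn.init ops are no-ops here, so no initialization work is performed.
    nn.init.kaiming_uniform_(t)
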
Note that this pattern only works for modules that properly support a `device` kwarg during construction, as
described in :ref:`Updating`.