File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 12
12
from __future__ import annotations
13
13
14
14
from collections .abc import Sequence
15
+ from typing import Optional
15
16
16
17
import numpy as np
17
18
import torch
@@ -53,7 +54,7 @@ def __init__(
53
54
pos_embed_type : str = "learnable" ,
54
55
dropout_rate : float = 0.0 ,
55
56
spatial_dims : int = 3 ,
56
- pos_embed_kwargs : dict = {} ,
57
+ pos_embed_kwargs : Optional [ dict ] = None ,
57
58
) -> None :
58
59
"""
59
60
Args:
@@ -108,6 +109,8 @@ def __init__(
108
109
self .position_embeddings = nn .Parameter (torch .zeros (1 , self .n_patches , hidden_size ))
109
110
self .dropout = nn .Dropout (dropout_rate )
110
111
112
+ pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs
113
+
111
114
if self .pos_embed_type == "none" :
112
115
pass
113
116
elif self .pos_embed_type == "learnable" :
You can’t perform that action at this time.
0 commit comments