Skip to content

Commit

Permalink
DeepSpeed: hardcode torch.arange dtype on float usage to avoid in…
Browse files Browse the repository at this point in the history
…correct initialization (#28760)
  • Loading branch information
gante authored Jan 31, 2024
1 parent f7076cd commit beb2a09
Show file tree
Hide file tree
Showing 50 changed files with 192 additions and 118 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/clvp/modeling_clvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class ClvpRotaryPositionalEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@

# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def forward(self, pixel_values, pixel_mask):
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/ctrl/modeling_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def angle_defn(pos, i, d_model_size):
def positional_encoding(position, d_model_size, dtype):
# create the sinusoidal pattern for the positional encoding
angle_rads = angle_defn(
torch.arange(position, dtype=dtype).unsqueeze(1),
torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1),
torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0),
d_model_size,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def forward(self, pixel_values, pixel_mask):
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
Expand Down Expand Up @@ -617,7 +617,7 @@ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):

def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
Expand Down Expand Up @@ -1557,7 +1557,7 @@ def get_proposal_pos_embed(self, proposals):
temperature = 10000
scale = 2 * math.pi

dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
Expand All @@ -81,7 +81,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor

freqs = torch.outer(t, self.inv_freq)
Expand All @@ -135,10 +135,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ def forward(
hids = []
attentions = [] if output_attentions else None
if self.attn_type == 0: # default
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=torch.int64).type_as(
dtype=word_emb.dtype
)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def forward(self, pixel_values, pixel_mask):
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int)

def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
Expand Down Expand Up @@ -1447,7 +1447,7 @@ def get_proposal_pos_embed(self, proposals):
temperature = 10000
scale = 2 * math.pi

dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
# batch_size, num_queries, 4
proposals = proposals.sigmoid() * scale
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def forward(self, pixel_values, pixel_mask):
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

pos_x = x_embed[:, :, :, None] / dim_t
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/esm/modeling_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq)

Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
Expand All @@ -148,7 +148,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor

freqs = torch.outer(t, self.inv_freq)
Expand All @@ -202,10 +202,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,9 @@ def extend_pos_enc(self, x):
# are to the left (i>j) and negative relative positions otherwise (i<j).
pos_enc_positive = torch.zeros(x.size(1), self.embed_dim)
pos_enc_negative = torch.zeros(x.size(1), self.embed_dim)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embed_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embed_dim)
torch.arange(0, self.embed_dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.embed_dim)
)
pos_enc_positive[:, 0::2] = torch.sin(position * div_term)
pos_enc_positive[:, 1::2] = torch.cos(position * div_term)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/fsmt/modeling_fsmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,8 +1346,8 @@ def get_embedding(num_embeddings, embedding_dim, padding_idx):
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/funnel/modeling_funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def get_position_embeds(
if self.config.attention_type == "factorized":
# Notations from the paper, appending A.2.2, final formula.
# We need to create and return the matrices phi, psi, pi and omega.
pos_seq = torch.arange(0, seq_len, 1.0, dtype=dtype, device=device)
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device)
pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
sinusoid = pos_seq[:, None] * inv_freq[None]
sin_embed = torch.sin(sinusoid)
Expand All @@ -252,17 +252,17 @@ def get_position_embeds(
else:
# Notations from the paper, appending A.2.1, final formula.
# We need to create and return all the possible vectors R for all blocks and shifts.
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device)
freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
# Maximum relative positions for the first input
rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype, device=device)
rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
zero_offset = seq_len * 2
sinusoid = rel_pos_id[:, None] * inv_freq[None]
sin_embed = self.sin_dropout(torch.sin(sinusoid))
cos_embed = self.cos_dropout(torch.cos(sinusoid))
pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)

pos = torch.arange(0, seq_len, dtype=dtype, device=device)
pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
pooled_pos = pos
position_embeds_list = []
for block_index in range(0, self.config.num_blocks):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/fuyu/image_processing_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ def preprocess_with_tokenizer_info(
# Indices of image patches.
patches_mask = subseq_image_input_ids == image_placeholder_id
num_patches = torch.count_nonzero(patches_mask)
indices = torch.arange(
num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device
indices = torch.arange(num_patches, dtype=torch.int64, device=subseq_image_input_ids.device).type_as(
subseq_image_input_ids
)

# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
Expand All @@ -544,7 +544,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down Expand Up @@ -573,7 +573,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor

freqs = torch.outer(t, self.inv_freq)
Expand All @@ -598,10 +598,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
Expand All @@ -252,7 +252,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
Expand Down
Loading

0 comments on commit beb2a09

Please sign in to comment.