Commit 4bd4673: looking better
drisspg committed Dec 20, 2023
1 parent 189241e commit 4bd4673
Showing 4 changed files with 21 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ufmt.yml
@@ -20,7 +20,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-pip install ufmt
+pip install ufmt==2.1.0
- name: Analyzing the code with ufmt
run: |
ufmt check .
4 changes: 1 addition & 3 deletions transformer_nuggets/llama/README.md
@@ -23,7 +23,7 @@ This should take around 3 minutes to run and prepare the training data.

#### Train Model
To edit the training configs take a look at `transformer_nuggets/llama/train.py`. The `entrypoint` function constructs the hyperparam configs as well as the
-training configs. By default this will train a 7b model and and save the checkpoints to `transformer_nuggets/llama/data/out/`. It will also save the loss
+training configs. By default this will train a 7b model and and save the checkpoints to `transformer_nuggets/llama/data/out/`. It will also save the loss
logs to `transformer_nuggets/llama/data/logs`.


@@ -34,7 +34,5 @@ python transformer_nuggets/llama/train.py \
```




### Notes
To get the Llama2 tokenizer go to https://huggingface.co/meta-llama/Llama-2-7b and go through steps to obtain access. This will get you pretrained weights as well as the tokenizer.
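Once access has been granted, one way to fetch the tokenizer programmatically is sketched below. This snippet is not part of the repository: it assumes `huggingface_hub` and `sentencepiece` are installed and that you are already authenticated (e.g. via `huggingface-cli login`), and the destination directory is only illustrative.

```python
import sentencepiece as spm
from huggingface_hub import hf_hub_download

# Download just the SentencePiece tokenizer file from the gated Llama-2 repo.
# This only works after access has been approved on the model page.
tokenizer_path = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="tokenizer.model",
    local_dir="transformer_nuggets/llama/data",  # illustrative destination
)

# Quick sanity check that the tokenizer loads and encodes text.
sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
print(sp.encode("Hello, Llama!", out_type=int))
```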
13 changes: 9 additions & 4 deletions transformer_nuggets/llama/model.py
@@ -108,7 +108,7 @@ def __init__(self, config: ModelArgs) -> None:
self.max_batch_size = -1
self.max_seq_length = -1

-def setup_caches(self, max_batch_size, max_seq_length):
+def setup_caches(self, max_batch_size, max_seq_length, device: torch.device):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
@@ -117,7 +117,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
self.max_batch_size = max_batch_size

self.freqs_cis = precompute_freqs_cis(
-self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base
+self.config.block_size,
+head_dim,
+device,
+self.config.rope_base,
)

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -236,13 +239,15 @@ def forward(self, x: Tensor) -> Tensor:
return output * self.weight


-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+def precompute_freqs_cis(
+seq_len: int, n_elem: int, device: torch.device, base: int = 10000
+) -> Tensor:
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-return cache.to(dtype=torch.bfloat16)
+return cache.to(dtype=torch.bfloat16, device=device)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
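For context, the model.py changes thread a `torch.device` through cache setup so the RoPE frequency cache lands directly on the target device. A minimal sketch of exercising the updated helper (not part of the commit; the sizes are illustrative, and the real call site derives them from `ModelArgs`):

```python
import torch

from transformer_nuggets.llama.model import precompute_freqs_cis

# Illustrative sizes; in the model these come from config.block_size and
# config.dim // config.n_head.
seq_len, head_dim = 2048, 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# After this commit the cache is materialized on the requested device
# (and in bfloat16) rather than defaulting to CPU.
freqs_cis = precompute_freqs_cis(seq_len, head_dim, device)
assert freqs_cis.dtype == torch.bfloat16
assert freqs_cis.device.type == device.type
```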
24 changes: 10 additions & 14 deletions transformer_nuggets/llama/train.py
@@ -43,7 +43,7 @@ class Hyperparameters:
learning_rate: float = 6e-4
batch_size: int = 128
micro_batch_size: int = 1
-max_seq_length: int = 4096
+max_seq_length: int = 2048
gradient_accumulation_iters: int = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_iters: int = 600000 # train dataset size
@@ -60,7 +60,7 @@ class Hyperparameters:
# Float8 Specific Config
# We want to skip the first embedding layer since scaled_mm needs to multiple of 16
float8_skip_list: List[str] = field(default_factory=lambda: ["lm_head"])
-fp8_linear_type: Optional[str] = None
+fp8_linear_type: Optional[LinearType] = None


@dataclass
@@ -81,21 +81,20 @@ class TrainingConfig:
base_path = Path("transformer_nuggets/llama/data")
out_dir: Path = base_path / "out"
data_dir: Path = base_path
-log_dir: Path = out_dir / "logs"
+log_dir: Path = base_path / "logs"

device: torch.device = torch.device("cuda:0")
# If true we will profile iters 100-102 of the model training
profile: bool = False


-def write_loss_to_file(config: TrainingConfig, step: int, loss: float):
+def write_loss_to_file(loss_file: Path, step: int, loss: float):
"""Writes the loss to a csv file for later plotting
Args:
loss_file: The file to write the loss to
step: The current step
loss: The loss to write
"""
-loss_file = config.log_dir / "loss.csv"
if not loss_file.exists():
with open(loss_file, "w") as f:
writer = csv.writer(f)
@@ -156,7 +155,9 @@ def main(
with training_config.device:
model = Transformer(model_args).to(torch.bfloat16)
model.init_parameters()
-model.setup_caches(hyper_params.micro_batch_size, hyper_params.max_seq_length)
+model.setup_caches(
+hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
+)

logging.info("Setting up the dataloaders")
train_data, val_data = load_datasets(hyper_params, training_config)
@@ -166,8 +167,6 @@
val_dataloader = DataLoader(val_data, batch_size=hyper_params.micro_batch_size, num_workers=2)

fp8_linear_type = hyper_params.fp8_linear_type
-if fp8_linear_type is not None:
-fp8_linear_type = LinearType[fp8_linear_type.upper()]
if fp8_linear_type is not None:
fp8_module = LINEAR_TYPE_MAP[fp8_linear_type]
swap_linear_with_float8_linear(model, fp8_module)
@@ -245,7 +244,6 @@ def train(
if linear_requires_sync(fp8_linear_type):
sync_func(model)

-t0 = time.perf_counter()
input_ids, targets = next(train_iter)
input_ids = input_ids.pin_memory().to(training_config.device)
targets = targets.pin_memory().to(training_config.device)
@@ -264,8 +262,6 @@
optimizer.step()
optimizer.zero_grad()
step_count += 1

-dt = time.perf_counter() - t0
total_lengths += input_ids.size(1)

if not is_accumulating and step_count % training_config.eval_interval == 0:
Expand All @@ -283,9 +279,7 @@ def train(
if iter_num % training_config.log_interval == 0:
# loss.item causes a sync so we update the progress bar sporadically
write_loss_to_file(train_loss_file, step_count, loss.item())
-progress_bar.set_postfix_str(
-f"Iter {iter_num}: Loss {loss.item():.4f}, Time: {dt*1000:.2f}ms"
-)
+progress_bar.set_postfix_str(f"Iter {iter_num}: Loss {loss.item():.4f}")
progress_bar.update(1)

if training_config.profile and iter_num < 103:
Expand Down Expand Up @@ -383,6 +377,8 @@ def entrypoint(
overfit: bool = False,
profile: bool = False,
):
+if fp8_linear_type is not None:
+fp8_linear_type = LinearType[fp8_linear_type.upper()]
hyper_params = Hyperparameters(fp8_linear_type=fp8_linear_type)
training_config = TrainingConfig(compile=compile, overfit=overfit, profile=profile)
main(hyper_params, training_config)
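To tie the train.py changes together: the string-to-enum conversion now happens once in `entrypoint`, so `Hyperparameters.fp8_linear_type` stores a `LinearType` member (or `None`) instead of a raw string. A hypothetical direct invocation is sketched below; the README drives this through the CLI instead, and `"dynamic"` is only an illustrative value that must match a member of `float8_experimental`'s `LinearType` enum.

```python
from transformer_nuggets.llama.train import entrypoint

# The string is upper-cased and looked up in the LinearType enum inside
# entrypoint(), so Hyperparameters receives the enum member directly.
# Leaving fp8_linear_type as None skips the float8 module swap entirely.
entrypoint(fp8_linear_type="dynamic")  # illustrative; must name a LinearType member
```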
