Skip to content

Commit

Permalink
little cleaner, overfit works
Browse files (browse the repository at this point in the history)
  • Loading branch information
drisspg committed Dec 20, 2023
1 parent 0f4089b commit 14cddf3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ufmt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
- pip install ufmt==2.1.0
+ pip install -e .[dev]
- name: Analyzing the code with ufmt
run: |
ufmt check .
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ dependencies = [

[project.optional-dependencies]
dev = [
- "black",
- "usort",
- "libcst",
+ "black==23.3.0",
+ "usort==1.0.6",
+ "ufmt==2.1.0",
+ "libcst==1.0.1",
"bumpver",
"pip-tools",
"pytest"
Expand Down
17 changes: 11 additions & 6 deletions transformer_nuggets/llama/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def __post_init__(self):
self.gradient_accumulation_iters = self.batch_size // self.micro_batch_size
self.lr_decay_iters = self.max_iters
assert self.gradient_accumulation_iters > 0
+ if self.fp8_linear_type is not None:
+ self.fp8_linear_type = LinearType[self.fp8_linear_type.upper()]


@dataclass
Expand Down Expand Up @@ -213,7 +215,6 @@ def train(
) -> None:
"""Lets go!"""
step_count = 0
- total_lengths = 0
progress_bar = tqdm(total=hyper_params.max_iters)

model.train()
Expand Down Expand Up @@ -268,7 +269,6 @@ def train(
optimizer.step()
optimizer.zero_grad()
step_count += 1
- total_lengths += input_ids.size(1)

if not is_accumulating and step_count % training_config.eval_interval == 0:
t0 = time.time()
Expand All @@ -285,7 +285,7 @@ def train(
if iter_num % training_config.log_interval == 0:
# loss.item causes a sync so we update the progress bar sporadically
write_loss_to_file(train_loss_file, step_count, loss.item())
- progress_bar.set_postfix_str(f"Iter {iter_num}: Loss {loss.item():.4f}")
+ progress_bar.set_postfix_str(f"Loss {loss.item():.4f}")
progress_bar.update(1)

if training_config.profile and iter_num < 103:
Expand Down Expand Up @@ -378,13 +378,18 @@ def get_lr(it, hyper_params: Hyperparameters):


def entrypoint(
- fp8_linear_type: LinearType = None,
+ fp8_linear_type: Optional[LinearType] = None,
compile: bool = False,
overfit: bool = False,
profile: bool = False,
):
- if fp8_linear_type is not None:
- fp8_linear_type = LinearType[fp8_linear_type.upper()]
+ assert (
+ isinstance(fp8_linear_type, str) or fp8_linear_type is None
+ ), "fp8_linear_type must be str"
+ assert isinstance(compile, bool), "compile must be bool"
+ assert isinstance(overfit, bool), "overfit must be bool"
+ assert isinstance(profile, bool), "profile must be bool"

if overfit:
batch_size = 1
else:
Expand Down

0 comments on commit 14cddf3

Please sign in to comment.