Commit
-Sun- 1. Added optimizers. 2. test7: I'll fix Cora dataset streaming on GTransformer later. 3. cr_boosters is a package containing optimizers written as an extension to PyTorch.
1 parent 3503f88 · commit 3c47962
Showing 29 changed files with 11,708 additions and 60 deletions.
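The commit message describes cr_boosters as a package of optimizers extending PyTorch. A minimal usage sketch, assuming the package mirrors the torch.optim interface — the import path and constructor arguments below are assumptions, not confirmed by this commit:

import torch
from cr_boosters import Adam  # hypothetical import path for the vendored optimizers

model = torch.nn.Linear(4, 2)
opt = Adam(model.parameters(), lr=1e-3)  # assumed to accept the same arguments as torch.optim.Adam
loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()
opt.zero_grad()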
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,38 @@
# """
# :mod:`torch.optim` is a package implementing various optimization algorithms.

# Most commonly used methods are already supported, and the interface is general
# enough, so that more sophisticated ones can also be easily integrated in the
# future.
# """

# from torch.optim import lr_scheduler, swa_utils
# from torch.optim.adadelta import Adadelta
# from torch.optim.adagrad import Adagrad
# from torch.optim.adam import Adam
# from torch.optim.adamax import Adamax
# from torch.optim.adamw import AdamW
# from torch.optim.asgd import ASGD
# from torch.optim.lbfgs import LBFGS
# from torch.optim.nadam import NAdam
# from torch.optim.optimizer import Optimizer
# from torch.optim.radam import RAdam
# from torch.optim.rmsprop import RMSprop
# from torch.optim.rprop import Rprop
# from torch.optim.sgd import SGD
# from torch.optim.sparse_adam import SparseAdam

# del Adadelta  # type: ignore[name-defined] # noqa: F821
# del Adagrad  # type: ignore[name-defined] # noqa: F821
# del Adam  # type: ignore[name-defined] # noqa: F821
# del AdamW  # type: ignore[name-defined] # noqa: F821
# del SparseAdam  # type: ignore[name-defined] # noqa: F821
# del Adamax  # type: ignore[name-defined] # noqa: F821
# del ASGD  # type: ignore[name-defined] # noqa: F821
# del SGD  # type: ignore[name-defined] # noqa: F821
# del RAdam  # type: ignore[name-defined] # noqa: F821
# del Rprop  # type: ignore[name-defined] # noqa: F821
# del RMSprop  # type: ignore[name-defined] # noqa: F821
# del Optimizer  # type: ignore[name-defined] # noqa: F821
# del NAdam  # type: ignore[name-defined] # noqa: F821
# del LBFGS  # type: ignore[name-defined] # noqa: F821
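For reference, the "general interface" the docstring above refers to is the usual construct / zero_grad / backward / step cycle. A minimal, self-contained loop using the stock torch.optim.SGD, shown here only to illustrate the interface these vendored files re-export:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()                              # clear gradients from the previous step
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                                    # populate .grad on the parameters
    optimizer.step()                                   # apply the update rule (here: plain SGD)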
@@ -0,0 +1,84 @@
# # mypy: allow-untyped-defs
# r"""Functional interface."""
# import math
# from typing import List

# from torch import Tensor

# from .adadelta import adadelta  # type: ignore[attr-defined] # noqa: F401
# from .adagrad import _make_sparse, adagrad  # type: ignore[attr-defined] # noqa: F401
# from .adam import adam  # type: ignore[attr-defined] # noqa: F401
# from .adamax import adamax  # type: ignore[attr-defined] # noqa: F401
# # from .adamw import adamw  # type: ignore[attr-defined] # noqa: F401
# from .asgd import asgd  # type: ignore[attr-defined] # noqa: F401
# from .nadam import nadam  # type: ignore[attr-defined] # noqa: F401
# from .radam import radam  # type: ignore[attr-defined] # noqa: F401
# from .rmsprop import rmsprop  # type: ignore[attr-defined] # noqa: F401
# from .rprop import rprop  # type: ignore[attr-defined] # noqa: F401
# from .sgd import sgd  # type: ignore[attr-defined] # noqa: F401


# # TODO: use foreach API in optim._functional to do all the computation


# def sparse_adam(
#     params: List[Tensor],
#     grads: List[Tensor],
#     exp_avgs: List[Tensor],
#     exp_avg_sqs: List[Tensor],
#     state_steps: List[int],
#     *,
#     eps: float,
#     beta1: float,
#     beta2: float,
#     lr: float,
#     maximize: bool,
# ):
#     r"""Functional API that performs Sparse Adam algorithm computation.

#     See :class:`~torch.optim.SparseAdam` for details.
#     """
#     for i, param in enumerate(params):
#         grad = grads[i]
#         grad = grad if not maximize else -grad
#         grad = grad.coalesce()  # the update is non-linear so indices must be unique
#         grad_indices = grad._indices()
#         grad_values = grad._values()
#         if grad_values.numel() == 0:
#             # Skip update for empty grad
#             continue
#         size = grad.size()

#         exp_avg = exp_avgs[i]
#         exp_avg_sq = exp_avg_sqs[i]
#         step = state_steps[i]

#         def make_sparse(values):
#             constructor = grad.new
#             if grad_indices.dim() == 0 or values.dim() == 0:
#                 return constructor().resize_as_(grad)
#             return constructor(grad_indices, values, size)

#         # Decay the first and second moment running average coefficient
#         #      old <- b * old + (1 - b) * new
#         # <==> old += (1 - b) * (new - old)
#         old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
#         exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
#         exp_avg.add_(make_sparse(exp_avg_update_values))
#         old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
#         exp_avg_sq_update_values = (
#             grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
#         )
#         exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

#         # Dense addition again is intended, avoiding another sparse_mask
#         numer = exp_avg_update_values.add_(old_exp_avg_values)
#         exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
#         denom = exp_avg_sq_update_values.sqrt_().add_(eps)
#         del exp_avg_update_values, exp_avg_sq_update_values

#         bias_correction1 = 1 - beta1**step
#         bias_correction2 = 1 - beta2**step
#         step_size = lr * math.sqrt(bias_correction2) / bias_correction1

#         param.add_(make_sparse(-step_size * numer.div_(denom)))
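A rough sketch of how the functional sparse_adam above could be called once the file is uncommented: one dense parameter with a sparse gradient, a single step, and hyperparameters matching SparseAdam's usual defaults. The tensors and values here are made up purely for illustration.

import torch

param = torch.zeros(4)                                   # dense parameter to update
grad = torch.tensor([0.0, 1.0, 0.0, 2.0]).to_sparse()    # SparseAdam-style sparse gradient
exp_avg = torch.zeros_like(param)                        # first-moment state
exp_avg_sq = torch.zeros_like(param)                     # second-moment state
sparse_adam(
    [param], [grad], [exp_avg], [exp_avg_sq], [1],       # step counts start at 1
    eps=1e-8, beta1=0.9, beta2=0.999, lr=1e-3, maximize=False,
)
print(param)  # only the entries that appear in the sparse gradient have moved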
@@ -0,0 +1,30 @@
"""
:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms.

Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""
from functools import partialmethod

from torch import optim


def partialclass(cls, *args, **kwargs):  # noqa: D103
    class NewCls(cls):
        __init__ = partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


Adam = partialclass(optim.Adam, foreach=True)
AdamW = partialclass(optim.AdamW, foreach=True)
NAdam = partialclass(optim.NAdam, foreach=True)
SGD = partialclass(optim.SGD, foreach=True)
RAdam = partialclass(optim.RAdam, foreach=True)
RMSprop = partialclass(optim.RMSprop, foreach=True)
Rprop = partialclass(optim.Rprop, foreach=True)
ASGD = partialclass(optim.ASGD, foreach=True)
Adamax = partialclass(optim.Adamax, foreach=True)
Adadelta = partialclass(optim.Adadelta, foreach=True)
Adagrad = partialclass(optim.Adagrad, foreach=True)
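A small illustration (not part of the commit) of what partialclass buys: the returned object is a real subclass, so isinstance checks and normal optimizer usage keep working while foreach=True is pre-bound into __init__. This assumes the module's names (e.g. Adam above) are in scope.

import torch
from torch import optim

model = torch.nn.Linear(4, 2)
opt = Adam(model.parameters(), lr=1e-3)   # Adam here is partialclass(optim.Adam, foreach=True)
print(isinstance(opt, optim.Adam))        # True -- NewCls genuinely subclasses optim.Adam
model(torch.randn(8, 4)).sum().backward()
opt.step()                                # foreach=True was injected by partialmethod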
@@ -0,0 +1,15 @@
from functools import partial

from torch import optim

Adam = partial(optim.Adam, foreach=True)
AdamW = partial(optim.AdamW, foreach=True)
NAdam = partial(optim.NAdam, foreach=True)
SGD = partial(optim.SGD, foreach=True)
RAdam = partial(optim.RAdam, foreach=True)
RMSprop = partial(optim.RMSprop, foreach=True)
Rprop = partial(optim.Rprop, foreach=True)
ASGD = partial(optim.ASGD, foreach=True)
Adamax = partial(optim.Adamax, foreach=True)
Adadelta = partial(optim.Adadelta, foreach=True)
Adagrad = partial(optim.Adagrad, foreach=True)
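By contrast, this functools.partial variant produces factories rather than subclasses: calling one still returns a plain optim.Adam instance, but the bound name cannot be subclassed or used with issubclass. A short sketch of that difference, for illustration only and not part of the commit:

import torch
from torch import optim

opt = Adam([torch.zeros(2, requires_grad=True)], lr=1e-3)  # Adam = partial(optim.Adam, foreach=True)
print(type(opt) is optim.Adam)                 # True: partial just forwards to the original class
print(callable(Adam), isinstance(Adam, type))  # True, False -- a factory, not a class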