Skip to content

Commit

Permalink
fix ci
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubospica committed Jan 29, 2023
1 parent 07dd34f commit b983eaf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,9 @@ def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "")
ret : tvm.relax.Var
A newly created variable that gets bound to the cast result.
"""
return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info, name_hint) # type: ignore
return _ffi_api.BlockBuilderEmitMatchCast( # type: ignore
self, value, struct_info, name_hint
)

def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "") -> Var:
"""Emit output for the current dataflow block or function.
Expand Down
19 changes: 10 additions & 9 deletions python/tvm/relax/training/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Provide abstraction for defining optimizers and a set of common optimizers."""

from decimal import Decimal
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np # type: ignore

Expand Down Expand Up @@ -221,7 +221,7 @@ def set_vm_config(
return self

def __call__(
self, params_ADT: tvm.runtime.container.ADT, grads_ADT: tvm.runtime.container.ADT
self, params_adt: tvm.runtime.container.ADT, grads_adt: tvm.runtime.container.ADT
) -> tvm.runtime.container.ADT:
"""Optimization process. This function takes an ADT tuple of the input parameters and an ADT
tuple of the gradients of the input parameters, and returns an ADT tuple of parameters
Expand All @@ -234,10 +234,10 @@ def __call__(
Parameters
----------
params_ADT : tvm.runtime.container.ADT
params_adt : tvm.runtime.container.ADT
An ADT tuple of the input parameters. A TVM runtime object.
grads_ADT : tvm.runtime.container.ADT
grads_adt : tvm.runtime.container.ADT
An ADT tuple of the gradients of the input parameters. A TVM runtime object.
"""
if self._vm_module is None:
Expand All @@ -246,10 +246,11 @@ def __call__(
"The vm configs of the optimizer is not set. Please call set_vm_config first"
)
mod = tvm.IRModule({self.name: self.get_function()})
lowered_mod = LegalizeOps()(mod)
ex = rx.vm.build(lowered_mod, self._target)
self._vm_module = rx.VirtualMachine(ex, self._device)
new_params, self.state = self._vm_module[self.name](params_ADT, grads_ADT, self.state)
# pylint: disable=not-callable
lowered_mod = LegalizeOps()(mod) # type: ignore
executable = rx.vm.build(lowered_mod, self._target)
self._vm_module = rx.VirtualMachine(executable, self._device)
new_params, self.state = self._vm_module[self.name](params_adt, grads_adt, self.state)
return new_params


Expand Down Expand Up @@ -531,7 +532,7 @@ def Adam(param_tuple, grad_tuple, state_tuple):
num_steps_new = num_steps + 1
param_tuple_new = []
state_tuple_new = [None] * len(state_tuple) # type: ignore
state_tuple_new = [None] * len(state_tuple)
state_tuple_new[0] = num_steps_new
state_tuple_new[1] = state_tuple[1] * betas[0]
state_tuple_new[2] = state_tuple[2] * betas[1]
Expand Down

0 comments on commit b983eaf

Please sign in to comment.