lintering
* linter

* update import

* delete the commented code

* synchronize the compressor at first
huangrh99 committed Oct 28, 2020
1 parent b4d48e3 commit c774e46
Showing 1 changed file with 24 additions and 100 deletions.
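
The last bullet is the substantive fix. PowerSGD keeps a per-variable low-rank factor Q (`self.compressor` below), and before this commit each worker drew its own random Q, so workers could compress against mismatched factors. The diff now all-reduces Q once, right after initialization, under a dedicated collective instance key. As a rough single-process sketch of the round that `reduce()` implements (NumPy stands in for the TensorFlow collective ops; `all_reduce_mean` and the toy worker setup are assumptions of this sketch, and error feedback is omitted):

import numpy as np

def all_reduce_mean(tensors):
    """Toy stand-in for collective_ops: average one tensor per worker."""
    return [np.mean(tensors, axis=0)] * len(tensors)

def modified_gram_schmidt(m):
    """Column-wise orthonormalization, as in _modified_gram_schmidt below."""
    q = m.copy()
    for i in range(q.shape[1]):
        q[:, i] /= np.linalg.norm(q[:, i])
        q[:, i + 1:] -= np.outer(q[:, i], q[:, i] @ q[:, i + 1:])
    return q

n_workers, n, m, rank = 2, 6, 4, 1
rng = np.random.default_rng(0)
grads = [rng.normal(size=(n, m)) for _ in range(n_workers)]

# Each worker draws its own random Q; averaging once right after init is the
# "synchronize the compressor at first" step, so every worker compresses
# against the same factor.
qs = [rng.normal(size=(m, rank)) for _ in range(n_workers)]
qs = all_reduce_mean(qs)

ps = all_reduce_mean([g @ q for g, q in zip(grads, qs)])    # reduce P = M Q
ps = [modified_gram_schmidt(p) for p in ps]                 # orthonormalize P
qs = all_reduce_mean([g.T @ p for g, p in zip(grads, ps)])  # update Q = M^T P
approx = [p @ q.T for p, q in zip(ps, qs)]                  # decompress P Q^T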
124 changes: 24 additions & 100 deletions autodist/kernel/synchronization/compressor.py
@@ -13,13 +13,12 @@
 # limitations under the License.

 """Gradient Compressors for All-Reduce."""
+import copy
 from abc import ABC, abstractmethod
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework.ops import Tensor
-from tensorflow.python.ops import collective_ops, math_ops
-
-#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops
-#from autodist.kernel.synchronization.collective_key import get_collective_keys
-#from autodist.utils import logging
+from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops
+from autodist.kernel.synchronization.collective_key import get_collective_keys


@@ -207,11 +206,13 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus

 class PowerSGDCompressor(CompressorEF):
     """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727)."""

     def __init__(self, var_op_name, rank=1):
         self.rank = rank
-        self.og_shape, self.ndims, self.compressor = None, None, None  # compressor is the Q in paper
+        self.og_shape, self.ndims = None, None
+        self.compressor, self.compressor_conf = None, None  # compressor is the Q in paper
         self.var_op_name = var_op_name
-        super.__init__(var_op_name)
+        super().__init__(var_op_name)

     def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):
         """
@@ -229,13 +230,18 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):
         self.ndims = len(self.og_shape)

         # rank <= 1
-        if self.ndims <= 1 or (self.ndims==2 and any([d == 1 for d in self.og_shape])):
+        if self.ndims <= 1 or (self.ndims == 2 and any([d == 1 for d in self.og_shape])):
             return self._all_reduce(tensor, conf)

         # compressor init
         if self.compressor is None:
             self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank])
+
+            # synchronize compressor init status
+            self.compressor_conf = copy.copy(conf)
+            self.compressor_conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor')
+            self.compressor = self._all_reduce(self.compressor, self.compressor_conf)

         if self.error is not None:
             tensor += self.error
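
The added block clones the gradient's collective config instead of mutating it: a shallow copy keeps the group settings but carries its own instance key, so the Q all-reduce runs as a separate collective instance from the gradient all-reduce. A minimal illustration of the pattern; the `CollectiveOpsConfig` stand-in and its field names here are simplified assumptions, not the real class from this module:

import copy
from dataclasses import dataclass

@dataclass
class CollectiveOpsConfig:
    group_size: int     # simplified stand-in fields, for illustration only
    group_key: str
    instance_key: int

grad_conf = CollectiveOpsConfig(group_size=2, group_key="g0", instance_key=100)

compressor_conf = copy.copy(grad_conf)  # same group settings, independent object
compressor_conf.instance_key = 101      # stands in for get_instance_key(var_op_name + '/compressor')

assert grad_conf.instance_key == 100    # the gradient's config is left untouched

The next hunk deletes the old approach, which temporarily overwrote `conf.instance_key` on the shared config and restored it after the collective op.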

@@ -246,13 +252,10 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):

         orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor)

-        self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr
+        self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True)  # mxn * nxr => mxr

         # all reduce mean compressor
-        instance_key = conf.instance_key
-        conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor')
-        self.compressor = self._all_reduce(self.compressor, conf)
-        conf.instance_key = instance_key
+        self.compressor = self._all_reduce(self.compressor, self.compressor_conf)

         return self._decompress(orthonormal_reduced_tensor)
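
For context on the `self.error` bookkeeping inherited from `CompressorEF`: whatever the low-rank step loses in one round is stored and added back onto the next round's gradient. A hypothetical NumPy sketch of that loop, not code from this repository:

import numpy as np

rng = np.random.default_rng(2)
grad = rng.normal(size=(6, 4))  # pretend the same gradient arrives each step
q = rng.normal(size=(4, 1))     # fixed rank-1 compressor, for illustration only
error = np.zeros_like(grad)

for _ in range(3):
    g = grad + error            # re-inject what compression lost last round
    p = g @ q                   # compress: P = (M + E) Q
    approx = p @ q.T            # decompress: P Q^T
    error = g - approx          # store this round's residual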

@@ -278,105 +281,26 @@ def _decompress(self, compressed_tensor: Tensor):

         Returns:
             Tensor, Context
         """
-        return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm = nxm
+        return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True)  # nxr * rxm => nxm

     @staticmethod
     def _modified_gram_schmidt(matrix):
-        '''
-        apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns
+        """
+        Apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns.
         Args:
             matrix (Tensor): the Tensor to orthogonalize.
         Returns:
-            matrix (Tensor)
-        '''
-        n, m = matrix.shape
+            matrix (Tensor)
+        """
+        _, m = matrix.shape

         for i in range(m):
-            v = matrix[:, i:i+1]
+            v = matrix[:, i:(i + 1)]
             v /= linalg_ops.norm_v2(v, axis=0)

-            rest = matrix[:,i+1:]
+            rest = matrix[:, (i + 1):]
             rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v
-            matrix = array_ops.concat([matrix[:,:i], v, rest],axis=1)
+            matrix = array_ops.concat([matrix[:, :i], v, rest], axis=1)
         return matrix


-# class PowerSGDCompressor(CompressorEF):
-#     """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727)."""
-
-#     def __init__(self, var_op_name, rank=1):
-#         self.rank = rank
-#         self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None
-#         super().__init__(var_op_name)
-
-#     def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):
-#         """
-#         Compress, reduce, and decompress a given tensor.
-
-#         Args:
-#             tensor (Tensor): the Tensor to reduce.
-#             conf (CollectiveOpsConfig): the config for Collective Ops.
-
-#         Returns:
-#             Reduced Tensor
-#         """
-#         if self.og_shape is None:
-#             self.og_shape = tensor.shape
-#             self.ndims = len(self.og_shape)
-
-#         # Check if rank 1 tensor (this shouldn't be called with sparse tensors)
-#         # Just reduce it if it is, no need to compress
-#         if self._is_1d:
-#             return self._all_reduce(tensor, conf)
-
-#         logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}")
-#         if self.ndims > 2:
-#             tensor = array_ops.reshape(tensor, [self.og_shape[0], -1])
-
-#         if self.compressor is None:
-#             self.new_shape = array_ops.shape_v2(tensor)
-#             self.compressor = random_ops.random_normal([self.new_shape[1], self.rank])
-
-#         if self.error is not None:
-#             tensor += self.error
-
-#         compressed_tensor = self._compress(tensor)
-#         self.error = tensor - self._decompress(compressed_tensor)
-
-#         # all reduce mean p
-#         reduced = self._all_reduce(compressed_tensor, conf)
-#         reduced = self._orthogonalize(reduced)
-
-#         # update compressor
-#         self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True)
-#         conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor")
-#         self.compressor = self._all_reduce(self.compressor, conf)
-#         return array_ops.reshape(self._decompress(reduced), self.og_shape) \
-#             if self.ndims > 2 else self._decompress(reduced)
-
-#     def _compress(self, tensor: Tensor):
-#         return math_ops.matmul(tensor, self.compressor)
-
-#     def _decompress(self, compressed_tensor: Tensor):
-#         return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True)
-
-#     @property
-#     def _is_1d(self):
-#         return self.ndims <= 1 or (
-#             self.ndims == 2 and any(d == 1 for d in self.og_shape)
-#         )
-
-#     @staticmethod
-#     def _orthogonalize(matrix):
-#         _, m = matrix.shape
-#         for i in range(m):
-#             v = matrix[:, i]
-#             v /= linalg_ops.norm_v2(v)
-#             v = array_ops.expand_dims_v2(v, 1)
-
-#             begin, rest = matrix[:, :i], matrix[:, (i + 1):]
-#             rest -= math_ops.matmul(v, rest, transpose_a=True) * v
-#             matrix = array_ops.concat([begin, v, rest], 1)
-#         return matrix
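
A direct NumPy transcription of the new `_modified_gram_schmidt` (NumPy calls replace the TF ops; otherwise the logic mirrors the diff above), handy for sanity-checking that the procedure really returns orthonormal columns:

import numpy as np

def modified_gram_schmidt(matrix):
    matrix = matrix.copy()
    _, m = matrix.shape
    for i in range(m):
        v = matrix[:, i:(i + 1)]
        v = v / np.linalg.norm(v, axis=0)  # normalize column i
        rest = matrix[:, (i + 1):]
        rest = rest - np.sum(v * rest, axis=0, keepdims=True) * v  # subtract projections onto v
        matrix = np.concatenate([matrix[:, :i], v, rest], axis=1)
    return matrix

q = modified_gram_schmidt(np.random.default_rng(1).normal(size=(8, 3)))
assert np.allclose(q.T @ q, np.eye(3))  # columns are orthonormal

Orthonormal columns are what let `_decompress` treat the reduced tensor as a proper low-rank basis when it forms P Q^T.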
