Skip to content

Commit

Permalink
Merge branch 'main' into line-search
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Meyer committed Nov 18, 2024
2 parents 67f3581 + 51a5f2a commit bea8b43
Show file tree
Hide file tree
Showing 18 changed files with 320 additions and 390 deletions.
116 changes: 0 additions & 116 deletions timpute/SVD_impute.py

This file was deleted.

File renamed without changes.
9 changes: 9 additions & 0 deletions timpute/figures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ..method_ALS import perform_ALS
from ..method_CLS import perform_CLS
from ..method_DO import perform_DO

METHODS = (perform_DO, perform_ALS, perform_CLS)
METHODNAMES = ["DO","ALS","CLS"]
SAVENAMES = ["zohar", "alter", "hms", "coh_response"]
DATANAMES = ['Covid serology', 'HIV serology', 'DyeDrop profiling', 'BC cytokine']
DROPS = (0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5)
15 changes: 15 additions & 0 deletions timpute/common.py → timpute/figures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,18 @@ def subplotLabel(axs):
""" Place subplot labels on figure. """
for ii, ax in enumerate(axs):
ax.text(-0.2, 1.2, ascii_lowercase[ii], transform=ax.transAxes, fontsize=16, fontweight="bold", va="top")


def set_boxplot_color(bp, color):
plt.setp(bp['boxes'], color=color)
plt.setp(bp['whiskers'], color=color)
plt.setp(bp['caps'], color=color)
plt.setp(bp['medians'], color=color)


def rgbs(color = 0, transparency = None):
color_rgbs = [sns.color_palette("bright")[0],sns.color_palette("bright")[8],sns.color_palette("bright")[3]]
if transparency is not None:
return tuple(list(color_rgbs[color]) + [transparency])
else:
return color_rgbs[color]
8 changes: 4 additions & 4 deletions timpute/figures/figure2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
from .figure_helper import *
from ..plot import *
from ..common import *
import math
import numpy as np
from .figure_helper import loadImputation
from .common import getSetup, subplotLabel, rgbs
from figures import METHODS, METHODNAMES, SAVENAMES, DATANAMES

# poetry run python -m timpute.figures.figure2

Expand Down
9 changes: 5 additions & 4 deletions timpute/figures/figure3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
from .runImputation import *
from ..plot import *
from ..common import *
import math
import numpy as np
from matplotlib.lines import Line2D
from .figure_helper import loadImputation
from .common import getSetup, subplotLabel, rgbs, set_boxplot_color
from figures import METHODS, METHODNAMES, SAVENAMES, DATANAMES, DROPS

# poetry run python -m timpute.figures.figure3

Expand Down
7 changes: 3 additions & 4 deletions timpute/figures/figure4.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import numpy as np
from .runImputation import *
from ..plot import *
from ..common import *
import math
from .figure_helper import loadImputation
from .common import getSetup, subplotLabel, rgbs
from figures import METHODS, METHODNAMES, SAVENAMES, DATANAMES

# poetry run python -m timpute.figures.figure4

Expand Down
6 changes: 3 additions & 3 deletions timpute/figures/figure5.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
from .figure_data import bestComps
from .figure_helper import *
from ..plot import *
from ..common import *
from .figure_helper import loadImputation
from .common import getSetup, subplotLabel, rgbs
from figures import METHODS, METHODNAMES, SAVENAMES, DATANAMES
# from matplotlib.legend_handler import HandlerErrorbar

# poetry run python -m timpute.figures.figure5
Expand Down
9 changes: 3 additions & 6 deletions timpute/figures/figure_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import numpy as np
import os

from .figure_helper import *
import numpy as np
from .figure_helper import runImputation, loadImputation
from ..generateTensor import generateTensor
from ..plot import *
from ..common import *
import pickle
from figures import METHODS, METHODNAMES, SAVENAMES, DROPS

# poetry run python -m timpute.figures.figure_data

Expand Down
10 changes: 0 additions & 10 deletions timpute/figures/figure_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
from ..decomposition import Decomposition
from ..tracker import Tracker

from ..method_ALS import perform_ALS
from ..method_CLS import perform_CLS
from ..method_DO import perform_DO
METHODS = (perform_DO, perform_ALS, perform_CLS)
METHODNAMES = ["DO","ALS","CLS"]
SAVENAMES = ["zohar", "alter", "hms", "coh_response"]
DATANAMES = ['Covid serology', 'HIV serology', 'DyeDrop profiling', 'BC cytokine']
LINESTYLES = ('dashdot', (0,(1,1)), 'solid', (3,(3,1,1,1,1,1)), 'dotted', (0,(5,1)))
DROPS = (0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5)


def runImputation(data:np.ndarray,
max_rr:int,
Expand Down
8 changes: 6 additions & 2 deletions timpute/figures/memUsage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# import psutil
import pickle
import os
import numpy as np
import resource
import argparse
from .figure_helper import *

from figures import METHODS, METHODNAMES, SAVENAMES
from ..decomposition import Decomposition
from ..generateTensor import generateTensor
import pickle

# poetry run python -m timpute.figures.dataUsage

Expand Down
7 changes: 4 additions & 3 deletions timpute/figures/supplements.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pickle
from ..plot import *
from ..common import *
from figures import METHODNAMES, SAVENAMES, DATANAMES, DROPS
import numpy as np
import pandas as pd
from .figure_helper import *
from .common import getSetup, rgbs
from ..decomposition import Decomposition

# poetry run python -m timpute.figures.supplements

Expand Down
2 changes: 1 addition & 1 deletion timpute/generateTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tensordata.atyeo import data as atyeo
from tensordata.zohar import data as zohar
from tensordata.alter import data as alter
from .import_hmsData import hms_tensor
from .data.import_hmsData import hms_tensor

def generateTensor(type=None, r=6, shape=(10,10,10), scale=2, distribution='gamma', par=2, missingness=0.0, noise_scale=50):
"""
Expand Down
103 changes: 102 additions & 1 deletion timpute/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,108 @@
import tensorly as tl
from tensorly.tenalg import svd_interface
from tensorly.random import random_cp
from .SVD_impute import IterativeSVD

class IterativeSVD(object):
def __init__(
self,
rank,
convergence_threshold=1e-7,
max_iters=500,
random_state=None,
min_value=None,
max_value=None,
verbose=False):
self.min_value=min_value
self.max_value=max_value
self.rank = rank
self.max_iters = max_iters
self.convergence_threshold = convergence_threshold
self.verbose = verbose
self.random_state = random_state

def clip(self, X):
"""
Clip values to fall within any global or column-wise min/max constraints
"""
X = np.asarray(X)
if self.min_value is not None:
X[X < self.min_value] = self.min_value
if self.max_value is not None:
X[X > self.max_value] = self.max_value
return X

def prepare_input_data(self, X):
"""
Check to make sure that the input matrix and its mask of missing
values are valid. Returns X and missing mask.
"""
if X.dtype != "f" and X.dtype != "d":
X = X.astype(float)

assert X.ndim == 2
missing_mask = np.isnan(X)
assert not missing_mask.all()
return X, missing_mask

def fit_transform(self, X, y=None):
"""
Fit the imputer and then transform input `X`
Note: all imputations should have a `fit_transform` method,
but only some (like IterativeImputer in sklearn) also support inductive
mode using `fit` or `fit_transform` on `X_train` and then `transform`
on new `X_test`.
"""
X_original, missing_mask = self.prepare_input_data(X)
observed_mask = ~missing_mask
X_filled = X_original.copy()
X_filled[missing_mask] = 0.0
assert isinstance(X_filled, np.ndarray)
X_result = self.solve(X_filled, missing_mask)
assert isinstance(X_result, np.ndarray)
X_result = self.clip(np.asarray(X_result))
X_result[observed_mask] = X_original[observed_mask]
return X_result

def _converged(self, X_old, X_new, missing_mask):
F32PREC = np.finfo(np.float32).eps
# check for convergence
old_missing_values = X_old[missing_mask]
new_missing_values = X_new[missing_mask]
difference = old_missing_values - new_missing_values
ssd = np.sum(difference ** 2)
old_norm_squared = (old_missing_values ** 2).sum()
# edge cases
if old_norm_squared == 0 or \
(old_norm_squared < F32PREC and ssd > F32PREC):
return False
else:
return (ssd / old_norm_squared) < self.convergence_threshold

def solve(self, X, missing_mask):
observed_mask = ~missing_mask
X_filled = X
for i in range(self.max_iters):
curr_rank = self.rank
self.U, S, V = svd_interface(matrix=X_filled, n_eigenvecs=curr_rank, random_state=self.random_state)
X_reconstructed = self.U @ np.diag(S) @ V
X_reconstructed = self.clip(X_reconstructed)

# Masked mae
mae = np.mean(np.abs(X[observed_mask] - X_reconstructed[observed_mask]))

if self.verbose:
print(
"[IterativeSVD] Iter %d: observed MAE=%0.6f" % (
i + 1, mae))
converged = self._converged(
X_old=X_filled,
X_new=X_reconstructed,
missing_mask=missing_mask)
X_filled[missing_mask] = X_reconstructed[missing_mask]
if converged:
break
return X_filled


def initialize_fac(tensor: np.ndarray, rank: int, method='svd'):
"""Initialize factors used in `parafac`.
Expand Down
Loading

0 comments on commit bea8b43

Please sign in to comment.