Skip to content

Commit

Permalink
basic SAMC implementation functional
Browse files Browse the repository at this point in the history
  • Loading branch information
tkchafin committed Apr 10, 2024
1 parent c0fc71e commit 602f82f
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 46 deletions.
165 changes: 165 additions & 0 deletions src/resistnet/CFPT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import sys
import numpy as np
from scipy.sparse import diags, eye, csr_matrix
from scipy.sparse.linalg import spsolve, lgmres, cg

# 1. Scale dest row off-diagonals by absorption
# 2. fill diagonals
# 3. Extract Qj and qj
# 4. Negate Qj and +1 diagonals

def CFPT(Q, R, edge_site_indices):
N = len(edge_site_indices)
cfpt_matrix = np.zeros((N, N))

for i, dest in enumerate(edge_site_indices):
Q_temp = Q.copy().tolil() # LIL format for easier row manipulation
absorption_factor = R[dest]

# Get indices and data for the dest row
for j in range(Q.shape[1]):
if j != dest: # Avoid altering the diagonal element at this step
Q_temp[dest, j] *= (1 - absorption_factor)
Q_temp = Q_temp.tocsr()
# Ensure diagonals make row sums to 1
Q_temp = _set_diags(Q_temp)

for j, orig in enumerate(edge_site_indices):
if orig != dest:
# Extract Qj and qj for current destination
mask = np.ones(Q.shape[0], dtype=bool)
mask[dest] = False
Qj = Q_temp[mask, :][:, mask]
qj = Q_temp[mask, dest]

# Apply numeric transformations to Qj for solving
Qj.data *= -1
Qj.setdiag(Qj.diagonal() + 1)

# Solve for passage times using the adjusted Qj
try:
solution, info = lgmres(Qj, qj.toarray().flatten())
if info == 0:
adjusted_index = np.where(mask)[0].tolist().index(orig)
cfpt_matrix[j, i] = solution[adjusted_index]
else:
print(f"Convergence issue with dest={dest}, orig={orig}, info={info}")
cfpt_matrix[j, i] = np.nan
except Exception as e:
print(f"Solver failed for dest={dest}, orig={orig}: {e}")
cfpt_matrix[j, i] = np.nan
return cfpt_matrix

# def CFPT(Q, R, edge_site_indices):
# """
# Calculate conditional first passage times for a Markov chain represented by matrix Q
# with absorption probabilities R for each state specified in edge_site_indices.

# Parameters:
# - Q (csr_matrix): Transition probability matrix without absorption (diagonals adjusted).
# - R (numpy.ndarray): Absorption probabilities for each state.
# - edge_site_indices (list): Indices of states for which to calculate CFPT.

# Returns:
# - cfpt_matrix (numpy.ndarray): Matrix of conditional first passage times.
# """

# # create Q_scaled
# # if more efficient can stack these
# abs_vec = np.ravel(R)
# # NOTE: I don't think we need the below part currently
# # for each row (1 - abs_vec[i]) * row_values / np.sum(row_values)

# # set diagonals
# Q = _set_diags(Q)

# # validate Q
# row_sums = Q.sum(axis=1).A1
# if not np.allclose(row_sums, 1):
# raise ValueError("Row sums in Q should sum to 1.")

# # in-place modifications for downstream steps
# _opt_prep(Q)

# # compute cfpt
# Np = len(edge_site_indices)
# cfpt_matrix = np.zeros((Np, Np))

# for i, dest in enumerate(edge_site_indices):
# # Mask to exclude 'dest'
# mask = np.ones(Q.shape[0], dtype=bool)
# mask[dest] = False

# # Create Qj by excluding the 'dest' state
# Qj = Q[mask, :][:, mask]

# # Create qj as the 'dest' column, excluding itself
# qj = Q[mask, dest]

# for j, orig in enumerate(edge_site_indices):
# if orig != dest:
# try:
# passage_times = spsolve(Qj, qj.toarray().flatten())
# print(passage_times)
# cfpt_matrix[j, i] = passage_times[orig]
# except Exception as e:
# print(f"Solver failed for dest={dest}, orig={orig}: {e}")
# cfpt_matrix[j, i] = np.nan
# print(cfpt_matrix)
# sys.exit()
# return cfpt_matrix

def _set_diags(sm, offset=None):
"""
Adjusts the diagonal elements of a sparse matrix to ensure that each
row sums to 1, optionally incorporating an offset subtraction from each
diagonal element.
This function modifies the input sparse matrix in-place.
Parameters:
- sm (csr_matrix): The sparse matrix whose diagonals are to be adjusted.
- offset (numpy.ndarray, optional): An array of values to subtract from
each diagonal. If `None`, no offset is subtracted.
Returns:
- None: The matrix `sm` is modified in-place.
"""
row_sums = sm.sum(axis=1).A1
d = 1 - row_sums
if offset is not None:
d -= offset
dm = diags(d, 0, shape=sm.shape, format='csr')
sm += dm
_validate_row_sums(sm)
return sm


def _validate_row_sums(sm):
"""
Validates that each row of a sparse matrix sums to 1.
Parameters:
- sm (csr_matrix): The sparse matrix to validate.
Raises:
- ValueError: If any row sum does not equal 1.
"""
row_sums_after = sm.sum(axis=1).A1
if not np.allclose(row_sums_after, 1):
raise ValueError("Row sums do not sum to 1.")


def _opt_prep(sm):
"""
Prepares a scipy.sparse matrix for optimization by negating non-zero
elements and incrementing diagonal elements by 1, in-place.
Parameters:
- sm (scipy.sparse matrix): The sparse matrix to be modified in-place.
Returns:
- None
"""
sm.data = -sm.data
sm.setdiag(sm.diagonal() + 1)
3 changes: 3 additions & 0 deletions src/resistnet/model_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,9 @@ def start_workers(self, threads):
worker_args['adj'] = self.resistance_network._adj
worker_args['origin'] = self.resistance_network._origin
worker_args['R'] = self.resistance_network._R
worker_args[
'edge_site_indices'
] = self.resistance_network._edge_site_indices

worker_process = Process(
target=self.worker_task,
Expand Down
15 changes: 14 additions & 1 deletion src/resistnet/resist_dist.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import itertools
import pandas as pd
import sys
import numpy as np

import resistnet.MLPE as mlpe_rga

import resistnet.CFPT as cfpt_samc

def parsePairwise(points, inc_matrix, multi, gendist):
"""
Expand Down Expand Up @@ -59,3 +60,15 @@ def effectiveResistanceMatrix(points, inc_matrix, edge_resistance):
np.fill_diagonal(r, 0.0)

return r

def conditionalFirstPassTime(Q, R, sites_i, gendist):

# get cfpt matrix
cfpt = cfpt_samc.CFPT(Q, R, sites_i)
cfpt = np.array(cfpt)

# fit MLPE
res = mlpe_rga.MLPE_R(gendist, cfpt, scale=True)
return cfpt, res


77 changes: 32 additions & 45 deletions src/resistnet/samc_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import resistnet.utils as utils
import resistnet.transform as trans
import resistnet.resist_dist as rd
from resistnet.resistance_network import ResistanceNetwork

class ResistanceNetworkSAMC(ResistanceNetwork):
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self, network, shapefile, sizes, coords, variables, inmat,
self._Kd = None
self._edge_absorption = None
self._R = None
self._edge_site_indices = None

# Initialization methods
self.initialize_network()
Expand Down Expand Up @@ -319,6 +321,8 @@ def calculate_absorption(self):
edge_id: idx for idx, edge_id in enumerate(self._edge_order)}

sites = list(self._points_snapped.keys())
print(sites)
tips = [None] * len(sites)
# Iterate over edges in the graph
for u, v, edge_data in self._K.edges(data=True):
# Get the edge ID from edge_data using the identifier column name
Expand All @@ -335,10 +339,12 @@ def calculate_absorption(self):
absorption_value = 1 / (2 * pop_size)
idx = edge_to_idx[edge_id]
self._edge_absorption[idx] = absorption_value
tips[sites.index(u)] = idx
else:
print(f"Population size missing for {u}")
self._R = np.array(self._edge_absorption)
self._R = self._R.reshape(-1, 1)
self._edge_site_indices = tips

def read_sizes(self):
"""
Expand Down Expand Up @@ -418,6 +424,7 @@ def evaluate(self, individual):
evaluation.
"""
first = True
multi = None

# Compute any transformations
for i, variable in enumerate(self._predictors.columns):
Expand Down Expand Up @@ -455,52 +462,31 @@ def evaluate(self, individual):
var_m.data *= individual[1::5][i]
multi.data += var_m.data

# minmax scale 0-1
multi.data = utils.minmax(multi.data)

# inverse to get transition rates
# avoid a division by zero error by setting zero to smallest non-zero
non_zero_min = np.min(multi.data[np.nonzero(multi.data)])
multi.data[multi.data == 0] = non_zero_min
multi.data = utils.minmax(1 / multi.data)

# convert to dense array
Q = multi.toarray()

# fill diagonals (row sums should be 1)
np.fill_diagonal(Q, 0)
row_sums = Q.sum(axis=1)
np.fill_diagonal(Q, 1 - row_sums)

# append R and compute P matrix
bottom_row = np.zeros((1, Q.shape[1]))
bottom_row = np.append(bottom_row, 1).reshape(1, -1)
P = np.hstack((Q, self._R))
P = np.vstack((P, bottom_row))

# If no layers are selected, return a zero fitness
if first:
if multi is None:
return float('-inf'), None
else:
pass
# Compute P matrix

# fit SAMC model and compute cfpt matrix

# compute likelihoods with MLPE

# multi = trans.rescaleCols(multi, 1, 10)
# r, res = rd.parsePairwise(
# self._points_snapped, self._inc, multi, self._gendist
# )
# # fitness = res[self.fitmetric][0]
# fitness = res[self.fitmetric].iloc[0]
# res = list(res.iloc[0])

# Return fitness value and results
sys.exit()
return 0.0, [0.0,0.0,0.0,0.0,0.0]
#return (fitness, res)
# complete Q matrix
# minmax scale 0-1
multi.data = utils.minmax(multi.data)

# inverse to get transition rates
# avoid divide-by-zero by setting zero to smallest non-zero element
non_zero_min = np.min(multi.data[np.nonzero(multi.data)])
multi.data[multi.data == 0] = non_zero_min
multi.data = utils.minmax(1 / multi.data)

# compute cfpt matrix
cfpt, res = rd.conditionalFirstPassTime(
multi, self._R, self._edge_site_indices, self._gendist)

# plot cfpt matrix pairwise regression against genetic distance
# matrix held as self._gendist
# these can be assumed to have the same order
# fitness = res[self.fitmetric][0]
fitness = res[self.fitmetric].iloc[0]
res = list(res.iloc[0])
return (fitness, res)

def _compute_transition(self, var, directional=False):
data = []
Expand Down Expand Up @@ -628,7 +614,7 @@ def __init__(self, network, pop_agg, reachid_col, length_col,
variables, agg_opts, fitmetric, posWeight, fixWeight,
allShapes, fixShape, min_weight, max_shape, inc, point_coords,
points_names, points_snapped, points_labels, predictors,
edge_order, gendist, adj, origin, R):
edge_order, gendist, adj, origin, R, edge_site_indices):

self.network = network
self.pop_agg = pop_agg
Expand Down Expand Up @@ -659,4 +645,5 @@ def __init__(self, network, pop_agg, reachid_col, length_col,
self._edge_order = edge_order
self._gendist = gendist
self._adj = adj
self._R = R
self._R = R
self._edge_site_indices = edge_site_indices

0 comments on commit 602f82f

Please sign in to comment.