Skip to content

Commit

Permalink
Merge pull request #167 from cjb873/sindy
Browse files Browse the repository at this point in the history
SINDy
  • Loading branch information
drgona authored Jul 2, 2024
2 parents 0ce896e + 21c214e commit 70b1598
Show file tree
Hide file tree
Showing 6 changed files with 1,764 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ Part 5: Using Cvxpylayers for differentiable projection onto the polytopic feasi
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/master/examples/ODEs/Part_8_nonauto_DeepKoopman.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 8: control-oriented Deep Koopman operator

+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/master/examples/ODEs/Part_9_SINDy.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 9: Sparse Identification of Nonlinear Dynamics (SINDy)


### Physics-Informed Neural Networks (PINNs) for Partial Differential Equations (PDEs)
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/master/examples/PDEs/Part_1_PINN_DiffusionEquation.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 1: Diffusion Equation
Expand Down
1,521 changes: 1,521 additions & 0 deletions examples/ODEs/Part_9_SINDy.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/ODEs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ showcasing system identification capabilities for ordinary differential equation
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/master/examples/ODEs/Part_8_nonauto_DeepKoopman.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 8: control-oriented Deep Koopman operator

+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/master/examples/ODEs/Part_9_SINDy.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 9: Sparse Identification of Nonlinear Dynamics (SINDy)


## System Identification

Expand Down
Binary file added examples/ODEs/figs/sindy.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
180 changes: 180 additions & 0 deletions src/neuromancer/dynamics/library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import numpy as np
import itertools
import torch

class FunctionLibrary:

def __init__(
self,
functions,
n_features,
function_names=None
):
"""
:param functions: (list) a one-dimensional list of callable functions.
All callables must be able to operate on the
columns of a two-dimensional torch.Tensor of
floating point numbers.
:param n_features: (int) number of columns of data matrix
:param function_names: (list) name for each function, must either be None
or a list of the same length as the functions
"""
if function_names is not None:
assert len(functions) == len(function_names), "Must have one name for each function"

assert isinstance(functions, list), "Functions must be passed as list"

self.library = functions

self.shape = (len(self.library), n_features)
self.function_names = function_names

def evaluate(self, x):
"""
:param x: (torch.Tensor) an array of values to evaluate the functions at
:return: (torch.Tensor, shape=[[# of rows of X, number of functions) the
functions evaluated at every single time step of the data
"""
output = torch.zeros(size=(x.shape[0], self.shape[0]))

for i in range(self.shape[0]):
output[:, i] = self.library[i](x)

return output

def __str__(self):
"""
:return: (str) all of the names provided to the library, or generic function names
"""
out_str = ""
if self.function_names is not None:
for name in self.function_names:
out_str += f"{name}, "

else:
for i in range(self.shape[0]):
out_str += f"f{i}, "
return out_str[:-2]


class PolynomialLibrary(FunctionLibrary):
def __init__(
self,
n_features,
max_degree=2
):
"""
:param n_features: (int) the number of features (columns) in the input dataset
:param max_degree: (int) the maximum total degree of any polynomial
(i.e. max_degree=3 -> x^2y)
"""
self.max_degree = max_degree
lib, function_names = self.__create_library(n_features)
super().__init__(lib, n_features, function_names)

def __create_library(self, n_features):
"""
:param n_features: (int) the number of features (columns) in the input dataset
:return: (list) a row vector of all of the polynomial functions
:return: (list) a list of all of the names of the functions
"""
funs_list = [(lambda y: lambda X: X[:, y])(i) for i in range(n_features)]
vars_list = [f"x{i}" for i in range(n_features)]
all_combos = [(lambda X: torch.ones(size=(X.shape[0],)),)]
all_names = []

for i in range(1, self.max_degree+1):
combos = list(itertools.combinations_with_replacement(funs_list, i))
names = list(itertools.combinations_with_replacement(vars_list, i))

for j in range(len(combos)):
all_combos.append(combos[j])

for j in range(len(names)):
all_names.append(names[j])

library = all_combos
names = self.__convert(all_names)
return library, names

def __convert(self, names):
"""
:param names: (list) a list of tuples of all of the names of terms that were combined
:return: (list) the names converted into a more intrepreted form (i.e. x*x*x -> x^3)
"""
return_names = ["1"]
for func in names:
name = ""
for term in func:
if term not in name:
name += f"{term}*"
else:
if f"{term}^" in name:
degree = int(name[name.find(f"{term}^") + 3])
name = name.replace(f"{term}^{degree}", f"{term}^{degree+1}")
else:
name = name.replace(term, f"{term}^2")
return_names.append(name[:-1])
return return_names

def evaluate(self, X):
"""
:param X: (torch.Tensor) the two-dimensional dataset to put through the library
:return: (torch.Tensor, [# of rows of X, number of functions]) the library evaluated at X
"""
output = torch.ones((X.shape[0], self.shape[0]))
for i in range(self.shape[0]):
for func in self.library[i]:
output[:, i] *= func(X)

return output


class FourierLibrary(FunctionLibrary):
def __init__(
self,
n_features,
max_freq=2,
include_sin=True,
include_cos=True
):
"""
:param n_features: (int) the number of features of the dataset (columns of data matrix)
:param max_freq: (int) the max multiplier for the frequency
:param include_sin: (bool) include the sine function
:param include_cos: (bool) include the cosine function
"""
self.max_freq = max_freq
self.include_sin = include_sin
self.include_cos = include_cos
lib, function_names = self.__create_library(n_features)
super().__init__(lib, n_features, function_names)

def __create_library(self, n_features):
"""
:param n_features: (int) the number of input features of the dataset (columns of data matrix)
:return: (list) an array of functions of sines and cosines
:return: (list) a list of the names of each function (i.e. sin(2*x0))
"""
sines = []
sin_names = []
cosines = []
cos_names = []
if self.include_sin:
sines = [(lambda z: [(lambda y: lambda X: torch.sin(y * X[:,z]))(i) \
for i in range(1,self.max_freq+1)]) (j) for j in \
range(n_features)]
sin_names = [[f"sin({j}*x{i})" for j in \
range(1, self.max_freq+1)] for i in range(n_features)]

if self.include_cos:
cosines = [(lambda z: [(lambda y: lambda X: torch.cos(y * X[:,z]))(i)\
for i in range(1,self.max_freq+1)]) (j) for j in \
range(n_features)]
cos_names = [[f"cos({j}*x{i})" for j in \
range(1, self.max_freq+1)] for i in range(n_features)]

library = sum(sines, []) + sum(cosines, [])
function_names = sum(sin_names, []) + sum(cos_names, [])

return library, function_names
57 changes: 57 additions & 0 deletions src/neuromancer/dynamics/ode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from neuromancer.dynamics.library import FunctionLibrary


class SSM(nn.Module):
Expand Down Expand Up @@ -156,6 +157,62 @@ def coupling_physics(self, x, *args):
return dx


class SINDy(ODESystem):
"""
Sparse Identification of Nonlinear Dynamics
Reference: https://www.pnas.org/doi/10.1073/pnas.1517384113
"""
def __init__(
self,
library,
threshold=1e-2
):
"""
:param library: (FunctionLibrary) the library of candidate functions
:param threshold: (float) all functions with coefficients lower than this are omitted
"""
assert isinstance(library, FunctionLibrary), "Must be valid library"

super().__init__(library.shape[1], library.shape[1])

self.library = library
self.threshold = threshold
init_coef = torch.rand(self.library.shape)
self.coef = torch.nn.Parameter(init_coef, requires_grad=True)
self.float()

def ode_equations(self, x):
"""
:param x: (torch.tensor) time series data
"""
assert x.ndim == 2, "Input must not be empty"
assert x.shape[1] == self.library.shape[1], "Must have same number of states as insize"
lib_eval = self.library.evaluate(x)
output = torch.matmul(lib_eval, self.coef)
return output


def __str__(self):
"""
return: (str) a list of the linear combinations of candidate functions for each state variable
"""
f_names = self.library.__str__()
f_names = f_names.split(", ")
return_str = ""

for i in range(self.library.shape[1]):
return_str += f"dx{i}/dt = "
for j in range(len(f_names)):
coef = self.coef[j, i]
if torch.abs(coef) > self.threshold:
func = f_names[j]
return_str += f"{coef:.3f}*{func} + "
return_str = return_str[:-2]
return_str += "\n"

return return_str


class TwoTankParam(ODESystem):
def __init__(self, insize=4, outsize=2):
"""
Expand Down

0 comments on commit 70b1598

Please sign in to comment.