Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/tree functions #177

Merged
merged 15 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions discretize/TensorMesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,18 @@ def _repr_html_(self):
fmt = "<table>\n"
fmt += " <tr>\n"
fmt += " <td style='font-weight: bold; font-size: 1.2em; text-align"
fmt += ": center;' colspan='3'>{}</td\n>".format(type(self).__name__)
fmt += ": center;' colspan='3'>{}</td>\n".format(type(self).__name__)
fmt += " <td style='font-size: 1.2em; text-align: center;'"
fmt += "colspan='4'>{:,} cells</td>\n".format(self.nC)
fmt += " </tr>\n"

fmt += " <tr>\n"
fmt += " <th></th\n>"
fmt += " <th></th\n>"
fmt += " <th colspan='2'"+style+">MESH EXTENT</th\n>"
fmt += " <th colspan='2'"+style+">CELL WIDTH</th\n>"
fmt += " <th"+style+">FACTOR</th\n>"
fmt += " </tr\n>"
fmt += " <th></th>\n"
fmt += " <th></th>\n"
fmt += " <th colspan='2'"+style+">MESH EXTENT</th>\n"
fmt += " <th colspan='2'"+style+">CELL WIDTH</th>\n"
fmt += " <th"+style+">FACTOR</th>\n"
fmt += " </tr>\n"

fmt += " <tr>\n"
fmt += " <th"+style+">dir</th>\n"
Expand Down
253 changes: 205 additions & 48 deletions discretize/TreeMesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .InnerProducts import InnerProducts
from .MeshIO import TreeMeshIO
from . import utils
from .tree_ext import _TreeMesh
from .tree_ext import _TreeMesh, TreeCell
import numpy as np
from scipy.spatial import Delaunay
import scipy.sparse as sp
Expand All @@ -113,55 +113,132 @@ def is_pow2(num): return ((num & (num - 1)) == 0) and num != 0
# Now can initialize cpp tree parent
_TreeMesh.__init__(self, self.h, self.x0)

def __str__(self):
outStr = ' ---- {0!s}TreeMesh ---- '.format(
('Oc' if self.dim == 3 else 'Quad')
)

def printH(hx, outStr=''):
i = -1
while True:
i = i + 1
if i > hx.size:
break
elif i == hx.size:
break
h = hx[i]
n = 1
for j in range(i+1, hx.size):
if hx[j] == h:
n = n + 1
i = i + 1
else:
break
if n == 1:
outStr += ' {0:.2f}, '.format(h)
else:
outStr += ' {0:d}*{1:.2f}, '.format(n, h)
return outStr[:-1]

if self.dim == 2:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
elif self.dim == 3:
outStr += '\n x0: {0:.2f}'.format(self.x0[0])
outStr += '\n y0: {0:.2f}'.format(self.x0[1])
outStr += '\n z0: {0:.2f}'.format(self.x0[2])
outStr += printH(self.hx, outStr='\n hx:')
outStr += printH(self.hy, outStr='\n hy:')
outStr += printH(self.hz, outStr='\n hz:')
outStr += '\n nC: {0:d}'.format(self.nC)
outStr += '\n Fill: {0:2.2f}%'.format((self.fill*100))
return outStr
def __repr__(self):
"""Plain text representation."""
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))

top = "\n"+mesh_name+": {0:2.2f}% filled\n\n".format(self.fill*100)

# Number of cells per level
level_count = self._count_cells_per_index()
non_zero_levels = np.nonzero(level_count)[0]
cell_display = ["Level : Number of cells"]
cell_display.append("-----------------------")
for level in non_zero_levels:
cell_display.append("{:^5} : {:^15}".format(level, level_count[level]))
cell_display.append("-----------------------")
cell_display.append("Total : {:^15}".format(self.nC))

extent_display = [" Mesh Extent "]
extent_display.append(" min , max ")
extent_display.append(" ---------------------------")
dim_label = {0:'x',1:'y',2:'z'}
for dim in range(self.dim):
n_vector = getattr(self, 'vectorN'+dim_label[dim])
extent_display.append("{}: {:^13},{:^13}".format(dim_label[dim], n_vector[0], n_vector[-1]))

for i, line in enumerate(extent_display):
if i==len(cell_display):
cell_display.append(" "*(len(cell_display[0])-3-len(line)))
cell_display[i] += 3*" " + line

h_display = [' Cell Widths ']
h_display.append(" min , max ")
h_display.append("-"*(len(h_display[0])))
h_gridded = self.h_gridded
mins = np.min(h_gridded,axis=0)
maxs = np.max(h_gridded,axis=0)
for dim in range(self.dim):
h_display.append("{:^10}, {:^10}".format(mins[dim], maxs[dim]))

for i, line in enumerate(h_display):
if i==len(cell_display):
cell_display.append(" "*len(cell_display[0]))
cell_display[i] += 3*" " + line

return top+"\n".join(cell_display)

def _repr_html_(self):
"""html representation"""
mesh_name = '{0!s}TreeMesh'.format(('Oc' if self.dim==3 else 'Quad'))
level_count = self._count_cells_per_index()
non_zero_levels = np.nonzero(level_count)[0]
dim_label = {0:'x',1:'y',2:'z'}
h_gridded = self.h_gridded
mins = np.min(h_gridded,axis=0)
maxs = np.max(h_gridded,axis=0)

style = " style='padding: 5px 20px 5px 20px;'"
#Cell level table:
cel_tbl = "<table>\n"
cel_tbl += "<tr>\n"
cel_tbl += "<th"+style+">Level</th>\n"
cel_tbl += "<th"+style+">Number of cells</th>\n"
cel_tbl += "</tr>\n"
for level in non_zero_levels:
cel_tbl += "<tr>\n"
cel_tbl += "<td"+style+">{}</td>\n".format(level)
cel_tbl += "<td"+style+">{}</td>\n".format(level_count[level])
cel_tbl += "</tr>\n"
cel_tbl += "<tr>\n"
cel_tbl += "<td style='font-weight: bold; padding: 5px 20px 5px 20px;'> Total </td>\n"
cel_tbl += "<td"+style+"> {} </td>\n".format(self.nC)
cel_tbl += "</tr>\n"
cel_tbl += "</table>\n"

det_tbl = "<table>\n"
det_tbl += "<tr>\n"
det_tbl += "<th></th>\n"
det_tbl += "<th"+style+" colspan='2'>Mesh extent</th>\n"
det_tbl += "<th"+style+" colspan='2'>Cell widths</th>\n"
det_tbl += "</tr>\n"

det_tbl += "<tr>\n"
det_tbl += "<th></th>\n"
det_tbl += "<th"+style+">min</th>\n"
det_tbl += "<th"+style+">max</th>\n"
det_tbl += "<th"+style+">min</th>\n"
det_tbl += "<th"+style+">max</th>\n"
det_tbl += "</tr>\n"
for dim in range(self.dim):
n_vector = getattr(self, 'vectorN'+dim_label[dim])
det_tbl += "<tr>\n"
det_tbl += "<td"+style+">{}</td>\n".format(dim_label[dim])
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[0])
det_tbl += "<td"+style+">{}</td>\n".format(n_vector[-1])
det_tbl += "<td"+style+">{}</td>\n".format(mins[dim])
det_tbl += "<td"+style+">{}</td>\n".format(maxs[dim])
det_tbl += "</tr>\n"
det_tbl += "</table>\n"

full_tbl = "<table>\n"
full_tbl += "<tr>\n"
full_tbl += "<td style='font-weight: bold; font-size: 1.2em; text-align: center;'>{}</td>\n".format(mesh_name)
full_tbl += "<td style='font-size: 1.2em; text-align: center;' colspan='2'>{0:2.2f}% filled</td>\n".format(100*self.fill)
full_tbl += "</tr>\n"
full_tbl += "<tr>\n"

full_tbl += "<td>\n"
full_tbl += cel_tbl
full_tbl += "</td>\n"

full_tbl += "<td>\n"
full_tbl += det_tbl
full_tbl += "</td>\n"

full_tbl += "</tr>\n"
full_tbl += "</table>\n"

return full_tbl

@property
def vntF(self):
"""Total number of hanging and non-hanging faces in a [nx,ny,nz] form"""
return [self.ntFx, self.ntFy] + ([] if self.dim == 2 else [self.ntFz])

@property
def vntE(self):
"""Total number of hanging and non-hanging edges in a [nx,ny,nz] form"""
return [self.ntEx, self.ntEy] + ([] if self.dim == 2 else [self.ntEz])

@property
Expand Down Expand Up @@ -242,7 +319,7 @@ def cellGradx(self):
@property
def cellGrady(self):
"""
Cell centered Gradient operator in y-direction (Gradx)
Cell centered Gradient operator in y-direction (Grady)
Grad = sp.vstack((Gradx, Grady, Gradz))
"""
if getattr(self, '_cellGrady', None) is None:
Expand Down Expand Up @@ -271,6 +348,8 @@ def cellGradz(self):
Cell centered Gradient operator in z-direction (Gradz)
Grad = sp.vstack((Gradx, Grady, Gradz))
"""
if self.dim == 2:
raise TypeError("z derivative not defined in 2D")
if getattr(self, '_cellGradz', None) is None:

nFx = self.nFx
Expand Down Expand Up @@ -310,21 +389,98 @@ def faceDivz(self):
return self._faceDivz

def point2index(self, locs):
"""Finds cells that contain the given points.
Returns an array of index values of the cells that contain the given
points

Parameters
----------
locs: array_like of shape (N, dim)
points to search for the location of

Returns
-------
numpy.array of integers of length(N)
Cell indices that contain the points
"""
locs = utils.asArray_N_x_Dim(locs, self.dim)

inds = np.empty(locs.shape[0], dtype=np.int64)
for ind, loc in enumerate(locs):
inds[ind] = self._get_containing_cell_index(loc)
inds = self._get_containing_cell_indexes(locs)
return inds

def cell_levels_by_index(self, indices):
"""Fast function to return a list of levels for the given cell indices

Parameters
----------
index: array_like of length (N)
Cell indexes to query

Returns
-------
numpy.array of length (N)
Levels for the cells.
"""

return self._cell_levels_by_indexes(indices)


def getInterpolationMat(self, locs, locType, zerosOutside=False):
""" Produces interpolation matrix

Parameters
----------
loc : numpy.ndarray
Location of points to interpolate to

locType: str
What to interpolate

locType can be::

'Ex' -> x-component of field defined on edges
'Ey' -> y-component of field defined on edges
'Ez' -> z-component of field defined on edges
'Fx' -> x-component of field defined on faces
'Fy' -> y-component of field defined on faces
'Fz' -> z-component of field defined on faces
'N' -> scalar field defined on nodes
'CC' -> scalar field defined on cell centers

Returns
-------
scipy.sparse.csr_matrix
M, the interpolation matrix

"""
locs = utils.asArray_N_x_Dim(locs, self.dim)
if locType not in ['N', 'CC', "Ex", "Ey", "Ez", "Fx", "Fy", "Fz"]:
raise Exception('locType must be one of N, CC, Ex, Ey, Ez, Fx, Fy, or Fz')

if self.dim == 2 and locType in ['Ez', 'Fz']:
raise Exception('Unable to interpolate from Z edges/face in 2D')

locs = np.require(np.atleast_2d(locs), dtype=np.float64, requirements='C')

if locType == 'N':
Av = self._getNodeIntMat(locs, zerosOutside)
elif locType in ['Ex', 'Ey', 'Ez']:
Av = self._getEdgeIntMat(locs, zerosOutside, locType[1])
elif locType in ['Fx', 'Fy', 'Fz']:
Av = self._getFaceIntMat(locs, zerosOutside, locType[1])
elif locType in ['CC']:
Av = self._getCellIntMat(locs, zerosOutside)
return Av

@property
def permuteCC(self):
"""Permutation matrix re-ordering of cells sorted by x, then y, then z"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the doc-strings :)

# TODO: cache these?
P = np.lexsort(self.gridCC.T) # sort by x, then y, then z
return sp.identity(self.nC).tocsr()[P]

@property
def permuteF(self):
"""Permutation matrix re-ordering of faces sorted by x, then y, then z"""
# TODO: cache these?
Px = np.lexsort(self.gridFx.T)
Py = np.lexsort(self.gridFy.T)+self.nFx
Expand All @@ -337,6 +493,7 @@ def permuteF(self):

@property
def permuteE(self):
"""Permutation matrix re-ordering of edges sorted by x, then y, then z"""
# TODO: cache these?
Px = np.lexsort(self.gridEx.T)
Py = np.lexsort(self.gridEy.T) + self.nEx
Expand Down
8 changes: 5 additions & 3 deletions discretize/View.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def plotGrid(
else:
if not isinstance(ax, matplotlib.axes.Axes):
raise AssertionError("ax must be an matplotlib.axes.Axes")
if lines:
color = kwargs.get('color', 'C0')
linewidth = kwargs.get('linewidth', 1.)

if self.dim == 1:
if nodes:
Expand Down Expand Up @@ -660,8 +663,7 @@ def plotGrid(
marker="^", linestyle=""
)

color = kwargs.get('color', 'C0')
linewidth = kwargs.get('linewidth', 1.)

# Plot the grid lines
if lines:
NN = self.r(self.gridN, 'N', 'N', 'M')
Expand Down Expand Up @@ -728,7 +730,7 @@ def plotGrid(
X = np.r_[X1, X2, X3]
Y = np.r_[Y1, Y2, Y3]
Z = np.r_[Z1, Z2, Z3]
ax.plot(X, Y, color="C0", linestyle="-", zs=Z)
ax.plot(X, Y, color=color, linestyle="-", lw=linewidth, zs=Z)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('x3')
Expand Down
2 changes: 2 additions & 0 deletions discretize/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ Cell::Cell(Node *pts[8], int_t ndim, int_t maxlevel, function func){
int_t n_points = 1<<n_dim;
for(int_t i = 0; i < n_points; ++i)
points[i] = pts[i];
index = -1;
level = 0;
max_level = maxlevel;
parent = NULL;
Expand Down Expand Up @@ -212,6 +213,7 @@ Cell::Cell(Node *pts[8], Cell *parent){
int_t n_points = 1<<n_dim;
for(int_t i = 0; i < n_points; ++i)
points[i] = pts[i];
index = -1;
level = parent->level + 1;
max_level = parent->max_level;
test_func = parent->test_func;
Expand Down
3 changes: 2 additions & 1 deletion discretize/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ class Cell{
Edge *edges[12];
Face *faces[6];

int_t location_ind[3], index, key, level, max_level;
int_t location_ind[3], key, level, max_level;
long long int index; // non root parents will have a -1 value
double location[3];
double volume;
function test_func;
Expand Down
Loading