Skip to content

Commit

Permalink
Merge pull request #1363 from ChaosTHL/develop
Browse files Browse the repository at this point in the history
update
  • Loading branch information
weihuayi authored Nov 24, 2024
2 parents 417f563 + 371c8cd commit f6220de
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 121 deletions.
2 changes: 1 addition & 1 deletion app/lafem-ims/lafemims/data_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from near_field_data_generator import NearFieldDataFEMGenerator2d
from .near_field_data_generator import NearFieldDataFEMGenerator2d
195 changes: 124 additions & 71 deletions app/lafem-ims/lafemims/data_generator/near_field_data_generator.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
import os

import os
import matplotlib.pyplot as plt
from numpy.typing import NDArray
from typing import Sequence, Callable
import matplotlib.pyplot as plt

from fealpy.backend import backend_manager as bm
from fealpy.mesh import TriangleMesh, QuadrangleMesh, UniformMesh2d
from fealpy.functionspace import LagrangeFESpace
from fealpy.fem import ScalarDiffusionIntegrator, ScalarMassIntegrator, ScalarSourceIntegrator, ScalarConvectionIntegrator, DirichletBC
from fealpy.fem import BilinearForm, LinearForm
from fealpy.pde.pml_2d import PMLPDEModel2d
from fealpy.functionspace import LagrangeFESpace
from fealpy.fem import (
ScalarDiffusionIntegrator,
ScalarMassIntegrator,
ScalarSourceIntegrator,
ScalarConvectionIntegrator,
BilinearForm,
LinearForm,
DirichletBC
)
from fealpy.solver import spsolve


class NearFieldDataFEMGenerator2d:
def __init__(self,
domain:Sequence[float],
mesh:str,
nx:int,
ny:int,
p:int,
q:int,
u_inc:str,
levelset:Callable[[NDArray], NDArray],
d:Sequence[float],
k:Sequence[float],
reciever_points:NDArray):
domain: Sequence[float],
mesh: str,
nx: int,
ny: int,
p: int,
q: int,
u_inc: str,
levelset: Callable[[NDArray], NDArray],
d: Sequence[float],
k: Sequence[float],
reciever_points: NDArray):

self.domain = domain
self.nx = nx
Expand All @@ -35,53 +42,65 @@ def __init__(self,
self.u_inc = u_inc
self.levelset = levelset

# 验证并创建网格
if mesh not in ['InterfaceMesh', 'QuadrangleMesh', 'UniformMesh']:
raise ValueError("Invalid value for 'mesh'. Choose from 'InterfaceMesh', 'QuadrangleMesh' or 'UniformMesh'.")

if mesh == 'InterfaceMesh':
self.mesh = TriangleMesh.interfacemesh_generator(box=self.domain, nx=self.nx, ny=self.ny, phi=self.levelset)
self.meshtype = 'InterfaceMesh'
elif mesh == 'QuadrangleMesh':
self.mesh = QuadrangleMesh.from_box(box=self.domain, nx=self.nx, ny=self.ny)
self.meshtype = 'QuadrangleMesh'
else:
if mesh == 'InterfaceMesh':
self.mesh = TriangleMesh.interfacemesh_generator(box=self.domain, nx=self.nx, ny=self.ny, phi=self.levelset)
self.meshtype = 'InterfaceMesh'
elif mesh == 'QuadrangleMesh':
self.mesh = QuadrangleMesh.from_box(box=self.domain, nx=self.nx, ny=self.ny)
self.meshtype = 'QuadrangleMesh'
else:
EXTC_1 = self.nx
EXTC_2 = self.ny
HC_1 = 1/EXTC_1 * (self.domain[1] - self.domain[0])
HC_2 = 1/EXTC_2 * (self.domain[3] - self.domain[2])
self.mesh = UniformMesh2d((0, EXTC_1, 0, EXTC_2), (HC_1, HC_2), origin=(self.domain[0], self.domain[2]))
self.meshtype = 'UniformMesh'

self.mesh.ftype = bm.complex128
self.d = d
EXTC_1 = self.nx
EXTC_2 = self.ny
HC_1 = 1 / EXTC_1 * (self.domain[1] - self.domain[0])
HC_2 = 1 / EXTC_2 * (self.domain[3] - self.domain[2])
self.mesh = UniformMesh2d((0, EXTC_1, 0, EXTC_2), (HC_1, HC_2), origin=(self.domain[0], self.domain[2]))
self.meshtype = 'UniformMesh'

self.d = d
self.k = k
self.reciever_points = reciever_points
qf = self.mesh.quadrature_formula(self.q, 'cell')
self.bc, _= qf.get_quadrature_points_and_weights()

def get_nearfield_data(self, k:float, d:Sequence[float]):

k_index = (self.k).index(k)
d_index = (self.d).index(d)
pde = PMLPDEModel2d(levelset=self.levelset,
domain=self.domain,
u_inc=self.u_inc,
A=1,
k=self.k[k_index],
d=self.d[d_index],
refractive_index=[1, 1+1/self.k[k_index]**2],
absortion_constant=1.79,
lx=1.0,
ly=1.0
)
self.bc, _ = qf.get_quadrature_points_and_weights()

def get_nearfield_data(self, k: float, d: Sequence[float]):
"""
获取近场数据。
参数:
- k: 波数
- d: 波矢量方向
返回:
- uh: 近场数据
"""
k_index = self.k.index(k)
d_index = self.d.index(d)
pde = PMLPDEModel2d(
levelset=self.levelset,
domain=self.domain,
u_inc=self.u_inc,
A=1,
k=self.k[k_index],
d=self.d[d_index],
refractive_index=[1, 1 + 1 / self.k[k_index]**2],
absortion_constant=1.79,
lx=1.0,
ly=1.0
)

space = LagrangeFESpace(self.mesh, p=self.p)

# 定义积分器
D = ScalarDiffusionIntegrator(pde.diffusion_coefficient, q=self.q)
C = ScalarConvectionIntegrator(pde.convection_coefficient, q=self.q)
M = ScalarMassIntegrator(pde.reaction_coefficient, q=self.q)
f = ScalarSourceIntegrator(pde.source, q=self.q)

# 组装双线性形式和线性形式
b = BilinearForm(space)
b.add_integrator([D, C, M])

Expand All @@ -90,14 +109,27 @@ def get_nearfield_data(self, k:float, d:Sequence[float]):

A = b.assembly()
F = l.assembly()

bc = DirichletBC(space, pde.dirichlet)
uh = space.function(dtype=bm.complex128)
A, F = bc.apply(A, F)
uh[:] = spsolve(A, F, solver='scipy')
return uh

def points_location_and_bc(self, p, domain:Sequence[float], nx:int, ny:int):

def points_location_and_bc(self, p: NDArray, domain: Sequence[float], nx: int, ny: int):
"""
计算接收点的位置和重心坐标。
参数:
- p: 接收点坐标
- domain: 计算域
- nx: x方向的网格数
- ny: y方向的网格数
返回:
- location: 接收点所在单元的索引
- bc: 重心坐标
"""
x = p[..., 0]
y = p[..., 1]
cell_length_x = (domain[1] - domain[0]) / nx
Expand All @@ -113,17 +145,26 @@ def points_location_and_bc(self, p, domain:Sequence[float], nx:int, ny:int):
bc = (bc_x, bc_y)
return location, bc

def data_for_dsm(self, k:float, d:Sequence[float]):
def data_for_dsm(self, k: float, d: Sequence[float]):
"""
获取用于DSM的数据。
参数:
- k: 波数
- d: 波矢量方向
返回:
- data: DSM数据
"""
reciever_points = self.reciever_points
data_length = reciever_points.shape[0]
data = bm.zeros((data_length, ), dtype=bm.complex128)
data = bm.zeros((data_length,), dtype=bm.complex128)
uh = self.get_nearfield_data(k=k, d=d)
if self.meshtype =='InterfaceMesh':

if self.meshtype == 'InterfaceMesh':
b = self.mesh.point_to_bc(reciever_points)
location = self.mesh.location(reciever_points)
for i in range (data_length):
for i in range(data_length):
data[i] = uh(b[i])[location[i]]
elif self.meshtype == 'QuadrangleMesh':
for i in range(data_length):
Expand All @@ -140,39 +181,51 @@ def data_for_dsm(self, k:float, d:Sequence[float]):
u = uh(b).reshape(-1)
data[i] = u[location]
return data

def save(self, save_path:str, scatterer_index:int):

def save(self, save_path: str, scatterer_index: int):
"""
保存数据。
参数:
- save_path: 保存路径
- scatterer_index: 散射体索引
"""
k_values = self.k
d_values = self.d
data_dict = {}
for i in range (len(k_values)):
for j in range (len(d_values)):
for i in range(len(k_values)):
for j in range(len(d_values)):
k_name = f'k={k_values[i]}'
d_name = d_values[j]
name = f"{k_name}, d={d_name}"
data_dict[name] = self.data_for_dsm(k=k_values[i], d=d_values[j])
filename = os.path.join(save_path, f"data_for_dsm_{scatterer_index}.bmz")
bm.savez(filename, **data_dict)

def visualization_of_nearfield_data(self, k:float, d:Sequence[float]):
def visualization_of_nearfield_data(self, k: float, d: Sequence[float]):
"""
可视化近场数据。
参数:
- k: 波数
- d: 波矢量方向
"""
uh = self.get_nearfield_data(k=k, d=d)
value = uh(self.bc)
if self.meshtype == 'UniformMesh':
self.mesh.ftype = bm.float64
self.mesh.add_plot(plt, cellcolor=value[..., 0].real, linewidths=0)
self.mesh.add_plot(plt, cellcolor=value[..., 0].imag, linewidths=0)

#TODO
fig = plt.figure()
axes = fig.add_subplot(1, 3, 1)
self.mesh.add_plot(axes)
if self.meshtype == 'UniformMesh':
uh = uh.view(bm.ndarray)
axes = fig.add_subplot(1, 3, 2, projection='3d')
self.mesh.show_function(axes, bm.real(uh))
axes = fig.add_subplot(1, 3, 3, projection='3d')
self.mesh.show_function(axes, bm.imag(uh))
# TODO: 添加更多可视化选项
# fig = plt.figure()
# axes = fig.add_subplot(1, 3, 1)
# self.mesh.add_plot(axes)
# if self.meshtype == 'UniformMesh':
# uh = uh.view(bm.ndarray)
# axes = fig.add_subplot(1, 3, 2, projection='3d')
# self.mesh.show_function(axes, bm.real(uh))
# axes = fig.add_subplot(1, 3, 3, projection='3d')
# self.mesh.show_function(axes, bm.imag(uh))
plt.show()

46 changes: 0 additions & 46 deletions app/lafem-ims/lafemims/test/test_near_field_data_generator.py

This file was deleted.

19 changes: 19 additions & 0 deletions app/lafem-ims/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

import os
import pathlib
from setuptools import setup, find_packages

from lafemims import __version__


setup(
name="lafemims",
version=__version__,
description="LaFEM-IMS: Learn-automated FEM Inverse Medium Scattering",
url="",
author="Chaos",
author_email="",
license="GNU",
packages=find_packages(),
python_requires=">=3.8",
)
Loading

0 comments on commit f6220de

Please sign in to comment.