Skip to content

Commit ede3cd9

Browse files
committed
WIP: EigWorkImpl for f64
1 parent 0a78593 commit ede3cd9

File tree

1 file changed

+153
-30
lines changed

1 file changed

+153
-30
lines changed

lax/src/eig.rs

Lines changed: 153 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,18 @@ pub struct EigWork<T: Scalar> {
4040
pub jobvl: JobEv,
4141

4242
/// Eigenvalues used in complex routines
43-
pub eigs: Option<Vec<MaybeUninit<T>>>,
43+
pub eigs: Vec<MaybeUninit<T::Complex>>,
4444
/// Real part of eigenvalues used in real routines
45-
pub eigs_re: Option<Vec<MaybeUninit<T>>>,
45+
pub eigs_re: Option<Vec<MaybeUninit<T::Real>>>,
4646
/// Imaginary part of eigenvalues used in real routines
47-
pub eigs_im: Option<Vec<MaybeUninit<T>>>,
47+
pub eigs_im: Option<Vec<MaybeUninit<T::Real>>>,
4848

4949
/// Left eigenvectors
50-
pub vl: Option<Vec<MaybeUninit<T>>>,
50+
pub vc_l: Option<Vec<MaybeUninit<T::Complex>>>,
51+
pub vr_l: Option<Vec<MaybeUninit<T::Real>>>,
5152
/// Right eigenvectors
52-
pub vr: Option<Vec<MaybeUninit<T>>>,
53+
pub vc_r: Option<Vec<MaybeUninit<T::Complex>>>,
54+
pub vr_r: Option<Vec<MaybeUninit<T::Real>>>,
5355

5456
/// Working memory
5557
pub work: Vec<MaybeUninit<T>>,
@@ -97,8 +99,8 @@ impl EigWorkImpl for EigWork<c64> {
9799
let mut eigs: Vec<MaybeUninit<c64>> = vec_uninit(n as usize);
98100
let mut rwork: Vec<MaybeUninit<f64>> = vec_uninit(2 * n as usize);
99101

100-
let mut vl: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
101-
let mut vr: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
102+
let mut vc_l: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
103+
let mut vc_r: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
102104

103105
// calc work size
104106
let mut info = 0;
@@ -111,9 +113,9 @@ impl EigWorkImpl for EigWork<c64> {
111113
std::ptr::null_mut(),
112114
&n,
113115
AsPtr::as_mut_ptr(&mut eigs),
114-
AsPtr::as_mut_ptr(vl.as_deref_mut().unwrap_or(&mut [])),
116+
AsPtr::as_mut_ptr(vc_l.as_deref_mut().unwrap_or(&mut [])),
115117
&n,
116-
AsPtr::as_mut_ptr(vr.as_deref_mut().unwrap_or(&mut [])),
118+
AsPtr::as_mut_ptr(vc_r.as_deref_mut().unwrap_or(&mut [])),
117119
&n,
118120
AsPtr::as_mut_ptr(&mut work_size),
119121
&(-1),
@@ -129,12 +131,14 @@ impl EigWorkImpl for EigWork<c64> {
129131
n,
130132
jobvl,
131133
jobvr,
132-
eigs: Some(eigs),
134+
eigs,
133135
eigs_re: None,
134136
eigs_im: None,
135137
rwork: Some(rwork),
136-
vl,
137-
vr,
138+
vc_l,
139+
vc_r,
140+
vr_l: None,
141+
vr_r: None,
138142
work,
139143
})
140144
}
@@ -149,10 +153,10 @@ impl EigWorkImpl for EigWork<c64> {
149153
&self.n,
150154
AsPtr::as_mut_ptr(a),
151155
&self.n,
152-
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
153-
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
156+
AsPtr::as_mut_ptr(&mut self.eigs),
157+
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
154158
&self.n,
155-
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
159+
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
156160
&self.n,
157161
AsPtr::as_mut_ptr(&mut self.work),
158162
&lwork,
@@ -162,14 +166,10 @@ impl EigWorkImpl for EigWork<c64> {
162166
};
163167
info.as_lapack_result()?;
164168

165-
let eigs = self
166-
.eigs
167-
.as_ref()
168-
.map(|v| unsafe { v.slice_assume_init_ref() })
169-
.unwrap();
169+
let eigs = unsafe { self.eigs.slice_assume_init_ref() };
170170

171171
// Hermite conjugate
172-
if let Some(vl) = self.vl.as_mut() {
172+
if let Some(vl) = self.vc_l.as_mut() {
173173
for value in vl {
174174
let value = unsafe { value.assume_init_mut() };
175175
value.im = -value.im;
@@ -178,11 +178,11 @@ impl EigWorkImpl for EigWork<c64> {
178178
Ok(EigRef {
179179
eigs,
180180
vl: self
181-
.vl
181+
.vc_l
182182
.as_ref()
183183
.map(|v| unsafe { v.slice_assume_init_ref() }),
184184
vr: self
185-
.vr
185+
.vc_r
186186
.as_ref()
187187
.map(|v| unsafe { v.slice_assume_init_ref() }),
188188
})
@@ -198,10 +198,10 @@ impl EigWorkImpl for EigWork<c64> {
198198
&self.n,
199199
AsPtr::as_mut_ptr(a),
200200
&self.n,
201-
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
202-
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
201+
AsPtr::as_mut_ptr(&mut self.eigs),
202+
AsPtr::as_mut_ptr(self.vc_l.as_deref_mut().unwrap_or(&mut [])),
203203
&self.n,
204-
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
204+
AsPtr::as_mut_ptr(self.vc_r.as_deref_mut().unwrap_or(&mut [])),
205205
&self.n,
206206
AsPtr::as_mut_ptr(&mut self.work),
207207
&lwork,
@@ -210,21 +210,134 @@ impl EigWorkImpl for EigWork<c64> {
210210
)
211211
};
212212
info.as_lapack_result()?;
213-
let eigs = self.eigs.map(|v| unsafe { v.assume_init() }).unwrap();
213+
let eigs = unsafe { self.eigs.assume_init() };
214214

215215
// Hermite conjugate
216-
if let Some(vl) = self.vl.as_mut() {
216+
if let Some(vl) = self.vc_l.as_mut() {
217217
for value in vl {
218218
let value = unsafe { value.assume_init_mut() };
219219
value.im = -value.im;
220220
}
221221
}
222222
Ok(Eig {
223223
eigs,
224-
vl: self.vl.map(|v| unsafe { v.assume_init() }),
225-
vr: self.vr.map(|v| unsafe { v.assume_init() }),
224+
vl: self.vc_l.map(|v| unsafe { v.assume_init() }),
225+
vr: self.vc_r.map(|v| unsafe { v.assume_init() }),
226+
})
227+
}
228+
}
229+
230+
impl EigWorkImpl for EigWork<f64> {
231+
type Elem = f64;
232+
233+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self> {
234+
let (n, _) = l.size();
235+
let (jobvl, jobvr) = if calc_v {
236+
match l {
237+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
238+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
239+
}
240+
} else {
241+
(JobEv::None, JobEv::None)
242+
};
243+
let mut eigs_re: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);
244+
let mut eigs_im: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);
245+
246+
let mut vr_l: Option<Vec<MaybeUninit<f64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
247+
let mut vr_r: Option<Vec<MaybeUninit<f64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
248+
249+
// calc work size
250+
let mut info = 0;
251+
let mut work_size: [f64; 1] = [0.0];
252+
unsafe {
253+
lapack_sys::dgeev_(
254+
jobvl.as_ptr(),
255+
jobvr.as_ptr(),
256+
&n,
257+
std::ptr::null_mut(),
258+
&n,
259+
AsPtr::as_mut_ptr(&mut eigs_re),
260+
AsPtr::as_mut_ptr(&mut eigs_im),
261+
AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])),
262+
&n,
263+
AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])),
264+
&n,
265+
AsPtr::as_mut_ptr(&mut work_size),
266+
&(-1),
267+
&mut info,
268+
)
269+
};
270+
info.as_lapack_result()?;
271+
272+
// actual ev
273+
let lwork = work_size[0].to_usize().unwrap();
274+
let work: Vec<MaybeUninit<f64>> = vec_uninit(lwork);
275+
276+
Ok(Self {
277+
n,
278+
jobvr,
279+
jobvl,
280+
eigs: vec_uninit(n as usize),
281+
eigs_re: Some(eigs_re),
282+
eigs_im: Some(eigs_im),
283+
rwork: None,
284+
vr_l,
285+
vr_r,
286+
vc_l: None,
287+
vc_r: None,
288+
work,
226289
})
227290
}
291+
292+
fn calc<'work>(&'work mut self, _a: &mut [f64]) -> Result<EigRef<'work, f64>> {
293+
todo!()
294+
}
295+
296+
fn eval(mut self, a: &mut [f64]) -> Result<Eig<f64>> {
297+
let lwork = self.work.len().to_i32().unwrap();
298+
let mut info = 0;
299+
unsafe {
300+
lapack_sys::dgeev_(
301+
self.jobvl.as_ptr(),
302+
self.jobvr.as_ptr(),
303+
&self.n,
304+
AsPtr::as_mut_ptr(a),
305+
&self.n,
306+
AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()),
307+
AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()),
308+
AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])),
309+
&self.n,
310+
AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])),
311+
&self.n,
312+
AsPtr::as_mut_ptr(&mut self.work),
313+
&lwork,
314+
&mut info,
315+
)
316+
};
317+
info.as_lapack_result()?;
318+
319+
let eigs_re = unsafe { self.eigs_re.unwrap().assume_init() };
320+
let eigs_im = unsafe { self.eigs_im.unwrap().assume_init() };
321+
322+
let n = self.n as usize;
323+
let vl = self.vr_l.map(|v| {
324+
let v = unsafe { v.assume_init() };
325+
let mut vc = vec_uninit(n * n);
326+
reconstruct_eigenvectors(false, &eigs_im, &v, &mut vc);
327+
unsafe { vc.assume_init() }
328+
});
329+
let vr = self.vr_r.map(|v| {
330+
let v = unsafe { v.assume_init() };
331+
let mut vc = vec_uninit(n * n);
332+
reconstruct_eigenvectors(true, &eigs_im, &v, &mut vc);
333+
unsafe { vc.assume_init() }
334+
});
335+
336+
reconstruct_eigs(&eigs_re, &eigs_im, &mut self.eigs);
337+
let eigs = unsafe { self.eigs.assume_init() };
338+
339+
Ok(Eig { eigs, vl, vr })
340+
}
228341
}
229342

230343
macro_rules! impl_eig_complex {
@@ -497,3 +610,13 @@ fn reconstruct_eigenvectors<T: Scalar>(
497610
}
498611
}
499612
}
613+
614+
/// Create complex eigenvalues from real and imaginary parts.
615+
fn reconstruct_eigs<T: Scalar>(re: &[T], im: &[T], eigs: &mut [MaybeUninit<T::Complex>]) {
616+
let n = eigs.len();
617+
assert_eq!(re.len(), n);
618+
assert_eq!(im.len(), n);
619+
for i in 0..n {
620+
eigs[i].write(T::complex(re[i], im[i]));
621+
}
622+
}

0 commit comments

Comments
 (0)