Skip to content

Kabsch–Umeyama algorithm for vector alignment #1051

@jalvesz

Description

@jalvesz

Motivation

I would like to propose adding the Kabsch–Umeyama algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm) to a (yet to be created) stdlib_spatial module for vector alignment purposes.

A first API implementation combining several approaches including the use of weights:

module test
    use stdlib_kinds, only: dp
    use stdlib_intrinsics, only: stdlib_sum
    use stdlib_linalg, only: svd, det, eye, trace
    implicit none

contains

!> Kabsch-Umeyama alignment of two point clouds.
!>
!> Computes a rotation R, a translation t and a scaling factor c such that
!> t + c*matmul(R, Q(:,k)) approximates P(:,k) for every point k, and reports
!> the (optionally weighted) root mean square of the residuals.
!>
!> When W is present every per-point contribution (centroids, covariance,
!> variance and residuals) is multiplied by W(k) and normalised by sum(W);
!> multiplying all weights by a common positive factor therefore leaves
!> the outputs unchanged.
!>
!> Preconditions (not checked): d >= 1, N >= 1, sum(W) /= 0, and the
!> covariance between P and Q must not vanish, otherwise the divisions below
!> produce non-finite values.
subroutine kabsh_umeyama( R, t, c, lrms, d, N, P, Q, W )
    integer, parameter :: wp = dp
    integer, intent(in) :: d, N
    real(wp), intent(out):: R(d,d) !> Rotation matrix
    real(wp), intent(out):: t(d)     !> Translation vector
    real(wp), intent(out):: c         !> Scaling factor
    real(wp), intent(out):: lrms    !> Root mean square of the residuals
    real(wp), intent(in) :: P(d,N)  !> Reference point cloud
    real(wp), intent(in) :: Q(d,N) !> Target point cloud (the one being moved)
    real(wp), intent(in), optional :: W(N) !> Weights

    ! -- Internal Variables
    integer :: j, point
    logical :: weighted
    real(wp) :: centroid_P(d), centroid_Q(d)
    real(wp) :: covariance(d,d), S(d), U(d,d), Vt(d,d)
    real(wp) :: vec(d), dev_p(d), dev_q(d)
    real(wp) :: variance_p, isum_w, wk, sgn
    !---------------------------------------
    weighted = present(W)

    ! Normalisation factor: 1/N, or 1/sum(W) when weights are supplied
    if (weighted) then
        isum_w = 1._wp / stdlib_sum(W)
    else
        isum_w = 1._wp / real(N, wp)
    end if

    ! Compute centroids.
    ! Plain sequential loops are used for every accumulation in this routine:
    ! summing into a shared variable is a cross-iteration dependency, which
    ! is not permitted inside do concurrent.
    centroid_P = 0._wp
    centroid_Q = 0._wp
    do point = 1, N
        wk = 1._wp
        if (weighted) wk = W(point)
        centroid_P = centroid_P + wk * P(1:d,point)
        centroid_Q = centroid_Q + wk * Q(1:d,point)
    end do
    centroid_P = centroid_P * isum_w
    centroid_Q = centroid_Q * isum_w

    ! Compute the cross-covariance matrix, covariance(i,j) = <dP_i * dQ_j>,
    ! and the total variance of P about its centroid
    covariance = 0._wp
    variance_p = 0._wp
    do point = 1, N
        wk = 1._wp
        if (weighted) wk = W(point)
        dev_p = P(1:d,point) - centroid_P
        dev_q = Q(1:d,point) - centroid_Q
        do j = 1, d
            covariance(1:d,j) = covariance(1:d,j) + wk * dev_p * dev_q(j)
        end do
        variance_p = variance_p + wk * dot_product(dev_p, dev_p)
    end do
    covariance = covariance * isum_w
    variance_p = variance_p * isum_w

    ! SVD: covariance = U * diag(S) * Vt
    call svd(covariance, S, U, Vt )

    ! Validate right-handed coordinate system: if U*Vt is a reflection
    ! (determinant < 0), flip the sign of the last row of Vt. This is
    ! equivalent to R = U * diag(1,...,1,sgn) * Vt.
    sgn = sign( 1._wp, det(matmul( U , Vt )) )
    Vt(d,1:d) = sgn * Vt(d,1:d)

    ! Optimal rotation
    R = matmul( U , Vt )

    ! Scaling factor: variance of P over the (sign-corrected) sum of the
    ! singular values.
    ! NOTE(review): Umeyama's least-squares scale for mapping Q onto P is
    ! tr(B*S)/var(Q); this routine uses var(P)/tr(B*S). The two coincide for
    ! an exact similarity transform - confirm which one is intended.
    c = variance_p / (sum(S(1:d-1)) + sgn*S(d))

    ! Optimal translation (aligning Q to P)
    t = centroid_P - c * matmul(R , centroid_Q )

    ! RMSD of the residuals between the transformed Q and P
    lrms = 0._wp
    do point = 1, N
        wk = 1._wp
        if (weighted) wk = W(point)
        vec = t + c * matmul(R , Q(1:d,point) )
        vec = vec - P(1:d,point)
        lrms = lrms + wk * dot_product(vec,vec)
    end do
    lrms = sqrt( lrms * isum_w )

end subroutine kabsh_umeyama

end module test

!> Demo driver: aligns a 2-D point cloud Q onto a reference cloud P with
!> kabsh_umeyama and prints the resulting RMSD, translation, scaling factor
!> and rotation matrix.
program main
use test, only: kabsh_umeyama
use stdlib_kinds, only: dp

implicit none

! Same real kind as the one kabsh_umeyama is declared with
integer, parameter :: wp = dp
integer, parameter :: d = 2
real(wp), allocatable :: P(:,:)
real(wp), allocatable :: Q(:,:)
real(wp) :: R(d,d), t(d), c, lrms
integer :: i, N

N = 7

! Each pair of values below is one point; points are stored column-wise (d,N)
P = reshape( &
    [23._wp, 178._wp,&
     66._wp, 173._wp,&
     88._wp, 187._wp,&
    119._wp, 202._wp,&
    122._wp, 229._wp,&
    170._wp, 232._wp,&
    179._wp, 199._wp] , [d,N] )

Q = reshape( &
    [232._wp, 38._wp,&
     208._wp, 32._wp,&
     181._wp, 31._wp,&
     155._wp, 45._wp,&
     142._wp, 33._wp,&
     121._wp, 59._wp,&
     139._wp, 69._wp] , [d,N] )

call kabsh_umeyama(R, t, c, lrms, d, N, P, Q)
print *, 'RMSD        :', lrms
print *, 'Translation :', t
print *, 'Scaling     :', c
print *, 'Rotation    :'
! Print the rotation matrix one row per line
do i = 1, d
    print *, R(i,:)
end do

end program main

gives:

 RMSD        :   16.242818369005477     
 Translation :   271.33459510449791        396.07800316838268     
 Scaling     :   1.4616613091002035     
 Rotation    :
 -0.81034281019830057       0.58595608193782001     
 -0.58595608193782023      -0.81034281019830057     

Prior Art

Additional Information

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    idea: Proposition of an idea and opening an issue to discuss it; topic: algorithms (searching and sorting, merging, ...); topic: mathematics (linear algebra, sparse matrices, special functions, FFT, random numbers, statistics, ...)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions