Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linalg eye: allow generalized return type and kind #902

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,33 +239,34 @@ Pure function.

### Description

Construct the identity matrix.
Constructs the identity matrix.

### Syntax

`I = ` [[stdlib_linalg(module):eye(function)]] `(dim1 [, dim2])`
`I = ` [[stdlib_linalg(module):eye(function)]] `(dim1 [, dim2] [, mold])`

### Arguments

`dim1`: Shall be a scalar of default type `integer`.
This is an `intent(in)` argument.

`dim2`: Shall be a scalar of default type `integer`.
This is an `intent(in)` and `optional` argument.
- `dim1`: A scalar of type `integer`. This is an `intent(in)` argument and specifies the number of rows.
- `dim2`: A scalar of type `integer`. This is an optional `intent(in)` argument specifying the number of columns. If not provided, the matrix is square (`dim1 = dim2`).
- `mold`: A scalar of any supported `integer`, `real`, or `complex` type. This is an optional `intent(in)` argument. If provided, the returned identity matrix will have the same type and kind as `mold`. If not provided, the matrix will be of type `integer(int8)` by default.
perazz marked this conversation as resolved.
Show resolved Hide resolved

### Return value

Return the identity matrix, i.e. a matrix with ones on the main diagonal and zeros elsewhere. The return value is of type `integer(int8)`.
The use of `int8` was suggested to save storage.
Returns the identity matrix, with ones on the main diagonal and zeros elsewhere.

- By default, the return value is of type `integer(int8)`, which is recommended for storage efficiency.
- If the `mold` argument is provided, the return value will match the type and kind of `mold`, allowing for arbitrary `integer`, `real`, or `complex` return types.

#### Warning

Since the result of `eye` is of `integer(int8)` type, one should be careful about using it in arithmetic expressions. For example:
When using the default `integer(int8)` type, be cautious when performing arithmetic operations, as integer division may occur. For example:

```fortran
!> Be careful
perazz marked this conversation as resolved.
Show resolved Hide resolved
A = eye(2,2)/2 !! A == 0.0
!> Recommend
A = eye(2,2)/2.0 !! A == diag([0.5, 0.5])
!> Caution: default type is `integer`
A = eye(2,2)/2 !! A == 0.0 due to integer division
!> Recommend using a non-integer type for division
A = eye(2,2, mold=1.0)/2 !! A == diag([0.5, 0.5])
```

### Example
Expand Down
10 changes: 5 additions & 5 deletions example/linalg/example_eye1.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ program example_eye1
real :: a(3, 3)
real :: b(2, 3) !! Matrix is non-square.
complex :: c(2, 2)
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2, 2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
I = eye(2) !! [1,0; 0,1]
A = eye(3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
A = eye(3, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0; 0.0,0.0,1.0]
B = eye(2, 3) !! [1.0,0.0,0.0; 0.0,1.0,0.0]
C = eye(2, 2) !! [(1.0,0.0),(0.0,0.0); (0.0,0.0),(1.0,0.0)]
C = (1.0, 1.0)*eye(2, 2) !! [(1.0,1.0),(0.0,0.0); (0.0,0.0),(1.0,1.0)]
end program example_eye1
23 changes: 18 additions & 5 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ module stdlib_linalg
#:endfor
end interface

! Identity matrix
interface eye
!! version: experimental
!!
!! Constructs the identity matrix
!! ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
#:for k1, t1 in RCI_KINDS_TYPES
module procedure eye_${t1[0]}$${k1}$
#:endfor
end interface eye

! Outer product (of two vectors)
interface outer_product
Expand Down Expand Up @@ -1327,24 +1337,27 @@ contains
!>
!> Constructs the identity matrix.
!> ([Specification](../page/specs/stdlib_linalg.html#eye-construct-the-identity-matrix))
pure function eye(dim1, dim2) result(result)
#:for k1, t1 in RCI_KINDS_TYPES
pure function eye_${t1[0]}$${k1}$(dim1, dim2, mold) result(result)

integer, intent(in) :: dim1
integer, intent(in), optional :: dim2
integer(int8), allocatable :: result(:, :)
${t1}$, intent(in) #{if 'int8' in t1}#, optional #{endif}#:: mold
${t1}$, allocatable :: result(:, :)

integer :: dim2_
integer :: i

dim2_ = optval(dim2, dim1)
allocate(result(dim1, dim2_))

result = 0_int8
result = 0
do i = 1, min(dim1, dim2_)
result(i, i) = 1_int8
result(i, i) = 1
end do

end function eye
end function eye_${t1[0]}$${k1}$
#:endfor

#:for k1, t1 in RCI_KINDS_TYPES
function trace_${t1[0]}$${k1}$(A) result(res)
Expand Down
Loading