Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex number support to expm1 #452

Merged
merged 8 commits into from
Nov 17, 2022
Merged

Add complex number support to expm1 #452

merged 8 commits into from
Nov 17, 2022

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Jun 13, 2022

This PR

  • adds complex number support to expm1 by documenting special cases. The exponential function is an entire function in the complex plane. Thus, the function does not have branch cuts.
  • updates the input and output array data types to be any floating-point data type, not just real-valued floating-point data types.
  • derives special cases from C99 exp and tested against NumPy (script found below).
import numpy as np
import math

def is_equal_float(x, y):
    """Test whether two floating-point numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : float
        First input number.
    y : float
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two floating-point numbers are equal.

    Examples
    --------
    >>> is_equal_float(0.0, -0.0)
    False
    >>> is_equal_float(-0.0, -0.0)
    True
    """
    # Handle +-0:
    if x == 0.0 and y == 0.0:
        return math.copysign(1.0, x) == math.copysign(1.0, y)

    # Handle NaNs:
    if x != x:
        return y != y

    # Everything else, including infinities:
    return x == y


def is_equal(x, y):
    """Test whether two complex numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : complex
        First input number.
    y : complex
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two complex numbers are equal.

    Examples
    --------
    >>> import numpy as np
    >>> is_equal(complex(np.nan, np.nan), complex(np.nan, np.nan))
    True
    """
    return is_equal_float(x.real, y.real) and is_equal_float(x.imag, y.imag)


# Strided array consisting of input values and expected values:
values = [
    complex(0.0, 0.0),        # 0
    complex(0.0, 0.0),        # 0
    
    complex(-0.0, 0.0),       # 1
    complex(-0.0, 0.0),       # 1
    
    complex(1.0, np.inf),     # 2
    complex(np.nan, np.nan),  # 2
    
    complex(1.0, np.nan),     # 3
    complex(np.nan, np.nan),  # 3
    
    complex(np.inf, 0.0),     # 4
    complex(np.inf, 0.0),     # 4, seems to be a bug in NumPy, as it returns (inf+nanj), vs np.exp(complex(np.inf, 0.0))-1.0 == (inf+0j)
    
    complex(-np.inf, 1.0),    # 5
    complex(-1.0, 0.0),       # 5
    
    complex(np.inf, 1.0),     # 6
    complex(np.inf, np.inf),  # 6
    
    complex(-np.inf, np.inf), # 7
    complex(-1.0, 0.0),       # 7, seems to be a bug in NumPy, as it returns (nan+nanj), vs np.exp(complex(-np.inf, np.inf))-1.0 == (-1+0j)
    
    complex(np.inf, np.inf),  # 8
    complex(np.inf, np.nan),  # 8, seems to be a bug in NumPy, as it returns (nan+nanj), vs np.exp(complex(np.inf, np.inf))-1.0 == (inf+nanj)
    
    complex(-np.inf, np.nan), # 9
    complex(-1.0, 0.0),       # 9, seems to be a bug in NumPy, as it returns (nan+nanj), vs np.exp(complex(-np.inf, np.nan))-1.0 == (-1+0j)
    
    complex(np.inf, np.nan),  # 10
    complex(np.inf, np.nan),  # 10, seems to be a bug in NumPy, as it returns (nan+nanj), vs np.exp(complex(np.inf, np.nan))-1.0 == (inf+nanj)
    
    complex(np.nan, 0.0),     # 11
    complex(np.nan, 0.0),     # 11, seems to be a bug in NumPy, as it returns (nan+nanj), vs np.exp(complex(np.nan, 0.0))-1.0 == (nan+0j)
    
    complex(np.nan, 1.0),     # 12
    complex(np.nan, np.nan),  # 12
    
    complex(np.nan, np.nan),  # 13
    complex(np.nan, np.nan)   # 13
]

for i in range(len(values)//2):
    j = i * 2
    v = values[j]
    e = values[j+1]
    actual = np.expm1(v)
    print('Value: {value}'.format(value=str(v)))
    print('Actual: {actual}'.format(actual=str(actual)))
    print('Naive: {naive}'.format(naive=str(np.exp(v)-1.0)))
    print('Expected: {expected}'.format(expected=str(e)))
    print('Equal: {is_equal}'.format(is_equal=str(is_equal(actual, e))))
    print('\n')
Value: 0j
Actual: 0j
Naive: 0j
Expected: 0j
Equal: True


Value: (-0+0j)
Actual: (-0+0j)
Naive: 0j
Expected: (-0+0j)
Equal: True


/path/to/cexpm1.py:113: RuntimeWarning: invalid value encountered in expm1
  actual = np.expm1(v)
Value: (1+infj)
Actual: (nan+nanj)
/path/to/cexpm1.py:116: RuntimeWarning: invalid value encountered in exp
  print('Naive: {naive}'.format(naive=str(np.exp(v)-1.0)))
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (1+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (inf+0j)
Actual: (inf+nanj)
Naive: (inf+0j)
Expected: (inf+0j)
Equal: False


Value: (-inf+1j)
Actual: (-1+0j)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: True


Value: (inf+1j)
Actual: (inf+infj)
Naive: (inf+infj)
Expected: (inf+infj)
Equal: True


Value: (-inf+infj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False


Value: (inf+infj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False


Value: (-inf+nanj)
Actual: (nan+nanj)
Naive: (-1+0j)
Expected: (-1+0j)
Equal: False


Value: (inf+nanj)
Actual: (nan+nanj)
Naive: (inf+nanj)
Expected: (inf+nanj)
Equal: False


Value: (nan+0j)
Actual: (nan+nanj)
Naive: (nan+0j)
Expected: (nan+0j)
Equal: False


Value: (nan+1j)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True


Value: (nan+nanj)
Actual: (nan+nanj)
Naive: (nan+nanj)
Expected: (nan+nanj)
Equal: True

Notes

  • NumPy currently fails for 4 6 complex number special cases. It's behavior is inconsistent with np.exp(z)-1, as documented in the script above.

@kgryte kgryte added API extension Adds new functions or objects to the API. topic: Complex Data Types Complex number data types. labels Jun 13, 2022
@kgryte kgryte added this to the v2022 milestone Jun 13, 2022
@kgryte kgryte added API change Changes to existing functions or objects in the API. and removed API extension Adds new functions or objects to the API. labels Jun 20, 2022
@kgryte
Copy link
Contributor Author

kgryte commented Nov 17, 2022

As no objections have been raised to the changes introduced in this PR and the changes follow established conventions (C99 and equivalent exp(x)-1 behavior), will merge. Revisions to special cases can be made in follow-up PRs.

@kgryte kgryte merged commit ea8f6a9 into main Nov 17, 2022
@kgryte kgryte deleted the cmplx-expm1 branch November 17, 2022 10:30
@honno
Copy link
Member

honno commented Dec 9, 2022

Seems torch doesn't support complex inputs for expm1

>>> torch.expm1(torch.as_tensor(1.+1.j))
RuntimeError: "expm1_vml_cpu" not implemented for 'ComplexFloat'

exp however does take complex inputs. Couldn't see any relevant issues to this on PyTorch's tracker.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API change Changes to existing functions or objects in the API. topic: Complex Data Types Complex number data types.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants