Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests for uniform distribution and triangular distribution #58

Merged
merged 5 commits into from
Jan 15, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions tests/crvs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# limitations under the License.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not required


import sys

import jax
import pytest
import jax.numpy as jnp
from jax.scipy.stats import uniform as jax_uniform

sys.path.append("../jaxampler")

from jaxampler._src.rvs.uniform import Uniform

eps = 1e-3


def test_LogNormal():
pass

class TestUniform:

def test_Pareto():
pass
def test_shape(self):
assert jnp.allclose(Uniform(low=0, high=10, name="uniform_0_to_10").pdf_x(5), jax_uniform.pdf(5, 0, 10))

# when low is negative
assert jnp.allclose(Uniform(low=-10, high=10, name="uniform_n10_to_10").pdf_x(5), jax_uniform.pdf(5, -10, 10))

def test_Rayleigh():
pass
# when both low and high are negative
assert jnp.allclose(Uniform(low=-10, high=-1, name="uniform_n10_to_n1").pdf_x(5), jax_uniform.pdf(5, -10, -1))

# when low is equal to high
with pytest.raises(AssertionError):
Uniform(low=10, high=10, name="uniform_10_to_10")

def test_Triangular():
pass
# when high is greater than low
with pytest.raises(AssertionError):
Uniform(low=10, high=0, name="uniform_10_to_0")

def test_cdf_x(self):
uniform_cdf = Uniform(low=0, high=10, name="cdf_0_to_10")
assert uniform_cdf.cdf_x(5) <= 1
assert uniform_cdf.cdf_x(5) >= 0
assert uniform_cdf.cdf_x(15) == 1
assert uniform_cdf.cdf_x(-1) == -jnp.inf
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong test case

assert uniform_cdf.cdf_x(-1) == 0.0


def test_TruncPowerLaw():
pass
# when low is negative
uniform_cdf = Uniform(low=-10, high=10, name="cdf_n10_to_10")
assert uniform_cdf.cdf_x(0) <= 1
assert uniform_cdf.cdf_x(0) >= 0
assert uniform_cdf.cdf_x(15) == 1
assert uniform_cdf.cdf_x(-11) == -jnp.inf
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong test case

assert uniform_cdf.cdf_x(-11) == 0.0


# when low and high are negative
uniform_cdf = Uniform(low=-10, high=-1, name="cdf_n10_to_n1")
assert uniform_cdf.cdf_x(-5) <= 1
assert uniform_cdf.cdf_x(-5) >= 0
assert uniform_cdf.cdf_x(1) == 1
assert uniform_cdf.cdf_x(-20) == -jnp.inf
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong test case


def test_Uniform():
pass
def test_rvs(self):
uniforn_rvs = Uniform(low=0, high=10, name="tets_rvs")
shape = (3, 4)

# with key
key = jax.random.PRNGKey(123)
result = uniforn_rvs.rvs(shape, key)
assert result.shape, shape + uniforn_rvs._shape

def test_Weibull():
pass
# without key
result = uniforn_rvs.rvs(shape)
assert result.shape, shape + uniforn_rvs._shape
Loading