diff --git a/src/astx/__init__.py b/src/astx/__init__.py index d95161c..77572c2 100644 --- a/src/astx/__init__.py +++ b/src/astx/__init__.py @@ -39,6 +39,9 @@ ) from astx.datatypes import ( Boolean, + Complex, + Complex32, + Complex64, DataTypeOps, Float16, Float32, @@ -51,6 +54,9 @@ Integer, Literal, LiteralBoolean, + LiteralComplex, + LiteralComplex32, + LiteralComplex64, LiteralFloat16, LiteralFloat32, LiteralFloat64, @@ -203,6 +209,12 @@ def get_version() -> str: "variables", "VisibilityKind", "While", + "Complex", + "Complex32", + "Complex64", + "LiteralComplex", + "LiteralComplex32", + "LiteralComplex64", ] diff --git a/src/astx/datatypes.py b/src/astx/datatypes.py index db57c1e..8c3e559 100644 --- a/src/astx/datatypes.py +++ b/src/astx/datatypes.py @@ -552,3 +552,103 @@ def __init__( self.value = value self.type_ = Float64 self.loc = loc + + +@public +class Complex(Number): + """Base class for complex numbers.""" + + def __init__(self, real: float, imag: float) -> None: + """Initialize a complex number with real and imaginary parts.""" + self.real = real + self.imag = imag + + def __str__(self) -> str: + """Return a string representation of the complex number.""" + return f"{self.real} + {self.imag}j" + + +@public +class Complex32(Complex): + """Complex32 data type class.""" + + nbytes: int = 8 + + def __init__(self, real: float, imag: float) -> None: + """Initialize a 32-bit complex number.""" + super().__init__(real, imag) + + +@public +class Complex64(Complex): + """Complex64 data type class.""" + + nbytes: int = 16 + + def __init__(self, real: float, imag: float) -> None: + """Initialize a 64-bit complex number.""" + super().__init__(real, imag) + + +@public +class LiteralComplex(Literal): + """Base class for literal complex numbers.""" + + value: Complex + + def __init__( + self, value: Complex, loc: SourceLocation = NO_SOURCE_LOCATION + ) -> None: + """Initialize LiteralComplex with a complex number.""" + super().__init__(loc) + if isinstance(value, Complex): + self.value = value + self.type_ = ( + Complex64 if isinstance(value, Complex64) else Complex32 + ) + else: + raise TypeError("Value must be an instance of Complex.") + self.loc = loc + + def __str__(self) -> str: + """Return a string that represents the object.""" + return f"LiteralComplex({self.value.real} + {self.value.imag}j)" + + def get_struct(self, simplified: bool = False) -> ReprStruct: + """Return the AST representation for the complex literal.""" + key = f"{self.__class__.__name__}: {self.value}" + value: ReprStruct = { + "real": self.value.real, + "imag": self.value.imag, + } + return self._prepare_struct(key, value, simplified) + + +@public +class LiteralComplex32(LiteralComplex): + """LiteralComplex32 data type class.""" + + def __init__( + self, + real: float, + imag: float, + loc: SourceLocation = NO_SOURCE_LOCATION, + ) -> None: + """Initialize LiteralComplex32.""" + super().__init__(Complex32(real, imag), loc) + self.type_ = Complex32 + + +@public +class LiteralComplex64(LiteralComplex): + """LiteralComplex64 data type class.""" + + def __init__( + self, + real: float, + imag: float, + loc: SourceLocation = NO_SOURCE_LOCATION, + ) -> None: + """Initialize LiteralComplex64.""" + super().__init__(Complex64(real, imag), loc) + self.type_ = Complex64 diff --git a/src/astx/transpilers/python.py b/src/astx/transpilers/python.py index 7789283..4726af2 100644 --- a/src/astx/transpilers/python.py +++ b/src/astx/transpilers/python.py @@ -195,3 +195,57 @@ def visit(self, node: astx.VariableAssignment) -> str: target = node.name value = self.visit(node.value) return f"{target} = {value}" + + @dispatch # type: ignore[no-redef] + def visit(self, node: Type[astx.Float16]) -> str: + """Handle Float nodes.""" + return "float" + + @dispatch # type: ignore[no-redef] + def visit(self, node: Type[astx.Float32]) -> str: + """Handle Float nodes.""" + return "float" + + @dispatch # type: ignore[no-redef] + def visit(self, node: Type[astx.Float64]) -> str: + """Handle Float nodes.""" + return "float" + + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.LiteralFloat16) -> str: + """Handle LiteralFloat nodes.""" + return str(node.value) + + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.LiteralFloat32) -> str: + """Handle LiteralFloat nodes.""" + return str(node.value) + + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.LiteralFloat64) -> str: + """Handle LiteralFloat nodes.""" + return str(node.value) + + @dispatch # type: ignore[no-redef] + def visit(self, node: Type[astx.Complex32]) -> str: + """Handle Complex32 nodes.""" + return "Complex" + + @dispatch # type: ignore[no-redef] + def visit(self, node: Type[astx.Complex64]) -> str: + """Handle Complex64 nodes.""" + return "Complex" + + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.LiteralComplex32) -> str: + """Handle LiteralComplex32 nodes.""" + real = node.value.real + imag = node.value.imag + return f"complex({real}, {imag})" + + @dispatch # type: ignore[no-redef] + def visit(self, node: astx.LiteralComplex64) -> str: + """Handle LiteralComplex64 nodes.""" + real = node.value.real + imag = node.value.imag + return f"complex({real}, {imag})" diff --git a/src/astx/viz.py b/src/astx/viz.py index 7f4f878..c5e8ae4 100644 --- a/src/astx/viz.py +++ b/src/astx/viz.py @@ -22,7 +22,7 @@ _AsciiGraphProxy, ) from graphviz import Digraph -from IPython.display import Image, display # type: ignore[attr-defined] +from IPython.display import Image, display # type: ignore from msgpack import dumps, loads from astx.types import DictDataTypesStruct, ReprStruct diff --git a/tests/test_datatype_complex.py b/tests/test_datatype_complex.py new file mode 100644 index 0000000..2dab50d --- /dev/null +++ b/tests/test_datatype_complex.py @@ -0,0 +1,82 @@ +"""Tests for complex number data types.""" + +from __future__ import annotations + +from typing import Callable, Type + +import astx +import pytest + +from astx.operators import BinaryOp, UnaryOp +from astx.variables import Variable + +VAR_A = Variable("a") + +COMPLEX_LITERAL_CLASSES = [ + astx.LiteralComplex32, + astx.LiteralComplex64, +] + + +def test_variable() -> None: + """Test variable complex.""" + var_a = Variable("a") + var_b = Variable("b") + + BinaryOp(op_code="+", lhs=var_a, rhs=var_b) + + +@pytest.mark.parametrize("literal_class", COMPLEX_LITERAL_CLASSES) +def test_literal(literal_class: Type[astx.Literal]) -> None: + """Test complex literals.""" + lit_a = literal_class(1.5, 2.5) + lit_b = literal_class(3.0, -4.0) + BinaryOp(op_code="+", lhs=lit_a, rhs=lit_b) + + +@pytest.mark.parametrize( + "fn_bin_op,op_code", + [ + (lambda literal_class: VAR_A + literal_class(1.5, 2.5), "+"), + (lambda literal_class: VAR_A - literal_class(1.5, 2.5), "-"), + (lambda literal_class: VAR_A / literal_class(1.5, 2.5), "/"), + (lambda literal_class: VAR_A * literal_class(1.5, 2.5), "*"), + (lambda literal_class: VAR_A == literal_class(1.5, 2.5), "=="), + (lambda literal_class: VAR_A != literal_class(1.5, 2.5), "!="), + ], +) +@pytest.mark.parametrize("literal_class", COMPLEX_LITERAL_CLASSES) +def test_bin_ops( + literal_class: Type[astx.Literal], + fn_bin_op: Callable[[Type[astx.Literal]], BinaryOp], + op_code: str, +) -> None: + """Test binary operations on complex numbers.""" + bin_op = fn_bin_op(literal_class) + assert bin_op.op_code == op_code + assert str(bin_op) != "" + assert repr(bin_op) != "" + assert bin_op.get_struct() != {} + assert bin_op.get_struct(simplified=True) != {} + + +@pytest.mark.parametrize( + "fn_unary_op,op_code", + [ + (lambda literal_class: +literal_class(1.5, 2.5), "+"), + (lambda literal_class: -literal_class(1.5, 2.5), "-"), + ], +) +@pytest.mark.parametrize("literal_class", COMPLEX_LITERAL_CLASSES) +def test_unary_ops( + literal_class: Type[astx.Literal], + fn_unary_op: Callable[[Type[astx.Literal]], UnaryOp], + op_code: str, +) -> None: + """Test unary operations on complex numbers.""" + unary_op = fn_unary_op(literal_class) + assert unary_op.op_code == op_code + assert str(unary_op) != "" + assert repr(unary_op) != "" + assert unary_op.get_struct() != {} + assert unary_op.get_struct(simplified=True) != {} diff --git a/tests/transpilers/test_python.py b/tests/transpilers/test_python.py index eae2d93..0ffe532 100644 --- a/tests/transpilers/test_python.py +++ b/tests/transpilers/test_python.py @@ -240,3 +240,97 @@ def test_transpiler_function() -> None: ) assert generated_code == expected_code, "generated_code != expected_code" + + +def test_literal_int32() -> None: + """Test astx.LiteralInt32.""" + # Create a LiteralInt32 node + literal_int32_node = astx.LiteralInt32(value=42) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_int32_node) + expected_code = "42" + + assert generated_code == expected_code, "generated_code != expected_code" + + +def test_literal_float16() -> None: + """Test astx.LiteralFloat16.""" + # Create a LiteralFloat16 node + literal_float16_node = astx.LiteralFloat16(value=3.14) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_float16_node) + expected_code = "3.14" + + assert generated_code == expected_code, "generated_code != expected_code" + + +def test_literal_float32() -> None: + """Test astx.LiteralFloat32.""" + # Create a LiteralFloat32 node + literal_float32_node = astx.LiteralFloat32(value=2.718) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_float32_node) + expected_code = "2.718" + + assert generated_code == expected_code, "generated_code != expected_code" + + +def test_literal_float64() -> None: + """Test astx.LiteralFloat64.""" + # Create a LiteralFloat64 node + literal_float64_node = astx.LiteralFloat64(value=1.414) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_float64_node) + expected_code = "1.414" + + assert generated_code == expected_code, "generated_code != expected_code" + + +def test_literal_complex32() -> None: + """Test astx.LiteralComplex32.""" + # Create a LiteralComplex32 node + literal_complex32_node = astx.LiteralComplex32(real=1, imag=2.8) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_complex32_node) + expected_code = "complex(1, 2.8)" + + assert ( + generated_code == expected_code + ), f"Expected '{expected_code}', but got '{generated_code}'" + + +def test_literal_complex64() -> None: + """Test astx.LiteralComplex64.""" + # Create a LiteralComplex64 node + literal_complex64_node = astx.LiteralComplex64(real=3.5, imag=4) + + # Initialize the generator + generator = astx2py.ASTxPythonTranspiler() + + # Generate Python code + generated_code = generator.visit(literal_complex64_node) + expected_code = "complex(3.5, 4)" + + assert ( + generated_code == expected_code + ), f"Expected '{expected_code}', but got '{generated_code}'"