diff --git a/frontend/src/utils/constants.js b/frontend/src/utils/constants.js index 3d8941f74..d6ec692ce 100644 --- a/frontend/src/utils/constants.js +++ b/frontend/src/utils/constants.js @@ -1,5 +1,5 @@ export const CONTRACT_ADDRESS = '0x041a78e741e5af2fec34b695679bc6891742439f7afb8484ecd7766661ad02bf'; -export const CLASS_HASH = '0x031ae5768dced8a123907259d946edec39afa08073277abb3e55be8daa8fe49d'; +export const CLASS_HASH = '0x035ae0fe6ca00fcc8020a6c64503f38bfaf3481ae9a6c8b7daec2f899df735fa'; export const UNIQUE = '0x0'; export const EKUBO_ADDRESS = '0x00000005dd3d2f4429af886cd1a3b08289dbcea99a294197e9eb43b0e0325b4b'; export const ZKLEND_ADDRESS = '0x04c0a5193d58f74fbace4b74dcf65481e734ed1714121bdc571da345540efa05'; diff --git a/spotnet_tracker/tasks.py b/spotnet_tracker/tasks.py index 3ae354344..f965cf6d9 100644 --- a/spotnet_tracker/tasks.py +++ b/spotnet_tracker/tasks.py @@ -13,7 +13,6 @@ import time from web_app.contract_tools.mixins.alert import AlertMixin -from web_app.tasks.claim_airdrops import AirdropClaimer from .celery_config import app @@ -33,19 +32,3 @@ def check_users_health_ratio() -> None: except Exception as e: logger.error(f"Error in check_users_health_ratio task: {e}") - -@app.task(name="claim_airdrop_task") -def claim_airdrop_task() -> None: - """ - Background task to claim user airdrops. - - :return: None - """ - try: - logger.info("Running claim_airdrop_task.") - logger.info("Task started at: ",time.strftime("%a, %d %b %Y %H:%M:%S")) - airdrop_claimer = AirdropClaimer() - asyncio.run(airdrop_claimer.claim_airdrops()) - logger.info("Task started at: ", time.strftime("%a, %d %b %Y %H:%M:%S")) - except Exception as e: - logger.error(f"Error in claiming airdrop task: {e}") diff --git a/web_app/alembic/versions/b6eaae01419c_remove_airdrop.py b/web_app/alembic/versions/b6eaae01419c_remove_airdrop.py new file mode 100644 index 000000000..3b2de55e0 --- /dev/null +++ b/web_app/alembic/versions/b6eaae01419c_remove_airdrop.py @@ -0,0 +1,41 @@ +"""remove_airdrop + +Revision ID: b6eaae01419c +Revises: cda4342b007d +Create Date: 2024-12-10 19:37:04.898405 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b6eaae01419c' +down_revision = 'cda4342b007d' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_airdrop_is_claimed', table_name='airdrop') + op.drop_index('ix_airdrop_user_id', table_name='airdrop') + op.drop_table('airdrop') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('airdrop', + sa.Column('id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('user_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=False), + sa.Column('amount', sa.NUMERIC(), autoincrement=False, nullable=True), + sa.Column('is_claimed', sa.BOOLEAN(), autoincrement=False, nullable=True), + sa.Column('claimed_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='airdrop_user_id_fkey'), + sa.PrimaryKeyConstraint('id', name='airdrop_pkey') + ) + op.create_index('ix_airdrop_user_id', 'airdrop', ['user_id'], unique=False) + op.create_index('ix_airdrop_is_claimed', 'airdrop', ['is_claimed'], unique=False) + # ### end Alembic commands ### diff --git a/web_app/api/dashboard.py b/web_app/api/dashboard.py index 526a63e97..3d5c31d04 100644 --- a/web_app/api/dashboard.py +++ b/web_app/api/dashboard.py @@ -8,7 +8,7 @@ from web_app.api.serializers.dashboard import DashboardResponse from web_app.contract_tools.mixins import DashboardMixin, HealthRatioMixin from web_app.db.crud import PositionDBConnector -from decimal import Decimal +from decimal import Decimal, DivisionByZero router = APIRouter() position_db_connector = PositionDBConnector() @@ -59,12 +59,14 @@ async def get_dashboard(wallet_id: str) -> DashboardResponse: if opened_positions else collections.defaultdict(lambda: None) ) + if not first_opened_position: + return default_dashboard_response try: # Fetch zkLend position for the wallet ID health_ratio, tvl = await HealthRatioMixin.get_health_ratio_and_tvl( contract_address ) - except IndexError: + except (IndexError, DivisionByZero) as e: return default_dashboard_response position_multiplier = first_opened_position["multiplier"] diff --git a/web_app/api/position.py b/web_app/api/position.py index a5579d8e5..1c752d011 100644 --- a/web_app/api/position.py +++ b/web_app/api/position.py @@ -14,8 +14,7 @@ TokenMultipliers, ) from web_app.api.serializers.position import TokenMultiplierResponse -from web_app.contract_tools.mixins.deposit import DepositMixin -from web_app.contract_tools.mixins.dashboard import DashboardMixin +from web_app.contract_tools.mixins import DepositMixin, DashboardMixin, PositionMixin from web_app.db.crud import PositionDBConnector router = APIRouter() # Initialize the router @@ -106,19 +105,17 @@ async def get_repay_data( :return: Dict containing the repay transaction data :raises: HTTPException :return: Dict containing status code and detail """ - # TODO rework it too many requests to DB if not wallet_id: raise HTTPException(status_code=404, detail="Wallet not found") - contract_address = position_db_connector.get_contract_address_by_wallet_id( - wallet_id - ) - position_id = position_db_connector.get_position_id_by_wallet_id(wallet_id) - position = position_db_connector.get_position_by_id(position_id) - if not position: - raise HTTPException(status_code=404, detail="Position not found") + contract_address, position_id, token_symbol = position_db_connector.get_repay_data(wallet_id) + is_opened_position = await PositionMixin.is_opened_position(contract_address) + if not is_opened_position: + raise HTTPException(status_code=400, detail="Position was closed") + if not position_id: + raise HTTPException(status_code=404, detail="Position not found or closed") - repay_data = await DepositMixin.get_repay_data(position.token_symbol) + repay_data = await DepositMixin.get_repay_data(token_symbol) repay_data["contract_address"] = contract_address repay_data["position_id"] = str(position_id) return repay_data diff --git a/web_app/api/serializers/airdrop.py b/web_app/api/serializers/airdrop.py deleted file mode 100644 index b90dc090e..000000000 --- a/web_app/api/serializers/airdrop.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Serializers for airdrop data. -""" - -from typing import List -from pydantic import BaseModel - - -class AirdropItem(BaseModel): - """Model for individual airdrop items.""" - - amount: str - proof: List[str] # This needs to be List[str], not str - is_claimed: bool - recipient: str - - -class AirdropResponseModel(BaseModel): - """Model for the complete airdrop response.""" - - airdrops: List[AirdropItem] diff --git a/web_app/api/user.py b/web_app/api/user.py index 54c3a3003..5f85d24b9 100644 --- a/web_app/api/user.py +++ b/web_app/api/user.py @@ -15,7 +15,7 @@ UpdateUserContractResponse, UserHistoryResponse, ) -from web_app.contract_tools.mixins.dashboard import DashboardMixin +from web_app.contract_tools.mixins import PositionMixin, DashboardMixin from web_app.db.crud import ( PositionDBConnector, TelegramUserDBConnector, @@ -45,7 +45,9 @@ async def has_user_opened_position(wallet_id: str) -> dict: """ try: has_position = position_db.has_opened_position(wallet_id) - return {"has_opened_position": has_position} + contract_address = user_db.get_contract_address_by_wallet_id(wallet_id) + is_position_opened = await PositionMixin.is_opened_position(contract_address) + return {"has_opened_position": has_position or is_position_opened} except ValueError as e: raise HTTPException( status_code=404, detail=f"Invalid wallet ID format: {str(e)}" diff --git a/web_app/contract_tools/airdrop.py b/web_app/contract_tools/airdrop.py deleted file mode 100644 index e9d7f21e0..000000000 --- a/web_app/contract_tools/airdrop.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -This module defines the contract tools for the airdrop data. -""" - -from typing import List -from web_app.api.serializers.airdrop import AirdropItem, AirdropResponseModel -from web_app.contract_tools.api_request import APIRequest -from web_app.contract_tools.constants import TokenParams - - -class ZkLendAirdrop: - """ - A class to fetch and validate airdrop data - for a specified contract. - """ - - REWARD_API_ENDPOINT = "https://app.zklend.com/api/reward/all/" - - def __init__(self): - """ - Initializes the ZkLendAirdrop class with an APIRequest instance. - """ - self.api = APIRequest(base_url=self.REWARD_API_ENDPOINT) - - async def get_contract_airdrop(self, contract_id: str) -> AirdropResponseModel: - """ - Fetches all available airdrops - for a specific contract asynchronously. - Args: - contract_id (str): The ID of the contract - for which to fetch airdrop data. - Returns: - AirdropResponseModel: A validated list of airdrop items - for the specified contract. - Raises: - ValueError: If contract_id is None - """ - if contract_id is None: - raise ValueError("Contract ID cannot be None") - - underlying_contract_id = TokenParams.add_underlying_address(contract_id) - response = await self.api.fetch(underlying_contract_id) - return self._validate_response(response) - - @staticmethod - def _validate_response(data: List[dict]) -> AirdropResponseModel: - """ - Validates and formats the response data, keeping only necessary fields. - Args: - data (List[dict]): Raw response data from the API. - Returns: - AirdropResponseModel: Structured and validated airdrop data. - """ - validated_items = [] - for item in data: - validated_item = AirdropItem( - amount=item["amount"], - proof=item[ - "proof" - ], # This is correct now as AirdropItem expects List[str] - is_claimed=item["is_claimed"], - recipient=item["recipient"], - ) - validated_items.append(validated_item) - return AirdropResponseModel(airdrops=validated_items) - - -if __name__ == "__main__": - airdrop_fetcher = ZkLendAirdrop() - result = airdrop_fetcher.get_contract_airdrop( - "0x698b63df00be56ba39447c9b9ca576ffd0edba0526d98b3e8e4a902ffcf12f0" - ) - print(result) diff --git a/web_app/contract_tools/blockchain_call.py b/web_app/contract_tools/blockchain_call.py index c4993fa61..db7480161 100644 --- a/web_app/contract_tools/blockchain_call.py +++ b/web_app/contract_tools/blockchain_call.py @@ -325,3 +325,16 @@ async def claim_airdrop(self, contract_address: str, proofs: list[str]) -> None: selector="claim", calldata=proofs, ) + + async def is_opened_position(self, contract_address: str) -> bool: + """ + Checks if a position is opened on the Starknet blockchain. + + :param contract_address: The contract address. + :return: A boolean indicating if the position is opened. + """ + return await self._func_call( + addr=self._convert_address(contract_address), + selector="is_position_open", + calldata=[], + ) diff --git a/web_app/contract_tools/mixins/__init__.py b/web_app/contract_tools/mixins/__init__.py index 65277273d..5fdbce757 100644 --- a/web_app/contract_tools/mixins/__init__.py +++ b/web_app/contract_tools/mixins/__init__.py @@ -2,7 +2,15 @@ Import all mixins here to make them available to the rest of the application. """ + from .dashboard import DashboardMixin from .health_ratio import HealthRatioMixin from .deposit import DepositMixin from .alert import AlertMixin +from .position import PositionMixin +from web_app.contract_tools.blockchain_call import ( + StarknetClient, +) + + +CLIENT = StarknetClient() diff --git a/web_app/contract_tools/mixins/alert.py b/web_app/contract_tools/mixins/alert.py index d529b2426..c71572a1c 100644 --- a/web_app/contract_tools/mixins/alert.py +++ b/web_app/contract_tools/mixins/alert.py @@ -8,6 +8,7 @@ from web_app.contract_tools.mixins import HealthRatioMixin from web_app.db.crud import UserDBConnector + logger = logging.getLogger(__name__) ALERT_THRESHOLD = 3.2 # FIXME return to 1.1 after testing diff --git a/web_app/contract_tools/mixins/dashboard.py b/web_app/contract_tools/mixins/dashboard.py index 87b8b5928..a19e2c4d6 100644 --- a/web_app/contract_tools/mixins/dashboard.py +++ b/web_app/contract_tools/mixins/dashboard.py @@ -6,15 +6,13 @@ from typing import Dict from decimal import Decimal -from web_app.contract_tools.blockchain_call import StarknetClient + from web_app.contract_tools.constants import TokenParams, MULTIPLIER_POWER from web_app.contract_tools.api_request import APIRequest -from web_app.api.serializers.dashboard import DashboardResponse logger = logging.getLogger(__name__) -CLIENT = StarknetClient() # example of ARGENT_X_POSITION_URL # "https://cloud.argent-api.com/v1/tokens/defi/decomposition/{wallet_id}?chain=starknet" ARGENT_X_POSITION_URL = "https://cloud.argent-api.com/v1/tokens/defi/" @@ -67,6 +65,8 @@ async def get_wallet_balances(cls, holder_address: str) -> Dict[str, str]: :param holder_address: holder address :return: Returns the wallet balances for the given holder address. """ + from . import CLIENT + wallet_balances = {} for token in TokenParams.tokens(): @@ -84,18 +84,6 @@ async def get_wallet_balances(cls, holder_address: str) -> Dict[str, str]: return wallet_balances - @classmethod - async def get_zklend_position( - cls, contract_address: str, position: "Position" - ) -> DashboardResponse: - """ - Get the zkLend position for the given wallet ID. - :param contract_address: contract address - :param position: Position db model - :return: zkLend position validated by Pydantic models - """ - pass - @classmethod def _get_products(cls, dapps: list) -> list[dict]: """ @@ -133,8 +121,8 @@ async def get_current_position_sum(cls, position: dict) -> Decimal: """ current_prices = await cls.get_current_prices() price = current_prices.get(position.get("token_symbol"), Decimal(0)) - amount = Decimal(position.get("amount", 0)) - multiplier = Decimal(position.get("multiplier", 0)) + amount = Decimal(position.get("amount", 0) or 0) + multiplier = Decimal(position.get("multiplier", 0) or 0) return cls._calculate_sum(price, amount, multiplier) @classmethod diff --git a/web_app/contract_tools/mixins/deposit.py b/web_app/contract_tools/mixins/deposit.py index 4061c289e..546ee9b97 100644 --- a/web_app/contract_tools/mixins/deposit.py +++ b/web_app/contract_tools/mixins/deposit.py @@ -3,10 +3,9 @@ """ from decimal import Decimal -from web_app.contract_tools.blockchain_call import StarknetClient from web_app.contract_tools.constants import TokenParams -CLIENT = StarknetClient() + # alternative ARGENT_X_POSITION_URL # "https://cloud.argent-api.com/v1/tokens/defi/decomposition/{wallet_id}?chain=starknet" ARGENT_X_POSITION_URL = "https://cloud.argent-api.com/v1/tokens/defi/" @@ -35,6 +34,8 @@ async def get_transaction_data( :param borrowing_token: Borrowing token :return: approve_data and loop_liquidity_data """ + from . import CLIENT + deposit_token_address = TokenParams.get_token_address(deposit_token) decimal = TokenParams.get_token_decimals(deposit_token_address) amount = int(Decimal(amount) * 10**decimal) @@ -52,8 +53,14 @@ async def get_repay_data(cls, supply_token: str) -> dict: :param supply_token: Deposit token :return: dict with repay data """ + from . import CLIENT + deposit_token_address = TokenParams.get_token_address(supply_token) - debt_token_address = TokenParams.get_token_address("USDC") if supply_token != "USDC" else TokenParams.get_token_address("ETH") + debt_token_address = ( + TokenParams.get_token_address("USDC") + if supply_token != "USDC" + else TokenParams.get_token_address("ETH") + ) repay_data = { "supply_token": deposit_token_address, "debt_token": debt_token_address, diff --git a/web_app/contract_tools/mixins/health_ratio.py b/web_app/contract_tools/mixins/health_ratio.py index 1c8fefaf0..32fba27cb 100644 --- a/web_app/contract_tools/mixins/health_ratio.py +++ b/web_app/contract_tools/mixins/health_ratio.py @@ -7,13 +7,8 @@ from pragma_sdk.common.types.types import AggregationMode from pragma_sdk.onchain.client import PragmaOnChainClient - -from web_app.contract_tools.blockchain_call import ( - StarknetClient, -) from web_app.contract_tools.constants import TokenParams -CLIENT = StarknetClient() PRAGMA = PragmaOnChainClient( network="mainnet", ) @@ -53,6 +48,8 @@ async def _get_z_balances( :return: A dictionary of token balances with token symbols as keys and balances as Decimal values. """ + from . import CLIENT + tasks = [ CLIENT.get_balance( z_data[1], @@ -81,6 +78,8 @@ async def _get_deposited_tokens( :return: A dictionary of deposited tokens with token symbols as keys and amounts as Decimal values. """ + from . import CLIENT + reserves = await CLIENT.get_z_addresses() deposits = await cls._get_z_balances( reserves, deposit_contract_address @@ -123,6 +122,8 @@ async def _get_borrowed_token( """ :return: Tuple with borrowed token and current debt on ZkLend """ + from . import CLIENT + tasks = [ CLIENT.get_zklend_debt(deposit_contract_address, token.address) for token in TokenParams.tokens() @@ -174,10 +175,6 @@ async def get_health_ratio_and_tvl( * prices[borrowed_token] / 10 ** int(TokenParams.get_token_decimals(borrowed_address)) ) - # return { - # "health_factor": f"{round(deposit_usdc / Decimal(debt_usdc), 2)}" if debt_usdc != 0 else "0", # pylint: disable=line-too-long - # "ltv": f"{round((debt_usdc / TokenParams.get_borrow_factor(borrowed_token)) / deposit_usdc, 2)}" # pylint: disable=line-too-long - # } health_factor = ( f"{round(deposit_usdc / Decimal(debt_usdc), 2)}" if debt_usdc != 0 diff --git a/web_app/contract_tools/mixins/position.py b/web_app/contract_tools/mixins/position.py new file mode 100644 index 000000000..c62fcb449 --- /dev/null +++ b/web_app/contract_tools/mixins/position.py @@ -0,0 +1,24 @@ +""" +Mixins for position related methods +""" + + +class PositionMixin: + """ + Mixin for position related methods + """ + + @classmethod + async def is_opened_position(cls, contract_address: str) -> bool: + """ + Check if the position is opened. + :param contract_address: Contract address + :return: True if the position is opened, False otherwise + """ + from . import CLIENT + + response = await CLIENT.is_opened_position(contract_address) + try: + return bool(response[0]) + except IndexError: + return False diff --git a/web_app/db/crud/__init__.py b/web_app/db/crud/__init__.py index 90ae2a919..c05dbf571 100644 --- a/web_app/db/crud/__init__.py +++ b/web_app/db/crud/__init__.py @@ -2,7 +2,6 @@ This module contains the CRUD operations for the database. """ -from .airdrop import * from .base import * from .deposit import * from .position import * diff --git a/web_app/db/crud/airdrop.py b/web_app/db/crud/airdrop.py deleted file mode 100644 index bd177052b..000000000 --- a/web_app/db/crud/airdrop.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -This module contains the database configuration for airdrops. -""" - -import logging -import uuid -from datetime import datetime -from decimal import Decimal -from typing import List, TypeVar - -from sqlalchemy.exc import SQLAlchemyError - -from web_app.db.models import AirDrop, Base -from .base import DBConnector - -logger = logging.getLogger(__name__) -ModelType = TypeVar("ModelType", bound=Base) - - -class AirDropDBConnector(DBConnector): - """ - Provides database connection and operations management for the AirDrop model. - """ - - def save_claim_data(self, airdrop_id: uuid.UUID, amount: Decimal) -> None: - """ - Updates the AirDrop instance with claim data. - :param airdrop_id: uuid.UUID - :param amount: Decimal - """ - airdrop = self.get_object(AirDrop, airdrop_id) - if airdrop: - airdrop.amount = amount - airdrop.is_claimed = True - airdrop.claimed_at = datetime.now() - self.write_to_db(airdrop) - else: - logger.error(f"AirDrop with ID {airdrop_id} not found") - - def get_all_unclaimed(self) -> List[AirDrop]: - """ - Returns all unclaimed AirDrop instances (where is_claimed is False). - - :return: List of unclaimed AirDrop instances - """ - with self.Session() as db: - try: - unclaimed_instances = ( - db.query(AirDrop).filter_by(is_claimed=False).all() - ) - return unclaimed_instances - except SQLAlchemyError as e: - logger.error( - f"Failed to retrieve unclaimed AirDrop instances: {str(e)}" - ) - return [] - - def delete_all_users_airdrop(self, user_id: uuid.UUID) -> None: - """ - Delete all airdrops for a user. - :param user_id: User ID - """ - with self.Session() as db: - try: - airdrops = db.query(AirDrop).filter_by(user_id=user_id).all() - for airdrop in airdrops: - db.delete(airdrop) - db.commit() - except SQLAlchemyError as e: - logger.error(f"Error deleting airdrops for user {user_id}: {str(e)}") diff --git a/web_app/db/crud/base.py b/web_app/db/crud/base.py index f356f9da5..16d1446dd 100644 --- a/web_app/db/crud/base.py +++ b/web_app/db/crud/base.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker from web_app.db.database import SQLALCHEMY_DATABASE_URL -from web_app.db.models import AirDrop, Base +from web_app.db.models import Base logger = logging.getLogger(__name__) ModelType = TypeVar("ModelType", bound=Base) @@ -130,13 +130,3 @@ def delete_object(self, object: Base) -> None: finally: db.close() - - def create_empty_claim(self, user_id: uuid.UUID) -> AirDrop: - """ - Creates a new empty AirDrop instance for the given user_id. - :param user_id: uuid.UUID - :return: AirDrop - """ - airdrop = AirDrop(user_id=user_id) - self.write_to_db(airdrop) - return airdrop diff --git a/web_app/db/crud/position.py b/web_app/db/crud/position.py index ace783baf..a467a2798 100644 --- a/web_app/db/crud/position.py +++ b/web_app/db/crud/position.py @@ -207,7 +207,7 @@ def close_position(self, position_id: uuid) -> Position | None: def open_position(self, position_id: uuid.UUID, current_prices: dict) -> str | None: """ - Opens a position by updating its status and creating an AirDrop claim. + Opens a position by updating its status. :param position_id: uuid.UUID :param current_prices: dict :return: str | None @@ -216,7 +216,6 @@ def open_position(self, position_id: uuid.UUID, current_prices: dict) -> str | N if position: position.status = Status.OPENED.value self.write_to_db(position) - self.create_empty_claim(position.user_id) self.save_current_price(position, current_prices) return position.status else: @@ -322,6 +321,29 @@ def get_all_liquidated_positions(self) -> list[dict]: logger.error(f"Error retrieving liquidated positions: {str(e)}") return [] + def get_repay_data(self, wallet_id: str) -> tuple: + """ + Retrieves the repay data for a user. + :param wallet_id: + :return: + """ + with self.Session() as db: + result = ( + db.query( + User.contract_address, Position.id, Position.token_symbol + ) + .join(Position, Position.user_id == User.id) + .filter(User.wallet_id == wallet_id) + .first() + ) + + if not result: + return None, None, None + + return result + + + def get_position_by_id(self, position_id: int) -> Position | None: """ Retrieves a position by its ID. diff --git a/web_app/db/models.py b/web_app/db/models.py index 30070cef5..e6debfdca 100644 --- a/web_app/db/models.py +++ b/web_app/db/models.py @@ -1,6 +1,6 @@ """ This module contains SQLAlchemy models for the application, including -User, Position, AirDrop, and TelegramUser. Each model represents a +User, Position, and TelegramUser. Each model represents a table in the database and defines the structure and relationships between the data entities. """ @@ -84,23 +84,6 @@ class Position(Base): datetime_liquidation = Column(DateTime, nullable=True) -class AirDrop(Base): - """ - SQLAlchemy model for the airdrop table. - """ - - __tablename__ = "airdrop" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4) - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id"), index=True, nullable=False - ) - created_at = Column(DateTime, nullable=False, default=func.now()) - amount = Column(DECIMAL, nullable=True) - is_claimed = Column(Boolean, default=False, index=True) - claimed_at = Column(DateTime, nullable=True) - - class TelegramUser(Base): """ SQLAlchemy model for the telegram_user table. diff --git a/web_app/db/seed_data.py b/web_app/db/seed_data.py index cdf99057a..e9cf0a52a 100644 --- a/web_app/db/seed_data.py +++ b/web_app/db/seed_data.py @@ -5,7 +5,7 @@ import logging from decimal import Decimal from faker import Faker -from web_app.db.models import Status, User, Position, AirDrop, TelegramUser, Vault +from web_app.db.models import Status, User, Position, TelegramUser, Vault from web_app.db.database import SessionLocal from web_app.contract_tools.constants import TokenParams @@ -74,30 +74,6 @@ def create_positions(session: SessionLocal, users: list[User]) -> None: logger.info("No positions created.") -def create_airdrops(session: SessionLocal, users: list[User]) -> None: - """ - Create and save fake airdrop records for each user. - Args: - session (Session): SQLAlchemy session object. - users (list): List of User objects to associate with airdrops. - """ - airdrops = [] - for user in users: - for _ in range(2): - airdrop = AirDrop( - user_id=user.id, - amount=Decimal( - fake.pydecimal(left_digits=5, right_digits=2, positive=True) - ), - is_claimed=fake.boolean(), - claimed_at=fake.date_time_this_decade() if fake.boolean() else None, - ) - airdrops.append(airdrop) - if airdrops: - session.bulk_save_objects(airdrops) - session.commit() - - def create_telegram_users(session: SessionLocal, users: list[User]) -> None: """ Create and save fake Telegram user records to the database. @@ -157,7 +133,6 @@ def create_vaults(session: SessionLocal, users: list[User]) -> None: # Populate the database users = create_users(session) create_positions(session, users) - create_airdrops(session, users) create_telegram_users(session, users) create_vaults(session, users) diff --git a/web_app/tasks/claim_airdrops.py b/web_app/tasks/claim_airdrops.py deleted file mode 100644 index d67c0c403..000000000 --- a/web_app/tasks/claim_airdrops.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Module for claiming unclaimed airdrops from the AirDropDBConnector -and updating the database when a claim is successful. -""" - -import asyncio -import logging -from typing import List - -from requests.exceptions import ConnectionError, Timeout -from sqlalchemy.exc import SQLAlchemyError -from web_app.contract_tools.airdrop import ZkLendAirdrop -from web_app.contract_tools.blockchain_call import StarknetClient -from web_app.db.crud import AirDropDBConnector - -logger = logging.getLogger(__name__) - - -class AirdropClaimer: - """ - Handles the process of claiming unclaimed airdrops and updating the database. - """ - - def __init__(self): - """ - Initializes the AirdropClaimer with database and Starknet client instances. - """ - self.db_connector = AirDropDBConnector() - self.starknet_client = StarknetClient() - self.zk_lend_airdrop = ZkLendAirdrop() - - async def claim_airdrops(self) -> None: - """ - Retrieves unclaimed airdrops, attempts to claim them on the Starknet blockchain, - and updates the database if the claim is successful. - """ - unclaimed_airdrops = self.db_connector.get_all_unclaimed() - for airdrop in unclaimed_airdrops: - try: - user_contract_address = airdrop.user.contract_address - proofs = self.zk_lend_airdrop.get_contract_airdrop( - user_contract_address - ) - - claim_successful = await self._claim_airdrop( - user_contract_address, proofs - ) - - if claim_successful: - self.db_connector.save_claim_data(airdrop.id, airdrop.amount) - logger.info("Airdrop %s claimed succesfully.", airdrop.id) - except ValueError as ve: - logger.error("Invalid data for airdrop %s: %s", airdrop.id, ve) - except SQLAlchemyError as db_err: - logger.error( - "Database error while updating claim data for airdrop %s: %s", - airdrop.id, - db_err, - ) - except ConnectionError as ce: - logger.error( - "Network connection error during claim for airdrop %s: %s", - airdrop.id, - ce, - ) - except Timeout as te: - logger.error("Timeout during claim for airdrop %s: %s", airdrop.id, te) - except Exception as e: - logger.error("Unexpected error claiming airdrop %s: %s", airdrop.id, e) - - async def _claim_airdrop(self, contract_address: str, proofs: List[str]) -> bool: - """ - Claims a single airdrop by making a contract call on the Starknet blockchain. - """ - try: - await self.starknet_client.claim_airdrop(contract_address, proofs) - return True - except ConnectionError as ce: - logger.error( - "Network connection failed for address %s: %s", contract_address, ce - ) - return False - except Timeout as te: - logger.error( - "Timeout during claim for address %s: %s", contract_address, te - ) - return False - except ValueError as ve: - logger.error( - "Invalid data format for calldata during claim for address %s: %s", - contract_address, - ve, - ) - return False - except Exception as e: - logger.error( - "Unexpected error claiming address %s: %s", contract_address, e - ) - return False - - -if __name__ == "__main__": - airdrop_claimer = AirdropClaimer() - asyncio.run(airdrop_claimer.claim_airdrops()) diff --git a/web_app/test_integration/test_close_position.py b/web_app/test_integration/test_close_position.py index 30decfbd8..c6eb13aae 100644 --- a/web_app/test_integration/test_close_position.py +++ b/web_app/test_integration/test_close_position.py @@ -8,11 +8,10 @@ import pytest from web_app.contract_tools.mixins.dashboard import DashboardMixin -from web_app.db.crud import PositionDBConnector, UserDBConnector, AirDropDBConnector +from web_app.db.crud import PositionDBConnector, UserDBConnector from web_app.db.models import Status user_db = UserDBConnector() -airdrop = AirDropDBConnector() position_db = PositionDBConnector() @@ -104,7 +103,6 @@ def test_close_position(self, form_data: Dict[str, Any]) -> None: # Clean up - delete the position and user user = position_db.get_user_by_wallet_id(wallet_id) - airdrop.delete_all_users_airdrop(user.id) position_db.delete_position(position) if not position_db.get_positions_by_wallet_id(wallet_id): position_db.delete_user_by_wallet_id(wallet_id) diff --git a/web_app/test_integration/test_create_position.py b/web_app/test_integration/test_create_position.py index b06547c4d..00f81bcb1 100644 --- a/web_app/test_integration/test_create_position.py +++ b/web_app/test_integration/test_create_position.py @@ -17,12 +17,11 @@ import pytest from typing import Dict, Any from datetime import datetime -from web_app.db.crud import PositionDBConnector, AirDropDBConnector +from web_app.db.crud import PositionDBConnector from web_app.contract_tools.mixins.dashboard import DashboardMixin from web_app.db.models import Status position_db = PositionDBConnector() -airdrop = AirDropDBConnector() class TestPositionCreation: @@ -106,6 +105,5 @@ def test_create_position(self, form_data: Dict[str, Any]) -> None: print(f"Position {position.id} successfully opened.") user = position_db.get_user_by_wallet_id(wallet_id) - airdrop.delete_all_users_airdrop(user.id) position_db.delete_all_user_positions(user.id) position_db.delete_user_by_wallet_id(wallet_id) diff --git a/web_app/tests/db/test_dbconnector.py b/web_app/tests/db/test_dbconnector.py index e5ab4ea17..9b4e7b9c8 100644 --- a/web_app/tests/db/test_dbconnector.py +++ b/web_app/tests/db/test_dbconnector.py @@ -12,7 +12,7 @@ PositionDBConnector, UserDBConnector, ) -from web_app.db.models import AirDrop, Base, Position, Status, User +from web_app.db.models import Base, Position, Status, User @pytest.fixture(scope="function") @@ -121,26 +121,4 @@ def db_connector(): connector.delete_object_by_id(User, test_user.id) -def test_create_empty_claim_positive(db_connector): - """ - Test that create_empty_claim successfully creates an AirDrop - for an existing user. - """ - connector, test_user = db_connector - airdrop = connector.create_empty_claim(test_user.id) - assert airdrop is not None - assert airdrop.user_id == test_user.id - assert not airdrop.is_claimed - assert airdrop.amount is None - assert airdrop.claimed_at is None - connector.delete_object_by_id(AirDrop, airdrop.id) - -def test_create_empty_claim_non_existent_user(db_connector): - """ - Test that create_empty_claim raises an error when called with a non-existent user ID. - """ - connector, _ = db_connector - fake_user_id = uuid.uuid4() - with pytest.raises(SQLAlchemyError): - connector.create_empty_claim(fake_user_id) diff --git a/web_app/tests/db/test_user_dbconnector.py b/web_app/tests/db/test_user_dbconnector.py index 4d62a0004..06dcf8216 100644 --- a/web_app/tests/db/test_user_dbconnector.py +++ b/web_app/tests/db/test_user_dbconnector.py @@ -7,8 +7,8 @@ import pytest from sqlalchemy.exc import SQLAlchemyError -from web_app.db.crud import AirDropDBConnector, UserDBConnector -from web_app.db.models import AirDrop, User +from web_app.db.crud import UserDBConnector +from web_app.db.models import User @pytest.fixture @@ -90,50 +90,3 @@ def test_get_unique_users_count(mock_user_db_connector): result = mock_user_db_connector.get_unique_users_count() assert result == 5 - - -def test_delete_all_users_airdrop_success(user_db): - """ - Test successful deletion of all airdrops for a user. - """ - user_id = "123e4567-e89b-12d3-a456-426614174000" - mock_session = MagicMock() - mock_airdrops = [ - AirDrop(id=1, user_id=user_id), - AirDrop(id=2, user_id=user_id), - ] - - air_drop_connector = AirDropDBConnector() - with patch.object(air_drop_connector, "Session") as mock_session_factory: - mock_session_factory.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.filter_by.return_value.all.return_value = ( - mock_airdrops - ) - - air_drop_connector.delete_all_users_airdrop(user_id) - - mock_session.query.assert_called_once_with(AirDrop) - mock_session.query.return_value.filter_by.assert_called_once_with( - user_id=user_id - ) - assert mock_session.delete.call_count == len(mock_airdrops) - mock_session.commit.assert_called_once() - - -def test_delete_all_users_airdrop_failure(user_db): - """ - Test failure while deleting airdrops for a user. - """ - user_id = "123e4567-e89b-12d3-a456-426614174000" - mock_session = MagicMock() - mock_session.query.side_effect = SQLAlchemyError("Database error") - - air_drop_connector = AirDropDBConnector() - with patch.object( - air_drop_connector, "Session", return_value=mock_session - ) as mock_session_factory: - mock_session_factory.return_value.__enter__.return_value = mock_session - - air_drop_connector.delete_all_users_airdrop(user_id) - mock_session.query.assert_called_once_with(AirDrop) - # mock_session.rollback.assert_called_once() diff --git a/web_app/tests/test_airdrop.py b/web_app/tests/test_airdrop.py deleted file mode 100644 index da4d76b16..000000000 --- a/web_app/tests/test_airdrop.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Tests for the AirDropDBConnector class, covering key database operations for airdrops. - -Fixtures: -- db_connector: Provides an AirDropDBConnector instance with test user and airdrop data. - -Test Cases: -- test_create_empty_claim_positive: Verifies airdrop creation for an existing user. -- test_create_empty_claim_non_existent_user: Checks error handling for invalid user IDs. -- test_save_claim_data_positive: Ensures claim data updates correctly. -- test_save_claim_data_non_existent_airdrop: Confirms logging for invalid airdrop IDs. -- test_get_all_unclaimed_positive: Retrieves unclaimed airdrops. -- test_get_all_unclaimed_after_claiming: Excludes claimed airdrops from unclaimed results. -""" - -import uuid -from datetime import datetime -from decimal import Decimal - -import pytest -from sqlalchemy.exc import SQLAlchemyError - -from web_app.db.crud import AirDropDBConnector -from web_app.db.models import AirDrop, User - - -@pytest.fixture -def db_connector(): - """ - Sets up an AirDropDBConnector with a test user and airdrop record, then cleans - up after the test. - This fixture: - - Initializes an AirDropDBConnector instance. - - Creates and saves a test user and associated airdrop record. - - Yields the connector, user, and airdrop instances for test use. - - Cleans up the database by removing the test user and airdrop after the test. - - Yields: - tuple: (AirDropDBConnector, User, AirDrop) - """ - connector = AirDropDBConnector() - test_user = User(wallet_id="test_wallet_id") - connector.write_to_db(test_user) - airdrop = AirDrop(user_id=test_user.id) - connector.write_to_db(airdrop) - yield connector, test_user, airdrop - connector.delete_object_by_id(AirDrop, airdrop.id) - connector.delete_object_by_id(User, test_user.id) - - -def test_create_empty_claim_positive(db_connector): - """ - Tests that create_empty_claim successfully creates a new airdrop for an - existing user. - - Steps: - - Calls create_empty_claim with a valid user ID. - - Asserts the airdrop is created with the correct user_id and - is initially unclaimed. - - Args: - db_connector (fixture): Provides the AirDropDBConnector, test user, - and test airdrop. - """ - connector, test_user, _ = db_connector - new_airdrop = connector.create_empty_claim(test_user.id) - assert new_airdrop is not None - assert new_airdrop.user_id == test_user.id - assert not new_airdrop.is_claimed - connector.delete_object_by_id(AirDrop, new_airdrop.id) - - -def test_create_empty_claim_non_existent_user(db_connector): - """ - Tests that create_empty_claim raises an error when called with - a non-existent user ID. - - Steps: - - Generates a fake user ID that does not exist in the database. - - Verifies that calling create_empty_claim with this ID raises - an SQLAlchemyError. - - Args: - db_connector (fixture): Provides the AirDropDBConnector - and test setup. - """ - connector, _, _ = db_connector - fake_user_id = uuid.uuid4() - with pytest.raises(SQLAlchemyError): - connector.create_empty_claim(fake_user_id) - - -def test_save_claim_data_positive(db_connector): - """ - Tests that save_claim_data correctly updates an existing airdrop - with claim details. - - Steps: - - Calls save_claim_data with a valid airdrop ID and amount. - - Asserts the airdrop's amount, is_claimed status, and claimed_at - timestamp are updated correctly. - - Args: - db_connector (fixture): Provides the AirDropDBConnector, test user, - and test airdrop. - """ - connector, _, airdrop = db_connector - amount = Decimal("100.50") - connector.save_claim_data(airdrop.id, amount) - updated_airdrop = connector.get_object(AirDrop, airdrop.id) - assert updated_airdrop.amount == amount - assert updated_airdrop.is_claimed - assert updated_airdrop.claimed_at is not None - - -def test_save_claim_data_non_existent_airdrop(db_connector, caplog): - """ - Tests that save_claim_data logs an error when called with a non-existent - airdrop ID. - - Steps: - - Generates a fake airdrop ID that is not in the database. - - Calls save_claim_data with this ID and checks that the appropriate - error message is logged. - - Args: - db_connector (fixture): Provides the AirDropDBConnector and - test setup. - caplog (fixture): Captures log output for verification. - """ - connector, _, _ = db_connector - fake_airdrop_id = uuid.uuid4() - connector.save_claim_data(fake_airdrop_id, Decimal("50.00")) - assert f"AirDrop with ID {fake_airdrop_id} not found" in caplog.text - - -def test_get_all_unclaimed_positive(db_connector): - """ - Tests that get_all_unclaimed retrieves unclaimed airdrops correctly. - - Steps: - - Calls get_all_unclaimed to fetch unclaimed airdrops. - - Asserts that the test airdrop (unclaimed) is present in the retrieved - list by matching IDs. - - Args: - db_connector (fixture): Provides the AirDropDBConnector, test user, - and test airdrop. - """ - connector, _, airdrop = db_connector - unclaimed_airdrops = connector.get_all_unclaimed() - assert any(airdrop.id == unclaimed.id for unclaimed in unclaimed_airdrops) - - -def test_get_all_unclaimed_after_claiming(db_connector): - """ - Tests that get_all_unclaimed excludes airdrops that have been claimed. - - Steps: - - Marks the test airdrop as claimed using save_claim_data. - - Calls get_all_unclaimed to fetch unclaimed airdrops. - - Asserts that the claimed airdrop is not included in the - returned list. - - Args: - db_connector (fixture): Provides the AirDropDBConnector, - test user, and test airdrop. - """ - connector, _, airdrop = db_connector - connector.save_claim_data(airdrop.id, Decimal("50.00")) - unclaimed_airdrops = connector.get_all_unclaimed() - assert airdrop not in unclaimed_airdrops diff --git a/web_app/tests/test_claim_airdrops.py b/web_app/tests/test_claim_airdrops.py deleted file mode 100644 index d8463a408..000000000 --- a/web_app/tests/test_claim_airdrops.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Tests for the AirdropClaimer class, covering comprehensive airdrop claim operations. - -Fixtures: -- airdrop_claimer: Fixture creating a mock AirdropClaimer instance for consistent testing -- mock_airdrop: Fixture generating a standard mock airdrop object for reusable test scenarios - -Test Cases: -- test_claim_airdrops_successful: Validates successful airdrop claim workflow -- test_claim_airdrops_no_unclaimed: Checks behavior when no unclaimed airdrops exist -- test_claim_airdrops_partial_failure: Tests mixed success and failure scenarios -- test_claim_airdrops_database_error: Verifies database error handling -- test_claim_airdrop_timeout_error: Ensures proper handling of request timeout errors -- test_claim_airdrop_invalid_proof: Checks processing of invalid proof data -- test_claim_airdrop_unexpected_error: Validates unexpected error management -""" - -import logging -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from requests.exceptions import ConnectionError, Timeout -from sqlalchemy.exc import SQLAlchemyError - -from web_app.tasks.claim_airdrops import AirdropClaimer - - -@pytest.fixture -def airdrop_claimer(): - """ - Fixture to create a mock AirdropClaimer instance for each test. - - Yields: - claimer - """ - claimer = AirdropClaimer() - claimer.db_connector = MagicMock() - claimer.starknet_client = AsyncMock() - claimer.zk_lend_airdrop = MagicMock() - yield claimer - - -@pytest.fixture -def mock_airdrop(): - """ - Create a standard mock airdrop for reusable test setup. - - Yields: - mock_airdrop - """ - mock_airdrop = MagicMock() - mock_airdrop.user.contract_address = "0x123" - mock_airdrop.id = 1 - mock_airdrop.amount = 100 - yield mock_airdrop - - -@pytest.mark.asyncio -async def test_claim_airdrops_successful(airdrop_claimer, mock_airdrop): - """ - Test the claim_airdrops method for successful claims. - """ - # Arrange - airdrop_claimer.db_connector.get_all_unclaimed.return_value = [mock_airdrop] - airdrop_claimer.zk_lend_airdrop.get_contract_airdrop.return_value = [ - "proof1", - "proof2", - ] - airdrop_claimer.starknet_client.claim_airdrop.return_value = True - - # Act - await airdrop_claimer.claim_airdrops() - - # Assertions - airdrop_claimer.zk_lend_airdrop.get_contract_airdrop.assert_called_with("0x123") - airdrop_claimer.starknet_client.claim_airdrop.assert_awaited_with( - "0x123", ["proof1", "proof2"] - ) - airdrop_claimer.db_connector.save_claim_data.assert_called_with(1, 100) - - -@pytest.mark.asyncio -async def test_claim_airdrops_no_unclaimed(airdrop_claimer): - """ - Test claim_airdrops when no unclaimed airdrops exist. - """ - # Arrange - airdrop_claimer.db_connector.get_all_unclaimed.return_value = [] - - # Act - await airdrop_claimer.claim_airdrops() - - # Assertions - airdrop_claimer.zk_lend_airdrop.get_contract_airdrop.assert_not_called() - airdrop_claimer.starknet_client.claim_airdrop.assert_not_called() - airdrop_claimer.db_connector.save_claim_data.assert_not_called() - - -@pytest.mark.asyncio -async def test_claim_airdrops_partial_failure(airdrop_claimer): - """ - Test claim_airdrops with multiple airdrops, some failing and some succeeding. - """ - # Arrange - mock_airdrop1 = MagicMock( - user=MagicMock(contract_address="0x123"), id=1, amount=100 - ) - mock_airdrop2 = MagicMock( - user=MagicMock(contract_address="0x456"), id=2, amount=200 - ) - - airdrop_claimer.db_connector.get_all_unclaimed.return_value = [ - mock_airdrop1, - mock_airdrop2, - ] - - # Mock different behaviors for different airdrops - airdrop_claimer.zk_lend_airdrop.get_contract_airdrop.side_effect = [ - ["proof1"], - ["proof2"], - ] - airdrop_claimer.starknet_client.claim_airdrop.side_effect = [ - True, - ValueError("Claim failed"), - ] - - # Act - await airdrop_claimer.claim_airdrops() - - # Assertions - # Verify first airdrop was claimed and saved - airdrop_claimer.db_connector.save_claim_data.assert_any_call(1, 100) - # Verify second airdrop was not saved due to claim failure - assert airdrop_claimer.db_connector.save_claim_data.call_count == 1 - - -@pytest.mark.asyncio -async def test_claim_airdrops_database_error(airdrop_claimer, mock_airdrop, caplog): - """ - Test handling of database errors during airdrop claiming. - """ - # Arrange - airdrop_claimer.db_connector.get_all_unclaimed.return_value = [mock_airdrop] - airdrop_claimer.zk_lend_airdrop.get_contract_airdrop.return_value = ["proof1"] - airdrop_claimer.starknet_client.claim_airdrop.return_value = True - - # Simulate database save error - airdrop_claimer.db_connector.save_claim_data.side_effect = SQLAlchemyError( - "Database error" - ) - - # Act - with caplog.at_level(logging.ERROR): - await airdrop_claimer.claim_airdrops() - - # Assertions - assert "Database error while updating claim data" in caplog.text - airdrop_claimer.starknet_client.claim_airdrop.assert_called_once() - airdrop_claimer.db_connector.save_claim_data.assert_called_once() - - -@pytest.mark.asyncio -async def test_claim_airdrop_timeout_error(airdrop_claimer): - """ - Test _claim_airdrop method handling of timeout errors. - """ - # Arrange - airdrop_claimer.starknet_client.claim_airdrop.side_effect = Timeout( - "Request timed out" - ) - - # Act - result = await airdrop_claimer._claim_airdrop("0x123", ["proof1"]) - - # Assertions - assert result is False - airdrop_claimer.starknet_client.claim_airdrop.assert_awaited_with( - "0x123", ["proof1"] - ) - - -@pytest.mark.asyncio -async def test_claim_airdrop_invalid_proof(airdrop_claimer): - """ - Test _claim_airdrop method with invalid proof data. - """ - # Arrange - airdrop_claimer.starknet_client.claim_airdrop.side_effect = ValueError( - "Invalid proof" - ) - - # Act - result = await airdrop_claimer._claim_airdrop("0x123", ["invalid_proof"]) - - # Assertions - assert result is False - airdrop_claimer.starknet_client.claim_airdrop.assert_awaited_with( - "0x123", ["invalid_proof"] - ) - - -@pytest.mark.asyncio -async def test_claim_airdrop_unexpected_error(airdrop_claimer, caplog): - """ - Test _claim_airdrop method handling of unexpected errors. - """ - # Arrange - unexpected_error = Exception("Completely unexpected error") - airdrop_claimer.starknet_client.claim_airdrop.side_effect = unexpected_error - - # Act - with caplog.at_level(logging.ERROR): - result = await airdrop_claimer._claim_airdrop("0x123", ["proof1"]) - - # Assertions - assert result is False - assert "Unexpected error claiming address" in caplog.text - airdrop_claimer.starknet_client.claim_airdrop.assert_awaited_with( - "0x123", ["proof1"] - ) diff --git a/web_app/tests/test_zklend_airdrop.py b/web_app/tests/test_zklend_airdrop.py deleted file mode 100644 index f53cb8e24..000000000 --- a/web_app/tests/test_zklend_airdrop.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Test module for ZkLendAirdrop class""" - -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from web_app.api.serializers.airdrop import AirdropResponseModel -from web_app.contract_tools.airdrop import ZkLendAirdrop - - -@pytest.fixture -def mock_api_response() -> list: - """Fixture providing mock API response data.""" - return [ - { - "amount": "1000000000000000000", - "proof": ["0xabcd", "0x1234"], - "is_claimed": False, - "recipient": "0x742d35Cc6634C0532925a3b844Bc454e4438f44e", - } - ] - - -@pytest.fixture -def zklend_airdrop(): - """Fixture providing a ZkLendAirdrop instance with mocked API.""" - instance = ZkLendAirdrop() - instance.api = Mock() - instance.api.fetch = AsyncMock() - return instance - - -class TestZkLendAirdrop: - """Test suite for ZkLendAirdrop class.""" - - def test_init(self, zklend_airdrop): - """Test ZkLendAirdrop initialization.""" - assert isinstance(zklend_airdrop, ZkLendAirdrop) - assert ( - zklend_airdrop.REWARD_API_ENDPOINT - == "https://app.zklend.com/api/reward/all/" - ) - assert hasattr(zklend_airdrop, "api") - - @pytest.mark.asyncio - async def test_get_contract_airdrop_success( - self, zklend_airdrop, mock_api_response - ): - """Test successful retrieval of airdrop data.""" - - contract_id = "0x123456" - zklend_airdrop.api.fetch.return_value = mock_api_response - - result = await zklend_airdrop.get_contract_airdrop(contract_id) - - # Assert - assert isinstance(result, AirdropResponseModel) - assert len(result.airdrops) == 1 - airdrop = result.airdrops[0] - assert airdrop.amount == "1000000000000000000" - assert airdrop.proof == ["0xabcd", "0x1234"] - assert airdrop.is_claimed is False - assert airdrop.recipient == "0x742d35Cc6634C0532925a3b844Bc454e4438f44e" - - @pytest.mark.asyncio - async def test_get_contract_airdrop_empty_response(self, zklend_airdrop): - """Test handling of empty API response.""" - - contract_id = "0x123456" - zklend_airdrop.api.fetch.return_value = [] - - result = await zklend_airdrop.get_contract_airdrop(contract_id) - - # Assert - assert isinstance(result, AirdropResponseModel) - assert len(result.airdrops) == 0 - - @pytest.mark.asyncio - async def test_get_contract_airdrop_with_invalid_contract_id(self, zklend_airdrop): - """Test handling of invalid contract IDs.""" - - invalid_ids = ["", "0x"] - zklend_airdrop.api.fetch.return_value = [] - - for invalid_id in invalid_ids: - result = await zklend_airdrop.get_contract_airdrop(invalid_id) - assert isinstance(result, AirdropResponseModel) - assert len(result.airdrops) == 0 - - @pytest.mark.asyncio - async def test_get_contract_airdrop_none_contract_id(self, zklend_airdrop): - """Test handling of None contract ID.""" - with pytest.raises(ValueError): - await zklend_airdrop.get_contract_airdrop(None) - - def test_validate_response(self, zklend_airdrop, mock_api_response): - """Test response validation.""" - - result = zklend_airdrop._validate_response(mock_api_response) - - assert isinstance(result, AirdropResponseModel) - assert len(result.airdrops) == 1 - airdrop = result.airdrops[0] - assert isinstance(airdrop.proof, list) - assert airdrop.proof == ["0xabcd", "0x1234"] - - def test_validate_response_empty(self, zklend_airdrop): - """Test validation of empty response.""" - - result = zklend_airdrop._validate_response([]) - - assert isinstance(result, AirdropResponseModel) - assert len(result.airdrops) == 0 - - def test_validate_response_missing_fields(self, zklend_airdrop): - """Test validation with missing required fields.""" - - invalid_data = [{"amount": "1000"}] - - with pytest.raises(KeyError): - zklend_airdrop._validate_response(invalid_data) - - @pytest.mark.asyncio - async def test_underlying_contract_id_formatting(self, zklend_airdrop): - """Test correct formatting of underlying contract ID.""" - - contract_id = "0x123456" - expected_underlying_id = "0x0123456" - zklend_airdrop.api.fetch.return_value = [] - - await zklend_airdrop.get_contract_airdrop(contract_id) - - zklend_airdrop.api.fetch.assert_called_once_with(expected_underlying_id)