diff --git a/pallets/subtensor/src/macros/dispatches.rs b/pallets/subtensor/src/macros/dispatches.rs index 4a70af6ca3..fcc086eb58 100644 --- a/pallets/subtensor/src/macros/dispatches.rs +++ b/pallets/subtensor/src/macros/dispatches.rs @@ -2189,7 +2189,13 @@ mod dispatches { ); // Add or remove liquidity - let result = T::SwapInterface::modify_position(netuid, &coldkey, &hotkey, position_id, liquidity_delta)?; + let result = T::SwapInterface::modify_position( + netuid, + &coldkey, + &hotkey, + position_id, + liquidity_delta, + )?; if liquidity_delta > 0 { // Remove TAO and Alpha balances or fail transaction if they can't be removed exactly @@ -2197,9 +2203,15 @@ mod dispatches { ensure!(tao_provided == result.tao, Error::::InsufficientBalance); let alpha_provided = Self::decrease_stake_for_hotkey_and_coldkey_on_subnet( - &hotkey, &coldkey, netuid, result.alpha, + &hotkey, + &coldkey, + netuid, + result.alpha, + ); + ensure!( + alpha_provided == result.alpha, + Error::::InsufficientBalance ); - ensure!(alpha_provided == result.alpha, Error::::InsufficientBalance); // Emit an event Self::deposit_event(Event::LiquidityAdded { diff --git a/pallets/subtensor/src/tests/children.rs b/pallets/subtensor/src/tests/children.rs index 813a0c2064..d29fd91184 100644 --- a/pallets/subtensor/src/tests/children.rs +++ b/pallets/subtensor/src/tests/children.rs @@ -2416,14 +2416,14 @@ fn test_revoke_child_no_min_stake_check() { add_network(netuid, 13, 0); register_ok_neuron(netuid, parent, coldkey, 0); - let reserve = 1_000_000_000_000_000; - mock::setup_reserves(netuid, reserve, reserve); - mock::setup_reserves(root, reserve, reserve); + let reserve = 1_000_000_000_000_000; + mock::setup_reserves(netuid, reserve, reserve); + mock::setup_reserves(root, reserve, reserve); // Set minimum stake for setting children StakeThreshold::::put(1_000_000_000_000); - let (_, fee) = mock::swap_tao_to_alpha(root, StakeThreshold::::get()); + let (_, fee) = mock::swap_tao_to_alpha(root, StakeThreshold::::get()); SubtensorModule::increase_stake_for_hotkey_and_coldkey_on_subnet( &parent, &coldkey, @@ -2494,12 +2494,12 @@ fn test_do_set_child_registration_disabled() { add_network(netuid, 13, 0); register_ok_neuron(netuid, parent, coldkey, 0); - let reserve = 1_000_000_000_000_000; - mock::setup_reserves(netuid, reserve, reserve); + let reserve = 1_000_000_000_000_000; + mock::setup_reserves(netuid, reserve, reserve); // Set minimum stake for setting children StakeThreshold::::put(1_000_000_000_000); - let (_, fee) = mock::swap_tao_to_alpha(netuid, StakeThreshold::::get()); + let (_, fee) = mock::swap_tao_to_alpha(netuid, StakeThreshold::::get()); SubtensorModule::increase_stake_for_hotkey_and_coldkey_on_subnet( &parent, &coldkey, diff --git a/pallets/subtensor/src/tests/coinbase.rs b/pallets/subtensor/src/tests/coinbase.rs index f8c405edbd..2636d0f3d8 100644 --- a/pallets/subtensor/src/tests/coinbase.rs +++ b/pallets/subtensor/src/tests/coinbase.rs @@ -1980,8 +1980,8 @@ fn test_run_coinbase_not_started() { // Set weight-set limit to 0. SubtensorModule::set_weights_set_rate_limit(netuid, 0); - let reserve = init_stake * 1000; - mock::setup_reserves(netuid, reserve, reserve); + let reserve = init_stake * 1000; + mock::setup_reserves(netuid, reserve, reserve); register_ok_neuron(netuid, hotkey, coldkey, 0); register_ok_neuron(netuid, miner_hk, miner_ck, 0); diff --git a/pallets/subtensor/src/tests/mock.rs b/pallets/subtensor/src/tests/mock.rs index 9bce02b12b..cc6753b552 100644 --- a/pallets/subtensor/src/tests/mock.rs +++ b/pallets/subtensor/src/tests/mock.rs @@ -899,4 +899,3 @@ pub(crate) fn swap_alpha_to_tao(netuid: u16, alpha: u64) -> (u64, u64) { (result.amount_paid_out, result.fee_paid) } - diff --git a/pallets/subtensor/src/tests/senate.rs b/pallets/subtensor/src/tests/senate.rs index 1efefc6130..64eaef266c 100644 --- a/pallets/subtensor/src/tests/senate.rs +++ b/pallets/subtensor/src/tests/senate.rs @@ -191,19 +191,19 @@ fn test_senate_vote_works() { stake )); - let approx_expected = stake - fee; + let approx_expected = stake - fee; assert_abs_diff_eq!( SubtensorModule::get_stake_for_hotkey_and_coldkey_on_subnet( &hotkey_account_id, &staker_coldkey, netuid ), - approx_expected, + approx_expected, epsilon = approx_expected / 1000 ); assert_abs_diff_eq!( SubtensorModule::get_stake_for_hotkey_on_subnet(&hotkey_account_id, netuid), - approx_expected, + approx_expected, epsilon = approx_expected / 1000 ); diff --git a/pallets/subtensor/src/tests/swap_hotkey.rs b/pallets/subtensor/src/tests/swap_hotkey.rs index 059e489a04..d93ca0c6c3 100644 --- a/pallets/subtensor/src/tests/swap_hotkey.rs +++ b/pallets/subtensor/src/tests/swap_hotkey.rs @@ -79,7 +79,7 @@ fn test_swap_total_hotkey_stake() { // Add stake let (expected_alpha, _) = mock::swap_tao_to_alpha(netuid, amount); - assert!(expected_alpha > 0); + assert!(expected_alpha > 0); assert_ok!(SubtensorModule::add_stake( RuntimeOrigin::signed(coldkey), old_hotkey, @@ -88,7 +88,10 @@ fn test_swap_total_hotkey_stake() { )); // Check if stake has increased - assert_eq!(TotalHotkeyAlpha::::get(old_hotkey, netuid), expected_alpha); + assert_eq!( + TotalHotkeyAlpha::::get(old_hotkey, netuid), + expected_alpha + ); assert_abs_diff_eq!( SubtensorModule::get_total_stake_for_hotkey(&new_hotkey), 0, @@ -109,7 +112,10 @@ fn test_swap_total_hotkey_stake() { 0, epsilon = 1, ); - assert_eq!(TotalHotkeyAlpha::::get(new_hotkey, netuid), expected_alpha); + assert_eq!( + TotalHotkeyAlpha::::get(new_hotkey, netuid), + expected_alpha + ); }); } diff --git a/pallets/swap/src/pallet/impls.rs b/pallets/swap/src/pallet/impls.rs index feb8a4bcd8..7ff4bcce76 100644 --- a/pallets/swap/src/pallet/impls.rs +++ b/pallets/swap/src/pallet/impls.rs @@ -7,7 +7,7 @@ use sp_arithmetic::helpers_128bit; use sp_runtime::traits::AccountIdConversion; use substrate_fixed::types::{U64F64, U96F32}; use subtensor_swap_interface::{ - LiquidityDataProvider, UpdateLiquidityResult, SwapHandler, SwapResult, + LiquidityDataProvider, SwapHandler, SwapResult, UpdateLiquidityResult, }; use super::pallet::*; @@ -90,14 +90,14 @@ impl SwapStep { lq = one.safe_div(TickIndex::min_sqrt_price()); } lq - }, + } OrderType::Buy => { let mut lq = TickIndex::max_sqrt_price().min(sqrt_price_limit.into()); if lq < current_price { lq = TickIndex::max_sqrt_price(); } lq - }, + } }; Self { @@ -228,10 +228,8 @@ impl SwapStep { delta_fixed.saturating_mul(u16_max.safe_div(u16_max.saturating_sub(fee_rate))); // Hold the fees - let fee = Pallet::::calculate_fee_amount( - self.netuid, - total_cost.saturating_to_num::(), - ); + let fee = + Pallet::::calculate_fee_amount(self.netuid, total_cost.saturating_to_num::()); Pallet::::add_fees(self.netuid, self.order_type, fee); let delta_out = Pallet::::convert_deltas(self.netuid, self.order_type, self.delta_in); @@ -336,14 +334,11 @@ impl Pallet { let epsilon = U64F64::saturating_from_num(0.000001); let current_sqrt_price = price.checked_sqrt(epsilon).unwrap_or(U64F64::from_num(0)); - AlphaSqrtPrice::::set( - netuid, - current_sqrt_price, - ); + AlphaSqrtPrice::::set(netuid, current_sqrt_price); // Set current tick let current_tick = TickIndex::from_sqrt_price_bounded(current_sqrt_price); - CurrentTick::::set(netuid, current_tick); + CurrentTick::::set(netuid, current_tick); // Set initial (protocol owned) liquidity and positions // Protocol liquidity makes one position from TickIndex::MIN to TickIndex::MAX @@ -1174,8 +1169,14 @@ impl SwapHandler for Pallet { position_id: u128, liquidity_delta: i64, ) -> Result { - Self::modify_position(netuid.into(), coldkey_account_id, hotkey_account_id, position_id.into(), liquidity_delta) - .map_err(Into::into) + Self::modify_position( + netuid.into(), + coldkey_account_id, + hotkey_account_id, + position_id.into(), + liquidity_delta, + ) + .map_err(Into::into) } fn approx_fee_amount(netuid: u16, amount: u64) -> u64 { @@ -1298,7 +1299,7 @@ mod tests { assert_eq!(sqrt_price, expected_sqrt_price); // Verify that current tick is set - let current_tick = CurrentTick::::get(netuid); + let current_tick = CurrentTick::::get(netuid); let expected_current_tick = TickIndex::from_sqrt_price_bounded(expected_sqrt_price); assert_eq!(current_tick, expected_current_tick); @@ -1729,15 +1730,14 @@ mod tests { // Modify liquidity (also causes claiming of fees) let liquidity_before = CurrentLiquidity::::get(netuid); - let modify_result = - Pallet::::modify_position( - netuid, - &OK_COLDKEY_ACCOUNT_ID, - &OK_HOTKEY_ACCOUNT_ID, - position_id, - -1_i64 * ((liquidity / 10) as i64), - ) - .unwrap(); + let modify_result = Pallet::::modify_position( + netuid, + &OK_COLDKEY_ACCOUNT_ID, + &OK_HOTKEY_ACCOUNT_ID, + position_id, + -1_i64 * ((liquidity / 10) as i64), + ) + .unwrap(); assert_abs_diff_eq!(modify_result.alpha, alpha / 10, epsilon = alpha / 1000); assert!(modify_result.fee_tao > 0); assert_eq!(modify_result.fee_alpha, 0); @@ -1753,22 +1753,20 @@ mod tests { // Position liquidity is reduced let position = - Positions::::get(&(netuid, OK_COLDKEY_ACCOUNT_ID, position_id)) - .unwrap(); + Positions::::get(&(netuid, OK_COLDKEY_ACCOUNT_ID, position_id)).unwrap(); assert_eq!(position.liquidity, liquidity * 9 / 10); assert_eq!(position.tick_low, tick_low); assert_eq!(position.tick_high, tick_high); // Modify liquidity again (ensure fees aren't double-collected) - let modify_result = - Pallet::::modify_position( - netuid, - &OK_COLDKEY_ACCOUNT_ID, - &OK_HOTKEY_ACCOUNT_ID, - position_id, - -1_i64 * ((liquidity / 100) as i64), - ) - .unwrap(); + let modify_result = Pallet::::modify_position( + netuid, + &OK_COLDKEY_ACCOUNT_ID, + &OK_HOTKEY_ACCOUNT_ID, + position_id, + -1_i64 * ((liquidity / 100) as i64), + ) + .unwrap(); assert_abs_diff_eq!(modify_result.alpha, alpha / 100, epsilon = alpha / 1000); assert_eq!(modify_result.fee_tao, 0); @@ -1913,9 +1911,10 @@ mod tests { } // Assert that current tick is updated - let current_tick = CurrentTick::::get(netuid); - let expected_current_tick = TickIndex::from_sqrt_price_bounded(sqrt_current_price_after); - assert_eq!(current_tick, expected_current_tick); + let current_tick = CurrentTick::::get(netuid); + let expected_current_tick = + TickIndex::from_sqrt_price_bounded(sqrt_current_price_after); + assert_eq!(current_tick, expected_current_tick); }, ); }); @@ -2390,6 +2389,7 @@ mod tests { let current_price = SqrtPrice::from_num(0.50000051219212275465); let tick = TickIndex::try_from_sqrt_price(current_price).unwrap(); + let round_trip_price = TickIndex::try_to_sqrt_price(&tick).unwrap(); assert!(round_trip_price <= current_price); @@ -2397,6 +2397,4 @@ mod tests { assert!(tick == roundtrip_tick); }); } - - } diff --git a/pallets/swap/src/pallet/mod.rs b/pallets/swap/src/pallet/mod.rs index 46ad895e9d..3a44cb4601 100644 --- a/pallets/swap/src/pallet/mod.rs +++ b/pallets/swap/src/pallet/mod.rs @@ -170,8 +170,8 @@ mod pallet { /// Provided liquidity parameter is invalid (likely too small) InvalidLiquidityValue, - /// Reserves too low for operation. - ReservesTooLow, + /// Reserves too low for operation. + ReservesTooLow, } #[pallet::call] diff --git a/pallets/swap/src/position.rs b/pallets/swap/src/position.rs index 4039e9f191..8f1ef69d4d 100644 --- a/pallets/swap/src/position.rs +++ b/pallets/swap/src/position.rs @@ -107,7 +107,10 @@ impl Position { fee_tao = liquidity_frac.saturating_mul(fee_tao); fee_alpha = liquidity_frac.saturating_mul(fee_alpha); - (fee_tao.saturating_to_num::(), fee_alpha.saturating_to_num::()) + ( + fee_tao.saturating_to_num::(), + fee_alpha.saturating_to_num::(), + ) } /// Get fees in a position's range diff --git a/pallets/swap/src/tick.rs b/pallets/swap/src/tick.rs index 7c4d9b22cb..2526914711 100644 --- a/pallets/swap/src/tick.rs +++ b/pallets/swap/src/tick.rs @@ -5,7 +5,7 @@ use core::fmt; use core::hash::Hash; use core::ops::{Add, AddAssign, BitOr, Deref, Neg, Shl, Shr, Sub, SubAssign}; -use alloy_primitives::{I256, U256}; +use alloy_primitives::{I256, U256, uint}; use codec::{Decode, Encode, MaxEncodedLen}; use frame_support::pallet_prelude::*; use safe_math::*; @@ -443,17 +443,25 @@ impl TickIndex { /// tick index matches the price by the following inequality: /// sqrt_lower_price <= sqrt_price < sqrt_higher_price pub fn try_from_sqrt_price(sqrt_price: SqrtPrice) -> Result { - let tick = get_tick_at_sqrt_ratio(u64f64_to_u256_q64_96(sqrt_price))?; + // price in the native Q64.96 integer format + let price_x96 = u64f64_to_u256_q64_96(sqrt_price); - // Correct for rounding error during conversions between different fixed-point formats - if tick == 0 { - Ok(Self(tick)) + // first‑pass estimate from the log calculation + let mut tick = get_tick_at_sqrt_ratio(price_x96)?; + + // post‑verification, *both* directions + let price_at_tick = get_sqrt_ratio_at_tick(tick)?; + if price_at_tick > price_x96 { + tick -= 1; // estimate was too high } else { - match (tick + 1).into_tick_index() { - Ok(incremented) => Ok(incremented), - Err(e) => Err(e), + // it may still be one too low + let price_at_tick_plus = get_sqrt_ratio_at_tick(tick + 1)?; + if price_at_tick_plus <= price_x96 { + tick += 1; // step up when required } } + + tick.into_tick_index() } } @@ -905,12 +913,13 @@ fn get_sqrt_ratio_at_tick(tick: i32) -> Result { ratio = U256::MAX / ratio; } - Ok((ratio >> 32) - + if (ratio.wrapping_rem(U256_1 << 32)).is_zero() { - U256::ZERO - } else { - U256_1 - }) + let shifted = ratio >> 32; + let ceil = if ratio & U256::from((1u128 << 32) - 1) != U256::ZERO { + shifted + U256_1 + } else { + shifted + }; + Ok(ceil) } fn get_tick_at_sqrt_ratio(sqrt_price_x_96: U256) -> Result { @@ -1025,7 +1034,7 @@ fn get_tick_at_sqrt_ratio(sqrt_price_x_96: U256) -> Result { /// * `value` - The U256 value in Q64.96 format /// /// # Returns -/// * `Result` - Converted value or error if too large +/// * `Result` - Converted value or error if too large fn u256_to_u64f64(value: U256, source_fractional_bits: u32) -> Result { if value > U256::from(u128::MAX) { return Err(TickMathError::ConversionError); @@ -1049,6 +1058,11 @@ fn u256_to_u64f64(value: U256, source_fractional_bits: u32) -> Result U256 { + u64f64_to_u256(value, 96) +} + /// Convert U64F64 to U256 /// /// # Arguments @@ -1075,12 +1089,32 @@ fn u64f64_to_u256(value: U64F64, target_fractional_bits: u32) -> U256 { /// Convert U256 in Q64.96 format (Uniswap's sqrt price format) to U64F64 fn u256_q64_96_to_u64f64(value: U256) -> Result { - u256_to_u64f64(value, 96) + q_to_u64f64(value, 96) } -/// Convert U64F64 to U256 in Q64.96 format (Uniswap's sqrt price format) -fn u64f64_to_u256_q64_96(value: U64F64) -> U256 { - u64f64_to_u256(value, 96) +fn q_to_u64f64(x: U256, frac_bits: u32) -> Result { + let diff = frac_bits.checked_sub(64).unwrap_or(0); + + // 1. shift right diff bits + let shifted = if diff != 0 { x >> diff } else { x }; + + // 2. **round up** if we threw away any 1‑bits + let mask = if diff != 0 { + (U256_1 << diff) - U256_1 + } else { + U256::ZERO + }; + let rounded = if diff != 0 && (x & mask) != U256::ZERO { + shifted + U256_1 + } else { + shifted + }; + + // 3. check that it fits in 128 bits and transmute + if (rounded >> 128) != U256::ZERO { + return Err(TickMathError::Overflow); + } + Ok(U64F64::from_bits(rounded.to::())) } #[derive(Debug, PartialEq, Eq)] @@ -1110,6 +1144,7 @@ impl Error for TickMathError {} mod tests { use std::{ops::Sub, str::FromStr}; + use approx::assert_abs_diff_eq; use safe_math::FixedExt; use super::*; @@ -1361,18 +1396,31 @@ mod tests { assert_eq!(tick_index, TickIndex::new_unchecked(0)); // Test with sqrt price equal to tick_spacing_tao (should be tick index 2) - let tick_index = TickIndex::try_from_sqrt_price(tick_spacing).unwrap(); - assert_eq!(tick_index, TickIndex::new_unchecked(2)); + let epsilon = SqrtPrice::from_num(0.0000000000000001); + assert!( + TickIndex::new_unchecked(2) + .to_sqrt_price_bounded() + .abs_diff(tick_spacing) + < epsilon + ); // Test with sqrt price equal to tick_spacing_tao^2 (should be tick index 4) let sqrt_price = tick_spacing * tick_spacing; - let tick_index = TickIndex::try_from_sqrt_price(sqrt_price).unwrap(); - assert_eq!(tick_index, TickIndex::new_unchecked(4)); + assert!( + TickIndex::new_unchecked(4) + .to_sqrt_price_bounded() + .abs_diff(sqrt_price) + < epsilon + ); // Test with sqrt price equal to tick_spacing_tao^5 (should be tick index 10) let sqrt_price = tick_spacing.checked_pow(5).unwrap(); - let tick_index = TickIndex::try_from_sqrt_price(sqrt_price).unwrap(); - assert_eq!(tick_index, TickIndex::new_unchecked(10)); + assert!( + TickIndex::new_unchecked(10) + .to_sqrt_price_bounded() + .abs_diff(sqrt_price) + < epsilon + ); } #[test] @@ -1392,9 +1440,9 @@ mod tests { 1000, TickIndex::MAX.get(), ] - .iter() + .into_iter() { - let tick_index = TickIndex(*i32_value); + let tick_index = TickIndex::new_unchecked(i32_value); let sqrt_price = tick_index.try_to_sqrt_price().unwrap(); let round_trip_tick_index = TickIndex::try_from_sqrt_price(sqrt_price).unwrap(); assert_eq!(round_trip_tick_index, tick_index);