From 645870e3608cbdbea9b1e570d30787539593e693 Mon Sep 17 00:00:00 2001
From: Aniket Dixit <dixitaniket199@gmail.com>
Date: Fri, 28 Jul 2023 10:21:51 +0530
Subject: [PATCH] contract test fix

---
 cosmwasm/contracts/price-feed/src/contract.rs | 55 +++++++++++--------
 cw-relayer/relayer/client/client.go           |  7 ++-
 cw-relayer/relayer/relayer.go                 |  2 +-
 3 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/cosmwasm/contracts/price-feed/src/contract.rs b/cosmwasm/contracts/price-feed/src/contract.rs
index f0497f6..b7603c1 100644
--- a/cosmwasm/contracts/price-feed/src/contract.rs
+++ b/cosmwasm/contracts/price-feed/src/contract.rs
@@ -417,7 +417,7 @@ fn query_deviation_ref(deps: Deps, symbol: &str) -> StdResult<RefDeviationData>
     }
 }
 
-fn query_deviation_ref_bulk(deps: Deps, symbols: &[String]) -> StdResult<Vec<RefData>> {
+fn query_deviation_ref_bulk(deps: Deps, symbols: &[String]) -> StdResult<Vec<RefDeviationData>> {
     symbols
         .iter()
         .map(|symbol| query_deviation_ref(deps, symbol))
@@ -1030,14 +1030,19 @@ mod tests {
                 .into_iter()
                 .map(|s| s.to_string())
                 .collect::<Vec<String>>();
-            let rates = [1000, 2000, 3000]
+            let deviations = [1000, 2000, 3000]
                 .iter()
                 .map(|r| Uint64::new(*r))
                 .collect::<Vec<Uint64>>();
 
+            let symbol_rates: Vec<(String, Vec<Uint64>)> = symbols
+                .iter()
+                .zip(std::iter::repeat(deviations.clone()))
+                .map(|(s, r)| (s.to_owned(), r))
+                .collect();
+
             let msg = RelayHistoricalDeviation {
-                symbol_rates: zip(symbols.clone(), rates.clone())
-                    .collect::<Vec<(String, Uint64)>>(),
+                symbol_rates: symbol_rates.clone(),
                 resolve_time: Uint64::from(10u64),
                 request_id: Uint64::one(),
             };
@@ -1047,13 +1052,9 @@ mod tests {
             let reference_datas =
                 query_deviation_ref_bulk(deps.as_ref(), &symbols.clone()).unwrap();
 
-            let retrieved_rates = reference_datas
-                .clone()
-                .iter()
-                .map(|r| r.rate)
-                .collect::<Vec<Uint64>>();
-
-            assert_eq!(retrieved_rates, rates);
+            for (expected, actual) in symbol_rates.iter().zip(reference_datas.iter()) {
+                assert_eq!(expected.1, actual.rates)
+            }
         }
 
         #[test]
@@ -1075,9 +1076,16 @@ mod tests {
                 .map(|r| Uint64::new(*r))
                 .collect::<Vec<Uint64>>();
 
+
+            let symbol_rates: Vec<(String, Vec<Uint64>)> = symbols
+                .iter()
+                .zip(std::iter::repeat(deviations.clone()))
+                .map(|(s, r)| (s.to_owned(), r))
+                .collect();
+
+
             let msg = ForceRelayHistoricalDeviation {
-                symbol_rates: zip(symbols.clone(), deviations.clone())
-                    .collect::<Vec<(String, Uint64)>>(),
+                symbol_rates:symbol_rates.clone(),
                 resolve_time: Uint64::from(100u64),
                 request_id: Uint64::from(2u64),
             };
@@ -1091,9 +1099,16 @@ mod tests {
                 .map(|r| Uint64::new(*r))
                 .collect::<Vec<Uint64>>();
 
+
+            let forced_deviation_symbol_rates: Vec<(String, Vec<Uint64>)> = symbols
+                .iter()
+                .zip(std::iter::repeat(forced_deviations.clone()))
+                .map(|(s, r)| (s.to_owned(), r))
+                .collect();
+
+
             let msg = ForceRelayHistoricalDeviation {
-                symbol_rates: zip(symbols.clone(), forced_deviations.clone())
-                    .collect::<Vec<(String, Uint64)>>(),
+                symbol_rates: forced_deviation_symbol_rates.clone(),
                 resolve_time: Uint64::from(10u64),
                 request_id: Uint64::zero(),
             };
@@ -1103,13 +1118,9 @@ mod tests {
             let reference_datas =
                 query_deviation_ref_bulk(deps.as_ref(), &symbols.clone()).unwrap();
 
-            let retrieved_rates = reference_datas
-                .clone()
-                .iter()
-                .map(|r| r.rate)
-                .collect::<Vec<Uint64>>();
-
-            assert_eq!(retrieved_rates, forced_deviations);
+            for (expected, actual) in forced_deviation_symbol_rates.iter().zip(reference_datas.iter()) {
+                assert_eq!(expected.1, actual.rates)
+            }
         }
 
         #[test]
diff --git a/cw-relayer/relayer/client/client.go b/cw-relayer/relayer/client/client.go
index 4e7e433..7fca69f 100644
--- a/cw-relayer/relayer/client/client.go
+++ b/cw-relayer/relayer/client/client.go
@@ -158,9 +158,10 @@ func (r *passReader) Read(p []byte) (n int, err error) {
 
 // BroadcastTx attempts to broadcast a signed transaction. If it fails, a few re-attempts
 // will be made until the transaction succeeds or ultimately times out or fails.
-func (oc RelayerClient) BroadcastTx(clientCtx client.Context, nextBlockHeight, timeoutHeight int64, msgs ...sdk.Msg) error {
+func (oc RelayerClient) BroadcastTx(clientCtx client.Context, timeoutDuration time.Duration, nextBlockHeight, timeoutHeight int64, msgs ...sdk.Msg) error {
 	maxBlockHeight := nextBlockHeight + timeoutHeight
 	lastCheckHeight := nextBlockHeight - 1
+	start := time.Now()
 
 	factory, err := oc.CreateTxFactory()
 	if err != nil {
@@ -175,6 +176,10 @@ func (oc RelayerClient) BroadcastTx(clientCtx client.Context, nextBlockHeight, t
 		}
 
 		if latestBlockHeight <= lastCheckHeight {
+			if time.Since(start).Seconds() >= timeoutDuration.Seconds() {
+				return fmt.Errorf("timeout duration exceeded")
+			}
+
 			continue
 		}
 
diff --git a/cw-relayer/relayer/relayer.go b/cw-relayer/relayer/relayer.go
index c718920..1f025a4 100644
--- a/cw-relayer/relayer/relayer.go
+++ b/cw-relayer/relayer/relayer.go
@@ -407,7 +407,7 @@ func (r *Relayer) tick(ctx context.Context) error {
 
 	logs.Msg("broadcasting execute to contract")
 
-	if err := r.relayerClient.BroadcastTx(*clientCtx, nextBlockHeight, r.timeoutHeight, executeMsgs...); err != nil {
+	if err := r.relayerClient.BroadcastTx(*clientCtx, r.resolveDuration, nextBlockHeight, r.timeoutHeight, executeMsgs...); err != nil {
 		r.missedCounter += 1
 		return err
 	}