From f1970c8542fb0b9133569108be156601dd7fce7d Mon Sep 17 00:00:00 2001 From: ian Date: Thu, 27 Jun 2024 10:13:42 +0800 Subject: [PATCH] feat: allow sha256 for HTLC Allow choosing the hash algorithm when adding the HTLC output. Related PR: https://github.com/nervosnetwork/cfn-scripts/pull/6 --- src/ckb/channel.rs | 29 +++- src/ckb/gen/cfn.rs | 49 ++++-- src/ckb/gen/invoice.rs | 302 +++++++++++++++++++++++++++++++++++- src/ckb/hash_algorithm.rs | 79 ++++++++++ src/ckb/mod.rs | 2 + src/ckb/schema/cfn.mol | 1 + src/ckb/schema/invoice.mol | 8 + src/ckb/types.rs | 6 + src/invoice/invoice_impl.rs | 13 ++ src/rpc/channel.rs | 3 + 10 files changed, 473 insertions(+), 19 deletions(-) create mode 100644 src/ckb/hash_algorithm.rs diff --git a/src/ckb/channel.rs b/src/ckb/channel.rs index 1caf52fc4..233d41924 100644 --- a/src/ckb/channel.rs +++ b/src/ckb/channel.rs @@ -41,6 +41,7 @@ use crate::{ use super::{ config::{DEFAULT_CHANNEL_MINIMAL_CKB_AMOUNT, MIN_UDT_OCCUPIED_CAPACITY}, + hash_algorithm::HashAlgorithm, key::blake2b_hash_with_salt, network::CFNMessageWithPeerId, serde_utils::EntityHex, @@ -103,6 +104,7 @@ pub struct AddTlcCommand { pub preimage: Option, pub payment_hash: Option, pub expiry: LockTime, + pub hash_algorithm: HashAlgorithm, } #[derive(Debug)] @@ -563,6 +565,7 @@ impl ChannelActor { amount: tlc.amount, payment_hash: tlc.payment_hash, expiry: tlc.lock_time, + hash_algorithm: tlc.hash_algorithm, }), }; debug!("Sending AddTlc message: {:?}", &msg); @@ -2130,8 +2133,11 @@ impl ChannelActorState { reason, removed_at, current ); if let RemoveTlcReason::RemoveTlcFulfill(fulfill) = reason { - let filled_payment_hash: Hash256 = - blake2b_256(fulfill.payment_preimage).into(); + let filled_payment_hash: Hash256 = current + .tlc + .hash_algorithm + .hash(fulfill.payment_preimage) + .into(); if current.tlc.payment_hash != filled_payment_hash { return Err(ProcessingChannelError::InvalidParameter(format!( "Preimage {:?} is hashed to {}, which does not match payment hash {:?}", @@ -2436,7 +2442,7 @@ impl ChannelActorState { tlcs.iter() .map(|(tlc, local, remote)| { [ - (if tlc.tlc.is_offered() { [0] } else { [1] }).to_vec(), + vec![tlc.tlc.get_htlc_type()], tlc.tlc.amount.to_le_bytes().to_vec(), tlc.tlc.get_hash().to_vec(), local.serialize().to_vec(), @@ -2520,7 +2526,7 @@ impl ChannelActorState { let preimage = command.preimage.unwrap_or(get_random_preimage()); let payment_hash = command .payment_hash - .unwrap_or(blake2b_256(&preimage).into()); + .unwrap_or_else(|| command.hash_algorithm.hash(&preimage).into()); TLC { id: TLCId::Offered(id), @@ -2528,6 +2534,7 @@ impl ChannelActorState { payment_hash, lock_time: command.expiry, payment_preimage: Some(preimage), + hash_algorithm: command.hash_algorithm, } } @@ -2551,6 +2558,7 @@ impl ChannelActorState { payment_hash: message.payment_hash, lock_time: message.expiry, payment_preimage: None, + hash_algorithm: message.hash_algorithm, }) } } @@ -3171,6 +3179,7 @@ impl ChannelActorState { amount: info.tlc.amount, payment_hash: info.tlc.payment_hash, expiry: info.tlc.lock_time, + hash_algorithm: info.tlc.hash_algorithm, }), }), )) @@ -4170,6 +4179,8 @@ pub struct TLC { pub payment_hash: Hash256, /// The preimage of the hash to be sent to the counterparty. pub payment_preimage: Option, + /// Which hash algorithm is applied on the preimage + pub hash_algorithm: HashAlgorithm, } impl TLC { @@ -4186,6 +4197,16 @@ impl TLC { self.id.flip_mut() } + /// Get the value for the field `htlc_type` in commitment lock witness. + /// - Lowest 1 bit: 0 if the tlc is offered by the remote party, 1 otherwise. + /// - High 7 bits: + /// - 0: ckb hash + /// - 1: sha256 + pub fn get_htlc_type(&self) -> u8 { + let offered_flag = if self.is_offered() { 0u8 } else { 1u8 }; + ((self.hash_algorithm as u8) << 1) + offered_flag + } + fn get_hash(&self) -> ShortHash { self.payment_hash.as_ref()[..20].try_into().unwrap() } diff --git a/src/ckb/gen/cfn.rs b/src/ckb/gen/cfn.rs index 1579b5e5c..1d132457c 100644 --- a/src/ckb/gen/cfn.rs +++ b/src/ckb/gen/cfn.rs @@ -7042,6 +7042,7 @@ impl ::core::fmt::Display for AddTlc { write!(f, ", {}: {}", "amount", self.amount())?; write!(f, ", {}: {}", "payment_hash", self.payment_hash())?; write!(f, ", {}: {}", "expiry", self.expiry())?; + write!(f, ", {}: {}", "hash_algorithm", self.hash_algorithm())?; let extra_count = self.count_extra_fields(); if extra_count != 0 { write!(f, ", .. ({} fields)", extra_count)?; @@ -7056,14 +7057,14 @@ impl ::core::default::Default for AddTlc { } } impl AddTlc { - const DEFAULT_VALUE: [u8; 120] = [ - 120, 0, 0, 0, 24, 0, 0, 0, 56, 0, 0, 0, 64, 0, 0, 0, 80, 0, 0, 0, 112, 0, 0, 0, 0, 0, 0, 0, + const DEFAULT_VALUE: [u8; 125] = [ + 125, 0, 0, 0, 28, 0, 0, 0, 60, 0, 0, 0, 68, 0, 0, 0, 84, 0, 0, 0, 116, 0, 0, 0, 124, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]; - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn total_size(&self) -> usize { molecule::unpack_number(self.as_slice()) as usize } @@ -7107,11 +7108,17 @@ impl AddTlc { pub fn expiry(&self) -> Uint64 { let slice = self.as_slice(); let start = molecule::unpack_number(&slice[20..]) as usize; + let end = molecule::unpack_number(&slice[24..]) as usize; + Uint64::new_unchecked(self.0.slice(start..end)) + } + pub fn hash_algorithm(&self) -> Byte { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[24..]) as usize; if self.has_extra_fields() { - let end = molecule::unpack_number(&slice[24..]) as usize; - Uint64::new_unchecked(self.0.slice(start..end)) + let end = molecule::unpack_number(&slice[28..]) as usize; + Byte::new_unchecked(self.0.slice(start..end)) } else { - Uint64::new_unchecked(self.0.slice(start..)) + Byte::new_unchecked(self.0.slice(start..)) } } pub fn as_reader<'r>(&'r self) -> AddTlcReader<'r> { @@ -7146,6 +7153,7 @@ impl molecule::prelude::Entity for AddTlc { .amount(self.amount()) .payment_hash(self.payment_hash()) .expiry(self.expiry()) + .hash_algorithm(self.hash_algorithm()) } } #[derive(Clone, Copy)] @@ -7172,6 +7180,7 @@ impl<'r> ::core::fmt::Display for AddTlcReader<'r> { write!(f, ", {}: {}", "amount", self.amount())?; write!(f, ", {}: {}", "payment_hash", self.payment_hash())?; write!(f, ", {}: {}", "expiry", self.expiry())?; + write!(f, ", {}: {}", "hash_algorithm", self.hash_algorithm())?; let extra_count = self.count_extra_fields(); if extra_count != 0 { write!(f, ", .. ({} fields)", extra_count)?; @@ -7180,7 +7189,7 @@ impl<'r> ::core::fmt::Display for AddTlcReader<'r> { } } impl<'r> AddTlcReader<'r> { - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn total_size(&self) -> usize { molecule::unpack_number(self.as_slice()) as usize } @@ -7224,11 +7233,17 @@ impl<'r> AddTlcReader<'r> { pub fn expiry(&self) -> Uint64Reader<'r> { let slice = self.as_slice(); let start = molecule::unpack_number(&slice[20..]) as usize; + let end = molecule::unpack_number(&slice[24..]) as usize; + Uint64Reader::new_unchecked(&self.as_slice()[start..end]) + } + pub fn hash_algorithm(&self) -> ByteReader<'r> { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[24..]) as usize; if self.has_extra_fields() { - let end = molecule::unpack_number(&slice[24..]) as usize; - Uint64Reader::new_unchecked(&self.as_slice()[start..end]) + let end = molecule::unpack_number(&slice[28..]) as usize; + ByteReader::new_unchecked(&self.as_slice()[start..end]) } else { - Uint64Reader::new_unchecked(&self.as_slice()[start..]) + ByteReader::new_unchecked(&self.as_slice()[start..]) } } } @@ -7283,6 +7298,7 @@ impl<'r> molecule::prelude::Reader<'r> for AddTlcReader<'r> { Uint128Reader::verify(&slice[offsets[2]..offsets[3]], compatible)?; Byte32Reader::verify(&slice[offsets[3]..offsets[4]], compatible)?; Uint64Reader::verify(&slice[offsets[4]..offsets[5]], compatible)?; + ByteReader::verify(&slice[offsets[5]..offsets[6]], compatible)?; Ok(()) } } @@ -7293,9 +7309,10 @@ pub struct AddTlcBuilder { pub(crate) amount: Uint128, pub(crate) payment_hash: Byte32, pub(crate) expiry: Uint64, + pub(crate) hash_algorithm: Byte, } impl AddTlcBuilder { - pub const FIELD_COUNT: usize = 5; + pub const FIELD_COUNT: usize = 6; pub fn channel_id(mut self, v: Byte32) -> Self { self.channel_id = v; self @@ -7316,6 +7333,10 @@ impl AddTlcBuilder { self.expiry = v; self } + pub fn hash_algorithm(mut self, v: Byte) -> Self { + self.hash_algorithm = v; + self + } } impl molecule::prelude::Builder for AddTlcBuilder { type Entity = AddTlc; @@ -7327,6 +7348,7 @@ impl molecule::prelude::Builder for AddTlcBuilder { + self.amount.as_slice().len() + self.payment_hash.as_slice().len() + self.expiry.as_slice().len() + + self.hash_algorithm.as_slice().len() } fn write(&self, writer: &mut W) -> molecule::io::Result<()> { let mut total_size = molecule::NUMBER_SIZE * (Self::FIELD_COUNT + 1); @@ -7341,6 +7363,8 @@ impl molecule::prelude::Builder for AddTlcBuilder { total_size += self.payment_hash.as_slice().len(); offsets.push(total_size); total_size += self.expiry.as_slice().len(); + offsets.push(total_size); + total_size += self.hash_algorithm.as_slice().len(); writer.write_all(&molecule::pack_number(total_size as molecule::Number))?; for offset in offsets.into_iter() { writer.write_all(&molecule::pack_number(offset as molecule::Number))?; @@ -7350,6 +7374,7 @@ impl molecule::prelude::Builder for AddTlcBuilder { writer.write_all(self.amount.as_slice())?; writer.write_all(self.payment_hash.as_slice())?; writer.write_all(self.expiry.as_slice())?; + writer.write_all(self.hash_algorithm.as_slice())?; Ok(()) } fn build(&self) -> Self::Entity { diff --git a/src/ckb/gen/invoice.rs b/src/ckb/gen/invoice.rs index 89f2dcfaa..1ad201a63 100644 --- a/src/ckb/gen/invoice.rs +++ b/src/ckb/gen/invoice.rs @@ -6279,6 +6279,266 @@ impl molecule::prelude::Builder for PaymentPreimageBuilder { } } #[derive(Clone)] +pub struct HashAlgorithm(molecule::bytes::Bytes); +impl ::core::fmt::LowerHex for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + use molecule::hex_string; + if f.alternate() { + write!(f, "0x")?; + } + write!(f, "{}", hex_string(self.as_slice())) + } +} +impl ::core::fmt::Debug for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{}({:#x})", Self::NAME, self) + } +} +impl ::core::fmt::Display for HashAlgorithm { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{} {{ ", Self::NAME)?; + write!(f, "{}: {}", "attr_id", self.attr_id())?; + write!(f, ", {}: {}", "value", self.value())?; + let extra_count = self.count_extra_fields(); + if extra_count != 0 { + write!(f, ", .. ({} fields)", extra_count)?; + } + write!(f, " }}") + } +} +impl ::core::default::Default for HashAlgorithm { + fn default() -> Self { + let v = molecule::bytes::Bytes::from_static(&Self::DEFAULT_VALUE); + HashAlgorithm::new_unchecked(v) + } +} +impl HashAlgorithm { + const DEFAULT_VALUE: [u8; 14] = [14, 0, 0, 0, 12, 0, 0, 0, 13, 0, 0, 0, 0, 0]; + pub const FIELD_COUNT: usize = 2; + pub fn total_size(&self) -> usize { + molecule::unpack_number(self.as_slice()) as usize + } + pub fn field_count(&self) -> usize { + if self.total_size() == molecule::NUMBER_SIZE { + 0 + } else { + (molecule::unpack_number(&self.as_slice()[molecule::NUMBER_SIZE..]) as usize / 4) - 1 + } + } + pub fn count_extra_fields(&self) -> usize { + self.field_count() - Self::FIELD_COUNT + } + pub fn has_extra_fields(&self) -> bool { + Self::FIELD_COUNT != self.field_count() + } + pub fn attr_id(&self) -> Byte { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[4..]) as usize; + let end = molecule::unpack_number(&slice[8..]) as usize; + Byte::new_unchecked(self.0.slice(start..end)) + } + pub fn value(&self) -> Byte { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[8..]) as usize; + if self.has_extra_fields() { + let end = molecule::unpack_number(&slice[12..]) as usize; + Byte::new_unchecked(self.0.slice(start..end)) + } else { + Byte::new_unchecked(self.0.slice(start..)) + } + } + pub fn as_reader<'r>(&'r self) -> HashAlgorithmReader<'r> { + HashAlgorithmReader::new_unchecked(self.as_slice()) + } +} +impl molecule::prelude::Entity for HashAlgorithm { + type Builder = HashAlgorithmBuilder; + const NAME: &'static str = "HashAlgorithm"; + fn new_unchecked(data: molecule::bytes::Bytes) -> Self { + HashAlgorithm(data) + } + fn as_bytes(&self) -> molecule::bytes::Bytes { + self.0.clone() + } + fn as_slice(&self) -> &[u8] { + &self.0[..] + } + fn from_slice(slice: &[u8]) -> molecule::error::VerificationResult { + HashAlgorithmReader::from_slice(slice).map(|reader| reader.to_entity()) + } + fn from_compatible_slice(slice: &[u8]) -> molecule::error::VerificationResult { + HashAlgorithmReader::from_compatible_slice(slice).map(|reader| reader.to_entity()) + } + fn new_builder() -> Self::Builder { + ::core::default::Default::default() + } + fn as_builder(self) -> Self::Builder { + Self::new_builder() + .attr_id(self.attr_id()) + .value(self.value()) + } +} +#[derive(Clone, Copy)] +pub struct HashAlgorithmReader<'r>(&'r [u8]); +impl<'r> ::core::fmt::LowerHex for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + use molecule::hex_string; + if f.alternate() { + write!(f, "0x")?; + } + write!(f, "{}", hex_string(self.as_slice())) + } +} +impl<'r> ::core::fmt::Debug for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{}({:#x})", Self::NAME, self) + } +} +impl<'r> ::core::fmt::Display for HashAlgorithmReader<'r> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + write!(f, "{} {{ ", Self::NAME)?; + write!(f, "{}: {}", "attr_id", self.attr_id())?; + write!(f, ", {}: {}", "value", self.value())?; + let extra_count = self.count_extra_fields(); + if extra_count != 0 { + write!(f, ", .. ({} fields)", extra_count)?; + } + write!(f, " }}") + } +} +impl<'r> HashAlgorithmReader<'r> { + pub const FIELD_COUNT: usize = 2; + pub fn total_size(&self) -> usize { + molecule::unpack_number(self.as_slice()) as usize + } + pub fn field_count(&self) -> usize { + if self.total_size() == molecule::NUMBER_SIZE { + 0 + } else { + (molecule::unpack_number(&self.as_slice()[molecule::NUMBER_SIZE..]) as usize / 4) - 1 + } + } + pub fn count_extra_fields(&self) -> usize { + self.field_count() - Self::FIELD_COUNT + } + pub fn has_extra_fields(&self) -> bool { + Self::FIELD_COUNT != self.field_count() + } + pub fn attr_id(&self) -> ByteReader<'r> { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[4..]) as usize; + let end = molecule::unpack_number(&slice[8..]) as usize; + ByteReader::new_unchecked(&self.as_slice()[start..end]) + } + pub fn value(&self) -> ByteReader<'r> { + let slice = self.as_slice(); + let start = molecule::unpack_number(&slice[8..]) as usize; + if self.has_extra_fields() { + let end = molecule::unpack_number(&slice[12..]) as usize; + ByteReader::new_unchecked(&self.as_slice()[start..end]) + } else { + ByteReader::new_unchecked(&self.as_slice()[start..]) + } + } +} +impl<'r> molecule::prelude::Reader<'r> for HashAlgorithmReader<'r> { + type Entity = HashAlgorithm; + const NAME: &'static str = "HashAlgorithmReader"; + fn to_entity(&self) -> Self::Entity { + Self::Entity::new_unchecked(self.as_slice().to_owned().into()) + } + fn new_unchecked(slice: &'r [u8]) -> Self { + HashAlgorithmReader(slice) + } + fn as_slice(&self) -> &'r [u8] { + self.0 + } + fn verify(slice: &[u8], compatible: bool) -> molecule::error::VerificationResult<()> { + use molecule::verification_error as ve; + let slice_len = slice.len(); + if slice_len < molecule::NUMBER_SIZE { + return ve!(Self, HeaderIsBroken, molecule::NUMBER_SIZE, slice_len); + } + let total_size = molecule::unpack_number(slice) as usize; + if slice_len != total_size { + return ve!(Self, TotalSizeNotMatch, total_size, slice_len); + } + if slice_len < molecule::NUMBER_SIZE * 2 { + return ve!(Self, HeaderIsBroken, molecule::NUMBER_SIZE * 2, slice_len); + } + let offset_first = molecule::unpack_number(&slice[molecule::NUMBER_SIZE..]) as usize; + if offset_first % molecule::NUMBER_SIZE != 0 || offset_first < molecule::NUMBER_SIZE * 2 { + return ve!(Self, OffsetsNotMatch); + } + if slice_len < offset_first { + return ve!(Self, HeaderIsBroken, offset_first, slice_len); + } + let field_count = offset_first / molecule::NUMBER_SIZE - 1; + if field_count < Self::FIELD_COUNT { + return ve!(Self, FieldCountNotMatch, Self::FIELD_COUNT, field_count); + } else if !compatible && field_count > Self::FIELD_COUNT { + return ve!(Self, FieldCountNotMatch, Self::FIELD_COUNT, field_count); + }; + let mut offsets: Vec = slice[molecule::NUMBER_SIZE..offset_first] + .chunks_exact(molecule::NUMBER_SIZE) + .map(|x| molecule::unpack_number(x) as usize) + .collect(); + offsets.push(total_size); + if offsets.windows(2).any(|i| i[0] > i[1]) { + return ve!(Self, OffsetsNotMatch); + } + ByteReader::verify(&slice[offsets[0]..offsets[1]], compatible)?; + ByteReader::verify(&slice[offsets[1]..offsets[2]], compatible)?; + Ok(()) + } +} +#[derive(Clone, Debug, Default)] +pub struct HashAlgorithmBuilder { + pub(crate) attr_id: Byte, + pub(crate) value: Byte, +} +impl HashAlgorithmBuilder { + pub const FIELD_COUNT: usize = 2; + pub fn attr_id(mut self, v: Byte) -> Self { + self.attr_id = v; + self + } + pub fn value(mut self, v: Byte) -> Self { + self.value = v; + self + } +} +impl molecule::prelude::Builder for HashAlgorithmBuilder { + type Entity = HashAlgorithm; + const NAME: &'static str = "HashAlgorithmBuilder"; + fn expected_length(&self) -> usize { + molecule::NUMBER_SIZE * (Self::FIELD_COUNT + 1) + + self.attr_id.as_slice().len() + + self.value.as_slice().len() + } + fn write(&self, writer: &mut W) -> molecule::io::Result<()> { + let mut total_size = molecule::NUMBER_SIZE * (Self::FIELD_COUNT + 1); + let mut offsets = Vec::with_capacity(Self::FIELD_COUNT); + offsets.push(total_size); + total_size += self.attr_id.as_slice().len(); + offsets.push(total_size); + total_size += self.value.as_slice().len(); + writer.write_all(&molecule::pack_number(total_size as molecule::Number))?; + for offset in offsets.into_iter() { + writer.write_all(&molecule::pack_number(offset as molecule::Number))?; + } + writer.write_all(self.attr_id.as_slice())?; + writer.write_all(self.value.as_slice())?; + Ok(()) + } + fn build(&self) -> Self::Entity { + let mut inner = Vec::with_capacity(self.expected_length()); + self.write(&mut inner) + .unwrap_or_else(|_| panic!("{} build should be ok", Self::NAME)); + HashAlgorithm::new_unchecked(inner.into()) + } +} +#[derive(Clone)] pub struct InvoiceAttr(molecule::bytes::Bytes); impl ::core::fmt::LowerHex for InvoiceAttr { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { @@ -6312,7 +6572,7 @@ impl InvoiceAttr { 0, 0, 0, 0, 41, 0, 0, 0, 12, 0, 0, 0, 13, 0, 0, 0, 0, 28, 0, 0, 0, 12, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]; - pub const ITEMS_COUNT: usize = 9; + pub const ITEMS_COUNT: usize = 10; pub fn item_id(&self) -> molecule::Number { molecule::unpack_number(self.as_slice()) } @@ -6328,6 +6588,7 @@ impl InvoiceAttr { 6 => UdtScript::new_unchecked(inner).into(), 7 => PayeePublicKey::new_unchecked(inner).into(), 8 => PaymentPreimage::new_unchecked(inner).into(), + 9 => HashAlgorithm::new_unchecked(inner).into(), _ => panic!("{}: invalid data", Self::NAME), } } @@ -6384,7 +6645,7 @@ impl<'r> ::core::fmt::Display for InvoiceAttrReader<'r> { } } impl<'r> InvoiceAttrReader<'r> { - pub const ITEMS_COUNT: usize = 9; + pub const ITEMS_COUNT: usize = 10; pub fn item_id(&self) -> molecule::Number { molecule::unpack_number(self.as_slice()) } @@ -6400,6 +6661,7 @@ impl<'r> InvoiceAttrReader<'r> { 6 => UdtScriptReader::new_unchecked(inner).into(), 7 => PayeePublicKeyReader::new_unchecked(inner).into(), 8 => PaymentPreimageReader::new_unchecked(inner).into(), + 9 => HashAlgorithmReader::new_unchecked(inner).into(), _ => panic!("{}: invalid data", Self::NAME), } } @@ -6434,6 +6696,7 @@ impl<'r> molecule::prelude::Reader<'r> for InvoiceAttrReader<'r> { 6 => UdtScriptReader::verify(inner_slice, compatible), 7 => PayeePublicKeyReader::verify(inner_slice, compatible), 8 => PaymentPreimageReader::verify(inner_slice, compatible), + 9 => HashAlgorithmReader::verify(inner_slice, compatible), _ => ve!(Self, UnknownItem, Self::ITEMS_COUNT, item_id), }?; Ok(()) @@ -6442,7 +6705,7 @@ impl<'r> molecule::prelude::Reader<'r> for InvoiceAttrReader<'r> { #[derive(Clone, Debug, Default)] pub struct InvoiceAttrBuilder(pub(crate) InvoiceAttrUnion); impl InvoiceAttrBuilder { - pub const ITEMS_COUNT: usize = 9; + pub const ITEMS_COUNT: usize = 10; pub fn set(mut self, v: I) -> Self where I: ::core::convert::Into, @@ -6479,6 +6742,7 @@ pub enum InvoiceAttrUnion { UdtScript(UdtScript), PayeePublicKey(PayeePublicKey), PaymentPreimage(PaymentPreimage), + HashAlgorithm(HashAlgorithm), } #[derive(Debug, Clone, Copy)] pub enum InvoiceAttrUnionReader<'r> { @@ -6491,6 +6755,7 @@ pub enum InvoiceAttrUnionReader<'r> { UdtScript(UdtScriptReader<'r>), PayeePublicKey(PayeePublicKeyReader<'r>), PaymentPreimage(PaymentPreimageReader<'r>), + HashAlgorithm(HashAlgorithmReader<'r>), } impl ::core::default::Default for InvoiceAttrUnion { fn default() -> Self { @@ -6533,6 +6798,9 @@ impl ::core::fmt::Display for InvoiceAttrUnion { InvoiceAttrUnion::PaymentPreimage(ref item) => { write!(f, "{}::{}({})", Self::NAME, PaymentPreimage::NAME, item) } + InvoiceAttrUnion::HashAlgorithm(ref item) => { + write!(f, "{}::{}({})", Self::NAME, HashAlgorithm::NAME, item) + } } } } @@ -6572,6 +6840,9 @@ impl<'r> ::core::fmt::Display for InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::PaymentPreimage(ref item) => { write!(f, "{}::{}({})", Self::NAME, PaymentPreimage::NAME, item) } + InvoiceAttrUnionReader::HashAlgorithm(ref item) => { + write!(f, "{}::{}({})", Self::NAME, HashAlgorithm::NAME, item) + } } } } @@ -6587,6 +6858,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(ref item) => write!(f, "{}", item), InvoiceAttrUnion::PayeePublicKey(ref item) => write!(f, "{}", item), InvoiceAttrUnion::PaymentPreimage(ref item) => write!(f, "{}", item), + InvoiceAttrUnion::HashAlgorithm(ref item) => write!(f, "{}", item), } } } @@ -6602,6 +6874,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::UdtScript(ref item) => write!(f, "{}", item), InvoiceAttrUnionReader::PayeePublicKey(ref item) => write!(f, "{}", item), InvoiceAttrUnionReader::PaymentPreimage(ref item) => write!(f, "{}", item), + InvoiceAttrUnionReader::HashAlgorithm(ref item) => write!(f, "{}", item), } } } @@ -6650,6 +6923,11 @@ impl ::core::convert::From for InvoiceAttrUnion { InvoiceAttrUnion::PaymentPreimage(item) } } +impl ::core::convert::From for InvoiceAttrUnion { + fn from(item: HashAlgorithm) -> Self { + InvoiceAttrUnion::HashAlgorithm(item) + } +} impl<'r> ::core::convert::From> for InvoiceAttrUnionReader<'r> { fn from(item: ExpiryTimeReader<'r>) -> Self { InvoiceAttrUnionReader::ExpiryTime(item) @@ -6697,6 +6975,11 @@ impl<'r> ::core::convert::From> for InvoiceAttrUnionRe InvoiceAttrUnionReader::PaymentPreimage(item) } } +impl<'r> ::core::convert::From> for InvoiceAttrUnionReader<'r> { + fn from(item: HashAlgorithmReader<'r>) -> Self { + InvoiceAttrUnionReader::HashAlgorithm(item) + } +} impl InvoiceAttrUnion { pub const NAME: &'static str = "InvoiceAttrUnion"; pub fn as_bytes(&self) -> molecule::bytes::Bytes { @@ -6710,6 +6993,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(item) => item.as_bytes(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_bytes(), InvoiceAttrUnion::PaymentPreimage(item) => item.as_bytes(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_bytes(), } } pub fn as_slice(&self) -> &[u8] { @@ -6723,6 +7007,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(item) => item.as_slice(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_slice(), InvoiceAttrUnion::PaymentPreimage(item) => item.as_slice(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_slice(), } } pub fn item_id(&self) -> molecule::Number { @@ -6736,6 +7021,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(_) => 6, InvoiceAttrUnion::PayeePublicKey(_) => 7, InvoiceAttrUnion::PaymentPreimage(_) => 8, + InvoiceAttrUnion::HashAlgorithm(_) => 9, } } pub fn item_name(&self) -> &str { @@ -6749,6 +7035,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(_) => "UdtScript", InvoiceAttrUnion::PayeePublicKey(_) => "PayeePublicKey", InvoiceAttrUnion::PaymentPreimage(_) => "PaymentPreimage", + InvoiceAttrUnion::HashAlgorithm(_) => "HashAlgorithm", } } pub fn as_reader<'r>(&'r self) -> InvoiceAttrUnionReader<'r> { @@ -6762,6 +7049,7 @@ impl InvoiceAttrUnion { InvoiceAttrUnion::UdtScript(item) => item.as_reader().into(), InvoiceAttrUnion::PayeePublicKey(item) => item.as_reader().into(), InvoiceAttrUnion::PaymentPreimage(item) => item.as_reader().into(), + InvoiceAttrUnion::HashAlgorithm(item) => item.as_reader().into(), } } } @@ -6778,6 +7066,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::UdtScript(item) => item.as_slice(), InvoiceAttrUnionReader::PayeePublicKey(item) => item.as_slice(), InvoiceAttrUnionReader::PaymentPreimage(item) => item.as_slice(), + InvoiceAttrUnionReader::HashAlgorithm(item) => item.as_slice(), } } pub fn item_id(&self) -> molecule::Number { @@ -6791,6 +7080,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::UdtScript(_) => 6, InvoiceAttrUnionReader::PayeePublicKey(_) => 7, InvoiceAttrUnionReader::PaymentPreimage(_) => 8, + InvoiceAttrUnionReader::HashAlgorithm(_) => 9, } } pub fn item_name(&self) -> &str { @@ -6804,6 +7094,7 @@ impl<'r> InvoiceAttrUnionReader<'r> { InvoiceAttrUnionReader::UdtScript(_) => "UdtScript", InvoiceAttrUnionReader::PayeePublicKey(_) => "PayeePublicKey", InvoiceAttrUnionReader::PaymentPreimage(_) => "PaymentPreimage", + InvoiceAttrUnionReader::HashAlgorithm(_) => "HashAlgorithm", } } } @@ -6852,6 +7143,11 @@ impl From for InvoiceAttr { Self::new_builder().set(value).build() } } +impl From for InvoiceAttr { + fn from(value: HashAlgorithm) -> Self { + Self::new_builder().set(value).build() + } +} #[derive(Clone)] pub struct InvoiceAttrsVec(molecule::bytes::Bytes); impl ::core::fmt::LowerHex for InvoiceAttrsVec { diff --git a/src/ckb/hash_algorithm.rs b/src/ckb/hash_algorithm.rs new file mode 100644 index 000000000..7ed07a30e --- /dev/null +++ b/src/ckb/hash_algorithm.rs @@ -0,0 +1,79 @@ +use bitcoin::hashes::{sha256::Hash as Sha256, Hash as _}; +use ckb_hash::blake2b_256; +use ckb_types::packed; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum HashAlgorithm { + CKBHash = 0, + Sha256 = 1, +} + +impl HashAlgorithm { + pub fn hash>(&self, s: T) -> [u8; 32] { + match self { + HashAlgorithm::CKBHash => blake2b_256(s), + HashAlgorithm::Sha256 => sha256(s), + } + } +} + +impl Default for HashAlgorithm { + fn default() -> Self { + HashAlgorithm::CKBHash + } +} + +/// The error type wrap various ser/de errors. +#[derive(Error, Debug)] +#[error("Unknown Hash Algorithm: {0}")] +pub struct UnknownHashAlgorithmError(pub u8); + +impl TryFrom for HashAlgorithm { + type Error = UnknownHashAlgorithmError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(HashAlgorithm::CKBHash), + 1 => Ok(HashAlgorithm::Sha256), + _ => Err(UnknownHashAlgorithmError(value)), + } + } +} + +impl TryFrom for HashAlgorithm { + type Error = UnknownHashAlgorithmError; + + fn try_from(value: packed::Byte) -> Result { + let value: u8 = value.into(); + value.try_into() + } +} + +pub fn sha256>(s: T) -> [u8; 32] { + Sha256::hash(s.as_ref()).to_byte_array() +} + +#[cfg(test)] +mod tests { + #[test] + fn test_hash_algorithm_serialization_sha256() { + let algorithm = super::HashAlgorithm::Sha256; + let serialized = serde_json::to_string(&algorithm).unwrap(); + assert_eq!(serialized, r#""sha256""#); + let deserialized: super::HashAlgorithm = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, algorithm); + } + + #[test] + fn test_hash_algorithm_serialization_ckb_hash() { + let algorithm = super::HashAlgorithm::CKBHash; + let serialized = serde_json::to_string(&algorithm).unwrap(); + assert_eq!(serialized, r#""ckb_hash""#); + let deserialized: super::HashAlgorithm = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, algorithm); + } +} diff --git a/src/ckb/mod.rs b/src/ckb/mod.rs index 2c77fc69d..3418b52da 100644 --- a/src/ckb/mod.rs +++ b/src/ckb/mod.rs @@ -16,6 +16,8 @@ pub mod channel; pub mod types; +pub mod hash_algorithm; + pub mod serde_utils; #[cfg(test)] diff --git a/src/ckb/schema/cfn.mol b/src/ckb/schema/cfn.mol index ce7f9e1f5..814d990ac 100644 --- a/src/ckb/schema/cfn.mol +++ b/src/ckb/schema/cfn.mol @@ -101,6 +101,7 @@ table AddTlc { amount: Uint128, payment_hash: Byte32, expiry: Uint64, + hash_algorithm: byte, } table RevokeAndAck { diff --git a/src/ckb/schema/invoice.mol b/src/ckb/schema/invoice.mol index c8e83f09b..392522532 100644 --- a/src/ckb/schema/invoice.mol +++ b/src/ckb/schema/invoice.mol @@ -60,6 +60,13 @@ table PaymentPreimage { value: Preimage, } +// 0 - ckb hash (Default) +// 1 - sha256 +table HashAlgorithm { + attr_id: byte, + value: byte, +} + union InvoiceAttr { ExpiryTime, Description, @@ -70,6 +77,7 @@ union InvoiceAttr { UdtScript, PayeePublicKey, PaymentPreimage, + HashAlgorithm, } vector InvoiceAttrsVec ; diff --git a/src/ckb/types.rs b/src/ckb/types.rs index 9804ea59a..b7ce6c3c0 100644 --- a/src/ckb/types.rs +++ b/src/ckb/types.rs @@ -1,6 +1,7 @@ use std::str::FromStr; use super::gen::cfn::{self as molecule_cfn, PubNonce as Byte66}; +use super::hash_algorithm::{HashAlgorithm, UnknownHashAlgorithmError}; use super::serde_utils::SliceHex; use anyhow::anyhow; use ckb_sdk::{Since, SinceType}; @@ -927,6 +928,7 @@ pub struct AddTlc { pub amount: u128, pub payment_hash: Hash256, pub expiry: LockTime, + pub hash_algorithm: HashAlgorithm, } impl From for molecule_cfn::AddTlc { @@ -951,6 +953,10 @@ impl TryFrom for AddTlc { amount: add_tlc.amount().unpack(), payment_hash: add_tlc.payment_hash().into(), expiry: add_tlc.expiry().try_into()?, + hash_algorithm: add_tlc + .hash_algorithm() + .try_into() + .map_err(|err: UnknownHashAlgorithmError| Error::AnyHow(err.into()))?, }) } } diff --git a/src/invoice/invoice_impl.rs b/src/invoice/invoice_impl.rs index 936c6a7b7..f79a5fe01 100644 --- a/src/invoice/invoice_impl.rs +++ b/src/invoice/invoice_impl.rs @@ -1,6 +1,7 @@ use super::errors::VerificationError; use super::utils::*; use crate::ckb::gen::invoice::{self as gen_invoice, *}; +use crate::ckb::hash_algorithm::HashAlgorithm; use crate::ckb::serde_utils::EntityHex; use crate::ckb::serde_utils::U128Hex; use crate::ckb::types::Hash256; @@ -88,6 +89,7 @@ pub enum Attribute { UdtScript(CkbScript), PayeePublicKey(PublicKey), PaymentPreimage(Hash256), + HashAlgorithm(HashAlgorithm), Feature(u64), } @@ -463,6 +465,11 @@ impl From for InvoiceAttr { ) .build(), ), + Attribute::HashAlgorithm(hash_algorithm) => InvoiceAttrUnion::HashAlgorithm( + gen_invoice::HashAlgorithm::new_builder() + .value(Byte::new(hash_algorithm as u8)) + .build(), + ), }; InvoiceAttr::new_builder().set(a).build() } @@ -504,6 +511,12 @@ impl From for Attribute { preimage.copy_from_slice(&value); Attribute::PaymentPreimage(preimage.into()) } + InvoiceAttrUnion::HashAlgorithm(x) => { + let value = x.value(); + // Consider unknown algorithm as the default one. + let hash_algorithm = value.try_into().unwrap_or_default(); + Attribute::HashAlgorithm(hash_algorithm) + } } } } diff --git a/src/rpc/channel.rs b/src/rpc/channel.rs index b6d80932d..e35fc8fb7 100644 --- a/src/rpc/channel.rs +++ b/src/rpc/channel.rs @@ -3,6 +3,7 @@ use crate::ckb::{ AddTlcCommand, ChannelActorStateStore, ChannelCommand, ChannelCommandWithId, ChannelState, RemoveTlcCommand, ShutdownCommand, }, + hash_algorithm::HashAlgorithm, network::{AcceptChannelCommand, OpenChannelCommand}, serde_utils::{U128Hex, U32Hex, U64Hex}, types::{Hash256, LockTime, RemoveTlcFail, RemoveTlcFulfill}, @@ -93,6 +94,7 @@ pub struct AddTlcParams { pub amount: u128, pub payment_hash: Hash256, pub expiry: LockTime, + pub hash_algorithm: Option, } #[serde_as] @@ -278,6 +280,7 @@ where preimage: None, payment_hash: Some(params.payment_hash), expiry: params.expiry, + hash_algorithm: params.hash_algorithm.unwrap_or_default(), }, rpc_reply, ),