Skip to content

Commit 4fc5b30

Browse files
authored
breaking: fix name collision in FromRow, return Error::ColumnDecode for TryFrom errors (#3356)
* chore: create regression test for #3344 * fix(derives): use a parameter name that's less likely to collide * breaking(derives): emit `Error::ColumnDecode` when a `TryFrom` conversion fails in `FromRow` Breaking because `#[sqlx(default)]` on an individual field or the struct itself would have previously suppressed the error. This doesn't seem like good behavior as it could result in some potentially very difficult bugs. Instead of using `TryFrom` for these fields, just implement `From` and apply the default explicitly. * fix: run `cargo fmt` * fix: use correct field in `ColumnDecode`
1 parent b37b34b commit 4fc5b30

File tree

4 files changed

+181
-22
lines changed

4 files changed

+181
-22
lines changed

sqlx-macros-core/src/derives/row.rs

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,30 @@ fn expand_derive_from_row_struct(
104104
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
105105
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
106106

107-
parse_quote!(row.try_get(#id_s))
107+
parse_quote!(__row.try_get(#id_s))
108108
}
109109
// Flatten
110110
(true, None, false) => {
111111
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
112-
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row))
112+
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(__row))
113113
}
114114
// Flatten + Try from
115115
(true, Some(try_from), false) => {
116116
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
117-
parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
117+
parse_quote!(
118+
<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(__row)
119+
.and_then(|v| {
120+
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v)
121+
.map_err(|e| {
122+
// Triggers a lint warning if `TryFrom::Err = Infallible`
123+
#[allow(unreachable_code)]
124+
::sqlx::Error::ColumnDecode {
125+
index: #id_s.to_string(),
126+
source: sqlx::__spec_error!(e),
127+
}
128+
})
129+
})
130+
)
118131
}
119132
// Flatten + Json
120133
(true, _, true) => {
@@ -126,7 +139,20 @@ fn expand_derive_from_row_struct(
126139
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
127140
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));
128141

129-
parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))))
142+
parse_quote!(
143+
__row.try_get(#id_s)
144+
.and_then(|v| {
145+
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v)
146+
.map_err(|e| {
147+
// Triggers a lint warning if `TryFrom::Err = Infallible`
148+
#[allow(unreachable_code)]
149+
::sqlx::Error::ColumnDecode {
150+
index: #id_s.to_string(),
151+
source: sqlx::__spec_error!(e),
152+
}
153+
})
154+
})
155+
)
130156
}
131157
// Try from + Json
132158
(false, Some(try_from), true) => {
@@ -135,10 +161,18 @@ fn expand_derive_from_row_struct(
135161
predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type<R::Database>));
136162

137163
parse_quote!(
138-
row.try_get::<::sqlx::types::Json<_>, _>(#id_s).and_then(|v|
139-
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0)
140-
.map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string()))
141-
)
164+
__row.try_get::<::sqlx::types::Json<_>, _>(#id_s)
165+
.and_then(|v| {
166+
<#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0)
167+
.map_err(|e| {
168+
// Triggers a lint warning if `TryFrom::Err = Infallible`
169+
#[allow(unreachable_code)]
170+
::sqlx::Error::ColumnDecode {
171+
index: #id_s.to_string(),
172+
source: sqlx::__spec_error!(e),
173+
}
174+
})
175+
})
142176
)
143177
},
144178
// Json
@@ -147,24 +181,28 @@ fn expand_derive_from_row_struct(
147181
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>));
148182
predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type<R::Database>));
149183

150-
parse_quote!(row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
184+
parse_quote!(__row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
151185
},
152186
};
153187

154188
if attributes.default {
155-
Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e {
156-
::sqlx::Error::ColumnNotFound(_) => {
157-
::std::result::Result::Ok(Default::default())
158-
},
159-
e => ::std::result::Result::Err(e)
160-
})?;))
189+
Some(parse_quote!(
190+
let #id: #ty = #expr.or_else(|e| match e {
191+
::sqlx::Error::ColumnNotFound(_) => {
192+
::std::result::Result::Ok(Default::default())
193+
},
194+
e => ::std::result::Result::Err(e)
195+
})?;
196+
))
161197
} else if container_attributes.default {
162-
Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e {
163-
::sqlx::Error::ColumnNotFound(_) => {
164-
::std::result::Result::Ok(__default.#id)
165-
},
166-
e => ::std::result::Result::Err(e)
167-
})?;))
198+
Some(parse_quote!(
199+
let #id: #ty = #expr.or_else(|e| match e {
200+
::sqlx::Error::ColumnNotFound(_) => {
201+
::std::result::Result::Ok(__default.#id)
202+
},
203+
e => ::std::result::Result::Err(e)
204+
})?;
205+
))
168206
} else {
169207
Some(parse_quote!(
170208
let #id: #ty = #expr?;
@@ -180,7 +218,7 @@ fn expand_derive_from_row_struct(
180218
Ok(quote!(
181219
#[automatically_derived]
182220
impl #impl_generics ::sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause {
183-
fn from_row(row: &#lifetime R) -> ::sqlx::Result<Self> {
221+
fn from_row(__row: &#lifetime R) -> ::sqlx::Result<Self> {
184222
#default_instance
185223

186224
#(#reads)*

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ mod macros;
8686
#[doc(hidden)]
8787
pub mod ty_match;
8888

89+
#[cfg(feature = "macros")]
90+
pub mod spec_error;
91+
8992
#[doc(hidden)]
9093
pub use sqlx_core::rt as __rt;
9194

src/spec_error.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use std::any::Any;
2+
use std::error::Error;
3+
use std::fmt::{Debug, Display};
4+
5+
// Autoderef specialization similar to `clap::value_parser!()`.
6+
pub struct SpecErrorWrapper<E>(pub E);
7+
8+
pub trait SpecError<E>: Sized {
9+
fn __sqlx_spec_error(
10+
&self,
11+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static>;
12+
}
13+
14+
impl<E> SpecError<E> for &&&&SpecErrorWrapper<E>
15+
where
16+
E: Error + Send + Sync + 'static,
17+
{
18+
fn __sqlx_spec_error(
19+
&self,
20+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static> {
21+
|e| Box::new(e.0)
22+
}
23+
}
24+
25+
impl<E> SpecError<E> for &&&SpecErrorWrapper<E>
26+
where
27+
E: Display,
28+
{
29+
fn __sqlx_spec_error(
30+
&self,
31+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static> {
32+
|e| e.0.to_string().into()
33+
}
34+
}
35+
36+
impl<E> SpecError<E> for &&SpecErrorWrapper<E>
37+
where
38+
E: Debug,
39+
{
40+
fn __sqlx_spec_error(
41+
&self,
42+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static> {
43+
|e| format!("{:?}", e.0).into()
44+
}
45+
}
46+
47+
impl<E> SpecError<E> for &SpecErrorWrapper<E>
48+
where
49+
E: Any,
50+
{
51+
fn __sqlx_spec_error(
52+
&self,
53+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static> {
54+
|_e| format!("unprintable error: {}", std::any::type_name::<E>()).into()
55+
}
56+
}
57+
58+
impl<E> SpecError<E> for SpecErrorWrapper<E> {
59+
fn __sqlx_spec_error(
60+
&self,
61+
) -> fn(SpecErrorWrapper<E>) -> Box<dyn Error + Send + Sync + 'static> {
62+
|_e| "unprintable error: (unprintable type)".into()
63+
}
64+
}
65+
66+
#[doc(hidden)]
67+
#[macro_export]
68+
macro_rules! __spec_error {
69+
($e:expr) => {{
70+
use $crate::spec_error::{SpecError, SpecErrorWrapper};
71+
72+
let wrapper = SpecErrorWrapper($e);
73+
let wrap_err = wrapper.__sqlx_spec_error();
74+
wrap_err(wrapper)
75+
}};
76+
}
77+
78+
#[test]
79+
fn test_spec_error() {
80+
#[derive(Debug)]
81+
struct DebugError;
82+
83+
struct AnyError;
84+
85+
let _e: Box<dyn Error + Send + Sync + 'static> =
86+
__spec_error!(std::io::Error::from(std::io::ErrorKind::Unsupported));
87+
88+
let _e: Box<dyn Error + Send + Sync + 'static> = __spec_error!("displayable error");
89+
90+
let _e: Box<dyn Error + Send + Sync + 'static> = __spec_error!(DebugError);
91+
92+
let _e: Box<dyn Error + Send + Sync + 'static> = __spec_error!(AnyError);
93+
94+
let _e: Box<dyn Error + Send + Sync + 'static> = __spec_error!(&1i32);
95+
}

tests/postgres/derives.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,3 +769,26 @@ async fn test_enum_with_schema() -> anyhow::Result<()> {
769769

770770
Ok(())
771771
}
772+
773+
#[cfg(feature = "macros")]
774+
#[sqlx_macros::test]
775+
async fn test_from_row_hygiene() -> anyhow::Result<()> {
776+
// A field named `row` previously would shadow the `row` parameter of `FromRow::from_row()`:
777+
// https://github.com/launchbadge/sqlx/issues/3344
778+
#[derive(Debug, sqlx::FromRow)]
779+
pub struct Foo {
780+
pub row: i32,
781+
pub bar: i32,
782+
}
783+
784+
let mut conn = new::<Postgres>().await?;
785+
786+
let foo: Foo = sqlx::query_as("SELECT 1234 as row, 5678 as bar")
787+
.fetch_one(&mut conn)
788+
.await?;
789+
790+
assert_eq!(foo.row, 1234);
791+
assert_eq!(foo.bar, 5678);
792+
793+
Ok(())
794+
}

0 commit comments

Comments
 (0)