Skip to content

Commit 38da7fa

Browse files
authored
Merge pull request sfackler#875 from halfmatthalfcat/ltree-support
Add ltree, lquery and ltxtquery support
2 parents a638ada + 6fae655 commit 38da7fa

File tree

8 files changed

+272
-17
lines changed

8 files changed

+272
-17
lines changed

docker/sql_setup.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,5 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL
9696
CREATE ROLE ssl_user LOGIN;
9797
CREATE EXTENSION hstore;
9898
CREATE EXTENSION citext;
99+
CREATE EXTENSION ltree;
99100
EOSQL

postgres-protocol/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "postgres-protocol"
3-
version = "0.6.3"
3+
version = "0.6.4"
44
authors = ["Steven Fackler <sfackler@gmail.com>"]
55
edition = "2018"
66
description = "Low level Postgres protocol APIs"

postgres-protocol/src/types/mod.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,3 +1059,60 @@ impl Inet {
10591059
self.netmask
10601060
}
10611061
}
1062+
1063+
/// Serializes a Postgres ltree string
1064+
#[inline]
1065+
pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
1066+
// A version number is prepended to an ltree string per spec
1067+
buf.put_u8(1);
1068+
// Append the rest of the query
1069+
buf.put_slice(v.as_bytes());
1070+
}
1071+
1072+
/// Deserialize a Postgres ltree string
1073+
#[inline]
1074+
pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
1075+
match buf {
1076+
// Remove the version number from the front of the ltree per spec
1077+
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
1078+
_ => Err("ltree version 1 only supported".into()),
1079+
}
1080+
}
1081+
1082+
/// Serializes a Postgres lquery string
1083+
#[inline]
1084+
pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
1085+
// A version number is prepended to an lquery string per spec
1086+
buf.put_u8(1);
1087+
// Append the rest of the query
1088+
buf.put_slice(v.as_bytes());
1089+
}
1090+
1091+
/// Deserialize a Postgres lquery string
1092+
#[inline]
1093+
pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
1094+
match buf {
1095+
// Remove the version number from the front of the lquery per spec
1096+
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
1097+
_ => Err("lquery version 1 only supported".into()),
1098+
}
1099+
}
1100+
1101+
/// Serializes a Postgres ltxtquery string
1102+
#[inline]
1103+
pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
1104+
// A version number is prepended to an ltxtquery string per spec
1105+
buf.put_u8(1);
1106+
// Append the rest of the query
1107+
buf.put_slice(v.as_bytes());
1108+
}
1109+
1110+
/// Deserialize a Postgres ltxtquery string
1111+
#[inline]
1112+
pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
1113+
match buf {
1114+
// Remove the version number from the front of the ltxtquery per spec
1115+
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
1116+
_ => Err("ltxtquery version 1 only supported".into()),
1117+
}
1118+
}

postgres-protocol/src/types/test.rs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use bytes::BytesMut;
1+
use bytes::{Buf, BytesMut};
22
use fallible_iterator::FallibleIterator;
33
use std::collections::HashMap;
44

@@ -156,3 +156,87 @@ fn non_null_array() {
156156
assert_eq!(array.dimensions().collect::<Vec<_>>().unwrap(), dimensions);
157157
assert_eq!(array.values().collect::<Vec<_>>().unwrap(), values);
158158
}
159+
160+
#[test]
161+
fn ltree_sql() {
162+
let mut query = vec![1u8];
163+
query.extend_from_slice("A.B.C".as_bytes());
164+
165+
let mut buf = BytesMut::new();
166+
167+
ltree_to_sql("A.B.C", &mut buf);
168+
169+
assert_eq!(query.as_slice(), buf.chunk());
170+
}
171+
172+
#[test]
173+
fn ltree_str() {
174+
let mut query = vec![1u8];
175+
query.extend_from_slice("A.B.C".as_bytes());
176+
177+
assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_)))
178+
}
179+
180+
#[test]
181+
fn ltree_wrong_version() {
182+
let mut query = vec![2u8];
183+
query.extend_from_slice("A.B.C".as_bytes());
184+
185+
assert!(matches!(ltree_from_sql(query.as_slice()), Err(_)))
186+
}
187+
188+
#[test]
189+
fn lquery_sql() {
190+
let mut query = vec![1u8];
191+
query.extend_from_slice("A.B.C".as_bytes());
192+
193+
let mut buf = BytesMut::new();
194+
195+
lquery_to_sql("A.B.C", &mut buf);
196+
197+
assert_eq!(query.as_slice(), buf.chunk());
198+
}
199+
200+
#[test]
201+
fn lquery_str() {
202+
let mut query = vec![1u8];
203+
query.extend_from_slice("A.B.C".as_bytes());
204+
205+
assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_)))
206+
}
207+
208+
#[test]
209+
fn lquery_wrong_version() {
210+
let mut query = vec![2u8];
211+
query.extend_from_slice("A.B.C".as_bytes());
212+
213+
assert!(matches!(lquery_from_sql(query.as_slice()), Err(_)))
214+
}
215+
216+
#[test]
217+
fn ltxtquery_sql() {
218+
let mut query = vec![1u8];
219+
query.extend_from_slice("a & b*".as_bytes());
220+
221+
let mut buf = BytesMut::new();
222+
223+
ltree_to_sql("a & b*", &mut buf);
224+
225+
assert_eq!(query.as_slice(), buf.chunk());
226+
}
227+
228+
#[test]
229+
fn ltxtquery_str() {
230+
let mut query = vec![1u8];
231+
query.extend_from_slice("a & b*".as_bytes());
232+
233+
assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_)))
234+
}
235+
236+
#[test]
237+
fn ltxtquery_wrong_version() {
238+
let mut query = vec![2u8];
239+
query.extend_from_slice("a & b*".as_bytes());
240+
241+
assert!(matches!(ltree_from_sql(query.as_slice()), Err(_)))
242+
}

postgres-types/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "postgres-types"
3-
version = "0.2.2"
3+
version = "0.2.3"
44
authors = ["Steven Fackler <sfackler@gmail.com>"]
55
edition = "2018"
66
license = "MIT/Apache-2.0"
@@ -28,7 +28,7 @@ with-time-0_3 = ["time-03"]
2828
[dependencies]
2929
bytes = "1.0"
3030
fallible-iterator = "0.2"
31-
postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" }
31+
postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" }
3232
postgres-derive = { version = "0.4.0", optional = true, path = "../postgres-derive" }
3333

3434
array-init = { version = "2", optional = true }

postgres-types/src/lib.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ impl WrongType {
407407
/// | `f32` | REAL |
408408
/// | `f64` | DOUBLE PRECISION |
409409
/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME, UNKNOWN |
410+
/// | | LTREE, LQUERY, LTXTQUERY |
410411
/// | `&[u8]`/`Vec<u8>` | BYTEA |
411412
/// | `HashMap<String, Option<String>>` | HSTORE |
412413
/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE |
@@ -594,8 +595,8 @@ impl<'a> FromSql<'a> for &'a [u8] {
594595
}
595596

596597
impl<'a> FromSql<'a> for String {
597-
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
598-
types::text_from_sql(raw).map(ToString::to_string)
598+
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
599+
<&str as FromSql>::from_sql(ty, raw).map(ToString::to_string)
599600
}
600601

601602
fn accepts(ty: &Type) -> bool {
@@ -604,8 +605,8 @@ impl<'a> FromSql<'a> for String {
604605
}
605606

606607
impl<'a> FromSql<'a> for Box<str> {
607-
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Box<str>, Box<dyn Error + Sync + Send>> {
608-
types::text_from_sql(raw)
608+
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Box<str>, Box<dyn Error + Sync + Send>> {
609+
<&str as FromSql>::from_sql(ty, raw)
609610
.map(ToString::to_string)
610611
.map(String::into_boxed_str)
611612
}
@@ -616,14 +617,26 @@ impl<'a> FromSql<'a> for Box<str> {
616617
}
617618

618619
impl<'a> FromSql<'a> for &'a str {
619-
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
620-
types::text_from_sql(raw)
620+
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
621+
match *ty {
622+
ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw),
623+
ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw),
624+
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw),
625+
_ => types::text_from_sql(raw),
626+
}
621627
}
622628

623629
fn accepts(ty: &Type) -> bool {
624630
match *ty {
625631
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
626-
ref ty if ty.name() == "citext" => true,
632+
ref ty
633+
if (ty.name() == "citext"
634+
|| ty.name() == "ltree"
635+
|| ty.name() == "lquery"
636+
|| ty.name() == "ltxtquery") =>
637+
{
638+
true
639+
}
627640
_ => false,
628641
}
629642
}
@@ -727,6 +740,7 @@ pub enum IsNull {
727740
/// | `f32` | REAL |
728741
/// | `f64` | DOUBLE PRECISION |
729742
/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME |
743+
/// | | LTREE, LQUERY, LTXTQUERY |
730744
/// | `&[u8]`/`Vec<u8>` | BYTEA |
731745
/// | `HashMap<String, Option<String>>` | HSTORE |
732746
/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE |
@@ -924,15 +938,27 @@ impl ToSql for Vec<u8> {
924938
}
925939

926940
impl<'a> ToSql for &'a str {
927-
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
928-
types::text_to_sql(*self, w);
941+
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
942+
match *ty {
943+
ref ty if ty.name() == "ltree" => types::ltree_to_sql(*self, w),
944+
ref ty if ty.name() == "lquery" => types::lquery_to_sql(*self, w),
945+
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(*self, w),
946+
_ => types::text_to_sql(*self, w),
947+
}
929948
Ok(IsNull::No)
930949
}
931950

932951
fn accepts(ty: &Type) -> bool {
933952
match *ty {
934953
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
935-
ref ty if ty.name() == "citext" => true,
954+
ref ty
955+
if (ty.name() == "citext"
956+
|| ty.name() == "ltree"
957+
|| ty.name() == "lquery"
958+
|| ty.name() == "ltxtquery") =>
959+
{
960+
true
961+
}
936962
_ => false,
937963
}
938964
}

tokio-postgres/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "tokio-postgres"
3-
version = "0.7.5"
3+
version = "0.7.6"
44
authors = ["Steven Fackler <sfackler@gmail.com>"]
55
edition = "2018"
66
license = "MIT/Apache-2.0"
@@ -50,8 +50,8 @@ parking_lot = "0.12"
5050
percent-encoding = "2.0"
5151
pin-project-lite = "0.2"
5252
phf = "0.10"
53-
postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" }
54-
postgres-types = { version = "0.2.2", path = "../postgres-types" }
53+
postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" }
54+
postgres-types = { version = "0.2.3", path = "../postgres-types" }
5555
socket2 = "0.4"
5656
tokio = { version = "1.0", features = ["io-util"] }
5757
tokio-util = { version = "0.7", features = ["codec"] }

tokio-postgres/tests/test/types/mod.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,3 +648,90 @@ async fn inet() {
648648
)
649649
.await;
650650
}
651+
652+
#[tokio::test]
653+
async fn ltree() {
654+
test_type(
655+
"ltree",
656+
&[(Some("b.c.d".to_owned()), "'b.c.d'"), (None, "NULL")],
657+
)
658+
.await;
659+
}
660+
661+
#[tokio::test]
662+
async fn ltree_any() {
663+
test_type(
664+
"ltree[]",
665+
&[
666+
(Some(vec![]), "ARRAY[]"),
667+
(Some(vec!["a.b.c".to_string()]), "ARRAY['a.b.c']"),
668+
(
669+
Some(vec!["a.b.c".to_string(), "e.f.g".to_string()]),
670+
"ARRAY['a.b.c','e.f.g']",
671+
),
672+
(None, "NULL"),
673+
],
674+
)
675+
.await;
676+
}
677+
678+
#[tokio::test]
679+
async fn lquery() {
680+
test_type(
681+
"lquery",
682+
&[
683+
(Some("b.c.d".to_owned()), "'b.c.d'"),
684+
(Some("b.c.*".to_owned()), "'b.c.*'"),
685+
(Some("b.*{1,2}.d|e".to_owned()), "'b.*{1,2}.d|e'"),
686+
(None, "NULL"),
687+
],
688+
)
689+
.await;
690+
}
691+
692+
#[tokio::test]
693+
async fn lquery_any() {
694+
test_type(
695+
"lquery[]",
696+
&[
697+
(Some(vec![]), "ARRAY[]"),
698+
(Some(vec!["b.c.*".to_string()]), "ARRAY['b.c.*']"),
699+
(
700+
Some(vec!["b.c.*".to_string(), "b.*{1,2}.d|e".to_string()]),
701+
"ARRAY['b.c.*','b.*{1,2}.d|e']",
702+
),
703+
(None, "NULL"),
704+
],
705+
)
706+
.await;
707+
}
708+
709+
#[tokio::test]
710+
async fn ltxtquery() {
711+
test_type(
712+
"ltxtquery",
713+
&[
714+
(Some("b & c & d".to_owned()), "'b & c & d'"),
715+
(Some("b@* & !c".to_owned()), "'b@* & !c'"),
716+
(None, "NULL"),
717+
],
718+
)
719+
.await;
720+
}
721+
722+
#[tokio::test]
723+
async fn ltxtquery_any() {
724+
test_type(
725+
"ltxtquery[]",
726+
&[
727+
(Some(vec![]), "ARRAY[]"),
728+
(Some(vec!["b & c & d".to_string()]), "ARRAY['b & c & d']"),
729+
(
730+
Some(vec!["b & c & d".to_string(), "b@* & !c".to_string()]),
731+
"ARRAY['b & c & d','b@* & !c']",
732+
),
733+
(None, "NULL"),
734+
],
735+
)
736+
.await;
737+
}

0 commit comments

Comments
 (0)