Skip to content

Commit 620305a

Browse files
committed
impl FromSql for Vec<f32>, Vec<f64>, Vec<i8> and Vec<u8> to get VECTOR type
1 parent 9fc72c0 commit 620305a

File tree

5 files changed

+183
-4
lines changed

5 files changed

+183
-4
lines changed

ChangeLog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ New features:
1010
* Add [`VecFmt`] enum type
1111
* Add [`VecRef`] enum type to set rust values to Oracle VECTOR data type
1212
* Add [`VectorFormat`] trait type
13+
* impl `FromSql` for `Vec<f32>`, `Vec<f64>`, `Vec<i8>` and `Vec<u8>` to get values from Oracle VECTOR data type
1314

1415
Incompatible changes:
1516

src/sql_type/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ impl_from_and_to_sql!(f32, to_f32, set_f32, OracleType::Number(0, 0));
243243
impl_from_and_to_sql!(bool, to_bool, set_bool, OracleType::Boolean);
244244
impl_from_sql!(String, to_string);
245245
impl_from_sql!(Vec<u8>, to_bytes);
246+
impl_from_sql!(Vec<f32>, to_f32_vec);
247+
impl_from_sql!(Vec<f64>, to_f64_vec);
248+
impl_from_sql!(Vec<i8>, to_i8_vec);
246249
impl_from_and_to_sql!(
247250
Timestamp,
248251
to_timestamp,

src/sql_type/oracle_type.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ pub(crate) struct VarParam {
182182
pub native_type: NativeType,
183183
pub size: u32,
184184
pub size_is_byte: i32,
185+
pub vector_format: VecFmt,
185186
}
186187

187188
impl VarParam {
@@ -191,6 +192,7 @@ impl VarParam {
191192
native_type,
192193
size: 0,
193194
size_is_byte: 0,
195+
vector_format: VecFmt::Flexible,
194196
}
195197
}
196198

@@ -203,6 +205,11 @@ impl VarParam {
203205
self.size_is_byte = 1;
204206
self
205207
}
208+
209+
fn vector_format(mut self, format: VecFmt) -> VarParam {
210+
self.vector_format = format;
211+
self
212+
}
206213
}
207214

208215
/// Oracle data type
@@ -483,8 +490,8 @@ impl OracleType {
483490
)),
484491
OracleType::LongRaw => Ok(VarParam::new(DPI_ORACLE_TYPE_LONG_RAW, NativeType::Raw)),
485492
OracleType::Xml => Ok(VarParam::new(DPI_ORACLE_TYPE_XMLTYPE, NativeType::Char)),
486-
OracleType::Vector(_, _) => {
487-
Ok(VarParam::new(DPI_ORACLE_TYPE_VECTOR, NativeType::Vector))
493+
OracleType::Vector(_, format) => {
494+
Ok(VarParam::new(DPI_ORACLE_TYPE_VECTOR, NativeType::Vector).vector_format(format))
488495
}
489496
OracleType::Int64 => Ok(VarParam::new(DPI_ORACLE_TYPE_NATIVE_INT, NativeType::Int64)),
490497
OracleType::UInt64 => Ok(VarParam::new(

src/sql_type/vector.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::Result;
2525
use odpic_sys::*;
2626
use std::fmt;
2727
use std::os::raw::c_void;
28+
use std::slice;
2829

2930
/// Vector dimension element format
3031
///
@@ -100,6 +101,33 @@ pub enum VecRef<'a> {
100101
}
101102

102103
impl VecRef<'_> {
104+
// The 'static lifetime in the returned value is incorrect.
105+
// Its actual lifetime is that of data referred by info.
106+
pub(crate) unsafe fn from_dpi(info: dpiVectorInfo) -> Result<VecRef<'static>> {
107+
match info.format as u32 {
108+
DPI_VECTOR_FORMAT_FLOAT32 => Ok(VecRef::Float32(slice::from_raw_parts(
109+
info.dimensions.asFloat,
110+
info.numDimensions as usize,
111+
))),
112+
DPI_VECTOR_FORMAT_FLOAT64 => Ok(VecRef::Float64(slice::from_raw_parts(
113+
info.dimensions.asDouble,
114+
info.numDimensions as usize,
115+
))),
116+
DPI_VECTOR_FORMAT_INT8 => Ok(VecRef::Int8(slice::from_raw_parts(
117+
info.dimensions.asInt8,
118+
info.numDimensions as usize,
119+
))),
120+
DPI_VECTOR_FORMAT_BINARY => Ok(VecRef::Binary(slice::from_raw_parts(
121+
info.dimensions.asPtr as *const u8,
122+
(info.numDimensions / 8) as usize,
123+
))),
124+
_ => Err(Error::internal_error(format!(
125+
"unknown vector format {}",
126+
info.format
127+
))),
128+
}
129+
}
130+
103131
pub(crate) fn to_dpi(&self) -> Result<dpiVectorInfo> {
104132
match self {
105133
VecRef::Float32(slice) => Ok(dpiVectorInfo {
@@ -184,7 +212,7 @@ impl VecRef<'_> {
184212
T::vec_ref_to_slice(self)
185213
}
186214

187-
fn oracle_type(&self) -> OracleType {
215+
pub(crate) fn oracle_type(&self) -> OracleType {
188216
match self {
189217
VecRef::Float32(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float32),
190218
VecRef::Float64(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float64),
@@ -415,4 +443,49 @@ mod tests {
415443
assert_eq!(index, expected_data.len());
416444
Ok(())
417445
}
446+
447+
#[test]
448+
fn vec_from_sql() -> Result<()> {
449+
let conn = test_util::connect()?;
450+
451+
if !test_util::check_version(&conn, &test_util::VER23, &test_util::VER23)? {
452+
return Ok(());
453+
}
454+
let binary_vec = test_util::check_version(&conn, &test_util::VER23_5, &test_util::VER23_5)?;
455+
conn.execute("delete from test_vector_type", &[])?;
456+
let mut expected_data = vec![];
457+
conn.execute("insert into test_vector_type(id, vec) values(1, TO_VECTOR('[1.0, 2.25, 3.5]', 3, FLOAT32))", &[])?;
458+
expected_data.push((1, "FLOAT32", VecRef::Float32(&[1.0, 2.25, 3.5])));
459+
conn.execute("insert into test_vector_type(id, vec) values(2, TO_VECTOR('[4.0, 5.25, 6.5]', 3, FLOAT64))", &[])?;
460+
expected_data.push((2, "FLOAT64", VecRef::Float64(&[4.0, 5.25, 6.5])));
461+
conn.execute(
462+
"insert into test_vector_type(id, vec) values(3, TO_VECTOR('[7, 8, 9]', 3, INT8))",
463+
&[],
464+
)?;
465+
expected_data.push((3, "INT8", VecRef::Int8(&[7, 8, 9])));
466+
if binary_vec {
467+
conn.execute("insert into test_vector_type(id, vec) values(4, TO_VECTOR('[10, 11, 12]', 24, BINARY))", &[])?;
468+
expected_data.push((4, "BINARY", VecRef::Binary(&[10, 11, 12])));
469+
}
470+
let mut index = 0;
471+
for row_result in conn.query(
472+
"select id, vector_dimension_format(vec), vec from test_vector_type order by id",
473+
&[],
474+
)? {
475+
let row = row_result?;
476+
assert!(index < expected_data.len());
477+
let data = &expected_data[index];
478+
assert_eq!(row.get::<_, i32>(0)?, data.0);
479+
assert_eq!(row.get::<_, String>(1)?, data.1);
480+
match data.2 {
481+
VecRef::Float32(slice) => assert_eq!(row.get::<_, Vec<f32>>(2)?, slice),
482+
VecRef::Float64(slice) => assert_eq!(row.get::<_, Vec<f64>>(2)?, slice),
483+
VecRef::Int8(slice) => assert_eq!(row.get::<_, Vec<i8>>(2)?, slice),
484+
VecRef::Binary(slice) => assert_eq!(row.get::<_, Vec<u8>>(2)?, slice),
485+
}
486+
index += 1;
487+
}
488+
assert_eq!(index, expected_data.len());
489+
Ok(())
490+
}
418491
}

src/sql_value.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
use crate::chkerr;
1717
use crate::connection::Conn;
18+
use crate::sql_type::vector::VecFmt;
1819
use crate::sql_type::vector::VecRef;
1920
use crate::sql_type::Bfile;
2021
use crate::sql_type::Blob;
@@ -52,6 +53,7 @@ use odpic_sys::*;
5253
use std::borrow::Cow;
5354
use std::convert::TryInto;
5455
use std::fmt;
56+
use std::mem::MaybeUninit;
5557
use std::os::raw::c_char;
5658
use std::ptr;
5759
use std::rc::Rc;
@@ -154,6 +156,7 @@ pub struct SqlValue<'a> {
154156
keep_dpiobj: DpiObject,
155157
pub(crate) lob_bind_type: LobBindType,
156158
pub(crate) query_params: QueryParams,
159+
vector_format: VecFmt,
157160
}
158161

159162
impl SqlValue<'_> {
@@ -174,6 +177,7 @@ impl SqlValue<'_> {
174177
keep_dpiobj: DpiObject::null(),
175178
lob_bind_type,
176179
query_params,
180+
vector_format: VecFmt::Flexible,
177181
}
178182
}
179183

@@ -228,6 +232,7 @@ impl SqlValue<'_> {
228232
keep_dpiobj: DpiObject::null(),
229233
lob_bind_type: LobBindType::Locator,
230234
query_params: QueryParams::new(),
235+
vector_format: param.vector_format,
231236
})
232237
}
233238

@@ -728,6 +733,20 @@ impl SqlValue<'_> {
728733
unsafe { Ok(dpiData_getStmt(self.data()?)) }
729734
}
730735

736+
/// # Safety
737+
///
738+
/// The actual lifetime of VecRef isn't 'static.
739+
/// It is same with that of DpiVar.
740+
unsafe fn get_vec_ref_unchecked(&self) -> Result<VecRef<'static>> {
741+
self.check_not_null()?;
742+
let mut info = MaybeUninit::uninit();
743+
chkerr!(
744+
self.ctxt(),
745+
dpiVector_getValue(self.data()?.value.asVector, info.as_mut_ptr())
746+
);
747+
VecRef::from_dpi(info.assume_init())
748+
}
749+
731750
//
732751
// set_TYPE_unchecked methods
733752
//
@@ -923,6 +942,7 @@ impl SqlValue<'_> {
923942
keep_dpiobj: DpiObject::null(),
924943
lob_bind_type: self.lob_bind_type,
925944
query_params: self.query_params.clone(),
945+
vector_format: self.vector_format,
926946
})
927947
} else {
928948
Err(Error::internal_error("dpVar handle isn't initialized"))
@@ -1064,7 +1084,12 @@ impl SqlValue<'_> {
10641084
}),
10651085
NativeType::Rowid => self.get_rowid_as_string_unchecked(),
10661086
NativeType::Stmt => self.invalid_conversion_to_rust_type("string"),
1067-
NativeType::Vector => todo!(),
1087+
NativeType::Vector => Ok(match unsafe { self.get_vec_ref_unchecked()? } {
1088+
VecRef::Float32(slice) => format!("{:?}", slice),
1089+
VecRef::Float64(slice) => format!("{:?}", slice),
1090+
VecRef::Int8(slice) => format!("{:?}", slice),
1091+
VecRef::Binary(slice) => format!("{:?}", slice),
1092+
}),
10681093
}
10691094
}
10701095

@@ -1075,10 +1100,79 @@ impl SqlValue<'_> {
10751100
NativeType::Blob => self.get_blob_unchecked(),
10761101
NativeType::Char => Ok(parse_str_into_raw(&self.get_cow_str_unchecked()?)?),
10771102
NativeType::Clob => Ok(parse_str_into_raw(&self.get_clob_as_string_unchecked()?)?),
1103+
NativeType::Vector
1104+
if self.vector_format == VecFmt::Binary
1105+
|| self.vector_format == VecFmt::Flexible =>
1106+
unsafe {
1107+
let vec_ref = self.get_vec_ref_unchecked()?;
1108+
match vec_ref {
1109+
VecRef::Binary(slice) => Ok(slice.to_vec()),
1110+
_ => Err(Error::invalid_type_conversion(
1111+
vec_ref.oracle_type().to_string(),
1112+
"Vec<u8>",
1113+
)),
1114+
}
1115+
},
10781116
_ => self.invalid_conversion_to_rust_type("raw"),
10791117
}
10801118
}
10811119

1120+
pub(crate) fn to_f32_vec(&self) -> Result<Vec<f32>> {
1121+
match self.native_type {
1122+
NativeType::Vector
1123+
if self.vector_format == VecFmt::Float32
1124+
|| self.vector_format == VecFmt::Flexible =>
1125+
unsafe {
1126+
let vec_ref = self.get_vec_ref_unchecked()?;
1127+
match vec_ref {
1128+
VecRef::Float32(slice) => Ok(slice.to_vec()),
1129+
_ => Err(Error::invalid_type_conversion(
1130+
vec_ref.oracle_type().to_string(),
1131+
"Vec<f32>",
1132+
)),
1133+
}
1134+
},
1135+
_ => self.invalid_conversion_to_rust_type("Vec<f32>"),
1136+
}
1137+
}
1138+
1139+
pub(crate) fn to_f64_vec(&self) -> Result<Vec<f64>> {
1140+
match self.native_type {
1141+
NativeType::Vector
1142+
if self.vector_format == VecFmt::Float64
1143+
|| self.vector_format == VecFmt::Flexible =>
1144+
unsafe {
1145+
let vec_ref = self.get_vec_ref_unchecked()?;
1146+
match vec_ref {
1147+
VecRef::Float64(slice) => Ok(slice.to_vec()),
1148+
_ => Err(Error::invalid_type_conversion(
1149+
vec_ref.oracle_type().to_string(),
1150+
"Vec<f64>",
1151+
)),
1152+
}
1153+
},
1154+
_ => self.invalid_conversion_to_rust_type("Vec<f64>"),
1155+
}
1156+
}
1157+
1158+
pub(crate) fn to_i8_vec(&self) -> Result<Vec<i8>> {
1159+
match self.native_type {
1160+
NativeType::Vector
1161+
if self.vector_format == VecFmt::Int8 || self.vector_format == VecFmt::Flexible =>
1162+
unsafe {
1163+
let vec_ref = self.get_vec_ref_unchecked()?;
1164+
match vec_ref {
1165+
VecRef::Int8(slice) => Ok(slice.to_vec()),
1166+
_ => Err(Error::invalid_type_conversion(
1167+
vec_ref.oracle_type().to_string(),
1168+
"Vec<i8>",
1169+
)),
1170+
}
1171+
},
1172+
_ => self.invalid_conversion_to_rust_type("Vec<i8>"),
1173+
}
1174+
}
1175+
10821176
/// Gets the SQL value as Timestamp. The Oracle type must be
10831177
/// `DATE`, `TIMESTAMP`, or `TIMESTAMP WITH TIME ZONE`.
10841178
pub(crate) fn to_timestamp(&self) -> Result<Timestamp> {
@@ -1401,6 +1495,7 @@ impl SqlValue<'_> {
14011495
keep_dpiobj: DpiObject::null(),
14021496
lob_bind_type: self.lob_bind_type,
14031497
query_params: self.query_params.clone(),
1498+
vector_format: self.vector_format,
14041499
})
14051500
}
14061501
}

0 commit comments

Comments
 (0)