Skip to content

Commit

Permalink
Add supports for slice of struct ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
haoxins committed Nov 4, 2024
1 parent ada350f commit 0985fcf
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 10 deletions.
22 changes: 14 additions & 8 deletions result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,16 @@ func (res ResultSet) Scan(v interface{}) error {
func (res ResultSet) scanRow(row *nebula.Row, colNames []string, rowType reflect.Type) (reflect.Value, error) {
rowVals := row.GetValues()

val := reflect.New(rowType).Elem()
var result reflect.Value
if rowType.Kind() == reflect.Ptr {
result = reflect.New(rowType.Elem())
} else {
result = reflect.New(rowType).Elem()
}
structVal := reflect.Indirect(result)

for fIdx := 0; fIdx < rowType.NumField(); fIdx++ {
f := rowType.Field(fIdx)
for fIdx := 0; fIdx < structVal.Type().NumField(); fIdx++ {
f := structVal.Type().Field(fIdx)
tag := f.Tag.Get("nebula")

if tag == "" {
Expand All @@ -358,19 +364,19 @@ func (res ResultSet) scanRow(row *nebula.Row, colNames []string, rowType reflect

if f.Type.Kind() == reflect.Slice {
list := rowVal.GetLVal()
err := scanListCol(list.Values, val.Field(fIdx), f.Type)
err := scanListCol(list.Values, structVal.Field(fIdx), f.Type)
if err != nil {
return val, err
return result, err
}
} else {
err := scanPrimitiveCol(rowVal, val.Field(fIdx), f.Type.Kind())
err := scanPrimitiveCol(rowVal, structVal.Field(fIdx), f.Type.Kind())
if err != nil {
return val, err
return result, err
}
}
}

return val, nil
return result, nil
}

func scanListCol(vals []*nebula.Value, listVal reflect.Value, sliceType reflect.Type) error {
Expand Down
97 changes: 95 additions & 2 deletions result_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,54 @@ func TestScan(t *testing.T) {
assert.Equal(t, true, testStructList[1].Col3)
}

func TestScanPtr(t *testing.T) {
resp := &graph.ExecutionResponse{
ErrorCode: nebula.ErrorCode_SUCCEEDED,
LatencyInUs: 1000,
Data: getDateset2(),
SpaceName: []byte("test_space"),
ErrorMsg: []byte("test"),
PlanDesc: graph.NewPlanDescription(),
Comment: []byte("test_comment")}
resultSet, err := genResultSet(resp, testTimezone)
if err != nil {
t.Error(err)
}

type testStruct struct {
Col0 int64 `nebula:"col0_int64"`
Col1 float64 `nebula:"col1_float64"`
Col2 string `nebula:"col2_string"`
Col3 bool `nebula:"col3_bool"`
}

var testStructList []*testStruct
err = resultSet.Scan(&testStructList)
if err != nil {
t.Error(err)
}
assert.Equal(t, 1, len(testStructList))
assert.Equal(t, int64(1), testStructList[0].Col0)
assert.Equal(t, float64(2.0), testStructList[0].Col1)
assert.Equal(t, "string", testStructList[0].Col2)
assert.Equal(t, true, testStructList[0].Col3)

// Scan again should work
err = resultSet.Scan(&testStructList)
if err != nil {
t.Error(err)
}
assert.Equal(t, 2, len(testStructList))
assert.Equal(t, int64(1), testStructList[0].Col0)
assert.Equal(t, float64(2.0), testStructList[0].Col1)
assert.Equal(t, "string", testStructList[0].Col2)
assert.Equal(t, true, testStructList[0].Col3)
assert.Equal(t, int64(1), testStructList[1].Col0)
assert.Equal(t, float64(2.0), testStructList[1].Col1)
assert.Equal(t, "string", testStructList[1].Col2)
assert.Equal(t, true, testStructList[1].Col3)
}

func TestScanWithNestStruct(t *testing.T) {
resp := &graph.ExecutionResponse{
ErrorCode: nebula.ErrorCode_SUCCEEDED,
Expand Down Expand Up @@ -916,8 +964,6 @@ func TestScanWithNestStructPtr(t *testing.T) {
Edges []*Friend `nebula:"relationships"`
}

// TODO: actually, the `results` should be []*Result,
// we still need to support this case
var results []Result
err = resultSet.Scan(&results)
if err != nil {
Expand All @@ -939,6 +985,53 @@ func TestScanWithNestStructPtr(t *testing.T) {
assert.Equal(t, 2, len(results))
}

func TestScanWithStructPtr(t *testing.T) {
resp := &graph.ExecutionResponse{
ErrorCode: nebula.ErrorCode_SUCCEEDED,
LatencyInUs: 1000,
Data: getNestDateset(),
SpaceName: []byte("test_space"),
ErrorMsg: []byte("test"),
PlanDesc: graph.NewPlanDescription(),
Comment: []byte("test_comment")}
resultSet, err := genResultSet(resp, testTimezone)
if err != nil {
t.Error(err)
}

type Person struct {
Name string `nebula:"name"`
City string `nebula:"city"`
}
type Friend struct {
CreatedAt string `nebula:"created_at"`
}
type Result struct {
Nodes []*Person `nebula:"nodes"`
Edges []*Friend `nebula:"relationships"`
}

var results []*Result
err = resultSet.Scan(&results)
if err != nil {
t.Error(err)
}
assert.Equal(t, 1, len(results))
assert.Equal(t, "Tom", results[0].Nodes[0].Name)
assert.Equal(t, "Shanghai", results[0].Nodes[0].City)
assert.Equal(t, "Bob", results[0].Nodes[1].Name)
assert.Equal(t, "Hangzhou", results[0].Nodes[1].City)
assert.Equal(t, "2024-07-07", results[0].Edges[0].CreatedAt)
assert.Equal(t, "2024-07-07", results[0].Edges[1].CreatedAt)

// Scan again should work
err = resultSet.Scan(&results)
if err != nil {
t.Error(err)
}
assert.Equal(t, 2, len(results))
}

func TestIntVid(t *testing.T) {
vertex := getVertexInt(101, 3, 5)
node, err := genNode(vertex, testTimezone)
Expand Down

0 comments on commit 0985fcf

Please sign in to comment.