Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
tzhao-viant committed Aug 29, 2024
2 parents 54eefc2 + 311e08a commit 409ce60
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 18 deletions.
56 changes: 42 additions & 14 deletions internal/inference/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (s State) Parameters() state.Parameters {
}

func (s State) Compact(modulePath string) (State, error) {
if err := s.EnsureReflectTypes(modulePath); err != nil {
if err := s.EnsureReflectTypes(modulePath, ""); err != nil {
return nil, err
}
var result = State{}
Expand Down Expand Up @@ -406,8 +406,7 @@ func (s State) ReflectType(pkgPath string, lookupType xreflect.LookupType) (refl
return baseType, nil
}

func (s State) EnsureReflectTypes(modulePath string) error {
typeRegistry := xreflect.NewTypes(xreflect.WithPackagePath(modulePath), xreflect.WithRegistry(extension.Config.Types))
func (s State) EnsureStructQLTypes() error {
for _, param := range s {
if param.Schema == nil {
continue
Expand All @@ -416,17 +415,43 @@ func (s State) EnsureReflectTypes(modulePath string) error {
continue
}
if param.In.Kind == state.KindParam {
sourceParam := s.Lookup(param.In.Name)
if sourceParam == nil {
return fmt.Errorf("failed to lookup param parameter: %v", param.In.Name)
if err := s.enureStructQLType(param); err != nil {
return err
}
if param.SQL != "" {
query, err := structql.NewQuery(param.SQL, sourceParam.Schema.Type(), nil)
if err != nil {
return fmt.Errorf("failed to queryql param %v from %s(%s) due to: %w", param.Name, param.In.Name, sourceParam.Schema.Type().String(), err)
}
param.Schema = state.NewSchema(query.StructType())
param.Schema.DataType = param.Name
continue
}
}
return nil
}

func (s State) enureStructQLType(param *Parameter) error {
sourceParam := s.Lookup(param.In.Name)
if sourceParam == nil {
return fmt.Errorf("failed to lookup param parameter: %v", param.In.Name)
}
if param.SQL != "" {
query, err := structql.NewQuery(param.SQL, sourceParam.Schema.Type(), nil)
if err != nil {
return fmt.Errorf("failed to queryql %v param %v from %s(%s) due to: %w", param.SQL, param.Name, param.In.Name, sourceParam.Schema.Type().String(), err)
}
param.Schema = state.NewSchema(query.StructType())
param.Schema.DataType = param.Name
}
return nil
}

func (s State) EnsureReflectTypes(modulePath string, pkg string) error {
typeRegistry := xreflect.NewTypes(xreflect.WithPackagePath(modulePath), xreflect.WithRegistry(extension.Config.Types))
for _, param := range s {
if param.Schema == nil {
continue
}
if param.Schema.Type() != nil {
continue
}
if param.In.Kind == state.KindParam {
if err := s.enureStructQLType(param); err != nil {
return err
}
continue
}
Expand All @@ -441,7 +466,10 @@ func (s State) EnsureReflectTypes(modulePath string) error {
}
rType, err := types.LookupType(typeRegistry.Lookup, dataType, xreflect.WithPackage(param.Schema.Package))
if err != nil {
return err
rType, err = types.LookupType(typeRegistry.Lookup, dataType, xreflect.WithPackage(pkg))
if err != nil {
return err
}
}
param.Schema.SetType(rType)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/translator/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *Service) updateOutputParameters(resource *Resource, rootViewlet *Viewle
}
}

if err = resource.OutputState.EnsureReflectTypes(resource.rule.ModuleLocation); err != nil {
if err = resource.OutputState.EnsureReflectTypes(resource.rule.ModuleLocation, resource.rule.Package()); err != nil {
return err
}

Expand Down
5 changes: 4 additions & 1 deletion internal/translator/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,10 @@ func (r *Resource) extractRuleSetting(dSQL *string) error {
func (r *Resource) expandSQL(viewlet *Viewlet) (*sqlx.SQL, error) {
types := viewlet.Resource.Resource.TypeRegistry()
resourceState := viewlet.Resource.State
_ = resourceState.EnsureReflectTypes(r.rule.GoModuleLocation())
err := resourceState.EnsureStructQLTypes()
if err != nil {
return nil, err
}
sqlState := viewlet.Resource.State.StateForSQL(viewlet.SQL, r.Rule.Root == viewlet.Name)
metaViewSQL := sqlState.MetaViewSQL()
compacted, err := sqlState.Compact(r.rule.ModuleLocation)
Expand Down
9 changes: 9 additions & 0 deletions service/executor/sequencer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sequencer
import (
"context"
"database/sql"
"fmt"
"github.com/viant/sqlx/io/insert"
"github.com/viant/sqlx/metadata/info/dialect"
"strings"
Expand All @@ -14,6 +15,14 @@ type Service struct {
}

func (s *Service) Next(table string, any interface{}, selector string) error {
err := s.next(table, any, selector)
if err != nil {
return fmt.Errorf("failed to allocate %v sequence due to: %w", table, err)
}
return nil
}

func (s *Service) next(table string, any interface{}, selector string) error {
parts := strings.Split(selector, "/")
aWalker, err := NewWalker(any, parts)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions service/reader/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func (s *Service) readAll(ctx context.Context, session *Session, collector *view
relationGroup.Wait()
ptr, xslice := collector.Slice()
for i := 0; i < xslice.Len(ptr); i++ {

if actual, ok := xslice.ValuePointerAt(ptr, i).(OnRelationer); ok {
actual.OnRelation(ctx)
continue
Expand Down
3 changes: 1 addition & 2 deletions view/extension/codec/structql.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ func (s *StructQLCodec) evaluateQuery() (*structql.Query, error) {
}
query, err := structql.NewQuery(s.query, s.ownerType, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to evaludate structql codec: %w", err)
}

s._query = query
return query, nil
}
Expand Down
1 change: 1 addition & 0 deletions view/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,7 @@ func (v *View) SetParameter(name string, selectors *State, value interface{}) er
}

func (v *View) BuildParametrizedSQL(aState state.Parameters, types *xreflect.Types, SQL string, bindingArgs []interface{}, options ...expand2.StateOption) (*sqlx.SQL, error) {

reflectType, err := aState.ReflectType(pkgPath, types.Lookup, state.WithSetMarker())
if err != nil {
return nil, fmt.Errorf("failed to create aState %v type: %w", v.Name, err)
Expand Down

0 comments on commit 409ce60

Please sign in to comment.