diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 02047f1..2201709 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -24,6 +24,12 @@ jobs: - name: Build run: go build -v ./... + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.55.2 + args: --verbose --timeout 10m --fix=false --config=.golangci.yml ./pkg/... + - name: Generate coverage report for coveralls.io run: | go test -coverprofile=/var/tmp/capillaries.p.tmp -cover $(find ./ -name '*_test.go' -printf "%h\n" | sort -u) @@ -75,7 +81,7 @@ jobs: - name: pkg/sc test coverage threshold check env: - TESTCOVERAGE_THRESHOLD: 90.8 + TESTCOVERAGE_THRESHOLD: 91.5 run: | go test -v ./pkg/sc/... -coverprofile coverage.out -covermode count totalCoverage=`go tool cover -func=coverage.out | grep total | grep -Eo '[0-9]+\.[0-9]+'` @@ -88,7 +94,7 @@ jobs: - name: pkg/custom/py_calc test coverage threshold check env: - TESTCOVERAGE_THRESHOLD: 81.3 + TESTCOVERAGE_THRESHOLD: 83.3 run: | go test -v ./pkg/custom/py_calc/... -coverprofile coverage.out -covermode count totalCoverage=`go tool cover -func=coverage.out | grep total | grep -Eo '[0-9]+\.[0-9]+'` diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4a5c509 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,119 @@ +# https://golangci-lint.run/usage/configuration/#config-file +linters: + disable-all: true + enable: +# - goerr113 + - errcheck +# - goimports + # - paralleltest # missing the call to method parallel, but testify does not seem to work well with parallel test: https://github.com/stretchr/testify/issues/187 + - revive # revive supersedes golint, which is now archived + - staticcheck + - vet +# - forbidigo +run: +# skip-dirs: +# - ^api +# - ^proto +# - ^.git +linters-settings: +# govet: +# fieldalignment: 0 +# forbidigo: +# forbid: +# - p: ^time\.After$ +# msg: "time.After may leak resources. Use time.NewTimer instead." 
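+# revive runs with enable-all-rules: true, so the list below explicitly disables the rules we opt out of and tunes thresholds (complexity, cyclomatic, result limit, unhandled-error patterns) for the ones kept.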
+ revive: + severity: error + confidence: 0.8 + enable-all-rules: true + rules: + # Disabled rules + - name: confusing-results + disabled: true + - name: add-constant + disabled: true + - name: argument-limit + disabled: true +# - name: bare-return +# disabled: true +# - name: banned-characters +# disabled: true +# - name: bool-literal-in-expr +# disabled: true +# - name: confusing-naming +# disabled: true + - name: empty-lines + disabled: true +# - name: error-naming +# disabled: true +# - name: errorf +# disabled: true + - name: exported + disabled: true +# - name: file-header +# disabled: true + - name: function-length + disabled: true +# - name: imports-blacklist +# disabled: true +# - name: increment-decrement +# disabled: true + - name: line-length-limit + disabled: true + - name: max-public-structs + disabled: true +# - name: nested-structs +# disabled: true +# - name: package-comments +# disabled: true +# - name: string-format +# disabled: true +# - name: unexported-naming +# disabled: true +# - name: unexported-return +# disabled: true +# - name: unused-parameter +# disabled: true + - name: unused-receiver + disabled: true +# - name: use-any +# disabled: true + - name: var-naming + disabled: true +# - name: empty-block +# disabled: true + - name: flag-parameter + disabled: true + + # Rule tuning + - name: cognitive-complexity + arguments: + - 400 # TODO: do something + - name: cyclomatic + arguments: + - 100 + - name: function-result-limit + arguments: + - 4 + - name: unhandled-error + arguments: + - "fmt.*" + - "bytes.Buffer.*" + - "strings.Builder.*" + - "os.File.Close" + - "io.Closer.Close" + - "zap.Logger.Sync*" +# issues: +# # Exclude cyclomatic and cognitive complexity rules for functional tests in the `tests` root directory. +# exclude-rules: +# - path: ^tests\/.+\.go +# text: "(cyclomatic|cognitive)" +# linters: +# - revive +# - path: _test\.go|^common/persistence\/tests\/.+\.go # Ignore things like err = errors.New("test error") in tests +# linters: +# - goerr113 +# - path: ^tools\/.+\.go +# linters: +# - goerr113 +# - revive diff --git a/.vscode/settings.json b/.vscode/settings.json index ca6fe06..4a5133c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,4 @@ { - "git.ignoreLimitWarning": true + "git.ignoreLimitWarning": true, + "editor.insertSpaces": false } \ No newline at end of file diff --git a/README.md b/README.md index 3113dc6..61fe0d5 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# logo Capillaries
-[![Coverage Status](https://coveralls.io/repos/github/capillariesio/capillaries/badge.svg?branch=main)](https://coveralls.io/github/capillariesio/capillaries?branch=main)
+# logo Capillaries
[![coveralls](https://coveralls.io/repos/github/capillariesio/capillaries/badge.svg?branch=main)](https://coveralls.io/github/capillariesio/capillaries?branch=main) [![goreport](https://goreportcard.com/badge/github.com/capillariesio/capillaries)](https://goreportcard.com/report/github.com/capillariesio/capillaries)
Capillaries is a data processing framework that: diff --git a/pkg/api/db.go b/pkg/api/db.go index 47c1575..e84b660 100644 --- a/pkg/api/db.go +++ b/pkg/api/db.go @@ -1,117 +1,114 @@ -package api - -import ( - "fmt" - "reflect" - "regexp" - "strings" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfdb" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -const ProhibitedKeyspaceNameRegex = "^system" -const AllowedKeyspaceNameRegex = "[a-zA-Z0-9_]+" - -func IsSystemKeyspaceName(keyspace string) bool { - re := regexp.MustCompile(ProhibitedKeyspaceNameRegex) - invalidNamePieceFound := re.FindString(keyspace) - if len(invalidNamePieceFound) > 0 { - return true - } - return false -} - -func checkKeyspaceName(keyspace string) error { - re := regexp.MustCompile(ProhibitedKeyspaceNameRegex) - invalidNamePieceFound := re.FindString(keyspace) - if len(invalidNamePieceFound) > 0 { - return fmt.Errorf("invalid keyspace name [%s]: prohibited regex is [%s]", keyspace, ProhibitedKeyspaceNameRegex) - } - re = regexp.MustCompile(AllowedKeyspaceNameRegex) - if !re.MatchString(keyspace) { - return fmt.Errorf("invalid keyspace name [%s]: allowed regex is [%s]", keyspace, AllowedKeyspaceNameRegex) - } - return nil -} - -// A helper used by Toolbelt get_table_cql cmd, no logging needed -func GetTablesCql(script *sc.ScriptDef, keyspace string, runId int16, startNodeNames []string) string { - sb := strings.Builder{} - sb.WriteString("-- Workflow\n") - sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.BatchHistoryEvent{}), keyspace, wfmodel.TableNameBatchHistory))) - sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.NodeHistoryEvent{}), keyspace, wfmodel.TableNameNodeHistory))) - sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunHistoryEvent{}), keyspace, wfmodel.TableNameRunHistory))) - sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunProperties{}), keyspace, wfmodel.TableNameRunAffectedNodes))) - sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunCounter{}), keyspace, wfmodel.TableNameRunCounter))) - qb := cql.QueryBuilder{} - sb.WriteString(fmt.Sprintf("%s\n", qb.Keyspace(keyspace).Write("ks", keyspace).Write("last_run", 0).InsertUnpreparedQuery(wfmodel.TableNameRunCounter, cql.IgnoreIfExists))) - - for _, nodeName := range script.GetAffectedNodes(startNodeNames) { - node, ok := script.ScriptNodes[nodeName] - if !ok || !node.HasTableCreator() { - continue - } - sb.WriteString(fmt.Sprintf("-- %s\n", nodeName)) - sb.WriteString(fmt.Sprintf("%s\n", proc.CreateDataTableCql(keyspace, runId, &node.TableCreator))) - for idxName, idxDef := range node.TableCreator.Indexes { - sb.WriteString(fmt.Sprintf("%s\n", proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef))) - } - } - return sb.String() -} - -// Used by Toolbelt and Webapi -func DropKeyspace(logger *l.Logger, cqlSession *gocql.Session, keyspace string) error { - logger.PushF("api.DropKeyspace") - defer logger.PopF() - - if err := checkKeyspaceName(keyspace); err != nil { - return err - } - - qb := cql.QueryBuilder{} - q := qb. - Keyspace(keyspace). 
- DropKeyspace() - if err := cqlSession.Query(q).Exec(); err != nil { - return db.WrapDbErrorWithQuery("cannot drop keyspace", q, err) - } - return nil -} - -// wfdb wrapper for webapi use -func HarvestRunLifespans(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runIds []int16) (wfmodel.RunLifespanMap, error) { - logger.PushF("api.HarvestRunLifespans") - defer logger.PopF() - - return wfdb.HarvestRunLifespans(logger, cqlSession, keyspace, runIds) -} - -// wfdb wrapper for webapi use -func GetRunProperties(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.RunProperties, error) { - logger.PushF("api.GetRunProperties") - defer logger.PopF() - return wfdb.GetRunProperties(logger, cqlSession, keyspace, runId) -} - -// wfdb wrapper for webapi use -func GetNodeHistoryForRun(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.NodeHistoryEvent, error) { - logger.PushF("api.GetNodeHistoryForRun") - defer logger.PopF() - - return wfdb.GetNodeHistoryForRun(logger, cqlSession, keyspace, runId) -} - -// wfdb wrapper for webapi use -func GetRunNodeBatchHistory(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) ([]*wfmodel.BatchHistoryEvent, error) { - logger.PushF("api.GetRunNodeBatchHistory") - defer logger.PopF() - return wfdb.GetRunNodeBatchHistory(logger, cqlSession, keyspace, runId, nodeName) -} +package api + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfdb" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +const ProhibitedKeyspaceNameRegex = "^system" +const AllowedKeyspaceNameRegex = "[a-zA-Z0-9_]+" + +func IsSystemKeyspaceName(keyspace string) bool { + re := regexp.MustCompile(ProhibitedKeyspaceNameRegex) + invalidNamePieceFound := re.FindString(keyspace) + return len(invalidNamePieceFound) > 0 +} + +func checkKeyspaceName(keyspace string) error { + re := regexp.MustCompile(ProhibitedKeyspaceNameRegex) + invalidNamePieceFound := re.FindString(keyspace) + if len(invalidNamePieceFound) > 0 { + return fmt.Errorf("invalid keyspace name [%s]: prohibited regex is [%s]", keyspace, ProhibitedKeyspaceNameRegex) + } + re = regexp.MustCompile(AllowedKeyspaceNameRegex) + if !re.MatchString(keyspace) { + return fmt.Errorf("invalid keyspace name [%s]: allowed regex is [%s]", keyspace, AllowedKeyspaceNameRegex) + } + return nil +} + +// A helper used by Toolbelt get_table_cql cmd, no logging needed +func GetTablesCql(script *sc.ScriptDef, keyspace string, runId int16, startNodeNames []string) string { + sb := strings.Builder{} + sb.WriteString("-- Workflow\n") + sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.BatchHistoryEvent{}), keyspace, wfmodel.TableNameBatchHistory))) + sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.NodeHistoryEvent{}), keyspace, wfmodel.TableNameNodeHistory))) + sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunHistoryEvent{}), keyspace, wfmodel.TableNameRunHistory))) + sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunProperties{}), keyspace, wfmodel.TableNameRunAffectedNodes))) + 
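// Below: create and seed the run counter table (last_run = 0), then emit data and index table DDL for every affected node that has a table creator.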
sb.WriteString(fmt.Sprintf("%s\n", wfmodel.GetCreateTableCql(reflect.TypeOf(wfmodel.RunCounter{}), keyspace, wfmodel.TableNameRunCounter))) + qb := cql.QueryBuilder{} + sb.WriteString(fmt.Sprintf("%s\n", qb.Keyspace(keyspace).Write("ks", keyspace).Write("last_run", 0).InsertUnpreparedQuery(wfmodel.TableNameRunCounter, cql.IgnoreIfExists))) + + for _, nodeName := range script.GetAffectedNodes(startNodeNames) { + node, ok := script.ScriptNodes[nodeName] + if !ok || !node.HasTableCreator() { + continue + } + sb.WriteString(fmt.Sprintf("-- %s\n", nodeName)) + sb.WriteString(fmt.Sprintf("%s\n", proc.CreateDataTableCql(keyspace, runId, &node.TableCreator))) + for idxName, idxDef := range node.TableCreator.Indexes { + sb.WriteString(fmt.Sprintf("%s\n", proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef))) + } + } + return sb.String() +} + +// Used by Toolbelt and Webapi +func DropKeyspace(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string) error { + logger.PushF("api.DropKeyspace") + defer logger.PopF() + + if err := checkKeyspaceName(keyspace); err != nil { + return err + } + + qb := cql.QueryBuilder{} + q := qb. + Keyspace(keyspace). + DropKeyspace() + if err := cqlSession.Query(q).Exec(); err != nil { + return db.WrapDbErrorWithQuery("cannot drop keyspace", q, err) + } + return nil +} + +// wfdb wrapper for webapi use +func HarvestRunLifespans(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runIds []int16) (wfmodel.RunLifespanMap, error) { + logger.PushF("api.HarvestRunLifespans") + defer logger.PopF() + + return wfdb.HarvestRunLifespans(logger, cqlSession, keyspace, runIds) +} + +// wfdb wrapper for webapi use +func GetRunProperties(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.RunProperties, error) { + logger.PushF("api.GetRunProperties") + defer logger.PopF() + return wfdb.GetRunProperties(logger, cqlSession, keyspace, runId) +} + +// wfdb wrapper for webapi use +func GetNodeHistoryForRun(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.NodeHistoryEvent, error) { + logger.PushF("api.GetNodeHistoryForRun") + defer logger.PopF() + + return wfdb.GetNodeHistoryForRun(logger, cqlSession, keyspace, runId) +} + +// wfdb wrapper for webapi use +func GetRunNodeBatchHistory(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) ([]*wfmodel.BatchHistoryEvent, error) { + logger.PushF("api.GetRunNodeBatchHistory") + defer logger.PopF() + return wfdb.GetRunNodeBatchHistory(logger, cqlSession, keyspace, runId, nodeName) +} diff --git a/pkg/api/reporting.go b/pkg/api/reporting.go index 607eb3d..7d75a3d 100644 --- a/pkg/api/reporting.go +++ b/pkg/api/reporting.go @@ -1,92 +1,92 @@ -package api - -import ( - "fmt" - "sort" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func GetRunHistory(logger *l.Logger, cqlSession *gocql.Session, keyspace string) ([]*wfmodel.RunHistoryEvent, error) { - logger.PushF("api.GetRunHistory") - defer logger.PopF() - - qb := cql.QueryBuilder{} - q := qb. - Keyspace(keyspace). 
- Select(wfmodel.TableNameRunHistory, wfmodel.RunHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, db.WrapDbErrorWithQuery("cannot get run history", q, err) - } - - result := make([]*wfmodel.RunHistoryEvent, len(rows)) - for rowIdx, r := range rows { - result[rowIdx], err = wfmodel.NewRunHistoryEventFromMap(r, wfmodel.RunHistoryEventAllFields()) - if err != nil { - return nil, fmt.Errorf("cannot deserialize run history row: %s, %s", err.Error(), q) - } - } - sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) - - return result, nil -} - -func GetRunsNodeHistory(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runIds []int16) ([]*wfmodel.NodeHistoryEvent, error) { - logger.PushF("api.GetNodeHistory") - defer logger.PopF() - - qb := cql.QueryBuilder{} - qb.Keyspace(keyspace) - if len(runIds) > 0 { - qb.CondInInt16("run_id", runIds) - } - q := qb.Select(wfmodel.TableNameNodeHistory, wfmodel.NodeHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, db.WrapDbErrorWithQuery("cannot get node history", q, err) - } - - result := make([]*wfmodel.NodeHistoryEvent, len(rows)) - for rowIdx, r := range rows { - result[rowIdx], err = wfmodel.NewNodeHistoryEventFromMap(r, wfmodel.NodeHistoryEventAllFields()) - if err != nil { - return nil, fmt.Errorf("cannot deserialize node history row: %s, %s", err.Error(), q) - } - } - sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) - return result, nil -} - -func GetBatchHistory(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runIds []int16, scriptNodes []string) ([]*wfmodel.BatchHistoryEvent, error) { - logger.PushF("api.GetBatchHistory") - defer logger.PopF() - - qb := cql.QueryBuilder{} - qb.Keyspace(keyspace) - if len(runIds) > 0 { - qb.CondInInt16("run_id", runIds) - } - if len(scriptNodes) > 0 { - qb.CondInString("script_node", scriptNodes) - } - q := qb.Select(wfmodel.TableNameBatchHistory, wfmodel.BatchHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, db.WrapDbErrorWithQuery("cannot get batch history", q, err) - } - - result := make([]*wfmodel.BatchHistoryEvent, len(rows)) - for rowIdx, r := range rows { - result[rowIdx], err = wfmodel.NewBatchHistoryEventFromMap(r, wfmodel.BatchHistoryEventAllFields()) - if err != nil { - return nil, fmt.Errorf("cannot deserialize batch history row: %s, %s", err.Error(), q) - } - } - sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) - return result, nil -} +package api + +import ( + "fmt" + "sort" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func GetRunHistory(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string) ([]*wfmodel.RunHistoryEvent, error) { + logger.PushF("api.GetRunHistory") + defer logger.PopF() + + qb := cql.QueryBuilder{} + q := qb. + Keyspace(keyspace). 
+ Select(wfmodel.TableNameRunHistory, wfmodel.RunHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, db.WrapDbErrorWithQuery("cannot get run history", q, err) + } + + result := make([]*wfmodel.RunHistoryEvent, len(rows)) + for rowIdx, r := range rows { + result[rowIdx], err = wfmodel.NewRunHistoryEventFromMap(r, wfmodel.RunHistoryEventAllFields()) + if err != nil { + return nil, fmt.Errorf("cannot deserialize run history row: %s, %s", err.Error(), q) + } + } + sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) + + return result, nil +} + +func GetRunsNodeHistory(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runIds []int16) ([]*wfmodel.NodeHistoryEvent, error) { + logger.PushF("api.GetNodeHistory") + defer logger.PopF() + + qb := cql.QueryBuilder{} + qb.Keyspace(keyspace) + if len(runIds) > 0 { + qb.CondInInt16("run_id", runIds) + } + q := qb.Select(wfmodel.TableNameNodeHistory, wfmodel.NodeHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, db.WrapDbErrorWithQuery("cannot get node history", q, err) + } + + result := make([]*wfmodel.NodeHistoryEvent, len(rows)) + for rowIdx, r := range rows { + result[rowIdx], err = wfmodel.NewNodeHistoryEventFromMap(r, wfmodel.NodeHistoryEventAllFields()) + if err != nil { + return nil, fmt.Errorf("cannot deserialize node history row: %s, %s", err.Error(), q) + } + } + sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) + return result, nil +} + +func GetBatchHistory(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runIds []int16, scriptNodes []string) ([]*wfmodel.BatchHistoryEvent, error) { + logger.PushF("api.GetBatchHistory") + defer logger.PopF() + + qb := cql.QueryBuilder{} + qb.Keyspace(keyspace) + if len(runIds) > 0 { + qb.CondInInt16("run_id", runIds) + } + if len(scriptNodes) > 0 { + qb.CondInString("script_node", scriptNodes) + } + q := qb.Select(wfmodel.TableNameBatchHistory, wfmodel.BatchHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, db.WrapDbErrorWithQuery("cannot get batch history", q, err) + } + + result := make([]*wfmodel.BatchHistoryEvent, len(rows)) + for rowIdx, r := range rows { + result[rowIdx], err = wfmodel.NewBatchHistoryEventFromMap(r, wfmodel.BatchHistoryEventAllFields()) + if err != nil { + return nil, fmt.Errorf("cannot deserialize batch history row: %s, %s", err.Error(), q) + } + } + sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) + return result, nil +} diff --git a/pkg/api/run.go b/pkg/api/run.go index d43530b..6a0ef57 100644 --- a/pkg/api/run.go +++ b/pkg/api/run.go @@ -1,249 +1,254 @@ -package api - -import ( - "fmt" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wf" - "github.com/capillariesio/capillaries/pkg/wfdb" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" - amqp "github.com/rabbitmq/amqp091-go" -) - -func StopRun(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16, comment string) error { - logger.PushF("api.StopRun") - defer logger.PopF() - - 
if err := checkKeyspaceName(keyspace); err != nil { - return err - } - - return wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStop, comment, cql.IgnoreIfExists) -} - -func StartRun(envConfig *env.EnvConfig, logger *l.Logger, amqpChannel *amqp.Channel, scriptFilePath string, paramsFilePath string, cqlSession *gocql.Session, keyspace string, startNodes []string, desc string) (int16, error) { - logger.PushF("api.StartRun") - defer logger.PopF() - - if err := checkKeyspaceName(keyspace); err != nil { - return 0, err - } - - script, err, _ := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, scriptFilePath, paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if err != nil { - return 0, err - } - - // Verify that all start nodes actually present - missingNodesSb := strings.Builder{} - for _, nodeName := range startNodes { - if _, ok := script.ScriptNodes[nodeName]; !ok { - if missingNodesSb.Len() > 0 { - missingNodesSb.WriteString(",") - } - missingNodesSb.WriteString(nodeName) - } - } - if missingNodesSb.Len() > 0 { - return 0, fmt.Errorf("node(s) %s missing from %s, check node name spelling", missingNodesSb.String(), scriptFilePath) - } - - // Get new run_id - runId, err := wfdb.GetNextRunCounter(logger, cqlSession, keyspace) - if err != nil { - return 0, err - } - logger.Info("incremented run_id to %d", runId) - - // Write affected nodes - affectedNodes := script.GetAffectedNodes(startNodes) - if err := wfdb.WriteRunProperties(logger, cqlSession, keyspace, runId, startNodes, affectedNodes, scriptFilePath, paramsFilePath, desc); err != nil { - return 0, err - } - - logger.Info("creating data and idx tables for run %d...", runId) - - // Create all run-specific tables, do not create them in daemon on the fly to avoid INCOMPATIBLE_SCHEMA error - // (apparently, thrown if we try to insert immediately after creating a table) - tablesCreated := 0 - for _, nodeName := range affectedNodes { - node, ok := script.ScriptNodes[nodeName] - if !ok || !node.HasTableCreator() { - continue - } - q := proc.CreateDataTableCql(keyspace, runId, &node.TableCreator) - if err := cqlSession.Query(q).Exec(); err != nil { - return 0, db.WrapDbErrorWithQuery("cannot create data table", q, err) - } - tablesCreated++ - for idxName, idxDef := range node.TableCreator.Indexes { - q = proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef) - if err := cqlSession.Query(q).Exec(); err != nil { - return 0, db.WrapDbErrorWithQuery("cannot create idx table", q, err) - } - tablesCreated++ - } - } - - logger.Info("created %d tables, creating messages to send for run %d...", tablesCreated, runId) - - allMsgs := make([]*wfmodel.Message, 0) - allHandlerExeTypes := make([]string, 0) - for _, affectedNodeName := range affectedNodes { - affectedNode, ok := script.ScriptNodes[affectedNodeName] - if !ok { - return 0, fmt.Errorf("cannot find node to start with: %s in the script %s", affectedNodeName, scriptFilePath) - } - intervals, err := affectedNode.GetTokenIntervalsByNumberOfBatches() - if err != nil { - return 0, err - } - msgs := make([]*wfmodel.Message, len(intervals)) - handlerExeTypes := make([]string, len(intervals)) - for msgIdx := 0; msgIdx < len(intervals); msgIdx++ { - msgs[msgIdx] = &wfmodel.Message{ - Ts: time.Now().UnixMilli(), - MessageType: wfmodel.MessageTypeDataBatch, - Payload: wfmodel.MessagePayloadDataBatch{ - ScriptURI: scriptFilePath, - ScriptParamsURI: paramsFilePath, - DataKeyspace: keyspace, - RunId: runId, - TargetNodeName: 
affectedNodeName, - FirstToken: intervals[msgIdx][0], - LastToken: intervals[msgIdx][1], - BatchIdx: int16(msgIdx), - BatchesTotal: int16(len(intervals))}} - handlerExeTypes[msgIdx] = affectedNode.HandlerExeType - } - allMsgs = append(allMsgs, msgs...) - allHandlerExeTypes = append(allHandlerExeTypes, handlerExeTypes...) - } - - // Write status 'start', fail if a record for run_id is already there (too many operators) - if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStart, "api.StartRun", cql.ThrowIfExists); err != nil { - return 0, err - } - - logger.Info("sending %d messages for run %d...", len(allMsgs), runId) - - // Send one msg after another - // TODO: there easily may be hundreds of messages, can we send them in a single shot? - for msgIdx := 0; msgIdx < len(allMsgs); msgIdx++ { - msgOutBytes, errMsgOut := allMsgs[msgIdx].Serialize() - if errMsgOut != nil { - return 0, fmt.Errorf("cannot serialize outgoing message %v. %v", allMsgs[msgIdx].ToString(), errMsgOut) - } - - errSend := amqpChannel.Publish( - envConfig.Amqp.Exchange, // exchange - allHandlerExeTypes[msgIdx], // routing key / hander exe type - false, // mandatory - false, // immediate - amqp.Publishing{ContentType: "text/plain", Body: msgOutBytes}) - if errSend != nil { - // Reconnect required - return 0, fmt.Errorf("failed to send next message: %v\n", errSend) - } - } - return runId, nil -} - -func RunNode(envConfig *env.EnvConfig, logger *l.Logger, nodeName string, runId int16, scriptFilePath string, paramsFilePath string, cqlSession *gocql.Session, keyspace string) (int16, error) { - logger.PushF("api.RunNode") - defer logger.PopF() - - script, err, _ := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, scriptFilePath, paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if err != nil { - return 0, err - } - // Get new run_id if needed - if runId == 0 { - runId, err = wfdb.GetNextRunCounter(logger, cqlSession, keyspace) - if err != nil { - return 0, err - } - logger.Info("incremented run_id to %d", runId) - } - - // Calculate intervals for this node - node, ok := script.ScriptNodes[nodeName] - if !ok { - return 0, fmt.Errorf("cannot find node to start with: %s in the script %s", nodeName, scriptFilePath) - } - - intervals, err := node.GetTokenIntervalsByNumberOfBatches() - if err != nil { - return 0, err - } - - // Write affected nodes - affectedNodes := script.GetAffectedNodes([]string{nodeName}) - if err := wfdb.WriteRunProperties(logger, cqlSession, keyspace, runId, []string{nodeName}, affectedNodes, scriptFilePath, paramsFilePath, "started by Toolbelt direct RunNode"); err != nil { - return 0, err - } - - // Write status 'start', fail if a record for run_id is already there (too many operators) - if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStart, fmt.Sprintf("Toolbelt RunNode(%s)", nodeName), cql.ThrowIfExists); err != nil { - return 0, err - } - - logger.Info("creating data and idx tables for run %d...", runId) - - // Create all run-specific tables, do not create them in daemon on the fly to avoid INCOMPATIBLE_SCHEMA error - // (apparently, thrown if we try to insert immediately after creating a table) - tablesCreated := 0 - for _, nodeName := range affectedNodes { - node, ok := script.ScriptNodes[nodeName] - if !ok || !node.HasTableCreator() { - continue - } - q := proc.CreateDataTableCql(keyspace, runId, &node.TableCreator) - if err := cqlSession.Query(q).Exec(); err != nil { - return 0, 
db.WrapDbErrorWithQuery("cannot create data table", q, err) - } - tablesCreated++ - for idxName, idxDef := range node.TableCreator.Indexes { - q = proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef) - if err := cqlSession.Query(q).Exec(); err != nil { - return 0, db.WrapDbErrorWithQuery("cannot create idx table", q, err) - } - tablesCreated++ - } - } - - logger.Info("created %d tables, creating messages to send for run %d...", tablesCreated, runId) - - for i := 0; i < len(intervals); i++ { - batchStartTs := time.Now() - logger.Info("BatchStarted: [%d,%d]...", intervals[i][0], intervals[i][1]) - dataBatchInfo := wfmodel.MessagePayloadDataBatch{ - ScriptURI: scriptFilePath, - ScriptParamsURI: paramsFilePath, - DataKeyspace: keyspace, - RunId: runId, - TargetNodeName: nodeName, - FirstToken: intervals[i][0], - LastToken: intervals[i][1], - BatchIdx: int16(i), - BatchesTotal: int16(len(intervals))} - - if daemonCmd := wf.ProcessDataBatchMsg(envConfig, logger, batchStartTs.UnixMilli(), &dataBatchInfo); daemonCmd != wf.DaemonCmdAckSuccess { - return 0, fmt.Errorf("processor returned daemon cmd %d, assuming failure, check the logs", daemonCmd) - } - logger.Info("BatchComplete: [%d,%d], %.3fs", intervals[i][0], intervals[i][1], time.Since(batchStartTs).Seconds()) - } - if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunComplete, fmt.Sprintf("Toolbelt RunNode(%s), run successful", nodeName), cql.IgnoreIfExists); err != nil { - return 0, err - } - - return runId, nil -} +package api + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wf" + "github.com/capillariesio/capillaries/pkg/wfdb" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" + amqp "github.com/rabbitmq/amqp091-go" +) + +func StopRun(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16, comment string) error { + logger.PushF("api.StopRun") + defer logger.PopF() + + if err := checkKeyspaceName(keyspace); err != nil { + return err + } + + return wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStop, comment, cql.IgnoreIfExists) +} + +func StartRun(envConfig *env.EnvConfig, logger *l.CapiLogger, amqpChannel *amqp.Channel, scriptFilePath string, paramsFilePath string, cqlSession *gocql.Session, keyspace string, startNodes []string, desc string) (int16, error) { + logger.PushF("api.StartRun") + defer logger.PopF() + + if err := checkKeyspaceName(keyspace); err != nil { + return 0, err + } + + script, _, err := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, scriptFilePath, paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if err != nil { + return 0, err + } + + // Verify that all start nodes actually present + missingNodesSb := strings.Builder{} + for _, nodeName := range startNodes { + if _, ok := script.ScriptNodes[nodeName]; !ok { + if missingNodesSb.Len() > 0 { + missingNodesSb.WriteString(",") + } + missingNodesSb.WriteString(nodeName) + } + } + if missingNodesSb.Len() > 0 { + return 0, fmt.Errorf("node(s) %s missing from %s, check node name spelling", missingNodesSb.String(), scriptFilePath) + } + + // Get new run_id + runId, err := 
wfdb.GetNextRunCounter(logger, cqlSession, keyspace) + if err != nil { + return 0, err + } + logger.Info("incremented run_id to %d", runId) + + // Write affected nodes + affectedNodes := script.GetAffectedNodes(startNodes) + if err := wfdb.WriteRunProperties(cqlSession, keyspace, runId, startNodes, affectedNodes, scriptFilePath, paramsFilePath, desc); err != nil { + return 0, err + } + + logger.Info("creating data and idx tables for run %d...", runId) + + // Create all run-specific tables, do not create them in daemon on the fly to avoid INCOMPATIBLE_SCHEMA error + // (apparently, thrown if we try to insert immediately after creating a table) + tablesCreated := 0 + for _, nodeName := range affectedNodes { + node, ok := script.ScriptNodes[nodeName] + if !ok || !node.HasTableCreator() { + continue + } + q := proc.CreateDataTableCql(keyspace, runId, &node.TableCreator) + if err := cqlSession.Query(q).Exec(); err != nil { + return 0, db.WrapDbErrorWithQuery("cannot create data table", q, err) + } + tablesCreated++ + for idxName, idxDef := range node.TableCreator.Indexes { + q = proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef) + if err := cqlSession.Query(q).Exec(); err != nil { + return 0, db.WrapDbErrorWithQuery("cannot create idx table", q, err) + } + tablesCreated++ + } + } + + logger.Info("created %d tables, creating messages to send for run %d...", tablesCreated, runId) + + allMsgs := make([]*wfmodel.Message, 0) + allHandlerExeTypes := make([]string, 0) + for _, affectedNodeName := range affectedNodes { + affectedNode, ok := script.ScriptNodes[affectedNodeName] + if !ok { + return 0, fmt.Errorf("cannot find node to start with: %s in the script %s", affectedNodeName, scriptFilePath) + } + intervals, err := affectedNode.GetTokenIntervalsByNumberOfBatches() + if err != nil { + return 0, err + } + msgs := make([]*wfmodel.Message, len(intervals)) + handlerExeTypes := make([]string, len(intervals)) + for msgIdx := 0; msgIdx < len(intervals); msgIdx++ { + msgs[msgIdx] = &wfmodel.Message{ + Ts: time.Now().UnixMilli(), + MessageType: wfmodel.MessageTypeDataBatch, + Payload: wfmodel.MessagePayloadDataBatch{ + ScriptURI: scriptFilePath, + ScriptParamsURI: paramsFilePath, + DataKeyspace: keyspace, + RunId: runId, + TargetNodeName: affectedNodeName, + FirstToken: intervals[msgIdx][0], + LastToken: intervals[msgIdx][1], + BatchIdx: int16(msgIdx), + BatchesTotal: int16(len(intervals))}} + handlerExeTypes[msgIdx] = affectedNode.HandlerExeType + } + allMsgs = append(allMsgs, msgs...) + allHandlerExeTypes = append(allHandlerExeTypes, handlerExeTypes...) + } + + // Write status 'start', fail if a record for run_id is already there (too many operators) + if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStart, "api.StartRun", cql.ThrowIfExists); err != nil { + return 0, err + } + + logger.Info("sending %d messages for run %d...", len(allMsgs), runId) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Send one msg after another + // TODO: there easily may be hundreds of messages, can we send them in a single shot? + for msgIdx := 0; msgIdx < len(allMsgs); msgIdx++ { + msgOutBytes, errMsgOut := allMsgs[msgIdx].Serialize() + if errMsgOut != nil { + return 0, fmt.Errorf("cannot serialize outgoing message %v. 
%v", allMsgs[msgIdx].ToString(), errMsgOut) + } + + errSend := amqpChannel.PublishWithContext( + ctx, + envConfig.Amqp.Exchange, // exchange + allHandlerExeTypes[msgIdx], // routing key / hander exe type + false, // mandatory + false, // immediate + amqp.Publishing{ContentType: "text/plain", Body: msgOutBytes}) + if errSend != nil { + // Reconnect required + return 0, fmt.Errorf("failed to send next message: %s", errSend.Error()) + } + } + return runId, nil +} + +func RunNode(envConfig *env.EnvConfig, logger *l.CapiLogger, nodeName string, runId int16, scriptFilePath string, paramsFilePath string, cqlSession *gocql.Session, keyspace string) (int16, error) { + logger.PushF("api.RunNode") + defer logger.PopF() + + script, _, err := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, scriptFilePath, paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if err != nil { + return 0, err + } + // Get new run_id if needed + if runId == 0 { + runId, err = wfdb.GetNextRunCounter(logger, cqlSession, keyspace) + if err != nil { + return 0, err + } + logger.Info("incremented run_id to %d", runId) + } + + // Calculate intervals for this node + node, ok := script.ScriptNodes[nodeName] + if !ok { + return 0, fmt.Errorf("cannot find node to start with: %s in the script %s", nodeName, scriptFilePath) + } + + intervals, err := node.GetTokenIntervalsByNumberOfBatches() + if err != nil { + return 0, err + } + + // Write affected nodes + affectedNodes := script.GetAffectedNodes([]string{nodeName}) + if err := wfdb.WriteRunProperties(cqlSession, keyspace, runId, []string{nodeName}, affectedNodes, scriptFilePath, paramsFilePath, "started by Toolbelt direct RunNode"); err != nil { + return 0, err + } + + // Write status 'start', fail if a record for run_id is already there (too many operators) + if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunStart, fmt.Sprintf("Toolbelt RunNode(%s)", nodeName), cql.ThrowIfExists); err != nil { + return 0, err + } + + logger.Info("creating data and idx tables for run %d...", runId) + + // Create all run-specific tables, do not create them in daemon on the fly to avoid INCOMPATIBLE_SCHEMA error + // (apparently, thrown if we try to insert immediately after creating a table) + tablesCreated := 0 + for _, nodeName := range affectedNodes { + node, ok := script.ScriptNodes[nodeName] + if !ok || !node.HasTableCreator() { + continue + } + q := proc.CreateDataTableCql(keyspace, runId, &node.TableCreator) + if err := cqlSession.Query(q).Exec(); err != nil { + return 0, db.WrapDbErrorWithQuery("cannot create data table", q, err) + } + tablesCreated++ + for idxName, idxDef := range node.TableCreator.Indexes { + q = proc.CreateIdxTableCql(keyspace, runId, idxName, idxDef) + if err := cqlSession.Query(q).Exec(); err != nil { + return 0, db.WrapDbErrorWithQuery("cannot create idx table", q, err) + } + tablesCreated++ + } + } + + logger.Info("created %d tables, creating messages to send for run %d...", tablesCreated, runId) + + for i := 0; i < len(intervals); i++ { + batchStartTs := time.Now() + logger.Info("BatchStarted: [%d,%d]...", intervals[i][0], intervals[i][1]) + dataBatchInfo := wfmodel.MessagePayloadDataBatch{ + ScriptURI: scriptFilePath, + ScriptParamsURI: paramsFilePath, + DataKeyspace: keyspace, + RunId: runId, + TargetNodeName: nodeName, + FirstToken: intervals[i][0], + LastToken: intervals[i][1], + BatchIdx: int16(i), + BatchesTotal: int16(len(intervals))} + + if daemonCmd := 
wf.ProcessDataBatchMsg(envConfig, logger, batchStartTs.UnixMilli(), &dataBatchInfo); daemonCmd != wf.DaemonCmdAckSuccess { + return 0, fmt.Errorf("processor returned daemon cmd %d, assuming failure, check the logs", daemonCmd) + } + logger.Info("BatchComplete: [%d,%d], %.3fs", intervals[i][0], intervals[i][1], time.Since(batchStartTs).Seconds()) + } + if err := wfdb.SetRunStatus(logger, cqlSession, keyspace, runId, wfmodel.RunComplete, fmt.Sprintf("Toolbelt RunNode(%s), run successful", nodeName), cql.IgnoreIfExists); err != nil { + return 0, err + } + + return runId, nil +} diff --git a/pkg/cql/cql_query_builder.go b/pkg/cql/cql_query_builder.go index 1d0fd57..0a6785b 100644 --- a/pkg/cql/cql_query_builder.go +++ b/pkg/cql/cql_query_builder.go @@ -1,481 +1,480 @@ -package cql - -import ( - "fmt" - "math" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/shopspring/decimal" - "gopkg.in/inf.v0" -) - -type IfNotExistsType int - -const ( - IgnoreIfExists IfNotExistsType = 1 - ThrowIfExists IfNotExistsType = 0 -) - -type QuotePolicyType int - -const ( - LeaveQuoteAsIs QuotePolicyType = iota - ForceUnquote -) - -/* -Data/idx table name for each run needs run id as a suffix -*/ -func RunIdSuffix(runId int16) string { - if runId > 0 { - return fmt.Sprintf("_%05d", runId) - } else { - return "" - } -} - -/* -Helper used in query builder -*/ -func valueToString(value interface{}, quotePolicy QuotePolicyType) string { - switch v := value.(type) { - case string: - if quotePolicy == ForceUnquote { - return strings.ReplaceAll(v, "'", "''") - } else { - return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) - } - case time.Time: - if quotePolicy == ForceUnquote { - return v.Format(sc.CassandraDatetimeFormat) - } else { - return v.Format(fmt.Sprintf("'%s'", sc.CassandraDatetimeFormat)) - } - default: - return fmt.Sprintf("%v", v) - } -} - -func valueToCqlParam(value interface{}) interface{} { - switch v := value.(type) { - case decimal.Decimal: - f, _ := v.Float64() - scaled := int64(math.Round(f * 100)) - return inf.NewDec(scaled, 2) - default: - return v - } -} - -type queryBuilderColumnDefs struct { - Columns [256]string - Types [256]string - Len int -} - -func (cd *queryBuilderColumnDefs) add(column string, fieldType sc.TableFieldType) { - cd.Columns[cd.Len] = column - switch fieldType { - case sc.FieldTypeInt: - cd.Types[cd.Len] = "BIGINT" - case sc.FieldTypeDecimal2: - cd.Types[cd.Len] = "DECIMAL" - case sc.FieldTypeFloat: - cd.Types[cd.Len] = "DOUBLE" - case sc.FieldTypeString: - cd.Types[cd.Len] = "TEXT" - case sc.FieldTypeBool: - cd.Types[cd.Len] = "BOOLEAN" - case sc.FieldTypeDateTime: - cd.Types[cd.Len] = "TIMESTAMP" // Cassandra stores milliseconds since epoch - default: - cd.Types[cd.Len] = fmt.Sprintf("UKNOWN_TYPE_%s", fieldType) - } - cd.Len++ -} - -type queryBuilderPreparedColumnData struct { - Columns [256]string - Values [256]interface{} - ColumnIdxMap map[string]int - ValueIdxMap map[string]int -} - -func (cd *queryBuilderPreparedColumnData) addColumnName(column string) error { - if _, ok := cd.ColumnIdxMap[column]; ok { - return fmt.Errorf("cannot add same column %s to a prepared query twice: %v", column, cd.Columns) - } - curColCount := len(cd.ColumnIdxMap) - cd.Columns[curColCount] = column - cd.ColumnIdxMap[column] = curColCount - return nil -} -func (cd *queryBuilderPreparedColumnData) addColumnValue(column string, value interface{}) error { - colIdx, ok := cd.ColumnIdxMap[column] - if !ok { - return fmt.Errorf("cannot set value for 
non-prepared column %s, available columns are %v", column, cd.Columns) - } - cd.Values[colIdx] = valueToCqlParam(value) - cd.ValueIdxMap[column] = colIdx - return nil -} - -type queryBuilderColumnData struct { - Columns [256]string - Values [256]string - Len int -} - -func (cd *queryBuilderColumnData) add(column string, value interface{}, quotePolicy QuotePolicyType) { - cd.Values[cd.Len] = valueToString(value, quotePolicy) - cd.Columns[cd.Len] = column - cd.Len++ -} - -type queryBuilderConditions struct { - Items [256]string - Len int -} - -func (cc *queryBuilderConditions) addInInt(column string, values []int64) { - inValues := make([]string, len(values)) - for i, v := range values { - inValues[i] = fmt.Sprintf("%d", v) - } - cc.Items[cc.Len] = fmt.Sprintf("%s IN ( %s )", column, strings.Join(inValues, ", ")) - cc.Len++ -} - -func (cc *queryBuilderConditions) addInInt16(column string, values []int16) { - inValues := make([]string, len(values)) - for i, v := range values { - inValues[i] = fmt.Sprintf("%d", v) - } - cc.Items[cc.Len] = fmt.Sprintf("%s IN ( %s )", column, strings.Join(inValues, ", ")) - cc.Len++ -} - -func (cc *queryBuilderConditions) addInString(column string, values []string) { - cc.Items[cc.Len] = fmt.Sprintf("%s IN ( '%s' )", column, strings.Join(values, "', '")) - cc.Len++ -} - -func (cc *queryBuilderConditions) addSimple(column string, op string, value interface{}) { - cc.Items[cc.Len] = fmt.Sprintf("%s %s %s", column, op, valueToString(value, LeaveQuoteAsIs)) - cc.Len++ -} -func (cc *queryBuilderConditions) addSimpleForceUnquote(column string, op string, value interface{}) { - cc.Items[cc.Len] = fmt.Sprintf("%s %s %s", column, op, valueToString(value, ForceUnquote)) - cc.Len++ -} - -/* -QueryBuilder - very simple cql query builder that does not require db connection -*/ -type QueryBuilder struct { - ColumnDefs queryBuilderColumnDefs - PartitionKeyColumns []string - ClusteringKeyColumns []string - ColumnData queryBuilderColumnData - PreparedColumnData queryBuilderPreparedColumnData - Conditions queryBuilderConditions - IfConditions queryBuilderConditions - SelectLimit int - FormattedKeyspace string - OrderByColumns []string -} - -func NewQB() *QueryBuilder { - var qb QueryBuilder - qb.PreparedColumnData.ColumnIdxMap = map[string]int{} - qb.PreparedColumnData.ValueIdxMap = map[string]int{} - return &qb -} - -func (qb *QueryBuilder) ColumnDef(column string, fieldType sc.TableFieldType) *QueryBuilder { - qb.ColumnDefs.add(column, fieldType) - return qb -} - -/* - */ -func (qb *QueryBuilder) PartitionKey(column ...string) *QueryBuilder { - qb.PartitionKeyColumns = column - return qb -} -func (qb *QueryBuilder) ClusteringKey(column ...string) *QueryBuilder { - qb.ClusteringKeyColumns = column - return qb -} - -/* -Keyspace - specify keyspace (optional) -*/ -func (qb *QueryBuilder) Keyspace(keyspace string) *QueryBuilder { - if trimmedKeyspace := strings.TrimSpace(keyspace); len(trimmedKeyspace) > 0 { - qb.FormattedKeyspace = fmt.Sprintf("%s.", trimmedKeyspace) - } else { - qb.FormattedKeyspace = "" - } - return qb -} - -func (qb *QueryBuilder) Limit(limit int) *QueryBuilder { - qb.SelectLimit = limit - return qb -} - -/* -Write - add a column for INSERT or UPDATE -*/ -func (qb *QueryBuilder) Write(column string, value interface{}) *QueryBuilder { - qb.ColumnData.add(column, value, LeaveQuoteAsIs) - return qb -} - -func (qb *QueryBuilder) WritePreparedColumn(column string) error { - return qb.PreparedColumnData.addColumnName(column) -} - -func (qb *QueryBuilder) 
WritePreparedValue(column string, value interface{}) error { - return qb.PreparedColumnData.addColumnValue(column, value) -} - -/* -WriteForceUnquote - add a column for INSERT or UPDATE -*/ -func (qb *QueryBuilder) WriteForceUnquote(column string, value interface{}) *QueryBuilder { - qb.ColumnData.add(column, value, ForceUnquote) - return qb -} - -/* -Cond - add condition for SELECT, UPDATE or DELETE -*/ -func (qb *QueryBuilder) Cond(column string, op string, value interface{}) *QueryBuilder { - qb.Conditions.addSimple(column, op, value) - return qb -} - -func (qb *QueryBuilder) CondPrepared(column string, op string) *QueryBuilder { - qb.Conditions.addSimpleForceUnquote(column, op, "?") - return qb -} - -func (qb *QueryBuilder) CondInPrepared(column string) *QueryBuilder { - qb.Conditions.addSimpleForceUnquote(column, "IN", "?") - return qb -} - -/* -CondIn - add IN condition for SELECT, UPDATE or DELETE -*/ -func (qb *QueryBuilder) CondInInt(column string, values []int64) *QueryBuilder { - qb.Conditions.addInInt(column, values) - return qb -} - -func (qb *QueryBuilder) CondInInt16(column string, values []int16) *QueryBuilder { - qb.Conditions.addInInt16(column, values) - return qb -} - -func (qb *QueryBuilder) CondInString(column string, values []string) *QueryBuilder { - qb.Conditions.addInString(column, values) - return qb -} - -func (qb *QueryBuilder) OrderBy(columns ...string) *QueryBuilder { - qb.OrderByColumns = columns - return qb -} - -func (qb *QueryBuilder) If(column string, op string, value interface{}) *QueryBuilder { - qb.IfConditions.addSimple(column, op, value) - return qb -} - -/* -Insert - build INSERT query -*/ -const RunIdForEmptyRun = -1 - -func (qb *QueryBuilder) InsertUnpreparedQuery(tableName string, ifNotExists IfNotExistsType) string { - return qb.insertRunUnpreparedQuery(tableName, RunIdForEmptyRun, ifNotExists) -} -func (qb *QueryBuilder) insertRunUnpreparedQuery(tableName string, runId int16, ifNotExists IfNotExistsType) string { - ifNotExistsStr := "" - if ifNotExists == IgnoreIfExists { - ifNotExistsStr = "IF NOT EXISTS" - } - q := fmt.Sprintf("INSERT INTO %s%s%s ( %s ) VALUES ( %s ) %s;", - qb.FormattedKeyspace, - tableName, - RunIdSuffix(runId), - strings.Join(qb.ColumnData.Columns[:qb.ColumnData.Len], ", "), - strings.Join(qb.ColumnData.Values[:qb.ColumnData.Len], ", "), - ifNotExistsStr) - if runId == 0 { - q = "INVALID runId: " + q - } - return q -} - -func (qb *QueryBuilder) InsertRunPreparedQuery(tableName string, runId int16, ifNotExists IfNotExistsType) (string, error) { - ifNotExistsStr := "" - if ifNotExists == IgnoreIfExists { - ifNotExistsStr = "IF NOT EXISTS" - } - columnCount := len(qb.PreparedColumnData.ColumnIdxMap) - paramArray := make([]string, columnCount) - for paramIdx := 0; paramIdx < columnCount; paramIdx++ { - paramArray[paramIdx] = "?" 
- } - q := fmt.Sprintf("INSERT INTO %s%s%s ( %s ) VALUES ( %s ) %s;", - qb.FormattedKeyspace, - tableName, - RunIdSuffix(runId), - strings.Join(qb.PreparedColumnData.Columns[:columnCount], ", "), - strings.Join(paramArray, ", "), - ifNotExistsStr) - if runId == 0 { - return "", fmt.Errorf("invalid runId=0 in %s", q) - } - return q, nil -} - -func (qb *QueryBuilder) InsertRunParams() ([]interface{}, error) { - if len(qb.PreparedColumnData.ColumnIdxMap) != len(qb.PreparedColumnData.ValueIdxMap) { - return nil, fmt.Errorf("cannot produce insert params, length mismatch: columns %v, values %v", qb.PreparedColumnData.ColumnIdxMap, qb.PreparedColumnData.ValueIdxMap) - } - return qb.PreparedColumnData.Values[:len(qb.PreparedColumnData.ValueIdxMap)], nil -} - -/* -Select - build SELECT query -*/ -func (qb *QueryBuilder) Select(tableName string, columns []string) string { - return qb.SelectRun(tableName, RunIdForEmptyRun, columns) -} -func (qb *QueryBuilder) SelectRun(tableName string, runId int16, columns []string) string { - b := strings.Builder{} - if runId == 0 { - b.WriteString("INVALID runId: ") - } - b.WriteString(fmt.Sprintf("SELECT %s FROM %s%s%s", - strings.Join(columns, ", "), - qb.FormattedKeyspace, - tableName, - RunIdSuffix(runId))) - if qb.Conditions.Len > 0 { - b.WriteString(" WHERE ") - b.WriteString(strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) - } - if len(qb.OrderByColumns) > 0 { - b.WriteString(fmt.Sprintf(" ORDER BY %s ", strings.Join(qb.OrderByColumns, ","))) - } - if qb.SelectLimit > 0 { - b.WriteString(fmt.Sprintf(" LIMIT %d", qb.SelectLimit)) - } - b.WriteString(";") - - return b.String() -} - -/* -Delete - build DELETE query -*/ -func (qb *QueryBuilder) Delete(tableName string) string { - return qb.DeleteRun(tableName, RunIdForEmptyRun) -} -func (qb *QueryBuilder) DeleteRun(tableName string, runId int16) string { - q := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s", - qb.FormattedKeyspace, - tableName, - RunIdSuffix(runId), - strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) - if runId == 0 { - q = "DEV ERROR, INVALID runId: " + q - } - - return q -} - -/* -Update - build UPDATE query -*/ -func (qb *QueryBuilder) Update(tableName string) string { - return qb.UpdateRun(tableName, RunIdForEmptyRun) -} -func (qb *QueryBuilder) UpdateRun(tableName string, runId int16) string { - var assignments [256]string - for i := 0; i < qb.ColumnData.Len; i++ { - assignments[i] = fmt.Sprintf("%s = %s", qb.ColumnData.Columns[i], qb.ColumnData.Values[i]) - } - q := fmt.Sprintf("UPDATE %s%s%s SET %s WHERE %s", - qb.FormattedKeyspace, - tableName, - RunIdSuffix(runId), - strings.Join(assignments[:qb.ColumnData.Len], ", "), - strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) - - if qb.IfConditions.Len > 0 { - q += " IF " + strings.Join(qb.IfConditions.Items[:qb.IfConditions.Len], " AND ") - } - if runId == 0 { - q = "INVALID runId: " + q - } - return q -} - -func (qb *QueryBuilder) Create(tableName string, ifNotExists IfNotExistsType) string { - return qb.CreateRun(tableName, RunIdForEmptyRun, ifNotExists) -} -func (qb *QueryBuilder) CreateRun(tableName string, runId int16, ifNotExists IfNotExistsType) string { - var b strings.Builder - if runId == 0 { - b.WriteString("INVALID runId: ") - } - b.WriteString("CREATE TABLE ") - if ifNotExists == IgnoreIfExists { - b.WriteString("IF NOT EXISTS ") - } - b.WriteString(fmt.Sprintf("%s%s%s ( ", qb.FormattedKeyspace, tableName, RunIdSuffix(runId))) - for i := 0; i < qb.ColumnDefs.Len; i++ { - 
b.WriteString(qb.ColumnDefs.Columns[i]) - b.WriteString(" ") - b.WriteString(qb.ColumnDefs.Types[i]) - if i < qb.ColumnDefs.Len-1 { - b.WriteString(", ") - } - } - if len(qb.PartitionKeyColumns) > 0 { - b.WriteString(", ") - b.WriteString(fmt.Sprintf("PRIMARY KEY((%s)", strings.Join(qb.PartitionKeyColumns, ", "))) - if len(qb.ClusteringKeyColumns) > 0 { - b.WriteString(", ") - b.WriteString(strings.Join(qb.ClusteringKeyColumns, ", ")) - } - b.WriteString(")") - } - b.WriteString(");") - return b.String() -} - -// Currently not used, leave it commented out just in case -// func (qb *QueryBuilder) Drop(tableName string) string { -// return qb.DropRun(tableName, RunIdForEmptyRun) -// } -// func (qb *QueryBuilder) DropRun(tableName string, runId int16) string { -// q := fmt.Sprintf("DROP TABLE IF EXISTS %s%s%s", qb.FormattedKeyspace, tableName, RunIdSuffix(runId)) -// if runId == 0 { -// q = "INVALID runId: " + q -// } -// return q -// } - -func (qb *QueryBuilder) DropKeyspace() string { - return fmt.Sprintf("DROP KEYSPACE IF EXISTS %s", strings.ReplaceAll(qb.FormattedKeyspace, ".", "")) -} +package cql + +import ( + "fmt" + "math" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/shopspring/decimal" + "gopkg.in/inf.v0" +) + +type IfNotExistsType int + +const ( + IgnoreIfExists IfNotExistsType = 1 + ThrowIfExists IfNotExistsType = 0 +) + +type QuotePolicyType int + +const ( + LeaveQuoteAsIs QuotePolicyType = iota + ForceUnquote +) + +/* +Data/idx table name for each run needs run id as a suffix +*/ +func RunIdSuffix(runId int16) string { + if runId > 0 { + return fmt.Sprintf("_%05d", runId) + } + return "" +} + +/* +Helper used in query builder +*/ +func valueToString(value any, quotePolicy QuotePolicyType) string { + switch v := value.(type) { + case string: + if quotePolicy == ForceUnquote { + return strings.ReplaceAll(v, "'", "''") + } else { + return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) + } + case time.Time: + if quotePolicy == ForceUnquote { + return v.Format(sc.CassandraDatetimeFormat) + } else { + return v.Format(fmt.Sprintf("'%s'", sc.CassandraDatetimeFormat)) + } + default: + return fmt.Sprintf("%v", v) + } +} + +func valueToCqlParam(value any) any { + switch v := value.(type) { + case decimal.Decimal: + f, _ := v.Float64() + scaled := int64(math.Round(f * 100)) + return inf.NewDec(scaled, 2) + default: + return v + } +} + +type queryBuilderColumnDefs struct { + Columns [256]string + Types [256]string + Len int +} + +func (cd *queryBuilderColumnDefs) add(column string, fieldType sc.TableFieldType) { + cd.Columns[cd.Len] = column + switch fieldType { + case sc.FieldTypeInt: + cd.Types[cd.Len] = "BIGINT" + case sc.FieldTypeDecimal2: + cd.Types[cd.Len] = "DECIMAL" + case sc.FieldTypeFloat: + cd.Types[cd.Len] = "DOUBLE" + case sc.FieldTypeString: + cd.Types[cd.Len] = "TEXT" + case sc.FieldTypeBool: + cd.Types[cd.Len] = "BOOLEAN" + case sc.FieldTypeDateTime: + cd.Types[cd.Len] = "TIMESTAMP" // Cassandra stores milliseconds since epoch + default: + cd.Types[cd.Len] = fmt.Sprintf("UKNOWN_TYPE_%s", fieldType) + } + cd.Len++ +} + +type queryBuilderPreparedColumnData struct { + Columns [256]string + Values [256]any + ColumnIdxMap map[string]int + ValueIdxMap map[string]int +} + +func (cd *queryBuilderPreparedColumnData) addColumnName(column string) error { + if _, ok := cd.ColumnIdxMap[column]; ok { + return fmt.Errorf("cannot add same column %s to a prepared query twice: %v", column, cd.Columns) + } + curColCount := len(cd.ColumnIdxMap) + 
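// Columns keep registration order; addColumnValue later writes each value at the index recorded here.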
cd.Columns[curColCount] = column + cd.ColumnIdxMap[column] = curColCount + return nil +} +func (cd *queryBuilderPreparedColumnData) addColumnValue(column string, value any) error { + colIdx, ok := cd.ColumnIdxMap[column] + if !ok { + return fmt.Errorf("cannot set value for non-prepared column %s, available columns are %v", column, cd.Columns) + } + cd.Values[colIdx] = valueToCqlParam(value) + cd.ValueIdxMap[column] = colIdx + return nil +} + +type queryBuilderColumnData struct { + Columns [256]string + Values [256]string + Len int +} + +func (cd *queryBuilderColumnData) add(column string, value any, quotePolicy QuotePolicyType) { + cd.Values[cd.Len] = valueToString(value, quotePolicy) + cd.Columns[cd.Len] = column + cd.Len++ +} + +type queryBuilderConditions struct { + Items [256]string + Len int +} + +func (cc *queryBuilderConditions) addInInt(column string, values []int64) { + inValues := make([]string, len(values)) + for i, v := range values { + inValues[i] = fmt.Sprintf("%d", v) + } + cc.Items[cc.Len] = fmt.Sprintf("%s IN ( %s )", column, strings.Join(inValues, ", ")) + cc.Len++ +} + +func (cc *queryBuilderConditions) addInInt16(column string, values []int16) { + inValues := make([]string, len(values)) + for i, v := range values { + inValues[i] = fmt.Sprintf("%d", v) + } + cc.Items[cc.Len] = fmt.Sprintf("%s IN ( %s )", column, strings.Join(inValues, ", ")) + cc.Len++ +} + +func (cc *queryBuilderConditions) addInString(column string, values []string) { + cc.Items[cc.Len] = fmt.Sprintf("%s IN ( '%s' )", column, strings.Join(values, "', '")) + cc.Len++ +} + +func (cc *queryBuilderConditions) addSimple(column string, op string, value any) { + cc.Items[cc.Len] = fmt.Sprintf("%s %s %s", column, op, valueToString(value, LeaveQuoteAsIs)) + cc.Len++ +} +func (cc *queryBuilderConditions) addSimpleForceUnquote(column string, op string, value any) { + cc.Items[cc.Len] = fmt.Sprintf("%s %s %s", column, op, valueToString(value, ForceUnquote)) + cc.Len++ +} + +/* +QueryBuilder - very simple cql query builder that does not require db connection +*/ +type QueryBuilder struct { + ColumnDefs queryBuilderColumnDefs + PartitionKeyColumns []string + ClusteringKeyColumns []string + ColumnData queryBuilderColumnData + PreparedColumnData queryBuilderPreparedColumnData + Conditions queryBuilderConditions + IfConditions queryBuilderConditions + SelectLimit int + FormattedKeyspace string + OrderByColumns []string +} + +func NewQB() *QueryBuilder { + var qb QueryBuilder + qb.PreparedColumnData.ColumnIdxMap = map[string]int{} + qb.PreparedColumnData.ValueIdxMap = map[string]int{} + return &qb +} + +func (qb *QueryBuilder) ColumnDef(column string, fieldType sc.TableFieldType) *QueryBuilder { + qb.ColumnDefs.add(column, fieldType) + return qb +} + +/* + */ +func (qb *QueryBuilder) PartitionKey(column ...string) *QueryBuilder { + qb.PartitionKeyColumns = column + return qb +} +func (qb *QueryBuilder) ClusteringKey(column ...string) *QueryBuilder { + qb.ClusteringKeyColumns = column + return qb +} + +/* +Keyspace - specify keyspace (optional) +*/ +func (qb *QueryBuilder) Keyspace(keyspace string) *QueryBuilder { + if trimmedKeyspace := strings.TrimSpace(keyspace); len(trimmedKeyspace) > 0 { + qb.FormattedKeyspace = fmt.Sprintf("%s.", trimmedKeyspace) + } else { + qb.FormattedKeyspace = "" + } + return qb +} + +func (qb *QueryBuilder) Limit(limit int) *QueryBuilder { + qb.SelectLimit = limit + return qb +} + +/* +Write - add a column for INSERT or UPDATE +*/ +func (qb *QueryBuilder) Write(column string, value any) 
*QueryBuilder { + qb.ColumnData.add(column, value, LeaveQuoteAsIs) + return qb +} + +func (qb *QueryBuilder) WritePreparedColumn(column string) error { + return qb.PreparedColumnData.addColumnName(column) +} + +func (qb *QueryBuilder) WritePreparedValue(column string, value any) error { + return qb.PreparedColumnData.addColumnValue(column, value) +} + +/* +WriteForceUnquote - add a column for INSERT or UPDATE +*/ +func (qb *QueryBuilder) WriteForceUnquote(column string, value any) *QueryBuilder { + qb.ColumnData.add(column, value, ForceUnquote) + return qb +} + +/* +Cond - add condition for SELECT, UPDATE or DELETE +*/ +func (qb *QueryBuilder) Cond(column string, op string, value any) *QueryBuilder { + qb.Conditions.addSimple(column, op, value) + return qb +} + +func (qb *QueryBuilder) CondPrepared(column string, op string) *QueryBuilder { + qb.Conditions.addSimpleForceUnquote(column, op, "?") + return qb +} + +func (qb *QueryBuilder) CondInPrepared(column string) *QueryBuilder { + qb.Conditions.addSimpleForceUnquote(column, "IN", "?") + return qb +} + +/* +CondIn - add IN condition for SELECT, UPDATE or DELETE +*/ +func (qb *QueryBuilder) CondInInt(column string, values []int64) *QueryBuilder { + qb.Conditions.addInInt(column, values) + return qb +} + +func (qb *QueryBuilder) CondInInt16(column string, values []int16) *QueryBuilder { + qb.Conditions.addInInt16(column, values) + return qb +} + +func (qb *QueryBuilder) CondInString(column string, values []string) *QueryBuilder { + qb.Conditions.addInString(column, values) + return qb +} + +func (qb *QueryBuilder) OrderBy(columns ...string) *QueryBuilder { + qb.OrderByColumns = columns + return qb +} + +func (qb *QueryBuilder) If(column string, op string, value any) *QueryBuilder { + qb.IfConditions.addSimple(column, op, value) + return qb +} + +/* +Insert - build INSERT query +*/ +const RunIdForEmptyRun = -1 + +func (qb *QueryBuilder) InsertUnpreparedQuery(tableName string, ifNotExists IfNotExistsType) string { + return qb.insertRunUnpreparedQuery(tableName, RunIdForEmptyRun, ifNotExists) +} +func (qb *QueryBuilder) insertRunUnpreparedQuery(tableName string, runId int16, ifNotExists IfNotExistsType) string { + ifNotExistsStr := "" + if ifNotExists == IgnoreIfExists { + ifNotExistsStr = "IF NOT EXISTS" + } + q := fmt.Sprintf("INSERT INTO %s%s%s ( %s ) VALUES ( %s ) %s;", + qb.FormattedKeyspace, + tableName, + RunIdSuffix(runId), + strings.Join(qb.ColumnData.Columns[:qb.ColumnData.Len], ", "), + strings.Join(qb.ColumnData.Values[:qb.ColumnData.Len], ", "), + ifNotExistsStr) + if runId == 0 { + q = "INVALID runId: " + q + } + return q +} + +func (qb *QueryBuilder) InsertRunPreparedQuery(tableName string, runId int16, ifNotExists IfNotExistsType) (string, error) { + ifNotExistsStr := "" + if ifNotExists == IgnoreIfExists { + ifNotExistsStr = "IF NOT EXISTS" + } + columnCount := len(qb.PreparedColumnData.ColumnIdxMap) + paramArray := make([]string, columnCount) + for paramIdx := 0; paramIdx < columnCount; paramIdx++ { + paramArray[paramIdx] = "?" 
+ } + q := fmt.Sprintf("INSERT INTO %s%s%s ( %s ) VALUES ( %s ) %s;", + qb.FormattedKeyspace, + tableName, + RunIdSuffix(runId), + strings.Join(qb.PreparedColumnData.Columns[:columnCount], ", "), + strings.Join(paramArray, ", "), + ifNotExistsStr) + if runId == 0 { + return "", fmt.Errorf("invalid runId=0 in %s", q) + } + return q, nil +} + +func (qb *QueryBuilder) InsertRunParams() ([]any, error) { + if len(qb.PreparedColumnData.ColumnIdxMap) != len(qb.PreparedColumnData.ValueIdxMap) { + return nil, fmt.Errorf("cannot produce insert params, length mismatch: columns %v, values %v", qb.PreparedColumnData.ColumnIdxMap, qb.PreparedColumnData.ValueIdxMap) + } + return qb.PreparedColumnData.Values[:len(qb.PreparedColumnData.ValueIdxMap)], nil +} + +/* +Select - build SELECT query +*/ +func (qb *QueryBuilder) Select(tableName string, columns []string) string { + return qb.SelectRun(tableName, RunIdForEmptyRun, columns) +} +func (qb *QueryBuilder) SelectRun(tableName string, runId int16, columns []string) string { + b := strings.Builder{} + if runId == 0 { + b.WriteString("INVALID runId: ") + } + b.WriteString(fmt.Sprintf("SELECT %s FROM %s%s%s", + strings.Join(columns, ", "), + qb.FormattedKeyspace, + tableName, + RunIdSuffix(runId))) + if qb.Conditions.Len > 0 { + b.WriteString(" WHERE ") + b.WriteString(strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) + } + if len(qb.OrderByColumns) > 0 { + b.WriteString(fmt.Sprintf(" ORDER BY %s ", strings.Join(qb.OrderByColumns, ","))) + } + if qb.SelectLimit > 0 { + b.WriteString(fmt.Sprintf(" LIMIT %d", qb.SelectLimit)) + } + b.WriteString(";") + + return b.String() +} + +/* +Delete - build DELETE query +*/ +func (qb *QueryBuilder) Delete(tableName string) string { + return qb.DeleteRun(tableName, RunIdForEmptyRun) +} +func (qb *QueryBuilder) DeleteRun(tableName string, runId int16) string { + q := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s", + qb.FormattedKeyspace, + tableName, + RunIdSuffix(runId), + strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) + if runId == 0 { + q = "DEV ERROR, INVALID runId: " + q + } + + return q +} + +/* +Update - build UPDATE query +*/ +func (qb *QueryBuilder) Update(tableName string) string { + return qb.UpdateRun(tableName, RunIdForEmptyRun) +} +func (qb *QueryBuilder) UpdateRun(tableName string, runId int16) string { + var assignments [256]string + for i := 0; i < qb.ColumnData.Len; i++ { + assignments[i] = fmt.Sprintf("%s = %s", qb.ColumnData.Columns[i], qb.ColumnData.Values[i]) + } + q := fmt.Sprintf("UPDATE %s%s%s SET %s WHERE %s", + qb.FormattedKeyspace, + tableName, + RunIdSuffix(runId), + strings.Join(assignments[:qb.ColumnData.Len], ", "), + strings.Join(qb.Conditions.Items[:qb.Conditions.Len], " AND ")) + + if qb.IfConditions.Len > 0 { + q += " IF " + strings.Join(qb.IfConditions.Items[:qb.IfConditions.Len], " AND ") + } + if runId == 0 { + q = "INVALID runId: " + q + } + return q +} + +func (qb *QueryBuilder) Create(tableName string, ifNotExists IfNotExistsType) string { + return qb.CreateRun(tableName, RunIdForEmptyRun, ifNotExists) +} +func (qb *QueryBuilder) CreateRun(tableName string, runId int16, ifNotExists IfNotExistsType) string { + var b strings.Builder + if runId == 0 { + b.WriteString("INVALID runId: ") + } + b.WriteString("CREATE TABLE ") + if ifNotExists == IgnoreIfExists { + b.WriteString("IF NOT EXISTS ") + } + b.WriteString(fmt.Sprintf("%s%s%s ( ", qb.FormattedKeyspace, tableName, RunIdSuffix(runId))) + for i := 0; i < qb.ColumnDefs.Len; i++ { + 
b.WriteString(qb.ColumnDefs.Columns[i]) + b.WriteString(" ") + b.WriteString(qb.ColumnDefs.Types[i]) + if i < qb.ColumnDefs.Len-1 { + b.WriteString(", ") + } + } + if len(qb.PartitionKeyColumns) > 0 { + b.WriteString(", ") + b.WriteString(fmt.Sprintf("PRIMARY KEY((%s)", strings.Join(qb.PartitionKeyColumns, ", "))) + if len(qb.ClusteringKeyColumns) > 0 { + b.WriteString(", ") + b.WriteString(strings.Join(qb.ClusteringKeyColumns, ", ")) + } + b.WriteString(")") + } + b.WriteString(");") + return b.String() +} + +// Currently not used, leave it commented out just in case +// func (qb *QueryBuilder) Drop(tableName string) string { +// return qb.DropRun(tableName, RunIdForEmptyRun) +// } +// func (qb *QueryBuilder) DropRun(tableName string, runId int16) string { +// q := fmt.Sprintf("DROP TABLE IF EXISTS %s%s%s", qb.FormattedKeyspace, tableName, RunIdSuffix(runId)) +// if runId == 0 { +// q = "INVALID runId: " + q +// } +// return q +// } + +func (qb *QueryBuilder) DropKeyspace() string { + return fmt.Sprintf("DROP KEYSPACE IF EXISTS %s", strings.ReplaceAll(qb.FormattedKeyspace, ".", "")) +} diff --git a/pkg/cql/cql_query_builder_test.go b/pkg/cql/cql_query_builder_test.go index 77c77a6..15960e4 100644 --- a/pkg/cql/cql_query_builder_test.go +++ b/pkg/cql/cql_query_builder_test.go @@ -1,110 +1,110 @@ -package cql - -import ( - "fmt" - "testing" - - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" - "gopkg.in/inf.v0" -) - -func TestValueToCqlParam(t *testing.T) { - // Simple - assert.Equal(t, "1.23", valueToCqlParam(decimal.NewFromFloat(1.23)).(*inf.Dec).String()) - - // big round up - assert.Equal(t, "1.24", valueToCqlParam(decimal.NewFromFloat(1.235)).(*inf.Dec).String()) - - // small round down - assert.Equal(t, "0.03", valueToCqlParam(decimal.NewFromFloat(0.0345)).(*inf.Dec).String()) -} - -func TestInsertRunParams(t *testing.T) { - qb := NewQB() - qb.WritePreparedColumn("param_name") - qb.WritePreparedValue("param_name", "param_value") - q, err := qb.Keyspace("ks1").InsertRunPreparedQuery("table1", 1, IgnoreIfExists) - assert.Nil(t, err) - assert.Equal(t, "INSERT INTO ks1.table1_00001 ( param_name ) VALUES ( ? ) IF NOT EXISTS;", q) - - params, err := qb.InsertRunParams() - assert.Equal(t, []interface{}([]interface{}{"param_value"}), params) -} - -func TestInsert(t *testing.T) { - const qTemplate string = "INSERT INTO table1%s ( col1, col2, col3 ) VALUES ( 'val1', 2, now() ) IF NOT EXISTS;" - qb := (&QueryBuilder{}). - Write("col1", "val1"). - Write("col2", 2). - WriteForceUnquote("col3", "now()") - assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.insertRunUnpreparedQuery("table1", 123, IgnoreIfExists)) -} - -func TestDropKeyspace(t *testing.T) { - assert.Equal(t, "DROP KEYSPACE IF EXISTS aaa", (&QueryBuilder{}).Keyspace("aaa").DropKeyspace()) -} - -func TestSelect(t *testing.T) { - const qTemplate string = "SELECT col3, col4 FROM somekeyspace.table1%s WHERE col1 > 1 AND col2 = 2 AND col3 IN ( 'val31', 'val32' ) AND col7 IN ( 1, 2 ) ORDER BY col3 LIMIT 10;" - qb := (&QueryBuilder{}). - Keyspace("somekeyspace"). - Cond("col1", ">", 1). - Cond("col2", "=", 2). - CondInString("col3", []string{"val31", "val32"}). - CondInInt16("col7", []int16{1, 2}). - OrderBy("col3"). 
- Limit(10) - - assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.SelectRun("table1", 123, []string{"col3", "col4"})) - assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Select("table1", []string{"col3", "col4"})) -} - -func TestDelete(t *testing.T) { - const qTemplate string = "DELETE FROM table1%s WHERE col1 > 1 AND col2 = 2 AND col3 IN ( 'val31', 'val32' ) AND col7 IN ( 1, 2 )" - qb := (&QueryBuilder{}). - Cond("col1", ">", 1). - Cond("col2", "=", 2). - CondInString("col3", []string{"val31", "val32"}). - CondInInt("col7", []int64{1, 2}) - assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.DeleteRun("table1", 123)) - assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Delete("table1")) -} - -func TestUpdate(t *testing.T) { - const qTemplate string = "UPDATE table1%s SET col1 = 'val1', col2 = 2 WHERE col1 > 1 AND col2 = '2' IF col1 = 2" - qb := (&QueryBuilder{}). - Write("col1", "val1"). - Write("col2", 2). - Cond("col1", ">", 1). - Cond("col2", "=", "2"). - If("col1", "=", 2) - assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.UpdateRun("table1", 123)) - assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Update("table1")) -} - -func TestCreate(t *testing.T) { - const qTemplate string = "CREATE TABLE IF NOT EXISTS table1%s ( col_int BIGINT, col_bool BOOLEAN, col_string TEXT, col_datetime TIMESTAMP, col_decimal2 DECIMAL, col_float DOUBLE, PRIMARY KEY((col_int, col_decimal2), col_bool, col_float));" - qb := (&QueryBuilder{}). - ColumnDef("col_int", sc.FieldTypeInt). - ColumnDef("col_bool", sc.FieldTypeBool). - ColumnDef("col_string", sc.FieldTypeString). - ColumnDef("col_datetime", sc.FieldTypeDateTime). - ColumnDef("col_decimal2", sc.FieldTypeDecimal2). - ColumnDef("col_float", sc.FieldTypeFloat). - PartitionKey("col_int", "col_decimal2"). - ClusteringKey("col_bool", "col_float") - assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.CreateRun("table1", 123, IgnoreIfExists)) - assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Create("table1", IgnoreIfExists)) -} - -func TestInsertPrepared(t *testing.T) { - dataQb := NewQB() - err := dataQb.WritePreparedColumn("col_int") - assert.Nil(t, err) - err = dataQb.WritePreparedValue("col_int", 2) - assert.Nil(t, err) - s, _ := dataQb.InsertRunPreparedQuery("table1", 123, IgnoreIfExists) - assert.Equal(t, "INSERT INTO table1_00123 ( col_int ) VALUES ( ? ) IF NOT EXISTS;", s) -} +package cql + +import ( + "fmt" + "testing" + + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "gopkg.in/inf.v0" +) + +func TestValueToCqlParam(t *testing.T) { + // Simple + assert.Equal(t, "1.23", valueToCqlParam(decimal.NewFromFloat(1.23)).(*inf.Dec).String()) + + // big round up + assert.Equal(t, "1.24", valueToCqlParam(decimal.NewFromFloat(1.235)).(*inf.Dec).String()) + + // small round down + assert.Equal(t, "0.03", valueToCqlParam(decimal.NewFromFloat(0.0345)).(*inf.Dec).String()) +} + +func TestInsertRunParams(t *testing.T) { + qb := NewQB() + assert.Nil(t, qb.WritePreparedColumn("param_name")) + assert.Nil(t, qb.WritePreparedValue("param_name", "param_value")) + q, err := qb.Keyspace("ks1").InsertRunPreparedQuery("table1", 1, IgnoreIfExists) + assert.Nil(t, err) + assert.Equal(t, "INSERT INTO ks1.table1_00001 ( param_name ) VALUES ( ? 
) IF NOT EXISTS;", q) + + params, _ := qb.InsertRunParams() + assert.Equal(t, []any([]any{"param_value"}), params) +} + +func TestInsert(t *testing.T) { + const qTemplate string = "INSERT INTO table1%s ( col1, col2, col3 ) VALUES ( 'val1', 2, now() ) IF NOT EXISTS;" + qb := (&QueryBuilder{}). + Write("col1", "val1"). + Write("col2", 2). + WriteForceUnquote("col3", "now()") + assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.insertRunUnpreparedQuery("table1", 123, IgnoreIfExists)) +} + +func TestDropKeyspace(t *testing.T) { + assert.Equal(t, "DROP KEYSPACE IF EXISTS aaa", (&QueryBuilder{}).Keyspace("aaa").DropKeyspace()) +} + +func TestSelect(t *testing.T) { + const qTemplate string = "SELECT col3, col4 FROM somekeyspace.table1%s WHERE col1 > 1 AND col2 = 2 AND col3 IN ( 'val31', 'val32' ) AND col7 IN ( 1, 2 ) ORDER BY col3 LIMIT 10;" + qb := (&QueryBuilder{}). + Keyspace("somekeyspace"). + Cond("col1", ">", 1). + Cond("col2", "=", 2). + CondInString("col3", []string{"val31", "val32"}). + CondInInt16("col7", []int16{1, 2}). + OrderBy("col3"). + Limit(10) + + assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.SelectRun("table1", 123, []string{"col3", "col4"})) + assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Select("table1", []string{"col3", "col4"})) +} + +func TestDelete(t *testing.T) { + const qTemplate string = "DELETE FROM table1%s WHERE col1 > 1 AND col2 = 2 AND col3 IN ( 'val31', 'val32' ) AND col7 IN ( 1, 2 )" + qb := (&QueryBuilder{}). + Cond("col1", ">", 1). + Cond("col2", "=", 2). + CondInString("col3", []string{"val31", "val32"}). + CondInInt("col7", []int64{1, 2}) + assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.DeleteRun("table1", 123)) + assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Delete("table1")) +} + +func TestUpdate(t *testing.T) { + const qTemplate string = "UPDATE table1%s SET col1 = 'val1', col2 = 2 WHERE col1 > 1 AND col2 = '2' IF col1 = 2" + qb := (&QueryBuilder{}). + Write("col1", "val1"). + Write("col2", 2). + Cond("col1", ">", 1). + Cond("col2", "=", "2"). + If("col1", "=", 2) + assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.UpdateRun("table1", 123)) + assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Update("table1")) +} + +func TestCreate(t *testing.T) { + const qTemplate string = "CREATE TABLE IF NOT EXISTS table1%s ( col_int BIGINT, col_bool BOOLEAN, col_string TEXT, col_datetime TIMESTAMP, col_decimal2 DECIMAL, col_float DOUBLE, PRIMARY KEY((col_int, col_decimal2), col_bool, col_float));" + qb := (&QueryBuilder{}). + ColumnDef("col_int", sc.FieldTypeInt). + ColumnDef("col_bool", sc.FieldTypeBool). + ColumnDef("col_string", sc.FieldTypeString). + ColumnDef("col_datetime", sc.FieldTypeDateTime). + ColumnDef("col_decimal2", sc.FieldTypeDecimal2). + ColumnDef("col_float", sc.FieldTypeFloat). + PartitionKey("col_int", "col_decimal2"). + ClusteringKey("col_bool", "col_float") + assert.Equal(t, fmt.Sprintf(qTemplate, "_00123"), qb.CreateRun("table1", 123, IgnoreIfExists)) + assert.Equal(t, fmt.Sprintf(qTemplate, ""), qb.Create("table1", IgnoreIfExists)) +} + +func TestInsertPrepared(t *testing.T) { + dataQb := NewQB() + err := dataQb.WritePreparedColumn("col_int") + assert.Nil(t, err) + err = dataQb.WritePreparedValue("col_int", 2) + assert.Nil(t, err) + s, _ := dataQb.InsertRunPreparedQuery("table1", 123, IgnoreIfExists) + assert.Equal(t, "INSERT INTO table1_00123 ( col_int ) VALUES ( ? 
) IF NOT EXISTS;", s) +} diff --git a/pkg/ctx/processing_context.go b/pkg/ctx/processing_context.go index 06ca4a2..e6597fc 100644 --- a/pkg/ctx/processing_context.go +++ b/pkg/ctx/processing_context.go @@ -1,41 +1,40 @@ -package ctx - -import ( - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" - "go.uber.org/zap/zapcore" -) - -type MessageProcessingContext struct { - MsgTs int64 - BatchInfo wfmodel.MessagePayloadDataBatch - CqlSession *gocql.Session - Script *sc.ScriptDef - CurrentScriptNode *sc.ScriptNodeDef - ZapDataKeyspace zapcore.Field - ZapRun zapcore.Field - ZapNode zapcore.Field - ZapBatchIdx zapcore.Field - ZapMsgAgeMillis zapcore.Field -} - -func (pCtx *MessageProcessingContext) DbConnect(envConfig *env.EnvConfig) error { - var err error - if pCtx.CqlSession, err = db.NewSession(envConfig, pCtx.BatchInfo.DataKeyspace, db.CreateKeyspaceOnConnect); err != nil { - return err - } - return nil -} - -func (pCtx *MessageProcessingContext) DbClose() { - if pCtx.CqlSession != nil { - if pCtx.CqlSession.Closed() { - // TODO: something is not clean in the code, find a way to communicate it without using logger - } else { - pCtx.CqlSession.Close() - } - } -} +package ctx + +import ( + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" + "go.uber.org/zap/zapcore" +) + +type MessageProcessingContext struct { + MsgTs int64 + BatchInfo wfmodel.MessagePayloadDataBatch + CqlSession *gocql.Session + Script *sc.ScriptDef + CurrentScriptNode *sc.ScriptNodeDef + ZapDataKeyspace zapcore.Field + ZapRun zapcore.Field + ZapNode zapcore.Field + ZapBatchIdx zapcore.Field + ZapMsgAgeMillis zapcore.Field +} + +func (pCtx *MessageProcessingContext) DbConnect(envConfig *env.EnvConfig) error { + var err error + if pCtx.CqlSession, err = db.NewSession(envConfig, pCtx.BatchInfo.DataKeyspace, db.CreateKeyspaceOnConnect); err != nil { + return err + } + return nil +} + +func (pCtx *MessageProcessingContext) DbClose() { + if pCtx.CqlSession != nil { + // TODO: if it's already closed, something is not clean in the code, find a way to communicate it without using logger + if !pCtx.CqlSession.Closed() { + pCtx.CqlSession.Close() + } + } +} diff --git a/pkg/custom/py_calc/py_calc.donotcover.go b/pkg/custom/py_calc/py_calc.donotcover.go index 865d728..0c8924b 100644 --- a/pkg/custom/py_calc/py_calc.donotcover.go +++ b/pkg/custom/py_calc/py_calc.donotcover.go @@ -14,12 +14,10 @@ import ( "github.com/capillariesio/capillaries/pkg/proc" ) -func (procDef *PyCalcProcessorDef) Run(logger *l.Logger, pCtx *ctx.MessageProcessingContext, rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { +func (procDef *PyCalcProcessorDef) Run(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { logger.PushF("custom.PyCalcProcessorDef.Run") defer logger.PopF() - //err := procDef.executeCalculations(logger, pCtx, rsIn, rsOut, time.Duration(procDef.EnvSettings.ExecutionTimeout*int(time.Millisecond))) - timeout := time.Duration(procDef.EnvSettings.ExecutionTimeout * int(time.Millisecond)) codeBase, err := procDef.buildPythonCodebaseFromRowset(rsIn) 
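For readers of the Run() hunks above and below: the method builds one Python codebase for the whole input rowset and executes it in an external interpreter bounded by EnvSettings.ExecutionTimeout (milliseconds, default 5000). The process wiring itself lives in py_calc.donotcover.go and is only partially visible in this diff, so what follows is a hedged, standalone sketch (the helper name and parameters are hypothetical, not the project's API) of how such a deadline-bounded interpreter call is typically set up with os/exec and context, consistent with the cmdCtx/stdout/stderr usage in the surrounding hunks:

package sketch

import (
	"bytes"
	"context"
	"fmt"
	"os/exec"
	"strings"
	"time"
)

// runPythonWithTimeout is an illustrative helper, not the project's code.
// timeoutMs mirrors EnvSettings.ExecutionTimeout (milliseconds).
func runPythonWithTimeout(interpreterPath string, interpreterParams []string, codeBase string, timeoutMs int) (string, string, error) {
	timeout := time.Duration(timeoutMs * int(time.Millisecond))
	cmdCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// "-u" in interpreterParams keeps Python output unbuffered; "-" makes it read the code from stdin.
	p := exec.CommandContext(cmdCtx, interpreterPath, interpreterParams...)
	p.Stdin = strings.NewReader(codeBase)
	var stdout, stderr bytes.Buffer
	p.Stdout = &stdout
	p.Stderr = &stderr

	err := p.Run()
	if cmdCtx.Err() == context.DeadlineExceeded {
		// The deadline killed the process; err is typically just "signal: killed".
		return stdout.String(), stderr.String(), fmt.Errorf("python calculation timeout %v expired", timeout)
	}
	return stdout.String(), stderr.String(), err
}

The "-u" flag called out in the PyCalcEnvSettings comment matters here: without it, Python buffers stdout, and a kill on timeout would discard the partial output that analyseExecError/analyseExecSuccess later parse.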
@@ -42,7 +40,7 @@ func (procDef *PyCalcProcessorDef) Run(logger *l.Logger, pCtx *ctx.MessageProces p.Stdout = &stdout p.Stderr = &stderr - //return fmt.Errorf(codeBase.String()) + // return fmt.Errorf(codeBase.String()) // Run pythonStartTime := time.Now() @@ -50,13 +48,13 @@ func (procDef *PyCalcProcessorDef) Run(logger *l.Logger, pCtx *ctx.MessageProces pythonDur := time.Since(pythonStartTime) logger.InfoCtx(pCtx, "PythonInterpreter: %d items in %v (%.0f items/s)", rsIn.RowCount, pythonDur, float64(rsIn.RowCount)/pythonDur.Seconds()) - rawOutput := string(stdout.Bytes()) - rawErrors := string(stderr.Bytes()) + rawOutput := stdout.String() + rawErrors := stderr.String() // Really verbose, use for troubleshooting only // fmt.Println(codeBase, rawOutput) - //fmt.Println(fmt.Sprintf("err.Error():'%s', cmdCtx.Err():'%v'", err.Error(), cmdCtx.Err())) + // fmt.Println(fmt.Sprintf("err.Error():'%s', cmdCtx.Err():'%v'", err.Error(), cmdCtx.Err())) if err != nil { fullErrorInfo, err := procDef.analyseExecError(codeBase, rawOutput, rawErrors, err) @@ -64,12 +62,10 @@ func (procDef *PyCalcProcessorDef) Run(logger *l.Logger, pCtx *ctx.MessageProces logger.ErrorCtx(pCtx, fullErrorInfo) } return fmt.Errorf("Python interpreter returned an error: %s", err) - } else { - if cmdCtx.Err() == context.DeadlineExceeded { - // Timeout occurred, err.Error() is probably: 'signal: killed' - return fmt.Errorf("Python calculation timeout %d s expired;", timeout) - } else { - return procDef.analyseExecSuccess(codeBase, rawOutput, rawErrors, procDef.GetFieldRefs(), rsIn, flushVarsArray) - } } + if cmdCtx.Err() == context.DeadlineExceeded { + // Timeout occurred, err.Error() is probably: 'signal: killed' + return fmt.Errorf("Python calculation timeout %d s expired;", timeout) + } + return procDef.analyseExecSuccess(codeBase, rawOutput, rawErrors, procDef.GetFieldRefs(), rsIn, flushVarsArray) } diff --git a/pkg/custom/py_calc/py_calc.go b/pkg/custom/py_calc/py_calc.go index 920616e..d2f4040 100644 --- a/pkg/custom/py_calc/py_calc.go +++ b/pkg/custom/py_calc/py_calc.go @@ -1,610 +1,611 @@ -package py_calc - -import ( - "bufio" - "encoding/json" - "fmt" - "go/ast" - "regexp" - "strconv" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/xfer" - "github.com/shopspring/decimal" -) - -const ProcessorPyCalcName string = "py_calc" - -type PyCalcEnvSettings struct { - // Windows: `python` or `C:\Users\%USERNAME%\AppData\Local\Programs\Python\Python310\python.exe` - // WSL: `python` or `/mnt/c/Users/myusername/AppData/Local/Programs/Python/Python310/python.exe` - // Linux: `python` - InterpreterPath string `json:"python_interpreter_path"` - // Usually: ["-u", "-"]. 
-u is essential: without it, we will not see stdout/stderr in the timeout scenario - InterpreterParams []string `json:"python_interpreter_params"` - ExecutionTimeout int `json:"execution_timeout"` // Default 5000 milliseconds -} - -type PyCalcProcessorDef struct { - PythonUrls []string `json:"python_code_urls"` - CalculatedFields map[string]*sc.WriteTableFieldDef `json:"calculated_fields"` - UsedInTargetExpressionsFields sc.FieldRefs - PythonCode string - CalculationOrder []string - EnvSettings PyCalcEnvSettings -} - -func (procDef *PyCalcProcessorDef) GetFieldRefs() *sc.FieldRefs { - fieldRefs := make(sc.FieldRefs, len(procDef.CalculatedFields)) - i := 0 - for fieldName, fieldDef := range procDef.CalculatedFields { - fieldRefs[i] = sc.FieldRef{ - TableName: sc.CustomProcessorAlias, - FieldName: fieldName, - FieldType: fieldDef.Type} - i += 1 - } - return &fieldRefs -} - -func harvestCallExp(callExp *ast.CallExpr, sigMap map[string]struct{}) { - // This expression is a python fuction call. - // Build a func_name(arg,arg,...) signature for it to check our Python code later - funIdentExp, _ := callExp.Fun.(*ast.Ident) - funcSig := fmt.Sprintf("%s(%s)", funIdentExp.Name, strings.Trim(strings.Repeat("arg,", len(callExp.Args)), ",")) - sigMap[funcSig] = struct{}{} - - for _, argExp := range callExp.Args { - switch typedExp := argExp.(type) { - case *ast.CallExpr: - harvestCallExp(typedExp, sigMap) - } - } -} - -func (procDef *PyCalcProcessorDef) Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error { - var err error - if err = json.Unmarshal(raw, procDef); err != nil { - return fmt.Errorf("cannot unmarshal py_calc processor def: %s", err.Error()) - } - - if err = json.Unmarshal(customProcSettings, &procDef.EnvSettings); err != nil { - return fmt.Errorf("cannot unmarshal py_calc processor env settings: %s", err.Error()) - } - - if len(procDef.EnvSettings.InterpreterPath) == 0 { - return fmt.Errorf("py_calc interpreter path canot be empty") - } - - if procDef.EnvSettings.ExecutionTimeout == 0 { - procDef.EnvSettings.ExecutionTimeout = 5000 - } - - errors := make([]string, 0) - usedPythonFunctionSignatures := map[string]struct{}{} - - // Calculated fields - for _, fieldDef := range procDef.CalculatedFields { - - // Use relaxed Go parser for Python - we are lucky that Go designers liked Python, so we do not have to implement a separate Python partser (for now) - if fieldDef.ParsedExpression, err = sc.ParseRawRelaxedGolangExpressionStringAndHarvestFieldRefs(fieldDef.RawExpression, &fieldDef.UsedFields, sc.FieldRefAllowUnknownIdents); err != nil { - errors = append(errors, fmt.Sprintf("cannot parse field expression [%s]: [%s]", fieldDef.RawExpression, err.Error())) - } else if !sc.IsValidFieldType(fieldDef.Type) { - errors = append(errors, fmt.Sprintf("invalid field type [%s]", fieldDef.Type)) - } - - // Each calculated field expression must be a valid Python expression and either: - // 1. some_python_func() - // 2. some reader field, like r.order_id (will be checked by checkFieldUsageInCustomProcessor) - // 3. some field calculated by this processor, like p.calculatedMargin (will be checked by checkFieldUsageInCustomProcessor) - - // Check top-level expression - switch typedExp := fieldDef.ParsedExpression.(type) { - case *ast.CallExpr: - harvestCallExp(typedExp, usedPythonFunctionSignatures) - case *ast.SelectorExpr: - // Assume it's a reader or calculated field. 
Do not check it here, checkFieldUsageInCustomProcessor() will do that - default: - errors = append(errors, fmt.Sprintf("invalid calculated field expression '%s', expected either 'some_function_from_your_python_code(...)' or some reader field, like 'r.order_id', or some other calculated (by this processor) field, like 'p.calculatedMargin'", fieldDef.RawExpression)) - } - } - - procDef.UsedInTargetExpressionsFields = sc.GetFieldRefsUsedInAllTargetExpressions(procDef.CalculatedFields) - - // Python files - var b strings.Builder - procDef.PythonCode = "" - for _, url := range procDef.PythonUrls { - bytes, err := xfer.GetFileBytes(url, caPath, privateKeys) - if err != nil { - errors = append(errors, err.Error()) - } - b.WriteString(string(bytes)) - b.WriteString("\n") - } - - procDef.PythonCode = b.String() - - if errCheckDefs := checkPythonFuncDefAvailability(usedPythonFunctionSignatures, procDef.PythonCode); errCheckDefs != nil { - errors = append(errors, errCheckDefs.Error()) - } - - // Build a set of "r.inputFieldX" and "p.calculatedFieldY" to perform Python code checks - srcVarsSet := map[string]struct{}{} - for _, fieldRef := range *procDef.GetFieldRefs() { - srcVarsSet[fieldRef.GetAliasHash()] = struct{}{} - } - - // Define calculation sequence - // Populate DAG - dag := map[string][]string{} - for tgtFieldName, tgtFieldDef := range procDef.CalculatedFields { - tgtFieldAlias := fmt.Sprintf("p.%s", tgtFieldName) - dag[tgtFieldAlias] = make([]string, len(tgtFieldDef.UsedFields)) - for i, usedFieldRef := range tgtFieldDef.UsedFields { - dag[tgtFieldAlias][i] = usedFieldRef.GetAliasHash() - } - } - - // Check DAG and return calc fields in the order they should be calculated - if procDef.CalculationOrder, err = kahn(dag); err != nil { - errors = append(errors, fmt.Sprintf("%s. Calc dependency map:\n%v", err, dag)) - } - - // TODO: deserialize other stuff from raw here if needed - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (procDef *PyCalcProcessorDef) GetUsedInTargetExpressionsFields() *sc.FieldRefs { - return &procDef.UsedInTargetExpressionsFields -} - -// Python supports microseconds in datetime. Unfortunately, Cassandra supports only milliseconds. Millis are our lingua franca. 
-// So, use only three digits after decimal point -// Python 8601 requires ":" in the timezone -const PythonDatetimeFormat string = "2006-01-02T15:04:05.000-07:00" - -func valueToPythonExpr(val interface{}) string { - switch typedVal := val.(type) { - case int64: - return fmt.Sprintf("%d", typedVal) - case float64: - return fmt.Sprintf("%f", typedVal) - case string: - return fmt.Sprintf("'%s'", strings.ReplaceAll(typedVal, "'", `\'`)) // Use single commas in Python - this may go to logs - case bool: - if typedVal { - return "TRUE" - } else { - return "FALSE" - } - case decimal.Decimal: - return typedVal.String() - case time.Time: - return typedVal.Format(fmt.Sprintf("\"%s\"", PythonDatetimeFormat)) - default: - return fmt.Sprintf("cannot convert '%v(%T)' to Python expression", typedVal, typedVal) - } -} - -func pythonResultToRowsetValue(fieldRef *sc.FieldRef, fieldValue interface{}) (interface{}, error) { - switch fieldRef.FieldType { - case sc.FieldTypeString: - finalVal, ok := fieldValue.(string) - if ok { - return finalVal, nil - } else { - return nil, fmt.Errorf("string %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - case sc.FieldTypeBool: - finalVal, ok := fieldValue.(bool) - if ok { - return finalVal, nil - } else { - return nil, fmt.Errorf("bool %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - case sc.FieldTypeInt: - finalVal, ok := fieldValue.(float64) - if ok { - finalIntVal := int64(finalVal) - return finalIntVal, nil - } else { - return nil, fmt.Errorf("int %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - case sc.FieldTypeFloat: - finalVal, ok := fieldValue.(float64) - if ok { - return finalVal, nil - } else { - return nil, fmt.Errorf("float %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - case sc.FieldTypeDecimal2: - finalVal, ok := fieldValue.(float64) - if ok { - finalDecVal := decimal.NewFromFloat(finalVal).Round(2) - return finalDecVal, nil - } else { - return nil, fmt.Errorf("decimal %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - case sc.FieldTypeDateTime: - finalVal, ok := fieldValue.(string) - if ok { - timeVal, err := time.Parse(PythonDatetimeFormat, finalVal) - if err != nil { - return nil, fmt.Errorf("bad time result %s, unexpected format %s", fieldRef.FieldName, finalVal) - } else { - return timeVal, nil - } - } else { - return nil, fmt.Errorf("time %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) - } - default: - return nil, fmt.Errorf("unexpected field type %s, %s, %T(%v)", fieldRef.FieldType, fieldRef.FieldName, fieldValue, fieldValue) - } -} - -/* -CheckDefsAvailability - makes sure all expressions mentioned in calc expressions have correspondent Python functions defined. -Pays attention to function name and number of arguments. Expressions may contain: -- calculated field references like p.orderMargin -- number/string/bool constants -- calls to other Python functions -Does NOT perform deep checks for Python function call hierarchy. 
-*/ -func checkPythonFuncDefAvailability(usedPythonFunctionSignatures map[string]struct{}, codeFormulaDefs string) error { - var errors strings.Builder - - // Walk trhough the whole code and collect Python function defs - availableDefSigs := map[string]struct{}{} - reDefSig := regexp.MustCompile(`(?m)^def[ ]+([a-zA-Z0-9_]+)[ ]*([ \(\)a-zA-Z0-9_,\.\"\'\t]+)[ ]*:`) - reArgCommas := regexp.MustCompile("[^)(,]+") - - // Find "def someFunc123(someParam, 'some literal'):" - defSigMatches := reDefSig.FindAllStringSubmatch(codeFormulaDefs, -1) - for _, sigMatch := range defSigMatches { - if strings.ReplaceAll(sigMatch[2], " ", "") == "()" { - // This Python function definition does not accept arguments - availableDefSigs[fmt.Sprintf("%s()", sigMatch[1])] = struct{}{} - } else { - // Strip all characters from the arg list, except commas and parenthesis, - // this will give us the canonical signature with the number of arguments presented by commas: "sum(arg,arg,arg)"" - canonicalSig := fmt.Sprintf("%s%s", sigMatch[1], reArgCommas.ReplaceAllString(sigMatch[2], "arg")) - availableDefSigs[canonicalSig] = struct{}{} - } - } - - // A curated list of Python functions allowed in script expressions. No exec() please. - pythoBuiltinFunctions := map[string]struct{}{ - "str(arg)": {}, - "int(arg)": {}, - "float(arg)": {}, - "round(arg)": {}, - "len(arg)": {}, - "bool(arg)": {}, - "abs(arg)": {}, - } - - // Walk through all Python signatures in calc expressions (we harvested them in Deserialize) - // and make sure correspondent python function defs are available - for usedSig, _ := range usedPythonFunctionSignatures { - if _, ok := availableDefSigs[usedSig]; !ok { - if _, ok := pythoBuiltinFunctions[usedSig]; !ok { - errors.WriteString(fmt.Sprintf("function def '%s' not found in Python file, and it's not in the list of allowed Python built-in functions; ", usedSig)) - } - } - } - - if errors.Len() > 0 { - var defs strings.Builder - for defSig, _ := range availableDefSigs { - defs.WriteString(fmt.Sprintf("%s; ", defSig)) - } - return fmt.Errorf("Python function defs availability check failed, the following functions are not defined: [%s]. 
Full list of available Python function definitions: %s", errors.String(), defs.String()) - } - - return nil -} - -func kahn(depMap map[string][]string) ([]string, error) { - inDegreeMap := make(map[string]int) - for node := range depMap { - if depMap[node] != nil { - for _, v := range depMap[node] { - inDegreeMap[v]++ - } - } - } - - var queue []string - for node := range depMap { - if _, ok := inDegreeMap[node]; !ok { - queue = append(queue, node) - } - } - - var execOrder []string - for len(queue) > 0 { - node := queue[len(queue)-1] - queue = queue[:(len(queue) - 1)] - // Prepend "node" to execOrder only if it's the key in depMap - // (this is not one of the algorithm requirements, - // it's just our caller needs only those elements) - if _, ok := depMap[node]; ok { - execOrder = append(execOrder, "") - copy(execOrder[1:], execOrder) - execOrder[0] = node - } - for _, v := range depMap[node] { - inDegreeMap[v]-- - if inDegreeMap[v] == 0 { - queue = append(queue, v) - } - } - } - - for _, inDegree := range inDegreeMap { - if inDegree > 0 { - return []string{}, fmt.Errorf("Formula expressions have cycle(s)") - } - } - - return execOrder, nil -} - -func (procDef *PyCalcProcessorDef) buildPythonCodebaseFromRowset(rsIn *proc.Rowset) (string, error) { - // Build a massive Python source: function defs, dp init, calculation, result json - // This is the hardcoded Python code structure we rely on. Do not change it. - var codeBase strings.Builder - - codeBase.WriteString(fmt.Sprintf(` -import traceback -import json -print("\n%s") # Provide function defs -%s -`, - FORMULA_MARKER_FUNCTION_DEFINITIONS, - procDef.PythonCode)) - - for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { - itemCalculationCodebase, err := procDef.printItemCalculationCode(rowIdx, rsIn) - if err != nil { - return "", err - } - codeBase.WriteString(fmt.Sprintf("%s\n", itemCalculationCodebase)) - } - - return codeBase.String(), nil -} - -func (procDef *PyCalcProcessorDef) analyseExecError(codeBase string, rawOutput string, rawErrors string, err error) (string, error) { - // Linux: err.Error():'exec: "python333": executable file not found in $PATH' - // Windows: err.Error():'exec: "C:\\Program Files\\Python3badpath\\python.exe": file does not exist' - // MacOS/WSL: err.Error():'fork/exec /mnt/c/Users/myusername/AppData/Local/Programs/Python/Python310/python.exe: no such file or directory' - if strings.Contains(err.Error(), "file not found") || - strings.Contains(err.Error(), "file does not exist") || - strings.Contains(err.Error(), "no such file or directory") { - return "", fmt.Errorf("interpreter binary not found: %s", procDef.EnvSettings.InterpreterPath) - } else if strings.Contains(err.Error(), "exit status") { - //err.Error():'exit status 1', cmdCtx.Err():'%!s()' - // Python interpreter reported an error: there was a syntax error in the codebase and no results were returned - fullErrorInfo := fmt.Sprintf("interpreter returned an error (probably syntax):\n%s %v\n%s\n%s\n%s", - procDef.EnvSettings.InterpreterPath, procDef.EnvSettings.InterpreterParams, - rawOutput, - rawErrors, - getErrorLineNumberInfo(codeBase, rawErrors)) - - //fmt.Println(fullErrorInfo) - return fullErrorInfo, fmt.Errorf("interpreter returned an error (probably syntax), see log for details: %s", rawErrors) - } else { - return "", fmt.Errorf("unexpected calculation errors: %s", rawErrors) - } -} - -const pyCalcFlushBufferSize int = 1000 - -func (procDef *PyCalcProcessorDef) analyseExecSuccess(codeBase string, rawOutput string, rawErrors string, outFieldRefs 
*sc.FieldRefs, rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { - // No Python interpreter errors, but there may be runtime errors and good results. - // Timeout error may be there too. - - var errors strings.Builder - varsArray := make([]*eval.VarValuesMap, pyCalcFlushBufferSize) - varsArrayCount := 0 - - sectionEndPos := 0 - for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { - startMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_DATA_POINTS_INITIALIZATION, rowIdx) - endMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_END, rowIdx) - relSectionStartPos := strings.Index(rawOutput[sectionEndPos:], startMarker) - relSectionEndPos := strings.Index(rawOutput[sectionEndPos:], endMarker) - sectionStartPos := sectionEndPos + relSectionStartPos - if sectionStartPos == -1 { - return fmt.Errorf("%d: unexpected, cannot find calculation start marker %s;", rowIdx, startMarker) - } - sectionEndPos = sectionEndPos + relSectionEndPos - if sectionEndPos == -1 { - return fmt.Errorf("%d: unexpected, cannot find calculation end marker %s;", rowIdx, endMarker) - } - if sectionStartPos > sectionEndPos { - return fmt.Errorf("%d: unexpected, end marker %s(%d) is earlier than start marker %s(%d);", rowIdx, endMarker, sectionStartPos, endMarker, sectionEndPos) - } - - rawSectionOutput := rawOutput[sectionStartPos:sectionEndPos] - successMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_SUCCESS, rowIdx) - sectionSuccessPos := strings.Index(rawSectionOutput, successMarker) - if sectionSuccessPos == -1 { - // There was an error calculating fields for this item - // Assume the last line of the out is the error - errorLines := strings.Split(rawSectionOutput, "\n") - errorText := "" - for i := len(errorLines) - 1; i >= 0; i-- { - errorText = strings.Trim(errorLines[i], "\r \t") - if len(errorText) > 0 { - break - } - } - errorText = fmt.Sprintf("%d:cannot calculate data points;%s; ", rowIdx, errorText) - // errors.WriteString(errorText) - errors.WriteString(fmt.Sprintf("%s\n%s", errorText, getErrorLineNumberInfo(codeBase, rawSectionOutput))) - } else { - // SUCESS code snippet is there, try to get the result JSON - var itemResults map[string]interface{} - jsonString := rawSectionOutput[sectionSuccessPos+len(successMarker):] - err := json.Unmarshal([]byte(jsonString), &itemResults) - if err != nil { - // Bad JSON - errorText := fmt.Sprintf("%d:unexpected error, cannot deserialize results, %s, '%s'", rowIdx, err, jsonString) - errors.WriteString(errorText) - //logText.WriteString(errorText) - } else { - // Success - - // We need to include reader fieldsin the result, writermay use any of them - vars := eval.VarValuesMap{} - if err := rsIn.ExportToVars(rowIdx, &vars); err != nil { - return err - } - - vars[sc.CustomProcessorAlias] = map[string]interface{}{} - - for _, outFieldRef := range *outFieldRefs { - pythonFieldValue, ok := itemResults[outFieldRef.FieldName] - if !ok { - errors.WriteString(fmt.Sprintf("cannot find result for row %d, field %s;", rowIdx, outFieldRef.FieldName)) - } else { - valVolatile, err := pythonResultToRowsetValue(&outFieldRef, pythonFieldValue) - if err != nil { - errors.WriteString(fmt.Sprintf("cannot deserialize result for row %d: %s;", rowIdx, err.Error())) - } else { - vars[sc.CustomProcessorAlias][outFieldRef.FieldName] = valVolatile - } - } - } - - if errors.Len() == 0 { - varsArray[varsArrayCount] = &vars - varsArrayCount++ - if varsArrayCount == len(varsArray) { - if err = flushVarsArray(varsArray, varsArrayCount); err != nil { - return 
fmt.Errorf("error flushing vars array of size %d: %s", varsArrayCount, err.Error()) - } - varsArray = make([]*eval.VarValuesMap, pyCalcFlushBufferSize) - varsArrayCount = 0 - } - } - } - } - } - - if errors.Len() > 0 { - //fmt.Println(fmt.Sprintf("%s\nRaw output below:\n%s\nFull codebase below (may be big):\n%s", logText.String(), rawOutput, codeBase.String())) - return fmt.Errorf(errors.String()) - } else { - //fmt.Println(fmt.Sprintf("%s\nRaw output below:\n%s", logText.String(), rawOutput)) - if varsArrayCount > 0 { - if err := flushVarsArray(varsArray, varsArrayCount); err != nil { - return fmt.Errorf("error flushing leftovers vars array of size %d: %s", varsArrayCount, err.Error()) - } - } - return nil - } -} - -/* -getErrorLineNumbers - shows error lines +/- 5 if error info found in the output -*/ -func getErrorLineNumberInfo(codeBase string, rawErrors string) string { - var errorLineNumberInfo strings.Builder - - reErrLine := regexp.MustCompile(`File "", line ([\d]+)`) - groupMatches := reErrLine.FindAllStringSubmatch(rawErrors, -1) - if len(groupMatches) > 0 { - for matchIdx := 0; matchIdx < len(groupMatches); matchIdx++ { - errLineNum, errAtoi := strconv.Atoi(groupMatches[matchIdx][1]) - if errAtoi != nil { - errorLineNumberInfo.WriteString(fmt.Sprintf("Unexpected error, cannot parse error line number (%s): %s", groupMatches[matchIdx][1], errAtoi)) - } else { - errorLineNumberInfo.WriteString(fmt.Sprintf("Source code lines close to the error location (line %d):\n", errLineNum)) - scanner := bufio.NewScanner(strings.NewReader(codeBase)) - lineNum := 1 - for scanner.Scan() { - if lineNum+15 >= errLineNum && lineNum-15 <= errLineNum { - errorLineNumberInfo.WriteString(fmt.Sprintf("%06d %s\n", lineNum, scanner.Text())) - } - lineNum++ - } - } - } - } else { - errorLineNumberInfo.WriteString(fmt.Sprintf("Unexpected error, cannot find error line number in raw error output %s", rawErrors)) - } - - return errorLineNumberInfo.String() -} - -const FORMULA_MARKER_FUNCTION_DEFINITIONS = "--FMDEF" -const FORMULA_MARKER_DATA_POINTS_INITIALIZATION = "--FMINIT" -const FORMULA_MARKER_CALCULATIONS = "--FMCALC" -const FORMULA_MARKER_SUCCESS = "--FMOK" -const FORMULA_MARKER_END = "--FMEND" - -const ReaderPrefix string = "r_" -const ProcPrefix string = "p_" - -func (procDef *PyCalcProcessorDef) printItemCalculationCode(rowIdx int, rsIn *proc.Rowset) (string, error) { - // Initialize input variables in no particular order - vars := eval.VarValuesMap{} - err := rsIn.ExportToVars(rowIdx, &vars) - if err != nil { - return "", err - } - var bIn strings.Builder - for fieldName, fieldVal := range vars[sc.ReaderAlias] { - bIn.WriteString(fmt.Sprintf(" %s%s = %s\n", ReaderPrefix, fieldName, valueToPythonExpr(fieldVal))) - } - - // Calculation expression order matters (we got it from DAG analysis), so follow it - // for calc data points. 
Also follow it for results JSON (although this is not important) - var bCalc strings.Builder - var bRes strings.Builder - prefixRemover := strings.NewReplacer(fmt.Sprintf("%s.", sc.CustomProcessorAlias), "") - prefixReplacer := strings.NewReplacer(fmt.Sprintf("%s.", sc.ReaderAlias), ReaderPrefix, fmt.Sprintf("%s.", sc.CustomProcessorAlias), ProcPrefix) - for fieldIdx, procFieldWithAlias := range procDef.CalculationOrder { - procField := prefixRemover.Replace(procFieldWithAlias) - bCalc.WriteString(fmt.Sprintf(" %s%s = %s\n", ProcPrefix, procField, prefixReplacer.Replace(procDef.CalculatedFields[procField].RawExpression))) - bRes.WriteString(fmt.Sprintf(" \"%s\":%s%s", procField, ProcPrefix, procField)) - if fieldIdx < len(procDef.CalculationOrder)-1 { - bRes.WriteString(",") - } - } - - const codeBaseSkeleton = ` -print('') -print('%s:%d') -try: -%s - print('%s:%d') -%s - print('%s:%d') - print(json.dumps({%s})) -except: - print(traceback.format_exc()) -print('%s:%d') -` - return fmt.Sprintf(codeBaseSkeleton, - FORMULA_MARKER_DATA_POINTS_INITIALIZATION, rowIdx, - bIn.String(), - FORMULA_MARKER_CALCULATIONS, rowIdx, - bCalc.String(), - FORMULA_MARKER_SUCCESS, rowIdx, - bRes.String(), - FORMULA_MARKER_END, rowIdx), nil -} +package py_calc + +import ( + "bufio" + "encoding/json" + "fmt" + "go/ast" + "regexp" + "strconv" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/xfer" + "github.com/shopspring/decimal" +) + +const ProcessorPyCalcName string = "py_calc" + +type PyCalcEnvSettings struct { + // Windows: `python` or `C:\Users\%USERNAME%\AppData\Local\Programs\Python\Python310\python.exe` + // WSL: `python` or `/mnt/c/Users/myusername/AppData/Local/Programs/Python/Python310/python.exe` + // Linux: `python` + InterpreterPath string `json:"python_interpreter_path"` + // Usually: ["-u", "-"]. -u is essential: without it, we will not see stdout/stderr in the timeout scenario + InterpreterParams []string `json:"python_interpreter_params"` + ExecutionTimeout int `json:"execution_timeout"` // Default 5000 milliseconds +} + +type PyCalcProcessorDef struct { + PythonUrls []string `json:"python_code_urls"` + CalculatedFields map[string]*sc.WriteTableFieldDef `json:"calculated_fields"` + UsedInTargetExpressionsFields sc.FieldRefs + PythonCode string + CalculationOrder []string + EnvSettings PyCalcEnvSettings +} + +func (procDef *PyCalcProcessorDef) GetFieldRefs() *sc.FieldRefs { + fieldRefs := make(sc.FieldRefs, len(procDef.CalculatedFields)) + i := 0 + for fieldName, fieldDef := range procDef.CalculatedFields { + fieldRefs[i] = sc.FieldRef{ + TableName: sc.CustomProcessorAlias, + FieldName: fieldName, + FieldType: fieldDef.Type} + i++ + } + return &fieldRefs +} + +func harvestCallExp(callExp *ast.CallExpr, sigMap map[string]struct{}) error { + // This expression is a python fuction call. + // Build a func_name(arg,arg,...) 
signature for it to check our Python code later + funIdentExp, ok := callExp.Fun.(*ast.Ident) + if !ok { + return fmt.Errorf("cannot cast to ident in harvestCallExp") + } + funcSig := fmt.Sprintf("%s(%s)", funIdentExp.Name, strings.Trim(strings.Repeat("arg,", len(callExp.Args)), ",")) + sigMap[funcSig] = struct{}{} + + for _, argExp := range callExp.Args { + switch typedExp := argExp.(type) { //nolint:all + case *ast.CallExpr: + if err := harvestCallExp(typedExp, sigMap); err != nil { + return err + } + } + } + + return nil +} + +func (procDef *PyCalcProcessorDef) Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error { + var err error + if err = json.Unmarshal(raw, procDef); err != nil { + return fmt.Errorf("cannot unmarshal py_calc processor def: %s", err.Error()) + } + + if err = json.Unmarshal(customProcSettings, &procDef.EnvSettings); err != nil { + return fmt.Errorf("cannot unmarshal py_calc processor env settings: %s", err.Error()) + } + + if len(procDef.EnvSettings.InterpreterPath) == 0 { + return fmt.Errorf("py_calc interpreter path canot be empty") + } + + if procDef.EnvSettings.ExecutionTimeout == 0 { + procDef.EnvSettings.ExecutionTimeout = 5000 + } + + errors := make([]string, 0) + usedPythonFunctionSignatures := map[string]struct{}{} + + // Calculated fields + for _, fieldDef := range procDef.CalculatedFields { + + // Use relaxed Go parser for Python - we are lucky that Go designers liked Python, so we do not have to implement a separate Python partser (for now) + if fieldDef.ParsedExpression, err = sc.ParseRawRelaxedGolangExpressionStringAndHarvestFieldRefs(fieldDef.RawExpression, &fieldDef.UsedFields, sc.FieldRefAllowUnknownIdents); err != nil { + errors = append(errors, fmt.Sprintf("cannot parse field expression [%s]: [%s]", fieldDef.RawExpression, err.Error())) + } else if !sc.IsValidFieldType(fieldDef.Type) { + errors = append(errors, fmt.Sprintf("invalid field type [%s]", fieldDef.Type)) + } + + // Each calculated field expression must be a valid Python expression and either: + // 1. some_python_func() + // 2. some reader field, like r.order_id (will be checked by checkFieldUsageInCustomProcessor) + // 3. some field calculated by this processor, like p.calculatedMargin (will be checked by checkFieldUsageInCustomProcessor) + + // Check top-level expression + switch typedExp := fieldDef.ParsedExpression.(type) { + case *ast.CallExpr: + if err := harvestCallExp(typedExp, usedPythonFunctionSignatures); err != nil { + errors = append(errors, fmt.Sprintf("cannot harvest Python call expressions in %s: %s", fieldDef.RawExpression, err.Error())) + } + case *ast.SelectorExpr: + // Assume it's a reader or calculated field. 
Do not check it here, checkFieldUsageInCustomProcessor() will do that + default: + errors = append(errors, fmt.Sprintf("invalid calculated field expression '%s', expected either 'some_function_from_your_python_code(...)' or some reader field, like 'r.order_id', or some other calculated (by this processor) field, like 'p.calculatedMargin'", fieldDef.RawExpression)) + } + } + + procDef.UsedInTargetExpressionsFields = sc.GetFieldRefsUsedInAllTargetExpressions(procDef.CalculatedFields) + + // Python files + var b strings.Builder + procDef.PythonCode = "" + for _, url := range procDef.PythonUrls { + bytes, err := xfer.GetFileBytes(url, caPath, privateKeys) + if err != nil { + errors = append(errors, err.Error()) + } + b.WriteString(string(bytes)) + b.WriteString("\n") + } + + procDef.PythonCode = b.String() + + if errCheckDefs := checkPythonFuncDefAvailability(usedPythonFunctionSignatures, procDef.PythonCode); errCheckDefs != nil { + errors = append(errors, errCheckDefs.Error()) + } + + // Build a set of "r.inputFieldX" and "p.calculatedFieldY" to perform Python code checks + srcVarsSet := map[string]struct{}{} + for _, fieldRef := range *procDef.GetFieldRefs() { + srcVarsSet[fieldRef.GetAliasHash()] = struct{}{} + } + + // Define calculation sequence + // Populate DAG + dag := map[string][]string{} + for tgtFieldName, tgtFieldDef := range procDef.CalculatedFields { + tgtFieldAlias := fmt.Sprintf("p.%s", tgtFieldName) + dag[tgtFieldAlias] = make([]string, len(tgtFieldDef.UsedFields)) + for i, usedFieldRef := range tgtFieldDef.UsedFields { + dag[tgtFieldAlias][i] = usedFieldRef.GetAliasHash() + } + } + + // Check DAG and return calc fields in the order they should be calculated + if procDef.CalculationOrder, err = kahn(dag); err != nil { + errors = append(errors, fmt.Sprintf("%s. Calc dependency map:\n%v", err, dag)) + } + + // TODO: deserialize other stuff from raw here if needed + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +func (procDef *PyCalcProcessorDef) GetUsedInTargetExpressionsFields() *sc.FieldRefs { + return &procDef.UsedInTargetExpressionsFields +} + +// Python supports microseconds in datetime. Unfortunately, Cassandra supports only milliseconds. Millis are our lingua franca. 
+// So, use only three digits after decimal point +// Python 8601 requires ":" in the timezone +const PythonDatetimeFormat string = "2006-01-02T15:04:05.000-07:00" + +func valueToPythonExpr(val any) string { + switch typedVal := val.(type) { + case int64: + return fmt.Sprintf("%d", typedVal) + case float64: + return fmt.Sprintf("%f", typedVal) + case string: + return fmt.Sprintf("'%s'", strings.ReplaceAll(typedVal, "'", `\'`)) // Use single commas in Python - this may go to logs + case bool: + if typedVal { + return "TRUE" + } else { + return "FALSE" + } + case decimal.Decimal: + return typedVal.String() + case time.Time: + return typedVal.Format(fmt.Sprintf("\"%s\"", PythonDatetimeFormat)) + default: + return fmt.Sprintf("cannot convert '%v(%T)' to Python expression", typedVal, typedVal) + } +} + +func pythonResultToRowsetValue(fieldRef *sc.FieldRef, fieldValue any) (any, error) { + switch fieldRef.FieldType { + case sc.FieldTypeString: + finalVal, ok := fieldValue.(string) + if !ok { + return nil, fmt.Errorf("string %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + return finalVal, nil + case sc.FieldTypeBool: + finalVal, ok := fieldValue.(bool) + if !ok { + return nil, fmt.Errorf("bool %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + return finalVal, nil + case sc.FieldTypeInt: + finalVal, ok := fieldValue.(float64) + if !ok { + return nil, fmt.Errorf("int %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + finalIntVal := int64(finalVal) + return finalIntVal, nil + case sc.FieldTypeFloat: + finalVal, ok := fieldValue.(float64) + if !ok { + return nil, fmt.Errorf("float %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + return finalVal, nil + case sc.FieldTypeDecimal2: + finalVal, ok := fieldValue.(float64) + if !ok { + return nil, fmt.Errorf("decimal %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + finalDecVal := decimal.NewFromFloat(finalVal).Round(2) + return finalDecVal, nil + case sc.FieldTypeDateTime: + finalVal, ok := fieldValue.(string) + if !ok { + return nil, fmt.Errorf("time %s, unexpected type %T(%v)", fieldRef.FieldName, fieldValue, fieldValue) + } + timeVal, err := time.Parse(PythonDatetimeFormat, finalVal) + if err != nil { + return nil, fmt.Errorf("bad time result %s, unexpected format %s", fieldRef.FieldName, finalVal) + } + return timeVal, nil + default: + return nil, fmt.Errorf("unexpected field type %s, %s, %T(%v)", fieldRef.FieldType, fieldRef.FieldName, fieldValue, fieldValue) + } +} + +/* +CheckDefsAvailability - makes sure all expressions mentioned in calc expressions have correspondent Python functions defined. +Pays attention to function name and number of arguments. Expressions may contain: +- calculated field references like p.orderMargin +- number/string/bool constants +- calls to other Python functions +Does NOT perform deep checks for Python function call hierarchy. 
+*/ +func checkPythonFuncDefAvailability(usedPythonFunctionSignatures map[string]struct{}, codeFormulaDefs string) error { + var errors strings.Builder + + // Walk trhough the whole code and collect Python function defs + availableDefSigs := map[string]struct{}{} + reDefSig := regexp.MustCompile(`(?m)^def[ ]+([a-zA-Z0-9_]+)[ ]*([ \(\)a-zA-Z0-9_,\.\"\'\t]+)[ ]*:`) + reArgCommas := regexp.MustCompile("[^)(,]+") + + // Find "def someFunc123(someParam, 'some literal'):" + defSigMatches := reDefSig.FindAllStringSubmatch(codeFormulaDefs, -1) + for _, sigMatch := range defSigMatches { + if strings.ReplaceAll(sigMatch[2], " ", "") == "()" { + // This Python function definition does not accept arguments + availableDefSigs[fmt.Sprintf("%s()", sigMatch[1])] = struct{}{} + } else { + // Strip all characters from the arg list, except commas and parenthesis, + // this will give us the canonical signature with the number of arguments presented by commas: "sum(arg,arg,arg)"" + canonicalSig := fmt.Sprintf("%s%s", sigMatch[1], reArgCommas.ReplaceAllString(sigMatch[2], "arg")) + availableDefSigs[canonicalSig] = struct{}{} + } + } + + // A curated list of Python functions allowed in script expressions. No exec() please. + pythoBuiltinFunctions := map[string]struct{}{ + "str(arg)": {}, + "int(arg)": {}, + "float(arg)": {}, + "round(arg)": {}, + "len(arg)": {}, + "bool(arg)": {}, + "abs(arg)": {}, + } + + // Walk through all Python signatures in calc expressions (we harvested them in Deserialize) + // and make sure correspondent python function defs are available + for usedSig := range usedPythonFunctionSignatures { + if _, ok := availableDefSigs[usedSig]; !ok { + if _, ok := pythoBuiltinFunctions[usedSig]; !ok { + errors.WriteString(fmt.Sprintf("function def '%s' not found in Python file, and it's not in the list of allowed Python built-in functions; ", usedSig)) + } + } + } + + if errors.Len() > 0 { + var defs strings.Builder + for defSig := range availableDefSigs { + defs.WriteString(fmt.Sprintf("%s; ", defSig)) + } + return fmt.Errorf("Python function defs availability check failed, the following functions are not defined: [%s]. 
Full list of available Python function definitions: %s", errors.String(), defs.String()) + } + + return nil +} + +func kahn(depMap map[string][]string) ([]string, error) { + inDegreeMap := make(map[string]int) + for node := range depMap { + if depMap[node] != nil { + for _, v := range depMap[node] { + inDegreeMap[v]++ + } + } + } + + var queue []string + for node := range depMap { + if _, ok := inDegreeMap[node]; !ok { + queue = append(queue, node) + } + } + + var execOrder []string + for len(queue) > 0 { + node := queue[len(queue)-1] + queue = queue[:(len(queue) - 1)] + // Prepend "node" to execOrder only if it's the key in depMap + // (this is not one of the algorithm requirements, + // it's just our caller needs only those elements) + if _, ok := depMap[node]; ok { + execOrder = append(execOrder, "") + copy(execOrder[1:], execOrder) + execOrder[0] = node + } + for _, v := range depMap[node] { + inDegreeMap[v]-- + if inDegreeMap[v] == 0 { + queue = append(queue, v) + } + } + } + + for _, inDegree := range inDegreeMap { + if inDegree > 0 { + return []string{}, fmt.Errorf("Formula expressions have cycle(s)") + } + } + + return execOrder, nil +} + +func (procDef *PyCalcProcessorDef) buildPythonCodebaseFromRowset(rsIn *proc.Rowset) (string, error) { + // Build a massive Python source: function defs, dp init, calculation, result json + // This is the hardcoded Python code structure we rely on. Do not change it. + var codeBase strings.Builder + + codeBase.WriteString(fmt.Sprintf(` +import traceback +import json +print("\n%s") # Provide function defs +%s +`, + FORMULA_MARKER_FUNCTION_DEFINITIONS, + procDef.PythonCode)) + + for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { + itemCalculationCodebase, err := procDef.printItemCalculationCode(rowIdx, rsIn) + if err != nil { + return "", err + } + codeBase.WriteString(fmt.Sprintf("%s\n", itemCalculationCodebase)) + } + + return codeBase.String(), nil +} + +func (procDef *PyCalcProcessorDef) analyseExecError(codeBase string, rawOutput string, rawErrors string, err error) (string, error) { + // Linux: err.Error():'exec: "python333": executable file not found in $PATH' + // Windows: err.Error():'exec: "C:\\Program Files\\Python3badpath\\python.exe": file does not exist' + // MacOS/WSL: err.Error():'fork/exec /mnt/c/Users/myusername/AppData/Local/Programs/Python/Python310/python.exe: no such file or directory' + if strings.Contains(err.Error(), "file not found") || + strings.Contains(err.Error(), "file does not exist") || + strings.Contains(err.Error(), "no such file or directory") { + return "", fmt.Errorf("interpreter binary not found: %s", procDef.EnvSettings.InterpreterPath) + } else if strings.Contains(err.Error(), "exit status") { + // err.Error():'exit status 1', cmdCtx.Err():'%!s()' + // Python interpreter reported an error: there was a syntax error in the codebase and no results were returned + fullErrorInfo := fmt.Sprintf("interpreter returned an error (probably syntax):\n%s %v\n%s\n%s\n%s", + procDef.EnvSettings.InterpreterPath, procDef.EnvSettings.InterpreterParams, + rawOutput, + rawErrors, + getErrorLineNumberInfo(codeBase, rawErrors)) + + // fmt.Println(fullErrorInfo) + return fullErrorInfo, fmt.Errorf("interpreter returned an error (probably syntax), see log for details: %s", rawErrors) + } + return "", fmt.Errorf("unexpected calculation errors: %s", rawErrors) +} + +const pyCalcFlushBufferSize int = 1000 + +func (procDef *PyCalcProcessorDef) analyseExecSuccess(codeBase string, rawOutput string, _ string, outFieldRefs *sc.FieldRefs, 
rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { + // No Python interpreter errors, but there may be runtime errors and good results. + // Timeout error may be there too. + + var errors strings.Builder + varsArray := make([]*eval.VarValuesMap, pyCalcFlushBufferSize) + varsArrayCount := 0 + + sectionEndPos := 0 + for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { + startMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_DATA_POINTS_INITIALIZATION, rowIdx) + endMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_END, rowIdx) + relSectionStartPos := strings.Index(rawOutput[sectionEndPos:], startMarker) + relSectionEndPos := strings.Index(rawOutput[sectionEndPos:], endMarker) + sectionStartPos := sectionEndPos + relSectionStartPos + if sectionStartPos == -1 { + return fmt.Errorf("%d: unexpected, cannot find calculation start marker %s;", rowIdx, startMarker) + } + sectionEndPos = sectionEndPos + relSectionEndPos + if sectionEndPos == -1 { + return fmt.Errorf("%d: unexpected, cannot find calculation end marker %s;", rowIdx, endMarker) + } + if sectionStartPos > sectionEndPos { + return fmt.Errorf("%d: unexpected, end marker %s(%d) is earlier than start marker %s(%d);", rowIdx, endMarker, sectionStartPos, endMarker, sectionEndPos) + } + + rawSectionOutput := rawOutput[sectionStartPos:sectionEndPos] + successMarker := fmt.Sprintf("%s:%d", FORMULA_MARKER_SUCCESS, rowIdx) + sectionSuccessPos := strings.Index(rawSectionOutput, successMarker) + if sectionSuccessPos == -1 { + // There was an error calculating fields for this item + // Assume the last line of the out is the error + errorLines := strings.Split(rawSectionOutput, "\n") + errorText := "" + for i := len(errorLines) - 1; i >= 0; i-- { + errorText = strings.Trim(errorLines[i], "\r \t") + if len(errorText) > 0 { + break + } + } + errorText = fmt.Sprintf("%d:cannot calculate data points;%s; ", rowIdx, errorText) + // errors.WriteString(errorText) + errors.WriteString(fmt.Sprintf("%s\n%s", errorText, getErrorLineNumberInfo(codeBase, rawSectionOutput))) + } else { + // SUCCESS code snippet is there, try to get the result JSON + var itemResults map[string]any + jsonString := rawSectionOutput[sectionSuccessPos+len(successMarker):] + err := json.Unmarshal([]byte(jsonString), &itemResults) + if err != nil { + // Bad JSON + errorText := fmt.Sprintf("%d:unexpected error, cannot deserialize results, %s, '%s'", rowIdx, err, jsonString) + errors.WriteString(errorText) + // logText.WriteString(errorText) + } else { + // Success + + // We need to include reader fieldsin the result, writermay use any of them + vars := eval.VarValuesMap{} + if err := rsIn.ExportToVars(rowIdx, &vars); err != nil { + return err + } + + vars[sc.CustomProcessorAlias] = map[string]any{} + + for _, outFieldRef := range *outFieldRefs { + pythonFieldValue, ok := itemResults[outFieldRef.FieldName] + if !ok { + errors.WriteString(fmt.Sprintf("cannot find result for row %d, field %s;", rowIdx, outFieldRef.FieldName)) + } else { + valVolatile, err := pythonResultToRowsetValue(&outFieldRef, pythonFieldValue) + if err != nil { + errors.WriteString(fmt.Sprintf("cannot deserialize result for row %d: %s;", rowIdx, err.Error())) + } else { + vars[sc.CustomProcessorAlias][outFieldRef.FieldName] = valVolatile + } + } + } + + if errors.Len() == 0 { + varsArray[varsArrayCount] = &vars + varsArrayCount++ + if varsArrayCount == len(varsArray) { + if err = flushVarsArray(varsArray, varsArrayCount); err != nil { + return fmt.Errorf("error flushing vars 
array of size %d: %s", varsArrayCount, err.Error()) + } + varsArray = make([]*eval.VarValuesMap, pyCalcFlushBufferSize) + varsArrayCount = 0 + } + } + } + } + } + + if errors.Len() > 0 { + // fmt.Println(fmt.Sprintf("%s\nRaw output below:\n%s\nFull codebase below (may be big):\n%s", logText.String(), rawOutput, codeBase.String())) + return fmt.Errorf(errors.String()) + } + + // fmt.Println(fmt.Sprintf("%s\nRaw output below:\n%s", logText.String(), rawOutput)) + if varsArrayCount > 0 { + if err := flushVarsArray(varsArray, varsArrayCount); err != nil { + return fmt.Errorf("error flushing leftovers vars array of size %d: %s", varsArrayCount, err.Error()) + } + } + return nil +} + +/* +getErrorLineNumbers - shows error lines +/- 5 if error info found in the output +*/ +func getErrorLineNumberInfo(codeBase string, rawErrors string) string { + var errorLineNumberInfo strings.Builder + + reErrLine := regexp.MustCompile(`File "", line ([\d]+)`) + groupMatches := reErrLine.FindAllStringSubmatch(rawErrors, -1) + if len(groupMatches) > 0 { + for matchIdx := 0; matchIdx < len(groupMatches); matchIdx++ { + errLineNum, errAtoi := strconv.Atoi(groupMatches[matchIdx][1]) + if errAtoi != nil { + errorLineNumberInfo.WriteString(fmt.Sprintf("Unexpected error, cannot parse error line number (%s): %s", groupMatches[matchIdx][1], errAtoi)) + } else { + errorLineNumberInfo.WriteString(fmt.Sprintf("Source code lines close to the error location (line %d):\n", errLineNum)) + scanner := bufio.NewScanner(strings.NewReader(codeBase)) + lineNum := 1 + for scanner.Scan() { + if lineNum+15 >= errLineNum && lineNum-15 <= errLineNum { + errorLineNumberInfo.WriteString(fmt.Sprintf("%06d %s\n", lineNum, scanner.Text())) + } + lineNum++ + } + } + } + } else { + errorLineNumberInfo.WriteString(fmt.Sprintf("Unexpected error, cannot find error line number in raw error output %s", rawErrors)) + } + + return errorLineNumberInfo.String() +} + +const FORMULA_MARKER_FUNCTION_DEFINITIONS = "--FMDEF" +const FORMULA_MARKER_DATA_POINTS_INITIALIZATION = "--FMINIT" +const FORMULA_MARKER_CALCULATIONS = "--FMCALC" +const FORMULA_MARKER_SUCCESS = "--FMOK" +const FORMULA_MARKER_END = "--FMEND" + +const ReaderPrefix string = "r_" +const ProcPrefix string = "p_" + +func (procDef *PyCalcProcessorDef) printItemCalculationCode(rowIdx int, rsIn *proc.Rowset) (string, error) { + // Initialize input variables in no particular order + vars := eval.VarValuesMap{} + err := rsIn.ExportToVars(rowIdx, &vars) + if err != nil { + return "", err + } + var bIn strings.Builder + for fieldName, fieldVal := range vars[sc.ReaderAlias] { + bIn.WriteString(fmt.Sprintf(" %s%s = %s\n", ReaderPrefix, fieldName, valueToPythonExpr(fieldVal))) + } + + // Calculation expression order matters (we got it from DAG analysis), so follow it + // for calc data points. 
Also follow it for results JSON (although this is not important) + var bCalc strings.Builder + var bRes strings.Builder + prefixRemover := strings.NewReplacer(fmt.Sprintf("%s.", sc.CustomProcessorAlias), "") + prefixReplacer := strings.NewReplacer(fmt.Sprintf("%s.", sc.ReaderAlias), ReaderPrefix, fmt.Sprintf("%s.", sc.CustomProcessorAlias), ProcPrefix) + for fieldIdx, procFieldWithAlias := range procDef.CalculationOrder { + procField := prefixRemover.Replace(procFieldWithAlias) + bCalc.WriteString(fmt.Sprintf(" %s%s = %s\n", ProcPrefix, procField, prefixReplacer.Replace(procDef.CalculatedFields[procField].RawExpression))) + bRes.WriteString(fmt.Sprintf(" \"%s\":%s%s", procField, ProcPrefix, procField)) + if fieldIdx < len(procDef.CalculationOrder)-1 { + bRes.WriteString(",") + } + } + + const codeBaseSkeleton = ` +print('') +print('%s:%d') +try: +%s + print('%s:%d') +%s + print('%s:%d') + print(json.dumps({%s})) +except: + print(traceback.format_exc()) +print('%s:%d') +` + return fmt.Sprintf(codeBaseSkeleton, + FORMULA_MARKER_DATA_POINTS_INITIALIZATION, rowIdx, + bIn.String(), + FORMULA_MARKER_CALCULATIONS, rowIdx, + bCalc.String(), + FORMULA_MARKER_SUCCESS, rowIdx, + bRes.String(), + FORMULA_MARKER_END, rowIdx), nil +} diff --git a/pkg/custom/py_calc/py_calc_test.go b/pkg/custom/py_calc/py_calc_test.go index f7b0e6b..c50a329 100644 --- a/pkg/custom/py_calc/py_calc_test.go +++ b/pkg/custom/py_calc/py_calc_test.go @@ -1,388 +1,408 @@ -package py_calc - -import ( - "encoding/json" - "fmt" - "regexp" - "strings" - "testing" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -type PyCalcTestTestProcessorDefFactory struct { -} - -func (f *PyCalcTestTestProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { - switch processorType { - case ProcessorPyCalcName: - return &PyCalcProcessorDef{}, true - default: - return nil, false - } -} - -const scriptJson string = ` -{ - "nodes": { - "read_table1": { - "type": "file_table", - "r": { - "urls": [ - "file1.csv" - ], - "csv":{ - "first_data_line_idx": 0 - }, - "columns": { - "col_field_int": { - "csv":{"col_idx": 0}, - "col_type": "int" - } - } - }, - "w": { - "name": "table1", - "having": "w.field_int1 > 1", - "fields": { - "field_int1": { - "expression": "r.col_field_int", - "type": "int" - }, - "field_float1": { - "expression": "float(r.col_field_int)", - "type": "float" - }, - "field_decimal1": { - "expression": "decimal2(r.col_field_int)", - "type": "decimal2" - }, - "field_string1": { - "expression": "string(r.col_field_int)", - "type": "string" - }, - "field_dt1": { - "expression": "time.Date(2000, time.January, 1, 0, 0, 0, 0, time.FixedZone(\"\", -7200))", - "type": "datetime" - } - } - } - }, - "tax_table1": { - "type": "table_custom_tfm_table", - "custom_proc_type": "py_calc", - "r": { - "table": "table1" - }, - "p": { - "python_code_urls": [ - "../../../test/data/cfg/py_calc_quicktest/py/calc_order_items_code.py" - ], - "calculated_fields": { - "taxed_field_int1": { - "expression": "increase_by_ten_percent(increase_by_ten_percent(r.field_int1))", - "type": "int" - }, - "taxed_field_float1": { - "expression": "increase_by_ten_percent(r.field_float1)", - "type": "float" - }, - "taxed_field_string1": { - "expression": "str(increase_by_ten_percent(float(r.field_string1)))", - "type": "string" - }, - "taxed_field_decimal1": { - 
"expression": "increase_by_ten_percent(r.field_decimal1)", - "type": "decimal2" - }, - "taxed_field_bool1": { - "expression": "bool(r.field_int1)", - "type": "bool" - }, - "taxed_field_dt1": { - "expression": "r.field_dt1", - "type": "datetime" - } - } - }, - "w": { - "name": "taxed_table1", - "having": "w.taxed_field_decimal > 10", - "fields": { - "field_int1": { - "expression": "p.taxed_field_int1", - "type": "int" - }, - "field_float1": { - "expression": "p.taxed_field_float1", - "type": "float" - }, - "field_string1": { - "expression": "p.taxed_field_string1", - "type": "string" - }, - "taxed_field_decimal": { - "expression": "decimal2(p.taxed_field_float1)", - "type": "decimal2" - }, - "taxed_field_bool": { - "expression": "p.taxed_field_bool1", - "type": "bool" - }, - "taxed_field_dt": { - "expression": "p.taxed_field_dt1", - "type": "datetime" - } - } - } - }, - "file_taxed_table1": { - "type": "table_file", - "r": { - "table": "taxed_table1" - }, - "w": { - "top": { - "order": "taxed_field_int1(asc)" - }, - "url_template": "taxed_table1.csv", - "columns": [ - { - "csv":{ - "header": "field_int1", - "format": "%d" - }, - "name": "field_int1", - "expression": "r.field_int1", - "type": "int" - }, - { - "csv":{ - "header": "field_string1", - "format": "%s" - }, - "name": "field_string1", - "expression": "r.field_string1", - "type": "string" - }, - { - "csv":{ - "header": "taxed_field_decimal", - "format": "%s" - }, - "name": "taxed_field_decimal", - "expression": "r.taxed_field_decimal", - "type": "decimal2" - } - ] - } - } - }, - "dependency_policies": { - "current_active_first_stopped_nogo":` + sc.DefaultPolicyCheckerConf + - ` - } -}` - -const envSettings string = ` -{ - "python_interpreter_path": "/some/bad/python/path", - "python_interpreter_params": ["-u", "-"] -}` - -func TestPyCalcDefCalculator(t *testing.T) { - scriptDef := &sc.ScriptDef{} - err := scriptDef.Deserialize([]byte(scriptJson), &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) - assert.Nil(t, err) - - // Initializing rowset is tedious and error-prone. Add schema first. - rs := proc.NewRowsetFromFieldRefs(sc.FieldRefs{ - {TableName: "r", FieldName: "field_int1", FieldType: sc.FieldTypeInt}, - {TableName: "r", FieldName: "field_float1", FieldType: sc.FieldTypeFloat}, - {TableName: "r", FieldName: "field_decimal1", FieldType: sc.FieldTypeDecimal2}, - {TableName: "r", FieldName: "field_string1", FieldType: sc.FieldTypeString}, - {TableName: "r", FieldName: "field_bool1", FieldType: sc.FieldTypeBool}, - {TableName: "r", FieldName: "field_dt1", FieldType: sc.FieldTypeDateTime}, - }) - - // Allocate rows - rs.InitRows(1) - - // Initialize with pointers - i := int64(235) - (*rs.Rows[0])[0] = &i - f := float64(236) - (*rs.Rows[0])[1] = &f - d := decimal.NewFromFloat(237) - (*rs.Rows[0])[2] = &d - s := "238" - (*rs.Rows[0])[3] = &s - b := true - (*rs.Rows[0])[4] = &b - dt := time.Date(2002, 2, 2, 2, 2, 2, 0, time.FixedZone("", -7200)) - (*rs.Rows[0])[5] = &dt - - // Tell it we wrote something to [0] - rs.RowCount++ - - // PyCalcProcessorDef implements both sc.CustomProcessorDef and proc.CustomProcessorRunner. - // We only need the sc.CustomProcessorDef part here, no plans to run Python as part of the unit testing process. 
- pyCalcProcDef := scriptDef.ScriptNodes["tax_table1"].CustomProcessor.(sc.CustomProcessorDef).(*PyCalcProcessorDef) - - codeBase, err := pyCalcProcDef.buildPythonCodebaseFromRowset(rs) - assert.Nil(t, err) - assert.Contains(t, codeBase, "r_field_int1 = 235") - assert.Contains(t, codeBase, "r_field_float1 = 236.000000") - assert.Contains(t, codeBase, "r_field_decimal1 = 237") - assert.Contains(t, codeBase, "r_field_string1 = '238'") - assert.Contains(t, codeBase, "r_field_bool1 = TRUE") - assert.Contains(t, codeBase, "r_field_dt1 = \"2002-02-02T02:02:02.000-02:00\"") // Capillaries official PythonDatetimeFormat - assert.Contains(t, codeBase, "p_taxed_field_int1 = increase_by_ten_percent(increase_by_ten_percent(r_field_int1))") - assert.Contains(t, codeBase, "p_taxed_field_float1 = increase_by_ten_percent(r_field_float1)") - assert.Contains(t, codeBase, "p_taxed_field_decimal1 = increase_by_ten_percent(r_field_decimal1)") - assert.Contains(t, codeBase, "p_taxed_field_string1 = str(increase_by_ten_percent(float(r_field_string1)))") - assert.Contains(t, codeBase, "p_taxed_field_bool1 = bool(r_field_int1)") - assert.Contains(t, codeBase, "p_taxed_field_dt1 = r_field_dt1") - - // Interpreter executable returns an error - - _, err = pyCalcProcDef.analyseExecError(codeBase, "", "", fmt.Errorf("file not found")) - assert.Equal(t, "interpreter binary not found: /some/bad/python/path", err.Error()) - - _, err = pyCalcProcDef.analyseExecError(codeBase, "", "rawErrors", fmt.Errorf("exit status")) - assert.Equal(t, "interpreter returned an error (probably syntax), see log for details: rawErrors", err.Error()) - - _, err = pyCalcProcDef.analyseExecError(codeBase, "", "unknown raw errors", fmt.Errorf("unexpected error")) - assert.Equal(t, "unexpected calculation errors: unknown raw errors", err.Error()) - - // Interpreter ok, analyse output - - // Test flusher, doesn't write anywhere, just saves data in the local variable - var results []*eval.VarValuesMap - flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { - results = varsArray - return nil - } - - // Some error was caught by Python try/catch, it's in the raw output, analyse it - - err = pyCalcProcDef.analyseExecSuccess(codeBase, "", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Equal(t, "0: unexpected, cannot find calculation start marker --FMINIT:0;", err.Error()) - - err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMINIT:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Equal(t, "0: unexpected, cannot find calculation end marker --FMEND:0;", err.Error()) - - err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMEND:0\n--FMINIT:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Equal(t, "0: unexpected, end marker --FMEND:0(10) is earlier than start marker --FMEND:0(0);", err.Error()) - - err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMINIT:0\n--FMEND:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Equal(t, "0:cannot calculate data points;--FMINIT:0; \nUnexpected error, cannot find error line number in raw error output --FMINIT:0\n", err.Error()) - - rawOutput := - ` ---FMINIT:0 -Traceback (most recent call last): - File "", line 1, in - s = Something() -NameError: name 'Something' is not defined ---FMEND:0 -` - err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Contains(t, err.Error(), "0:cannot calculate data points;NameError: name 'Something' is not defined; \nSource 
code lines close to the error location (line 1):\n000001 \n000002 import traceback") - - rawOutput = - ` ---FMINIT:0 -Traceback (most recent call last): - File "some_invalid_file_path", line 1, in - s = Something() -NameError: name 'Something' is not defined ---FMEND:0 -` - err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Contains(t, err.Error(), "0:cannot calculate data points;NameError: name 'Something' is not defined; \nUnexpected error, cannot find error line number in raw error output") - - // No error from Python try/catch, get the results from raw output - - rawOutput = - ` ---FMINIT:0 ---FMOK:0 -bla ---FMEND:0 -` - err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Contains(t, err.Error(), "0:unexpected error, cannot deserialize results, invalid character 'b' looking for beginning of value, '\nbla\n'") - - rawOutput = - ` ---FMINIT:0 ---FMOK:0 -{"taxed_field_float1":2.2,"taxed_field_string1":"aaa","taxed_field_decimal1":3.3,"taxed_field_bool1":true,"taxed_field_int1":1} ---FMEND:0 -` - err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Contains(t, err.Error(), "cannot find result for row 0, field taxed_field_dt1;") - - rawOutput = - ` ---FMINIT:0 ---FMOK:0 -{"taxed_field_float1":2.2,"taxed_field_string1":"aaa","taxed_field_decimal1":3.3,"taxed_field_bool1":true,"taxed_field_int1":1,"taxed_field_dt1":"2003-03-03T03:03:03.000-02:00"} ---FMEND:0 -` - err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) - assert.Nil(t, err) - flushedRow := *results[0] - // r fields must be present in the result, they can be used by the writer - assert.Equal(t, i, flushedRow["r"]["field_int1"]) - assert.Equal(t, f, flushedRow["r"]["field_float1"]) - assert.Equal(t, d, flushedRow["r"]["field_decimal1"]) - assert.Equal(t, s, flushedRow["r"]["field_string1"]) - assert.Equal(t, b, flushedRow["r"]["field_bool1"]) - assert.Equal(t, dt, flushedRow["r"]["field_dt1"]) - // p field must be in the result - assert.Equal(t, int64(1), flushedRow["p"]["taxed_field_int1"]) - assert.Equal(t, 2.2, flushedRow["p"]["taxed_field_float1"]) - assert.True(t, decimal.NewFromFloat(3.3).Equal(flushedRow["p"]["taxed_field_decimal1"].(decimal.Decimal))) - assert.Equal(t, "aaa", flushedRow["p"]["taxed_field_string1"]) - assert.Equal(t, true, flushedRow["p"]["taxed_field_bool1"]) - assert.Equal(t, time.Date(2003, 3, 3, 3, 3, 3, 0, time.FixedZone("", -7200)), flushedRow["p"]["taxed_field_dt1"]) -} - -func TestPyCalcDefBadScript(t *testing.T) { - - scriptDef := &sc.ScriptDef{} - err := scriptDef.Deserialize( - []byte(strings.Replace(scriptJson, `"having": "w.taxed_field_decimal > 10"`, `"having": "p.taxed_field_int1 > 10"`, 1)), - &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) - assert.Contains(t, err.Error(), "prohibited field p.taxed_field_int1") - - err = scriptDef.Deserialize( - []byte(strings.Replace(scriptJson, `increase_by_ten_percent(r.field_int1)`, `bad_func(r.field_int1)`, 1)), - &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) - assert.Contains(t, err.Error(), "function def 'bad_func(arg)' not found in Python file") - - re := regexp.MustCompile(`"python_code_urls": \[[^\]]+\]`) - err = scriptDef.Deserialize( - 
[]byte(re.ReplaceAllString(scriptJson, `"python_code_urls":[123]`)), - &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) - assert.Contains(t, err.Error(), "cannot unmarshal py_calc processor def") - - re = regexp.MustCompile(`"python_interpreter_path": "[^"]+"`) - err = scriptDef.Deserialize( - []byte(scriptJson), - &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(re.ReplaceAllString(envSettings, `"python_interpreter_path": 123`))}, "", nil) - assert.Contains(t, err.Error(), "cannot unmarshal py_calc processor env settings") - - err = scriptDef.Deserialize( - []byte(scriptJson), - &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(re.ReplaceAllString(envSettings, `"python_interpreter_path": ""`))}, "", nil) - assert.Contains(t, err.Error(), "py_calc interpreter path canot be empty") - -} +package py_calc + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + "testing" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +type PyCalcTestTestProcessorDefFactory struct { +} + +func (f *PyCalcTestTestProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { + switch processorType { + case ProcessorPyCalcName: + return &PyCalcProcessorDef{}, true + default: + return nil, false + } +} + +const scriptJson string = ` +{ + "nodes": { + "read_table1": { + "type": "file_table", + "r": { + "urls": [ + "file1.csv" + ], + "csv":{ + "first_data_line_idx": 0 + }, + "columns": { + "col_field_int": { + "csv":{"col_idx": 0}, + "col_type": "int" + } + } + }, + "w": { + "name": "table1", + "having": "w.field_int1 > 1", + "fields": { + "field_int1": { + "expression": "r.col_field_int", + "type": "int" + }, + "field_float1": { + "expression": "float(r.col_field_int)", + "type": "float" + }, + "field_decimal1": { + "expression": "decimal2(r.col_field_int)", + "type": "decimal2" + }, + "field_string1": { + "expression": "string(r.col_field_int)", + "type": "string" + }, + "field_dt1": { + "expression": "time.Date(2000, time.January, 1, 0, 0, 0, 0, time.FixedZone(\"\", -7200))", + "type": "datetime" + } + } + } + }, + "tax_table1": { + "type": "table_custom_tfm_table", + "custom_proc_type": "py_calc", + "r": { + "table": "table1" + }, + "p": { + "python_code_urls": [ + "../../../test/data/cfg/py_calc_quicktest/py/calc_order_items_code.py" + ], + "calculated_fields": { + "taxed_field_int1": { + "expression": "increase_by_ten_percent(increase_by_ten_percent(r.field_int1))", + "type": "int" + }, + "taxed_field_float1": { + "expression": "increase_by_ten_percent(r.field_float1)", + "type": "float" + }, + "taxed_field_string1": { + "expression": "str(increase_by_ten_percent(float(r.field_string1)))", + "type": "string" + }, + "taxed_field_decimal1": { + "expression": "increase_by_ten_percent(r.field_decimal1)", + "type": "decimal2" + }, + "taxed_field_bool1": { + "expression": "bool(r.field_int1)", + "type": "bool" + }, + "taxed_field_dt1": { + "expression": "r.field_dt1", + "type": "datetime" + } + } + }, + "w": { + "name": "taxed_table1", + "having": "w.taxed_field_decimal > 10", + "fields": { + "field_int1": { + "expression": "p.taxed_field_int1", + "type": "int" + }, + "field_float1": { + "expression": "p.taxed_field_float1", + "type": "float" + }, + "field_string1": { + 
"expression": "p.taxed_field_string1", + "type": "string" + }, + "taxed_field_decimal": { + "expression": "decimal2(p.taxed_field_float1)", + "type": "decimal2" + }, + "taxed_field_bool": { + "expression": "p.taxed_field_bool1", + "type": "bool" + }, + "taxed_field_dt": { + "expression": "p.taxed_field_dt1", + "type": "datetime" + } + } + } + }, + "file_taxed_table1": { + "type": "table_file", + "r": { + "table": "taxed_table1" + }, + "w": { + "top": { + "order": "field_int1(asc)" + }, + "url_template": "taxed_table1.csv", + "columns": [ + { + "csv":{ + "header": "field_int1", + "format": "%d" + }, + "name": "field_int1", + "expression": "r.field_int1", + "type": "int" + }, + { + "csv":{ + "header": "field_string1", + "format": "%s" + }, + "name": "field_string1", + "expression": "r.field_string1", + "type": "string" + }, + { + "csv":{ + "header": "taxed_field_decimal", + "format": "%s" + }, + "name": "taxed_field_decimal", + "expression": "r.taxed_field_decimal", + "type": "decimal2" + } + ] + } + } + }, + "dependency_policies": { + "current_active_first_stopped_nogo":` + sc.DefaultPolicyCheckerConf + + ` + } +}` + +const envSettings string = ` +{ + "python_interpreter_path": "/some/bad/python/path", + "python_interpreter_params": ["-u", "-"] +}` + +func TestPyCalcDefCalculator(t *testing.T) { + scriptDef := &sc.ScriptDef{} + err := scriptDef.Deserialize([]byte(scriptJson), &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) + assert.Nil(t, err) + + // Initializing rowset is tedious and error-prone. Add schema first. + rs := proc.NewRowsetFromFieldRefs(sc.FieldRefs{ + {TableName: "r", FieldName: "field_int1", FieldType: sc.FieldTypeInt}, + {TableName: "r", FieldName: "field_float1", FieldType: sc.FieldTypeFloat}, + {TableName: "r", FieldName: "field_decimal1", FieldType: sc.FieldTypeDecimal2}, + {TableName: "r", FieldName: "field_string1", FieldType: sc.FieldTypeString}, + {TableName: "r", FieldName: "field_bool1", FieldType: sc.FieldTypeBool}, + {TableName: "r", FieldName: "field_dt1", FieldType: sc.FieldTypeDateTime}, + }) + + // Allocate rows + assert.Nil(t, rs.InitRows(1)) + + // Initialize with pointers + i := int64(235) + (*rs.Rows[0])[0] = &i + f := float64(236) + (*rs.Rows[0])[1] = &f + d := decimal.NewFromFloat(237) + (*rs.Rows[0])[2] = &d + s := "238" + (*rs.Rows[0])[3] = &s + b := true + (*rs.Rows[0])[4] = &b + dt := time.Date(2002, 2, 2, 2, 2, 2, 0, time.FixedZone("", -7200)) + (*rs.Rows[0])[5] = &dt + + // Tell it we wrote something to [0] + rs.RowCount++ + + // PyCalcProcessorDef implements both sc.CustomProcessorDef and proc.CustomProcessorRunner. + // We only need the sc.CustomProcessorDef part here, no plans to run Python as part of the unit testing process. 
+ pyCalcProcDef, ok := scriptDef.ScriptNodes["tax_table1"].CustomProcessor.(*PyCalcProcessorDef) + assert.True(t, ok) + + codeBase, err := pyCalcProcDef.buildPythonCodebaseFromRowset(rs) + assert.Nil(t, err) + assert.Contains(t, codeBase, "r_field_int1 = 235") + assert.Contains(t, codeBase, "r_field_float1 = 236.000000") + assert.Contains(t, codeBase, "r_field_decimal1 = 237") + assert.Contains(t, codeBase, "r_field_string1 = '238'") + assert.Contains(t, codeBase, "r_field_bool1 = TRUE") + assert.Contains(t, codeBase, "r_field_dt1 = \"2002-02-02T02:02:02.000-02:00\"") // Capillaries official PythonDatetimeFormat + assert.Contains(t, codeBase, "p_taxed_field_int1 = increase_by_ten_percent(increase_by_ten_percent(r_field_int1))") + assert.Contains(t, codeBase, "p_taxed_field_float1 = increase_by_ten_percent(r_field_float1)") + assert.Contains(t, codeBase, "p_taxed_field_decimal1 = increase_by_ten_percent(r_field_decimal1)") + assert.Contains(t, codeBase, "p_taxed_field_string1 = str(increase_by_ten_percent(float(r_field_string1)))") + assert.Contains(t, codeBase, "p_taxed_field_bool1 = bool(r_field_int1)") + assert.Contains(t, codeBase, "p_taxed_field_dt1 = r_field_dt1") + + // Interpreter executable returns an error + + _, err = pyCalcProcDef.analyseExecError(codeBase, "", "", fmt.Errorf("file not found")) + assert.Equal(t, "interpreter binary not found: /some/bad/python/path", err.Error()) + + _, err = pyCalcProcDef.analyseExecError(codeBase, "", "rawErrors", fmt.Errorf("exit status")) + assert.Equal(t, "interpreter returned an error (probably syntax), see log for details: rawErrors", err.Error()) + + _, err = pyCalcProcDef.analyseExecError(codeBase, "", "unknown raw errors", fmt.Errorf("unexpected error")) + assert.Equal(t, "unexpected calculation errors: unknown raw errors", err.Error()) + + // Interpreter ok, analyse output + + // Test flusher, doesn't write anywhere, just saves data in the local variable + var results []*eval.VarValuesMap + flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { + results = varsArray + return nil + } + + // Some error was caught by Python try/catch, it's in the raw output, analyse it + + err = pyCalcProcDef.analyseExecSuccess(codeBase, "", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Equal(t, "0: unexpected, cannot find calculation start marker --FMINIT:0;", err.Error()) + + err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMINIT:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Equal(t, "0: unexpected, cannot find calculation end marker --FMEND:0;", err.Error()) + + err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMEND:0\n--FMINIT:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Equal(t, "0: unexpected, end marker --FMEND:0(10) is earlier than start marker --FMEND:0(0);", err.Error()) + + err = pyCalcProcDef.analyseExecSuccess(codeBase, "--FMINIT:0\n--FMEND:0", "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Equal(t, "0:cannot calculate data points;--FMINIT:0; \nUnexpected error, cannot find error line number in raw error output --FMINIT:0\n", err.Error()) + + rawOutput := + ` +--FMINIT:0 +Traceback (most recent call last): + File "", line 1, in + s = Something() +NameError: name 'Something' is not defined +--FMEND:0 +` + err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Contains(t, err.Error(), "0:cannot calculate data points;NameError: name 'Something' is not defined; \nSource 
code lines close to the error location (line 1):\n000001 \n000002 import traceback") + + rawOutput = + ` +--FMINIT:0 +Traceback (most recent call last): + File "some_invalid_file_path", line 1, in + s = Something() +NameError: name 'Something' is not defined +--FMEND:0 +` + err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Contains(t, err.Error(), "0:cannot calculate data points;NameError: name 'Something' is not defined; \nUnexpected error, cannot find error line number in raw error output") + + // No error from Python try/catch, get the results from raw output + + rawOutput = + ` +--FMINIT:0 +--FMOK:0 +bla +--FMEND:0 +` + err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Contains(t, err.Error(), "0:unexpected error, cannot deserialize results, invalid character 'b' looking for beginning of value, '\nbla\n'") + + rawOutput = + ` +--FMINIT:0 +--FMOK:0 +{"taxed_field_float1":2.2,"taxed_field_string1":"aaa","taxed_field_decimal1":3.3,"taxed_field_bool1":true,"taxed_field_int1":1} +--FMEND:0 +` + err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Contains(t, err.Error(), "cannot find result for row 0, field taxed_field_dt1;") + + rawOutput = + ` +--FMINIT:0 +--FMOK:0 +{"taxed_field_float1":2.2,"taxed_field_string1":"aaa","taxed_field_decimal1":3.3,"taxed_field_bool1":true,"taxed_field_int1":1,"taxed_field_dt1":"2003-03-03T03:03:03.000-02:00"} +--FMEND:0 +` + err = pyCalcProcDef.analyseExecSuccess(codeBase, rawOutput, "", pyCalcProcDef.GetFieldRefs(), rs, flushVarsArray) + assert.Nil(t, err) + flushedRow := *results[0] + // r fields must be present in the result, they can be used by the writer + assert.Equal(t, i, flushedRow["r"]["field_int1"]) + assert.Equal(t, f, flushedRow["r"]["field_float1"]) + assert.Equal(t, d, flushedRow["r"]["field_decimal1"]) + assert.Equal(t, s, flushedRow["r"]["field_string1"]) + assert.Equal(t, b, flushedRow["r"]["field_bool1"]) + assert.Equal(t, dt, flushedRow["r"]["field_dt1"]) + // p field must be in the result + assert.Equal(t, int64(1), flushedRow["p"]["taxed_field_int1"]) + assert.Equal(t, 2.2, flushedRow["p"]["taxed_field_float1"]) + assert.True(t, decimal.NewFromFloat(3.3).Equal(flushedRow["p"]["taxed_field_decimal1"].(decimal.Decimal))) + assert.Equal(t, "aaa", flushedRow["p"]["taxed_field_string1"]) + assert.Equal(t, true, flushedRow["p"]["taxed_field_bool1"]) + assert.Equal(t, time.Date(2003, 3, 3, 3, 3, 3, 0, time.FixedZone("", -7200)), flushedRow["p"]["taxed_field_dt1"]) +} + +func TestPyCalcDefBadScript(t *testing.T) { + + scriptDef := &sc.ScriptDef{} + err := scriptDef.Deserialize( + []byte(strings.Replace(scriptJson, `"having": "w.taxed_field_decimal > 10"`, `"having": "p.taxed_field_int1 > 10"`, 1)), + &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) + assert.Contains(t, err.Error(), "prohibited field p.taxed_field_int1") + + err = scriptDef.Deserialize( + []byte(strings.Replace(scriptJson, `increase_by_ten_percent(r.field_int1)`, `bad_func(r.field_int1)`, 1)), + &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) + assert.Contains(t, err.Error(), "function def 'bad_func(arg)' not found in Python file") + + re := regexp.MustCompile(`"python_code_urls": \[[^\]]+\]`) + err = scriptDef.Deserialize( + 
[]byte(re.ReplaceAllString(scriptJson, `"python_code_urls":[123]`)), + &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(envSettings)}, "", nil) + assert.Contains(t, err.Error(), "cannot unmarshal py_calc processor def") + + re = regexp.MustCompile(`"python_interpreter_path": "[^"]+"`) + err = scriptDef.Deserialize( + []byte(scriptJson), + &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(re.ReplaceAllString(envSettings, `"python_interpreter_path": 123`))}, "", nil) + assert.Contains(t, err.Error(), "cannot unmarshal py_calc processor env settings") + + err = scriptDef.Deserialize( + []byte(scriptJson), + &PyCalcTestTestProcessorDefFactory{}, map[string]json.RawMessage{"py_calc": []byte(re.ReplaceAllString(envSettings, `"python_interpreter_path": ""`))}, "", nil) + assert.Contains(t, err.Error(), "py_calc interpreter path canot be empty") + +} + +func TestPythonResultToRowsetValueFailures(t *testing.T) { + _, err := pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_int1", FieldType: sc.FieldTypeInt}, true) + assert.Contains(t, err.Error(), "int field_int1, unexpected type bool(true)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_float1", FieldType: sc.FieldTypeFloat}, true) + assert.Contains(t, err.Error(), "float field_float1, unexpected type bool(true)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_decimal1", FieldType: sc.FieldTypeDecimal2}, true) + assert.Contains(t, err.Error(), "decimal field_decimal1, unexpected type bool(true)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_string1", FieldType: sc.FieldTypeString}, true) + assert.Contains(t, err.Error(), "string field_string1, unexpected type bool(true)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_datetime1", FieldType: sc.FieldTypeDateTime}, true) + assert.Contains(t, err.Error(), "time field_datetime1, unexpected type bool(true)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_datetime1", FieldType: sc.FieldTypeDateTime}, "aaa") + assert.Contains(t, err.Error(), "bad time result field_datetime1, unexpected format aaa") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "field_bool1", FieldType: sc.FieldTypeBool}, "aaa") + assert.Contains(t, err.Error(), "bool field_bool1, unexpected type string(aaa)") + _, err = pythonResultToRowsetValue(&sc.FieldRef{TableName: "p", FieldName: "bad_field", FieldType: sc.FieldTypeUnknown}, "") + assert.Contains(t, err.Error(), "unexpected field type unknown, bad_field, string()") +} diff --git a/pkg/custom/tag_and_denormalize/tag_and_denormalize.donotcover.go b/pkg/custom/tag_and_denormalize/tag_and_denormalize.donotcover.go index 03d9bf6..08396ab 100644 --- a/pkg/custom/tag_and_denormalize/tag_and_denormalize.donotcover.go +++ b/pkg/custom/tag_and_denormalize/tag_and_denormalize.donotcover.go @@ -1,15 +1,15 @@ -package tag_and_denormalize - -import ( - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/proc" -) - -func (procDef *TagAndDenormalizeProcessorDef) Run(logger *l.Logger, pCtx *ctx.MessageProcessingContext, rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { - 
logger.PushF("custom.TagAndDenormalizeProcessorDef.Run") - defer logger.PopF() - - return procDef.tagAndDenormalize(rsIn, flushVarsArray) -} +package tag_and_denormalize + +import ( + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/proc" +) + +func (procDef *TagAndDenormalizeProcessorDef) Run(logger *l.CapiLogger, _ *ctx.MessageProcessingContext, rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { + logger.PushF("custom.TagAndDenormalizeProcessorDef.Run") + defer logger.PopF() + + return procDef.tagAndDenormalize(rsIn, flushVarsArray) +} diff --git a/pkg/custom/tag_and_denormalize/tag_and_denormalize.go b/pkg/custom/tag_and_denormalize/tag_and_denormalize.go index a9abdd1..9e7c95c 100644 --- a/pkg/custom/tag_and_denormalize/tag_and_denormalize.go +++ b/pkg/custom/tag_and_denormalize/tag_and_denormalize.go @@ -1,141 +1,140 @@ -package tag_and_denormalize - -import ( - "encoding/json" - "fmt" - "go/ast" - "strings" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/xfer" -) - -const ProcessorTagAndDenormalizeName string = "tag_and_denormalize" - -type TagAndDenormalizeProcessorDef struct { - TagFieldName string `json:"tag_field_name"` - RawTagCriteria map[string]string `json:"tag_criteria"` - RawTagCriteriaUri string `json:"tag_criteria_uri"` - ParsedTagCriteria map[string]ast.Expr - UsedInCriteriaFields sc.FieldRefs -} - -func (procDef *TagAndDenormalizeProcessorDef) GetFieldRefs() *sc.FieldRefs { - return &sc.FieldRefs{ - { - TableName: sc.CustomProcessorAlias, - FieldName: procDef.TagFieldName, - FieldType: sc.FieldTypeString}} -} - -func (procDef *TagAndDenormalizeProcessorDef) GetUsedInTargetExpressionsFields() *sc.FieldRefs { - return &procDef.UsedInCriteriaFields -} - -func (procDef *TagAndDenormalizeProcessorDef) Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error { - var err error - if err = json.Unmarshal(raw, procDef); err != nil { - return fmt.Errorf("cannot unmarshal tag_and_denormalize processor def: %s", err.Error()) - } - - errors := make([]string, 0) - procDef.ParsedTagCriteria = map[string]ast.Expr{} - - if len(procDef.RawTagCriteriaUri) > 0 { - if len(procDef.RawTagCriteria) > 0 { - return fmt.Errorf("cannot unmarshal both tag_criteria and tag_criteria_url - pick one") - } - - criteriaBytes, err := xfer.GetFileBytes(procDef.RawTagCriteriaUri, caPath, privateKeys) - if err != nil { - return fmt.Errorf("cannot get criteria file [%s]: %s", procDef.RawTagCriteriaUri, err.Error()) - } - - if criteriaBytes == nil || len(criteriaBytes) == 0 { - return fmt.Errorf("criteria file [%s] is empty", procDef.RawTagCriteriaUri) - } - - if criteriaBytes != nil { - if err := json.Unmarshal(criteriaBytes, &procDef.RawTagCriteria); err != nil { - return fmt.Errorf("cannot unmarshal tag criteria file [%s]: [%s]", procDef.RawTagCriteriaUri, err.Error()) - } - } - } else if len(procDef.RawTagCriteria) == 0 { - return fmt.Errorf("cannot unmarshal with tag_criteria and tag_criteria_url missing") - } - - for tag, rawExp := range procDef.RawTagCriteria { - if procDef.ParsedTagCriteria[tag], err = sc.ParseRawGolangExpressionStringAndHarvestFieldRefs(rawExp, &procDef.UsedInCriteriaFields); err != nil { 
- errors = append(errors, fmt.Sprintf("cannot parse tag criteria expression [%s]: [%s]", rawExp, err.Error())) - } - } - - // Later on, checkFieldUsageInCustomProcessor() will verify all fields from procDef.UsedInCriteriaFields are valid reader fields - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -const tagAndDenormalizeFlushBufferSize int = 1000 - -func (procDef *TagAndDenormalizeProcessorDef) tagAndDenormalize(rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { - varsArray := make([]*eval.VarValuesMap, tagAndDenormalizeFlushBufferSize) - varsArrayCount := 0 - - for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { - vars := eval.VarValuesMap{} - if err := rsIn.ExportToVars(rowIdx, &vars); err != nil { - return err - } - - for tag, tagCriteria := range procDef.ParsedTagCriteria { - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) - valVolatile, err := eCtx.Eval(tagCriteria) - if err != nil { - return fmt.Errorf("cannot evaluate expression for tag %s criteria: [%s]", tag, err.Error()) - } - valBool, ok := valVolatile.(bool) - if !ok { - return fmt.Errorf("tag %s criteria returned type %T, expected bool", tag, valVolatile) - } - - if !valBool { - // This tag criteria were not met, skip it - continue - } - - // Add new tag field to the output - - varsArray[varsArrayCount] = &eval.VarValuesMap{} - // Write tag - (*varsArray[varsArrayCount])[sc.CustomProcessorAlias] = map[string]interface{}{procDef.TagFieldName: tag} - // Write r values - (*varsArray[varsArrayCount])[sc.ReaderAlias] = map[string]interface{}{} - for fieldName, fieldVal := range vars[sc.ReaderAlias] { - (*varsArray[varsArrayCount])[sc.ReaderAlias][fieldName] = fieldVal - } - varsArrayCount++ - - if varsArrayCount == len(varsArray) { - if err = flushVarsArray(varsArray, varsArrayCount); err != nil { - return fmt.Errorf("error flushing vars array of size %d: %s", varsArrayCount, err.Error()) - } - varsArray = make([]*eval.VarValuesMap, tagAndDenormalizeFlushBufferSize) - varsArrayCount = 0 - } - } - } - - if varsArrayCount > 0 { - if err := flushVarsArray(varsArray, varsArrayCount); err != nil { - return fmt.Errorf("error flushing leftovers vars array of size %d: %s", varsArrayCount, err.Error()) - } - } - - return nil -} +package tag_and_denormalize + +import ( + "encoding/json" + "fmt" + "go/ast" + "strings" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/xfer" +) + +const ProcessorTagAndDenormalizeName string = "tag_and_denormalize" + +type TagAndDenormalizeProcessorDef struct { + TagFieldName string `json:"tag_field_name"` + RawTagCriteria map[string]string `json:"tag_criteria"` + RawTagCriteriaUri string `json:"tag_criteria_uri"` + ParsedTagCriteria map[string]ast.Expr + UsedInCriteriaFields sc.FieldRefs +} + +func (procDef *TagAndDenormalizeProcessorDef) GetFieldRefs() *sc.FieldRefs { + return &sc.FieldRefs{ + { + TableName: sc.CustomProcessorAlias, + FieldName: procDef.TagFieldName, + FieldType: sc.FieldTypeString}} +} + +func (procDef *TagAndDenormalizeProcessorDef) GetUsedInTargetExpressionsFields() *sc.FieldRefs { + return &procDef.UsedInCriteriaFields +} + +func (procDef *TagAndDenormalizeProcessorDef) Deserialize(raw json.RawMessage, _ json.RawMessage, caPath string, privateKeys map[string]string) error { + var err error + if err = 
json.Unmarshal(raw, procDef); err != nil { + return fmt.Errorf("cannot unmarshal tag_and_denormalize processor def: %s", err.Error()) + } + + errors := make([]string, 0) + procDef.ParsedTagCriteria = map[string]ast.Expr{} + + if len(procDef.RawTagCriteriaUri) > 0 { + if len(procDef.RawTagCriteria) > 0 { + return fmt.Errorf("cannot unmarshal both tag_criteria and tag_criteria_url - pick one") + } + + criteriaBytes, err := xfer.GetFileBytes(procDef.RawTagCriteriaUri, caPath, privateKeys) + if err != nil { + return fmt.Errorf("cannot get criteria file [%s]: %s", procDef.RawTagCriteriaUri, err.Error()) + } + + if len(criteriaBytes) == 0 { + return fmt.Errorf("criteria file [%s] is empty", procDef.RawTagCriteriaUri) + } + + if criteriaBytes != nil { + if err := json.Unmarshal(criteriaBytes, &procDef.RawTagCriteria); err != nil { + return fmt.Errorf("cannot unmarshal tag criteria file [%s]: [%s]", procDef.RawTagCriteriaUri, err.Error()) + } + } + } else if len(procDef.RawTagCriteria) == 0 { + return fmt.Errorf("cannot unmarshal with tag_criteria and tag_criteria_url missing") + } + + for tag, rawExp := range procDef.RawTagCriteria { + if procDef.ParsedTagCriteria[tag], err = sc.ParseRawGolangExpressionStringAndHarvestFieldRefs(rawExp, &procDef.UsedInCriteriaFields); err != nil { + errors = append(errors, fmt.Sprintf("cannot parse tag criteria expression [%s]: [%s]", rawExp, err.Error())) + } + } + + // Later on, checkFieldUsageInCustomProcessor() will verify all fields from procDef.UsedInCriteriaFields are valid reader fields + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + return nil +} + +const tagAndDenormalizeFlushBufferSize int = 1000 + +func (procDef *TagAndDenormalizeProcessorDef) tagAndDenormalize(rsIn *proc.Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error { + varsArray := make([]*eval.VarValuesMap, tagAndDenormalizeFlushBufferSize) + varsArrayCount := 0 + + for rowIdx := 0; rowIdx < rsIn.RowCount; rowIdx++ { + vars := eval.VarValuesMap{} + if err := rsIn.ExportToVars(rowIdx, &vars); err != nil { + return err + } + + for tag, tagCriteria := range procDef.ParsedTagCriteria { + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) + valVolatile, err := eCtx.Eval(tagCriteria) + if err != nil { + return fmt.Errorf("cannot evaluate expression for tag %s criteria: [%s]", tag, err.Error()) + } + valBool, ok := valVolatile.(bool) + if !ok { + return fmt.Errorf("tag %s criteria returned type %T, expected bool", tag, valVolatile) + } + + if !valBool { + // This tag criteria were not met, skip it + continue + } + + // Add new tag field to the output + + varsArray[varsArrayCount] = &eval.VarValuesMap{} + // Write tag + (*varsArray[varsArrayCount])[sc.CustomProcessorAlias] = map[string]any{procDef.TagFieldName: tag} + // Write r values + (*varsArray[varsArrayCount])[sc.ReaderAlias] = map[string]any{} + for fieldName, fieldVal := range vars[sc.ReaderAlias] { + (*varsArray[varsArrayCount])[sc.ReaderAlias][fieldName] = fieldVal + } + varsArrayCount++ + + if varsArrayCount == len(varsArray) { + if err = flushVarsArray(varsArray, varsArrayCount); err != nil { + return fmt.Errorf("error flushing vars array of size %d: %s", varsArrayCount, err.Error()) + } + varsArray = make([]*eval.VarValuesMap, tagAndDenormalizeFlushBufferSize) + varsArrayCount = 0 + } + } + } + + if varsArrayCount > 0 { + if err := flushVarsArray(varsArray, varsArrayCount); err != nil { + return fmt.Errorf("error flushing leftovers vars array of size 
%d: %s", varsArrayCount, err.Error()) + } + } + + return nil +} diff --git a/pkg/custom/tag_and_denormalize/tag_and_denormalize_test.go b/pkg/custom/tag_and_denormalize/tag_and_denormalize_test.go index 2abd591..a29a74e 100644 --- a/pkg/custom/tag_and_denormalize/tag_and_denormalize_test.go +++ b/pkg/custom/tag_and_denormalize/tag_and_denormalize_test.go @@ -1,299 +1,304 @@ -package tag_and_denormalize - -import ( - "encoding/json" - "fmt" - "regexp" - "strings" - "testing" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -type TagAndDenormalizeTestTestProcessorDefFactory struct { -} - -func (f *TagAndDenormalizeTestTestProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { - switch processorType { - case ProcessorTagAndDenormalizeName: - return &TagAndDenormalizeProcessorDef{}, true - default: - return nil, false - } -} - -const scriptJson string = ` -{ - "nodes": { - "read_products": { - "type": "file_table", - "desc": "Load product data from CSV files to a table, one input file - one batch", - "explicit_run_only": true, - "r": { - "urls": ["{test_root_dir}/data/in/flipcart_products.tsv"], - "csv":{ - "separator": "\t", - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_product_id": { - "csv":{ - "col_idx": 0, - "col_format": "%d" - }, - "col_type": "int" - }, - "col_product_name": { - "csv":{ - "col_idx": 1 - }, - "col_type": "string" - }, - "col_product_category_tree": { - "csv":{ - "col_idx": 2 - }, - "col_type": "string" - }, - "col_retail_price": { - "csv":{ - "col_idx": 3, - "col_format": "%f" - }, - "col_type": "decimal2" - }, - "col_product_specifications": { - "csv":{ - "col_idx": 4 - }, - "col_type": "string" - } - } - }, - "w": { - "name": "products", - "fields": { - "product_id": { - "expression": "r.col_product_id", - "type": "int" - }, - "name": { - "expression": "r.col_product_name", - "type": "string" - }, - "category_tree": { - "expression": "r.col_product_category_tree", - "type": "string" - }, - "price": { - "expression": "r.col_retail_price", - "type": "decimal2" - }, - "product_spec": { - "expression": "r.col_product_specifications", - "type": "string" - } - } - } - }, - "tag_products": { - "type": "table_custom_tfm_table", - "custom_proc_type": "tag_and_denormalize", - "desc": "Tag products according to criteria and write product tag, id, price to a new table", - "r": { - "table": "products", - "expected_batches_total": 10 - }, - "p": { - "tag_field_name": "tag", - "tag_criteria": { - "boys":"re.MatchString(` + "`" + `\"k\":\"Ideal For\",\"v\":\"[\\w ,]*Boys[\\w ,]*\"` + "`" + `, r.product_spec)", - "diving":"re.MatchString(` + "`" + `\"k\":\"Water Resistance Depth\",\"v\":\"(100|200) m\"` + "`" + `, r.product_spec)", - "engagement":"re.MatchString(` + "`" + `\"k\":\"Occasion\",\"v\":\"[\\w ,]*Engagement[\\w ,]*\"` + "`" + `, r.product_spec) && re.MatchString(` + "`" + `\"k\":\"Gemstone\",\"v\":\"Diamond\"` + "`" + `, r.product_spec) && r.price > 5000" - } - }, - "w": { - "name": "tagged_products", - "having": "len(w.tag) > 0", - "fields": { - "tag": { - "expression": "p.tag", - "type": "string" - }, - "product_id": { - "expression": "r.product_id", - "type": "int" - }, - "price": { - "expression": "r.price", - "type": "decimal2" - } - }, - "indexes": { - "idx_tagged_products_tag": "non_unique(tag)" - } - } - } - }, - "dependency_policies": 
{ - "current_active_first_stopped_nogo":` + sc.DefaultPolicyCheckerConf + - ` - } -}` - -func TestTagAndDenormalizeDeserializeFileCriteria(t *testing.T) { - scriptDef := &sc.ScriptDef{} - - re := regexp.MustCompile(`"tag_criteria": \{[^\}]+\}`) - err := scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria_uri": "../../../test/data/cfg/tag_and_denormalize_quicktest/tag_criteria.json"`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Nil(t, err) - - tndProcessor, _ := scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) - assert.Equal(t, 4, len(tndProcessor.ParsedTagCriteria)) -} - -func TestTagAndDenormalizeRunEmbeddedCriteria(t *testing.T) { - scriptDef := &sc.ScriptDef{} - - err := scriptDef.Deserialize([]byte(scriptJson), &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Nil(t, err) - - tndProcessor, _ := scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) - assert.Equal(t, 3, len(tndProcessor.ParsedTagCriteria)) - - // Initializing rowset is tedious and error-prone. Add schema first. - rs := proc.NewRowsetFromFieldRefs(sc.FieldRefs{ - {TableName: "r", FieldName: "product_id", FieldType: sc.FieldTypeInt}, - {TableName: "r", FieldName: "name", FieldType: sc.FieldTypeString}, - {TableName: "r", FieldName: "price", FieldType: sc.FieldTypeDecimal2}, - {TableName: "r", FieldName: "product_spec", FieldType: sc.FieldTypeString}, - }) - - // Allocate rows - rs.InitRows(1) - - // Initialize with pointers - product_id := int64(1) - (*rs.Rows[0])[0] = &product_id - name := "Breitling AB011010/BB08 131S Chronomat 44 Analog Watch" - (*rs.Rows[0])[1] = &name - price := decimal.NewFromFloat(571230) - (*rs.Rows[0])[2] = &price - product_spec := `{"k":"Occasion","v":"Formal, Casual"}, {"k":"Ideal For","v":"Boys, Men"}, {"k":"Water Resistance Depth","v":"100 m"}` - (*rs.Rows[0])[3] = &product_spec - - // Tell it we wrote something to [0] - rs.RowCount++ - - // Test flusher, doesn't write anywhere, just saves data in the local variable - var results []*eval.VarValuesMap - flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { - results = varsArray - return nil - } - - err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) - assert.Nil(t, err) - - // Check that 2 rows were produced: thiswatch is good for boys and for diving - - flushedRow := *results[0] - // r fields must be present in the result, they can be used by the writer - assert.Equal(t, product_id, flushedRow["r"]["product_id"]) - assert.Equal(t, name, flushedRow["r"]["name"]) - assert.Equal(t, price, flushedRow["r"]["price"]) - assert.Equal(t, product_spec, flushedRow["r"]["product_spec"]) - // p field must be in the result - var nextExpectedTag string - if flushedRow["p"]["tag"].(string) == "boys" { - nextExpectedTag = "diving" - } else if flushedRow["p"]["tag"].(string) == "diving" { - nextExpectedTag = "boys" - } else { - assert.Fail(t, fmt.Sprintf("unexpected tag %s", *(flushedRow["p"]["tag"].(*string)))) - } - - flushedRow = *results[1] - // r fields must be present in the result, they can be used by the writer - assert.Equal(t, product_id, flushedRow["r"]["product_id"]) - assert.Equal(t, name, flushedRow["r"]["name"]) - assert.Equal(t, price, flushedRow["r"]["price"]) - assert.Equal(t, product_spec, flushedRow["r"]["product_spec"]) - // p field must be in the result - 
assert.Equal(t, nextExpectedTag, flushedRow["p"]["tag"]) - - // Bad criteria - re := regexp.MustCompile(`"tag_criteria": \{[^\}]+\}`) - - // Bad function used - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"re.BadGoMethod(\"aaa\")"}`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - - tndProcessor, _ = scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) - assert.Equal(t, 1, len(tndProcessor.ParsedTagCriteria)) - - err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) - assert.Contains(t, err.Error(), "cannot evaluate expression for tag boys criteria") - - // Bad type - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"math.Round(1.1)"}`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - - tndProcessor, _ = scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) - assert.Equal(t, 1, len(tndProcessor.ParsedTagCriteria)) - - err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) - assert.Contains(t, err.Error(), "tag boys criteria returned type float64, expected bool") -} - -func TestTagAndDenormalizeDeserializeFailures(t *testing.T) { - scriptDef := &sc.ScriptDef{} - - // Exercise checkFieldUsageInCustomProcessor() error code path - err := scriptDef.Deserialize( - []byte(strings.ReplaceAll(scriptJson, `r.product_spec`, `w.product_spec`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "unknown field w.product_spec") - - // Prohibited field - err = scriptDef.Deserialize( - []byte(strings.Replace(scriptJson, `"having": "len(w.tag) > 0"`, `"having": "len(p.tag) > 0"`, 1)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "prohibited field p.tag") - - // Bad criteria - re := regexp.MustCompile(`"tag_criteria": \{[^\}]+\}`) - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"some_bogus_key": 123`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "cannot unmarshal with tag_criteria and tag_criteria_url missing") - - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria":{"a":"b"},"tag_criteria_uri":"aaa"`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "cannot unmarshal both tag_criteria and tag_criteria_url - pick one") - - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria_uri":"aaa"`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "cannot get criteria file") - - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": ["boys"]`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "cannot unmarshal array into Go struct") - - err = scriptDef.Deserialize( - []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"["}`)), - &TagAndDenormalizeTestTestProcessorDefFactory{}, 
map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) - assert.Contains(t, err.Error(), "cannot parse tag criteria expression") -} +package tag_and_denormalize + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + "testing" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +type TagAndDenormalizeTestTestProcessorDefFactory struct { +} + +func (f *TagAndDenormalizeTestTestProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { + switch processorType { + case ProcessorTagAndDenormalizeName: + return &TagAndDenormalizeProcessorDef{}, true + default: + return nil, false + } +} + +const scriptJson string = ` +{ + "nodes": { + "read_products": { + "type": "file_table", + "desc": "Load product data from CSV files to a table, one input file - one batch", + "explicit_run_only": true, + "r": { + "urls": ["{test_root_dir}/data/in/flipcart_products.tsv"], + "csv":{ + "separator": "\t", + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_product_id": { + "csv":{ + "col_idx": 0, + "col_format": "%d" + }, + "col_type": "int" + }, + "col_product_name": { + "csv":{ + "col_idx": 1 + }, + "col_type": "string" + }, + "col_product_category_tree": { + "csv":{ + "col_idx": 2 + }, + "col_type": "string" + }, + "col_retail_price": { + "csv":{ + "col_idx": 3, + "col_format": "%f" + }, + "col_type": "decimal2" + }, + "col_product_specifications": { + "csv":{ + "col_idx": 4 + }, + "col_type": "string" + } + } + }, + "w": { + "name": "products", + "fields": { + "product_id": { + "expression": "r.col_product_id", + "type": "int" + }, + "name": { + "expression": "r.col_product_name", + "type": "string" + }, + "category_tree": { + "expression": "r.col_product_category_tree", + "type": "string" + }, + "price": { + "expression": "r.col_retail_price", + "type": "decimal2" + }, + "product_spec": { + "expression": "r.col_product_specifications", + "type": "string" + } + } + } + }, + "tag_products": { + "type": "table_custom_tfm_table", + "custom_proc_type": "tag_and_denormalize", + "desc": "Tag products according to criteria and write product tag, id, price to a new table", + "r": { + "table": "products", + "expected_batches_total": 10 + }, + "p": { + "tag_field_name": "tag", + "tag_criteria": { + "boys":"re.MatchString(` + "`" + `\"k\":\"Ideal For\",\"v\":\"[\\w ,]*Boys[\\w ,]*\"` + "`" + `, r.product_spec)", + "diving":"re.MatchString(` + "`" + `\"k\":\"Water Resistance Depth\",\"v\":\"(100|200) m\"` + "`" + `, r.product_spec)", + "engagement":"re.MatchString(` + "`" + `\"k\":\"Occasion\",\"v\":\"[\\w ,]*Engagement[\\w ,]*\"` + "`" + `, r.product_spec) && re.MatchString(` + "`" + `\"k\":\"Gemstone\",\"v\":\"Diamond\"` + "`" + `, r.product_spec) && r.price > 5000" + } + }, + "w": { + "name": "tagged_products", + "having": "len(w.tag) > 0", + "fields": { + "tag": { + "expression": "p.tag", + "type": "string" + }, + "product_id": { + "expression": "r.product_id", + "type": "int" + }, + "price": { + "expression": "r.price", + "type": "decimal2" + } + }, + "indexes": { + "idx_tagged_products_tag": "non_unique(tag)" + } + } + } + }, + "dependency_policies": { + "current_active_first_stopped_nogo":` + sc.DefaultPolicyCheckerConf + + ` + } +}` + +func TestTagAndDenormalizeDeserializeFileCriteria(t *testing.T) { + scriptDef := &sc.ScriptDef{} + + re := regexp.MustCompile(`"tag_criteria": 
\{[^\}]+\}`) + err := scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria_uri": "../../../test/data/cfg/tag_and_denormalize_quicktest/tag_criteria.json"`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Nil(t, err) + + tndProcessor, ok := scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) + assert.True(t, ok) + assert.Equal(t, 4, len(tndProcessor.ParsedTagCriteria)) +} + +func TestTagAndDenormalizeRunEmbeddedCriteria(t *testing.T) { + scriptDef := &sc.ScriptDef{} + + err := scriptDef.Deserialize([]byte(scriptJson), &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Nil(t, err) + + tndProcessor, ok := scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) + assert.True(t, ok) + assert.Equal(t, 3, len(tndProcessor.ParsedTagCriteria)) + + // Initializing rowset is tedious and error-prone. Add schema first. + rs := proc.NewRowsetFromFieldRefs(sc.FieldRefs{ + {TableName: "r", FieldName: "product_id", FieldType: sc.FieldTypeInt}, + {TableName: "r", FieldName: "name", FieldType: sc.FieldTypeString}, + {TableName: "r", FieldName: "price", FieldType: sc.FieldTypeDecimal2}, + {TableName: "r", FieldName: "product_spec", FieldType: sc.FieldTypeString}, + }) + + // Allocate rows + assert.Nil(t, rs.InitRows(1)) + + // Initialize with pointers + product_id := int64(1) + (*rs.Rows[0])[0] = &product_id + name := "Breitling AB011010/BB08 131S Chronomat 44 Analog Watch" + (*rs.Rows[0])[1] = &name + price := decimal.NewFromFloat(571230) + (*rs.Rows[0])[2] = &price + product_spec := `{"k":"Occasion","v":"Formal, Casual"}, {"k":"Ideal For","v":"Boys, Men"}, {"k":"Water Resistance Depth","v":"100 m"}` + (*rs.Rows[0])[3] = &product_spec + + // Tell it we wrote something to [0] + rs.RowCount++ + + // Test flusher, doesn't write anywhere, just saves data in the local variable + var results []*eval.VarValuesMap + flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { + results = varsArray + return nil + } + + err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) + assert.Nil(t, err) + + // Check that 2 rows were produced: thiswatch is good for boys and for diving + + flushedRow := *results[0] + // r fields must be present in the result, they can be used by the writer + assert.Equal(t, product_id, flushedRow["r"]["product_id"]) + assert.Equal(t, name, flushedRow["r"]["name"]) + assert.Equal(t, price, flushedRow["r"]["price"]) + assert.Equal(t, product_spec, flushedRow["r"]["product_spec"]) + // p field must be in the result + var nextExpectedTag string + flushedRowTag, ok := flushedRow["p"]["tag"].(string) + assert.True(t, ok) + if flushedRowTag == "boys" { + nextExpectedTag = "diving" + } else if flushedRowTag == "diving" { + nextExpectedTag = "boys" + } else { + assert.Fail(t, fmt.Sprintf("unexpected tag %s", flushedRowTag)) + } + + flushedRow = *results[1] + // r fields must be present in the result, they can be used by the writer + assert.Equal(t, product_id, flushedRow["r"]["product_id"]) + assert.Equal(t, name, flushedRow["r"]["name"]) + assert.Equal(t, price, flushedRow["r"]["price"]) + assert.Equal(t, product_spec, flushedRow["r"]["product_spec"]) + // p field must be in the result + assert.Equal(t, nextExpectedTag, flushedRow["p"]["tag"]) + + // Bad criteria + re := regexp.MustCompile(`"tag_criteria": \{[^\}]+\}`) + + // Bad function used 
+ assert.Nil(t, scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"re.BadGoMethod(\"aaa\")"}`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil)) + + tndProcessor, ok = scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) + assert.True(t, ok) + assert.Equal(t, 1, len(tndProcessor.ParsedTagCriteria)) + + err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) + assert.Contains(t, err.Error(), "cannot evaluate expression for tag boys criteria") + + // Bad type + assert.Nil(t, scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"math.Round(1.1)"}`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil)) + tndProcessor, ok = scriptDef.ScriptNodes["tag_products"].CustomProcessor.(*TagAndDenormalizeProcessorDef) + assert.True(t, ok) + assert.Equal(t, 1, len(tndProcessor.ParsedTagCriteria)) + + err = tndProcessor.tagAndDenormalize(rs, flushVarsArray) + assert.Contains(t, err.Error(), "tag boys criteria returned type float64, expected bool") +} + +func TestTagAndDenormalizeDeserializeFailures(t *testing.T) { + scriptDef := &sc.ScriptDef{} + + // Exercise checkFieldUsageInCustomProcessor() error code path + err := scriptDef.Deserialize( + []byte(strings.ReplaceAll(scriptJson, `r.product_spec`, `w.product_spec`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "unknown field w.product_spec") + + // Prohibited field + err = scriptDef.Deserialize( + []byte(strings.Replace(scriptJson, `"having": "len(w.tag) > 0"`, `"having": "len(p.tag) > 0"`, 1)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "prohibited field p.tag") + + // Bad criteria + re := regexp.MustCompile(`"tag_criteria": \{[^\}]+\}`) + err = scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"some_bogus_key": 123`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "cannot unmarshal with tag_criteria and tag_criteria_url missing") + + err = scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria":{"a":"b"},"tag_criteria_uri":"aaa"`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "cannot unmarshal both tag_criteria and tag_criteria_url - pick one") + + err = scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria_uri":"aaa"`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "cannot get criteria file") + + err = scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": ["boys"]`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), "cannot unmarshal array into Go struct") + + err = scriptDef.Deserialize( + []byte(re.ReplaceAllString(scriptJson, `"tag_criteria": {"boys":"["}`)), + &TagAndDenormalizeTestTestProcessorDefFactory{}, map[string]json.RawMessage{"tag_and_denormalize": {}}, "", nil) + assert.Contains(t, err.Error(), 
"cannot parse tag criteria expression") +} diff --git a/pkg/db/cassandra.go b/pkg/db/cassandra.go index 0229b62..d553888 100644 --- a/pkg/db/cassandra.go +++ b/pkg/db/cassandra.go @@ -1,107 +1,107 @@ -package db - -import ( - "fmt" - "reflect" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -const ErrorPrefixDb string = "dberror:" - -func WrapDbErrorWithQuery(msg string, query string, dbErr error) error { - if len(query) > 500 { - query = query[:500] - } - return fmt.Errorf("%s, query:%s, %s%s", msg, query, ErrorPrefixDb, dbErr.Error()) -} - -func IsDbConnError(err error) bool { - return strings.Contains(err.Error(), ErrorPrefixDb+gocql.ErrNoConnections.Error()) || - strings.Contains(err.Error(), ErrorPrefixDb+"EOF") - -} - -func createWfTable(cqlSession *gocql.Session, keyspace string, t reflect.Type, tableName string) error { - q := wfmodel.GetCreateTableCql(t, keyspace, tableName) - if err := cqlSession.Query(q).Exec(); err != nil { - return WrapDbErrorWithQuery("failed to create WF table", q, err) - } - return nil -} - -type CreateKeyspaceEnumType int - -const DoNotCreateKeyspaceOnConnect CreateKeyspaceEnumType = 0 -const CreateKeyspaceOnConnect CreateKeyspaceEnumType = 1 - -func NewSession(envConfig *env.EnvConfig, keyspace string, createKeyspace CreateKeyspaceEnumType) (*gocql.Session, error) { - dataCluster := gocql.NewCluster(envConfig.Cassandra.Hosts...) - dataCluster.Port = envConfig.Cassandra.Port - dataCluster.Authenticator = gocql.PasswordAuthenticator{Username: envConfig.Cassandra.Username, Password: envConfig.Cassandra.Password} - dataCluster.NumConns = envConfig.Cassandra.NumConns - dataCluster.Timeout = time.Duration(envConfig.Cassandra.Timeout * int(time.Millisecond)) - dataCluster.ConnectTimeout = time.Duration(envConfig.Cassandra.ConnectTimeout * int(time.Millisecond)) - // Token-aware policy should give better perf results when used together with prepared queries, and Capillaries chatty inserts are killing Cassandra. - // TODO: consider making it configurable - dataCluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()) - // When testing, we load Cassandra cluster at 100%. There will be "Operation timed out - received only 0 responses" errors. - // It's up to admins how to handle the load, but we should not give up quickly in any case. Make it 3 attempts. 
- dataCluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} - if envConfig.Cassandra.SslOpts != nil { - dataCluster.SslOpts = &gocql.SslOptions{ - EnableHostVerification: envConfig.Cassandra.SslOpts.EnableHostVerification, - CaPath: envConfig.Cassandra.SslOpts.CaPath, - CertPath: envConfig.Cassandra.SslOpts.CaPath, - KeyPath: envConfig.Cassandra.SslOpts.KeyPath} - } - cqlSession, err := dataCluster.CreateSession() - if err != nil { - return nil, fmt.Errorf("failed to connect to data cluster %v, keyspace [%s]: %s", envConfig.Cassandra.Hosts, keyspace, err.Error()) - } - // Create keyspace if needed - if len(keyspace) > 0 { - dataCluster.Keyspace = keyspace - - if createKeyspace == CreateKeyspaceOnConnect { - createKsQuery := fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = %s", keyspace, envConfig.Cassandra.KeyspaceReplicationConfig) - if err := cqlSession.Query(createKsQuery).Exec(); err != nil { - return nil, WrapDbErrorWithQuery("failed to create keyspace", createKsQuery, err) - } - - // Create WF tables if needed - if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.BatchHistoryEvent{}), wfmodel.TableNameBatchHistory); err != nil { - return nil, err - } - if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.NodeHistoryEvent{}), wfmodel.TableNameNodeHistory); err != nil { - return nil, err - } - if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunHistoryEvent{}), wfmodel.TableNameRunHistory); err != nil { - return nil, err - } - if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunProperties{}), wfmodel.TableNameRunAffectedNodes); err != nil { - return nil, err - } - if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunCounter{}), wfmodel.TableNameRunCounter); err != nil { - return nil, err - } - - qb := cql.QueryBuilder{} - qb. - Keyspace(keyspace). - Write("ks", keyspace). - Write("last_run", 0) - q := qb.InsertUnpreparedQuery(wfmodel.TableNameRunCounter, cql.IgnoreIfExists) // If not exists. Insert only once. 
- err = cqlSession.Query(q).Exec() - if err != nil { - return nil, WrapDbErrorWithQuery("cannot initialize run counter", q, err) - } - } - } - return cqlSession, nil -} +package db + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +const ErrorPrefixDb string = "dberror:" + +func WrapDbErrorWithQuery(msg string, query string, dbErr error) error { + if len(query) > 500 { + query = query[:500] + } + return fmt.Errorf("%s, query:%s, %s%s", msg, query, ErrorPrefixDb, dbErr.Error()) +} + +func IsDbConnError(err error) bool { + return strings.Contains(err.Error(), ErrorPrefixDb+gocql.ErrNoConnections.Error()) || + strings.Contains(err.Error(), ErrorPrefixDb+"EOF") + +} + +func createWfTable(cqlSession *gocql.Session, keyspace string, t reflect.Type, tableName string) error { + q := wfmodel.GetCreateTableCql(t, keyspace, tableName) + if err := cqlSession.Query(q).Exec(); err != nil { + return WrapDbErrorWithQuery("failed to create WF table", q, err) + } + return nil +} + +type CreateKeyspaceEnumType int + +const DoNotCreateKeyspaceOnConnect CreateKeyspaceEnumType = 0 +const CreateKeyspaceOnConnect CreateKeyspaceEnumType = 1 + +func NewSession(envConfig *env.EnvConfig, keyspace string, createKeyspace CreateKeyspaceEnumType) (*gocql.Session, error) { + dataCluster := gocql.NewCluster(envConfig.Cassandra.Hosts...) + dataCluster.Port = envConfig.Cassandra.Port + dataCluster.Authenticator = gocql.PasswordAuthenticator{Username: envConfig.Cassandra.Username, Password: envConfig.Cassandra.Password} + dataCluster.NumConns = envConfig.Cassandra.NumConns + dataCluster.Timeout = time.Duration(envConfig.Cassandra.Timeout * int(time.Millisecond)) + dataCluster.ConnectTimeout = time.Duration(envConfig.Cassandra.ConnectTimeout * int(time.Millisecond)) + // Token-aware policy should give better perf results when used together with prepared queries, and Capillaries chatty inserts are killing Cassandra. + // TODO: consider making it configurable + dataCluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy()) + // When testing, we load Cassandra cluster at 100%. There will be "Operation timed out - received only 0 responses" errors. + // It's up to admins how to handle the load, but we should not give up quickly in any case. Make it 3 attempts. 
+ dataCluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} + if envConfig.Cassandra.SslOpts != nil { + dataCluster.SslOpts = &gocql.SslOptions{ + EnableHostVerification: envConfig.Cassandra.SslOpts.EnableHostVerification, + CaPath: envConfig.Cassandra.SslOpts.CaPath, + CertPath: envConfig.Cassandra.SslOpts.CaPath, + KeyPath: envConfig.Cassandra.SslOpts.KeyPath} + } + cqlSession, err := dataCluster.CreateSession() + if err != nil { + return nil, fmt.Errorf("failed to connect to data cluster %v, keyspace [%s]: %s", envConfig.Cassandra.Hosts, keyspace, err.Error()) + } + // Create keyspace if needed + if len(keyspace) > 0 { + dataCluster.Keyspace = keyspace + + if createKeyspace == CreateKeyspaceOnConnect { + createKsQuery := fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = %s", keyspace, envConfig.Cassandra.KeyspaceReplicationConfig) + if err := cqlSession.Query(createKsQuery).Exec(); err != nil { + return nil, WrapDbErrorWithQuery("failed to create keyspace", createKsQuery, err) + } + + // Create WF tables if needed + if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.BatchHistoryEvent{}), wfmodel.TableNameBatchHistory); err != nil { + return nil, err + } + if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.NodeHistoryEvent{}), wfmodel.TableNameNodeHistory); err != nil { + return nil, err + } + if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunHistoryEvent{}), wfmodel.TableNameRunHistory); err != nil { + return nil, err + } + if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunProperties{}), wfmodel.TableNameRunAffectedNodes); err != nil { + return nil, err + } + if err = createWfTable(cqlSession, keyspace, reflect.TypeOf(wfmodel.RunCounter{}), wfmodel.TableNameRunCounter); err != nil { + return nil, err + } + + qb := cql.QueryBuilder{} + qb. + Keyspace(keyspace). + Write("ks", keyspace). + Write("last_run", 0) + q := qb.InsertUnpreparedQuery(wfmodel.TableNameRunCounter, cql.IgnoreIfExists) // If not exists. Insert only once. 
+ err = cqlSession.Query(q).Exec() + if err != nil { + return nil, WrapDbErrorWithQuery("cannot initialize run counter", q, err) + } + } + } + return cqlSession, nil +} diff --git a/pkg/deploy/aws_exec.go b/pkg/deploy/aws_exec.go index d3f5d80..4f69621 100644 --- a/pkg/deploy/aws_exec.go +++ b/pkg/deploy/aws_exec.go @@ -7,7 +7,7 @@ import ( "github.com/itchyny/gojq" ) -func ExecLocalAndGetJsonValue(prj *Project, cmdPath string, params []string, query string) (interface{}, ExecResult) { +func ExecLocalAndGetJsonValue(prj *Project, cmdPath string, params []string, query string) (any, ExecResult) { er := ExecLocal(prj, cmdPath, params, prj.CliEnvVars, "") if er.Error != nil { return nil, er @@ -20,19 +20,14 @@ func ExecLocalAndGetJsonValue(prj *Project, cmdPath string, params []string, que } // This is a brutal way to unmarshal incoming JSON, but it should work - var jsonObj map[string]interface{} + var jsonObj map[string]any if err := json.Unmarshal([]byte(er.Stdout), &jsonObj); err != nil { er.Error = fmt.Errorf("cannot unmarshal json, error %s, json %s", err.Error(), er.Stdout) return nil, er } iter := q.Run(jsonObj) - for { - v, ok := iter.Next() - if !ok { - break - } - + if v, ok := iter.Next(); ok { return v, er } diff --git a/pkg/deploy/aws_instances.go b/pkg/deploy/aws_instances.go index 447d83f..783aa8f 100644 --- a/pkg/deploy/aws_instances.go +++ b/pkg/deploy/aws_instances.go @@ -161,7 +161,7 @@ func waitForAwsInstanceToBeCreated(prj *Project, instanceId string, timeoutSecon if status != "pending" { return lb.Complete(fmt.Errorf("%s was built, but the status is unknown: %s", instanceId, status)) } - if time.Since(startWaitTs).Seconds() > float64(prj.Timeouts.OpenstackInstanceCreation) { + if time.Since(startWaitTs).Seconds() > float64(timeoutSeconds) { return lb.Complete(fmt.Errorf("giving up after waiting for %s to be created", instanceId)) } time.Sleep(10 * time.Second) @@ -184,7 +184,7 @@ func assignAwsFloatingIp(prj *Project, instanceId string, floatingIp string, isV return lb.Complete(nil) } -func (*AwsDeployProvider) CreateInstanceAndWaitForCompletion(prjPair *ProjectPair, iNickname string, flavorId string, imageId string, availabilityZone string, isVerbose bool) (LogMsg, error) { +func (*AwsDeployProvider) CreateInstanceAndWaitForCompletion(prjPair *ProjectPair, iNickname string, flavorId string, imageId string, _ string, isVerbose bool) (LogMsg, error) { sb := strings.Builder{} logMsg, err := createAwsInstance(prjPair, iNickname, flavorId, imageId, isVerbose) diff --git a/pkg/deploy/aws_networking.go b/pkg/deploy/aws_networking.go index 88f0ea5..8c177cf 100644 --- a/pkg/deploy/aws_networking.go +++ b/pkg/deploy/aws_networking.go @@ -161,7 +161,7 @@ func waitForAwsVpcToBeCreated(prj *Project, vpcId string, timeoutSeconds int, is if status != "pending" { return lb.Complete(fmt.Errorf("vpc %s was built, but the status is %s", vpcId, status)) } - if time.Since(startWaitTs).Seconds() > float64(prj.Timeouts.OpenstackInstanceCreation) { + if time.Since(startWaitTs).Seconds() > float64(timeoutSeconds) { return lb.Complete(fmt.Errorf("giving up after waiting for vpc %s to be created", vpcId)) } time.Sleep(10 * time.Second) @@ -388,7 +388,7 @@ func createNatGatewayAndRoutePrivateSubnet(prjPair *ProjectPair, isVerbose bool) return lb.Complete(er.Error) } - if result != true { + if !result { if er.Error != nil { return lb.Complete(fmt.Errorf("route creation returned false")) } @@ -463,7 +463,7 @@ func createInternetGatewayAndRoutePublicSubnet(prjPair *ProjectPair, isVerbose b return 
lb.Complete(er.Error) } } else if attachedVpcId != prjPair.Live.Network.Id { - return lb.Complete(fmt.Errorf("network gateway %s seems to be attached to a wrong vpc %s already\n", prjPair.Live.Network.Router.Name, attachedVpcId)) + return lb.Complete(fmt.Errorf("network gateway %s seems to be attached to a wrong vpc %s already", prjPair.Live.Network.Router.Name, attachedVpcId)) } else { lb.Add(fmt.Sprintf("network gateway %s seems to be attached to vpc already\n", prjPair.Live.Network.Router.Name)) } @@ -509,7 +509,7 @@ func createInternetGatewayAndRoutePublicSubnet(prjPair *ProjectPair, isVerbose b return lb.Complete(er.Error) } - if result != true { + if !result { if er.Error != nil { return lb.Complete(fmt.Errorf("route creation returned false")) } diff --git a/pkg/deploy/aws_security_group.go b/pkg/deploy/aws_security_group.go index cb15e1d..e6cc19c 100644 --- a/pkg/deploy/aws_security_group.go +++ b/pkg/deploy/aws_security_group.go @@ -66,7 +66,7 @@ func createAwsSecurityGroup(prjPair *ProjectPair, sgNickname string, isVerbose b if er.Error != nil { return lb.Complete(er.Error) } - if result != true { + if !result { if er.Error != nil { return lb.Complete(fmt.Errorf("rule creation returned false: %v", rule)) } @@ -81,7 +81,7 @@ func createAwsSecurityGroup(prjPair *ProjectPair, sgNickname string, isVerbose b func (*AwsDeployProvider) CreateSecurityGroups(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) { sb := strings.Builder{} - for sgNickname, _ := range prjPair.Live.SecurityGroups { + for sgNickname := range prjPair.Live.SecurityGroups { logMsg, err := createAwsSecurityGroup(prjPair, sgNickname, isVerbose) AddLogMsg(&sb, logMsg) if err != nil { @@ -127,7 +127,7 @@ func deleteAwsSecurityGroup(prjPair *ProjectPair, sgNickname string, isVerbose b func (*AwsDeployProvider) DeleteSecurityGroups(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) { sb := strings.Builder{} - for sgNickname, _ := range prjPair.Live.SecurityGroups { + for sgNickname := range prjPair.Live.SecurityGroups { logMsg, err := deleteAwsSecurityGroup(prjPair, sgNickname, isVerbose) AddLogMsg(&sb, logMsg) if err != nil { diff --git a/pkg/deploy/aws_volumes.go b/pkg/deploy/aws_volumes.go index b2ee31c..ea1a103 100644 --- a/pkg/deploy/aws_volumes.go +++ b/pkg/deploy/aws_volumes.go @@ -36,7 +36,6 @@ func (*AwsDeployProvider) CreateVolume(prjPair *ProjectPair, iNickname string, v // If it was already created, save it for future use, but do not create if foundVolIdByName != "" { lb.Add(fmt.Sprintf("volume %s(%s) already there, updating project", prjPair.Live.Instances[iNickname].Volumes[volNickname].Name, foundVolIdByName)) - //fmt.Printf("Setting existing %s-%s %s\n", iNickname, volNickname, foundVolIdByName) prjPair.SetVolumeId(iNickname, volNickname, foundVolIdByName) } } else { @@ -91,9 +90,10 @@ func volNicknameToAwsSuggestedDeviceName(volumes map[string]*VolumeDef, volNickn return "invalid-device-for-vol-" + volNickname } -func awsFinalDeviceNameOld(suggestedDeviceName string) string { - return strings.ReplaceAll(suggestedDeviceName, "/dev/sd", "/dev/xvd") -} +// Not used +// func awsFinalDeviceNameOld(suggestedDeviceName string) string { +// return strings.ReplaceAll(suggestedDeviceName, "/dev/sd", "/dev/xvd") +// } func awsFinalDeviceNameNitro(suggestedDeviceName string) string { // See what lsblk shows for your case. 
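An illustrative aside on the pkg/deploy hunks above and below: most of them apply the same handful of mechanical lint cleanups — dropping the blank identifier when ranging over map keys, negating a boolean instead of comparing it to a literal, returning early instead of keeping an else branch after a return, and calling Buffer.String() instead of string(buf.Bytes()). The following minimal, self-contained Go sketch shows the idiomatic forms they converge on; the names (pickProvider, the enabled map) are hypothetical and not code from this repository.

package main

import (
	"bytes"
	"fmt"
)

// Early return instead of an else branch after the happy path,
// the same shape as the DeployProviderFactory hunk below.
func pickProvider(name string) (string, error) {
	if name == "openstack" || name == "aws" {
		return name, nil
	}
	return "", fmt.Errorf("unsupported deploy provider %s", name)
}

func main() {
	enabled := map[string]bool{"bastion": true, "internal": false}

	// "for name := range m", not "for name, _ := range m".
	for name := range enabled {
		// "if !enabled[name]", not "if enabled[name] != true".
		if !enabled[name] {
			fmt.Printf("security group %s is disabled\n", name)
		}
	}

	// Buffer.String() rather than string(buf.Bytes()).
	var out bytes.Buffer
	out.WriteString("deploy log")
	fmt.Println(out.String())

	if _, err := pickProvider("azure"); err != nil {
		fmt.Println(err) // unsupported deploy provider azure
	}
}

All four rewrites are behavior-preserving; the surrounding hunks make the equivalent changes in the security-group, exec and provider-factory files.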
diff --git a/pkg/deploy/deploy_provider.go b/pkg/deploy/deploy_provider.go index e7a14cf..8029bee 100644 --- a/pkg/deploy/deploy_provider.go +++ b/pkg/deploy/deploy_provider.go @@ -28,9 +28,8 @@ func DeployProviderFactory(deployProviderName string) (DeployProvider, error) { return &OpenstackDeployProvider{}, nil } else if deployProviderName == DeployProviderAws { return &AwsDeployProvider{}, nil - } else { - return nil, fmt.Errorf("unsupported deploy provider %s", deployProviderName) } + return nil, fmt.Errorf("unsupported deploy provider %s", deployProviderName) } func reportPublicIp(prj *Project) { diff --git a/pkg/deploy/exec_local.go b/pkg/deploy/exec_local.go index deddf6d..d009ddb 100644 --- a/pkg/deploy/exec_local.go +++ b/pkg/deploy/exec_local.go @@ -56,9 +56,9 @@ func CmdChainExecToString(title string, logContent string, err error, isVerbose %s ========================================= `, title, logContent) - } else { - return title } + + return title } func ExecLocal(prj *Project, cmdPath string, params []string, envVars map[string]string, dir string) ExecResult { @@ -93,17 +93,17 @@ func ExecLocal(prj *Project, cmdPath string, params []string, envVars map[string elapsed := time.Since(runStartTime).Seconds() rawInput := fmt.Sprintf("%s %s", cmdPath, strings.Join(params, " ")) - rawOutput := string(stdout.Bytes()) - rawErrors := string(stderr.Bytes()) + rawOutput := stdout.String() + rawErrors := stderr.String() if err != nil { // Cmd not found, nonzero exit status etc return ExecResult{rawInput, rawOutput, rawErrors, elapsed, err} } else if cmdCtx.Err() == context.DeadlineExceeded { // Timeout occurred, err.Error() is probably: 'signal: killed' return ExecResult{rawInput, rawOutput, rawErrors, elapsed, fmt.Errorf("cmd execution timeout exceeded")} - } else { - return ExecResult{rawInput, rawOutput, rawErrors, elapsed, nil} } + + return ExecResult{rawInput, rawOutput, rawErrors, elapsed, nil} } func BuildArtifacts(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) { diff --git a/pkg/deploy/exec_ssh.go b/pkg/deploy/exec_ssh.go index e19b6a8..0bb0448 100644 --- a/pkg/deploy/exec_ssh.go +++ b/pkg/deploy/exec_ssh.go @@ -3,7 +3,7 @@ package deploy import ( "bytes" "fmt" - "io/ioutil" + "io" "net" "os" "path/filepath" @@ -55,7 +55,7 @@ func (tsc *TunneledSshClient) Close() { tsc.TunneledSshConn.Close() } if tsc.TunneledTcpConn != nil { - tsc.TunneledTcpConn.Close() + tsc.TunneledTcpConn.Close() //nolint:all } if tsc.ProxySshClient != nil { tsc.ProxySshClient.Close() @@ -65,8 +65,6 @@ func (tsc *TunneledSshClient) Close() { func NewTunneledSshClient(sshConfig *SshConfigDef, ipAddress string) (*TunneledSshClient, error) { bastionSshClientConfig, err := xfer.NewSshClientConfig( sshConfig.User, - sshConfig.ExternalIpAddress, - sshConfig.Port, sshConfig.PrivateKeyPath, sshConfig.PrivateKeyPassword) if err != nil { @@ -99,8 +97,6 @@ func NewTunneledSshClient(sshConfig *SshConfigDef, ipAddress string) (*TunneledS tunneledSshClientConfig, err := xfer.NewSshClientConfig( sshConfig.User, - ipAddress, - sshConfig.Port, sshConfig.PrivateKeyPath, sshConfig.PrivateKeyPassword) if err != nil { @@ -140,7 +136,7 @@ func ExecSsh(sshConfig *SshConfigDef, ipAddress string, cmd string) ExecResult { err = session.Run(cmd) elapsed := time.Since(runStartTime).Seconds() - er := ExecResult{cmd, string(stdout.Bytes()), string(stderr.Bytes()), elapsed, err} + er := ExecResult{cmd, stdout.String(), stderr.String(), elapsed, err} return er } @@ -238,9 +234,9 @@ func ExecScriptsOnInstance(sshConfig *SshConfigDef, 
ipAddress string, env map[st if err != nil { return lb.Complete(fmt.Errorf("cannot open shell script %s: %s", fullScriptPath, err.Error())) } - defer f.Close() + defer f.Close() //nolint:all - shellScriptBytes, err := ioutil.ReadAll(f) + shellScriptBytes, err := io.ReadAll(f) if err != nil { return lb.Complete(fmt.Errorf("cannot read shell script %s: %s", fullScriptPath, err.Error())) } diff --git a/pkg/deploy/openstack_instances.go b/pkg/deploy/openstack_instances.go index 988c071..5363694 100644 --- a/pkg/deploy/openstack_instances.go +++ b/pkg/deploy/openstack_instances.go @@ -14,7 +14,7 @@ func (*OpenstackDeployProvider) GetFlavorIds(prjPair *ProjectPair, flavorMap map return lb.Complete(er.Error) } - for flavorName, _ := range flavorMap { + for flavorName := range flavorMap { foundFlavorIdByName := findOpenstackColumnValue(rows, "ID", "Name", flavorName) if foundFlavorIdByName == "" { return lb.Complete(fmt.Errorf("cannot find flavor %s", flavorName)) @@ -34,7 +34,7 @@ func (*OpenstackDeployProvider) GetImageIds(prjPair *ProjectPair, imageMap map[s return lb.Complete(er.Error) } - for name, _ := range imageMap { + for name := range imageMap { foundIdByName := findOpenstackColumnValue(rows, "ID", "Name", name) if foundIdByName == "" { return lb.Complete(fmt.Errorf("cannot find image %s", name)) @@ -54,7 +54,7 @@ func (*OpenstackDeployProvider) GetKeypairs(prjPair *ProjectPair, keypairMap map return lb.Complete(er.Error) } - for keypairName, _ := range keypairMap { + for keypairName := range keypairMap { foundName := findOpenstackColumnValue(rows, "Fingerprint", "Name", keypairName) if foundName == "" { return lb.Complete(fmt.Errorf("cannot find keypair %s, you have to create it before running this command", keypairName)) diff --git a/pkg/deploy/openstack_networking.go b/pkg/deploy/openstack_networking.go index 994af97..c4b9b1b 100644 --- a/pkg/deploy/openstack_networking.go +++ b/pkg/deploy/openstack_networking.go @@ -331,7 +331,7 @@ func createOpenstackRouter(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) lb.Add(fmt.Sprintf("router %s seems to be connected to internet\n", prjPair.Live.Network.Router.Name)) } else { lb.Add(fmt.Sprintf("router %s needs to be connected to internet\n", prjPair.Live.Network.Router.Name)) - rows, er = execLocalAndParseOpenstackOutput(&prjPair.Live, "openstack", []string{"router", "set", "--external-gateway", prjPair.Live.Network.Router.ExternalGatewayNetworkName, prjPair.Live.Network.Router.Name}) + _, er = execLocalAndParseOpenstackOutput(&prjPair.Live, "openstack", []string{"router", "set", "--external-gateway", prjPair.Live.Network.Router.ExternalGatewayNetworkName, prjPair.Live.Network.Router.Name}) lb.Add(er.ToString()) if er.Error != nil { return lb.Complete(er.Error) @@ -369,7 +369,7 @@ func deleteOpenstackRouter(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) } // Release gateway. Some providers (genesis) will not remove port if gateway not released. 
- rows, er = execLocalAndParseOpenstackOutput(&prjPair.Live, "openstack", []string{"router", "unset", "--external-gateway", prjPair.Live.Network.Router.Name}) + _, er = execLocalAndParseOpenstackOutput(&prjPair.Live, "openstack", []string{"router", "unset", "--external-gateway", prjPair.Live.Network.Router.Name}) lb.Add(er.ToString()) if er.Error != nil { return lb.Complete(er.Error) diff --git a/pkg/deploy/openstack_security_group.go b/pkg/deploy/openstack_security_group.go index b32f519..e1740ac 100644 --- a/pkg/deploy/openstack_security_group.go +++ b/pkg/deploy/openstack_security_group.go @@ -7,7 +7,7 @@ import ( func (*OpenstackDeployProvider) CreateSecurityGroups(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) { sb := strings.Builder{} - for sgNickname, _ := range prjPair.Live.SecurityGroups { + for sgNickname := range prjPair.Live.SecurityGroups { logMsg, err := createOpenstackSecurityGroup(prjPair, sgNickname, isVerbose) AddLogMsg(&sb, logMsg) if err != nil { @@ -146,7 +146,7 @@ func createOpenstackSecurityGroup(prjPair *ProjectPair, sgNickname string, isVer func (*OpenstackDeployProvider) DeleteSecurityGroups(prjPair *ProjectPair, isVerbose bool) (LogMsg, error) { sb := strings.Builder{} - for sgNickname, _ := range prjPair.Live.SecurityGroups { + for sgNickname := range prjPair.Live.SecurityGroups { logMsg, err := deleteOpenstackSecurityGroup(prjPair, sgNickname, isVerbose) AddLogMsg(&sb, logMsg) if err != nil { diff --git a/pkg/deploy/openstack_volumes.go b/pkg/deploy/openstack_volumes.go index 1b567b1..0d9835f 100644 --- a/pkg/deploy/openstack_volumes.go +++ b/pkg/deploy/openstack_volumes.go @@ -93,7 +93,6 @@ func (*OpenstackDeployProvider) CreateVolume(prjPair *ProjectPair, iNickname str // If it was already created, save it for future use, but do not create if foundVolIdByName != "" { lb.Add(fmt.Sprintf("volume %s(%s) already there, updating project", prjPair.Live.Instances[iNickname].Volumes[volNickname].Name, foundVolIdByName)) - //fmt.Printf("Setting existing %s-%s %s\n", iNickname, volNickname, foundVolIdByName) prjPair.SetVolumeId(iNickname, volNickname, foundVolIdByName) } } else { @@ -134,7 +133,6 @@ func (*OpenstackDeployProvider) CreateVolume(prjPair *ProjectPair, iNickname str lb.Add(fmt.Sprintf("created volume %s: %s(%s)", volNickname, prjPair.Live.Instances[iNickname].Volumes[volNickname].Name, newId)) prjPair.SetVolumeId(iNickname, volNickname, newId) - //fmt.Printf("Setting id %s-%s %s\n", iNickname, volNickname, newId) return lb.Complete(nil) } diff --git a/pkg/deploy/project.go b/pkg/deploy/project.go index 098c218..0e030ce 100644 --- a/pkg/deploy/project.go +++ b/pkg/deploy/project.go @@ -3,7 +3,6 @@ package deploy import ( "encoding/json" "fmt" - "io/ioutil" "os" "path/filepath" "strings" @@ -42,7 +41,7 @@ type PrivateSubnetDef struct { Name string `json:"name"` Id string `json:"id"` Cidr string `json:"cidr"` - AllocationPool string `json:"allocation_pool"` //start=192.168.199.2,end=192.168.199.254 + AllocationPool string `json:"allocation_pool"` // start=192.168.199.2,end=192.168.199.254 AvailabilityZone string `json:"availability_zone"` // AWS only RouteTableToNat string `json:"route_table_to_nat"` // AWS only } @@ -352,12 +351,12 @@ func (prj *Project) validate() error { } // All file groups should be referenced, otherwise useless - for fgName, _ := range prj.FileGroupsUp { + for fgName := range prj.FileGroupsUp { if _, ok := referencedUpFileGroups[fgName]; !ok { return fmt.Errorf("up file group %s not reference by any instance, consider removing 
it", fgName) } } - for fgName, _ := range prj.FileGroupsDown { + for fgName := range prj.FileGroupsDown { if _, ok := referencedDownFileGroups[fgName]; !ok { return fmt.Errorf("down file group %s not reference by any instance, consider removing it", fgName) } @@ -376,7 +375,7 @@ func LoadProject(prjFile string) (*ProjectPair, string, error) { return nil, "", fmt.Errorf("cannot find project file [%s]: [%s]", prjFullPath, err.Error()) } - prjBytes, err := ioutil.ReadFile(prjFullPath) + prjBytes, err := os.ReadFile(prjFullPath) if err != nil { return nil, "", fmt.Errorf("cannot read project file %s: %s", prjFullPath, err.Error()) } @@ -446,10 +445,12 @@ func (prj *Project) SaveProject(fullPrjPath string) error { } fPrj, err := os.Create(fullPrjPath) + if err != nil { + return err + } defer fPrj.Close() if _, err := fPrj.WriteString(string(prjJsonBytes)); err != nil { return err } - fPrj.Sync() - return nil + return fPrj.Sync() } diff --git a/pkg/deploy/up_down.go b/pkg/deploy/up_down.go index 12267f6..f8039fa 100644 --- a/pkg/deploy/up_down.go +++ b/pkg/deploy/up_down.go @@ -102,14 +102,14 @@ func InstanceFileGroupDownDefsToSpecs(prj *Project, ipAddress string, fgDef *Fil } defer tsc.Close() - sftp, err := sftp.NewClient(tsc.SshClient) + sftpClient, err := sftp.NewClient(tsc.SshClient) if err != nil { return nil, fmt.Errorf("cannot create sftp client: %s", err.Error()) } - defer sftp.Close() + defer sftpClient.Close() fileDownloadSpecs := make([]*FileDownloadSpec, 0) - w := sftp.Walk(fgDef.Src) + w := sftpClient.Walk(fgDef.Src) for w.Step() { if w.Err() != nil { return nil, fmt.Errorf("sftp walker error in %s, %s: %s", fgDef.Src, w.Path(), w.Err().Error()) @@ -166,11 +166,11 @@ func UploadFileSftp(prj *Project, ipAddress string, srcPath string, dstPath stri } defer tsc.Close() - sftp, err := sftp.NewClient(tsc.SshClient) + sftpClient, err := sftp.NewClient(tsc.SshClient) if err != nil { return lb.Complete(fmt.Errorf("cannot create sftp client to %s: %s", ipAddress, err.Error())) } - defer sftp.Close() + defer sftpClient.Close() pathParts := strings.Split(dstPath, string(os.PathSeparator)) curPath := string(os.PathSeparator) @@ -179,7 +179,7 @@ func UploadFileSftp(prj *Project, ipAddress string, srcPath string, dstPath stri continue } curPath = filepath.Join(curPath, pathParts[partIdx]) - fi, err := sftp.Stat(curPath) + fi, err := sftpClient.Stat(curPath) if err == nil && fi.IsDir() { // Nothing to do, we do not change existing directories continue @@ -229,7 +229,7 @@ func UploadFileSftp(prj *Project, ipAddress string, srcPath string, dstPath stri return lb.Complete(fmt.Errorf("cannot delete dst file on upload %s: %s", dstPath, err.Error())) } - dstFile, err := sftp.Create(dstPath) + dstFile, err := sftpClient.Create(dstPath) if err != nil { return lb.Complete(fmt.Errorf("cannot create on upload %s%s: %s", ipAddress, dstPath, err.Error())) } @@ -263,17 +263,17 @@ func DownloadFileSftp(prj *Project, ipAddress string, srcPath string, dstPath st } defer tsc.Close() - sftp, err := sftp.NewClient(tsc.SshClient) + sftpClient, err := sftp.NewClient(tsc.SshClient) if err != nil { return lb.Complete(fmt.Errorf("cannot create sftp client: %s", err.Error())) } - defer sftp.Close() + defer sftpClient.Close() if err := os.MkdirAll(filepath.Dir(dstPath), 0777); err != nil { return lb.Complete(fmt.Errorf("cannot create target dir for %s: %s", dstPath, err.Error())) } - srcFile, err := sftp.Open(srcPath) + srcFile, err := sftpClient.Open(srcPath) if err != nil { return lb.Complete(fmt.Errorf("cannot open for 
download %s: %s", srcPath, err.Error())) } diff --git a/pkg/deploy/users.go b/pkg/deploy/users.go index cad37dc..95aad50 100644 --- a/pkg/deploy/users.go +++ b/pkg/deploy/users.go @@ -2,7 +2,6 @@ package deploy import ( "fmt" - "io/ioutil" "os" "path/filepath" "strings" @@ -35,25 +34,24 @@ create_instance_user() func NewCreateInstanceUsersCommands(iDef *InstanceDef) ([]string, error) { cmds := make([]string, len(iDef.Users)) for uIdx, uDef := range iDef.Users { - if len(uDef.Name) > 0 { - keyPath := uDef.PublicKeyPath - if strings.HasPrefix(keyPath, "~/") { - homeDir, _ := os.UserHomeDir() - keyPath = filepath.Join(homeDir, keyPath[2:]) - } - keyBytes, err := ioutil.ReadFile(keyPath) - if err != nil { - return nil, fmt.Errorf("cannot read public key '%s' for user %s on %s: %s", uDef.PublicKeyPath, uDef.Name, iDef.HostName, err.Error()) - } - key := string(keyBytes) - if !strings.HasPrefix(key, "ssh-") && !strings.HasPrefix(key, "ecsda-") { - return nil, fmt.Errorf("cannot copy private key '%s' on %s: public key should start with ssh or ecdsa-", uDef.PublicKeyPath, iDef.HostName) - } - - cmds[uIdx] = fmt.Sprintf("%s\ncreate_instance_user '%s' '%s'", CreateInstanceUserFunc, uDef.Name, key) - } else { + if len(uDef.Name) == 0 { return nil, fmt.Errorf("cannot create instance %s user '%s': name cannot be null, public key should start with ssh or ecdsa-", iDef.HostName, uDef.Name) } + keyPath := uDef.PublicKeyPath + if strings.HasPrefix(keyPath, "~/") { + homeDir, _ := os.UserHomeDir() + keyPath = filepath.Join(homeDir, keyPath[2:]) + } + keyBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("cannot read public key '%s' for user %s on %s: %s", uDef.PublicKeyPath, uDef.Name, iDef.HostName, err.Error()) + } + key := string(keyBytes) + if !strings.HasPrefix(key, "ssh-") && !strings.HasPrefix(key, "ecsda-") { + return nil, fmt.Errorf("cannot copy private key '%s' on %s: public key should start with ssh or ecdsa-", uDef.PublicKeyPath, iDef.HostName) + } + + cmds[uIdx] = fmt.Sprintf("%s\ncreate_instance_user '%s' '%s'", CreateInstanceUserFunc, uDef.Name, key) } return cmds, nil } @@ -72,26 +70,25 @@ copy_private_key() func NewCopyPrivateKeysCommands(iDef *InstanceDef) ([]string, error) { cmds := make([]string, len(iDef.PrivateKeys)) for uIdx, uDef := range iDef.PrivateKeys { - if len(uDef.Name) > 0 { - keyPath := uDef.PrivateKeyPath - if strings.HasPrefix(keyPath, "~/") { - homeDir, _ := os.UserHomeDir() - keyPath = filepath.Join(homeDir, keyPath[2:]) - } - keyBytes, err := ioutil.ReadFile(keyPath) - if err != nil { - return nil, fmt.Errorf("cannot read private key '%s' for user %s on %s: %s", keyPath, uDef.Name, iDef.HostName, err.Error()) - } - key := string(keyBytes) - if !strings.HasPrefix(key, "-----BEGIN OPENSSH PRIVATE KEY-----") { - return nil, fmt.Errorf("cannot copy private key '%s' on %s: private key should start with -----BEGIN OPENSSH PRIVATE KEY-----", uDef.PrivateKeyPath, iDef.HostName) - } - - // Make sure escaped \n remains escaped (this is how we store private keys in our json config files) with actual EOLs - cmds[uIdx] = fmt.Sprintf("%s\ncopy_private_key '%s' '%s'", CopyPrivateKeyFunc, uDef.Name, strings.ReplaceAll(string(key), "\n", "\\n")) - } else { + if len(uDef.Name) == 0 { return nil, fmt.Errorf("cannot copy private key '%s' on %s: name cannot be null", uDef.Name, iDef.HostName) } + keyPath := uDef.PrivateKeyPath + if strings.HasPrefix(keyPath, "~/") { + homeDir, _ := os.UserHomeDir() + keyPath = filepath.Join(homeDir, keyPath[2:]) + } + keyBytes, 
err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("cannot read private key '%s' for user %s on %s: %s", keyPath, uDef.Name, iDef.HostName, err.Error()) + } + key := string(keyBytes) + if !strings.HasPrefix(key, "-----BEGIN OPENSSH PRIVATE KEY-----") { + return nil, fmt.Errorf("cannot copy private key '%s' on %s: private key should start with -----BEGIN OPENSSH PRIVATE KEY-----", uDef.PrivateKeyPath, iDef.HostName) + } + + // Make sure escaped \n remains escaped (this is how we store private keys in our json config files) with actual EOLs + cmds[uIdx] = fmt.Sprintf("%s\ncopy_private_key '%s' '%s'", CopyPrivateKeyFunc, uDef.Name, strings.ReplaceAll(string(key), "\n", "\\n")) } return cmds, nil } diff --git a/pkg/dpc/dependency_policy_checker.go b/pkg/dpc/dependency_policy_checker.go index 56518e0..d88c744 100644 --- a/pkg/dpc/dependency_policy_checker.go +++ b/pkg/dpc/dependency_policy_checker.go @@ -1,43 +1,43 @@ -package dpc - -import ( - "fmt" - "sort" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfmodel" -) - -func CheckDependencyPolicyAgainstNodeEventList(targetNodeDepPol *sc.DependencyPolicyDef, events wfmodel.DependencyNodeEvents) (sc.ReadyToRunNodeCmdType, int16, string, error) { - var err error - - for eventIdx := 0; eventIdx < len(events); eventIdx++ { - vars := wfmodel.NewVarsFromDepCtx(0, events[eventIdx]) - events[eventIdx].SortKey, err = sc.BuildKey(vars[wfmodel.DependencyNodeEventTableName], &targetNodeDepPol.OrderIdxDef) - if err != nil { - return sc.NodeNogo, 0, "", fmt.Errorf("unexpectedly, cannot build key to sort events: %s", err.Error()) - } - } - sort.Slice(events, func(i, j int) bool { return events[i].SortKey < events[j].SortKey }) - - for eventIdx := 0; eventIdx < len(events); eventIdx++ { - vars := wfmodel.NewVarsFromDepCtx(0, events[eventIdx]) - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) - for ruleIdx, rule := range targetNodeDepPol.Rules { - ruleMatched, err := eCtx.Eval(rule.ParsedExpression) - if err != nil { - return sc.NodeNogo, 0, "", fmt.Errorf("cannot check rule %d '%s' against event %s, eval failed: %s", ruleIdx, rule.RawExpression, events[eventIdx].ToString(), err.Error()) - } - ruleMatchedBool, ok := ruleMatched.(bool) - if !ok { - return sc.NodeNogo, 0, "", fmt.Errorf("cannot check rule %d '%s' against event %s: expected result type was bool, got %T", ruleIdx, rule.RawExpression, events[eventIdx].ToString(), ruleMatched) - } - if ruleMatchedBool { - return rule.Cmd, events[eventIdx].RunId, fmt.Sprintf("matched rule %d(%s) '%s' against event %d %s. 
All events %s", ruleIdx, rule.Cmd, rule.RawExpression, eventIdx, events[eventIdx].ToString(), events.ToString()), nil - } - } - } - - return sc.NodeNogo, 0, fmt.Sprintf("no rules matched against events %s", events.ToString()), nil -} +package dpc + +import ( + "fmt" + "sort" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfmodel" +) + +func CheckDependencyPolicyAgainstNodeEventList(targetNodeDepPol *sc.DependencyPolicyDef, events wfmodel.DependencyNodeEvents) (sc.ReadyToRunNodeCmdType, int16, string, error) { + var err error + + for eventIdx := 0; eventIdx < len(events); eventIdx++ { + vars := wfmodel.NewVarsFromDepCtx(events[eventIdx]) + events[eventIdx].SortKey, err = sc.BuildKey(vars[wfmodel.DependencyNodeEventTableName], &targetNodeDepPol.OrderIdxDef) + if err != nil { + return sc.NodeNogo, 0, "", fmt.Errorf("unexpectedly, cannot build key to sort events: %s", err.Error()) + } + } + sort.Slice(events, func(i, j int) bool { return events[i].SortKey < events[j].SortKey }) + + for eventIdx := 0; eventIdx < len(events); eventIdx++ { + vars := wfmodel.NewVarsFromDepCtx(events[eventIdx]) + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) + for ruleIdx, rule := range targetNodeDepPol.Rules { + ruleMatched, err := eCtx.Eval(rule.ParsedExpression) + if err != nil { + return sc.NodeNogo, 0, "", fmt.Errorf("cannot check rule %d '%s' against event %s, eval failed: %s", ruleIdx, rule.RawExpression, events[eventIdx].ToString(), err.Error()) + } + ruleMatchedBool, ok := ruleMatched.(bool) + if !ok { + return sc.NodeNogo, 0, "", fmt.Errorf("cannot check rule %d '%s' against event %s: expected result type was bool, got %T", ruleIdx, rule.RawExpression, events[eventIdx].ToString(), ruleMatched) + } + if ruleMatchedBool { + return rule.Cmd, events[eventIdx].RunId, fmt.Sprintf("matched rule %d(%s) '%s' against event %d %s. 
All events %s", ruleIdx, rule.Cmd, rule.RawExpression, eventIdx, events[eventIdx].ToString(), events.ToString()), nil + } + } + } + + return sc.NodeNogo, 0, fmt.Sprintf("no rules matched against events %s", events.ToString()), nil +} diff --git a/pkg/dpc/dependency_policy_checker_test.go b/pkg/dpc/dependency_policy_checker_test.go index f77252f..564f2c4 100644 --- a/pkg/dpc/dependency_policy_checker_test.go +++ b/pkg/dpc/dependency_policy_checker_test.go @@ -1,120 +1,120 @@ -package dpc - -import ( - "regexp" - "testing" - "time" - - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/stretchr/testify/assert" -) - -func TestDefaultDependencyPolicyChecker(t *testing.T) { - events := wfmodel.DependencyNodeEvents{ - { - RunId: 10, - RunIsCurrent: true, - RunStartTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - RunFinalStatus: wfmodel.RunStart, - RunCompletedTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - RunStoppedTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - NodeIsStarted: true, - NodeStartTs: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), - NodeStatus: wfmodel.NodeBatchNone, - NodeStatusTs: time.Date(2000, 1, 1, 0, 0, 2, 0, time.UTC)}} - - polDef := sc.DependencyPolicyDef{} - if err := polDef.Deserialize([]byte(sc.DefaultPolicyCheckerConf)); err != nil { - t.Error(err) - return - } - - var cmd sc.ReadyToRunNodeCmdType - var runId int16 - var checkerLogMsg string - var err error - - events[0].RunIsCurrent = true - - events[0].NodeStatus = wfmodel.NodeBatchRunStopReceived - cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeNogo, cmd) - assert.Contains(t, checkerLogMsg, "no rules matched against events") - - events[0].NodeStatus = wfmodel.NodeBatchSuccess - cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeGo, cmd) - assert.Equal(t, int16(10), runId) - assert.Contains(t, checkerLogMsg, "matched rule 0(go)") - - events[0].NodeStatus = wfmodel.NodeBatchNone - cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeWait, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 1(wait)") - - events[0].NodeStatus = wfmodel.NodeBatchStart - cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeWait, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 2(wait)") - - events[0].NodeStatus = wfmodel.NodeBatchFail - cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeNogo, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 3(nogo)") - - events[0].RunIsCurrent = false - - events[0].NodeStatus = wfmodel.NodeBatchRunStopReceived - cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeNogo, cmd) - assert.Contains(t, checkerLogMsg, "no rules matched against events") - - events[0].NodeStatus = wfmodel.NodeBatchSuccess - cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeGo, cmd) - assert.Equal(t, int16(10), runId) - assert.Contains(t, checkerLogMsg, "matched rule 4(go)") - - events[0].NodeStatus = wfmodel.NodeBatchNone - cmd, _, checkerLogMsg, err = 
CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeWait, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 5(wait)") - - events[0].NodeStatus = wfmodel.NodeBatchStart - cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeWait, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 6(wait)") - - events[0].RunFinalStatus = wfmodel.RunComplete - - events[0].NodeStatus = wfmodel.NodeBatchSuccess - cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeGo, cmd) - assert.Equal(t, int16(10), runId) - assert.Contains(t, checkerLogMsg, "matched rule 7(go)") - - events[0].NodeStatus = wfmodel.NodeBatchFail - cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Nil(t, err) - assert.Equal(t, sc.NodeNogo, cmd) - assert.Contains(t, checkerLogMsg, "matched rule 8(nogo)") - - // Failures - - re := regexp.MustCompile(`"expression": "e\.run[^"]+"`) - err = polDef.Deserialize([]byte(re.ReplaceAllString(sc.DefaultPolicyCheckerConf, `"expression": "1"`))) - assert.Nil(t, err) - _, _, _, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) - assert.Contains(t, err.Error(), "expected result type was bool, got int64") -} +package dpc + +import ( + "regexp" + "testing" + "time" + + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/stretchr/testify/assert" +) + +func TestDefaultDependencyPolicyChecker(t *testing.T) { + events := wfmodel.DependencyNodeEvents{ + { + RunId: 10, + RunIsCurrent: true, + RunStartTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + RunFinalStatus: wfmodel.RunStart, + RunCompletedTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + RunStoppedTs: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + NodeIsStarted: true, + NodeStartTs: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), + NodeStatus: wfmodel.NodeBatchNone, + NodeStatusTs: time.Date(2000, 1, 1, 0, 0, 2, 0, time.UTC)}} + + polDef := sc.DependencyPolicyDef{} + if err := polDef.Deserialize([]byte(sc.DefaultPolicyCheckerConf)); err != nil { + t.Error(err) + return + } + + var cmd sc.ReadyToRunNodeCmdType + var runId int16 + var checkerLogMsg string + var err error + + events[0].RunIsCurrent = true + + events[0].NodeStatus = wfmodel.NodeBatchRunStopReceived + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeNogo, cmd) + assert.Contains(t, checkerLogMsg, "no rules matched against events") + + events[0].NodeStatus = wfmodel.NodeBatchSuccess + cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeGo, cmd) + assert.Equal(t, int16(10), runId) + assert.Contains(t, checkerLogMsg, "matched rule 0(go)") + + events[0].NodeStatus = wfmodel.NodeBatchNone + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeWait, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 1(wait)") + + events[0].NodeStatus = wfmodel.NodeBatchStart + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeWait, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 2(wait)") + + events[0].NodeStatus = wfmodel.NodeBatchFail 
+ cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeNogo, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 3(nogo)") + + events[0].RunIsCurrent = false + + events[0].NodeStatus = wfmodel.NodeBatchRunStopReceived + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeNogo, cmd) + assert.Contains(t, checkerLogMsg, "no rules matched against events") + + events[0].NodeStatus = wfmodel.NodeBatchSuccess + cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeGo, cmd) + assert.Equal(t, int16(10), runId) + assert.Contains(t, checkerLogMsg, "matched rule 4(go)") + + events[0].NodeStatus = wfmodel.NodeBatchNone + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeWait, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 5(wait)") + + events[0].NodeStatus = wfmodel.NodeBatchStart + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeWait, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 6(wait)") + + events[0].RunFinalStatus = wfmodel.RunComplete + + events[0].NodeStatus = wfmodel.NodeBatchSuccess + cmd, runId, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeGo, cmd) + assert.Equal(t, int16(10), runId) + assert.Contains(t, checkerLogMsg, "matched rule 7(go)") + + events[0].NodeStatus = wfmodel.NodeBatchFail + cmd, _, checkerLogMsg, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Nil(t, err) + assert.Equal(t, sc.NodeNogo, cmd) + assert.Contains(t, checkerLogMsg, "matched rule 8(nogo)") + + // Failures + + re := regexp.MustCompile(`"expression": "e\.run[^"]+"`) + err = polDef.Deserialize([]byte(re.ReplaceAllString(sc.DefaultPolicyCheckerConf, `"expression": "1"`))) + assert.Nil(t, err) + _, _, _, err = CheckDependencyPolicyAgainstNodeEventList(&polDef, events) + assert.Contains(t, err.Error(), "expected result type was bool, got int64") +} diff --git a/pkg/env/amqp_config.go b/pkg/env/amqp_config.go index c1c900c..a1195fd 100644 --- a/pkg/env/amqp_config.go +++ b/pkg/env/amqp_config.go @@ -1,12 +1,12 @@ -package env - -type AmqpConfig struct { - URL string `json:"url"` - Exchange string `json:"exchange"` - PrefetchCount int `json:"prefetch_count"` - PrefetchSize int `json:"prefetch_size"` - // Became obsolete after we refactored the framework so it sends all batch messages in the very beginning - // FlowMaxPerConsumer int `json:"flow_max_per_consumer"` - // FlowWaitMillisMin int `json:"flow_wait_millis_min"` - // FlowWaitMillisMax int `json:"flow_wait_millis_max"` -} +package env + +type AmqpConfig struct { + URL string `json:"url"` + Exchange string `json:"exchange"` + PrefetchCount int `json:"prefetch_count"` + PrefetchSize int `json:"prefetch_size"` + // Became obsolete after we refactored the framework so it sends all batch messages in the very beginning + // FlowMaxPerConsumer int `json:"flow_max_per_consumer"` + // FlowWaitMillisMin int `json:"flow_wait_millis_min"` + // FlowWaitMillisMax int `json:"flow_wait_millis_max"` +} diff --git a/pkg/env/cassandra_config.go b/pkg/env/cassandra_config.go index c935e83..fbc3c84 100644 --- a/pkg/env/cassandra_config.go +++ 
b/pkg/env/cassandra_config.go @@ -1,23 +1,23 @@ -package env - -// This was not tested outside of the EnableHostVerification=false scenario -type SslOptions struct { - CertPath string `json:"cert_path"` - KeyPath string `json:"key_path"` - CaPath string `json:"ca_path"` - EnableHostVerification bool `json:"enable_host_verification"` -} - -type CassandraConfig struct { - Hosts []string `json:"hosts"` - Port int `json:"port"` - Username string `json:"username"` - Password string `json:"password"` - WriterWorkers int `json:"writer_workers"` // 20 is conservative, 80 is very aggressive - MinInserterRate int `json:"min_inserter_rate"` // writes/sec; if the rate falls below this, we consider the db too slow and throw an error - NumConns int `json:"num_conns"` // gocql default is 2, don't make it too high - Timeout int `json:"timeout"` // in ms, set it to 5s, gocql default 600ms is way too aggressive for heavy writes by multiple workers - ConnectTimeout int `json:"connect_timeout"` // in ms, set it to 1s, gocql default 600ms may be ok, but let's stay on the safe side - KeyspaceReplicationConfig string `json:"keyspace_replication_config"` // { 'class' : 'NetworkTopologyStrategy', 'datacenter1' : 1 } - SslOpts *SslOptions `json:"ssl_opts"` -} +package env + +// This was not tested outside of the EnableHostVerification=false scenario +type SslOptions struct { + CertPath string `json:"cert_path"` + KeyPath string `json:"key_path"` + CaPath string `json:"ca_path"` + EnableHostVerification bool `json:"enable_host_verification"` +} + +type CassandraConfig struct { + Hosts []string `json:"hosts"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"password"` + WriterWorkers int `json:"writer_workers"` // 20 is conservative, 80 is very aggressive + MinInserterRate int `json:"min_inserter_rate"` // writes/sec; if the rate falls below this, we consider the db too slow and throw an error + NumConns int `json:"num_conns"` // gocql default is 2, don't make it too high + Timeout int `json:"timeout"` // in ms, set it to 5s, gocql default 600ms is way too aggressive for heavy writes by multiple workers + ConnectTimeout int `json:"connect_timeout"` // in ms, set it to 1s, gocql default 600ms may be ok, but let's stay on the safe side + KeyspaceReplicationConfig string `json:"keyspace_replication_config"` // { 'class' : 'NetworkTopologyStrategy', 'datacenter1' : 1 } + SslOpts *SslOptions `json:"ssl_opts"` +} diff --git a/pkg/env/env_config.go b/pkg/env/env_config.go index 8b27812..69ff330 100644 --- a/pkg/env/env_config.go +++ b/pkg/env/env_config.go @@ -1,75 +1,74 @@ -package env - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "os" - "path/filepath" - - "github.com/capillariesio/capillaries/pkg/sc" - "go.uber.org/zap" -) - -type EnvConfig struct { - HandlerExecutableType string `json:"handler_executable_type"` - Cassandra CassandraConfig `json:"cassandra"` - Amqp AmqpConfig `json:"amqp"` - ZapConfig zap.Config `json:"zap_config"` - ThreadPoolSize int `json:"thread_pool_size"` - DeadLetterTtl int `json:"dead_letter_ttl"` - CaPath string `json:"ca_path"` - PrivateKeys map[string]string `json:"private_keys"` - Webapi WebapiConfig `json:"webapi,omitempty"` - CustomProcessorsSettings map[string]json.RawMessage `json:"custom_processors"` - CustomProcessorDefFactoryInstance sc.CustomProcessorDefFactory -} - -func (ec *EnvConfig) Deserialize(jsonBytes []byte) error { - err := json.Unmarshal(jsonBytes, ec) - if err != nil { - return fmt.Errorf("cannot deserialize env config: %s", 
err.Error()) - } - - // Defaults - - if ec.ThreadPoolSize <= 0 || ec.ThreadPoolSize > 100 { - ec.ThreadPoolSize = 5 - } - - if ec.DeadLetterTtl < 100 || ec.DeadLetterTtl > 3600000 { // [100ms,1hr] - ec.DeadLetterTtl = 1000 - } - - return nil -} - -func ReadEnvConfigFile(envConfigFile string) (*EnvConfig, error) { - exec, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("cannot find current executable path: %s", err.Error()) - } - configFullPath := filepath.Join(filepath.Dir(exec), envConfigFile) - if _, err := os.Stat(configFullPath); err != nil { - cwd, err := os.Getwd() - if err != nil { - return nil, fmt.Errorf("cannot get current dir: [%s]", err.Error()) - } - configFullPath = filepath.Join(cwd, envConfigFile) - if _, err := os.Stat(configFullPath); err != nil { - return nil, fmt.Errorf("cannot find config file [%s], neither at [%s] nor at current dir [%s]: [%s]", envConfigFile, filepath.Dir(exec), filepath.Join(cwd, envConfigFile), err.Error()) - } - } - - envBytes, err := ioutil.ReadFile(configFullPath) - if err != nil { - return nil, fmt.Errorf("cannot read env config file %s: %s", configFullPath, err.Error()) - } - - var envConfig EnvConfig - if err := envConfig.Deserialize(envBytes); err != nil { - return nil, fmt.Errorf("cannot parse env config file %s: %s", configFullPath, err.Error()) - } - - return &envConfig, nil -} +package env + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/capillariesio/capillaries/pkg/sc" + "go.uber.org/zap" +) + +type EnvConfig struct { + HandlerExecutableType string `json:"handler_executable_type"` + Cassandra CassandraConfig `json:"cassandra"` + Amqp AmqpConfig `json:"amqp"` + ZapConfig zap.Config `json:"zap_config"` + ThreadPoolSize int `json:"thread_pool_size"` + DeadLetterTtl int `json:"dead_letter_ttl"` + CaPath string `json:"ca_path"` + PrivateKeys map[string]string `json:"private_keys"` + Webapi WebapiConfig `json:"webapi,omitempty"` + CustomProcessorsSettings map[string]json.RawMessage `json:"custom_processors"` + CustomProcessorDefFactoryInstance sc.CustomProcessorDefFactory +} + +func (ec *EnvConfig) Deserialize(jsonBytes []byte) error { + err := json.Unmarshal(jsonBytes, ec) + if err != nil { + return fmt.Errorf("cannot deserialize env config: %s", err.Error()) + } + + // Defaults + + if ec.ThreadPoolSize <= 0 || ec.ThreadPoolSize > 100 { + ec.ThreadPoolSize = 5 + } + + if ec.DeadLetterTtl < 100 || ec.DeadLetterTtl > 3600000 { // [100ms,1hr] + ec.DeadLetterTtl = 1000 + } + + return nil +} + +func ReadEnvConfigFile(envConfigFile string) (*EnvConfig, error) { + exec, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("cannot find current executable path: %s", err.Error()) + } + configFullPath := filepath.Join(filepath.Dir(exec), envConfigFile) + if _, err := os.Stat(configFullPath); err != nil { + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("cannot get current dir: [%s]", err.Error()) + } + configFullPath = filepath.Join(cwd, envConfigFile) + if _, err := os.Stat(configFullPath); err != nil { + return nil, fmt.Errorf("cannot find config file [%s], neither at [%s] nor at current dir [%s]: [%s]", envConfigFile, filepath.Dir(exec), filepath.Join(cwd, envConfigFile), err.Error()) + } + } + + envBytes, err := os.ReadFile(configFullPath) + if err != nil { + return nil, fmt.Errorf("cannot read env config file %s: %s", configFullPath, err.Error()) + } + + var envConfig EnvConfig + if err := envConfig.Deserialize(envBytes); err != nil { + return nil, fmt.Errorf("cannot 
parse env config file %s: %s", configFullPath, err.Error()) + } + + return &envConfig, nil +} diff --git a/pkg/env/webapi_config.go b/pkg/env/webapi_config.go index 9700ef7..f1f2db6 100644 --- a/pkg/env/webapi_config.go +++ b/pkg/env/webapi_config.go @@ -1,6 +1,6 @@ -package env - -type WebapiConfig struct { - Port int `json:"webapi_port"` - AccessControlAllowOrigin string `json:"access_control_allow_origin"` -} +package env + +type WebapiConfig struct { + Port int `json:"webapi_port"` + AccessControlAllowOrigin string `json:"access_control_allow_origin"` +} diff --git a/pkg/eval/agg.go b/pkg/eval/agg.go index 50b07f4..237919d 100644 --- a/pkg/eval/agg.go +++ b/pkg/eval/agg.go @@ -1,358 +1,358 @@ -package eval - -import ( - "fmt" - "go/ast" - "strings" - - "github.com/shopspring/decimal" -) - -type AggFuncType string - -const ( - AggStringAgg AggFuncType = "string_agg" - AggSum AggFuncType = "sum" - AggCount AggFuncType = "count" - AggAvg AggFuncType = "avg" - AggMin AggFuncType = "min" - AggMax AggFuncType = "max" - AggUnknown AggFuncType = "unknown" -) - -func StringToAggFunc(testString string) AggFuncType { - switch testString { - case string(AggStringAgg): - return AggStringAgg - case string(AggSum): - return AggSum - case string(AggCount): - return AggCount - case string(AggAvg): - return AggAvg - case string(AggMin): - return AggMin - case string(AggMax): - return AggMax - default: - return AggUnknown - } -} - -type SumCollector struct { - Int int64 - Float float64 - Dec decimal.Decimal -} -type AvgCollector struct { - Int int64 - Float float64 - Dec decimal.Decimal - Count int64 -} - -type MinCollector struct { - Int int64 - Float float64 - Dec decimal.Decimal - Str string - Count int64 -} - -type MaxCollector struct { - Int int64 - Float float64 - Dec decimal.Decimal - Str string - Count int64 -} - -type StringAggCollector struct { - Sb strings.Builder - Separator string -} - -type AggDataType string - -const ( - AggTypeUnknown AggDataType = "unknown" - AggTypeInt AggDataType = "int" - AggTypeFloat AggDataType = "float" - AggTypeDec AggDataType = "decimal" - AggTypeString AggDataType = "string" -) - -func (eCtx *EvalCtx) checkAgg(funcName string, callExp *ast.CallExpr, aggFunc AggFuncType) error { - if eCtx.AggEnabled != AggFuncEnabled { - return fmt.Errorf("cannot evaluate %s(), context aggregate not enabled", funcName) - } - if eCtx.AggCallExp != nil { - if eCtx.AggCallExp != callExp { - return fmt.Errorf("cannot evaluate more than one aggregate functions in the expression, extra %s() found besides %s()", funcName, eCtx.AggFunc) - } - } else { - eCtx.AggCallExp = callExp - eCtx.AggFunc = aggFunc - } - return nil -} - -func (eCtx *EvalCtx) CallAggStringAgg(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("string_agg", callExp, AggStringAgg); err != nil { - return nil, err - } - if err := checkArgs("string_agg", 2, len(args)); err != nil { - return nil, err - } - - switch typedArg0 := args[0].(type) { - case string: - if eCtx.StringAgg.Sb.Len() > 0 { - eCtx.StringAgg.Sb.WriteString(eCtx.StringAgg.Separator) - } - eCtx.StringAgg.Sb.WriteString(typedArg0) - return eCtx.StringAgg.Sb.String(), nil - - default: - return nil, fmt.Errorf("cannot evaluate string_agg(), unexpected argument %v of unsupported type %T", args[0], args[0]) - } -} - -func (eCtx *EvalCtx) CallAggSum(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("sum", callExp, AggSum); err != nil { - return nil, err - } - if err := 
checkArgs("sum", 1, len(args)); err != nil { - return nil, err - } - stdTypedArg, err := castNumberToStandardType(args[0]) - if err != nil { - return nil, err - } - switch typedArg0 := stdTypedArg.(type) { - case int64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeInt - } else if eCtx.AggType != AggTypeInt { - return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got int value %d", eCtx.AggType, typedArg0) - } - eCtx.Sum.Int += typedArg0 - return eCtx.Sum.Int, nil - - case float64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeFloat - } else if eCtx.AggType != AggTypeFloat { - return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got float value %f", eCtx.AggType, typedArg0) - } - eCtx.Sum.Float += typedArg0 - return eCtx.Sum.Float, nil - - case decimal.Decimal: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeDec - } else if eCtx.AggType != AggTypeDec { - return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got decimal value %s", eCtx.AggType, typedArg0.String()) - } - eCtx.Sum.Dec = eCtx.Sum.Dec.Add(typedArg0) - return eCtx.Sum.Dec, nil - - default: - return nil, fmt.Errorf("cannot evaluate sum(), unexpected argument %v of unsupported type %T", args[0], args[0]) - } -} - -func (eCtx *EvalCtx) CallAggAvg(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("avg", callExp, AggAvg); err != nil { - return nil, err - } - if err := checkArgs("avg", 1, len(args)); err != nil { - return nil, err - } - stdTypedArg, err := castNumberToStandardType(args[0]) - if err != nil { - return nil, err - } - switch typedArg0 := stdTypedArg.(type) { - case int64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeInt - } else if eCtx.AggType != AggTypeInt { - return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got int value %d", eCtx.AggType, typedArg0) - } - eCtx.Avg.Int += typedArg0 - eCtx.Avg.Count++ - return eCtx.Avg.Int / eCtx.Avg.Count, nil - - case float64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeFloat - } else if eCtx.AggType != AggTypeFloat { - return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got float value %f", eCtx.AggType, typedArg0) - } - eCtx.Avg.Float += typedArg0 - eCtx.Avg.Count++ - return eCtx.Avg.Float / float64(eCtx.Avg.Count), nil - - case decimal.Decimal: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeDec - } else if eCtx.AggType != AggTypeDec { - return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got decimal value %s", eCtx.AggType, typedArg0.String()) - } - eCtx.Avg.Dec = eCtx.Avg.Dec.Add(typedArg0) - eCtx.Avg.Count++ - return eCtx.Avg.Dec.Div(decimal.NewFromInt(eCtx.Avg.Count)).Round(2), nil - - default: - return nil, fmt.Errorf("cannot evaluate avg(), unexpected argument %v of unsupported type %T", args[0], args[0]) - } -} - -func (eCtx *EvalCtx) CallAggCount(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("count", callExp, AggCount); err != nil { - return nil, err - } - if err := checkArgs("count", 0, len(args)); err != nil { - return nil, err - } - eCtx.Count++ - return eCtx.Count, nil -} - -func (eCtx *EvalCtx) CallAggMin(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("min", callExp, AggMin); err != nil { - return nil, err - } - if err := checkArgs("min", 1, len(args)); err != nil { - return nil, err - } - - switch 
typedArg0 := args[0].(type) { - case string: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeString - } else if eCtx.AggType != AggTypeString { - return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got string value %s", eCtx.AggType, typedArg0) - } - eCtx.Min.Count++ - if len(eCtx.Min.Str) == 0 || typedArg0 < eCtx.Min.Str { - eCtx.Min.Str = typedArg0 - } - return eCtx.Min.Str, nil - - default: - stdTypedArg0, err := castNumberToStandardType(args[0]) - if err != nil { - return nil, err - } - switch typedNumberArg0 := stdTypedArg0.(type) { - case int64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeInt - } else if eCtx.AggType != AggTypeInt { - return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got int value %d", eCtx.AggType, typedNumberArg0) - } - eCtx.Min.Count++ - if typedNumberArg0 < eCtx.Min.Int { - eCtx.Min.Int = typedNumberArg0 - } - return eCtx.Min.Int, nil - - case float64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeFloat - } else if eCtx.AggType != AggTypeFloat { - return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got float value %f", eCtx.AggType, typedNumberArg0) - } - eCtx.Min.Count++ - if typedNumberArg0 < eCtx.Min.Float { - eCtx.Min.Float = typedNumberArg0 - } - return eCtx.Min.Float, nil - - case decimal.Decimal: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeDec - } else if eCtx.AggType != AggTypeDec { - return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got decimal value %s", eCtx.AggType, typedNumberArg0.String()) - } - eCtx.Min.Count++ - if typedNumberArg0.LessThan(eCtx.Min.Dec) { - eCtx.Min.Dec = typedNumberArg0 - } - return eCtx.Min.Dec, nil - - default: - return nil, fmt.Errorf("cannot evaluate min(), unexpected argument %v of unsupported type %T", args[0], args[0]) - } - } -} - -func (eCtx *EvalCtx) CallAggMax(callExp *ast.CallExpr, args []interface{}) (interface{}, error) { - if err := eCtx.checkAgg("max", callExp, AggMax); err != nil { - return nil, err - } - if err := checkArgs("max", 1, len(args)); err != nil { - return nil, err - } - - switch typedArg0 := args[0].(type) { - case string: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeString - } else if eCtx.AggType != AggTypeString { - return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got string value %s", eCtx.AggType, typedArg0) - } - eCtx.Max.Count++ - if len(eCtx.Max.Str) == 0 || typedArg0 > eCtx.Max.Str { - eCtx.Max.Str = typedArg0 - } - return eCtx.Max.Str, nil - default: - stdTypedNumberArg0, err := castNumberToStandardType(args[0]) - if err != nil { - return nil, err - } - switch typedNumberArg0 := stdTypedNumberArg0.(type) { - case int64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeInt - } else if eCtx.AggType != AggTypeInt { - return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got int value %d", eCtx.AggType, typedNumberArg0) - } - eCtx.Max.Count++ - if typedNumberArg0 > eCtx.Max.Int { - eCtx.Max.Int = typedNumberArg0 - } - return eCtx.Max.Int, nil - - case float64: - if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeFloat - } else if eCtx.AggType != AggTypeFloat { - return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got float value %f", eCtx.AggType, typedNumberArg0) - } - eCtx.Max.Count++ - if typedNumberArg0 > eCtx.Max.Float { - eCtx.Max.Float = typedNumberArg0 - } - return eCtx.Max.Float, nil - - case decimal.Decimal: - 
if eCtx.AggType == AggTypeUnknown { - eCtx.AggType = AggTypeDec - } else if eCtx.AggType != AggTypeDec { - return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got decimal value %s", eCtx.AggType, typedNumberArg0.String()) - } - eCtx.Max.Count++ - if typedNumberArg0.GreaterThan(eCtx.Max.Dec) { - eCtx.Max.Dec = typedNumberArg0 - } - return eCtx.Max.Dec, nil - - default: - return nil, fmt.Errorf("cannot evaluate max(), unexpected argument %v of unsupported type %T", args[0], args[0]) - } - } -} +package eval + +import ( + "fmt" + "go/ast" + "strings" + + "github.com/shopspring/decimal" +) + +type AggFuncType string + +const ( + AggStringAgg AggFuncType = "string_agg" + AggSum AggFuncType = "sum" + AggCount AggFuncType = "count" + AggAvg AggFuncType = "avg" + AggMin AggFuncType = "min" + AggMax AggFuncType = "max" + AggUnknown AggFuncType = "unknown" +) + +func StringToAggFunc(testString string) AggFuncType { + switch testString { + case string(AggStringAgg): + return AggStringAgg + case string(AggSum): + return AggSum + case string(AggCount): + return AggCount + case string(AggAvg): + return AggAvg + case string(AggMin): + return AggMin + case string(AggMax): + return AggMax + default: + return AggUnknown + } +} + +type SumCollector struct { + Int int64 + Float float64 + Dec decimal.Decimal +} +type AvgCollector struct { + Int int64 + Float float64 + Dec decimal.Decimal + Count int64 +} + +type MinCollector struct { + Int int64 + Float float64 + Dec decimal.Decimal + Str string + Count int64 +} + +type MaxCollector struct { + Int int64 + Float float64 + Dec decimal.Decimal + Str string + Count int64 +} + +type StringAggCollector struct { + Sb strings.Builder + Separator string +} + +type AggDataType string + +const ( + AggTypeUnknown AggDataType = "unknown" + AggTypeInt AggDataType = "int" + AggTypeFloat AggDataType = "float" + AggTypeDec AggDataType = "decimal" + AggTypeString AggDataType = "string" +) + +func (eCtx *EvalCtx) checkAgg(funcName string, callExp *ast.CallExpr, aggFunc AggFuncType) error { + if eCtx.AggEnabled != AggFuncEnabled { + return fmt.Errorf("cannot evaluate %s(), context aggregate not enabled", funcName) + } + if eCtx.AggCallExp != nil { + if eCtx.AggCallExp != callExp { + return fmt.Errorf("cannot evaluate more than one aggregate functions in the expression, extra %s() found besides %s()", funcName, eCtx.AggFunc) + } + } else { + eCtx.AggCallExp = callExp + eCtx.AggFunc = aggFunc + } + return nil +} + +func (eCtx *EvalCtx) CallAggStringAgg(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("string_agg", callExp, AggStringAgg); err != nil { + return nil, err + } + if err := checkArgs("string_agg", 2, len(args)); err != nil { + return nil, err + } + + switch typedArg0 := args[0].(type) { + case string: + if eCtx.StringAgg.Sb.Len() > 0 { + eCtx.StringAgg.Sb.WriteString(eCtx.StringAgg.Separator) + } + eCtx.StringAgg.Sb.WriteString(typedArg0) + return eCtx.StringAgg.Sb.String(), nil + + default: + return nil, fmt.Errorf("cannot evaluate string_agg(), unexpected argument %v of unsupported type %T", args[0], args[0]) + } +} + +func (eCtx *EvalCtx) CallAggSum(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("sum", callExp, AggSum); err != nil { + return nil, err + } + if err := checkArgs("sum", 1, len(args)); err != nil { + return nil, err + } + stdTypedArg, err := castNumberToStandardType(args[0]) + if err != nil { + return nil, err + } + switch typedArg0 := stdTypedArg.(type) { + case int64: + if 
eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeInt + } else if eCtx.AggType != AggTypeInt { + return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got int value %d", eCtx.AggType, typedArg0) + } + eCtx.Sum.Int += typedArg0 + return eCtx.Sum.Int, nil + + case float64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeFloat + } else if eCtx.AggType != AggTypeFloat { + return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got float value %f", eCtx.AggType, typedArg0) + } + eCtx.Sum.Float += typedArg0 + return eCtx.Sum.Float, nil + + case decimal.Decimal: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeDec + } else if eCtx.AggType != AggTypeDec { + return nil, fmt.Errorf("cannot evaluate sum(), it started with type %s, now got decimal value %s", eCtx.AggType, typedArg0.String()) + } + eCtx.Sum.Dec = eCtx.Sum.Dec.Add(typedArg0) + return eCtx.Sum.Dec, nil + + default: + return nil, fmt.Errorf("cannot evaluate sum(), unexpected argument %v of unsupported type %T", args[0], args[0]) + } +} + +func (eCtx *EvalCtx) CallAggAvg(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("avg", callExp, AggAvg); err != nil { + return nil, err + } + if err := checkArgs("avg", 1, len(args)); err != nil { + return nil, err + } + stdTypedArg, err := castNumberToStandardType(args[0]) + if err != nil { + return nil, err + } + switch typedArg0 := stdTypedArg.(type) { + case int64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeInt + } else if eCtx.AggType != AggTypeInt { + return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got int value %d", eCtx.AggType, typedArg0) + } + eCtx.Avg.Int += typedArg0 + eCtx.Avg.Count++ + return eCtx.Avg.Int / eCtx.Avg.Count, nil + + case float64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeFloat + } else if eCtx.AggType != AggTypeFloat { + return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got float value %f", eCtx.AggType, typedArg0) + } + eCtx.Avg.Float += typedArg0 + eCtx.Avg.Count++ + return eCtx.Avg.Float / float64(eCtx.Avg.Count), nil + + case decimal.Decimal: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeDec + } else if eCtx.AggType != AggTypeDec { + return nil, fmt.Errorf("cannot evaluate avg(), it started with type %s, now got decimal value %s", eCtx.AggType, typedArg0.String()) + } + eCtx.Avg.Dec = eCtx.Avg.Dec.Add(typedArg0) + eCtx.Avg.Count++ + return eCtx.Avg.Dec.Div(decimal.NewFromInt(eCtx.Avg.Count)).Round(2), nil + + default: + return nil, fmt.Errorf("cannot evaluate avg(), unexpected argument %v of unsupported type %T", args[0], args[0]) + } +} + +func (eCtx *EvalCtx) CallAggCount(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("count", callExp, AggCount); err != nil { + return nil, err + } + if err := checkArgs("count", 0, len(args)); err != nil { + return nil, err + } + eCtx.Count++ + return eCtx.Count, nil +} + +func (eCtx *EvalCtx) CallAggMin(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("min", callExp, AggMin); err != nil { + return nil, err + } + if err := checkArgs("min", 1, len(args)); err != nil { + return nil, err + } + + switch typedArg0 := args[0].(type) { + case string: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeString + } else if eCtx.AggType != AggTypeString { + return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got string value %s", eCtx.AggType, 
typedArg0) + } + eCtx.Min.Count++ + if len(eCtx.Min.Str) == 0 || typedArg0 < eCtx.Min.Str { + eCtx.Min.Str = typedArg0 + } + return eCtx.Min.Str, nil + + default: + stdTypedArg0, err := castNumberToStandardType(args[0]) + if err != nil { + return nil, err + } + switch typedNumberArg0 := stdTypedArg0.(type) { + case int64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeInt + } else if eCtx.AggType != AggTypeInt { + return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got int value %d", eCtx.AggType, typedNumberArg0) + } + eCtx.Min.Count++ + if typedNumberArg0 < eCtx.Min.Int { + eCtx.Min.Int = typedNumberArg0 + } + return eCtx.Min.Int, nil + + case float64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeFloat + } else if eCtx.AggType != AggTypeFloat { + return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got float value %f", eCtx.AggType, typedNumberArg0) + } + eCtx.Min.Count++ + if typedNumberArg0 < eCtx.Min.Float { + eCtx.Min.Float = typedNumberArg0 + } + return eCtx.Min.Float, nil + + case decimal.Decimal: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeDec + } else if eCtx.AggType != AggTypeDec { + return nil, fmt.Errorf("cannot evaluate min(), it started with type %s, now got decimal value %s", eCtx.AggType, typedNumberArg0.String()) + } + eCtx.Min.Count++ + if typedNumberArg0.LessThan(eCtx.Min.Dec) { + eCtx.Min.Dec = typedNumberArg0 + } + return eCtx.Min.Dec, nil + + default: + return nil, fmt.Errorf("cannot evaluate min(), unexpected argument %v of unsupported type %T", args[0], args[0]) + } + } +} + +func (eCtx *EvalCtx) CallAggMax(callExp *ast.CallExpr, args []any) (any, error) { + if err := eCtx.checkAgg("max", callExp, AggMax); err != nil { + return nil, err + } + if err := checkArgs("max", 1, len(args)); err != nil { + return nil, err + } + + switch typedArg0 := args[0].(type) { + case string: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeString + } else if eCtx.AggType != AggTypeString { + return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got string value %s", eCtx.AggType, typedArg0) + } + eCtx.Max.Count++ + if len(eCtx.Max.Str) == 0 || typedArg0 > eCtx.Max.Str { + eCtx.Max.Str = typedArg0 + } + return eCtx.Max.Str, nil + default: + stdTypedNumberArg0, err := castNumberToStandardType(args[0]) + if err != nil { + return nil, err + } + switch typedNumberArg0 := stdTypedNumberArg0.(type) { + case int64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeInt + } else if eCtx.AggType != AggTypeInt { + return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got int value %d", eCtx.AggType, typedNumberArg0) + } + eCtx.Max.Count++ + if typedNumberArg0 > eCtx.Max.Int { + eCtx.Max.Int = typedNumberArg0 + } + return eCtx.Max.Int, nil + + case float64: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeFloat + } else if eCtx.AggType != AggTypeFloat { + return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got float value %f", eCtx.AggType, typedNumberArg0) + } + eCtx.Max.Count++ + if typedNumberArg0 > eCtx.Max.Float { + eCtx.Max.Float = typedNumberArg0 + } + return eCtx.Max.Float, nil + + case decimal.Decimal: + if eCtx.AggType == AggTypeUnknown { + eCtx.AggType = AggTypeDec + } else if eCtx.AggType != AggTypeDec { + return nil, fmt.Errorf("cannot evaluate max(), it started with type %s, now got decimal value %s", eCtx.AggType, typedNumberArg0.String()) + } + eCtx.Max.Count++ + if 
typedNumberArg0.GreaterThan(eCtx.Max.Dec) { + eCtx.Max.Dec = typedNumberArg0 + } + return eCtx.Max.Dec, nil + + default: + return nil, fmt.Errorf("cannot evaluate max(), unexpected argument %v of unsupported type %T", args[0], args[0]) + } + } +} diff --git a/pkg/eval/agg_test.go b/pkg/eval/agg_test.go index be98258..3b46105 100644 --- a/pkg/eval/agg_test.go +++ b/pkg/eval/agg_test.go @@ -1,497 +1,470 @@ -package eval - -import ( - "fmt" - "go/ast" - "go/parser" - "testing" - - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -func getTestValuesMap() VarValuesMap { - return VarValuesMap{ - "t1": { - "fieldInt": 1, - "fieldFloat": 2.1, - "fieldDec": decimal.NewFromInt(1), - "fieldStr": "a", - }, - } -} - -func TestMissingCtxVars(t *testing.T) { - varValuesMap := getTestValuesMap() - - var err error - var exp ast.Expr - var eCtx EvalCtx - - exp, _ = parser.ParseExpr("avg(t1.fieldInt)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - _, err = eCtx.Eval(exp) - assert.Contains(t, err.Error(), "no variables supplied to the context") - - delete(varValuesMap["t1"], "fieldInt") - exp, _ = parser.ParseExpr("avg(t1.fieldInt)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Contains(t, err.Error(), "variable not supplied") -} - -func TestExtraAgg(t *testing.T) { - varValuesMap := getTestValuesMap() - - var err error - var exp ast.Expr - var eCtx EvalCtx - - // Extra sum - exp, _ = parser.ParseExpr("sum(min(t1.fieldFloat))") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra sum() found besides min()") - - // Extra avg - exp, _ = parser.ParseExpr("avg(min(t1.fieldFloat))") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra avg() found besides min()") - - // Extra min - exp, _ = parser.ParseExpr("min(min(t1.fieldFloat))") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra min() found besides min()") - - // Extra max - exp, _ = parser.ParseExpr("max(min(t1.fieldFloat))") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra max() found besides min()") - - // Extra count - exp, _ = parser.ParseExpr("min(t1.fieldFloat)+count())") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - _, err = eCtx.Eval(exp) - assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra count() found besides min()") -} - -func assertFuncTypeAndArgs(t *testing.T, expression string, aggFuncEnabled AggEnabledType, expectedAggFuncType AggFuncType, expectedNumberOfArgs int) { - exp, _ := parser.ParseExpr(expression) - aggEnabledType, aggFuncType, aggFuncArgs := DetectRootAggFunc(exp) - assert.Equal(t, aggFuncEnabled, aggEnabledType, "expected AggFuncEnabled for "+expression) - assert.Equal(t, expectedAggFuncType, aggFuncType, fmt.Sprintf("expected %s for %s", expectedAggFuncType, expression)) - assert.Equal(t, expectedNumberOfArgs, len(aggFuncArgs), fmt.Sprintf("expected %d args for %s", expectedNumberOfArgs, expression)) -} - -func 
TestDetectRootArgFunc(t *testing.T) { - assertFuncTypeAndArgs(t, `string_agg(t1.fieldStr,",")`, AggFuncEnabled, AggStringAgg, 2) - assertFuncTypeAndArgs(t, `sum(t1.fieldFloat)`, AggFuncEnabled, AggSum, 1) - assertFuncTypeAndArgs(t, `avg(t1.fieldFloat)`, AggFuncEnabled, AggAvg, 1) - assertFuncTypeAndArgs(t, `min(t1.fieldFloat)`, AggFuncEnabled, AggMin, 1) - assertFuncTypeAndArgs(t, `max(t1.fieldFloat)`, AggFuncEnabled, AggMax, 1) - assertFuncTypeAndArgs(t, `count()`, AggFuncEnabled, AggCount, 0) - assertFuncTypeAndArgs(t, `some_func(t1.fieldFloat)`, AggFuncDisabled, AggUnknown, 0) -} - -func TestStringAgg(t *testing.T) { - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var result interface{} - - varValuesMap["t1"]["fieldStr"] = "a" - - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,"-")`) - eCtx, _ := NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - result, _ = eCtx.Eval(exp) - assert.Equal(t, "a", result) - varValuesMap["t1"]["fieldStr"] = "b" - result, _ = eCtx.Eval(exp) - assert.Equal(t, "a-b", result) - - // Empty str - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,",")`) - eCtx, _ = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - assert.Equal(t, "", eCtx.StringAgg.Sb.String()) - - var err error - - // Bad number of args - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr)`) - eCtx, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - assert.Contains(t, err.Error(), "string_agg must have two parameters") - - // Bad separators - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, t2.someBadField)`) - eCtx, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - assert.Contains(t, err.Error(), "string_agg second parameter must be a basic literal") - - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, 123)`) - eCtx, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - assert.Contains(t, err.Error(), "string_agg second parameter must be a constant string") - - // Bad data type - exp, _ = parser.ParseExpr(`string_agg(t1.fieldFloat, ",")`) - eCtx, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) - // TODO: can we check expression type before Eval? 
- _, err = eCtx.Eval(exp) - assert.Contains(t, err.Error(), "unsupported type float64") - - // Bad ctx with disabled agg func calling string_agg() - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,"-")`) - badCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) - _, err = badCtx.Eval(exp) - assert.Contains(t, err.Error(), "context aggregate not enabled") -} - -func TestSum(t *testing.T) { - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - var err error - - // Sum float - exp, _ = parser.ParseExpr("5 + sum(t1.fieldFloat)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldFloat"] = 2.1 - result, _ = eCtx.Eval(exp) - assert.Equal(t, 5+2.1, result) - result, _ = eCtx.Eval(exp) - assert.Equal(t, 5+4.2, result) - - // float -> dec - varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate sum(), it started with type float, now got decimal value 1", err.Error()) - - // Sum int - exp, _ = parser.ParseExpr("5 + sum(t1.fieldInt)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldInt"] = 1 - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(6), result) - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(7), result) - - // int -> float - varValuesMap["t1"]["fieldInt"] = float64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate sum(), it started with type int, now got float value 1.000000", err.Error()) - - // Sum dec - exp, _ = parser.ParseExpr("5 + sum(t1.fieldDec)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.New(600, -2), result) - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.New(700, -2), result) - - // dec -> int - varValuesMap["t1"]["fieldDec"] = int64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate sum(), it started with type decimal, now got int value 1", err.Error()) - - // Sum int empty - exp, _ = parser.ParseExpr("sum(t1.fieldInt)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, int64(0), eCtx.Sum.Int) - - // Sum float empty - exp, _ = parser.ParseExpr("sum(t1.fieldFloat)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, float64(0), eCtx.Sum.Float) - - // Sum dec empty - exp, _ = parser.ParseExpr("sum(t1.fieldDec)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, defaultDecimal(), eCtx.Sum.Dec) -} - -func TestAvg(t *testing.T) { - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - var err error - - // Avg int - exp, _ = parser.ParseExpr("avg(t1.fieldInt)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - - varValuesMap["t1"]["fieldInt"] = 1 - eCtx.Eval(exp) - eCtx.Eval(exp) - varValuesMap["t1"]["fieldInt"] = 2 - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(1), result) - - // int -> float - varValuesMap["t1"]["fieldInt"] = float64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate avg(), it started with type int, now got float value 1.000000", err.Error()) - - // Avg float - exp, _ = parser.ParseExpr("avg(t1.fieldFloat)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldFloat"] = float64(1) - eCtx.Eval(exp) - eCtx.Eval(exp) - varValuesMap["t1"]["fieldFloat"] = float64(2) - result, _ = eCtx.Eval(exp) - assert.Equal(t, float64(1.3333333333333333), result) - - 
// float -> dec - varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate avg(), it started with type float, now got decimal value 1", err.Error()) - - // Avg dec - exp, _ = parser.ParseExpr("avg(t1.fieldDec)") - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) - eCtx.Eval(exp) - eCtx.Eval(exp) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.NewFromFloat32(1.33), result) - - // dec -> int - varValuesMap["t1"]["fieldDec"] = int64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate avg(), it started with type decimal, now got int value 1", err.Error()) - - // Avg int empty - exp, _ = parser.ParseExpr("avg(t1.fieldInt)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, int64(0), eCtx.Avg.Int) - - // Avg float empty - exp, _ = parser.ParseExpr("avg(t1.fieldFloat)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, float64(0), eCtx.Avg.Float) - - // Avg dec empty - exp, _ = parser.ParseExpr("avg(t1.fieldDec)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, defaultDecimal(), eCtx.Avg.Dec) -} - -func TestMin(t *testing.T) { - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - var err error - - // Min float - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldFloat"] = 1.0 - exp, _ = parser.ParseExpr("min(t1.fieldFloat)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, 1.0, result) - varValuesMap["t1"]["fieldFloat"] = 2.0 - result, _ = eCtx.Eval(exp) - assert.Equal(t, 1.0, result) - - // float -> dec - varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate min(), it started with type float, now got decimal value 1", err.Error()) - - // float -> string - varValuesMap["t1"]["fieldFloat"] = "a" - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate min(), it started with type float, now got string value a", err.Error()) - - // Min int - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldInt"] = 1 - exp, _ = parser.ParseExpr("min(t1.fieldInt)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(1), result) - varValuesMap["t1"]["fieldInt"] = 2 - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(1), result) - - // int -> float - varValuesMap["t1"]["fieldInt"] = float64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate min(), it started with type int, now got float value 1.000000", err.Error()) - - // Min dec - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) - exp, _ = parser.ParseExpr("min(t1.fieldDec)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.NewFromInt(1), result) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.NewFromInt(1), result) - - // dec -> int - varValuesMap["t1"]["fieldDec"] = int64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate min(), it started with type decimal, now got int value 1", err.Error()) - - // Min str - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldStr"] = "a" - exp, _ = parser.ParseExpr("min(t1.fieldStr)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, "a", result) - varValuesMap["t1"]["fieldStr"] = "b" - result, _ = 
eCtx.Eval(exp) - assert.Equal(t, "a", result) - - // Empty int - exp, _ = parser.ParseExpr("min(t1.fieldInt)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, maxSupportedInt, eCtx.Min.Int) - - // Empty float - exp, _ = parser.ParseExpr("min(t1.fieldFloat)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, maxSupportedFloat, eCtx.Min.Float) - - // Empty dec - exp, _ = parser.ParseExpr("min(t1.fieldDec)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, maxSupportedDecimal(), eCtx.Min.Dec) - - // Empty str - exp, _ = parser.ParseExpr("min(t1.fieldString)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, "", eCtx.Min.Str) -} - -func TestMax(t *testing.T) { - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - var err error - - // Max float - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldFloat"] = 10.0 - exp, _ = parser.ParseExpr("max(t1.fieldFloat)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, 10.0, result) - varValuesMap["t1"]["fieldFloat"] = 2.0 - result, _ = eCtx.Eval(exp) - assert.Equal(t, 10.0, result) - - // float -> dec - varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate max(), it started with type float, now got decimal value 1", err.Error()) - - // float -> string - varValuesMap["t1"]["fieldFloat"] = "a" - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate max(), it started with type float, now got string value a", err.Error()) - - // Max int - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldInt"] = 1 - exp, _ = parser.ParseExpr("max(t1.fieldInt)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(1), result) - varValuesMap["t1"]["fieldInt"] = 2 - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(2), result) - - // int -> float - varValuesMap["t1"]["fieldInt"] = float64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate max(), it started with type int, now got float value 1.000000", err.Error()) - - // Max dec - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) - exp, _ = parser.ParseExpr("max(t1.fieldDec)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.NewFromInt(1), result) - varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) - result, _ = eCtx.Eval(exp) - assert.Equal(t, decimal.NewFromInt(2), result) - - // dec -> int - varValuesMap["t1"]["fieldDec"] = int64(1) - _, err = eCtx.Eval(exp) - assert.Equal(t, "cannot evaluate max(), it started with type decimal, now got int value 1", err.Error()) - - // Max str - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - varValuesMap["t1"]["fieldStr"] = "a" - exp, _ = parser.ParseExpr("max(t1.fieldStr)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, "a", result) - varValuesMap["t1"]["fieldStr"] = "b" - result, _ = eCtx.Eval(exp) - assert.Equal(t, "b", result) - - // Empty int - exp, _ = parser.ParseExpr("max(t1.fieldInt)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, minSupportedInt, eCtx.Max.Int) - - // Empty float - exp, _ = parser.ParseExpr("max(t1.fieldFloat)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, minSupportedFloat, eCtx.Max.Float) - - // Empty dec - exp, _ = parser.ParseExpr("max(t1.fieldDec)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, minSupportedDecimal(), eCtx.Max.Dec) - - // Empty str - exp, _ = 
parser.ParseExpr("max(t1.fieldString)") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, "", eCtx.Max.Str) -} - -func TestCount(t *testing.T) { - - varValuesMap := getTestValuesMap() - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - - eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) - exp, _ = parser.ParseExpr("count()") - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(1), result) - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(2), result) - - // Empty - exp, _ = parser.ParseExpr("count()") - eCtx = NewPlainEvalCtx(AggFuncEnabled) - assert.Equal(t, int64(0), eCtx.Count) -} - -func TestNoVars(t *testing.T) { - - var exp ast.Expr - var eCtx EvalCtx - var result interface{} - - eCtx = NewPlainEvalCtx(AggFuncEnabled) - exp, _ = parser.ParseExpr("sum(5)") - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(5), result) - result, _ = eCtx.Eval(exp) - assert.Equal(t, int64(5+5), result) -} +package eval + +import ( + "fmt" + "go/ast" + "go/parser" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +func getTestValuesMap() VarValuesMap { + return VarValuesMap{ + "t1": { + "fieldInt": 1, + "fieldFloat": 2.1, + "fieldDec": decimal.NewFromInt(1), + "fieldStr": "a", + }, + } +} + +func TestMissingCtxVars(t *testing.T) { + varValuesMap := getTestValuesMap() + + var err error + var exp ast.Expr + var eCtx EvalCtx + + exp, _ = parser.ParseExpr("avg(t1.fieldInt)") + eCtx = NewPlainEvalCtx(AggFuncEnabled) + _, err = eCtx.Eval(exp) + assert.Contains(t, err.Error(), "no variables supplied to the context") + + delete(varValuesMap["t1"], "fieldInt") + exp, _ = parser.ParseExpr("avg(t1.fieldInt)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Contains(t, err.Error(), "variable not supplied") +} + +func TestExtraAgg(t *testing.T) { + varValuesMap := getTestValuesMap() + + var err error + var exp ast.Expr + var eCtx EvalCtx + + // Extra sum + exp, _ = parser.ParseExpr("sum(min(t1.fieldFloat))") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra sum() found besides min()") + + // Extra avg + exp, _ = parser.ParseExpr("avg(min(t1.fieldFloat))") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra avg() found besides min()") + + // Extra min + exp, _ = parser.ParseExpr("min(min(t1.fieldFloat))") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra min() found besides min()") + + // Extra max + exp, _ = parser.ParseExpr("max(min(t1.fieldFloat))") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra max() found besides min()") + + // Extra count + exp, _ = parser.ParseExpr("min(t1.fieldFloat)+count())") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + _, err = eCtx.Eval(exp) + assert.Equal(t, err.Error(), "cannot evaluate more than one aggregate functions in the expression, extra count() found besides min()") +} + +func assertFuncTypeAndArgs(t *testing.T, expression string, aggFuncEnabled 
AggEnabledType, expectedAggFuncType AggFuncType, expectedNumberOfArgs int) { + exp, _ := parser.ParseExpr(expression) + aggEnabledType, aggFuncType, aggFuncArgs := DetectRootAggFunc(exp) + assert.Equal(t, aggFuncEnabled, aggEnabledType, "expected AggFuncEnabled for "+expression) + assert.Equal(t, expectedAggFuncType, aggFuncType, fmt.Sprintf("expected %s for %s", expectedAggFuncType, expression)) + assert.Equal(t, expectedNumberOfArgs, len(aggFuncArgs), fmt.Sprintf("expected %d args for %s", expectedNumberOfArgs, expression)) +} + +func TestDetectRootArgFunc(t *testing.T) { + assertFuncTypeAndArgs(t, `string_agg(t1.fieldStr,",")`, AggFuncEnabled, AggStringAgg, 2) + assertFuncTypeAndArgs(t, `sum(t1.fieldFloat)`, AggFuncEnabled, AggSum, 1) + assertFuncTypeAndArgs(t, `avg(t1.fieldFloat)`, AggFuncEnabled, AggAvg, 1) + assertFuncTypeAndArgs(t, `min(t1.fieldFloat)`, AggFuncEnabled, AggMin, 1) + assertFuncTypeAndArgs(t, `max(t1.fieldFloat)`, AggFuncEnabled, AggMax, 1) + assertFuncTypeAndArgs(t, `count()`, AggFuncEnabled, AggCount, 0) + assertFuncTypeAndArgs(t, `some_func(t1.fieldFloat)`, AggFuncDisabled, AggUnknown, 0) +} + +func TestStringAgg(t *testing.T) { + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var result any + + varValuesMap["t1"]["fieldStr"] = "a" + + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,"-")`) + eCtx, _ := NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + result, _ = eCtx.Eval(exp) + assert.Equal(t, "a", result) + varValuesMap["t1"]["fieldStr"] = "b" + result, _ = eCtx.Eval(exp) + assert.Equal(t, "a-b", result) + + // Empty str + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,",")`) + eCtx, _ = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + assert.Equal(t, "", eCtx.StringAgg.Sb.String()) + + var err error + + // Bad number of args + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr)`) + _, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + assert.Contains(t, err.Error(), "string_agg must have two parameters") + + // Bad separators + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, t2.someBadField)`) + _, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + assert.Contains(t, err.Error(), "string_agg second parameter must be a basic literal") + + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, 123)`) + _, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + assert.Contains(t, err.Error(), "string_agg second parameter must be a constant string") + + // Bad data type + exp, _ = parser.ParseExpr(`string_agg(t1.fieldFloat, ",")`) + eCtx, err = NewPlainEvalCtxWithVarsAndInitializedAgg(AggFuncEnabled, &varValuesMap, AggStringAgg, exp.(*ast.CallExpr).Args) + assert.Nil(t, err) + // TODO: can we check expression type before Eval? 
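The aggregate tests in this file all follow the same pattern: parse the expression once, keep a single EvalCtx per group, mutate the variable map for each incoming row, and call Eval repeatedly so the context accumulates the running value. A compact sketch of that pattern, assuming only the API exercised in these tests (sumRows is a hypothetical helper, not part of pkg/eval; the go/parser import is assumed):

// sumRows is an illustrative helper showing the per-row Eval loop.
func sumRows(rows []float64) (any, error) {
	varValuesMap := VarValuesMap{"t1": {"fieldFloat": 0.0}}
	exp, err := parser.ParseExpr("sum(t1.fieldFloat)")
	if err != nil {
		return nil, err
	}
	eCtx := NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap)
	var result any
	for _, v := range rows {
		varValuesMap["t1"]["fieldFloat"] = v // expose the current row's value to the context
		if result, err = eCtx.Eval(exp); err != nil {
			return nil, err
		}
	}
	return result, nil // running sum after the last row
}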
+ _, err = eCtx.Eval(exp) + assert.Contains(t, err.Error(), "unsupported type float64") + + // Bad ctx with disabled agg func calling string_agg() + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,"-")`) + badCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) + _, err = badCtx.Eval(exp) + assert.Contains(t, err.Error(), "context aggregate not enabled") +} + +func TestSum(t *testing.T) { + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var eCtx EvalCtx + var result any + var err error + + // Sum float + exp, _ = parser.ParseExpr("5 + sum(t1.fieldFloat)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldFloat"] = 2.1 + result, _ = eCtx.Eval(exp) + assert.Equal(t, 5+2.1, result) + result, _ = eCtx.Eval(exp) + assert.Equal(t, 5+4.2, result) + + // float -> dec + varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate sum(), it started with type float, now got decimal value 1", err.Error()) + + // Sum int + exp, _ = parser.ParseExpr("5 + sum(t1.fieldInt)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldInt"] = 1 + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(6), result) + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(7), result) + + // int -> float + varValuesMap["t1"]["fieldInt"] = float64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate sum(), it started with type int, now got float value 1.000000", err.Error()) + + // Sum dec + exp, _ = parser.ParseExpr("5 + sum(t1.fieldDec)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.New(600, -2), result) + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.New(700, -2), result) + + // dec -> int + varValuesMap["t1"]["fieldDec"] = int64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate sum(), it started with type decimal, now got int value 1", err.Error()) + + eCtx = NewPlainEvalCtx(AggFuncEnabled) + // Sum int empty + assert.Equal(t, int64(0), eCtx.Sum.Int) + // Sum float empty + assert.Equal(t, float64(0), eCtx.Sum.Float) + // Sum dec empty + assert.Equal(t, defaultDecimal(), eCtx.Sum.Dec) +} + +func TestAvg(t *testing.T) { + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var eCtx EvalCtx + var result any + var err error + + // Avg int + exp, _ = parser.ParseExpr("avg(t1.fieldInt)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + + varValuesMap["t1"]["fieldInt"] = 1 + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + varValuesMap["t1"]["fieldInt"] = 2 + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(1), result) + + // int -> float + varValuesMap["t1"]["fieldInt"] = float64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate avg(), it started with type int, now got float value 1.000000", err.Error()) + + // Avg float + exp, _ = parser.ParseExpr("avg(t1.fieldFloat)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldFloat"] = float64(1) + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + varValuesMap["t1"]["fieldFloat"] = float64(2) + result, _ = eCtx.Eval(exp) + assert.Equal(t, float64(1.3333333333333333), result) + + // float -> dec + varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot 
evaluate avg(), it started with type float, now got decimal value 1", err.Error()) + + // Avg dec + exp, _ = parser.ParseExpr("avg(t1.fieldDec)") + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + _, err = eCtx.Eval(exp) + assert.Nil(t, err) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.NewFromFloat32(1.33), result) + + // dec -> int + varValuesMap["t1"]["fieldDec"] = int64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate avg(), it started with type decimal, now got int value 1", err.Error()) + + eCtx = NewPlainEvalCtx(AggFuncEnabled) + + // Avg int empty + assert.Equal(t, int64(0), eCtx.Avg.Int) + // Avg float empty + assert.Equal(t, float64(0), eCtx.Avg.Float) + // Avg dec empty + assert.Equal(t, defaultDecimal(), eCtx.Avg.Dec) +} + +func TestMin(t *testing.T) { + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var eCtx EvalCtx + var result any + var err error + + // Min float + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldFloat"] = 1.0 + exp, _ = parser.ParseExpr("min(t1.fieldFloat)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, 1.0, result) + varValuesMap["t1"]["fieldFloat"] = 2.0 + result, _ = eCtx.Eval(exp) + assert.Equal(t, 1.0, result) + + // float -> dec + varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate min(), it started with type float, now got decimal value 1", err.Error()) + + // float -> string + varValuesMap["t1"]["fieldFloat"] = "a" + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate min(), it started with type float, now got string value a", err.Error()) + + // Min int + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldInt"] = 1 + exp, _ = parser.ParseExpr("min(t1.fieldInt)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(1), result) + varValuesMap["t1"]["fieldInt"] = 2 + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(1), result) + + // int -> float + varValuesMap["t1"]["fieldInt"] = float64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate min(), it started with type int, now got float value 1.000000", err.Error()) + + // Min dec + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) + exp, _ = parser.ParseExpr("min(t1.fieldDec)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.NewFromInt(1), result) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.NewFromInt(1), result) + + // dec -> int + varValuesMap["t1"]["fieldDec"] = int64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate min(), it started with type decimal, now got int value 1", err.Error()) + + // Min str + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldStr"] = "a" + exp, _ = parser.ParseExpr("min(t1.fieldStr)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, "a", result) + varValuesMap["t1"]["fieldStr"] = "b" + result, _ = eCtx.Eval(exp) + assert.Equal(t, "a", result) + + eCtx = NewPlainEvalCtx(AggFuncEnabled) + // Empty int + assert.Equal(t, maxSupportedInt, eCtx.Min.Int) + // Empty float + assert.Equal(t, maxSupportedFloat, eCtx.Min.Float) + // Empty dec + assert.Equal(t, maxSupportedDecimal(), eCtx.Min.Dec) + // Empty str + 
assert.Equal(t, "", eCtx.Min.Str) +} + +func TestMax(t *testing.T) { + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var eCtx EvalCtx + var result any + var err error + + // Max float + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldFloat"] = 10.0 + exp, _ = parser.ParseExpr("max(t1.fieldFloat)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, 10.0, result) + varValuesMap["t1"]["fieldFloat"] = 2.0 + result, _ = eCtx.Eval(exp) + assert.Equal(t, 10.0, result) + + // float -> dec + varValuesMap["t1"]["fieldFloat"] = decimal.NewFromInt(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate max(), it started with type float, now got decimal value 1", err.Error()) + + // float -> string + varValuesMap["t1"]["fieldFloat"] = "a" + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate max(), it started with type float, now got string value a", err.Error()) + + // Max int + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldInt"] = 1 + exp, _ = parser.ParseExpr("max(t1.fieldInt)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(1), result) + varValuesMap["t1"]["fieldInt"] = 2 + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(2), result) + + // int -> float + varValuesMap["t1"]["fieldInt"] = float64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate max(), it started with type int, now got float value 1.000000", err.Error()) + + // Max dec + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(1) + exp, _ = parser.ParseExpr("max(t1.fieldDec)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.NewFromInt(1), result) + varValuesMap["t1"]["fieldDec"] = decimal.NewFromInt(2) + result, _ = eCtx.Eval(exp) + assert.Equal(t, decimal.NewFromInt(2), result) + + // dec -> int + varValuesMap["t1"]["fieldDec"] = int64(1) + _, err = eCtx.Eval(exp) + assert.Equal(t, "cannot evaluate max(), it started with type decimal, now got int value 1", err.Error()) + + // Max str + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + varValuesMap["t1"]["fieldStr"] = "a" + exp, _ = parser.ParseExpr("max(t1.fieldStr)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, "a", result) + varValuesMap["t1"]["fieldStr"] = "b" + result, _ = eCtx.Eval(exp) + assert.Equal(t, "b", result) + + eCtx = NewPlainEvalCtx(AggFuncEnabled) + // Empty int + assert.Equal(t, minSupportedInt, eCtx.Max.Int) + // Empty float + assert.Equal(t, minSupportedFloat, eCtx.Max.Float) + // Empty dec + assert.Equal(t, minSupportedDecimal(), eCtx.Max.Dec) + // Empty str + assert.Equal(t, "", eCtx.Max.Str) +} + +func TestCount(t *testing.T) { + + varValuesMap := getTestValuesMap() + + var exp ast.Expr + var eCtx EvalCtx + var result any + + eCtx = NewPlainEvalCtxWithVars(AggFuncEnabled, &varValuesMap) + exp, _ = parser.ParseExpr("count()") + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(1), result) + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(2), result) + + // Empty + eCtx = NewPlainEvalCtx(AggFuncEnabled) + assert.Equal(t, int64(0), eCtx.Count) +} + +func TestNoVars(t *testing.T) { + + var exp ast.Expr + var eCtx EvalCtx + var result any + + eCtx = NewPlainEvalCtx(AggFuncEnabled) + exp, _ = parser.ParseExpr("sum(5)") + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(5), result) + result, _ = eCtx.Eval(exp) + assert.Equal(t, int64(5+5), result) +} diff --git a/pkg/eval/cast.go b/pkg/eval/cast.go index 35d25db..536d973 100644 --- a/pkg/eval/cast.go 
+++ b/pkg/eval/cast.go @@ -1,147 +1,147 @@ -package eval - -import ( - "fmt" - - "github.com/shopspring/decimal" -) - -func castNumberToStandardType(arg interface{}) (interface{}, error) { - switch typedArg := arg.(type) { - case int: - return int64(typedArg), nil - case int16: - return int64(typedArg), nil - case int32: - return int64(typedArg), nil - case int64: - return typedArg, nil - case float32: - return float64(typedArg), nil - case float64: - return typedArg, nil - case decimal.Decimal: - return typedArg, nil - default: - return 0.0, fmt.Errorf("cannot cast %v(%T) to standard number type, unsuported type", typedArg, typedArg) - } -} - -func castNumberPairToCommonType(argLeft interface{}, argRight interface{}) (interface{}, interface{}, error) { - stdArgLeft, err := castNumberToStandardType(argLeft) - if err != nil { - return nil, nil, fmt.Errorf("invalid left arg: %s", err.Error()) - } - stdArgRight, err := castNumberToStandardType(argRight) - if err != nil { - return nil, nil, fmt.Errorf("invalid right arg: %s", err.Error()) - } - // Check for float64 - _, floatLeft := stdArgLeft.(float64) - _, floatRight := stdArgRight.(float64) - if floatLeft || floatRight { - finalArgLeft, err := castToFloat64(stdArgLeft) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to float64: %s", err.Error()) - } - finalArgRight, err := castToFloat64(stdArgRight) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to float64: %s", err.Error()) - } - return finalArgLeft, finalArgRight, nil - } - - // Check for decimal2 - _, decLeft := stdArgLeft.(decimal.Decimal) - _, decRight := stdArgRight.(decimal.Decimal) - if decLeft || decRight { - finalArgLeft, err := castToDecimal2(stdArgLeft) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to decimal2: %s", err.Error()) - } - finalArgRight, err := castToDecimal2(stdArgRight) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to decimal2: %s", err.Error()) - } - return finalArgLeft, finalArgRight, nil - } - - // Cast both to int64 - finalArgLeft, err := castToInt64(stdArgLeft) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to int64: %s", err.Error()) - } - finalArgRight, err := castToInt64(stdArgRight) - if err != nil { - return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to int64: %s", err.Error()) - } - return finalArgLeft, finalArgRight, nil -} - -func castToInt64(arg interface{}) (int64, error) { - switch typedArg := arg.(type) { - case int: - return int64(typedArg), nil - case int16: - return int64(typedArg), nil - case int32: - return int64(typedArg), nil - case int64: - return typedArg, nil - case float32: - return int64(typedArg), nil - case float64: - return int64(typedArg), nil - case decimal.Decimal: - if typedArg.IsInteger() { - return typedArg.BigInt().Int64(), nil - } else { - return 0.0, fmt.Errorf("cannot cast decimal '%v' to int64, exact conversion impossible", typedArg) - } - default: - return 0.0, fmt.Errorf("cannot cast %v(%T) to int64, unsuported type", typedArg, typedArg) - } -} - -func castToFloat64(arg interface{}) (float64, error) { - switch typedArg := arg.(type) { - case int: - return float64(typedArg), nil - case int16: - return float64(typedArg), nil - case int32: - return float64(typedArg), nil - case int64: - return float64(typedArg), nil - case float32: - return float64(typedArg), nil - case float64: - return typedArg, nil - case decimal.Decimal: - 
valFloat, _ := typedArg.Float64() - return valFloat, nil - default: - return 0.0, fmt.Errorf("cannot cast %v(%T) to float64, unsuported type", typedArg, typedArg) - } -} - -func castToDecimal2(arg interface{}) (decimal.Decimal, error) { - switch typedArg := arg.(type) { - case int: - return decimal.NewFromInt(int64(typedArg)), nil - case int16: - return decimal.NewFromInt(int64(typedArg)), nil - case int32: - return decimal.NewFromInt(int64(typedArg)), nil - case int64: - return decimal.NewFromInt(typedArg), nil - case float32: - return decimal.NewFromFloat32(typedArg), nil - case float64: - return decimal.NewFromFloat(typedArg), nil - case decimal.Decimal: - return typedArg, nil - default: - return decimal.NewFromInt(0), fmt.Errorf("cannot cast %v(%T) to decimal2, unsuported type", typedArg, typedArg) - } -} +package eval + +import ( + "fmt" + + "github.com/shopspring/decimal" +) + +func castNumberToStandardType(arg any) (any, error) { + switch typedArg := arg.(type) { + case int: + return int64(typedArg), nil + case int16: + return int64(typedArg), nil + case int32: + return int64(typedArg), nil + case int64: + return typedArg, nil + case float32: + return float64(typedArg), nil + case float64: + return typedArg, nil + case decimal.Decimal: + return typedArg, nil + default: + return 0.0, fmt.Errorf("cannot cast %v(%T) to standard number type, unsuported type", typedArg, typedArg) + } +} + +func castNumberPairToCommonType(argLeft any, argRight any) (any, any, error) { + stdArgLeft, err := castNumberToStandardType(argLeft) + if err != nil { + return nil, nil, fmt.Errorf("invalid left arg: %s", err.Error()) + } + stdArgRight, err := castNumberToStandardType(argRight) + if err != nil { + return nil, nil, fmt.Errorf("invalid right arg: %s", err.Error()) + } + // Check for float64 + _, floatLeft := stdArgLeft.(float64) + _, floatRight := stdArgRight.(float64) + if floatLeft || floatRight { + finalArgLeft, err := castToFloat64(stdArgLeft) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to float64: %s", err.Error()) + } + finalArgRight, err := castToFloat64(stdArgRight) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to float64: %s", err.Error()) + } + return finalArgLeft, finalArgRight, nil + } + + // Check for decimal2 + _, decLeft := stdArgLeft.(decimal.Decimal) + _, decRight := stdArgRight.(decimal.Decimal) + if decLeft || decRight { + finalArgLeft, err := castToDecimal2(stdArgLeft) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to decimal2: %s", err.Error()) + } + finalArgRight, err := castToDecimal2(stdArgRight) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to decimal2: %s", err.Error()) + } + return finalArgLeft, finalArgRight, nil + } + + // Cast both to int64 + finalArgLeft, err := castToInt64(stdArgLeft) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast left arg to int64: %s", err.Error()) + } + finalArgRight, err := castToInt64(stdArgRight) + if err != nil { + return nil, nil, fmt.Errorf("unexpectedly cannot cast right arg to int64: %s", err.Error()) + } + return finalArgLeft, finalArgRight, nil +} + +func castToInt64(arg any) (int64, error) { + switch typedArg := arg.(type) { + case int: + return int64(typedArg), nil + case int16: + return int64(typedArg), nil + case int32: + return int64(typedArg), nil + case int64: + return typedArg, nil + case float32: + return int64(typedArg), nil + case float64: + return 
int64(typedArg), nil + case decimal.Decimal: + if typedArg.IsInteger() { + return typedArg.BigInt().Int64(), nil + } else { + return 0.0, fmt.Errorf("cannot cast decimal '%v' to int64, exact conversion impossible", typedArg) + } + default: + return 0.0, fmt.Errorf("cannot cast %v(%T) to int64, unsuported type", typedArg, typedArg) + } +} + +func castToFloat64(arg any) (float64, error) { + switch typedArg := arg.(type) { + case int: + return float64(typedArg), nil + case int16: + return float64(typedArg), nil + case int32: + return float64(typedArg), nil + case int64: + return float64(typedArg), nil + case float32: + return float64(typedArg), nil + case float64: + return typedArg, nil + case decimal.Decimal: + valFloat, _ := typedArg.Float64() + return valFloat, nil + default: + return 0.0, fmt.Errorf("cannot cast %v(%T) to float64, unsuported type", typedArg, typedArg) + } +} + +func castToDecimal2(arg any) (decimal.Decimal, error) { + switch typedArg := arg.(type) { + case int: + return decimal.NewFromInt(int64(typedArg)), nil + case int16: + return decimal.NewFromInt(int64(typedArg)), nil + case int32: + return decimal.NewFromInt(int64(typedArg)), nil + case int64: + return decimal.NewFromInt(typedArg), nil + case float32: + return decimal.NewFromFloat32(typedArg), nil + case float64: + return decimal.NewFromFloat(typedArg), nil + case decimal.Decimal: + return typedArg, nil + default: + return decimal.NewFromInt(0), fmt.Errorf("cannot cast %v(%T) to decimal2, unsuported type", typedArg, typedArg) + } +} diff --git a/pkg/eval/cast_test.go b/pkg/eval/cast_test.go index ce6df03..82fa72f 100644 --- a/pkg/eval/cast_test.go +++ b/pkg/eval/cast_test.go @@ -1,108 +1,108 @@ -package eval - -import ( - "testing" - - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -func TestCastSingle(t *testing.T) { - var val interface{} - var err error - - val, _ = castNumberToStandardType(int(12)) - assert.Equal(t, int64(12), val) - val, _ = castNumberToStandardType(int16(12)) - assert.Equal(t, int64(12), val) - val, _ = castNumberToStandardType(int32(12)) - assert.Equal(t, int64(12), val) - val, _ = castNumberToStandardType(int64(12)) - assert.Equal(t, int64(12), val) - val, _ = castNumberToStandardType(float32(12)) - assert.Equal(t, float64(12), val) - val, _ = castNumberToStandardType(float64(12)) - assert.Equal(t, float64(12), val) - val, _ = castNumberToStandardType(decimal.NewFromInt(12)) - assert.Equal(t, decimal.NewFromInt(12), val) - _, err = castNumberToStandardType("12") - assert.Equal(t, "cannot cast 12(string) to standard number type, unsuported type", err.Error()) - - val, _ = castToInt64(int(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(int16(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(int32(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(int64(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(float32(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(float64(12)) - assert.Equal(t, int64(12), val) - val, _ = castToInt64(decimal.NewFromInt(12)) - assert.Equal(t, int64(12), val) - _, err = castToInt64("12") - assert.Equal(t, "cannot cast 12(string) to int64, unsuported type", err.Error()) - _, err = castToInt64(decimal.NewFromFloat(12.2)) - assert.Equal(t, "cannot cast decimal '12.2' to int64, exact conversion impossible", err.Error()) - - val, _ = castToFloat64(int(12)) - assert.Equal(t, float64(12.0), val) - val, _ = castToFloat64(int16(12)) - assert.Equal(t, float64(12.0), val) - val, _ = 
castToFloat64(int32(12)) - assert.Equal(t, float64(12.0), val) - val, _ = castToFloat64(float64(12.0)) - assert.Equal(t, float64(12.0), val) - val, _ = castToFloat64(float32(12)) - assert.Equal(t, float64(12.0), val) - val, _ = castToFloat64(float64(12)) - assert.Equal(t, float64(12.0), val) - val, _ = castToFloat64(decimal.NewFromInt(12)) - assert.Equal(t, float64(12.0), val) - _, err = castToFloat64("12") - assert.Equal(t, "cannot cast 12(string) to float64, unsuported type", err.Error()) - val, _ = castToFloat64(decimal.NewFromFloat(12.2)) - assert.Equal(t, float64(12.2), val) - - val, _ = castToDecimal2(int(12)) - assert.Equal(t, decimal.NewFromFloat(12.0), val) - val, _ = castToDecimal2(int16(12)) - assert.Equal(t, decimal.NewFromFloat(12.0), val) - val, _ = castToDecimal2(int32(12)) - assert.Equal(t, decimal.NewFromFloat(12.0), val) - val, _ = castToDecimal2(float64(12.1)) - assert.Equal(t, decimal.NewFromFloat(12.1), val) - val, _ = castToDecimal2(float32(12.1)) - assert.Equal(t, decimal.NewFromFloat(12.1), val) - val, _ = castToDecimal2(float64(12.1)) - assert.Equal(t, decimal.NewFromFloat(12.1), val) - val, _ = castToDecimal2(decimal.NewFromFloat(12.1)) - assert.Equal(t, decimal.NewFromFloat(12.1), val) - _, err = castToDecimal2("12") - assert.Equal(t, "cannot cast 12(string) to decimal2, unsuported type", err.Error()) -} - -func TestCastPair(t *testing.T) { - var vLeft, vRight interface{} - var err error - - vLeft, vRight, _ = castNumberPairToCommonType(float32(12), int(13)) - assert.Equal(t, float64(12), vLeft) - assert.Equal(t, float64(13), vRight) - - vLeft, vRight, _ = castNumberPairToCommonType(decimal.NewFromInt(12), int16(13)) - assert.Equal(t, decimal.NewFromInt(12), vLeft) - assert.Equal(t, decimal.NewFromInt(13), vRight) - - vLeft, vRight, _ = castNumberPairToCommonType(int32(12), int(13)) - assert.Equal(t, int64(12), vLeft) - assert.Equal(t, int64(13), vRight) - - _, _, err = castNumberPairToCommonType("12", int(13)) - assert.Equal(t, "invalid left arg: cannot cast 12(string) to standard number type, unsuported type", err.Error()) - - _, _, err = castNumberPairToCommonType(int(132), "13") - assert.Equal(t, "invalid right arg: cannot cast 13(string) to standard number type, unsuported type", err.Error()) -} +package eval + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +func TestCastSingle(t *testing.T) { + var val any + var err error + + val, _ = castNumberToStandardType(int(12)) + assert.Equal(t, int64(12), val) + val, _ = castNumberToStandardType(int16(12)) + assert.Equal(t, int64(12), val) + val, _ = castNumberToStandardType(int32(12)) + assert.Equal(t, int64(12), val) + val, _ = castNumberToStandardType(int64(12)) + assert.Equal(t, int64(12), val) + val, _ = castNumberToStandardType(float32(12)) + assert.Equal(t, float64(12), val) + val, _ = castNumberToStandardType(float64(12)) + assert.Equal(t, float64(12), val) + val, _ = castNumberToStandardType(decimal.NewFromInt(12)) + assert.Equal(t, decimal.NewFromInt(12), val) + _, err = castNumberToStandardType("12") + assert.Equal(t, "cannot cast 12(string) to standard number type, unsuported type", err.Error()) + + val, _ = castToInt64(int(12)) + assert.Equal(t, int64(12), val) + val, _ = castToInt64(int16(12)) + assert.Equal(t, int64(12), val) + val, _ = castToInt64(int32(12)) + assert.Equal(t, int64(12), val) + val, _ = castToInt64(int64(12)) + assert.Equal(t, int64(12), val) + val, _ = castToInt64(float32(12)) + assert.Equal(t, int64(12), val) + val, _ = 
castToInt64(float64(12)) + assert.Equal(t, int64(12), val) + val, _ = castToInt64(decimal.NewFromInt(12)) + assert.Equal(t, int64(12), val) + _, err = castToInt64("12") + assert.Equal(t, "cannot cast 12(string) to int64, unsuported type", err.Error()) + _, err = castToInt64(decimal.NewFromFloat(12.2)) + assert.Equal(t, "cannot cast decimal '12.2' to int64, exact conversion impossible", err.Error()) + + val, _ = castToFloat64(int(12)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(int16(12)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(int32(12)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(float64(12.0)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(float32(12)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(float64(12)) + assert.Equal(t, float64(12.0), val) + val, _ = castToFloat64(decimal.NewFromInt(12)) + assert.Equal(t, float64(12.0), val) + _, err = castToFloat64("12") + assert.Equal(t, "cannot cast 12(string) to float64, unsuported type", err.Error()) + val, _ = castToFloat64(decimal.NewFromFloat(12.2)) + assert.Equal(t, float64(12.2), val) + + val, _ = castToDecimal2(int(12)) + assert.Equal(t, decimal.NewFromFloat(12.0), val) + val, _ = castToDecimal2(int16(12)) + assert.Equal(t, decimal.NewFromFloat(12.0), val) + val, _ = castToDecimal2(int32(12)) + assert.Equal(t, decimal.NewFromFloat(12.0), val) + val, _ = castToDecimal2(float64(12.1)) + assert.Equal(t, decimal.NewFromFloat(12.1), val) + val, _ = castToDecimal2(float32(12.1)) + assert.Equal(t, decimal.NewFromFloat(12.1), val) + val, _ = castToDecimal2(float64(12.1)) + assert.Equal(t, decimal.NewFromFloat(12.1), val) + val, _ = castToDecimal2(decimal.NewFromFloat(12.1)) + assert.Equal(t, decimal.NewFromFloat(12.1), val) + _, err = castToDecimal2("12") + assert.Equal(t, "cannot cast 12(string) to decimal2, unsuported type", err.Error()) +} + +func TestCastPair(t *testing.T) { + var vLeft, vRight any + var err error + + vLeft, vRight, _ = castNumberPairToCommonType(float32(12), int(13)) + assert.Equal(t, float64(12), vLeft) + assert.Equal(t, float64(13), vRight) + + vLeft, vRight, _ = castNumberPairToCommonType(decimal.NewFromInt(12), int16(13)) + assert.Equal(t, decimal.NewFromInt(12), vLeft) + assert.Equal(t, decimal.NewFromInt(13), vRight) + + vLeft, vRight, _ = castNumberPairToCommonType(int32(12), int(13)) + assert.Equal(t, int64(12), vLeft) + assert.Equal(t, int64(13), vRight) + + _, _, err = castNumberPairToCommonType("12", int(13)) + assert.Equal(t, "invalid left arg: cannot cast 12(string) to standard number type, unsuported type", err.Error()) + + _, _, err = castNumberPairToCommonType(int(132), "13") + assert.Equal(t, "invalid right arg: cannot cast 13(string) to standard number type, unsuported type", err.Error()) +} diff --git a/pkg/eval/const.go b/pkg/eval/const.go index 55e4ec1..ee25160 100644 --- a/pkg/eval/const.go +++ b/pkg/eval/const.go @@ -1,18 +1,18 @@ -package eval - -import "time" - -var GolangConstants map[string]interface{} = map[string]interface{}{ - "time.January": time.January, - "time.February": time.February, - "time.March": time.March, - "time.April": time.April, - "time.May": time.May, - "time.June": time.June, - "time.July": time.July, - "time.August": time.August, - "time.September": time.September, - "time.October": time.October, - "time.November": time.November, - "time.December": time.December, - "time.UTC": time.UTC} +package eval + +import "time" + +var GolangConstants = map[string]any{ + "time.January": 
time.January, + "time.February": time.February, + "time.March": time.March, + "time.April": time.April, + "time.May": time.May, + "time.June": time.June, + "time.July": time.July, + "time.August": time.August, + "time.September": time.September, + "time.October": time.October, + "time.November": time.November, + "time.December": time.December, + "time.UTC": time.UTC} diff --git a/pkg/eval/convert.go b/pkg/eval/convert.go index ae60418..7f782ff 100644 --- a/pkg/eval/convert.go +++ b/pkg/eval/convert.go @@ -1,151 +1,151 @@ -package eval - -import ( - "fmt" - "strconv" - - "github.com/shopspring/decimal" -) - -func callString(args []interface{}) (interface{}, error) { - if err := checkArgs("string", 1, len(args)); err != nil { - return nil, err - } - return fmt.Sprintf("%v", args[0]), nil -} - -func callInt(args []interface{}) (interface{}, error) { - if err := checkArgs("int", 1, len(args)); err != nil { - return nil, err - } - - switch typedArg0 := args[0].(type) { - case string: - retVal, err := strconv.ParseInt(typedArg0, 10, 64) - if err != nil { - return nil, fmt.Errorf("cannot eval int(%s):%s", typedArg0, err.Error()) - } - return retVal, nil - - case bool: - if typedArg0 { - return int64(1), nil - } else { - return int64(0), nil - } - - case int: - return int64(typedArg0), nil - - case int32: - return int64(typedArg0), nil - - case int16: - return int64(typedArg0), nil - - case int64: - return typedArg0, nil - - case float32: - return int64(typedArg0), nil - - case float64: - return int64(typedArg0), nil - - case decimal.Decimal: - return (*typedArg0.BigInt()).Int64(), nil - - default: - return nil, fmt.Errorf("unsupported arg type for int(%v):%T", typedArg0, typedArg0) - } -} - -func callDecimal2(args []interface{}) (interface{}, error) { - if err := checkArgs("decimal2", 1, len(args)); err != nil { - return nil, err - } - - switch typedArg0 := args[0].(type) { - case string: - retVal, err := decimal.NewFromString(typedArg0) - if err != nil { - return nil, fmt.Errorf("cannot eval decimal2(%s):%s", typedArg0, err.Error()) - } - return retVal, nil - - case bool: - if typedArg0 { - return decimal.NewFromInt(1), nil - } else { - return decimal.NewFromInt(0), nil - } - - case int: - return decimal.NewFromInt(int64(typedArg0)), nil - - case int16: - return decimal.NewFromInt(int64(typedArg0)), nil - - case int32: - return decimal.NewFromInt(int64(typedArg0)), nil - - case int64: - return decimal.NewFromInt(typedArg0), nil - - case float32: - return decimal.NewFromFloat(float64(typedArg0)), nil - - case float64: - return decimal.NewFromFloat(typedArg0), nil - - case decimal.Decimal: - return typedArg0, nil - - default: - return nil, fmt.Errorf("unsupported arg type for decimal2(%v):%T", typedArg0, typedArg0) - } -} - -func callFloat(args []interface{}) (interface{}, error) { - if err := checkArgs("float", 1, len(args)); err != nil { - return nil, err - } - switch typedArg0 := args[0].(type) { - case string: - retVal, err := strconv.ParseFloat(typedArg0, 64) - if err != nil { - return nil, fmt.Errorf("cannot eval float(%s):%s", typedArg0, err.Error()) - } - return retVal, nil - case bool: - if typedArg0 { - return float64(1), nil - } else { - return float64(0), nil - } - - case int: - return float64(typedArg0), nil - - case int16: - return float64(typedArg0), nil - - case int32: - return float64(typedArg0), nil - - case int64: - return float64(typedArg0), nil - - case float32: - return float64(typedArg0), nil - - case float64: - return typedArg0, nil - - case decimal.Decimal: - valFloat, _ 
:= typedArg0.Float64() - return valFloat, nil - default: - return nil, fmt.Errorf("unsupported arg type for float(%v):%T", typedArg0, typedArg0) - } -} +package eval + +import ( + "fmt" + "strconv" + + "github.com/shopspring/decimal" +) + +func callString(args []any) (any, error) { + if err := checkArgs("string", 1, len(args)); err != nil { + return nil, err + } + return fmt.Sprintf("%v", args[0]), nil +} + +func callInt(args []any) (any, error) { + if err := checkArgs("int", 1, len(args)); err != nil { + return nil, err + } + + switch typedArg0 := args[0].(type) { + case string: + retVal, err := strconv.ParseInt(typedArg0, 10, 64) + if err != nil { + return nil, fmt.Errorf("cannot eval int(%s):%s", typedArg0, err.Error()) + } + return retVal, nil + + case bool: + if typedArg0 { + return int64(1), nil + } else { + return int64(0), nil + } + + case int: + return int64(typedArg0), nil + + case int32: + return int64(typedArg0), nil + + case int16: + return int64(typedArg0), nil + + case int64: + return typedArg0, nil + + case float32: + return int64(typedArg0), nil + + case float64: + return int64(typedArg0), nil + + case decimal.Decimal: + return (*typedArg0.BigInt()).Int64(), nil + + default: + return nil, fmt.Errorf("unsupported arg type for int(%v):%T", typedArg0, typedArg0) + } +} + +func callDecimal2(args []any) (any, error) { + if err := checkArgs("decimal2", 1, len(args)); err != nil { + return nil, err + } + + switch typedArg0 := args[0].(type) { + case string: + retVal, err := decimal.NewFromString(typedArg0) + if err != nil { + return nil, fmt.Errorf("cannot eval decimal2(%s):%s", typedArg0, err.Error()) + } + return retVal, nil + + case bool: + if typedArg0 { + return decimal.NewFromInt(1), nil + } else { + return decimal.NewFromInt(0), nil + } + + case int: + return decimal.NewFromInt(int64(typedArg0)), nil + + case int16: + return decimal.NewFromInt(int64(typedArg0)), nil + + case int32: + return decimal.NewFromInt(int64(typedArg0)), nil + + case int64: + return decimal.NewFromInt(typedArg0), nil + + case float32: + return decimal.NewFromFloat(float64(typedArg0)), nil + + case float64: + return decimal.NewFromFloat(typedArg0), nil + + case decimal.Decimal: + return typedArg0, nil + + default: + return nil, fmt.Errorf("unsupported arg type for decimal2(%v):%T", typedArg0, typedArg0) + } +} + +func callFloat(args []any) (any, error) { + if err := checkArgs("float", 1, len(args)); err != nil { + return nil, err + } + switch typedArg0 := args[0].(type) { + case string: + retVal, err := strconv.ParseFloat(typedArg0, 64) + if err != nil { + return nil, fmt.Errorf("cannot eval float(%s):%s", typedArg0, err.Error()) + } + return retVal, nil + case bool: + if typedArg0 { + return float64(1), nil + } else { + return float64(0), nil + } + + case int: + return float64(typedArg0), nil + + case int16: + return float64(typedArg0), nil + + case int32: + return float64(typedArg0), nil + + case int64: + return float64(typedArg0), nil + + case float32: + return float64(typedArg0), nil + + case float64: + return typedArg0, nil + + case decimal.Decimal: + valFloat, _ := typedArg0.Float64() + return valFloat, nil + default: + return nil, fmt.Errorf("unsupported arg type for float(%v):%T", typedArg0, typedArg0) + } +} diff --git a/pkg/eval/convert_test.go b/pkg/eval/convert_test.go index b7f417a..e21f3a6 100644 --- a/pkg/eval/convert_test.go +++ b/pkg/eval/convert_test.go @@ -1,73 +1,73 @@ -package eval - -import ( - "testing" - "time" - - "github.com/shopspring/decimal" - 
"github.com/stretchr/testify/assert" -) - -func TestConvert(t *testing.T) { - var val interface{} - var err error - var testTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - - val, _ = callString([]interface{}{12.2}) - assert.Equal(t, "12.2", val) - val, _ = callString([]interface{}{true}) - assert.Equal(t, "true", val) - val, _ = callString([]interface{}{false}) - assert.Equal(t, "false", val) - val, _ = callString([]interface{}{testTime}) - assert.Equal(t, "0001-01-01 01:01:01.000000001 +0000 UTC", val) - _, err = callString([]interface{}{12.2, 13.0}) - assert.Equal(t, "cannot evaluate string(), requires 1 args, 2 supplied", err.Error()) - - val, _ = callInt([]interface{}{int64(12)}) - assert.Equal(t, int64(12), val) - val, _ = callInt([]interface{}{"12"}) - assert.Equal(t, int64(12), val) - val, _ = callInt([]interface{}{true}) - assert.Equal(t, int64(1), val) - val, _ = callInt([]interface{}{false}) - assert.Equal(t, int64(0), val) - _, err = callInt([]interface{}{testTime}) - assert.Equal(t, `unsupported arg type for int(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) - _, err = callInt([]interface{}{"12.2"}) - assert.Equal(t, `cannot eval int(12.2):strconv.ParseInt: parsing "12.2": invalid syntax`, err.Error()) - _, err = callInt([]interface{}{"12.0", "13.0"}) - assert.Equal(t, "cannot evaluate int(), requires 1 args, 2 supplied", err.Error()) - - val, _ = callDecimal2([]interface{}{int64(12)}) - assert.Equal(t, decimal.NewFromInt(12), val) - val, _ = callDecimal2([]interface{}{"12"}) - assert.Equal(t, decimal.NewFromInt(12), val) - val, _ = callDecimal2([]interface{}{true}) - assert.Equal(t, decimal.NewFromInt(1), val) - val, _ = callDecimal2([]interface{}{false}) - assert.Equal(t, decimal.NewFromInt(0), val) - _, err = callDecimal2([]interface{}{testTime}) - assert.Equal(t, `unsupported arg type for decimal2(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) - _, err = callDecimal2([]interface{}{"somestring"}) - assert.Equal(t, "cannot eval decimal2(somestring):can't convert somestring to decimal: exponent is not numeric", err.Error()) - _, err = callDecimal2([]interface{}{"12.0", "13.0"}) - assert.Equal(t, "cannot evaluate decimal2(), requires 1 args, 2 supplied", err.Error()) - - val, _ = callFloat([]interface{}{int64(12)}) - assert.Equal(t, float64(12), val) - val, _ = callFloat([]interface{}{"12.1"}) - assert.Equal(t, float64(12.1), val) - val, _ = callFloat([]interface{}{decimal.NewFromFloat(12.2)}) - assert.Equal(t, float64(12.2), val) - val, _ = callFloat([]interface{}{true}) - assert.Equal(t, float64(1), val) - val, _ = callFloat([]interface{}{false}) - assert.Equal(t, float64(0), val) - _, err = callFloat([]interface{}{testTime}) - assert.Equal(t, `unsupported arg type for float(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) - _, err = callFloat([]interface{}{"somestring"}) - assert.Equal(t, `cannot eval float(somestring):strconv.ParseFloat: parsing "somestring": invalid syntax`, err.Error()) - _, err = callFloat([]interface{}{"12.0", "13.0"}) - assert.Equal(t, "cannot evaluate float(), requires 1 args, 2 supplied", err.Error()) -} +package eval + +import ( + "testing" + "time" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +func TestConvert(t *testing.T) { + var val any + var err error + var testTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + + val, _ = callString([]any{12.2}) + assert.Equal(t, "12.2", val) + val, _ = callString([]any{true}) + assert.Equal(t, "true", val) + val, _ = 
callString([]any{false}) + assert.Equal(t, "false", val) + val, _ = callString([]any{testTime}) + assert.Equal(t, "0001-01-01 01:01:01.000000001 +0000 UTC", val) + _, err = callString([]any{12.2, 13.0}) + assert.Equal(t, "cannot evaluate string(), requires 1 args, 2 supplied", err.Error()) + + val, _ = callInt([]any{int64(12)}) + assert.Equal(t, int64(12), val) + val, _ = callInt([]any{"12"}) + assert.Equal(t, int64(12), val) + val, _ = callInt([]any{true}) + assert.Equal(t, int64(1), val) + val, _ = callInt([]any{false}) + assert.Equal(t, int64(0), val) + _, err = callInt([]any{testTime}) + assert.Equal(t, `unsupported arg type for int(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) + _, err = callInt([]any{"12.2"}) + assert.Equal(t, `cannot eval int(12.2):strconv.ParseInt: parsing "12.2": invalid syntax`, err.Error()) + _, err = callInt([]any{"12.0", "13.0"}) + assert.Equal(t, "cannot evaluate int(), requires 1 args, 2 supplied", err.Error()) + + val, _ = callDecimal2([]any{int64(12)}) + assert.Equal(t, decimal.NewFromInt(12), val) + val, _ = callDecimal2([]any{"12"}) + assert.Equal(t, decimal.NewFromInt(12), val) + val, _ = callDecimal2([]any{true}) + assert.Equal(t, decimal.NewFromInt(1), val) + val, _ = callDecimal2([]any{false}) + assert.Equal(t, decimal.NewFromInt(0), val) + _, err = callDecimal2([]any{testTime}) + assert.Equal(t, `unsupported arg type for decimal2(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) + _, err = callDecimal2([]any{"somestring"}) + assert.Equal(t, "cannot eval decimal2(somestring):can't convert somestring to decimal: exponent is not numeric", err.Error()) + _, err = callDecimal2([]any{"12.0", "13.0"}) + assert.Equal(t, "cannot evaluate decimal2(), requires 1 args, 2 supplied", err.Error()) + + val, _ = callFloat([]any{int64(12)}) + assert.Equal(t, float64(12), val) + val, _ = callFloat([]any{"12.1"}) + assert.Equal(t, float64(12.1), val) + val, _ = callFloat([]any{decimal.NewFromFloat(12.2)}) + assert.Equal(t, float64(12.2), val) + val, _ = callFloat([]any{true}) + assert.Equal(t, float64(1), val) + val, _ = callFloat([]any{false}) + assert.Equal(t, float64(0), val) + _, err = callFloat([]any{testTime}) + assert.Equal(t, `unsupported arg type for float(0001-01-01 01:01:01.000000001 +0000 UTC):time.Time`, err.Error()) + _, err = callFloat([]any{"somestring"}) + assert.Equal(t, `cannot eval float(somestring):strconv.ParseFloat: parsing "somestring": invalid syntax`, err.Error()) + _, err = callFloat([]any{"12.0", "13.0"}) + assert.Equal(t, "cannot evaluate float(), requires 1 args, 2 supplied", err.Error()) +} diff --git a/pkg/eval/eval_ctx.go b/pkg/eval/eval_ctx.go index 3e14e09..0212922 100644 --- a/pkg/eval/eval_ctx.go +++ b/pkg/eval/eval_ctx.go @@ -1,730 +1,720 @@ -package eval - -import ( - "fmt" - "go/ast" - "go/token" - "math" - "strconv" - "strings" - "time" - - "github.com/shopspring/decimal" -) - -type AggEnabledType int - -const ( - AggFuncDisabled AggEnabledType = iota - AggFuncEnabled -) - -type EvalCtx struct { - Vars *VarValuesMap - AggFunc AggFuncType - AggType AggDataType - AggCallExp *ast.CallExpr - Count int64 - StringAgg StringAggCollector - Sum SumCollector - Avg AvgCollector - Min MinCollector - Max MaxCollector - Value interface{} - AggEnabled AggEnabledType -} - -// Not ready to make these limits/defaults public - -const ( - maxSupportedInt int64 = int64(math.MaxInt64) - minSupportedInt int64 = int64(math.MinInt64) - maxSupportedFloat float64 = math.MaxFloat64 - minSupportedFloat float64 = 
-math.MaxFloat32 -) - -func maxSupportedDecimal() decimal.Decimal { - return decimal.NewFromFloat32(math.MaxFloat32) -} -func minSupportedDecimal() decimal.Decimal { - return decimal.NewFromFloat32(-math.MaxFloat32 + 1) -} - -func defaultDecimal() decimal.Decimal { - // Explicit zero, otherwise its decimal NIL - return decimal.NewFromInt(0) -} - -// TODO: refactor to avoid duplicated ctx creationcode - -func NewPlainEvalCtx(aggEnabled AggEnabledType) EvalCtx { - return EvalCtx{ - AggFunc: AggUnknown, - AggType: AggTypeUnknown, - AggEnabled: aggEnabled, - StringAgg: StringAggCollector{Separator: "", Sb: strings.Builder{}}, - Sum: SumCollector{Dec: defaultDecimal()}, - Avg: AvgCollector{Dec: defaultDecimal()}, - Min: MinCollector{Int: maxSupportedInt, Float: maxSupportedFloat, Dec: maxSupportedDecimal(), Str: ""}, - Max: MaxCollector{Int: minSupportedInt, Float: minSupportedFloat, Dec: minSupportedDecimal(), Str: ""}} -} - -func NewPlainEvalCtxAndInitializedAgg(aggEnabled AggEnabledType, aggFuncType AggFuncType, aggFuncArgs []ast.Expr) (*EvalCtx, error) { - eCtx := NewPlainEvalCtx(aggEnabled) - // Special case: we need to provide eCtx.StringAgg with a separator and - // explicitly set its type to AggTypeString from the very beginning (instead of detecting it later, as we do for other agg functions) - if aggEnabled == AggFuncEnabled && aggFuncType == AggStringAgg { - var aggStringErr error - eCtx.StringAgg.Separator, aggStringErr = GetAggStringSeparator(aggFuncArgs) - if aggStringErr != nil { - return nil, aggStringErr - } - eCtx.AggType = AggTypeString - } - return &eCtx, nil -} - -func NewPlainEvalCtxWithVars(aggEnabled AggEnabledType, vars *VarValuesMap) EvalCtx { - return EvalCtx{ - AggFunc: AggUnknown, - Vars: vars, - AggType: AggTypeUnknown, - AggEnabled: aggEnabled, - StringAgg: StringAggCollector{Separator: "", Sb: strings.Builder{}}, - Sum: SumCollector{Dec: defaultDecimal()}, - Avg: AvgCollector{Dec: defaultDecimal()}, - Min: MinCollector{Int: maxSupportedInt, Float: maxSupportedFloat, Dec: maxSupportedDecimal(), Str: ""}, - Max: MaxCollector{Int: minSupportedInt, Float: minSupportedFloat, Dec: minSupportedDecimal(), Str: ""}} -} - -func NewPlainEvalCtxWithVarsAndInitializedAgg(aggEnabled AggEnabledType, vars *VarValuesMap, aggFuncType AggFuncType, aggFuncArgs []ast.Expr) (*EvalCtx, error) { - eCtx := NewPlainEvalCtxWithVars(aggEnabled, vars) - // Special case: we need to provide eCtx.StringAgg with a separator and - // explicitly set its type to AggTypeString from the very beginning (instead of detecting it later, as we do for other agg functions) - if aggEnabled == AggFuncEnabled && aggFuncType == AggStringAgg { - var aggStringErr error - eCtx.StringAgg.Separator, aggStringErr = GetAggStringSeparator(aggFuncArgs) - if aggStringErr != nil { - return nil, aggStringErr - } - eCtx.AggType = AggTypeString - } - return &eCtx, nil -} - -func checkArgs(funcName string, requiredArgCount int, actualArgCount int) error { - if actualArgCount != requiredArgCount { - return fmt.Errorf("cannot evaluate %s(), requires %d args, %d supplied", funcName, requiredArgCount, actualArgCount) - } else { - return nil - } -} - -func (eCtx *EvalCtx) EvalBinaryInt(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (result int64, finalErr error) { - - result = math.MaxInt - valLeft, ok := valLeftVolatile.(int64) - if !ok { - return 0, fmt.Errorf("cannot evaluate binary int64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := 
valRightVolatile.(int64) - if !ok { - return 0, fmt.Errorf("cannot evaluate binary int64 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - defer func() { - if r := recover(); r != nil { - finalErr = fmt.Errorf("%v", r) - } - }() - - switch op { - case token.ADD: - return valLeft + valRight, nil - case token.SUB: - return valLeft - valRight, nil - case token.MUL: - return valLeft * valRight, nil - case token.QUO: - return valLeft / valRight, nil - case token.REM: - return valLeft % valRight, nil - default: - return 0, fmt.Errorf("cannot perform int op %v against int %d and int %d", op, valLeft, valRight) - } -} - -func isCompareOp(op token.Token) bool { - return op == token.GTR || op == token.LSS || op == token.GEQ || op == token.LEQ || op == token.EQL || op == token.NEQ -} - -func (eCtx *EvalCtx) EvalBinaryIntToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(int64) - if !ok { - return false, fmt.Errorf("cannot evaluate binary int64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(int64) - if !ok { - return false, fmt.Errorf("cannot evaluate binary int64 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - if isCompareOp(op) { - if op == token.GTR && valLeft > valRight || - op == token.LSS && valLeft < valRight || - op == token.GEQ && valLeft >= valRight || - op == token.LEQ && valLeft <= valRight || - op == token.EQL && valLeft == valRight || - op == token.NEQ && valLeft != valRight { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot perform bool op %v against int %d and int %d", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryFloat64ToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(float64) - if !ok { - return false, fmt.Errorf("cannot evaluate binary foat64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(float64) - if !ok { - return false, fmt.Errorf("cannot evaluate binary float64 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - if isCompareOp(op) { - if op == token.GTR && valLeft > valRight || - op == token.LSS && valLeft < valRight || - op == token.GEQ && valLeft >= valRight || - op == token.LEQ && valLeft <= valRight || - op == token.EQL && valLeft == valRight || - op == token.NEQ && valLeft != valRight { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot perform bool op %v against float %f and float %f", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryDecimal2ToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(decimal.Decimal) - if !ok { - return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(decimal.Decimal) - if !ok { - return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - if isCompareOp(op) { - if op == token.GTR && 
valLeft.Cmp(valRight) > 0 || - op == token.LSS && valLeft.Cmp(valRight) < 0 || - op == token.GEQ && valLeft.Cmp(valRight) >= 0 || - op == token.LEQ && valLeft.Cmp(valRight) <= 0 || - op == token.EQL && valLeft.Cmp(valRight) == 0 || - op == token.NEQ && valLeft.Cmp(valRight) != 0 { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot perform bool op %v against decimal2 %v and decimal2 %v", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryTimeToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(time.Time) - if !ok { - return false, fmt.Errorf("cannot evaluate binary time expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(time.Time) - if !ok { - return false, fmt.Errorf("cannot evaluate binary time expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - if isCompareOp(op) { - if op == token.GTR && valLeft.After(valRight) || - op == token.LSS && valLeft.Before(valRight) || - op == token.GEQ && (valLeft.After(valRight) || valLeft == valRight) || - op == token.LEQ && (valLeft.Before(valRight) || valLeft == valRight) || - op == token.EQL && valLeft == valRight || - op == token.NEQ && valLeft != valRight { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot perform bool op %v against time %v and time %v", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryFloat64(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (float64, error) { - - valLeft, ok := valLeftVolatile.(float64) - if !ok { - return 0.0, fmt.Errorf("cannot evaluate binary float64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(float64) - if !ok { - return 0.0, fmt.Errorf("cannot evaluate binary float expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - switch op { - case token.ADD: - return valLeft + valRight, nil - case token.SUB: - return valLeft - valRight, nil - case token.MUL: - return valLeft * valRight, nil - case token.QUO: - return valLeft / valRight, nil - default: - return 0, fmt.Errorf("cannot perform float64 op %v against float64 %f and float64 %f", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryDecimal2(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (result decimal.Decimal, err error) { - - result = decimal.NewFromFloat(math.MaxFloat64) - err = nil - valLeft, ok := valLeftVolatile.(decimal.Decimal) - if !ok { - return decimal.NewFromInt(0), fmt.Errorf("cannot evaluate binary decimal2 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(decimal.Decimal) - if !ok { - return decimal.NewFromInt(0), fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) - } - }() - - switch op { - case token.ADD: - return valLeft.Add(valRight).Round(2), nil - case token.SUB: - return valLeft.Sub(valRight).Round(2), nil - case token.MUL: - return valLeft.Mul(valRight).Round(2), nil - case token.QUO: - return valLeft.Div(valRight).Round(2), nil - default: - 
return decimal.NewFromInt(0), fmt.Errorf("cannot perform decimal2 op %v against decimal2 %v and float64 %v", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate binary bool expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate binary bool expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - switch op { - case token.LAND: - return valLeft && valRight, nil - case token.LOR: - return valLeft || valRight, nil - default: - return false, fmt.Errorf("cannot perform bool op %v against bool %v and bool %v", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalBinaryBoolToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate binary bool expression %v with %T on the left", op, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate binary bool expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - if op == token.EQL || op == token.NEQ { - if op == token.EQL && valLeft == valRight || - op == token.NEQ && valLeft != valRight { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot evaluate binary bool expression, op %v not supported (and will never be)", op) - } -} - -func (eCtx *EvalCtx) EvalUnaryBoolNot(exp ast.Expr) (bool, error) { - valVolatile, err := eCtx.Eval(exp) - if err != nil { - return false, err - } - - val, ok := valVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate unary bool not expression with %T on the right", valVolatile) - } - - return !val, nil -} - -func (eCtx *EvalCtx) EvalUnaryMinus(exp ast.Expr) (interface{}, error) { - valVolatile, err := eCtx.Eval(exp) - if err != nil { - return false, err - } - - switch typedVal := valVolatile.(type) { - case int: - return int64(-typedVal), nil - case int16: - return int64(-typedVal), nil - case int32: - return int64(-typedVal), nil - case int64: - return -typedVal, nil - case float32: - return float64(-typedVal), nil - case float64: - return -typedVal, nil - case decimal.Decimal: - return typedVal.Neg(), nil - default: - return false, fmt.Errorf("cannot evaluate unary minus expression '-%v(%T)', unsupported type", valVolatile, valVolatile) - } -} - -func (eCtx *EvalCtx) EvalBinaryString(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (string, error) { - - valLeft, ok := valLeftVolatile.(string) - if !ok { - return "", fmt.Errorf("cannot evaluate binary string expression %v with %T on the left", op, valLeftVolatile) - } - - valRight, ok := valRightVolatile.(string) - if !ok { - return "", fmt.Errorf("cannot evaluate binary string expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - - switch op { - case token.ADD: - return valLeft + valRight, nil - default: - return "", fmt.Errorf("cannot perform string op %v against string '%s' and string '%s', op not supported", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) 
EvalBinaryStringToBool(valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}) (bool, error) { - - valLeft, ok := valLeftVolatile.(string) - if !ok { - return false, fmt.Errorf("cannot evaluate binary string expression %v with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) - } - valLeft = strings.Replace(strings.Trim(valLeft, "\""), `\"`, `\`, -1) - - valRight, ok := valRightVolatile.(string) - if !ok { - return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) - } - valRight = strings.Replace(strings.Trim(valRight, "\""), `\"`, `"`, -1) - - if isCompareOp(op) { - if op == token.GTR && valLeft > valRight || - op == token.LSS && valLeft < valRight || - op == token.GEQ && valLeft >= valRight || - op == token.LEQ && valLeft <= valRight || - op == token.EQL && valLeft == valRight || - op == token.NEQ && valLeft != valRight { - return true, nil - } else { - return false, nil - } - } else { - return false, fmt.Errorf("cannot perform bool op %v against string %v and string %v", op, valLeft, valRight) - } -} - -func (eCtx *EvalCtx) EvalFunc(callExp *ast.CallExpr, funcName string, args []interface{}) (interface{}, error) { - var err error = nil - switch funcName { - case "math.Sqrt": - eCtx.Value, err = callMathSqrt(args) - case "math.Round": - eCtx.Value, err = callMathRound(args) - case "len": - eCtx.Value, err = callLen(args) - case "string": - eCtx.Value, err = callString(args) - case "float": - eCtx.Value, err = callFloat(args) - case "int": - eCtx.Value, err = callInt(args) - case "decimal2": - eCtx.Value, err = callDecimal2(args) - case "time.Parse": - eCtx.Value, err = callTimeParse(args) - case "time.Format": - eCtx.Value, err = callTimeFormat(args) - case "time.Date": - eCtx.Value, err = callTimeDate(args) - case "time.Now": - eCtx.Value, err = callTimeNow(args) - case "time.Unix": - eCtx.Value, err = callTimeUnix(args) - case "time.UnixMilli": - eCtx.Value, err = callTimeUnixMilli(args) - case "time.DiffMilli": - eCtx.Value, err = callTimeDiffMilli(args) - case "time.Before": - eCtx.Value, err = callTimeBefore(args) - case "time.After": - eCtx.Value, err = callTimeAfter(args) - case "time.FixedZone": - eCtx.Value, err = callTimeFixedZone(args) - case "re.MatchString": - eCtx.Value, err = callReMatchString(args) - case "strings.ReplaceAll": - eCtx.Value, err = callStringsReplaceAll(args) - case "fmt.Sprintf": - eCtx.Value, err = callFmtSprintf(args) - - // Aggregate functions, to be used only in grouped lookups - - case "string_agg": - eCtx.Value, err = eCtx.CallAggStringAgg(callExp, args) - case "sum": - eCtx.Value, err = eCtx.CallAggSum(callExp, args) - case "count": - eCtx.Value, err = eCtx.CallAggCount(callExp, args) - case "avg": - eCtx.Value, err = eCtx.CallAggAvg(callExp, args) - case "min": - eCtx.Value, err = eCtx.CallAggMin(callExp, args) - case "max": - eCtx.Value, err = eCtx.CallAggMax(callExp, args) - - default: - return nil, fmt.Errorf("cannot evaluate unsupported func '%s'", funcName) - } - return eCtx.Value, err -} - -func (eCtx *EvalCtx) Eval(exp ast.Expr) (interface{}, error) { - switch exp := exp.(type) { - case *ast.BinaryExpr: - valLeftVolatile, err := eCtx.Eval(exp.X) - if err != nil { - return nil, err - } - - valRightVolatile, err := eCtx.Eval(exp.Y) - if err != nil { - return 0, err - } - - if exp.Op == token.ADD || exp.Op == token.SUB || exp.Op == token.MUL || exp.Op == token.QUO || exp.Op == token.REM { - switch 
valLeftVolatile.(type) { - case string: - eCtx.Value, err = eCtx.EvalBinaryString(valLeftVolatile, exp.Op, valRightVolatile) - return eCtx.Value, err - - default: - // Assume both args are numbers (int, float, dec) - stdArgLeft, stdArgRight, err := castNumberPairToCommonType(valLeftVolatile, valRightVolatile) - if err != nil { - return nil, fmt.Errorf("cannot perform binary arithmetic op, incompatible arg types '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) - } - switch stdArgLeft.(type) { - case int64: - eCtx.Value, err = eCtx.EvalBinaryInt(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - case float64: - eCtx.Value, err = eCtx.EvalBinaryFloat64(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - case decimal.Decimal: - eCtx.Value, err = eCtx.EvalBinaryDecimal2(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - default: - return nil, fmt.Errorf("cannot perform binary arithmetic op, unexpected std type '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) - } - } - } else if exp.Op == token.LOR || exp.Op == token.LAND { - switch valLeftTyped := valLeftVolatile.(type) { - case bool: - eCtx.Value, err = eCtx.EvalBinaryBool(valLeftTyped, exp.Op, valRightVolatile) - return eCtx.Value, err - default: - return nil, fmt.Errorf("cannot perform binary op %v against %T left", exp.Op, valLeftVolatile) - } - - } else if exp.Op == token.GTR || exp.Op == token.GEQ || exp.Op == token.LSS || exp.Op == token.LEQ || exp.Op == token.EQL || exp.Op == token.NEQ { - switch valLeftVolatile.(type) { - case time.Time: - eCtx.Value, err = eCtx.EvalBinaryTimeToBool(valLeftVolatile, exp.Op, valRightVolatile) - return eCtx.Value, err - case string: - eCtx.Value, err = eCtx.EvalBinaryStringToBool(valLeftVolatile, exp.Op, valRightVolatile) - return eCtx.Value, err - case bool: - eCtx.Value, err = eCtx.EvalBinaryBoolToBool(valLeftVolatile, exp.Op, valRightVolatile) - return eCtx.Value, err - default: - // Assume both args are numbers (int, float, dec) - stdArgLeft, stdArgRight, err := castNumberPairToCommonType(valLeftVolatile, valRightVolatile) - if err != nil { - return nil, fmt.Errorf("cannot perform binary comp op, incompatible arg types '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) - } - switch stdArgLeft.(type) { - case int64: - eCtx.Value, err = eCtx.EvalBinaryIntToBool(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - case float64: - eCtx.Value, err = eCtx.EvalBinaryFloat64ToBool(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - case decimal.Decimal: - eCtx.Value, err = eCtx.EvalBinaryDecimal2ToBool(stdArgLeft, exp.Op, stdArgRight) - return eCtx.Value, err - default: - return nil, fmt.Errorf("cannot perform binary comp op, unexpected std type '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) - } - } - } else { - return nil, fmt.Errorf("cannot perform binary expression unknown op %v", exp.Op) - } - case *ast.BasicLit: - switch exp.Kind { - case token.INT: - i, _ := strconv.ParseInt(exp.Value, 10, 64) - eCtx.Value = i - return i, nil - case token.FLOAT: - i, _ := strconv.ParseFloat(exp.Value, 64) - eCtx.Value = i - return i, nil - case token.IDENT: - return nil, fmt.Errorf("cannot evaluate expression %s of type token.IDENT", exp.Value) - case token.STRING: - eCtx.Value = exp.Value - if exp.Value[0] == '"' { - return strings.Trim(exp.Value, "\""), nil - } else { 
- return strings.Trim(exp.Value, "`"), nil - } - default: - return nil, fmt.Errorf("cannot evaluate expression %s of type %v", exp.Value, exp.Kind) - } - case *ast.UnaryExpr: - switch exp.Op { - case token.NOT: - var err error - eCtx.Value, err = eCtx.EvalUnaryBoolNot(exp.X) - return eCtx.Value, err - case token.SUB: - var err error - eCtx.Value, err = eCtx.EvalUnaryMinus(exp.X) - return eCtx.Value, err - default: - return nil, fmt.Errorf("cannot evaluate unary op %v, unkown op", exp.Op) - } - - case *ast.Ident: - if exp.Name == "true" { - eCtx.Value = true - return true, nil - } else if exp.Name == "false" { - eCtx.Value = false - return false, nil - } else { - return nil, fmt.Errorf("cannot evaluate identifier %s", exp.Name) - } - - case *ast.CallExpr: - args := make([]interface{}, len(exp.Args)) - - for i, v := range exp.Args { - arg, err := eCtx.Eval(v) - if err != nil { - return nil, err - } - args[i] = arg - } - - switch exp.Fun.(type) { - case *ast.Ident: - funcIdent, _ := exp.Fun.(*ast.Ident) - var err error - eCtx.Value, err = eCtx.EvalFunc(exp, funcIdent.Name, args) - return eCtx.Value, err - - case *ast.SelectorExpr: - expSel := exp.Fun.(*ast.SelectorExpr) - switch expSel.X.(type) { - case *ast.Ident: - expIdent, _ := expSel.X.(*ast.Ident) - var err error - eCtx.Value, err = eCtx.EvalFunc(exp, fmt.Sprintf("%s.%s", expIdent.Name, expSel.Sel.Name), args) - return eCtx.Value, err - default: - return nil, fmt.Errorf("cannot evaluate fun expression %v, unknown type of X: %T", expSel.X, expSel.X) - } - - default: - return nil, fmt.Errorf("cannot evaluate func call expression %v, unknown type of X: %T", exp.Fun, exp.Fun) - } - - case *ast.SelectorExpr: - switch exp.X.(type) { - case *ast.Ident: - objectIdent, _ := exp.X.(*ast.Ident) - golangConst, ok := GolangConstants[fmt.Sprintf("%s.%s", objectIdent.Name, exp.Sel.Name)] - if ok { - eCtx.Value = golangConst - return golangConst, nil - } - - if eCtx.Vars == nil { - return nil, fmt.Errorf("cannot evaluate expression '%s', no variables supplied to the context", objectIdent.Name) - } - - objectAttributes, ok := (*eCtx.Vars)[objectIdent.Name] - if !ok { - return nil, fmt.Errorf("cannot evaluate expression '%s', variable not supplied, check table/alias name", objectIdent.Name) - } - - val, ok := objectAttributes[exp.Sel.Name] - if !ok { - return nil, fmt.Errorf("cannot evaluate expression %s.%s, variable not supplied, check field name", objectIdent.Name, exp.Sel.Name) - } - eCtx.Value = val - return val, nil - default: - return nil, fmt.Errorf("cannot evaluate selector expression %v, unknown type of X: %T", exp.X, exp.X) - } - - default: - return nil, fmt.Errorf("cannot evaluate generic expression %v of unknown type %T", exp, exp) - } -} +package eval + +import ( + "fmt" + "go/ast" + "go/token" + "math" + "strconv" + "strings" + "time" + + "github.com/shopspring/decimal" +) + +type AggEnabledType int + +const ( + AggFuncDisabled AggEnabledType = iota + AggFuncEnabled +) + +type EvalCtx struct { + Vars *VarValuesMap + AggFunc AggFuncType + AggType AggDataType + AggCallExp *ast.CallExpr + Count int64 + StringAgg StringAggCollector + Sum SumCollector + Avg AvgCollector + Min MinCollector + Max MaxCollector + Value any + AggEnabled AggEnabledType +} + +// Not ready to make these limits/defaults public + +const ( + maxSupportedInt int64 = int64(math.MaxInt64) + minSupportedInt int64 = int64(math.MinInt64) + maxSupportedFloat float64 = math.MaxFloat64 + minSupportedFloat float64 = -math.MaxFloat32 +) + +func maxSupportedDecimal() decimal.Decimal 
{ + return decimal.NewFromFloat32(math.MaxFloat32) +} +func minSupportedDecimal() decimal.Decimal { + return decimal.NewFromFloat32(-math.MaxFloat32 + 1) +} + +func defaultDecimal() decimal.Decimal { + // Explicit zero, otherwise it's decimal NIL + return decimal.NewFromInt(0) +} + +// TODO: refactor to avoid duplicated ctx creation code + +func NewPlainEvalCtx(aggEnabled AggEnabledType) EvalCtx { + return EvalCtx{ + AggFunc: AggUnknown, + AggType: AggTypeUnknown, + AggEnabled: aggEnabled, + StringAgg: StringAggCollector{Separator: "", Sb: strings.Builder{}}, + Sum: SumCollector{Dec: defaultDecimal()}, + Avg: AvgCollector{Dec: defaultDecimal()}, + Min: MinCollector{Int: maxSupportedInt, Float: maxSupportedFloat, Dec: maxSupportedDecimal(), Str: ""}, + Max: MaxCollector{Int: minSupportedInt, Float: minSupportedFloat, Dec: minSupportedDecimal(), Str: ""}} +} + +func NewPlainEvalCtxAndInitializedAgg(aggEnabled AggEnabledType, aggFuncType AggFuncType, aggFuncArgs []ast.Expr) (*EvalCtx, error) { + eCtx := NewPlainEvalCtx(aggEnabled) + // Special case: we need to provide eCtx.StringAgg with a separator and + // explicitly set its type to AggTypeString from the very beginning (instead of detecting it later, as we do for other agg functions) + if aggEnabled == AggFuncEnabled && aggFuncType == AggStringAgg { + var aggStringErr error + eCtx.StringAgg.Separator, aggStringErr = GetAggStringSeparator(aggFuncArgs) + if aggStringErr != nil { + return nil, aggStringErr + } + eCtx.AggType = AggTypeString + } + return &eCtx, nil +} + +func NewPlainEvalCtxWithVars(aggEnabled AggEnabledType, vars *VarValuesMap) EvalCtx { + return EvalCtx{ + AggFunc: AggUnknown, + Vars: vars, + AggType: AggTypeUnknown, + AggEnabled: aggEnabled, + StringAgg: StringAggCollector{Separator: "", Sb: strings.Builder{}}, + Sum: SumCollector{Dec: defaultDecimal()}, + Avg: AvgCollector{Dec: defaultDecimal()}, + Min: MinCollector{Int: maxSupportedInt, Float: maxSupportedFloat, Dec: maxSupportedDecimal(), Str: ""}, + Max: MaxCollector{Int: minSupportedInt, Float: minSupportedFloat, Dec: minSupportedDecimal(), Str: ""}} +} + +func NewPlainEvalCtxWithVarsAndInitializedAgg(aggEnabled AggEnabledType, vars *VarValuesMap, aggFuncType AggFuncType, aggFuncArgs []ast.Expr) (*EvalCtx, error) { + eCtx := NewPlainEvalCtxWithVars(aggEnabled, vars) + // Special case: we need to provide eCtx.StringAgg with a separator and + // explicitly set its type to AggTypeString from the very beginning (instead of detecting it later, as we do for other agg functions) + if aggEnabled == AggFuncEnabled && aggFuncType == AggStringAgg { + var aggStringErr error + eCtx.StringAgg.Separator, aggStringErr = GetAggStringSeparator(aggFuncArgs) + if aggStringErr != nil { + return nil, aggStringErr + } + eCtx.AggType = AggTypeString + } + return &eCtx, nil +} + +func checkArgs(funcName string, requiredArgCount int, actualArgCount int) error { + if actualArgCount != requiredArgCount { + return fmt.Errorf("cannot evaluate %s(), requires %d args, %d supplied", funcName, requiredArgCount, actualArgCount) + } + return nil +} + +func (eCtx *EvalCtx) EvalBinaryInt(valLeftVolatile any, op token.Token, valRightVolatile any) (result int64, finalErr error) { + + result = math.MaxInt + valLeft, ok := valLeftVolatile.(int64) + if !ok { + return 0, fmt.Errorf("cannot evaluate binary int64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(int64) + if !ok { + return 0, fmt.Errorf("cannot evaluate binary int64 expression 
'%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + defer func() { + if r := recover(); r != nil { + finalErr = fmt.Errorf("%v", r) + } + }() + + switch op { + case token.ADD: + return valLeft + valRight, nil + case token.SUB: + return valLeft - valRight, nil + case token.MUL: + return valLeft * valRight, nil + case token.QUO: + return valLeft / valRight, nil + case token.REM: + return valLeft % valRight, nil + default: + return 0, fmt.Errorf("cannot perform int op %v against int %d and int %d", op, valLeft, valRight) + } +} + +func isCompareOp(op token.Token) bool { + return op == token.GTR || op == token.LSS || op == token.GEQ || op == token.LEQ || op == token.EQL || op == token.NEQ +} + +func (eCtx *EvalCtx) EvalBinaryIntToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(int64) + if !ok { + return false, fmt.Errorf("cannot evaluate binary int64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(int64) + if !ok { + return false, fmt.Errorf("cannot evaluate binary int64 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + if !isCompareOp(op) { + return false, fmt.Errorf("cannot perform bool op %v against int %d and int %d", op, valLeft, valRight) + } + + if op == token.GTR && valLeft > valRight || + op == token.LSS && valLeft < valRight || + op == token.GEQ && valLeft >= valRight || + op == token.LEQ && valLeft <= valRight || + op == token.EQL && valLeft == valRight || + op == token.NEQ && valLeft != valRight { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalBinaryFloat64ToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(float64) + if !ok { + return false, fmt.Errorf("cannot evaluate binary float64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(float64) + if !ok { + return false, fmt.Errorf("cannot evaluate binary float64 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + if !isCompareOp(op) { + return false, fmt.Errorf("cannot perform bool op %v against float %f and float %f", op, valLeft, valRight) + } + + if op == token.GTR && valLeft > valRight || + op == token.LSS && valLeft < valRight || + op == token.GEQ && valLeft >= valRight || + op == token.LEQ && valLeft <= valRight || + op == token.EQL && valLeft == valRight || + op == token.NEQ && valLeft != valRight { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalBinaryDecimal2ToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(decimal.Decimal) + if !ok { + return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(decimal.Decimal) + if !ok { + return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + if !isCompareOp(op) { + return false, fmt.Errorf("cannot perform bool op %v against decimal2 %v and decimal2 %v", op, valLeft, valRight) + } + if op == token.GTR && valLeft.Cmp(valRight) > 0 || + op == token.LSS && 
valLeft.Cmp(valRight) < 0 || + op == token.GEQ && valLeft.Cmp(valRight) >= 0 || + op == token.LEQ && valLeft.Cmp(valRight) <= 0 || + op == token.EQL && valLeft.Cmp(valRight) == 0 || + op == token.NEQ && valLeft.Cmp(valRight) != 0 { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalBinaryTimeToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(time.Time) + if !ok { + return false, fmt.Errorf("cannot evaluate binary time expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(time.Time) + if !ok { + return false, fmt.Errorf("cannot evaluate binary time expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + if !isCompareOp(op) { + return false, fmt.Errorf("cannot perform bool op %v against time %v and time %v", op, valLeft, valRight) + } + if op == token.GTR && valLeft.After(valRight) || + op == token.LSS && valLeft.Before(valRight) || + op == token.GEQ && (valLeft.After(valRight) || valLeft.Equal(valRight)) || + op == token.LEQ && (valLeft.Before(valRight) || valLeft.Equal(valRight)) || + op == token.EQL && valLeft.Equal(valRight) || + op == token.NEQ && !valLeft.Equal(valRight) { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalBinaryFloat64(valLeftVolatile any, op token.Token, valRightVolatile any) (float64, error) { + + valLeft, ok := valLeftVolatile.(float64) + if !ok { + return 0.0, fmt.Errorf("cannot evaluate binary float64 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(float64) + if !ok { + return 0.0, fmt.Errorf("cannot evaluate binary float expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + switch op { + case token.ADD: + return valLeft + valRight, nil + case token.SUB: + return valLeft - valRight, nil + case token.MUL: + return valLeft * valRight, nil + case token.QUO: + return valLeft / valRight, nil + default: + return 0, fmt.Errorf("cannot perform float64 op %v against float64 %f and float64 %f", op, valLeft, valRight) + } +} + +func (eCtx *EvalCtx) EvalBinaryDecimal2(valLeftVolatile any, op token.Token, valRightVolatile any) (result decimal.Decimal, err error) { + + result = decimal.NewFromFloat(math.MaxFloat64) + err = nil + valLeft, ok := valLeftVolatile.(decimal.Decimal) + if !ok { + return decimal.NewFromInt(0), fmt.Errorf("cannot evaluate binary decimal2 expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(decimal.Decimal) + if !ok { + return decimal.NewFromInt(0), fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + } + }() + + switch op { + case token.ADD: + return valLeft.Add(valRight).Round(2), nil + case token.SUB: + return valLeft.Sub(valRight).Round(2), nil + case token.MUL: + return valLeft.Mul(valRight).Round(2), nil + case token.QUO: + return valLeft.Div(valRight).Round(2), nil + default: + return decimal.NewFromInt(0), fmt.Errorf("cannot perform decimal2 op %v against decimal2 %v and float64 %v", op, valLeft, valRight) + } +} + +func (eCtx *EvalCtx) EvalBinaryBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, 
error) { + + valLeft, ok := valLeftVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate binary bool expression '%v' with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate binary bool expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + switch op { + case token.LAND: + return valLeft && valRight, nil + case token.LOR: + return valLeft || valRight, nil + default: + return false, fmt.Errorf("cannot perform bool op %v against bool %v and bool %v", op, valLeft, valRight) + } +} + +func (eCtx *EvalCtx) EvalBinaryBoolToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate binary bool expression %v with %T on the left", op, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate binary bool expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + if !(op == token.EQL || op == token.NEQ) { + return false, fmt.Errorf("cannot evaluate binary bool expression, op %v not supported (and will never be)", op) + } + + if op == token.EQL && valLeft == valRight || + op == token.NEQ && valLeft != valRight { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalUnaryBoolNot(exp ast.Expr) (bool, error) { + valVolatile, err := eCtx.Eval(exp) + if err != nil { + return false, err + } + + val, ok := valVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate unary bool not expression with %T on the right", valVolatile) + } + + return !val, nil +} + +func (eCtx *EvalCtx) EvalUnaryMinus(exp ast.Expr) (any, error) { + valVolatile, err := eCtx.Eval(exp) + if err != nil { + return false, err + } + + switch typedVal := valVolatile.(type) { + case int: + return int64(-typedVal), nil + case int16: + return int64(-typedVal), nil + case int32: + return int64(-typedVal), nil + case int64: + return -typedVal, nil + case float32: + return float64(-typedVal), nil + case float64: + return -typedVal, nil + case decimal.Decimal: + return typedVal.Neg(), nil + default: + return false, fmt.Errorf("cannot evaluate unary minus expression '-%v(%T)', unsupported type", valVolatile, valVolatile) + } +} + +func (eCtx *EvalCtx) EvalBinaryString(valLeftVolatile any, op token.Token, valRightVolatile any) (string, error) { + + valLeft, ok := valLeftVolatile.(string) + if !ok { + return "", fmt.Errorf("cannot evaluate binary string expression %v with %T on the left", op, valLeftVolatile) + } + + valRight, ok := valRightVolatile.(string) + if !ok { + return "", fmt.Errorf("cannot evaluate binary string expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + + switch op { + case token.ADD: + return valLeft + valRight, nil + default: + return "", fmt.Errorf("cannot perform string op %v against string '%s' and string '%s', op not supported", op, valLeft, valRight) + } +} + +func (eCtx *EvalCtx) EvalBinaryStringToBool(valLeftVolatile any, op token.Token, valRightVolatile any) (bool, error) { + + valLeft, ok := valLeftVolatile.(string) + if !ok { + return false, fmt.Errorf("cannot evaluate binary string expression %v with '%v(%T)' on the left", op, valLeftVolatile, valLeftVolatile) + } + valLeft = 
strings.Replace(strings.Trim(valLeft, "\""), `\"`, `\`, -1) + + valRight, ok := valRightVolatile.(string) + if !ok { + return false, fmt.Errorf("cannot evaluate binary decimal2 expression '%v(%T) %v %v(%T)', invalid right arg", valLeft, valLeft, op, valRightVolatile, valRightVolatile) + } + valRight = strings.Replace(strings.Trim(valRight, "\""), `\"`, `"`, -1) + + if !isCompareOp(op) { + return false, fmt.Errorf("cannot perform bool op %v against string %v and string %v", op, valLeft, valRight) + } + if op == token.GTR && valLeft > valRight || + op == token.LSS && valLeft < valRight || + op == token.GEQ && valLeft >= valRight || + op == token.LEQ && valLeft <= valRight || + op == token.EQL && valLeft == valRight || + op == token.NEQ && valLeft != valRight { + return true, nil + } + return false, nil +} + +func (eCtx *EvalCtx) EvalFunc(callExp *ast.CallExpr, funcName string, args []any) (any, error) { + var err error + switch funcName { + case "math.Sqrt": + eCtx.Value, err = callMathSqrt(args) + case "math.Round": + eCtx.Value, err = callMathRound(args) + case "len": + eCtx.Value, err = callLen(args) + case "string": + eCtx.Value, err = callString(args) + case "float": + eCtx.Value, err = callFloat(args) + case "int": + eCtx.Value, err = callInt(args) + case "decimal2": + eCtx.Value, err = callDecimal2(args) + case "time.Parse": + eCtx.Value, err = callTimeParse(args) + case "time.Format": + eCtx.Value, err = callTimeFormat(args) + case "time.Date": + eCtx.Value, err = callTimeDate(args) + case "time.Now": + eCtx.Value, err = callTimeNow(args) + case "time.Unix": + eCtx.Value, err = callTimeUnix(args) + case "time.UnixMilli": + eCtx.Value, err = callTimeUnixMilli(args) + case "time.DiffMilli": + eCtx.Value, err = callTimeDiffMilli(args) + case "time.Before": + eCtx.Value, err = callTimeBefore(args) + case "time.After": + eCtx.Value, err = callTimeAfter(args) + case "time.FixedZone": + eCtx.Value, err = callTimeFixedZone(args) + case "re.MatchString": + eCtx.Value, err = callReMatchString(args) + case "strings.ReplaceAll": + eCtx.Value, err = callStringsReplaceAll(args) + case "fmt.Sprintf": + eCtx.Value, err = callFmtSprintf(args) + + // Aggregate functions, to be used only in grouped lookups + + case "string_agg": + eCtx.Value, err = eCtx.CallAggStringAgg(callExp, args) + case "sum": + eCtx.Value, err = eCtx.CallAggSum(callExp, args) + case "count": + eCtx.Value, err = eCtx.CallAggCount(callExp, args) + case "avg": + eCtx.Value, err = eCtx.CallAggAvg(callExp, args) + case "min": + eCtx.Value, err = eCtx.CallAggMin(callExp, args) + case "max": + eCtx.Value, err = eCtx.CallAggMax(callExp, args) + + default: + return nil, fmt.Errorf("cannot evaluate unsupported func '%s'", funcName) + } + return eCtx.Value, err +} + +func (eCtx *EvalCtx) Eval(exp ast.Expr) (any, error) { + switch exp := exp.(type) { + case *ast.BinaryExpr: + valLeftVolatile, err := eCtx.Eval(exp.X) + if err != nil { + return nil, err + } + + valRightVolatile, err := eCtx.Eval(exp.Y) + if err != nil { + return 0, err + } + + if exp.Op == token.ADD || exp.Op == token.SUB || exp.Op == token.MUL || exp.Op == token.QUO || exp.Op == token.REM { + switch valLeftVolatile.(type) { + case string: + eCtx.Value, err = eCtx.EvalBinaryString(valLeftVolatile, exp.Op, valRightVolatile) + return eCtx.Value, err + + default: + // Assume both args are numbers (int, float, dec) + stdArgLeft, stdArgRight, err := castNumberPairToCommonType(valLeftVolatile, valRightVolatile) + if err != nil { + return nil, fmt.Errorf("cannot perform binary 
arithmetic op, incompatible arg types '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) + } + switch stdArgLeft.(type) { + case int64: + eCtx.Value, err = eCtx.EvalBinaryInt(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + case float64: + eCtx.Value, err = eCtx.EvalBinaryFloat64(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + case decimal.Decimal: + eCtx.Value, err = eCtx.EvalBinaryDecimal2(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + default: + return nil, fmt.Errorf("cannot perform binary arithmetic op, unexpected std type '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) + } + } + } else if exp.Op == token.LOR || exp.Op == token.LAND { + switch valLeftTyped := valLeftVolatile.(type) { + case bool: + eCtx.Value, err = eCtx.EvalBinaryBool(valLeftTyped, exp.Op, valRightVolatile) + return eCtx.Value, err + default: + return nil, fmt.Errorf("cannot perform binary op %v against %T left", exp.Op, valLeftVolatile) + } + + } else if exp.Op == token.GTR || exp.Op == token.GEQ || exp.Op == token.LSS || exp.Op == token.LEQ || exp.Op == token.EQL || exp.Op == token.NEQ { + switch valLeftVolatile.(type) { + case time.Time: + eCtx.Value, err = eCtx.EvalBinaryTimeToBool(valLeftVolatile, exp.Op, valRightVolatile) + return eCtx.Value, err + case string: + eCtx.Value, err = eCtx.EvalBinaryStringToBool(valLeftVolatile, exp.Op, valRightVolatile) + return eCtx.Value, err + case bool: + eCtx.Value, err = eCtx.EvalBinaryBoolToBool(valLeftVolatile, exp.Op, valRightVolatile) + return eCtx.Value, err + default: + // Assume both args are numbers (int, float, dec) + stdArgLeft, stdArgRight, err := castNumberPairToCommonType(valLeftVolatile, valRightVolatile) + if err != nil { + return nil, fmt.Errorf("cannot perform binary comp op, incompatible arg types '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) + } + switch stdArgLeft.(type) { + case int64: + eCtx.Value, err = eCtx.EvalBinaryIntToBool(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + case float64: + eCtx.Value, err = eCtx.EvalBinaryFloat64ToBool(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + case decimal.Decimal: + eCtx.Value, err = eCtx.EvalBinaryDecimal2ToBool(stdArgLeft, exp.Op, stdArgRight) + return eCtx.Value, err + default: + return nil, fmt.Errorf("cannot perform binary comp op, unexpected std type '%v(%T)' %v '%v(%T)' ", valLeftVolatile, valLeftVolatile, exp.Op, valRightVolatile, valRightVolatile) + } + } + } else { + return nil, fmt.Errorf("cannot perform binary expression unknown op %v", exp.Op) + } + case *ast.BasicLit: + switch exp.Kind { + case token.INT: + i, _ := strconv.ParseInt(exp.Value, 10, 64) + eCtx.Value = i + return i, nil + case token.FLOAT: + i, _ := strconv.ParseFloat(exp.Value, 64) + eCtx.Value = i + return i, nil + case token.IDENT: + return nil, fmt.Errorf("cannot evaluate expression %s of type token.IDENT", exp.Value) + case token.STRING: + eCtx.Value = exp.Value + if exp.Value[0] == '"' { + return strings.Trim(exp.Value, "\""), nil + } else { + return strings.Trim(exp.Value, "`"), nil + } + default: + return nil, fmt.Errorf("cannot evaluate expression %s of type %v", exp.Value, exp.Kind) + } + case *ast.UnaryExpr: + switch exp.Op { + case token.NOT: + var err error + eCtx.Value, err = eCtx.EvalUnaryBoolNot(exp.X) + return eCtx.Value, err + case token.SUB: + var err error + eCtx.Value, err = 
eCtx.EvalUnaryMinus(exp.X) + return eCtx.Value, err + default: + return nil, fmt.Errorf("cannot evaluate unary op %v, unknown op", exp.Op) + } + + case *ast.Ident: + if exp.Name == "true" { + eCtx.Value = true + return true, nil + } else if exp.Name == "false" { + eCtx.Value = false + return false, nil + } else { + return nil, fmt.Errorf("cannot evaluate identifier %s", exp.Name) + } + + case *ast.CallExpr: + args := make([]any, len(exp.Args)) + + for i, v := range exp.Args { + arg, err := eCtx.Eval(v) + if err != nil { + return nil, err + } + args[i] = arg + } + + switch exp.Fun.(type) { + case *ast.Ident: + funcIdent, _ := exp.Fun.(*ast.Ident) //revive:disable-line + var err error + eCtx.Value, err = eCtx.EvalFunc(exp, funcIdent.Name, args) + return eCtx.Value, err + + case *ast.SelectorExpr: + expSel := exp.Fun.(*ast.SelectorExpr) //revive:disable-line + switch expSel.X.(type) { + case *ast.Ident: + expIdent, _ := expSel.X.(*ast.Ident) //revive:disable-line + var err error + eCtx.Value, err = eCtx.EvalFunc(exp, fmt.Sprintf("%s.%s", expIdent.Name, expSel.Sel.Name), args) + return eCtx.Value, err + default: + return nil, fmt.Errorf("cannot evaluate fun expression %v, unknown type of X: %T", expSel.X, expSel.X) + } + + default: + return nil, fmt.Errorf("cannot evaluate func call expression %v, unknown type of X: %T", exp.Fun, exp.Fun) + } + + case *ast.SelectorExpr: + switch exp.X.(type) { + case *ast.Ident: + objectIdent, _ := exp.X.(*ast.Ident) //revive:disable-line + golangConst, ok := GolangConstants[fmt.Sprintf("%s.%s", objectIdent.Name, exp.Sel.Name)] + if ok { + eCtx.Value = golangConst + return golangConst, nil + } + + if eCtx.Vars == nil { + return nil, fmt.Errorf("cannot evaluate expression '%s', no variables supplied to the context", objectIdent.Name) + } + + objectAttributes, ok := (*eCtx.Vars)[objectIdent.Name] + if !ok { + return nil, fmt.Errorf("cannot evaluate expression '%s', variable not supplied, check table/alias name", objectIdent.Name) + } + + val, ok := objectAttributes[exp.Sel.Name] + if !ok { + return nil, fmt.Errorf("cannot evaluate expression %s.%s, variable not supplied, check field name", objectIdent.Name, exp.Sel.Name) + } + eCtx.Value = val + return val, nil + default: + return nil, fmt.Errorf("cannot evaluate selector expression %v, unknown type of X: %T", exp.X, exp.X) + } + + default: + return nil, fmt.Errorf("cannot evaluate generic expression %v of unknown type %T", exp, exp) + } +} diff --git a/pkg/eval/eval_ctx_test.go b/pkg/eval/eval_ctx_test.go index c9f9423..db90e2c 100644 --- a/pkg/eval/eval_ctx_test.go +++ b/pkg/eval/eval_ctx_test.go @@ -1,426 +1,427 @@ -package eval - -import ( - "fmt" - "go/parser" - "go/token" - "math" - "testing" - "time" - - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -func assertEqual(t *testing.T, expString string, expectedResult interface{}, varValuesMap VarValuesMap) { - exp, err1 := parser.ParseExpr(expString) - if err1 != nil { - t.Error(fmt.Errorf("%s: %s", expString, err1.Error())) - return - } - eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) - result, err2 := eCtx.Eval(exp) - if err2 != nil { - t.Error(fmt.Errorf("%s: %s", expString, err2.Error())) - return - } - - assert.Equal(t, expectedResult, result, fmt.Sprintf("Unmatched: %v = %v: %s ", expectedResult, result, expString)) -} - -func assertFloatNan(t *testing.T, expString string, varValuesMap VarValuesMap) { - exp, err1 := parser.ParseExpr(expString) - if err1 != nil { - t.Error(fmt.Errorf("%s: %s", expString, 
err1.Error())) - return - } - eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) - result, err2 := eCtx.Eval(exp) - if err2 != nil { - t.Error(fmt.Errorf("%s: %s", expString, err2.Error())) - return - } - floatResult, _ := result.(float64) - assert.True(t, math.IsNaN(floatResult)) -} - -func assertEvalError(t *testing.T, expString string, expectedErrorMsg string, varValuesMap VarValuesMap) { - exp, err1 := parser.ParseExpr(expString) - if err1 != nil { - assert.Equal(t, expectedErrorMsg, err1.Error(), fmt.Sprintf("Unmatched: %v = %v: %s ", expectedErrorMsg, err1.Error(), expString)) - return - } - eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) - _, err2 := eCtx.Eval(exp) - - assert.Equal(t, expectedErrorMsg, err2.Error(), fmt.Sprintf("Unmatched: %v = %v: %s ", expectedErrorMsg, err2.Error(), expString)) -} - -func TestBad(t *testing.T) { - // Missing identifier - assertEvalError(t, "some(", "1:6: expected ')', found 'EOF'", VarValuesMap{}) - - // Missing identifier - assertEvalError(t, "someident", "cannot evaluate identifier someident", VarValuesMap{}) - assertEvalError(t, "somefunc()", "cannot evaluate unsupported func 'somefunc'", VarValuesMap{}) - assertEvalError(t, "t2.aaa == 1", "cannot evaluate expression 't2', variable not supplied, check table/alias name", VarValuesMap{}) - - // Unsupported binary operators - assertEvalError(t, "2 ^ 1", "cannot perform binary expression unknown op ^", VarValuesMap{}) // TODO: implement ^ xor - assertEvalError(t, "2 << 1", "cannot perform binary expression unknown op <<", VarValuesMap{}) // TODO: implement >> and << - assertEvalError(t, "1 &^ 2", "cannot perform binary expression unknown op &^", VarValuesMap{}) // No plans to support this op - - // Unsupported unary operators - assertEvalError(t, "&1", "cannot evaluate unary op &, unkown op", VarValuesMap{}) - - // Unsupported selector expr - assertEvalError(t, "t1.fieldInt.w", "cannot evaluate selector expression &{t1 fieldInt}, unknown type of X: *ast.SelectorExpr", VarValuesMap{"t1": {"fieldInt": 1}}) -} - -func TestConvertEval(t *testing.T) { - varValuesMap := VarValuesMap{ - "t1": { - "fieldInt": 1, - "fieldInt16": int16(1), - "fieldInt32": int32(1), - "fieldInt64": int16(1), - "fieldFloat32": float32(1.0), - "fieldFloat64": float64(1.0), - "fieldDecimal2": decimal.NewFromInt(1), - }, - } - - // Number to number - for k, _ := range varValuesMap["t1"] { - assertEqual(t, fmt.Sprintf("decimal2(t1.%s) == 1", k), true, varValuesMap) - assertEqual(t, fmt.Sprintf("float(t1.%s) == 1.0", k), true, varValuesMap) - assertEqual(t, fmt.Sprintf("int(t1.%s) == 1", k), true, varValuesMap) - } - - // String to number - assertEqual(t, `int("1") == 1`, true, varValuesMap) - assertEqual(t, `float("1.0") == 1.0`, true, varValuesMap) - assertEqual(t, `decimal2("1.0") == 1.0`, true, varValuesMap) - - // Number to string - assertEqual(t, `string(1) == "1"`, true, varValuesMap) - assertEqual(t, `string(1.1) == "1.1"`, true, varValuesMap) - assertEqual(t, `string(decimal2(1.1)) == "1.1"`, true, varValuesMap) -} - -func TestArithmetic(t *testing.T) { - varValuesMap := VarValuesMap{ - "t1": { - "fieldInt": 1, - "fieldInt16": int16(1), - "fieldInt32": int32(1), - "fieldInt64": int16(1), - "fieldFloat32": float32(1.0), - "fieldFloat64": float64(1.0), - "fieldDecimal2": decimal.NewFromInt(1), - }, - "t2": { - "fieldInt": 2, - "fieldInt16": int16(2), - "fieldInt32": int32(2), - "fieldInt64": int16(2), - "fieldFloat32": float32(2.0), - "fieldFloat64": float64(2.0), - "fieldDecimal2": 
decimal.NewFromInt(2), - }, - } - for k1, _ := range varValuesMap["t1"] { - for k2, _ := range varValuesMap["t2"] { - assertEqual(t, fmt.Sprintf("t1.%s + t2.%s == 3", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t1.%s - t2.%s == -1", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t1.%s * t2.%s == 2", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t2.%s / t1.%s == 2", k1, k2), true, varValuesMap) - } - } - - // Integer div - assertEqual(t, "t1.fieldInt / t2.fieldInt == 0", true, varValuesMap) - assertEqual(t, "t1.fieldInt % t2.fieldInt == 1", true, varValuesMap) - - // Float div - assertEqual(t, "t1.fieldInt / t2.fieldFloat32 == 0.5", true, varValuesMap) - assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) - assertEqual(t, "t1.fieldInt / t2.fieldFloat32 == 0.5", true, varValuesMap) - assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) - assertEqual(t, "t1.fieldDecimal2 / t2.fieldInt == 0.5", true, varValuesMap) - assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) - - // Div by zero - assertEvalError(t, "t1.fieldInt / 0", "runtime error: integer divide by zero", varValuesMap) - assertEqual(t, "t1.fieldFloat32 / 0", math.Inf(1), varValuesMap) - assertEvalError(t, "t1.fieldDecimal2 / 0", "decimal division by 0", varValuesMap) - - // Bad types - assertEvalError(t, "t1.fieldDecimal2 / `a`", "cannot perform binary arithmetic op, incompatible arg types '1(decimal.Decimal)' / 'a(string)' ", varValuesMap) - assertEvalError(t, "-`a`", "cannot evaluate unary minus expression '-a(string)', unsupported type", varValuesMap) - - // String - varValuesMap = VarValuesMap{ - "t1": { - "field1": "aaa", - "field2": `c"cc`, - }, - } - assertEqual(t, `t1.field1+t1.field2+"d"`, `aaac"ccd`, varValuesMap) - -} - -func TestCompare(t *testing.T) { - varValuesMap := VarValuesMap{ - "t1": { - "fieldInt": 1, - "fieldInt16": int16(1), - "fieldInt32": int32(1), - "fieldInt64": int16(1), - "fieldFloat32": float32(1.0), - "fieldFloat64": float64(1.0), - "fieldDecimal2": decimal.NewFromInt(1), - }, - "t2": { - "fieldInt": 2, - "fieldInt16": int16(2), - "fieldInt32": int32(2), - "fieldInt64": int16(2), - "fieldFloat32": float32(2.0), - "fieldFloat64": float64(2.0), - "fieldDecimal2": decimal.NewFromInt(2), - }, - } - for k1, _ := range varValuesMap["t1"] { - for k2, _ := range varValuesMap["t2"] { - assertEqual(t, fmt.Sprintf("t1.%s == t2.%s", k1, k2), false, varValuesMap) - assertEqual(t, fmt.Sprintf("t1.%s != t2.%s", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t1.%s < t2.%s", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t1.%s <= t2.%s", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t2.%s > t1.%s", k1, k2), true, varValuesMap) - assertEqual(t, fmt.Sprintf("t2.%s >= t1.%s", k1, k2), true, varValuesMap) - } - } - - // Bool - assertEqual(t, "false == false", true, varValuesMap) - assertEqual(t, "false == !true", true, varValuesMap) - assertEqual(t, "false == true", false, varValuesMap) - - // String - assertEqual(t, `"aaa" != "b"`, true, varValuesMap) - assertEqual(t, `"aaa" == "b"`, false, varValuesMap) - assertEqual(t, `"aaa" < "b"`, true, varValuesMap) - assertEqual(t, `"aaa" <= "b"`, true, varValuesMap) - assertEqual(t, `"aaa" > "b"`, false, varValuesMap) - assertEqual(t, `"aaa" >= "b"`, false, varValuesMap) - assertEqual(t, `"aaa" == "aaa"`, true, varValuesMap) - assertEqual(t, `"aaa" != "aaa"`, false, varValuesMap) - - assertEvalError(t, "1 > true", "cannot perform binary 
comp op, incompatible arg types '1(int64)' > 'true(bool)' ", varValuesMap) -} - -func TestBool(t *testing.T) { - varValuesMap := VarValuesMap{} - assertEqual(t, `true && true`, true, varValuesMap) - assertEqual(t, `true && false`, false, varValuesMap) - assertEqual(t, `true || true`, true, varValuesMap) - assertEqual(t, `true || false`, true, varValuesMap) - assertEqual(t, `!false`, true, varValuesMap) - assertEqual(t, `!true`, false, varValuesMap) - - assertEvalError(t, `!123`, "cannot evaluate unary bool not expression with int64 on the right", varValuesMap) - assertEvalError(t, "true || 1", "cannot evaluate binary bool expression 'true(bool) || 1(int64)', invalid right arg", varValuesMap) - assertEvalError(t, "1 || true", "cannot perform binary op || against int64 left", varValuesMap) -} - -func TestUnaryMinus(t *testing.T) { - varValuesMap := VarValuesMap{ - "t1": { - "fieldInt": 1, - "fieldInt16": int16(1), - "fieldInt32": int32(1), - "fieldInt64": int16(1), - "fieldFloat32": float32(1.0), - "fieldFloat64": float64(1.0), - "fieldDecimal2": decimal.NewFromInt(1), - }, - } - for k, _ := range varValuesMap["t1"] { - assertEqual(t, fmt.Sprintf("-t1.%s == -1", k), true, varValuesMap) - } -} - -func TestTime(t *testing.T) { - varValuesMap := VarValuesMap{ - "t1": { - "fTime": time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC), - }, - "t2": { - "fTime": time.Date(1, 1, 1, 1, 1, 1, 2, time.UTC), - }, - } - assertEqual(t, `t1.fTime < t2.fTime`, true, varValuesMap) - assertEqual(t, `t1.fTime <= t2.fTime`, true, varValuesMap) - assertEqual(t, `t2.fTime >= t1.fTime`, true, varValuesMap) - assertEqual(t, `t2.fTime > t1.fTime`, true, varValuesMap) - assertEqual(t, `t1.fTime == t2.fTime`, false, varValuesMap) - assertEqual(t, `t1.fTime != t2.fTime`, true, varValuesMap) -} - -func TestNewPlainEvalCtxAndInitializedAgg(t *testing.T) { - varValuesMap := getTestValuesMap() - varValuesMap["t1"]["fieldStr"] = "a" - - exp, _ := parser.ParseExpr(`string_agg(t1.fieldStr,",")`) - aggEnabledType, aggFuncType, aggFuncArgs := DetectRootAggFunc(exp) - eCtx, err := NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) - assert.Equal(t, AggTypeString, eCtx.AggType) - assert.Nil(t, err) - - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,1)`) - aggEnabledType, aggFuncType, aggFuncArgs = DetectRootAggFunc(exp) - eCtx, err = NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) - assert.Equal(t, "string_agg second parameter must be a constant string", err.Error()) - - exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, a)`) - aggEnabledType, aggFuncType, aggFuncArgs = DetectRootAggFunc(exp) - eCtx, err = NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) - assert.Equal(t, "string_agg second parameter must be a basic literal", err.Error()) -} - -type EvalFunc int - -const ( - BinaryIntFunc = iota - BinaryIntToBoolFunc - BinaryFloat64Func - BinaryFloat64ToBoolFunc - BinaryDecimal2Func - BinaryDecimal2ToBoolFunc - BinaryTimeToBoolFunc - BinaryBoolFunc - BinaryBoolToBoolFunc - BinaryStringFunc - BinaryStringToBoolFunc -) - -func assertBinaryEval(t *testing.T, evalFunc EvalFunc, valLeftVolatile interface{}, op token.Token, valRightVolatile interface{}, errorMessage string) { - var err error - eCtx := NewPlainEvalCtx(AggFuncDisabled) - switch evalFunc { - case BinaryIntFunc: - _, err = eCtx.EvalBinaryInt(valLeftVolatile, op, valRightVolatile) - case BinaryIntToBoolFunc: - _, err = eCtx.EvalBinaryIntToBool(valLeftVolatile, op, valRightVolatile) - case BinaryFloat64Func: 
- _, err = eCtx.EvalBinaryFloat64(valLeftVolatile, op, valRightVolatile) - case BinaryFloat64ToBoolFunc: - _, err = eCtx.EvalBinaryFloat64ToBool(valLeftVolatile, op, valRightVolatile) - case BinaryDecimal2Func: - _, err = eCtx.EvalBinaryDecimal2(valLeftVolatile, op, valRightVolatile) - case BinaryDecimal2ToBoolFunc: - _, err = eCtx.EvalBinaryDecimal2ToBool(valLeftVolatile, op, valRightVolatile) - case BinaryTimeToBoolFunc: - _, err = eCtx.EvalBinaryTimeToBool(valLeftVolatile, op, valRightVolatile) - case BinaryBoolFunc: - _, err = eCtx.EvalBinaryBool(valLeftVolatile, op, valRightVolatile) - case BinaryBoolToBoolFunc: - _, err = eCtx.EvalBinaryBoolToBool(valLeftVolatile, op, valRightVolatile) - case BinaryStringFunc: - _, err = eCtx.EvalBinaryString(valLeftVolatile, op, valRightVolatile) - case BinaryStringToBoolFunc: - _, err = eCtx.EvalBinaryStringToBool(valLeftVolatile, op, valRightVolatile) - default: - assert.Fail(t, "unsupported EvalFunc") - } - assert.Equal(t, errorMessage, err.Error()) -} - -func TestBadEvalBinaryInt(t *testing.T) { - goodVal := int64(1) - badVal := "a" - assertBinaryEval(t, BinaryIntFunc, badVal, token.ADD, goodVal, "cannot evaluate binary int64 expression '+' with 'a(string)' on the left") - assertBinaryEval(t, BinaryIntFunc, goodVal, token.ADD, badVal, "cannot evaluate binary int64 expression '1(int64) + a(string)', invalid right arg") - assertBinaryEval(t, BinaryIntFunc, goodVal, token.AND, goodVal, "cannot perform int op & against int 1 and int 1") -} - -func TestBadEvalBinaryIntToBool(t *testing.T) { - goodVal := int64(1) - badVal := "a" - assertBinaryEval(t, BinaryIntToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary int64 expression '<' with 'a(string)' on the left") - assertBinaryEval(t, BinaryIntToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary int64 expression '1(int64) < a(string)', invalid right arg") - assertBinaryEval(t, BinaryIntToBoolFunc, goodVal, token.ADD, int64(1), "cannot perform bool op + against int 1 and int 1") -} - -func TestBadEvalBinaryFloat64(t *testing.T) { - goodVal := float64(1) - badVal := "a" - assertBinaryEval(t, BinaryFloat64Func, badVal, token.ADD, goodVal, "cannot evaluate binary float64 expression '+' with 'a(string)' on the left") - assertBinaryEval(t, BinaryFloat64Func, goodVal, token.ADD, badVal, "cannot evaluate binary float expression '1(float64) + a(string)', invalid right arg") - assertBinaryEval(t, BinaryFloat64Func, goodVal, token.AND, goodVal, "cannot perform float64 op & against float64 1.000000 and float64 1.000000") -} - -func TestBadEvalBinaryFloat64ToBool(t *testing.T) { - goodVal := float64(1) - badVal := "a" - assertBinaryEval(t, BinaryFloat64ToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary foat64 expression '<' with 'a(string)' on the left") - assertBinaryEval(t, BinaryFloat64ToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary float64 expression '1(float64) < a(string)', invalid right arg") - assertBinaryEval(t, BinaryFloat64ToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against float 1.000000 and float 1.000000") -} - -func TestBadEvalBinaryDecimal2(t *testing.T) { - goodVal := decimal.NewFromFloat(1) - badVal := "a" - assertBinaryEval(t, BinaryDecimal2Func, badVal, token.ADD, goodVal, "cannot evaluate binary decimal2 expression '+' with 'a(string)' on the left") - assertBinaryEval(t, BinaryDecimal2Func, goodVal, token.ADD, badVal, "cannot evaluate binary decimal2 expression '1(decimal.Decimal) + a(string)', invalid right arg") - 
assertBinaryEval(t, BinaryDecimal2Func, goodVal, token.AND, goodVal, "cannot perform decimal2 op & against decimal2 1 and float64 1") -} - -func TestBadEvalBinaryDecimal2Bool(t *testing.T) { - goodVal := decimal.NewFromFloat(1) - badVal := "a" - assertBinaryEval(t, BinaryDecimal2ToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary decimal2 expression '<' with 'a(string)' on the left") - assertBinaryEval(t, BinaryDecimal2ToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary decimal2 expression '1(decimal.Decimal) < a(string)', invalid right arg") - assertBinaryEval(t, BinaryDecimal2ToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against decimal2 1 and decimal2 1") -} - -func TestBadEvalBinaryTimeBool(t *testing.T) { - goodVal := time.Date(2000, 1, 1, 0, 0, 0, 0, time.FixedZone("", -7200)) - badVal := "a" - assertBinaryEval(t, BinaryTimeToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary time expression '<' with 'a(string)' on the left") - assertBinaryEval(t, BinaryTimeToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary time expression '2000-01-01 00:00:00 -0200 -0200(time.Time) < a(string)', invalid right arg") - assertBinaryEval(t, BinaryTimeToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against time 2000-01-01 00:00:00 -0200 -0200 and time 2000-01-01 00:00:00 -0200 -0200") -} - -func TestBadEvalBinaryBool(t *testing.T) { - goodVal := true - badVal := "a" - assertBinaryEval(t, BinaryBoolFunc, badVal, token.LAND, goodVal, "cannot evaluate binary bool expression '&&' with 'a(string)' on the left") - assertBinaryEval(t, BinaryBoolFunc, goodVal, token.LOR, badVal, "cannot evaluate binary bool expression 'true(bool) || a(string)', invalid right arg") - assertBinaryEval(t, BinaryBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against bool true and bool true") -} - -func TestBadEvalBinaryBoolToBool(t *testing.T) { - goodVal := true - badVal := "a" - assertBinaryEval(t, BinaryBoolToBoolFunc, badVal, token.EQL, goodVal, "cannot evaluate binary bool expression == with string on the left") - assertBinaryEval(t, BinaryBoolToBoolFunc, goodVal, token.NEQ, badVal, "cannot evaluate binary bool expression 'true(bool) != a(string)', invalid right arg") - assertBinaryEval(t, BinaryBoolToBoolFunc, goodVal, token.ADD, goodVal, "cannot evaluate binary bool expression, op + not supported (and will never be)") -} - -func TestBadEvalBinaryString(t *testing.T) { - goodVal := "good" - badVal := 1 - assertBinaryEval(t, BinaryStringFunc, badVal, token.ADD, goodVal, "cannot evaluate binary string expression + with int on the left") - assertBinaryEval(t, BinaryStringFunc, goodVal, token.ADD, badVal, "cannot evaluate binary string expression 'good(string) + 1(int)', invalid right arg") - assertBinaryEval(t, BinaryStringFunc, goodVal, token.AND, goodVal, "cannot perform string op & against string 'good' and string 'good', op not supported") -} - -func TestBadEvalBinaryStringToBool(t *testing.T) { - goodVal := "good" - badVal := 1 - assertBinaryEval(t, BinaryStringToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary string expression < with '1(int)' on the left") - assertBinaryEval(t, BinaryStringToBoolFunc, goodVal, token.GTR, badVal, "cannot evaluate binary decimal2 expression 'good(string) > 1(int)', invalid right arg") - assertBinaryEval(t, BinaryStringToBoolFunc, goodVal, token.AND, goodVal, "cannot perform bool op & against string good and string good") -} +package eval + +import ( + "fmt" + "go/parser" + 
"go/token" + "math" + "testing" + "time" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +func assertEqual(t *testing.T, expString string, expectedResult any, varValuesMap VarValuesMap) { + exp, err1 := parser.ParseExpr(expString) + if err1 != nil { + t.Error(fmt.Errorf("%s: %s", expString, err1.Error())) + return + } + eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) + result, err2 := eCtx.Eval(exp) + if err2 != nil { + t.Error(fmt.Errorf("%s: %s", expString, err2.Error())) + return + } + + assert.Equal(t, expectedResult, result, fmt.Sprintf("Unmatched: %v = %v: %s ", expectedResult, result, expString)) +} + +func assertFloatNan(t *testing.T, expString string, varValuesMap VarValuesMap) { + exp, err1 := parser.ParseExpr(expString) + if err1 != nil { + t.Error(fmt.Errorf("%s: %s", expString, err1.Error())) + return + } + eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) + result, err2 := eCtx.Eval(exp) + if err2 != nil { + t.Error(fmt.Errorf("%s: %s", expString, err2.Error())) + return + } + floatResult, ok := result.(float64) + assert.True(t, ok) + assert.True(t, math.IsNaN(floatResult)) +} + +func assertEvalError(t *testing.T, expString string, expectedErrorMsg string, varValuesMap VarValuesMap) { + exp, err1 := parser.ParseExpr(expString) + if err1 != nil { + assert.Equal(t, expectedErrorMsg, err1.Error(), fmt.Sprintf("Unmatched: %v = %v: %s ", expectedErrorMsg, err1.Error(), expString)) + return + } + eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &varValuesMap) + _, err2 := eCtx.Eval(exp) + + assert.Equal(t, expectedErrorMsg, err2.Error(), fmt.Sprintf("Unmatched: %v = %v: %s ", expectedErrorMsg, err2.Error(), expString)) +} + +func TestBad(t *testing.T) { + // Missing identifier + assertEvalError(t, "some(", "1:6: expected ')', found 'EOF'", VarValuesMap{}) + + // Missing identifier + assertEvalError(t, "someident", "cannot evaluate identifier someident", VarValuesMap{}) + assertEvalError(t, "somefunc()", "cannot evaluate unsupported func 'somefunc'", VarValuesMap{}) + assertEvalError(t, "t2.aaa == 1", "cannot evaluate expression 't2', variable not supplied, check table/alias name", VarValuesMap{}) + + // Unsupported binary operators + assertEvalError(t, "2 ^ 1", "cannot perform binary expression unknown op ^", VarValuesMap{}) // TODO: implement ^ xor + assertEvalError(t, "2 << 1", "cannot perform binary expression unknown op <<", VarValuesMap{}) // TODO: implement >> and << + assertEvalError(t, "1 &^ 2", "cannot perform binary expression unknown op &^", VarValuesMap{}) // No plans to support this op + + // Unsupported unary operators + assertEvalError(t, "&1", "cannot evaluate unary op &, unknown op", VarValuesMap{}) + + // Unsupported selector expr + assertEvalError(t, "t1.fieldInt.w", "cannot evaluate selector expression &{t1 fieldInt}, unknown type of X: *ast.SelectorExpr", VarValuesMap{"t1": {"fieldInt": 1}}) +} + +func TestConvertEval(t *testing.T) { + varValuesMap := VarValuesMap{ + "t1": { + "fieldInt": 1, + "fieldInt16": int16(1), + "fieldInt32": int32(1), + "fieldInt64": int16(1), + "fieldFloat32": float32(1.0), + "fieldFloat64": float64(1.0), + "fieldDecimal2": decimal.NewFromInt(1), + }, + } + + // Number to number + for fldName := range varValuesMap["t1"] { + assertEqual(t, fmt.Sprintf("decimal2(t1.%s) == 1", fldName), true, varValuesMap) + assertEqual(t, fmt.Sprintf("float(t1.%s) == 1.0", fldName), true, varValuesMap) + assertEqual(t, fmt.Sprintf("int(t1.%s) == 1", fldName), true, varValuesMap) + } + + // 
String to number + assertEqual(t, `int("1") == 1`, true, varValuesMap) + assertEqual(t, `float("1.0") == 1.0`, true, varValuesMap) + assertEqual(t, `decimal2("1.0") == 1.0`, true, varValuesMap) + + // Number to string + assertEqual(t, `string(1) == "1"`, true, varValuesMap) + assertEqual(t, `string(1.1) == "1.1"`, true, varValuesMap) + assertEqual(t, `string(decimal2(1.1)) == "1.1"`, true, varValuesMap) +} + +func TestArithmetic(t *testing.T) { + varValuesMap := VarValuesMap{ + "t1": { + "fieldInt": 1, + "fieldInt16": int16(1), + "fieldInt32": int32(1), + "fieldInt64": int16(1), + "fieldFloat32": float32(1.0), + "fieldFloat64": float64(1.0), + "fieldDecimal2": decimal.NewFromInt(1), + }, + "t2": { + "fieldInt": 2, + "fieldInt16": int16(2), + "fieldInt32": int32(2), + "fieldInt64": int16(2), + "fieldFloat32": float32(2.0), + "fieldFloat64": float64(2.0), + "fieldDecimal2": decimal.NewFromInt(2), + }, + } + for k1 := range varValuesMap["t1"] { + for k2 := range varValuesMap["t2"] { + assertEqual(t, fmt.Sprintf("t1.%s + t2.%s == 3", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t1.%s - t2.%s == -1", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t1.%s * t2.%s == 2", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t2.%s / t1.%s == 2", k1, k2), true, varValuesMap) + } + } + + // Integer div + assertEqual(t, "t1.fieldInt / t2.fieldInt == 0", true, varValuesMap) + assertEqual(t, "t1.fieldInt % t2.fieldInt == 1", true, varValuesMap) + + // Float div + assertEqual(t, "t1.fieldInt / t2.fieldFloat32 == 0.5", true, varValuesMap) + assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) + assertEqual(t, "t1.fieldInt / t2.fieldFloat32 == 0.5", true, varValuesMap) + assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) + assertEqual(t, "t1.fieldDecimal2 / t2.fieldInt == 0.5", true, varValuesMap) + assertEqual(t, "t1.fieldInt / t2.fieldDecimal2 == 0.5", true, varValuesMap) + + // Div by zero + assertEvalError(t, "t1.fieldInt / 0", "runtime error: integer divide by zero", varValuesMap) + assertEqual(t, "t1.fieldFloat32 / 0", math.Inf(1), varValuesMap) + assertEvalError(t, "t1.fieldDecimal2 / 0", "decimal division by 0", varValuesMap) + + // Bad types + assertEvalError(t, "t1.fieldDecimal2 / `a`", "cannot perform binary arithmetic op, incompatible arg types '1(decimal.Decimal)' / 'a(string)' ", varValuesMap) + assertEvalError(t, "-`a`", "cannot evaluate unary minus expression '-a(string)', unsupported type", varValuesMap) + + // String + varValuesMap = VarValuesMap{ + "t1": { + "field1": "aaa", + "field2": `c"cc`, + }, + } + assertEqual(t, `t1.field1+t1.field2+"d"`, `aaac"ccd`, varValuesMap) + +} + +func TestCompare(t *testing.T) { + varValuesMap := VarValuesMap{ + "t1": { + "fieldInt": 1, + "fieldInt16": int16(1), + "fieldInt32": int32(1), + "fieldInt64": int16(1), + "fieldFloat32": float32(1.0), + "fieldFloat64": float64(1.0), + "fieldDecimal2": decimal.NewFromInt(1), + }, + "t2": { + "fieldInt": 2, + "fieldInt16": int16(2), + "fieldInt32": int32(2), + "fieldInt64": int16(2), + "fieldFloat32": float32(2.0), + "fieldFloat64": float64(2.0), + "fieldDecimal2": decimal.NewFromInt(2), + }, + } + for k1 := range varValuesMap["t1"] { + for k2 := range varValuesMap["t2"] { + assertEqual(t, fmt.Sprintf("t1.%s == t2.%s", k1, k2), false, varValuesMap) + assertEqual(t, fmt.Sprintf("t1.%s != t2.%s", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t1.%s < t2.%s", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t1.%s <= 
t2.%s", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t2.%s > t1.%s", k1, k2), true, varValuesMap) + assertEqual(t, fmt.Sprintf("t2.%s >= t1.%s", k1, k2), true, varValuesMap) + } + } + + // Bool + assertEqual(t, "false == false", true, varValuesMap) + assertEqual(t, "false == !true", true, varValuesMap) + assertEqual(t, "false == true", false, varValuesMap) + + // String + assertEqual(t, `"aaa" != "b"`, true, varValuesMap) + assertEqual(t, `"aaa" == "b"`, false, varValuesMap) + assertEqual(t, `"aaa" < "b"`, true, varValuesMap) + assertEqual(t, `"aaa" <= "b"`, true, varValuesMap) + assertEqual(t, `"aaa" > "b"`, false, varValuesMap) + assertEqual(t, `"aaa" >= "b"`, false, varValuesMap) + assertEqual(t, `"aaa" == "aaa"`, true, varValuesMap) + assertEqual(t, `"aaa" != "aaa"`, false, varValuesMap) + + assertEvalError(t, "1 > true", "cannot perform binary comp op, incompatible arg types '1(int64)' > 'true(bool)' ", varValuesMap) +} + +func TestBool(t *testing.T) { + varValuesMap := VarValuesMap{} + assertEqual(t, `true && true`, true, varValuesMap) + assertEqual(t, `true && false`, false, varValuesMap) + assertEqual(t, `true || true`, true, varValuesMap) + assertEqual(t, `true || false`, true, varValuesMap) + assertEqual(t, `!false`, true, varValuesMap) + assertEqual(t, `!true`, false, varValuesMap) + + assertEvalError(t, `!123`, "cannot evaluate unary bool not expression with int64 on the right", varValuesMap) + assertEvalError(t, "true || 1", "cannot evaluate binary bool expression 'true(bool) || 1(int64)', invalid right arg", varValuesMap) + assertEvalError(t, "1 || true", "cannot perform binary op || against int64 left", varValuesMap) +} + +func TestUnaryMinus(t *testing.T) { + varValuesMap := VarValuesMap{ + "t1": { + "fieldInt": 1, + "fieldInt16": int16(1), + "fieldInt32": int32(1), + "fieldInt64": int16(1), + "fieldFloat32": float32(1.0), + "fieldFloat64": float64(1.0), + "fieldDecimal2": decimal.NewFromInt(1), + }, + } + for k := range varValuesMap["t1"] { + assertEqual(t, fmt.Sprintf("-t1.%s == -1", k), true, varValuesMap) + } +} + +func TestTime(t *testing.T) { + varValuesMap := VarValuesMap{ + "t1": { + "fTime": time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC), + }, + "t2": { + "fTime": time.Date(1, 1, 1, 1, 1, 1, 2, time.UTC), + }, + } + assertEqual(t, `t1.fTime < t2.fTime`, true, varValuesMap) + assertEqual(t, `t1.fTime <= t2.fTime`, true, varValuesMap) + assertEqual(t, `t2.fTime >= t1.fTime`, true, varValuesMap) + assertEqual(t, `t2.fTime > t1.fTime`, true, varValuesMap) + assertEqual(t, `t1.fTime == t2.fTime`, false, varValuesMap) + assertEqual(t, `t1.fTime != t2.fTime`, true, varValuesMap) +} + +func TestNewPlainEvalCtxAndInitializedAgg(t *testing.T) { + varValuesMap := getTestValuesMap() + varValuesMap["t1"]["fieldStr"] = "a" + + exp, _ := parser.ParseExpr(`string_agg(t1.fieldStr,",")`) + aggEnabledType, aggFuncType, aggFuncArgs := DetectRootAggFunc(exp) + eCtx, err := NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) + assert.Equal(t, AggTypeString, eCtx.AggType) + assert.Nil(t, err) + + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr,1)`) + aggEnabledType, aggFuncType, aggFuncArgs = DetectRootAggFunc(exp) + _, err = NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) + assert.Equal(t, "string_agg second parameter must be a constant string", err.Error()) + + exp, _ = parser.ParseExpr(`string_agg(t1.fieldStr, a)`) + aggEnabledType, aggFuncType, aggFuncArgs = DetectRootAggFunc(exp) + _, err = 
NewPlainEvalCtxAndInitializedAgg(aggEnabledType, aggFuncType, aggFuncArgs) + assert.Equal(t, "string_agg second parameter must be a basic literal", err.Error()) +} + +type EvalFunc int + +const ( + BinaryIntFunc = iota + BinaryIntToBoolFunc + BinaryFloat64Func + BinaryFloat64ToBoolFunc + BinaryDecimal2Func + BinaryDecimal2ToBoolFunc + BinaryTimeToBoolFunc + BinaryBoolFunc + BinaryBoolToBoolFunc + BinaryStringFunc + BinaryStringToBoolFunc +) + +func assertBinaryEval(t *testing.T, evalFunc EvalFunc, valLeftVolatile any, op token.Token, valRightVolatile any, errorMessage string) { + var err error + eCtx := NewPlainEvalCtx(AggFuncDisabled) + switch evalFunc { + case BinaryIntFunc: + _, err = eCtx.EvalBinaryInt(valLeftVolatile, op, valRightVolatile) + case BinaryIntToBoolFunc: + _, err = eCtx.EvalBinaryIntToBool(valLeftVolatile, op, valRightVolatile) + case BinaryFloat64Func: + _, err = eCtx.EvalBinaryFloat64(valLeftVolatile, op, valRightVolatile) + case BinaryFloat64ToBoolFunc: + _, err = eCtx.EvalBinaryFloat64ToBool(valLeftVolatile, op, valRightVolatile) + case BinaryDecimal2Func: + _, err = eCtx.EvalBinaryDecimal2(valLeftVolatile, op, valRightVolatile) + case BinaryDecimal2ToBoolFunc: + _, err = eCtx.EvalBinaryDecimal2ToBool(valLeftVolatile, op, valRightVolatile) + case BinaryTimeToBoolFunc: + _, err = eCtx.EvalBinaryTimeToBool(valLeftVolatile, op, valRightVolatile) + case BinaryBoolFunc: + _, err = eCtx.EvalBinaryBool(valLeftVolatile, op, valRightVolatile) + case BinaryBoolToBoolFunc: + _, err = eCtx.EvalBinaryBoolToBool(valLeftVolatile, op, valRightVolatile) + case BinaryStringFunc: + _, err = eCtx.EvalBinaryString(valLeftVolatile, op, valRightVolatile) + case BinaryStringToBoolFunc: + _, err = eCtx.EvalBinaryStringToBool(valLeftVolatile, op, valRightVolatile) + default: + assert.Fail(t, "unsupported EvalFunc") + } + assert.Equal(t, errorMessage, err.Error()) +} + +func TestBadEvalBinaryInt(t *testing.T) { + goodVal := int64(1) + badVal := "a" + assertBinaryEval(t, BinaryIntFunc, badVal, token.ADD, goodVal, "cannot evaluate binary int64 expression '+' with 'a(string)' on the left") + assertBinaryEval(t, BinaryIntFunc, goodVal, token.ADD, badVal, "cannot evaluate binary int64 expression '1(int64) + a(string)', invalid right arg") + assertBinaryEval(t, BinaryIntFunc, goodVal, token.AND, goodVal, "cannot perform int op & against int 1 and int 1") +} + +func TestBadEvalBinaryIntToBool(t *testing.T) { + goodVal := int64(1) + badVal := "a" + assertBinaryEval(t, BinaryIntToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary int64 expression '<' with 'a(string)' on the left") + assertBinaryEval(t, BinaryIntToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary int64 expression '1(int64) < a(string)', invalid right arg") + assertBinaryEval(t, BinaryIntToBoolFunc, goodVal, token.ADD, int64(1), "cannot perform bool op + against int 1 and int 1") +} + +func TestBadEvalBinaryFloat64(t *testing.T) { + goodVal := float64(1) + badVal := "a" + assertBinaryEval(t, BinaryFloat64Func, badVal, token.ADD, goodVal, "cannot evaluate binary float64 expression '+' with 'a(string)' on the left") + assertBinaryEval(t, BinaryFloat64Func, goodVal, token.ADD, badVal, "cannot evaluate binary float expression '1(float64) + a(string)', invalid right arg") + assertBinaryEval(t, BinaryFloat64Func, goodVal, token.AND, goodVal, "cannot perform float64 op & against float64 1.000000 and float64 1.000000") +} + +func TestBadEvalBinaryFloat64ToBool(t *testing.T) { + goodVal := float64(1) + badVal := "a" + 
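// Exercise all three failure paths: bad left arg type, bad right arg type, and a non-comparison operator. +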
assertBinaryEval(t, BinaryFloat64ToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary foat64 expression '<' with 'a(string)' on the left") + assertBinaryEval(t, BinaryFloat64ToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary float64 expression '1(float64) < a(string)', invalid right arg") + assertBinaryEval(t, BinaryFloat64ToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against float 1.000000 and float 1.000000") +} + +func TestBadEvalBinaryDecimal2(t *testing.T) { + goodVal := decimal.NewFromFloat(1) + badVal := "a" + assertBinaryEval(t, BinaryDecimal2Func, badVal, token.ADD, goodVal, "cannot evaluate binary decimal2 expression '+' with 'a(string)' on the left") + assertBinaryEval(t, BinaryDecimal2Func, goodVal, token.ADD, badVal, "cannot evaluate binary decimal2 expression '1(decimal.Decimal) + a(string)', invalid right arg") + assertBinaryEval(t, BinaryDecimal2Func, goodVal, token.AND, goodVal, "cannot perform decimal2 op & against decimal2 1 and float64 1") +} + +func TestBadEvalBinaryDecimal2Bool(t *testing.T) { + goodVal := decimal.NewFromFloat(1) + badVal := "a" + assertBinaryEval(t, BinaryDecimal2ToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary decimal2 expression '<' with 'a(string)' on the left") + assertBinaryEval(t, BinaryDecimal2ToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary decimal2 expression '1(decimal.Decimal) < a(string)', invalid right arg") + assertBinaryEval(t, BinaryDecimal2ToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against decimal2 1 and decimal2 1") +} + +func TestBadEvalBinaryTimeBool(t *testing.T) { + goodVal := time.Date(2000, 1, 1, 0, 0, 0, 0, time.FixedZone("", -7200)) + badVal := "a" + assertBinaryEval(t, BinaryTimeToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary time expression '<' with 'a(string)' on the left") + assertBinaryEval(t, BinaryTimeToBoolFunc, goodVal, token.LSS, badVal, "cannot evaluate binary time expression '2000-01-01 00:00:00 -0200 -0200(time.Time) < a(string)', invalid right arg") + assertBinaryEval(t, BinaryTimeToBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against time 2000-01-01 00:00:00 -0200 -0200 and time 2000-01-01 00:00:00 -0200 -0200") +} + +func TestBadEvalBinaryBool(t *testing.T) { + goodVal := true + badVal := "a" + assertBinaryEval(t, BinaryBoolFunc, badVal, token.LAND, goodVal, "cannot evaluate binary bool expression '&&' with 'a(string)' on the left") + assertBinaryEval(t, BinaryBoolFunc, goodVal, token.LOR, badVal, "cannot evaluate binary bool expression 'true(bool) || a(string)', invalid right arg") + assertBinaryEval(t, BinaryBoolFunc, goodVal, token.ADD, goodVal, "cannot perform bool op + against bool true and bool true") +} + +func TestBadEvalBinaryBoolToBool(t *testing.T) { + goodVal := true + badVal := "a" + assertBinaryEval(t, BinaryBoolToBoolFunc, badVal, token.EQL, goodVal, "cannot evaluate binary bool expression == with string on the left") + assertBinaryEval(t, BinaryBoolToBoolFunc, goodVal, token.NEQ, badVal, "cannot evaluate binary bool expression 'true(bool) != a(string)', invalid right arg") + assertBinaryEval(t, BinaryBoolToBoolFunc, goodVal, token.ADD, goodVal, "cannot evaluate binary bool expression, op + not supported (and will never be)") +} + +func TestBadEvalBinaryString(t *testing.T) { + goodVal := "good" + badVal := 1 + assertBinaryEval(t, BinaryStringFunc, badVal, token.ADD, goodVal, "cannot evaluate binary string expression + with int on the left") + 
assertBinaryEval(t, BinaryStringFunc, goodVal, token.ADD, badVal, "cannot evaluate binary string expression 'good(string) + 1(int)', invalid right arg") + assertBinaryEval(t, BinaryStringFunc, goodVal, token.AND, goodVal, "cannot perform string op & against string 'good' and string 'good', op not supported") +} + +func TestBadEvalBinaryStringToBool(t *testing.T) { + goodVal := "good" + badVal := 1 + assertBinaryEval(t, BinaryStringToBoolFunc, badVal, token.LSS, goodVal, "cannot evaluate binary string expression < with '1(int)' on the left") + assertBinaryEval(t, BinaryStringToBoolFunc, goodVal, token.GTR, badVal, "cannot evaluate binary decimal2 expression 'good(string) > 1(int)', invalid right arg") + assertBinaryEval(t, BinaryStringToBoolFunc, goodVal, token.AND, goodVal, "cannot perform bool op & against string good and string good") +} diff --git a/pkg/eval/fmt.go b/pkg/eval/fmt.go index 088a738..9962aca 100644 --- a/pkg/eval/fmt.go +++ b/pkg/eval/fmt.go @@ -4,7 +4,7 @@ import ( "fmt" ) -func callFmtSprintf(args []interface{}) (interface{}, error) { +func callFmtSprintf(args []any) (any, error) { if len(args) < 2 { return nil, fmt.Errorf("cannot evaluate fmt.Sprintf(), requires at least 2 args, %d supplied", len(args)) } @@ -12,7 +12,7 @@ func callFmtSprintf(args []interface{}) (interface{}, error) { if !ok0 { return nil, fmt.Errorf("cannot convert fmt.Sprintf() arg %v to string", args[0]) } - afterStringArgs := make([]interface{}, len(args)-1) + afterStringArgs := make([]any, len(args)-1) copy(afterStringArgs, args[1:]) return fmt.Sprintf(argString0, afterStringArgs...), nil } diff --git a/pkg/eval/math.go b/pkg/eval/math.go index 6f5628c..dac2497 100644 --- a/pkg/eval/math.go +++ b/pkg/eval/math.go @@ -1,41 +1,41 @@ -package eval - -import ( - "fmt" - "math" -) - -func callLen(args []interface{}) (interface{}, error) { - if err := checkArgs("len", 1, len(args)); err != nil { - return nil, err - } - argString, ok := args[0].(string) - if !ok { - return nil, fmt.Errorf("cannot convert len() arg %v to string", args[0]) - } - return len(argString), nil -} - -func callMathSqrt(args []interface{}) (interface{}, error) { - if err := checkArgs("math.Sqrt", 1, len(args)); err != nil { - return nil, err - } - argFloat, err := castToFloat64(args[0]) - if err != nil { - return nil, fmt.Errorf("cannot evaluate math.Sqrt(), invalid args %v: [%s]", args, err.Error()) - } - - return math.Sqrt(argFloat), nil -} - -func callMathRound(args []interface{}) (interface{}, error) { - if err := checkArgs("math.Round", 1, len(args)); err != nil { - return nil, err - } - argFloat, err := castToFloat64(args[0]) - if err != nil { - return nil, fmt.Errorf("cannot evaluate math.Round(), invalid args %v: [%s]", args, err.Error()) - } - - return math.Round(argFloat), nil -} +package eval + +import ( + "fmt" + "math" +) + +func callLen(args []any) (any, error) { + if err := checkArgs("len", 1, len(args)); err != nil { + return nil, err + } + argString, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("cannot convert len() arg %v to string", args[0]) + } + return len(argString), nil +} + +func callMathSqrt(args []any) (any, error) { + if err := checkArgs("math.Sqrt", 1, len(args)); err != nil { + return nil, err + } + argFloat, err := castToFloat64(args[0]) + if err != nil { + return nil, fmt.Errorf("cannot evaluate math.Sqrt(), invalid args %v: [%s]", args, err.Error()) + } + + return math.Sqrt(argFloat), nil +} + +func callMathRound(args []any) (any, error) { + if err := checkArgs("math.Round", 1, 
len(args)); err != nil { + return nil, err + } + argFloat, err := castToFloat64(args[0]) + if err != nil { + return nil, fmt.Errorf("cannot evaluate math.Round(), invalid args %v: [%s]", args, err.Error()) + } + + return math.Round(argFloat), nil +} diff --git a/pkg/eval/math_test.go b/pkg/eval/math_test.go index 345c298..63565fa 100644 --- a/pkg/eval/math_test.go +++ b/pkg/eval/math_test.go @@ -1,19 +1,19 @@ -package eval - -import "testing" - -func TestMathFunctions(t *testing.T) { - varValuesMap := VarValuesMap{} - assertEqual(t, `len("aaa")`, 3, varValuesMap) - assertEvalError(t, "len(123)", "cannot convert len() arg 123 to string", varValuesMap) - assertEvalError(t, "len(123,567)", "cannot evaluate len(), requires 1 args, 2 supplied", varValuesMap) - - assertEqual(t, "math.Sqrt(5)", 2.23606797749979, varValuesMap) - assertEvalError(t, `math.Sqrt("aa")`, "cannot evaluate math.Sqrt(), invalid args [aa]: [cannot cast aa(string) to float64, unsuported type]", varValuesMap) - assertFloatNan(t, "math.Sqrt(-1)", varValuesMap) - assertEvalError(t, "math.Sqrt(123,567)", "cannot evaluate math.Sqrt(), requires 1 args, 2 supplied", varValuesMap) - - assertEqual(t, "math.Round(5.1)", 5.0, varValuesMap) - assertEvalError(t, `math.Round("aa")`, "cannot evaluate math.Round(), invalid args [aa]: [cannot cast aa(string) to float64, unsuported type]", varValuesMap) - assertEvalError(t, "math.Round(5,1)", "cannot evaluate math.Round(), requires 1 args, 2 supplied", varValuesMap) -} +package eval + +import "testing" + +func TestMathFunctions(t *testing.T) { + varValuesMap := VarValuesMap{} + assertEqual(t, `len("aaa")`, 3, varValuesMap) + assertEvalError(t, "len(123)", "cannot convert len() arg 123 to string", varValuesMap) + assertEvalError(t, "len(123,567)", "cannot evaluate len(), requires 1 args, 2 supplied", varValuesMap) + + assertEqual(t, "math.Sqrt(5)", 2.23606797749979, varValuesMap) + assertEvalError(t, `math.Sqrt("aa")`, "cannot evaluate math.Sqrt(), invalid args [aa]: [cannot cast aa(string) to float64, unsuported type]", varValuesMap) + assertFloatNan(t, "math.Sqrt(-1)", varValuesMap) + assertEvalError(t, "math.Sqrt(123,567)", "cannot evaluate math.Sqrt(), requires 1 args, 2 supplied", varValuesMap) + + assertEqual(t, "math.Round(5.1)", 5.0, varValuesMap) + assertEvalError(t, `math.Round("aa")`, "cannot evaluate math.Round(), invalid args [aa]: [cannot cast aa(string) to float64, unsuported type]", varValuesMap) + assertEvalError(t, "math.Round(5,1)", "cannot evaluate math.Round(), requires 1 args, 2 supplied", varValuesMap) +} diff --git a/pkg/eval/re.go b/pkg/eval/re.go index fce9780..0b632aa 100644 --- a/pkg/eval/re.go +++ b/pkg/eval/re.go @@ -1,18 +1,18 @@ -package eval - -import ( - "fmt" - "regexp" -) - -func callReMatchString(args []interface{}) (interface{}, error) { - if err := checkArgs("re.MatchString", 2, len(args)); err != nil { - return nil, err - } - argString0, ok0 := args[0].(string) - argString1, ok1 := args[1].(string) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot convert re.MatchString() args %v and %v to string", args[0], args[1]) - } - return regexp.MatchString(argString0, argString1) -} +package eval + +import ( + "fmt" + "regexp" +) + +func callReMatchString(args []any) (any, error) { + if err := checkArgs("re.MatchString", 2, len(args)); err != nil { + return nil, err + } + argString0, ok0 := args[0].(string) + argString1, ok1 := args[1].(string) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot convert re.MatchString() args %v and %v to string", args[0], 
args[1]) + } + return regexp.MatchString(argString0, argString1) +} diff --git a/pkg/eval/re_test.go b/pkg/eval/re_test.go index a29fcd6..dcb7515 100644 --- a/pkg/eval/re_test.go +++ b/pkg/eval/re_test.go @@ -1,34 +1,34 @@ -package eval - -import ( - "testing" -) - -func TestReFunctions(t *testing.T) { - var vars VarValuesMap - var re string - - vars = VarValuesMap{"r": map[string]interface{}{"product_spec": `{"k":"Ideal For","v":"Boys, Men, Girls, Women"}`}} - re = "re.MatchString(`\"k\":\"Ideal For\",\"v\":\"[\\w ,]*Boys[\\w ,]*\"`, r.product_spec)" - assertEqual(t, re, true, vars) - - vars = VarValuesMap{"r": map[string]interface{}{"product_spec": `{"k":"Water Resistance Depth","v":"100 m"}`}} - re = "re.MatchString(`\"k\":\"Water Resistance Depth\",\"v\":\"(100|200) m\"`, r.product_spec)" - assertEqual(t, re, true, vars) - - vars = VarValuesMap{"r": map[string]interface{}{"product_spec": `{"k":"Occasion","v":"Ethnic, Casual, Party, Formal"}`}} - re = "re.MatchString(`\"k\":\"Occasion\",\"v\":\"[\\w ,]*(Casual|Festive)[\\w ,]*\"`, r.product_spec)" - assertEqual(t, re, true, vars) - - vars = VarValuesMap{"r": map[string]interface{}{"product_spec": `{"k":"Base Material","v":"Gold"},{"k":"Gemstone","v":"Diamond"}`, "retail_price": 101}} - re = "re.MatchString(`\"k\":\"Base Material\",\"v\":\"Gold\"`, r.product_spec) && re.MatchString(`\"k\":\"Gemstone\",\"v\":\"Diamond\"`, r.product_spec) && r.retail_price > 100" - assertEqual(t, re, true, vars) - - vars = VarValuesMap{"r": map[string]interface{}{"product_spec": `{"k":"Base Material","v":"Gold"},{"k":"Gemstone","v":"Diamond"}`, "retail_price": 100}} - re = "re.MatchString(`\"k\":\"Base Material\",\"v\":\"Gold\"`, r.product_spec) && re.MatchString(`\"k\":\"Gemstone\",\"v\":\"Diamond\"`, r.product_spec) && r.retail_price > 100" - assertEqual(t, re, false, vars) - - assertEvalError(t, `re.MatchString("a")`, "cannot evaluate re.MatchString(), requires 2 args, 1 supplied", vars) - assertEvalError(t, `re.MatchString("a",1)`, "cannot convert re.MatchString() args a and 1 to string", vars) - -} +package eval + +import ( + "testing" +) + +func TestReFunctions(t *testing.T) { + var vars VarValuesMap + var re string + + vars = VarValuesMap{"r": map[string]any{"product_spec": `{"k":"Ideal For","v":"Boys, Men, Girls, Women"}`}} + re = "re.MatchString(`\"k\":\"Ideal For\",\"v\":\"[\\w ,]*Boys[\\w ,]*\"`, r.product_spec)" + assertEqual(t, re, true, vars) + + vars = VarValuesMap{"r": map[string]any{"product_spec": `{"k":"Water Resistance Depth","v":"100 m"}`}} + re = "re.MatchString(`\"k\":\"Water Resistance Depth\",\"v\":\"(100|200) m\"`, r.product_spec)" + assertEqual(t, re, true, vars) + + vars = VarValuesMap{"r": map[string]any{"product_spec": `{"k":"Occasion","v":"Ethnic, Casual, Party, Formal"}`}} + re = "re.MatchString(`\"k\":\"Occasion\",\"v\":\"[\\w ,]*(Casual|Festive)[\\w ,]*\"`, r.product_spec)" + assertEqual(t, re, true, vars) + + vars = VarValuesMap{"r": map[string]any{"product_spec": `{"k":"Base Material","v":"Gold"},{"k":"Gemstone","v":"Diamond"}`, "retail_price": 101}} + re = "re.MatchString(`\"k\":\"Base Material\",\"v\":\"Gold\"`, r.product_spec) && re.MatchString(`\"k\":\"Gemstone\",\"v\":\"Diamond\"`, r.product_spec) && r.retail_price > 100" + assertEqual(t, re, true, vars) + + vars = VarValuesMap{"r": map[string]any{"product_spec": `{"k":"Base Material","v":"Gold"},{"k":"Gemstone","v":"Diamond"}`, "retail_price": 100}} + re = "re.MatchString(`\"k\":\"Base Material\",\"v\":\"Gold\"`, r.product_spec) && 
re.MatchString(`\"k\":\"Gemstone\",\"v\":\"Diamond\"`, r.product_spec) && r.retail_price > 100" + assertEqual(t, re, false, vars) + + assertEvalError(t, `re.MatchString("a")`, "cannot evaluate re.MatchString(), requires 2 args, 1 supplied", vars) + assertEvalError(t, `re.MatchString("a",1)`, "cannot convert re.MatchString() args a and 1 to string", vars) + +} diff --git a/pkg/eval/strings.go b/pkg/eval/strings.go index 2f8fb5b..f0bbd5a 100644 --- a/pkg/eval/strings.go +++ b/pkg/eval/strings.go @@ -1,19 +1,19 @@ -package eval - -import ( - "fmt" - "strings" -) - -func callStringsReplaceAll(args []interface{}) (interface{}, error) { - if err := checkArgs("strings.ReplaceAll", 3, len(args)); err != nil { - return nil, err - } - argString0, ok0 := args[0].(string) - argString1, ok1 := args[1].(string) - argString2, ok2 := args[2].(string) - if !ok0 || !ok1 || !ok2 { - return nil, fmt.Errorf("cannot convert strings.ReplaceAll() args %v,%v,%v to string", args[0], args[1], args[2]) - } - return strings.ReplaceAll(argString0, argString1, argString2), nil -} +package eval + +import ( + "fmt" + "strings" +) + +func callStringsReplaceAll(args []any) (any, error) { + if err := checkArgs("strings.ReplaceAll", 3, len(args)); err != nil { + return nil, err + } + argString0, ok0 := args[0].(string) + argString1, ok1 := args[1].(string) + argString2, ok2 := args[2].(string) + if !ok0 || !ok1 || !ok2 { + return nil, fmt.Errorf("cannot convert strings.ReplaceAll() args %v,%v,%v to string", args[0], args[1], args[2]) + } + return strings.ReplaceAll(argString0, argString1, argString2), nil +} diff --git a/pkg/eval/strings_test.go b/pkg/eval/strings_test.go index c9ef862..731c0b7 100644 --- a/pkg/eval/strings_test.go +++ b/pkg/eval/strings_test.go @@ -1,13 +1,12 @@ -package eval - -import ( - "fmt" - "testing" -) - -func TestStringsFunctions(t *testing.T) { - varValuesMap := VarValuesMap{} - assertEqual(t, fmt.Sprintf(`strings.ReplaceAll("abc","a","b")`), "bbc", varValuesMap) - assertEvalError(t, `strings.ReplaceAll("a","b")`, "cannot evaluate strings.ReplaceAll(), requires 3 args, 2 supplied", varValuesMap) - assertEvalError(t, `strings.ReplaceAll("a","b",1)`, "cannot convert strings.ReplaceAll() args a,b,1 to string", varValuesMap) -} +package eval + +import ( + "testing" +) + +func TestStringsFunctions(t *testing.T) { + varValuesMap := VarValuesMap{} + assertEqual(t, `strings.ReplaceAll("abc","a","b")`, "bbc", varValuesMap) + assertEvalError(t, `strings.ReplaceAll("a","b")`, "cannot evaluate strings.ReplaceAll(), requires 3 args, 2 supplied", varValuesMap) + assertEvalError(t, `strings.ReplaceAll("a","b",1)`, "cannot convert strings.ReplaceAll() args a,b,1 to string", varValuesMap) +} diff --git a/pkg/eval/time.go b/pkg/eval/time.go index 1e54ab8..7cba2fb 100644 --- a/pkg/eval/time.go +++ b/pkg/eval/time.go @@ -1,130 +1,130 @@ -package eval - -import ( - "fmt" - "time" -) - -func callTimeParse(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Parse", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(string) - arg1, ok1 := args[1].(string) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate time.Parse(), invalid args %v", args) - } - return time.Parse(arg0, arg1) -} - -func callTimeFormat(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Format", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - arg1, ok1 := args[1].(string) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate 
time.Format(), invalid args %v", args) - } - return arg0.Format(arg1), nil -} - -func callTimeFixedZone(args []interface{}) (interface{}, error) { - if err := checkArgs("time.FixedZone", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(string) // Name - arg1, ok1 := args[1].(int64) // Offset in min - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate time.FixedZone(), invalid args %v", args) - } - return time.FixedZone(arg0, int(arg1)), nil -} - -func callTimeDate(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Date", 8, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(int64) // Year - arg1, ok1 := args[1].(time.Month) // Month - arg2, ok2 := args[2].(int64) // Day - arg3, ok3 := args[3].(int64) // Hour - arg4, ok4 := args[4].(int64) // Min - arg5, ok5 := args[5].(int64) // Sec - arg6, ok6 := args[6].(int64) // Nsec - arg7, ok7 := args[7].(*time.Location) - if !ok0 || !ok1 || !ok2 || !ok3 || !ok4 || !ok5 || !ok6 || !ok7 { - return nil, fmt.Errorf("cannot evaluate time.Date(), invalid args %v", args) - } - return time.Date(int(arg0), arg1, int(arg2), int(arg3), int(arg4), int(arg5), int(arg6), arg7), nil -} - -func callTimeDiffMilli(args []interface{}) (interface{}, error) { - if err := checkArgs("time.DiffMilli", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - arg1, ok1 := args[1].(time.Time) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate time.DiffMilli(), invalid args %v", args) - } - - return arg0.Sub(arg1).Milliseconds(), nil -} - -func callTimeNow(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Now", 0, len(args)); err != nil { - return nil, err - } - return time.Now(), nil -} - -func callTimeUnix(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Unix", 1, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - if !ok0 { - return nil, fmt.Errorf("cannot evaluate time.Unix(), invalid args %v", args) - } - - return arg0.Unix(), nil -} - -func callTimeUnixMilli(args []interface{}) (interface{}, error) { - if err := checkArgs("time.UnixMilli", 1, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - if !ok0 { - return nil, fmt.Errorf("cannot evaluate time.UnixMilli(), invalid args %v", args) - } - - return arg0.UnixMilli(), nil -} - -func callTimeBefore(args []interface{}) (interface{}, error) { - if err := checkArgs("time.Before", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - arg1, ok1 := args[1].(time.Time) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate time.Before(), invalid args %v", args) - } - - return arg0.Before(arg1), nil -} - -func callTimeAfter(args []interface{}) (interface{}, error) { - if err := checkArgs("time.After", 2, len(args)); err != nil { - return nil, err - } - arg0, ok0 := args[0].(time.Time) - arg1, ok1 := args[1].(time.Time) - if !ok0 || !ok1 { - return nil, fmt.Errorf("cannot evaluate time.After(), invalid args %v", args) - } - - return arg0.After(arg1), nil -} +package eval + +import ( + "fmt" + "time" +) + +func callTimeParse(args []any) (any, error) { + if err := checkArgs("time.Parse", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(string) + arg1, ok1 := args[1].(string) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.Parse(), invalid args %v", args) + } + return time.Parse(arg0, arg1) +} + +func callTimeFormat(args []any) (any, 
error) { + if err := checkArgs("time.Format", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + arg1, ok1 := args[1].(string) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.Format(), invalid args %v", args) + } + return arg0.Format(arg1), nil +} + +func callTimeFixedZone(args []any) (any, error) { + if err := checkArgs("time.FixedZone", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(string) // Name + arg1, ok1 := args[1].(int64) // Offset in min + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.FixedZone(), invalid args %v", args) + } + return time.FixedZone(arg0, int(arg1)), nil +} + +func callTimeDate(args []any) (any, error) { + if err := checkArgs("time.Date", 8, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(int64) // Year + arg1, ok1 := args[1].(time.Month) // Month + arg2, ok2 := args[2].(int64) // Day + arg3, ok3 := args[3].(int64) // Hour + arg4, ok4 := args[4].(int64) // Min + arg5, ok5 := args[5].(int64) // Sec + arg6, ok6 := args[6].(int64) // Nsec + arg7, ok7 := args[7].(*time.Location) + if !ok0 || !ok1 || !ok2 || !ok3 || !ok4 || !ok5 || !ok6 || !ok7 { + return nil, fmt.Errorf("cannot evaluate time.Date(), invalid args %v", args) + } + return time.Date(int(arg0), arg1, int(arg2), int(arg3), int(arg4), int(arg5), int(arg6), arg7), nil +} + +func callTimeDiffMilli(args []any) (any, error) { + if err := checkArgs("time.DiffMilli", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + arg1, ok1 := args[1].(time.Time) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.DiffMilli(), invalid args %v", args) + } + + return arg0.Sub(arg1).Milliseconds(), nil +} + +func callTimeNow(args []any) (any, error) { + if err := checkArgs("time.Now", 0, len(args)); err != nil { + return nil, err + } + return time.Now(), nil +} + +func callTimeUnix(args []any) (any, error) { + if err := checkArgs("time.Unix", 1, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + if !ok0 { + return nil, fmt.Errorf("cannot evaluate time.Unix(), invalid args %v", args) + } + + return arg0.Unix(), nil +} + +func callTimeUnixMilli(args []any) (any, error) { + if err := checkArgs("time.UnixMilli", 1, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + if !ok0 { + return nil, fmt.Errorf("cannot evaluate time.UnixMilli(), invalid args %v", args) + } + + return arg0.UnixMilli(), nil +} + +func callTimeBefore(args []any) (any, error) { + if err := checkArgs("time.Before", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + arg1, ok1 := args[1].(time.Time) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.Before(), invalid args %v", args) + } + + return arg0.Before(arg1), nil +} + +func callTimeAfter(args []any) (any, error) { + if err := checkArgs("time.After", 2, len(args)); err != nil { + return nil, err + } + arg0, ok0 := args[0].(time.Time) + arg1, ok1 := args[1].(time.Time) + if !ok0 || !ok1 { + return nil, fmt.Errorf("cannot evaluate time.After(), invalid args %v", args) + } + + return arg0.After(arg1), nil +} diff --git a/pkg/eval/time_test.go b/pkg/eval/time_test.go index ce63dda..85ebd37 100644 --- a/pkg/eval/time_test.go +++ b/pkg/eval/time_test.go @@ -1,77 +1,77 @@ -package eval - -import ( - "fmt" - "go/parser" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestTimeFunctions(t *testing.T) { - 
testTime := time.Date(2001, 1, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)) - testTimeUtc := time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) - varValuesMap := VarValuesMap{"t": map[string]interface{}{"test_time": testTime}} - - assertEqual(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01.100-0200")`, testTime, varValuesMap) - assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01.100-0200","aaa")`, "cannot evaluate time.Parse(), requires 2 args, 3 supplied", varValuesMap) - assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700",123)`, "cannot evaluate time.Parse(), invalid args [2006-01-02T15:04:05.000-0700 123]", varValuesMap) - assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01")`, `parsing time "2001-01-01T01:01:01" as "2006-01-02T15:04:05.000-0700": cannot parse "" as ".000"`, varValuesMap) - assertEqual(t, `time.Parse("2006-01-02","2001-01-01")`, testTimeUtc, varValuesMap) - - assertEqual(t, `time.Format(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), "2006-01-02T15:04:05.000-0700")`, testTime.Format("2006-01-02T15:04:05.000-0700"), varValuesMap) - assertEvalError(t, `time.Format(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)))`, "cannot evaluate time.Format(), requires 2 args, 1 supplied", varValuesMap) - assertEvalError(t, `time.Format("some_bad_param", "2006-01-02T15:04:05.000-0700")`, "cannot evaluate time.Format(), invalid args [some_bad_param 2006-01-02T15:04:05.000-0700]", varValuesMap) - - assertEqual(t, `time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200))`, testTime, varValuesMap) - assertEvalError(t, `time.Date(2001, 354, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200))`, "cannot evaluate time.Date(), invalid args [2001 354 1 1 1 1 100000000 ]", varValuesMap) - assertEvalError(t, `time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200), "extraparam")`, "cannot evaluate time.Date(), requires 8 args, 9 supplied", varValuesMap) - - assertEvalError(t, `time.FixedZone("")`, "cannot evaluate time.FixedZone(), requires 2 args, 1 supplied", varValuesMap) - assertEvalError(t, `time.FixedZone("", "some_bad_param")`, "cannot evaluate time.FixedZone(), invalid args [ some_bad_param]", varValuesMap) - - assertEqual(t, `time.DiffMilli(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(0), varValuesMap) - assertEqual(t, `time.DiffMilli(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(31536000000), varValuesMap) - assertEqual(t, `time.DiffMilli(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(-31622400000), varValuesMap) - assertEvalError(t, `time.DiffMilli(1)`, "cannot evaluate time.DiffMilli(), requires 2 args, 1 supplied", varValuesMap) - assertEvalError(t, `time.DiffMilli("some_bad_param", t.test_time)`, "cannot evaluate time.DiffMilli(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) - - assertEqual(t, `time.Before(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, true, varValuesMap) - assertEqual(t, `time.Before(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, false, varValuesMap) - assertEvalError(t, `time.Before()`, "cannot evaluate time.Before(), requires 2 args, 0 supplied", varValuesMap) - 
assertEvalError(t, `time.Before("some_bad_param", t.test_time)`, "cannot evaluate time.Before(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) - - assertEqual(t, `time.After(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, true, varValuesMap) - assertEqual(t, `time.After(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, false, varValuesMap) - assertEvalError(t, `time.After()`, "cannot evaluate time.After(), requires 2 args, 0 supplied", varValuesMap) - assertEvalError(t, `time.After("some_bad_param", t.test_time)`, "cannot evaluate time.After(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) - - assertEqual(t, `time.Unix(t.test_time)`, testTime.Unix(), varValuesMap) - assertEvalError(t, `time.Unix()`, "cannot evaluate time.Unix(), requires 1 args, 0 supplied", varValuesMap) - assertEvalError(t, `time.Unix("some_bad_param")`, "cannot evaluate time.Unix(), invalid args [some_bad_param]", varValuesMap) - - assertEqual(t, `time.UnixMilli(t.test_time)`, testTime.UnixMilli(), varValuesMap) - assertEvalError(t, `time.UnixMilli()`, "cannot evaluate time.UnixMilli(), requires 1 args, 0 supplied", varValuesMap) - assertEvalError(t, `time.UnixMilli("some_bad_param")`, "cannot evaluate time.UnixMilli(), invalid args [some_bad_param]", varValuesMap) -} - -func TestNow(t *testing.T) { - - exp, err1 := parser.ParseExpr(`time.Now()`) - if err1 != nil { - t.Error(fmt.Errorf("cannot parse Now(): %s", err1.Error())) - return - } - eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &VarValuesMap{}) - result, err2 := eCtx.Eval(exp) - if err2 != nil { - t.Error(fmt.Errorf("cannot eval Now(): %s", err2.Error())) - return - } - - resultTime, _ := result.(time.Time) - - assert.True(t, time.Since(resultTime).Milliseconds() < 500) - assertEvalError(t, `time.Now(1)`, "cannot evaluate time.Now(), requires 0 args, 1 supplied", VarValuesMap{}) -} +package eval + +import ( + "fmt" + "go/parser" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeFunctions(t *testing.T) { + testTime := time.Date(2001, 1, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)) + testTimeUtc := time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + varValuesMap := VarValuesMap{"t": map[string]any{"test_time": testTime}} + + assertEqual(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01.100-0200")`, testTime, varValuesMap) + assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01.100-0200","aaa")`, "cannot evaluate time.Parse(), requires 2 args, 3 supplied", varValuesMap) + assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700",123)`, "cannot evaluate time.Parse(), invalid args [2006-01-02T15:04:05.000-0700 123]", varValuesMap) + assertEvalError(t, `time.Parse("2006-01-02T15:04:05.000-0700","2001-01-01T01:01:01")`, `parsing time "2001-01-01T01:01:01" as "2006-01-02T15:04:05.000-0700": cannot parse "" as ".000"`, varValuesMap) + assertEqual(t, `time.Parse("2006-01-02","2001-01-01")`, testTimeUtc, varValuesMap) + + assertEqual(t, `time.Format(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), "2006-01-02T15:04:05.000-0700")`, testTime.Format("2006-01-02T15:04:05.000-0700"), varValuesMap) + assertEvalError(t, `time.Format(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)))`, "cannot evaluate time.Format(), requires 2 args, 1 supplied", varValuesMap) + 
assertEvalError(t, `time.Format("some_bad_param", "2006-01-02T15:04:05.000-0700")`, "cannot evaluate time.Format(), invalid args [some_bad_param 2006-01-02T15:04:05.000-0700]", varValuesMap) + + assertEqual(t, `time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200))`, testTime, varValuesMap) + assertEvalError(t, `time.Date(2001, 354, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200))`, "cannot evaluate time.Date(), invalid args [2001 354 1 1 1 1 100000000 ]", varValuesMap) + assertEvalError(t, `time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200), "extraparam")`, "cannot evaluate time.Date(), requires 8 args, 9 supplied", varValuesMap) + + assertEvalError(t, `time.FixedZone("")`, "cannot evaluate time.FixedZone(), requires 2 args, 1 supplied", varValuesMap) + assertEvalError(t, `time.FixedZone("", "some_bad_param")`, "cannot evaluate time.FixedZone(), invalid args [ some_bad_param]", varValuesMap) + + assertEqual(t, `time.DiffMilli(time.Date(2001, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(0), varValuesMap) + assertEqual(t, `time.DiffMilli(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(31536000000), varValuesMap) + assertEqual(t, `time.DiffMilli(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, int64(-31622400000), varValuesMap) + assertEvalError(t, `time.DiffMilli(1)`, "cannot evaluate time.DiffMilli(), requires 2 args, 1 supplied", varValuesMap) + assertEvalError(t, `time.DiffMilli("some_bad_param", t.test_time)`, "cannot evaluate time.DiffMilli(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) + + assertEqual(t, `time.Before(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, true, varValuesMap) + assertEqual(t, `time.Before(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, false, varValuesMap) + assertEvalError(t, `time.Before()`, "cannot evaluate time.Before(), requires 2 args, 0 supplied", varValuesMap) + assertEvalError(t, `time.Before("some_bad_param", t.test_time)`, "cannot evaluate time.Before(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) + + assertEqual(t, `time.After(time.Date(2002, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, true, varValuesMap) + assertEqual(t, `time.After(time.Date(2000, time.January, 1, 1, 1, 1, 100000000, time.FixedZone("", -7200)), t.test_time)`, false, varValuesMap) + assertEvalError(t, `time.After()`, "cannot evaluate time.After(), requires 2 args, 0 supplied", varValuesMap) + assertEvalError(t, `time.After("some_bad_param", t.test_time)`, "cannot evaluate time.After(), invalid args [some_bad_param 2001-01-01 01:01:01.1 -0200 -0200]", varValuesMap) + + assertEqual(t, `time.Unix(t.test_time)`, testTime.Unix(), varValuesMap) + assertEvalError(t, `time.Unix()`, "cannot evaluate time.Unix(), requires 1 args, 0 supplied", varValuesMap) + assertEvalError(t, `time.Unix("some_bad_param")`, "cannot evaluate time.Unix(), invalid args [some_bad_param]", varValuesMap) + + assertEqual(t, `time.UnixMilli(t.test_time)`, testTime.UnixMilli(), varValuesMap) + assertEvalError(t, `time.UnixMilli()`, "cannot evaluate time.UnixMilli(), requires 1 args, 0 supplied", varValuesMap) + assertEvalError(t, `time.UnixMilli("some_bad_param")`, "cannot evaluate time.UnixMilli(), invalid args 
[some_bad_param]", varValuesMap) +} + +func TestNow(t *testing.T) { + + exp, err1 := parser.ParseExpr(`time.Now()`) + if err1 != nil { + t.Error(fmt.Errorf("cannot parse Now(): %s", err1.Error())) + return + } + eCtx := NewPlainEvalCtxWithVars(AggFuncDisabled, &VarValuesMap{}) + result, err2 := eCtx.Eval(exp) + if err2 != nil { + t.Error(fmt.Errorf("cannot eval Now(): %s", err2.Error())) + return + } + + resultTime, ok := result.(time.Time) + assert.True(t, ok) + assert.True(t, time.Since(resultTime).Milliseconds() < 500) + assertEvalError(t, `time.Now(1)`, "cannot evaluate time.Now(), requires 0 args, 1 supplied", VarValuesMap{}) +} diff --git a/pkg/eval/util.go b/pkg/eval/util.go index f7b825b..54ca69d 100644 --- a/pkg/eval/util.go +++ b/pkg/eval/util.go @@ -1,37 +1,37 @@ -package eval - -import ( - "fmt" - "go/ast" - "go/token" - "strings" -) - -func DetectRootAggFunc(exp ast.Expr) (AggEnabledType, AggFuncType, []ast.Expr) { - if callExp, ok := exp.(*ast.CallExpr); ok { - funExp := callExp.Fun - if funIdentExp, ok := funExp.(*ast.Ident); ok { - if StringToAggFunc(funIdentExp.Name) != AggUnknown { - return AggFuncEnabled, StringToAggFunc(funIdentExp.Name), callExp.Args - } - } - } - return AggFuncDisabled, AggUnknown, nil -} - -func GetAggStringSeparator(aggFuncArgs []ast.Expr) (string, error) { - if len(aggFuncArgs) < 2 { - return "", fmt.Errorf("string_agg must have two parameters") - } - switch separatorExpTyped := aggFuncArgs[1].(type) { - case *ast.BasicLit: - switch separatorExpTyped.Kind { - case token.STRING: - return strings.Trim(separatorExpTyped.Value, "\""), nil - default: - return "", fmt.Errorf("string_agg second parameter must be a constant string") - } - default: - return "", fmt.Errorf("string_agg second parameter must be a basic literal") - } -} +package eval + +import ( + "fmt" + "go/ast" + "go/token" + "strings" +) + +func DetectRootAggFunc(exp ast.Expr) (AggEnabledType, AggFuncType, []ast.Expr) { + if callExp, ok := exp.(*ast.CallExpr); ok { + funExp := callExp.Fun + if funIdentExp, ok := funExp.(*ast.Ident); ok { + if StringToAggFunc(funIdentExp.Name) != AggUnknown { + return AggFuncEnabled, StringToAggFunc(funIdentExp.Name), callExp.Args + } + } + } + return AggFuncDisabled, AggUnknown, nil +} + +func GetAggStringSeparator(aggFuncArgs []ast.Expr) (string, error) { + if len(aggFuncArgs) < 2 { + return "", fmt.Errorf("string_agg must have two parameters") + } + switch separatorExpTyped := aggFuncArgs[1].(type) { + case *ast.BasicLit: + switch separatorExpTyped.Kind { + case token.STRING: + return strings.Trim(separatorExpTyped.Value, "\""), nil + default: + return "", fmt.Errorf("string_agg second parameter must be a constant string") + } + default: + return "", fmt.Errorf("string_agg second parameter must be a basic literal") + } +} diff --git a/pkg/eval/var_values_map.go b/pkg/eval/var_values_map.go index 02da408..08f360a 100644 --- a/pkg/eval/var_values_map.go +++ b/pkg/eval/var_values_map.go @@ -1,30 +1,30 @@ -package eval - -import ( - "fmt" - "strings" -) - -type VarValuesMap map[string]map[string]interface{} - -func (vars *VarValuesMap) Tables() string { - sb := strings.Builder{} - sb.WriteString("[") - for table, _ := range *vars { - sb.WriteString(fmt.Sprintf("%s ", table)) - } - sb.WriteString("]") - return sb.String() -} - -func (vars *VarValuesMap) Names() string { - sb := strings.Builder{} - sb.WriteString("[") - for table, fldMap := range *vars { - for fld, _ := range fldMap { - sb.WriteString(fmt.Sprintf("%s.%s ", table, fld)) - } - } - 
sb.WriteString("]") - return sb.String() -} +package eval + +import ( + "fmt" + "strings" +) + +type VarValuesMap map[string]map[string]any + +func (vars *VarValuesMap) Tables() string { + sb := strings.Builder{} + sb.WriteString("[") + for table := range *vars { + sb.WriteString(fmt.Sprintf("%s ", table)) + } + sb.WriteString("]") + return sb.String() +} + +func (vars *VarValuesMap) Names() string { + sb := strings.Builder{} + sb.WriteString("[") + for table, fldMap := range *vars { + for fld := range fldMap { + sb.WriteString(fmt.Sprintf("%s.%s ", table, fld)) + } + } + sb.WriteString("]") + return sb.String() +} diff --git a/pkg/eval/var_values_map_test.go b/pkg/eval/var_values_map_test.go index 507544c..edede59 100644 --- a/pkg/eval/var_values_map_test.go +++ b/pkg/eval/var_values_map_test.go @@ -7,7 +7,7 @@ import ( ) func TestVarValuesMapUtils(t *testing.T) { - varValuesMap := VarValuesMap{"some_table": map[string]interface{}{"some_field": 1}} + varValuesMap := VarValuesMap{"some_table": map[string]any{"some_field": 1}} assert.Equal(t, "[some_table ]", varValuesMap.Tables()) assert.Equal(t, "[some_table.some_field ]", varValuesMap.Names()) } diff --git a/pkg/exe/daemon/capidaemon.go b/pkg/exe/daemon/capidaemon.go index a8fa11b..6ed8595 100644 --- a/pkg/exe/daemon/capidaemon.go +++ b/pkg/exe/daemon/capidaemon.go @@ -1,91 +1,90 @@ -package main - -import ( - "log" - "os" - "os/signal" - "syscall" - "time" - - "github.com/capillariesio/capillaries/pkg/custom/py_calc" - "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wf" -) - -// https://stackoverflow.com/questions/25927660/how-to-get-the-current-function-name -// func trc() string { -// pc := make([]uintptr, 15) -// n := runtime.Callers(2, pc) -// frames := runtime.CallersFrames(pc[:n]) -// frame, _ := frames.Next() -// return fmt.Sprintf("%s:%d %s\n", frame.File, frame.Line, frame.Function) -// } - -type StandardDaemonProcessorDefFactory struct { -} - -func (f *StandardDaemonProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { - // All processors to be supported by this 'stock' binary (daemon/toolbelt). 
- // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, - // they all must implement CustomProcessorRunner interface - switch processorType { - case py_calc.ProcessorPyCalcName: - return &py_calc.PyCalcProcessorDef{}, true - case tag_and_denormalize.ProcessorTagAndDenormalizeName: - return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true - default: - return nil, false - } -} - -func main() { - - envConfig, err := env.ReadEnvConfigFile("capidaemon.json") - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - envConfig.CustomProcessorDefFactoryInstance = &StandardDaemonProcessorDefFactory{} - - logger, err := l.NewLoggerFromEnvConfig(envConfig) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - defer logger.Close() - - logger.PushF("daemon.main") - defer logger.PopF() - - osSignalChannel := make(chan os.Signal, 1) - signal.Notify(osSignalChannel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - - for { - daemonCmd := wf.AmqpFullReconnectCycle(envConfig, logger, osSignalChannel) - if daemonCmd == wf.DaemonCmdQuit { - logger.Info("got quit cmd, shut down is supposed to be complete by now") - os.Exit(0) - } - logger.Info("got %d, waiting before reconnect...", daemonCmd) - - // Read from osSignalChannel with timeout - timeoutChannel := make(chan bool, 1) - go func() { - time.Sleep(10 * time.Second) - timeoutChannel <- true - }() - select { - case osSignal := <-osSignalChannel: - if osSignal == os.Interrupt || osSignal == os.Kill { - logger.Info("received os signal %v while reconnecting to mq, quitting...", osSignal) - os.Exit(0) - } - case <-timeoutChannel: - // Break from select - break - } - } -} +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/capillariesio/capillaries/pkg/custom/py_calc" + "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wf" +) + +// https://stackoverflow.com/questions/25927660/how-to-get-the-current-function-name +// func trc() string { +// pc := make([]uintptr, 15) +// n := runtime.Callers(2, pc) +// frames := runtime.CallersFrames(pc[:n]) +// frame, _ := frames.Next() +// return fmt.Sprintf("%s:%d %s\n", frame.File, frame.Line, frame.Function) +// } + +type StandardDaemonProcessorDefFactory struct { +} + +func (f *StandardDaemonProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { + // All processors to be supported by this 'stock' binary (daemon/toolbelt). 
+ // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, + // they all must implement CustomProcessorRunner interface + switch processorType { + case py_calc.ProcessorPyCalcName: + return &py_calc.PyCalcProcessorDef{}, true + case tag_and_denormalize.ProcessorTagAndDenormalizeName: + return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true + default: + return nil, false + } +} + +func main() { + + envConfig, err := env.ReadEnvConfigFile("capidaemon.json") + if err != nil { + log.Fatalf(err.Error()) + } + envConfig.CustomProcessorDefFactoryInstance = &StandardDaemonProcessorDefFactory{} + + logger, err := l.NewLoggerFromEnvConfig(envConfig) + if err != nil { + log.Fatalf(err.Error()) + } + defer logger.Close() + + logger.PushF("daemon.main") + defer logger.PopF() + + osSignalChannel := make(chan os.Signal, 1) + signal.Notify(osSignalChannel, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + for { + daemonCmd := wf.AmqpFullReconnectCycle(envConfig, logger, osSignalChannel) + if daemonCmd == wf.DaemonCmdQuit { + logger.Info("got quit cmd, shut down is supposed to be complete by now") + os.Exit(0) + } + logger.Info("got %d, waiting before reconnect...", daemonCmd) + + // Read from osSignalChannel with timeout + timeoutChannel := make(chan bool, 1) + go func() { + time.Sleep(10 * time.Second) + timeoutChannel <- true + }() + select { + case osSignal := <-osSignalChannel: + if osSignal == os.Interrupt || osSignal == os.Kill { + logger.Info("received os signal %v while reconnecting to mq, quitting...", osSignal) + os.Exit(0) + } + case <-timeoutChannel: + logger.Info("timeout while reconnecting to mq, will try to reconnect again") + // Break from select + break //nolint:all + } + } +} diff --git a/pkg/exe/deploy/capideploy.go b/pkg/exe/deploy/capideploy.go index ad1d587..41bfec2 100644 --- a/pkg/exe/deploy/capideploy.go +++ b/pkg/exe/deploy/capideploy.go @@ -46,7 +46,7 @@ func DumpLogChan(logChan chan deploy.LogMsg) { } } -func getNicknamesArg(commonArgs *flag.FlagSet, entityName string) (string, error) { +func getNicknamesArg(entityName string) (string, error) { if len(os.Args) < 3 { return "", fmt.Errorf("not enough args, expected comma-separated list of %s or '*'", entityName) } @@ -159,7 +159,6 @@ Commands: ) fmt.Printf("\nOptional parameters:\n") flagset.PrintDefaults() - os.Exit(0) } func main() { @@ -175,10 +174,11 @@ func main() { cmdStartTs := time.Now() throttle := time.Tick(time.Second) // One call per second, to avoid error 429 on openstack calls - const MaxWorkerThreads int = 10 - var logChan = make(chan deploy.LogMsg, MaxWorkerThreads*5) - var sem = make(chan int, MaxWorkerThreads) + const maxWorkerThreads int = 10 + var logChan = make(chan deploy.LogMsg, maxWorkerThreads*5) + var sem = make(chan int, maxWorkerThreads) var errChan chan error + var parseErr error errorsExpected := 1 var prjPair *deploy.ProjectPair var fullPrjPath string @@ -195,9 +195,12 @@ func main() { } if _, ok := singleThreadCommands[os.Args[1]]; ok { - commonArgs.Parse(os.Args[2:]) + parseErr = commonArgs.Parse(os.Args[2:]) } else { - commonArgs.Parse(os.Args[3:]) + parseErr = commonArgs.Parse(os.Args[3:]) + } + if parseErr != nil { + log.Fatalf(parseErr.Error()) } prjPair, fullPrjPath, prjErr = deploy.LoadProject(*argPrjFile) @@ -226,7 +229,7 @@ func main() { <-sem }() } else if os.Args[1] == CmdCreateInstances || os.Args[1] == CmdDeleteInstances { - nicknames, err := getNicknamesArg(commonArgs, "instances") + nicknames, err := 
getNicknamesArg("instances") if err != nil { log.Fatalf(err.Error()) } @@ -310,7 +313,7 @@ func main() { os.Args[1] == CmdConfigServices || os.Args[1] == CmdStartServices || os.Args[1] == CmdStopServices { - nicknames, err := getNicknamesArg(commonArgs, "instances") + nicknames, err := getNicknamesArg("instances") if err != nil { log.Fatalf(err.Error()) } @@ -369,7 +372,7 @@ func main() { } } else if os.Args[1] == CmdCreateVolumes || os.Args[1] == CmdAttachVolumes || os.Args[1] == CmdDeleteVolumes { - nicknames, err := getNicknamesArg(commonArgs, "instances") + nicknames, err := getNicknamesArg("instances") if err != nil { log.Fatalf(err.Error()) } @@ -423,7 +426,7 @@ func main() { } else { switch os.Args[1] { case CmdUploadFiles: - nicknames, err := getNicknamesArg(commonArgs, "file groups to upload") + nicknames, err := getNicknamesArg("file groups to upload") if err != nil { log.Fatalf(err.Error()) } @@ -469,7 +472,7 @@ func main() { } case CmdDownloadFiles: - nicknames, err := getNicknamesArg(commonArgs, "file groups to download") + nicknames, err := getNicknamesArg("file groups to download") if err != nil { log.Fatalf(err.Error()) } diff --git a/pkg/exe/toolbelt/capitoolbelt.go b/pkg/exe/toolbelt/capitoolbelt.go index 11fe1f2..d6572f2 100644 --- a/pkg/exe/toolbelt/capitoolbelt.go +++ b/pkg/exe/toolbelt/capitoolbelt.go @@ -1,570 +1,533 @@ -package main - -import ( - "flag" - "fmt" - "log" - "os" - "strconv" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/api" - "github.com/capillariesio/capillaries/pkg/custom/py_calc" - "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfmodel" - amqp "github.com/rabbitmq/amqp091-go" -) - -type DotDiagramType string - -const ( - DotDiagramIndexes DotDiagramType = "indexes" - DotDiagramFields DotDiagramType = "fields" - DotDiagramRunStatus DotDiagramType = "run_status" -) - -func NodeBatchStatusToColor(status wfmodel.NodeBatchStatusType) string { - switch status { - case wfmodel.NodeBatchNone: - return "white" - case wfmodel.NodeBatchStart: - return "lightblue" - case wfmodel.NodeBatchSuccess: - return "green" - case wfmodel.NodeBatchFail: - return "red" - case wfmodel.NodeBatchRunStopReceived: - return "orangered" - default: - return "cyan" - } -} - -func GetDotDiagram(scriptDef *sc.ScriptDef, dotDiagramType DotDiagramType, nodeColorMap map[string]string) string { - var b strings.Builder - - const RecordFontSize int = 20 - const ArrowFontSize int = 18 - - urlEscaper := strings.NewReplacer(`{`, `\{`, `}`, `\}`, `|`, `\|`) - b.WriteString(fmt.Sprintf("\ndigraph %s {\nrankdir=\"TD\";\n node [fontname=\"Helvetica\"];\nedge [fontname=\"Helvetica\"];\ngraph [splines=true, pad=\"0.5\", ranksep=\"0.5\", nodesep=\"0.5\"];\n", dotDiagramType)) - for _, node := range scriptDef.ScriptNodes { - penWidth := "1" - if node.StartPolicy == sc.NodeStartManual { - penWidth = "6" - } - fillColor := "white" - var ok bool - if nodeColorMap != nil { - if fillColor, ok = nodeColorMap[node.Name]; !ok { - fillColor = "white" // This run does not affect this node, or the node was not started - } - } - - if node.HasFileReader() { - arrowLabelBuilder := strings.Builder{} - if dotDiagramType == DotDiagramType(DotDiagramFields) { - for colName := range node.FileReader.Columns { - 
arrowLabelBuilder.WriteString(colName) - arrowLabelBuilder.WriteString("\\l") - } - } - fileNames := make([]string, len(node.FileReader.SrcFileUrls)) - copy(fileNames, node.FileReader.SrcFileUrls) - - b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dotted, fontsize=\"%d\", label=\"%s\"];\n", node.FileReader.SrcFileUrls[0], node.GetTargetName(), ArrowFontSize, arrowLabelBuilder.String())) - b.WriteString(fmt.Sprintf("\"%s\" [shape=folder, fontsize=\"%d\", label=\"%s\", tooltip=\"Source data file(s)\"];\n", node.FileReader.SrcFileUrls[0], RecordFontSize, strings.Join(fileNames, "\\n"))) - } - - allUsedFields := sc.FieldRefs{} - - if node.HasFileCreator() { - usedInAllTargetFileExpressions := node.FileCreator.GetFieldRefsUsedInAllTargetFileExpressions() - allUsedFields.Append(usedInAllTargetFileExpressions) - } else if node.HasTableCreator() { - usedInAllTargetTableExpressions := sc.GetFieldRefsUsedInAllTargetExpressions(node.TableCreator.Fields) - allUsedFields.Append(usedInAllTargetTableExpressions) - } - - if node.HasTableReader() { - var inSrcArrowLabel string - if dotDiagramType == DotDiagramType(DotDiagramIndexes) || dotDiagramType == DotDiagramType(DotDiagramRunStatus) { - if node.TableReader.ExpectedBatchesTotal > 1 { - inSrcArrowLabel = fmt.Sprintf("%s (%d batches)", node.TableReader.TableName, node.TableReader.ExpectedBatchesTotal) - } else { - inSrcArrowLabel = fmt.Sprintf("%s (no parallelism)", node.TableReader.TableName) - } - } else if dotDiagramType == DotDiagramType(DotDiagramFields) { - inSrcArrowLabelBuilder := strings.Builder{} - for i := 0; i < len(allUsedFields); i++ { - if allUsedFields[i].TableName == sc.ReaderAlias { - inSrcArrowLabelBuilder.WriteString(allUsedFields[i].FieldName) - inSrcArrowLabelBuilder.WriteString("\\l") - } - } - inSrcArrowLabel = inSrcArrowLabelBuilder.String() - } - if node.HasFileCreator() { - // In (reader) - //if node.TableReader.ExpectedBatchesTotal > 1 { - //b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid];\n", node.TableReader.TableName, node.Name)) - //b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid];\n", node.TableReader.TableName, node.Name)) - //} - b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid, fontsize=\"%d\", label=\"%s\"];\n", node.TableReader.TableName, node.Name, ArrowFontSize, inSrcArrowLabel)) - - // Node (file) - b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates file:\\n%s}\", tooltip=\"%s\"];\n", node.Name, penWidth, RecordFontSize, fillColor, node.Name, urlEscaper.Replace(node.FileCreator.UrlTemplate), node.Desc)) - - // Out (file) - arrowLabelBuilder := strings.Builder{} - if dotDiagramType == DotDiagramType(DotDiagramFields) { - for i := 0; i < len(allUsedFields); i++ { - arrowLabelBuilder.WriteString(allUsedFields[i].FieldName) - arrowLabelBuilder.WriteString("\\l") - } - } - - b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dotted, fontsize=\"%d\", label=\"%s\"];\n", node.Name, node.FileCreator.UrlTemplate, ArrowFontSize, arrowLabelBuilder.String())) - b.WriteString(fmt.Sprintf("\"%s\" [shape=note, fontsize=\"%d\", label=\"%s\", tooltip=\"Target data file(s)\"];\n", node.FileCreator.UrlTemplate, RecordFontSize, node.FileCreator.UrlTemplate)) - } else { - // In (reader) - // if node.TableReader.ExpectedBatchesTotal > 1 { - // b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid];\n", node.TableReader.TableName, node.GetTargetName())) - // b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" 
[style=solid];\n", node.TableReader.TableName, node.GetTargetName())) - // } - b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid, fontsize=\"%d\", label=\"%s\"];\n", node.TableReader.TableName, node.GetTargetName(), ArrowFontSize, inSrcArrowLabel)) - } - - if node.HasLookup() { - inLkpArrowLabel := fmt.Sprintf("%s (lookup)", node.Lookup.IndexName) - if dotDiagramType == DotDiagramType(DotDiagramFields) { - inLkpArrowLabelBuilder := strings.Builder{} - for i := 0; i < len(allUsedFields); i++ { - if allUsedFields[i].TableName == sc.LookupAlias { - inLkpArrowLabelBuilder.WriteString(allUsedFields[i].FieldName) - inLkpArrowLabelBuilder.WriteString("\\l") - } - } - inLkpArrowLabel = inLkpArrowLabelBuilder.String() - } - // In (lookup) - b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dashed, fontsize=\"%d\", label=\"%s\"];\n", node.Lookup.TableCreator.Name, node.GetTargetName(), ArrowFontSize, inLkpArrowLabel)) - } - - } - - if node.HasTableCreator() { - // Node (table) - if node.HasLookup() { - b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates table:\\n%s|group:%t, join:%s}\", tooltip=\"%s\"];\n", node.TableCreator.Name, penWidth, RecordFontSize, fillColor, node.Name, node.TableCreator.Name, node.Lookup.IsGroup, node.Lookup.LookupJoin, node.Desc)) - } else { - b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates table:\\n%s}\", tooltip=\"%s\"];\n", node.TableCreator.Name, penWidth, RecordFontSize, fillColor, node.Name, node.TableCreator.Name, node.Desc)) - } - } - } - b.WriteString("}\n") - - return b.String() -} - -const LogTsFormatUnquoted = `2006-01-02T15:04:05.000-0700` - -type StandardToolbeltProcessorDefFactory struct { -} - -func (f *StandardToolbeltProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { - // All processors to be supported by this 'stock' binary (daemon/toolbelt). 
- // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, - // they all must implement CustomProcessorRunner interface - switch processorType { - case py_calc.ProcessorPyCalcName: - return &py_calc.PyCalcProcessorDef{}, true - case tag_and_denormalize.ProcessorTagAndDenormalizeName: - return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true - default: - return nil, false - } -} - -func stringToArrayOfInt16(s string) ([]int16, error) { - var result []int16 - if len(strings.TrimSpace(s)) > 0 { - stringItems := strings.Split(s, ",") - result = make([]int16, len(stringItems)) - for itemIdx, stringItem := range stringItems { - intItem, err := strconv.ParseInt(strings.TrimSpace(stringItem), 10, 16) - if err != nil { - return nil, fmt.Errorf("invalid int16 %s:%s", stringItem, err.Error()) - } - result[itemIdx] = int16(intItem) - } - } - return result, nil -} - -func stringToArrayOfStrings(s string) ([]string, error) { - var result []string - if len(strings.TrimSpace(s)) > 0 { - stringItems := strings.Split(s, ",") - result = make([]string, len(stringItems)) - for itemIdx, stringItem := range stringItems { - result[itemIdx] = stringItem - } - } - return result, nil -} - -const ( - CmdValidateScript string = "validate_script" - CmdStartRun string = "start_run" - CmdStopRun string = "stop_run" - CmdExecNode string = "exec_node" - CmdGetRunHistory string = "get_run_history" - CmdGetNodeHistory string = "get_node_history" - CmdGetBatchHistory string = "get_batch_history" - CmdGetRunStatusDiagram string = "drop_run_status_diagram" - CmdDropKeyspace string = "drop_keyspace" - CmdGetTableCql string = "get_table_cql" -) - -func usage(flagset *flag.FlagSet) { - fmt.Printf("Capillaries toolbelt\nUsage: capitoolbelt \nCommands:\n") - fmt.Printf(" %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n", - CmdValidateScript, - CmdStartRun, - CmdStopRun, - CmdExecNode, - CmdGetRunHistory, - CmdGetNodeHistory, - CmdGetBatchHistory, - CmdGetRunStatusDiagram, - CmdDropKeyspace, - CmdGetTableCql) - if flagset != nil { - fmt.Printf("\n%s parameters:\n", flagset.Name()) - flagset.PrintDefaults() - } - os.Exit(0) -} - -func main() { - //defer profile.Start().Stop() - - envConfig, err := env.ReadEnvConfigFile("capitoolbelt.json") - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - envConfig.CustomProcessorDefFactoryInstance = &StandardToolbeltProcessorDefFactory{} - logger, err := l.NewLoggerFromEnvConfig(envConfig) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - defer logger.Close() - - if len(os.Args) <= 1 { - usage(nil) - } - - switch os.Args[1] { - case CmdValidateScript: - validateScriptCmd := flag.NewFlagSet(CmdValidateScript, flag.ExitOnError) - scriptFilePath := validateScriptCmd.String("script_file", "", "Path to script file") - paramsFilePath := validateScriptCmd.String("params_file", "", "Path to script parameters map file") - isIdxDag := validateScriptCmd.Bool("idx_dag", false, "Print index DAG") - isFieldDag := validateScriptCmd.Bool("field_dag", false, "Print field DAG") - if err := validateScriptCmd.Parse(os.Args[2:]); err != nil || *scriptFilePath == "" || *paramsFilePath == "" || (!*isIdxDag && !*isFieldDag) { - usage(validateScriptCmd) - } - - script, err, _ := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - if *isIdxDag { - 
fmt.Println(GetDotDiagram(script, DotDiagramIndexes, nil)) - } - if *isFieldDag { - fmt.Println(GetDotDiagram(script, DotDiagramFields, nil)) - } - - case CmdStopRun: - stopRunCmd := flag.NewFlagSet(CmdStopRun, flag.ExitOnError) - keyspace := stopRunCmd.String("keyspace", "", "Keyspace (session id)") - runIdString := stopRunCmd.String("run_id", "", "Run id") - if err := stopRunCmd.Parse(os.Args[2:]); err != nil { - usage(stopRunCmd) - } - - runId, err := strconv.ParseInt(strings.TrimSpace(*runIdString), 10, 16) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - err = api.StopRun(logger, cqlSession, *keyspace, int16(runId), "stopped by toolbelt") - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - case CmdGetRunHistory: - getRunsCmd := flag.NewFlagSet(CmdGetRunHistory, flag.ExitOnError) - keyspace := getRunsCmd.String("keyspace", "", "Keyspace (session id)") - if err := getRunsCmd.Parse(os.Args[2:]); err != nil { - usage(getRunsCmd) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - runs, err := api.GetRunHistory(logger, cqlSession, *keyspace) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - fmt.Println(strings.Join(wfmodel.RunHistoryEventAllFields(), ",")) - for _, r := range runs { - fmt.Printf("%s,%d,%d,%s\n", r.Ts.Format(LogTsFormatUnquoted), r.RunId, r.Status, strings.ReplaceAll(r.Comment, ",", ";")) - } - - case CmdGetNodeHistory: - getNodeHistoryCmd := flag.NewFlagSet(CmdGetNodeHistory, flag.ExitOnError) - keyspace := getNodeHistoryCmd.String("keyspace", "", "Keyspace (session id)") - runIdsString := getNodeHistoryCmd.String("run_ids", "", "Limit results to specific run ids (optional), comma-separated list") - if err := getNodeHistoryCmd.Parse(os.Args[2:]); err != nil { - usage(getNodeHistoryCmd) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - runIds, err := stringToArrayOfInt16(*runIdsString) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - nodes, err := api.GetRunsNodeHistory(logger, cqlSession, *keyspace, runIds) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - fmt.Println(strings.Join(wfmodel.NodeHistoryEventAllFields(), ",")) - for _, n := range nodes { - fmt.Printf("%s,%d,%s,%d,%s\n", n.Ts.Format(LogTsFormatUnquoted), n.RunId, n.ScriptNode, n.Status, strings.ReplaceAll(n.Comment, ",", ";")) - } - - case CmdGetBatchHistory: - getBatchHistoryCmd := flag.NewFlagSet(CmdGetBatchHistory, flag.ExitOnError) - keyspace := getBatchHistoryCmd.String("keyspace", "", "Keyspace (session id)") - runIdsString := getBatchHistoryCmd.String("run_ids", "", "Limit results to specific run ids (optional), comma-separated list") - nodeNamesString := getBatchHistoryCmd.String("nodes", "", "Limit results to specific node names (optional), comma-separated list") - if err := getBatchHistoryCmd.Parse(os.Args[2:]); err != nil { - usage(getBatchHistoryCmd) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - runIds, err := stringToArrayOfInt16(*runIdsString) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - nodeNames, err := 
stringToArrayOfStrings(*nodeNamesString) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - runs, err := api.GetBatchHistory(logger, cqlSession, *keyspace, runIds, nodeNames) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - fmt.Println(strings.Join(wfmodel.BatchHistoryEventAllFields(), ",")) - for _, r := range runs { - fmt.Printf("%s,%d,%s,%d,%d,%d,%d,%d,%s\n", r.Ts.Format(LogTsFormatUnquoted), r.RunId, r.ScriptNode, r.BatchIdx, r.BatchesTotal, r.Status, r.FirstToken, r.LastToken, strings.ReplaceAll(r.Comment, ",", ";")) - } - - case CmdGetTableCql: - getTableCqlCmd := flag.NewFlagSet(CmdGetTableCql, flag.ExitOnError) - scriptFilePath := getTableCqlCmd.String("script_file", "", "Path to script file") - paramsFilePath := getTableCqlCmd.String("params_file", "", "Path to script parameters map file") - keyspace := getTableCqlCmd.String("keyspace", "", "Keyspace (session id)") - runId := getTableCqlCmd.Int("run_id", 0, "Run id") - startNodesString := getTableCqlCmd.String("start_nodes", "", "Comma-separated list of start node names") - if err := getTableCqlCmd.Parse(os.Args[2:]); err != nil { - usage(getTableCqlCmd) - } - - script, err, _ := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - startNodes := strings.Split(*startNodesString, ",") - - fmt.Print(api.GetTablesCql(script, *keyspace, int16(*runId), startNodes)) - - case CmdGetRunStatusDiagram: - getRunStatusDiagramCmd := flag.NewFlagSet(CmdGetRunStatusDiagram, flag.ExitOnError) - scriptFilePath := getRunStatusDiagramCmd.String("script_file", "", "Path to script file") - paramsFilePath := getRunStatusDiagramCmd.String("params_file", "", "Path to script parameters map file") - keyspace := getRunStatusDiagramCmd.String("keyspace", "", "Keyspace (session id)") - runIdString := getRunStatusDiagramCmd.String("run_id", "", "Run id") - if err := getRunStatusDiagramCmd.Parse(os.Args[2:]); err != nil { - usage(getRunStatusDiagramCmd) - } - - runId, err := strconv.ParseInt(strings.TrimSpace(*runIdString), 10, 16) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - script, err, _ := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - nodes, err := api.GetRunsNodeHistory(logger, cqlSession, *keyspace, []int16{int16(runId)}) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - nodeColorMap := map[string]string{} - for _, node := range nodes { - nodeColorMap[node.ScriptNode] = NodeBatchStatusToColor(node.Status) - } - - fmt.Println(GetDotDiagram(script, DotDiagramRunStatus, nodeColorMap)) - - case CmdDropKeyspace: - dropKsCmd := flag.NewFlagSet(CmdDropKeyspace, flag.ExitOnError) - keyspace := dropKsCmd.String("keyspace", "", "Keyspace (session id)") - if err := dropKsCmd.Parse(os.Args[2:]); err != nil { - usage(dropKsCmd) - } - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - err = api.DropKeyspace(logger, cqlSession, *keyspace) - if err != nil { - 
log.Fatalf(err.Error()) - os.Exit(1) - } - - case CmdStartRun: - startRunCmd := flag.NewFlagSet(CmdStartRun, flag.ExitOnError) - keyspace := startRunCmd.String("keyspace", "", "Keyspace (session id)") - scriptFilePath := startRunCmd.String("script_file", "", "Path to script file") - paramsFilePath := startRunCmd.String("params_file", "", "Path to script parameters map file") - startNodesString := startRunCmd.String("start_nodes", "", "Comma-separated list of start node names") - if err := startRunCmd.Parse(os.Args[2:]); err != nil { - usage(startRunCmd) - } - - startNodes := strings.Split(*startNodesString, ",") - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.CreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - // RabbitMQ boilerplate - amqpConnection, err := amqp.Dial(envConfig.Amqp.URL) - if err != nil { - log.Fatalf(fmt.Sprintf("cannot dial RabbitMQ at %v, will reconnect: %v\n", envConfig.Amqp.URL, err)) - os.Exit(1) - } - defer amqpConnection.Close() - - amqpChannel, err := amqpConnection.Channel() - if err != nil { - log.Fatalf(fmt.Sprintf("cannot create amqp channel, will reconnect: %v\n", err)) - os.Exit(1) - } - defer amqpChannel.Close() - - runId, err := api.StartRun(envConfig, logger, amqpChannel, *scriptFilePath, *paramsFilePath, cqlSession, *keyspace, startNodes, "started by Toolbelt") - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - fmt.Println(runId) - - case CmdExecNode: - execNodeCmd := flag.NewFlagSet(CmdExecNode, flag.ExitOnError) - keyspace := execNodeCmd.String("keyspace", "", "Keyspace (session id)") - scriptFilePath := execNodeCmd.String("script_file", "", "Path to script file") - paramsFilePath := execNodeCmd.String("params_file", "", "Path to script parameters map file") - runIdParam := execNodeCmd.Int("run_id", 0, "run id (optional, use with extra caution as it will modify existing run id results)") - nodeName := execNodeCmd.String("node_id", "", "Script node name") - if err := execNodeCmd.Parse(os.Args[2:]); err != nil { - usage(execNodeCmd) - } - - runId := int16(*runIdParam) - - startTime := time.Now() - - cqlSession, err := db.NewSession(envConfig, *keyspace, db.CreateKeyspaceOnConnect) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - runId, err = api.RunNode(envConfig, logger, *nodeName, runId, *scriptFilePath, *paramsFilePath, cqlSession, *keyspace) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - fmt.Printf("run %d, elapsed %v\n", runId, time.Since(startTime)) - - default: - fmt.Printf("invalid command: %s\n", os.Args[1]) - usage(nil) - } -} +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strconv" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/api" + "github.com/capillariesio/capillaries/pkg/custom/py_calc" + "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfmodel" + amqp "github.com/rabbitmq/amqp091-go" +) + +type DotDiagramType string + +const ( + DotDiagramIndexes DotDiagramType = "indexes" + DotDiagramFields DotDiagramType = "fields" + DotDiagramRunStatus DotDiagramType = "run_status" +) + +func NodeBatchStatusToColor(status wfmodel.NodeBatchStatusType) string { + switch status { + case wfmodel.NodeBatchNone: + return "white" + case wfmodel.NodeBatchStart: + return 
"lightblue" + case wfmodel.NodeBatchSuccess: + return "green" + case wfmodel.NodeBatchFail: + return "red" + case wfmodel.NodeBatchRunStopReceived: + return "orangered" + default: + return "cyan" + } +} + +func GetDotDiagram(scriptDef *sc.ScriptDef, dotDiagramType DotDiagramType, nodeColorMap map[string]string) string { + var b strings.Builder + + const recordFontSize int = 20 + const arrowFontSize int = 18 + + urlEscaper := strings.NewReplacer(`{`, `\{`, `}`, `\}`, `|`, `\|`) + b.WriteString(fmt.Sprintf("\ndigraph %s {\nrankdir=\"TD\";\n node [fontname=\"Helvetica\"];\nedge [fontname=\"Helvetica\"];\ngraph [splines=true, pad=\"0.5\", ranksep=\"0.5\", nodesep=\"0.5\"];\n", dotDiagramType)) + for _, node := range scriptDef.ScriptNodes { + penWidth := "1" + if node.StartPolicy == sc.NodeStartManual { + penWidth = "6" + } + fillColor := "white" + var ok bool + if nodeColorMap != nil { + if fillColor, ok = nodeColorMap[node.Name]; !ok { + fillColor = "white" // This run does not affect this node, or the node was not started + } + } + + if node.HasFileReader() { + arrowLabelBuilder := strings.Builder{} + if dotDiagramType == DotDiagramType(DotDiagramFields) { + for colName := range node.FileReader.Columns { + arrowLabelBuilder.WriteString(colName) + arrowLabelBuilder.WriteString("\\l") + } + } + fileNames := make([]string, len(node.FileReader.SrcFileUrls)) + copy(fileNames, node.FileReader.SrcFileUrls) + + b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dotted, fontsize=\"%d\", label=\"%s\"];\n", node.FileReader.SrcFileUrls[0], node.GetTargetName(), arrowFontSize, arrowLabelBuilder.String())) + b.WriteString(fmt.Sprintf("\"%s\" [shape=folder, fontsize=\"%d\", label=\"%s\", tooltip=\"Source data file(s)\"];\n", node.FileReader.SrcFileUrls[0], recordFontSize, strings.Join(fileNames, "\\n"))) + } + + allUsedFields := sc.FieldRefs{} + + if node.HasFileCreator() { + usedInAllTargetFileExpressions := node.FileCreator.GetFieldRefsUsedInAllTargetFileExpressions() + allUsedFields.Append(usedInAllTargetFileExpressions) + } else if node.HasTableCreator() { + usedInAllTargetTableExpressions := sc.GetFieldRefsUsedInAllTargetExpressions(node.TableCreator.Fields) + allUsedFields.Append(usedInAllTargetTableExpressions) + } + + if node.HasTableReader() { + var inSrcArrowLabel string + if dotDiagramType == DotDiagramType(DotDiagramIndexes) || dotDiagramType == DotDiagramType(DotDiagramRunStatus) { + if node.TableReader.ExpectedBatchesTotal > 1 { + inSrcArrowLabel = fmt.Sprintf("%s (%d batches)", node.TableReader.TableName, node.TableReader.ExpectedBatchesTotal) + } else { + inSrcArrowLabel = fmt.Sprintf("%s (no parallelism)", node.TableReader.TableName) + } + } else if dotDiagramType == DotDiagramType(DotDiagramFields) { + inSrcArrowLabelBuilder := strings.Builder{} + for i := 0; i < len(allUsedFields); i++ { + if allUsedFields[i].TableName == sc.ReaderAlias { + inSrcArrowLabelBuilder.WriteString(allUsedFields[i].FieldName) + inSrcArrowLabelBuilder.WriteString("\\l") + } + } + inSrcArrowLabel = inSrcArrowLabelBuilder.String() + } + if node.HasFileCreator() { + b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid, fontsize=\"%d\", label=\"%s\"];\n", node.TableReader.TableName, node.Name, arrowFontSize, inSrcArrowLabel)) + + // Node (file) + b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates file:\\n%s}\", tooltip=\"%s\"];\n", node.Name, penWidth, recordFontSize, fillColor, node.Name, 
urlEscaper.Replace(node.FileCreator.UrlTemplate), node.Desc)) + + // Out (file) + arrowLabelBuilder := strings.Builder{} + if dotDiagramType == DotDiagramType(DotDiagramFields) { + for i := 0; i < len(allUsedFields); i++ { + arrowLabelBuilder.WriteString(allUsedFields[i].FieldName) + arrowLabelBuilder.WriteString("\\l") + } + } + + b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dotted, fontsize=\"%d\", label=\"%s\"];\n", node.Name, node.FileCreator.UrlTemplate, arrowFontSize, arrowLabelBuilder.String())) + b.WriteString(fmt.Sprintf("\"%s\" [shape=note, fontsize=\"%d\", label=\"%s\", tooltip=\"Target data file(s)\"];\n", node.FileCreator.UrlTemplate, recordFontSize, node.FileCreator.UrlTemplate)) + } else { + b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=solid, fontsize=\"%d\", label=\"%s\"];\n", node.TableReader.TableName, node.GetTargetName(), arrowFontSize, inSrcArrowLabel)) + } + + if node.HasLookup() { + inLkpArrowLabel := fmt.Sprintf("%s (lookup)", node.Lookup.IndexName) + if dotDiagramType == DotDiagramType(DotDiagramFields) { + inLkpArrowLabelBuilder := strings.Builder{} + for i := 0; i < len(allUsedFields); i++ { + if allUsedFields[i].TableName == sc.LookupAlias { + inLkpArrowLabelBuilder.WriteString(allUsedFields[i].FieldName) + inLkpArrowLabelBuilder.WriteString("\\l") + } + } + inLkpArrowLabel = inLkpArrowLabelBuilder.String() + } + // In (lookup) + b.WriteString(fmt.Sprintf("\"%s\" -> \"%s\" [style=dashed, fontsize=\"%d\", label=\"%s\"];\n", node.Lookup.TableCreator.Name, node.GetTargetName(), arrowFontSize, inLkpArrowLabel)) + } + + } + + if node.HasTableCreator() { + // Node (table) + if node.HasLookup() { + b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates table:\\n%s|group:%t, join:%s}\", tooltip=\"%s\"];\n", node.TableCreator.Name, penWidth, recordFontSize, fillColor, node.Name, node.TableCreator.Name, node.Lookup.IsGroup, node.Lookup.LookupJoin, node.Desc)) + } else { + b.WriteString(fmt.Sprintf("\"%s\" [shape=record, penwidth=\"%s\", fontsize=\"%d\", fillcolor=\"%s\", style=\"filled\", label=\"{%s|creates table:\\n%s}\", tooltip=\"%s\"];\n", node.TableCreator.Name, penWidth, recordFontSize, fillColor, node.Name, node.TableCreator.Name, node.Desc)) + } + } + } + b.WriteString("}\n") + + return b.String() +} + +const LogTsFormatUnquoted = `2006-01-02T15:04:05.000-0700` + +type StandardToolbeltProcessorDefFactory struct { +} + +func (f *StandardToolbeltProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { + // All processors to be supported by this 'stock' binary (daemon/toolbelt). 
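+ // For a supported processor type, Create returns its processor definition and true; for anything else it returns nil and false.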
+ // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, + // they all must implement CustomProcessorRunner interface + switch processorType { + case py_calc.ProcessorPyCalcName: + return &py_calc.PyCalcProcessorDef{}, true + case tag_and_denormalize.ProcessorTagAndDenormalizeName: + return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true + default: + return nil, false + } +} + +func stringToArrayOfInt16(s string) ([]int16, error) { + var result []int16 + if len(strings.TrimSpace(s)) > 0 { + stringItems := strings.Split(s, ",") + result = make([]int16, len(stringItems)) + for itemIdx, stringItem := range stringItems { + intItem, err := strconv.ParseInt(strings.TrimSpace(stringItem), 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid int16 %s:%s", stringItem, err.Error()) + } + result[itemIdx] = int16(intItem) + } + } + return result, nil +} + +const ( + CmdValidateScript string = "validate_script" + CmdStartRun string = "start_run" + CmdStopRun string = "stop_run" + CmdExecNode string = "exec_node" + CmdGetRunHistory string = "get_run_history" + CmdGetNodeHistory string = "get_node_history" + CmdGetBatchHistory string = "get_batch_history" + CmdGetRunStatusDiagram string = "drop_run_status_diagram" + CmdDropKeyspace string = "drop_keyspace" + CmdGetTableCql string = "get_table_cql" +) + +func usage(flagset *flag.FlagSet) { + fmt.Printf("Capillaries toolbelt\nUsage: capitoolbelt \nCommands:\n") + fmt.Printf(" %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n %s\n", + CmdValidateScript, + CmdStartRun, + CmdStopRun, + CmdExecNode, + CmdGetRunHistory, + CmdGetNodeHistory, + CmdGetBatchHistory, + CmdGetRunStatusDiagram, + CmdDropKeyspace, + CmdGetTableCql) + if flagset != nil { + fmt.Printf("\n%s parameters:\n", flagset.Name()) + flagset.PrintDefaults() + } +} + +func main() { + // defer profile.Start().Stop() + + envConfig, err := env.ReadEnvConfigFile("capitoolbelt.json") + if err != nil { + log.Fatalf(err.Error()) + } + envConfig.CustomProcessorDefFactoryInstance = &StandardToolbeltProcessorDefFactory{} + logger, err := l.NewLoggerFromEnvConfig(envConfig) + if err != nil { + log.Fatalf(err.Error()) + } + defer logger.Close() + + if len(os.Args) <= 1 { + usage(nil) + os.Exit(0) + } + + switch os.Args[1] { + case CmdValidateScript: + validateScriptCmd := flag.NewFlagSet(CmdValidateScript, flag.ExitOnError) + scriptFilePath := validateScriptCmd.String("script_file", "", "Path to script file") + paramsFilePath := validateScriptCmd.String("params_file", "", "Path to script parameters map file") + isIdxDag := validateScriptCmd.Bool("idx_dag", false, "Print index DAG") + isFieldDag := validateScriptCmd.Bool("field_dag", false, "Print field DAG") + if err := validateScriptCmd.Parse(os.Args[2:]); err != nil || *scriptFilePath == "" || *paramsFilePath == "" || (!*isIdxDag && !*isFieldDag) { + usage(validateScriptCmd) + os.Exit(0) + } + + script, _, err := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if err != nil { + log.Fatalf(err.Error()) + } + + if *isIdxDag { + fmt.Println(GetDotDiagram(script, DotDiagramIndexes, nil)) + } + if *isFieldDag { + fmt.Println(GetDotDiagram(script, DotDiagramFields, nil)) + } + + case CmdStopRun: + stopRunCmd := flag.NewFlagSet(CmdStopRun, flag.ExitOnError) + keyspace := stopRunCmd.String("keyspace", "", "Keyspace (session id)") + runIdString := stopRunCmd.String("run_id", "", 
"Run id") + if err := stopRunCmd.Parse(os.Args[2:]); err != nil { + usage(stopRunCmd) + os.Exit(0) + } + + runId, err := strconv.ParseInt(strings.TrimSpace(*runIdString), 10, 16) + if err != nil { + log.Fatalf(err.Error()) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + err = api.StopRun(logger, cqlSession, *keyspace, int16(runId), "stopped by toolbelt") + if err != nil { + log.Fatalf(err.Error()) + } + + case CmdGetRunHistory: + getRunsCmd := flag.NewFlagSet(CmdGetRunHistory, flag.ExitOnError) + keyspace := getRunsCmd.String("keyspace", "", "Keyspace (session id)") + if err := getRunsCmd.Parse(os.Args[2:]); err != nil { + usage(getRunsCmd) + os.Exit(0) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + runs, err := api.GetRunHistory(logger, cqlSession, *keyspace) + if err != nil { + log.Fatalf(err.Error()) + } + fmt.Println(strings.Join(wfmodel.RunHistoryEventAllFields(), ",")) + for _, r := range runs { + fmt.Printf("%s,%d,%d,%s\n", r.Ts.Format(LogTsFormatUnquoted), r.RunId, r.Status, strings.ReplaceAll(r.Comment, ",", ";")) + } + + case CmdGetNodeHistory: + getNodeHistoryCmd := flag.NewFlagSet(CmdGetNodeHistory, flag.ExitOnError) + keyspace := getNodeHistoryCmd.String("keyspace", "", "Keyspace (session id)") + runIdsString := getNodeHistoryCmd.String("run_ids", "", "Limit results to specific run ids (optional), comma-separated list") + if err := getNodeHistoryCmd.Parse(os.Args[2:]); err != nil { + usage(getNodeHistoryCmd) + os.Exit(0) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + runIds, err := stringToArrayOfInt16(*runIdsString) + if err != nil { + log.Fatalf(err.Error()) + } + + nodes, err := api.GetRunsNodeHistory(logger, cqlSession, *keyspace, runIds) + if err != nil { + log.Fatalf(err.Error()) + } + fmt.Println(strings.Join(wfmodel.NodeHistoryEventAllFields(), ",")) + for _, n := range nodes { + fmt.Printf("%s,%d,%s,%d,%s\n", n.Ts.Format(LogTsFormatUnquoted), n.RunId, n.ScriptNode, n.Status, strings.ReplaceAll(n.Comment, ",", ";")) + } + + case CmdGetBatchHistory: + getBatchHistoryCmd := flag.NewFlagSet(CmdGetBatchHistory, flag.ExitOnError) + keyspace := getBatchHistoryCmd.String("keyspace", "", "Keyspace (session id)") + runIdsString := getBatchHistoryCmd.String("run_ids", "", "Limit results to specific run ids (optional), comma-separated list") + nodeNamesString := getBatchHistoryCmd.String("nodes", "", "Limit results to specific node names (optional), comma-separated list") + if err := getBatchHistoryCmd.Parse(os.Args[2:]); err != nil { + usage(getBatchHistoryCmd) + os.Exit(0) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + runIds, err := stringToArrayOfInt16(*runIdsString) + if err != nil { + log.Fatalf(err.Error()) + } + + var nodeNames []string + if len(strings.TrimSpace(*nodeNamesString)) > 0 { + nodeNames = strings.Split(*nodeNamesString, ",") + } else { + nodeNames = make([]string, 0) + } + + runs, err := api.GetBatchHistory(logger, cqlSession, *keyspace, runIds, nodeNames) + if err != nil { + log.Fatalf(err.Error()) + } + + fmt.Println(strings.Join(wfmodel.BatchHistoryEventAllFields(), ",")) + for _, r := range runs { + fmt.Printf("%s,%d,%s,%d,%d,%d,%d,%d,%s\n", 
r.Ts.Format(LogTsFormatUnquoted), r.RunId, r.ScriptNode, r.BatchIdx, r.BatchesTotal, r.Status, r.FirstToken, r.LastToken, strings.ReplaceAll(r.Comment, ",", ";")) + } + + case CmdGetTableCql: + getTableCqlCmd := flag.NewFlagSet(CmdGetTableCql, flag.ExitOnError) + scriptFilePath := getTableCqlCmd.String("script_file", "", "Path to script file") + paramsFilePath := getTableCqlCmd.String("params_file", "", "Path to script parameters map file") + keyspace := getTableCqlCmd.String("keyspace", "", "Keyspace (session id)") + runId := getTableCqlCmd.Int("run_id", 0, "Run id") + startNodesString := getTableCqlCmd.String("start_nodes", "", "Comma-separated list of start node names") + if err := getTableCqlCmd.Parse(os.Args[2:]); err != nil { + usage(getTableCqlCmd) + os.Exit(0) + } + + script, _, err := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if err != nil { + log.Fatalf(err.Error()) + } + + startNodes := strings.Split(*startNodesString, ",") + + fmt.Print(api.GetTablesCql(script, *keyspace, int16(*runId), startNodes)) + + case CmdGetRunStatusDiagram: + getRunStatusDiagramCmd := flag.NewFlagSet(CmdGetRunStatusDiagram, flag.ExitOnError) + scriptFilePath := getRunStatusDiagramCmd.String("script_file", "", "Path to script file") + paramsFilePath := getRunStatusDiagramCmd.String("params_file", "", "Path to script parameters map file") + keyspace := getRunStatusDiagramCmd.String("keyspace", "", "Keyspace (session id)") + runIdString := getRunStatusDiagramCmd.String("run_id", "", "Run id") + if err := getRunStatusDiagramCmd.Parse(os.Args[2:]); err != nil { + usage(getRunStatusDiagramCmd) + os.Exit(0) + } + + runId, err := strconv.ParseInt(strings.TrimSpace(*runIdString), 10, 16) + if err != nil { + log.Fatalf(err.Error()) + } + + script, _, err := sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, *scriptFilePath, *paramsFilePath, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if err != nil { + log.Fatalf(err.Error()) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + nodes, err := api.GetRunsNodeHistory(logger, cqlSession, *keyspace, []int16{int16(runId)}) + if err != nil { + log.Fatalf(err.Error()) + } + + nodeColorMap := map[string]string{} + for _, node := range nodes { + nodeColorMap[node.ScriptNode] = NodeBatchStatusToColor(node.Status) + } + + fmt.Println(GetDotDiagram(script, DotDiagramRunStatus, nodeColorMap)) + + case CmdDropKeyspace: + dropKsCmd := flag.NewFlagSet(CmdDropKeyspace, flag.ExitOnError) + keyspace := dropKsCmd.String("keyspace", "", "Keyspace (session id)") + if err := dropKsCmd.Parse(os.Args[2:]); err != nil { + usage(dropKsCmd) + os.Exit(0) + } + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + err = api.DropKeyspace(logger, cqlSession, *keyspace) + if err != nil { + log.Fatalf(err.Error()) + } + + case CmdStartRun: + startRunCmd := flag.NewFlagSet(CmdStartRun, flag.ExitOnError) + keyspace := startRunCmd.String("keyspace", "", "Keyspace (session id)") + scriptFilePath := startRunCmd.String("script_file", "", "Path to script file") + paramsFilePath := startRunCmd.String("params_file", "", "Path to script parameters map file") + startNodesString := startRunCmd.String("start_nodes", "", "Comma-separated list of 
start node names") + if err := startRunCmd.Parse(os.Args[2:]); err != nil { + usage(startRunCmd) + os.Exit(0) + } + + startNodes := strings.Split(*startNodesString, ",") + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.CreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + // RabbitMQ boilerplate + amqpConnection, err := amqp.Dial(envConfig.Amqp.URL) + if err != nil { + log.Fatalf(fmt.Sprintf("cannot dial RabbitMQ at %v, will reconnect: %v\n", envConfig.Amqp.URL, err)) + } + defer amqpConnection.Close() + + amqpChannel, err := amqpConnection.Channel() + if err != nil { + log.Fatalf(fmt.Sprintf("cannot create amqp channel, will reconnect: %v\n", err)) + } + defer amqpChannel.Close() + + runId, err := api.StartRun(envConfig, logger, amqpChannel, *scriptFilePath, *paramsFilePath, cqlSession, *keyspace, startNodes, "started by Toolbelt") + if err != nil { + log.Fatalf(err.Error()) + } + + fmt.Println(runId) + + case CmdExecNode: + execNodeCmd := flag.NewFlagSet(CmdExecNode, flag.ExitOnError) + keyspace := execNodeCmd.String("keyspace", "", "Keyspace (session id)") + scriptFilePath := execNodeCmd.String("script_file", "", "Path to script file") + paramsFilePath := execNodeCmd.String("params_file", "", "Path to script parameters map file") + runIdParam := execNodeCmd.Int("run_id", 0, "run id (optional, use with extra caution as it will modify existing run id results)") + nodeName := execNodeCmd.String("node_id", "", "Script node name") + if err := execNodeCmd.Parse(os.Args[2:]); err != nil { + usage(execNodeCmd) + os.Exit(0) + } + + runId := int16(*runIdParam) + + startTime := time.Now() + + cqlSession, err := db.NewSession(envConfig, *keyspace, db.CreateKeyspaceOnConnect) + if err != nil { + log.Fatalf(err.Error()) + } + + runId, err = api.RunNode(envConfig, logger, *nodeName, runId, *scriptFilePath, *paramsFilePath, cqlSession, *keyspace) + if err != nil { + log.Fatalf(err.Error()) + } + fmt.Printf("run %d, elapsed %v\n", runId, time.Since(startTime)) + + default: + fmt.Printf("invalid command: %s\n", os.Args[1]) + usage(nil) + os.Exit(1) + } +} diff --git a/pkg/exe/webapi/capiwebapi.go b/pkg/exe/webapi/capiwebapi.go index c1f18e6..a3b2de3 100644 --- a/pkg/exe/webapi/capiwebapi.go +++ b/pkg/exe/webapi/capiwebapi.go @@ -1,587 +1,592 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "log" - "net/http" - "os" - "regexp" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/capillariesio/capillaries/pkg/api" - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/custom/py_calc" - "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" - amqp "github.com/rabbitmq/amqp091-go" -) - -type StandardWebapiProcessorDefFactory struct { -} - -func (f *StandardWebapiProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { - // All processors to be supported by this 'stock' binary (daemon/toolbelt/webapi). 
- // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, - // they all must implement CustomProcessorRunner interface - switch processorType { - case py_calc.ProcessorPyCalcName: - return &py_calc.PyCalcProcessorDef{}, true - case tag_and_denormalize.ProcessorTagAndDenormalizeName: - return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true - default: - return nil, false - } -} - -type route struct { - method string - regex *regexp.Regexp - handler http.HandlerFunc -} - -func newRoute(method, pattern string, handler http.HandlerFunc) route { - return route{method, regexp.MustCompile("^" + pattern + "$"), handler} -} - -type ApiResponseError struct { - Msg string `json:"msg"` -} - -type ApiResponse struct { - Data interface{} `json:"data"` - Error ApiResponseError `json:"error"` -} - -func pickAccessControlAllowOrigin(wc *env.WebapiConfig, r *http.Request) string { - if wc.AccessControlAllowOrigin == "*" { - return "*" - } - allowedOrigins := strings.Split(wc.AccessControlAllowOrigin, ",") - requestedOrigins, ok := r.Header["Origin"] - if !ok || len(requestedOrigins) == 0 { - return "no-origins-requested" - } - for _, allowedOrigin := range allowedOrigins { - for _, requestedOrigin := range requestedOrigins { - if strings.ToUpper(requestedOrigin) == strings.ToUpper(allowedOrigin) { - return requestedOrigin - } - } - } - return "no-allowed-origins" -} - -func WriteApiError(l *l.Logger, wc *env.WebapiConfig, r *http.Request, w http.ResponseWriter, urlPath string, err error, httpStatus int) { - w.Header().Set("Access-Control-Allow-Origin", pickAccessControlAllowOrigin(wc, r)) - l.Error("cannot process %s: %s", urlPath, err.Error()) - respJson, err := json.Marshal(ApiResponse{Error: ApiResponseError{Msg: err.Error()}}) - if err != nil { - http.Error(w, fmt.Sprintf("unexpected: cannot serialize error response %s", err.Error()), httpStatus) - } else { - http.Error(w, string(respJson), httpStatus) - } -} - -func WriteApiSuccess(l *l.Logger, wc *env.WebapiConfig, r *http.Request, w http.ResponseWriter, data interface{}) { - w.Header().Set("Access-Control-Allow-Origin", pickAccessControlAllowOrigin(wc, r)) - respJson, err := json.Marshal(ApiResponse{Data: data}) - if err != nil { - http.Error(w, fmt.Sprintf("cannot serialize success response: %s", err.Error()), http.StatusInternalServerError) - } else { - if _, err := w.Write([]byte(respJson)); err != nil { - l.Error("cannot write success response, error %s, response %s", err.Error(), respJson) - } - } -} - -func (h *UrlHandler) ks(w http.ResponseWriter, r *http.Request) { - cqlSession, err := db.NewSession(h.Env, "", db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - defer cqlSession.Close() - - // This works only for Cassandra 4.X, not guaranteed to work for later versions - qb := cql.QueryBuilder{} - q := qb.Keyspace("system_schema").Select("keyspaces", []string{"keyspace_name"}) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - respData := make([]string, len(rows)) - ksCount := 0 - - for _, row := range rows { - ks := row["keyspace_name"].(string) - if len(ks) == 0 || api.IsSystemKeyspaceName(ks) { - continue - } - respData[ksCount] = ks - ksCount++ - } - - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, respData[:ksCount]) -} - -type FullRunInfo struct { - Props 
*wfmodel.RunProperties `json:"props"` - History []*wfmodel.RunHistoryEvent `json:"history"` -} - -type WebapiNodeStatus struct { - RunId int16 `json:"run_id"` - Status wfmodel.NodeBatchStatusType `json:"status"` - Ts string `json:"ts"` -} - -type WebapiNodeRunMatrixRow struct { - NodeName string `json:"node_name"` - NodeDesc string `json:"node_desc"` - NodeStatuses []WebapiNodeStatus `json:"node_statuses"` -} -type WebapiNodeRunMatrix struct { - RunLifespans []*wfmodel.RunLifespan `json:"run_lifespans"` - Nodes []WebapiNodeRunMatrixRow `json:"nodes"` -} - -// Poor man's cache -var NodeDescCache map[string]string = map[string]string{} -var NodeDescCacheLock = sync.RWMutex{} - -func (h *UrlHandler) getNodeDesc(cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) (string, error) { - - NodeDescCacheLock.RLock() - nodeDesc, ok := NodeDescCache[keyspace+nodeName] - NodeDescCacheLock.RUnlock() - if ok { - return nodeDesc, nil - } - - allRunsProps, err := api.GetRunProperties(h.L, cqlSession, keyspace, int16(runId)) - if err != nil { - return "", err - } - - if len(allRunsProps) != 1 { - return "", fmt.Errorf("invalid number of matching runs (%d), expected 1; this usually happens when webapi caller makes wrong assumptions about the process status", len(allRunsProps)) - } - - script, err, _ := sc.NewScriptFromFiles(h.Env.CaPath, h.Env.PrivateKeys, allRunsProps[0].ScriptUri, allRunsProps[0].ScriptParamsUri, h.Env.CustomProcessorDefFactoryInstance, h.Env.CustomProcessorsSettings) - if err != nil { - return "", err - } - - nodeDef, ok := script.ScriptNodes[nodeName] - if !ok { - return "", fmt.Errorf("cannot find node %s", nodeName) - } - - NodeDescCacheLock.RLock() - NodeDescCache[keyspace+nodeName] = nodeDef.Desc - NodeDescCacheLock.RUnlock() - - return nodeDef.Desc, nil -} - -func (h *UrlHandler) ksMatrix(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - defer cqlSession.Close() - - // Retrieve all runs that happened in this ks and find their current statuses - runLifespanMap, err := api.HarvestRunLifespans(h.L, cqlSession, keyspace, []int16{}) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - // Arrange run statuses for the matrix header - mx := WebapiNodeRunMatrix{RunLifespans: make([]*wfmodel.RunLifespan, len(runLifespanMap))} - runCount := 0 - for _, runLifespan := range runLifespanMap { - mx.RunLifespans[runCount] = runLifespan - runCount++ - } - sort.Slice(mx.RunLifespans, func(i, j int) bool { return mx.RunLifespans[i].RunId < mx.RunLifespans[j].RunId }) - - // Retireve all node events for this ks, for all runs - nodeHistory, err := api.GetRunsNodeHistory(h.L, cqlSession, keyspace, []int16{}) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - nodeStartTsMap := map[string]time.Time{} // Arrange by the ts in the last run - - // For each node/run, harvest current node status, latest wins - nodeRunStatusMap := map[string]map[int16]WebapiNodeStatus{} - for _, nodeEvent := range nodeHistory { - if _, ok := nodeRunStatusMap[nodeEvent.ScriptNode]; !ok { - nodeRunStatusMap[nodeEvent.ScriptNode] = map[int16]WebapiNodeStatus{} - } - nodeRunStatusMap[nodeEvent.ScriptNode][nodeEvent.RunId] = 
WebapiNodeStatus{RunId: nodeEvent.RunId, Status: nodeEvent.Status, Ts: nodeEvent.Ts.Format("2006-01-02T15:04:05.000-0700")} - - if nodeEvent.Status == wfmodel.NodeBatchStart { - if _, ok := nodeStartTsMap[nodeEvent.ScriptNode]; !ok { - nodeStartTsMap[nodeEvent.ScriptNode] = nodeEvent.Ts - } - } - } - - // Arrange status in the result mx - mx.Nodes = make([]WebapiNodeRunMatrixRow, len(nodeRunStatusMap)) - nodeCount := 0 - for nodeName, runNodeStatusMap := range nodeRunStatusMap { - nodeDesc, err := h.getNodeDesc(cqlSession, keyspace, runLifespanMap[1].RunId, nodeName) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot get node description: %s", err.Error()), http.StatusInternalServerError) - return - } - mx.Nodes[nodeCount] = WebapiNodeRunMatrixRow{NodeName: nodeName, NodeDesc: nodeDesc, NodeStatuses: make([]WebapiNodeStatus, len(mx.RunLifespans))} - for runIdx, matrixRunLifespan := range mx.RunLifespans { - if nodeStatus, ok := runNodeStatusMap[matrixRunLifespan.RunId]; ok { - mx.Nodes[nodeCount].NodeStatuses[runIdx] = nodeStatus - } - } - nodeCount++ - } - - // Sort nodes: started come first, sorted by start ts, other come after that, sorted by node name - // Ideally, they should be sorted geometrically from DAG, with start ts coming into play when DAG says nodes are equal. - // But this will require script analysis which takes too long. - sort.Slice(mx.Nodes, func(i, j int) bool { - leftTs, leftPresent := nodeStartTsMap[mx.Nodes[i].NodeName] - rightTs, rightPresent := nodeStartTsMap[mx.Nodes[j].NodeName] - if !leftPresent && rightPresent { - return false - } else if leftPresent && !rightPresent { - return true - } else if !leftPresent && !rightPresent { - // Sort by node name - return mx.Nodes[i].NodeName < mx.Nodes[j].NodeName - } else { - return leftTs.Before(rightTs) - } - }) - - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, mx) -} - -func getRunPropsAndLifespans(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) (*wfmodel.RunProperties, *wfmodel.RunLifespan, error) { - // Static run properties - // TODO: consider caching - - allRunsProps, err := api.GetRunProperties(logger, cqlSession, keyspace, int16(runId)) - if err != nil { - return nil, nil, err - } - - if len(allRunsProps) != 1 { - return nil, nil, fmt.Errorf("invalid number of matching runs (%d), expected 1; this usually happens when webapi caller makes wrong assumptions about the process status", len(allRunsProps)) - } - - // Run status - - runLifeSpans, err := api.HarvestRunLifespans(logger, cqlSession, keyspace, []int16{int16(runId)}) - if err != nil { - return nil, nil, err - } - if len(runLifeSpans) != 1 { - return nil, nil, fmt.Errorf("invalid number of run life spans (%d), expected 1 ", len(runLifeSpans)) - } - - return allRunsProps[0], runLifeSpans[int16(runId)], nil -} - -type RunNodeBatchesInfo struct { - RunProps *wfmodel.RunProperties `json:"run_props"` - RunLs *wfmodel.RunLifespan `json:"run_lifespan"` - RunNodeBatchHistory []*wfmodel.BatchHistoryEvent `json:"batch_history"` -} - -func (h *UrlHandler) ksRunNodeBatchHistory(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - defer cqlSession.Close() - - runId, err := strconv.Atoi(getField(r, 1)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, 
http.StatusInternalServerError) - return - } - - result := RunNodeBatchesInfo{} - result.RunProps, result.RunLs, err = getRunPropsAndLifespans(h.L, cqlSession, keyspace, int16(runId)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - // Batch history - - nodeName := getField(r, 2) - result.RunNodeBatchHistory, err = api.GetRunNodeBatchHistory(h.L, cqlSession, keyspace, int16(runId), nodeName) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, result) -} - -type RunNodesInfo struct { - RunProps *wfmodel.RunProperties `json:"run_props"` - RunLs *wfmodel.RunLifespan `json:"run_lifespan"` - RunNodeHistory []*wfmodel.NodeHistoryEvent `json:"node_history"` -} - -func (h *UrlHandler) ksRunNodeHistory(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - defer cqlSession.Close() - - runId, err := strconv.Atoi(getField(r, 1)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - result := RunNodesInfo{} - result.RunProps, result.RunLs, err = getRunPropsAndLifespans(h.L, cqlSession, keyspace, int16(runId)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - // Node history - - result.RunNodeHistory, err = api.GetNodeHistoryForRun(h.L, cqlSession, keyspace, int16(runId)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - sort.Slice(result.RunNodeHistory, func(i, j int) bool { return result.RunNodeHistory[i].Ts.Before(result.RunNodeHistory[j].Ts) }) - - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, result) -} - -type StartedRunInfo struct { - RunId int16 `json:"run_id"` -} - -func (h *UrlHandler) ksStartRunOptions(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,POST") - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) -} - -func (h *UrlHandler) ksStartRun(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.CreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - amqpConnection, err := amqp.Dial(h.Env.Amqp.URL) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot dial RabbitMQ at %v, will reconnect: %v\n", h.Env.Amqp.URL, err), http.StatusInternalServerError) - return - } - defer amqpConnection.Close() - - amqpChannel, err := amqpConnection.Channel() - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot create amqp channel: %v\n", err), http.StatusInternalServerError) - return - } - defer amqpChannel.Close() - - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - runProps := wfmodel.RunProperties{} - if err = json.Unmarshal(bodyBytes, &runProps); err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - runId, err := 
api.StartRun(h.Env, h.L, amqpChannel, runProps.ScriptUri, runProps.ScriptParamsUri, cqlSession, keyspace, strings.Split(runProps.StartNodes, ","), runProps.RunDescription) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, StartedRunInfo{RunId: runId}) -} - -type StopRunInfo struct { - Comment string `json:"comment"` -} - -func (h *UrlHandler) ksStopRunOptions(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,DELETE") - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) -} - -func (h *UrlHandler) ksStopRun(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - stopRunInfo := StopRunInfo{} - if err = json.Unmarshal(bodyBytes, &stopRunInfo); err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - runId, err := strconv.Atoi(getField(r, 1)) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - if err = api.StopRun(h.L, cqlSession, keyspace, int16(runId), stopRunInfo.Comment); err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) -} - -func (h *UrlHandler) ksDropOptions(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,DELETE") - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) -} - -func (h *UrlHandler) ksDrop(w http.ResponseWriter, r *http.Request) { - keyspace := getField(r, 0) - cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - - err = api.DropKeyspace(h.L, cqlSession, keyspace) - if err != nil { - WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) - return - } - WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) -} - -type UrlHandler struct { - Env *env.EnvConfig - L *l.Logger -} - -type ctxKey struct { -} - -func getField(r *http.Request, index int) string { - fields := r.Context().Value(ctxKey{}).([]string) - return fields[index] -} - -var routes []route - -func (h UrlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var allow []string - for _, route := range routes { - matches := route.regex.FindStringSubmatch(r.URL.Path) - if len(matches) > 0 { - if r.Method != route.method { - allow = append(allow, route.method) - continue - } - ctx := context.WithValue(r.Context(), ctxKey{}, matches[1:]) - - route.handler(w, r.WithContext(ctx)) - return - } - } - if len(allow) > 0 { - w.Header().Set("Allow", strings.Join(allow, ", ")) - http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed) - return - } - http.NotFound(w, r) -} - -func main() { - envConfig, err := env.ReadEnvConfigFile("capiwebapi.json") - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - - // Webapi (like toolbelt and daemon) requires custom proc def factory, otherwise 
it will not be able to start runs - envConfig.CustomProcessorDefFactoryInstance = &StandardWebapiProcessorDefFactory{} - logger, err := l.NewLoggerFromEnvConfig(envConfig) - if err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } - defer logger.Close() - - mux := http.NewServeMux() - - h := UrlHandler{Env: envConfig, L: logger} - - routes = []route{ - newRoute("GET", "/ks[/]*", h.ks), - newRoute("GET", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksMatrix), - newRoute("GET", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)/node/([a-zA-Z0-9_]+)/batch_history[/]*", h.ksRunNodeBatchHistory), - newRoute("GET", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)/node_history[/]*", h.ksRunNodeHistory), - newRoute("POST", "/ks/([a-zA-Z0-9_]+)/run[/]*", h.ksStartRun), - newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)/run[/]*", h.ksStartRunOptions), - newRoute("DELETE", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)[/]*", h.ksStopRun), - newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)[/]*", h.ksStopRunOptions), - newRoute("DELETE", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksDrop), - newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksDropOptions), - } - - mux.Handle("/", h) - - logger.Info("listening on %d...", h.Env.Webapi.Port) - if err := http.ListenAndServe(fmt.Sprintf(":%d", h.Env.Webapi.Port), mux); err != nil { - log.Fatalf(err.Error()) - os.Exit(1) - } -} +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/capillariesio/capillaries/pkg/api" + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/custom/py_calc" + "github.com/capillariesio/capillaries/pkg/custom/tag_and_denormalize" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" + amqp "github.com/rabbitmq/amqp091-go" +) + +type StandardWebapiProcessorDefFactory struct { +} + +func (f *StandardWebapiProcessorDefFactory) Create(processorType string) (sc.CustomProcessorDef, bool) { + // All processors to be supported by this 'stock' binary (daemon/toolbelt/webapi). 
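+ // Like the daemon and toolbelt, webapi needs this factory to load scripts that reference custom processors; without it, runs cannot be started.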
+ // If you develop your own processor(s), use your own ProcessorDefFactory that lists all processors, + // they all must implement CustomProcessorRunner interface + switch processorType { + case py_calc.ProcessorPyCalcName: + return &py_calc.PyCalcProcessorDef{}, true + case tag_and_denormalize.ProcessorTagAndDenormalizeName: + return &tag_and_denormalize.TagAndDenormalizeProcessorDef{}, true + default: + return nil, false + } +} + +type route struct { + method string + regex *regexp.Regexp + handler http.HandlerFunc +} + +func newRoute(method, pattern string, handler http.HandlerFunc) route { + return route{method, regexp.MustCompile("^" + pattern + "$"), handler} +} + +type ApiResponseError struct { + Msg string `json:"msg"` +} + +type ApiResponse struct { + Data any `json:"data"` + Error ApiResponseError `json:"error"` +} + +func pickAccessControlAllowOrigin(wc *env.WebapiConfig, r *http.Request) string { + if wc.AccessControlAllowOrigin == "*" { + return "*" + } + allowedOrigins := strings.Split(wc.AccessControlAllowOrigin, ",") + requestedOrigins, ok := r.Header["Origin"] + if !ok || len(requestedOrigins) == 0 { + return "no-origins-requested" + } + for _, allowedOrigin := range allowedOrigins { + for _, requestedOrigin := range requestedOrigins { + if strings.EqualFold(requestedOrigin, allowedOrigin) { + return requestedOrigin + } + } + } + return "no-allowed-origins" +} + +func WriteApiError(logger *l.CapiLogger, wc *env.WebapiConfig, r *http.Request, w http.ResponseWriter, urlPath string, err error, httpStatus int) { + w.Header().Set("Access-Control-Allow-Origin", pickAccessControlAllowOrigin(wc, r)) + logger.Error("cannot process %s: %s", urlPath, err.Error()) + respJson, err := json.Marshal(ApiResponse{Error: ApiResponseError{Msg: err.Error()}}) + if err != nil { + http.Error(w, fmt.Sprintf("unexpected: cannot serialize error response %s", err.Error()), httpStatus) + } else { + http.Error(w, string(respJson), httpStatus) + } +} + +func WriteApiSuccess(logger *l.CapiLogger, wc *env.WebapiConfig, r *http.Request, w http.ResponseWriter, data any) { + w.Header().Set("Access-Control-Allow-Origin", pickAccessControlAllowOrigin(wc, r)) + respJson, err := json.Marshal(ApiResponse{Data: data}) + if err != nil { + http.Error(w, fmt.Sprintf("cannot serialize success response: %s", err.Error()), http.StatusInternalServerError) + } else { + if _, err := w.Write([]byte(respJson)); err != nil { + logger.Error("cannot write success response, error %s, response %s", err.Error(), respJson) + } + } +} + +func (h *UrlHandler) ks(w http.ResponseWriter, r *http.Request) { + cqlSession, err := db.NewSession(h.Env, "", db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + defer cqlSession.Close() + + // This works only for Cassandra 4.X, not guaranteed to work for later versions + qb := cql.QueryBuilder{} + q := qb.Keyspace("system_schema").Select("keyspaces", []string{"keyspace_name"}) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + respData := make([]string, len(rows)) + ksCount := 0 + + for _, row := range rows { + ksVolatile, ok := row["keyspace_name"] + if !ok { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot find keyspace_name in the response"), http.StatusInternalServerError) + return + } + + ks, ok := ksVolatile.(string) + if !ok { + 
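// keyspace_name came back as something other than a string (SliceMap() returns column values as interface{}); report it rather than letting an unchecked type assertion panic. +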
WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot cast keyspace_name to string: %v", row["keyspace_name"]), http.StatusInternalServerError) + return + } + if len(ks) == 0 || api.IsSystemKeyspaceName(ks) { + continue + } + respData[ksCount] = ks + ksCount++ + } + + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, respData[:ksCount]) +} + +type FullRunInfo struct { + Props *wfmodel.RunProperties `json:"props"` + History []*wfmodel.RunHistoryEvent `json:"history"` +} + +type WebapiNodeStatus struct { + RunId int16 `json:"run_id"` + Status wfmodel.NodeBatchStatusType `json:"status"` + Ts string `json:"ts"` +} + +type WebapiNodeRunMatrixRow struct { + NodeName string `json:"node_name"` + NodeDesc string `json:"node_desc"` + NodeStatuses []WebapiNodeStatus `json:"node_statuses"` +} +type WebapiNodeRunMatrix struct { + RunLifespans []*wfmodel.RunLifespan `json:"run_lifespans"` + Nodes []WebapiNodeRunMatrixRow `json:"nodes"` +} + +// Poor man's cache +var NodeDescCache = map[string]string{} +var NodeDescCacheLock = sync.RWMutex{} + +func (h *UrlHandler) getNodeDesc(cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) (string, error) { + + NodeDescCacheLock.RLock() + nodeDesc, ok := NodeDescCache[keyspace+nodeName] + NodeDescCacheLock.RUnlock() + if ok { + return nodeDesc, nil + } + + allRunsProps, err := api.GetRunProperties(h.L, cqlSession, keyspace, int16(runId)) + if err != nil { + return "", err + } + + if len(allRunsProps) != 1 { + return "", fmt.Errorf("invalid number of matching runs (%d), expected 1; this usually happens when webapi caller makes wrong assumptions about the process status", len(allRunsProps)) + } + + script, _, err := sc.NewScriptFromFiles(h.Env.CaPath, h.Env.PrivateKeys, allRunsProps[0].ScriptUri, allRunsProps[0].ScriptParamsUri, h.Env.CustomProcessorDefFactoryInstance, h.Env.CustomProcessorsSettings) + if err != nil { + return "", err + } + + nodeDef, ok := script.ScriptNodes[nodeName] + if !ok { + return "", fmt.Errorf("cannot find node %s", nodeName) + } + + NodeDescCacheLock.RLock() + NodeDescCache[keyspace+nodeName] = nodeDef.Desc + NodeDescCacheLock.RUnlock() + + return nodeDef.Desc, nil +} + +func (h *UrlHandler) ksMatrix(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + defer cqlSession.Close() + + // Retrieve all runs that happened in this ks and find their current statuses + runLifespanMap, err := api.HarvestRunLifespans(h.L, cqlSession, keyspace, []int16{}) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + // Arrange run statuses for the matrix header + mx := WebapiNodeRunMatrix{RunLifespans: make([]*wfmodel.RunLifespan, len(runLifespanMap))} + runCount := 0 + for _, runLifespan := range runLifespanMap { + mx.RunLifespans[runCount] = runLifespan + runCount++ + } + sort.Slice(mx.RunLifespans, func(i, j int) bool { return mx.RunLifespans[i].RunId < mx.RunLifespans[j].RunId }) + + // Retrieve all node events for this ks, for all runs + nodeHistory, err := api.GetRunsNodeHistory(h.L, cqlSession, keyspace, []int16{}) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + nodeStartTsMap := map[string]time.Time{} // Arrange by the ts in the last run + + // For each 
node/run, harvest current node status, latest wins + nodeRunStatusMap := map[string]map[int16]WebapiNodeStatus{} + for _, nodeEvent := range nodeHistory { + if _, ok := nodeRunStatusMap[nodeEvent.ScriptNode]; !ok { + nodeRunStatusMap[nodeEvent.ScriptNode] = map[int16]WebapiNodeStatus{} + } + nodeRunStatusMap[nodeEvent.ScriptNode][nodeEvent.RunId] = WebapiNodeStatus{RunId: nodeEvent.RunId, Status: nodeEvent.Status, Ts: nodeEvent.Ts.Format("2006-01-02T15:04:05.000-0700")} + + if nodeEvent.Status == wfmodel.NodeBatchStart { + if _, ok := nodeStartTsMap[nodeEvent.ScriptNode]; !ok { + nodeStartTsMap[nodeEvent.ScriptNode] = nodeEvent.Ts + } + } + } + + // Arrange status in the result mx + mx.Nodes = make([]WebapiNodeRunMatrixRow, len(nodeRunStatusMap)) + nodeCount := 0 + for nodeName, runNodeStatusMap := range nodeRunStatusMap { + nodeDesc, err := h.getNodeDesc(cqlSession, keyspace, runLifespanMap[1].RunId, nodeName) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot get node description: %s", err.Error()), http.StatusInternalServerError) + return + } + mx.Nodes[nodeCount] = WebapiNodeRunMatrixRow{NodeName: nodeName, NodeDesc: nodeDesc, NodeStatuses: make([]WebapiNodeStatus, len(mx.RunLifespans))} + for runIdx, matrixRunLifespan := range mx.RunLifespans { + if nodeStatus, ok := runNodeStatusMap[matrixRunLifespan.RunId]; ok { + mx.Nodes[nodeCount].NodeStatuses[runIdx] = nodeStatus + } + } + nodeCount++ + } + + // Sort nodes: started come first, sorted by start ts, other come after that, sorted by node name + // Ideally, they should be sorted geometrically from DAG, with start ts coming into play when DAG says nodes are equal. + // But this will require script analysis which takes too long. + sort.Slice(mx.Nodes, func(i, j int) bool { + leftTs, leftPresent := nodeStartTsMap[mx.Nodes[i].NodeName] + rightTs, rightPresent := nodeStartTsMap[mx.Nodes[j].NodeName] + if !leftPresent && rightPresent { + return false + } else if leftPresent && !rightPresent { + return true + } else if !leftPresent && !rightPresent { + // Sort by node name + return mx.Nodes[i].NodeName < mx.Nodes[j].NodeName + } + return leftTs.Before(rightTs) + }) + + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, mx) +} + +func getRunPropsAndLifespans(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) (*wfmodel.RunProperties, *wfmodel.RunLifespan, error) { + // Static run properties + // TODO: consider caching + + allRunsProps, err := api.GetRunProperties(logger, cqlSession, keyspace, int16(runId)) + if err != nil { + return nil, nil, err + } + + if len(allRunsProps) != 1 { + return nil, nil, fmt.Errorf("invalid number of matching runs (%d), expected 1; this usually happens when webapi caller makes wrong assumptions about the process status", len(allRunsProps)) + } + + // Run status + + runLifeSpans, err := api.HarvestRunLifespans(logger, cqlSession, keyspace, []int16{int16(runId)}) + if err != nil { + return nil, nil, err + } + if len(runLifeSpans) != 1 { + return nil, nil, fmt.Errorf("invalid number of run life spans (%d), expected 1 ", len(runLifeSpans)) + } + + return allRunsProps[0], runLifeSpans[int16(runId)], nil +} + +type RunNodeBatchesInfo struct { + RunProps *wfmodel.RunProperties `json:"run_props"` + RunLs *wfmodel.RunLifespan `json:"run_lifespan"` + RunNodeBatchHistory []*wfmodel.BatchHistoryEvent `json:"batch_history"` +} + +func (h *UrlHandler) ksRunNodeBatchHistory(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := 
db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + defer cqlSession.Close() + + runId, err := strconv.Atoi(getField(r, 1)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + result := RunNodeBatchesInfo{} + result.RunProps, result.RunLs, err = getRunPropsAndLifespans(h.L, cqlSession, keyspace, int16(runId)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + // Batch history + + nodeName := getField(r, 2) + result.RunNodeBatchHistory, err = api.GetRunNodeBatchHistory(h.L, cqlSession, keyspace, int16(runId), nodeName) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, result) +} + +type RunNodesInfo struct { + RunProps *wfmodel.RunProperties `json:"run_props"` + RunLs *wfmodel.RunLifespan `json:"run_lifespan"` + RunNodeHistory []*wfmodel.NodeHistoryEvent `json:"node_history"` +} + +func (h *UrlHandler) ksRunNodeHistory(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + defer cqlSession.Close() + + runId, err := strconv.Atoi(getField(r, 1)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + result := RunNodesInfo{} + result.RunProps, result.RunLs, err = getRunPropsAndLifespans(h.L, cqlSession, keyspace, int16(runId)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + // Node history + + result.RunNodeHistory, err = api.GetNodeHistoryForRun(h.L, cqlSession, keyspace, int16(runId)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + sort.Slice(result.RunNodeHistory, func(i, j int) bool { return result.RunNodeHistory[i].Ts.Before(result.RunNodeHistory[j].Ts) }) + + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, result) +} + +type StartedRunInfo struct { + RunId int16 `json:"run_id"` +} + +func (h *UrlHandler) ksStartRunOptions(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,POST") + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) +} + +func (h *UrlHandler) ksStartRun(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := db.NewSession(h.Env, keyspace, db.CreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + amqpConnection, err := amqp.Dial(h.Env.Amqp.URL) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot dial RabbitMQ at %v, will reconnect: %v", h.Env.Amqp.URL, err), http.StatusInternalServerError) + return + } + defer amqpConnection.Close() + + amqpChannel, err := amqpConnection.Channel() + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, fmt.Errorf("cannot create amqp channel: %v", err), http.StatusInternalServerError) + return + } + defer amqpChannel.Close() + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + 
WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + runProps := wfmodel.RunProperties{} + if err = json.Unmarshal(bodyBytes, &runProps); err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + runId, err := api.StartRun(h.Env, h.L, amqpChannel, runProps.ScriptUri, runProps.ScriptParamsUri, cqlSession, keyspace, strings.Split(runProps.StartNodes, ","), runProps.RunDescription) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, StartedRunInfo{RunId: runId}) +} + +type StopRunInfo struct { + Comment string `json:"comment"` +} + +func (h *UrlHandler) ksStopRunOptions(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,DELETE") + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) +} + +func (h *UrlHandler) ksStopRun(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + stopRunInfo := StopRunInfo{} + if err = json.Unmarshal(bodyBytes, &stopRunInfo); err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + runId, err := strconv.Atoi(getField(r, 1)) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + if err = api.StopRun(h.L, cqlSession, keyspace, int16(runId), stopRunInfo.Comment); err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) +} + +func (h *UrlHandler) ksDropOptions(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS,DELETE") + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) +} + +func (h *UrlHandler) ksDrop(w http.ResponseWriter, r *http.Request) { + keyspace := getField(r, 0) + cqlSession, err := db.NewSession(h.Env, keyspace, db.DoNotCreateKeyspaceOnConnect) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + + err = api.DropKeyspace(h.L, cqlSession, keyspace) + if err != nil { + WriteApiError(h.L, &h.Env.Webapi, r, w, r.URL.Path, err, http.StatusInternalServerError) + return + } + WriteApiSuccess(h.L, &h.Env.Webapi, r, w, nil) +} + +type UrlHandler struct { + Env *env.EnvConfig + L *l.CapiLogger +} + +type ctxKey struct { +} + +func getField(r *http.Request, index int) string { + fields := r.Context().Value(ctxKey{}).([]string) //nolint:all + return fields[index] +} + +var routes []route + +func (h UrlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var allow []string + for _, route := range routes { + matches := route.regex.FindStringSubmatch(r.URL.Path) + if len(matches) > 0 { + if r.Method != route.method { + allow = append(allow, route.method) + continue + } + ctx := context.WithValue(r.Context(), ctxKey{}, matches[1:]) + + route.handler(w, r.WithContext(ctx)) + return + } + } + if len(allow) > 0 { + w.Header().Set("Allow", strings.Join(allow, ", 
")) + http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed) + return + } + http.NotFound(w, r) +} + +func main() { + envConfig, err := env.ReadEnvConfigFile("capiwebapi.json") + if err != nil { + log.Fatalf(err.Error()) + } + + // Webapi (like toolbelt and daemon) requires custom proc def factory, otherwise it will not be able to start runs + envConfig.CustomProcessorDefFactoryInstance = &StandardWebapiProcessorDefFactory{} + logger, err := l.NewLoggerFromEnvConfig(envConfig) + if err != nil { + log.Fatalf(err.Error()) + } + defer logger.Close() + + mux := http.NewServeMux() + + h := UrlHandler{Env: envConfig, L: logger} + + routes = []route{ + newRoute("GET", "/ks[/]*", h.ks), + newRoute("GET", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksMatrix), + newRoute("GET", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)/node/([a-zA-Z0-9_]+)/batch_history[/]*", h.ksRunNodeBatchHistory), + newRoute("GET", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)/node_history[/]*", h.ksRunNodeHistory), + newRoute("POST", "/ks/([a-zA-Z0-9_]+)/run[/]*", h.ksStartRun), + newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)/run[/]*", h.ksStartRunOptions), + newRoute("DELETE", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)[/]*", h.ksStopRun), + newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)/run/([0-9]+)[/]*", h.ksStopRunOptions), + newRoute("DELETE", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksDrop), + newRoute("OPTIONS", "/ks/([a-zA-Z0-9_]+)[/]*", h.ksDropOptions), + } + + mux.Handle("/", h) + + logger.Info("listening on %d...", h.Env.Webapi.Port) + if err := http.ListenAndServe(fmt.Sprintf(":%d", h.Env.Webapi.Port), mux); err != nil { + log.Fatalf(err.Error()) + } +} diff --git a/pkg/l/logger.go b/pkg/l/logger.go index 9521167..a84db36 100644 --- a/pkg/l/logger.go +++ b/pkg/l/logger.go @@ -1,128 +1,128 @@ -package l - -import ( - "fmt" - "os" - "sync/atomic" - "time" - - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/env" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Logger struct { - Logger *zap.Logger - ZapMachine zapcore.Field - ZapThread zapcore.Field - SavedZapConfig zap.Config - AtomicThreadCounter *int64 - ZapFunction zapcore.Field - FunctionStack []string -} - -func (logger *Logger) PushF(functionName string) { - logger.ZapFunction = zap.String("f", functionName) - logger.FunctionStack = append(logger.FunctionStack, functionName) -} -func (logger *Logger) PopF() { - if len(logger.FunctionStack) > 0 { - logger.FunctionStack = logger.FunctionStack[:len(logger.FunctionStack)-1] - if len(logger.FunctionStack) > 0 { - logger.ZapFunction = zap.String("f", logger.FunctionStack[len(logger.FunctionStack)-1]) - } else { - logger.ZapFunction = zap.String("f", "stack_underflow") - } - } -} - -func NewLoggerFromEnvConfig(envConfig *env.EnvConfig) (*Logger, error) { - atomicTreadCounter := int64(0) - l := Logger{AtomicThreadCounter: &atomicTreadCounter} - hostName, err := os.Hostname() - if err != nil { - return nil, fmt.Errorf("cannot get hostname: %s", err.Error()) - } - l.ZapMachine = zap.String("i", fmt.Sprintf("%s/%s/%s", hostName, envConfig.HandlerExecutableType, time.Now().Format("01-02T15:04:05.000"))) - l.ZapThread = zap.Int64("t", 0) - l.ZapFunction = zap.String("f", "") - - // TODO: this solution writes everything to stdout. Potentially, there is a way to write Debug/Info/Warn to stdout and - // errors to std err: https://stackoverflow.com/questions/68472667/how-to-log-to-stdout-or-stderr-based-on-log-level-using-uber-go-zap - // Do some research to see if this can be added to our ZapConfig.Build() scenario. 
- l.SavedZapConfig = envConfig.ZapConfig - l.Logger, err = envConfig.ZapConfig.Build() - if err != nil { - return nil, fmt.Errorf("cannot build logger from config: %s", err.Error()) - } - return &l, nil -} - -func NewLoggerFromLogger(srcLogger *Logger) (*Logger, error) { - l := Logger{ - SavedZapConfig: srcLogger.SavedZapConfig, - AtomicThreadCounter: srcLogger.AtomicThreadCounter, - ZapMachine: srcLogger.ZapMachine, - ZapFunction: zap.String("f", ""), - ZapThread: zap.Int64("t", atomic.AddInt64(srcLogger.AtomicThreadCounter, 1))} - - var err error - l.Logger, err = srcLogger.SavedZapConfig.Build() - if err != nil { - return nil, fmt.Errorf("cannot build logger from logger: %s", err.Error()) - } - return &l, nil -} - -func (l *Logger) Close() { - l.Logger.Sync() -} - -func (l *Logger) Debug(format string, a ...interface{}) { - l.Logger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) -} - -func (l *Logger) DebugCtx(pCtx *ctx.MessageProcessingContext, format string, a ...interface{}) { - if pCtx == nil { - l.Logger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) - } else { - l.Logger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) - } -} - -func (l *Logger) Info(format string, a ...interface{}) { - l.Logger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) -} - -func (l *Logger) InfoCtx(pCtx *ctx.MessageProcessingContext, format string, a ...interface{}) { - if pCtx == nil { - l.Logger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) - } else { - l.Logger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) - } -} - -func (l *Logger) Warn(format string, a ...interface{}) { - l.Logger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) -} - -func (l *Logger) WarnCtx(pCtx *ctx.MessageProcessingContext, format string, a ...interface{}) { - if pCtx == nil { - l.Logger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) - } else { - l.Logger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) - } -} - -func (l *Logger) Error(format string, a ...interface{}) { - l.Logger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) -} - -func (l *Logger) ErrorCtx(pCtx *ctx.MessageProcessingContext, format string, a ...interface{}) { - if pCtx == nil { - l.Logger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) - } else { - l.Logger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) - } -} +package l + +import ( + "fmt" + "os" + "sync/atomic" + "time" + + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/env" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type CapiLogger struct { + ZapLogger *zap.Logger + ZapMachine zapcore.Field + ZapThread zapcore.Field + SavedZapConfig zap.Config + AtomicThreadCounter *int64 + ZapFunction zapcore.Field + FunctionStack []string +} + +func (l *CapiLogger) PushF(functionName string) { + l.ZapFunction = zap.String("f", functionName) + l.FunctionStack = append(l.FunctionStack, functionName) +} 
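PushF and PopF (defined next) maintain a per-logger stack of function names so that every log entry carries an "f" field naming the function currently being traced. The intended usage, mirroring how pkg/proc calls it later in this patch, is a push at function entry with a deferred pop. The sketch below only illustrates that convention; the package and function names in it are made up.

package example

import "github.com/capillariesio/capillaries/pkg/l"

// doWork shows the push-at-entry / deferred-pop convention used throughout the codebase.
func doWork(logger *l.CapiLogger) error {
	logger.PushF("example.doWork")
	defer logger.PopF()

	logger.Info("working...") // this entry is tagged with f=example.doWork
	return nil
}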
+func (l *CapiLogger) PopF() { + if len(l.FunctionStack) > 0 { + l.FunctionStack = l.FunctionStack[:len(l.FunctionStack)-1] + if len(l.FunctionStack) > 0 { + l.ZapFunction = zap.String("f", l.FunctionStack[len(l.FunctionStack)-1]) + } else { + l.ZapFunction = zap.String("f", "stack_underflow") + } + } +} + +func NewLoggerFromEnvConfig(envConfig *env.EnvConfig) (*CapiLogger, error) { + atomicTreadCounter := int64(0) + l := CapiLogger{AtomicThreadCounter: &atomicTreadCounter} + hostName, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("cannot get hostname: %s", err.Error()) + } + l.ZapMachine = zap.String("i", fmt.Sprintf("%s/%s/%s", hostName, envConfig.HandlerExecutableType, time.Now().Format("01-02T15:04:05.000"))) + l.ZapThread = zap.Int64("t", 0) + l.ZapFunction = zap.String("f", "") + + // TODO: this solution writes everything to stdout. Potentially, there is a way to write Debug/Info/Warn to stdout and + // errors to std err: https://stackoverflow.com/questions/68472667/how-to-log-to-stdout-or-stderr-based-on-log-level-using-uber-go-zap + // Do some research to see if this can be added to our ZapConfig.Build() scenario. + l.SavedZapConfig = envConfig.ZapConfig + l.ZapLogger, err = envConfig.ZapConfig.Build() + if err != nil { + return nil, fmt.Errorf("cannot build l from config: %s", err.Error()) + } + return &l, nil +} + +func NewLoggerFromLogger(srcLogger *CapiLogger) (*CapiLogger, error) { + l := CapiLogger{ + SavedZapConfig: srcLogger.SavedZapConfig, + AtomicThreadCounter: srcLogger.AtomicThreadCounter, + ZapMachine: srcLogger.ZapMachine, + ZapFunction: zap.String("f", ""), + ZapThread: zap.Int64("t", atomic.AddInt64(srcLogger.AtomicThreadCounter, 1))} + + var err error + l.ZapLogger, err = srcLogger.SavedZapConfig.Build() + if err != nil { + return nil, fmt.Errorf("cannot build l from l: %s", err.Error()) + } + return &l, nil +} + +func (l *CapiLogger) Close() { + l.ZapLogger.Sync() //nolint:all +} + +func (l *CapiLogger) Debug(format string, a ...any) { + l.ZapLogger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) +} + +func (l *CapiLogger) DebugCtx(pCtx *ctx.MessageProcessingContext, format string, a ...any) { + if pCtx == nil { + l.ZapLogger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) + } else { + l.ZapLogger.Debug(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) + } +} + +func (l *CapiLogger) Info(format string, a ...any) { + l.ZapLogger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) +} + +func (l *CapiLogger) InfoCtx(pCtx *ctx.MessageProcessingContext, format string, a ...any) { + if pCtx == nil { + l.ZapLogger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) + } else { + l.ZapLogger.Info(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) + } +} + +func (l *CapiLogger) Warn(format string, a ...any) { + l.ZapLogger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) +} + +func (l *CapiLogger) WarnCtx(pCtx *ctx.MessageProcessingContext, format string, a ...any) { + if pCtx == nil { + l.ZapLogger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) + } else { + l.ZapLogger.Warn(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, 
pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) + } +} + +func (l *CapiLogger) Error(format string, a ...any) { + l.ZapLogger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) +} + +func (l *CapiLogger) ErrorCtx(pCtx *ctx.MessageProcessingContext, format string, a ...any) { + if pCtx == nil { + l.ZapLogger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction) + } else { + l.ZapLogger.Error(fmt.Sprintf(format, a...), l.ZapMachine, l.ZapThread, l.ZapFunction, pCtx.ZapDataKeyspace, pCtx.ZapRun, pCtx.ZapNode, pCtx.ZapBatchIdx, pCtx.ZapMsgAgeMillis) + } +} diff --git a/pkg/proc/custom_processor_runner.go b/pkg/proc/custom_processor_runner.go index fcc2f1f..11b500e 100644 --- a/pkg/proc/custom_processor_runner.go +++ b/pkg/proc/custom_processor_runner.go @@ -1,11 +1,11 @@ -package proc - -import ( - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/l" -) - -type CustomProcessorRunner interface { - Run(logger *l.Logger, pCtx *ctx.MessageProcessingContext, rsIn *Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error -} +package proc + +import ( + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/l" +) + +type CustomProcessorRunner interface { + Run(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, rsIn *Rowset, flushVarsArray func(varsArray []*eval.VarValuesMap, varsArrayCount int) error) error +} diff --git a/pkg/proc/data_util.go b/pkg/proc/data_util.go index 7c608d6..c2a069b 100644 --- a/pkg/proc/data_util.go +++ b/pkg/proc/data_util.go @@ -1,378 +1,377 @@ -package proc - -import ( - "fmt" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/gocql/gocql" -) - -// func ClearNodeOutputs(logger *l.Logger, script *sc.ScriptDef, session *gocql.Session, keyspace string, nodeName string, runId int16) error { -// node, ok := script.ScriptNodes[nodeName] -// if !ok { -// return fmt.Errorf("cannot find node %s", nodeName) -// } -// if node.HasTableCreator() { -// qb := cql.QueryBuilder{} -// logger.Info("deleting data table %s.%s...", keyspace, node.TableCreator.Name) -// query := qb.Keyspace(keyspace).DropRun(node.TableCreator.Name, runId) -// if err := session.Query(query).Exec(); err != nil { -// return fmt.Errorf("cannot drop data table [%s]: [%s]", query, err.Error()) -// } - -// for idxName, _ := range node.TableCreator.Indexes { -// qb := cql.QueryBuilder{} -// logger.Info("deleting index table %s.%s...", keyspace, idxName) -// query := qb.Keyspace(keyspace).DropRun(idxName, runId) -// if err := session.Query(query).Exec(); err != nil { -// return fmt.Errorf("cannot drop idx table [%s]: [%s]", query, err.Error()) -// } -// } -// } else if node.HasFileCreator() { -// if _, err := os.Stat(node.FileCreator.Url); err == nil { -// logger.Info("deleting output file %s...", node.FileCreator.Url) -// if err := os.Remove(node.FileCreator.Url); err != nil { -// return fmt.Errorf("cannot delete file [%s]: [%s]", node.FileCreator.Url, err.Error()) -// } -// } -// } -// return nil -// } - -func selectBatchFromDataTablePaged(logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - rs *Rowset, - tableName string, - 
lookupNodeRunId int16, - batchSize int, - pageState []byte, - rowidsToFind map[int64]struct{}) ([]byte, error) { - - logger.PushF("proc.selectBatchFromDataTablePaged") - defer logger.PopF() - - if err := rs.InitRows(batchSize); err != nil { - return nil, err - } - - rowids := make([]int64, len(rowidsToFind)) - i := 0 - for k := range rowidsToFind { - rowids[i] = k - i++ - } - - qb := cql.QueryBuilder{} - q := qb. - Keyspace(pCtx.BatchInfo.DataKeyspace). - CondInPrepared("rowid"). // This is a right-side lookup table, select by rowid - SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) - - var iter *gocql.Iter - selectRetryIdx := 0 - curSelectExpBackoffFactor := 1 - for { - iter = pCtx.CqlSession.Query(q, rowids).PageSize(batchSize).PageState(pageState).Iter() - - dbWarnings := iter.Warnings() - if len(dbWarnings) > 0 { - // TODO: figure out what those warnigs can be, never saw one - logger.WarnCtx(pCtx, "got warnigs while selecting %d rows from %s%s: %s", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), strings.Join(dbWarnings, ";")) - } - - rs.RowCount = 0 - - scanner := iter.Scanner() - for scanner.Next() { - if rs.RowCount >= len(rs.Rows) { - return nil, fmt.Errorf("unexpected data row retrieved, exceeding rowset size %d", len(rs.Rows)) - } - if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { - return nil, db.WrapDbErrorWithQuery("cannot scan paged data row", q, err) - } - // We assume gocql creates only UTC timestamps, so this is not needed. - // If we ever catch a ts stored in our tables with a non-UTC tz, or gocql returning a non-UTC tz - investigate it. Sanitizing is the last resort and should be avoided. - // if err := rs.SanitizeScannedDatetimesToUtc(rs.RowCount); err != nil { - // return nil, db.WrapDbErrorWithQuery("cannot sanitize datetimes", q, err) - // } - rs.RowCount++ - } - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "Operation timed out") || strings.Contains(err.Error(), "Cannot achieve consistency level") && selectRetryIdx < 3 { - logger.WarnCtx(pCtx, "cannot select %d rows from %s%s on retry %d, getting timeout/consistency error (%s), will wait for %dms and retry", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), selectRetryIdx, err.Error(), 10*curSelectExpBackoffFactor) - time.Sleep(time.Duration(10*curSelectExpBackoffFactor) * time.Millisecond) - curSelectExpBackoffFactor *= 2 - } else { - return nil, db.WrapDbErrorWithQuery(fmt.Sprintf("paged data scanner cannot select %d rows from %s%s after %d attempts; another worker may retry this batch later, but, if some unique idx records has been written already by current worker, the next worker handling this batch will throw an error on them and there is nothing we can do about it;", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), selectRetryIdx+1), q, err) - } - } else { - break - } - selectRetryIdx++ - } - - return iter.PageState(), nil -} - -func selectBatchPagedAllRowids(logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - rs *Rowset, - tableName string, - lookupNodeRunId int16, - batchSize int, - pageState []byte) ([]byte, error) { - - logger.PushF("proc.selectBatchPagedAllRowids") - defer logger.PopF() - - if err := rs.InitRows(batchSize); err != nil { - return nil, err - } - - qb := cql.QueryBuilder{} - q := qb. - Keyspace(pCtx.BatchInfo.DataKeyspace). 
- SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) - - iter := pCtx.CqlSession.Query(q).PageSize(batchSize).PageState(pageState).Iter() - - dbWarnings := iter.Warnings() - if len(dbWarnings) > 0 { - logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) - } - - rs.RowCount = 0 - - scanner := iter.Scanner() - for scanner.Next() { - if rs.RowCount >= len(rs.Rows) { - return nil, fmt.Errorf("unexpected data row retrieved, exceeding rowset size %d", len(rs.Rows)) - } - if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { - return nil, db.WrapDbErrorWithQuery("cannot scan all rows data row", q, err) - } - // We assume gocql creates only UTC timestamps, so this is not needed - // If we ever catch a ts stored in our tables with a non-UTC tz, or gocql returning a non-UTC tz - investigate it. Sanitizing is the last resort and should be avoided. - // if err := rs.SanitizeScannedDatetimesToUtc(rs.RowCount); err != nil { - // return nil, db.WrapDbErrorWithQuery("cannot sanitize datetimes", q, err) - // } - rs.RowCount++ - } - if err := scanner.Err(); err != nil { - return nil, db.WrapDbErrorWithQuery("data all rows scanner error", q, err) - } - - return iter.PageState(), nil -} - -func selectBatchFromIdxTablePaged(logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - rs *Rowset, - tableName string, - lookupNodeRunId int16, - batchSize int, - pageState []byte, - keysToFind *[]string) ([]byte, error) { - - logger.PushF("proc.selectBatchFromIdxTablePaged") - defer logger.PopF() - - if err := rs.InitRows(batchSize); err != nil { - return nil, err - } - - qb := cql.QueryBuilder{} - q := qb.Keyspace(pCtx.BatchInfo.DataKeyspace). - CondInPrepared("key"). // This is an index table, select only selected keys - SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) - - iter := pCtx.CqlSession.Query(q, *keysToFind).PageSize(batchSize).PageState(pageState).Iter() - - dbWarnings := iter.Warnings() - if len(dbWarnings) > 0 { - logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) - } - - rs.RowCount = 0 - - scanner := iter.Scanner() - for scanner.Next() { - if rs.RowCount >= len(rs.Rows) { - return nil, fmt.Errorf("unexpected idx row retrieved, exceeding rowset size %d", len(rs.Rows)) - } - if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { - return nil, db.WrapDbErrorWithQuery("cannot scan idx row", q, err) - } - rs.RowCount++ - } - if err := scanner.Err(); err != nil { - return nil, db.WrapDbErrorWithQuery("idx scanner error", q, err) - } - - return iter.PageState(), nil -} - -func selectBatchFromTableByToken(logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - rs *Rowset, - tableName string, - readerNodeRunId int16, - batchSize int, - startToken int64, - endToken int64) (int64, error) { - - logger.PushF("proc.selectBatchFromTableByToken") - defer logger.PopF() - - if err := rs.InitRows(batchSize); err != nil { - return 0, err - } - - qb := cql.QueryBuilder{} - q := qb.Keyspace(pCtx.BatchInfo.DataKeyspace). - Limit(batchSize). - CondPrepared("token(rowid)", ">="). - CondPrepared("token(rowid)", "<="). 
- SelectRun(tableName, readerNodeRunId, *rs.GetFieldNames()) - - // TODO: consider retries as we do in selectBatchFromDataTablePaged(); although no timeouts were detected so far here - - iter := pCtx.CqlSession.Query(q, startToken, endToken).Iter() - - dbWarnings := iter.Warnings() - if len(dbWarnings) > 0 { - logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) - } - rs.RowCount = 0 - var lastRetrievedToken int64 - for rs.RowCount < len(rs.Rows) && iter.Scan(*rs.Rows[rs.RowCount]...) { - lastRetrievedToken = *((*rs.Rows[rs.RowCount])[rs.FieldsByFieldName["token(rowid)"]].(*int64)) - rs.RowCount++ - } - if err := iter.Close(); err != nil { - return 0, db.WrapDbErrorWithQuery("cannot close iterator", q, err) - } - - return lastRetrievedToken, nil -} - -const HarvestForDeleteRowsetSize = 1000 // Do not let users tweak it, maybe too sensitive - -func DeleteDataAndUniqueIndexesByBatchIdx(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("proc.DeleteDataAndUniqueIndexesByBatchIdx") - defer logger.PopF() - - logger.DebugCtx(pCtx, "deleting data records for %s...", pCtx.BatchInfo.FullBatchId()) - deleteStartTime := time.Now() - - if !pCtx.CurrentScriptNode.HasTableCreator() { - logger.InfoCtx(pCtx, "no table creator, nothing to delete for %s", pCtx.BatchInfo.FullBatchId()) - return nil - } - - // Select from data table by rowid, retrieve all fields that are involved i building unique indexes - uniqueIdxFieldRefs := pCtx.CurrentScriptNode.GetUniqueIndexesFieldRefs() - - rs := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(pCtx.CurrentScriptNode.TableCreator.Name)}, - *uniqueIdxFieldRefs, - sc.FieldRefs{sc.FieldRef{TableName: pCtx.CurrentScriptNode.TableCreator.Name, FieldName: "batch_idx", FieldType: sc.FieldTypeInt}}) - - var pageState []byte - var err error - for { - pageState, err = selectBatchPagedAllRowids(logger, - pCtx, - rs, - pCtx.CurrentScriptNode.TableCreator.Name, - pCtx.BatchInfo.RunId, - HarvestForDeleteRowsetSize, - pageState) - if err != nil { - return err - } - - if rs.RowCount == 0 { - break - } - - // Harvest rowids with batchIdx we are interested in, also harvest keys - - // Prepare the storage for rowids and keys - rowIdsToDelete := make([]int64, rs.RowCount) - uniqueKeysToDeleteMap := map[string][]string{} // unique_idx_name -> list_of_keys_to_delete - for idxName, idxDef := range pCtx.CurrentScriptNode.TableCreator.Indexes { - if idxDef.Uniqueness == sc.IdxUnique { - uniqueKeysToDeleteMap[idxName] = make([]string, rs.RowCount) - } - } - - rowIdsToDeleteCount := 0 - for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { - rowId := *((*rs.Rows[rowIdx])[rs.FieldsByFieldName["rowid"]].(*int64)) - batchIdx := int16(*((*rs.Rows[rowIdx])[rs.FieldsByFieldName["batch_idx"]].(*int64))) - if batchIdx == pCtx.BatchInfo.BatchIdx { - // Add this rowid to the list - rowIdsToDelete[rowIdsToDeleteCount] = rowId - // Build the key and add it to the list - tableRecord, err := rs.GetTableRecord(rowIdx) - if err != nil { - return fmt.Errorf("while deleting previous batch attempt leftovers, cannot get table record from [%v]: %s", rs.Rows[rowIdx], err.Error()) - } - for idxName, idxDef := range pCtx.CurrentScriptNode.TableCreator.Indexes { - if _, ok := uniqueKeysToDeleteMap[idxName]; ok { - uniqueKeysToDeleteMap[idxName][rowIdsToDeleteCount], err = sc.BuildKey(tableRecord, idxDef) - if err != nil { - return fmt.Errorf("while deleting previous batch attempt leftovers, cannot build a key for index %s from [%v]: %s", idxName, tableRecord, err.Error()) - } - if 
len(uniqueKeysToDeleteMap[idxName][rowIdsToDeleteCount]) == 0 { - logger.ErrorCtx(pCtx, "invalid empty key calculated for %v", tableRecord) - } - } - } - rowIdsToDeleteCount++ - } - } - if rowIdsToDeleteCount > 0 { - rowIdsToDelete = rowIdsToDelete[:rowIdsToDeleteCount] - // NOTE: Assuming Delete won't interfere with paging - logger.DebugCtx(pCtx, "deleting %d data records from %s: %v", len(rowIdsToDelete), pCtx.BatchInfo.FullBatchId(), rowIdsToDelete) - qbDel := cql.QueryBuilder{} - qDel := qbDel. - Keyspace(pCtx.BatchInfo.DataKeyspace). - CondInInt("rowid", rowIdsToDelete[:rowIdsToDeleteCount]). - DeleteRun(pCtx.CurrentScriptNode.TableCreator.Name, pCtx.BatchInfo.RunId) - if err := pCtx.CqlSession.Query(qDel).Exec(); err != nil { - return db.WrapDbErrorWithQuery("cannot delete from data table", qDel, err) - } - logger.InfoCtx(pCtx, "deleted %d records from data table for %s, now will delete from %d indexes", len(rowIdsToDelete), pCtx.BatchInfo.FullBatchId(), len(uniqueKeysToDeleteMap)) - - for idxName, idxKeysToDelete := range uniqueKeysToDeleteMap { - logger.DebugCtx(pCtx, "deleting %d idx %s records from %d/%s idx %s for batch_idx %d: %v", len(rowIdsToDelete), idxName, pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, idxName, pCtx.BatchInfo.BatchIdx, idxKeysToDelete) - qbDel := cql.QueryBuilder{} - qDel := qbDel. - Keyspace(pCtx.BatchInfo.DataKeyspace). - CondInString("key", idxKeysToDelete[:rowIdsToDeleteCount]). - DeleteRun(idxName, pCtx.BatchInfo.RunId) - if err := pCtx.CqlSession.Query(qDel).Exec(); err != nil { - return db.WrapDbErrorWithQuery("cannot delete from idx table", qDel, err) - } - logger.InfoCtx(pCtx, "deleted %d records from idx table %s for batch %d/%s/%d", len(rowIdsToDelete), idxName, pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, pCtx.BatchInfo.BatchIdx) - } - } - if rs.RowCount < pCtx.CurrentScriptNode.TableReader.RowsetSize || len(pageState) == 0 { - break - } - } - - logger.DebugCtx(pCtx, "deleted data records for %s, elapsed %v", pCtx.BatchInfo.FullBatchId(), time.Since(deleteStartTime)) - - return nil -} +package proc + +import ( + "fmt" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/gocql/gocql" +) + +// func ClearNodeOutputs(logger *l.Logger, script *sc.ScriptDef, session *gocql.Session, keyspace string, nodeName string, runId int16) error { +// node, ok := script.ScriptNodes[nodeName] +// if !ok { +// return fmt.Errorf("cannot find node %s", nodeName) +// } +// if node.HasTableCreator() { +// qb := cql.QueryBuilder{} +// logger.Info("deleting data table %s.%s...", keyspace, node.TableCreator.Name) +// query := qb.Keyspace(keyspace).DropRun(node.TableCreator.Name, runId) +// if err := session.Query(query).Exec(); err != nil { +// return fmt.Errorf("cannot drop data table [%s]: [%s]", query, err.Error()) +// } + +// for idxName, _ := range node.TableCreator.Indexes { +// qb := cql.QueryBuilder{} +// logger.Info("deleting index table %s.%s...", keyspace, idxName) +// query := qb.Keyspace(keyspace).DropRun(idxName, runId) +// if err := session.Query(query).Exec(); err != nil { +// return fmt.Errorf("cannot drop idx table [%s]: [%s]", query, err.Error()) +// } +// } +// } else if node.HasFileCreator() { +// if _, err := os.Stat(node.FileCreator.Url); err == nil { +// logger.Info("deleting output file %s...", 
node.FileCreator.Url) +// if err := os.Remove(node.FileCreator.Url); err != nil { +// return fmt.Errorf("cannot delete file [%s]: [%s]", node.FileCreator.Url, err.Error()) +// } +// } +// } +// return nil +// } + +func selectBatchFromDataTablePaged(logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + rs *Rowset, + tableName string, + lookupNodeRunId int16, + batchSize int, + pageState []byte, + rowidsToFind map[int64]struct{}) ([]byte, error) { + + logger.PushF("proc.selectBatchFromDataTablePaged") + defer logger.PopF() + + if err := rs.InitRows(batchSize); err != nil { + return nil, err + } + + rowids := make([]int64, len(rowidsToFind)) + i := 0 + for k := range rowidsToFind { + rowids[i] = k + i++ + } + + qb := cql.QueryBuilder{} + q := qb. + Keyspace(pCtx.BatchInfo.DataKeyspace). + CondInPrepared("rowid"). // This is a right-side lookup table, select by rowid + SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) + + var iter *gocql.Iter + selectRetryIdx := 0 + curSelectExpBackoffFactor := 1 + for { + iter = pCtx.CqlSession.Query(q, rowids).PageSize(batchSize).PageState(pageState).Iter() + + dbWarnings := iter.Warnings() + if len(dbWarnings) > 0 { + // TODO: figure out what those warnigs can be, never saw one + logger.WarnCtx(pCtx, "got warnigs while selecting %d rows from %s%s: %s", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), strings.Join(dbWarnings, ";")) + } + + rs.RowCount = 0 + + scanner := iter.Scanner() + for scanner.Next() { + if rs.RowCount >= len(rs.Rows) { + return nil, fmt.Errorf("unexpected data row retrieved, exceeding rowset size %d", len(rs.Rows)) + } + if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { + return nil, db.WrapDbErrorWithQuery("cannot scan paged data row", q, err) + } + // We assume gocql creates only UTC timestamps, so this is not needed. + // If we ever catch a ts stored in our tables with a non-UTC tz, or gocql returning a non-UTC tz - investigate it. Sanitizing is the last resort and should be avoided. 
+ // if err := rs.SanitizeScannedDatetimesToUtc(rs.RowCount); err != nil { + // return nil, db.WrapDbErrorWithQuery("cannot sanitize datetimes", q, err) + // } + rs.RowCount++ + } + + err := scanner.Err() + if err == nil { + break + } + if !(strings.Contains(err.Error(), "Operation timed out") || strings.Contains(err.Error(), "Cannot achieve consistency level") && selectRetryIdx < 3) { + return nil, db.WrapDbErrorWithQuery(fmt.Sprintf("paged data scanner cannot select %d rows from %s%s after %d attempts; another worker may retry this batch later, but, if some unique idx records has been written already by current worker, the next worker handling this batch will throw an error on them and there is nothing we can do about it;", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), selectRetryIdx+1), q, err) + } + logger.WarnCtx(pCtx, "cannot select %d rows from %s%s on retry %d, getting timeout/consistency error (%s), will wait for %dms and retry", batchSize, tableName, cql.RunIdSuffix(lookupNodeRunId), selectRetryIdx, err.Error(), 10*curSelectExpBackoffFactor) + time.Sleep(time.Duration(10*curSelectExpBackoffFactor) * time.Millisecond) + curSelectExpBackoffFactor *= 2 + selectRetryIdx++ + } + + return iter.PageState(), nil +} + +func selectBatchPagedAllRowids(logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + rs *Rowset, + tableName string, + lookupNodeRunId int16, + batchSize int, + pageState []byte) ([]byte, error) { + + logger.PushF("proc.selectBatchPagedAllRowids") + defer logger.PopF() + + if err := rs.InitRows(batchSize); err != nil { + return nil, err + } + + qb := cql.QueryBuilder{} + q := qb. + Keyspace(pCtx.BatchInfo.DataKeyspace). + SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) + + iter := pCtx.CqlSession.Query(q).PageSize(batchSize).PageState(pageState).Iter() + + dbWarnings := iter.Warnings() + if len(dbWarnings) > 0 { + logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) + } + + rs.RowCount = 0 + + scanner := iter.Scanner() + for scanner.Next() { + if rs.RowCount >= len(rs.Rows) { + return nil, fmt.Errorf("unexpected data row retrieved, exceeding rowset size %d", len(rs.Rows)) + } + if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { + return nil, db.WrapDbErrorWithQuery("cannot scan all rows data row", q, err) + } + // We assume gocql creates only UTC timestamps, so this is not needed + // If we ever catch a ts stored in our tables with a non-UTC tz, or gocql returning a non-UTC tz - investigate it. Sanitizing is the last resort and should be avoided. + // if err := rs.SanitizeScannedDatetimesToUtc(rs.RowCount); err != nil { + // return nil, db.WrapDbErrorWithQuery("cannot sanitize datetimes", q, err) + // } + rs.RowCount++ + } + if err := scanner.Err(); err != nil { + return nil, db.WrapDbErrorWithQuery("data all rows scanner error", q, err) + } + + return iter.PageState(), nil +} + +func selectBatchFromIdxTablePaged(logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + rs *Rowset, + tableName string, + lookupNodeRunId int16, + batchSize int, + pageState []byte, + keysToFind *[]string) ([]byte, error) { + + logger.PushF("proc.selectBatchFromIdxTablePaged") + defer logger.PopF() + + if err := rs.InitRows(batchSize); err != nil { + return nil, err + } + + qb := cql.QueryBuilder{} + q := qb.Keyspace(pCtx.BatchInfo.DataKeyspace). + CondInPrepared("key"). 
// This is an index table, select only selected keys + SelectRun(tableName, lookupNodeRunId, *rs.GetFieldNames()) + + iter := pCtx.CqlSession.Query(q, *keysToFind).PageSize(batchSize).PageState(pageState).Iter() + + dbWarnings := iter.Warnings() + if len(dbWarnings) > 0 { + logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) + } + + rs.RowCount = 0 + + scanner := iter.Scanner() + for scanner.Next() { + if rs.RowCount >= len(rs.Rows) { + return nil, fmt.Errorf("unexpected idx row retrieved, exceeding rowset size %d", len(rs.Rows)) + } + if err := scanner.Scan(*rs.Rows[rs.RowCount]...); err != nil { + return nil, db.WrapDbErrorWithQuery("cannot scan idx row", q, err) + } + rs.RowCount++ + } + if err := scanner.Err(); err != nil { + return nil, db.WrapDbErrorWithQuery("idx scanner error", q, err) + } + + return iter.PageState(), nil +} + +func selectBatchFromTableByToken(logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + rs *Rowset, + tableName string, + readerNodeRunId int16, + batchSize int, + startToken int64, + endToken int64) (int64, error) { + + logger.PushF("proc.selectBatchFromTableByToken") + defer logger.PopF() + + if err := rs.InitRows(batchSize); err != nil { + return 0, err + } + + qb := cql.QueryBuilder{} + q := qb.Keyspace(pCtx.BatchInfo.DataKeyspace). + Limit(batchSize). + CondPrepared("token(rowid)", ">="). + CondPrepared("token(rowid)", "<="). + SelectRun(tableName, readerNodeRunId, *rs.GetFieldNames()) + + // TODO: consider retries as we do in selectBatchFromDataTablePaged(); although no timeouts were detected so far here + + iter := pCtx.CqlSession.Query(q, startToken, endToken).Iter() + + dbWarnings := iter.Warnings() + if len(dbWarnings) > 0 { + logger.WarnCtx(pCtx, strings.Join(dbWarnings, ";")) + } + rs.RowCount = 0 + var lastRetrievedToken int64 + for rs.RowCount < len(rs.Rows) && iter.Scan(*rs.Rows[rs.RowCount]...) 
{ + lastRetrievedToken = *((*rs.Rows[rs.RowCount])[rs.FieldsByFieldName["token(rowid)"]].(*int64)) + rs.RowCount++ + } + if err := iter.Close(); err != nil { + return 0, db.WrapDbErrorWithQuery("cannot close iterator", q, err) + } + + return lastRetrievedToken, nil +} + +const HarvestForDeleteRowsetSize = 1000 // Do not let users tweak it, maybe too sensitive + +func DeleteDataAndUniqueIndexesByBatchIdx(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("proc.DeleteDataAndUniqueIndexesByBatchIdx") + defer logger.PopF() + + logger.DebugCtx(pCtx, "deleting data records for %s...", pCtx.BatchInfo.FullBatchId()) + deleteStartTime := time.Now() + + if !pCtx.CurrentScriptNode.HasTableCreator() { + logger.InfoCtx(pCtx, "no table creator, nothing to delete for %s", pCtx.BatchInfo.FullBatchId()) + return nil + } + + // Select from data table by rowid, retrieve all fields that are involved i building unique indexes + uniqueIdxFieldRefs := pCtx.CurrentScriptNode.GetUniqueIndexesFieldRefs() + + rs := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(pCtx.CurrentScriptNode.TableCreator.Name)}, + *uniqueIdxFieldRefs, + sc.FieldRefs{sc.FieldRef{TableName: pCtx.CurrentScriptNode.TableCreator.Name, FieldName: "batch_idx", FieldType: sc.FieldTypeInt}}) + + var pageState []byte + var err error + for { + pageState, err = selectBatchPagedAllRowids(logger, + pCtx, + rs, + pCtx.CurrentScriptNode.TableCreator.Name, + pCtx.BatchInfo.RunId, + HarvestForDeleteRowsetSize, + pageState) + if err != nil { + return err + } + + if rs.RowCount == 0 { + break + } + + // Harvest rowids with batchIdx we are interested in, also harvest keys + + // Prepare the storage for rowids and keys + rowIdsToDelete := make([]int64, rs.RowCount) + uniqueKeysToDeleteMap := map[string][]string{} // unique_idx_name -> list_of_keys_to_delete + for idxName, idxDef := range pCtx.CurrentScriptNode.TableCreator.Indexes { + if idxDef.Uniqueness == sc.IdxUnique { + uniqueKeysToDeleteMap[idxName] = make([]string, rs.RowCount) + } + } + + rowIdsToDeleteCount := 0 + for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { + rowId := *((*rs.Rows[rowIdx])[rs.FieldsByFieldName["rowid"]].(*int64)) + batchIdx := int16(*((*rs.Rows[rowIdx])[rs.FieldsByFieldName["batch_idx"]].(*int64))) + if batchIdx == pCtx.BatchInfo.BatchIdx { + // Add this rowid to the list + rowIdsToDelete[rowIdsToDeleteCount] = rowId + // Build the key and add it to the list + tableRecord, err := rs.GetTableRecord(rowIdx) + if err != nil { + return fmt.Errorf("while deleting previous batch attempt leftovers, cannot get table record from [%v]: %s", rs.Rows[rowIdx], err.Error()) + } + for idxName, idxDef := range pCtx.CurrentScriptNode.TableCreator.Indexes { + if _, ok := uniqueKeysToDeleteMap[idxName]; ok { + uniqueKeysToDeleteMap[idxName][rowIdsToDeleteCount], err = sc.BuildKey(tableRecord, idxDef) + if err != nil { + return fmt.Errorf("while deleting previous batch attempt leftovers, cannot build a key for index %s from [%v]: %s", idxName, tableRecord, err.Error()) + } + if len(uniqueKeysToDeleteMap[idxName][rowIdsToDeleteCount]) == 0 { + logger.ErrorCtx(pCtx, "invalid empty key calculated for %v", tableRecord) + } + } + } + rowIdsToDeleteCount++ + } + } + if rowIdsToDeleteCount > 0 { + rowIdsToDelete = rowIdsToDelete[:rowIdsToDeleteCount] + // NOTE: Assuming Delete won't interfere with paging + logger.DebugCtx(pCtx, "deleting %d data records from %s: %v", len(rowIdsToDelete), pCtx.BatchInfo.FullBatchId(), rowIdsToDelete) + qbDel := cql.QueryBuilder{} + qDel 
:= qbDel. + Keyspace(pCtx.BatchInfo.DataKeyspace). + CondInInt("rowid", rowIdsToDelete[:rowIdsToDeleteCount]). + DeleteRun(pCtx.CurrentScriptNode.TableCreator.Name, pCtx.BatchInfo.RunId) + if err := pCtx.CqlSession.Query(qDel).Exec(); err != nil { + return db.WrapDbErrorWithQuery("cannot delete from data table", qDel, err) + } + logger.InfoCtx(pCtx, "deleted %d records from data table for %s, now will delete from %d indexes", len(rowIdsToDelete), pCtx.BatchInfo.FullBatchId(), len(uniqueKeysToDeleteMap)) + + for idxName, idxKeysToDelete := range uniqueKeysToDeleteMap { + logger.DebugCtx(pCtx, "deleting %d idx %s records from %d/%s idx %s for batch_idx %d: %v", len(rowIdsToDelete), idxName, pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, idxName, pCtx.BatchInfo.BatchIdx, idxKeysToDelete) + qbDel := cql.QueryBuilder{} + qDel := qbDel. + Keyspace(pCtx.BatchInfo.DataKeyspace). + CondInString("key", idxKeysToDelete[:rowIdsToDeleteCount]). + DeleteRun(idxName, pCtx.BatchInfo.RunId) + if err := pCtx.CqlSession.Query(qDel).Exec(); err != nil { + return db.WrapDbErrorWithQuery("cannot delete from idx table", qDel, err) + } + logger.InfoCtx(pCtx, "deleted %d records from idx table %s for batch %d/%s/%d", len(rowIdsToDelete), idxName, pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, pCtx.BatchInfo.BatchIdx) + } + } + if rs.RowCount < pCtx.CurrentScriptNode.TableReader.RowsetSize || len(pageState) == 0 { + break + } + } + + logger.DebugCtx(pCtx, "deleted data records for %s, elapsed %v", pCtx.BatchInfo.FullBatchId(), time.Since(deleteStartTime)) + + return nil +} diff --git a/pkg/proc/file_inserter.go b/pkg/proc/file_inserter.go index f97f231..1edc4e4 100644 --- a/pkg/proc/file_inserter.go +++ b/pkg/proc/file_inserter.go @@ -1,151 +1,151 @@ -package proc - -import ( - "fmt" - "os" - "strings" - - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/xfer" -) - -type FileInserter struct { - PCtx *ctx.MessageProcessingContext - FileCreator *sc.FileCreatorDef - CurrentBatch *WriteFileBatch - BatchCapacity int - BatchesIn chan *WriteFileBatch - ErrorsOut chan error - BatchesSent int - FinalFileUrl string - TempFilePath string -} - -const DefaultFileInserterBatchCapacity int = 1000 - -type WriteFileBatch struct { - Rows [][]interface{} - RowCount int -} - -func newWriteFileBatch(batchCapacity int) *WriteFileBatch { - return &WriteFileBatch{ - Rows: make([][]interface{}, batchCapacity), - RowCount: 0, - } -} - -func newFileInserter(pCtx *ctx.MessageProcessingContext, fileCreator *sc.FileCreatorDef, runId int16, batchIdx int16) *FileInserter { - instr := FileInserter{ - PCtx: pCtx, - FileCreator: fileCreator, - BatchCapacity: DefaultFileInserterBatchCapacity, - BatchesIn: make(chan *WriteFileBatch, sc.MaxFileCreatorTopLimit/DefaultFileInserterBatchCapacity), - ErrorsOut: make(chan error, 1), - BatchesSent: 0, - FinalFileUrl: strings.ReplaceAll(strings.ReplaceAll(fileCreator.UrlTemplate, sc.ReservedParamRunId, fmt.Sprintf("%05d", runId)), sc.ReservedParamBatchIdx, fmt.Sprintf("%05d", batchIdx)), - } - - return &instr -} - -func (instr *FileInserter) checkWorkerOutputForErrors() error { - errors := make([]string, 0) - for { - select { - case err := <-instr.ErrorsOut: - instr.BatchesSent-- - if err != nil { - errors = append(errors, err.Error()) - } - default: - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } - } - } -} - 
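For orientation, the sketch below (a hypothetical helper in the same package as FileInserter) shows how its pieces are meant to compose: add buffers rows and ships a full WriteFileBatch to BatchesIn, the worker answers every batch with exactly one value on ErrorsOut, and waitForWorkerAndCloseErrorsOut flushes the last partial batch and drains those answers before closing the channels. The real call sites are the CSV and Parquet file creators further down in this patch.

package proc

import (
	"github.com/capillariesio/capillaries/pkg/ctx"
	"github.com/capillariesio/capillaries/pkg/l"
	"github.com/capillariesio/capillaries/pkg/sc"
)

// writeCsvRows is a hypothetical illustration of the expected call sequence;
// the real callers live in the file creator code of this package.
func writeCsvRows(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, fileCreator *sc.FileCreatorDef, rows [][]any, runId int16, batchIdx int16) error {
	instr := newFileInserter(pCtx, fileCreator, runId, batchIdx)
	if err := instr.createCsvFileAndStartWorker(logger); err != nil {
		return err
	}
	for _, row := range rows {
		instr.add(row) // ships a WriteFileBatch to BatchesIn every BatchCapacity rows
	}
	// Sends the last partial batch, reads one result per batch sent, then closes the channels
	return instr.waitForWorkerAndCloseErrorsOut(logger, pCtx)
}

If the output went to a temporary file, sendFileToFinal then uploads it to FinalFileUrl.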
-func (instr *FileInserter) waitForWorker(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("proc.waitForWorkers/FieInserter") - defer logger.PopF() - - // waitForWorker may be used for writing leftovers, handle them - if instr.CurrentBatch != nil && instr.CurrentBatch.RowCount > 0 { - instr.BatchesIn <- instr.CurrentBatch - instr.BatchesSent++ - instr.CurrentBatch = nil - } - - logger.DebugCtx(pCtx, "started reading BatchesSent=%d from instr.ErrorsOut", instr.BatchesSent) - errors := make([]string, 0) - // It's crucial that the number of errors to receive eventually should match instr.BatchesSent - errCount := 0 - for i := 0; i < instr.BatchesSent; i++ { - err := <-instr.ErrorsOut - if err != nil { - errors = append(errors, err.Error()) - errCount++ - } - logger.DebugCtx(pCtx, "got result for sent record %d out of %d from instr.ErrorsOut, %d errors so far", i, instr.BatchesSent, errCount) - } - logger.DebugCtx(pCtx, "done reading BatchesSent=%d from instr.ErrorsOut, %d errors", instr.BatchesSent, errCount) - - // Reset for the next cycle, if it ever happens - instr.BatchesSent = 0 - - // Now it's safe to close - logger.DebugCtx(pCtx, "closing BatchesIn") - close(instr.BatchesIn) - logger.DebugCtx(pCtx, "closed BatchesIn") - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (instr *FileInserter) waitForWorkerAndCloseErrorsOut(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("proc.waitForWorkersAndClose/FileInserter") - defer logger.PopF() - - err := instr.waitForWorker(logger, pCtx) - logger.DebugCtx(pCtx, "closing ErrorsOut") - close(instr.ErrorsOut) - logger.DebugCtx(pCtx, "closed ErrorsOut") - return err -} - -func (instr *FileInserter) add(row []interface{}) { - if instr.CurrentBatch == nil { - instr.CurrentBatch = newWriteFileBatch(instr.BatchCapacity) - } - instr.CurrentBatch.Rows[instr.CurrentBatch.RowCount] = row - instr.CurrentBatch.RowCount++ - - if instr.CurrentBatch.RowCount == instr.BatchCapacity { - instr.BatchesIn <- instr.CurrentBatch - instr.BatchesSent++ - instr.CurrentBatch = nil - } -} - -func (instr *FileInserter) sendFileToFinal(logger *l.Logger, pCtx *ctx.MessageProcessingContext, privateKeys map[string]string) error { - logger.PushF("proc.sendFileToFinal") - defer logger.PopF() - - if instr.TempFilePath == "" { - // Nothing to do, the file is already at its destination - return nil - } - defer os.Remove(instr.TempFilePath) - - logger.InfoCtx(pCtx, "uploading %s to %s...", instr.TempFilePath, instr.FinalFileUrl) - - return xfer.UploadSftpFile(instr.TempFilePath, instr.FinalFileUrl, privateKeys) -} +package proc + +import ( + "fmt" + "os" + "strings" + + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/xfer" +) + +type FileInserter struct { + PCtx *ctx.MessageProcessingContext + FileCreator *sc.FileCreatorDef + CurrentBatch *WriteFileBatch + BatchCapacity int + BatchesIn chan *WriteFileBatch + ErrorsOut chan error + BatchesSent int + FinalFileUrl string + TempFilePath string +} + +const DefaultFileInserterBatchCapacity int = 1000 + +type WriteFileBatch struct { + Rows [][]any + RowCount int +} + +func newWriteFileBatch(batchCapacity int) *WriteFileBatch { + return &WriteFileBatch{ + Rows: make([][]any, batchCapacity), + RowCount: 0, + } +} + +func newFileInserter(pCtx *ctx.MessageProcessingContext, fileCreator 
*sc.FileCreatorDef, runId int16, batchIdx int16) *FileInserter { + instr := FileInserter{ + PCtx: pCtx, + FileCreator: fileCreator, + BatchCapacity: DefaultFileInserterBatchCapacity, + BatchesIn: make(chan *WriteFileBatch, sc.MaxFileCreatorTopLimit/DefaultFileInserterBatchCapacity), + ErrorsOut: make(chan error, 1), + BatchesSent: 0, + FinalFileUrl: strings.ReplaceAll(strings.ReplaceAll(fileCreator.UrlTemplate, sc.ReservedParamRunId, fmt.Sprintf("%05d", runId)), sc.ReservedParamBatchIdx, fmt.Sprintf("%05d", batchIdx)), + } + + return &instr +} + +func (instr *FileInserter) checkWorkerOutputForErrors() error { + errors := make([]string, 0) + for { + select { + case err := <-instr.ErrorsOut: + instr.BatchesSent-- + if err != nil { + errors = append(errors, err.Error()) + } + default: + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } else { + return nil + } + } + } +} + +func (instr *FileInserter) waitForWorker(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("proc.waitForWorkers/FieInserter") + defer logger.PopF() + + // waitForWorker may be used for writing leftovers, handle them + if instr.CurrentBatch != nil && instr.CurrentBatch.RowCount > 0 { + instr.BatchesIn <- instr.CurrentBatch + instr.BatchesSent++ + instr.CurrentBatch = nil + } + + logger.DebugCtx(pCtx, "started reading BatchesSent=%d from instr.ErrorsOut", instr.BatchesSent) + errors := make([]string, 0) + // It's crucial that the number of errors to receive eventually should match instr.BatchesSent + errCount := 0 + for i := 0; i < instr.BatchesSent; i++ { + err := <-instr.ErrorsOut + if err != nil { + errors = append(errors, err.Error()) + errCount++ + } + logger.DebugCtx(pCtx, "got result for sent record %d out of %d from instr.ErrorsOut, %d errors so far", i, instr.BatchesSent, errCount) + } + logger.DebugCtx(pCtx, "done reading BatchesSent=%d from instr.ErrorsOut, %d errors", instr.BatchesSent, errCount) + + // Reset for the next cycle, if it ever happens + instr.BatchesSent = 0 + + // Now it's safe to close + logger.DebugCtx(pCtx, "closing BatchesIn") + close(instr.BatchesIn) + logger.DebugCtx(pCtx, "closed BatchesIn") + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +func (instr *FileInserter) waitForWorkerAndCloseErrorsOut(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("proc.waitForWorkersAndClose/FileInserter") + defer logger.PopF() + + err := instr.waitForWorker(logger, pCtx) + logger.DebugCtx(pCtx, "closing ErrorsOut") + close(instr.ErrorsOut) + logger.DebugCtx(pCtx, "closed ErrorsOut") + return err +} + +func (instr *FileInserter) add(row []any) { + if instr.CurrentBatch == nil { + instr.CurrentBatch = newWriteFileBatch(instr.BatchCapacity) + } + instr.CurrentBatch.Rows[instr.CurrentBatch.RowCount] = row + instr.CurrentBatch.RowCount++ + + if instr.CurrentBatch.RowCount == instr.BatchCapacity { + instr.BatchesIn <- instr.CurrentBatch + instr.BatchesSent++ + instr.CurrentBatch = nil + } +} + +func (instr *FileInserter) sendFileToFinal(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, privateKeys map[string]string) error { + logger.PushF("proc.sendFileToFinal") + defer logger.PopF() + + if instr.TempFilePath == "" { + // Nothing to do, the file is already at its destination + return nil + } + defer os.Remove(instr.TempFilePath) + + logger.InfoCtx(pCtx, "uploading %s to %s...", instr.TempFilePath, instr.FinalFileUrl) + + return xfer.UploadSftpFile(instr.TempFilePath, 
instr.FinalFileUrl, privateKeys) +} diff --git a/pkg/proc/file_inserter_csv.go b/pkg/proc/file_inserter_csv.go index 07f6d20..d1771e9 100644 --- a/pkg/proc/file_inserter_csv.go +++ b/pkg/proc/file_inserter_csv.go @@ -12,7 +12,7 @@ import ( "github.com/shopspring/decimal" ) -func (instr *FileInserter) createCsvFileAndStartWorker(logger *l.Logger) error { +func (instr *FileInserter) createCsvFileAndStartWorker(logger *l.CapiLogger) error { logger.PushF("proc.createCsvFileAndStartWorker") defer logger.PopF() @@ -67,7 +67,7 @@ func (instr *FileInserter) createCsvFileAndStartWorker(logger *l.Logger) error { return nil } -func (instr *FileInserter) csvFileInserterWorker(logger *l.Logger) { +func (instr *FileInserter) csvFileInserterWorker(logger *l.CapiLogger) { logger.PushF("proc.csvFileInserterWorker") defer logger.PopF() @@ -121,13 +121,16 @@ func (instr *FileInserter) csvFileInserterWorker(logger *l.Logger) { } } - f.Sync() - if _, err := f.WriteString(b.String()); err != nil { - instr.ErrorsOut <- fmt.Errorf("cannot write string to %s(temp %s): [%s]", instr.FinalFileUrl, instr.TempFilePath, err.Error()) + if err = f.Sync(); err == nil { + if _, err = f.WriteString(b.String()); err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot write string to %s(temp %s): [%s]", instr.FinalFileUrl, instr.TempFilePath, err.Error()) + } else { + dur := time.Since(batchStartTime) + logger.InfoCtx(instr.PCtx, "%d items in %.3fs (%.0f items/s)", batch.RowCount, dur.Seconds(), float64(batch.RowCount)/dur.Seconds()) + instr.ErrorsOut <- nil + } } else { - dur := time.Since(batchStartTime) - logger.InfoCtx(instr.PCtx, "%d items in %.3fs (%.0f items/s)", batch.RowCount, dur.Seconds(), float64(batch.RowCount)/dur.Seconds()) - instr.ErrorsOut <- nil + instr.ErrorsOut <- fmt.Errorf("cannot sync file %s(temp %s): [%s]", instr.FinalFileUrl, instr.TempFilePath, err.Error()) } } // next batch } diff --git a/pkg/proc/file_inserter_parquet.go b/pkg/proc/file_inserter_parquet.go index 734a6dd..2af9178 100644 --- a/pkg/proc/file_inserter_parquet.go +++ b/pkg/proc/file_inserter_parquet.go @@ -13,7 +13,7 @@ import ( "github.com/shopspring/decimal" ) -func (instr *FileInserter) createParquetFileAndStartWorker(logger *l.Logger, codec sc.ParquetCodecType) error { +func (instr *FileInserter) createParquetFileAndStartWorker(logger *l.CapiLogger, codec sc.ParquetCodecType) error { logger.PushF("proc.createParquetFileAndStartWorker") defer logger.PopF() @@ -47,7 +47,7 @@ func (instr *FileInserter) createParquetFileAndStartWorker(logger *l.Logger, cod return nil } -func (instr *FileInserter) parquetFileInserterWorker(logger *l.Logger, codec sc.ParquetCodecType) { +func (instr *FileInserter) parquetFileInserterWorker(logger *l.CapiLogger, codec sc.ParquetCodecType) { logger.PushF("proc.parquetFileInserterWorker") defer logger.PopF() @@ -87,7 +87,7 @@ func (instr *FileInserter) parquetFileInserterWorker(logger *l.Logger, codec sc. batchStartTime := time.Now() var errAddData error for rowIdx := 0; rowIdx < batch.RowCount; rowIdx++ { - d := map[string]interface{}{} + d := map[string]any{} for i := 0; i < len(instr.FileCreator.Columns); i++ { switch instr.FileCreator.Columns[i].Type { case sc.FieldTypeString: @@ -134,7 +134,7 @@ func (instr *FileInserter) parquetFileInserterWorker(logger *l.Logger, codec sc. 
d[instr.FileCreator.Columns[i].Parquet.ColumnName] = storage.ParquetWriterMilliTs(typedValue) default: errAddData = fmt.Errorf("cannot convert column %s value [%v] to Parquet: unsupported type", instr.FileCreator.Columns[i].Parquet.ColumnName, batch.Rows[rowIdx][i]) - break + break //nolint:all , https://github.com/dominikh/go-tools/issues/59 } } if err := w.FileWriter.AddData(d); err != nil { diff --git a/pkg/proc/file_read_csv.go b/pkg/proc/file_read_csv.go index 4f38f62..b16d8da 100644 --- a/pkg/proc/file_read_csv.go +++ b/pkg/proc/file_read_csv.go @@ -13,7 +13,7 @@ import ( "github.com/capillariesio/capillaries/pkg/sc" ) -func readCsv(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProcessingContext, totalStartTime time.Time, filePath string, fileReader io.Reader) (BatchStats, error) { +func readCsv(envConfig *env.EnvConfig, logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, totalStartTime time.Time, filePath string, fileReader io.Reader) (BatchStats, error) { bs := BatchStats{RowsRead: 0, RowsWritten: 0} node := pCtx.CurrentScriptNode @@ -23,11 +23,10 @@ func readCsv(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProces // To avoid bare \" error: https://stackoverflow.com/questions/31326659/golang-csv-error-bare-in-non-quoted-field r.LazyQuotes = true - var lineIdx int64 = 0 + var lineIdx int64 tableRecordBatchCount := 0 - instr := newTableInserter(envConfig, logger, pCtx, &node.TableCreator, DefaultInserterBatchSize) - //instr.verifyTablesExist() + instr := newTableInserter(envConfig, pCtx, &node.TableCreator, DefaultInserterBatchSize) if err := instr.startWorkers(logger, pCtx); err != nil { return bs, err } @@ -69,11 +68,13 @@ func readCsv(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProces // Write batch if needed if inResult { - instr.add(tableRecord) + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } tableRecordBatchCount++ if tableRecordBatchCount == DefaultInserterBatchSize { if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + return bs, fmt.Errorf("cannot save record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) } reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) batchStartTime = time.Now() diff --git a/pkg/proc/file_read_parquet.go b/pkg/proc/file_read_parquet.go index 494af51..c6cc061 100644 --- a/pkg/proc/file_read_parquet.go +++ b/pkg/proc/file_read_parquet.go @@ -15,14 +15,14 @@ import ( "github.com/fraugster/parquet-go/parquet" ) -func readParquetRowToValuesMap(d map[string]interface{}, +func readParquetRowToValuesMap(d map[string]any, rowIdx int, requestedParquetColumnNames []string, parquetToCapiFieldNameMap map[string]string, parquetToCapiTypeMap map[string]sc.TableFieldType, schemaElementMap map[string]*parquet.SchemaElement, colVars eval.VarValuesMap) error { - colVars[sc.ReaderAlias] = map[string]interface{}{} + colVars[sc.ReaderAlias] = map[string]any{} for _, parquetColName := range requestedParquetColumnNames { capiFieldName, ok := parquetToCapiFieldNameMap[parquetColName] if !ok { @@ -79,7 +79,7 @@ func readParquetRowToValuesMap(d map[string]interface{}, return nil } -func readParquet(envConfig 
*env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProcessingContext, totalStartTime time.Time, filePath string, fileReadSeeker io.ReadSeeker) (BatchStats, error) { +func readParquet(envConfig *env.EnvConfig, logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, totalStartTime time.Time, filePath string, fileReadSeeker io.ReadSeeker) (BatchStats, error) { bs := BatchStats{RowsRead: 0, RowsWritten: 0} node := pCtx.CurrentScriptNode @@ -122,11 +122,11 @@ func readParquet(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessagePr } } - var lineIdx int64 = 0 + lineIdx := int64(0) tableRecordBatchCount := 0 // Prepare inserter - instr := newTableInserter(envConfig, logger, pCtx, &node.TableCreator, DefaultInserterBatchSize) + instr := newTableInserter(envConfig, pCtx, &node.TableCreator, DefaultInserterBatchSize) if err := instr.startWorkers(logger, pCtx); err != nil { return bs, err } @@ -162,7 +162,9 @@ func readParquet(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessagePr // Write batch if needed if inResult { - instr.add(tableRecord) + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } tableRecordBatchCount++ if tableRecordBatchCount == DefaultInserterBatchSize { if err := instr.waitForWorkers(logger, pCtx); err != nil { diff --git a/pkg/proc/proc_file_creator.go b/pkg/proc/proc_file_creator.go index f41277c..8690f88 100644 --- a/pkg/proc/proc_file_creator.go +++ b/pkg/proc/proc_file_creator.go @@ -1,201 +1,201 @@ -package proc - -import ( - "container/heap" - "fmt" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" -) - -type FileRecordHeapItem struct { - FileRecord *[]interface{} - Key string -} - -type FileRecordHeap []*FileRecordHeapItem - -func (h FileRecordHeap) Len() int { return len(h) } -func (h FileRecordHeap) Less(i, j int) bool { return h[i].Key > h[j].Key } // Reverse order: https://stackoverflow.com/questions/49065781/limit-size-of-the-priority-queue-for-gos-heap-interface-implementation -func (h FileRecordHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h *FileRecordHeap) Push(x interface{}) { - item := x.(*FileRecordHeapItem) - *h = append(*h, item) -} -func (h *FileRecordHeap) Pop() interface{} { - old := *h - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - *h = old[0 : n-1] - return item -} - -func readAndInsert(logger *l.Logger, pCtx *ctx.MessageProcessingContext, tableName string, rs *Rowset, instr *FileInserter, readerNodeRunId int16, startToken int64, endToken int64, srcBatchSize int) (BatchStats, error) { - - bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: tableName + cql.RunIdSuffix(readerNodeRunId), Dst: instr.FinalFileUrl} - - var topHeap FileRecordHeap - if instr.FileCreator.HasTop() { - topHeap := FileRecordHeap{} - heap.Init(&topHeap) - } - - curStartToken := startToken - - for { - lastRetrievedToken, err := selectBatchFromTableByToken(logger, - pCtx, - rs, - tableName, - readerNodeRunId, - srcBatchSize, - curStartToken, - endToken) - if err != nil { - return bs, err - } - curStartToken = lastRetrievedToken + 1 - - if rs.RowCount == 0 { - break - } - - for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { - vars := 
eval.VarValuesMap{} - if err := rs.ExportToVars(rowIdx, &vars); err != nil { - return bs, err - } - - fileRecord, err := instr.FileCreator.CalculateFileRecordFromSrcVars(vars) - if err != nil { - return bs, fmt.Errorf("cannot populate file record from [%v]: [%s]", vars, err.Error()) - } - - inResult, err := instr.FileCreator.CheckFileRecordHavingCondition(fileRecord) - if err != nil { - return bs, fmt.Errorf("cannot check having condition [%s], file record [%v]: [%s]", instr.FileCreator.RawHaving, fileRecord, err.Error()) - } - - if !inResult { - continue - } - - if instr.FileCreator.HasTop() { - keyVars := map[string]interface{}{} - for i := 0; i < len(instr.FileCreator.Columns); i++ { - keyVars[instr.FileCreator.Columns[i].Name] = fileRecord[i] - } - key, err := sc.BuildKey(keyVars, &instr.FileCreator.Top.OrderIdxDef) - if err != nil { - return bs, fmt.Errorf("cannot build top key for [%v]: [%s]", vars, err.Error()) - } - heap.Push(&topHeap, &FileRecordHeapItem{FileRecord: &fileRecord, Key: key}) - if len(topHeap) > instr.FileCreator.Top.Limit { - heap.Pop(&topHeap) - } - } else { - instr.add(fileRecord) - bs.RowsWritten++ - } - } - - bs.RowsRead += rs.RowCount - if rs.RowCount < srcBatchSize { - break - } - - if err := instr.checkWorkerOutputForErrors(); err != nil { - return bs, fmt.Errorf("cannot save record batch from %s to %s(temp %s): [%s]", tableName, instr.FinalFileUrl, instr.TempFilePath, err.Error()) - } - - } // for each source table batch - - if instr.FileCreator.HasTop() { - properlyOrderedTopList := make([]*FileRecordHeapItem, topHeap.Len()) - for i := topHeap.Len() - 1; i >= 0; i-- { - properlyOrderedTopList[i] = heap.Pop(&topHeap).(*FileRecordHeapItem) - } - for i := 0; i < len(properlyOrderedTopList); i++ { - instr.add(*properlyOrderedTopList[i].FileRecord) - bs.RowsWritten++ - } - } - - return bs, nil - -} - -func RunCreateFile(envConfig *env.EnvConfig, - logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - readerNodeRunId int16, - startToken int64, - endToken int64) (BatchStats, error) { - - logger.PushF("proc.RunCreateFile") - defer logger.PopF() - - totalStartTime := time.Now() - - if readerNodeRunId == 0 { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") - } - - node := pCtx.CurrentScriptNode - - if !node.HasFileCreator() { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("node does not have file creator") - } - - // Fields to read from source table - srcFieldRefs := sc.FieldRefs{} - // No src fields in having! 
- srcFieldRefs.AppendWithFilter(node.FileCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) - - rs := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, - sc.FieldRefs{sc.RowidTokenFieldRef()}, - srcFieldRefs) - - instr := newFileInserter(pCtx, &node.FileCreator, pCtx.BatchInfo.RunId, pCtx.BatchInfo.BatchIdx) - - if node.FileCreator.CreatorFileType == sc.CreatorFileTypeCsv { - if err := instr.createCsvFileAndStartWorker(logger); err != nil { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("cannot start csv inserter worker: %s", err.Error()) - } - } else if node.FileCreator.CreatorFileType == sc.CreatorFileTypeParquet { - if err := instr.createParquetFileAndStartWorker(logger, node.FileCreator.Parquet.Codec); err != nil { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("cannot start parquet inserter worker: %s", err.Error()) - } - } else { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("unknown inserter file type: %d", node.FileCreator.CreatorFileType) - } - - bs, err := readAndInsert(logger, pCtx, node.TableReader.TableName, rs, instr, readerNodeRunId, startToken, endToken, node.TableReader.RowsetSize) - if err != nil { - if closeErr := instr.waitForWorkerAndCloseErrorsOut(logger, pCtx); err != nil { - logger.ErrorCtx(pCtx, "unexpected error while calling waitForWorkerAndCloseErrorsOut: %s", closeErr.Error()) - } - return bs, err - } - - // Successful so far, write leftovers - if err := instr.waitForWorkerAndCloseErrorsOut(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch from %s to %s(temp %s): [%s]", node.TableReader.TableName, instr.FinalFileUrl, instr.TempFilePath, err.Error()) - } - - bs.Elapsed = time.Since(totalStartTime) - logger.InfoCtx(pCtx, "WriteFileComplete: read %d, wrote %d items in %.3fs (%.0f items/s)", bs.RowsRead, bs.RowsWritten, bs.Elapsed.Seconds(), float64(bs.RowsWritten)/bs.Elapsed.Seconds()) - - if err := instr.sendFileToFinal(logger, pCtx, envConfig.PrivateKeys); err != nil { - return bs, err - } - - return bs, nil -} +package proc + +import ( + "container/heap" + "fmt" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" +) + +type FileRecordHeapItem struct { + FileRecord *[]any + Key string +} + +type FileRecordHeap []*FileRecordHeapItem + +func (h FileRecordHeap) Len() int { return len(h) } +func (h FileRecordHeap) Less(i, j int) bool { return h[i].Key > h[j].Key } // Reverse order: https://stackoverflow.com/questions/49065781/limit-size-of-the-priority-queue-for-gos-heap-interface-implementation +func (h FileRecordHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *FileRecordHeap) Push(x any) { + item := x.(*FileRecordHeapItem) //nolint:all + *h = append(*h, item) +} +func (h *FileRecordHeap) Pop() any { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + *h = old[0 : n-1] + return item +} + +func readAndInsert(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, tableName string, rs *Rowset, instr *FileInserter, readerNodeRunId int16, startToken int64, endToken int64, srcBatchSize int) (BatchStats, error) { + + bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: tableName + cql.RunIdSuffix(readerNodeRunId), Dst: instr.FinalFileUrl} + + var topHeap FileRecordHeap + 
if instr.FileCreator.HasTop() {
+ topHeap = FileRecordHeap{} // assign to the topHeap declared above; := would shadow it and leave the outer var nil
+ heap.Init(&topHeap)
+ }
+
+ curStartToken := startToken
+
+ for {
+ lastRetrievedToken, err := selectBatchFromTableByToken(logger,
+ pCtx,
+ rs,
+ tableName,
+ readerNodeRunId,
+ srcBatchSize,
+ curStartToken,
+ endToken)
+ if err != nil {
+ return bs, err
+ }
+ curStartToken = lastRetrievedToken + 1
+
+ if rs.RowCount == 0 {
+ break
+ }
+
+ for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ {
+ vars := eval.VarValuesMap{}
+ if err := rs.ExportToVars(rowIdx, &vars); err != nil {
+ return bs, err
+ }
+
+ fileRecord, err := instr.FileCreator.CalculateFileRecordFromSrcVars(vars)
+ if err != nil {
+ return bs, fmt.Errorf("cannot populate file record from [%v]: [%s]", vars, err.Error())
+ }
+
+ inResult, err := instr.FileCreator.CheckFileRecordHavingCondition(fileRecord)
+ if err != nil {
+ return bs, fmt.Errorf("cannot check having condition [%s], file record [%v]: [%s]", instr.FileCreator.RawHaving, fileRecord, err.Error())
+ }
+
+ if !inResult {
+ continue
+ }
+
+ if instr.FileCreator.HasTop() {
+ keyVars := map[string]any{}
+ for i := 0; i < len(instr.FileCreator.Columns); i++ {
+ keyVars[instr.FileCreator.Columns[i].Name] = fileRecord[i]
+ }
+ key, err := sc.BuildKey(keyVars, &instr.FileCreator.Top.OrderIdxDef)
+ if err != nil {
+ return bs, fmt.Errorf("cannot build top key for [%v]: [%s]", vars, err.Error())
+ }
+ heap.Push(&topHeap, &FileRecordHeapItem{FileRecord: &fileRecord, Key: key})
+ if len(topHeap) > instr.FileCreator.Top.Limit {
+ heap.Pop(&topHeap)
+ }
+ } else {
+ instr.add(fileRecord)
+ bs.RowsWritten++
+ }
+ }
+
+ bs.RowsRead += rs.RowCount
+ if rs.RowCount < srcBatchSize {
+ break
+ }
+
+ if err := instr.checkWorkerOutputForErrors(); err != nil {
+ return bs, fmt.Errorf("cannot save record batch from %s to %s(temp %s): [%s]", tableName, instr.FinalFileUrl, instr.TempFilePath, err.Error())
+ }
+
+ } // for each source table batch
+
+ if instr.FileCreator.HasTop() {
+ properlyOrderedTopList := make([]*FileRecordHeapItem, topHeap.Len())
+ for i := topHeap.Len() - 1; i >= 0; i-- {
+ properlyOrderedTopList[i] = heap.Pop(&topHeap).(*FileRecordHeapItem) //nolint:all
+ }
+ for i := 0; i < len(properlyOrderedTopList); i++ {
+ instr.add(*properlyOrderedTopList[i].FileRecord)
+ bs.RowsWritten++
+ }
+ }
+
+ return bs, nil
+
+}
+
+func RunCreateFile(envConfig *env.EnvConfig,
+ logger *l.CapiLogger,
+ pCtx *ctx.MessageProcessingContext,
+ readerNodeRunId int16,
+ startToken int64,
+ endToken int64) (BatchStats, error) {
+
+ logger.PushF("proc.RunCreateFile")
+ defer logger.PopF()
+
+ totalStartTime := time.Now()
+
+ if readerNodeRunId == 0 {
+ return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)")
+ }
+
+ node := pCtx.CurrentScriptNode
+
+ if !node.HasFileCreator() {
+ return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("node does not have file creator")
+ }
+
+ // Fields to read from source table
+ srcFieldRefs := sc.FieldRefs{}
+ // No src fields in having! 
+ srcFieldRefs.AppendWithFilter(node.FileCreator.UsedInTargetExpressionsFields, sc.ReaderAlias)
+
+ rs := NewRowsetFromFieldRefs(
+ sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)},
+ sc.FieldRefs{sc.RowidTokenFieldRef()},
+ srcFieldRefs)
+
+ instr := newFileInserter(pCtx, &node.FileCreator, pCtx.BatchInfo.RunId, pCtx.BatchInfo.BatchIdx)
+
+ if node.FileCreator.CreatorFileType == sc.CreatorFileTypeCsv {
+ if err := instr.createCsvFileAndStartWorker(logger); err != nil {
+ return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("cannot start csv inserter worker: %s", err.Error())
+ }
+ } else if node.FileCreator.CreatorFileType == sc.CreatorFileTypeParquet {
+ if err := instr.createParquetFileAndStartWorker(logger, node.FileCreator.Parquet.Codec); err != nil {
+ return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("cannot start parquet inserter worker: %s", err.Error())
+ }
+ } else {
+ return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("unknown inserter file type: %d", node.FileCreator.CreatorFileType)
+ }
+
+ bs, err := readAndInsert(logger, pCtx, node.TableReader.TableName, rs, instr, readerNodeRunId, startToken, endToken, node.TableReader.RowsetSize)
+ if err != nil {
+ if closeErr := instr.waitForWorkerAndCloseErrorsOut(logger, pCtx); closeErr != nil {
+ logger.ErrorCtx(pCtx, "unexpected error while calling waitForWorkerAndCloseErrorsOut: %s", closeErr.Error())
+ }
+ return bs, err
+ }
+
+ // Successful so far, write leftovers
+ if err := instr.waitForWorkerAndCloseErrorsOut(logger, pCtx); err != nil {
+ return bs, fmt.Errorf("cannot save record batch from %s to %s(temp %s): [%s]", node.TableReader.TableName, instr.FinalFileUrl, instr.TempFilePath, err.Error())
+ }
+
+ bs.Elapsed = time.Since(totalStartTime)
+ logger.InfoCtx(pCtx, "WriteFileComplete: read %d, wrote %d items in %.3fs (%.0f items/s)", bs.RowsRead, bs.RowsWritten, bs.Elapsed.Seconds(), float64(bs.RowsWritten)/bs.Elapsed.Seconds())
+
+ if err := instr.sendFileToFinal(logger, pCtx, envConfig.PrivateKeys); err != nil {
+ return bs, err
+ }
+
+ return bs, nil
+}
diff --git a/pkg/proc/proc_table_creator.go b/pkg/proc/proc_table_creator.go
index 5033cef..9abab5e 100644
--- a/pkg/proc/proc_table_creator.go
+++ b/pkg/proc/proc_table_creator.go
@@ -1,890 +1,891 @@
-package proc
-
-import (
- "bufio"
- "fmt"
- "io"
- "net/url"
- "os"
- "time"
-
- "github.com/capillariesio/capillaries/pkg/cql"
- "github.com/capillariesio/capillaries/pkg/ctx"
- "github.com/capillariesio/capillaries/pkg/env"
- "github.com/capillariesio/capillaries/pkg/eval"
- "github.com/capillariesio/capillaries/pkg/l"
- "github.com/capillariesio/capillaries/pkg/sc"
- "github.com/capillariesio/capillaries/pkg/xfer"
-)
-
-type TableRecord map[string]interface{}
-type TableRecordPtr *map[string]interface{}
-type TableRecordBatch []TableRecordPtr
-
-const DefaultInserterBatchSize int = 5000
-
-func reportWriteTable(logger *l.Logger, pCtx *ctx.MessageProcessingContext, recordCount int, dur time.Duration, indexCount int, workerCount int) {
- logger.InfoCtx(pCtx, "WriteTable: %d items in %.3fs (%.0f items/s, %d indexes, eff rate %.0f writes/s), %d workers",
- recordCount,
- dur.Seconds(),
- float64(recordCount)/dur.Seconds(),
- indexCount,
- float64(recordCount*(indexCount+1))/dur.Seconds(),
- workerCount)
-}
-
-func reportWriteTableLeftovers(logger *l.Logger, pCtx *ctx.MessageProcessingContext, recordCount int, dur time.Duration, indexCount int, workerCount int) {
- logger.InfoCtx(pCtx, "WriteTableLeftovers: %d items in %.3fs (%.0f items/s, %d 
indexes, eff rate %.0f writes/s), %d workers", - recordCount, - dur.Seconds(), - float64(recordCount)/dur.Seconds(), - indexCount, - float64(recordCount*(indexCount+1))/dur.Seconds(), - workerCount) -} - -func reportWriteTableComplete(logger *l.Logger, pCtx *ctx.MessageProcessingContext, readCount int, recordCount int, dur time.Duration, indexCount int, workerCount int) { - logger.InfoCtx(pCtx, "WriteTableComplete: read %d, wrote %d items in %.3fs (%.0f items/s, %d indexes, eff rate %.0f writes/s), %d workers", - readCount, - recordCount, - dur.Seconds(), - float64(recordCount)/dur.Seconds(), - indexCount, - float64(recordCount*(indexCount+1))/dur.Seconds(), - workerCount) -} - -func RunReadFileForBatch(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProcessingContext, srcFileIdx int) (BatchStats, error) { - logger.PushF("proc.RunReadFileForBatch") - defer logger.PopF() - - totalStartTime := time.Now() - bs := BatchStats{RowsRead: 0, RowsWritten: 0} - - node := pCtx.CurrentScriptNode - - if !node.HasFileReader() { - return bs, fmt.Errorf("node does not have file reader") - } - if !node.HasTableCreator() { - return bs, fmt.Errorf("node does not have table creator") - } - - if srcFileIdx < 0 || srcFileIdx >= len(node.FileReader.SrcFileUrls) { - return bs, fmt.Errorf("cannot find file to read: asked to read src file with index %d while there are only %d source files available", srcFileIdx, len(node.FileReader.SrcFileUrls)) - } - filePath := node.FileReader.SrcFileUrls[srcFileIdx] - - u, err := url.Parse(filePath) - if err != nil { - return bs, fmt.Errorf("cannot parse file uri %s: %s", filePath, err.Error()) - } - - bs.Src = filePath - bs.Dst = node.TableCreator.Name + cql.RunIdSuffix(pCtx.BatchInfo.RunId) - - var localSrcFile *os.File - var fileReader io.Reader - var fileReadSeeker io.ReadSeeker - if u.Scheme == xfer.UriSchemeFile || len(u.Scheme) == 0 { - localSrcFile, err = os.Open(filePath) - if err != nil { - return bs, err - } - defer localSrcFile.Close() - fileReader = bufio.NewReader(localSrcFile) - fileReadSeeker = localSrcFile - } else if u.Scheme == xfer.UriSchemeHttp || u.Scheme == xfer.UriSchemeHttps { - // If this is a parquet file, download it and then open so we have fileReadSeeker - if node.FileReader.ReaderFileType == sc.ReaderFileTypeParquet { - dstFile, err := os.CreateTemp("", "capi") - if err != nil { - return bs, fmt.Errorf("cannot create temp file for %s: %s", filePath, err.Error()) - } - - readCloser, err := xfer.GetHttpReadCloser(filePath, u.Scheme, envConfig.CaPath) - if err != nil { - dstFile.Close() - return bs, fmt.Errorf("cannot open http file %s: %s", filePath, err.Error()) - } - defer readCloser.Close() - - if _, err := io.Copy(dstFile, readCloser); err != nil { - dstFile.Close() - return bs, fmt.Errorf("cannot save http file %s to temp file %s: %s", filePath, dstFile.Name(), err.Error()) - } - - logger.Info("downloaded http file %s to %s", filePath, dstFile.Name()) - dstFile.Close() - defer os.Remove(dstFile.Name()) - - localSrcFile, err = os.Open(dstFile.Name()) - if err != nil { - return bs, fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), filePath, err.Error()) - } - defer localSrcFile.Close() - fileReadSeeker = localSrcFile - } else { - // Just read from the net - readCloser, err := xfer.GetHttpReadCloser(filePath, u.Scheme, envConfig.CaPath) - if err != nil { - return bs, err - } - fileReader = readCloser - defer readCloser.Close() - } - } else if u.Scheme == xfer.UriSchemeSftp { - // When dealing with sftp, we 
download the *whole* file, instead of providing a reader - dstFile, err := os.CreateTemp("", "capi") - if err != nil { - return bs, fmt.Errorf("cannot create temp file for %s: %s", filePath, err.Error()) - } - - // Download and schedule delete - if err = xfer.DownloadSftpFile(filePath, envConfig.PrivateKeys, dstFile); err != nil { - dstFile.Close() - return bs, err - } - logger.Info("downloaded sftp file %s to %s", filePath, dstFile.Name()) - dstFile.Close() - defer os.Remove(dstFile.Name()) - - // Create a reader for the temp file - localSrcFile, err = os.Open(dstFile.Name()) - if err != nil { - return bs, fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), filePath, err.Error()) - } - defer localSrcFile.Close() - fileReader = bufio.NewReader(localSrcFile) - fileReadSeeker = localSrcFile - } else { - return bs, fmt.Errorf("uri scheme %s not supported: %s", u.Scheme, filePath) - } - - if node.FileReader.ReaderFileType == sc.ReaderFileTypeCsv { - return readCsv(envConfig, logger, pCtx, totalStartTime, filePath, fileReader) - } else if node.FileReader.ReaderFileType == sc.ReaderFileTypeParquet { - return readParquet(envConfig, logger, pCtx, totalStartTime, filePath, fileReadSeeker) - } else { - return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("unknown reader file type: %d", node.FileReader.ReaderFileType) - } -} - -func RunCreateTableForCustomProcessorForBatch(envConfig *env.EnvConfig, - logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - readerNodeRunId int16, - startLeftToken int64, - endLeftToken int64) (BatchStats, error) { - - logger.PushF("proc.RunCreateTableForCustomProcessorForBatch") - defer logger.PopF() - - node := pCtx.CurrentScriptNode - - totalStartTime := time.Now() - bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} - - if readerNodeRunId == 0 { - return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") - } - - if !node.HasTableReader() { - return bs, fmt.Errorf("node does not have table reader") - } - if !node.HasTableCreator() { - return bs, fmt.Errorf("node does not have table creator") - } - - // Fields to read from source table - srcLeftFieldRefs := sc.FieldRefs{} - srcLeftFieldRefs.AppendWithFilter(*node.CustomProcessor.GetUsedInTargetExpressionsFields(), sc.ReaderAlias) - srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) - - leftBatchSize := node.TableReader.RowsetSize - curStartLeftToken := startLeftToken - - rsIn := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, - sc.FieldRefs{sc.RowidTokenFieldRef()}, - srcLeftFieldRefs) - - inserterBatchSize := DefaultInserterBatchSize - if inserterBatchSize < node.TableReader.RowsetSize { - inserterBatchSize = node.TableReader.RowsetSize - } - instr := newTableInserter(envConfig, logger, pCtx, &node.TableCreator, inserterBatchSize) - //instr.verifyTablesExist() - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) - - flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { - logger.PushF("proc.flushRowset") - defer logger.PopF() - - flushStartTime := time.Now() - rowsWritten := 0 - - for outRowIdx := 0; outRowIdx < varsArrayCount; outRowIdx++ { - vars := varsArray[outRowIdx] - - 
tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, *vars) - if err != nil { - return fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) - } - - // Check table creator having - inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) - if err != nil { - return fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) - } - - // Write batch if needed - if inResult { - instr.add(tableRecord) - rowsWritten++ - bs.RowsWritten++ - } - } - - reportWriteTable(logger, pCtx, rowsWritten, time.Since(flushStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - flushStartTime = time.Now() - rowsWritten = 0 - - return nil - } - - for { - lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, - pCtx, - rsIn, - node.TableReader.TableName, - readerNodeRunId, - leftBatchSize, - curStartLeftToken, - endLeftToken) - if err != nil { - return bs, err - } - curStartLeftToken = lastRetrievedLeftToken + 1 - - if rsIn.RowCount == 0 { - break - } - customProcBatchStartTime := time.Now() - - if err = node.CustomProcessor.(CustomProcessorRunner).Run(logger, pCtx, rsIn, flushVarsArray); err != nil { - return bs, err - } - - custProcDur := time.Since(customProcBatchStartTime) - logger.InfoCtx(pCtx, "CustomProcessor: %d items in %v (%.0f items/s)", rsIn.RowCount, custProcDur, float64(rsIn.RowCount)/custProcDur.Seconds()) - - bs.RowsRead += rsIn.RowCount - if rsIn.RowCount < leftBatchSize { - break - } - } // for each source table batch - - bs.Elapsed = time.Since(totalStartTime) - reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) - - return bs, nil -} - -func RunCreateTableForBatch(envConfig *env.EnvConfig, - logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - readerNodeRunId int16, - startLeftToken int64, - endLeftToken int64) (BatchStats, error) { - - logger.PushF("proc.RunCreateTableForBatch") - defer logger.PopF() - - node := pCtx.CurrentScriptNode - - batchStartTime := time.Now() - totalStartTime := time.Now() - bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} - - if readerNodeRunId == 0 { - return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") - } - - if !node.HasTableReader() { - return bs, fmt.Errorf("node does not have table reader") - } - if !node.HasTableCreator() { - return bs, fmt.Errorf("node does not have table creator") - } - - // Fields to read from source table - srcLeftFieldRefs := sc.FieldRefs{} - srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) - - leftBatchSize := node.TableReader.RowsetSize - tableRecordBatchCount := 0 - curStartLeftToken := startLeftToken - - rsIn := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, - sc.FieldRefs{sc.RowidTokenFieldRef()}, - srcLeftFieldRefs) - - inserterBatchSize := DefaultInserterBatchSize - if inserterBatchSize < node.TableReader.RowsetSize { - inserterBatchSize = node.TableReader.RowsetSize - } - instr := newTableInserter(envConfig, logger, pCtx, &node.TableCreator, inserterBatchSize) - //instr.verifyTablesExist() - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - defer 
instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) - - for { - lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, - pCtx, - rsIn, - node.TableReader.TableName, - readerNodeRunId, - leftBatchSize, - curStartLeftToken, - endLeftToken) - if err != nil { - return bs, err - } - curStartLeftToken = lastRetrievedLeftToken + 1 - - if rsIn.RowCount == 0 { - break - } - - // Save rsIn - for outRowIdx := 0; outRowIdx < rsIn.RowCount; outRowIdx++ { - vars := eval.VarValuesMap{} - if err := rsIn.ExportToVars(outRowIdx, &vars); err != nil { - return bs, err - } - - tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, vars) - if err != nil { - return bs, fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) - } - - // Check table creator having - inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) - if err != nil { - return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) - } - - // Write batch if needed - if inResult { - instr.add(tableRecord) - tableRecordBatchCount++ - if tableRecordBatchCount == DefaultInserterBatchSize { - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - batchStartTime = time.Now() - tableRecordBatchCount = 0 - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - } - bs.RowsWritten++ - } - } - - bs.RowsRead += rsIn.RowCount - if rsIn.RowCount < leftBatchSize { - break - } - } // for each source table batch - - // Write leftovers regardless of tableRecordBatchCount == 0 - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTableLeftovers(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - - bs.Elapsed = time.Since(totalStartTime) - reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) - - return bs, nil -} - -func RunCreateTableRelForBatch(envConfig *env.EnvConfig, - logger *l.Logger, - pCtx *ctx.MessageProcessingContext, - readerNodeRunId int16, - lookupNodeRunId int16, - startLeftToken int64, - endLeftToken int64) (BatchStats, error) { - - logger.PushF("proc.RunCreateTableRelForBatch") - defer logger.PopF() - - node := pCtx.CurrentScriptNode - - batchStartTime := time.Now() - totalStartTime := time.Now() - - bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} - - if readerNodeRunId == 0 { - return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") - } - - if lookupNodeRunId == 0 { - return bs, fmt.Errorf("this node has a dependency node to lookup data at that was never started in this keyspace (lookupNodeRunId == 0)") - } - - if !node.HasTableReader() { - return bs, fmt.Errorf("node does not have table reader") - } - if !node.HasTableCreator() { - return bs, fmt.Errorf("node does not have table creator") - 
} - if !node.HasLookup() { - return bs, fmt.Errorf("node does not have lookup") - } - - // Fields to read from source table - srcLeftFieldRefs := sc.FieldRefs{} - srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) - srcLeftFieldRefs.Append(node.Lookup.LeftTableFields) - - srcRightFieldRefs := sc.FieldRefs{} - srcRightFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.LookupAlias) - if node.Lookup.UsesFilter() { - srcRightFieldRefs.AppendWithFilter(node.Lookup.UsedInFilterFields, sc.LookupAlias) - } - - leftBatchSize := node.TableReader.RowsetSize - tableRecordBatchCount := 0 - - rsLeft := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, - sc.FieldRefs{sc.RowidTokenFieldRef()}, - srcLeftFieldRefs) - - inserterBatchSize := DefaultInserterBatchSize - if inserterBatchSize < node.TableReader.RowsetSize { - inserterBatchSize = node.TableReader.RowsetSize - } - instr := newTableInserter(envConfig, logger, pCtx, &node.TableCreator, inserterBatchSize) - //instr.verifyTablesExist() - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) - - curStartLeftToken := startLeftToken - leftPageIdx := 0 - for { - selectLeftBatchByTokenStartTime := time.Now() - lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, - pCtx, - rsLeft, - node.TableReader.TableName, - readerNodeRunId, - leftBatchSize, - curStartLeftToken, - endLeftToken) - if err != nil { - return bs, err - } - - logger.DebugCtx(pCtx, "selectBatchFromTableByToken: leftPageIdx %d, queried tokens from %d to %d in %.3fs, retrieved %d rows", leftPageIdx, curStartLeftToken, endLeftToken, time.Since(selectLeftBatchByTokenStartTime).Seconds(), rsLeft.RowCount) - - curStartLeftToken = lastRetrievedLeftToken + 1 - - if rsLeft.RowCount == 0 { - break - } - - // Setup eval ctx for each target field if grouping is involved - // map: rowid -> field -> ctx - eCtxMap := map[int64]map[string]*eval.EvalCtx{} - if node.HasLookup() && node.Lookup.IsGroup { - for rowIdx := 0; rowIdx < rsLeft.RowCount; rowIdx++ { - rowid := *((*rsLeft.Rows[rowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) - eCtxMap[rowid] = map[string]*eval.EvalCtx{} - for fieldName, fieldDef := range node.TableCreator.Fields { - aggFuncEnabled, aggFuncType, aggFuncArgs := eval.DetectRootAggFunc(fieldDef.ParsedExpression) - newCtx, newCtxErr := eval.NewPlainEvalCtxAndInitializedAgg(aggFuncEnabled, aggFuncType, aggFuncArgs) - if newCtxErr != nil { - return bs, newCtxErr - } - eCtxMap[rowid][fieldName] = newCtx - } - } - } - - // Build keys to find in the lookup index, one key may yield multiple rowids - keyToLeftRowIdxMap := map[string][]int{} - leftRowFoundRightLookup := make([]bool, rsLeft.RowCount) - for rowIdx := 0; rowIdx < rsLeft.RowCount; rowIdx++ { - leftRowFoundRightLookup[rowIdx] = false - vars := eval.VarValuesMap{} - if err := rsLeft.ExportToVars(rowIdx, &vars); err != nil { - return bs, err - } - key, err := sc.BuildKey(vars[sc.ReaderAlias], node.Lookup.TableCreator.Indexes[node.Lookup.IndexName]) - if err != nil { - return bs, err - } - - _, ok := keyToLeftRowIdxMap[key] - if !ok { - keyToLeftRowIdxMap[key] = make([]int, 0) - } - keyToLeftRowIdxMap[key] = append(keyToLeftRowIdxMap[key], rowIdx) - } - - keysToFind := make([]string, len(keyToLeftRowIdxMap)) - i := 0 - for k := range keyToLeftRowIdxMap { - keysToFind[i] = k - i++ - } - - lookupFieldRefs := sc.FieldRefs{} - 
lookupFieldRefs.AppendWithFilter(node.TableCreator.UsedInHavingFields, node.Lookup.TableCreator.Name) - lookupFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, node.Lookup.TableCreator.Name) - - rsIdx := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.Lookup.IndexName)}, - sc.FieldRefs{sc.KeyTokenFieldRef()}, - sc.FieldRefs{sc.IdxKeyFieldRef()}) - - var idxPageState []byte - rightIdxPageIdx := 0 - for { - selectIdxBatchStartTime := time.Now() - idxPageState, err = selectBatchFromIdxTablePaged(logger, - pCtx, - rsIdx, - node.Lookup.IndexName, - lookupNodeRunId, - node.Lookup.IdxReadBatchSize, - idxPageState, - &keysToFind) - if err != nil { - return bs, err - } - - if rsIdx.RowCount == 0 { - break - } - - // Build a map of right-row-id -> key - rightRowIdToKeyMap := map[int64]string{} - for rowIdx := 0; rowIdx < rsIdx.RowCount; rowIdx++ { - rightRowId := *((*rsIdx.Rows[rowIdx])[rsIdx.FieldsByFieldName["rowid"]].(*int64)) - key := *((*rsIdx.Rows[rowIdx])[rsIdx.FieldsByFieldName["key"]].(*string)) - rightRowIdToKeyMap[rightRowId] = key - } - - rowidsToFind := make(map[int64]struct{}, len(rightRowIdToKeyMap)) - for k := range rightRowIdToKeyMap { - rowidsToFind[k] = struct{}{} - } - - logger.DebugCtx(pCtx, "selectBatchFromIdxTablePaged: leftPageIdx %d, rightIdxPageIdx %d, queried %d keys in %.3fs, retrieved %d rowids", leftPageIdx, rightIdxPageIdx, len(keysToFind), time.Since(selectIdxBatchStartTime).Seconds(), len(rowidsToFind)) - - // Select from right table by rowid - rsRight := NewRowsetFromFieldRefs( - sc.FieldRefs{sc.RowidFieldRef(node.Lookup.TableCreator.Name)}, - sc.FieldRefs{sc.RowidTokenFieldRef()}, - srcRightFieldRefs) - - var rightPageState []byte - rightDataPageIdx := 0 - for { - selectBatchStartTime := time.Now() - rightPageState, err = selectBatchFromDataTablePaged(logger, - pCtx, - rsRight, - node.Lookup.TableCreator.Name, - lookupNodeRunId, - node.Lookup.RightLookupReadBatchSize, - rightPageState, - rowidsToFind) - if err != nil { - return bs, err - } - - logger.DebugCtx(pCtx, "selectBatchFromDataTablePaged: leftPageIdx %d, rightIdxPageIdx %d, rightDataPageIdx %d, queried %d rowids in %.3fs, retrieved %d rowids", leftPageIdx, rightIdxPageIdx, rightDataPageIdx, len(rowidsToFind), time.Since(selectBatchStartTime).Seconds(), rsRight.RowCount) - - if rsRight.RowCount == 0 { - break - } - - for rightRowIdx := 0; rightRowIdx < rsRight.RowCount; rightRowIdx++ { - rightRowId := *((*rsRight.Rows[rightRowIdx])[rsRight.FieldsByFieldName["rowid"]].(*int64)) - rightRowKey := rightRowIdToKeyMap[rightRowId] - - // Remove this right rowid from the set, we do not need it anymore. Reset page state. 
- rightPageState = nil - delete(rowidsToFind, rightRowId) - - // Check filter condition if needed - lookupFilterOk := true - if node.Lookup.UsesFilter() { - vars := eval.VarValuesMap{} - if err := rsRight.ExportToVars(rightRowIdx, &vars); err != nil { - return bs, err - } - var err error - lookupFilterOk, err = node.Lookup.CheckFilterCondition(vars) - if err != nil { - return bs, fmt.Errorf("cannot check filter condition [%s] against [%v]: [%s]", node.Lookup.RawFilter, vars, err.Error()) - } - } - - if !lookupFilterOk { - continue - } - - if node.Lookup.IsGroup { - // Find correspondent row from rsLeft, merge left and right and - // call group eval eCtxMap[leftRowid] for each output field - for _, leftRowIdx := range keyToLeftRowIdxMap[rightRowKey] { - - leftRowFoundRightLookup[leftRowIdx] = true - - leftRowid := *((*rsLeft.Rows[leftRowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) - for fieldName, fieldDef := range node.TableCreator.Fields { - eCtxMap[leftRowid][fieldName].Vars = &eval.VarValuesMap{} - if err := rsLeft.ExportToVars(leftRowIdx, eCtxMap[leftRowid][fieldName].Vars); err != nil { - return bs, err - } - if err := rsRight.ExportToVarsWithAlias(rightRowIdx, eCtxMap[leftRowid][fieldName].Vars, sc.LookupAlias); err != nil { - return bs, err - } - _, err := eCtxMap[leftRowid][fieldName].Eval(fieldDef.ParsedExpression) - if err != nil { - return bs, fmt.Errorf("cannot evaluate target expression [%s]: [%s]", fieldDef.RawExpression, err.Error()) - } - } - } - } else { - // Non-group. Find correspondent row from rsLeft, merge left and right and call row-level eval - for _, leftRowIdx := range keyToLeftRowIdxMap[rightRowKey] { - - leftRowFoundRightLookup[leftRowIdx] = true - - vars := eval.VarValuesMap{} - if err := rsLeft.ExportToVars(leftRowIdx, &vars); err != nil { - return bs, err - } - if err := rsRight.ExportToVarsWithAlias(rightRowIdx, &vars, sc.LookupAlias); err != nil { - return bs, err - } - - // We are ready to write this result right away, so prepare the output tableRecord - tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, vars) - if err != nil { - return bs, fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) - } - - // Check table creator having - inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) - if err != nil { - return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) - } - - // Write batch if needed - if inResult { - instr.add(tableRecord) - tableRecordBatchCount++ - if tableRecordBatchCount == instr.BatchSize { - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - batchStartTime = time.Now() - tableRecordBatchCount = 0 - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - } - bs.RowsWritten++ - } - } // non-group result row written - } // group case handled - } // for each found right row - - if rsRight.RowCount < node.Lookup.RightLookupReadBatchSize || len(rightPageState) == 0 { - break - } - rightDataPageIdx++ - } // for each data page - - if rsIdx.RowCount < node.Lookup.IdxReadBatchSize || len(idxPageState) == 0 { - break - } - rightIdxPageIdx++ - } // for each idx page - - if 
node.Lookup.IsGroup { - // Time to write the result of the grouped - for leftRowIdx := 0; leftRowIdx < rsLeft.RowCount; leftRowIdx++ { - tableRecord := map[string]interface{}{} - if !leftRowFoundRightLookup[leftRowIdx] { - if node.Lookup.LookupJoin == sc.LookupJoinLeft { - - // Grouped left outer join with no data on the right - - leftVars := eval.VarValuesMap{} - if err := rsLeft.ExportToVars(leftRowIdx, &leftVars); err != nil { - return bs, err - } - - for fieldName, fieldDef := range node.TableCreator.Fields { - isAggEnabled, _, _ := eval.DetectRootAggFunc(fieldDef.ParsedExpression) - if isAggEnabled == eval.AggFuncEnabled { - // Aggregate func is used in field expression - ignore the expression and produce default - tableRecord[fieldName], err = node.TableCreator.GetFieldDefaultReadyForDb(fieldName) - if err != nil { - return bs, fmt.Errorf("cannot initialize default field %s: [%s]", fieldName, err.Error()) - } - } else { - // No aggregate function used in field expression - assume it contains only left-side fields - tableRecord[fieldName], err = sc.CalculateFieldValue(fieldName, fieldDef, leftVars, false) - if err != nil { - return bs, err - } - } - } - } else { - - // Grouped inner join with no data on the right - // Do not insert this left row - - continue - } - } else { - - // Grouped inner or left outer with present data on the right - - leftRowid := *((*rsLeft.Rows[leftRowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) - for fieldName, fieldDef := range node.TableCreator.Fields { - finalValue := eCtxMap[leftRowid][fieldName].Value - - if err := sc.CheckValueType(finalValue, fieldDef.Type); err != nil { - return bs, fmt.Errorf("invalid field %s type: [%s]", fieldName, err.Error()) - } - tableRecord[fieldName] = finalValue - } - } - - // Check table creator having - inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) - if err != nil { - return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) - } - - // Write batch if needed - if inResult { - instr.add(tableRecord) - tableRecordBatchCount++ - if tableRecordBatchCount == instr.BatchSize { - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - batchStartTime = time.Now() - tableRecordBatchCount = 0 - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - } - bs.RowsWritten++ - } - } - } else if node.Lookup.LookupJoin == sc.LookupJoinLeft { - - // Non-grouped left outer join. 
- // Handle those left rows that did not have right lookup counterpart - // (those who had - they have been written already) - - for leftRowIdx := 0; leftRowIdx < rsLeft.RowCount; leftRowIdx++ { - if leftRowFoundRightLookup[leftRowIdx] { - continue - } - - leftVars := eval.VarValuesMap{} - if err := rsLeft.ExportToVars(leftRowIdx, &leftVars); err != nil { - return bs, err - } - - tableRecord := map[string]interface{}{} - - for fieldName, fieldDef := range node.TableCreator.Fields { - if fieldDef.UsedFields.HasFieldsWithTableAlias(sc.LookupAlias) { - // This field expression uses fields from lkp table - produce default value - tableRecord[fieldName], err = node.TableCreator.GetFieldDefaultReadyForDb(fieldName) - if err != nil { - return bs, fmt.Errorf("cannot initialize non-grouped default field %s: [%s]", fieldName, err.Error()) - } - } else { - // This field expression does not use fields from lkp table - assume the expression contains only left-side fields - tableRecord[fieldName], err = sc.CalculateFieldValue(fieldName, fieldDef, leftVars, false) - if err != nil { - return bs, err - } - } - } - - // Check table creator having - inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) - if err != nil { - return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) - } - - // Write batch if needed - if inResult { - instr.add(tableRecord) - tableRecordBatchCount++ - if tableRecordBatchCount == instr.BatchSize { - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - batchStartTime = time.Now() - tableRecordBatchCount = 0 - if err := instr.startWorkers(logger, pCtx); err != nil { - return bs, err - } - } - bs.RowsWritten++ - } - } - } else { - // Non-grouped inner join, already handled above - } - - bs.RowsRead += rsLeft.RowCount - if rsLeft.RowCount < leftBatchSize { - break - } - leftPageIdx++ - } // for each source table batch - - // Write leftovers regardless of tableRecordBatchCount == 0 - if err := instr.waitForWorkers(logger, pCtx); err != nil { - return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) - } - reportWriteTableLeftovers(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) - - bs.Elapsed = time.Since(totalStartTime) - reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) - return bs, nil -} +package proc + +import ( + "bufio" + "fmt" + "io" + "net/url" + "os" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/xfer" +) + +type TableRecord map[string]any +type TableRecordPtr *map[string]any +type TableRecordBatch []TableRecordPtr + +const DefaultInserterBatchSize int = 5000 + +func reportWriteTable(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, recordCount int, dur time.Duration, 
indexCount int, workerCount int) { + logger.InfoCtx(pCtx, "WriteTable: %d items in %.3fs (%.0f items/s, %d indexes, eff rate %.0f writes/s), %d workers", + recordCount, + dur.Seconds(), + float64(recordCount)/dur.Seconds(), + indexCount, + float64(recordCount*(indexCount+1))/dur.Seconds(), + workerCount) +} + +func reportWriteTableLeftovers(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, recordCount int, dur time.Duration, indexCount int, workerCount int) { + logger.InfoCtx(pCtx, "WriteTableLeftovers: %d items in %.3fs (%.0f items/s, %d indexes, eff rate %.0f writes/s), %d workers", + recordCount, + dur.Seconds(), + float64(recordCount)/dur.Seconds(), + indexCount, + float64(recordCount*(indexCount+1))/dur.Seconds(), + workerCount) +} + +func reportWriteTableComplete(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, readCount int, recordCount int, dur time.Duration, indexCount int, workerCount int) { + logger.InfoCtx(pCtx, "WriteTableComplete: read %d, wrote %d items in %.3fs (%.0f items/s, %d indexes, eff rate %.0f writes/s), %d workers", + readCount, + recordCount, + dur.Seconds(), + float64(recordCount)/dur.Seconds(), + indexCount, + float64(recordCount*(indexCount+1))/dur.Seconds(), + workerCount) +} + +func RunReadFileForBatch(envConfig *env.EnvConfig, logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, srcFileIdx int) (BatchStats, error) { + logger.PushF("proc.RunReadFileForBatch") + defer logger.PopF() + + totalStartTime := time.Now() + bs := BatchStats{RowsRead: 0, RowsWritten: 0} + + node := pCtx.CurrentScriptNode + + if !node.HasFileReader() { + return bs, fmt.Errorf("node does not have file reader") + } + if !node.HasTableCreator() { + return bs, fmt.Errorf("node does not have table creator") + } + + if srcFileIdx < 0 || srcFileIdx >= len(node.FileReader.SrcFileUrls) { + return bs, fmt.Errorf("cannot find file to read: asked to read src file with index %d while there are only %d source files available", srcFileIdx, len(node.FileReader.SrcFileUrls)) + } + filePath := node.FileReader.SrcFileUrls[srcFileIdx] + + u, err := url.Parse(filePath) + if err != nil { + return bs, fmt.Errorf("cannot parse file uri %s: %s", filePath, err.Error()) + } + + bs.Src = filePath + bs.Dst = node.TableCreator.Name + cql.RunIdSuffix(pCtx.BatchInfo.RunId) + + var localSrcFile *os.File + var fileReader io.Reader + var fileReadSeeker io.ReadSeeker + if u.Scheme == xfer.UriSchemeFile || len(u.Scheme) == 0 { + localSrcFile, err = os.Open(filePath) + if err != nil { + return bs, err + } + defer localSrcFile.Close() + fileReader = bufio.NewReader(localSrcFile) + fileReadSeeker = localSrcFile + } else if u.Scheme == xfer.UriSchemeHttp || u.Scheme == xfer.UriSchemeHttps { + // If this is a parquet file, download it and then open so we have fileReadSeeker + if node.FileReader.ReaderFileType == sc.ReaderFileTypeParquet { + dstFile, err := os.CreateTemp("", "capi") + if err != nil { + return bs, fmt.Errorf("cannot create temp file for %s: %s", filePath, err.Error()) + } + + readCloser, err := xfer.GetHttpReadCloser(filePath, u.Scheme, envConfig.CaPath) + if err != nil { + dstFile.Close() + return bs, fmt.Errorf("cannot open http file %s: %s", filePath, err.Error()) + } + defer readCloser.Close() + + if _, err := io.Copy(dstFile, readCloser); err != nil { + dstFile.Close() + return bs, fmt.Errorf("cannot save http file %s to temp file %s: %s", filePath, dstFile.Name(), err.Error()) + } + + logger.Info("downloaded http file %s to %s", filePath, dstFile.Name()) + dstFile.Close() + defer 
os.Remove(dstFile.Name()) + + localSrcFile, err = os.Open(dstFile.Name()) + if err != nil { + return bs, fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), filePath, err.Error()) + } + defer localSrcFile.Close() + fileReadSeeker = localSrcFile + } else { + // Just read from the net + readCloser, err := xfer.GetHttpReadCloser(filePath, u.Scheme, envConfig.CaPath) + if err != nil { + return bs, err + } + fileReader = readCloser + defer readCloser.Close() + } + } else if u.Scheme == xfer.UriSchemeSftp { + // When dealing with sftp, we download the *whole* file, instead of providing a reader + dstFile, err := os.CreateTemp("", "capi") + if err != nil { + return bs, fmt.Errorf("cannot create temp file for %s: %s", filePath, err.Error()) + } + + // Download and schedule delete + if err = xfer.DownloadSftpFile(filePath, envConfig.PrivateKeys, dstFile); err != nil { + dstFile.Close() + return bs, err + } + logger.Info("downloaded sftp file %s to %s", filePath, dstFile.Name()) + dstFile.Close() + defer os.Remove(dstFile.Name()) + + // Create a reader for the temp file + localSrcFile, err = os.Open(dstFile.Name()) + if err != nil { + return bs, fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), filePath, err.Error()) + } + defer localSrcFile.Close() + fileReader = bufio.NewReader(localSrcFile) + fileReadSeeker = localSrcFile + } else { + return bs, fmt.Errorf("uri scheme %s not supported: %s", u.Scheme, filePath) + } + + if node.FileReader.ReaderFileType == sc.ReaderFileTypeCsv { + return readCsv(envConfig, logger, pCtx, totalStartTime, filePath, fileReader) + } else if node.FileReader.ReaderFileType == sc.ReaderFileTypeParquet { + return readParquet(envConfig, logger, pCtx, totalStartTime, filePath, fileReadSeeker) + } + + return BatchStats{RowsRead: 0, RowsWritten: 0}, fmt.Errorf("unknown reader file type: %d", node.FileReader.ReaderFileType) +} + +func RunCreateTableForCustomProcessorForBatch(envConfig *env.EnvConfig, + logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + readerNodeRunId int16, + startLeftToken int64, + endLeftToken int64) (BatchStats, error) { + + logger.PushF("proc.RunCreateTableForCustomProcessorForBatch") + defer logger.PopF() + + node := pCtx.CurrentScriptNode + + totalStartTime := time.Now() + bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} + + if readerNodeRunId == 0 { + return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") + } + + if !node.HasTableReader() { + return bs, fmt.Errorf("node does not have table reader") + } + if !node.HasTableCreator() { + return bs, fmt.Errorf("node does not have table creator") + } + + // Fields to read from source table + srcLeftFieldRefs := sc.FieldRefs{} + srcLeftFieldRefs.AppendWithFilter(*node.CustomProcessor.GetUsedInTargetExpressionsFields(), sc.ReaderAlias) + srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) + + leftBatchSize := node.TableReader.RowsetSize + curStartLeftToken := startLeftToken + + rsIn := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, + sc.FieldRefs{sc.RowidTokenFieldRef()}, + srcLeftFieldRefs) + + inserterBatchSize := DefaultInserterBatchSize + if inserterBatchSize < node.TableReader.RowsetSize { + inserterBatchSize = node.TableReader.RowsetSize + } + instr 
:= newTableInserter(envConfig, pCtx, &node.TableCreator, inserterBatchSize) + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) + + flushVarsArray := func(varsArray []*eval.VarValuesMap, varsArrayCount int) error { + logger.PushF("proc.flushRowset") + defer logger.PopF() + + flushStartTime := time.Now() + rowsWritten := 0 + + for outRowIdx := 0; outRowIdx < varsArrayCount; outRowIdx++ { + vars := varsArray[outRowIdx] + + tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, *vars) + if err != nil { + return fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) + } + + // Check table creator having + inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) + if err != nil { + return fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) + } + + // Write batch if needed + if inResult { + if err = instr.add(tableRecord); err != nil { + return fmt.Errorf("cannot add record to %s: [%s]", node.TableCreator.Name, err.Error()) + } + rowsWritten++ + bs.RowsWritten++ + } + } + + reportWriteTable(logger, pCtx, rowsWritten, time.Since(flushStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + + return nil + } + + for { + lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, + pCtx, + rsIn, + node.TableReader.TableName, + readerNodeRunId, + leftBatchSize, + curStartLeftToken, + endLeftToken) + if err != nil { + return bs, err + } + curStartLeftToken = lastRetrievedLeftToken + 1 + + if rsIn.RowCount == 0 { + break + } + customProcBatchStartTime := time.Now() + + if err = node.CustomProcessor.(CustomProcessorRunner).Run(logger, pCtx, rsIn, flushVarsArray); err != nil { + return bs, err + } + + custProcDur := time.Since(customProcBatchStartTime) + logger.InfoCtx(pCtx, "CustomProcessor: %d items in %v (%.0f items/s)", rsIn.RowCount, custProcDur, float64(rsIn.RowCount)/custProcDur.Seconds()) + + bs.RowsRead += rsIn.RowCount + if rsIn.RowCount < leftBatchSize { + break + } + } // for each source table batch + + bs.Elapsed = time.Since(totalStartTime) + reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) + + return bs, nil +} + +func RunCreateTableForBatch(envConfig *env.EnvConfig, + logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + readerNodeRunId int16, + startLeftToken int64, + endLeftToken int64) (BatchStats, error) { + + logger.PushF("proc.RunCreateTableForBatch") + defer logger.PopF() + + node := pCtx.CurrentScriptNode + + batchStartTime := time.Now() + totalStartTime := time.Now() + bs := BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} + + if readerNodeRunId == 0 { + return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") + } + + if !node.HasTableReader() { + return bs, fmt.Errorf("node does not have table reader") + } + if !node.HasTableCreator() { + return bs, fmt.Errorf("node does not have table creator") + } + + // Fields to read from source table + srcLeftFieldRefs := sc.FieldRefs{} + srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) + + leftBatchSize := node.TableReader.RowsetSize + tableRecordBatchCount := 
0 + curStartLeftToken := startLeftToken + + rsIn := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, + sc.FieldRefs{sc.RowidTokenFieldRef()}, + srcLeftFieldRefs) + + inserterBatchSize := DefaultInserterBatchSize + if inserterBatchSize < node.TableReader.RowsetSize { + inserterBatchSize = node.TableReader.RowsetSize + } + instr := newTableInserter(envConfig, pCtx, &node.TableCreator, inserterBatchSize) + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) + + for { + lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, + pCtx, + rsIn, + node.TableReader.TableName, + readerNodeRunId, + leftBatchSize, + curStartLeftToken, + endLeftToken) + if err != nil { + return bs, err + } + curStartLeftToken = lastRetrievedLeftToken + 1 + + if rsIn.RowCount == 0 { + break + } + + // Save rsIn + for outRowIdx := 0; outRowIdx < rsIn.RowCount; outRowIdx++ { + vars := eval.VarValuesMap{} + if err := rsIn.ExportToVars(outRowIdx, &vars); err != nil { + return bs, err + } + + tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, vars) + if err != nil { + return bs, fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) + } + + // Check table creator having + inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) + if err != nil { + return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) + } + + // Write batch if needed + if inResult { + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + tableRecordBatchCount++ + if tableRecordBatchCount == DefaultInserterBatchSize { + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + batchStartTime = time.Now() + tableRecordBatchCount = 0 + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + } + bs.RowsWritten++ + } + } + + bs.RowsRead += rsIn.RowCount + if rsIn.RowCount < leftBatchSize { + break + } + } // for each source table batch + + // Write leftovers regardless of tableRecordBatchCount == 0 + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTableLeftovers(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + + bs.Elapsed = time.Since(totalStartTime) + reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) + + return bs, nil +} + +func RunCreateTableRelForBatch(envConfig *env.EnvConfig, + logger *l.CapiLogger, + pCtx *ctx.MessageProcessingContext, + readerNodeRunId int16, + lookupNodeRunId int16, + startLeftToken int64, + endLeftToken int64) (BatchStats, error) { + + logger.PushF("proc.RunCreateTableRelForBatch") + defer logger.PopF() + + node := pCtx.CurrentScriptNode + + batchStartTime := time.Now() + totalStartTime := time.Now() + + bs := 
BatchStats{RowsRead: 0, RowsWritten: 0, Src: node.TableReader.TableName + cql.RunIdSuffix(readerNodeRunId), Dst: node.TableCreator.Name + cql.RunIdSuffix(readerNodeRunId)} + + if readerNodeRunId == 0 { + return bs, fmt.Errorf("this node has a dependency node to read data from that was never started in this keyspace (readerNodeRunId == 0)") + } + + if lookupNodeRunId == 0 { + return bs, fmt.Errorf("this node has a dependency node to lookup data at that was never started in this keyspace (lookupNodeRunId == 0)") + } + + if !node.HasTableReader() { + return bs, fmt.Errorf("node does not have table reader") + } + if !node.HasTableCreator() { + return bs, fmt.Errorf("node does not have table creator") + } + if !node.HasLookup() { + return bs, fmt.Errorf("node does not have lookup") + } + + // Fields to read from source table + srcLeftFieldRefs := sc.FieldRefs{} + srcLeftFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.ReaderAlias) + srcLeftFieldRefs.Append(node.Lookup.LeftTableFields) + + srcRightFieldRefs := sc.FieldRefs{} + srcRightFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, sc.LookupAlias) + if node.Lookup.UsesFilter() { + srcRightFieldRefs.AppendWithFilter(node.Lookup.UsedInFilterFields, sc.LookupAlias) + } + + leftBatchSize := node.TableReader.RowsetSize + tableRecordBatchCount := 0 + + rsLeft := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(node.TableReader.TableName)}, + sc.FieldRefs{sc.RowidTokenFieldRef()}, + srcLeftFieldRefs) + + inserterBatchSize := DefaultInserterBatchSize + if inserterBatchSize < node.TableReader.RowsetSize { + inserterBatchSize = node.TableReader.RowsetSize + } + instr := newTableInserter(envConfig, pCtx, &node.TableCreator, inserterBatchSize) + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx) + + curStartLeftToken := startLeftToken + leftPageIdx := 0 + for { + selectLeftBatchByTokenStartTime := time.Now() + lastRetrievedLeftToken, err := selectBatchFromTableByToken(logger, + pCtx, + rsLeft, + node.TableReader.TableName, + readerNodeRunId, + leftBatchSize, + curStartLeftToken, + endLeftToken) + if err != nil { + return bs, err + } + + logger.DebugCtx(pCtx, "selectBatchFromTableByToken: leftPageIdx %d, queried tokens from %d to %d in %.3fs, retrieved %d rows", leftPageIdx, curStartLeftToken, endLeftToken, time.Since(selectLeftBatchByTokenStartTime).Seconds(), rsLeft.RowCount) + + curStartLeftToken = lastRetrievedLeftToken + 1 + + if rsLeft.RowCount == 0 { + break + } + + // Setup eval ctx for each target field if grouping is involved + // map: rowid -> field -> ctx + eCtxMap := map[int64]map[string]*eval.EvalCtx{} + if node.HasLookup() && node.Lookup.IsGroup { + for rowIdx := 0; rowIdx < rsLeft.RowCount; rowIdx++ { + rowid := *((*rsLeft.Rows[rowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) + eCtxMap[rowid] = map[string]*eval.EvalCtx{} + for fieldName, fieldDef := range node.TableCreator.Fields { + aggFuncEnabled, aggFuncType, aggFuncArgs := eval.DetectRootAggFunc(fieldDef.ParsedExpression) + newCtx, newCtxErr := eval.NewPlainEvalCtxAndInitializedAgg(aggFuncEnabled, aggFuncType, aggFuncArgs) + if newCtxErr != nil { + return bs, newCtxErr + } + eCtxMap[rowid][fieldName] = newCtx + } + } + } + + // Build keys to find in the lookup index, one key may yield multiple rowids + keyToLeftRowIdxMap := map[string][]int{} + leftRowFoundRightLookup := make([]bool, rsLeft.RowCount) + for rowIdx := 0; rowIdx < 
rsLeft.RowCount; rowIdx++ { + leftRowFoundRightLookup[rowIdx] = false + vars := eval.VarValuesMap{} + if err := rsLeft.ExportToVars(rowIdx, &vars); err != nil { + return bs, err + } + key, err := sc.BuildKey(vars[sc.ReaderAlias], node.Lookup.TableCreator.Indexes[node.Lookup.IndexName]) + if err != nil { + return bs, err + } + + _, ok := keyToLeftRowIdxMap[key] + if !ok { + keyToLeftRowIdxMap[key] = make([]int, 0) + } + keyToLeftRowIdxMap[key] = append(keyToLeftRowIdxMap[key], rowIdx) + } + + keysToFind := make([]string, len(keyToLeftRowIdxMap)) + i := 0 + for k := range keyToLeftRowIdxMap { + keysToFind[i] = k + i++ + } + + lookupFieldRefs := sc.FieldRefs{} + lookupFieldRefs.AppendWithFilter(node.TableCreator.UsedInHavingFields, node.Lookup.TableCreator.Name) + lookupFieldRefs.AppendWithFilter(node.TableCreator.UsedInTargetExpressionsFields, node.Lookup.TableCreator.Name) + + rsIdx := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(node.Lookup.IndexName)}, + sc.FieldRefs{sc.KeyTokenFieldRef()}, + sc.FieldRefs{sc.IdxKeyFieldRef()}) + + var idxPageState []byte + rightIdxPageIdx := 0 + for { + selectIdxBatchStartTime := time.Now() + idxPageState, err = selectBatchFromIdxTablePaged(logger, + pCtx, + rsIdx, + node.Lookup.IndexName, + lookupNodeRunId, + node.Lookup.IdxReadBatchSize, + idxPageState, + &keysToFind) + if err != nil { + return bs, err + } + + if rsIdx.RowCount == 0 { + break + } + + // Build a map of right-row-id -> key + rightRowIdToKeyMap := map[int64]string{} + for rowIdx := 0; rowIdx < rsIdx.RowCount; rowIdx++ { + rightRowId := *((*rsIdx.Rows[rowIdx])[rsIdx.FieldsByFieldName["rowid"]].(*int64)) + key := *((*rsIdx.Rows[rowIdx])[rsIdx.FieldsByFieldName["key"]].(*string)) + rightRowIdToKeyMap[rightRowId] = key + } + + rowidsToFind := make(map[int64]struct{}, len(rightRowIdToKeyMap)) + for k := range rightRowIdToKeyMap { + rowidsToFind[k] = struct{}{} + } + + logger.DebugCtx(pCtx, "selectBatchFromIdxTablePaged: leftPageIdx %d, rightIdxPageIdx %d, queried %d keys in %.3fs, retrieved %d rowids", leftPageIdx, rightIdxPageIdx, len(keysToFind), time.Since(selectIdxBatchStartTime).Seconds(), len(rowidsToFind)) + + // Select from right table by rowid + rsRight := NewRowsetFromFieldRefs( + sc.FieldRefs{sc.RowidFieldRef(node.Lookup.TableCreator.Name)}, + sc.FieldRefs{sc.RowidTokenFieldRef()}, + srcRightFieldRefs) + + var rightPageState []byte + rightDataPageIdx := 0 + for { + selectBatchStartTime := time.Now() + rightPageState, err = selectBatchFromDataTablePaged(logger, + pCtx, + rsRight, + node.Lookup.TableCreator.Name, + lookupNodeRunId, + node.Lookup.RightLookupReadBatchSize, + rightPageState, + rowidsToFind) + if err != nil { + return bs, err + } + + logger.DebugCtx(pCtx, "selectBatchFromDataTablePaged: leftPageIdx %d, rightIdxPageIdx %d, rightDataPageIdx %d, queried %d rowids in %.3fs, retrieved %d rowids", leftPageIdx, rightIdxPageIdx, rightDataPageIdx, len(rowidsToFind), time.Since(selectBatchStartTime).Seconds(), rsRight.RowCount) + + if rsRight.RowCount == 0 { + break + } + + for rightRowIdx := 0; rightRowIdx < rsRight.RowCount; rightRowIdx++ { + rightRowId := *((*rsRight.Rows[rightRowIdx])[rsRight.FieldsByFieldName["rowid"]].(*int64)) + rightRowKey := rightRowIdToKeyMap[rightRowId] + + // Remove this right rowid from the set, we do not need it anymore. Reset page state. 
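+ // (With the page state cleared, the next selectBatchFromDataTablePaged call starts a fresh paged query
+ // over whatever rowids remain in rowidsToFind, rather than resuming the previous paging cursor.)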
+ rightPageState = nil + delete(rowidsToFind, rightRowId) + + // Check filter condition if needed + lookupFilterOk := true + if node.Lookup.UsesFilter() { + vars := eval.VarValuesMap{} + if err := rsRight.ExportToVars(rightRowIdx, &vars); err != nil { + return bs, err + } + var err error + lookupFilterOk, err = node.Lookup.CheckFilterCondition(vars) + if err != nil { + return bs, fmt.Errorf("cannot check filter condition [%s] against [%v]: [%s]", node.Lookup.RawFilter, vars, err.Error()) + } + } + + if !lookupFilterOk { + continue + } + + if node.Lookup.IsGroup { + // Find correspondent row from rsLeft, merge left and right and + // call group eval eCtxMap[leftRowid] for each output field + for _, leftRowIdx := range keyToLeftRowIdxMap[rightRowKey] { + + leftRowFoundRightLookup[leftRowIdx] = true + + leftRowid := *((*rsLeft.Rows[leftRowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) + for fieldName, fieldDef := range node.TableCreator.Fields { + eCtxMap[leftRowid][fieldName].Vars = &eval.VarValuesMap{} + if err := rsLeft.ExportToVars(leftRowIdx, eCtxMap[leftRowid][fieldName].Vars); err != nil { + return bs, err + } + if err := rsRight.ExportToVarsWithAlias(rightRowIdx, eCtxMap[leftRowid][fieldName].Vars, sc.LookupAlias); err != nil { + return bs, err + } + _, err := eCtxMap[leftRowid][fieldName].Eval(fieldDef.ParsedExpression) + if err != nil { + return bs, fmt.Errorf("cannot evaluate target expression [%s]: [%s]", fieldDef.RawExpression, err.Error()) + } + } + } + } else { + // Non-group. Find correspondent row from rsLeft, merge left and right and call row-level eval + for _, leftRowIdx := range keyToLeftRowIdxMap[rightRowKey] { + + leftRowFoundRightLookup[leftRowIdx] = true + + vars := eval.VarValuesMap{} + if err := rsLeft.ExportToVars(leftRowIdx, &vars); err != nil { + return bs, err + } + if err := rsRight.ExportToVarsWithAlias(rightRowIdx, &vars, sc.LookupAlias); err != nil { + return bs, err + } + + // We are ready to write this result right away, so prepare the output tableRecord + tableRecord, err := node.TableCreator.CalculateTableRecordFromSrcVars(false, vars) + if err != nil { + return bs, fmt.Errorf("cannot populate table record from [%v]: [%s]", vars, err.Error()) + } + + // Check table creator having + inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) + if err != nil { + return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) + } + + // Write batch if needed + if inResult { + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + tableRecordBatchCount++ + if tableRecordBatchCount == instr.BatchSize { + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + batchStartTime = time.Now() + tableRecordBatchCount = 0 + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + } + bs.RowsWritten++ + } + } // non-group result row written + } // group case handled + } // for each found right row + + if rsRight.RowCount < node.Lookup.RightLookupReadBatchSize || len(rightPageState) == 0 { + break + } + rightDataPageIdx++ + } // for each data 
page + + if rsIdx.RowCount < node.Lookup.IdxReadBatchSize || len(idxPageState) == 0 { + break + } + rightIdxPageIdx++ + } // for each idx page + + // For grouped - group + // For non-grouped left join - add empty left-side (those who have right counterpart were alredy hendled above) + // Non-grouped inner join - already handled above + if node.Lookup.IsGroup { + // Time to write the result of the grouped + for leftRowIdx := 0; leftRowIdx < rsLeft.RowCount; leftRowIdx++ { + tableRecord := map[string]any{} + if !leftRowFoundRightLookup[leftRowIdx] { + if node.Lookup.LookupJoin == sc.LookupJoinInner { + // Grouped inner join with no data on the right + // Do not insert this left row + continue + } + // Grouped left outer join with no data on the right + leftVars := eval.VarValuesMap{} + if err := rsLeft.ExportToVars(leftRowIdx, &leftVars); err != nil { + return bs, err + } + + for fieldName, fieldDef := range node.TableCreator.Fields { + isAggEnabled, _, _ := eval.DetectRootAggFunc(fieldDef.ParsedExpression) + if isAggEnabled == eval.AggFuncEnabled { + // Aggregate func is used in field expression - ignore the expression and produce default + tableRecord[fieldName], err = node.TableCreator.GetFieldDefaultReadyForDb(fieldName) + if err != nil { + return bs, fmt.Errorf("cannot initialize default field %s: [%s]", fieldName, err.Error()) + } + } else { + // No aggregate function used in field expression - assume it contains only left-side fields + tableRecord[fieldName], err = sc.CalculateFieldValue(fieldName, fieldDef, leftVars, false) + if err != nil { + return bs, err + } + } + } + } else { + + // Grouped inner or left outer with present data on the right + + leftRowid := *((*rsLeft.Rows[leftRowIdx])[rsLeft.FieldsByFieldName["rowid"]].(*int64)) + for fieldName, fieldDef := range node.TableCreator.Fields { + finalValue := eCtxMap[leftRowid][fieldName].Value + + if err := sc.CheckValueType(finalValue, fieldDef.Type); err != nil { + return bs, fmt.Errorf("invalid field %s type: [%s]", fieldName, err.Error()) + } + tableRecord[fieldName] = finalValue + } + } + + // Check table creator having + inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) + if err != nil { + return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) + } + + // Write batch if needed + if inResult { + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + tableRecordBatchCount++ + if tableRecordBatchCount == instr.BatchSize { + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + batchStartTime = time.Now() + tableRecordBatchCount = 0 + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + } + bs.RowsWritten++ + } + } + } else if node.Lookup.LookupJoin == sc.LookupJoinLeft { + + // Non-grouped left outer join. 
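+ // Left rows with no matching right row are still emitted: fields whose expressions reference the lookup
+ // alias fall back to their declared defaults, left-only expressions are evaluated as usual, and the
+ // having condition is applied before writing.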
+ // Handle those left rows that did not have right lookup counterpart + // (those who had - they have been written already) + + for leftRowIdx := 0; leftRowIdx < rsLeft.RowCount; leftRowIdx++ { + if leftRowFoundRightLookup[leftRowIdx] { + continue + } + + leftVars := eval.VarValuesMap{} + if err := rsLeft.ExportToVars(leftRowIdx, &leftVars); err != nil { + return bs, err + } + + tableRecord := map[string]any{} + + for fieldName, fieldDef := range node.TableCreator.Fields { + if fieldDef.UsedFields.HasFieldsWithTableAlias(sc.LookupAlias) { + // This field expression uses fields from lkp table - produce default value + tableRecord[fieldName], err = node.TableCreator.GetFieldDefaultReadyForDb(fieldName) + if err != nil { + return bs, fmt.Errorf("cannot initialize non-grouped default field %s: [%s]", fieldName, err.Error()) + } + } else { + // This field expression does not use fields from lkp table - assume the expression contains only left-side fields + tableRecord[fieldName], err = sc.CalculateFieldValue(fieldName, fieldDef, leftVars, false) + if err != nil { + return bs, err + } + } + } + + // Check table creator having + inResult, err := node.TableCreator.CheckTableRecordHavingCondition(tableRecord) + if err != nil { + return bs, fmt.Errorf("cannot check having condition [%s], table record [%v]: [%s]", node.TableCreator.RawHaving, tableRecord, err.Error()) + } + + // Write batch if needed + if inResult { + if err = instr.add(tableRecord); err != nil { + return bs, fmt.Errorf("cannot add record to batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + tableRecordBatchCount++ + if tableRecordBatchCount == instr.BatchSize { + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTable(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + batchStartTime = time.Now() + tableRecordBatchCount = 0 + if err := instr.startWorkers(logger, pCtx); err != nil { + return bs, err + } + } + bs.RowsWritten++ + } + } + } + + bs.RowsRead += rsLeft.RowCount + if rsLeft.RowCount < leftBatchSize { + break + } + leftPageIdx++ + } // for each source table batch + + // Write leftovers regardless of tableRecordBatchCount == 0 + if err := instr.waitForWorkers(logger, pCtx); err != nil { + return bs, fmt.Errorf("cannot save record batch of size %d to %s: [%s]", tableRecordBatchCount, node.TableCreator.Name, err.Error()) + } + reportWriteTableLeftovers(logger, pCtx, tableRecordBatchCount, time.Since(batchStartTime), len(node.TableCreator.Indexes), instr.NumWorkers) + + bs.Elapsed = time.Since(totalStartTime) + reportWriteTableComplete(logger, pCtx, bs.RowsRead, bs.RowsWritten, bs.Elapsed, len(node.TableCreator.Indexes), instr.NumWorkers) + return bs, nil +} diff --git a/pkg/proc/rowset.go b/pkg/proc/rowset.go index e45ee36..41a8207 100644 --- a/pkg/proc/rowset.go +++ b/pkg/proc/rowset.go @@ -1,255 +1,255 @@ -package proc - -import ( - "fmt" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/shopspring/decimal" - "gopkg.in/inf.v0" -) - -type Rowset struct { - Fields []sc.FieldRef - FieldsByFullAliasName map[string]int - FieldsByFieldName map[string]int - Rows []*[]interface{} - RowCount int -} - -func NewRowsetFromFieldRefs(fieldRefsList ...sc.FieldRefs) *Rowset { - rs := 
Rowset{} - for i := 0; i < len(fieldRefsList); i++ { - rs.AppendFieldRefs(&fieldRefsList[i]) - } - return &rs -} - -func (rs *Rowset) ToString() string { - var b strings.Builder - for _, fr := range rs.Fields { - b.WriteString(fmt.Sprintf("%30s", fr.GetAliasHash())) - } - b.WriteString("\n") - for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { - vals := rs.Rows[rowIdx] - for _, val := range *vals { - switch typedVal := val.(type) { - case *int64: - b.WriteString(fmt.Sprintf("%30d", *typedVal)) - case *float64: - b.WriteString(fmt.Sprintf("%30f", *typedVal)) - case *string: - b.WriteString(fmt.Sprintf("\"%30s\"", *typedVal)) - case *bool: - if *typedVal { - return " TRUE" - } else { - return " FALSE" - } - case *decimal.Decimal: - b.WriteString(fmt.Sprintf("%30s", (*typedVal).String())) - case *time.Time: - b.WriteString(fmt.Sprintf("%30s", (*typedVal).Format("\"2006-01-02T15:04:05.000-0700\""))) - default: - b.WriteString("bla") - } - } - b.WriteString("\n") - } - return b.String() -} - -func (rs *Rowset) ArrangeByRowid(rowids []int64) error { - if len(rowids) < rs.RowCount { - return fmt.Errorf("invalid rowid array length") - } - - rowidColIdx := rs.FieldsByFieldName["rowid"] - - // Build a map for quicker access - rowMap := map[int64]int{} - for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { - rowid := *((*rs.Rows[rowIdx])[rowidColIdx]).(*int64) - rowMap[rowid] = rowIdx - } - - for i := 0; i < rs.RowCount; i++ { - // rowids[i] must be at i-th position in rs.Rows - if rowMap[rowids[i]] != i { - // Swap - tailIdx := rowMap[rowids[i]] - tailRowPtr := rs.Rows[tailIdx] - headRowid := *((*rs.Rows[i])[rowidColIdx]).(*int64) - - // Move rs.Rows[i] to the tail of rs.Rows - rs.Rows[tailIdx] = rs.Rows[i] - rowMap[headRowid] = tailIdx - - // Move tail row to the i-th position - rs.Rows[i] = tailRowPtr - rowMap[rowids[i]] = i // As it should be - } - } - - return nil -} - -func (rs *Rowset) GetFieldNames() *[]string { - fieldNames := make([]string, len(rs.Fields)) - for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { - fieldNames[colIdx] = rs.Fields[colIdx].FieldName - } - return &fieldNames -} -func (rs *Rowset) AppendFieldRefs(fieldRefs *sc.FieldRefs) { - rs.AppendFieldRefsWithFilter(fieldRefs, "") -} - -func (rs *Rowset) AppendFieldRefsWithFilter(fieldRefs *sc.FieldRefs, tableFilter string) { - if rs.Fields == nil { - rs.Fields = make([]sc.FieldRef, 0) - } - if rs.FieldsByFullAliasName == nil { - rs.FieldsByFullAliasName = map[string]int{} - } - if rs.FieldsByFieldName == nil { - rs.FieldsByFieldName = map[string]int{} - } - - for i := 0; i < len(*fieldRefs); i++ { - if len(tableFilter) > 0 && (*fieldRefs)[i].TableName != tableFilter { - continue - } - key := (*fieldRefs)[i].GetAliasHash() - if _, ok := rs.FieldsByFullAliasName[key]; !ok { - rs.Fields = append(rs.Fields, (*fieldRefs)[i]) - rs.FieldsByFullAliasName[key] = len(rs.Fields) - 1 - rs.FieldsByFieldName[(*fieldRefs)[i].FieldName] = len(rs.Fields) - 1 - } - } -} - -func (rs *Rowset) InitRows(capacity int) error { - if rs.Rows == nil || len(rs.Rows) != capacity { - rs.Rows = make([](*[]interface{}), capacity) - } - for rowIdx := 0; rowIdx < capacity; rowIdx++ { - newRow := make([]interface{}, len(rs.Fields)) - rs.Rows[rowIdx] = &newRow - for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { - switch rs.Fields[colIdx].FieldType { - case sc.FieldTypeInt: - v := int64(0) - (*rs.Rows[rowIdx])[colIdx] = &v - case sc.FieldTypeFloat: - v := float64(0.0) - (*rs.Rows[rowIdx])[colIdx] = &v - case sc.FieldTypeString: - v := "" - 
(*rs.Rows[rowIdx])[colIdx] = &v - case sc.FieldTypeDecimal2: - // Set it to Cassandra-accepted value, not decimal.Decimal: https://github.com/gocql/gocql/issues/1578 - (*rs.Rows[rowIdx])[colIdx] = inf.NewDec(0, 0) - case sc.FieldTypeBool: - v := false - (*rs.Rows[rowIdx])[colIdx] = &v - case sc.FieldTypeDateTime: - v := sc.DefaultDateTime() - (*rs.Rows[rowIdx])[colIdx] = &v - default: - return fmt.Errorf("InitRows unsupported field type %s, field %s.%s", rs.Fields[colIdx].FieldType, rs.Fields[colIdx].TableName, rs.Fields[colIdx].FieldName) - } - } - } - return nil -} -func (rs *Rowset) ExportToVars(rowIdx int, vars *eval.VarValuesMap) error { - return rs.ExportToVarsWithAlias(rowIdx, vars, "") -} - -func (rs *Rowset) GetTableRecord(rowIdx int) (map[string]interface{}, error) { - tableRecord := map[string]interface{}{} - for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { - fName := rs.Fields[colIdx].FieldName - valuePtr := (*rs.Rows[rowIdx])[rs.FieldsByFieldName[fName]] - switch assertedValuePtr := valuePtr.(type) { - case *int64: - tableRecord[fName] = *assertedValuePtr - case *string: - tableRecord[fName] = *assertedValuePtr - case *time.Time: - tableRecord[fName] = *assertedValuePtr - case *bool: - tableRecord[fName] = *assertedValuePtr - case *decimal.Decimal: - tableRecord[fName] = *assertedValuePtr - case *float64: - tableRecord[fName] = *assertedValuePtr - case *inf.Dec: - decVal, err := decimal.NewFromString((*(valuePtr.(*inf.Dec))).String()) - if err != nil { - return nil, fmt.Errorf("GetTableRecord cannot convert inf.Dec [%v]to decimal.Decimal", *(valuePtr.(*inf.Dec))) - } - tableRecord[fName] = decVal - default: - return nil, fmt.Errorf("GetTableRecord unsupported field type %T", valuePtr) - } - } - return tableRecord, nil -} - -func (rs *Rowset) ExportToVarsWithAlias(rowIdx int, vars *eval.VarValuesMap, useTableAlias string) error { - for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { - tName := &rs.Fields[colIdx].TableName - if len(useTableAlias) > 0 { - tName = &useTableAlias - } - fName := &rs.Fields[colIdx].FieldName - _, ok := (*vars)[*tName] - if !ok { - (*vars)[*tName] = map[string]interface{}{} - } - valuePtr := (*rs.Rows[rowIdx])[colIdx] - switch assertedValuePtr := valuePtr.(type) { - case *int64: - (*vars)[*tName][*fName] = *assertedValuePtr - case *string: - (*vars)[*tName][*fName] = *assertedValuePtr - case *time.Time: - (*vars)[*tName][*fName] = *assertedValuePtr - case *bool: - (*vars)[*tName][*fName] = *assertedValuePtr - case *decimal.Decimal: - (*vars)[*tName][*fName] = *assertedValuePtr - case *float64: - (*vars)[*tName][*fName] = *assertedValuePtr - case *inf.Dec: - decVal, err := decimal.NewFromString((*(valuePtr.(*inf.Dec))).String()) - if err != nil { - return fmt.Errorf("ExportToVars cannot convert inf.Dec [%v]to decimal.Decimal", *(valuePtr.(*inf.Dec))) - } - (*vars)[*tName][*fName] = decVal - default: - return fmt.Errorf("ExportToVars unsupported field type %T", valuePtr) - } - } - return nil -} - -// Force UTC TZ to each ts returned by gocql -// func (rs *Rowset) SanitizeScannedDatetimesToUtc(rowIdx int) error { -// for valIdx := 0; valIdx < len(rs.Fields); valIdx++ { -// if rs.Fields[valIdx].FieldType == sc.FieldTypeDateTime { -// origVolatile := (*rs.Rows[rowIdx])[valIdx] -// origDt, ok := origVolatile.(time.Time) -// if !ok { -// return fmt.Errorf("invalid type %t(%v), expected datetime", origVolatile, origVolatile) -// } -// (*rs.Rows[rowIdx])[valIdx] = origDt.In(time.UTC) -// } -// } -// return nil -// } +package proc + +import ( + "fmt" + 
"strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/shopspring/decimal" + "gopkg.in/inf.v0" +) + +type Rowset struct { + Fields []sc.FieldRef + FieldsByFullAliasName map[string]int + FieldsByFieldName map[string]int + Rows []*[]any + RowCount int +} + +func NewRowsetFromFieldRefs(fieldRefsList ...sc.FieldRefs) *Rowset { + rs := Rowset{} + for i := 0; i < len(fieldRefsList); i++ { + rs.AppendFieldRefs(&fieldRefsList[i]) + } + return &rs +} + +func (rs *Rowset) ToString() string { + var b strings.Builder + for _, fr := range rs.Fields { + b.WriteString(fmt.Sprintf("%30s", fr.GetAliasHash())) + } + b.WriteString("\n") + for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { + vals := rs.Rows[rowIdx] + for _, val := range *vals { + switch typedVal := val.(type) { + case *int64: + b.WriteString(fmt.Sprintf("%30d", *typedVal)) + case *float64: + b.WriteString(fmt.Sprintf("%30f", *typedVal)) + case *string: + b.WriteString(fmt.Sprintf("\"%30s\"", *typedVal)) + case *bool: + if *typedVal { + return " TRUE" + } else { + return " FALSE" + } + case *decimal.Decimal: + b.WriteString(fmt.Sprintf("%30s", (*typedVal).String())) + case *time.Time: + b.WriteString(fmt.Sprintf("%30s", (*typedVal).Format("\"2006-01-02T15:04:05.000-0700\""))) + default: + b.WriteString("bla") + } + } + b.WriteString("\n") + } + return b.String() +} + +func (rs *Rowset) ArrangeByRowid(rowids []int64) error { + if len(rowids) < rs.RowCount { + return fmt.Errorf("invalid rowid array length") + } + + rowidColIdx := rs.FieldsByFieldName["rowid"] + + // Build a map for quicker access + rowMap := map[int64]int{} + for rowIdx := 0; rowIdx < rs.RowCount; rowIdx++ { + rowid := *((*rs.Rows[rowIdx])[rowidColIdx]).(*int64) + rowMap[rowid] = rowIdx + } + + for i := 0; i < rs.RowCount; i++ { + // rowids[i] must be at i-th position in rs.Rows + if rowMap[rowids[i]] != i { + // Swap + tailIdx := rowMap[rowids[i]] + tailRowPtr := rs.Rows[tailIdx] + headRowid := *((*rs.Rows[i])[rowidColIdx]).(*int64) + + // Move rs.Rows[i] to the tail of rs.Rows + rs.Rows[tailIdx] = rs.Rows[i] + rowMap[headRowid] = tailIdx + + // Move tail row to the i-th position + rs.Rows[i] = tailRowPtr + rowMap[rowids[i]] = i // As it should be + } + } + + return nil +} + +func (rs *Rowset) GetFieldNames() *[]string { + fieldNames := make([]string, len(rs.Fields)) + for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { + fieldNames[colIdx] = rs.Fields[colIdx].FieldName + } + return &fieldNames +} +func (rs *Rowset) AppendFieldRefs(fieldRefs *sc.FieldRefs) { + rs.AppendFieldRefsWithFilter(fieldRefs, "") +} + +func (rs *Rowset) AppendFieldRefsWithFilter(fieldRefs *sc.FieldRefs, tableFilter string) { + if rs.Fields == nil { + rs.Fields = make([]sc.FieldRef, 0) + } + if rs.FieldsByFullAliasName == nil { + rs.FieldsByFullAliasName = map[string]int{} + } + if rs.FieldsByFieldName == nil { + rs.FieldsByFieldName = map[string]int{} + } + + for i := 0; i < len(*fieldRefs); i++ { + if len(tableFilter) > 0 && (*fieldRefs)[i].TableName != tableFilter { + continue + } + key := (*fieldRefs)[i].GetAliasHash() + if _, ok := rs.FieldsByFullAliasName[key]; !ok { + rs.Fields = append(rs.Fields, (*fieldRefs)[i]) + rs.FieldsByFullAliasName[key] = len(rs.Fields) - 1 + rs.FieldsByFieldName[(*fieldRefs)[i].FieldName] = len(rs.Fields) - 1 + } + } +} + +func (rs *Rowset) InitRows(capacity int) error { + if rs.Rows == nil || len(rs.Rows) != capacity { + rs.Rows = make([](*[]any), capacity) + } + for rowIdx := 0; rowIdx < 
capacity; rowIdx++ { + newRow := make([]any, len(rs.Fields)) + rs.Rows[rowIdx] = &newRow + for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { + switch rs.Fields[colIdx].FieldType { + case sc.FieldTypeInt: + v := int64(0) + (*rs.Rows[rowIdx])[colIdx] = &v + case sc.FieldTypeFloat: + v := float64(0.0) + (*rs.Rows[rowIdx])[colIdx] = &v + case sc.FieldTypeString: + v := "" + (*rs.Rows[rowIdx])[colIdx] = &v + case sc.FieldTypeDecimal2: + // Set it to Cassandra-accepted value, not decimal.Decimal: https://github.com/gocql/gocql/issues/1578 + (*rs.Rows[rowIdx])[colIdx] = inf.NewDec(0, 0) + case sc.FieldTypeBool: + v := false + (*rs.Rows[rowIdx])[colIdx] = &v + case sc.FieldTypeDateTime: + v := sc.DefaultDateTime() + (*rs.Rows[rowIdx])[colIdx] = &v + default: + return fmt.Errorf("InitRows unsupported field type %s, field %s.%s", rs.Fields[colIdx].FieldType, rs.Fields[colIdx].TableName, rs.Fields[colIdx].FieldName) + } + } + } + return nil +} +func (rs *Rowset) ExportToVars(rowIdx int, vars *eval.VarValuesMap) error { + return rs.ExportToVarsWithAlias(rowIdx, vars, "") +} + +func (rs *Rowset) GetTableRecord(rowIdx int) (map[string]any, error) { + tableRecord := map[string]any{} + for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { + fName := rs.Fields[colIdx].FieldName + valuePtr := (*rs.Rows[rowIdx])[rs.FieldsByFieldName[fName]] + switch assertedValuePtr := valuePtr.(type) { + case *int64: + tableRecord[fName] = *assertedValuePtr + case *string: + tableRecord[fName] = *assertedValuePtr + case *time.Time: + tableRecord[fName] = *assertedValuePtr + case *bool: + tableRecord[fName] = *assertedValuePtr + case *decimal.Decimal: + tableRecord[fName] = *assertedValuePtr + case *float64: + tableRecord[fName] = *assertedValuePtr + case *inf.Dec: + decVal, err := decimal.NewFromString((*(valuePtr.(*inf.Dec))).String()) + if err != nil { + return nil, fmt.Errorf("GetTableRecord cannot convert inf.Dec [%v]to decimal.Decimal", *(valuePtr.(*inf.Dec))) + } + tableRecord[fName] = decVal + default: + return nil, fmt.Errorf("GetTableRecord unsupported field type %T", valuePtr) + } + } + return tableRecord, nil +} + +func (rs *Rowset) ExportToVarsWithAlias(rowIdx int, vars *eval.VarValuesMap, useTableAlias string) error { + for colIdx := 0; colIdx < len(rs.Fields); colIdx++ { + tName := &rs.Fields[colIdx].TableName + if len(useTableAlias) > 0 { + tName = &useTableAlias + } + fName := &rs.Fields[colIdx].FieldName + _, ok := (*vars)[*tName] + if !ok { + (*vars)[*tName] = map[string]any{} + } + valuePtr := (*rs.Rows[rowIdx])[colIdx] + switch assertedValuePtr := valuePtr.(type) { + case *int64: + (*vars)[*tName][*fName] = *assertedValuePtr + case *string: + (*vars)[*tName][*fName] = *assertedValuePtr + case *time.Time: + (*vars)[*tName][*fName] = *assertedValuePtr + case *bool: + (*vars)[*tName][*fName] = *assertedValuePtr + case *decimal.Decimal: + (*vars)[*tName][*fName] = *assertedValuePtr + case *float64: + (*vars)[*tName][*fName] = *assertedValuePtr + case *inf.Dec: + decVal, err := decimal.NewFromString((*(valuePtr.(*inf.Dec))).String()) + if err != nil { + return fmt.Errorf("ExportToVars cannot convert inf.Dec [%v]to decimal.Decimal", *(valuePtr.(*inf.Dec))) + } + (*vars)[*tName][*fName] = decVal + default: + return fmt.Errorf("ExportToVars unsupported field type %T", valuePtr) + } + } + return nil +} + +// Force UTC TZ to each ts returned by gocql +// func (rs *Rowset) SanitizeScannedDatetimesToUtc(rowIdx int) error { +// for valIdx := 0; valIdx < len(rs.Fields); valIdx++ { +// if rs.Fields[valIdx].FieldType 
== sc.FieldTypeDateTime { +// origVolatile := (*rs.Rows[rowIdx])[valIdx] +// origDt, ok := origVolatile.(time.Time) +// if !ok { +// return fmt.Errorf("invalid type %t(%v), expected datetime", origVolatile, origVolatile) +// } +// (*rs.Rows[rowIdx])[valIdx] = origDt.In(time.UTC) +// } +// } +// return nil +// } diff --git a/pkg/proc/table_inserter.go b/pkg/proc/table_inserter.go index 5563403..e4c9cfa 100644 --- a/pkg/proc/table_inserter.go +++ b/pkg/proc/table_inserter.go @@ -1,420 +1,446 @@ -package proc - -import ( - "fmt" - "math/rand" - "strings" - "sync" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/sc" -) - -type TableInserter struct { - PCtx *ctx.MessageProcessingContext - TableCreator *sc.TableCreatorDef - BatchSize int - RecordsIn chan WriteChannelItem // Channel to pass records from the main function like RunCreateTableForBatch, usig add(), to TableInserter - ErrorsOut chan error - RowidRand *rand.Rand - RandMutex sync.Mutex - NumWorkers int - MinInserterRate int - WorkerWaitGroup sync.WaitGroup - RecordsSent int // Records sent to RecordsIn - // TODO: the only reason we have this is because we decided to end handlers - // with "defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx)" - not the cleanest way, get rid of this bool thingy. - // That defer is convenient because there are so many early returns. - RecordsInOpen bool -} - -type WriteChannelItem struct { - TableRecord *TableRecord - IndexKeyMap map[string]string -} - -var seedCounter = int64(0) - -func newSeed() int64 { - seedCounter += 3333 - return (time.Now().Unix() << 32) + time.Now().UnixMilli() + seedCounter -} - -func newTableInserter(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProcessingContext, tableCreator *sc.TableCreatorDef, batchSize int) *TableInserter { - - return &TableInserter{ - PCtx: pCtx, - TableCreator: tableCreator, - BatchSize: batchSize, - ErrorsOut: make(chan error, batchSize), - RowidRand: rand.New(rand.NewSource(newSeed())), - NumWorkers: envConfig.Cassandra.WriterWorkers, - MinInserterRate: envConfig.Cassandra.MinInserterRate, - RecordsInOpen: false, - //Logger: logger, - } -} - -func CreateDataTableCql(keyspace string, runId int16, tableCreator *sc.TableCreatorDef) string { - qb := cql.NewQB() - qb.ColumnDef("rowid", sc.FieldTypeInt) - qb.ColumnDef("batch_idx", sc.FieldTypeInt) - for fieldName, fieldDef := range tableCreator.Fields { - qb.ColumnDef(fieldName, fieldDef.Type) - } - return qb.PartitionKey("rowid").Keyspace(keyspace).CreateRun(tableCreator.Name, runId, cql.IgnoreIfExists) -} - -func CreateIdxTableCql(keyspace string, runId int16, idxName string, idxDef *sc.IdxDef) string { - qb := cql.NewQB() - qb.Keyspace(keyspace). - ColumnDef("key", sc.FieldTypeString). 
- ColumnDef("rowid", sc.FieldTypeInt) - //ColumnDef("batch_idx", sc.FieldTypeInt) - if idxDef.Uniqueness == sc.IdxUnique { - // Key must be unique, let Cassandra enforce it for us: PRIMARY KEY (key) - qb.PartitionKey("key") - } else { - // There can be multiple rowids with the same key: PRIMARY KEY (key, rowid) - qb.PartitionKey("key") - qb.ClusteringKey("rowid") - } - return qb.CreateRun(idxName, runId, cql.IgnoreIfExists) -} - -// Obsolete: now we create all run-specific tables in api.StartRun -// -// func (instr *TableInserter) verifyTablesExist() error { -// q := CreateDataTableCql(instr.PCtx.BatchInfo.DataKeyspace, instr.PCtx.BatchInfo.RunId, instr.TableCreator) -// if err := instr.PCtx.CqlSession.Query(q).Exec(); err != nil { -// return db.WrapDbErrorWithQuery("cannot create data table", q, err) -// } - -// for idxName, idxDef := range instr.TableCreator.Indexes { -// q := CreateIdxTableCql(instr.PCtx.BatchInfo.DataKeyspace, instr.PCtx.BatchInfo.RunId, idxName, idxDef) -// if err := instr.PCtx.CqlSession.Query(q).Exec(); err != nil { -// return db.WrapDbErrorWithQuery("cannot create idx table", q, err) -// } -// } -// return nil -// } - -func (instr *TableInserter) startWorkers(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("proc.startWorkers/TableInserter") - defer logger.PopF() - - instr.RecordsIn = make(chan WriteChannelItem, instr.BatchSize) - logger.DebugCtx(pCtx, "startWorkers created RecordsIn,now launching %d writers...", instr.NumWorkers) - instr.RecordsInOpen = true - - for w := 0; w < instr.NumWorkers; w++ { - newLogger, err := l.NewLoggerFromLogger(logger) - if err != nil { - return err - } - // Increase busy worker count - instr.WorkerWaitGroup.Add(1) - go instr.tableInserterWorker(newLogger, pCtx) - } - return nil -} - -func (instr *TableInserter) waitForWorkers(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("proc.waitForWorkers/TableInserter") - defer logger.PopF() - - logger.DebugCtx(pCtx, "started reading RecordsSent=%d from instr.ErrorsOut", instr.RecordsSent) - - errors := make([]string, 0) - if instr.RecordsSent > 0 { - errCount := 0 - startTime := time.Now() - // 1. It's crucial that the number of errors to receive eventually should match instr.RecordsSent - // 2. 
We do not need an extra select/timeout here - we are guaranteed to receive something in instr.ErrorsOut because of cassndra read timeouts (5-15s or so) - for i := 0; i < instr.RecordsSent; i++ { - err := <-instr.ErrorsOut - if err != nil { - errors = append(errors, err.Error()) - errCount++ - } - //logger.DebugCtx(pCtx, "got result for sent record %d out of %d from instr.ErrorsOut, %d errors so far", i, instr.RecordsSent, errCount) - - inserterRate := float64(i+1) / time.Now().Sub(startTime).Seconds() - // If it falls below min rate, it does not make sense to continue - if i > 5 && inserterRate < float64(instr.MinInserterRate) { - logger.DebugCtx(pCtx, "slow db insertion rate triggered, will stop reading from instr.ErrorsOut") - errors = append(errors, fmt.Sprintf("table inserter detected slow db insertion rate %.0f records/s, wrote %d records out of %d", inserterRate, i, instr.RecordsSent)) - errCount++ - break - } - } - logger.DebugCtx(pCtx, "done writing RecordsSent=%d from instr.ErrorsOut, %d errors", instr.RecordsSent, errCount) - - // Reset for the next cycle, if it ever happens - instr.RecordsSent = 0 - } else { - logger.DebugCtx(pCtx, "no need to waitfor writer results, no records were sent") - } - - // Close instr.RecordsIn, it will trigger the completion of all writer workers - if instr.RecordsInOpen { - logger.DebugCtx(pCtx, "closing RecordsIn") - close(instr.RecordsIn) - logger.DebugCtx(pCtx, "closed RecordsIn") - instr.RecordsInOpen = false - } - - // Wait for all writer threads to complete, otherwise they will keep writing to instr.ErrorsOut, which can close anytime after we exit this function - logger.DebugCtx(pCtx, "waiting for writer workers to complete...") - instr.WorkerWaitGroup.Wait() - logger.DebugCtx(pCtx, "writer workers are done") - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (instr *TableInserter) waitForWorkersAndCloseErrorsOut(logger *l.Logger, pCtx *ctx.MessageProcessingContext) { - logger.PushF("proc.waitForWorkersAndClose/TableInserter") - defer logger.PopF() - - // Make sure no workers are running, so they do not hit closed ErrorsOut - instr.waitForWorkers(logger, pCtx) - // Safe to close now - logger.DebugCtx(pCtx, "closing ErrorsOut") - close(instr.ErrorsOut) - logger.DebugCtx(pCtx, "closed ErrorsOut") -} - -func (instr *TableInserter) add(tableRecord TableRecord) error { - indexKeyMap := map[string]string{} - for idxName, idxDef := range instr.TableCreator.Indexes { - var err error - indexKeyMap[idxName], err = sc.BuildKey(tableRecord, idxDef) - if err != nil { - return fmt.Errorf("cannot build key for idx %s, table record [%v]: [%s]", idxName, tableRecord, err.Error()) - } - } - - instr.RecordsSent++ - instr.RecordsIn <- WriteChannelItem{TableRecord: &tableRecord, IndexKeyMap: indexKeyMap} - - return nil -} - -func (instr *TableInserter) tableInserterWorker(logger *l.Logger, pCtx *ctx.MessageProcessingContext) { - logger.PushF("proc.tableInserterWorker") - defer logger.PopF() - - logger.DebugCtx(pCtx, "writer started reading from RecordsIn") - dataTableName := instr.TableCreator.Name + cql.RunIdSuffix(instr.PCtx.BatchInfo.RunId) - - var dataQb *cql.QueryBuilder - var preparedDataQueryErr error - var preparedDataQuery string - - handledRecordCount := 0 - for writeItem := range instr.RecordsIn { - handledRecordCount++ - maxDataRetries := 5 - curDataExpBackoffFactor := 1 - var errorToReport error - - if preparedDataQueryErr != nil { - instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: 
%s", preparedDataQueryErr) - continue - } else if dataQb == nil { - dataQb = cql.NewQB() - dataQb.WritePreparedColumn("rowid") - dataQb.WritePreparedColumn("batch_idx") - dataQb.WritePreparedValue("batch_idx", instr.PCtx.BatchInfo.BatchIdx) - - for fieldName, _ := range *writeItem.TableRecord { - if err := dataQb.WritePreparedColumn(fieldName); err != nil { - errorToReport = fmt.Errorf("cannot prepare data query: %s", err) - break - } - } - if errorToReport != nil { - instr.ErrorsOut <- errorToReport - continue // next insert - } - - var err error - preparedDataQuery, err = dataQb.Keyspace(instr.PCtx.BatchInfo.DataKeyspace). - InsertRunPreparedQuery(instr.TableCreator.Name, instr.PCtx.BatchInfo.RunId, cql.IgnoreIfExists) // INSERT IF NOT EXISTS; if exists, returned isApplied = false - if err != nil { - instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: %s", err) - continue // next insert - } - } - - instr.RandMutex.Lock() - (*writeItem.TableRecord)["rowid"] = instr.RowidRand.Int63() - instr.RandMutex.Unlock() - - for fieldName, fieldValue := range *writeItem.TableRecord { - dataQb.WritePreparedValue(fieldName, fieldValue) - } - preparedDataQueryParams, err := dataQb.InsertRunParams() - if err != nil { - instr.ErrorsOut <- fmt.Errorf("cannot provide insert params for prepared query %s: %s", preparedDataQuery, err.Error()) - continue // next insert - } - - for dataRetryCount := 0; dataRetryCount < maxDataRetries; dataRetryCount++ { - - existingDataRow := map[string]interface{}{} - isApplied, err := instr.PCtx.CqlSession.Query(preparedDataQuery, preparedDataQueryParams...).MapScanCAS(existingDataRow) - - if err == nil { - if isApplied { - // Success - break - } else { - if dataRetryCount < maxDataRetries-1 { - // Retry now with a new rowid - logger.InfoCtx(instr.PCtx, "duplicate rowid not written [%s], existing record [%v], retry count %d", preparedDataQuery, existingDataRow, dataRetryCount) - instr.RandMutex.Lock() - instr.RowidRand = rand.New(rand.NewSource(newSeed())) - (*writeItem.TableRecord)["rowid"] = instr.RowidRand.Int63() - instr.RandMutex.Unlock() - - // Set new rowid and re-build query params array (shouldn't throw errors this time) - dataQb.WritePreparedValue("rowid", (*writeItem.TableRecord)["rowid"]) - preparedDataQueryParams, _ = dataQb.InsertRunParams() - } else { - // No more retries - logger.ErrorCtx(instr.PCtx, "duplicate rowid not written [%s], existing record [%v], retry count %d reached, giving up", preparedDataQuery, existingDataRow, dataRetryCount) - errorToReport = fmt.Errorf("cannot write to data table after multiple attempts, keep getting rowid duplicates [%s]", preparedDataQuery) - break - } - } - } else { - if strings.Contains(err.Error(), "does not exist") { - // There is a chance this table is brand new and table schema was not propagated to all Cassandra nodes - if dataRetryCount < maxDataRetries-1 { - logger.WarnCtx(instr.PCtx, "will wait for table %s to be created, retry count %d, got %s", dataTableName, dataRetryCount, err.Error()) - // TODO: come up with a better waiting strategy (exp backoff, at least) - time.Sleep(5 * time.Second) - } else { - errorToReport = fmt.Errorf("cannot write to data table %s after %d attempts, apparently, table schema still not propagated to all nodes: %s", dataTableName, dataRetryCount+1, err.Error()) - break - } - } else if strings.Contains(err.Error(), "Operation timed out") { - // The cluster is overloaded, slow down - if dataRetryCount < maxDataRetries-1 { - logger.WarnCtx(instr.PCtx, "cluster overloaded (%s), will 
wait for %dms before writing to data table %s again, retry count %d", err.Error(), 10*curDataExpBackoffFactor, dataTableName, dataRetryCount) - time.Sleep(time.Duration(10*curDataExpBackoffFactor) * time.Millisecond) - curDataExpBackoffFactor *= 2 - } else { - errorToReport = fmt.Errorf("cannot write to data table %s after %d attempts, still getting timeouts: %s", dataTableName, dataRetryCount+1, err.Error()) - break - } - } else { - // Some serious error happened, stop trying this rowid - errorToReport = db.WrapDbErrorWithQuery("cannot write to data table", preparedDataQuery, err) - break - } - } - } // data retry loop - - if errorToReport == nil { - // Index tables - for idxName, idxDef := range instr.TableCreator.Indexes { - - maxIdxRetries := 5 - idxTableName := idxName + cql.RunIdSuffix(instr.PCtx.BatchInfo.RunId) - curIdxExpBackoffFactor := 1 - - ifNotExistsFlag := cql.ThrowIfExists - if idxDef.Uniqueness == sc.IdxUnique { - ifNotExistsFlag = cql.IgnoreIfExists - } - - idxQb := cql.NewQB() - idxQb.WritePreparedColumn("key") - idxQb.WritePreparedValue("key", writeItem.IndexKeyMap[idxName]) - idxQb.WritePreparedColumn("rowid") - idxQb.WritePreparedValue("rowid", (*writeItem.TableRecord)["rowid"]) - - preparedIdxQuery, err := idxQb.Keyspace(instr.PCtx.BatchInfo.DataKeyspace).InsertRunPreparedQuery(idxName, instr.PCtx.BatchInfo.RunId, ifNotExistsFlag) - if err != nil { - errorToReport = fmt.Errorf("cannot prepare idx query: %s", err.Error()) - break - } - preparedIdxQueryParams, err := idxQb.InsertRunParams() - if err != nil { - errorToReport = fmt.Errorf("cannot provide idx query params for %s: %s", preparedIdxQuery, err.Error()) - break - } - - for idxRetryCount := 0; idxRetryCount < maxIdxRetries; idxRetryCount++ { - existingIdxRow := map[string]interface{}{} - var isApplied = true - var err error - if idxDef.Uniqueness == sc.IdxUnique { - // Unique idx assumed, check isApplied - isApplied, err = instr.PCtx.CqlSession.Query(preparedIdxQuery, preparedIdxQueryParams...).MapScanCAS(existingIdxRow) - } else { - // No uniqueness assumed, just insert - err = instr.PCtx.CqlSession.Query(preparedIdxQuery, preparedIdxQueryParams...).Exec() - } - - if err == nil { - if !isApplied { - // If attempt number > 0, there is a chance that Cassandra managed to insert the record on the previous attempt but returned an error - if idxRetryCount > 0 && existingIdxRow["key"] == writeItem.IndexKeyMap[idxName] && existingIdxRow["rowid"] == (*writeItem.TableRecord)["rowid"] { - // Cassandra screwed up, but we know how to handle it, the record is there, just log a warning - logger.WarnCtx(instr.PCtx, "duplicate idx record found (%s) in idx %s on retry %d when writing (%d,'%s'), assuming this retry was successful, proceeding as usual", idxName, existingIdxRow, idxRetryCount, (*writeItem.TableRecord)["rowid"], writeItem.IndexKeyMap[idxName]) - } else { - // We screwed up, report everything we can - errorToReport = fmt.Errorf("cannot write duplicate index key [%s] on retry %d, existing record [%v]", preparedDataQuery, idxRetryCount, existingIdxRow) - } - } - // Success or not - we are done - break - } else { - if strings.Contains(err.Error(), "does not exist") { - // There is a chance this table is brand new and table schema was not propagated to all Cassandra nodes - if idxRetryCount < maxIdxRetries-1 { - logger.WarnCtx(instr.PCtx, "will wait for idx table %s to be created, retry count %d, got %s", idxTableName, idxRetryCount, err.Error()) - // TODO: come up with a better waiting strategy (exp backoff, at least) 
- time.Sleep(5 * time.Second) - } else { - errorToReport = fmt.Errorf("cannot write to idx table %s after %d attempts, apparently, table schema still not propagated to all nodes: %s", idxTableName, idxRetryCount+1, err.Error()) - break - } - } else if strings.Contains(err.Error(), "Operation timed out") { - // The cluster is overloaded, slow down - if idxRetryCount < maxIdxRetries-1 { - logger.WarnCtx(instr.PCtx, "cluster overloaded (%s), will wait for %dms before writing to idx table %s again, retry count %d", err.Error(), 10*curIdxExpBackoffFactor, idxTableName, idxRetryCount) - time.Sleep(time.Duration(10*curIdxExpBackoffFactor) * time.Millisecond) - curIdxExpBackoffFactor *= 2 - } else { - errorToReport = fmt.Errorf("cannot write to idx table %s after %d attempts, still getting timeout: %s", idxTableName, idxRetryCount+1, err.Error()) - break - } - } else { - // Some serious error happened, stop trying this idx record - errorToReport = db.WrapDbErrorWithQuery("cannot write to idx table", preparedDataQuery, err) - break - } - } - } // idx retry loop - } // idx loop - } - // logger.DebugCtx(pCtx, "writer wrote") - instr.ErrorsOut <- errorToReport - } - logger.DebugCtx(pCtx, "done reading from RecordsIn, this writer worker handled %d records from instr.RecordsIn", handledRecordCount) - // Decrease busy worker count - instr.WorkerWaitGroup.Done() -} +package proc + +import ( + "fmt" + "math/rand" + "strings" + "sync" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/sc" +) + +type TableInserter struct { + PCtx *ctx.MessageProcessingContext + TableCreator *sc.TableCreatorDef + BatchSize int + RecordsIn chan WriteChannelItem // Channel to pass records from the main function like RunCreateTableForBatch, using add(), to TableInserter + ErrorsOut chan error + RowidRand *rand.Rand + RandMutex sync.Mutex + NumWorkers int + MinInserterRate int + WorkerWaitGroup sync.WaitGroup + RecordsSent int // Records sent to RecordsIn + // TODO: the only reason we have this is because we decided to end handlers + // with "defer instr.waitForWorkersAndCloseErrorsOut(logger, pCtx)" - not the cleanest way, get rid of this bool thingy. + // That defer is convenient because there are so many early returns.
+ RecordsInOpen bool +} + +type WriteChannelItem struct { + TableRecord *TableRecord + IndexKeyMap map[string]string +} + +var seedCounter = int64(0) + +func newSeed() int64 { + seedCounter += 3333 + return (time.Now().Unix() << 32) + time.Now().UnixMilli() + seedCounter +} + +func newTableInserter(envConfig *env.EnvConfig, pCtx *ctx.MessageProcessingContext, tableCreator *sc.TableCreatorDef, batchSize int) *TableInserter { + + return &TableInserter{ + PCtx: pCtx, + TableCreator: tableCreator, + BatchSize: batchSize, + ErrorsOut: make(chan error, batchSize), + RowidRand: rand.New(rand.NewSource(newSeed())), + NumWorkers: envConfig.Cassandra.WriterWorkers, + MinInserterRate: envConfig.Cassandra.MinInserterRate, + RecordsInOpen: false, + } +} + +func CreateDataTableCql(keyspace string, runId int16, tableCreator *sc.TableCreatorDef) string { + qb := cql.NewQB() + qb.ColumnDef("rowid", sc.FieldTypeInt) + qb.ColumnDef("batch_idx", sc.FieldTypeInt) + for fieldName, fieldDef := range tableCreator.Fields { + qb.ColumnDef(fieldName, fieldDef.Type) + } + return qb.PartitionKey("rowid").Keyspace(keyspace).CreateRun(tableCreator.Name, runId, cql.IgnoreIfExists) +} + +func CreateIdxTableCql(keyspace string, runId int16, idxName string, idxDef *sc.IdxDef) string { + qb := cql.NewQB() + qb.Keyspace(keyspace). + ColumnDef("key", sc.FieldTypeString). + ColumnDef("rowid", sc.FieldTypeInt) + if idxDef.Uniqueness == sc.IdxUnique { + // Key must be unique, let Cassandra enforce it for us: PRIMARY KEY (key) + qb.PartitionKey("key") + } else { + // There can be multiple rowids with the same key: PRIMARY KEY (key, rowid) + qb.PartitionKey("key") + qb.ClusteringKey("rowid") + } + return qb.CreateRun(idxName, runId, cql.IgnoreIfExists) +} + +// Obsolete: now we create all run-specific tables in api.StartRun +// +// func (instr *TableInserter) verifyTablesExist() error { +// q := CreateDataTableCql(instr.PCtx.BatchInfo.DataKeyspace, instr.PCtx.BatchInfo.RunId, instr.TableCreator) +// if err := instr.PCtx.CqlSession.Query(q).Exec(); err != nil { +// return db.WrapDbErrorWithQuery("cannot create data table", q, err) +// } + +// for idxName, idxDef := range instr.TableCreator.Indexes { +// q := CreateIdxTableCql(instr.PCtx.BatchInfo.DataKeyspace, instr.PCtx.BatchInfo.RunId, idxName, idxDef) +// if err := instr.PCtx.CqlSession.Query(q).Exec(); err != nil { +// return db.WrapDbErrorWithQuery("cannot create idx table", q, err) +// } +// } +// return nil +// } + +func (instr *TableInserter) startWorkers(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("proc.startWorkers/TableInserter") + defer logger.PopF() + + instr.RecordsIn = make(chan WriteChannelItem, instr.BatchSize) + logger.DebugCtx(pCtx, "startWorkers created RecordsIn,now launching %d writers...", instr.NumWorkers) + instr.RecordsInOpen = true + + for w := 0; w < instr.NumWorkers; w++ { + newLogger, err := l.NewLoggerFromLogger(logger) + if err != nil { + return err + } + // Increase busy worker count + instr.WorkerWaitGroup.Add(1) + go instr.tableInserterWorker(newLogger, pCtx) + } + return nil +} + +func (instr *TableInserter) waitForWorkers(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("proc.waitForWorkers/TableInserter") + defer logger.PopF() + + logger.DebugCtx(pCtx, "started reading RecordsSent=%d from instr.ErrorsOut", instr.RecordsSent) + + errors := make([]string, 0) + if instr.RecordsSent > 0 { + errCount := 0 + startTime := time.Now() + // 1. 
It's crucial that the number of results received on instr.ErrorsOut eventually matches instr.RecordsSent + // 2. We do not need an extra select/timeout here - we are guaranteed to receive something in instr.ErrorsOut because of Cassandra read timeouts (5-15s or so) + for i := 0; i < instr.RecordsSent; i++ { + err := <-instr.ErrorsOut + if err != nil { + errors = append(errors, err.Error()) + errCount++ + } + + inserterRate := float64(i+1) / time.Since(startTime).Seconds() + // If it falls below min rate, it does not make sense to continue + if i > 5 && inserterRate < float64(instr.MinInserterRate) { + logger.DebugCtx(pCtx, "slow db insertion rate triggered, will stop reading from instr.ErrorsOut") + errors = append(errors, fmt.Sprintf("table inserter detected slow db insertion rate %.0f records/s, wrote %d records out of %d", inserterRate, i, instr.RecordsSent)) + errCount++ + break + } + } + logger.DebugCtx(pCtx, "done reading RecordsSent=%d from instr.ErrorsOut, %d errors", instr.RecordsSent, errCount) + + // Reset for the next cycle, if it ever happens + instr.RecordsSent = 0 + } else { + logger.DebugCtx(pCtx, "no need to wait for writer results, no records were sent") + } + + // Close instr.RecordsIn, it will trigger the completion of all writer workers + if instr.RecordsInOpen { + logger.DebugCtx(pCtx, "closing RecordsIn") + close(instr.RecordsIn) + logger.DebugCtx(pCtx, "closed RecordsIn") + instr.RecordsInOpen = false + } + + // Wait for all writer threads to complete, otherwise they will keep writing to instr.ErrorsOut, which can close anytime after we exit this function + logger.DebugCtx(pCtx, "waiting for writer workers to complete...") + instr.WorkerWaitGroup.Wait() + logger.DebugCtx(pCtx, "writer workers are done") + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +func (instr *TableInserter) waitForWorkersAndCloseErrorsOut(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) { + logger.PushF("proc.waitForWorkersAndClose/TableInserter") + defer logger.PopF() + + // Make sure no workers are running, so they do not hit closed ErrorsOut + if err := instr.waitForWorkers(logger, pCtx); err != nil { + logger.ErrorCtx(pCtx, fmt.Sprintf("error(s) while waiting for workers to complete: %s", err.Error())) + } + + // Safe to close now + logger.DebugCtx(pCtx, "closing ErrorsOut") + close(instr.ErrorsOut) + logger.DebugCtx(pCtx, "closed ErrorsOut") +} + +func (instr *TableInserter) add(tableRecord TableRecord) error { + indexKeyMap := map[string]string{} + for idxName, idxDef := range instr.TableCreator.Indexes { + var err error + indexKeyMap[idxName], err = sc.BuildKey(tableRecord, idxDef) + if err != nil { + return fmt.Errorf("cannot build key for idx %s, table record [%v]: [%s]", idxName, tableRecord, err.Error()) + } + } + + instr.RecordsSent++ + instr.RecordsIn <- WriteChannelItem{TableRecord: &tableRecord, IndexKeyMap: indexKeyMap} + + return nil +} + +func (instr *TableInserter) tableInserterWorker(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) { + logger.PushF("proc.tableInserterWorker") + defer logger.PopF() + + logger.DebugCtx(pCtx, "writer started reading from RecordsIn") + dataTableName := instr.TableCreator.Name + cql.RunIdSuffix(instr.PCtx.BatchInfo.RunId) + + var dataQb *cql.QueryBuilder + var preparedDataQueryErr error + var preparedDataQuery string + + handledRecordCount := 0 + for writeItem := range instr.RecordsIn { + handledRecordCount++ + maxDataRetries := 5 + curDataExpBackoffFactor := 1 + var errorToReport error + + if
preparedDataQueryErr != nil { + instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: %s", preparedDataQueryErr) + continue + } else if dataQb == nil { + dataQb = cql.NewQB() + if err := dataQb.WritePreparedColumn("rowid"); err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: %s", err) + continue // next insert + } + if err := dataQb.WritePreparedColumn("batch_idx"); err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: %s", err) + continue // next insert + } + if err := dataQb.WritePreparedValue("batch_idx", instr.PCtx.BatchInfo.BatchIdx); err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot write prepared value for batch_idx: %s", err.Error()) + continue // next insert + } + + for fieldName := range *writeItem.TableRecord { + if err := dataQb.WritePreparedColumn(fieldName); err != nil { + errorToReport = fmt.Errorf("cannot prepare data query: %s", err) + break + } + } + if errorToReport != nil { + instr.ErrorsOut <- errorToReport + continue // next insert + } + + var err error + preparedDataQuery, err = dataQb.Keyspace(instr.PCtx.BatchInfo.DataKeyspace). + InsertRunPreparedQuery(instr.TableCreator.Name, instr.PCtx.BatchInfo.RunId, cql.IgnoreIfExists) // INSERT IF NOT EXISTS; if exists, returned isApplied = false + if err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot prepare data query: %s", err) + continue // next insert + } + } + + instr.RandMutex.Lock() + (*writeItem.TableRecord)["rowid"] = instr.RowidRand.Int63() + instr.RandMutex.Unlock() + + for fieldName, fieldValue := range *writeItem.TableRecord { + if err := dataQb.WritePreparedValue(fieldName, fieldValue); err != nil { + errorToReport = fmt.Errorf("cannot write prepared value for %s: %s", fieldName, err.Error()) + break + } + } + if errorToReport != nil { + instr.ErrorsOut <- errorToReport + continue // next insert + } + preparedDataQueryParams, err := dataQb.InsertRunParams() + if err != nil { + instr.ErrorsOut <- fmt.Errorf("cannot provide insert params for prepared query %s: %s", preparedDataQuery, err.Error()) + continue // next insert + } + + for dataRetryCount := 0; dataRetryCount < maxDataRetries; dataRetryCount++ { + + existingDataRow := map[string]any{} + isApplied, err := instr.PCtx.CqlSession.Query(preparedDataQuery, preparedDataQueryParams...).MapScanCAS(existingDataRow) + + if err == nil { + if isApplied { + // Success + break + } + + // This rowid was already there, retry or give up + if dataRetryCount >= maxDataRetries-1 { + // No more retries + logger.ErrorCtx(instr.PCtx, "duplicate rowid not written [%s], existing record [%v], retry count %d reached, giving up", preparedDataQuery, existingDataRow, dataRetryCount) + errorToReport = fmt.Errorf("cannot write to data table after multiple attempts, keep getting rowid duplicates [%s]", preparedDataQuery) + break + } + + // Retry now with a new rowid + logger.InfoCtx(instr.PCtx, "duplicate rowid not written [%s], existing record [%v], retry count %d", preparedDataQuery, existingDataRow, dataRetryCount) + instr.RandMutex.Lock() + instr.RowidRand = rand.New(rand.NewSource(newSeed())) + (*writeItem.TableRecord)["rowid"] = instr.RowidRand.Int63() + instr.RandMutex.Unlock() + + // Set new rowid and re-build query params array (shouldn't throw errors this time) + if err := dataQb.WritePreparedValue("rowid", (*writeItem.TableRecord)["rowid"]); err != nil { + errorToReport = fmt.Errorf("cannot write prepared value for rowid: %s", err.Error()) + break + } + + // Will retry (if retry count allows) + preparedDataQueryParams, _ = dataQb.InsertRunParams() + } else { + if strings.Contains(err.Error(), "does not exist") { + // There is a chance this table is brand new and table schema was not propagated to all Cassandra nodes + if
dataRetryCount >= maxDataRetries-1 { + errorToReport = fmt.Errorf("cannot write to data table %s after %d attempts, apparently, table schema still not propagated to all nodes: %s", dataTableName, dataRetryCount+1, err.Error()) + break + } + + logger.WarnCtx(instr.PCtx, "will wait for table %s to be created, retry count %d, got %s", dataTableName, dataRetryCount, err.Error()) + // TODO: come up with a better waiting strategy (exp backoff, at least) + time.Sleep(5 * time.Second) + } else if strings.Contains(err.Error(), "Operation timed out") { + // The cluster is overloaded, slow down + if dataRetryCount >= maxDataRetries-1 { + errorToReport = fmt.Errorf("cannot write to data table %s after %d attempts, still getting timeouts: %s", dataTableName, dataRetryCount+1, err.Error()) + break + } + logger.WarnCtx(instr.PCtx, "cluster overloaded (%s), will wait for %dms before writing to data table %s again, retry count %d", err.Error(), 10*curDataExpBackoffFactor, dataTableName, dataRetryCount) + time.Sleep(time.Duration(10*curDataExpBackoffFactor) * time.Millisecond) + curDataExpBackoffFactor *= 2 + } else { + // Some serious error happened, stop trying this rowid + errorToReport = db.WrapDbErrorWithQuery("cannot write to data table", preparedDataQuery, err) + break + } + } + } // data retry loop + + if errorToReport == nil { + // Index tables + for idxName, idxDef := range instr.TableCreator.Indexes { + + maxIdxRetries := 5 + idxTableName := idxName + cql.RunIdSuffix(instr.PCtx.BatchInfo.RunId) + curIdxExpBackoffFactor := 1 + + ifNotExistsFlag := cql.ThrowIfExists + if idxDef.Uniqueness == sc.IdxUnique { + ifNotExistsFlag = cql.IgnoreIfExists + } + + idxQb := cql.NewQB() + if err := idxQb.WritePreparedColumn("key"); err != nil { + errorToReport = err + break + } + if err := idxQb.WritePreparedValue("key", writeItem.IndexKeyMap[idxName]); err != nil { + errorToReport = err + break + } + if err := idxQb.WritePreparedColumn("rowid"); err != nil { + errorToReport = err + break + } + if err := idxQb.WritePreparedValue("rowid", (*writeItem.TableRecord)["rowid"]); err != nil { + errorToReport = err + break + } + + preparedIdxQuery, err := idxQb.Keyspace(instr.PCtx.BatchInfo.DataKeyspace).InsertRunPreparedQuery(idxName, instr.PCtx.BatchInfo.RunId, ifNotExistsFlag) + if err != nil { + errorToReport = fmt.Errorf("cannot prepare idx query: %s", err.Error()) + break + } + preparedIdxQueryParams, err := idxQb.InsertRunParams() + if err != nil { + errorToReport = fmt.Errorf("cannot provide idx query params for %s: %s", preparedIdxQuery, err.Error()) + break + } + + for idxRetryCount := 0; idxRetryCount < maxIdxRetries; idxRetryCount++ { + existingIdxRow := map[string]any{} + var isApplied = true + var err error + if idxDef.Uniqueness == sc.IdxUnique { + // Unique idx assumed, check isApplied + isApplied, err = instr.PCtx.CqlSession.Query(preparedIdxQuery, preparedIdxQueryParams...).MapScanCAS(existingIdxRow) + } else { + // No uniqueness assumed, just insert + err = instr.PCtx.CqlSession.Query(preparedIdxQuery, preparedIdxQueryParams...).Exec() + } + + if err == nil { + if !isApplied { + // If attempt number > 0, there is a chance that Cassandra managed to insert the record on the previous attempt but returned an error + if idxRetryCount > 0 && existingIdxRow["key"] == writeItem.IndexKeyMap[idxName] && existingIdxRow["rowid"] == (*writeItem.TableRecord)["rowid"] { + // Cassandra screwed up, but we know how to handle it, the record is there, just log a warning + logger.WarnCtx(instr.PCtx, "duplicate idx record 
found (%s) in idx %s on retry %d when writing (%d,'%s'), assuming this retry was successful, proceeding as usual", idxName, existingIdxRow, idxRetryCount, (*writeItem.TableRecord)["rowid"], writeItem.IndexKeyMap[idxName]) + } else { + // We screwed up, report everything we can + errorToReport = fmt.Errorf("cannot write duplicate index key [%s] on retry %d, existing record [%v]", preparedDataQuery, idxRetryCount, existingIdxRow) + } + } + // Success or not - we are done + break + } + if strings.Contains(err.Error(), "does not exist") { + // There is a chance this table is brand new and table schema was not propagated to all Cassandra nodes + if idxRetryCount >= maxIdxRetries-1 { + errorToReport = fmt.Errorf("cannot write to idx table %s after %d attempts, apparently, table schema still not propagated to all nodes: %s", idxTableName, idxRetryCount+1, err.Error()) + break + } + logger.WarnCtx(instr.PCtx, "will wait for idx table %s to be created, retry count %d, got %s", idxTableName, idxRetryCount, err.Error()) + // TODO: come up with a better waiting strategy (exp backoff, at least) + time.Sleep(5 * time.Second) + } else if strings.Contains(err.Error(), "Operation timed out") { + // The cluster is overloaded, slow down + if idxRetryCount >= maxIdxRetries-1 { + errorToReport = fmt.Errorf("cannot write to idx table %s after %d attempts, still getting timeout: %s", idxTableName, idxRetryCount+1, err.Error()) + break + } + logger.WarnCtx(instr.PCtx, "cluster overloaded (%s), will wait for %dms before writing to idx table %s again, retry count %d", err.Error(), 10*curIdxExpBackoffFactor, idxTableName, idxRetryCount) + time.Sleep(time.Duration(10*curIdxExpBackoffFactor) * time.Millisecond) + curIdxExpBackoffFactor *= 2 + } else { + // Some serious error happened, stop trying this idx record + errorToReport = db.WrapDbErrorWithQuery("cannot write to idx table", preparedDataQuery, err) + break + } + } // idx retry loop + } // idx loop + } + // logger.DebugCtx(pCtx, "writer wrote") + instr.ErrorsOut <- errorToReport + } + logger.DebugCtx(pCtx, "done reading from RecordsIn, this writer worker handled %d records from instr.RecordsIn", handledRecordCount) + // Decrease busy worker count + instr.WorkerWaitGroup.Done() +} diff --git a/pkg/sc/custom_processor_def.go b/pkg/sc/custom_processor_def.go index 9f3f557..6259a2a 100644 --- a/pkg/sc/custom_processor_def.go +++ b/pkg/sc/custom_processor_def.go @@ -1,15 +1,15 @@ -package sc - -import ( - "encoding/json" -) - -type CustomProcessorDefFactory interface { - Create(processorType string) (CustomProcessorDef, bool) -} - -type CustomProcessorDef interface { - Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error - GetFieldRefs() *FieldRefs - GetUsedInTargetExpressionsFields() *FieldRefs -} +package sc + +import ( + "encoding/json" +) + +type CustomProcessorDefFactory interface { + Create(processorType string) (CustomProcessorDef, bool) +} + +type CustomProcessorDef interface { + Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error + GetFieldRefs() *FieldRefs + GetUsedInTargetExpressionsFields() *FieldRefs +} diff --git a/pkg/sc/dependency_policy_def.go b/pkg/sc/dependency_policy_def.go index 8b149c4..ac2f1e7 100644 --- a/pkg/sc/dependency_policy_def.go +++ b/pkg/sc/dependency_policy_def.go @@ -1,137 +1,123 @@ -package sc - -import ( - "encoding/json" - "fmt" - "go/ast" - - "github.com/capillariesio/capillaries/pkg/eval" - 
"github.com/capillariesio/capillaries/pkg/wfmodel" -) - -// This conf should be never referenced in prod code. It's always in the the config.json. Or in the unit tests. -const DefaultPolicyCheckerConf string = ` -{ - "is_default": true, - "event_priority_order": "run_is_current(desc),node_start_ts(desc)", - "rules": [ - {"cmd": "go", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchSuccess" }, - {"cmd": "wait", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchNone" }, - {"cmd": "wait", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchStart" }, - {"cmd": "nogo", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchFail" }, - - {"cmd": "go", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchSuccess" }, - {"cmd": "wait", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchNone" }, - {"cmd": "wait", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchStart" }, - - {"cmd": "go", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunComplete && e.node_status == wfmodel.NodeBatchSuccess" }, - {"cmd": "nogo", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunComplete && e.node_status == wfmodel.NodeBatchFail" } - ] -}` - -type ReadyToRunNodeCmdType string - -const ( - NodeNone ReadyToRunNodeCmdType = "none" - NodeGo ReadyToRunNodeCmdType = "go" - NodeWait ReadyToRunNodeCmdType = "wait" - NodeNogo ReadyToRunNodeCmdType = "nogo" -) - -type DependencyRule struct { - Cmd ReadyToRunNodeCmdType `json:"cmd"` - RawExpression string `json:"expression"` - ParsedExpression ast.Expr -} - -// type EventPriorityOrderDirection string - -// const ( -// EventSortAsc EventPriorityOrderDirection = "asc" -// EventSortDesc EventPriorityOrderDirection = "desc" -// EventSortUnknown EventPriorityOrderDirection = "unknown" -// ) - -// type EventPriorityOrderField struct { -// FieldName string -// Direction EventPriorityOrderDirection -// } - -type DependencyPolicyDef struct { - EventPriorityOrderString string `json:"event_priority_order"` - IsDefault bool `json:"is_default"` - Rules []DependencyRule `json:"rules"` - OrderIdxDef IdxDef - //EventPriorityOrder []EventPriorityOrderField -} - -func NewFieldRefsFromNodeEvent() *FieldRefs { - return &FieldRefs{ - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_id", FieldType: FieldTypeInt}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_is_current", FieldType: FieldTypeBool}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_start_ts", FieldType: FieldTypeDateTime}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_final_status", FieldType: FieldTypeInt}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_completed_ts", FieldType: FieldTypeDateTime}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_stopped_ts", FieldType: FieldTypeDateTime}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_is_started", FieldType: FieldTypeBool}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_start_ts", FieldType: 
FieldTypeDateTime}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_status", FieldType: FieldTypeInt}, - {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_status_ts", FieldType: FieldTypeDateTime}} -} - -func (polDef *DependencyPolicyDef) Deserialize(rawPol json.RawMessage) error { - var err error - if err = json.Unmarshal(rawPol, polDef); err != nil { - return fmt.Errorf("cannot unmarshal dependency policy: [%s]", err.Error()) - } - - if err = polDef.parseEventPriorityOrderString(); err != nil { - return err - } - - vars := wfmodel.NewVarsFromDepCtx(0, wfmodel.DependencyNodeEvent{}) - for ruleIdx := 0; ruleIdx < len(polDef.Rules); ruleIdx++ { - usedFieldRefs := FieldRefs{} - polDef.Rules[ruleIdx].ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(polDef.Rules[ruleIdx].RawExpression, &usedFieldRefs) - if err != nil { - return fmt.Errorf("cannot parse rule expression '%s': %s", polDef.Rules[ruleIdx].RawExpression, err.Error()) - } - - for _, fr := range usedFieldRefs { - fieldSubMap, ok := vars[fr.TableName] - if !ok { - return fmt.Errorf("cannot parse rule expression '%s': all fields must be prefixed with one of these : %s", polDef.Rules[ruleIdx].RawExpression, vars.Tables()) - } - if _, ok := fieldSubMap[fr.FieldName]; !ok { - return fmt.Errorf("cannot parse rule expression '%s': field %s.%s not found, available fields are %s", polDef.Rules[ruleIdx].RawExpression, fr.TableName, fr.FieldName, vars.Names()) - } - } - } - return nil -} - -func (polDef *DependencyPolicyDef) parseEventPriorityOrderString() error { - idxDefMap := IdxDefMap{} - rawIndexes := map[string]string{"order_by": fmt.Sprintf("non_unique(%s)", polDef.EventPriorityOrderString)} - if err := idxDefMap.parseRawIndexDefMap(rawIndexes, NewFieldRefsFromNodeEvent()); err != nil { - return fmt.Errorf("cannot parse event order string '%s': %s", polDef.EventPriorityOrderString, err.Error()) - } - polDef.OrderIdxDef = *idxDefMap["order_by"] - - return nil -} - -func (polDef *DependencyPolicyDef) evalRuleExpressionsAndCheckType() error { - vars := wfmodel.NewVarsFromDepCtx(0, wfmodel.DependencyNodeEvent{}) - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) - for ruleIdx, rule := range polDef.Rules { - result, err := eCtx.Eval(rule.ParsedExpression) - if err != nil { - return fmt.Errorf("invalid rule %d expression '%s': %s", ruleIdx, rule.RawExpression, err.Error()) - } - if err := CheckValueType(result, FieldTypeBool); err != nil { - return fmt.Errorf("invalid rule %d expression '%s' type: %s", ruleIdx, rule.RawExpression, err.Error()) - } - } - return nil -} +package sc + +import ( + "encoding/json" + "fmt" + "go/ast" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/capillariesio/capillaries/pkg/wfmodel" +) + +// This conf should be never referenced in prod code. It's always in the the config.json. Or in the unit tests. 
+const DefaultPolicyCheckerConf string = ` +{ + "is_default": true, + "event_priority_order": "run_is_current(desc),node_start_ts(desc)", + "rules": [ + {"cmd": "go", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchSuccess" }, + {"cmd": "wait", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchNone" }, + {"cmd": "wait", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchStart" }, + {"cmd": "nogo", "expression": "e.run_is_current == true && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchFail" }, + + {"cmd": "go", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchSuccess" }, + {"cmd": "wait", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchNone" }, + {"cmd": "wait", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunStart && e.node_status == wfmodel.NodeBatchStart" }, + + {"cmd": "go", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunComplete && e.node_status == wfmodel.NodeBatchSuccess" }, + {"cmd": "nogo", "expression": "e.run_is_current == false && e.run_final_status == wfmodel.RunComplete && e.node_status == wfmodel.NodeBatchFail" } + ] +}` + +type ReadyToRunNodeCmdType string + +const ( + NodeNone ReadyToRunNodeCmdType = "none" + NodeGo ReadyToRunNodeCmdType = "go" + NodeWait ReadyToRunNodeCmdType = "wait" + NodeNogo ReadyToRunNodeCmdType = "nogo" +) + +type DependencyRule struct { + Cmd ReadyToRunNodeCmdType `json:"cmd"` + RawExpression string `json:"expression"` + ParsedExpression ast.Expr +} + +type DependencyPolicyDef struct { + EventPriorityOrderString string `json:"event_priority_order"` + IsDefault bool `json:"is_default"` + Rules []DependencyRule `json:"rules"` + OrderIdxDef IdxDef +} + +func NewFieldRefsFromNodeEvent() *FieldRefs { + return &FieldRefs{ + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_id", FieldType: FieldTypeInt}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_is_current", FieldType: FieldTypeBool}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_start_ts", FieldType: FieldTypeDateTime}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_final_status", FieldType: FieldTypeInt}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_completed_ts", FieldType: FieldTypeDateTime}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "run_stopped_ts", FieldType: FieldTypeDateTime}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_is_started", FieldType: FieldTypeBool}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_start_ts", FieldType: FieldTypeDateTime}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_status", FieldType: FieldTypeInt}, + {TableName: wfmodel.DependencyNodeEventTableName, FieldName: "node_status_ts", FieldType: FieldTypeDateTime}} +} + +func (polDef *DependencyPolicyDef) Deserialize(rawPol json.RawMessage) error { + var err error + if err = json.Unmarshal(rawPol, polDef); err != nil { + return fmt.Errorf("cannot unmarshal dependency policy: [%s]", err.Error()) + } + + if err = polDef.parseEventPriorityOrderString(); err != nil { + return err + } + + vars := 
wfmodel.NewVarsFromDepCtx(wfmodel.DependencyNodeEvent{}) + for ruleIdx := 0; ruleIdx < len(polDef.Rules); ruleIdx++ { + usedFieldRefs := FieldRefs{} + polDef.Rules[ruleIdx].ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(polDef.Rules[ruleIdx].RawExpression, &usedFieldRefs) + if err != nil { + return fmt.Errorf("cannot parse rule expression '%s': %s", polDef.Rules[ruleIdx].RawExpression, err.Error()) + } + + for _, fr := range usedFieldRefs { + fieldSubMap, ok := vars[fr.TableName] + if !ok { + return fmt.Errorf("cannot parse rule expression '%s': all fields must be prefixed with one of these : %s", polDef.Rules[ruleIdx].RawExpression, vars.Tables()) + } + if _, ok := fieldSubMap[fr.FieldName]; !ok { + return fmt.Errorf("cannot parse rule expression '%s': field %s.%s not found, available fields are %s", polDef.Rules[ruleIdx].RawExpression, fr.TableName, fr.FieldName, vars.Names()) + } + } + } + return nil +} + +func (polDef *DependencyPolicyDef) parseEventPriorityOrderString() error { + idxDefMap := IdxDefMap{} + rawIndexes := map[string]string{"order_by": fmt.Sprintf("non_unique(%s)", polDef.EventPriorityOrderString)} + if err := idxDefMap.parseRawIndexDefMap(rawIndexes, NewFieldRefsFromNodeEvent()); err != nil { + return fmt.Errorf("cannot parse event order string '%s': %s", polDef.EventPriorityOrderString, err.Error()) + } + polDef.OrderIdxDef = *idxDefMap["order_by"] + + return nil +} + +func (polDef *DependencyPolicyDef) evalRuleExpressionsAndCheckType() error { + vars := wfmodel.NewVarsFromDepCtx(wfmodel.DependencyNodeEvent{}) + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) + for ruleIdx, rule := range polDef.Rules { + result, err := eCtx.Eval(rule.ParsedExpression) + if err != nil { + return fmt.Errorf("invalid rule %d expression '%s': %s", ruleIdx, rule.RawExpression, err.Error()) + } + if err := CheckValueType(result, FieldTypeBool); err != nil { + return fmt.Errorf("invalid rule %d expression '%s' type: %s", ruleIdx, rule.RawExpression, err.Error()) + } + } + return nil +} diff --git a/pkg/sc/dependency_policy_def_test.go b/pkg/sc/dependency_policy_def_test.go index 3ba92ec..38e2e93 100644 --- a/pkg/sc/dependency_policy_def_test.go +++ b/pkg/sc/dependency_policy_def_test.go @@ -1,47 +1,47 @@ -package sc - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDependencyPolicyBad(t *testing.T) { - - polDef := DependencyPolicyDef{} - - conf := `{"event_priority_order": "run_i(asc)", "rules": []}` - err := polDef.Deserialize([]byte(conf)) - assertErrorPrefix(t, "cannot parse event order string 'run_i(asc)'", err.Error()) - - conf = `{"event_priority_order": "run_id(", "rules": []}` - err = polDef.Deserialize([]byte(conf)) - assertErrorPrefix(t, "cannot parse event order string 'run_id(': cannot parse order def 'non_unique(run_id()'", err.Error()) - - conf = `{"event_priority_order": "run_id(bad)", "rules": []}` - err = polDef.Deserialize([]byte(conf)) - assertErrorPrefix(t, "cannot parse event order string 'run_id(bad)'", err.Error()) - - conf = `{"event_priority_order": "run_id(asc)", "rules": [{ "cmd": "go", "expression": "e.run_is_current && e.run_final_status == bad" }]}` - err = polDef.Deserialize([]byte(conf)) - assertErrorPrefix(t, "cannot parse rule expression 'e.run_is_current && e.run_final_status == bad': plain (non-selector) identifiers", err.Error()) -} - -func TestDependencyPolicyGood(t *testing.T) { - - polDef := DependencyPolicyDef{} - - conf := `{"event_priority_order": 
"run_id(asc),run_is_current(desc)", "rules": [ - {"cmd": "go", "expression": "e.run_is_current && time.DiffMilli(e.run_start_ts, e.node_status_ts) > 0 && e.run_final_status == wfmodel.RunStart" }, - {"cmd": "go", "expression": "time.DiffMilli(e.run_start_ts, time.Parse(\"2006-01-02 15:04:05\",\"2000-01-01 00:00:00.000\")) > 0 && e.run_is_current == true" } - ]}` - assert.Nil(t, polDef.Deserialize([]byte(conf))) - - assert.Equal(t, "run_id", polDef.OrderIdxDef.Components[0].FieldName) - assert.Equal(t, IdxSortAsc, polDef.OrderIdxDef.Components[0].SortOrder) - assert.Equal(t, "run_is_current", polDef.OrderIdxDef.Components[1].FieldName) - assert.Equal(t, IdxSortDesc, polDef.OrderIdxDef.Components[1].SortOrder) - assert.Equal(t, NodeGo, polDef.Rules[0].Cmd) - - assert.Nil(t, polDef.evalRuleExpressionsAndCheckType()) -} +package sc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDependencyPolicyBad(t *testing.T) { + + polDef := DependencyPolicyDef{} + + conf := `{"event_priority_order": "run_i(asc)", "rules": []}` + err := polDef.Deserialize([]byte(conf)) + assertErrorPrefix(t, "cannot parse event order string 'run_i(asc)'", err.Error()) + + conf = `{"event_priority_order": "run_id(", "rules": []}` + err = polDef.Deserialize([]byte(conf)) + assertErrorPrefix(t, "cannot parse event order string 'run_id(': cannot parse order def 'non_unique(run_id()'", err.Error()) + + conf = `{"event_priority_order": "run_id(bad)", "rules": []}` + err = polDef.Deserialize([]byte(conf)) + assertErrorPrefix(t, "cannot parse event order string 'run_id(bad)'", err.Error()) + + conf = `{"event_priority_order": "run_id(asc)", "rules": [{ "cmd": "go", "expression": "e.run_is_current && e.run_final_status == bad" }]}` + err = polDef.Deserialize([]byte(conf)) + assertErrorPrefix(t, "cannot parse rule expression 'e.run_is_current && e.run_final_status == bad': plain (non-selector) identifiers", err.Error()) +} + +func TestDependencyPolicyGood(t *testing.T) { + + polDef := DependencyPolicyDef{} + + conf := `{"event_priority_order": "run_id(asc),run_is_current(desc)", "rules": [ + {"cmd": "go", "expression": "e.run_is_current && time.DiffMilli(e.run_start_ts, e.node_status_ts) > 0 && e.run_final_status == wfmodel.RunStart" }, + {"cmd": "go", "expression": "time.DiffMilli(e.run_start_ts, time.Parse(\"2006-01-02 15:04:05\",\"2000-01-01 00:00:00.000\")) > 0 && e.run_is_current == true" } + ]}` + assert.Nil(t, polDef.Deserialize([]byte(conf))) + + assert.Equal(t, "run_id", polDef.OrderIdxDef.Components[0].FieldName) + assert.Equal(t, IdxSortAsc, polDef.OrderIdxDef.Components[0].SortOrder) + assert.Equal(t, "run_is_current", polDef.OrderIdxDef.Components[1].FieldName) + assert.Equal(t, IdxSortDesc, polDef.OrderIdxDef.Components[1].SortOrder) + assert.Equal(t, NodeGo, polDef.Rules[0].Cmd) + + assert.Nil(t, polDef.evalRuleExpressionsAndCheckType()) +} diff --git a/pkg/sc/field_ref.go b/pkg/sc/field_ref.go index c144be5..03d45ed 100644 --- a/pkg/sc/field_ref.go +++ b/pkg/sc/field_ref.go @@ -1,340 +1,339 @@ -package sc - -import ( - "fmt" - "go/ast" - "go/parser" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/shopspring/decimal" -) - -type FieldRef struct { - TableName string - FieldName string - FieldType TableFieldType -} - -func (fr *FieldRef) GetAliasHash() string { - return fmt.Sprintf("%s.%s", fr.TableName, fr.FieldName) -} - -type FieldRefs []FieldRef - -func (fieldRefs *FieldRefs) HasFieldsWithTableAlias(tableAlias string) bool { - for i := 0; i < 
len(*fieldRefs); i++ { - if tableAlias == (*fieldRefs)[i].TableName { - return true - } - } - return false -} - -func JoinFieldRefs(srcFieldRefs ...*FieldRefs) *FieldRefs { - hashes := map[string](*FieldRef){} - for i := 0; i < len(srcFieldRefs); i++ { - if srcFieldRefs[i] != nil { - for j := 0; j < len(*srcFieldRefs[i]); j++ { - hash := fmt.Sprintf("%s.%s", (*srcFieldRefs[i])[j].TableName, (*srcFieldRefs[i])[j].FieldName) - if _, ok := hashes[hash]; !ok { - hashes[hash] = &(*srcFieldRefs[i])[j] - } - } - } - } - - fieldRefs := make(FieldRefs, len(hashes)) - fieldRefCount := 0 - for _, fieldRef := range hashes { - fieldRefs[fieldRefCount] = *fieldRef - fieldRefCount++ - } - return &fieldRefs -} - -func RowidFieldRef(tableName string) FieldRef { - return FieldRef{ - TableName: tableName, - FieldName: "rowid", - FieldType: FieldTypeInt} -} - -func RowidTokenFieldRef() FieldRef { - return FieldRef{ - TableName: "db_system", - FieldName: "token(rowid)", - FieldType: FieldTypeInt} -} - -// func RunBatchRowidTokenFieldRef() FieldRef { -// return FieldRef{ -// TableName: "db_system", -// FieldName: "token(run_id,batch_idx,rowid)", -// FieldType: FieldTypeInt} -// } - -func KeyTokenFieldRef() FieldRef { - return FieldRef{ - TableName: "db_system", - FieldName: "token(key)", - FieldType: FieldTypeInt} -} - -// func RunBatchKeyTokenFieldRef() FieldRef { -// return FieldRef{ -// TableName: "db_system", -// FieldName: "token(run_id,batch_idx,key)", -// FieldType: FieldTypeInt} -// } -func IdxKeyFieldRef() FieldRef { - return FieldRef{ - TableName: "db_system", - FieldName: "key", - FieldType: FieldTypeString} -} - -func (fieldRefs *FieldRefs) contributeUnresolved(tableName string, fieldName string) { - // Check if it's already there - for i := 0; i < len(*fieldRefs); i++ { - if (*fieldRefs)[i].TableName == tableName && - (*fieldRefs)[i].FieldName == fieldName { - // Already there - return - } - } - - *fieldRefs = append(*fieldRefs, FieldRef{TableName: tableName, FieldName: fieldName, FieldType: FieldTypeUnknown}) -} - -func (fieldRefs *FieldRefs) Append(otherFieldRefs FieldRefs) { - fieldRefs.AppendWithFilter(otherFieldRefs, "") -} - -func (fieldRefs *FieldRefs) AppendWithFilter(otherFieldRefs FieldRefs, tableFilter string) { - fieldRefMap := map[string]FieldRef{} - - // Existing to map - for i := 0; i < len(*fieldRefs); i++ { - fieldRefMap[(*fieldRefs)[i].GetAliasHash()] = (*fieldRefs)[i] - } - - // New to map - for i := 0; i < len(otherFieldRefs); i++ { - if len(tableFilter) == 0 || tableFilter == (otherFieldRefs)[i].TableName { - fieldRefMap[(otherFieldRefs)[i].GetAliasHash()] = (otherFieldRefs)[i] - } - } - - *fieldRefs = make([]FieldRef, len(fieldRefMap)) - refIdx := 0 - for fieldRefHash := range fieldRefMap { - (*fieldRefs)[refIdx] = fieldRefMap[fieldRefHash] - refIdx++ - } -} - -func evalExpressionWithFieldRefsAndCheckType(exp ast.Expr, fieldRefs FieldRefs, expectedType TableFieldType) error { - if exp == nil { - // Nothing to evaluate - return nil - } - varValuesMap := eval.VarValuesMap{} - for i := 0; i < len(fieldRefs); i++ { - tName := fieldRefs[i].TableName - fName := fieldRefs[i].FieldName - fType := fieldRefs[i].FieldType - if _, ok := varValuesMap[tName]; !ok { - varValuesMap[tName] = map[string]interface{}{} - } - switch fType { - case FieldTypeInt: - varValuesMap[tName][fName] = int64(0) - case FieldTypeFloat: - varValuesMap[tName][fName] = float64(0.0) - case FieldTypeBool: - varValuesMap[tName][fName] = false - case FieldTypeString: - varValuesMap[tName][fName] = "12345.67" // There 
may be a float() call out there - case FieldTypeDateTime: - varValuesMap[tName][fName] = time.Now() - case FieldTypeDecimal2: - varValuesMap[tName][fName] = decimal.NewFromFloat(2.34) - default: - return fmt.Errorf("evalExpressionWithFieldRefsAndCheckType unsupported field type %s", fieldRefs[i].FieldType) - - } - } - - aggFuncEnabled, aggFuncType, aggFuncArgs := eval.DetectRootAggFunc(exp) - eCtx, err := eval.NewPlainEvalCtxWithVarsAndInitializedAgg(aggFuncEnabled, &varValuesMap, aggFuncType, aggFuncArgs) - if err != nil { - return err - } - - result, err := eCtx.Eval(exp) - if err != nil { - return err - } - - return CheckValueType(result, expectedType) -} - -func (fieldRefs *FieldRefs) FindByFieldName(fieldName string) (*FieldRef, bool) { - for i := 0; i < len(*fieldRefs); i++ { - if fieldName == (*fieldRefs)[i].FieldName { - return &(*fieldRefs)[i], true - } - } - return nil, false -} - -func checkAllowed(fieldRefsToCheck *FieldRefs, prohibitedFieldRefs *FieldRefs, allowedFieldRefs *FieldRefs) error { - if fieldRefsToCheck == nil { - return nil - } - - // Harvest allowed - allowedHashes := map[string](*FieldRef){} - if allowedFieldRefs != nil { - for i := 0; i < len(*allowedFieldRefs); i++ { - hash := fmt.Sprintf("%s.%s", (*allowedFieldRefs)[i].TableName, (*allowedFieldRefs)[i].FieldName) - if _, ok := allowedHashes[hash]; !ok { - allowedHashes[hash] = &(*allowedFieldRefs)[i] - } - } - } - - // Harvest prohibited - prohibitedHashes := map[string](*FieldRef){} - if prohibitedFieldRefs != nil { - for i := 0; i < len(*prohibitedFieldRefs); i++ { - hash := fmt.Sprintf("%s.%s", (*prohibitedFieldRefs)[i].TableName, (*prohibitedFieldRefs)[i].FieldName) - if _, ok := prohibitedHashes[hash]; !ok { - prohibitedHashes[hash] = &(*prohibitedFieldRefs)[i] - } - } - } - - errors := make([]string, 0, 2) - - for i := 0; i < len(*fieldRefsToCheck); i++ { - if len((*fieldRefsToCheck)[i].TableName) == 0 || len((*fieldRefsToCheck)[i].FieldName) == 0 { - errors = append(errors, fmt.Sprintf("dev error, empty FieldRef %s.%s", - (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) - } - hash := fmt.Sprintf("%s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName) - if _, ok := prohibitedHashes[hash]; ok { - errors = append(errors, fmt.Sprintf("prohibited field %s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) - } else if _, ok := allowedHashes[hash]; !ok { - errors = append(errors, fmt.Sprintf("unknown field %s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) - } else { - // Update check field type, we will use it later for test eval - (*fieldRefsToCheck)[i].FieldType = allowedHashes[hash].FieldType - } - } - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } - -} - -type FieldRefParserFlag uint32 - -func (f FieldRefParserFlag) HasFlag(flag FieldRefParserFlag) bool { return f&flag != 0 } - -// Not used for now, maybe later -// func (f *FieldRefParserFlag) AddFlag(flag FieldRefParserFlag) { *f |= flag } -// func (f *FieldRefParserFlag) ClearFlag(flag FieldRefParserFlag) { *f &= ^flag } -// func (f *FieldRefParserFlag) ToggleFlag(flag FieldRefParserFlag) { *f ^= flag } - -const ( - FieldRefStrict FieldRefParserFlag = 0 - FieldRefAllowUnknownIdents FieldRefParserFlag = 1 << iota - FieldRefAllowWhateverFeatureYouAreAddingHere -) - -func harvestFieldRefsFromParsedExpression(exp ast.Expr, usedFields *FieldRefs, parserFlags FieldRefParserFlag) error { - switch assertedExp := 
exp.(type) { - case *ast.BinaryExpr: - if err := harvestFieldRefsFromParsedExpression(assertedExp.X, usedFields, parserFlags); err != nil { - return err - } - return harvestFieldRefsFromParsedExpression(assertedExp.Y, usedFields, parserFlags) - - case *ast.UnaryExpr: - return harvestFieldRefsFromParsedExpression(assertedExp.X, usedFields, parserFlags) - - case *ast.CallExpr: - for _, v := range assertedExp.Args { - if err := harvestFieldRefsFromParsedExpression(v, usedFields, parserFlags); err != nil { - return err - } - } - - case *ast.SelectorExpr: - switch assertedExpIdent := assertedExp.X.(type) { - case *ast.Ident: - _, ok := eval.GolangConstants[fmt.Sprintf("%s.%s", assertedExpIdent.Name, assertedExp.Sel.Name)] - if !ok { - usedFields.contributeUnresolved(assertedExpIdent.Name, assertedExp.Sel.Name) - } - default: - return fmt.Errorf("selectors starting with non-ident are not allowed, found '%v'; aliases to use: readers - '%s', creators - '%s', custom processors - '%s', lookups - '%s'", - assertedExp.X, ReaderAlias, CreatorAlias, CustomProcessorAlias, LookupAlias) - } - - case *ast.Ident: - // Keep in mind we may use this parser for Python expressions. Allow unknown constructs for those cases. - if !parserFlags.HasFlag(FieldRefAllowUnknownIdents) { - if assertedExp.Name != "true" && assertedExp.Name != "false" { - return fmt.Errorf("plain (non-selector) identifiers are not allowed, expected field qualifiers (tableor_lkp_alias.field_name), found '%s'; for file readers, use '%s' alias; for file creators, use '%s' alias", - assertedExp.Name, ReaderAlias, CreatorAlias) - } - } - } - - return nil -} - -func ParseRawGolangExpressionStringAndHarvestFieldRefs(strExp string, usedFields *FieldRefs) (ast.Expr, error) { - if len(strings.TrimSpace(strExp)) == 0 { - return nil, nil - } - - expCondition, err := parser.ParseExpr(strExp) - if err != nil { - return nil, fmt.Errorf("strict parsing error: [%s]", err.Error()) - } - - if usedFields != nil { - if err := harvestFieldRefsFromParsedExpression(expCondition, usedFields, FieldRefStrict); err != nil { - return nil, err - } - } - - return expCondition, nil -} - -func ParseRawRelaxedGolangExpressionStringAndHarvestFieldRefs(strExp string, usedFields *FieldRefs, parserFlags FieldRefParserFlag) (ast.Expr, error) { - if len(strings.TrimSpace(strExp)) == 0 { - return nil, nil - } - - expCondition, err := parser.ParseExpr(strExp) - if err != nil { - return nil, fmt.Errorf("relaxed parsing error: [%s]", err.Error()) - } - - if usedFields != nil { - if err := harvestFieldRefsFromParsedExpression(expCondition, usedFields, parserFlags); err != nil { - return nil, err - } - } - - return expCondition, nil -} +package sc + +import ( + "fmt" + "go/ast" + "go/parser" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/shopspring/decimal" +) + +type FieldRef struct { + TableName string + FieldName string + FieldType TableFieldType +} + +func (fr *FieldRef) GetAliasHash() string { + return fmt.Sprintf("%s.%s", fr.TableName, fr.FieldName) +} + +type FieldRefs []FieldRef + +func (fieldRefs *FieldRefs) HasFieldsWithTableAlias(tableAlias string) bool { + for i := 0; i < len(*fieldRefs); i++ { + if tableAlias == (*fieldRefs)[i].TableName { + return true + } + } + return false +} + +func JoinFieldRefs(srcFieldRefs ...*FieldRefs) *FieldRefs { + hashes := map[string](*FieldRef){} + for i := 0; i < len(srcFieldRefs); i++ { + if srcFieldRefs[i] != nil { + for j := 0; j < len(*srcFieldRefs[i]); j++ { + hash := fmt.Sprintf("%s.%s", 
(*srcFieldRefs[i])[j].TableName, (*srcFieldRefs[i])[j].FieldName) + if _, ok := hashes[hash]; !ok { + hashes[hash] = &(*srcFieldRefs[i])[j] + } + } + } + } + + fieldRefs := make(FieldRefs, len(hashes)) + fieldRefCount := 0 + for _, fieldRef := range hashes { + fieldRefs[fieldRefCount] = *fieldRef + fieldRefCount++ + } + return &fieldRefs +} + +func RowidFieldRef(tableName string) FieldRef { + return FieldRef{ + TableName: tableName, + FieldName: "rowid", + FieldType: FieldTypeInt} +} + +func RowidTokenFieldRef() FieldRef { + return FieldRef{ + TableName: "db_system", + FieldName: "token(rowid)", + FieldType: FieldTypeInt} +} + +// func RunBatchRowidTokenFieldRef() FieldRef { +// return FieldRef{ +// TableName: "db_system", +// FieldName: "token(run_id,batch_idx,rowid)", +// FieldType: FieldTypeInt} +// } + +func KeyTokenFieldRef() FieldRef { + return FieldRef{ + TableName: "db_system", + FieldName: "token(key)", + FieldType: FieldTypeInt} +} + +// func RunBatchKeyTokenFieldRef() FieldRef { +// return FieldRef{ +// TableName: "db_system", +// FieldName: "token(run_id,batch_idx,key)", +// FieldType: FieldTypeInt} +// } +func IdxKeyFieldRef() FieldRef { + return FieldRef{ + TableName: "db_system", + FieldName: "key", + FieldType: FieldTypeString} +} + +func (fieldRefs *FieldRefs) contributeUnresolved(tableName string, fieldName string) { + // Check if it's already there + for i := 0; i < len(*fieldRefs); i++ { + if (*fieldRefs)[i].TableName == tableName && + (*fieldRefs)[i].FieldName == fieldName { + // Already there + return + } + } + + *fieldRefs = append(*fieldRefs, FieldRef{TableName: tableName, FieldName: fieldName, FieldType: FieldTypeUnknown}) +} + +func (fieldRefs *FieldRefs) Append(otherFieldRefs FieldRefs) { + fieldRefs.AppendWithFilter(otherFieldRefs, "") +} + +func (fieldRefs *FieldRefs) AppendWithFilter(otherFieldRefs FieldRefs, tableFilter string) { + fieldRefMap := map[string]FieldRef{} + + // Existing to map + for i := 0; i < len(*fieldRefs); i++ { + fieldRefMap[(*fieldRefs)[i].GetAliasHash()] = (*fieldRefs)[i] + } + + // New to map + for i := 0; i < len(otherFieldRefs); i++ { + if len(tableFilter) == 0 || tableFilter == (otherFieldRefs)[i].TableName { + fieldRefMap[(otherFieldRefs)[i].GetAliasHash()] = (otherFieldRefs)[i] + } + } + + *fieldRefs = make([]FieldRef, len(fieldRefMap)) + refIdx := 0 + for fieldRefHash := range fieldRefMap { + (*fieldRefs)[refIdx] = fieldRefMap[fieldRefHash] + refIdx++ + } +} + +func evalExpressionWithFieldRefsAndCheckType(exp ast.Expr, fieldRefs FieldRefs, expectedType TableFieldType) error { + if exp == nil { + // Nothing to evaluate + return nil + } + varValuesMap := eval.VarValuesMap{} + for i := 0; i < len(fieldRefs); i++ { + tName := fieldRefs[i].TableName + fName := fieldRefs[i].FieldName + fType := fieldRefs[i].FieldType + if _, ok := varValuesMap[tName]; !ok { + varValuesMap[tName] = map[string]any{} + } + switch fType { + case FieldTypeInt: + varValuesMap[tName][fName] = int64(0) + case FieldTypeFloat: + varValuesMap[tName][fName] = float64(0.0) + case FieldTypeBool: + varValuesMap[tName][fName] = false + case FieldTypeString: + varValuesMap[tName][fName] = "12345.67" // There may be a float() call out there + case FieldTypeDateTime: + varValuesMap[tName][fName] = time.Now() + case FieldTypeDecimal2: + varValuesMap[tName][fName] = decimal.NewFromFloat(2.34) + default: + return fmt.Errorf("evalExpressionWithFieldRefsAndCheckType unsupported field type %s", fieldRefs[i].FieldType) + + } + } + + aggFuncEnabled, aggFuncType, aggFuncArgs := 
eval.DetectRootAggFunc(exp) + eCtx, err := eval.NewPlainEvalCtxWithVarsAndInitializedAgg(aggFuncEnabled, &varValuesMap, aggFuncType, aggFuncArgs) + if err != nil { + return err + } + + result, err := eCtx.Eval(exp) + if err != nil { + return err + } + + return CheckValueType(result, expectedType) +} + +func (fieldRefs *FieldRefs) FindByFieldName(fieldName string) (*FieldRef, bool) { + for i := 0; i < len(*fieldRefs); i++ { + if fieldName == (*fieldRefs)[i].FieldName { + return &(*fieldRefs)[i], true + } + } + return nil, false +} + +func checkAllowed(fieldRefsToCheck *FieldRefs, prohibitedFieldRefs *FieldRefs, allowedFieldRefs *FieldRefs) error { + if fieldRefsToCheck == nil { + return nil + } + + // Harvest allowed + allowedHashes := map[string](*FieldRef){} + if allowedFieldRefs != nil { + for i := 0; i < len(*allowedFieldRefs); i++ { + hash := fmt.Sprintf("%s.%s", (*allowedFieldRefs)[i].TableName, (*allowedFieldRefs)[i].FieldName) + if _, ok := allowedHashes[hash]; !ok { + allowedHashes[hash] = &(*allowedFieldRefs)[i] + } + } + } + + // Harvest prohibited + prohibitedHashes := map[string](*FieldRef){} + if prohibitedFieldRefs != nil { + for i := 0; i < len(*prohibitedFieldRefs); i++ { + hash := fmt.Sprintf("%s.%s", (*prohibitedFieldRefs)[i].TableName, (*prohibitedFieldRefs)[i].FieldName) + if _, ok := prohibitedHashes[hash]; !ok { + prohibitedHashes[hash] = &(*prohibitedFieldRefs)[i] + } + } + } + + errors := make([]string, 0, 2) + + for i := 0; i < len(*fieldRefsToCheck); i++ { + if len((*fieldRefsToCheck)[i].TableName) == 0 || len((*fieldRefsToCheck)[i].FieldName) == 0 { + errors = append(errors, fmt.Sprintf("dev error, empty FieldRef %s.%s", + (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) + } + hash := fmt.Sprintf("%s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName) + if _, ok := prohibitedHashes[hash]; ok { + errors = append(errors, fmt.Sprintf("prohibited field %s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) + } else if _, ok := allowedHashes[hash]; !ok { + errors = append(errors, fmt.Sprintf("unknown field %s.%s", (*fieldRefsToCheck)[i].TableName, (*fieldRefsToCheck)[i].FieldName)) + } else { + // Update check field type, we will use it later for test eval + (*fieldRefsToCheck)[i].FieldType = allowedHashes[hash].FieldType + } + } + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +type FieldRefParserFlag uint32 + +func (f FieldRefParserFlag) HasFlag(flag FieldRefParserFlag) bool { return f&flag != 0 } + +// Not used for now, maybe later +// func (f *FieldRefParserFlag) AddFlag(flag FieldRefParserFlag) { *f |= flag } +// func (f *FieldRefParserFlag) ClearFlag(flag FieldRefParserFlag) { *f &= ^flag } +// func (f *FieldRefParserFlag) ToggleFlag(flag FieldRefParserFlag) { *f ^= flag } + +const ( + FieldRefStrict FieldRefParserFlag = 0 + FieldRefAllowUnknownIdents FieldRefParserFlag = 1 << iota + FieldRefAllowWhateverFeatureYouAreAddingHere +) + +func harvestFieldRefsFromParsedExpression(exp ast.Expr, usedFields *FieldRefs, parserFlags FieldRefParserFlag) error { + switch assertedExp := exp.(type) { + case *ast.BinaryExpr: + if err := harvestFieldRefsFromParsedExpression(assertedExp.X, usedFields, parserFlags); err != nil { + return err + } + return harvestFieldRefsFromParsedExpression(assertedExp.Y, usedFields, parserFlags) + + case *ast.UnaryExpr: + return harvestFieldRefsFromParsedExpression(assertedExp.X, usedFields, parserFlags) + + case *ast.CallExpr: + for _, 
v := range assertedExp.Args { + if err := harvestFieldRefsFromParsedExpression(v, usedFields, parserFlags); err != nil { + return err + } + } + + case *ast.SelectorExpr: + switch assertedExpIdent := assertedExp.X.(type) { + case *ast.Ident: + _, ok := eval.GolangConstants[fmt.Sprintf("%s.%s", assertedExpIdent.Name, assertedExp.Sel.Name)] + if !ok { + usedFields.contributeUnresolved(assertedExpIdent.Name, assertedExp.Sel.Name) + } + default: + return fmt.Errorf("selectors starting with non-ident are not allowed, found '%v'; aliases to use: readers - '%s', creators - '%s', custom processors - '%s', lookups - '%s'", + assertedExp.X, ReaderAlias, CreatorAlias, CustomProcessorAlias, LookupAlias) + } + + case *ast.Ident: + // Keep in mind we may use this parser for Python expressions. Allow unknown constructs for those cases. + if !parserFlags.HasFlag(FieldRefAllowUnknownIdents) { + if assertedExp.Name != "true" && assertedExp.Name != "false" { + return fmt.Errorf("plain (non-selector) identifiers are not allowed, expected field qualifiers (tableor_lkp_alias.field_name), found '%s'; for file readers, use '%s' alias; for file creators, use '%s' alias", + assertedExp.Name, ReaderAlias, CreatorAlias) + } + } + } + + return nil +} + +func ParseRawGolangExpressionStringAndHarvestFieldRefs(strExp string, usedFields *FieldRefs) (ast.Expr, error) { + if len(strings.TrimSpace(strExp)) == 0 { + return nil, nil + } + + expCondition, err := parser.ParseExpr(strExp) + if err != nil { + return nil, fmt.Errorf("strict parsing error: [%s]", err.Error()) + } + + if usedFields != nil { + if err := harvestFieldRefsFromParsedExpression(expCondition, usedFields, FieldRefStrict); err != nil { + return nil, err + } + } + + return expCondition, nil +} + +func ParseRawRelaxedGolangExpressionStringAndHarvestFieldRefs(strExp string, usedFields *FieldRefs, parserFlags FieldRefParserFlag) (ast.Expr, error) { + if len(strings.TrimSpace(strExp)) == 0 { + return nil, nil + } + + expCondition, err := parser.ParseExpr(strExp) + if err != nil { + return nil, fmt.Errorf("relaxed parsing error: [%s]", err.Error()) + } + + if usedFields != nil { + if err := harvestFieldRefsFromParsedExpression(expCondition, usedFields, parserFlags); err != nil { + return nil, err + } + } + + return expCondition, nil +} diff --git a/pkg/sc/field_ref_test.go b/pkg/sc/field_ref_test.go index d417cfe..6c94b8b 100644 --- a/pkg/sc/field_ref_test.go +++ b/pkg/sc/field_ref_test.go @@ -16,4 +16,7 @@ func TestAppendWithFilter(t *testing.T) { assert.True(t, targetRefs.HasFieldsWithTableAlias("t0")) assert.False(t, targetRefs.HasFieldsWithTableAlias("t1")) assert.True(t, targetRefs.HasFieldsWithTableAlias("t2")) + + targetRefs.Append(sourceRefs) + assert.Equal(t, 3, len(targetRefs)) } diff --git a/pkg/sc/file_creator_def.go b/pkg/sc/file_creator_def.go index 6d9f114..3d23dcb 100644 --- a/pkg/sc/file_creator_def.go +++ b/pkg/sc/file_creator_def.go @@ -1,216 +1,217 @@ -package sc - -import ( - "encoding/json" - "fmt" - "go/ast" - "strings" - - "github.com/capillariesio/capillaries/pkg/eval" -) - -const ( - CreatorFileTypeUnknown int = 0 - CreatorFileTypeCsv int = 1 - CreatorFileTypeParquet int = 2 -) - -type ParquetCodecType string - -const ( - ParquetCodecGzip ParquetCodecType = "gzip" - ParquetCodecSnappy ParquetCodecType = "snappy" - ParquetCodecUncompressed ParquetCodecType = "uncompressed" -) - -type WriteCsvColumnSettings struct { - Format string `json:"format"` - Header string `json:"header"` -} - -type WriteParquetColumnSettings struct { - ColumnName 
string `json:"column_name"` -} - -type WriteFileColumnDef struct { - RawExpression string `json:"expression"` - Name string `json:"name"` // To be used in Having - Type TableFieldType `json:"type"` // To be checked when checking expressions and to be used in Having - Csv WriteCsvColumnSettings `json:"csv,omitempty"` - Parquet WriteParquetColumnSettings `json:"parquet,omitempty"` - ParsedExpression ast.Expr - UsedFields FieldRefs -} - -type TopDef struct { - Limit int `json:"limit"` - RawOrder string `json:"order"` - OrderIdxDef IdxDef // Not an index really, we just re-use IdxDef infrastructure -} - -type CsvCreatorSettings struct { - Separator string `json:"separator"` -} - -type ParquetCreatorSettings struct { - Codec ParquetCodecType `json:"codec"` -} - -type FileCreatorDef struct { - RawHaving string `json:"having"` - Having ast.Expr - UsedInHavingFields FieldRefs - UsedInTargetExpressionsFields FieldRefs - Columns []WriteFileColumnDef `json:"columns"` - UrlTemplate string `json:"url_template"` - Top TopDef `json:"top"` - Csv CsvCreatorSettings `json:"csv,omitempty"` - Parquet ParquetCreatorSettings `json:"parquet,omitempty"` - CreatorFileType int -} - -const MaxFileCreatorTopLimit int = 500000 - -func (creatorDef *FileCreatorDef) getFieldRefs() *FieldRefs { - fieldRefs := make(FieldRefs, len(creatorDef.Columns)) - for i := 0; i < len(creatorDef.Columns); i++ { - fieldRefs[i] = FieldRef{ - TableName: CreatorAlias, - FieldName: creatorDef.Columns[i].Name, - FieldType: creatorDef.Columns[i].Type} - } - return &fieldRefs -} - -func (creatorDef *FileCreatorDef) GetFieldRefsUsedInAllTargetFileExpressions() FieldRefs { - fieldRefMap := map[string]FieldRef{} - for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { - targetColDef := &creatorDef.Columns[colIdx] - for i := 0; i < len((*targetColDef).UsedFields); i++ { - hash := fmt.Sprintf("%s.%s", (*targetColDef).UsedFields[i].TableName, (*targetColDef).UsedFields[i].FieldName) - if _, ok := fieldRefMap[hash]; !ok { - fieldRefMap[hash] = (*targetColDef).UsedFields[i] - } - } - } - - // Map to FieldRefs - fieldRefs := make([]FieldRef, len(fieldRefMap)) - i := 0 - for _, fieldRef := range fieldRefMap { - fieldRefs[i] = fieldRef - i++ - } - - return fieldRefs -} - -func (creatorDef *FileCreatorDef) HasTop() bool { - return len(strings.TrimSpace(creatorDef.Top.RawOrder)) > 0 -} - -func (creatorDef *FileCreatorDef) Deserialize(rawWriter json.RawMessage) error { - if err := json.Unmarshal(rawWriter, creatorDef); err != nil { - return fmt.Errorf("cannot unmarshal file creator: [%s]", err.Error()) - } - - if len(creatorDef.Columns) > 0 && creatorDef.Columns[0].Parquet.ColumnName != "" { - creatorDef.CreatorFileType = CreatorFileTypeParquet - if creatorDef.Parquet.Codec == "" { - creatorDef.Parquet.Codec = ParquetCodecGzip - } - } else if len(creatorDef.Columns) > 0 && creatorDef.Columns[0].Csv.Header != "" { - creatorDef.CreatorFileType = CreatorFileTypeCsv - if len(creatorDef.Csv.Separator) == 0 { - creatorDef.Csv.Separator = "," - } - } else { - return fmt.Errorf("cannot cannot detect file creator type: parquet should have column_name, csv should have header etc") - } - - // Having - var err error - creatorDef.Having, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(creatorDef.RawHaving, &creatorDef.UsedInHavingFields) - if err != nil { - return fmt.Errorf("cannot parse file creator 'having' condition [%s]: [%s]", creatorDef.RawHaving, err.Error()) - } - - // Columns - for i := 0; i < len(creatorDef.Columns); i++ { - colDef := 
&creatorDef.Columns[i] - if (*colDef).ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs((*colDef).RawExpression, &(*colDef).UsedFields); err != nil { - return fmt.Errorf("cannot parse column expression [%s]: [%s]", (*colDef).RawExpression, err.Error()) - } - if !IsValidFieldType(colDef.Type) { - return fmt.Errorf("invalid column type [%s]", colDef.Type) - } - } - - // Top - if creatorDef.HasTop() { - if creatorDef.Top.Limit <= 0 { - creatorDef.Top.Limit = MaxFileCreatorTopLimit - } else if creatorDef.Top.Limit > MaxFileCreatorTopLimit { - return fmt.Errorf("top.limit cannot exceed %d", MaxFileCreatorTopLimit) - } - idxDefMap := IdxDefMap{} - rawIndexes := map[string]string{"top": fmt.Sprintf("non_unique(%s)", creatorDef.Top.RawOrder)} - idxDefMap.parseRawIndexDefMap(rawIndexes, creatorDef.getFieldRefs()) - creatorDef.Top.OrderIdxDef = *idxDefMap["top"] - } - - creatorDef.UsedInTargetExpressionsFields = creatorDef.GetFieldRefsUsedInAllTargetFileExpressions() - return nil -} - -func (creatorDef *FileCreatorDef) CalculateFileRecordFromSrcVars(srcVars eval.VarValuesMap) ([]interface{}, error) { - errors := make([]string, 0, 2) - - fileRecord := make([]interface{}, len(creatorDef.Columns)) - - for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &srcVars) - valVolatile, err := eCtx.Eval(creatorDef.Columns[colIdx].ParsedExpression) - if err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate expression for column %s: [%s]", creatorDef.Columns[colIdx].Name, err.Error())) - } - if err := CheckValueType(valVolatile, creatorDef.Columns[colIdx].Type); err != nil { - errors = append(errors, fmt.Sprintf("invalid field %s type: [%s]", creatorDef.Columns[colIdx].Name, err.Error())) - } - fileRecord[colIdx] = valVolatile - } - - if len(errors) > 0 { - return nil, fmt.Errorf(strings.Join(errors, "; ")) - } else { - return fileRecord, nil - } -} - -func (creatorDef *FileCreatorDef) CheckFileRecordHavingCondition(fileRecord []interface{}) (bool, error) { - if creatorDef.Having == nil { - return true, nil - } - vars := eval.VarValuesMap{} - vars[CreatorAlias] = map[string]interface{}{} - if len(fileRecord) != len(creatorDef.Columns) { - return false, fmt.Errorf("file record length %d does not match file creator column list length %d", len(fileRecord), len(creatorDef.Columns)) - } - for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { - fieldName := creatorDef.Columns[colIdx].Name - fieldValue := fileRecord[colIdx] - vars[CreatorAlias][fieldName] = fieldValue - } - - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) - valVolatile, err := eCtx.Eval(creatorDef.Having) - if err != nil { - return false, fmt.Errorf("cannot evaluate 'having' expression: [%s]", err.Error()) - } - valBool, ok := valVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot get bool when evaluating having expression, got %v(%T) instead", valVolatile, valVolatile) - } - - return valBool, nil -} +package sc + +import ( + "encoding/json" + "fmt" + "go/ast" + "strings" + + "github.com/capillariesio/capillaries/pkg/eval" +) + +const ( + CreatorFileTypeUnknown int = 0 + CreatorFileTypeCsv int = 1 + CreatorFileTypeParquet int = 2 +) + +type ParquetCodecType string + +const ( + ParquetCodecGzip ParquetCodecType = "gzip" + ParquetCodecSnappy ParquetCodecType = "snappy" + ParquetCodecUncompressed ParquetCodecType = "uncompressed" +) + +type WriteCsvColumnSettings struct { + Format string `json:"format"` + Header 
string `json:"header"` +} + +type WriteParquetColumnSettings struct { + ColumnName string `json:"column_name"` +} + +type WriteFileColumnDef struct { + RawExpression string `json:"expression"` + Name string `json:"name"` // To be used in Having + Type TableFieldType `json:"type"` // To be checked when checking expressions and to be used in Having + Csv WriteCsvColumnSettings `json:"csv,omitempty"` + Parquet WriteParquetColumnSettings `json:"parquet,omitempty"` + ParsedExpression ast.Expr + UsedFields FieldRefs +} + +type TopDef struct { + Limit int `json:"limit"` + RawOrder string `json:"order"` + OrderIdxDef IdxDef // Not an index really, we just re-use IdxDef infrastructure +} + +type CsvCreatorSettings struct { + Separator string `json:"separator"` +} + +type ParquetCreatorSettings struct { + Codec ParquetCodecType `json:"codec"` +} + +type FileCreatorDef struct { + RawHaving string `json:"having"` + Having ast.Expr + UsedInHavingFields FieldRefs + UsedInTargetExpressionsFields FieldRefs + Columns []WriteFileColumnDef `json:"columns"` + UrlTemplate string `json:"url_template"` + Top TopDef `json:"top"` + Csv CsvCreatorSettings `json:"csv,omitempty"` + Parquet ParquetCreatorSettings `json:"parquet,omitempty"` + CreatorFileType int +} + +const MaxFileCreatorTopLimit int = 500000 + +func (creatorDef *FileCreatorDef) getFieldRefs() *FieldRefs { + fieldRefs := make(FieldRefs, len(creatorDef.Columns)) + for i := 0; i < len(creatorDef.Columns); i++ { + fieldRefs[i] = FieldRef{ + TableName: CreatorAlias, + FieldName: creatorDef.Columns[i].Name, + FieldType: creatorDef.Columns[i].Type} + } + return &fieldRefs +} + +func (creatorDef *FileCreatorDef) GetFieldRefsUsedInAllTargetFileExpressions() FieldRefs { + fieldRefMap := map[string]FieldRef{} + for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { + targetColDef := &creatorDef.Columns[colIdx] + for i := 0; i < len((*targetColDef).UsedFields); i++ { + hash := fmt.Sprintf("%s.%s", (*targetColDef).UsedFields[i].TableName, (*targetColDef).UsedFields[i].FieldName) + if _, ok := fieldRefMap[hash]; !ok { + fieldRefMap[hash] = (*targetColDef).UsedFields[i] + } + } + } + + // Map to FieldRefs + fieldRefs := make([]FieldRef, len(fieldRefMap)) + i := 0 + for _, fieldRef := range fieldRefMap { + fieldRefs[i] = fieldRef + i++ + } + + return fieldRefs +} + +func (creatorDef *FileCreatorDef) HasTop() bool { + return len(strings.TrimSpace(creatorDef.Top.RawOrder)) > 0 +} + +func (creatorDef *FileCreatorDef) Deserialize(rawWriter json.RawMessage) error { + if err := json.Unmarshal(rawWriter, creatorDef); err != nil { + return fmt.Errorf("cannot unmarshal file creator: [%s]", err.Error()) + } + + if len(creatorDef.Columns) > 0 && creatorDef.Columns[0].Parquet.ColumnName != "" { + creatorDef.CreatorFileType = CreatorFileTypeParquet + if creatorDef.Parquet.Codec == "" { + creatorDef.Parquet.Codec = ParquetCodecGzip + } + } else if len(creatorDef.Columns) > 0 && creatorDef.Columns[0].Csv.Header != "" { + creatorDef.CreatorFileType = CreatorFileTypeCsv + if len(creatorDef.Csv.Separator) == 0 { + creatorDef.Csv.Separator = "," + } + } else { + return fmt.Errorf("cannot cannot detect file creator type: parquet should have column_name, csv should have header etc") + } + + // Having + var err error + creatorDef.Having, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(creatorDef.RawHaving, &creatorDef.UsedInHavingFields) + if err != nil { + return fmt.Errorf("cannot parse file creator 'having' condition [%s]: [%s]", creatorDef.RawHaving, err.Error()) + } + 
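+ // At this point creatorDef.Having holds the parsed AST of the raw 'having' expression
+ // (the tests in this change use "len(w.field_string1) > 0"), and UsedInHavingFields
+ // lists the alias.field references harvested from it. CheckFileRecordHavingCondition
+ // below evaluates this AST against each produced file record, mapping record values
+ // into the creator alias namespace before calling eval.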
+ // Columns + for i := 0; i < len(creatorDef.Columns); i++ { + colDef := &creatorDef.Columns[i] + if (*colDef).ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs((*colDef).RawExpression, &(*colDef).UsedFields); err != nil { + return fmt.Errorf("cannot parse column expression [%s]: [%s]", (*colDef).RawExpression, err.Error()) + } + if !IsValidFieldType(colDef.Type) { + return fmt.Errorf("invalid column type [%s]", colDef.Type) + } + } + + // Top + if creatorDef.HasTop() { + if creatorDef.Top.Limit <= 0 { + creatorDef.Top.Limit = MaxFileCreatorTopLimit + } else if creatorDef.Top.Limit > MaxFileCreatorTopLimit { + return fmt.Errorf("top.limit cannot exceed %d", MaxFileCreatorTopLimit) + } + idxDefMap := IdxDefMap{} + rawIndexes := map[string]string{"top": fmt.Sprintf("non_unique(%s)", creatorDef.Top.RawOrder)} + if err := idxDefMap.parseRawIndexDefMap(rawIndexes, creatorDef.getFieldRefs()); err != nil { + return fmt.Errorf("cannot parse raw index definition(s) for top: %s", err.Error()) + } + creatorDef.Top.OrderIdxDef = *idxDefMap["top"] + } + + creatorDef.UsedInTargetExpressionsFields = creatorDef.GetFieldRefsUsedInAllTargetFileExpressions() + return nil +} + +func (creatorDef *FileCreatorDef) CalculateFileRecordFromSrcVars(srcVars eval.VarValuesMap) ([]any, error) { + errors := make([]string, 0, 2) + + fileRecord := make([]any, len(creatorDef.Columns)) + + for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &srcVars) + valVolatile, err := eCtx.Eval(creatorDef.Columns[colIdx].ParsedExpression) + if err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate expression for column %s: [%s]", creatorDef.Columns[colIdx].Name, err.Error())) + } + if err := CheckValueType(valVolatile, creatorDef.Columns[colIdx].Type); err != nil { + errors = append(errors, fmt.Sprintf("invalid field %s type: [%s]", creatorDef.Columns[colIdx].Name, err.Error())) + } + fileRecord[colIdx] = valVolatile + } + + if len(errors) > 0 { + return nil, fmt.Errorf(strings.Join(errors, "; ")) + } + return fileRecord, nil +} + +func (creatorDef *FileCreatorDef) CheckFileRecordHavingCondition(fileRecord []any) (bool, error) { + if len(fileRecord) != len(creatorDef.Columns) { + return false, fmt.Errorf("file record length %d does not match file creator column list length %d", len(fileRecord), len(creatorDef.Columns)) + } + if creatorDef.Having == nil { + return true, nil + } + vars := eval.VarValuesMap{} + vars[CreatorAlias] = map[string]any{} + for colIdx := 0; colIdx < len(creatorDef.Columns); colIdx++ { + fieldName := creatorDef.Columns[colIdx].Name + fieldValue := fileRecord[colIdx] + vars[CreatorAlias][fieldName] = fieldValue + } + + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) + valVolatile, err := eCtx.Eval(creatorDef.Having) + if err != nil { + return false, fmt.Errorf("cannot evaluate 'having' expression: [%s]", err.Error()) + } + valBool, ok := valVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot get bool when evaluating having expression, got %v(%T) instead", valVolatile, valVolatile) + } + + return valBool, nil +} diff --git a/pkg/sc/file_creator_def_test.go b/pkg/sc/file_creator_def_test.go index 997e6d2..74fab23 100644 --- a/pkg/sc/file_creator_def_test.go +++ b/pkg/sc/file_creator_def_test.go @@ -10,8 +10,9 @@ import ( const nodeCfgCsvJson string = ` { "top": { - "order": "taxed_field_int1(asc)" + "order": "field_string1(asc)" }, + "having": "len(w.field_string1) > 0", 
"url_template": "taxed_table1.csv", "columns": [ { @@ -29,9 +30,6 @@ const nodeCfgCsvJson string = ` const nodeCfgParquetJson string = ` { - "top": { - "order": "taxed_field_int1(asc)" - }, "url_template": "taxed_table1.csv", "columns": [ { @@ -62,4 +60,44 @@ func TestFileCreatorDefFailures(t *testing.T) { re := regexp.MustCompile(`"type": "[^"]+"`) assert.Contains(t, c.Deserialize([]byte(re.ReplaceAllString(nodeCfgCsvJson, `"type": "aaa"`))).Error(), "invalid column type [aaa]") + + re = regexp.MustCompile(`"order": "[^"]+"`) + assert.Contains(t, c.Deserialize([]byte(re.ReplaceAllString(nodeCfgCsvJson, `"order": "bad_field(asc)"`))).Error(), "cannot parse raw index definition(s) for top") + +} + +func TestCheckFileRecordHavingCondition(t *testing.T) { + c := FileCreatorDef{} + assert.Nil(t, c.Deserialize([]byte(nodeCfgCsvJson))) + + isPass, err := c.CheckFileRecordHavingCondition([]any{"aaa"}) + assert.Nil(t, err) + assert.True(t, isPass) + + isPass, err = c.CheckFileRecordHavingCondition([]any{""}) + assert.Nil(t, err) + assert.False(t, isPass) + + re := regexp.MustCompile(`"having": "[^"]+"`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(nodeCfgCsvJson, `"having": "w.bad_field"`)))) + _, err = c.CheckFileRecordHavingCondition([]any{"aaa"}) + assert.Contains(t, err.Error(), "cannot evaluate 'having' expression") + + re = regexp.MustCompile(`"having": "[^"]+"`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(nodeCfgCsvJson, `"having": "w.field_string1"`)))) + _, err = c.CheckFileRecordHavingCondition([]any{"aaa"}) + assert.Contains(t, err.Error(), "cannot get bool when evaluating having expression, got aaa(string) instead") + + // Remove having + c = FileCreatorDef{} + re = regexp.MustCompile(`"having": "[^"]+",`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(nodeCfgCsvJson, ``)))) + _, err = c.CheckFileRecordHavingCondition([]any{"aaa"}) + assert.Nil(t, err) + + // Missing field + c = FileCreatorDef{} + assert.Nil(t, c.Deserialize([]byte(nodeCfgCsvJson))) + _, err = c.CheckFileRecordHavingCondition([]any{}) + assert.Contains(t, err.Error(), "file record length 0 does not match file creator column list length 1") } diff --git a/pkg/sc/file_reader_def.go b/pkg/sc/file_reader_def.go index 57ac073..cf8b864 100644 --- a/pkg/sc/file_reader_def.go +++ b/pkg/sc/file_reader_def.go @@ -1,321 +1,320 @@ -package sc - -import ( - "encoding/json" - "fmt" - "strconv" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/shopspring/decimal" -) - -type CsvReaderColumnSettings struct { - SrcColIdx int `json:"col_idx"` - SrcColHeader string `json:"col_hdr"` - SrcColFormat string `json:"col_format"` // Optional for all except datetime -} - -type ParquetReaderColumnSettings struct { - SrcColName string `json:"col_name"` -} - -type FileReaderColumnDef struct { - DefaultValue string `json:"col_default_value"` // Optional. 
If omitted, zero value is used - Type TableFieldType `json:"col_type"` - Csv CsvReaderColumnSettings `json:"csv,omitempty"` - Parquet ParquetReaderColumnSettings `json:"parquet,omitempty"` -} - -type CsvReaderSettings struct { - SrcFileHdrLineIdx int `json:"hdr_line_idx"` - SrcFileFirstDataLineIdx int `json:"first_data_line_idx"` - Separator string `json:"separator"` - ColumnIndexingMode FileColumnIndexingMode -} - -const ( - ReaderFileTypeUnknown int = 0 - ReaderFileTypeCsv int = 1 - ReaderFileTypeParquet int = 2 -) - -type FileReaderDef struct { - SrcFileUrls []string `json:"urls"` - Columns map[string]*FileReaderColumnDef `json:"columns"` // Keys are names used in table writer - Csv CsvReaderSettings `json:"csv,omitempty"` - ReaderFileType int -} - -func (frDef *FileReaderDef) getFieldRefs() *FieldRefs { - fieldRefs := make(FieldRefs, len(frDef.Columns)) - i := 0 - for fieldName, colDef := range frDef.Columns { - fieldRefs[i] = FieldRef{ - TableName: ReaderAlias, - FieldName: fieldName, - FieldType: colDef.Type} - i += 1 - } - return &fieldRefs -} - -type FileColumnIndexingMode string - -const ( - FileColumnIndexingName FileColumnIndexingMode = "name" - FileColumnIndexingIdx FileColumnIndexingMode = "idx" - FileColumnIndexingUnknown FileColumnIndexingMode = "unknown" -) - -func (frDef *FileReaderDef) getCsvColumnIndexingMode() (FileColumnIndexingMode, error) { - usesIdxCount := 0 - usesHdrNameCount := 0 - for _, colDef := range frDef.Columns { - if len(colDef.Csv.SrcColHeader) > 0 { - usesHdrNameCount++ // We have a name, ignore col idx, it's probably zero (default) - } else if colDef.Csv.SrcColIdx >= 0 { - usesIdxCount++ - } else { - if colDef.Csv.SrcColIdx < 0 { - return "", fmt.Errorf("file reader column definition cannot use negative column index: %d", colDef.Csv.SrcColIdx) - } - } - } - - if usesIdxCount > 0 && usesHdrNameCount > 0 { - return "", fmt.Errorf("file reader column definitions cannot use both indexes and names, pick one method: col_hdr or col_idx") - } - - if usesIdxCount > 0 { - return FileColumnIndexingIdx, nil - } else if usesHdrNameCount > 0 { - return FileColumnIndexingName, nil - } - - // Never land here - return "", fmt.Errorf("file reader column indexing mode dev error") - -} - -func (frDef *FileReaderDef) Deserialize(rawReader json.RawMessage) error { - errors := make([]string, 0, 2) - - // Unmarshal - - if err := json.Unmarshal(rawReader, frDef); err != nil { - errors = append(errors, err.Error()) - } - - if len(frDef.SrcFileUrls) == 0 { - errors = append(errors, "no source file urls specified, need at least one") - } - - frDef.ReaderFileType = ReaderFileTypeUnknown - for _, colDef := range frDef.Columns { - if colDef.Parquet.SrcColName != "" { - frDef.ReaderFileType = ReaderFileTypeParquet - break - } else if (colDef.Csv.SrcColHeader != "" || colDef.Csv.SrcColIdx > 0) || - len(frDef.Columns) == 1 { // Special CSV case: no headers, only one column - frDef.ReaderFileType = ReaderFileTypeCsv - - // Detect column indexing mode: by idx or by name - var err error - frDef.Csv.ColumnIndexingMode, err = frDef.getCsvColumnIndexingMode() - if err != nil { - errors = append(errors, fmt.Sprintf("cannot detect csv column indexing mode: [%s]", err.Error())) - } - - // Default CSV field Separator - if len(frDef.Csv.Separator) == 0 { - frDef.Csv.Separator = "," - } - break - } - } - - if frDef.ReaderFileType == ReaderFileTypeUnknown { - errors = append(errors, "cannot detect file reader type: parquet should have col_name, csv should have col_hdr or col_idx etc") - } - - if 
len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (frDef *FileReaderDef) ResolveCsvColumnIndexesFromNames(srcHdrLine []string) error { - columnsResolved := 0 - for _, colDef := range frDef.Columns { - for i := 0; i < len(srcHdrLine); i++ { - if len(colDef.Csv.SrcColHeader) > 0 && srcHdrLine[i] == colDef.Csv.SrcColHeader { - colDef.Csv.SrcColIdx = i - columnsResolved++ - } - } - } - if columnsResolved < len(frDef.Columns) { - return fmt.Errorf("cannot resove all %d source file column indexes, resolved only %d", len(frDef.Columns), columnsResolved) - } - return nil -} - -func (frDef *FileReaderDef) ReadCsvLineToValuesMap(line *[]string, colVars eval.VarValuesMap) error { - colVars[ReaderAlias] = map[string]interface{}{} - for colName, colDef := range frDef.Columns { - colData := (*line)[colDef.Csv.SrcColIdx] - switch colDef.Type { - case FieldTypeString: - if len(colDef.Csv.SrcColFormat) > 0 { - return fmt.Errorf("cannot read string column %s, data '%s': format '%s' was specified, but string fields do not accept format specifier, remove this setting", colName, colData, colDef.Csv.SrcColFormat) - } - if len(colData) == 0 { - if len(colDef.DefaultValue) > 0 { - colVars[ReaderAlias][colName] = colDef.DefaultValue - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeString) - } - } else { - colVars[ReaderAlias][colName] = colData - } - - case FieldTypeBool: - if len(colDef.Csv.SrcColFormat) > 0 { - return fmt.Errorf("cannot read bool column %s, data '%s': format '%s' was specified, but bool fields do not accept format specifier, remove this setting", colName, colData, colDef.Csv.SrcColFormat) - } - - var err error - if len(strings.TrimSpace(colData)) == 0 { - if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { - colVars[ReaderAlias][colName], err = strconv.ParseBool(colDef.DefaultValue) - if err != nil { - return fmt.Errorf("cannot read bool column %s, from default value string '%s', allowed values are true,false,T,F,0,1: %s", colName, colDef.DefaultValue, err.Error()) - } - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeBool) - } - } else { - colVars[ReaderAlias][colName], err = strconv.ParseBool(colData) - if err != nil { - return fmt.Errorf("cannot read bool column %s, data '%s', allowed values are true,false,T,F,0,1: %s", colName, colData, err.Error()) - } - } - - case FieldTypeInt: - if len(strings.TrimSpace(colData)) == 0 { - if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { - valInt, err := strconv.ParseInt(colDef.DefaultValue, 10, 64) - if err != nil { - return fmt.Errorf("cannot read int64 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) - } - colVars[ReaderAlias][colName] = valInt - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeInt) - } - } else { - if len(colDef.Csv.SrcColFormat) > 0 { - var valInt int64 - _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valInt) - if err != nil { - return fmt.Errorf("cannot read int64 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) - } - colVars[ReaderAlias][colName] = valInt - } else { - valInt, err := strconv.ParseInt(colData, 10, 64) - if err != nil { - return fmt.Errorf("cannot read int64 column %s, data '%s', no format: %s", colName, colData, err.Error()) - } - colVars[ReaderAlias][colName] = valInt - } - } - - case FieldTypeDateTime: - if len(strings.TrimSpace(colData)) == 0 { - if 
len(strings.TrimSpace(colDef.DefaultValue)) > 0 { - valTime, err := time.Parse(colDef.Csv.SrcColFormat, colDef.DefaultValue) - if err != nil { - return fmt.Errorf("cannot read time column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) - } - colVars[ReaderAlias][colName] = valTime - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeDateTime) - } - } else { - if len(colDef.Csv.SrcColFormat) == 0 { - return fmt.Errorf("cannot read datetime column %s, data '%s': column format is missing, consider specifying something like 2006-01-02T15:04:05.000-0700, see go datetime format documentation for details", colName, colData) - } - - valTime, err := time.Parse(colDef.Csv.SrcColFormat, colData) - if err != nil { - return fmt.Errorf("cannot read datetime column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) - } - colVars[ReaderAlias][colName] = valTime - } - - case FieldTypeFloat: - if len(strings.TrimSpace(colData)) == 0 { - if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { - valFloat, err := strconv.ParseFloat(colDef.DefaultValue, 64) - if err != nil { - return fmt.Errorf("cannot read float64 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) - } - colVars[ReaderAlias][colName] = valFloat - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeFloat) - } - } else { - if len(colDef.Csv.SrcColFormat) > 0 { - var valFloat float64 - _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valFloat) - if err != nil { - return fmt.Errorf("cannot read float64 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) - } - colVars[ReaderAlias][colName] = valFloat - } else { - valFloat, err := strconv.ParseFloat(colData, 64) - if err != nil { - return fmt.Errorf("cannot read float64 column %s, data '%s', no format: %s", colName, colData, err.Error()) - } - colVars[ReaderAlias][colName] = valFloat - } - } - - case FieldTypeDecimal2: - // Round to 2 digits after decimal point right away - if len(strings.TrimSpace(colData)) == 0 { - if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { - valDec, err := decimal.NewFromString(colDef.DefaultValue) - if err != nil { - return fmt.Errorf("cannot read decimal2 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) - } - colVars[ReaderAlias][colName] = valDec.Round(2) - } else { - colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeDecimal2) - } - } else { - var valFloat float64 - if len(colDef.Csv.SrcColFormat) > 0 { - // Decimal type does not support sscanf, so sscanf string first - _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valFloat) - if err != nil { - return fmt.Errorf("cannot read decimal2 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) - } - colVars[ReaderAlias][colName] = decimal.NewFromFloat(valFloat).Round(2) - } else { - valDec, err := decimal.NewFromString(colData) - if err != nil { - return fmt.Errorf("cannot read decimal2 column %s, cannot parse data '%s': %s", colName, colData, err.Error()) - } - colVars[ReaderAlias][colName] = valDec.Round(2) - } - } - - default: - return fmt.Errorf("cannot read column %s, data '%s': unsupported column type '%s'", colName, colData, colDef.Type) - } - } - return nil -} +package sc + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + 
"github.com/shopspring/decimal" +) + +type CsvReaderColumnSettings struct { + SrcColIdx int `json:"col_idx"` + SrcColHeader string `json:"col_hdr"` + SrcColFormat string `json:"col_format"` // Optional for all except datetime +} + +type ParquetReaderColumnSettings struct { + SrcColName string `json:"col_name"` +} + +type FileReaderColumnDef struct { + DefaultValue string `json:"col_default_value"` // Optional. If omitted, zero value is used + Type TableFieldType `json:"col_type"` + Csv CsvReaderColumnSettings `json:"csv,omitempty"` + Parquet ParquetReaderColumnSettings `json:"parquet,omitempty"` +} + +type CsvReaderSettings struct { + SrcFileHdrLineIdx int `json:"hdr_line_idx"` + SrcFileFirstDataLineIdx int `json:"first_data_line_idx"` + Separator string `json:"separator"` + ColumnIndexingMode FileColumnIndexingMode +} + +const ( + ReaderFileTypeUnknown int = 0 + ReaderFileTypeCsv int = 1 + ReaderFileTypeParquet int = 2 +) + +type FileReaderDef struct { + SrcFileUrls []string `json:"urls"` + Columns map[string]*FileReaderColumnDef `json:"columns"` // Keys are names used in table writer + Csv CsvReaderSettings `json:"csv,omitempty"` + ReaderFileType int +} + +func (frDef *FileReaderDef) getFieldRefs() *FieldRefs { + fieldRefs := make(FieldRefs, len(frDef.Columns)) + i := 0 + for fieldName, colDef := range frDef.Columns { + fieldRefs[i] = FieldRef{ + TableName: ReaderAlias, + FieldName: fieldName, + FieldType: colDef.Type} + i++ + } + return &fieldRefs +} + +type FileColumnIndexingMode string + +const ( + FileColumnIndexingName FileColumnIndexingMode = "name" + FileColumnIndexingIdx FileColumnIndexingMode = "idx" + FileColumnIndexingUnknown FileColumnIndexingMode = "unknown" +) + +func (frDef *FileReaderDef) getCsvColumnIndexingMode() (FileColumnIndexingMode, error) { + usesIdxCount := 0 + usesHdrNameCount := 0 + for _, colDef := range frDef.Columns { + if len(colDef.Csv.SrcColHeader) > 0 { + usesHdrNameCount++ // We have a name, ignore col idx, it's probably zero (default) + } else if colDef.Csv.SrcColIdx >= 0 { + usesIdxCount++ + } else { + if colDef.Csv.SrcColIdx < 0 { + return "", fmt.Errorf("file reader column definition cannot use negative column index: %d", colDef.Csv.SrcColIdx) + } + } + } + + if usesIdxCount > 0 && usesHdrNameCount > 0 { + return "", fmt.Errorf("file reader column definitions cannot use both indexes and names, pick one method: col_hdr or col_idx") + } + + if usesIdxCount > 0 { + return FileColumnIndexingIdx, nil + } else if usesHdrNameCount > 0 { + return FileColumnIndexingName, nil + } + + // Never land here + return "", fmt.Errorf("file reader column indexing mode dev error") + +} + +func (frDef *FileReaderDef) Deserialize(rawReader json.RawMessage) error { + errors := make([]string, 0, 2) + + // Unmarshal + + if err := json.Unmarshal(rawReader, frDef); err != nil { + errors = append(errors, err.Error()) + } + + if len(frDef.SrcFileUrls) == 0 { + errors = append(errors, "no source file urls specified, need at least one") + } + + frDef.ReaderFileType = ReaderFileTypeUnknown + for _, colDef := range frDef.Columns { + if colDef.Parquet.SrcColName != "" { + frDef.ReaderFileType = ReaderFileTypeParquet + break + } else if (colDef.Csv.SrcColHeader != "" || colDef.Csv.SrcColIdx > 0) || + len(frDef.Columns) == 1 { // Special CSV case: no headers, only one column + frDef.ReaderFileType = ReaderFileTypeCsv + + // Detect column indexing mode: by idx or by name + var err error + frDef.Csv.ColumnIndexingMode, err = frDef.getCsvColumnIndexingMode() + if err != nil { + errors = 
append(errors, fmt.Sprintf("cannot detect csv column indexing mode: [%s]", err.Error())) + } + + // Default CSV field Separator + if len(frDef.Csv.Separator) == 0 { + frDef.Csv.Separator = "," + } + break + } + } + + if frDef.ReaderFileType == ReaderFileTypeUnknown { + errors = append(errors, "cannot detect file reader type: parquet should have col_name, csv should have col_hdr or col_idx etc") + } + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + return nil +} + +func (frDef *FileReaderDef) ResolveCsvColumnIndexesFromNames(srcHdrLine []string) error { + columnsResolved := 0 + for _, colDef := range frDef.Columns { + for i := 0; i < len(srcHdrLine); i++ { + if len(colDef.Csv.SrcColHeader) > 0 && srcHdrLine[i] == colDef.Csv.SrcColHeader { + colDef.Csv.SrcColIdx = i + columnsResolved++ + } + } + } + if columnsResolved < len(frDef.Columns) { + return fmt.Errorf("cannot resove all %d source file column indexes, resolved only %d", len(frDef.Columns), columnsResolved) + } + return nil +} + +func (frDef *FileReaderDef) ReadCsvLineToValuesMap(line *[]string, colVars eval.VarValuesMap) error { + colVars[ReaderAlias] = map[string]any{} + for colName, colDef := range frDef.Columns { + colData := (*line)[colDef.Csv.SrcColIdx] + switch colDef.Type { + case FieldTypeString: + if len(colDef.Csv.SrcColFormat) > 0 { + return fmt.Errorf("cannot read string column %s, data '%s': format '%s' was specified, but string fields do not accept format specifier, remove this setting", colName, colData, colDef.Csv.SrcColFormat) + } + if len(colData) == 0 { + if len(colDef.DefaultValue) > 0 { + colVars[ReaderAlias][colName] = colDef.DefaultValue + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeString) + } + } else { + colVars[ReaderAlias][colName] = colData + } + + case FieldTypeBool: + if len(colDef.Csv.SrcColFormat) > 0 { + return fmt.Errorf("cannot read bool column %s, data '%s': format '%s' was specified, but bool fields do not accept format specifier, remove this setting", colName, colData, colDef.Csv.SrcColFormat) + } + + var err error + if len(strings.TrimSpace(colData)) == 0 { + if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { + colVars[ReaderAlias][colName], err = strconv.ParseBool(colDef.DefaultValue) + if err != nil { + return fmt.Errorf("cannot read bool column %s, from default value string '%s', allowed values are true,false,T,F,0,1: %s", colName, colDef.DefaultValue, err.Error()) + } + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeBool) + } + } else { + colVars[ReaderAlias][colName], err = strconv.ParseBool(colData) + if err != nil { + return fmt.Errorf("cannot read bool column %s, data '%s', allowed values are true,false,T,F,0,1: %s", colName, colData, err.Error()) + } + } + + case FieldTypeInt: + if len(strings.TrimSpace(colData)) == 0 { + if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { + valInt, err := strconv.ParseInt(colDef.DefaultValue, 10, 64) + if err != nil { + return fmt.Errorf("cannot read int64 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) + } + colVars[ReaderAlias][colName] = valInt + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeInt) + } + } else { + if len(colDef.Csv.SrcColFormat) > 0 { + var valInt int64 + _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valInt) + if err != nil { + return fmt.Errorf("cannot read int64 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) 
+ } + colVars[ReaderAlias][colName] = valInt + } else { + valInt, err := strconv.ParseInt(colData, 10, 64) + if err != nil { + return fmt.Errorf("cannot read int64 column %s, data '%s', no format: %s", colName, colData, err.Error()) + } + colVars[ReaderAlias][colName] = valInt + } + } + + case FieldTypeDateTime: + if len(strings.TrimSpace(colData)) == 0 { + if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { + valTime, err := time.Parse(colDef.Csv.SrcColFormat, colDef.DefaultValue) + if err != nil { + return fmt.Errorf("cannot read time column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) + } + colVars[ReaderAlias][colName] = valTime + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeDateTime) + } + } else { + if len(colDef.Csv.SrcColFormat) == 0 { + return fmt.Errorf("cannot read datetime column %s, data '%s': column format is missing, consider specifying something like 2006-01-02T15:04:05.000-0700, see go datetime format documentation for details", colName, colData) + } + + valTime, err := time.Parse(colDef.Csv.SrcColFormat, colData) + if err != nil { + return fmt.Errorf("cannot read datetime column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) + } + colVars[ReaderAlias][colName] = valTime + } + + case FieldTypeFloat: + if len(strings.TrimSpace(colData)) == 0 { + if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { + valFloat, err := strconv.ParseFloat(colDef.DefaultValue, 64) + if err != nil { + return fmt.Errorf("cannot read float64 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) + } + colVars[ReaderAlias][colName] = valFloat + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeFloat) + } + } else { + if len(colDef.Csv.SrcColFormat) > 0 { + var valFloat float64 + _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valFloat) + if err != nil { + return fmt.Errorf("cannot read float64 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) + } + colVars[ReaderAlias][colName] = valFloat + } else { + valFloat, err := strconv.ParseFloat(colData, 64) + if err != nil { + return fmt.Errorf("cannot read float64 column %s, data '%s', no format: %s", colName, colData, err.Error()) + } + colVars[ReaderAlias][colName] = valFloat + } + } + + case FieldTypeDecimal2: + // Round to 2 digits after decimal point right away + if len(strings.TrimSpace(colData)) == 0 { + if len(strings.TrimSpace(colDef.DefaultValue)) > 0 { + valDec, err := decimal.NewFromString(colDef.DefaultValue) + if err != nil { + return fmt.Errorf("cannot read decimal2 column %s from default value string '%s': %s", colName, colDef.DefaultValue, err.Error()) + } + colVars[ReaderAlias][colName] = valDec.Round(2) + } else { + colVars[ReaderAlias][colName] = GetDefaultFieldTypeValue(FieldTypeDecimal2) + } + } else { + var valFloat float64 + if len(colDef.Csv.SrcColFormat) > 0 { + // Decimal type does not support sscanf, so sscanf string first + _, err := fmt.Sscanf(colData, colDef.Csv.SrcColFormat, &valFloat) + if err != nil { + return fmt.Errorf("cannot read decimal2 column %s, data '%s', format '%s': %s", colName, colData, colDef.Csv.SrcColFormat, err.Error()) + } + colVars[ReaderAlias][colName] = decimal.NewFromFloat(valFloat).Round(2) + } else { + valDec, err := decimal.NewFromString(colData) + if err != nil { + return fmt.Errorf("cannot read decimal2 column %s, cannot parse data '%s': %s", colName, colData, 
err.Error()) + } + colVars[ReaderAlias][colName] = valDec.Round(2) + } + } + + default: + return fmt.Errorf("cannot read column %s, data '%s': unsupported column type '%s'", colName, colData, colDef.Type) + } + } + return nil +} diff --git a/pkg/sc/file_reader_def_test.go b/pkg/sc/file_reader_def_test.go index 14fffeb..ed199fe 100644 --- a/pkg/sc/file_reader_def_test.go +++ b/pkg/sc/file_reader_def_test.go @@ -1,629 +1,629 @@ -package sc - -import ( - "fmt" - "strings" - "testing" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -func assertErrorPrefix(t *testing.T, expectedErrorPrefix string, actualError string) { - if !strings.HasPrefix(actualError, expectedErrorPrefix) { - t.Errorf("\nExpected error prefix:\n%s\nGot error:\n%s", expectedErrorPrefix, actualError) - } -} - -func testReader(fileReaderJson string, srcLine []string) (eval.VarValuesMap, error) { - fileReader := FileReaderDef{} - if err := fileReader.Deserialize([]byte(fileReaderJson)); err != nil { - return nil, err - } - - if fileReader.Csv.ColumnIndexingMode == FileColumnIndexingName { - srcHdrLine := []string{"order_id", "customer_id", "order_status", "order_purchase_timestamp"} - if err := fileReader.ResolveCsvColumnIndexesFromNames(srcHdrLine); err != nil { - return nil, err - } - } - - colRecord := eval.VarValuesMap{} - if err := fileReader.ReadCsvLineToValuesMap(&srcLine, colRecord); err != nil { - return nil, err - } - - return colRecord, nil -} - -func TestFieldRefs(t *testing.T) { - conf := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_idx": 0, - "col_hdr": null - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_idx": 2, - "col_hdr": null - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_idx": 3, - "col_hdr": null, - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }` - reader := FileReaderDef{} - assert.Nil(t, reader.Deserialize([]byte(conf))) - - fieldRefs := reader.getFieldRefs() - var fr *FieldRef - fr, _ = fieldRefs.FindByFieldName("col_order_id") - assert.Equal(t, ReaderAlias, fr.TableName) - assert.Equal(t, FieldTypeString, fr.FieldType) - fr, _ = fieldRefs.FindByFieldName("col_order_status") - assert.Equal(t, ReaderAlias, fr.TableName) - assert.Equal(t, FieldTypeString, fr.FieldType) - fr, _ = fieldRefs.FindByFieldName("col_order_purchase_timestamp") - assert.Equal(t, ReaderAlias, fr.TableName) - assert.Equal(t, FieldTypeDateTime, fr.FieldType) -} - -func TestColumnIndexing(t *testing.T) { - srcLine := []string{"order_id_1", "customer_id_1", "delivered", "2017-10-02 10:56:33"} - - // Good by idx - colRecord, err := testReader(` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_idx": 0, - "col_hdr": null - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_idx": 2, - "col_hdr": null - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_idx": 3, - "col_hdr": null, - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }`, srcLine) - assert.Nil(t, err) - - assert.Equal(t, srcLine[0], colRecord[ReaderAlias]["col_order_id"]) - assert.Equal(t, srcLine[2], colRecord[ReaderAlias]["col_order_status"]) - assert.Equal(t, time.Date(2017, 10, 2, 10, 56, 33, 0, time.UTC), 
colRecord[ReaderAlias]["col_order_purchase_timestamp"]) - - // Good by name - colRecord, err = testReader(` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_hdr": "order_id" - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_hdr": "order_status" - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_hdr": "order_purchase_timestamp", - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }`, srcLine) - assert.Nil(t, err) - - // Bad col idx - _, err = testReader(` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_idx": -1 - }, - "col_type": "string" - } - } - }`, srcLine) - - assertErrorPrefix(t, "cannot detect csv column indexing mode: [file reader column definition cannot use negative column index: -1]", err.Error()) - - // Bad number of source files - _, err = testReader(` - { - "urls": [], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_hdr": "order_id" - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_hdr": "order_status" - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_hdr": "order_purchase_timestamp", - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }`, srcLine) - assertErrorPrefix(t, "no source file urls specified", err.Error()) - - // Bad mixed indexing mode (second column says by idx, first and third say by name) - _, err = testReader(` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_idx": 1, - "col_hdr": "order_id" - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_idx": 2 - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_hdr": "order_purchase_timestamp", - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }`, srcLine) - assertErrorPrefix(t, "cannot detect csv column indexing mode", err.Error()) - - // Bad: cannot find file header some_unknown_col - _, err = testReader(` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv":{ - "col_idx": 1, - "col_hdr": "order_id" - }, - "col_type": "string" - }, - "col_order_status": { - "csv":{ - "col_idx": 2, - "col_hdr": "some_unknown_col" - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv":{ - "col_hdr": "order_purchase_timestamp", - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }`, srcLine) - assertErrorPrefix(t, "cannot resove all 3 source file column indexes, resolved only 2", err.Error()) -} - -func TestReadString(t *testing.T) { - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "string" - } - } - }` - - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatWithDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"default_str",`) - confWithFormat := fmt.Sprintf(confTemplate, `"col_format": "some_format",`, ``) - - srcLineWithData := []string{"", "data_str", ""} - srcLineEmpty := []string{"", "", ""} - - goodTestScenarios := [][]interface{}{ - {confNoFormatNoDefault, srcLineWithData, 
"data_str"}, - {confNoFormatNoDefault, srcLineEmpty, ""}, - {confNoFormatWithDefault, srcLineEmpty, "default_str"}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confWithFormat, srcLineWithData) - assertErrorPrefix(t, "cannot read string column col_1, data 'data_str': format 'some_format' was specified, but string fields do not accept format specifier, remove this setting", err.Error()) -} - -func TestReadDatetime(t *testing.T) { - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "datetime" - } - } - }` - - confGoodFormatGoodDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, `"col_default_value":"2001-07-07T11:22:33.700",`) - confGoodFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, ``) - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatGoodDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"2001-07-07T11:22:33.700",`) - confGoodFormatBadDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, `"col_default_value":"2001-07-07aaa11:22:33.700",`) - confBadFormatGoodDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02ccc15:04:05.000",`, `"col_default_value":"2001-07-07T11:22:33.700",`) - confBadFormatBadDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02ccc15:04:05.000",`, `"col_default_value":"2001-07-07aaa11:22:33.700",`) - - srcLineGood := []string{"", "2017-10-02T10:56:33.155"} - srcLineBad := []string{"", "2017-10-02bbb10:56:33.155"} - srcLineEmpty := []string{"", ""} - - goodVal := time.Date(2017, time.October, 2, 10, 56, 33, 155000000, time.UTC) - defaultVal := time.Date(2001, time.July, 7, 11, 22, 33, 700000000, time.UTC) - nullVal := GetDefaultFieldTypeValue(FieldTypeDateTime) - - goodTestScenarios := [][]interface{}{ - {confGoodFormatGoodDefault, srcLineGood, goodVal}, - {confGoodFormatGoodDefault, srcLineEmpty, defaultVal}, - {confGoodFormatNoDefault, srcLineEmpty, nullVal}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confNoFormatNoDefault, srcLineGood) - assertErrorPrefix(t, "cannot read datetime column col_1, data '2017-10-02T10:56:33.155': column format is missing, consider specifying something like 2006-01-02T15:04:05.000-0700, see go datetime format documentation for details", err.Error()) - _, err = testReader(confBadFormatGoodDefault, srcLineGood) - assertErrorPrefix(t, `cannot read datetime column col_1, data '2017-10-02T10:56:33.155', format '2006-01-02ccc15:04:05.000': parsing time "2017-10-02T10:56:33.155" as "2006-01-02ccc15:04:05.000": cannot parse "T10:56:33.155" as "ccc"`, err.Error()) - _, err = testReader(confBadFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07aaa11:22:33.700': parsing time "2001-07-07aaa11:22:33.700" as "2006-01-02ccc15:04:05.000": cannot parse 
"aaa11:22:33.700" as "ccc"`, err.Error()) - _, err = testReader(confNoFormatGoodDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07T11:22:33.700': parsing time "2001-07-07T11:22:33.700": extra text: "2001-07-07T11:22:33.700"`, err.Error()) - _, err = testReader(confGoodFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07aaa11:22:33.700': parsing time "2001-07-07aaa11:22:33.700" as "2006-01-02T15:04:05.000": cannot parse "aaa11:22:33.700" as "T"`, err.Error()) - _, err = testReader(confGoodFormatGoodDefault, srcLineBad) - assertErrorPrefix(t, `cannot read datetime column col_1, data '2017-10-02bbb10:56:33.155', format '2006-01-02T15:04:05.000': parsing time "2017-10-02bbb10:56:33.155" as "2006-01-02T15:04:05.000": cannot parse "bbb10:56:33.155" as "T"`, err.Error()) -} - -func TestReadInt(t *testing.T) { - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "int" - } - } - }` - - confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%d)",`, `"col_default_value":"123",`) - confSimpleFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "%d",`, ``) - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) - - srcLineComplexFormat := []string{"", "value(111)", ""} - srcLineSimpleFormat := []string{"", "111", ""} - srcLineEmpty := []string{"", "", ""} - - goodTestScenarios := [][]interface{}{ - {confComplexFormatWithDefault, srcLineComplexFormat, int64(111)}, - {confSimpleFormatNoDefault, srcLineSimpleFormat, int64(111)}, - {confNoFormatNoDefault, srcLineSimpleFormat, int64(111)}, - {confComplexFormatWithDefault, srcLineEmpty, int64(123)}, - {confSimpleFormatNoDefault, srcLineEmpty, int64(0)}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) - assertErrorPrefix(t, "cannot read int64 column col_1, data 'value(111)', format '%d': expected integer", err.Error()) - _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) - assertErrorPrefix(t, "cannot read int64 column col_1, data '111', format 'value(%d)': input does not match format", err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read int64 column col_1 from default value string 'badstring': strconv.ParseInt: parsing "badstring": invalid syntax`, err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) - assertErrorPrefix(t, `cannot read int64 column col_1, data 'value(111)', no format: strconv.ParseInt: parsing "value(111)": invalid syntax`, err.Error()) -} - -func TestReadFloat(t *testing.T) { - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "float" - } - } - }` - - confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%f)",`, `"col_default_value":"5.697",`) - confSimpleFormatNoDefault := 
fmt.Sprintf(confTemplate, `"col_format": "%f",`, ``) - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) - - srcLineComplexFormat := []string{"", "value(111.222)", ""} - srcLineSimpleFormat := []string{"", "111.222", ""} - srcLineEmpty := []string{"", "", ""} - - goodTestScenarios := [][]interface{}{ - {confComplexFormatWithDefault, srcLineComplexFormat, float64(111.222)}, - {confSimpleFormatNoDefault, srcLineSimpleFormat, float64(111.222)}, - {confNoFormatNoDefault, srcLineSimpleFormat, float64(111.222)}, - {confComplexFormatWithDefault, srcLineEmpty, float64(5.697)}, - {confSimpleFormatNoDefault, srcLineEmpty, float64(0.0)}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) - assertErrorPrefix(t, `cannot read float64 column col_1, data 'value(111.222)', format '%f': strconv.ParseFloat: parsing "": invalid syntax`, err.Error()) - _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) - assertErrorPrefix(t, "cannot read float64 column col_1, data '111.222', format 'value(%f)': input does not match format", err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read float64 column col_1 from default value string 'badstring': strconv.ParseFloat: parsing "badstring": invalid syntax`, err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) - assertErrorPrefix(t, `cannot read float64 column col_1, data 'value(111.222)', no format: strconv.ParseFloat: parsing "value(111.222)": invalid syntax`, err.Error()) -} - -func TestReadDecimal(t *testing.T) { - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "decimal2" - } - } - }` - - confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%f)",`, `"col_default_value":"-56.78",`) - confSimpleFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "%f",`, ``) - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) - - srcLineComplexFormat := []string{"", "value(12.34)", ""} - srcLineSimpleFormat := []string{"", "12.34", ""} - srcLineEmpty := []string{"", "", ""} - - goodTestScenarios := [][]interface{}{ - {confComplexFormatWithDefault, srcLineComplexFormat, decimal.NewFromFloat32(12.34)}, - {confSimpleFormatNoDefault, srcLineSimpleFormat, decimal.NewFromFloat32(12.34)}, - {confNoFormatNoDefault, srcLineSimpleFormat, decimal.NewFromFloat32(12.34)}, - {confComplexFormatWithDefault, srcLineEmpty, decimal.NewFromFloat32(-56.78)}, - {confSimpleFormatNoDefault, srcLineEmpty, decimal.NewFromFloat32(0.0)}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) - assertErrorPrefix(t, `cannot 
read decimal2 column col_1, data 'value(12.34)', format '%f': strconv.ParseFloat: parsing "": invalid syntax`, err.Error()) - _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) - assertErrorPrefix(t, "cannot read decimal2 column col_1, data '12.34', format 'value(%f)': input does not match format", err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read decimal2 column col_1 from default value string 'badstring': can't convert badstring to decimal`, err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) - assertErrorPrefix(t, `cannot read decimal2 column col_1, cannot parse data 'value(12.34)': can't convert value(12.34) to decimal: exponent is not numeric`, err.Error()) -} -func TestReadBool(t *testing.T) { - - confTemplate := ` - { - "urls": [""], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_1": { - "csv":{ - %s - "col_idx": 1 - }, - %s - "col_type": "bool" - } - } - }` - - confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) - confNoFormatWithDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"TRUE",`) - confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"baddefault",`) - confWithFormat := fmt.Sprintf(confTemplate, `"col_format": "some_format",`, ``) - - srcLineTrue := []string{"", "True", ""} - srcLineFalse := []string{"", "False", ""} - srcLineTrueCap := []string{"", "TRUE", ""} - srcLineFalseCap := []string{"", "FALSE", ""} - srcLineTrueSmall := []string{"", "true", ""} - srcLineFalseSmall := []string{"", "false", ""} - srcLineT := []string{"", "T", ""} - srcLineF := []string{"", "F", ""} - srcLineTSmall := []string{"", "t", ""} - srcLineFSmall := []string{"", "f", ""} - srcLine0 := []string{"", "0", ""} - srcLine1 := []string{"", "1", ""} - srcLineEmpty := []string{"", "", ""} - srcLineBad := []string{"", "bad", ""} - - goodTestScenarios := [][]interface{}{ - {confNoFormatNoDefault, srcLineTrue, true}, - {confNoFormatNoDefault, srcLineFalse, false}, - {confNoFormatNoDefault, srcLineTrueCap, true}, - {confNoFormatNoDefault, srcLineFalseCap, false}, - {confNoFormatNoDefault, srcLineTrueSmall, true}, - {confNoFormatNoDefault, srcLineFalseSmall, false}, - {confNoFormatNoDefault, srcLineT, true}, - {confNoFormatNoDefault, srcLineF, false}, - {confNoFormatNoDefault, srcLineTSmall, true}, - {confNoFormatNoDefault, srcLineFSmall, false}, - {confNoFormatNoDefault, srcLine1, true}, - {confNoFormatNoDefault, srcLine0, false}, - {confNoFormatWithDefault, srcLineEmpty, true}, - {confNoFormatNoDefault, srcLineEmpty, false}, - } - - for i := 0; i < len(goodTestScenarios); i++ { - scenario := goodTestScenarios[i] - colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) - assert.Nil(t, err) - assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) - } - - var err error - _, err = testReader(confNoFormatNoDefault, srcLineBad) - assertErrorPrefix(t, `cannot read bool column col_1, data 'bad', allowed values are true,false,T,F,0,1: strconv.ParseBool: parsing "bad": invalid syntax`, err.Error()) - _, err = testReader(confNoFormatBadDefault, srcLineEmpty) - assertErrorPrefix(t, `cannot read bool column col_1, from default value string 'baddefault', allowed values are true,false,T,F,0,1: strconv.ParseBool: parsing "baddefault": invalid syntax`, err.Error()) - _, err = testReader(confWithFormat, srcLineTrue) - assertErrorPrefix(t, `cannot read bool column col_1, data 'True': 
format 'some_format' was specified, but bool fields do not accept format specifier, remove this setting`, err.Error()) -} +package sc + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +func assertErrorPrefix(t *testing.T, expectedErrorPrefix string, actualError string) { + if !strings.HasPrefix(actualError, expectedErrorPrefix) { + t.Errorf("\nExpected error prefix:\n%s\nGot error:\n%s", expectedErrorPrefix, actualError) + } +} + +func testReader(fileReaderJson string, srcLine []string) (eval.VarValuesMap, error) { + fileReader := FileReaderDef{} + if err := fileReader.Deserialize([]byte(fileReaderJson)); err != nil { + return nil, err + } + + if fileReader.Csv.ColumnIndexingMode == FileColumnIndexingName { + srcHdrLine := []string{"order_id", "customer_id", "order_status", "order_purchase_timestamp"} + if err := fileReader.ResolveCsvColumnIndexesFromNames(srcHdrLine); err != nil { + return nil, err + } + } + + colRecord := eval.VarValuesMap{} + if err := fileReader.ReadCsvLineToValuesMap(&srcLine, colRecord); err != nil { + return nil, err + } + + return colRecord, nil +} + +func TestFieldRefs(t *testing.T) { + conf := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_idx": 0, + "col_hdr": null + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_idx": 2, + "col_hdr": null + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_idx": 3, + "col_hdr": null, + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }` + reader := FileReaderDef{} + assert.Nil(t, reader.Deserialize([]byte(conf))) + + fieldRefs := reader.getFieldRefs() + var fr *FieldRef + fr, _ = fieldRefs.FindByFieldName("col_order_id") + assert.Equal(t, ReaderAlias, fr.TableName) + assert.Equal(t, FieldTypeString, fr.FieldType) + fr, _ = fieldRefs.FindByFieldName("col_order_status") + assert.Equal(t, ReaderAlias, fr.TableName) + assert.Equal(t, FieldTypeString, fr.FieldType) + fr, _ = fieldRefs.FindByFieldName("col_order_purchase_timestamp") + assert.Equal(t, ReaderAlias, fr.TableName) + assert.Equal(t, FieldTypeDateTime, fr.FieldType) +} + +func TestColumnIndexing(t *testing.T) { + srcLine := []string{"order_id_1", "customer_id_1", "delivered", "2017-10-02 10:56:33"} + + // Good by idx + colRecord, err := testReader(` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_idx": 0, + "col_hdr": null + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_idx": 2, + "col_hdr": null + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_idx": 3, + "col_hdr": null, + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }`, srcLine) + assert.Nil(t, err) + + assert.Equal(t, srcLine[0], colRecord[ReaderAlias]["col_order_id"]) + assert.Equal(t, srcLine[2], colRecord[ReaderAlias]["col_order_status"]) + assert.Equal(t, time.Date(2017, 10, 2, 10, 56, 33, 0, time.UTC), colRecord[ReaderAlias]["col_order_purchase_timestamp"]) + + // Good by name + _, err = testReader(` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_hdr": "order_id" + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_hdr": 
"order_status" + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_hdr": "order_purchase_timestamp", + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }`, srcLine) + assert.Nil(t, err) + + // Bad col idx + _, err = testReader(` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_idx": -1 + }, + "col_type": "string" + } + } + }`, srcLine) + + assertErrorPrefix(t, "cannot detect csv column indexing mode: [file reader column definition cannot use negative column index: -1]", err.Error()) + + // Bad number of source files + _, err = testReader(` + { + "urls": [], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_hdr": "order_id" + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_hdr": "order_status" + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_hdr": "order_purchase_timestamp", + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }`, srcLine) + assertErrorPrefix(t, "no source file urls specified", err.Error()) + + // Bad mixed indexing mode (second column says by idx, first and third say by name) + _, err = testReader(` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_idx": 1, + "col_hdr": "order_id" + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_idx": 2 + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_hdr": "order_purchase_timestamp", + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }`, srcLine) + assertErrorPrefix(t, "cannot detect csv column indexing mode", err.Error()) + + // Bad: cannot find file header some_unknown_col + _, err = testReader(` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv":{ + "col_idx": 1, + "col_hdr": "order_id" + }, + "col_type": "string" + }, + "col_order_status": { + "csv":{ + "col_idx": 2, + "col_hdr": "some_unknown_col" + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv":{ + "col_hdr": "order_purchase_timestamp", + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }`, srcLine) + assertErrorPrefix(t, "cannot resove all 3 source file column indexes, resolved only 2", err.Error()) +} + +func TestReadString(t *testing.T) { + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "string" + } + } + }` + + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatWithDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"default_str",`) + confWithFormat := fmt.Sprintf(confTemplate, `"col_format": "some_format",`, ``) + + srcLineWithData := []string{"", "data_str", ""} + srcLineEmpty := []string{"", "", ""} + + goodTestScenarios := [][]any{ + {confNoFormatNoDefault, srcLineWithData, "data_str"}, + {confNoFormatNoDefault, srcLineEmpty, ""}, + {confNoFormatWithDefault, srcLineEmpty, "default_str"}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], 
colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confWithFormat, srcLineWithData) + assertErrorPrefix(t, "cannot read string column col_1, data 'data_str': format 'some_format' was specified, but string fields do not accept format specifier, remove this setting", err.Error()) +} + +func TestReadDatetime(t *testing.T) { + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "datetime" + } + } + }` + + confGoodFormatGoodDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, `"col_default_value":"2001-07-07T11:22:33.700",`) + confGoodFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, ``) + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatGoodDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"2001-07-07T11:22:33.700",`) + confGoodFormatBadDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02T15:04:05.000",`, `"col_default_value":"2001-07-07aaa11:22:33.700",`) + confBadFormatGoodDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02ccc15:04:05.000",`, `"col_default_value":"2001-07-07T11:22:33.700",`) + confBadFormatBadDefault := fmt.Sprintf(confTemplate, `"col_format": "2006-01-02ccc15:04:05.000",`, `"col_default_value":"2001-07-07aaa11:22:33.700",`) + + srcLineGood := []string{"", "2017-10-02T10:56:33.155"} + srcLineBad := []string{"", "2017-10-02bbb10:56:33.155"} + srcLineEmpty := []string{"", ""} + + goodVal := time.Date(2017, time.October, 2, 10, 56, 33, 155000000, time.UTC) + defaultVal := time.Date(2001, time.July, 7, 11, 22, 33, 700000000, time.UTC) + nullVal := GetDefaultFieldTypeValue(FieldTypeDateTime) + + goodTestScenarios := [][]any{ + {confGoodFormatGoodDefault, srcLineGood, goodVal}, + {confGoodFormatGoodDefault, srcLineEmpty, defaultVal}, + {confGoodFormatNoDefault, srcLineEmpty, nullVal}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confNoFormatNoDefault, srcLineGood) + assertErrorPrefix(t, "cannot read datetime column col_1, data '2017-10-02T10:56:33.155': column format is missing, consider specifying something like 2006-01-02T15:04:05.000-0700, see go datetime format documentation for details", err.Error()) + _, err = testReader(confBadFormatGoodDefault, srcLineGood) + assertErrorPrefix(t, `cannot read datetime column col_1, data '2017-10-02T10:56:33.155', format '2006-01-02ccc15:04:05.000': parsing time "2017-10-02T10:56:33.155" as "2006-01-02ccc15:04:05.000": cannot parse "T10:56:33.155" as "ccc"`, err.Error()) + _, err = testReader(confBadFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07aaa11:22:33.700': parsing time "2001-07-07aaa11:22:33.700" as "2006-01-02ccc15:04:05.000": cannot parse "aaa11:22:33.700" as "ccc"`, err.Error()) + _, err = testReader(confNoFormatGoodDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07T11:22:33.700': parsing time "2001-07-07T11:22:33.700": extra text: "2001-07-07T11:22:33.700"`, err.Error()) + _, err = 
testReader(confGoodFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read time column col_1 from default value string '2001-07-07aaa11:22:33.700': parsing time "2001-07-07aaa11:22:33.700" as "2006-01-02T15:04:05.000": cannot parse "aaa11:22:33.700" as "T"`, err.Error()) + _, err = testReader(confGoodFormatGoodDefault, srcLineBad) + assertErrorPrefix(t, `cannot read datetime column col_1, data '2017-10-02bbb10:56:33.155', format '2006-01-02T15:04:05.000': parsing time "2017-10-02bbb10:56:33.155" as "2006-01-02T15:04:05.000": cannot parse "bbb10:56:33.155" as "T"`, err.Error()) +} + +func TestReadInt(t *testing.T) { + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "int" + } + } + }` + + confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%d)",`, `"col_default_value":"123",`) + confSimpleFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "%d",`, ``) + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) + + srcLineComplexFormat := []string{"", "value(111)", ""} + srcLineSimpleFormat := []string{"", "111", ""} + srcLineEmpty := []string{"", "", ""} + + goodTestScenarios := [][]any{ + {confComplexFormatWithDefault, srcLineComplexFormat, int64(111)}, + {confSimpleFormatNoDefault, srcLineSimpleFormat, int64(111)}, + {confNoFormatNoDefault, srcLineSimpleFormat, int64(111)}, + {confComplexFormatWithDefault, srcLineEmpty, int64(123)}, + {confSimpleFormatNoDefault, srcLineEmpty, int64(0)}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) + assertErrorPrefix(t, "cannot read int64 column col_1, data 'value(111)', format '%d': expected integer", err.Error()) + _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) + assertErrorPrefix(t, "cannot read int64 column col_1, data '111', format 'value(%d)': input does not match format", err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read int64 column col_1 from default value string 'badstring': strconv.ParseInt: parsing "badstring": invalid syntax`, err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) + assertErrorPrefix(t, `cannot read int64 column col_1, data 'value(111)', no format: strconv.ParseInt: parsing "value(111)": invalid syntax`, err.Error()) +} + +func TestReadFloat(t *testing.T) { + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "float" + } + } + }` + + confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%f)",`, `"col_default_value":"5.697",`) + confSimpleFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "%f",`, ``) + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) + + srcLineComplexFormat := []string{"", "value(111.222)", ""} + srcLineSimpleFormat := []string{"", "111.222", 
""} + srcLineEmpty := []string{"", "", ""} + + goodTestScenarios := [][]any{ + {confComplexFormatWithDefault, srcLineComplexFormat, float64(111.222)}, + {confSimpleFormatNoDefault, srcLineSimpleFormat, float64(111.222)}, + {confNoFormatNoDefault, srcLineSimpleFormat, float64(111.222)}, + {confComplexFormatWithDefault, srcLineEmpty, float64(5.697)}, + {confSimpleFormatNoDefault, srcLineEmpty, float64(0.0)}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) + assertErrorPrefix(t, `cannot read float64 column col_1, data 'value(111.222)', format '%f': strconv.ParseFloat: parsing "": invalid syntax`, err.Error()) + _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) + assertErrorPrefix(t, "cannot read float64 column col_1, data '111.222', format 'value(%f)': input does not match format", err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read float64 column col_1 from default value string 'badstring': strconv.ParseFloat: parsing "badstring": invalid syntax`, err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) + assertErrorPrefix(t, `cannot read float64 column col_1, data 'value(111.222)', no format: strconv.ParseFloat: parsing "value(111.222)": invalid syntax`, err.Error()) +} + +func TestReadDecimal(t *testing.T) { + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "decimal2" + } + } + }` + + confComplexFormatWithDefault := fmt.Sprintf(confTemplate, `"col_format": "value(%f)",`, `"col_default_value":"-56.78",`) + confSimpleFormatNoDefault := fmt.Sprintf(confTemplate, `"col_format": "%f",`, ``) + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"badstring",`) + + srcLineComplexFormat := []string{"", "value(12.34)", ""} + srcLineSimpleFormat := []string{"", "12.34", ""} + srcLineEmpty := []string{"", "", ""} + + goodTestScenarios := [][]any{ + {confComplexFormatWithDefault, srcLineComplexFormat, decimal.NewFromFloat32(12.34)}, + {confSimpleFormatNoDefault, srcLineSimpleFormat, decimal.NewFromFloat32(12.34)}, + {confNoFormatNoDefault, srcLineSimpleFormat, decimal.NewFromFloat32(12.34)}, + {confComplexFormatWithDefault, srcLineEmpty, decimal.NewFromFloat32(-56.78)}, + {confSimpleFormatNoDefault, srcLineEmpty, decimal.NewFromFloat32(0.0)}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confSimpleFormatNoDefault, srcLineComplexFormat) + assertErrorPrefix(t, `cannot read decimal2 column col_1, data 'value(12.34)', format '%f': strconv.ParseFloat: parsing "": invalid syntax`, err.Error()) + _, err = testReader(confComplexFormatWithDefault, srcLineSimpleFormat) + assertErrorPrefix(t, "cannot read decimal2 column col_1, data '12.34', format 'value(%f)': input does not match format", err.Error()) + _, 
err = testReader(confNoFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read decimal2 column col_1 from default value string 'badstring': can't convert badstring to decimal`, err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineComplexFormat) + assertErrorPrefix(t, `cannot read decimal2 column col_1, cannot parse data 'value(12.34)': can't convert value(12.34) to decimal: exponent is not numeric`, err.Error()) +} +func TestReadBool(t *testing.T) { + + confTemplate := ` + { + "urls": [""], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_1": { + "csv":{ + %s + "col_idx": 1 + }, + %s + "col_type": "bool" + } + } + }` + + confNoFormatNoDefault := fmt.Sprintf(confTemplate, ``, ``) + confNoFormatWithDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"TRUE",`) + confNoFormatBadDefault := fmt.Sprintf(confTemplate, ``, `"col_default_value":"baddefault",`) + confWithFormat := fmt.Sprintf(confTemplate, `"col_format": "some_format",`, ``) + + srcLineTrue := []string{"", "True", ""} + srcLineFalse := []string{"", "False", ""} + srcLineTrueCap := []string{"", "TRUE", ""} + srcLineFalseCap := []string{"", "FALSE", ""} + srcLineTrueSmall := []string{"", "true", ""} + srcLineFalseSmall := []string{"", "false", ""} + srcLineT := []string{"", "T", ""} + srcLineF := []string{"", "F", ""} + srcLineTSmall := []string{"", "t", ""} + srcLineFSmall := []string{"", "f", ""} + srcLine0 := []string{"", "0", ""} + srcLine1 := []string{"", "1", ""} + srcLineEmpty := []string{"", "", ""} + srcLineBad := []string{"", "bad", ""} + + goodTestScenarios := [][]any{ + {confNoFormatNoDefault, srcLineTrue, true}, + {confNoFormatNoDefault, srcLineFalse, false}, + {confNoFormatNoDefault, srcLineTrueCap, true}, + {confNoFormatNoDefault, srcLineFalseCap, false}, + {confNoFormatNoDefault, srcLineTrueSmall, true}, + {confNoFormatNoDefault, srcLineFalseSmall, false}, + {confNoFormatNoDefault, srcLineT, true}, + {confNoFormatNoDefault, srcLineF, false}, + {confNoFormatNoDefault, srcLineTSmall, true}, + {confNoFormatNoDefault, srcLineFSmall, false}, + {confNoFormatNoDefault, srcLine1, true}, + {confNoFormatNoDefault, srcLine0, false}, + {confNoFormatWithDefault, srcLineEmpty, true}, + {confNoFormatNoDefault, srcLineEmpty, false}, + } + + for i := 0; i < len(goodTestScenarios); i++ { + scenario := goodTestScenarios[i] + colRecord, err := testReader(scenario[0].(string), scenario[1].([]string)) + assert.Nil(t, err) + assert.Equal(t, scenario[2], colRecord[ReaderAlias]["col_1"], fmt.Sprintf("Test %d", i)) + } + + var err error + _, err = testReader(confNoFormatNoDefault, srcLineBad) + assertErrorPrefix(t, `cannot read bool column col_1, data 'bad', allowed values are true,false,T,F,0,1: strconv.ParseBool: parsing "bad": invalid syntax`, err.Error()) + _, err = testReader(confNoFormatBadDefault, srcLineEmpty) + assertErrorPrefix(t, `cannot read bool column col_1, from default value string 'baddefault', allowed values are true,false,T,F,0,1: strconv.ParseBool: parsing "baddefault": invalid syntax`, err.Error()) + _, err = testReader(confWithFormat, srcLineTrue) + assertErrorPrefix(t, `cannot read bool column col_1, data 'True': format 'some_format' was specified, but bool fields do not accept format specifier, remove this setting`, err.Error()) +} diff --git a/pkg/sc/index_def.go b/pkg/sc/index_def.go index dd7a076..9e7f57b 100644 --- a/pkg/sc/index_def.go +++ b/pkg/sc/index_def.go @@ -1,230 +1,230 @@ -package sc - -import ( - "fmt" - "go/ast" - "go/parser" - 
"go/token" - "strconv" - "strings" -) - -const ( - DefaultStringComponentLen int64 = 64 - MinStringComponentLen int64 = 16 - MaxStringComponentLen int64 = 1024 -) - -type IdxSortOrder string - -const ( - IdxSortAsc IdxSortOrder = "asc" - IdxSortDesc IdxSortOrder = "desc" - IdxSortUnknown IdxSortOrder = "unknown" -) - -type IdxCaseSensitivity string - -const ( - IdxCaseSensitive IdxCaseSensitivity = "case_sensitive" - IdxIgnoreCase IdxCaseSensitivity = "ignore_case" - IdxCaseSensitivityUnknown IdxCaseSensitivity = "case_sensitivity_unknown" -) - -type IdxComponentDef struct { - FieldName string - CaseSensitivity IdxCaseSensitivity - SortOrder IdxSortOrder - StringLen int64 // For string fields only, default 64 - FieldType TableFieldType // Populated from tgt_table def -} - -type IdxUniqueness string - -const ( - IdxUnique IdxUniqueness = "unique" - IdxNonUnique IdxUniqueness = "non_unique" - IdxUniquenessUnknown IdxUniqueness = "unknown" -) - -type IdxDef struct { - Uniqueness IdxUniqueness - Components []IdxComponentDef -} - -type IdxDefMap map[string]*IdxDef - -type IndexRef struct { - TableName string - IdxName string -} - -func (idxDef *IdxDef) getComponentFieldRefs(tableName string) FieldRefs { - fieldRefs := make([]FieldRef, len(idxDef.Components)) - for i := 0; i < len(idxDef.Components); i++ { - fieldRefs[i] = FieldRef{ - TableName: tableName, - FieldName: idxDef.Components[i].FieldName, - FieldType: idxDef.Components[i].FieldType} - } - return fieldRefs -} - -func (idxDef *IdxDef) parseComponentExpr(fldExp *ast.Expr, fieldRefs *FieldRefs) error { - // Initialize index component with defaults and append it to idx def - idxCompDef := IdxComponentDef{ - FieldName: FieldNameUnknown, - CaseSensitivity: IdxCaseSensitivityUnknown, - SortOrder: IdxSortUnknown, - StringLen: DefaultStringComponentLen, // Users can override it, see below - FieldType: FieldTypeUnknown} - - switch (*fldExp).(type) { - case *ast.CallExpr: - callExp, _ := (*fldExp).(*ast.CallExpr) - identExp, _ := callExp.Fun.(*ast.Ident) - fieldRef, ok := fieldRefs.FindByFieldName(identExp.Name) - if !ok { - return fmt.Errorf("cannot parse order component func expression, field %s unknown", identExp.Name) - } - - // Defaults - idxCompDef.FieldType = (*fieldRef).FieldType - idxCompDef.FieldName = identExp.Name - - // Parse args: asc/desc, case_sensitive/ignore_case, number-string length - for _, modifierExp := range callExp.Args { - switch modifierExpType := modifierExp.(type) { - case *ast.Ident: - modIdentExp, _ := modifierExp.(*ast.Ident) - switch modIdentExp.Name { - case string(IdxCaseSensitive): - idxCompDef.CaseSensitivity = IdxCaseSensitive - case string(IdxIgnoreCase): - idxCompDef.CaseSensitivity = IdxIgnoreCase - case string(IdxSortAsc): - idxCompDef.SortOrder = IdxSortAsc - case string(IdxSortDesc): - idxCompDef.SortOrder = IdxSortDesc - default: - return fmt.Errorf( - "unknown modifier %s for field %s, expected %s,%s,%s,%s", - modIdentExp.Name, identExp.Name, IdxIgnoreCase, IdxCaseSensitive, IdxSortAsc, IdxSortDesc) - } - case *ast.BasicLit: - switch modifierExpType.Kind { - case token.INT: - if idxCompDef.FieldType != FieldTypeString { - return fmt.Errorf("invalid expression %v in %s, component length modifier is valid only for string fields, but %s has type %s", - modifierExpType, identExp.Name, idxCompDef.FieldName, idxCompDef.FieldType) - } - idxCompDef.StringLen, _ = strconv.ParseInt(modifierExpType.Value, 10, 64) - if idxCompDef.StringLen < MinStringComponentLen { - idxCompDef.StringLen = MinStringComponentLen 
- } else if idxCompDef.StringLen > MaxStringComponentLen { - return fmt.Errorf("invalid expression %v in %s, component length modifier for string fields cannot exceed %d", - modifierExpType, identExp.Name, MaxStringComponentLen) - } - default: - return fmt.Errorf("invalid expression %v in %s, expected an integer for string component length", modifierExpType, identExp.Name) - } - - default: - return fmt.Errorf( - "invalid expression %v, expected a modifier for field %s: expected %s,%s,%s,%s or an integer", - modifierExpType, identExp.Name, IdxIgnoreCase, IdxCaseSensitive, IdxSortAsc, IdxSortDesc) - } - - // Check some rules - if idxCompDef.FieldType != FieldTypeString && idxCompDef.CaseSensitivity != IdxCaseSensitivityUnknown { - return fmt.Errorf( - "index component for field %s of type %s cannot have case sensitivity modifier %s, remove it from index component definition", - identExp.Name, idxCompDef.FieldType, idxCompDef.CaseSensitivity) - } - } - - case *ast.Ident: - // This is a component def without modifiers (not filed1(...), just field1), so just apply defaults - identExp, _ := (*fldExp).(*ast.Ident) - - fieldRef, ok := fieldRefs.FindByFieldName(identExp.Name) - if !ok { - return fmt.Errorf("cannot parse order component ident expression, field %s unknown", identExp.Name) - } - - // Defaults - idxCompDef.FieldType = (*fieldRef).FieldType - idxCompDef.FieldName = identExp.Name - - default: - return fmt.Errorf( - "invalid expression in index component definition, expected 'field([modifiers])' or 'field' where 'field' is one of the fields of the table created by this node") - } - - // Apply defaults if no modifiers supplied: string -> case sensitive, ordered idx -> sort asc - if idxCompDef.FieldType == FieldTypeString && idxCompDef.CaseSensitivity == IdxCaseSensitivityUnknown { - idxCompDef.CaseSensitivity = IdxCaseSensitive - } - - if idxCompDef.SortOrder == IdxSortUnknown { - idxCompDef.SortOrder = IdxSortAsc - } - - // All good - add it to the list of idx components - idxDef.Components = append(idxDef.Components, idxCompDef) - - return nil -} - -func (idxDefMap *IdxDefMap) parseRawIndexDefMap(rawIdxDefMap map[string]string, fieldRefs *FieldRefs) error { - errors := make([]string, 0, 10) - - for idxName, rawIdxDef := range rawIdxDefMap { - - expIdxDef, err := parser.ParseExpr(rawIdxDef) - if err != nil { - return fmt.Errorf("cannot parse order def '%s': %v", rawIdxDef, err) - } - switch expIdxDef.(type) { - case *ast.CallExpr: - callExp, _ := expIdxDef.(*ast.CallExpr) - identExp, _ := callExp.Fun.(*ast.Ident) - - // Init idx def, defaults here if needed - idxDef := IdxDef{Uniqueness: IdxUniquenessUnknown} - - switch identExp.Name { - case string(IdxUnique): - idxDef.Uniqueness = IdxUnique - case string(IdxNonUnique): - idxDef.Uniqueness = IdxNonUnique - default: - return fmt.Errorf( - "cannot parse index def [%s]: expected top level unique()) or non_unique() definition, found %s", - rawIdxDef, identExp.Name) - } - - // Walk through args - idx field components - for _, fldExp := range callExp.Args { - err := idxDef.parseComponentExpr(&fldExp, fieldRefs) - if err != nil { - errors = append(errors, fmt.Sprintf("index %s: [%s]", rawIdxDef, err.Error())) - } - } - - // All good - add it to the idx map of the table def - (*idxDefMap)[idxName] = &idxDef - - default: - return fmt.Errorf( - "cannot parse index def [%s]: expected top level unique()) or non_unique() definition, found unknown expression", - rawIdxDef) - } - } - - if len(errors) > 0 { - return fmt.Errorf("cannot parse order 
definitions: [%s]", strings.Join(errors, "; ")) - } else { - return nil - } -} +package sc + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "strconv" + "strings" +) + +const ( + DefaultStringComponentLen int64 = 64 + MinStringComponentLen int64 = 16 + MaxStringComponentLen int64 = 1024 +) + +type IdxSortOrder string + +const ( + IdxSortAsc IdxSortOrder = "asc" + IdxSortDesc IdxSortOrder = "desc" + IdxSortUnknown IdxSortOrder = "unknown" +) + +type IdxCaseSensitivity string + +const ( + IdxCaseSensitive IdxCaseSensitivity = "case_sensitive" + IdxIgnoreCase IdxCaseSensitivity = "ignore_case" + IdxCaseSensitivityUnknown IdxCaseSensitivity = "case_sensitivity_unknown" +) + +type IdxComponentDef struct { + FieldName string + CaseSensitivity IdxCaseSensitivity + SortOrder IdxSortOrder + StringLen int64 // For string fields only, default 64 + FieldType TableFieldType // Populated from tgt_table def +} + +type IdxUniqueness string + +const ( + IdxUnique IdxUniqueness = "unique" + IdxNonUnique IdxUniqueness = "non_unique" + IdxUniquenessUnknown IdxUniqueness = "unknown" +) + +type IdxDef struct { + Uniqueness IdxUniqueness + Components []IdxComponentDef +} + +type IdxDefMap map[string]*IdxDef + +type IndexRef struct { + TableName string + IdxName string +} + +func (idxDef *IdxDef) getComponentFieldRefs(tableName string) FieldRefs { + fieldRefs := make([]FieldRef, len(idxDef.Components)) + for i := 0; i < len(idxDef.Components); i++ { + fieldRefs[i] = FieldRef{ + TableName: tableName, + FieldName: idxDef.Components[i].FieldName, + FieldType: idxDef.Components[i].FieldType} + } + return fieldRefs +} + +func (idxDef *IdxDef) parseComponentExpr(fldExp *ast.Expr, fieldRefs *FieldRefs) error { + // Initialize index component with defaults and append it to idx def + idxCompDef := IdxComponentDef{ + FieldName: FieldNameUnknown, + CaseSensitivity: IdxCaseSensitivityUnknown, + SortOrder: IdxSortUnknown, + StringLen: DefaultStringComponentLen, // Users can override it, see below + FieldType: FieldTypeUnknown} + + switch (*fldExp).(type) { + case *ast.CallExpr: + callExp, _ := (*fldExp).(*ast.CallExpr) //nolint:all + identExp, _ := callExp.Fun.(*ast.Ident) //nolint:all + fieldRef, ok := fieldRefs.FindByFieldName(identExp.Name) + if !ok { + return fmt.Errorf("cannot parse order component func expression, field %s unknown", identExp.Name) + } + + // Defaults + idxCompDef.FieldType = (*fieldRef).FieldType + idxCompDef.FieldName = identExp.Name + + // Parse args: asc/desc, case_sensitive/ignore_case, number-string length + for _, modifierExp := range callExp.Args { + switch modifierExpType := modifierExp.(type) { + case *ast.Ident: + modIdentExp, _ := modifierExp.(*ast.Ident) //nolint:all + switch modIdentExp.Name { + case string(IdxCaseSensitive): + idxCompDef.CaseSensitivity = IdxCaseSensitive + case string(IdxIgnoreCase): + idxCompDef.CaseSensitivity = IdxIgnoreCase + case string(IdxSortAsc): + idxCompDef.SortOrder = IdxSortAsc + case string(IdxSortDesc): + idxCompDef.SortOrder = IdxSortDesc + default: + return fmt.Errorf( + "unknown modifier %s for field %s, expected %s,%s,%s,%s", + modIdentExp.Name, identExp.Name, IdxIgnoreCase, IdxCaseSensitive, IdxSortAsc, IdxSortDesc) + } + case *ast.BasicLit: + switch modifierExpType.Kind { + case token.INT: + if idxCompDef.FieldType != FieldTypeString { + return fmt.Errorf("invalid expression %v in %s, component length modifier is valid only for string fields, but %s has type %s", + modifierExpType, identExp.Name, idxCompDef.FieldName, idxCompDef.FieldType) + 
} + idxCompDef.StringLen, _ = strconv.ParseInt(modifierExpType.Value, 10, 64) + if idxCompDef.StringLen < MinStringComponentLen { + idxCompDef.StringLen = MinStringComponentLen + } else if idxCompDef.StringLen > MaxStringComponentLen { + return fmt.Errorf("invalid expression %v in %s, component length modifier for string fields cannot exceed %d", + modifierExpType, identExp.Name, MaxStringComponentLen) + } + default: + return fmt.Errorf("invalid expression %v in %s, expected an integer for string component length", modifierExpType, identExp.Name) + } + + default: + return fmt.Errorf( + "invalid expression %v, expected a modifier for field %s: expected %s,%s,%s,%s or an integer", + modifierExpType, identExp.Name, IdxIgnoreCase, IdxCaseSensitive, IdxSortAsc, IdxSortDesc) + } + + // Check some rules + if idxCompDef.FieldType != FieldTypeString && idxCompDef.CaseSensitivity != IdxCaseSensitivityUnknown { + return fmt.Errorf( + "index component for field %s of type %s cannot have case sensitivity modifier %s, remove it from index component definition", + identExp.Name, idxCompDef.FieldType, idxCompDef.CaseSensitivity) + } + } + + case *ast.Ident: + // This is a component def without modifiers (not filed1(...), just field1), so just apply defaults + identExp, _ := (*fldExp).(*ast.Ident) //nolint:all + + fieldRef, ok := fieldRefs.FindByFieldName(identExp.Name) + if !ok { + return fmt.Errorf("cannot parse order component ident expression, field %s unknown", identExp.Name) + } + + // Defaults + idxCompDef.FieldType = (*fieldRef).FieldType + idxCompDef.FieldName = identExp.Name + + default: + return fmt.Errorf( + "invalid expression in index component definition, expected 'field([modifiers])' or 'field' where 'field' is one of the fields of the table created by this node") + } + + // Apply defaults if no modifiers supplied: string -> case sensitive, ordered idx -> sort asc + if idxCompDef.FieldType == FieldTypeString && idxCompDef.CaseSensitivity == IdxCaseSensitivityUnknown { + idxCompDef.CaseSensitivity = IdxCaseSensitive + } + + if idxCompDef.SortOrder == IdxSortUnknown { + idxCompDef.SortOrder = IdxSortAsc + } + + // All good - add it to the list of idx components + idxDef.Components = append(idxDef.Components, idxCompDef) + + return nil +} + +func (idxDefMap *IdxDefMap) parseRawIndexDefMap(rawIdxDefMap map[string]string, fieldRefs *FieldRefs) error { + errors := make([]string, 0, 10) + + for idxName, rawIdxDef := range rawIdxDefMap { + + expIdxDef, err := parser.ParseExpr(rawIdxDef) + if err != nil { + return fmt.Errorf("cannot parse order def '%s': %v", rawIdxDef, err) + } + switch expIdxDef.(type) { + case *ast.CallExpr: + callExp, _ := expIdxDef.(*ast.CallExpr) //nolint:all + identExp, _ := callExp.Fun.(*ast.Ident) //nolint:all + + // Init idx def, defaults here if needed + idxDef := IdxDef{Uniqueness: IdxUniquenessUnknown} + + switch identExp.Name { + case string(IdxUnique): + idxDef.Uniqueness = IdxUnique + case string(IdxNonUnique): + idxDef.Uniqueness = IdxNonUnique + default: + return fmt.Errorf( + "cannot parse index def [%s]: expected top level unique()) or non_unique() definition, found %s", + rawIdxDef, identExp.Name) + } + + // Walk through args - idx field components + for _, fldExp := range callExp.Args { + err := idxDef.parseComponentExpr(&fldExp, fieldRefs) + if err != nil { + errors = append(errors, fmt.Sprintf("index %s: [%s]", rawIdxDef, err.Error())) + } + } + + // All good - add it to the idx map of the table def + (*idxDefMap)[idxName] = &idxDef + + default: + return 
fmt.Errorf( + "cannot parse index def [%s]: expected top level unique()) or non_unique() definition, found unknown expression", + rawIdxDef) + } + } + + if len(errors) > 0 { + return fmt.Errorf("cannot parse order definitions: [%s]", strings.Join(errors, "; ")) + } + + return nil +} diff --git a/pkg/sc/index_def_test.go b/pkg/sc/index_def_test.go index b2602d8..83564b2 100644 --- a/pkg/sc/index_def_test.go +++ b/pkg/sc/index_def_test.go @@ -1,131 +1,131 @@ -package sc - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func assertIdxComp(t *testing.T, fName string, fType TableFieldType, caseSens IdxCaseSensitivity, sortOrder IdxSortOrder, strLen int64, compDef *IdxComponentDef) { - assert.Equal(t, fName, compDef.FieldName) - assert.Equal(t, fType, compDef.FieldType) - assert.Equal(t, caseSens, compDef.CaseSensitivity) - assert.Equal(t, sortOrder, compDef.SortOrder) - assert.Equal(t, strLen, compDef.StringLen) -} - -func TestIndexDefParser(t *testing.T) { - fieldRefs := FieldRefs{ - FieldRef{"t1", "f_int", FieldTypeInt}, - FieldRef{"t1", "f_float", FieldTypeFloat}, - FieldRef{"t1", "f_bool", FieldTypeBool}, - FieldRef{"t1", "f_str", FieldTypeString}, - FieldRef{"t1", "f_time", FieldTypeDateTime}, - FieldRef{"t1", "f_dec", FieldTypeDecimal2}, - } - rawIdxDefMap := map[string]string{ - "idx_all_default": "non_unique(f_int(),f_float(),f_bool(),f_str(),f_time(),f_dec())", - "idx_all_desc": "unique(f_int(desc),f_float(desc),f_bool(desc),f_str(desc,ignore_case,128),f_time(desc),f_dec(desc))", - "idx_all_asc": "unique(f_int(asc),f_float(asc),f_bool(asc),f_str(asc,case_sensitive,15),f_time(asc),f_dec(asc))", - "idx_no_mods": "unique(f_int,f_float,f_bool,f_str,f_time,f_dec)", - } - idxDefMap := IdxDefMap{} - idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - - extractedFieldRefs := idxDefMap["idx_all_default"].getComponentFieldRefs("t2") - for i := 0; i < len(extractedFieldRefs); i++ { - extractedFieldRef := &extractedFieldRefs[i] - assert.Equal(t, "t2", extractedFieldRef.TableName) - foundFieldRef, _ := fieldRefs.FindByFieldName(extractedFieldRef.FieldName) - assert.Equal(t, extractedFieldRef.FieldType, foundFieldRef.FieldType) - assert.Equal(t, "t1", foundFieldRef.TableName) - } - - assert.Equal(t, IdxNonUnique, idxDefMap["idx_all_default"].Uniqueness) - assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[0]) - assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[1]) - assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[2]) - assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[3]) - assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[4]) - assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[5]) - - assert.Equal(t, IdxUnique, idxDefMap["idx_all_desc"].Uniqueness) - assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[0]) - assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortDesc, 
DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[1]) - assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[2]) - assertIdxComp(t, "f_str", FieldTypeString, IdxIgnoreCase, IdxSortDesc, 128, &idxDefMap["idx_all_desc"].Components[3]) - assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[4]) - assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[5]) - - assert.Equal(t, IdxUnique, idxDefMap["idx_all_asc"].Uniqueness) - assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[0]) - assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[1]) - assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[2]) - assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, MinStringComponentLen, &idxDefMap["idx_all_asc"].Components[3]) - assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[4]) - assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[5]) - - assert.Equal(t, IdxUnique, idxDefMap["idx_no_mods"].Uniqueness) - assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[0]) - assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[1]) - assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[2]) - assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[3]) - assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[4]) - assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[5]) - -} - -func TestIndexDefParserBad(t *testing.T) { - fieldRefs := FieldRefs{ - FieldRef{"t1", "f_int", FieldTypeInt}, - FieldRef{"t1", "f_float", FieldTypeFloat}, - FieldRef{"t1", "f_bool", FieldTypeBool}, - FieldRef{"t1", "f_str", FieldTypeString}, - FieldRef{"t1", "f_time", FieldTypeDateTime}, - FieldRef{"t1", "f_dec", FieldTypeDecimal2}, - } - rawIdxDefMap := map[string]string{"idx_bad_unique": "somename(f_int,f_float,f_bool,f_str,f_time,f_dec)"} - idxDefMap := IdxDefMap{} - err := idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse index def [somename(f_int,f_float,f_bool,f_str,f_time,f_dec)]: expected top level unique()) or non_unique() definition, found somename", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_field": "unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, 
"cannot parse order definitions: [index unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec): [cannot parse order component ident expression, field somefield1 unknown]; index unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec): [cannot parse order component func expression, field somefield2 unknown]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_modifier": "unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order definitions: [index unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec): [unknown modifier somemodifier for field f_int, expected ignore_case,case_sensitive,asc,desc]; index unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec): [index component for field f_float of type float cannot have case sensitivity modifier case_sensitive, remove it from index component definition]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_comp_string_len": "unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec): [invalid expression &{29 INT 128} in f_bool, component length modifier is valid only for string fields, but f_bool has type bool]; index unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec): [invalid expression &{40 INT 32000} in f_str, component length modifier for string fields cannot exceed 1024]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_comp_string_len_float": "unique(f_int,f_float,f_bool,f_str(5.2),f_time,f_dec)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool,f_str(5.2),f_time,f_dec): [invalid expression &{35 FLOAT 5.2} in f_str, expected an integer for string component length]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_modifier_func": "unique(f_int,f_float,f_bool,f_str(badmofifierfunc()),f_time,f_dec)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool,f_str(badmofifierfunc()),f_time,f_dec): [invalid expression &{badmofifierfunc 50 [] 0 51}, expected a modifier for field f_str: expected ignore_case,case_sensitive,asc,desc or an integer]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_field_expr": "unique(123)"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order definitions: [index unique(123): [invalid expression in index component definition, expected 'field([modifiers])' or 'field' where 'field' is one of the fields of the table created by this node]]", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_syntax": "unique("} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse order def 'unique(': 1:8: expected ')', found 'EOF'", err.Error()) - - rawIdxDefMap = map[string]string{"idx_bad_no_call": "unique"} - idxDefMap = IdxDefMap{} - err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) - assert.Equal(t, "cannot parse index def [unique]: expected top level unique()) or 
non_unique() definition, found unknown expression", err.Error()) -} +package sc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func assertIdxComp(t *testing.T, fName string, fType TableFieldType, caseSens IdxCaseSensitivity, sortOrder IdxSortOrder, strLen int64, compDef *IdxComponentDef) { + assert.Equal(t, fName, compDef.FieldName) + assert.Equal(t, fType, compDef.FieldType) + assert.Equal(t, caseSens, compDef.CaseSensitivity) + assert.Equal(t, sortOrder, compDef.SortOrder) + assert.Equal(t, strLen, compDef.StringLen) +} + +func TestIndexDefParser(t *testing.T) { + fieldRefs := FieldRefs{ + FieldRef{"t1", "f_int", FieldTypeInt}, + FieldRef{"t1", "f_float", FieldTypeFloat}, + FieldRef{"t1", "f_bool", FieldTypeBool}, + FieldRef{"t1", "f_str", FieldTypeString}, + FieldRef{"t1", "f_time", FieldTypeDateTime}, + FieldRef{"t1", "f_dec", FieldTypeDecimal2}, + } + rawIdxDefMap := map[string]string{ + "idx_all_default": "non_unique(f_int(),f_float(),f_bool(),f_str(),f_time(),f_dec())", + "idx_all_desc": "unique(f_int(desc),f_float(desc),f_bool(desc),f_str(desc,ignore_case,128),f_time(desc),f_dec(desc))", + "idx_all_asc": "unique(f_int(asc),f_float(asc),f_bool(asc),f_str(asc,case_sensitive,15),f_time(asc),f_dec(asc))", + "idx_no_mods": "unique(f_int,f_float,f_bool,f_str,f_time,f_dec)", + } + idxDefMap := IdxDefMap{} + assert.Nil(t, idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs)) + + extractedFieldRefs := idxDefMap["idx_all_default"].getComponentFieldRefs("t2") + for i := 0; i < len(extractedFieldRefs); i++ { + extractedFieldRef := &extractedFieldRefs[i] + assert.Equal(t, "t2", extractedFieldRef.TableName) + foundFieldRef, _ := fieldRefs.FindByFieldName(extractedFieldRef.FieldName) + assert.Equal(t, extractedFieldRef.FieldType, foundFieldRef.FieldType) + assert.Equal(t, "t1", foundFieldRef.TableName) + } + + assert.Equal(t, IdxNonUnique, idxDefMap["idx_all_default"].Uniqueness) + assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[0]) + assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[1]) + assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[2]) + assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[3]) + assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[4]) + assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_default"].Components[5]) + + assert.Equal(t, IdxUnique, idxDefMap["idx_all_desc"].Uniqueness) + assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[0]) + assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[1]) + assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[2]) + assertIdxComp(t, "f_str", FieldTypeString, IdxIgnoreCase, IdxSortDesc, 128, &idxDefMap["idx_all_desc"].Components[3]) + assertIdxComp(t, "f_time", 
FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[4]) + assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortDesc, DefaultStringComponentLen, &idxDefMap["idx_all_desc"].Components[5]) + + assert.Equal(t, IdxUnique, idxDefMap["idx_all_asc"].Uniqueness) + assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[0]) + assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[1]) + assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[2]) + assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, MinStringComponentLen, &idxDefMap["idx_all_asc"].Components[3]) + assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[4]) + assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_all_asc"].Components[5]) + + assert.Equal(t, IdxUnique, idxDefMap["idx_no_mods"].Uniqueness) + assertIdxComp(t, "f_int", FieldTypeInt, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[0]) + assertIdxComp(t, "f_float", FieldTypeFloat, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[1]) + assertIdxComp(t, "f_bool", FieldTypeBool, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[2]) + assertIdxComp(t, "f_str", FieldTypeString, IdxCaseSensitive, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[3]) + assertIdxComp(t, "f_time", FieldTypeDateTime, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[4]) + assertIdxComp(t, "f_dec", FieldTypeDecimal2, IdxCaseSensitivityUnknown, IdxSortAsc, DefaultStringComponentLen, &idxDefMap["idx_no_mods"].Components[5]) + +} + +func TestIndexDefParserBad(t *testing.T) { + fieldRefs := FieldRefs{ + FieldRef{"t1", "f_int", FieldTypeInt}, + FieldRef{"t1", "f_float", FieldTypeFloat}, + FieldRef{"t1", "f_bool", FieldTypeBool}, + FieldRef{"t1", "f_str", FieldTypeString}, + FieldRef{"t1", "f_time", FieldTypeDateTime}, + FieldRef{"t1", "f_dec", FieldTypeDecimal2}, + } + rawIdxDefMap := map[string]string{"idx_bad_unique": "somename(f_int,f_float,f_bool,f_str,f_time,f_dec)"} + idxDefMap := IdxDefMap{} + err := idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse index def [somename(f_int,f_float,f_bool,f_str,f_time,f_dec)]: expected top level unique()) or non_unique() definition, found somename", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_field": "unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec): [cannot parse order component ident expression, field somefield1 unknown]; index unique(somefield1,somefield2(),f_bool,f_str,f_time,f_dec): [cannot parse order component func expression, field somefield2 unknown]]", err.Error()) + + rawIdxDefMap = 
map[string]string{"idx_bad_modifier": "unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec): [unknown modifier somemodifier for field f_int, expected ignore_case,case_sensitive,asc,desc]; index unique(f_int(somemodifier),f_float(case_sensitive),f_bool,f_str,f_time,f_dec): [index component for field f_float of type float cannot have case sensitivity modifier case_sensitive, remove it from index component definition]]", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_comp_string_len": "unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec): [invalid expression &{29 INT 128} in f_bool, component length modifier is valid only for string fields, but f_bool has type bool]; index unique(f_int,f_float,f_bool(128),f_str(32000),f_time,f_dec): [invalid expression &{40 INT 32000} in f_str, component length modifier for string fields cannot exceed 1024]]", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_comp_string_len_float": "unique(f_int,f_float,f_bool,f_str(5.2),f_time,f_dec)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool,f_str(5.2),f_time,f_dec): [invalid expression &{35 FLOAT 5.2} in f_str, expected an integer for string component length]]", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_modifier_func": "unique(f_int,f_float,f_bool,f_str(badmofifierfunc()),f_time,f_dec)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(f_int,f_float,f_bool,f_str(badmofifierfunc()),f_time,f_dec): [invalid expression &{badmofifierfunc 50 [] 0 51}, expected a modifier for field f_str: expected ignore_case,case_sensitive,asc,desc or an integer]]", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_field_expr": "unique(123)"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order definitions: [index unique(123): [invalid expression in index component definition, expected 'field([modifiers])' or 'field' where 'field' is one of the fields of the table created by this node]]", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_syntax": "unique("} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse order def 'unique(': 1:8: expected ')', found 'EOF'", err.Error()) + + rawIdxDefMap = map[string]string{"idx_bad_no_call": "unique"} + idxDefMap = IdxDefMap{} + err = idxDefMap.parseRawIndexDefMap(rawIdxDefMap, &fieldRefs) + assert.Equal(t, "cannot parse index def [unique]: expected top level unique()) or non_unique() definition, found unknown expression", err.Error()) +} diff --git a/pkg/sc/key.go b/pkg/sc/key.go index 2a31559..c01d7d5 100644 --- a/pkg/sc/key.go +++ b/pkg/sc/key.go @@ -1,171 +1,173 @@ -package sc - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "strings" - "time" - "unicode" - - "github.com/shopspring/decimal" - 
"golang.org/x/text/runes" - - "golang.org/x/text/transform" - "golang.org/x/text/unicode/norm" -) - -const BeginningOfTimeMicro = int64(-62135596800000000) // time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC).UnixMicro() - -func getNumericValueSign(v interface{}, expectedType TableFieldType) (string, interface{}, error) { - var sign string - var newVal interface{} - - switch expectedType { - case FieldTypeInt: - if n, ok := v.(int64); ok { - if n >= 0 { - sign = "0" // "0" > "-" - newVal = n - } else { - sign = "-" - newVal = -n - } - } else { - return "", nil, fmt.Errorf("cannot convert value %v to type %v", v, expectedType) - } - - case FieldTypeFloat: - if f, ok := v.(float64); ok { - if f >= 0 { - sign = "0" // "0" > "-" - newVal = f - } else { - sign = "-" - newVal = -f - } - } else { - return "", nil, fmt.Errorf("cannot convert value %v to type %v", v, expectedType) - } - - case FieldTypeDecimal2: - if d, ok := v.(decimal.Decimal); ok { - if d.Sign() >= 0 { - sign = "0" // "0" > "-" - newVal = d - } else { - sign = "-" - newVal = d.Neg() - } - } else { - return "", nil, errors.New(fmt.Sprintf("cannot convert value %v to type %v", v, expectedType)) - } - - default: - return "", nil, fmt.Errorf("unexpectedly, cannot convert value %v to type %v, type not supported", v, expectedType) - } - return sign, newVal, nil -} - -func BuildKey(fieldMap map[string]interface{}, idxDef *IdxDef) (string, error) { - var keyBuffer bytes.Buffer - t := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC) - flipReplacer := strings.NewReplacer("0", "9", "1", "8", "2", "7", "3", "6", "4", "5", "5", "4", "6", "3", "7", "2", "8", "1", "9", "0") - - for _, comp := range idxDef.Components { - if _, ok := fieldMap[comp.FieldName]; !ok { - return "", fmt.Errorf("cannot find value for field %v in %v while building key for index %v", comp.FieldName, fieldMap, idxDef) - } - - var stringValue string - - switch comp.FieldType { - case FieldTypeInt: - if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeInt); err == nil { - stringValue = fmt.Sprintf("%s%018d", sign, absVal) - // If this is a negative value, flip every digit - if sign == "-" { - stringValue = flipReplacer.Replace(stringValue) - } - } else { - return "", err - } - - case FieldTypeFloat: - // We should support numbers as big as 10^32 and with 32 digits afetr decimal point - if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeFloat); err == nil { - stringValue = strings.ReplaceAll(fmt.Sprintf("%s%66s", sign, fmt.Sprintf("%.32f", absVal)), " ", "0") - // If this is a negative value, flip every digit - if sign == "-" { - stringValue = flipReplacer.Replace(stringValue) - } - } else { - return "", err - } - - case FieldTypeDecimal2: - if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeDecimal2); err == nil { - decVal, _ := absVal.(decimal.Decimal) - floatVal, _ := decVal.Float64() - stringValue = strings.ReplaceAll(fmt.Sprintf("%s%66s", sign, fmt.Sprintf("%.32f", floatVal)), " ", "0") - // If this is a negative value, flip every digit - if sign == "-" { - stringValue = flipReplacer.Replace(stringValue) - } - } else { - return "", err - } - - case FieldTypeDateTime: - // We support time differences up to microsecond. Not nanosecond! Cassandra supports only milliseconds. Millis are our lingua franca. 
- if t, ok := fieldMap[comp.FieldName].(time.Time); ok { - stringValue = fmt.Sprintf("%020d", t.UnixMicro()-BeginningOfTimeMicro) - } else { - return "", fmt.Errorf("cannot convert value %v to type datetime", fieldMap[comp.FieldName]) - } - - case FieldTypeString: - if s, ok := fieldMap[comp.FieldName].(string); ok { - // Normalize the string - transformedString, _, _ := transform.String(t, s) - // Take only first 64 (or whatever we have in StringLen) characters - // use "%-64s" sprint format to pad with spaces on the right - formatString := fmt.Sprintf("%s-%ds", "%", comp.StringLen) - stringValue = fmt.Sprintf(formatString, transformedString)[:comp.StringLen] - if comp.CaseSensitivity == IdxIgnoreCase { - stringValue = strings.ToUpper(stringValue) - } - } else { - return "", fmt.Errorf("cannot convert value %v to type string", fieldMap[comp.FieldName]) - } - - case FieldTypeBool: - if b, ok := fieldMap[comp.FieldName].(bool); ok { - if b { - stringValue = "T" // "F" < "T" - } else { - stringValue = "F" - } - } else { - return "", fmt.Errorf("cannot convert value %v to type bool", fieldMap[comp.FieldName]) - } - - default: - return "", fmt.Errorf(fmt.Sprintf("cannot build key, unsupported field data type %s", comp.FieldType)) - } - - // Used by file creator top. Not used by actual indexes - Cassandra cannot do proper ORDER BY anyways - if comp.SortOrder == IdxSortDesc { - var stringBytes []byte = []byte(stringValue) - for i, b := range stringBytes { - stringBytes[i] = 0xFF - b - } - stringValue = hex.EncodeToString(stringBytes) - } - - keyBuffer.WriteString(stringValue) - } - - return keyBuffer.String(), nil -} +package sc + +import ( + "bytes" + "encoding/hex" + "fmt" + "strings" + "time" + "unicode" + + "github.com/shopspring/decimal" + "golang.org/x/text/runes" + + "golang.org/x/text/transform" + "golang.org/x/text/unicode/norm" +) + +const BeginningOfTimeMicro = int64(-62135596800000000) // time.Date(1, time.January, 1, 0, 0, 0, 0, time.UTC).UnixMicro() + +func getNumericValueSign(v any, expectedType TableFieldType) (string, any, error) { + var sign string + var newVal any + + switch expectedType { + case FieldTypeInt: + if n, ok := v.(int64); ok { + if n >= 0 { + sign = "0" // "0" > "-" + newVal = n + } else { + sign = "-" + newVal = -n + } + } else { + return "", nil, fmt.Errorf("cannot convert value %v to type %v", v, expectedType) + } + + case FieldTypeFloat: + if f, ok := v.(float64); ok { + if f >= 0 { + sign = "0" // "0" > "-" + newVal = f + } else { + sign = "-" + newVal = -f + } + } else { + return "", nil, fmt.Errorf("cannot convert value %v to type %v", v, expectedType) + } + + case FieldTypeDecimal2: + if d, ok := v.(decimal.Decimal); ok { + if d.Sign() >= 0 { + sign = "0" // "0" > "-" + newVal = d + } else { + sign = "-" + newVal = d.Neg() + } + } else { + return "", nil, fmt.Errorf("cannot convert value %v to type %v", v, expectedType) + } + + default: + return "", nil, fmt.Errorf("cannot convert value %v to type %v, type not supported", v, expectedType) + } + return sign, newVal, nil +} + +func BuildKey(fieldMap map[string]any, idxDef *IdxDef) (string, error) { + var keyBuffer bytes.Buffer + t := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC) + flipReplacer := strings.NewReplacer("0", "9", "1", "8", "2", "7", "3", "6", "4", "5", "5", "4", "6", "3", "7", "2", "8", "1", "9", "0") + + for _, comp := range idxDef.Components { + if _, ok := fieldMap[comp.FieldName]; !ok { + return "", fmt.Errorf("cannot find value for field %v in %v while building key 
for index %v", comp.FieldName, fieldMap, idxDef) + } + + var stringValue string + + switch comp.FieldType { + case FieldTypeInt: + if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeInt); err == nil { + stringValue = fmt.Sprintf("%s%018d", sign, absVal) + // If this is a negative value, flip every digit + if sign == "-" { + stringValue = flipReplacer.Replace(stringValue) + } + } else { + return "", err + } + + case FieldTypeFloat: + // We should support numbers as big as 10^32 and with 32 digits afetr decimal point + if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeFloat); err == nil { + stringValue = strings.ReplaceAll(fmt.Sprintf("%s%66s", sign, fmt.Sprintf("%.32f", absVal)), " ", "0") + // If this is a negative value, flip every digit + if sign == "-" { + stringValue = flipReplacer.Replace(stringValue) + } + } else { + return "", err + } + + case FieldTypeDecimal2: + if sign, absVal, err := getNumericValueSign(fieldMap[comp.FieldName], FieldTypeDecimal2); err == nil { + decVal, ok := absVal.(decimal.Decimal) + if !ok { + return "", fmt.Errorf("cannot convert value %v to type decimal2", fieldMap[comp.FieldName]) + } + floatVal, _ := decVal.Float64() + stringValue = strings.ReplaceAll(fmt.Sprintf("%s%66s", sign, fmt.Sprintf("%.32f", floatVal)), " ", "0") + // If this is a negative value, flip every digit + if sign == "-" { + stringValue = flipReplacer.Replace(stringValue) + } + } else { + return "", err + } + + case FieldTypeDateTime: + // We support time differences up to microsecond. Not nanosecond! Cassandra supports only milliseconds. Millis are our lingua franca. + if t, ok := fieldMap[comp.FieldName].(time.Time); ok { + stringValue = fmt.Sprintf("%020d", t.UnixMicro()-BeginningOfTimeMicro) + } else { + return "", fmt.Errorf("cannot convert value %v to type datetime", fieldMap[comp.FieldName]) + } + + case FieldTypeString: + if s, ok := fieldMap[comp.FieldName].(string); ok { + // Normalize the string + transformedString, _, _ := transform.String(t, s) + // Take only first 64 (or whatever we have in StringLen) characters + // use "%-64s" sprint format to pad with spaces on the right + formatString := fmt.Sprintf("%s-%ds", "%", comp.StringLen) + stringValue = fmt.Sprintf(formatString, transformedString)[:comp.StringLen] + if comp.CaseSensitivity == IdxIgnoreCase { + stringValue = strings.ToUpper(stringValue) + } + } else { + return "", fmt.Errorf("cannot convert value %v to type string", fieldMap[comp.FieldName]) + } + + case FieldTypeBool: + if b, ok := fieldMap[comp.FieldName].(bool); ok { + if b { + stringValue = "T" // "F" < "T" + } else { + stringValue = "F" + } + } else { + return "", fmt.Errorf("cannot convert value %v to type bool", fieldMap[comp.FieldName]) + } + + default: + return "", fmt.Errorf(fmt.Sprintf("cannot build key, unsupported field data type %s", comp.FieldType)) + } + + // Used by file creator top. 
Not used by actual indexes - Cassandra cannot do proper ORDER BY anyways + if comp.SortOrder == IdxSortDesc { + stringBytes := []byte(stringValue) + for i, b := range stringBytes { + stringBytes[i] = 0xFF - b + } + stringValue = hex.EncodeToString(stringBytes) + } + + keyBuffer.WriteString(stringValue) + } + + return keyBuffer.String(), nil +} diff --git a/pkg/sc/key_test.go b/pkg/sc/key_test.go index 86ec534..e1abc0f 100644 --- a/pkg/sc/key_test.go +++ b/pkg/sc/key_test.go @@ -18,9 +18,9 @@ func assertKeyErrorPrefix(t *testing.T, expectedErrorPrefix string, actualError func assertKeyCompare( t *testing.T, - row1 map[string]interface{}, + row1 map[string]any, moreLess string, - row2 map[string]interface{}, + row2 map[string]any, idxDef IdxDef) { if moreLess != "<" && moreLess != ">" && moreLess != "==" { @@ -38,8 +38,6 @@ func assertKeyCompare( t.Errorf("%s\n", err2) } - //t.Errorf("\n%s\n%s", key1, key2) - if moreLess == "<" && (key1 >= key2) || moreLess == ">" && (key1 <= key2) || moreLess == "==" && (key1 != key2) { t.Errorf("\nExpected:\n%s\n%s\n%s\n", key1, moreLess, key2) } @@ -48,7 +46,7 @@ func assertKeyCompare( func TestBad(t *testing.T) { idxDef := IdxDef{Uniqueness: "UNIQUE", Components: []IdxComponentDef{{FieldName: "fld", SortOrder: IdxSortAsc, FieldType: FieldTypeInt}}} - row1 := map[string]interface{}{"fld": false} + row1 := map[string]any{"fld": false} _, err := BuildKey(row1, &idxDef) assert.Equal(t, "cannot convert value false to type int", err.Error()) @@ -108,13 +106,13 @@ func TestCombined(t *testing.T) { }, } - row1 := map[string]interface{}{ + row1 := map[string]any{ "field_int": int64(1), "field_string": "abc", "field_float": -2.3, "field_bool": false, } - row2 := map[string]interface{}{ + row2 := map[string]any{ "field_int": int64(1), "field_string": "Abc", "field_float": 1.3, @@ -152,12 +150,12 @@ func TestTime(t *testing.T) { idxDef.Components[0].SortOrder = IdxSortAsc - row1 := map[string]interface{}{"fld": time.Date(1, time.January, 1, 2, 2, 2, 3000, time.UTC)} - row2 := map[string]interface{}{"fld": time.Date(1, time.January, 1, 2, 2, 2, 4000, time.UTC)} + row1 := map[string]any{"fld": time.Date(1, time.January, 1, 2, 2, 2, 3000, time.UTC)} + row2 := map[string]any{"fld": time.Date(1, time.January, 1, 2, 2, 2, 4000, time.UTC)} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": time.Date(850000, time.January, 1, 2, 2, 2, 3000, time.UTC)} - row2 = map[string]interface{}{"fld": time.Date(850000, time.January, 1, 2, 2, 2, 4000, time.UTC)} + row1 = map[string]any{"fld": time.Date(850000, time.January, 1, 2, 2, 2, 3000, time.UTC)} + row2 = map[string]any{"fld": time.Date(850000, time.January, 1, 2, 2, 2, 4000, time.UTC)} assertKeyCompare(t, row1, "<", row2, idxDef) idxDef.Components[0].SortOrder = IdxSortDesc @@ -172,8 +170,8 @@ func TestBool(t *testing.T) { Components: []IdxComponentDef{{FieldName: "fld", FieldType: FieldTypeBool}}, } - row1 := map[string]interface{}{"fld": false} - row2 := map[string]interface{}{"fld": true} + row1 := map[string]any{"fld": false} + row2 := map[string]any{"fld": true} idxDef.Components[0].SortOrder = IdxSortAsc assertKeyCompare(t, row1, "<", row2, idxDef) @@ -191,22 +189,22 @@ func TestInt(t *testing.T) { idxDef.Components[0].SortOrder = IdxSortAsc - row1 := map[string]interface{}{"fld": int64(1000)} - row2 := map[string]interface{}{"fld": int64(2000)} + row1 := map[string]any{"fld": int64(1000)} + row2 := map[string]any{"fld": int64(2000)} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = 
map[string]interface{}{"fld": int64(-1000)} - row2 = map[string]interface{}{"fld": int64(-2000)} + row1 = map[string]any{"fld": int64(-1000)} + row2 = map[string]any{"fld": int64(-2000)} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": int64(-1000)} - row2 = map[string]interface{}{"fld": int64(50)} + row1 = map[string]any{"fld": int64(-1000)} + row2 = map[string]any{"fld": int64(50)} assertKeyCompare(t, row1, "<", row2, idxDef) idxDef.Components[0].SortOrder = IdxSortDesc - row1 = map[string]interface{}{"fld": int64(-1000)} - row2 = map[string]interface{}{"fld": int64(50)} + row1 = map[string]any{"fld": int64(-1000)} + row2 = map[string]any{"fld": int64(50)} assertKeyCompare(t, row1, ">", row2, idxDef) } @@ -219,34 +217,34 @@ func TestFloat(t *testing.T) { idxDef.Components[0].SortOrder = IdxSortAsc - row1 := map[string]interface{}{"fld": 1.1} - row2 := map[string]interface{}{"fld": 1.2} + row1 := map[string]any{"fld": 1.1} + row2 := map[string]any{"fld": 1.2} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": math.Pow10(32)} - row2 = map[string]interface{}{"fld": math.Pow10(32) / 2} + row1 = map[string]any{"fld": math.Pow10(32)} + row2 = map[string]any{"fld": math.Pow10(32) / 2} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": -math.Pow10(32)} - row2 = map[string]interface{}{"fld": -math.Pow10(32) / 2} + row1 = map[string]any{"fld": -math.Pow10(32)} + row2 = map[string]any{"fld": -math.Pow10(32) / 2} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": math.Pow10(-32)} - row2 = map[string]interface{}{"fld": math.Pow10(-32) * 2} + row1 = map[string]any{"fld": math.Pow10(-32)} + row2 = map[string]any{"fld": math.Pow10(-32) * 2} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": -math.Pow10(-32)} - row2 = map[string]interface{}{"fld": -math.Pow10(-32) * 2} + row1 = map[string]any{"fld": -math.Pow10(-32)} + row2 = map[string]any{"fld": -math.Pow10(-32) * 2} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": -1.2} - row2 = map[string]interface{}{"fld": 0.005} + row1 = map[string]any{"fld": -1.2} + row2 = map[string]any{"fld": 0.005} assertKeyCompare(t, row1, "<", row2, idxDef) idxDef.Components[0].SortOrder = IdxSortDesc - row1 = map[string]interface{}{"fld": 1.1} - row2 = map[string]interface{}{"fld": 1.2} + row1 = map[string]any{"fld": 1.1} + row2 = map[string]any{"fld": 1.2} assertKeyCompare(t, row1, ">", row2, idxDef) } @@ -261,35 +259,35 @@ func TestString(t *testing.T) { idxDef.Components[0].SortOrder = IdxSortAsc // Different length - row1 := map[string]interface{}{"fld": "aaa"} - row2 := map[string]interface{}{"fld": "bb"} + row1 := map[string]any{"fld": "aaa"} + row2 := map[string]any{"fld": "bb"} assertKeyCompare(t, row1, "<", row2, idxDef) // Plain - row1 = map[string]interface{}{"fld": "aaa"} - row2 = map[string]interface{}{"fld": "bbb"} + row1 = map[string]any{"fld": "aaa"} + row2 = map[string]any{"fld": "bbb"} assertKeyCompare(t, row1, "<", row2, idxDef) // Ignore case - row1 = map[string]interface{}{"fld": "aaa"} - row2 = map[string]interface{}{"fld": "Abb"} + row1 = map[string]any{"fld": "aaa"} + row2 = map[string]any{"fld": "Abb"} assertKeyCompare(t, row1, "<", row2, idxDef) // Beyond StringLen - row1 = map[string]interface{}{"fld": "1234567890123456A"} - row2 = map[string]interface{}{"fld": "1234567890123456B"} + row1 = map[string]any{"fld": "1234567890123456A"} + row2 = 
map[string]any{"fld": "1234567890123456B"} assertKeyCompare(t, row1, "==", row2, idxDef) // Within StringLen - row1 = map[string]interface{}{"fld": "123456789012345A"} - row2 = map[string]interface{}{"fld": "123456789012345B"} + row1 = map[string]any{"fld": "123456789012345A"} + row2 = map[string]any{"fld": "123456789012345B"} assertKeyCompare(t, row1, "<", row2, idxDef) idxDef.Components[0].SortOrder = IdxSortDesc // Reverse order - row1 = map[string]interface{}{"fld": "aaa"} - row2 = map[string]interface{}{"fld": "bbb"} + row1 = map[string]any{"fld": "aaa"} + row2 = map[string]any{"fld": "bbb"} assertKeyCompare(t, row1, ">", row2, idxDef) } @@ -301,37 +299,42 @@ func TestDecimal(t *testing.T) { idxDef.Components[0].SortOrder = IdxSortAsc - row1 := map[string]interface{}{"fld": decimal.NewFromFloat32(0.23456)} - row2 := map[string]interface{}{"fld": decimal.NewFromFloat32(985.4)} + row1 := map[string]any{"fld": decimal.NewFromFloat32(0.23456)} + row2 := map[string]any{"fld": decimal.NewFromFloat32(985.4)} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.23456)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(-985.4)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(0.23456)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(-985.4)} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.002)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.01)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(0.002)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(0.01)} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(-2000)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(-1000)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(-2000)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(-1000)} assertKeyCompare(t, row1, "<", row2, idxDef) idxDef.Components[0].SortOrder = IdxSortDesc - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.23456)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(985.4)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(0.23456)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(985.4)} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.23456)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(-985.4)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(0.23456)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(-985.4)} assertKeyCompare(t, row1, "<", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.002)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(0.01)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(0.002)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(0.01)} assertKeyCompare(t, row1, ">", row2, idxDef) - row1 = map[string]interface{}{"fld": decimal.NewFromFloat32(-2000)} - row2 = map[string]interface{}{"fld": decimal.NewFromFloat32(-1000)} + row1 = map[string]any{"fld": decimal.NewFromFloat32(-2000)} + row2 = map[string]any{"fld": decimal.NewFromFloat32(-1000)} assertKeyCompare(t, row1, ">", row2, idxDef) } + +func TestGetNUmericValueSign(t *testing.T) { + _, _, err := getNumericValueSign(nil, FieldTypeUnknown) + assert.Contains(t, err.Error(), "cannot convert value to type unknown") +} diff --git a/pkg/sc/lookup_def.go b/pkg/sc/lookup_def.go index 9637a8f..b659f56 
100644 --- a/pkg/sc/lookup_def.go +++ b/pkg/sc/lookup_def.go @@ -1,136 +1,136 @@ -package sc - -import ( - "fmt" - "go/ast" - "strings" - - "github.com/capillariesio/capillaries/pkg/eval" -) - -type LookupJoinType string - -const ( - LookupJoinInner LookupJoinType = "inner" - LookupJoinLeft LookupJoinType = "left" -) - -type LookupDef struct { - IndexName string `json:"index_name"` - RawJoinOn string `json:"join_on"` - IsGroup bool `json:"group"` - RawFilter string `json:"filter"` - LookupJoin LookupJoinType `json:"join_type"` - IdxReadBatchSize int `json:"idx_read_batch_size"` - RightLookupReadBatchSize int `json:"right_lookup_read_batch_size"` - - LeftTableFields FieldRefs // In the same order as lookup idx - important - TableCreator *TableCreatorDef // Populated when walking through al nodes - UsedInFilterFields FieldRefs - Filter ast.Expr -} - -const ( - defaultIdxBatchSize int = 3000 - maxIdxBatchSize int = 5000 - defaultRightLookupBatchSize int = 3000 - maxRightLookupReadBatchSize int = 5000 -) - -func (lkpDef *LookupDef) CheckPagedBatchSize() error { - // Default gocql iterator page size is 5000, do not exceed it. - // Actually, getting close to it (4000) causes problems on small servers - if lkpDef.IdxReadBatchSize <= 0 { - lkpDef.IdxReadBatchSize = defaultIdxBatchSize - } else if lkpDef.IdxReadBatchSize > maxIdxBatchSize { - return fmt.Errorf("cannot use idx_read_batch_size %d, expected <= %d, default %d, ", lkpDef.IdxReadBatchSize, maxIdxBatchSize, defaultIdxBatchSize) - } - if lkpDef.RightLookupReadBatchSize <= 0 { - lkpDef.RightLookupReadBatchSize = defaultRightLookupBatchSize - } else if lkpDef.RightLookupReadBatchSize > maxRightLookupReadBatchSize { - return fmt.Errorf("cannot use right_lookup_read_batch_size %d, expected <= %d, default %d, ", lkpDef.RightLookupReadBatchSize, maxRightLookupReadBatchSize, defaultRightLookupBatchSize) - } - return nil -} - -func (lkpDef *LookupDef) UsesFilter() bool { - return len(strings.TrimSpace(lkpDef.RawFilter)) > 0 -} - -func (lkpDef *LookupDef) ValidateJoinType() error { - if lkpDef.LookupJoin != LookupJoinLeft && lkpDef.LookupJoin != LookupJoinInner { - return fmt.Errorf("invalid join type, expected inner or left, %s is not supported", lkpDef.LookupJoin) - } - return nil -} - -func (lkpDef *LookupDef) ParseFilter() error { - if !lkpDef.UsesFilter() { - return nil - } - var err error - lkpDef.Filter, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(lkpDef.RawFilter, &lkpDef.UsedInFilterFields) - if err != nil { - return fmt.Errorf("cannot parse lookup filter condition [%s]: %s", lkpDef.RawFilter, err.Error()) - } - return nil -} - -func (lkpDef *LookupDef) resolveLeftTableFields(srcName string, srcFieldRefs *FieldRefs) error { - fieldExpressions := strings.Split(lkpDef.RawJoinOn, ",") - lkpDef.LeftTableFields = make(FieldRefs, len(fieldExpressions)) - for fieldIdx := 0; fieldIdx < len(fieldExpressions); fieldIdx++ { - fieldNameParts := strings.Split(strings.TrimSpace(fieldExpressions[fieldIdx]), ".") - if len(fieldNameParts) != 2 { - return fmt.Errorf("expected a comma-separated list of ., got [%s]", lkpDef.RawJoinOn) - } - tName := strings.TrimSpace(fieldNameParts[0]) - fName := strings.TrimSpace(fieldNameParts[1]) - if tName != srcName { - return fmt.Errorf("source table name [%s] unknown, expected [%s]", tName, srcName) - } - srcFieldRef, ok := srcFieldRefs.FindByFieldName(fName) - if !ok { - return fmt.Errorf("source [%s] does not produce field [%s]", tName, fName) - } - if srcFieldRef.FieldType == FieldTypeUnknown { - return 
fmt.Errorf("source field [%s.%s] has unknown type", tName, fName) - } - lkpDef.LeftTableFields[fieldIdx] = *srcFieldRef - } - - // Verify lookup idx has this field and the type matches - idxFieldRefs := lkpDef.TableCreator.Indexes[lkpDef.IndexName].getComponentFieldRefs(lkpDef.TableCreator.Name) - if len(idxFieldRefs) != len(lkpDef.LeftTableFields) { - return fmt.Errorf("lookup joins on %d fields, while referenced index %s uses %d fields, these lengths need to be the same", len(lkpDef.LeftTableFields), lkpDef.IndexName, len(idxFieldRefs)) - } - - for fieldIdx := 0; fieldIdx < len(lkpDef.LeftTableFields); fieldIdx++ { - if lkpDef.LeftTableFields[fieldIdx].FieldType != idxFieldRefs[fieldIdx].FieldType { - return fmt.Errorf("left-side field %s has type %s, while index field %s has type %s", - lkpDef.LeftTableFields[fieldIdx].FieldName, - lkpDef.LeftTableFields[fieldIdx].FieldType, - idxFieldRefs[fieldIdx].FieldName, - idxFieldRefs[fieldIdx].FieldType) - } - } - - return nil -} - -func (lkpDef *LookupDef) CheckFilterCondition(varsFromLookup eval.VarValuesMap) (bool, error) { - if !lkpDef.UsesFilter() { - return true, nil - } - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &varsFromLookup) - valVolatile, err := eCtx.Eval(lkpDef.Filter) - if err != nil { - return false, fmt.Errorf("cannot evaluate expression: [%s]", err.Error()) - } - valBool, ok := valVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot evaluate lookup filter condition expression, expected bool, got %v(%T) instead", valVolatile, valVolatile) - } - - return valBool, nil -} +package sc + +import ( + "fmt" + "go/ast" + "strings" + + "github.com/capillariesio/capillaries/pkg/eval" +) + +type LookupJoinType string + +const ( + LookupJoinInner LookupJoinType = "inner" + LookupJoinLeft LookupJoinType = "left" +) + +type LookupDef struct { + IndexName string `json:"index_name"` + RawJoinOn string `json:"join_on"` + IsGroup bool `json:"group"` + RawFilter string `json:"filter"` + LookupJoin LookupJoinType `json:"join_type"` + IdxReadBatchSize int `json:"idx_read_batch_size"` + RightLookupReadBatchSize int `json:"right_lookup_read_batch_size"` + + LeftTableFields FieldRefs // In the same order as lookup idx - important + TableCreator *TableCreatorDef // Populated when walking through al nodes + UsedInFilterFields FieldRefs + Filter ast.Expr +} + +const ( + defaultIdxBatchSize int = 3000 + maxIdxBatchSize int = 5000 + defaultRightLookupBatchSize int = 3000 + maxRightLookupReadBatchSize int = 5000 +) + +func (lkpDef *LookupDef) CheckPagedBatchSize() error { + // Default gocql iterator page size is 5000, do not exceed it. 
+ // Actually, getting close to it (4000) causes problems on small servers + if lkpDef.IdxReadBatchSize <= 0 { + lkpDef.IdxReadBatchSize = defaultIdxBatchSize + } else if lkpDef.IdxReadBatchSize > maxIdxBatchSize { + return fmt.Errorf("cannot use idx_read_batch_size %d, expected <= %d, default %d, ", lkpDef.IdxReadBatchSize, maxIdxBatchSize, defaultIdxBatchSize) + } + if lkpDef.RightLookupReadBatchSize <= 0 { + lkpDef.RightLookupReadBatchSize = defaultRightLookupBatchSize + } else if lkpDef.RightLookupReadBatchSize > maxRightLookupReadBatchSize { + return fmt.Errorf("cannot use right_lookup_read_batch_size %d, expected <= %d, default %d, ", lkpDef.RightLookupReadBatchSize, maxRightLookupReadBatchSize, defaultRightLookupBatchSize) + } + return nil +} + +func (lkpDef *LookupDef) UsesFilter() bool { + return len(strings.TrimSpace(lkpDef.RawFilter)) > 0 +} + +func (lkpDef *LookupDef) ValidateJoinType() error { + if lkpDef.LookupJoin != LookupJoinLeft && lkpDef.LookupJoin != LookupJoinInner { + return fmt.Errorf("invalid join type, expected inner or left, %s is not supported", lkpDef.LookupJoin) + } + return nil +} + +func (lkpDef *LookupDef) ParseFilter() error { + if !lkpDef.UsesFilter() { + return nil + } + var err error + lkpDef.Filter, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(lkpDef.RawFilter, &lkpDef.UsedInFilterFields) + if err != nil { + return fmt.Errorf("cannot parse lookup filter condition [%s]: %s", lkpDef.RawFilter, err.Error()) + } + return nil +} + +func (lkpDef *LookupDef) resolveLeftTableFields(srcName string, srcFieldRefs *FieldRefs) error { + fieldExpressions := strings.Split(lkpDef.RawJoinOn, ",") + lkpDef.LeftTableFields = make(FieldRefs, len(fieldExpressions)) + for fieldIdx := 0; fieldIdx < len(fieldExpressions); fieldIdx++ { + fieldNameParts := strings.Split(strings.TrimSpace(fieldExpressions[fieldIdx]), ".") + if len(fieldNameParts) != 2 { + return fmt.Errorf("expected a comma-separated list of ., got [%s]", lkpDef.RawJoinOn) + } + tName := strings.TrimSpace(fieldNameParts[0]) + fName := strings.TrimSpace(fieldNameParts[1]) + if tName != srcName { + return fmt.Errorf("source table name [%s] unknown, expected [%s]", tName, srcName) + } + srcFieldRef, ok := srcFieldRefs.FindByFieldName(fName) + if !ok { + return fmt.Errorf("source [%s] does not produce field [%s]", tName, fName) + } + if srcFieldRef.FieldType == FieldTypeUnknown { + return fmt.Errorf("source field [%s.%s] has unknown type", tName, fName) + } + lkpDef.LeftTableFields[fieldIdx] = *srcFieldRef + } + + // Verify lookup idx has this field and the type matches + idxFieldRefs := lkpDef.TableCreator.Indexes[lkpDef.IndexName].getComponentFieldRefs(lkpDef.TableCreator.Name) + if len(idxFieldRefs) != len(lkpDef.LeftTableFields) { + return fmt.Errorf("lookup joins on %d fields, while referenced index %s uses %d fields, these lengths need to be the same", len(lkpDef.LeftTableFields), lkpDef.IndexName, len(idxFieldRefs)) + } + + for fieldIdx := 0; fieldIdx < len(lkpDef.LeftTableFields); fieldIdx++ { + if lkpDef.LeftTableFields[fieldIdx].FieldType != idxFieldRefs[fieldIdx].FieldType { + return fmt.Errorf("left-side field %s has type %s, while index field %s has type %s", + lkpDef.LeftTableFields[fieldIdx].FieldName, + lkpDef.LeftTableFields[fieldIdx].FieldType, + idxFieldRefs[fieldIdx].FieldName, + idxFieldRefs[fieldIdx].FieldType) + } + } + + return nil +} + +func (lkpDef *LookupDef) CheckFilterCondition(varsFromLookup eval.VarValuesMap) (bool, error) { + if !lkpDef.UsesFilter() { + return true, nil + } 
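By the time execution reaches the evaluation below, ParseFilter has already turned RawFilter into an ast.Expr via ParseRawGolangExpressionStringAndHarvestFieldRefs. The following is a stdlib-only sketch of that parse-and-harvest idea for a filter such as the len(l.product_id) > 0 used in the tests; it is not the package's actual implementation, which also records the references into FieldRefs and evaluates the expression through pkg/eval.

// Illustrative, stdlib-only sketch: parse a lookup filter as a Go expression
// and harvest its alias.field references.
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
)

func main() {
	expr, err := parser.ParseExpr("len(l.product_id) > 0")
	if err != nil {
		panic(err)
	}
	// Collect alias.field references (e.g. l.product_id) from the AST.
	ast.Inspect(expr, func(n ast.Node) bool {
		if sel, ok := n.(*ast.SelectorExpr); ok {
			if alias, ok := sel.X.(*ast.Ident); ok {
				fmt.Printf("alias=%s field=%s\n", alias.Name, sel.Sel.Name)
			}
		}
		return true
	})
	// Output: alias=l field=product_id
}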
+ eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &varsFromLookup) + valVolatile, err := eCtx.Eval(lkpDef.Filter) + if err != nil { + return false, fmt.Errorf("cannot evaluate expression: [%s]", err.Error()) + } + valBool, ok := valVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot evaluate lookup filter condition expression, expected bool, got %v(%T) instead", valVolatile, valVolatile) + } + + return valBool, nil +} diff --git a/pkg/sc/lookup_def_test.go b/pkg/sc/lookup_def_test.go index d6ef5c6..92fa772 100644 --- a/pkg/sc/lookup_def_test.go +++ b/pkg/sc/lookup_def_test.go @@ -1,250 +1,250 @@ -package sc - -import ( - "regexp" - "testing" - - "github.com/stretchr/testify/assert" -) - -const scriptDefJson string = ` -{ - "nodes": { - "read_orders": { - "type": "file_table", - "r": { - "urls": [ - "{dir_in}/olist_orders_dataset.csv" - ], - "csv": { - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv": { - "col_hdr": "order_id" - }, - "col_type": "string" - }, - "col_order_status": { - "csv": { - "col_hdr": "order_status" - }, - "col_type": "string" - }, - "col_order_purchase_timestamp": { - "csv": { - "col_hdr": "order_purchase_timestamp", - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - } - } - }, - "w": { - "name": "orders", - "fields": { - "order_id": { - "expression": "r.col_order_id", - "type": "string" - }, - "order_status": { - "expression": "r.col_order_status", - "type": "string" - }, - "order_purchase_timestamp": { - "expression": "r.col_order_purchase_timestamp", - "type": "datetime" - } - } - } - }, - "read_order_items": { - "type": "file_table", - "r": { - "urls": [ - "{dir_in}/olist_order_items_dataset.csv" - ], - "csv":{ - "hdr_line_idx": 0, - "first_data_line_idx": 1 - }, - "columns": { - "col_order_id": { - "csv": { - "col_idx": 0, - "col_hdr": null - }, - "col_type": "string" - }, - "col_order_item_id": { - "csv": { - "col_idx": 1, - "col_hdr": null, - "col_format": "%d" - }, - "col_type": "int" - }, - "col_product_id": { - "csv": { - "col_idx": 2, - "col_hdr": null - }, - "col_type": "string" - }, - "col_seller_id": { - "csv": { - "col_idx": 3, - "col_hdr": null - }, - "col_type": "string" - }, - "col_shipping_limit_date": { - "csv": { - "col_idx": 4, - "col_hdr": null, - "col_format": "2006-01-02 15:04:05" - }, - "col_type": "datetime" - }, - "col_price": { - "csv": { - "col_idx": 5, - "col_hdr": null, - "col_format": "%f" - }, - "col_type": "decimal2" - }, - "col_freight_value": { - "csv": { - "col_idx": 6, - "col_hdr": null, - "col_format": "%f" - }, - "col_type": "decimal2" - } - } - }, - "w": { - "name": "order_items", - "having": null, - "fields": { - "order_id": { - "expression": "r.col_order_id", - "type": "string" - }, - "order_item_id": { - "expression": "r.col_order_item_id", - "type": "int" - }, - "product_id": { - "expression": "r.col_product_id", - "type": "string" - }, - "seller_id": { - "expression": "r.col_seller_id", - "type": "string" - }, - "shipping_limit_date": { - "expression": "r.col_shipping_limit_date", - "type": "datetime" - }, - "value": { - "expression": "r.col_price+r.col_freight_value", - "type": "decimal2" - } - }, - "indexes": { - "idx_order_items_order_id": "non_unique(order_id(case_sensitive))" - } - } - }, - "order_item_date_inner": { - "type": "table_lookup_table", - "r": { - "table": "orders", - "expected_batches_total": 100 - }, - "l": { - "index_name": "idx_order_items_order_id", - "idx_read_batch_size": 3000, - "right_lookup_read_batch_size": 
5000, - "filter": "len(l.product_id) > 0", - "join_on": "r.order_id", - "group": false, - "join_type": "inner" - }, - "w": { - "name": "order_item_date_inner", - "fields": { - "order_id": { - "expression": "r.order_id", - "type": "string" - }, - "order_purchase_timestamp": { - "expression": "r.order_purchase_timestamp", - "type": "datetime" - }, - "order_item_id": { - "expression": "l.order_item_id", - "type": "int" - }, - "product_id": { - "expression": "l.product_id", - "type": "string" - }, - "seller_id": { - "expression": "l.seller_id", - "type": "string" - }, - "shipping_limit_date": { - "expression": "l.shipping_limit_date", - "type": "datetime" - }, - "value": { - "expression": "l.value", - "type": "decimal2" - } - } - } - } - }, - "dependency_policies": { - "current_active_first_stopped_nogo":` + DefaultPolicyCheckerConf + - ` - } -}` - -func TestLookupDef(t *testing.T) { - scriptDef := ScriptDef{} - assert.Nil(t, scriptDef.Deserialize([]byte(scriptDefJson), nil, nil, "", nil)) - - re := regexp.MustCompile(`"idx_read_batch_size": [\d]+`) - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"idx_read_batch_size": 10000`)), nil, nil, "", nil).Error(), - "cannot use idx_read_batch_size 10000, expected <= 5000") - - re = regexp.MustCompile(`"right_lookup_read_batch_size": [\d]+`) - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"right_lookup_read_batch_size": 10000`)), nil, nil, "", nil).Error(), - "cannot use right_lookup_read_batch_size 10000, expected <= 5000") - - re = regexp.MustCompile(`"filter": "[^"]+",`) - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"filter": "aaa",`)), nil, nil, "", nil).Error(), - "cannot parse lookup filter condition") - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"filter": "123",`)), nil, nil, "", nil).Error(), - "cannot evaluate lookup filter expression [123]: [expected type bool, but got int64 (123)]") - assert.Nil(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, ``)), nil, nil, "", nil)) - - re = regexp.MustCompile(`"join_on": "[^"]+",`) - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"join_on": "r.order_id,r.order_status",`)), nil, nil, "", nil).Error(), - "lookup joins on 2 fields, while referenced index idx_order_items_order_id uses 1 fields, these lengths need to be the same") - assert.Contains(t, - scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"join_on": "",`)), nil, nil, "", nil).Error(), - "failed to resolve lookup for node order_item_date_inner: [expected a comma-separated list of ., got []]") -} +package sc + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +const scriptDefJson string = ` +{ + "nodes": { + "read_orders": { + "type": "file_table", + "r": { + "urls": [ + "{dir_in}/olist_orders_dataset.csv" + ], + "csv": { + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv": { + "col_hdr": "order_id" + }, + "col_type": "string" + }, + "col_order_status": { + "csv": { + "col_hdr": "order_status" + }, + "col_type": "string" + }, + "col_order_purchase_timestamp": { + "csv": { + "col_hdr": "order_purchase_timestamp", + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + } + } + }, + "w": { + "name": "orders", + "fields": { + "order_id": { + "expression": "r.col_order_id", + "type": "string" + }, + "order_status": { + "expression": 
"r.col_order_status", + "type": "string" + }, + "order_purchase_timestamp": { + "expression": "r.col_order_purchase_timestamp", + "type": "datetime" + } + } + } + }, + "read_order_items": { + "type": "file_table", + "r": { + "urls": [ + "{dir_in}/olist_order_items_dataset.csv" + ], + "csv":{ + "hdr_line_idx": 0, + "first_data_line_idx": 1 + }, + "columns": { + "col_order_id": { + "csv": { + "col_idx": 0, + "col_hdr": null + }, + "col_type": "string" + }, + "col_order_item_id": { + "csv": { + "col_idx": 1, + "col_hdr": null, + "col_format": "%d" + }, + "col_type": "int" + }, + "col_product_id": { + "csv": { + "col_idx": 2, + "col_hdr": null + }, + "col_type": "string" + }, + "col_seller_id": { + "csv": { + "col_idx": 3, + "col_hdr": null + }, + "col_type": "string" + }, + "col_shipping_limit_date": { + "csv": { + "col_idx": 4, + "col_hdr": null, + "col_format": "2006-01-02 15:04:05" + }, + "col_type": "datetime" + }, + "col_price": { + "csv": { + "col_idx": 5, + "col_hdr": null, + "col_format": "%f" + }, + "col_type": "decimal2" + }, + "col_freight_value": { + "csv": { + "col_idx": 6, + "col_hdr": null, + "col_format": "%f" + }, + "col_type": "decimal2" + } + } + }, + "w": { + "name": "order_items", + "having": null, + "fields": { + "order_id": { + "expression": "r.col_order_id", + "type": "string" + }, + "order_item_id": { + "expression": "r.col_order_item_id", + "type": "int" + }, + "product_id": { + "expression": "r.col_product_id", + "type": "string" + }, + "seller_id": { + "expression": "r.col_seller_id", + "type": "string" + }, + "shipping_limit_date": { + "expression": "r.col_shipping_limit_date", + "type": "datetime" + }, + "value": { + "expression": "r.col_price+r.col_freight_value", + "type": "decimal2" + } + }, + "indexes": { + "idx_order_items_order_id": "non_unique(order_id(case_sensitive))" + } + } + }, + "order_item_date_inner": { + "type": "table_lookup_table", + "r": { + "table": "orders", + "expected_batches_total": 100 + }, + "l": { + "index_name": "idx_order_items_order_id", + "idx_read_batch_size": 3000, + "right_lookup_read_batch_size": 5000, + "filter": "len(l.product_id) > 0", + "join_on": "r.order_id", + "group": false, + "join_type": "inner" + }, + "w": { + "name": "order_item_date_inner", + "fields": { + "order_id": { + "expression": "r.order_id", + "type": "string" + }, + "order_purchase_timestamp": { + "expression": "r.order_purchase_timestamp", + "type": "datetime" + }, + "order_item_id": { + "expression": "l.order_item_id", + "type": "int" + }, + "product_id": { + "expression": "l.product_id", + "type": "string" + }, + "seller_id": { + "expression": "l.seller_id", + "type": "string" + }, + "shipping_limit_date": { + "expression": "l.shipping_limit_date", + "type": "datetime" + }, + "value": { + "expression": "l.value", + "type": "decimal2" + } + } + } + } + }, + "dependency_policies": { + "current_active_first_stopped_nogo":` + DefaultPolicyCheckerConf + + ` + } +}` + +func TestLookupDef(t *testing.T) { + scriptDef := ScriptDef{} + assert.Nil(t, scriptDef.Deserialize([]byte(scriptDefJson), nil, nil, "", nil)) + + re := regexp.MustCompile(`"idx_read_batch_size": [\d]+`) + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"idx_read_batch_size": 10000`)), nil, nil, "", nil).Error(), + "cannot use idx_read_batch_size 10000, expected <= 5000") + + re = regexp.MustCompile(`"right_lookup_read_batch_size": [\d]+`) + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"right_lookup_read_batch_size": 
10000`)), nil, nil, "", nil).Error(), + "cannot use right_lookup_read_batch_size 10000, expected <= 5000") + + re = regexp.MustCompile(`"filter": "[^"]+",`) + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"filter": "aaa",`)), nil, nil, "", nil).Error(), + "cannot parse lookup filter condition") + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"filter": "123",`)), nil, nil, "", nil).Error(), + "cannot evaluate lookup filter expression [123]: [expected type bool, but got int64 (123)]") + assert.Nil(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, ``)), nil, nil, "", nil)) + + re = regexp.MustCompile(`"join_on": "[^"]+",`) + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"join_on": "r.order_id,r.order_status",`)), nil, nil, "", nil).Error(), + "lookup joins on 2 fields, while referenced index idx_order_items_order_id uses 1 fields, these lengths need to be the same") + assert.Contains(t, + scriptDef.Deserialize([]byte(re.ReplaceAllString(scriptDefJson, `"join_on": "",`)), nil, nil, "", nil).Error(), + "failed to resolve lookup for node order_item_date_inner: [expected a comma-separated list of ., got []]") +} diff --git a/pkg/sc/script_def.go b/pkg/sc/script_def.go index 7490bc4..6314efd 100644 --- a/pkg/sc/script_def.go +++ b/pkg/sc/script_def.go @@ -1,309 +1,304 @@ -package sc - -import ( - "encoding/json" - "fmt" - "strings" -) - -const ( - ReservedParamBatchIdx string = "{batch_idx|string}" - ReservedParamRunId string = "{run_id|string}" -) - -type ScriptDef struct { - ScriptNodes map[string]*ScriptNodeDef `json:"nodes"` - RawDependencyPolicies map[string]json.RawMessage `json:"dependency_policies"` - TableCreatorNodeMap map[string](*ScriptNodeDef) - IndexNodeMap map[string](*ScriptNodeDef) -} - -func (scriptDef *ScriptDef) Deserialize(jsonBytesScript []byte, customProcessorDefFactory CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage, caPath string, privateKeys map[string]string) error { - - if err := json.Unmarshal(jsonBytesScript, &scriptDef); err != nil { - return fmt.Errorf("cannot unmarshal script json: [%s]", err.Error()) - } - - errors := make([]string, 0, 2) - - // Deserialize node by node - for nodeName, node := range scriptDef.ScriptNodes { - node.Name = nodeName - if err := node.Deserialize(customProcessorDefFactory, customProcessorsSettings, caPath, privateKeys); err != nil { - errors = append(errors, fmt.Sprintf("cannot deserialize node %s: [%s]", nodeName, err.Error())) - } - } - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } - - // Table -> node map, to look for ord and lkp indexes, for those nodes that create tables - scriptDef.TableCreatorNodeMap = map[string]*ScriptNodeDef{} - for _, node := range scriptDef.ScriptNodes { - if node.HasTableCreator() { - if _, ok := scriptDef.TableCreatorNodeMap[node.TableCreator.Name]; ok { - return fmt.Errorf("duplicate table name: %s", node.TableCreator.Name) - } - scriptDef.TableCreatorNodeMap[node.TableCreator.Name] = node - } - } - // Index -> node map, to look for ord and lkp indexes, for those nodes that create tables - scriptDef.IndexNodeMap = map[string]*ScriptNodeDef{} - for _, node := range scriptDef.ScriptNodes { - if node.HasTableCreator() { - for idxName := range node.TableCreator.Indexes { - if _, ok := scriptDef.IndexNodeMap[idxName]; ok { - return fmt.Errorf("duplicate index name: %s", idxName) - } - if _, ok := 
scriptDef.TableCreatorNodeMap[idxName]; ok { - return fmt.Errorf("cannot use same name for table and index: %s", idxName) - } - scriptDef.IndexNodeMap[idxName] = node - } - } - } - - for _, node := range scriptDef.ScriptNodes { - if err := scriptDef.resolveReader(node); err != nil { - return fmt.Errorf("failed to resolve reader for node %s: [%s]", node.Name, err.Error()) - } - } - - for _, node := range scriptDef.ScriptNodes { - if err := scriptDef.resolveLookup(node); err != nil { - return fmt.Errorf("failed to resolve lookup for node %s: [%s]", node.Name, err.Error()) - } - } - - for _, node := range scriptDef.ScriptNodes { - if err := scriptDef.checkFieldUsageInCustomProcessorCreator(node); err != nil { - return fmt.Errorf("field usage error in custom processor creator, node %s: [%s]", node.Name, err.Error()) - } - } - - for _, node := range scriptDef.ScriptNodes { - if err := scriptDef.checkFieldUsageInCreator(node); err != nil { - return fmt.Errorf("field usage error in creator, node %s: [%s]", node.Name, err.Error()) - } - } - - for _, node := range scriptDef.ScriptNodes { - if err := node.evalCreatorAndLookupExpressionsAndCheckType(); err != nil { - return fmt.Errorf("failed evaluating creator/lookup expressions for node %s: [%s]", node.Name, err.Error()) - } - } - - depPolMap := map[string](*DependencyPolicyDef){} - defaultDepPolCount := 0 - var defaultDepPol *DependencyPolicyDef - for polName, rawPolDef := range scriptDef.RawDependencyPolicies { - pol := DependencyPolicyDef{} - if err := pol.Deserialize(rawPolDef); err != nil { - return fmt.Errorf("failed to deserialize dependency policy %s: %s", polName, err.Error()) - } - depPolMap[polName] = &pol - if pol.IsDefault { - defaultDepPol = &pol - defaultDepPolCount++ - } - } - if defaultDepPolCount != 1 { - return fmt.Errorf("failed to deserialize dependency policies, found %d default policies, required 1", defaultDepPolCount) - } - - for polName, polDef := range depPolMap { - if err := polDef.evalRuleExpressionsAndCheckType(); err != nil { - return fmt.Errorf("failed to test dependency policy %s rules: %s", polName, err.Error()) - } - } - - for _, node := range scriptDef.ScriptNodes { - if node.HasTableReader() { - if len(node.DependencyPolicyName) == 0 { - node.DepPolDef = defaultDepPol - } else { - var ok bool - node.DepPolDef, ok = depPolMap[node.DependencyPolicyName] - if !ok { - return fmt.Errorf("cannot find dependency policy %s for node %s", node.DependencyPolicyName, node.Name) - } - } - } - } - - return nil -} - -func (scriptDef *ScriptDef) resolveReader(node *ScriptNodeDef) error { - if node.HasTableReader() { - tableCreatorNode, ok := scriptDef.TableCreatorNodeMap[node.TableReader.TableName] - if !ok { - return fmt.Errorf("cannot find the node that creates table [%s]", node.TableReader.TableName) - } - node.TableReader.TableCreator = &tableCreatorNode.TableCreator - } - return nil -} - -func (scriptDef *ScriptDef) resolveLookup(node *ScriptNodeDef) error { - if !node.HasLookup() { - return nil - } - - srcFieldRefs, err := node.getSourceFieldRefs() - if err != nil { - return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) - } - idxCreatorNode, ok := scriptDef.IndexNodeMap[node.Lookup.IndexName] - if !ok { - return fmt.Errorf("cannot find the node that creates index [%s]", node.Lookup.IndexName) - } - - node.Lookup.TableCreator = &idxCreatorNode.TableCreator - - if err = node.Lookup.resolveLeftTableFields(ReaderAlias, srcFieldRefs); err != nil { - return err - } - - if err = 
node.Lookup.ParseFilter(); err != nil { - return err - } - - if err = node.Lookup.ValidateJoinType(); err != nil { - return err - } - - if err = node.Lookup.CheckPagedBatchSize(); err != nil { - return err - } - - return nil - -} - -func (scriptDef *ScriptDef) checkFieldUsageInCreator(node *ScriptNodeDef) error { - srcFieldRefs, err := node.getSourceFieldRefs() - if err != nil { - return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) - } - - var processorFieldRefs *FieldRefs - if node.HasCustomProcessor() { - processorFieldRefs = node.CustomProcessor.GetFieldRefs() - if err != nil { - return fmt.Errorf("cannot resolve processor field refs: [%s]", err.Error()) - } - } - - var lookupFieldRefs *FieldRefs - if node.HasLookup() { - lookupFieldRefs = node.Lookup.TableCreator.GetFieldRefsWithAlias(LookupAlias) - } - - errors := make([]string, 0) - - var targetFieldRefs *FieldRefs - if node.HasTableCreator() { - targetFieldRefs = node.TableCreator.GetFieldRefsWithAlias(CreatorAlias) - } else if node.HasFileCreator() { - targetFieldRefs = node.FileCreator.getFieldRefs() - } else { - return fmt.Errorf("dev error, unknown creator") - } - - // Lookup - if node.HasLookup() && node.Lookup.UsesFilter() { - // Having: allow only lookup table, prohibit src and tgt - if err := checkAllowed(&node.Lookup.UsedInFilterFields, JoinFieldRefs(srcFieldRefs, targetFieldRefs), lookupFieldRefs); err != nil { - errors = append(errors, fmt.Sprintf("invalid field in lookup filter [%s], only fields from the lookup table [%s](alias %s) are allowed: [%s]", node.Lookup.RawFilter, node.Lookup.TableCreator.Name, LookupAlias, err.Error())) - } - } - - // Table creator - if node.HasTableCreator() { - srcLkpCustomFieldRefs := JoinFieldRefs(srcFieldRefs, lookupFieldRefs, processorFieldRefs) - // Having: allow tgt fields, prohibit src, lkp - if err := checkAllowed(&node.TableCreator.UsedInHavingFields, srcLkpCustomFieldRefs, targetFieldRefs); err != nil { - errors = append(errors, fmt.Sprintf("invalid field in table creator 'having' condition: [%s]; only target (w.*) fields allowed, reader (r.*) and lookup (l.*) fields are prohibited", err.Error())) - } - // Tgt expressions: allow src iterator table (or src file), lkp, custom processor, prohibit target - // TODO: aggregate functions cannot include fields from group field list - if err := checkAllowed(&node.TableCreator.UsedInTargetExpressionsFields, targetFieldRefs, srcLkpCustomFieldRefs); err != nil { - errors = append(errors, fmt.Sprintf("invalid field(s) in target table field expression: [%s]", err.Error())) - } - } - - // File creator - if node.HasFileCreator() { - // Having: allow tgt fields, prohibit src - if err := checkAllowed(&node.FileCreator.UsedInHavingFields, srcFieldRefs, targetFieldRefs); err != nil { - errors = append(errors, fmt.Sprintf("invalid field in file creator 'having' condition: [%s]", err.Error())) - } - - // Tgt expressions: allow src, prohibit target fields - // TODO: aggregate functions cannot include fields from group field list - if err := checkAllowed(&node.FileCreator.UsedInTargetExpressionsFields, targetFieldRefs, srcFieldRefs); err != nil { - errors = append(errors, fmt.Sprintf("invalid field in target file field expression: [%s]", err.Error())) - } - } - - if len(errors) > 0 { - return fmt.Errorf("%s", strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (scriptDef *ScriptDef) checkFieldUsageInCustomProcessorCreator(node *ScriptNodeDef) error { - if !node.HasCustomProcessor() { - return nil - } - - 
srcFieldRefs, err := node.getSourceFieldRefs() - if err != nil { - return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) - } - - procTgtFieldRefs := node.CustomProcessor.GetFieldRefs() - - // In processor fields, we are allowed to use only reader and processor fields ("r" and "p") - if err := checkAllowed(node.CustomProcessor.GetUsedInTargetExpressionsFields(), nil, JoinFieldRefs(srcFieldRefs, procTgtFieldRefs)); err != nil { - return fmt.Errorf("invalid field(s) in target table field expression: [%s]", err.Error()) - } - - return nil -} - -func (scriptDef *ScriptDef) addToAffected(rootNode *ScriptNodeDef, affectedSet map[string]struct{}) { - if _, ok := affectedSet[rootNode.Name]; ok { - return - } - - affectedSet[rootNode.Name] = struct{}{} - - for _, node := range scriptDef.ScriptNodes { - if rootNode.HasTableCreator() && node.HasTableReader() && rootNode.TableCreator.Name == node.TableReader.TableName && node.StartPolicy == NodeStartAuto { - scriptDef.addToAffected(node, affectedSet) - } else if rootNode.HasTableCreator() && node.HasLookup() && rootNode.TableCreator.Name == node.Lookup.TableCreator.Name && node.StartPolicy == NodeStartAuto { - scriptDef.addToAffected(node, affectedSet) - } - } -} - -func (scriptDef *ScriptDef) GetAffectedNodes(startNodeNames []string) []string { - affectedSet := map[string]struct{}{} - for _, nodeName := range startNodeNames { - if node, ok := scriptDef.ScriptNodes[nodeName]; ok { - scriptDef.addToAffected(node, affectedSet) - } - } - - affectedList := make([]string, len(affectedSet)) - i := 0 - for k := range affectedSet { - affectedList[i] = k - i++ - } - return affectedList -} +package sc + +import ( + "encoding/json" + "fmt" + "strings" +) + +const ( + ReservedParamBatchIdx string = "{batch_idx|string}" + ReservedParamRunId string = "{run_id|string}" +) + +type ScriptDef struct { + ScriptNodes map[string]*ScriptNodeDef `json:"nodes"` + RawDependencyPolicies map[string]json.RawMessage `json:"dependency_policies"` + TableCreatorNodeMap map[string](*ScriptNodeDef) + IndexNodeMap map[string](*ScriptNodeDef) +} + +func (scriptDef *ScriptDef) Deserialize(jsonBytesScript []byte, customProcessorDefFactory CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage, caPath string, privateKeys map[string]string) error { + + if err := json.Unmarshal(jsonBytesScript, &scriptDef); err != nil { + return fmt.Errorf("cannot unmarshal script json: [%s]", err.Error()) + } + + errors := make([]string, 0, 2) + + // Deserialize node by node + for nodeName, node := range scriptDef.ScriptNodes { + node.Name = nodeName + if err := node.Deserialize(customProcessorDefFactory, customProcessorsSettings, caPath, privateKeys); err != nil { + errors = append(errors, fmt.Sprintf("cannot deserialize node %s: [%s]", nodeName, err.Error())) + } + } + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + // Table -> node map, to look for ord and lkp indexes, for those nodes that create tables + scriptDef.TableCreatorNodeMap = map[string]*ScriptNodeDef{} + for _, node := range scriptDef.ScriptNodes { + if node.HasTableCreator() { + if _, ok := scriptDef.TableCreatorNodeMap[node.TableCreator.Name]; ok { + return fmt.Errorf("duplicate table name: %s", node.TableCreator.Name) + } + scriptDef.TableCreatorNodeMap[node.TableCreator.Name] = node + } + } + // Index -> node map, to look for ord and lkp indexes, for those nodes that create tables + scriptDef.IndexNodeMap = map[string]*ScriptNodeDef{} + for _, node := 
range scriptDef.ScriptNodes { + if node.HasTableCreator() { + for idxName := range node.TableCreator.Indexes { + if _, ok := scriptDef.IndexNodeMap[idxName]; ok { + return fmt.Errorf("duplicate index name: %s", idxName) + } + if _, ok := scriptDef.TableCreatorNodeMap[idxName]; ok { + return fmt.Errorf("cannot use same name for table and index: %s", idxName) + } + scriptDef.IndexNodeMap[idxName] = node + } + } + } + + for _, node := range scriptDef.ScriptNodes { + if err := scriptDef.resolveReader(node); err != nil { + return fmt.Errorf("failed to resolve reader for node %s: [%s]", node.Name, err.Error()) + } + } + + for _, node := range scriptDef.ScriptNodes { + if err := scriptDef.resolveLookup(node); err != nil { + return fmt.Errorf("failed to resolve lookup for node %s: [%s]", node.Name, err.Error()) + } + } + + for _, node := range scriptDef.ScriptNodes { + if err := scriptDef.checkFieldUsageInCustomProcessorCreator(node); err != nil { + return fmt.Errorf("field usage error in custom processor creator, node %s: [%s]", node.Name, err.Error()) + } + } + + for _, node := range scriptDef.ScriptNodes { + if err := scriptDef.checkFieldUsageInCreator(node); err != nil { + return fmt.Errorf("field usage error in creator, node %s: [%s]", node.Name, err.Error()) + } + } + + for _, node := range scriptDef.ScriptNodes { + if err := node.evalCreatorAndLookupExpressionsAndCheckType(); err != nil { + return fmt.Errorf("failed evaluating creator/lookup expressions for node %s: [%s]", node.Name, err.Error()) + } + } + + depPolMap := map[string](*DependencyPolicyDef){} + defaultDepPolCount := 0 + var defaultDepPol *DependencyPolicyDef + for polName, rawPolDef := range scriptDef.RawDependencyPolicies { + pol := DependencyPolicyDef{} + if err := pol.Deserialize(rawPolDef); err != nil { + return fmt.Errorf("failed to deserialize dependency policy %s: %s", polName, err.Error()) + } + depPolMap[polName] = &pol + if pol.IsDefault { + defaultDepPol = &pol + defaultDepPolCount++ + } + } + if defaultDepPolCount != 1 { + return fmt.Errorf("failed to deserialize dependency policies, found %d default policies, required 1", defaultDepPolCount) + } + + for polName, polDef := range depPolMap { + if err := polDef.evalRuleExpressionsAndCheckType(); err != nil { + return fmt.Errorf("failed to test dependency policy %s rules: %s", polName, err.Error()) + } + } + + for _, node := range scriptDef.ScriptNodes { + if node.HasTableReader() { + if len(node.DependencyPolicyName) == 0 { + node.DepPolDef = defaultDepPol + } else { + var ok bool + node.DepPolDef, ok = depPolMap[node.DependencyPolicyName] + if !ok { + return fmt.Errorf("cannot find dependency policy %s for node %s", node.DependencyPolicyName, node.Name) + } + } + } + } + + return nil +} + +func (scriptDef *ScriptDef) resolveReader(node *ScriptNodeDef) error { + if node.HasTableReader() { + tableCreatorNode, ok := scriptDef.TableCreatorNodeMap[node.TableReader.TableName] + if !ok { + return fmt.Errorf("cannot find the node that creates table [%s]", node.TableReader.TableName) + } + node.TableReader.TableCreator = &tableCreatorNode.TableCreator + } + return nil +} + +func (scriptDef *ScriptDef) resolveLookup(node *ScriptNodeDef) error { + if !node.HasLookup() { + return nil + } + + srcFieldRefs, err := node.getSourceFieldRefs() + if err != nil { + return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) + } + idxCreatorNode, ok := scriptDef.IndexNodeMap[node.Lookup.IndexName] + if !ok { + return fmt.Errorf("cannot find the node that 
creates index [%s]", node.Lookup.IndexName) + } + + node.Lookup.TableCreator = &idxCreatorNode.TableCreator + + if err = node.Lookup.resolveLeftTableFields(ReaderAlias, srcFieldRefs); err != nil { + return err + } + + if err = node.Lookup.ParseFilter(); err != nil { + return err + } + + if err = node.Lookup.ValidateJoinType(); err != nil { + return err + } + + return node.Lookup.CheckPagedBatchSize() +} + +func (scriptDef *ScriptDef) checkFieldUsageInCreator(node *ScriptNodeDef) error { + srcFieldRefs, err := node.getSourceFieldRefs() + if err != nil { + return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) + } + + var processorFieldRefs *FieldRefs + if node.HasCustomProcessor() { + processorFieldRefs = node.CustomProcessor.GetFieldRefs() + if err != nil { + return fmt.Errorf("cannot resolve processor field refs: [%s]", err.Error()) + } + } + + var lookupFieldRefs *FieldRefs + if node.HasLookup() { + lookupFieldRefs = node.Lookup.TableCreator.GetFieldRefsWithAlias(LookupAlias) + } + + errors := make([]string, 0) + + var targetFieldRefs *FieldRefs + if node.HasTableCreator() { + targetFieldRefs = node.TableCreator.GetFieldRefsWithAlias(CreatorAlias) + } else if node.HasFileCreator() { + targetFieldRefs = node.FileCreator.getFieldRefs() + } else { + return fmt.Errorf("dev error, unknown creator") + } + + // Lookup + if node.HasLookup() && node.Lookup.UsesFilter() { + // Having: allow only lookup table, prohibit src and tgt + if err := checkAllowed(&node.Lookup.UsedInFilterFields, JoinFieldRefs(srcFieldRefs, targetFieldRefs), lookupFieldRefs); err != nil { + errors = append(errors, fmt.Sprintf("invalid field in lookup filter [%s], only fields from the lookup table [%s](alias %s) are allowed: [%s]", node.Lookup.RawFilter, node.Lookup.TableCreator.Name, LookupAlias, err.Error())) + } + } + + // Table creator + if node.HasTableCreator() { + srcLkpCustomFieldRefs := JoinFieldRefs(srcFieldRefs, lookupFieldRefs, processorFieldRefs) + // Having: allow tgt fields, prohibit src, lkp + if err := checkAllowed(&node.TableCreator.UsedInHavingFields, srcLkpCustomFieldRefs, targetFieldRefs); err != nil { + errors = append(errors, fmt.Sprintf("invalid field in table creator 'having' condition: [%s]; only target (w.*) fields allowed, reader (r.*) and lookup (l.*) fields are prohibited", err.Error())) + } + // Tgt expressions: allow src iterator table (or src file), lkp, custom processor, prohibit target + // TODO: aggregate functions cannot include fields from group field list + if err := checkAllowed(&node.TableCreator.UsedInTargetExpressionsFields, targetFieldRefs, srcLkpCustomFieldRefs); err != nil { + errors = append(errors, fmt.Sprintf("invalid field(s) in target table field expression: [%s]", err.Error())) + } + } + + // File creator + if node.HasFileCreator() { + // Having: allow tgt fields, prohibit src + if err := checkAllowed(&node.FileCreator.UsedInHavingFields, srcFieldRefs, targetFieldRefs); err != nil { + errors = append(errors, fmt.Sprintf("invalid field in file creator 'having' condition: [%s]", err.Error())) + } + + // Tgt expressions: allow src, prohibit target fields + // TODO: aggregate functions cannot include fields from group field list + if err := checkAllowed(&node.FileCreator.UsedInTargetExpressionsFields, targetFieldRefs, srcFieldRefs); err != nil { + errors = append(errors, fmt.Sprintf("invalid field in target file field expression: [%s]", err.Error())) + } + } + + if len(errors) > 0 { + return fmt.Errorf("%s", strings.Join(errors, "; ")) + } + + 
return nil +} + +func (scriptDef *ScriptDef) checkFieldUsageInCustomProcessorCreator(node *ScriptNodeDef) error { + if !node.HasCustomProcessor() { + return nil + } + + srcFieldRefs, err := node.getSourceFieldRefs() + if err != nil { + return fmt.Errorf("unexpectedly cannot resolve source field refs: [%s]", err.Error()) + } + + procTgtFieldRefs := node.CustomProcessor.GetFieldRefs() + + // In processor fields, we are allowed to use only reader and processor fields ("r" and "p") + if err := checkAllowed(node.CustomProcessor.GetUsedInTargetExpressionsFields(), nil, JoinFieldRefs(srcFieldRefs, procTgtFieldRefs)); err != nil { + return fmt.Errorf("invalid field(s) in target table field expression: [%s]", err.Error()) + } + + return nil +} + +func (scriptDef *ScriptDef) addToAffected(rootNode *ScriptNodeDef, affectedSet map[string]struct{}) { + if _, ok := affectedSet[rootNode.Name]; ok { + return + } + + affectedSet[rootNode.Name] = struct{}{} + + for _, node := range scriptDef.ScriptNodes { + if rootNode.HasTableCreator() && node.HasTableReader() && rootNode.TableCreator.Name == node.TableReader.TableName && node.StartPolicy == NodeStartAuto { + scriptDef.addToAffected(node, affectedSet) + } else if rootNode.HasTableCreator() && node.HasLookup() && rootNode.TableCreator.Name == node.Lookup.TableCreator.Name && node.StartPolicy == NodeStartAuto { + scriptDef.addToAffected(node, affectedSet) + } + } +} + +func (scriptDef *ScriptDef) GetAffectedNodes(startNodeNames []string) []string { + affectedSet := map[string]struct{}{} + for _, nodeName := range startNodeNames { + if node, ok := scriptDef.ScriptNodes[nodeName]; ok { + scriptDef.addToAffected(node, affectedSet) + } + } + + affectedList := make([]string, len(affectedSet)) + i := 0 + for k := range affectedSet { + affectedList[i] = k + i++ + } + return affectedList +} diff --git a/pkg/sc/script_def_loader.donotcover.go b/pkg/sc/script_def_loader.donotcover.go index 3e0dc84..415cb6b 100644 --- a/pkg/sc/script_def_loader.donotcover.go +++ b/pkg/sc/script_def_loader.donotcover.go @@ -7,16 +7,16 @@ import ( "github.com/capillariesio/capillaries/pkg/xfer" ) -func NewScriptFromFiles(caPath string, privateKeys map[string]string, scriptUri string, scriptParamsUri string, customProcessorDefFactoryInstance CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage) (*ScriptDef, error, ScriptInitProblemType) { +func NewScriptFromFiles(caPath string, privateKeys map[string]string, scriptUri string, scriptParamsUri string, customProcessorDefFactoryInstance CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage) (*ScriptDef, ScriptInitProblemType, error) { jsonBytesScript, err := xfer.GetFileBytes(scriptUri, caPath, privateKeys) if err != nil { - return nil, fmt.Errorf("cannot read script: %s", err.Error()), ScriptInitConnectivityProblem + return nil, ScriptInitConnectivityProblem, fmt.Errorf("cannot read script: %s", err.Error()) } var jsonBytesParams []byte if len(scriptParamsUri) > 0 { jsonBytesParams, err = xfer.GetFileBytes(scriptParamsUri, caPath, privateKeys) if err != nil { - return nil, fmt.Errorf("cannot read script parameters: %s", err.Error()), ScriptInitConnectivityProblem + return nil, ScriptInitConnectivityProblem, fmt.Errorf("cannot read script parameters: %s", err.Error()) } } diff --git a/pkg/sc/script_def_loader.go b/pkg/sc/script_def_loader.go index 02c7cf6..78151b3 100644 --- a/pkg/sc/script_def_loader.go +++ b/pkg/sc/script_def_loader.go @@ -14,7 +14,7 @@ const ScriptInitUrlProblem 
ScriptInitProblemType = 1 const ScriptInitContentProblem ScriptInitProblemType = 2 const ScriptInitConnectivityProblem ScriptInitProblemType = 3 -func NewScriptFromFileBytes(caPath string, privateKeys map[string]string, scriptUri string, jsonBytesScript []byte, scriptParamsUri string, jsonBytesParams []byte, customProcessorDefFactoryInstance CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage) (*ScriptDef, error, ScriptInitProblemType) { +func NewScriptFromFileBytes(caPath string, privateKeys map[string]string, scriptUri string, jsonBytesScript []byte, scriptParamsUri string, jsonBytesParams []byte, customProcessorDefFactoryInstance CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage) (*ScriptDef, ScriptInitProblemType, error) { // Make sure parameters are in canonical format: {param_name|param_type} scriptString := string(jsonBytesScript) @@ -33,15 +33,15 @@ func NewScriptFromFileBytes(caPath string, privateKeys map[string]string, script re = regexp.MustCompile(`([^"]{[a-zA-Z0-9_]+\|(number|bool)})|({[a-zA-Z0-9_]+\|(number|bool)}[^"])`) invalidParamRefs := re.FindAllString(scriptString, -1) if len(invalidParamRefs) > 0 { - return nil, fmt.Errorf("cannot parse number/bool script parameter references in [%s], the following parameter references should not have extra characters between curly braces and double quotes: [%s]", scriptUri, strings.Join(invalidParamRefs, ",")), ScriptInitUrlProblem + return nil, ScriptInitUrlProblem, fmt.Errorf("cannot parse number/bool script parameter references in [%s], the following parameter references should not have extra characters between curly braces and double quotes: [%s]", scriptUri, strings.Join(invalidParamRefs, ",")) } // Apply template params here, script def should know nothing about them: they may tweak some 3d-party tfm config - paramsMap := map[string]interface{}{} + paramsMap := map[string]any{} if jsonBytesParams != nil { if err := json.Unmarshal(jsonBytesParams, ¶msMap); err != nil { - return nil, fmt.Errorf("cannot unmarshal script params json from [%s]: [%s]", scriptParamsUri, err.Error()), ScriptInitContentProblem + return nil, ScriptInitContentProblem, fmt.Errorf("cannot unmarshal script params json from [%s]: [%s]", scriptParamsUri, err.Error()) } } @@ -72,7 +72,7 @@ func NewScriptFromFileBytes(caPath string, privateKeys map[string]string, script replacerStrings[i] = fmt.Sprintf(`"{%s|bool}"`, templateParam) replacerStrings[i+1] = fmt.Sprintf("%t", typedParamVal) default: - return nil, fmt.Errorf("unsupported parameter type %T from [%s]: %s", templateParamVal, scriptParamsUri, templateParam), ScriptInitContentProblem + return nil, ScriptInitContentProblem, fmt.Errorf("unsupported parameter type %T from [%s]: %s", templateParamVal, scriptParamsUri, templateParam) } i += 2 } @@ -89,13 +89,13 @@ func NewScriptFromFileBytes(caPath string, privateKeys map[string]string, script } } if len(unresolvedParamMap) > 0 { - return nil, fmt.Errorf("unresolved parameter references in [%s]: %v; make sure that type in the script matches the type of the parameter value in the script parameters file", scriptUri, unresolvedParamMap), ScriptInitContentProblem + return nil, ScriptInitContentProblem, fmt.Errorf("unresolved parameter references in [%s]: %v; make sure that type in the script matches the type of the parameter value in the script parameters file", scriptUri, unresolvedParamMap) } newScript := &ScriptDef{} if err := newScript.Deserialize([]byte(scriptString), customProcessorDefFactoryInstance, 
customProcessorsSettings, caPath, privateKeys); err != nil { - return nil, fmt.Errorf("cannot deserialize script %s(%s): %s", scriptUri, scriptParamsUri, err.Error()), ScriptInitContentProblem + return nil, ScriptInitContentProblem, fmt.Errorf("cannot deserialize script %s(%s): %s", scriptUri, scriptParamsUri, err.Error()) } - return newScript, nil, ScriptInitNoProblem + return newScript, ScriptInitNoProblem, nil } diff --git a/pkg/sc/script_def_loader_test.go b/pkg/sc/script_def_loader_test.go index 16c585b..c782db6 100644 --- a/pkg/sc/script_def_loader_test.go +++ b/pkg/sc/script_def_loader_test.go @@ -190,12 +190,12 @@ func (procDef *SomeTestCustomProcessorDef) GetFieldRefs() *FieldRefs { TableName: CustomProcessorAlias, FieldName: fieldName, FieldType: fieldDef.Type} - i += 1 + i++ } return &fieldRefs } -func (procDef *SomeTestCustomProcessorDef) Deserialize(raw json.RawMessage, customProcSettings json.RawMessage, caPath string, privateKeys map[string]string) error { +func (procDef *SomeTestCustomProcessorDef) Deserialize(raw json.RawMessage, _ json.RawMessage, _ string, _ map[string]string) error { var err error if err = json.Unmarshal(raw, procDef); err != nil { return fmt.Errorf("cannot unmarshal some_test_custom_processor def: %s", err.Error()) @@ -212,6 +212,10 @@ func (procDef *SomeTestCustomProcessorDef) Deserialize(raw json.RawMessage, cust } } + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, ";")) + } + procDef.UsedInTargetExpressionsFields = GetFieldRefsUsedInAllTargetExpressions(procDef.ProducedFields) return nil } @@ -234,7 +238,7 @@ func (f *SomeTestCustomProcessorDefFactory) Create(processorType string) (Custom func TestNewScriptFromFileBytes(t *testing.T) { // Test main script parsing function - scriptDef, err, initProblem := NewScriptFromFileBytes("", nil, + scriptDef, initProblem, err := NewScriptFromFileBytes("", nil, "someScriptUri", []byte(parameterizedScriptJson), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) @@ -251,68 +255,68 @@ func TestNewScriptFromFileBytes(t *testing.T) { assert.Equal(t, true, scriptDef.ScriptNodes["join_table1_table2"].Lookup.IsGroup) // Tweak paramater name and make sure templating engine catches it - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.ReplaceAll(parameterizedScriptJson, "source_table_for_test_custom_processor", "some_bad_param")), "someScriptParamsUrl", []byte(paramsJson), nil, nil) assert.Contains(t, err.Error(), "unresolved parameter references", err.Error()) // Bad-formed JSON - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.TrimSuffix(parameterizedScriptJson, "}")), "someScriptParamsUrl", []byte(paramsJson), nil, nil) assert.Contains(t, err.Error(), "unexpected end of JSON input", err.Error()) // Invalid field in custom processor (Python) formula - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.ReplaceAll(parameterizedScriptJson, "-r.field_int1*2", "r.bad_field")), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "field usage error in custom processor creator") // Invalid dependency policy - scriptDef, err, 
initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.ReplaceAll(parameterizedScriptJson, "run_is_current(desc),node_start_ts(desc)", "some_bad_event_priority_order")), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "failed to deserialize dependency policy") // Run (tweaked) dependency policy checker with some vanilla values and see if it works - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.ReplaceAll(parameterizedScriptJson, "e.run_final_status == wfmodel.RunStart", "e.run_final_status == true")), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "failed to test dependency policy") re := regexp.MustCompile(`"expression": "e\.run[^"]+"`) - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(re.ReplaceAllString(parameterizedScriptJson, `"expression": 1`)), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "cannot unmarshal dependency policy") - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(re.ReplaceAllString(parameterizedScriptJson, `"expression": "a.aaa"`)), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "cannot parse rule expression 'a.aaa': all fields must be prefixed") - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(re.ReplaceAllString(parameterizedScriptJson, `"expression": "e.aaa"`)), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "cannot parse rule expression 'e.aaa': field e.aaa not found") // Tweak lookup isGroup = false and get error - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(parameterizedScriptJson), "someScriptParamsUrl", []byte(strings.ReplaceAll(paramsJson, "true", "false")), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) assert.Contains(t, err.Error(), "cannot use agg functions") // Invalid rerun_policy - scriptDef, err, initProblem = NewScriptFromFileBytes("", nil, + _, _, err = NewScriptFromFileBytes("", nil, "someScriptUri", []byte(strings.ReplaceAll(parameterizedScriptJson, "\"rerun_policy\": \"fail\"", "\"rerun_policy\": \"bad_rerun_policy\"")), "someScriptParamsUrl", []byte(paramsJson), &SomeTestCustomProcessorDefFactory{}, map[string]json.RawMessage{"some_test_custom_proc": []byte("{}")}) diff --git a/pkg/sc/script_def_test.go b/pkg/sc/script_def_test.go index fb76e25..9dfb8be 100644 --- a/pkg/sc/script_def_test.go +++ b/pkg/sc/script_def_test.go @@ -1,543 +1,543 @@ -package sc - -import ( - "strings" - "testing" - - "github.com/capillariesio/capillaries/pkg/eval" - 
"github.com/shopspring/decimal" - "github.com/stretchr/testify/assert" -) - -const plainScriptJson string = ` -{ - "nodes": { - "read_table1": { - "type": "file_table", - "r": { - "urls": [ - "file1.csv" - ], - "csv":{ - "first_data_line_idx": 0 - }, - "columns": { - "col_field_int": { - "csv":{ - "col_idx": 0 - }, - "col_type": "int" - }, - "col_field_string": { - "csv":{ - "col_idx": 1 - }, - "col_type": "string" - } - } - }, - "w": { - "name": "table1", - "having": "w.field_int1 > 1", - "fields": { - "field_int1": { - "expression": "r.col_field_int", - "type": "int" - }, - "field_string1": { - "expression": "r.col_field_string", - "type": "string" - } - } - } - }, - "read_table2": { - "type": "file_table", - "r": { - "urls": [ - "file2.tsv" - ], - "csv":{ - "first_data_line_idx": 0 - }, - "columns": { - "col_field_int": { - "csv":{ - "col_idx": 0 - }, - "col_type": "int" - }, - "col_field_string": { - "csv":{ - "col_idx": 1 - }, - "col_type": "string" - } - } - }, - "w": { - "name": "table2", - "fields": { - "field_int2": { - "expression": "r.col_field_int", - "type": "int" - }, - "field_string2": { - "expression": "r.col_field_string", - "type": "string" - } - }, - "indexes": { - "idx_table2_string2": "unique(field_string2)" - } - } - }, - "join_table1_table2": { - "type": "table_lookup_table", - "start_policy": "auto", - "r": { - "table": "table1", - "expected_batches_total": 2 - }, - "l": { - "index_name": "idx_table2_string2", - "join_on": "r.field_string1", - "filter": "l.field_int2 > 100", - "group": true, - "join_type": "left" - }, - "w": { - "name": "joined_table1_table2", - "having": "w.total_value > 2", - "fields": { - "field_int1": { - "expression": "r.field_int1", - "type": "int" - }, - "field_string1": { - "expression": "r.field_string1", - "type": "string" - }, - "total_value": { - "expression": "sum(l.field_int2)", - "type": "int" - }, - "item_count": { - "expression": "count()", - "type": "int" - } - } - } - }, - "file_totals": { - "type": "table_file", - "r": { - "table": "joined_table1_table2" - }, - "w": { - "top": { - "order": "field_int1(asc),item_count(asc)", - "limit": 500000 - }, - "having": "w.total_value > 3", - "url_template": "file_totals.csv", - "columns": [ - { - "csv":{ - "header": "field_int1", - "format": "%d" - }, - "name": "field_int1", - "expression": "r.field_int1", - "type": "int" - }, - { - "csv":{ - "header": "field_string1", - "format": "%s" - }, - "name": "field_string1", - "expression": "r.field_string1", - "type": "string" - }, - { - "csv":{ - "header": "total_value", - "format": "%s" - }, - "name": "total_value", - "expression": "decimal2(r.total_value)", - "type": "decimal2" - }, - { - "csv":{ - "header": "item_count", - "format": "%d" - }, - "name": "item_count", - "expression": "r.item_count", - "type": "int" - } - ] - } - } - }, - "dependency_policies": { - "current_active_first_stopped_nogo":` + DefaultPolicyCheckerConf + - ` - } -}` - -func TestCreatorFieldRefs(t *testing.T) { - var err error - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - tableFieldRefs := newScript.ScriptNodes["read_table2"].TableCreator.GetFieldRefsWithAlias(CreatorAlias) - var tableFieldRef *FieldRef - tableFieldRef, _ = tableFieldRefs.FindByFieldName("field_int2") - assert.Equal(t, CreatorAlias, tableFieldRef.TableName) - assert.Equal(t, FieldTypeInt, tableFieldRef.FieldType) - - fileFieldRefs := newScript.ScriptNodes["file_totals"].FileCreator.getFieldRefs() - var fileFieldRef *FieldRef - 
fileFieldRef, _ = fileFieldRefs.FindByFieldName("total_value") - assert.Equal(t, CreatorAlias, fileFieldRef.TableName) - assert.Equal(t, FieldTypeDecimal2, fileFieldRef.FieldType) - - // Duplicate creator - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"name": "table2"`, `"name": "table1"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "duplicate table name: table1") - - // Bad readertable name - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"table": "table1"`, `"table": "bad_table_name"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot find the node that creates table [bad_table_name]") -} - -func TestCreatorCalculateHaving(t *testing.T) { - var isHaving bool - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - // Table writer: calculate having - var tableRecord map[string]interface{} - tableCreator := newScript.ScriptNodes["join_table1_table2"].TableCreator - - tableRecord = map[string]interface{}{"total_value": 3} - isHaving, _ = tableCreator.CheckTableRecordHavingCondition(tableRecord) - assert.True(t, isHaving) - - tableRecord = map[string]interface{}{"total_value": 2} - isHaving, _ = tableCreator.CheckTableRecordHavingCondition(tableRecord) - assert.False(t, isHaving) - - // File writer: calculate having - var colVals []interface{} - fileCreator := newScript.ScriptNodes["file_totals"].FileCreator - - colVals = make([]interface{}, 0) - colVals = append(colVals, 0, "a", 4, 0) - isHaving, _ = fileCreator.CheckFileRecordHavingCondition(colVals) - assert.True(t, isHaving) - - colVals = make([]interface{}, 0) - colVals = append(colVals, 0, "a", 3, 0) - isHaving, _ = fileCreator.CheckFileRecordHavingCondition(colVals) - assert.False(t, isHaving) -} - -func TestCreatorCalculateOutput(t *testing.T) { - var err error - var vars eval.VarValuesMap - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - // Table creator: calculate fields - - var fields map[string]interface{} - vars = eval.VarValuesMap{"r": {"field_int1": int64(1), "field_string1": "a"}} - fields, _ = newScript.ScriptNodes["join_table1_table2"].TableCreator.CalculateTableRecordFromSrcVars(true, vars) - if len(fields) == 4 { - assert.Equal(t, int64(1), fields["field_int1"]) - assert.Equal(t, "a", fields["field_string1"]) - assert.Equal(t, int64(1), fields["total_value"]) - assert.Equal(t, int64(1), fields["item_count"]) - } - - // Table creator: bad field expression, tweak sum - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `sum(l.field_int2)`, `sum(l.field_int2`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot parse field expression [sum(l.field_int2]") - - // File creator: calculate columns - - var cols []interface{} - vars = eval.VarValuesMap{"r": {"field_int1": int64(1), "field_string1": "a", "total_value": decimal.NewFromInt(1), "item_count": int64(1)}} - cols, _ = newScript.ScriptNodes["file_totals"].FileCreator.CalculateFileRecordFromSrcVars(vars) - assert.Equal(t, 4, len(cols)) - if len(cols) == 4 { - assert.Equal(t, int64(1), cols[0]) - assert.Equal(t, "a", cols[1]) - assert.Equal(t, decimal.NewFromInt(1), cols[2]) - assert.Equal(t, int64(1), cols[3]) - } - - // File creator: bad column expression, tweak decimal2() - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `decimal2(r.total_value)`, `decimal2(r.total_value`, 1)), - nil, nil, "", nil) 
- assert.Contains(t, err.Error(), "[cannot parse column expression [decimal2(r.total_value]") - -} - -func TestLookup(t *testing.T) { - var err error - var vars eval.VarValuesMap - var isMatch bool - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - // Invalid (writer) field in aggregate, tweak sum() arg - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"expression": "sum(l.field_int2)"`, `"expression": "sum(w.field_int1)"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field(s) in target table field expression: [prohibited field w.field_int1]") - - // Filter calculation - - vars = eval.VarValuesMap{"l": {"field_int2": 101}} - isMatch, _ = newScript.ScriptNodes["join_table1_table2"].Lookup.CheckFilterCondition(vars) - assert.True(t, isMatch) - - vars = eval.VarValuesMap{"l": {"field_int2": 100}} - isMatch, _ = newScript.ScriptNodes["join_table1_table2"].Lookup.CheckFilterCondition(vars) - assert.False(t, isMatch) - - // bad index_name, tweak it - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"index_name": "idx_table2_string2"`, `"index_name": "idx_table2_string2_bad"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot find the node that creates index [idx_table2_string2_bad]") - - // bad join_on, tweak it - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": ""`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "expected a comma-separated list of ., got []") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "bla"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "expected a comma-separated list of ., got [bla]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "bla.bla"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "source table name [bla] unknown, expected [r]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "r.field_string1_bad"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "source [r] does not produce field [field_string1_bad]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "r.field_int1"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "left-side field field_int1 has type int, while index field field_string2 has type string") - - // bad filter, tweak it - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"filter": "l.field_int2 > 100"`, `"filter": "r.field_int2 > 100"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in lookup filter [r.field_int2 > 100], only fields from the lookup table [table2](alias l) are allowed: [unknown field r.field_int2]") - - // bad join_type, tweak it - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"join_type": "left"`, `"join_type": "left_bad"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid join type, expected inner or left, left_bad is not supported") -} - -func TestBadCreatorHaving(t *testing.T) { - var err error - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - // Bad expression, tweak having expression - - err = 
newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.total_value &> 2"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot parse table creator 'having' condition [w.total_value &> 2]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.bad_field &> 3"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot parse file creator 'having' condition [w.bad_field &> 3]") - - // Unknown field in having - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.bad_field > 2"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [unknown field w.bad_field]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.bad_field > 3"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in file creator 'having' condition: [unknown field w.bad_field]]") - - // Prohibited reader field in having - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "r.field_int1 > 2"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [prohibited field r.field_int1]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "r.field_int1 > 3"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in file creator 'having' condition: [prohibited field r.field_int1]") - - // Prohibited lookup field in table creator having - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "l.field_int2 > 2"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [prohibited field l.field_int2]") - - // Type mismatch in having - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.total_value == true"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot evaluate table creator 'having' expression [w.total_value == true]: [cannot perform binary comp op, incompatible arg types '0(int64)' == 'true(bool)' ]") - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.total_value == true"`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "cannot evaluate file creator 'having' expression [w.total_value == true]: [cannot perform binary comp op, incompatible arg types '2.34(decimal.Decimal)' == 'true(bool)' ]") -} - -func TestTopLimit(t *testing.T) { - var err error - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - // Tweak limit beyond allowed maximum - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"limit": 500000`, `"limit": 500001`, 1)), - nil, nil, "", nil) - assert.Contains(t, err.Error(), "top.limit cannot exceed 500000") - - // Remove limit altogether - - err = newScript.Deserialize( - []byte(strings.Replace(plainScriptJson, `"limit": 500000`, `"some_bogus_setting": 500000`, 1)), - nil, nil, "", nil) - assert.Equal(t, 500000, newScript.ScriptNodes["file_totals"].FileCreator.Top.Limit) -} - -func 
TestBatchIntervalsCalculation(t *testing.T) { - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - var intervals [][]int64 - - tableReaderNodeDef := newScript.ScriptNodes["join_table1_table2"] - intervals, _ = tableReaderNodeDef.GetTokenIntervalsByNumberOfBatches() - - assert.Equal(t, 2, len(intervals)) - if len(intervals) == 2 { - assert.Equal(t, int64(-9223372036854775808), intervals[0][0]) - assert.Equal(t, int64(-2), intervals[0][1]) - assert.Equal(t, int64(-1), intervals[1][0]) - assert.Equal(t, int64(9223372036854775807), intervals[1][1]) - } - - fileReaderNodeDef := newScript.ScriptNodes["read_table1"] - intervals, _ = fileReaderNodeDef.GetTokenIntervalsByNumberOfBatches() - - assert.Equal(t, 1, len(intervals)) - if len(intervals) == 1 { - assert.Equal(t, int64(0), intervals[0][0]) - assert.Equal(t, int64(0), intervals[0][1]) - } - - fileCreatorNodeDef := newScript.ScriptNodes["file_totals"] - intervals, _ = fileCreatorNodeDef.GetTokenIntervalsByNumberOfBatches() - - assert.Equal(t, 1, len(intervals)) - if len(intervals) == 1 { - assert.Equal(t, int64(-9223372036854775808), intervals[0][0]) - assert.Equal(t, int64(9223372036854775807), intervals[0][1]) - } -} - -func TestUniqueIndexesFieldRefs(t *testing.T) { - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - fileReaderNodeDef := newScript.ScriptNodes["read_table2"] - fieldRefs := fileReaderNodeDef.GetUniqueIndexesFieldRefs() - assert.Equal(t, 1, len(*fieldRefs)) - if len(*fieldRefs) == 1 { - assert.Equal(t, "table2", (*fieldRefs)[0].TableName) - assert.Equal(t, "field_string2", (*fieldRefs)[0].FieldName) - assert.Equal(t, FieldTypeString, (*fieldRefs)[0].FieldType) - } -} - -func TestAffectedNodes(t *testing.T) { - var affectedNodes []string - - newScript := &ScriptDef{} - assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) - - affectedNodes = newScript.GetAffectedNodes([]string{"read_table1"}) - assert.Equal(t, 3, len(affectedNodes)) - assert.Contains(t, affectedNodes, "read_table1") - assert.Contains(t, affectedNodes, "join_table1_table2") - assert.Contains(t, affectedNodes, "file_totals") - - affectedNodes = newScript.GetAffectedNodes([]string{"read_table1", "read_table2"}) - assert.Equal(t, 4, len(affectedNodes)) - assert.Contains(t, affectedNodes, "read_table1") - assert.Contains(t, affectedNodes, "read_table2") - assert.Contains(t, affectedNodes, "join_table1_table2") - assert.Contains(t, affectedNodes, "file_totals") - - // Make join manual and see the list of affected nodes shrinking - - assert.Nil(t, newScript.Deserialize([]byte(strings.Replace(plainScriptJson, `"start_policy": "auto"`, `"start_policy": "manual"`, 1)), nil, nil, "", nil)) - - affectedNodes = newScript.GetAffectedNodes([]string{"read_table1"}) - assert.Equal(t, 1, len(affectedNodes)) - assert.Contains(t, affectedNodes, "read_table1") - - affectedNodes = newScript.GetAffectedNodes([]string{"read_table1", "read_table2"}) - assert.Equal(t, 2, len(affectedNodes)) - assert.Contains(t, affectedNodes, "read_table1") - assert.Contains(t, affectedNodes, "read_table2") -} +package sc + +import ( + "strings" + "testing" + + "github.com/capillariesio/capillaries/pkg/eval" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" +) + +const plainScriptJson string = ` +{ + "nodes": { + "read_table1": { + "type": "file_table", + "r": { + "urls": [ + "file1.csv" + ], + "csv":{ + "first_data_line_idx": 0 + 
}, + "columns": { + "col_field_int": { + "csv":{ + "col_idx": 0 + }, + "col_type": "int" + }, + "col_field_string": { + "csv":{ + "col_idx": 1 + }, + "col_type": "string" + } + } + }, + "w": { + "name": "table1", + "having": "w.field_int1 > 1", + "fields": { + "field_int1": { + "expression": "r.col_field_int", + "type": "int" + }, + "field_string1": { + "expression": "r.col_field_string", + "type": "string" + } + } + } + }, + "read_table2": { + "type": "file_table", + "r": { + "urls": [ + "file2.tsv" + ], + "csv":{ + "first_data_line_idx": 0 + }, + "columns": { + "col_field_int": { + "csv":{ + "col_idx": 0 + }, + "col_type": "int" + }, + "col_field_string": { + "csv":{ + "col_idx": 1 + }, + "col_type": "string" + } + } + }, + "w": { + "name": "table2", + "fields": { + "field_int2": { + "expression": "r.col_field_int", + "type": "int" + }, + "field_string2": { + "expression": "r.col_field_string", + "type": "string" + } + }, + "indexes": { + "idx_table2_string2": "unique(field_string2)" + } + } + }, + "join_table1_table2": { + "type": "table_lookup_table", + "start_policy": "auto", + "r": { + "table": "table1", + "expected_batches_total": 2 + }, + "l": { + "index_name": "idx_table2_string2", + "join_on": "r.field_string1", + "filter": "l.field_int2 > 100", + "group": true, + "join_type": "left" + }, + "w": { + "name": "joined_table1_table2", + "having": "w.total_value > 2", + "fields": { + "field_int1": { + "expression": "r.field_int1", + "type": "int" + }, + "field_string1": { + "expression": "r.field_string1", + "type": "string" + }, + "total_value": { + "expression": "sum(l.field_int2)", + "type": "int" + }, + "item_count": { + "expression": "count()", + "type": "int" + } + } + } + }, + "file_totals": { + "type": "table_file", + "r": { + "table": "joined_table1_table2" + }, + "w": { + "top": { + "order": "field_int1(asc),item_count(asc)", + "limit": 500000 + }, + "having": "w.total_value > 3", + "url_template": "file_totals.csv", + "columns": [ + { + "csv":{ + "header": "field_int1", + "format": "%d" + }, + "name": "field_int1", + "expression": "r.field_int1", + "type": "int" + }, + { + "csv":{ + "header": "field_string1", + "format": "%s" + }, + "name": "field_string1", + "expression": "r.field_string1", + "type": "string" + }, + { + "csv":{ + "header": "total_value", + "format": "%s" + }, + "name": "total_value", + "expression": "decimal2(r.total_value)", + "type": "decimal2" + }, + { + "csv":{ + "header": "item_count", + "format": "%d" + }, + "name": "item_count", + "expression": "r.item_count", + "type": "int" + } + ] + } + } + }, + "dependency_policies": { + "current_active_first_stopped_nogo":` + DefaultPolicyCheckerConf + + ` + } +}` + +func TestCreatorFieldRefs(t *testing.T) { + var err error + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + tableFieldRefs := newScript.ScriptNodes["read_table2"].TableCreator.GetFieldRefsWithAlias(CreatorAlias) + var tableFieldRef *FieldRef + tableFieldRef, _ = tableFieldRefs.FindByFieldName("field_int2") + assert.Equal(t, CreatorAlias, tableFieldRef.TableName) + assert.Equal(t, FieldTypeInt, tableFieldRef.FieldType) + + fileFieldRefs := newScript.ScriptNodes["file_totals"].FileCreator.getFieldRefs() + var fileFieldRef *FieldRef + fileFieldRef, _ = fileFieldRefs.FindByFieldName("total_value") + assert.Equal(t, CreatorAlias, fileFieldRef.TableName) + assert.Equal(t, FieldTypeDecimal2, fileFieldRef.FieldType) + + // Duplicate creator + + err = newScript.Deserialize( + 
[]byte(strings.Replace(plainScriptJson, `"name": "table2"`, `"name": "table1"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "duplicate table name: table1") + + // Bad readertable name + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"table": "table1"`, `"table": "bad_table_name"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot find the node that creates table [bad_table_name]") +} + +func TestCreatorCalculateHaving(t *testing.T) { + var isHaving bool + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + // Table writer: calculate having + var tableRecord map[string]any + tableCreator := newScript.ScriptNodes["join_table1_table2"].TableCreator + + tableRecord = map[string]any{"total_value": 3} + isHaving, _ = tableCreator.CheckTableRecordHavingCondition(tableRecord) + assert.True(t, isHaving) + + tableRecord = map[string]any{"total_value": 2} + isHaving, _ = tableCreator.CheckTableRecordHavingCondition(tableRecord) + assert.False(t, isHaving) + + // File writer: calculate having + var colVals []any + fileCreator := newScript.ScriptNodes["file_totals"].FileCreator + + colVals = make([]any, 0) + colVals = append(colVals, 0, "a", 4, 0) + isHaving, _ = fileCreator.CheckFileRecordHavingCondition(colVals) + assert.True(t, isHaving) + + colVals = make([]any, 0) + colVals = append(colVals, 0, "a", 3, 0) + isHaving, _ = fileCreator.CheckFileRecordHavingCondition(colVals) + assert.False(t, isHaving) +} + +func TestCreatorCalculateOutput(t *testing.T) { + var err error + var vars eval.VarValuesMap + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + // Table creator: calculate fields + + var fields map[string]any + vars = eval.VarValuesMap{"r": {"field_int1": int64(1), "field_string1": "a"}} + fields, _ = newScript.ScriptNodes["join_table1_table2"].TableCreator.CalculateTableRecordFromSrcVars(true, vars) + if len(fields) == 4 { + assert.Equal(t, int64(1), fields["field_int1"]) + assert.Equal(t, "a", fields["field_string1"]) + assert.Equal(t, int64(1), fields["total_value"]) + assert.Equal(t, int64(1), fields["item_count"]) + } + + // Table creator: bad field expression, tweak sum + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `sum(l.field_int2)`, `sum(l.field_int2`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot parse field expression [sum(l.field_int2]") + + // File creator: calculate columns + + var cols []any + vars = eval.VarValuesMap{"r": {"field_int1": int64(1), "field_string1": "a", "total_value": decimal.NewFromInt(1), "item_count": int64(1)}} + cols, _ = newScript.ScriptNodes["file_totals"].FileCreator.CalculateFileRecordFromSrcVars(vars) + assert.Equal(t, 4, len(cols)) + if len(cols) == 4 { + assert.Equal(t, int64(1), cols[0]) + assert.Equal(t, "a", cols[1]) + assert.Equal(t, decimal.NewFromInt(1), cols[2]) + assert.Equal(t, int64(1), cols[3]) + } + + // File creator: bad column expression, tweak decimal2() + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `decimal2(r.total_value)`, `decimal2(r.total_value`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "[cannot parse column expression [decimal2(r.total_value]") + +} + +func TestLookup(t *testing.T) { + var err error + var vars eval.VarValuesMap + var isMatch bool + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, 
"", nil)) + + // Invalid (writer) field in aggregate, tweak sum() arg + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"expression": "sum(l.field_int2)"`, `"expression": "sum(w.field_int1)"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field(s) in target table field expression: [prohibited field w.field_int1]") + + // Filter calculation + + vars = eval.VarValuesMap{"l": {"field_int2": 101}} + isMatch, _ = newScript.ScriptNodes["join_table1_table2"].Lookup.CheckFilterCondition(vars) + assert.True(t, isMatch) + + vars = eval.VarValuesMap{"l": {"field_int2": 100}} + isMatch, _ = newScript.ScriptNodes["join_table1_table2"].Lookup.CheckFilterCondition(vars) + assert.False(t, isMatch) + + // bad index_name, tweak it + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"index_name": "idx_table2_string2"`, `"index_name": "idx_table2_string2_bad"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot find the node that creates index [idx_table2_string2_bad]") + + // bad join_on, tweak it + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": ""`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "expected a comma-separated list of ., got []") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "bla"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "expected a comma-separated list of ., got [bla]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "bla.bla"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "source table name [bla] unknown, expected [r]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "r.field_string1_bad"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "source [r] does not produce field [field_string1_bad]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_on": "r.field_string1"`, `"join_on": "r.field_int1"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "left-side field field_int1 has type int, while index field field_string2 has type string") + + // bad filter, tweak it + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"filter": "l.field_int2 > 100"`, `"filter": "r.field_int2 > 100"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in lookup filter [r.field_int2 > 100], only fields from the lookup table [table2](alias l) are allowed: [unknown field r.field_int2]") + + // bad join_type, tweak it + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"join_type": "left"`, `"join_type": "left_bad"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid join type, expected inner or left, left_bad is not supported") +} + +func TestBadCreatorHaving(t *testing.T) { + var err error + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + // Bad expression, tweak having expression + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.total_value &> 2"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot parse table creator 'having' condition [w.total_value &> 2]") + + err = newScript.Deserialize( + 
[]byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.bad_field &> 3"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot parse file creator 'having' condition [w.bad_field &> 3]") + + // Unknown field in having + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.bad_field > 2"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [unknown field w.bad_field]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.bad_field > 3"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in file creator 'having' condition: [unknown field w.bad_field]]") + + // Prohibited reader field in having + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "r.field_int1 > 2"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [prohibited field r.field_int1]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "r.field_int1 > 3"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in file creator 'having' condition: [prohibited field r.field_int1]") + + // Prohibited lookup field in table creator having + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "l.field_int2 > 2"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "invalid field in table creator 'having' condition: [prohibited field l.field_int2]") + + // Type mismatch in having + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 2"`, `"having": "w.total_value == true"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot evaluate table creator 'having' expression [w.total_value == true]: [cannot perform binary comp op, incompatible arg types '0(int64)' == 'true(bool)' ]") + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"having": "w.total_value > 3"`, `"having": "w.total_value == true"`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "cannot evaluate file creator 'having' expression [w.total_value == true]: [cannot perform binary comp op, incompatible arg types '2.34(decimal.Decimal)' == 'true(bool)' ]") +} + +func TestTopLimit(t *testing.T) { + var err error + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + // Tweak limit beyond allowed maximum + + err = newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"limit": 500000`, `"limit": 500001`, 1)), + nil, nil, "", nil) + assert.Contains(t, err.Error(), "top.limit cannot exceed 500000") + + // Remove limit altogether + + assert.Nil(t, newScript.Deserialize( + []byte(strings.Replace(plainScriptJson, `"limit": 500000`, `"some_bogus_setting": 500000`, 1)), + nil, nil, "", nil)) + assert.Equal(t, 500000, newScript.ScriptNodes["file_totals"].FileCreator.Top.Limit) +} + +func TestBatchIntervalsCalculation(t *testing.T) { + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + var intervals [][]int64 + + tableReaderNodeDef := newScript.ScriptNodes["join_table1_table2"] + intervals, _ = 
tableReaderNodeDef.GetTokenIntervalsByNumberOfBatches() + + assert.Equal(t, 2, len(intervals)) + if len(intervals) == 2 { + assert.Equal(t, int64(-9223372036854775808), intervals[0][0]) + assert.Equal(t, int64(-2), intervals[0][1]) + assert.Equal(t, int64(-1), intervals[1][0]) + assert.Equal(t, int64(9223372036854775807), intervals[1][1]) + } + + fileReaderNodeDef := newScript.ScriptNodes["read_table1"] + intervals, _ = fileReaderNodeDef.GetTokenIntervalsByNumberOfBatches() + + assert.Equal(t, 1, len(intervals)) + if len(intervals) == 1 { + assert.Equal(t, int64(0), intervals[0][0]) + assert.Equal(t, int64(0), intervals[0][1]) + } + + fileCreatorNodeDef := newScript.ScriptNodes["file_totals"] + intervals, _ = fileCreatorNodeDef.GetTokenIntervalsByNumberOfBatches() + + assert.Equal(t, 1, len(intervals)) + if len(intervals) == 1 { + assert.Equal(t, int64(-9223372036854775808), intervals[0][0]) + assert.Equal(t, int64(9223372036854775807), intervals[0][1]) + } +} + +func TestUniqueIndexesFieldRefs(t *testing.T) { + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + fileReaderNodeDef := newScript.ScriptNodes["read_table2"] + fieldRefs := fileReaderNodeDef.GetUniqueIndexesFieldRefs() + assert.Equal(t, 1, len(*fieldRefs)) + if len(*fieldRefs) == 1 { + assert.Equal(t, "table2", (*fieldRefs)[0].TableName) + assert.Equal(t, "field_string2", (*fieldRefs)[0].FieldName) + assert.Equal(t, FieldTypeString, (*fieldRefs)[0].FieldType) + } +} + +func TestAffectedNodes(t *testing.T) { + var affectedNodes []string + + newScript := &ScriptDef{} + assert.Nil(t, newScript.Deserialize([]byte(plainScriptJson), nil, nil, "", nil)) + + affectedNodes = newScript.GetAffectedNodes([]string{"read_table1"}) + assert.Equal(t, 3, len(affectedNodes)) + assert.Contains(t, affectedNodes, "read_table1") + assert.Contains(t, affectedNodes, "join_table1_table2") + assert.Contains(t, affectedNodes, "file_totals") + + affectedNodes = newScript.GetAffectedNodes([]string{"read_table1", "read_table2"}) + assert.Equal(t, 4, len(affectedNodes)) + assert.Contains(t, affectedNodes, "read_table1") + assert.Contains(t, affectedNodes, "read_table2") + assert.Contains(t, affectedNodes, "join_table1_table2") + assert.Contains(t, affectedNodes, "file_totals") + + // Make join manual and see the list of affected nodes shrinking + + assert.Nil(t, newScript.Deserialize([]byte(strings.Replace(plainScriptJson, `"start_policy": "auto"`, `"start_policy": "manual"`, 1)), nil, nil, "", nil)) + + affectedNodes = newScript.GetAffectedNodes([]string{"read_table1"}) + assert.Equal(t, 1, len(affectedNodes)) + assert.Contains(t, affectedNodes, "read_table1") + + affectedNodes = newScript.GetAffectedNodes([]string{"read_table1", "read_table2"}) + assert.Equal(t, 2, len(affectedNodes)) + assert.Contains(t, affectedNodes, "read_table1") + assert.Contains(t, affectedNodes, "read_table2") +} diff --git a/pkg/sc/script_node_def.go b/pkg/sc/script_node_def.go index 272efd4..4359a83 100644 --- a/pkg/sc/script_node_def.go +++ b/pkg/sc/script_node_def.go @@ -1,397 +1,396 @@ -package sc - -import ( - "encoding/json" - "fmt" - "go/ast" - "math" - "regexp" - "strings" - - "github.com/capillariesio/capillaries/pkg/eval" -) - -const ( - HandlerExeTypeGeneric string = "capi_daemon" - HandlerExeTypeToolbelt string = "capi_toolbelt" - HandlerExeTypeWebapi string = "capi_webapi" -) - -const MaxAcceptedBatchesByTableReader int = 1000000 -const DefaultRowsetSize int = 1000 -const MaxRowsetSize int = 100000 - -type 
AggFinderVisitor struct { - Error error -} - -func (v *AggFinderVisitor) Visit(node ast.Node) ast.Visitor { - switch callExp := node.(type) { - case *ast.CallExpr: - switch callIdentExp := callExp.Fun.(type) { - case *ast.Ident: - if eval.StringToAggFunc(callIdentExp.Name) != eval.AggUnknown { - v.Error = fmt.Errorf("found aggregate function %s()", callIdentExp.Name) - return nil - } else { - return v - } - default: - return v - } - default: - return v - } -} - -type NodeType string - -const ( - NodeTypeNone NodeType = "none" - NodeTypeFileTable NodeType = "file_table" - NodeTypeTableTable NodeType = "table_table" - NodeTypeTableLookupTable NodeType = "table_lookup_table" - NodeTypeTableFile NodeType = "table_file" - NodeTypeTableCustomTfmTable NodeType = "table_custom_tfm_table" -) - -func ValidateNodeType(nodeType NodeType) error { - if nodeType == NodeTypeFileTable || - nodeType == NodeTypeTableTable || - nodeType == NodeTypeTableLookupTable || - nodeType == NodeTypeTableFile || - nodeType == NodeTypeTableCustomTfmTable { - return nil - } - return fmt.Errorf("invalid node type %s", nodeType) -} - -const ReaderAlias string = "r" -const CreatorAlias string = "w" -const LookupAlias string = "l" -const CustomProcessorAlias string = "p" - -type NodeRerunPolicy string - -const ( - NodeRerun NodeRerunPolicy = "rerun" // Default - NodeFail NodeRerunPolicy = "fail" -) - -func ValidateRerunPolicy(rerunPolicy NodeRerunPolicy) error { - if rerunPolicy == NodeRerun || - rerunPolicy == NodeFail { - return nil - } - return fmt.Errorf("invalid node rerun policy %s", rerunPolicy) -} - -type NodeStartPolicy string - -const ( - NodeStartManual NodeStartPolicy = "manual" - NodeStartAuto NodeStartPolicy = "auto" // Default -) - -func ValidateStartPolicy(startPolicy NodeStartPolicy) error { - if startPolicy == NodeStartManual || - startPolicy == NodeStartAuto { - return nil - } - return fmt.Errorf("invalid node start policy %s", startPolicy) -} - -type ScriptNodeDef struct { - Name string // Get it from the key - Type NodeType `json:"type"` - Desc string `json:"desc"` - StartPolicy NodeStartPolicy `json:"start_policy"` - RerunPolicy NodeRerunPolicy `json:"rerun_policy"` - CustomProcessorType string `json:"custom_proc_type"` - HandlerExeType string `json:"handler_exe_type"` - - RawReader json.RawMessage `json:"r"` // This depends on tfm type - TableReader TableReaderDef - FileReader FileReaderDef - - Lookup LookupDef `json:"l"` - - RawProcessorDef json.RawMessage `json:"p"` // This depends on tfm type - CustomProcessor CustomProcessorDef // Also should implement CustomProcessorRunner - - RawWriter json.RawMessage `json:"w"` // This depends on tfm type - DependencyPolicyName string `json:"dependency_policy"` - TableCreator TableCreatorDef - TableUpdater TableUpdaterDef - FileCreator FileCreatorDef - DepPolDef *DependencyPolicyDef -} - -func (node *ScriptNodeDef) HasTableReader() bool { - return node.Type == NodeTypeTableTable || - node.Type == NodeTypeTableLookupTable || - node.Type == NodeTypeTableFile || - node.Type == NodeTypeTableCustomTfmTable -} -func (node *ScriptNodeDef) HasFileReader() bool { - return node.Type == NodeTypeFileTable -} - -func (node *ScriptNodeDef) HasLookup() bool { - return node.Type == NodeTypeTableLookupTable -} - -func (node *ScriptNodeDef) HasCustomProcessor() bool { - return node.Type == NodeTypeTableCustomTfmTable -} - -func (node *ScriptNodeDef) HasTableCreator() bool { - return node.Type == NodeTypeFileTable || - node.Type == NodeTypeTableTable || - node.Type == 
NodeTypeTableLookupTable || - node.Type == NodeTypeTableCustomTfmTable -} -func (node *ScriptNodeDef) HasFileCreator() bool { - return node.Type == NodeTypeTableFile -} -func (node *ScriptNodeDef) GetTargetName() string { - if node.HasTableCreator() { - return node.TableCreator.Name - } else if node.HasFileCreator() { - return CreatorAlias - } else { - return "dev_error_uknown_target_name" - } -} - -func (node *ScriptNodeDef) Deserialize(customProcessorDefFactory CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage, caPath string, privateKeys map[string]string) error { - errors := make([]string, 0) - - if err := ValidateNodeType(node.Type); err != nil { - return err - } - - // Defaults - - if len(node.HandlerExeType) == 0 { - node.HandlerExeType = HandlerExeTypeGeneric - } - - if len(node.RerunPolicy) == 0 { - node.RerunPolicy = NodeRerun - } else if err := ValidateRerunPolicy(node.RerunPolicy); err != nil { - return err - } - - if len(node.StartPolicy) == 0 { - node.StartPolicy = NodeStartAuto - } else if err := ValidateStartPolicy(node.StartPolicy); err != nil { - return err - } - - // Reader - - if node.HasTableReader() { - if err := json.Unmarshal(node.RawReader, &node.TableReader); err != nil { - errors = append(errors, fmt.Sprintf("cannot unmarshal table reader: [%s]", err.Error())) - } - if len(node.TableReader.TableName) == 0 { - errors = append(errors, "table reader cannot reference empty table name") - } - if node.TableReader.ExpectedBatchesTotal == 0 { - node.TableReader.ExpectedBatchesTotal = 1 - } else if node.TableReader.ExpectedBatchesTotal < 0 || node.TableReader.ExpectedBatchesTotal > MaxAcceptedBatchesByTableReader { - errors = append(errors, fmt.Sprintf("table reader can accept between 1 and %d batches, %d specified", MaxAcceptedBatchesByTableReader, node.TableReader.ExpectedBatchesTotal)) - } - if node.TableReader.RowsetSize < 0 || MaxRowsetSize < node.TableReader.RowsetSize { - errors = append(errors, fmt.Sprintf("invalid rowset size %d, table reader can accept between 0 (defaults to %d) and %d", node.TableReader.RowsetSize, DefaultRowsetSize, MaxRowsetSize)) - } - if node.TableReader.RowsetSize == 0 { - node.TableReader.RowsetSize = DefaultRowsetSize - } - - } else if node.HasFileReader() { - if err := node.FileReader.Deserialize(node.RawReader); err != nil { - errors = append(errors, fmt.Sprintf("cannot deserialize file reader [%s]: [%s]", string(node.RawReader), err.Error())) - } - } - - // Creator - - if node.HasTableCreator() { - if err := node.TableCreator.Deserialize(node.RawWriter); err != nil { - errors = append(errors, fmt.Sprintf("cannot deserialize table creator [%s]: [%s]", strings.ReplaceAll(string(node.RawWriter), "\n", " "), err.Error())) - } - } else if node.HasFileCreator() { - if err := node.FileCreator.Deserialize(node.RawWriter); err != nil { - errors = append(errors, fmt.Sprintf("cannot deserialize file creator [%s]: [%s]", strings.ReplaceAll(string(node.RawWriter), "\n", " "), err.Error())) - } - } - - // Custom processor - - if node.HasCustomProcessor() { - if customProcessorDefFactory == nil { - return fmt.Errorf("undefined custom processor factory") - } - if customProcessorsSettings == nil { - return fmt.Errorf("missing custom processor settings section") - } - var ok bool - node.CustomProcessor, ok = customProcessorDefFactory.Create(node.CustomProcessorType) - if !ok { - errors = append(errors, fmt.Sprintf("cannot deserialize unknown custom processor %s", node.CustomProcessorType)) - } else { - if customProcSettings, 
ok := customProcessorsSettings[node.CustomProcessorType]; !ok { - errors = append(errors, fmt.Sprintf("cannot find custom processing settings for [%s] in the environment config file", node.CustomProcessorType)) - } else { - if err := node.CustomProcessor.Deserialize(node.RawProcessorDef, customProcSettings, caPath, privateKeys); err != nil { - re := regexp.MustCompile("[ \r\n]+") - errors = append(errors, fmt.Sprintf("cannot deserialize custom processor [%s]: [%s]", re.ReplaceAllString(string(node.RawProcessorDef), ""), err.Error())) - } - } - } - } - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (node *ScriptNodeDef) evalCreatorAndLookupExpressionsAndCheckType() error { - errors := make([]string, 0, 2) - - if node.HasLookup() && node.Lookup.UsesFilter() { - if err := evalExpressionWithFieldRefsAndCheckType(node.Lookup.Filter, node.Lookup.UsedInFilterFields, FieldTypeBool); err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate lookup filter expression [%s]: [%s]", node.Lookup.RawFilter, err.Error())) - } - } - - if node.HasTableCreator() { - // Having - if err := evalExpressionWithFieldRefsAndCheckType(node.TableCreator.Having, node.TableCreator.UsedInHavingFields, FieldTypeBool); err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate table creator 'having' expression [%s]: [%s]", node.TableCreator.RawHaving, err.Error())) - } - - // Target table fields - for tgtFieldName, tgtFieldDef := range node.TableCreator.Fields { - - // TODO: find a way to check field usage: - // - lookup fields must be used only within enclosing agg calls (sum etc), otherwise last one wins - // - src table fields are allowed within enclosing agg calls, and there is even a biz case for it (multiply src field by the number of lookup rows) - - // If no grouping is used, no agg calls allowed - if node.HasLookup() && !node.Lookup.IsGroup || !node.HasLookup() { - v := AggFinderVisitor{} - ast.Walk(&v, tgtFieldDef.ParsedExpression) - if v.Error != nil { - errors = append(errors, fmt.Sprintf("cannot use agg functions in [%s], lookup group flag is not set or no lookups used: [%s]", tgtFieldDef.RawExpression, v.Error.Error())) - } - } - - // Just eval with test values, agg functions will go through preserving the type no problem - if err := evalExpressionWithFieldRefsAndCheckType(tgtFieldDef.ParsedExpression, node.TableCreator.UsedInTargetExpressionsFields, tgtFieldDef.Type); err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate table creator target field %s expression [%s]: [%s]", tgtFieldName, tgtFieldDef.RawExpression, err.Error())) - } - } - } - - if node.HasFileCreator() { - // Having - if err := evalExpressionWithFieldRefsAndCheckType(node.FileCreator.Having, node.FileCreator.UsedInHavingFields, FieldTypeBool); err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate file creator 'having' expression [%s]: [%s]", node.FileCreator.RawHaving, err.Error())) - } - - // Target table fields (yes, they are not just strings, we check the type) - for i := 0; i < len(node.FileCreator.Columns); i++ { - colDef := &node.FileCreator.Columns[i] - if err := evalExpressionWithFieldRefsAndCheckType(colDef.ParsedExpression, node.FileCreator.UsedInTargetExpressionsFields, colDef.Type); err != nil { - errors = append(errors, fmt.Sprintf("cannot evaluate table creator target field %s expression [%s]: [%s]", colDef.Name, colDef.RawExpression, err.Error())) - } - } - } - - // NOTE: do not even try to eval expressions from the 
custom processor here, - // they may contain custom stuff and are pretty much guaranteed to fail - - if len(errors) > 0 { - return fmt.Errorf(strings.Join(errors, "; ")) - } else { - return nil - } -} - -func (node *ScriptNodeDef) getSourceFieldRefs() (*FieldRefs, error) { - if node.HasFileReader() { - return node.FileReader.getFieldRefs(), nil - } else if node.HasTableReader() { - return node.TableReader.TableCreator.GetFieldRefsWithAlias(ReaderAlias), nil - } else { - return nil, fmt.Errorf("dev error, node of type %s has no file or table reader", node.Type) - } -} - -func (node *ScriptNodeDef) GetUniqueIndexesFieldRefs() *FieldRefs { - if !node.HasTableCreator() { - return &FieldRefs{} - } - fieldTypeMap := map[string]TableFieldType{} - for _, idxDef := range node.TableCreator.Indexes { - if idxDef.Uniqueness == IdxUnique { - for _, idxComponentDef := range idxDef.Components { - fieldTypeMap[idxComponentDef.FieldName] = idxComponentDef.FieldType - } - } - } - fieldRefs := make(FieldRefs, len(fieldTypeMap)) - fieldRefIdx := 0 - for fieldName, fieldType := range fieldTypeMap { - fieldRefs[fieldRefIdx] = FieldRef{ - FieldName: fieldName, - FieldType: fieldType, - TableName: node.TableCreator.Name} - fieldRefIdx++ - } - - return &fieldRefs -} - -func (node *ScriptNodeDef) GetTokenIntervalsByNumberOfBatches() ([][]int64, error) { - if node.HasTableReader() || node.HasFileCreator() && node.TableReader.ExpectedBatchesTotal > 1 { - if node.TableReader.ExpectedBatchesTotal == 1 { - return [][]int64{{int64(math.MinInt64), int64(math.MaxInt64)}}, nil - } - - tokenIntervalPerBatch := int64(math.MaxInt64/node.TableReader.ExpectedBatchesTotal) - int64(math.MinInt64/node.TableReader.ExpectedBatchesTotal) - - intervals := make([][]int64, node.TableReader.ExpectedBatchesTotal) - left := int64(math.MinInt64) - for i := 0; i < len(intervals); i++ { - var right int64 - if i == len(intervals)-1 { - right = math.MaxInt64 - } else { - right = left + tokenIntervalPerBatch - 1 - } - intervals[i] = []int64{left, right} - left = right + 1 - } - return intervals, nil - // } else if node.HasFileCreator() && node.TableReader.ExpectedBatchesTotal == 1 { - // // One output file - one batch, dummy intervals - // intervals := make([][]int64, 1) - // intervals[0] = []int64{int64(0), 0} - // return intervals, nil - } else if node.HasFileReader() { - // One input file - one batch - intervals := make([][]int64, len(node.FileReader.SrcFileUrls)) - for i := 0; i < len(node.FileReader.SrcFileUrls); i++ { - intervals[i] = []int64{int64(i), int64(i)} - } - return intervals, nil - } else { - return nil, fmt.Errorf("cannot find implementation for intervals for node %s", node.Name) - } -} +package sc + +import ( + "encoding/json" + "fmt" + "go/ast" + "math" + "regexp" + "strings" + + "github.com/capillariesio/capillaries/pkg/eval" +) + +const ( + HandlerExeTypeGeneric string = "capi_daemon" + HandlerExeTypeToolbelt string = "capi_toolbelt" + HandlerExeTypeWebapi string = "capi_webapi" +) + +const MaxAcceptedBatchesByTableReader int = 1000000 +const DefaultRowsetSize int = 1000 +const MaxRowsetSize int = 100000 + +type AggFinderVisitor struct { + Error error +} + +func (v *AggFinderVisitor) Visit(node ast.Node) ast.Visitor { + switch callExp := node.(type) { + case *ast.CallExpr: + switch callIdentExp := callExp.Fun.(type) { + case *ast.Ident: + if eval.StringToAggFunc(callIdentExp.Name) != eval.AggUnknown { + v.Error = fmt.Errorf("found aggregate function %s()", callIdentExp.Name) + return nil + } else { + return v + } + default: + 
return v + } + default: + return v + } +} + +type NodeType string + +const ( + NodeTypeNone NodeType = "none" + NodeTypeFileTable NodeType = "file_table" + NodeTypeTableTable NodeType = "table_table" + NodeTypeTableLookupTable NodeType = "table_lookup_table" + NodeTypeTableFile NodeType = "table_file" + NodeTypeTableCustomTfmTable NodeType = "table_custom_tfm_table" +) + +func ValidateNodeType(nodeType NodeType) error { + if nodeType == NodeTypeFileTable || + nodeType == NodeTypeTableTable || + nodeType == NodeTypeTableLookupTable || + nodeType == NodeTypeTableFile || + nodeType == NodeTypeTableCustomTfmTable { + return nil + } + return fmt.Errorf("invalid node type %s", nodeType) +} + +const ReaderAlias string = "r" +const CreatorAlias string = "w" +const LookupAlias string = "l" +const CustomProcessorAlias string = "p" + +type NodeRerunPolicy string + +const ( + NodeRerun NodeRerunPolicy = "rerun" // Default + NodeFail NodeRerunPolicy = "fail" +) + +func ValidateRerunPolicy(rerunPolicy NodeRerunPolicy) error { + if rerunPolicy == NodeRerun || + rerunPolicy == NodeFail { + return nil + } + return fmt.Errorf("invalid node rerun policy %s", rerunPolicy) +} + +type NodeStartPolicy string + +const ( + NodeStartManual NodeStartPolicy = "manual" + NodeStartAuto NodeStartPolicy = "auto" // Default +) + +func ValidateStartPolicy(startPolicy NodeStartPolicy) error { + if startPolicy == NodeStartManual || + startPolicy == NodeStartAuto { + return nil + } + return fmt.Errorf("invalid node start policy %s", startPolicy) +} + +type ScriptNodeDef struct { + Name string // Get it from the key + Type NodeType `json:"type"` + Desc string `json:"desc"` + StartPolicy NodeStartPolicy `json:"start_policy"` + RerunPolicy NodeRerunPolicy `json:"rerun_policy"` + CustomProcessorType string `json:"custom_proc_type"` + HandlerExeType string `json:"handler_exe_type"` + + RawReader json.RawMessage `json:"r"` // This depends on tfm type + TableReader TableReaderDef + FileReader FileReaderDef + + Lookup LookupDef `json:"l"` + + RawProcessorDef json.RawMessage `json:"p"` // This depends on tfm type + CustomProcessor CustomProcessorDef // Also should implement CustomProcessorRunner + + RawWriter json.RawMessage `json:"w"` // This depends on tfm type + DependencyPolicyName string `json:"dependency_policy"` + TableCreator TableCreatorDef + TableUpdater TableUpdaterDef + FileCreator FileCreatorDef + DepPolDef *DependencyPolicyDef +} + +func (node *ScriptNodeDef) HasTableReader() bool { + return node.Type == NodeTypeTableTable || + node.Type == NodeTypeTableLookupTable || + node.Type == NodeTypeTableFile || + node.Type == NodeTypeTableCustomTfmTable +} +func (node *ScriptNodeDef) HasFileReader() bool { + return node.Type == NodeTypeFileTable +} + +func (node *ScriptNodeDef) HasLookup() bool { + return node.Type == NodeTypeTableLookupTable +} + +func (node *ScriptNodeDef) HasCustomProcessor() bool { + return node.Type == NodeTypeTableCustomTfmTable +} + +func (node *ScriptNodeDef) HasTableCreator() bool { + return node.Type == NodeTypeFileTable || + node.Type == NodeTypeTableTable || + node.Type == NodeTypeTableLookupTable || + node.Type == NodeTypeTableCustomTfmTable +} +func (node *ScriptNodeDef) HasFileCreator() bool { + return node.Type == NodeTypeTableFile +} +func (node *ScriptNodeDef) GetTargetName() string { + if node.HasTableCreator() { + return node.TableCreator.Name + } else if node.HasFileCreator() { + return CreatorAlias + } + return "dev_error_uknown_target_name" +} + +func (node *ScriptNodeDef) 
Deserialize(customProcessorDefFactory CustomProcessorDefFactory, customProcessorsSettings map[string]json.RawMessage, caPath string, privateKeys map[string]string) error { + errors := make([]string, 0) + + if err := ValidateNodeType(node.Type); err != nil { + return err + } + + // Defaults + + if len(node.HandlerExeType) == 0 { + node.HandlerExeType = HandlerExeTypeGeneric + } + + if len(node.RerunPolicy) == 0 { + node.RerunPolicy = NodeRerun + } else if err := ValidateRerunPolicy(node.RerunPolicy); err != nil { + return err + } + + if len(node.StartPolicy) == 0 { + node.StartPolicy = NodeStartAuto + } else if err := ValidateStartPolicy(node.StartPolicy); err != nil { + return err + } + + // Reader + + if node.HasTableReader() { + if err := json.Unmarshal(node.RawReader, &node.TableReader); err != nil { + errors = append(errors, fmt.Sprintf("cannot unmarshal table reader: [%s]", err.Error())) + } + if len(node.TableReader.TableName) == 0 { + errors = append(errors, "table reader cannot reference empty table name") + } + if node.TableReader.ExpectedBatchesTotal == 0 { + node.TableReader.ExpectedBatchesTotal = 1 + } else if node.TableReader.ExpectedBatchesTotal < 0 || node.TableReader.ExpectedBatchesTotal > MaxAcceptedBatchesByTableReader { + errors = append(errors, fmt.Sprintf("table reader can accept between 1 and %d batches, %d specified", MaxAcceptedBatchesByTableReader, node.TableReader.ExpectedBatchesTotal)) + } + if node.TableReader.RowsetSize < 0 || MaxRowsetSize < node.TableReader.RowsetSize { + errors = append(errors, fmt.Sprintf("invalid rowset size %d, table reader can accept between 0 (defaults to %d) and %d", node.TableReader.RowsetSize, DefaultRowsetSize, MaxRowsetSize)) + } + if node.TableReader.RowsetSize == 0 { + node.TableReader.RowsetSize = DefaultRowsetSize + } + + } else if node.HasFileReader() { + if err := node.FileReader.Deserialize(node.RawReader); err != nil { + errors = append(errors, fmt.Sprintf("cannot deserialize file reader [%s]: [%s]", string(node.RawReader), err.Error())) + } + } + + // Creator + + if node.HasTableCreator() { + if err := node.TableCreator.Deserialize(node.RawWriter); err != nil { + errors = append(errors, fmt.Sprintf("cannot deserialize table creator [%s]: [%s]", strings.ReplaceAll(string(node.RawWriter), "\n", " "), err.Error())) + } + } else if node.HasFileCreator() { + if err := node.FileCreator.Deserialize(node.RawWriter); err != nil { + errors = append(errors, fmt.Sprintf("cannot deserialize file creator [%s]: [%s]", strings.ReplaceAll(string(node.RawWriter), "\n", " "), err.Error())) + } + } + + // Custom processor + + if node.HasCustomProcessor() { + if customProcessorDefFactory == nil { + return fmt.Errorf("undefined custom processor factory") + } + if customProcessorsSettings == nil { + return fmt.Errorf("missing custom processor settings section") + } + var ok bool + node.CustomProcessor, ok = customProcessorDefFactory.Create(node.CustomProcessorType) + if !ok { + errors = append(errors, fmt.Sprintf("cannot deserialize unknown custom processor %s", node.CustomProcessorType)) + } else { + if customProcSettings, ok := customProcessorsSettings[node.CustomProcessorType]; !ok { + errors = append(errors, fmt.Sprintf("cannot find custom processing settings for [%s] in the environment config file", node.CustomProcessorType)) + } else { + if err := node.CustomProcessor.Deserialize(node.RawProcessorDef, customProcSettings, caPath, privateKeys); err != nil { + re := regexp.MustCompile("[ \r\n]+") + errors = append(errors, fmt.Sprintf("cannot 
deserialize custom processor [%s]: [%s]", re.ReplaceAllString(string(node.RawProcessorDef), ""), err.Error())) + } + } + } + } + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +func (node *ScriptNodeDef) evalCreatorAndLookupExpressionsAndCheckType() error { + errors := make([]string, 0, 2) + + if node.HasLookup() && node.Lookup.UsesFilter() { + if err := evalExpressionWithFieldRefsAndCheckType(node.Lookup.Filter, node.Lookup.UsedInFilterFields, FieldTypeBool); err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate lookup filter expression [%s]: [%s]", node.Lookup.RawFilter, err.Error())) + } + } + + if node.HasTableCreator() { + // Having + if err := evalExpressionWithFieldRefsAndCheckType(node.TableCreator.Having, node.TableCreator.UsedInHavingFields, FieldTypeBool); err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate table creator 'having' expression [%s]: [%s]", node.TableCreator.RawHaving, err.Error())) + } + + // Target table fields + for tgtFieldName, tgtFieldDef := range node.TableCreator.Fields { + + // TODO: find a way to check field usage: + // - lookup fields must be used only within enclosing agg calls (sum etc), otherwise last one wins + // - src table fields are allowed within enclosing agg calls, and there is even a biz case for it (multiply src field by the number of lookup rows) + + // If no grouping is used, no agg calls allowed + if node.HasLookup() && !node.Lookup.IsGroup || !node.HasLookup() { + v := AggFinderVisitor{} + ast.Walk(&v, tgtFieldDef.ParsedExpression) + if v.Error != nil { + errors = append(errors, fmt.Sprintf("cannot use agg functions in [%s], lookup group flag is not set or no lookups used: [%s]", tgtFieldDef.RawExpression, v.Error.Error())) + } + } + + // Just eval with test values, agg functions will go through preserving the type no problem + if err := evalExpressionWithFieldRefsAndCheckType(tgtFieldDef.ParsedExpression, node.TableCreator.UsedInTargetExpressionsFields, tgtFieldDef.Type); err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate table creator target field %s expression [%s]: [%s]", tgtFieldName, tgtFieldDef.RawExpression, err.Error())) + } + } + } + + if node.HasFileCreator() { + // Having + if err := evalExpressionWithFieldRefsAndCheckType(node.FileCreator.Having, node.FileCreator.UsedInHavingFields, FieldTypeBool); err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate file creator 'having' expression [%s]: [%s]", node.FileCreator.RawHaving, err.Error())) + } + + // Target table fields (yes, they are not just strings, we check the type) + for i := 0; i < len(node.FileCreator.Columns); i++ { + colDef := &node.FileCreator.Columns[i] + if err := evalExpressionWithFieldRefsAndCheckType(colDef.ParsedExpression, node.FileCreator.UsedInTargetExpressionsFields, colDef.Type); err != nil { + errors = append(errors, fmt.Sprintf("cannot evaluate table creator target field %s expression [%s]: [%s]", colDef.Name, colDef.RawExpression, err.Error())) + } + } + } + + // NOTE: do not even try to eval expressions from the custom processor here, + // they may contain custom stuff and are pretty much guaranteed to fail + + if len(errors) > 0 { + return fmt.Errorf(strings.Join(errors, "; ")) + } + + return nil +} + +func (node *ScriptNodeDef) getSourceFieldRefs() (*FieldRefs, error) { + if node.HasFileReader() { + return node.FileReader.getFieldRefs(), nil + } else if node.HasTableReader() { + return 
node.TableReader.TableCreator.GetFieldRefsWithAlias(ReaderAlias), nil + } + + return nil, fmt.Errorf("dev error, node of type %s has no file or table reader", node.Type) +} + +func (node *ScriptNodeDef) GetUniqueIndexesFieldRefs() *FieldRefs { + if !node.HasTableCreator() { + return &FieldRefs{} + } + fieldTypeMap := map[string]TableFieldType{} + for _, idxDef := range node.TableCreator.Indexes { + if idxDef.Uniqueness == IdxUnique { + for _, idxComponentDef := range idxDef.Components { + fieldTypeMap[idxComponentDef.FieldName] = idxComponentDef.FieldType + } + } + } + fieldRefs := make(FieldRefs, len(fieldTypeMap)) + fieldRefIdx := 0 + for fieldName, fieldType := range fieldTypeMap { + fieldRefs[fieldRefIdx] = FieldRef{ + FieldName: fieldName, + FieldType: fieldType, + TableName: node.TableCreator.Name} + fieldRefIdx++ + } + + return &fieldRefs +} + +func (node *ScriptNodeDef) GetTokenIntervalsByNumberOfBatches() ([][]int64, error) { + if node.HasTableReader() || node.HasFileCreator() && node.TableReader.ExpectedBatchesTotal > 1 { + if node.TableReader.ExpectedBatchesTotal == 1 { + return [][]int64{{int64(math.MinInt64), int64(math.MaxInt64)}}, nil + } + + tokenIntervalPerBatch := int64(math.MaxInt64/node.TableReader.ExpectedBatchesTotal) - int64(math.MinInt64/node.TableReader.ExpectedBatchesTotal) + + intervals := make([][]int64, node.TableReader.ExpectedBatchesTotal) + left := int64(math.MinInt64) + for i := 0; i < len(intervals); i++ { + var right int64 + if i == len(intervals)-1 { + right = math.MaxInt64 + } else { + right = left + tokenIntervalPerBatch - 1 + } + intervals[i] = []int64{left, right} + left = right + 1 + } + return intervals, nil + // } else if node.HasFileCreator() && node.TableReader.ExpectedBatchesTotal == 1 { + // // One output file - one batch, dummy intervals + // intervals := make([][]int64, 1) + // intervals[0] = []int64{int64(0), 0} + // return intervals, nil + } else if node.HasFileReader() { + // One input file - one batch + intervals := make([][]int64, len(node.FileReader.SrcFileUrls)) + for i := 0; i < len(node.FileReader.SrcFileUrls); i++ { + intervals[i] = []int64{int64(i), int64(i)} + } + return intervals, nil + } + + return nil, fmt.Errorf("cannot find implementation for intervals for node %s", node.Name) +} diff --git a/pkg/sc/table_creator_def.go b/pkg/sc/table_creator_def.go index c021db2..cf7ebc0 100644 --- a/pkg/sc/table_creator_def.go +++ b/pkg/sc/table_creator_def.go @@ -1,276 +1,274 @@ -package sc - -import ( - "encoding/json" - "fmt" - "go/ast" - "math" - "regexp" - "strconv" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" - "gopkg.in/inf.v0" -) - -const ProhibitedTableNameRegex = "^idx|^wf|^system" -const AllowedTableNameRegex = "[A-Za-z0-9_]+" -const AllowedIdxNameRegex = "^idx[A-Za-z0-9_]+" - -type TableUpdaterDef struct { - Fields map[string]*WriteTableFieldDef `json:"fields"` -} - -type TableCreatorDef struct { - Name string `json:"name"` - RawHaving string `json:"having"` - Having ast.Expr - UsedInHavingFields FieldRefs - UsedInTargetExpressionsFields FieldRefs - Fields map[string]*WriteTableFieldDef `json:"fields"` - RawIndexes map[string]string `json:"indexes"` - Indexes IdxDefMap -} - -// func (fieldDef *WriteTableFieldDef) CheckValueType(val interface{}) error { -// switch assertedValue := val.(type) { -// case int64: -// if fieldDef.Type != FieldTypeInt { -// return fmt.Errorf("expected type %s, but got int64 (%d)", fieldDef.Type, assertedValue) -// } -// case float64: -// if fieldDef.Type != FieldTypeFloat { 
-// return fmt.Errorf("expected type %s, but got float64 (%f)", fieldDef.Type, assertedValue) -// } -// case string: -// if fieldDef.Type != FieldTypeString { -// return fmt.Errorf("expected type %s, but got string (%s)", fieldDef.Type, assertedValue) -// } -// case bool: -// if fieldDef.Type != FieldTypeBool { -// return fmt.Errorf("expected type %s, but got bool (%v)", fieldDef.Type, assertedValue) -// } -// case time.Time: -// if fieldDef.Type != FieldTypeDateTime { -// return fmt.Errorf("expected type %s, but got datetime (%s)", fieldDef.Type, assertedValue.String()) -// } -// case decimal.Decimal: -// if fieldDef.Type != FieldTypeDecimal2 { -// return fmt.Errorf("expected type %s, but got decimal (%s)", fieldDef.Type, assertedValue.String()) -// } -// default: -// return fmt.Errorf("expected type %s, but got unexpected type %T(%v)", fieldDef.Type, assertedValue, assertedValue) -// } -// return nil -// } - -func (tcDef *TableCreatorDef) GetFieldRefs() *FieldRefs { - return tcDef.GetFieldRefsWithAlias("") -} - -func (tcDef *TableCreatorDef) GetFieldRefsWithAlias(useTableAlias string) *FieldRefs { - fieldRefs := make(FieldRefs, len(tcDef.Fields)) - i := 0 - for fieldName, fieldDef := range tcDef.Fields { - tName := tcDef.Name - if len(useTableAlias) > 0 { - tName = useTableAlias - } - fieldRefs[i] = FieldRef{ - TableName: tName, - FieldName: fieldName, - FieldType: fieldDef.Type} - i += 1 - } - return &fieldRefs -} - -func (tcDef *TableCreatorDef) Deserialize(rawWriter json.RawMessage) error { - var err error - if err = json.Unmarshal(rawWriter, tcDef); err != nil { - return fmt.Errorf("cannot unmarshal table creator: %s", err.Error()) - } - - re := regexp.MustCompile(ProhibitedTableNameRegex) - invalidNamePieceFound := re.FindString(tcDef.Name) - if len(invalidNamePieceFound) > 0 { - return fmt.Errorf("invalid table name [%s]: prohibited regex is [%s]", tcDef.Name, ProhibitedTableNameRegex) - } - - re = regexp.MustCompile(AllowedTableNameRegex) - invalidNamePieceFound = re.FindString(tcDef.Name) - if len(invalidNamePieceFound) != len(tcDef.Name) { - return fmt.Errorf("invalid table name [%s]: allowed regex is [%s]", tcDef.Name, AllowedTableNameRegex) - } - - // Having - tcDef.Having, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(tcDef.RawHaving, &tcDef.UsedInHavingFields) - if err != nil { - return fmt.Errorf("cannot parse table creator 'having' condition [%s]: [%s]", tcDef.RawHaving, err.Error()) - } - - // Fields - for _, fieldDef := range tcDef.Fields { - if fieldDef.ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(fieldDef.RawExpression, &fieldDef.UsedFields); err != nil { - return fmt.Errorf("cannot parse field expression [%s]: [%s]", fieldDef.RawExpression, err.Error()) - } - if !IsValidFieldType(fieldDef.Type) { - return fmt.Errorf("invalid field type [%s]", fieldDef.Type) - } - } - - tcDef.UsedInTargetExpressionsFields = GetFieldRefsUsedInAllTargetExpressions(tcDef.Fields) - - // Indexes - tcDef.Indexes = IdxDefMap{} - if err := tcDef.Indexes.parseRawIndexDefMap(tcDef.RawIndexes, tcDef.GetFieldRefs()); err != nil { - return err - } - - re = regexp.MustCompile(AllowedIdxNameRegex) - for idxName, _ := range tcDef.Indexes { - invalidNamePieceFound := re.FindString(idxName) - if len(invalidNamePieceFound) != len(idxName) { - return fmt.Errorf("invalid index name [%s]: allowed regex is [%s]", idxName, AllowedIdxNameRegex) - } - - } - - return nil -} - -func (creatorDef *TableCreatorDef) GetFieldDefaultReadyForDb(fieldName string) (interface{}, 
error) { - writerFieldDef, ok := creatorDef.Fields[fieldName] - if !ok { - return nil, fmt.Errorf("default for unknown field %s", fieldName) - } - defaultValueString := strings.TrimSpace(writerFieldDef.DefaultValue) - - var err error - switch writerFieldDef.Type { - case FieldTypeInt: - v := DefaultInt - if len(defaultValueString) > 0 { - v, err = strconv.ParseInt(defaultValueString, 10, 64) - if err != nil { - return nil, fmt.Errorf("cannot read int64 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) - } - } - return v, nil - case FieldTypeFloat: - v := DefaultFloat - if len(defaultValueString) > 0 { - v, err = strconv.ParseFloat(defaultValueString, 64) - if err != nil { - return nil, fmt.Errorf("cannot read float64 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) - } - } - return v, nil - case FieldTypeString: - v := DefaultString - if len(defaultValueString) > 0 { - v = defaultValueString - } - return v, nil - case FieldTypeDecimal2: - // Set it to Cassandra-accepted value, not decimal.Decimal: https://github.com/gocql/gocql/issues/1578 - v := inf.NewDec(0, 0) - if len(defaultValueString) > 0 { - f, err := strconv.ParseFloat(defaultValueString, 64) - if err != nil { - return nil, fmt.Errorf("cannot read decimal2 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) - } - scaled := int64(math.Round(f * 100)) - v = inf.NewDec(scaled, 2) - } - return v, nil - case FieldTypeBool: - v := DefaultBool - if len(defaultValueString) > 0 { - v, err = strconv.ParseBool(defaultValueString) - if err != nil { - return nil, fmt.Errorf("cannot read bool field %s, from default value string '%s', allowed values are true,false,T,F,0,1: %s", fieldName, defaultValueString, err.Error()) - } - } - return v, nil - case FieldTypeDateTime: - v := DefaultDateTime() - if len(defaultValueString) > 0 { - v, err = time.Parse(CassandraDatetimeFormat, defaultValueString) - if err != nil { - return nil, fmt.Errorf("cannot read time field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) - } - } - return v, nil - default: - return nil, fmt.Errorf("GetFieldDefault unsupported field type %s, field %s", writerFieldDef.Type, fieldName) - } -} - -func CalculateFieldValue(fieldName string, fieldDef *WriteTableFieldDef, srcVars eval.VarValuesMap, canUseAggFunc bool) (interface{}, error) { - calcWithAggFunc, aggFuncType, aggFuncArgs := eval.DetectRootAggFunc(fieldDef.ParsedExpression) - if !canUseAggFunc { - calcWithAggFunc = eval.AggFuncDisabled - } - - eCtx, err := eval.NewPlainEvalCtxWithVarsAndInitializedAgg(calcWithAggFunc, &srcVars, aggFuncType, aggFuncArgs) - if err != nil { - return nil, err - } - - valVolatile, err := eCtx.Eval(fieldDef.ParsedExpression) - if err != nil { - return nil, fmt.Errorf("cannot evaluate expression for field %s: [%s]", fieldName, err.Error()) - } else { - if err := CheckValueType(valVolatile, fieldDef.Type); err != nil { - return nil, fmt.Errorf("invalid field %s type: [%s]", fieldName, err.Error()) - } else { - return valVolatile, nil - } - } -} - -func (creatorDef *TableCreatorDef) CalculateTableRecordFromSrcVars(canUseAggFunc bool, srcVars eval.VarValuesMap) (map[string]interface{}, error) { - errors := make([]string, 0, 2) - - tableRecord := map[string]interface{}{} - - for fieldName, fieldDef := range creatorDef.Fields { - var err error - tableRecord[fieldName], err = CalculateFieldValue(fieldName, fieldDef, srcVars, canUseAggFunc) - if err 
!= nil { - errors = append(errors, err.Error()) - } - } - - if len(errors) > 0 { - return nil, fmt.Errorf(strings.Join(errors, "; ")) - } else { - return tableRecord, nil - } -} - -func (creatorDef *TableCreatorDef) CheckTableRecordHavingCondition(tableRecord map[string]interface{}) (bool, error) { - if creatorDef.Having == nil { - // No Having condition specified - return true, nil - } - vars := eval.VarValuesMap{} - vars[CreatorAlias] = map[string]interface{}{} - for fieldName, fieldValue := range tableRecord { - vars[CreatorAlias][fieldName] = fieldValue - } - - eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) - valVolatile, err := eCtx.Eval(creatorDef.Having) - if err != nil { - return false, fmt.Errorf("cannot evaluate 'having' expression: [%s]", err.Error()) - } - valBool, ok := valVolatile.(bool) - if !ok { - return false, fmt.Errorf("cannot get bool when evaluating having expression, got %v(%T) instead", valVolatile, valVolatile) - } - - return valBool, nil -} +package sc + +import ( + "encoding/json" + "fmt" + "go/ast" + "math" + "regexp" + "strconv" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" + "gopkg.in/inf.v0" +) + +const ProhibitedTableNameRegex = "^idx|^wf|^system" +const AllowedTableNameRegex = "[A-Za-z0-9_]+" +const AllowedIdxNameRegex = "^idx[A-Za-z0-9_]+" + +type TableUpdaterDef struct { + Fields map[string]*WriteTableFieldDef `json:"fields"` +} + +type TableCreatorDef struct { + Name string `json:"name"` + RawHaving string `json:"having"` + Having ast.Expr + UsedInHavingFields FieldRefs + UsedInTargetExpressionsFields FieldRefs + Fields map[string]*WriteTableFieldDef `json:"fields"` + RawIndexes map[string]string `json:"indexes"` + Indexes IdxDefMap +} + +// func (fieldDef *WriteTableFieldDef) CheckValueType(val any) error { +// switch assertedValue := val.(type) { +// case int64: +// if fieldDef.Type != FieldTypeInt { +// return fmt.Errorf("expected type %s, but got int64 (%d)", fieldDef.Type, assertedValue) +// } +// case float64: +// if fieldDef.Type != FieldTypeFloat { +// return fmt.Errorf("expected type %s, but got float64 (%f)", fieldDef.Type, assertedValue) +// } +// case string: +// if fieldDef.Type != FieldTypeString { +// return fmt.Errorf("expected type %s, but got string (%s)", fieldDef.Type, assertedValue) +// } +// case bool: +// if fieldDef.Type != FieldTypeBool { +// return fmt.Errorf("expected type %s, but got bool (%v)", fieldDef.Type, assertedValue) +// } +// case time.Time: +// if fieldDef.Type != FieldTypeDateTime { +// return fmt.Errorf("expected type %s, but got datetime (%s)", fieldDef.Type, assertedValue.String()) +// } +// case decimal.Decimal: +// if fieldDef.Type != FieldTypeDecimal2 { +// return fmt.Errorf("expected type %s, but got decimal (%s)", fieldDef.Type, assertedValue.String()) +// } +// default: +// return fmt.Errorf("expected type %s, but got unexpected type %T(%v)", fieldDef.Type, assertedValue, assertedValue) +// } +// return nil +// } + +func (tcDef *TableCreatorDef) GetFieldRefs() *FieldRefs { + return tcDef.GetFieldRefsWithAlias("") +} + +func (tcDef *TableCreatorDef) GetFieldRefsWithAlias(useTableAlias string) *FieldRefs { + fieldRefs := make(FieldRefs, len(tcDef.Fields)) + i := 0 + for fieldName, fieldDef := range tcDef.Fields { + tName := tcDef.Name + if len(useTableAlias) > 0 { + tName = useTableAlias + } + fieldRefs[i] = FieldRef{ + TableName: tName, + FieldName: fieldName, + FieldType: fieldDef.Type} + i++ + } + return &fieldRefs +} + +func (tcDef *TableCreatorDef) 
Deserialize(rawWriter json.RawMessage) error { + var err error + if err = json.Unmarshal(rawWriter, tcDef); err != nil { + return fmt.Errorf("cannot unmarshal table creator: %s", err.Error()) + } + + re := regexp.MustCompile(ProhibitedTableNameRegex) + invalidNamePieceFound := re.FindString(tcDef.Name) + if len(invalidNamePieceFound) > 0 { + return fmt.Errorf("invalid table name [%s]: prohibited regex is [%s]", tcDef.Name, ProhibitedTableNameRegex) + } + + re = regexp.MustCompile(AllowedTableNameRegex) + invalidNamePieceFound = re.FindString(tcDef.Name) + if len(invalidNamePieceFound) != len(tcDef.Name) { + return fmt.Errorf("invalid table name [%s]: allowed regex is [%s]", tcDef.Name, AllowedTableNameRegex) + } + + // Having + tcDef.Having, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(tcDef.RawHaving, &tcDef.UsedInHavingFields) + if err != nil { + return fmt.Errorf("cannot parse table creator 'having' condition [%s]: [%s]", tcDef.RawHaving, err.Error()) + } + + // Fields + for _, fieldDef := range tcDef.Fields { + if fieldDef.ParsedExpression, err = ParseRawGolangExpressionStringAndHarvestFieldRefs(fieldDef.RawExpression, &fieldDef.UsedFields); err != nil { + return fmt.Errorf("cannot parse field expression [%s]: [%s]", fieldDef.RawExpression, err.Error()) + } + if !IsValidFieldType(fieldDef.Type) { + return fmt.Errorf("invalid field type [%s]", fieldDef.Type) + } + } + + tcDef.UsedInTargetExpressionsFields = GetFieldRefsUsedInAllTargetExpressions(tcDef.Fields) + + // Indexes + tcDef.Indexes = IdxDefMap{} + if err := tcDef.Indexes.parseRawIndexDefMap(tcDef.RawIndexes, tcDef.GetFieldRefs()); err != nil { + return err + } + + re = regexp.MustCompile(AllowedIdxNameRegex) + for idxName := range tcDef.Indexes { + invalidNamePieceFound := re.FindString(idxName) + if len(invalidNamePieceFound) != len(idxName) { + return fmt.Errorf("invalid index name [%s]: allowed regex is [%s]", idxName, AllowedIdxNameRegex) + } + + } + + return nil +} + +func (tcDef *TableCreatorDef) GetFieldDefaultReadyForDb(fieldName string) (any, error) { + writerFieldDef, ok := tcDef.Fields[fieldName] + if !ok { + return nil, fmt.Errorf("default for unknown field %s", fieldName) + } + defaultValueString := strings.TrimSpace(writerFieldDef.DefaultValue) + + var err error + switch writerFieldDef.Type { + case FieldTypeInt: + v := DefaultInt + if len(defaultValueString) > 0 { + v, err = strconv.ParseInt(defaultValueString, 10, 64) + if err != nil { + return nil, fmt.Errorf("cannot read int64 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) + } + } + return v, nil + case FieldTypeFloat: + v := DefaultFloat + if len(defaultValueString) > 0 { + v, err = strconv.ParseFloat(defaultValueString, 64) + if err != nil { + return nil, fmt.Errorf("cannot read float64 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) + } + } + return v, nil + case FieldTypeString: + v := DefaultString + if len(defaultValueString) > 0 { + v = defaultValueString + } + return v, nil + case FieldTypeDecimal2: + // Set it to Cassandra-accepted value, not decimal.Decimal: https://github.com/gocql/gocql/issues/1578 + v := inf.NewDec(0, 0) + if len(defaultValueString) > 0 { + f, err := strconv.ParseFloat(defaultValueString, 64) + if err != nil { + return nil, fmt.Errorf("cannot read decimal2 field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) + } + scaled := int64(math.Round(f * 100)) + v = inf.NewDec(scaled, 2) + } + return v, nil + 
case FieldTypeBool: + v := DefaultBool + if len(defaultValueString) > 0 { + v, err = strconv.ParseBool(defaultValueString) + if err != nil { + return nil, fmt.Errorf("cannot read bool field %s, from default value string '%s', allowed values are true,false,T,F,0,1: %s", fieldName, defaultValueString, err.Error()) + } + } + return v, nil + case FieldTypeDateTime: + v := DefaultDateTime() + if len(defaultValueString) > 0 { + v, err = time.Parse(CassandraDatetimeFormat, defaultValueString) + if err != nil { + return nil, fmt.Errorf("cannot read time field %s from default value string '%s': %s", fieldName, defaultValueString, err.Error()) + } + } + return v, nil + default: + return nil, fmt.Errorf("GetFieldDefault unsupported field type %s, field %s", writerFieldDef.Type, fieldName) + } +} + +func CalculateFieldValue(fieldName string, fieldDef *WriteTableFieldDef, srcVars eval.VarValuesMap, canUseAggFunc bool) (any, error) { + calcWithAggFunc, aggFuncType, aggFuncArgs := eval.DetectRootAggFunc(fieldDef.ParsedExpression) + if !canUseAggFunc { + calcWithAggFunc = eval.AggFuncDisabled + } + + eCtx, err := eval.NewPlainEvalCtxWithVarsAndInitializedAgg(calcWithAggFunc, &srcVars, aggFuncType, aggFuncArgs) + if err != nil { + return nil, err + } + + valVolatile, err := eCtx.Eval(fieldDef.ParsedExpression) + if err != nil { + return nil, fmt.Errorf("cannot evaluate expression for field %s: [%s]", fieldName, err.Error()) + } + if err := CheckValueType(valVolatile, fieldDef.Type); err != nil { + return nil, fmt.Errorf("invalid field %s type: [%s]", fieldName, err.Error()) + } + return valVolatile, nil +} + +func (tcDef *TableCreatorDef) CalculateTableRecordFromSrcVars(canUseAggFunc bool, srcVars eval.VarValuesMap) (map[string]any, error) { + errors := make([]string, 0, 2) + + tableRecord := map[string]any{} + + for fieldName, fieldDef := range tcDef.Fields { + var err error + tableRecord[fieldName], err = CalculateFieldValue(fieldName, fieldDef, srcVars, canUseAggFunc) + if err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return nil, fmt.Errorf(strings.Join(errors, "; ")) + } + + return tableRecord, nil +} + +func (tcDef *TableCreatorDef) CheckTableRecordHavingCondition(tableRecord map[string]any) (bool, error) { + if tcDef.Having == nil { + // No Having condition specified + return true, nil + } + vars := eval.VarValuesMap{} + vars[CreatorAlias] = map[string]any{} + for fieldName, fieldValue := range tableRecord { + vars[CreatorAlias][fieldName] = fieldValue + } + + eCtx := eval.NewPlainEvalCtxWithVars(eval.AggFuncDisabled, &vars) + valVolatile, err := eCtx.Eval(tcDef.Having) + if err != nil { + return false, fmt.Errorf("cannot evaluate 'having' expression: [%s]", err.Error()) + } + valBool, ok := valVolatile.(bool) + if !ok { + return false, fmt.Errorf("cannot get bool when evaluating having expression, got %v(%T) instead", valVolatile, valVolatile) + } + + return valBool, nil +} diff --git a/pkg/sc/table_creator_def_test.go b/pkg/sc/table_creator_def_test.go index 40b40c5..529fe3b 100644 --- a/pkg/sc/table_creator_def_test.go +++ b/pkg/sc/table_creator_def_test.go @@ -1,163 +1,198 @@ -package sc - -import ( - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "gopkg.in/inf.v0" -) - -const tableCreatorNodeJson string = ` -{ - "name": "test_table_creator", - "fields": { - "field_int": { - "expression": "r.field_int", - "default_value": "99", - "type": "int" - }, - "field_float": { - "expression": "r.field_float", - "default_value": "99.0", - 
"type": "float" - }, - "field_decimal2": { - "expression": "r.field_decimal2", - "default_value": "123.00", - "type": "decimal2" - }, - "field_datetime": { - "expression": "r.field_datetime", - "default_value": "1980-02-03T04:05:06.777+00:00", - "type": "datetime" - }, - "field_bool": { - "expression": "r.field_bool", - "default_value": "true", - "type": "bool" - }, - "field_string": { - "expression": "r.field_string", - "default_value": "some_string", - "type": "string" - } - }, - "indexes": { - "idx_1": "unique(field_string(case_sensitive))" - } -}` - -func TestCreatorDefaultFieldValues(t *testing.T) { - c := TableCreatorDef{} - assert.Nil(t, c.Deserialize([]byte(tableCreatorNodeJson))) - - var err error - var val interface{} - - val, err = c.GetFieldDefaultReadyForDb("field_int") - assert.Nil(t, err) - assert.Equal(t, int64(99), val.(int64)) - - val, err = c.GetFieldDefaultReadyForDb("field_float") - assert.Nil(t, err) - assert.Equal(t, float64(99.0), val.(float64)) - - val, err = c.GetFieldDefaultReadyForDb("field_decimal2") - assert.Nil(t, err) - assert.Equal(t, inf.NewDec(12300, 2), val.(*inf.Dec)) - - val, err = c.GetFieldDefaultReadyForDb("field_datetime") - assert.Nil(t, err) - dt, _ := time.Parse(CassandraDatetimeFormat, "1980-02-03T04:05:06.777+00:00") - assert.Equal(t, dt, val.(time.Time)) - - val, err = c.GetFieldDefaultReadyForDb("field_bool") - assert.Nil(t, err) - assert.Equal(t, true, val.(bool)) - - val, err = c.GetFieldDefaultReadyForDb("field_string") - assert.Nil(t, err) - assert.Equal(t, "some_string", val.(string)) - - confReplacer := strings.NewReplacer( - `"default_value": "99",`, ``, - `"default_value": "99.0",`, ``, - `"default_value": "123.00",`, ``, - `"default_value": "1980-02-03T04:05:06.777+00:00",`, ``, - `"default_value": "true",`, ``, - `"default_value": "some_string",`, ``) - - assert.Nil(t, c.Deserialize([]byte(confReplacer.Replace(tableCreatorNodeJson)))) - - val, err = c.GetFieldDefaultReadyForDb("field_int") - assert.Nil(t, err) - assert.Equal(t, int64(0), val.(int64)) - - val, err = c.GetFieldDefaultReadyForDb("field_float") - assert.Nil(t, err) - assert.Equal(t, float64(0.0), val.(float64)) - - val, err = c.GetFieldDefaultReadyForDb("field_decimal2") - assert.Nil(t, err) - assert.Equal(t, inf.NewDec(0, 0), val.(*inf.Dec)) - - val, err = c.GetFieldDefaultReadyForDb("field_datetime") - assert.Nil(t, err) - assert.Equal(t, DefaultDateTime(), val.(time.Time)) - - val, err = c.GetFieldDefaultReadyForDb("field_bool") - assert.Nil(t, err) - assert.False(t, val.(bool)) - - val, err = c.GetFieldDefaultReadyForDb("field_string") - assert.Nil(t, err) - assert.Equal(t, "", val.(string)) - - // Failures - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "test_table_creator", "&"))) - assert.Contains(t, err.Error(), "invalid table name [&]: allowed regex is") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "test_table_creator", "idx_a"))) - assert.Contains(t, err.Error(), "invalid table name [idx_a]: prohibited regex is") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "string", "bad_type"))) - assert.Contains(t, err.Error(), "invalid field type [bad_type]") - - c = TableCreatorDef{} - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "idx_1", "bad_idx_name"))) - assert.Contains(t, err.Error(), "invalid index name [bad_idx_name]: allowed regex is") - - // Check default fields - _, err = c.GetFieldDefaultReadyForDb("bad_field") - assert.Contains(t, err.Error(), "default for 
unknown field bad_field") - - c = TableCreatorDef{} - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "99", "aaa"))) - assert.Nil(t, err) - _, err = c.GetFieldDefaultReadyForDb("field_int") - assert.Contains(t, err.Error(), "cannot read int64 field field_int from default value string 'aaa'") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "99.0", "aaa"))) - assert.Nil(t, err) - _, err = c.GetFieldDefaultReadyForDb("field_float") - assert.Contains(t, err.Error(), "cannot read float64 field field_float from default value string 'aaa'") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "123.00", "aaa"))) - assert.Nil(t, err) - _, err = c.GetFieldDefaultReadyForDb("field_decimal2") - assert.Contains(t, err.Error(), "cannot read decimal2 field field_decimal2 from default value string 'aaa'") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "1980-02-03T04:05:06.777+00:00", "aaa"))) - assert.Nil(t, err) - _, err = c.GetFieldDefaultReadyForDb("field_datetime") - assert.Contains(t, err.Error(), "cannot read time field field_datetime from default value string 'aaa'") - - err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "true", "aaa"))) - assert.Nil(t, err) - _, err = c.GetFieldDefaultReadyForDb("field_bool") - assert.Contains(t, err.Error(), "cannot read bool field field_bool, from default value string 'aaa'") - -} +package sc + +import ( + "regexp" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gopkg.in/inf.v0" +) + +const tableCreatorNodeJson string = ` +{ + "name": "test_table_creator", + "having": "len(w.field_string) > 0", + "fields": { + "field_int": { + "expression": "r.field_int", + "default_value": "99", + "type": "int" + }, + "field_float": { + "expression": "r.field_float", + "default_value": "99.0", + "type": "float" + }, + "field_decimal2": { + "expression": "r.field_decimal2", + "default_value": "123.00", + "type": "decimal2" + }, + "field_datetime": { + "expression": "r.field_datetime", + "default_value": "1980-02-03T04:05:06.777+00:00", + "type": "datetime" + }, + "field_bool": { + "expression": "r.field_bool", + "default_value": "true", + "type": "bool" + }, + "field_string": { + "expression": "r.field_string", + "default_value": "some_string", + "type": "string" + } + }, + "indexes": { + "idx_1": "unique(field_string(case_sensitive))" + } +}` + +func TestCreatorDefaultFieldValues(t *testing.T) { + c := TableCreatorDef{} + assert.Nil(t, c.Deserialize([]byte(tableCreatorNodeJson))) + + var err error + var val any + + val, err = c.GetFieldDefaultReadyForDb("field_int") + assert.Nil(t, err) + assert.Equal(t, int64(99), val.(int64)) + + val, err = c.GetFieldDefaultReadyForDb("field_float") + assert.Nil(t, err) + assert.Equal(t, float64(99.0), val.(float64)) + + val, err = c.GetFieldDefaultReadyForDb("field_decimal2") + assert.Nil(t, err) + assert.Equal(t, inf.NewDec(12300, 2), val.(*inf.Dec)) + + val, err = c.GetFieldDefaultReadyForDb("field_datetime") + assert.Nil(t, err) + dt, _ := time.Parse(CassandraDatetimeFormat, "1980-02-03T04:05:06.777+00:00") + assert.Equal(t, dt, val.(time.Time)) + + val, err = c.GetFieldDefaultReadyForDb("field_bool") + assert.Nil(t, err) + assert.Equal(t, true, val.(bool)) + + val, err = c.GetFieldDefaultReadyForDb("field_string") + assert.Nil(t, err) + assert.Equal(t, "some_string", val.(string)) + + confReplacer := strings.NewReplacer( + `"default_value": "99",`, ``, + `"default_value": "99.0",`, ``, + 
`"default_value": "123.00",`, ``, + `"default_value": "1980-02-03T04:05:06.777+00:00",`, ``, + `"default_value": "true",`, ``, + `"default_value": "some_string",`, ``) + + assert.Nil(t, c.Deserialize([]byte(confReplacer.Replace(tableCreatorNodeJson)))) + + val, err = c.GetFieldDefaultReadyForDb("field_int") + assert.Nil(t, err) + assert.Equal(t, int64(0), val.(int64)) + + val, err = c.GetFieldDefaultReadyForDb("field_float") + assert.Nil(t, err) + assert.Equal(t, float64(0.0), val.(float64)) + + val, err = c.GetFieldDefaultReadyForDb("field_decimal2") + assert.Nil(t, err) + assert.Equal(t, inf.NewDec(0, 0), val.(*inf.Dec)) + + val, err = c.GetFieldDefaultReadyForDb("field_datetime") + assert.Nil(t, err) + assert.Equal(t, DefaultDateTime(), val.(time.Time)) + + val, err = c.GetFieldDefaultReadyForDb("field_bool") + assert.Nil(t, err) + assert.False(t, val.(bool)) + + val, err = c.GetFieldDefaultReadyForDb("field_string") + assert.Nil(t, err) + assert.Equal(t, "", val.(string)) + + // Failures + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "test_table_creator", "&"))) + assert.Contains(t, err.Error(), "invalid table name [&]: allowed regex is") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "test_table_creator", "idx_a"))) + assert.Contains(t, err.Error(), "invalid table name [idx_a]: prohibited regex is") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "string", "bad_type"))) + assert.Contains(t, err.Error(), "invalid field type [bad_type]") + + c = TableCreatorDef{} + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "idx_1", "bad_idx_name"))) + assert.Contains(t, err.Error(), "invalid index name [bad_idx_name]: allowed regex is") + + // Check default fields + _, err = c.GetFieldDefaultReadyForDb("bad_field") + assert.Contains(t, err.Error(), "default for unknown field bad_field") + + c = TableCreatorDef{} + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "99", "aaa"))) + assert.Nil(t, err) + _, err = c.GetFieldDefaultReadyForDb("field_int") + assert.Contains(t, err.Error(), "cannot read int64 field field_int from default value string 'aaa'") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "99.0", "aaa"))) + assert.Nil(t, err) + _, err = c.GetFieldDefaultReadyForDb("field_float") + assert.Contains(t, err.Error(), "cannot read float64 field field_float from default value string 'aaa'") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "123.00", "aaa"))) + assert.Nil(t, err) + _, err = c.GetFieldDefaultReadyForDb("field_decimal2") + assert.Contains(t, err.Error(), "cannot read decimal2 field field_decimal2 from default value string 'aaa'") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "1980-02-03T04:05:06.777+00:00", "aaa"))) + assert.Nil(t, err) + _, err = c.GetFieldDefaultReadyForDb("field_datetime") + assert.Contains(t, err.Error(), "cannot read time field field_datetime from default value string 'aaa'") + + err = c.Deserialize([]byte(strings.ReplaceAll(tableCreatorNodeJson, "true", "aaa"))) + assert.Nil(t, err) + _, err = c.GetFieldDefaultReadyForDb("field_bool") + assert.Contains(t, err.Error(), "cannot read bool field field_bool, from default value string 'aaa'") +} + +func TestCheckTableRecordHavingCondition(t *testing.T) { + c := TableCreatorDef{} + assert.Nil(t, c.Deserialize([]byte(tableCreatorNodeJson))) + + isPass, err := c.CheckTableRecordHavingCondition(map[string]any{"field_string": "aaa"}) + 
assert.Nil(t, err) + assert.True(t, isPass) + + isPass, err = c.CheckTableRecordHavingCondition(map[string]any{"field_string": ""}) + assert.Nil(t, err) + assert.False(t, isPass) + + re := regexp.MustCompile(`"having": "[^"]+",`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(tableCreatorNodeJson, `"having": "w.bad_field",`)))) + _, err = c.CheckTableRecordHavingCondition(map[string]any{"field_string": "aaa"}) + assert.Contains(t, err.Error(), "cannot evaluate 'having' expression") + + re = regexp.MustCompile(`"having": "[^"]+",`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(tableCreatorNodeJson, `"having": "w.field_string",`)))) + _, err = c.CheckTableRecordHavingCondition(map[string]any{"field_string": "aaa"}) + assert.Contains(t, err.Error(), "cannot get bool when evaluating having expression, got aaa(string) instead") + + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(tableCreatorNodeJson, `"having": "w.field_string",`)))) + _, err = c.CheckTableRecordHavingCondition(map[string]any{"field_string": "aaa"}) + assert.Contains(t, err.Error(), "cannot get bool when evaluating having expression, got aaa(string) instead") + + // Remove having + c = TableCreatorDef{} + re = regexp.MustCompile(`"having": "[^"]+",`) + assert.Nil(t, c.Deserialize([]byte(re.ReplaceAllString(tableCreatorNodeJson, ``)))) + _, err = c.CheckTableRecordHavingCondition(map[string]any{"field_string": "aaa"}) + assert.Nil(t, err) +} diff --git a/pkg/sc/table_def.go b/pkg/sc/table_def.go index 33d3715..9a8ecc9 100644 --- a/pkg/sc/table_def.go +++ b/pkg/sc/table_def.go @@ -1,100 +1,100 @@ -package sc - -import ( - "fmt" - "time" - - "github.com/shopspring/decimal" - "gopkg.in/inf.v0" -) - -const ( - FieldNameUnknown = "unknown_field_name" -) - -type TableFieldType string - -const ( - FieldTypeString TableFieldType = "string" - FieldTypeInt TableFieldType = "int" // sign+18digit string - FieldTypeFloat TableFieldType = "float" // sign+64digit string, 32 digits after point - FieldTypeBool TableFieldType = "bool" // F or T - FieldTypeDecimal2 TableFieldType = "decimal2" // sign + 18digit+point+2 - FieldTypeDateTime TableFieldType = "datetime" // int unix epoch milliseconds - FieldTypeUnknown TableFieldType = "unknown" -) - -func IsValidFieldType(fieldType TableFieldType) bool { - return fieldType == FieldTypeString || - fieldType == FieldTypeInt || - fieldType == FieldTypeFloat || - fieldType == FieldTypeBool || - fieldType == FieldTypeDecimal2 || - fieldType == FieldTypeDateTime -} - -// Cassandra timestamps are milliseconds. No microsecond support. 
-// On writes: -// - allows (but not requires) ":" in the timezone -// - allows (but not requires) "T" in as date/time separator -const CassandraDatetimeFormat string = "2006-01-02T15:04:05.000-07:00" - -const DefaultInt int64 = int64(0) -const DefaultFloat float64 = float64(0.0) -const DefaultString string = "" -const DefaultBool bool = false - -func DefaultDecimal2() decimal.Decimal { return decimal.NewFromFloat(0.0) } -func DefaultCassandraDecimal2() *inf.Dec { return inf.NewDec(0, 0) } -func DefaultDateTime() time.Time { return time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC) } // Same as time.Time default - -func GetDefaultFieldTypeValue(fieldType TableFieldType) interface{} { - switch fieldType { - case FieldTypeInt: - return DefaultInt - case FieldTypeFloat: - return DefaultFloat - case FieldTypeString: - return DefaultString - case FieldTypeDecimal2: - return DefaultDecimal2() - case FieldTypeBool: - return DefaultBool - case FieldTypeDateTime: - return DefaultDateTime() - default: - return nil - } -} - -func CheckValueType(val interface{}, fieldType TableFieldType) error { - switch assertedValue := val.(type) { - case int64: - if fieldType != FieldTypeInt { - return fmt.Errorf("expected type %s, but got int64 (%d)", fieldType, assertedValue) - } - case float64: - if fieldType != FieldTypeFloat { - return fmt.Errorf("expected type %s, but got float64 (%f)", fieldType, assertedValue) - } - case string: - if fieldType != FieldTypeString { - return fmt.Errorf("expected type %s, but got string (%s)", fieldType, assertedValue) - } - case bool: - if fieldType != FieldTypeBool { - return fmt.Errorf("expected type %s, but got bool (%v)", fieldType, assertedValue) - } - case time.Time: - if fieldType != FieldTypeDateTime { - return fmt.Errorf("expected type %s, but got datetime (%s)", fieldType, assertedValue.String()) - } - case decimal.Decimal: - if fieldType != FieldTypeDecimal2 { - return fmt.Errorf("expected type %s, but got decimal (%s)", fieldType, assertedValue.String()) - } - default: - return fmt.Errorf("expected type %s, but got unexpected type %T(%v)", fieldType, assertedValue, assertedValue) - } - return nil -} +package sc + +import ( + "fmt" + "time" + + "github.com/shopspring/decimal" + "gopkg.in/inf.v0" +) + +const ( + FieldNameUnknown = "unknown_field_name" +) + +type TableFieldType string + +const ( + FieldTypeString TableFieldType = "string" + FieldTypeInt TableFieldType = "int" // sign+18digit string + FieldTypeFloat TableFieldType = "float" // sign+64digit string, 32 digits after point + FieldTypeBool TableFieldType = "bool" // F or T + FieldTypeDecimal2 TableFieldType = "decimal2" // sign + 18digit+point+2 + FieldTypeDateTime TableFieldType = "datetime" // int unix epoch milliseconds + FieldTypeUnknown TableFieldType = "unknown" +) + +func IsValidFieldType(fieldType TableFieldType) bool { + return fieldType == FieldTypeString || + fieldType == FieldTypeInt || + fieldType == FieldTypeFloat || + fieldType == FieldTypeBool || + fieldType == FieldTypeDecimal2 || + fieldType == FieldTypeDateTime +} + +// Cassandra timestamps are milliseconds. No microsecond support. 
+// On writes: +// - allows (but not requires) ":" in the timezone +// - allows (but not requires) "T" in as date/time separator +const CassandraDatetimeFormat string = "2006-01-02T15:04:05.000-07:00" + +const DefaultInt int64 = int64(0) +const DefaultFloat float64 = float64(0.0) +const DefaultString string = "" +const DefaultBool bool = false + +func DefaultDecimal2() decimal.Decimal { return decimal.NewFromFloat(0.0) } +func DefaultCassandraDecimal2() *inf.Dec { return inf.NewDec(0, 0) } +func DefaultDateTime() time.Time { return time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC) } // Same as time.Time default + +func GetDefaultFieldTypeValue(fieldType TableFieldType) any { + switch fieldType { + case FieldTypeInt: + return DefaultInt + case FieldTypeFloat: + return DefaultFloat + case FieldTypeString: + return DefaultString + case FieldTypeDecimal2: + return DefaultDecimal2() + case FieldTypeBool: + return DefaultBool + case FieldTypeDateTime: + return DefaultDateTime() + default: + return nil + } +} + +func CheckValueType(val any, fieldType TableFieldType) error { + switch assertedValue := val.(type) { + case int64: + if fieldType != FieldTypeInt { + return fmt.Errorf("expected type %s, but got int64 (%d)", fieldType, assertedValue) + } + case float64: + if fieldType != FieldTypeFloat { + return fmt.Errorf("expected type %s, but got float64 (%f)", fieldType, assertedValue) + } + case string: + if fieldType != FieldTypeString { + return fmt.Errorf("expected type %s, but got string (%s)", fieldType, assertedValue) + } + case bool: + if fieldType != FieldTypeBool { + return fmt.Errorf("expected type %s, but got bool (%v)", fieldType, assertedValue) + } + case time.Time: + if fieldType != FieldTypeDateTime { + return fmt.Errorf("expected type %s, but got datetime (%s)", fieldType, assertedValue.String()) + } + case decimal.Decimal: + if fieldType != FieldTypeDecimal2 { + return fmt.Errorf("expected type %s, but got decimal (%s)", fieldType, assertedValue.String()) + } + default: + return fmt.Errorf("expected type %s, but got unexpected type %T(%v)", fieldType, assertedValue, assertedValue) + } + return nil +} diff --git a/pkg/sc/table_reader_def.go b/pkg/sc/table_reader_def.go index 24e1f03..1016471 100644 --- a/pkg/sc/table_reader_def.go +++ b/pkg/sc/table_reader_def.go @@ -1,8 +1,8 @@ -package sc - -type TableReaderDef struct { - TableName string `json:"table"` - ExpectedBatchesTotal int `json:"expected_batches_total"` - RowsetSize int `json:"rowset_size"` // DefaultRowsetSize = 1000 - TableCreator *TableCreatorDef -} +package sc + +type TableReaderDef struct { + TableName string `json:"table"` + ExpectedBatchesTotal int `json:"expected_batches_total"` + RowsetSize int `json:"rowset_size"` // DefaultRowsetSize = 1000 + TableCreator *TableCreatorDef +} diff --git a/pkg/sc/write_table_field_def.go b/pkg/sc/write_table_field_def.go index 1dc6588..7b51eeb 100644 --- a/pkg/sc/write_table_field_def.go +++ b/pkg/sc/write_table_field_def.go @@ -1,36 +1,36 @@ -package sc - -import ( - "fmt" - "go/ast" -) - -type WriteTableFieldDef struct { - RawExpression string `json:"expression"` - Type TableFieldType `json:"type"` - DefaultValue string `json:"default_value"` // Optional. 
If omitted, default zero value is used - ParsedExpression ast.Expr - UsedFields FieldRefs -} - -func GetFieldRefsUsedInAllTargetExpressions(fieldDefMap map[string]*WriteTableFieldDef) FieldRefs { - fieldRefMap := map[string]FieldRef{} - for _, targetFieldDef := range fieldDefMap { - for i := 0; i < len(targetFieldDef.UsedFields); i++ { - hash := fmt.Sprintf("%s.%s", targetFieldDef.UsedFields[i].TableName, targetFieldDef.UsedFields[i].FieldName) - if _, ok := fieldRefMap[hash]; !ok { - fieldRefMap[hash] = targetFieldDef.UsedFields[i] - } - } - } - - // Map to FieldRefs - fieldRefs := make([]FieldRef, len(fieldRefMap)) - i := 0 - for _, fieldRef := range fieldRefMap { - fieldRefs[i] = fieldRef - i++ - } - - return fieldRefs -} +package sc + +import ( + "fmt" + "go/ast" +) + +type WriteTableFieldDef struct { + RawExpression string `json:"expression"` + Type TableFieldType `json:"type"` + DefaultValue string `json:"default_value"` // Optional. If omitted, default zero value is used + ParsedExpression ast.Expr + UsedFields FieldRefs +} + +func GetFieldRefsUsedInAllTargetExpressions(fieldDefMap map[string]*WriteTableFieldDef) FieldRefs { + fieldRefMap := map[string]FieldRef{} + for _, targetFieldDef := range fieldDefMap { + for i := 0; i < len(targetFieldDef.UsedFields); i++ { + hash := fmt.Sprintf("%s.%s", targetFieldDef.UsedFields[i].TableName, targetFieldDef.UsedFields[i].FieldName) + if _, ok := fieldRefMap[hash]; !ok { + fieldRefMap[hash] = targetFieldDef.UsedFields[i] + } + } + } + + // Map to FieldRefs + fieldRefs := make([]FieldRef, len(fieldRefMap)) + i := 0 + for _, fieldRef := range fieldRefMap { + fieldRefs[i] = fieldRef + i++ + } + + return fieldRefs +} diff --git a/pkg/storage/parquet.go b/pkg/storage/parquet.go index de9cd7d..4401bbd 100644 --- a/pkg/storage/parquet.go +++ b/pkg/storage/parquet.go @@ -8,7 +8,7 @@ import ( "github.com/capillariesio/capillaries/pkg/sc" gp "github.com/fraugster/parquet-go" - gp_parquet "github.com/fraugster/parquet-go/parquet" + pgparquet "github.com/fraugster/parquet-go/parquet" "github.com/shopspring/decimal" ) @@ -18,10 +18,10 @@ type ParquetWriter struct { } func NewParquetWriter(ioWriter io.Writer, codec sc.ParquetCodecType) (*ParquetWriter, error) { - codecMap := map[sc.ParquetCodecType]gp_parquet.CompressionCodec{ - sc.ParquetCodecGzip: gp_parquet.CompressionCodec_GZIP, - sc.ParquetCodecSnappy: gp_parquet.CompressionCodec_SNAPPY, - sc.ParquetCodecUncompressed: gp_parquet.CompressionCodec_UNCOMPRESSED, + codecMap := map[sc.ParquetCodecType]pgparquet.CompressionCodec{ + sc.ParquetCodecGzip: pgparquet.CompressionCodec_GZIP, + sc.ParquetCodecSnappy: pgparquet.CompressionCodec_SNAPPY, + sc.ParquetCodecUncompressed: pgparquet.CompressionCodec_UNCOMPRESSED, } gpCodec, ok := codecMap[codec] if !ok { @@ -42,42 +42,42 @@ func (w *ParquetWriter) AddColumn(name string, fieldType sc.TableFieldType) erro var err error switch fieldType { case sc.FieldTypeString: - params := &gp.ColumnParameters{LogicalType: gp_parquet.NewLogicalType()} - params.LogicalType.STRING = gp_parquet.NewStringType() - params.ConvertedType = gp_parquet.ConvertedTypePtr(gp_parquet.ConvertedType_UTF8) - s, err = gp.NewByteArrayStore(gp_parquet.Encoding_PLAIN, true, params) + params := &gp.ColumnParameters{LogicalType: pgparquet.NewLogicalType()} + params.LogicalType.STRING = pgparquet.NewStringType() + params.ConvertedType = pgparquet.ConvertedTypePtr(pgparquet.ConvertedType_UTF8) + s, err = gp.NewByteArrayStore(pgparquet.Encoding_PLAIN, true, params) case sc.FieldTypeDateTime: - 
params := &gp.ColumnParameters{LogicalType: gp_parquet.NewLogicalType()} - params.LogicalType.TIMESTAMP = gp_parquet.NewTimestampType() - params.LogicalType.TIMESTAMP.Unit = gp_parquet.NewTimeUnit() + params := &gp.ColumnParameters{LogicalType: pgparquet.NewLogicalType()} + params.LogicalType.TIMESTAMP = pgparquet.NewTimestampType() + params.LogicalType.TIMESTAMP.Unit = pgparquet.NewTimeUnit() // Go and Parquet support nanoseconds. Unfortunately, Cassandra supports only milliseconds. Millis are our lingua franca. - params.LogicalType.TIMESTAMP.Unit.MILLIS = gp_parquet.NewMilliSeconds() - params.ConvertedType = gp_parquet.ConvertedTypePtr(gp_parquet.ConvertedType_TIMESTAMP_MILLIS) - s, err = gp.NewInt64Store(gp_parquet.Encoding_PLAIN, true, params) + params.LogicalType.TIMESTAMP.Unit.MILLIS = pgparquet.NewMilliSeconds() + params.ConvertedType = pgparquet.ConvertedTypePtr(pgparquet.ConvertedType_TIMESTAMP_MILLIS) + s, err = gp.NewInt64Store(pgparquet.Encoding_PLAIN, true, params) case sc.FieldTypeInt: - s, err = gp.NewInt64Store(gp_parquet.Encoding_PLAIN, true, &gp.ColumnParameters{}) + s, err = gp.NewInt64Store(pgparquet.Encoding_PLAIN, true, &gp.ColumnParameters{}) case sc.FieldTypeDecimal2: - params := &gp.ColumnParameters{LogicalType: gp_parquet.NewLogicalType()} - params.LogicalType.DECIMAL = gp_parquet.NewDecimalType() + params := &gp.ColumnParameters{LogicalType: pgparquet.NewLogicalType()} + params.LogicalType.DECIMAL = pgparquet.NewDecimalType() params.LogicalType.DECIMAL.Scale = 2 params.LogicalType.DECIMAL.Precision = 2 // This is to make fraugster/go-parquet happy so it writes this metadata, // see buildElement() implementation in schema.go params.Scale = ¶ms.LogicalType.DECIMAL.Scale params.Precision = ¶ms.LogicalType.DECIMAL.Precision - params.ConvertedType = gp_parquet.ConvertedTypePtr(gp_parquet.ConvertedType_DECIMAL) - s, err = gp.NewInt64Store(gp_parquet.Encoding_PLAIN, true, params) + params.ConvertedType = pgparquet.ConvertedTypePtr(pgparquet.ConvertedType_DECIMAL) + s, err = gp.NewInt64Store(pgparquet.Encoding_PLAIN, true, params) case sc.FieldTypeFloat: - s, err = gp.NewDoubleStore(gp_parquet.Encoding_PLAIN, true, &gp.ColumnParameters{}) + s, err = gp.NewDoubleStore(pgparquet.Encoding_PLAIN, true, &gp.ColumnParameters{}) case sc.FieldTypeBool: - s, err = gp.NewBooleanStore(gp_parquet.Encoding_PLAIN, &gp.ColumnParameters{}) + s, err = gp.NewBooleanStore(pgparquet.Encoding_PLAIN, &gp.ColumnParameters{}) default: return fmt.Errorf("cannot add %s column %s: unsupported field type", fieldType, name) } if err != nil { return fmt.Errorf("cannot create store for %s column %s: %s", fieldType, name, err.Error()) } - if err := w.FileWriter.AddColumn(name, gp.NewDataColumn(s, gp_parquet.FieldRepetitionType_OPTIONAL)); err != nil { + if err := w.FileWriter.AddColumnByPath([]string{name}, gp.NewDataColumn(s, pgparquet.FieldRepetitionType_OPTIONAL)); err != nil { return fmt.Errorf("cannot add %s column %s: %s", fieldType, name, err.Error()) } w.StoreMap[name] = s @@ -96,86 +96,86 @@ func (w *ParquetWriter) Close() error { } return nil } -func ParquetWriterMilliTs(t time.Time) interface{} { +func ParquetWriterMilliTs(t time.Time) any { if t.Equal(sc.DefaultDateTime()) { return nil - } else { - // Go and Parquet support nanoseconds. Unfortunately, Cassandra supports only milliseconds. Millis are our lingua franca. - return t.UnixMilli() } + + // Go and Parquet support nanoseconds. Unfortunately, Cassandra supports only milliseconds. Millis are our lingua franca. 
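	// A minimal round-trip sketch (illustrative only, assuming a non-default time.Time value t
	// with no sub-millisecond component):
	//
	//   ms := t.UnixMilli()                          // what ParquetWriterMilliTs stores
	//   restored := time.UnixMilli(ms).In(time.UTC)  // what ParquetReadDateTime rebuilds
	//   // restored.Equal(t) holds; anything finer than a millisecond is truncated on write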
+ return t.UnixMilli() } -func ParquetWriterDecimal2(dec decimal.Decimal) interface{} { +func ParquetWriterDecimal2(dec decimal.Decimal) any { return dec.Mul(decimal.NewFromInt(100)).IntPart() } -func isType(se *gp_parquet.SchemaElement, t gp_parquet.Type) bool { +func isType(se *pgparquet.SchemaElement, t pgparquet.Type) bool { return se.Type != nil && *se.Type == t } -func isLogicalOrConvertedString(se *gp_parquet.SchemaElement) bool { +func isLogicalOrConvertedString(se *pgparquet.SchemaElement) bool { return se.LogicalType != nil && se.LogicalType.STRING != nil || - se.ConvertedType != nil && *se.ConvertedType == gp_parquet.ConvertedType_UTF8 + se.ConvertedType != nil && *se.ConvertedType == pgparquet.ConvertedType_UTF8 } -func isLogicalOrConvertedDecimal(se *gp_parquet.SchemaElement) bool { +func isLogicalOrConvertedDecimal(se *pgparquet.SchemaElement) bool { return se.LogicalType != nil && se.LogicalType.DECIMAL != nil || - se.ConvertedType != nil && *se.ConvertedType == gp_parquet.ConvertedType_DECIMAL + se.ConvertedType != nil && *se.ConvertedType == pgparquet.ConvertedType_DECIMAL } -func isLogicalOrConvertedDateTime(se *gp_parquet.SchemaElement) bool { +func isLogicalOrConvertedDateTime(se *pgparquet.SchemaElement) bool { return se.LogicalType != nil && se.LogicalType.TIMESTAMP != nil || - se.ConvertedType != nil && (*se.ConvertedType == gp_parquet.ConvertedType_TIMESTAMP_MILLIS || *se.ConvertedType == gp_parquet.ConvertedType_TIMESTAMP_MICROS) + se.ConvertedType != nil && (*se.ConvertedType == pgparquet.ConvertedType_TIMESTAMP_MILLIS || *se.ConvertedType == pgparquet.ConvertedType_TIMESTAMP_MICROS) } -func isParquetString(se *gp_parquet.SchemaElement) bool { - return isLogicalOrConvertedString(se) && isType(se, gp_parquet.Type_BYTE_ARRAY) +func isParquetString(se *pgparquet.SchemaElement) bool { + return isLogicalOrConvertedString(se) && isType(se, pgparquet.Type_BYTE_ARRAY) } -func isParquetIntDecimal2(se *gp_parquet.SchemaElement) bool { +func isParquetIntDecimal2(se *pgparquet.SchemaElement) bool { return isLogicalOrConvertedDecimal(se) && - (isType(se, gp_parquet.Type_INT32) || isType(se, gp_parquet.Type_INT64)) && + (isType(se, pgparquet.Type_INT32) || isType(se, pgparquet.Type_INT64)) && se.Scale != nil && *se.Scale > -20 && *se.Scale < 20 && se.Precision != nil && *se.Precision >= 0 && *se.Precision < 18 } -func isParquetFixedLengthByteArrayDecimal2(se *gp_parquet.SchemaElement) bool { +func isParquetFixedLengthByteArrayDecimal2(se *pgparquet.SchemaElement) bool { return isLogicalOrConvertedDecimal(se) && - isType(se, gp_parquet.Type_FIXED_LEN_BYTE_ARRAY) && + isType(se, pgparquet.Type_FIXED_LEN_BYTE_ARRAY) && se.Scale != nil && *se.Scale > -20 && *se.Scale < 20 && se.Precision != nil && *se.Precision >= 0 && *se.Precision <= 38 } -func isParquetDateTime(se *gp_parquet.SchemaElement) bool { +func isParquetDateTime(se *pgparquet.SchemaElement) bool { return isLogicalOrConvertedDateTime(se) && - (isType(se, gp_parquet.Type_INT32) || isType(se, gp_parquet.Type_INT64)) + (isType(se, pgparquet.Type_INT32) || isType(se, pgparquet.Type_INT64)) } -func isParquetInt96Date(se *gp_parquet.SchemaElement) bool { - return isType(se, gp_parquet.Type_INT96) +func isParquetInt96Date(se *pgparquet.SchemaElement) bool { + return isType(se, pgparquet.Type_INT96) } -func isParquetInt32Date(se *gp_parquet.SchemaElement) bool { - return se.Type != nil && *se.Type == gp_parquet.Type_INT32 && +func isParquetInt32Date(se *pgparquet.SchemaElement) bool { + return se.Type != nil && *se.Type == 
pgparquet.Type_INT32 && se.LogicalType != nil && se.LogicalType.DATE != nil } -func isParquetInt(se *gp_parquet.SchemaElement) bool { +func isParquetInt(se *pgparquet.SchemaElement) bool { return (se.LogicalType == nil || se.LogicalType != nil && se.LogicalType.INTEGER != nil) && - se.Type != nil && (*se.Type == gp_parquet.Type_INT32 || *se.Type == gp_parquet.Type_INT64) + se.Type != nil && (*se.Type == pgparquet.Type_INT32 || *se.Type == pgparquet.Type_INT64) } -func isParquetFloat(se *gp_parquet.SchemaElement) bool { +func isParquetFloat(se *pgparquet.SchemaElement) bool { return se.LogicalType == nil && - se.Type != nil && (*se.Type == gp_parquet.Type_FLOAT || *se.Type == gp_parquet.Type_DOUBLE) + se.Type != nil && (*se.Type == pgparquet.Type_FLOAT || *se.Type == pgparquet.Type_DOUBLE) } -func isParquetBool(se *gp_parquet.SchemaElement) bool { +func isParquetBool(se *pgparquet.SchemaElement) bool { return se.LogicalType == nil && - se.Type != nil && *se.Type == gp_parquet.Type_BOOLEAN + se.Type != nil && *se.Type == pgparquet.Type_BOOLEAN } -func ParquetGuessCapiType(se *gp_parquet.SchemaElement) (sc.TableFieldType, error) { +func ParquetGuessCapiType(se *pgparquet.SchemaElement) (sc.TableFieldType, error) { if isParquetString(se) { return sc.FieldTypeString, nil } else if isParquetIntDecimal2(se) || isParquetFixedLengthByteArrayDecimal2(se) { @@ -188,12 +188,11 @@ func ParquetGuessCapiType(se *gp_parquet.SchemaElement) (sc.TableFieldType, erro return sc.FieldTypeFloat, nil } else if isParquetBool(se) { return sc.FieldTypeBool, nil - } else { - return sc.FieldTypeUnknown, fmt.Errorf("parquet schema element not supported: %v", se) } + return sc.FieldTypeUnknown, fmt.Errorf("parquet schema element not supported: %v", se) } -func ParquetReadString(val interface{}, se *gp_parquet.SchemaElement) (string, error) { +func ParquetReadString(val any, se *pgparquet.SchemaElement) (string, error) { if !isParquetString(se) { return sc.DefaultString, fmt.Errorf("cannot read parquet string, schema %v", se) } @@ -204,7 +203,7 @@ func ParquetReadString(val interface{}, se *gp_parquet.SchemaElement) (string, e return string(typedVal), nil } -func ParquetReadDateTime(val interface{}, se *gp_parquet.SchemaElement) (time.Time, error) { +func ParquetReadDateTime(val any, se *pgparquet.SchemaElement) (time.Time, error) { if !isParquetDateTime(se) && !isParquetInt96Date(se) && !isParquetInt32Date(se) { return sc.DefaultDateTime(), fmt.Errorf("cannot read parquet datetime, schema %v", se) } @@ -218,9 +217,9 @@ func ParquetReadDateTime(val interface{}, se *gp_parquet.SchemaElement) (time.Ti return time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC).AddDate(0, 0, int(typedVal)).In(time.UTC), nil } else { switch *se.ConvertedType { - case gp_parquet.ConvertedType_TIMESTAMP_MILLIS: + case pgparquet.ConvertedType_TIMESTAMP_MILLIS: return time.UnixMilli(int64(typedVal)).In(time.UTC), nil - case gp_parquet.ConvertedType_TIMESTAMP_MICROS: + case pgparquet.ConvertedType_TIMESTAMP_MICROS: return time.UnixMicro(int64(typedVal)).In(time.UTC), nil default: return sc.DefaultDateTime(), fmt.Errorf("cannot read parquet datetime from int32, unsupported converted type, schema %v", se) @@ -228,9 +227,9 @@ func ParquetReadDateTime(val interface{}, se *gp_parquet.SchemaElement) (time.Ti } case int64: switch *se.ConvertedType { - case gp_parquet.ConvertedType_TIMESTAMP_MILLIS: + case pgparquet.ConvertedType_TIMESTAMP_MILLIS: return time.UnixMilli(typedVal).In(time.UTC), nil - case gp_parquet.ConvertedType_TIMESTAMP_MICROS: + case 
pgparquet.ConvertedType_TIMESTAMP_MICROS: return time.UnixMicro(typedVal).In(time.UTC), nil default: return sc.DefaultDateTime(), fmt.Errorf("cannot read parquet datetime from int64, unsupported converted type, schema %v", se) @@ -243,7 +242,7 @@ func ParquetReadDateTime(val interface{}, se *gp_parquet.SchemaElement) (time.Ti } } -func ParquetReadInt(val interface{}, se *gp_parquet.SchemaElement) (int64, error) { +func ParquetReadInt(val any, se *pgparquet.SchemaElement) (int64, error) { if !isParquetInt(se) { return sc.DefaultInt, fmt.Errorf("cannot read parquet int, schema %v", se) } @@ -269,7 +268,7 @@ func ParquetReadInt(val interface{}, se *gp_parquet.SchemaElement) (int64, error } } -func ParquetReadDecimal2(val interface{}, se *gp_parquet.SchemaElement) (decimal.Decimal, error) { +func ParquetReadDecimal2(val any, se *pgparquet.SchemaElement) (decimal.Decimal, error) { if !isParquetIntDecimal2(se) && !isParquetFixedLengthByteArrayDecimal2(se) { return sc.DefaultDecimal2(), fmt.Errorf("cannot read parquet decimal2, schema %v", se) } @@ -311,7 +310,7 @@ func ParquetReadDecimal2(val interface{}, se *gp_parquet.SchemaElement) (decimal } } -func ParquetReadFloat(val interface{}, se *gp_parquet.SchemaElement) (float64, error) { +func ParquetReadFloat(val any, se *pgparquet.SchemaElement) (float64, error) { if !isParquetFloat(se) { return sc.DefaultFloat, fmt.Errorf("cannot read parquet float, schema %v", se) } @@ -325,7 +324,7 @@ func ParquetReadFloat(val interface{}, se *gp_parquet.SchemaElement) (float64, e } } -func ParquetReadBool(val interface{}, se *gp_parquet.SchemaElement) (bool, error) { +func ParquetReadBool(val any, se *pgparquet.SchemaElement) (bool, error) { if !isParquetBool(se) { return sc.DefaultBool, fmt.Errorf("cannot read parquet float, schema %v", se) } diff --git a/pkg/wf/amqp_client.go b/pkg/wf/amqp_client.go index 9cb1270..dc18d6f 100644 --- a/pkg/wf/amqp_client.go +++ b/pkg/wf/amqp_client.go @@ -1,354 +1,357 @@ -package wf - -import ( - "fmt" - "os" - "time" - - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - amqp "github.com/rabbitmq/amqp091-go" -) - -const DlxSuffix string = "_dlx" - -type DaemonCmdType int8 - -const ( - DaemonCmdNone DaemonCmdType = 0 // Should never see this - DaemonCmdAckSuccess DaemonCmdType = 2 // Best case - DaemonCmdRejectAndRetryLater DaemonCmdType = 3 // Node dependencies are not ready, wait with proessing this node - DaemonCmdReconnectDb DaemonCmdType = 4 // Db workflow error, try to reconnect - DaemonCmdQuit DaemonCmdType = 5 // Shutdown command was received - DaemonCmdAckWithError DaemonCmdType = 6 // There was a processing error: either some serious biz logic re-trying will not help, or it was a data table error (we consider it persistent), so ack it - DaemonCmdReconnectQueue DaemonCmdType = 7 // Queue error, try to reconnect -) - -func (daemonCmd DaemonCmdType) ToString() string { - switch daemonCmd { - case DaemonCmdNone: - return "none" - case DaemonCmdAckSuccess: - return "sucess" - case DaemonCmdRejectAndRetryLater: - return "reject_and_retry" - case DaemonCmdReconnectDb: - return "reconnect_db" - case DaemonCmdQuit: - return "quit" - case DaemonCmdAckWithError: - return "ack_with_error" - case DaemonCmdReconnectQueue: - return "reconnect_queue" - default: - return "unknown" - } -} - -/* -amqpDeliveryToString - helper to print the contents of amqp.Delivery object -*/ -func amqpDeliveryToString(d amqp.Delivery) string { - // Do 
not just do Sprintf("%v", m), it will print the whole Body and it can be very long - return fmt.Sprintf("Headers:%v, ContentType:%v, ContentEncoding:%v, DeliveryMode:%v, Priority:%v, CorrelationId:%v, ReplyTo:%v, Expiration:%v, MessageId:%v, Timestamp:%v, Type:%v, UserId:%v, AppId:%v, ConsumerTag:%v, MessageCount:%v, DeliveryTag:%v, Redelivered:%v, Exchange:%v, RoutingKey:%v, len(Body):%d", - d.Headers, - d.ContentType, - d.ContentEncoding, - d.DeliveryMode, - d.Priority, - d.CorrelationId, - d.ReplyTo, - d.Expiration, - d.MessageId, - d.Timestamp, - d.Type, - d.UserId, - d.AppId, - d.ConsumerTag, - d.MessageCount, - d.DeliveryTag, - d.Redelivered, - d.Exchange, - d.RoutingKey, - len(d.Body)) -} - -func processDelivery(envConfig *env.EnvConfig, logger *l.Logger, delivery *amqp.Delivery) DaemonCmdType { - logger.PushF("wf.processDelivery") - defer logger.PopF() - - // Deserialize incoming message - var msgIn wfmodel.Message - errDeserialize := msgIn.Deserialize(delivery.Body) - if errDeserialize != nil { - logger.Error("cannot deserialize incoming message: %s. %v", errDeserialize.Error(), delivery.Body) - return DaemonCmdAckWithError - } - - switch msgIn.MessageType { - case wfmodel.MessageTypeDataBatch: - dataBatchInfo, ok := msgIn.Payload.(wfmodel.MessagePayloadDataBatch) - if !ok { - logger.Error("unexpected type of data batch payload: %T", msgIn.Payload) - return DaemonCmdAckWithError - } - return ProcessDataBatchMsg(envConfig, logger, msgIn.Ts, &dataBatchInfo) - - // TODO: other commands like debug level or shutdown go here - default: - logger.Error("unexpected message type %d", msgIn.MessageType) - return DaemonCmdAckWithError - } -} - -func AmqpFullReconnectCycle(envConfig *env.EnvConfig, logger *l.Logger, osSignalChannel chan os.Signal) DaemonCmdType { - logger.PushF("wf.AmqpFullReconnectCycle") - defer logger.PopF() - - amqpConnection, err := amqp.Dial(envConfig.Amqp.URL) - if err != nil { - logger.Error("cannot dial RabbitMQ at %s, will reconnect: %s", envConfig.Amqp.URL, err.Error()) - return DaemonCmdReconnectQueue - } - - // Subscribe to errors, this is how we handle queue failures - chanErrors := amqpConnection.NotifyClose(make(chan *amqp.Error)) - var daemonCmd DaemonCmdType - - amqpChannel, err := amqpConnection.Channel() - if err != nil { - logger.Error("cannot create amqp channel, will reconnect: %s", err.Error()) - daemonCmd = DaemonCmdReconnectQueue - } else { - daemonCmd = amqpConnectAndSelect(envConfig, logger, osSignalChannel, amqpChannel, chanErrors) - time.Sleep(1000) - logger.Info("consuming %d amqp errors to avoid close deadlock...", len(chanErrors)) - for len(chanErrors) > 0 { - chanErr := <-chanErrors - logger.Info("consuming amqp error to avoid close deadlock: %v", chanErr) - } - logger.Info("consumed amqp errors, closing channel") - amqpChannel.Close() - logger.Info("consumed amqp errors, closed channel") - } - logger.Info("closing connection") - amqpConnection.Close() - logger.Info("closed connection") - return daemonCmd -} - -func amqpConnectAndSelect(envConfig *env.EnvConfig, logger *l.Logger, osSignalChannel chan os.Signal, amqpChannel *amqp.Channel, chanAmqpErrors chan *amqp.Error) DaemonCmdType { - logger.PushF("wf.amqpConnectAndSelect") - defer logger.PopF() - - errExchange := amqpChannel.ExchangeDeclare( - envConfig.Amqp.Exchange, // exchange name - "direct", // type, "direct" - false, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil) // arguments - if errExchange != nil { - logger.Error("cannot declare exchange %s, 
will reconnect: %s", envConfig.Amqp.Exchange, errExchange.Error()) - return DaemonCmdReconnectQueue - } - - errExchange = amqpChannel.ExchangeDeclare( - envConfig.Amqp.Exchange+DlxSuffix, // exchange name - "direct", // type - false, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil) // arguments - if errExchange != nil { - logger.Error("cannot declare exchange %s, will reconnect: %s", envConfig.Amqp.Exchange+DlxSuffix, errExchange.Error()) - return DaemonCmdReconnectQueue - } - - // TODO: declare exchange for non-data signals and handle them in a separate queue - - amqpQueue, err := amqpChannel.QueueDeclare( - envConfig.HandlerExecutableType, // queue name, matches routing key - false, // durable - false, // delete when unused - false, // exclusive - false, // no-wait - amqp.Table{"x-dead-letter-exchange": envConfig.Amqp.Exchange + DlxSuffix, "x-dead-letter-routing-key": envConfig.HandlerExecutableType + DlxSuffix}) // arguments - if err != nil { - logger.Error("cannot declare queue %s, will reconnect: %s\n", envConfig.HandlerExecutableType, err.Error()) - return DaemonCmdReconnectQueue - } - - amqpQueueDlx, err := amqpChannel.QueueDeclare( - envConfig.HandlerExecutableType+DlxSuffix, // queue name, matches routing key - false, // durable - false, // delete when unused - false, // exclusive - false, // no-wait - amqp.Table{"x-dead-letter-exchange": envConfig.Amqp.Exchange, "x-dead-letter-routing-key": envConfig.HandlerExecutableType, "x-message-ttl": envConfig.DeadLetterTtl}) - if err != nil { - logger.Error("cannot declare queue %s, will reconnect: %s\n", envConfig.HandlerExecutableType+DlxSuffix, err.Error()) - return DaemonCmdReconnectQueue - } - - errBind := amqpChannel.QueueBind( - amqpQueue.Name, // queue name - envConfig.HandlerExecutableType, // routing key / handler exe type - envConfig.Amqp.Exchange, // exchange - false, // nowait - nil) // args - if errBind != nil { - logger.Error("cannot bind queue %s with routing key %s, exchange %s , will reconnect: %s", amqpQueue.Name, envConfig.HandlerExecutableType, envConfig.Amqp.Exchange, errBind.Error()) - return DaemonCmdReconnectQueue - } - - errBind = amqpChannel.QueueBind( - amqpQueueDlx.Name, // queue name - envConfig.HandlerExecutableType+DlxSuffix, // routing key / handler exe type - envConfig.Amqp.Exchange+DlxSuffix, // exchange - false, // nowait - nil) // args - if errBind != nil { - logger.Error("cannot bind queue %s with routing key %s, exchange %s , will reconnect: %s", amqpQueueDlx.Name, envConfig.HandlerExecutableType+DlxSuffix, envConfig.Amqp.Exchange+DlxSuffix, errBind.Error()) - return DaemonCmdReconnectQueue - } - - errQos := amqpChannel.Qos(envConfig.Amqp.PrefetchCount, envConfig.Amqp.PrefetchSize, false) - if errQos != nil { - logger.Error("cannot set Qos, will reconnect: %s", errQos.Error()) - return DaemonCmdReconnectQueue - } - - ampqChannelConsumerTag := logger.ZapMachine.String + "/consumer" - chanDeliveries, err := amqpChannel.Consume( - amqpQueue.Name, // queue - ampqChannelConsumerTag, // unique consumer tag, default go ampq implementation is os.argv[0] (is it really unique?) 
- false, // auto-ack - false, // exclusive - false, // no-local - flag not supportd by rabbit - false, // no-wait - nil) // args - if err != nil { - logger.Error("cannot register consumer, queue %s, will reconnect: %s", amqpQueue.Name, err.Error()) - return DaemonCmdReconnectQueue - } - logger.Info("started consuming queue %s, routing key %s, exchange %s", amqpQueue.Name, envConfig.HandlerExecutableType, envConfig.Amqp.Exchange) - - var sem = make(chan int, envConfig.ThreadPoolSize) - - // daemonCommands len should be > ThreadPoolSize, otherwise on reconnect, we will get a deadlock: - // "still waiting for all workers to complete" will wait for one or more workers that will try adding - // "daemonCommands <- DaemonCmdReconnectDb" to the channel. Play safe by multiplying by 2. - var daemonCommands = make(chan DaemonCmdType, envConfig.ThreadPoolSize*2) - - for { - select { - case osSignal := <-osSignalChannel: - if osSignal == os.Interrupt || osSignal == os.Kill { - logger.Info("received os signal %v, sending quit...", osSignal) - daemonCommands <- DaemonCmdQuit - } - - case chanErr := <-chanAmqpErrors: - if chanErr != nil { - logger.Error("detected closed amqp channel, will reconnect: %s", chanErr.Error()) - } else { - logger.Error("detected closed amqp channel, will reconnect: nil error received") - } - daemonCommands <- DaemonCmdReconnectQueue - - case finalDaemonCmd := <-daemonCommands: - - // Here, we expect DaemonCmdReconnectDb, DaemonCmdReconnectQueue, DaemonCmdQuit. All of them require channel.Cancel() - - logger.Info("detected daemon cmd %d(%s), cancelling channel...", finalDaemonCmd, finalDaemonCmd.ToString()) - if err := amqpChannel.Cancel(ampqChannelConsumerTag, false); err != nil { - logger.Error("cannot cancel amqp channel: %s", err.Error()) - } else { - logger.Info("channel cancelled successfully") - } - - logger.Info("handling daemon cmd %d(%s), waiting for all workers to complete (%d items)...", finalDaemonCmd, finalDaemonCmd.ToString(), len(sem)) - for len(sem) > 0 { - logger.Info("still waiting for all workers to complete (%d items left)...", len(sem)) - time.Sleep(1000 * time.Millisecond) - } - - logger.Info("handling daemon cmd %d(%s), all workers completed, draining cmd channel (%d items)...", finalDaemonCmd, finalDaemonCmd.ToString(), len(daemonCommands)) - for len(daemonCommands) > 0 { - daemonCmd := <-daemonCommands - // Do not ignore quit command, make sure it makes it to the finals - if daemonCmd == DaemonCmdQuit { - finalDaemonCmd = DaemonCmdQuit - } - } - logger.Info("final daemon cmd %d(%s), all workers complete, cmd channel drained", finalDaemonCmd, finalDaemonCmd.ToString()) - return finalDaemonCmd - - case amqpDelivery := <-chanDeliveries: - - threadLogger, err := l.NewLoggerFromLogger(logger) - if err != nil { - logger.Error("cannot create logger for delivery handler thread: %s", err.Error()) - return DaemonCmdQuit - } - - logger.PushF("wf.amqpConnectAndSelect_worker") - defer logger.PopF() - - // Lock one slot in the semaphore - sem <- 1 - - go func(threadLogger *l.Logger, delivery amqp.Delivery, _channel amqp.Channel) { - var err error - - // I have spotted cases when m.Body is empty and Aknowledger is nil. Handle them. 
- if delivery.Acknowledger == nil { - threadLogger.Error("processor detected empty Acknowledger, assuming closed amqp channel, will reconnect: %s", amqpDeliveryToString(delivery)) - daemonCommands <- DaemonCmdReconnectQueue - } else { - daemonCmd := processDelivery(envConfig, threadLogger, &delivery) - - if daemonCmd == DaemonCmdAckSuccess || daemonCmd == DaemonCmdAckWithError { - err = delivery.Ack(false) - if err != nil { - threadLogger.Error("failed to ack message, will reconnect: %s", err.Error()) - daemonCommands <- DaemonCmdReconnectQueue - } - } else if daemonCmd == DaemonCmdRejectAndRetryLater { - err = delivery.Reject(false) - if err != nil { - threadLogger.Error("failed to reject message, will reconnect: %s", err.Error()) - daemonCommands <- DaemonCmdReconnectQueue - } - } else if daemonCmd == DaemonCmdReconnectQueue || daemonCmd == DaemonCmdReconnectDb { - // // Ideally, RabbitMQ should be smart enough to re-deliver a msg that was neither acked nor rejected. - // // But apparently, sometimes (when the machine goes to sleep, for example) the msg is never re-delivered. To improve our chances, force re-delivery by rejecting the msg. - // threadLogger.Error("daemonCmd %s detected, will reject(requeue) and reconnect", daemonCmd.ToString()) - // err = delivery.Reject(true) - // if err != nil { - // threadLogger.Error("failed to reject message, will reconnect: %s", err.Error()) - // daemonCommands <- DaemonCmdReconnectQueue - // } else { - // daemonCommands <- daemonCmd - // } - - // Verdict: we do not handle machine sleep scenario, amqp091-go goes into deadlock when shutting down. - daemonCommands <- daemonCmd - } else if daemonCmd == DaemonCmdQuit { - daemonCommands <- DaemonCmdQuit - } else { - threadLogger.Error("unexpected daemon cmd: %d", daemonCmd) - } - } - - // Unlock semaphore slot - <-sem - - }(threadLogger, amqpDelivery, *amqpChannel) - } - } -} +package wf + +import ( + "fmt" + "os" + "time" + + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + amqp "github.com/rabbitmq/amqp091-go" +) + +const DlxSuffix string = "_dlx" + +type DaemonCmdType int8 + +const ( + DaemonCmdNone DaemonCmdType = 0 // Should never see this + DaemonCmdAckSuccess DaemonCmdType = 2 // Best case + DaemonCmdRejectAndRetryLater DaemonCmdType = 3 // Node dependencies are not ready, wait with proessing this node + DaemonCmdReconnectDb DaemonCmdType = 4 // Db workflow error, try to reconnect + DaemonCmdQuit DaemonCmdType = 5 // Shutdown command was received + DaemonCmdAckWithError DaemonCmdType = 6 // There was a processing error: either some serious biz logic re-trying will not help, or it was a data table error (we consider it persistent), so ack it + DaemonCmdReconnectQueue DaemonCmdType = 7 // Queue error, try to reconnect +) + +func (daemonCmd DaemonCmdType) ToString() string { + switch daemonCmd { + case DaemonCmdNone: + return "none" + case DaemonCmdAckSuccess: + return "success" + case DaemonCmdRejectAndRetryLater: + return "reject_and_retry" + case DaemonCmdReconnectDb: + return "reconnect_db" + case DaemonCmdQuit: + return "quit" + case DaemonCmdAckWithError: + return "ack_with_error" + case DaemonCmdReconnectQueue: + return "reconnect_queue" + default: + return "unknown" + } +} + +/* +amqpDeliveryToString - helper to print the contents of amqp.Delivery object +*/ +func amqpDeliveryToString(d amqp.Delivery) string { + // Do not just do Sprintf("%v", m), it will print the whole Body and it can be very 
long + return fmt.Sprintf("Headers:%v, ContentType:%v, ContentEncoding:%v, DeliveryMode:%v, Priority:%v, CorrelationId:%v, ReplyTo:%v, Expiration:%v, MessageId:%v, Timestamp:%v, Type:%v, UserId:%v, AppId:%v, ConsumerTag:%v, MessageCount:%v, DeliveryTag:%v, Redelivered:%v, Exchange:%v, RoutingKey:%v, len(Body):%d", + d.Headers, + d.ContentType, + d.ContentEncoding, + d.DeliveryMode, + d.Priority, + d.CorrelationId, + d.ReplyTo, + d.Expiration, + d.MessageId, + d.Timestamp, + d.Type, + d.UserId, + d.AppId, + d.ConsumerTag, + d.MessageCount, + d.DeliveryTag, + d.Redelivered, + d.Exchange, + d.RoutingKey, + len(d.Body)) +} + +func processDelivery(envConfig *env.EnvConfig, logger *l.CapiLogger, delivery *amqp.Delivery) DaemonCmdType { + logger.PushF("wf.processDelivery") + defer logger.PopF() + + // Deserialize incoming message + var msgIn wfmodel.Message + errDeserialize := msgIn.Deserialize(delivery.Body) + if errDeserialize != nil { + logger.Error("cannot deserialize incoming message: %s. %v", errDeserialize.Error(), delivery.Body) + return DaemonCmdAckWithError + } + + switch msgIn.MessageType { + case wfmodel.MessageTypeDataBatch: + dataBatchInfo, ok := msgIn.Payload.(wfmodel.MessagePayloadDataBatch) + if !ok { + logger.Error("unexpected type of data batch payload: %T", msgIn.Payload) + return DaemonCmdAckWithError + } + return ProcessDataBatchMsg(envConfig, logger, msgIn.Ts, &dataBatchInfo) + + // TODO: other commands like debug level or shutdown go here + default: + logger.Error("unexpected message type %d", msgIn.MessageType) + return DaemonCmdAckWithError + } +} + +func AmqpFullReconnectCycle(envConfig *env.EnvConfig, logger *l.CapiLogger, osSignalChannel chan os.Signal) DaemonCmdType { + logger.PushF("wf.AmqpFullReconnectCycle") + defer logger.PopF() + + amqpConnection, err := amqp.Dial(envConfig.Amqp.URL) + if err != nil { + logger.Error("cannot dial RabbitMQ at %s, will reconnect: %s", envConfig.Amqp.URL, err.Error()) + return DaemonCmdReconnectQueue + } + + // Subscribe to errors, this is how we handle queue failures + chanErrors := amqpConnection.NotifyClose(make(chan *amqp.Error)) + var daemonCmd DaemonCmdType + + amqpChannel, err := amqpConnection.Channel() + if err != nil { + logger.Error("cannot create amqp channel, will reconnect: %s", err.Error()) + daemonCmd = DaemonCmdReconnectQueue + } else { + daemonCmd = amqpConnectAndSelect(envConfig, logger, osSignalChannel, amqpChannel, chanErrors) + time.Sleep(1000) + logger.Info("consuming %d amqp errors to avoid close deadlock...", len(chanErrors)) + for len(chanErrors) > 0 { + chanErr := <-chanErrors + logger.Info("consuming amqp error to avoid close deadlock: %v", chanErr) + } + logger.Info("consumed amqp errors, closing channel") + amqpChannel.Close() + logger.Info("consumed amqp errors, closed channel") + } + logger.Info("closing connection") + amqpConnection.Close() + logger.Info("closed connection") + return daemonCmd +} + +func amqpConnectAndSelect(envConfig *env.EnvConfig, logger *l.CapiLogger, osSignalChannel chan os.Signal, amqpChannel *amqp.Channel, chanAmqpErrors chan *amqp.Error) DaemonCmdType { + logger.PushF("wf.amqpConnectAndSelect") + defer logger.PopF() + + errExchange := amqpChannel.ExchangeDeclare( + envConfig.Amqp.Exchange, // exchange name + "direct", // type, "direct" + false, // durable + false, // auto-deleted + false, // internal + false, // no-wait + nil) // arguments + if errExchange != nil { + logger.Error("cannot declare exchange %s, will reconnect: %s", envConfig.Amqp.Exchange, errExchange.Error()) + 
return DaemonCmdReconnectQueue + } + + errExchange = amqpChannel.ExchangeDeclare( + envConfig.Amqp.Exchange+DlxSuffix, // exchange name + "direct", // type + false, // durable + false, // auto-deleted + false, // internal + false, // no-wait + nil) // arguments + if errExchange != nil { + logger.Error("cannot declare exchange %s, will reconnect: %s", envConfig.Amqp.Exchange+DlxSuffix, errExchange.Error()) + return DaemonCmdReconnectQueue + } + + // TODO: declare exchange for non-data signals and handle them in a separate queue + + amqpQueue, err := amqpChannel.QueueDeclare( + envConfig.HandlerExecutableType, // queue name, matches routing key + false, // durable + false, // delete when unused + false, // exclusive + false, // no-wait + amqp.Table{"x-dead-letter-exchange": envConfig.Amqp.Exchange + DlxSuffix, "x-dead-letter-routing-key": envConfig.HandlerExecutableType + DlxSuffix}) // arguments + if err != nil { + logger.Error("cannot declare queue %s, will reconnect: %s\n", envConfig.HandlerExecutableType, err.Error()) + return DaemonCmdReconnectQueue + } + + amqpQueueDlx, err := amqpChannel.QueueDeclare( + envConfig.HandlerExecutableType+DlxSuffix, // queue name, matches routing key + false, // durable + false, // delete when unused + false, // exclusive + false, // no-wait + amqp.Table{"x-dead-letter-exchange": envConfig.Amqp.Exchange, "x-dead-letter-routing-key": envConfig.HandlerExecutableType, "x-message-ttl": envConfig.DeadLetterTtl}) + if err != nil { + logger.Error("cannot declare queue %s, will reconnect: %s\n", envConfig.HandlerExecutableType+DlxSuffix, err.Error()) + return DaemonCmdReconnectQueue + } + + errBind := amqpChannel.QueueBind( + amqpQueue.Name, // queue name + envConfig.HandlerExecutableType, // routing key / handler exe type + envConfig.Amqp.Exchange, // exchange + false, // nowait + nil) // args + if errBind != nil { + logger.Error("cannot bind queue %s with routing key %s, exchange %s , will reconnect: %s", amqpQueue.Name, envConfig.HandlerExecutableType, envConfig.Amqp.Exchange, errBind.Error()) + return DaemonCmdReconnectQueue + } + + errBind = amqpChannel.QueueBind( + amqpQueueDlx.Name, // queue name + envConfig.HandlerExecutableType+DlxSuffix, // routing key / handler exe type + envConfig.Amqp.Exchange+DlxSuffix, // exchange + false, // nowait + nil) // args + if errBind != nil { + logger.Error("cannot bind queue %s with routing key %s, exchange %s , will reconnect: %s", amqpQueueDlx.Name, envConfig.HandlerExecutableType+DlxSuffix, envConfig.Amqp.Exchange+DlxSuffix, errBind.Error()) + return DaemonCmdReconnectQueue + } + + errQos := amqpChannel.Qos(envConfig.Amqp.PrefetchCount, envConfig.Amqp.PrefetchSize, false) + if errQos != nil { + logger.Error("cannot set Qos, will reconnect: %s", errQos.Error()) + return DaemonCmdReconnectQueue + } + + ampqChannelConsumerTag := logger.ZapMachine.String + "/consumer" + chanDeliveries, err := amqpChannel.Consume( + amqpQueue.Name, // queue + ampqChannelConsumerTag, // unique consumer tag, default go ampq implementation is os.argv[0] (is it really unique?) 
+ false, // auto-ack + false, // exclusive + false, // no-local - flag not supportd by rabbit + false, // no-wait + nil) // args + if err != nil { + logger.Error("cannot register consumer, queue %s, will reconnect: %s", amqpQueue.Name, err.Error()) + return DaemonCmdReconnectQueue + } + logger.Info("started consuming queue %s, routing key %s, exchange %s", amqpQueue.Name, envConfig.HandlerExecutableType, envConfig.Amqp.Exchange) + + var sem = make(chan int, envConfig.ThreadPoolSize) + + // daemonCommands len should be > ThreadPoolSize, otherwise on reconnect, we will get a deadlock: + // "still waiting for all workers to complete" will wait for one or more workers that will try adding + // "daemonCommands <- DaemonCmdReconnectDb" to the channel. Play safe by multiplying by 2. + var daemonCommands = make(chan DaemonCmdType, envConfig.ThreadPoolSize*2) + + for { + select { + case osSignal := <-osSignalChannel: + if osSignal == os.Interrupt || osSignal == os.Kill { + logger.Info("received os signal %v, sending quit...", osSignal) + daemonCommands <- DaemonCmdQuit + } + + case chanErr := <-chanAmqpErrors: + if chanErr != nil { + logger.Error("detected closed amqp channel, will reconnect: %s", chanErr.Error()) + } else { + logger.Error("detected closed amqp channel, will reconnect: nil error received") + } + daemonCommands <- DaemonCmdReconnectQueue + + case finalDaemonCmd := <-daemonCommands: + + // Here, we expect DaemonCmdReconnectDb, DaemonCmdReconnectQueue, DaemonCmdQuit. All of them require channel.Cancel() + + logger.Info("detected daemon cmd %d(%s), cancelling channel...", finalDaemonCmd, finalDaemonCmd.ToString()) + if err := amqpChannel.Cancel(ampqChannelConsumerTag, false); err != nil { + logger.Error("cannot cancel amqp channel: %s", err.Error()) + } else { + logger.Info("channel cancelled successfully") + } + + logger.Info("handling daemon cmd %d(%s), waiting for all workers to complete (%d items)...", finalDaemonCmd, finalDaemonCmd.ToString(), len(sem)) + for len(sem) > 0 { + logger.Info("still waiting for all workers to complete (%d items left)...", len(sem)) + time.Sleep(1000 * time.Millisecond) + } + + logger.Info("handling daemon cmd %d(%s), all workers completed, draining cmd channel (%d items)...", finalDaemonCmd, finalDaemonCmd.ToString(), len(daemonCommands)) + for len(daemonCommands) > 0 { + daemonCmd := <-daemonCommands + // Do not ignore quit command, make sure it makes it to the finals + if daemonCmd == DaemonCmdQuit { + finalDaemonCmd = DaemonCmdQuit + } + } + logger.Info("final daemon cmd %d(%s), all workers complete, cmd channel drained", finalDaemonCmd, finalDaemonCmd.ToString()) + return finalDaemonCmd + + case amqpDelivery := <-chanDeliveries: + + threadLogger, err := l.NewLoggerFromLogger(logger) + if err != nil { + logger.Error("cannot create logger for delivery handler thread: %s", err.Error()) + return DaemonCmdQuit + } + + // TODO: come up with safe logging + // it's tempting to move it into the async func below, but it will break the logger stack + // leaving it here is not good eiter: revive says "prefer not to defer inside loops" + // logger.PushF("wf.amqpConnectAndSelect_worker") + // defer logger.PopF() + + // Lock one slot in the semaphore + sem <- 1 + + go func(threadLogger *l.CapiLogger, delivery amqp.Delivery, _channel *amqp.Channel) { + var err error + + // I have spotted cases when m.Body is empty and Aknowledger is nil. Handle them. 
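	// In short, the daemon command below decides the AMQP outcome: Ack for AckSuccess/AckWithError;
	// Reject(requeue=false) for RejectAndRetryLater, which relies on the queue's
	// x-dead-letter-exchange to park the message in the _dlx queue until its x-message-ttl
	// expires and routes it back here for another attempt; and no ack at all for Reconnect*/Quit,
	// which are pushed to daemonCommands so the consumer is cancelled and the unacked message
	// can be redelivered after reconnect.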
+ if delivery.Acknowledger == nil { + threadLogger.Error("processor detected empty Acknowledger, assuming closed amqp channel, will reconnect: %s", amqpDeliveryToString(delivery)) + daemonCommands <- DaemonCmdReconnectQueue + } else { + daemonCmd := processDelivery(envConfig, threadLogger, &delivery) + + if daemonCmd == DaemonCmdAckSuccess || daemonCmd == DaemonCmdAckWithError { + err = delivery.Ack(false) + if err != nil { + threadLogger.Error("failed to ack message, will reconnect: %s", err.Error()) + daemonCommands <- DaemonCmdReconnectQueue + } + } else if daemonCmd == DaemonCmdRejectAndRetryLater { + err = delivery.Reject(false) + if err != nil { + threadLogger.Error("failed to reject message, will reconnect: %s", err.Error()) + daemonCommands <- DaemonCmdReconnectQueue + } + } else if daemonCmd == DaemonCmdReconnectQueue || daemonCmd == DaemonCmdReconnectDb { + // // Ideally, RabbitMQ should be smart enough to re-deliver a msg that was neither acked nor rejected. + // // But apparently, sometimes (when the machine goes to sleep, for example) the msg is never re-delivered. To improve our chances, force re-delivery by rejecting the msg. + // threadLogger.Error("daemonCmd %s detected, will reject(requeue) and reconnect", daemonCmd.ToString()) + // err = delivery.Reject(true) + // if err != nil { + // threadLogger.Error("failed to reject message, will reconnect: %s", err.Error()) + // daemonCommands <- DaemonCmdReconnectQueue + // } else { + // daemonCommands <- daemonCmd + // } + + // Verdict: we do not handle machine sleep scenario, amqp091-go goes into deadlock when shutting down. + daemonCommands <- daemonCmd + } else if daemonCmd == DaemonCmdQuit { + daemonCommands <- DaemonCmdQuit + } else { + threadLogger.Error("unexpected daemon cmd: %d", daemonCmd) + } + } + + // Unlock semaphore slot + <-sem + + }(threadLogger, amqpDelivery, amqpChannel) + } + } +} diff --git a/pkg/wf/message_handler.go b/pkg/wf/message_handler.go index eccd252..286d5fe 100644 --- a/pkg/wf/message_handler.go +++ b/pkg/wf/message_handler.go @@ -1,455 +1,459 @@ -package wf - -import ( - "fmt" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/dpc" - "github.com/capillariesio/capillaries/pkg/env" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/proc" - "github.com/capillariesio/capillaries/pkg/sc" - "github.com/capillariesio/capillaries/pkg/wfdb" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "go.uber.org/zap" -) - -func checkDependencyNodesReady(logger *l.Logger, pCtx *ctx.MessageProcessingContext) (sc.ReadyToRunNodeCmdType, int16, int16, error) { - logger.PushF("wf.checkDependencyNodesReady") - defer logger.PopF() - - depNodeNames := make([]string, 2) - depNodeCount := 0 - if pCtx.CurrentScriptNode.HasTableReader() { - tableToReadFrom := pCtx.CurrentScriptNode.TableReader.TableName - nodeToReadFrom, ok := pCtx.Script.TableCreatorNodeMap[tableToReadFrom] - if !ok { - return sc.NodeNone, 0, 0, fmt.Errorf("cannot find the node that creates reader table [%s]", tableToReadFrom) - } - depNodeNames[depNodeCount] = nodeToReadFrom.Name - depNodeCount++ - } - if pCtx.CurrentScriptNode.HasLookup() { - tableToReadFrom := pCtx.CurrentScriptNode.Lookup.TableCreator.Name - nodeToReadFrom, ok := pCtx.Script.TableCreatorNodeMap[tableToReadFrom] - if !ok { - return sc.NodeNone, 0, 0, fmt.Errorf("cannot find the node that creates 
lookup table [%s]", tableToReadFrom) - } - depNodeNames[depNodeCount] = nodeToReadFrom.Name - depNodeCount++ - } - - if depNodeCount == 0 { - return sc.NodeGo, 0, 0, nil - } - - depNodeNames = depNodeNames[:depNodeCount] - - nodeEventListMap, err := wfdb.BuildDependencyNodeEventLists(logger, pCtx, depNodeNames) - if err != nil { - return sc.NodeNone, 0, 0, err - } - - logger.DebugCtx(pCtx, "nodeEventListMap %v", nodeEventListMap) - - dependencyNodeCmds := make([]sc.ReadyToRunNodeCmdType, len(depNodeNames)) - dependencyRunIds := make([]int16, len(depNodeNames)) - for nodeIdx, depNodeName := range depNodeNames { - if len(nodeEventListMap[depNodeName]) == 0 { - return sc.NodeNogo, 0, 0, fmt.Errorf("target node %s, dep node %s not started yet, whoever started this run, failed to specify %s (or at least one of its dependencies) as start node", pCtx.CurrentScriptNode.Name, depNodeName, depNodeName) - } - var checkerLogMsg string - dependencyNodeCmds[nodeIdx], dependencyRunIds[nodeIdx], checkerLogMsg, err = dpc.CheckDependencyPolicyAgainstNodeEventList(pCtx.CurrentScriptNode.DepPolDef, nodeEventListMap[depNodeName]) - if len(checkerLogMsg) > 0 { - logger.Debug(checkerLogMsg) - } - if err != nil { - return sc.NodeNone, 0, 0, err - } - logger.DebugCtx(pCtx, "target node %s, dep node %s returned %s", pCtx.CurrentScriptNode.Name, depNodeName, dependencyNodeCmds[nodeIdx]) - } - - finalCmd := dependencyNodeCmds[0] - finalRunIdReader := dependencyRunIds[0] - finalRunIdLookup := int16(0) - if len(dependencyNodeCmds) == 2 { - finalRunIdLookup = dependencyRunIds[1] - if dependencyNodeCmds[0] == sc.NodeNogo || dependencyNodeCmds[1] == sc.NodeNogo { - finalCmd = sc.NodeNogo - } else if dependencyNodeCmds[0] == sc.NodeWait || dependencyNodeCmds[1] == sc.NodeWait { - finalCmd = sc.NodeWait - } else { - finalCmd = sc.NodeGo - } - } - - if finalCmd == sc.NodeNogo || finalCmd == sc.NodeGo { - logger.InfoCtx(pCtx, "checked all dependency nodes for %s, commands are %v, run ids are %v, finalCmd is %s", pCtx.CurrentScriptNode.Name, dependencyNodeCmds, dependencyRunIds, finalCmd) - } else { - logger.DebugCtx(pCtx, "checked all dependency nodes for %s, commands are %v, run ids are %v, finalCmd is wait", pCtx.CurrentScriptNode.Name, dependencyNodeCmds, dependencyRunIds) - } - - return finalCmd, finalRunIdReader, finalRunIdLookup, nil -} - -func SafeProcessBatch(envConfig *env.EnvConfig, logger *l.Logger, pCtx *ctx.MessageProcessingContext, readerNodeRunId int16, lookupNodeRunId int16) (wfmodel.NodeBatchStatusType, proc.BatchStats, error) { - logger.PushF("wf.SafeProcessBatch") - defer logger.PopF() - - var bs proc.BatchStats - var err error - - switch pCtx.CurrentScriptNode.Type { - case sc.NodeTypeFileTable: - if pCtx.BatchInfo.FirstToken != pCtx.BatchInfo.LastToken || pCtx.BatchInfo.FirstToken < 0 || pCtx.BatchInfo.FirstToken >= int64(len(pCtx.CurrentScriptNode.FileReader.SrcFileUrls)) { - err = fmt.Errorf( - "startToken %d must equal endToken %d must be smaller than the number of files specified by file reader %d", - pCtx.BatchInfo.FirstToken, - pCtx.BatchInfo.LastToken, - len(pCtx.CurrentScriptNode.FileReader.SrcFileUrls)) - } else { - bs, err = proc.RunReadFileForBatch(envConfig, logger, pCtx, int(pCtx.BatchInfo.FirstToken)) - } - - case sc.NodeTypeTableTable: - bs, err = proc.RunCreateTableForBatch(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) - - case sc.NodeTypeTableLookupTable: - bs, err = proc.RunCreateTableRelForBatch(envConfig, logger, pCtx, readerNodeRunId, 
lookupNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) - - case sc.NodeTypeTableFile: - bs, err = proc.RunCreateFile(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) - - case sc.NodeTypeTableCustomTfmTable: - bs, err = proc.RunCreateTableForCustomProcessorForBatch(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) - - default: - err = fmt.Errorf("unsupported node %s type %s", pCtx.CurrentScriptNode.Name, pCtx.CurrentScriptNode.Type) - } - - if err != nil { - logger.DebugCtx(pCtx, "batch processed, error: %s", err.Error()) - return wfmodel.NodeBatchFail, bs, fmt.Errorf("error running node %s of type %s in the script [%s]: [%s]", pCtx.CurrentScriptNode.Name, pCtx.CurrentScriptNode.Type, pCtx.BatchInfo.ScriptURI, err.Error()) - } else { - logger.DebugCtx(pCtx, "batch processed ok") - } - - return wfmodel.NodeBatchSuccess, bs, nil -} - -func UpdateNodeStatusFromBatches(logger *l.Logger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, bool, error) { - logger.PushF("wf.UpdateNodeStatusFromBatches") - defer logger.PopF() - - // Check all batches for this run/node, mark node complete if needed - totalNodeStatus, err := wfdb.HarvestBatchStatusesForNode(logger, pCtx) - if err != nil { - return wfmodel.NodeBatchNone, false, err - } - - if totalNodeStatus == wfmodel.NodeBatchFail || totalNodeStatus == wfmodel.NodeBatchSuccess || totalNodeStatus == wfmodel.NodeBatchRunStopReceived { - // Node processing completed, mark whole node as complete - var comment string - switch totalNodeStatus { - case wfmodel.NodeBatchSuccess: - comment = "completed - all batches ok" - case wfmodel.NodeBatchFail: - comment = "completed with some failed batches - check batch history" - case wfmodel.NodeBatchRunStopReceived: - comment = "run was stopped,check run and batch history" - } - - isApplied, err := wfdb.SetNodeStatus(logger, pCtx, totalNodeStatus, comment) - if err != nil { - return wfmodel.NodeBatchNone, false, err - } else { - return totalNodeStatus, isApplied, nil - } - } - - return totalNodeStatus, false, nil -} - -func UpdateRunStatusFromNodes(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("wf.UpdateRunStatusFromNodes") - defer logger.PopF() - - // Let's see if this run is complete - affectedNodes, err := wfdb.GetRunAffectedNodes(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, pCtx.BatchInfo.RunId) - if err != nil { - return err - } - combinedNodeStatus, nodeStatusString, err := wfdb.HarvestNodeStatusesForRun(logger, pCtx, affectedNodes) - if err != nil { - return err - } - - if combinedNodeStatus == wfmodel.NodeBatchSuccess || combinedNodeStatus == wfmodel.NodeBatchFail { - // Mark run as complete - if err := wfdb.SetRunStatus(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, pCtx.BatchInfo.RunId, wfmodel.RunComplete, nodeStatusString, cql.IgnoreIfExists); err != nil { - return err - } - } - - return nil -} - -func refreshNodeAndRunStatus(logger *l.Logger, pCtx *ctx.MessageProcessingContext) error { - logger.PushF("wf.refreshNodeAndRunStatus") - defer logger.PopF() - - _, _, err := UpdateNodeStatusFromBatches(logger, pCtx) - if err != nil { - logger.ErrorCtx(pCtx, "cannot refresh run/node status: %s", err.Error()) - return err - } else { - // Ideally, we should run the code below only if isApplied signaled something changed. But, there is a possibility - // that on the previous attempt, node status was updated and the daemon crashed right after that. 
- // We need to pick it up from there and refresh run status anyways. - err := UpdateRunStatusFromNodes(logger, pCtx) - if err != nil { - logger.ErrorCtx(pCtx, "cannot refresh run status: %s", err.Error()) - return err - } - } - return nil -} - -func ProcessDataBatchMsg(envConfig *env.EnvConfig, logger *l.Logger, msgTs int64, dataBatchInfo *wfmodel.MessagePayloadDataBatch) DaemonCmdType { - logger.PushF("wf.ProcessDataBatchMsg") - defer logger.PopF() - - pCtx := &ctx.MessageProcessingContext{ - MsgTs: msgTs, - BatchInfo: *dataBatchInfo, - ZapDataKeyspace: zap.String("ks", dataBatchInfo.DataKeyspace), - ZapRun: zap.Int16("run", dataBatchInfo.RunId), - ZapNode: zap.String("node", dataBatchInfo.TargetNodeName), - ZapBatchIdx: zap.Int16("bi", dataBatchInfo.BatchIdx), - ZapMsgAgeMillis: zap.Int64("age", time.Now().UnixMilli()-msgTs)} - - var err error - var initProblem sc.ScriptInitProblemType - pCtx.Script, err, initProblem = sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, dataBatchInfo.ScriptURI, dataBatchInfo.ScriptParamsURI, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) - if initProblem != sc.ScriptInitNoProblem { - switch initProblem { - case sc.ScriptInitUrlProblem: - logger.Error("cannot init script because of URI problem, will not let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) - return DaemonCmdAckWithError - case sc.ScriptInitContentProblem: - logger.Error("cannot init script because of content problem, will not let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) - return DaemonCmdAckWithError - case sc.ScriptInitConnectivityProblem: - logger.Error("cannot init script because of connectivity problem, will let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) - return DaemonCmdRejectAndRetryLater - default: - logger.Error("unexpected: cannot init script for unknown reason %d, will let other workers handle it, giving up with msg %s: %s", initProblem, dataBatchInfo.ToString(), err.Error()) - return DaemonCmdRejectAndRetryLater - } - } - - var ok bool - pCtx.CurrentScriptNode, ok = pCtx.Script.ScriptNodes[dataBatchInfo.TargetNodeName] - if !ok { - logger.Error("cannot find node %s in the script [%s], giving up with %s, returning DaemonCmdAckWithError, will not let other workers handle it", pCtx.BatchInfo.TargetNodeName, pCtx.BatchInfo.ScriptURI, dataBatchInfo.ToString()) - return DaemonCmdAckWithError - } - - if err := pCtx.DbConnect(envConfig); err != nil { - logger.Error("cannot connect to db: %s", err.Error()) - return DaemonCmdReconnectDb - } - defer pCtx.DbClose() - - logger.DebugCtx(pCtx, "started processing batch %s", dataBatchInfo.FullBatchId()) - - runStatus, err := wfdb.GetCurrentRunStatus(logger, pCtx) - if err != nil { - logger.ErrorCtx(pCtx, "cannot get current run status for batch %s: %s", dataBatchInfo.FullBatchId(), err.Error()) - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if runStatus == wfmodel.RunNone { - comment := fmt.Sprintf("run history status for batch %s is empty, looks like this run %d was never started", dataBatchInfo.FullBatchId(), pCtx.BatchInfo.RunId) - logger.ErrorCtx(pCtx, comment) - if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchRunStopReceived, comment); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - if err := 
refreshNodeAndRunStatus(logger, pCtx); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - return DaemonCmdAckWithError - } - - // If the user signaled stop to this proc, all results of the run are invalidated - if runStatus == wfmodel.RunStop { - comment := fmt.Sprintf("run stopped, batch %s marked %s", dataBatchInfo.FullBatchId(), wfmodel.NodeBatchRunStopReceived.ToString()) - if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchRunStopReceived, comment); err != nil { - logger.ErrorCtx(pCtx, fmt.Sprintf("%s, cannot set batch status: %s", comment, err.Error())) - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { - logger.ErrorCtx(pCtx, fmt.Sprintf("%s, cannot refresh status: %s", comment, err.Error())) - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - logger.DebugCtx(pCtx, fmt.Sprintf("%s, status successfully refreshed", comment)) - return DaemonCmdAckSuccess - } else if runStatus != wfmodel.RunStart { - logger.ErrorCtx(pCtx, "cannot process batch %s, run already has unexpected status %d", dataBatchInfo.FullBatchId(), runStatus) - return DaemonCmdAckWithError - } - - // Check if this run/node/batch has been handled already - lastBatchStatus, err := wfdb.HarvestLastStatusForBatch(logger, pCtx) - if err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if lastBatchStatus == wfmodel.NodeBatchFail || lastBatchStatus == wfmodel.NodeBatchSuccess { - logger.InfoCtx(pCtx, "will not process batch %s, it has been already processed (processor crashed after processing it and before marking as success/fail?) 
with status %d", dataBatchInfo.FullBatchId(), lastBatchStatus) - if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - return DaemonCmdAckSuccess - } else if lastBatchStatus == wfmodel.NodeBatchStart { - // This run/node/batch has been picked up by another crashed processor (processor crashed before marking success/fail) - if pCtx.CurrentScriptNode.RerunPolicy == sc.NodeRerun { - if err := proc.DeleteDataAndUniqueIndexesByBatchIdx(logger, pCtx); err != nil { - comment := fmt.Sprintf("cannot clean up leftovers of the previous processing of batch %s: %s", pCtx.BatchInfo.FullBatchId(), err.Error()) - logger.ErrorCtx(pCtx, comment) - wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeFail, comment) - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - // Clean up successful, process this node - } else if pCtx.CurrentScriptNode.RerunPolicy == sc.NodeFail { - logger.ErrorCtx(pCtx, "will not rerun %s, rerun policy says we have to fail", pCtx.BatchInfo.FullBatchId()) - return DaemonCmdAckWithError - } else { - logger.ErrorCtx(pCtx, "unexpected rerun policy %s, looks like dev error", pCtx.CurrentScriptNode.RerunPolicy) - return DaemonCmdAckWithError - } - } else if lastBatchStatus != wfmodel.NodeBatchNone { - logger.ErrorCtx(pCtx, "unexpected batch %s status %d, expected None, looks like dev error.", pCtx.BatchInfo.FullBatchId(), lastBatchStatus) - return DaemonCmdAckWithError - } - - // Here, we are assuming this batch processing either never started or was started and then abandoned - - // Check if we have dependency nodes ready - nodeReady, readerNodeRunId, lookupNodeRunId, err := checkDependencyNodesReady(logger, pCtx) - if err != nil { - logger.ErrorCtx(pCtx, "cannot verify dependency nodes status for %s: %s", pCtx.BatchInfo.FullBatchId(), err.Error()) - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if nodeReady == sc.NodeNogo { - comment := fmt.Sprintf("some dependency nodes for %s are in bad state, or runs executing dependency nodes were stopped/invalidated, will not run this node; for details, check rules in dependency_policies and previous runs history", pCtx.BatchInfo.FullBatchId()) - logger.InfoCtx(pCtx, comment) - if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeFail, comment); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - } - return DaemonCmdAckWithError - - } else if nodeReady == sc.NodeWait { - logger.InfoCtx(pCtx, "some dependency nodes for %s are not ready, will wait", pCtx.BatchInfo.FullBatchId()) - return DaemonCmdRejectAndRetryLater - } - - // Here, we are ready to actually process the node - - if _, err := wfdb.SetNodeStatus(logger, pCtx, wfmodel.NodeStart, "started"); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeStart, ""); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - batchStatus, batchStats, batchErr := SafeProcessBatch(envConfig, logger, pCtx, readerNodeRunId, lookupNodeRunId) - - // TODO: test only!!! 
- // if pCtx.BatchInfo.TargetNodeName == "order_item_date_inner" && pCtx.BatchInfo.BatchIdx == 3 { - // rnd := rand.New(rand.NewSource(time.Now().UnixMilli())) - // if rnd.Float32() < .5 { - // logger.InfoCtx(pCtx, "ProcessBatchWithStatus: test error") - // return DaemonCmdRejectAndRetryLater - // } - // } - - if batchErr != nil { - logger.ErrorCtx(pCtx, "ProcessBatchWithStatus: %s", batchErr.Error()) - if db.IsDbConnError(batchErr) { - return DaemonCmdReconnectDb - } - if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchFail, batchErr.Error()); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - } else { - logger.InfoCtx(pCtx, "ProcessBatchWithStatus: success") - if err := wfdb.SetBatchStatus(logger, pCtx, batchStatus, batchStats.ToString()); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - } - - if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { - if db.IsDbConnError(err) { - return DaemonCmdReconnectDb - } - return DaemonCmdAckWithError - } - - return DaemonCmdAckSuccess -} +package wf + +import ( + "fmt" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/dpc" + "github.com/capillariesio/capillaries/pkg/env" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/proc" + "github.com/capillariesio/capillaries/pkg/sc" + "github.com/capillariesio/capillaries/pkg/wfdb" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "go.uber.org/zap" +) + +func checkDependencyNodesReady(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) (sc.ReadyToRunNodeCmdType, int16, int16, error) { + logger.PushF("wf.checkDependencyNodesReady") + defer logger.PopF() + + depNodeNames := make([]string, 2) + depNodeCount := 0 + if pCtx.CurrentScriptNode.HasTableReader() { + tableToReadFrom := pCtx.CurrentScriptNode.TableReader.TableName + nodeToReadFrom, ok := pCtx.Script.TableCreatorNodeMap[tableToReadFrom] + if !ok { + return sc.NodeNone, 0, 0, fmt.Errorf("cannot find the node that creates reader table [%s]", tableToReadFrom) + } + depNodeNames[depNodeCount] = nodeToReadFrom.Name + depNodeCount++ + } + if pCtx.CurrentScriptNode.HasLookup() { + tableToReadFrom := pCtx.CurrentScriptNode.Lookup.TableCreator.Name + nodeToReadFrom, ok := pCtx.Script.TableCreatorNodeMap[tableToReadFrom] + if !ok { + return sc.NodeNone, 0, 0, fmt.Errorf("cannot find the node that creates lookup table [%s]", tableToReadFrom) + } + depNodeNames[depNodeCount] = nodeToReadFrom.Name + depNodeCount++ + } + + if depNodeCount == 0 { + return sc.NodeGo, 0, 0, nil + } + + depNodeNames = depNodeNames[:depNodeCount] + + nodeEventListMap, err := wfdb.BuildDependencyNodeEventLists(logger, pCtx, depNodeNames) + if err != nil { + return sc.NodeNone, 0, 0, err + } + + logger.DebugCtx(pCtx, "nodeEventListMap %v", nodeEventListMap) + + dependencyNodeCmds := make([]sc.ReadyToRunNodeCmdType, len(depNodeNames)) + dependencyRunIds := make([]int16, len(depNodeNames)) + for nodeIdx, depNodeName := range depNodeNames { + if len(nodeEventListMap[depNodeName]) == 0 { + return sc.NodeNogo, 0, 0, fmt.Errorf("target node %s, dep node %s not started yet, whoever started this run, failed to specify %s (or at least one of its dependencies) as start node", pCtx.CurrentScriptNode.Name, depNodeName, depNodeName) + } + var checkerLogMsg string 
+ dependencyNodeCmds[nodeIdx], dependencyRunIds[nodeIdx], checkerLogMsg, err = dpc.CheckDependencyPolicyAgainstNodeEventList(pCtx.CurrentScriptNode.DepPolDef, nodeEventListMap[depNodeName]) + if len(checkerLogMsg) > 0 { + logger.Debug(checkerLogMsg) + } + if err != nil { + return sc.NodeNone, 0, 0, err + } + logger.DebugCtx(pCtx, "target node %s, dep node %s returned %s", pCtx.CurrentScriptNode.Name, depNodeName, dependencyNodeCmds[nodeIdx]) + } + + finalCmd := dependencyNodeCmds[0] + finalRunIdReader := dependencyRunIds[0] + finalRunIdLookup := int16(0) + if len(dependencyNodeCmds) == 2 { + finalRunIdLookup = dependencyRunIds[1] + if dependencyNodeCmds[0] == sc.NodeNogo || dependencyNodeCmds[1] == sc.NodeNogo { + finalCmd = sc.NodeNogo + } else if dependencyNodeCmds[0] == sc.NodeWait || dependencyNodeCmds[1] == sc.NodeWait { + finalCmd = sc.NodeWait + } else { + finalCmd = sc.NodeGo + } + } + + if finalCmd == sc.NodeNogo || finalCmd == sc.NodeGo { + logger.InfoCtx(pCtx, "checked all dependency nodes for %s, commands are %v, run ids are %v, finalCmd is %s", pCtx.CurrentScriptNode.Name, dependencyNodeCmds, dependencyRunIds, finalCmd) + } else { + logger.DebugCtx(pCtx, "checked all dependency nodes for %s, commands are %v, run ids are %v, finalCmd is wait", pCtx.CurrentScriptNode.Name, dependencyNodeCmds, dependencyRunIds) + } + + return finalCmd, finalRunIdReader, finalRunIdLookup, nil +} + +func SafeProcessBatch(envConfig *env.EnvConfig, logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, readerNodeRunId int16, lookupNodeRunId int16) (wfmodel.NodeBatchStatusType, proc.BatchStats, error) { + logger.PushF("wf.SafeProcessBatch") + defer logger.PopF() + + var bs proc.BatchStats + var err error + + switch pCtx.CurrentScriptNode.Type { + case sc.NodeTypeFileTable: + if pCtx.BatchInfo.FirstToken != pCtx.BatchInfo.LastToken || pCtx.BatchInfo.FirstToken < 0 || pCtx.BatchInfo.FirstToken >= int64(len(pCtx.CurrentScriptNode.FileReader.SrcFileUrls)) { + err = fmt.Errorf( + "startToken %d must equal endToken %d and both must be less than the number of files (%d) specified by the file reader", + pCtx.BatchInfo.FirstToken, + pCtx.BatchInfo.LastToken, + len(pCtx.CurrentScriptNode.FileReader.SrcFileUrls)) + } else { + bs, err = proc.RunReadFileForBatch(envConfig, logger, pCtx, int(pCtx.BatchInfo.FirstToken)) + } + + case sc.NodeTypeTableTable: + bs, err = proc.RunCreateTableForBatch(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) + + case sc.NodeTypeTableLookupTable: + bs, err = proc.RunCreateTableRelForBatch(envConfig, logger, pCtx, readerNodeRunId, lookupNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) + + case sc.NodeTypeTableFile: + bs, err = proc.RunCreateFile(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) + + case sc.NodeTypeTableCustomTfmTable: + bs, err = proc.RunCreateTableForCustomProcessorForBatch(envConfig, logger, pCtx, readerNodeRunId, pCtx.BatchInfo.FirstToken, pCtx.BatchInfo.LastToken) + + default: + err = fmt.Errorf("unsupported node %s type %s", pCtx.CurrentScriptNode.Name, pCtx.CurrentScriptNode.Type) + } + + if err != nil { + logger.DebugCtx(pCtx, "batch processed, error: %s", err.Error()) + return wfmodel.NodeBatchFail, bs, fmt.Errorf("error running node %s of type %s in the script [%s]: [%s]", pCtx.CurrentScriptNode.Name, pCtx.CurrentScriptNode.Type, pCtx.BatchInfo.ScriptURI, err.Error()) + } + logger.DebugCtx(pCtx, "batch processed ok") + + return wfmodel.NodeBatchSuccess, bs, nil
+} + +func UpdateNodeStatusFromBatches(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, bool, error) { + logger.PushF("wf.UpdateNodeStatusFromBatches") + defer logger.PopF() + + // Check all batches for this run/node, mark node complete if needed + totalNodeStatus, err := wfdb.HarvestBatchStatusesForNode(logger, pCtx) + if err != nil { + return wfmodel.NodeBatchNone, false, err + } + + if totalNodeStatus == wfmodel.NodeBatchFail || totalNodeStatus == wfmodel.NodeBatchSuccess || totalNodeStatus == wfmodel.NodeBatchRunStopReceived { + // Node processing completed, mark whole node as complete + var comment string + switch totalNodeStatus { + case wfmodel.NodeBatchSuccess: + comment = "completed - all batches ok" + case wfmodel.NodeBatchFail: + comment = "completed with some failed batches - check batch history" + case wfmodel.NodeBatchRunStopReceived: + comment = "run was stopped, check run and batch history" + } + + isApplied, err := wfdb.SetNodeStatus(logger, pCtx, totalNodeStatus, comment) + if err != nil { + return wfmodel.NodeBatchNone, false, err + } + + return totalNodeStatus, isApplied, nil + } + + return totalNodeStatus, false, nil +} + +func UpdateRunStatusFromNodes(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("wf.UpdateRunStatusFromNodes") + defer logger.PopF() + + // Let's see if this run is complete + affectedNodes, err := wfdb.GetRunAffectedNodes(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, pCtx.BatchInfo.RunId) + if err != nil { + return err + } + combinedNodeStatus, nodeStatusString, err := wfdb.HarvestNodeStatusesForRun(logger, pCtx, affectedNodes) + if err != nil { + return err + } + + if combinedNodeStatus == wfmodel.NodeBatchSuccess || combinedNodeStatus == wfmodel.NodeBatchFail { + // Mark run as complete + if err := wfdb.SetRunStatus(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, pCtx.BatchInfo.RunId, wfmodel.RunComplete, nodeStatusString, cql.IgnoreIfExists); err != nil { + return err + } + } + + return nil +} + +func refreshNodeAndRunStatus(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) error { + logger.PushF("wf.refreshNodeAndRunStatus") + defer logger.PopF() + + _, _, err := UpdateNodeStatusFromBatches(logger, pCtx) + if err != nil { + logger.ErrorCtx(pCtx, "cannot refresh run/node status: %s", err.Error()) + return err + } + + // Ideally, we should run the code below only if isApplied signaled something changed. But, there is a possibility + // that on the previous attempt, node status was updated and the daemon crashed right after that. + // We need to pick it up from there and refresh run status anyways.
+ err = UpdateRunStatusFromNodes(logger, pCtx) + if err != nil { + logger.ErrorCtx(pCtx, "cannot refresh run status: %s", err.Error()) + return err + } + + return nil +} + +func ProcessDataBatchMsg(envConfig *env.EnvConfig, logger *l.CapiLogger, msgTs int64, dataBatchInfo *wfmodel.MessagePayloadDataBatch) DaemonCmdType { + logger.PushF("wf.ProcessDataBatchMsg") + defer logger.PopF() + + pCtx := &ctx.MessageProcessingContext{ + MsgTs: msgTs, + BatchInfo: *dataBatchInfo, + ZapDataKeyspace: zap.String("ks", dataBatchInfo.DataKeyspace), + ZapRun: zap.Int16("run", dataBatchInfo.RunId), + ZapNode: zap.String("node", dataBatchInfo.TargetNodeName), + ZapBatchIdx: zap.Int16("bi", dataBatchInfo.BatchIdx), + ZapMsgAgeMillis: zap.Int64("age", time.Now().UnixMilli()-msgTs)} + + var err error + var initProblem sc.ScriptInitProblemType + pCtx.Script, initProblem, err = sc.NewScriptFromFiles(envConfig.CaPath, envConfig.PrivateKeys, dataBatchInfo.ScriptURI, dataBatchInfo.ScriptParamsURI, envConfig.CustomProcessorDefFactoryInstance, envConfig.CustomProcessorsSettings) + if initProblem != sc.ScriptInitNoProblem { + switch initProblem { + case sc.ScriptInitUrlProblem: + logger.Error("cannot init script because of URI problem, will not let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) + return DaemonCmdAckWithError + case sc.ScriptInitContentProblem: + logger.Error("cannot init script because of content problem, will not let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) + return DaemonCmdAckWithError + case sc.ScriptInitConnectivityProblem: + logger.Error("cannot init script because of connectivity problem, will let other workers handle it, giving up with msg %s: %s", dataBatchInfo.ToString(), err.Error()) + return DaemonCmdRejectAndRetryLater + default: + logger.Error("unexpected: cannot init script for unknown reason %d, will let other workers handle it, giving up with msg %s: %s", initProblem, dataBatchInfo.ToString(), err.Error()) + return DaemonCmdRejectAndRetryLater + } + } + + var ok bool + pCtx.CurrentScriptNode, ok = pCtx.Script.ScriptNodes[dataBatchInfo.TargetNodeName] + if !ok { + logger.Error("cannot find node %s in the script [%s], giving up with %s, returning DaemonCmdAckWithError, will not let other workers handle it", pCtx.BatchInfo.TargetNodeName, pCtx.BatchInfo.ScriptURI, dataBatchInfo.ToString()) + return DaemonCmdAckWithError + } + + if err := pCtx.DbConnect(envConfig); err != nil { + logger.Error("cannot connect to db: %s", err.Error()) + return DaemonCmdReconnectDb + } + defer pCtx.DbClose() + + logger.DebugCtx(pCtx, "started processing batch %s", dataBatchInfo.FullBatchId()) + + runStatus, err := wfdb.GetCurrentRunStatus(logger, pCtx) + if err != nil { + logger.ErrorCtx(pCtx, "cannot get current run status for batch %s: %s", dataBatchInfo.FullBatchId(), err.Error()) + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if runStatus == wfmodel.RunNone { + comment := fmt.Sprintf("run history status for batch %s is empty, looks like this run %d was never started", dataBatchInfo.FullBatchId(), pCtx.BatchInfo.RunId) + logger.ErrorCtx(pCtx, comment) + if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchRunStopReceived, comment); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { + if db.IsDbConnError(err) { + return 
DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + return DaemonCmdAckWithError + } + + // If the user signaled stop to this proc, all results of the run are invalidated + if runStatus == wfmodel.RunStop { + comment := fmt.Sprintf("run stopped, batch %s marked %s", dataBatchInfo.FullBatchId(), wfmodel.NodeBatchRunStopReceived.ToString()) + if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchRunStopReceived, comment); err != nil { + logger.ErrorCtx(pCtx, fmt.Sprintf("%s, cannot set batch status: %s", comment, err.Error())) + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { + logger.ErrorCtx(pCtx, fmt.Sprintf("%s, cannot refresh status: %s", comment, err.Error())) + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + logger.DebugCtx(pCtx, fmt.Sprintf("%s, status successfully refreshed", comment)) + return DaemonCmdAckSuccess + } else if runStatus != wfmodel.RunStart { + logger.ErrorCtx(pCtx, "cannot process batch %s, run already has unexpected status %d", dataBatchInfo.FullBatchId(), runStatus) + return DaemonCmdAckWithError + } + + // Check if this run/node/batch has been handled already + lastBatchStatus, err := wfdb.HarvestLastStatusForBatch(logger, pCtx) + if err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if lastBatchStatus == wfmodel.NodeBatchFail || lastBatchStatus == wfmodel.NodeBatchSuccess { + logger.InfoCtx(pCtx, "will not process batch %s, it has already been processed (processor crashed after processing it and before marking as success/fail?) with status %d", dataBatchInfo.FullBatchId(), lastBatchStatus) + if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + return DaemonCmdAckSuccess + } else if lastBatchStatus == wfmodel.NodeBatchStart { + // This run/node/batch has been picked up by another crashed processor (processor crashed before marking success/fail) + if pCtx.CurrentScriptNode.RerunPolicy == sc.NodeRerun { + if deleteErr := proc.DeleteDataAndUniqueIndexesByBatchIdx(logger, pCtx); deleteErr != nil { + comment := fmt.Sprintf("cannot clean up leftovers of the previous processing of batch %s: %s", pCtx.BatchInfo.FullBatchId(), deleteErr.Error()) + logger.ErrorCtx(pCtx, comment) + setBatchStatusErr := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeFail, comment) + if setBatchStatusErr != nil { + comment += fmt.Sprintf("; cannot set batch status: %s", setBatchStatusErr.Error()) + logger.ErrorCtx(pCtx, comment) + } + if db.IsDbConnError(deleteErr) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + // Clean up successful, process this node + } else if pCtx.CurrentScriptNode.RerunPolicy == sc.NodeFail { + logger.ErrorCtx(pCtx, "will not rerun %s, rerun policy says we have to fail", pCtx.BatchInfo.FullBatchId()) + return DaemonCmdAckWithError + } else { + logger.ErrorCtx(pCtx, "unexpected rerun policy %s, looks like dev error", pCtx.CurrentScriptNode.RerunPolicy) + return DaemonCmdAckWithError + } + } else if lastBatchStatus != wfmodel.NodeBatchNone { + logger.ErrorCtx(pCtx, "unexpected batch %s status %d, expected None, looks like dev error.", pCtx.BatchInfo.FullBatchId(), lastBatchStatus) + return DaemonCmdAckWithError + } + + // Here, we are assuming this batch processing either never started or was started and
then abandoned + + // Check if we have dependency nodes ready + nodeReady, readerNodeRunId, lookupNodeRunId, err := checkDependencyNodesReady(logger, pCtx) + if err != nil { + logger.ErrorCtx(pCtx, "cannot verify dependency nodes status for %s: %s", pCtx.BatchInfo.FullBatchId(), err.Error()) + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if nodeReady == sc.NodeNogo { + comment := fmt.Sprintf("some dependency nodes for %s are in bad state, or runs executing dependency nodes were stopped/invalidated, will not run this node; for details, check rules in dependency_policies and previous runs history", pCtx.BatchInfo.FullBatchId()) + logger.InfoCtx(pCtx, comment) + if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeFail, comment); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + } + return DaemonCmdAckWithError + + } else if nodeReady == sc.NodeWait { + logger.InfoCtx(pCtx, "some dependency nodes for %s are not ready, will wait", pCtx.BatchInfo.FullBatchId()) + return DaemonCmdRejectAndRetryLater + } + + // Here, we are ready to actually process the node + + if _, err := wfdb.SetNodeStatus(logger, pCtx, wfmodel.NodeStart, "started"); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeStart, ""); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + batchStatus, batchStats, batchErr := SafeProcessBatch(envConfig, logger, pCtx, readerNodeRunId, lookupNodeRunId) + + // TODO: test only!!! 
+ // if pCtx.BatchInfo.TargetNodeName == "order_item_date_inner" && pCtx.BatchInfo.BatchIdx == 3 { + // rnd := rand.New(rand.NewSource(time.Now().UnixMilli())) + // if rnd.Float32() < .5 { + // logger.InfoCtx(pCtx, "ProcessBatchWithStatus: test error") + // return DaemonCmdRejectAndRetryLater + // } + // } + + if batchErr != nil { + logger.ErrorCtx(pCtx, "ProcessBatchWithStatus: %s", batchErr.Error()) + if db.IsDbConnError(batchErr) { + return DaemonCmdReconnectDb + } + if err := wfdb.SetBatchStatus(logger, pCtx, wfmodel.NodeBatchFail, batchErr.Error()); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + } else { + logger.InfoCtx(pCtx, "ProcessBatchWithStatus: success") + if err := wfdb.SetBatchStatus(logger, pCtx, batchStatus, batchStats.ToString()); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + } + + if err := refreshNodeAndRunStatus(logger, pCtx); err != nil { + if db.IsDbConnError(err) { + return DaemonCmdReconnectDb + } + return DaemonCmdAckWithError + } + + return DaemonCmdAckSuccess +} diff --git a/pkg/wfdb/batch_history.go b/pkg/wfdb/batch_history.go index 67e0aa6..7989943 100644 --- a/pkg/wfdb/batch_history.go +++ b/pkg/wfdb/batch_history.go @@ -1,177 +1,177 @@ -package wfdb - -import ( - "fmt" - "sort" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func HarvestLastStatusForBatch(logger *l.Logger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, error) { - logger.PushF("wfdb.HarvestLastStatusForBatch") - defer logger.PopF() - - fields := []string{"ts", "status"} - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - Cond("run_id", "=", pCtx.BatchInfo.RunId). - Cond("script_node", "=", pCtx.BatchInfo.TargetNodeName). - Cond("batch_idx", "=", pCtx.BatchInfo.BatchIdx). - Select(wfmodel.TableNameBatchHistory, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return wfmodel.NodeBatchNone, db.WrapDbErrorWithQuery(fmt.Sprintf("HarvestLastStatusForBatch: cannot get batch history for batch %s", pCtx.BatchInfo.FullBatchId()), q, err) - } - - lastStatus := wfmodel.NodeBatchNone - lastTs := time.Unix(0, 0) - for _, r := range rows { - rec, err := wfmodel.NewBatchHistoryEventFromMap(r, fields) - if err != nil { - return wfmodel.NodeBatchNone, fmt.Errorf("HarvestLastStatusForBatch: : cannot deserialize batch history row: %s, %s", err.Error(), q) - } - - if rec.Ts.After(lastTs) { - lastTs = rec.Ts - lastStatus = wfmodel.NodeBatchStatusType(rec.Status) - } - } - - logger.DebugCtx(pCtx, "batch %s, status %s", pCtx.BatchInfo.FullBatchId(), lastStatus.ToString()) - return lastStatus, nil -} - -func GetRunNodeBatchHistory(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) ([]*wfmodel.BatchHistoryEvent, error) { - logger.PushF("wfdb.GetRunNodeBatchHistory") - defer logger.PopF() - - q := (&cql.QueryBuilder{}). - Keyspace(keyspace). - Cond("run_id", "=", runId). - Cond("script_node", "=", nodeName). 
- Select(wfmodel.TableNameBatchHistory, wfmodel.BatchHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return []*wfmodel.BatchHistoryEvent{}, db.WrapDbErrorWithQuery("GetRunNodeBatchHistory: cannot get node batch history", q, err) - } - - result := make([]*wfmodel.BatchHistoryEvent, len(rows)) - for rowIdx, row := range rows { - rec, err := wfmodel.NewBatchHistoryEventFromMap(row, wfmodel.BatchHistoryEventAllFields()) - if err != nil { - return []*wfmodel.BatchHistoryEvent{}, fmt.Errorf("cannot deserialize batch node history row %s, %s", err.Error(), q) - } - result[rowIdx] = rec - } - - sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) - - return result, nil -} - -func HarvestBatchStatusesForNode(logger *l.Logger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, error) { - logger.PushF("wfdb.HarvestBatchStatusesForNode") - defer logger.PopF() - - fields := []string{"status", "batch_idx", "batches_total"} - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - Cond("run_id", "=", pCtx.BatchInfo.RunId). - Cond("script_node", "=", pCtx.BatchInfo.TargetNodeName). - Select(wfmodel.TableNameBatchHistory, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return wfmodel.NodeBatchNone, db.WrapDbErrorWithQuery(fmt.Sprintf("harvestBatchStatusesForNode: cannot get node batch history for node %s", pCtx.BatchInfo.FullBatchId()), q, err) - } - - foundBatchesTotal := int16(-1) - batchesInProgress := map[int16]struct{}{} - - failFound := false - stopReceivedFound := false - for _, r := range rows { - rec, err := wfmodel.NewBatchHistoryEventFromMap(r, fields) - if err != nil { - return wfmodel.NodeBatchNone, fmt.Errorf("harvestBatchStatusesForNode: cannot deserialize batch history row %s, %s", err.Error(), q) - } - if foundBatchesTotal == -1 { - foundBatchesTotal = rec.BatchesTotal - for i := int16(0); i < rec.BatchesTotal; i++ { - batchesInProgress[i] = struct{}{} - } - } else if rec.BatchesTotal != foundBatchesTotal { - return wfmodel.NodeBatchNone, fmt.Errorf("conflicting batches total value, was %d, now %d: %s, %s", foundBatchesTotal, rec.BatchesTotal, q, pCtx.BatchInfo.ToString()) - } - - if rec.BatchIdx >= rec.BatchesTotal || rec.BatchesTotal < 0 || rec.BatchesTotal <= 0 { - return wfmodel.NodeBatchNone, fmt.Errorf("invalid batch idx/total(%d/%d) when processing [%v]: %s, %s", rec.BatchIdx, rec.BatchesTotal, r, q, pCtx.BatchInfo.ToString()) - } - - if rec.Status == wfmodel.NodeBatchSuccess || - rec.Status == wfmodel.NodeBatchFail || - rec.Status == wfmodel.NodeBatchRunStopReceived { - delete(batchesInProgress, rec.BatchIdx) - } - - if rec.Status == wfmodel.NodeBatchFail { - failFound = true - } else if rec.Status == wfmodel.NodeBatchRunStopReceived { - stopReceivedFound = true - } - } - - if len(batchesInProgress) == 0 { - nodeStatus := wfmodel.NodeBatchSuccess - if stopReceivedFound { - nodeStatus = wfmodel.NodeBatchRunStopReceived - } - if failFound { - nodeStatus = wfmodel.NodeBatchFail - } - logger.InfoCtx(pCtx, "node %d/%s complete, status %s", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, nodeStatus.ToString()) - return nodeStatus, nil - } - - // Some batches are still not complete, and no run stop/fail/success for the whole node was signaled - logger.DebugCtx(pCtx, "node %d/%s incomplete, still waiting for %d/%d batches", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, len(batchesInProgress), foundBatchesTotal) - return 
wfmodel.NodeBatchStart, nil -} - -func SetBatchStatus(logger *l.Logger, pCtx *ctx.MessageProcessingContext, status wfmodel.NodeBatchStatusType, comment string) error { - logger.PushF("wfdb.SetBatchStatus") - defer logger.PopF() - - qb := cql.QueryBuilder{} - qb. - Keyspace(pCtx.BatchInfo.DataKeyspace). - WriteForceUnquote("ts", "toTimeStamp(now())"). - Write("run_id", pCtx.BatchInfo.RunId). - Write("script_node", pCtx.CurrentScriptNode.Name). - Write("batch_idx", pCtx.BatchInfo.BatchIdx). - Write("batches_total", pCtx.BatchInfo.BatchesTotal). - Write("status", status). - Write("first_token", pCtx.BatchInfo.FirstToken). - Write("last_token", pCtx.BatchInfo.LastToken). - Write("instance", logger.ZapMachine.String). - Write("thread", logger.ZapThread.Integer) - if len(comment) > 0 { - qb.Write("comment", comment) - } - - q := qb.InsertUnpreparedQuery(wfmodel.TableNameBatchHistory, cql.IgnoreIfExists) // If not exists. First one wins. - err := pCtx.CqlSession.Query(q).Exec() - if err != nil { - err := db.WrapDbErrorWithQuery(fmt.Sprintf("cannot write batch %s status %d", pCtx.BatchInfo.FullBatchId(), status), q, err) - logger.ErrorCtx(pCtx, err.Error()) - return err - } - - logger.DebugCtx(pCtx, "batch %s, set status %s", pCtx.BatchInfo.FullBatchId(), status.ToString()) - return nil -} +package wfdb + +import ( + "fmt" + "sort" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func HarvestLastStatusForBatch(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, error) { + logger.PushF("wfdb.HarvestLastStatusForBatch") + defer logger.PopF() + + fields := []string{"ts", "status"} + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + Cond("run_id", "=", pCtx.BatchInfo.RunId). + Cond("script_node", "=", pCtx.BatchInfo.TargetNodeName). + Cond("batch_idx", "=", pCtx.BatchInfo.BatchIdx). + Select(wfmodel.TableNameBatchHistory, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return wfmodel.NodeBatchNone, db.WrapDbErrorWithQuery(fmt.Sprintf("HarvestLastStatusForBatch: cannot get batch history for batch %s", pCtx.BatchInfo.FullBatchId()), q, err) + } + + lastStatus := wfmodel.NodeBatchNone + lastTs := time.Unix(0, 0) + for _, r := range rows { + rec, err := wfmodel.NewBatchHistoryEventFromMap(r, fields) + if err != nil { + return wfmodel.NodeBatchNone, fmt.Errorf("HarvestLastStatusForBatch: : cannot deserialize batch history row: %s, %s", err.Error(), q) + } + + if rec.Ts.After(lastTs) { + lastTs = rec.Ts + lastStatus = wfmodel.NodeBatchStatusType(rec.Status) + } + } + + logger.DebugCtx(pCtx, "batch %s, status %s", pCtx.BatchInfo.FullBatchId(), lastStatus.ToString()) + return lastStatus, nil +} + +func GetRunNodeBatchHistory(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16, nodeName string) ([]*wfmodel.BatchHistoryEvent, error) { + logger.PushF("wfdb.GetRunNodeBatchHistory") + defer logger.PopF() + + q := (&cql.QueryBuilder{}). + Keyspace(keyspace). + Cond("run_id", "=", runId). + Cond("script_node", "=", nodeName). 
+ Select(wfmodel.TableNameBatchHistory, wfmodel.BatchHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return []*wfmodel.BatchHistoryEvent{}, db.WrapDbErrorWithQuery("GetRunNodeBatchHistory: cannot get node batch history", q, err) + } + + result := make([]*wfmodel.BatchHistoryEvent, len(rows)) + for rowIdx, row := range rows { + rec, err := wfmodel.NewBatchHistoryEventFromMap(row, wfmodel.BatchHistoryEventAllFields()) + if err != nil { + return []*wfmodel.BatchHistoryEvent{}, fmt.Errorf("cannot deserialize batch node history row %s, %s", err.Error(), q) + } + result[rowIdx] = rec + } + + sort.Slice(result, func(i, j int) bool { return result[i].Ts.Before(result[j].Ts) }) + + return result, nil +} + +func HarvestBatchStatusesForNode(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) (wfmodel.NodeBatchStatusType, error) { + logger.PushF("wfdb.HarvestBatchStatusesForNode") + defer logger.PopF() + + fields := []string{"status", "batch_idx", "batches_total"} + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + Cond("run_id", "=", pCtx.BatchInfo.RunId). + Cond("script_node", "=", pCtx.BatchInfo.TargetNodeName). + Select(wfmodel.TableNameBatchHistory, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return wfmodel.NodeBatchNone, db.WrapDbErrorWithQuery(fmt.Sprintf("harvestBatchStatusesForNode: cannot get node batch history for node %s", pCtx.BatchInfo.FullBatchId()), q, err) + } + + foundBatchesTotal := int16(-1) + batchesInProgress := map[int16]struct{}{} + + failFound := false + stopReceivedFound := false + for _, r := range rows { + rec, err := wfmodel.NewBatchHistoryEventFromMap(r, fields) + if err != nil { + return wfmodel.NodeBatchNone, fmt.Errorf("harvestBatchStatusesForNode: cannot deserialize batch history row %s, %s", err.Error(), q) + } + if foundBatchesTotal == -1 { + foundBatchesTotal = rec.BatchesTotal + for i := int16(0); i < rec.BatchesTotal; i++ { + batchesInProgress[i] = struct{}{} + } + } else if rec.BatchesTotal != foundBatchesTotal { + return wfmodel.NodeBatchNone, fmt.Errorf("conflicting batches total value, was %d, now %d: %s, %s", foundBatchesTotal, rec.BatchesTotal, q, pCtx.BatchInfo.ToString()) + } + + if rec.BatchIdx >= rec.BatchesTotal || rec.BatchesTotal < 0 || rec.BatchesTotal <= 0 { + return wfmodel.NodeBatchNone, fmt.Errorf("invalid batch idx/total(%d/%d) when processing [%v]: %s, %s", rec.BatchIdx, rec.BatchesTotal, r, q, pCtx.BatchInfo.ToString()) + } + + if rec.Status == wfmodel.NodeBatchSuccess || + rec.Status == wfmodel.NodeBatchFail || + rec.Status == wfmodel.NodeBatchRunStopReceived { + delete(batchesInProgress, rec.BatchIdx) + } + + if rec.Status == wfmodel.NodeBatchFail { + failFound = true + } else if rec.Status == wfmodel.NodeBatchRunStopReceived { + stopReceivedFound = true + } + } + + if len(batchesInProgress) == 0 { + nodeStatus := wfmodel.NodeBatchSuccess + if stopReceivedFound { + nodeStatus = wfmodel.NodeBatchRunStopReceived + } + if failFound { + nodeStatus = wfmodel.NodeBatchFail + } + logger.InfoCtx(pCtx, "node %d/%s complete, status %s", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, nodeStatus.ToString()) + return nodeStatus, nil + } + + // Some batches are still not complete, and no run stop/fail/success for the whole node was signaled + logger.DebugCtx(pCtx, "node %d/%s incomplete, still waiting for %d/%d batches", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, len(batchesInProgress), foundBatchesTotal) + return 
wfmodel.NodeBatchStart, nil +} + +func SetBatchStatus(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, status wfmodel.NodeBatchStatusType, comment string) error { + logger.PushF("wfdb.SetBatchStatus") + defer logger.PopF() + + qb := cql.QueryBuilder{} + qb. + Keyspace(pCtx.BatchInfo.DataKeyspace). + WriteForceUnquote("ts", "toTimeStamp(now())"). + Write("run_id", pCtx.BatchInfo.RunId). + Write("script_node", pCtx.CurrentScriptNode.Name). + Write("batch_idx", pCtx.BatchInfo.BatchIdx). + Write("batches_total", pCtx.BatchInfo.BatchesTotal). + Write("status", status). + Write("first_token", pCtx.BatchInfo.FirstToken). + Write("last_token", pCtx.BatchInfo.LastToken). + Write("instance", logger.ZapMachine.String). + Write("thread", logger.ZapThread.Integer) + if len(comment) > 0 { + qb.Write("comment", comment) + } + + q := qb.InsertUnpreparedQuery(wfmodel.TableNameBatchHistory, cql.IgnoreIfExists) // If not exists. First one wins. + err := pCtx.CqlSession.Query(q).Exec() + if err != nil { + err := db.WrapDbErrorWithQuery(fmt.Sprintf("cannot write batch %s status %d", pCtx.BatchInfo.FullBatchId(), status), q, err) + logger.ErrorCtx(pCtx, err.Error()) + return err + } + + logger.DebugCtx(pCtx, "batch %s, set status %s", pCtx.BatchInfo.FullBatchId(), status.ToString()) + return nil +} diff --git a/pkg/wfdb/dependency_node_event.go b/pkg/wfdb/dependency_node_event.go index c4b617b..d488c30 100644 --- a/pkg/wfdb/dependency_node_event.go +++ b/pkg/wfdb/dependency_node_event.go @@ -1,72 +1,72 @@ -package wfdb - -import ( - "fmt" - "time" - - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" -) - -func BuildDependencyNodeEventLists(logger *l.Logger, pCtx *ctx.MessageProcessingContext, depNodeNames []string) (map[string][]wfmodel.DependencyNodeEvent, error) { - logger.PushF("wfdb.buildDependencyNodeEventLists") - defer logger.PopF() - - affectingRunIds, nodeAffectingRunIdsMap, err := HarvestRunIdsByAffectedNodes(logger, pCtx, depNodeNames) - if err != nil { - return nil, err - } - - runLifespanMap, err := HarvestRunLifespans(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, affectingRunIds) - if err != nil { - return nil, err - } - - runNodeLifespanMap, err := HarvestNodeLifespans(logger, pCtx, affectingRunIds, depNodeNames) - if err != nil { - return nil, err - } - - nodeEventListMap := map[string][]wfmodel.DependencyNodeEvent{} - for _, nodeName := range depNodeNames { - nodeEventListMap[nodeName] = []wfmodel.DependencyNodeEvent{} - // Walk through only runs that affect this specific node. Do not use all affectingRunIds here. 
- for _, affectingRunId := range nodeAffectingRunIdsMap[nodeName] { - runLifespan, ok := runLifespanMap[affectingRunId] - if !ok { - return nil, fmt.Errorf("unexpectedly, cannot find run lifespan map for run %d: %s", affectingRunId, runLifespanMap.ToString()) - } - if runLifespan.StartTs == time.Unix(0, 0) || runLifespan.FinalStatus == wfmodel.RunNone { - return nil, fmt.Errorf("unexpectedly, run lifespan %d looks like the run never started: %s", affectingRunId, runLifespanMap.ToString()) - } - e := wfmodel.DependencyNodeEvent{ - RunId: affectingRunId, - RunIsCurrent: affectingRunId == pCtx.BatchInfo.RunId, - RunStartTs: runLifespan.StartTs, - RunFinalStatus: runLifespan.FinalStatus, - RunCompletedTs: runLifespan.CompletedTs, - RunStoppedTs: runLifespan.StoppedTs, - NodeIsStarted: false, - NodeStartTs: time.Time{}, - NodeStatus: wfmodel.NodeBatchNone, - NodeStatusTs: time.Time{}} - - nodeLifespanMap, ok := runNodeLifespanMap[affectingRunId] - if !ok { - return nil, fmt.Errorf("unexpectedly, cannot find node lifespan map for run %d: %s", affectingRunId, runNodeLifespanMap.ToString()) - } - - if nodeLifespan, ok := nodeLifespanMap[nodeName]; ok { - // This run already started this node, so it has some status. Update last few attributes. - e.NodeIsStarted = true - e.NodeStartTs = nodeLifespan.StartTs - e.NodeStatus = nodeLifespan.LastStatus - e.NodeStatusTs = nodeLifespan.LastStatusTs - } - - nodeEventListMap[nodeName] = append(nodeEventListMap[nodeName], e) - } - } - return nodeEventListMap, nil -} +package wfdb + +import ( + "fmt" + "time" + + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" +) + +func BuildDependencyNodeEventLists(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, depNodeNames []string) (map[string][]wfmodel.DependencyNodeEvent, error) { + logger.PushF("wfdb.buildDependencyNodeEventLists") + defer logger.PopF() + + affectingRunIds, nodeAffectingRunIdsMap, err := HarvestRunIdsByAffectedNodes(logger, pCtx) + if err != nil { + return nil, err + } + + runLifespanMap, err := HarvestRunLifespans(logger, pCtx.CqlSession, pCtx.BatchInfo.DataKeyspace, affectingRunIds) + if err != nil { + return nil, err + } + + runNodeLifespanMap, err := HarvestNodeLifespans(logger, pCtx, affectingRunIds, depNodeNames) + if err != nil { + return nil, err + } + + nodeEventListMap := map[string][]wfmodel.DependencyNodeEvent{} + for _, nodeName := range depNodeNames { + nodeEventListMap[nodeName] = []wfmodel.DependencyNodeEvent{} + // Walk through only runs that affect this specific node. Do not use all affectingRunIds here. 
+ for _, affectingRunId := range nodeAffectingRunIdsMap[nodeName] { + runLifespan, ok := runLifespanMap[affectingRunId] + if !ok { + return nil, fmt.Errorf("unexpectedly, cannot find run lifespan map for run %d: %s", affectingRunId, runLifespanMap.ToString()) + } + if runLifespan.StartTs == time.Unix(0, 0) || runLifespan.FinalStatus == wfmodel.RunNone { + return nil, fmt.Errorf("unexpectedly, run lifespan %d looks like the run never started: %s", affectingRunId, runLifespanMap.ToString()) + } + e := wfmodel.DependencyNodeEvent{ + RunId: affectingRunId, + RunIsCurrent: affectingRunId == pCtx.BatchInfo.RunId, + RunStartTs: runLifespan.StartTs, + RunFinalStatus: runLifespan.FinalStatus, + RunCompletedTs: runLifespan.CompletedTs, + RunStoppedTs: runLifespan.StoppedTs, + NodeIsStarted: false, + NodeStartTs: time.Time{}, + NodeStatus: wfmodel.NodeBatchNone, + NodeStatusTs: time.Time{}} + + nodeLifespanMap, ok := runNodeLifespanMap[affectingRunId] + if !ok { + return nil, fmt.Errorf("unexpectedly, cannot find node lifespan map for run %d: %s", affectingRunId, runNodeLifespanMap.ToString()) + } + + if nodeLifespan, ok := nodeLifespanMap[nodeName]; ok { + // This run already started this node, so it has some status. Update last few attributes. + e.NodeIsStarted = true + e.NodeStartTs = nodeLifespan.StartTs + e.NodeStatus = nodeLifespan.LastStatus + e.NodeStatusTs = nodeLifespan.LastStatusTs + } + + nodeEventListMap[nodeName] = append(nodeEventListMap[nodeName], e) + } + } + return nodeEventListMap, nil +} diff --git a/pkg/wfdb/node_history.go b/pkg/wfdb/node_history.go index 03de895..b52e546 100644 --- a/pkg/wfdb/node_history.go +++ b/pkg/wfdb/node_history.go @@ -1,179 +1,179 @@ -package wfdb - -import ( - "fmt" - "sort" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func HarvestNodeStatusesForRun(logger *l.Logger, pCtx *ctx.MessageProcessingContext, affectedNodes []string) (wfmodel.NodeBatchStatusType, string, error) { - logger.PushF("wfdb.HarvestNodeStatusesForRun") - defer logger.PopF() - - fields := []string{"script_node", "status"} - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - Cond("run_id", "=", pCtx.BatchInfo.RunId). - CondInString("script_node", affectedNodes). // TODO: Is this really necessary? Shouldn't run id be enough? Of course, it's safer to be extra cautious, but...? 
- Select(wfmodel.TableNameNodeHistory, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return wfmodel.NodeBatchNone, "", db.WrapDbErrorWithQuery(fmt.Sprintf("cannot get node history for %s", pCtx.BatchInfo.FullBatchId()), q, err) - } - - nodeStatusMap := wfmodel.NodeStatusMap{} - for _, affectedNodeName := range affectedNodes { - nodeStatusMap[affectedNodeName] = wfmodel.NodeBatchNone - } - - nodeEvents := make([]*wfmodel.NodeHistoryEvent, len(rows)) - - for idx, r := range rows { - rec, err := wfmodel.NewNodeHistoryEventFromMap(r, fields) - if err != nil { - return wfmodel.NodeBatchNone, "", fmt.Errorf("cannot deserialize node history row %s, %s", err.Error(), q) - } - nodeEvents[idx] = rec - } - - sort.Slice(nodeEvents, func(i, j int) bool { return nodeEvents[i].Ts.Before(nodeEvents[j].Ts) }) - - for _, e := range nodeEvents { - lastStatus, ok := nodeStatusMap[e.ScriptNode] - if !ok { - nodeStatusMap[e.ScriptNode] = e.Status - } else { - // Stopreceived is higher priority than anything else - if lastStatus != wfmodel.NodeBatchRunStopReceived { - nodeStatusMap[e.ScriptNode] = e.Status - } - } - } - - highestStatus := wfmodel.NodeBatchNone - lowestStatus := wfmodel.NodeBatchRunStopReceived - for _, status := range nodeStatusMap { - if status > highestStatus { - highestStatus = status - } - if status < lowestStatus { - lowestStatus = status - } - } - - if lowestStatus > wfmodel.NodeBatchStart { - logger.InfoCtx(pCtx, "run %d complete, status map %s", pCtx.BatchInfo.RunId, nodeStatusMap.ToString()) - return highestStatus, nodeStatusMap.ToString(), nil - } else { - logger.DebugCtx(pCtx, "run %d incomplete, lowest status %s, status map %s", pCtx.BatchInfo.RunId, lowestStatus.ToString(), nodeStatusMap.ToString()) - return lowestStatus, nodeStatusMap.ToString(), nil - } -} - -func HarvestNodeLifespans(logger *l.Logger, pCtx *ctx.MessageProcessingContext, affectingRuns []int16, affectedNodes []string) (wfmodel.RunNodeLifespanMap, error) { - logger.PushF("wfdb.HarvestLastNodeStatuses") - defer logger.PopF() - - fields := []string{"ts", "run_id", "script_node", "status"} - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - CondInInt16("run_id", affectingRuns). - CondInString("script_node", affectedNodes). 
- Select(wfmodel.TableNameNodeHistory, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, db.WrapDbErrorWithQuery("cannot get node history", q, err) - } - - runNodeLifespanMap := wfmodel.RunNodeLifespanMap{} - for _, runId := range affectingRuns { - runNodeLifespanMap[runId] = wfmodel.NodeLifespanMap{} - for _, nodeName := range affectedNodes { - runNodeLifespanMap[runId][nodeName] = &wfmodel.NodeLifespan{ - StartTs: time.Time{}, - LastStatus: wfmodel.NodeBatchNone, - LastStatusTs: time.Time{}} - } - } - - for _, r := range rows { - rec, err := wfmodel.NewNodeHistoryEventFromMap(r, fields) - if err != nil { - return nil, fmt.Errorf("%s, %s", err.Error(), q) - } - - nodeLifespanMap, ok := runNodeLifespanMap[rec.RunId] - if !ok { - return nil, fmt.Errorf("unexpected run_id %d in the result %s", rec.RunId, q) - } - - if rec.Status == wfmodel.NodeStart { - nodeLifespanMap[rec.ScriptNode].StartTs = rec.Ts - } - - // Later status wins, Stop always wins - if rec.Ts.After(nodeLifespanMap[rec.ScriptNode].LastStatusTs) || rec.Status == wfmodel.NodeBatchRunStopReceived { - nodeLifespanMap[rec.ScriptNode].LastStatus = rec.Status - nodeLifespanMap[rec.ScriptNode].LastStatusTs = rec.Ts - } - } - return runNodeLifespanMap, nil -} - -func SetNodeStatus(logger *l.Logger, pCtx *ctx.MessageProcessingContext, status wfmodel.NodeBatchStatusType, comment string) (bool, error) { - logger.PushF("wfdb.SetNodeStatus") - defer logger.PopF() - - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - WriteForceUnquote("ts", "toTimeStamp(now())"). - Write("run_id", pCtx.BatchInfo.RunId). - Write("script_node", pCtx.CurrentScriptNode.Name). - Write("status", status). - Write("comment", comment). - InsertUnpreparedQuery(wfmodel.TableNameNodeHistory, cql.IgnoreIfExists) // If not exists. First one wins. - - existingDataRow := map[string]interface{}{} - isApplied, err := pCtx.CqlSession.Query(q).MapScanCAS(existingDataRow) - - if err != nil { - err = db.WrapDbErrorWithQuery(fmt.Sprintf("cannot update node %d/%s status to %d", pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, status), q, err) - logger.ErrorCtx(pCtx, err.Error()) - return false, err - } - logger.DebugCtx(pCtx, "%d/%s, %s, isApplied=%t", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, status.ToString(), isApplied) - return isApplied, nil -} - -func GetNodeHistoryForRun(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.NodeHistoryEvent, error) { - logger.PushF("wfdb.GetNodeHistoryForRun") - defer logger.PopF() - - q := (&cql.QueryBuilder{}). - Keyspace(keyspace). - Cond("run_id", "=", runId). 
- Select(wfmodel.TableNameNodeHistory, wfmodel.NodeHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return []*wfmodel.NodeHistoryEvent{}, db.WrapDbErrorWithQuery(fmt.Sprintf("cannot get node history for run %d", runId), q, err) - } - - result := make([]*wfmodel.NodeHistoryEvent, len(rows)) - - for idx, r := range rows { - rec, err := wfmodel.NewNodeHistoryEventFromMap(r, wfmodel.NodeHistoryEventAllFields()) - if err != nil { - return []*wfmodel.NodeHistoryEvent{}, fmt.Errorf("cannot deserialize node history row %s, %s", err.Error(), q) - } - result[idx] = rec - } - - return result, nil -} +package wfdb + +import ( + "fmt" + "sort" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func HarvestNodeStatusesForRun(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, affectedNodes []string) (wfmodel.NodeBatchStatusType, string, error) { + logger.PushF("wfdb.HarvestNodeStatusesForRun") + defer logger.PopF() + + fields := []string{"script_node", "status"} + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + Cond("run_id", "=", pCtx.BatchInfo.RunId). + CondInString("script_node", affectedNodes). // TODO: Is this really necessary? Shouldn't run id be enough? Of course, it's safer to be extra cautious, but...? + Select(wfmodel.TableNameNodeHistory, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return wfmodel.NodeBatchNone, "", db.WrapDbErrorWithQuery(fmt.Sprintf("cannot get node history for %s", pCtx.BatchInfo.FullBatchId()), q, err) + } + + nodeStatusMap := wfmodel.NodeStatusMap{} + for _, affectedNodeName := range affectedNodes { + nodeStatusMap[affectedNodeName] = wfmodel.NodeBatchNone + } + + nodeEvents := make([]*wfmodel.NodeHistoryEvent, len(rows)) + + for idx, r := range rows { + rec, err := wfmodel.NewNodeHistoryEventFromMap(r, fields) + if err != nil { + return wfmodel.NodeBatchNone, "", fmt.Errorf("cannot deserialize node history row %s, %s", err.Error(), q) + } + nodeEvents[idx] = rec + } + + sort.Slice(nodeEvents, func(i, j int) bool { return nodeEvents[i].Ts.Before(nodeEvents[j].Ts) }) + + for _, e := range nodeEvents { + lastStatus, ok := nodeStatusMap[e.ScriptNode] + if !ok { + nodeStatusMap[e.ScriptNode] = e.Status + } else { + // Stopreceived is higher priority than anything else + if lastStatus != wfmodel.NodeBatchRunStopReceived { + nodeStatusMap[e.ScriptNode] = e.Status + } + } + } + + highestStatus := wfmodel.NodeBatchNone + lowestStatus := wfmodel.NodeBatchRunStopReceived + for _, status := range nodeStatusMap { + if status > highestStatus { + highestStatus = status + } + if status < lowestStatus { + lowestStatus = status + } + } + + if lowestStatus > wfmodel.NodeBatchStart { + logger.InfoCtx(pCtx, "run %d complete, status map %s", pCtx.BatchInfo.RunId, nodeStatusMap.ToString()) + return highestStatus, nodeStatusMap.ToString(), nil + } + + logger.DebugCtx(pCtx, "run %d incomplete, lowest status %s, status map %s", pCtx.BatchInfo.RunId, lowestStatus.ToString(), nodeStatusMap.ToString()) + return lowestStatus, nodeStatusMap.ToString(), nil +} + +func HarvestNodeLifespans(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, affectingRuns []int16, affectedNodes []string) (wfmodel.RunNodeLifespanMap, error) { + 
logger.PushF("wfdb.HarvestLastNodeStatuses") + defer logger.PopF() + + fields := []string{"ts", "run_id", "script_node", "status"} + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + CondInInt16("run_id", affectingRuns). + CondInString("script_node", affectedNodes). + Select(wfmodel.TableNameNodeHistory, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, db.WrapDbErrorWithQuery("cannot get node history", q, err) + } + + runNodeLifespanMap := wfmodel.RunNodeLifespanMap{} + for _, runId := range affectingRuns { + runNodeLifespanMap[runId] = wfmodel.NodeLifespanMap{} + for _, nodeName := range affectedNodes { + runNodeLifespanMap[runId][nodeName] = &wfmodel.NodeLifespan{ + StartTs: time.Time{}, + LastStatus: wfmodel.NodeBatchNone, + LastStatusTs: time.Time{}} + } + } + + for _, r := range rows { + rec, err := wfmodel.NewNodeHistoryEventFromMap(r, fields) + if err != nil { + return nil, fmt.Errorf("%s, %s", err.Error(), q) + } + + nodeLifespanMap, ok := runNodeLifespanMap[rec.RunId] + if !ok { + return nil, fmt.Errorf("unexpected run_id %d in the result %s", rec.RunId, q) + } + + if rec.Status == wfmodel.NodeStart { + nodeLifespanMap[rec.ScriptNode].StartTs = rec.Ts + } + + // Later status wins, Stop always wins + if rec.Ts.After(nodeLifespanMap[rec.ScriptNode].LastStatusTs) || rec.Status == wfmodel.NodeBatchRunStopReceived { + nodeLifespanMap[rec.ScriptNode].LastStatus = rec.Status + nodeLifespanMap[rec.ScriptNode].LastStatusTs = rec.Ts + } + } + return runNodeLifespanMap, nil +} + +func SetNodeStatus(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext, status wfmodel.NodeBatchStatusType, comment string) (bool, error) { + logger.PushF("wfdb.SetNodeStatus") + defer logger.PopF() + + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + WriteForceUnquote("ts", "toTimeStamp(now())"). + Write("run_id", pCtx.BatchInfo.RunId). + Write("script_node", pCtx.CurrentScriptNode.Name). + Write("status", status). + Write("comment", comment). + InsertUnpreparedQuery(wfmodel.TableNameNodeHistory, cql.IgnoreIfExists) // If not exists. First one wins. + + existingDataRow := map[string]any{} + isApplied, err := pCtx.CqlSession.Query(q).MapScanCAS(existingDataRow) + + if err != nil { + err = db.WrapDbErrorWithQuery(fmt.Sprintf("cannot update node %d/%s status to %d", pCtx.BatchInfo.RunId, pCtx.BatchInfo.TargetNodeName, status), q, err) + logger.ErrorCtx(pCtx, err.Error()) + return false, err + } + logger.DebugCtx(pCtx, "%d/%s, %s, isApplied=%t", pCtx.BatchInfo.RunId, pCtx.CurrentScriptNode.Name, status.ToString(), isApplied) + return isApplied, nil +} + +func GetNodeHistoryForRun(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.NodeHistoryEvent, error) { + logger.PushF("wfdb.GetNodeHistoryForRun") + defer logger.PopF() + + q := (&cql.QueryBuilder{}). + Keyspace(keyspace). + Cond("run_id", "=", runId). 
+ Select(wfmodel.TableNameNodeHistory, wfmodel.NodeHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return []*wfmodel.NodeHistoryEvent{}, db.WrapDbErrorWithQuery(fmt.Sprintf("cannot get node history for run %d", runId), q, err) + } + + result := make([]*wfmodel.NodeHistoryEvent, len(rows)) + + for idx, r := range rows { + rec, err := wfmodel.NewNodeHistoryEventFromMap(r, wfmodel.NodeHistoryEventAllFields()) + if err != nil { + return []*wfmodel.NodeHistoryEvent{}, fmt.Errorf("cannot deserialize node history row %s, %s", err.Error(), q) + } + result[idx] = rec + } + + return result, nil +} diff --git a/pkg/wfdb/run_counter.go b/pkg/wfdb/run_counter.go index 4888790..f4e1faf 100644 --- a/pkg/wfdb/run_counter.go +++ b/pkg/wfdb/run_counter.go @@ -1,59 +1,59 @@ -package wfdb - -import ( - "fmt" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func GetNextRunCounter(logger *l.Logger, cqlSession *gocql.Session, keyspace string) (int16, error) { - logger.PushF("wfdb.GetNextRunCounter") - defer logger.PopF() - - maxRetries := 100 - for retryCount := 0; retryCount < maxRetries; retryCount++ { - - // Initialize optimistic locking - q := (&cql.QueryBuilder{}). - Keyspace(keyspace). - Select(wfmodel.TableNameRunCounter, []string{"last_run"}) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return 0, db.WrapDbErrorWithQuery("cannot get run counter", q, err) - } - - if len(rows) != 1 { - return 0, fmt.Errorf("cannot get run counter, wrong number of rows: %s, %s", q, err.Error()) - } - - lastRunId, ok := rows[0]["last_run"].(int) - if !ok { - return 0, fmt.Errorf("cannot get run counter from [%v]: %s, %s", rows[0], q, err.Error()) - } - - // Try incrementing - newRunId := lastRunId + 1 - q = (&cql.QueryBuilder{}). - Keyspace(keyspace). - Write("last_run", newRunId). - Cond("ks", "=", keyspace). - If("last_run", "=", lastRunId). - Update(wfmodel.TableNameRunCounter) - existingDataRow := map[string]interface{}{} - isApplied, err := cqlSession.Query(q).MapScanCAS(existingDataRow) - - if err != nil { - return 0, db.WrapDbErrorWithQuery("cannot increment run counter", q, err) - } else if isApplied { - return int16(newRunId), nil - } - - // Retry - logger.Info("GetNextRunCounter: retry %d", retryCount) - } - return 0, fmt.Errorf("cannot increment run counter, too many attempts") -} +package wfdb + +import ( + "fmt" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func GetNextRunCounter(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string) (int16, error) { + logger.PushF("wfdb.GetNextRunCounter") + defer logger.PopF() + + maxRetries := 100 + for retryCount := 0; retryCount < maxRetries; retryCount++ { + + // Initialize optimistic locking + q := (&cql.QueryBuilder{}). + Keyspace(keyspace). 
+ Select(wfmodel.TableNameRunCounter, []string{"last_run"}) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return 0, db.WrapDbErrorWithQuery("cannot get run counter", q, err) + } + + if len(rows) != 1 { + return 0, fmt.Errorf("cannot get run counter, wrong number of rows (%d): %s", len(rows), q) + } + + lastRunId, ok := rows[0]["last_run"].(int) + if !ok { + return 0, fmt.Errorf("cannot get run counter from [%v]: %s", rows[0], q) + } + + // Try incrementing + newRunId := lastRunId + 1 + q = (&cql.QueryBuilder{}). + Keyspace(keyspace). + Write("last_run", newRunId). + Cond("ks", "=", keyspace). + If("last_run", "=", lastRunId). + Update(wfmodel.TableNameRunCounter) + existingDataRow := map[string]any{} + isApplied, err := cqlSession.Query(q).MapScanCAS(existingDataRow) + + if err != nil { + return 0, db.WrapDbErrorWithQuery("cannot increment run counter", q, err) + } else if isApplied { + return int16(newRunId), nil + } + + // Retry + logger.Info("GetNextRunCounter: retry %d", retryCount) + } + return 0, fmt.Errorf("cannot increment run counter, too many attempts") +} diff --git a/pkg/wfdb/run_history.go b/pkg/wfdb/run_history.go index e923ab8..b0a2b85 100644 --- a/pkg/wfdb/run_history.go +++ b/pkg/wfdb/run_history.go @@ -1,120 +1,120 @@ -package wfdb - -import ( - "fmt" - "sort" - "time" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func GetCurrentRunStatus(logger *l.Logger, pCtx *ctx.MessageProcessingContext) (wfmodel.RunStatusType, error) { - logger.PushF("wfdb.GetCurrentRunStatus") - defer logger.PopF() - - fields := []string{"ts", "status"} - qb := cql.QueryBuilder{} - q := qb. - Keyspace(pCtx.BatchInfo.DataKeyspace). - Cond("run_id", "=", pCtx.BatchInfo.RunId).
- Select(wfmodel.TableNameRunHistory, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return wfmodel.RunNone, db.WrapDbErrorWithQuery(fmt.Sprintf("cannot query run status for %s", pCtx.BatchInfo.FullBatchId()), q, err) - } - - lastStatus := wfmodel.RunNone - lastTs := time.Unix(0, 0) - for _, r := range rows { - rec, err := wfmodel.NewRunHistoryEventFromMap(r, fields) - if err != nil { - return wfmodel.RunNone, fmt.Errorf("%s, %s", err.Error(), q) - } - - if rec.Ts.After(lastTs) { - lastTs = rec.Ts - lastStatus = wfmodel.RunStatusType(rec.Status) - } - } - - logger.DebugCtx(pCtx, "batch %s, run status %s", pCtx.BatchInfo.FullBatchId(), lastStatus.ToString()) - return lastStatus, nil -} - -func HarvestRunLifespans(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runIds []int16) (wfmodel.RunLifespanMap, error) { - logger.PushF("wfdb.HarvestRunLifespans") - defer logger.PopF() - - qb := (&cql.QueryBuilder{}).Keyspace(keyspace) - if len(runIds) > 0 { - qb.CondInInt16("run_id", runIds) - } - q := qb.Select(wfmodel.TableNameRunHistory, wfmodel.RunHistoryEventAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, db.WrapDbErrorWithQuery("cannot get run statuses for a list of run ids", q, err) - } - - events := make([]*wfmodel.RunHistoryEvent, len(rows)) - - for idx, r := range rows { - rec, err := wfmodel.NewRunHistoryEventFromMap(r, wfmodel.RunHistoryEventAllFields()) - if err != nil { - return nil, fmt.Errorf("%s, %s", err.Error(), q) - } - events[idx] = rec - } - - sort.Slice(events, func(i, j int) bool { return events[i].Ts.Before(events[j].Ts) }) - - runLifespanMap := wfmodel.RunLifespanMap{} - emptyUnix := time.Time{}.Unix() - for _, e := range events { - if e.Status == wfmodel.RunStart { - runLifespanMap[e.RunId] = &wfmodel.RunLifespan{RunId: e.RunId, StartTs: e.Ts, StartComment: e.Comment, FinalStatus: wfmodel.RunStart, CompletedTs: time.Time{}, StoppedTs: time.Time{}} - } else { - _, ok := runLifespanMap[e.RunId] - if !ok { - return nil, fmt.Errorf("unexpected sequence of run status events: %v, %s", events, q) - } - if e.Status == wfmodel.RunComplete && runLifespanMap[e.RunId].CompletedTs.Unix() == emptyUnix { - runLifespanMap[e.RunId].CompletedTs = e.Ts - runLifespanMap[e.RunId].CompletedComment = e.Comment - if runLifespanMap[e.RunId].StoppedTs.Unix() == emptyUnix { - runLifespanMap[e.RunId].FinalStatus = wfmodel.RunComplete // If it was not stopped so far, consider it complete - } - } else if e.Status == wfmodel.RunStop && runLifespanMap[e.RunId].StoppedTs.Unix() == emptyUnix { - runLifespanMap[e.RunId].StoppedTs = e.Ts - runLifespanMap[e.RunId].StoppedComment = e.Comment - runLifespanMap[e.RunId].FinalStatus = wfmodel.RunStop // Stop always wins as final status, it may be sign for dependency checker to declare results invalid (depending on the rules) - } - } - } - - return runLifespanMap, nil -} - -func SetRunStatus(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16, status wfmodel.RunStatusType, comment string, ifNotExistsFlag cql.IfNotExistsType) error { - logger.PushF("wfdb.SetRunStatus") - defer logger.PopF() - - q := (&cql.QueryBuilder{}). - Keyspace(keyspace). - WriteForceUnquote("ts", "toTimeStamp(now())"). - Write("run_id", runId). - Write("status", status). - Write("comment", comment). 
- InsertUnpreparedQuery(wfmodel.TableNameRunHistory, ifNotExistsFlag) - err := cqlSession.Query(q).Exec() - if err != nil { - return db.WrapDbErrorWithQuery("cannot write run status", q, err) - } - - logger.Debug("run %d, status %s", runId, status.ToString()) - return nil -} +package wfdb + +import ( + "fmt" + "sort" + "time" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func GetCurrentRunStatus(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) (wfmodel.RunStatusType, error) { + logger.PushF("wfdb.GetCurrentRunStatus") + defer logger.PopF() + + fields := []string{"ts", "status"} + qb := cql.QueryBuilder{} + q := qb. + Keyspace(pCtx.BatchInfo.DataKeyspace). + Cond("run_id", "=", pCtx.BatchInfo.RunId). + Select(wfmodel.TableNameRunHistory, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return wfmodel.RunNone, db.WrapDbErrorWithQuery(fmt.Sprintf("cannot query run status for %s", pCtx.BatchInfo.FullBatchId()), q, err) + } + + lastStatus := wfmodel.RunNone + lastTs := time.Unix(0, 0) + for _, r := range rows { + rec, err := wfmodel.NewRunHistoryEventFromMap(r, fields) + if err != nil { + return wfmodel.RunNone, fmt.Errorf("%s, %s", err.Error(), q) + } + + if rec.Ts.After(lastTs) { + lastTs = rec.Ts + lastStatus = wfmodel.RunStatusType(rec.Status) + } + } + + logger.DebugCtx(pCtx, "batch %s, run status %s", pCtx.BatchInfo.FullBatchId(), lastStatus.ToString()) + return lastStatus, nil +} + +func HarvestRunLifespans(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runIds []int16) (wfmodel.RunLifespanMap, error) { + logger.PushF("wfdb.HarvestRunLifespans") + defer logger.PopF() + + qb := (&cql.QueryBuilder{}).Keyspace(keyspace) + if len(runIds) > 0 { + qb.CondInInt16("run_id", runIds) + } + q := qb.Select(wfmodel.TableNameRunHistory, wfmodel.RunHistoryEventAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, db.WrapDbErrorWithQuery("cannot get run statuses for a list of run ids", q, err) + } + + events := make([]*wfmodel.RunHistoryEvent, len(rows)) + + for idx, r := range rows { + rec, err := wfmodel.NewRunHistoryEventFromMap(r, wfmodel.RunHistoryEventAllFields()) + if err != nil { + return nil, fmt.Errorf("%s, %s", err.Error(), q) + } + events[idx] = rec + } + + sort.Slice(events, func(i, j int) bool { return events[i].Ts.Before(events[j].Ts) }) + + runLifespanMap := wfmodel.RunLifespanMap{} + emptyUnix := time.Time{}.Unix() + for _, e := range events { + if e.Status == wfmodel.RunStart { + runLifespanMap[e.RunId] = &wfmodel.RunLifespan{RunId: e.RunId, StartTs: e.Ts, StartComment: e.Comment, FinalStatus: wfmodel.RunStart, CompletedTs: time.Time{}, StoppedTs: time.Time{}} + } else { + _, ok := runLifespanMap[e.RunId] + if !ok { + return nil, fmt.Errorf("unexpected sequence of run status events: %v, %s", events, q) + } + if e.Status == wfmodel.RunComplete && runLifespanMap[e.RunId].CompletedTs.Unix() == emptyUnix { + runLifespanMap[e.RunId].CompletedTs = e.Ts + runLifespanMap[e.RunId].CompletedComment = e.Comment + if runLifespanMap[e.RunId].StoppedTs.Unix() == emptyUnix { + runLifespanMap[e.RunId].FinalStatus = wfmodel.RunComplete // If it was not stopped so far, consider it complete + } + } else if e.Status == wfmodel.RunStop && 
runLifespanMap[e.RunId].StoppedTs.Unix() == emptyUnix { + runLifespanMap[e.RunId].StoppedTs = e.Ts + runLifespanMap[e.RunId].StoppedComment = e.Comment + runLifespanMap[e.RunId].FinalStatus = wfmodel.RunStop // Stop always wins as final status, it may be sign for dependency checker to declare results invalid (depending on the rules) + } + } + } + + return runLifespanMap, nil +} + +func SetRunStatus(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16, status wfmodel.RunStatusType, comment string, ifNotExistsFlag cql.IfNotExistsType) error { + logger.PushF("wfdb.SetRunStatus") + defer logger.PopF() + + q := (&cql.QueryBuilder{}). + Keyspace(keyspace). + WriteForceUnquote("ts", "toTimeStamp(now())"). + Write("run_id", runId). + Write("status", status). + Write("comment", comment). + InsertUnpreparedQuery(wfmodel.TableNameRunHistory, ifNotExistsFlag) + err := cqlSession.Query(q).Exec() + if err != nil { + return db.WrapDbErrorWithQuery("cannot write run status", q, err) + } + + logger.Debug("run %d, status %s", runId, status.ToString()) + return nil +} diff --git a/pkg/wfdb/run_properties.go b/pkg/wfdb/run_properties.go index 835f34a..00c036f 100644 --- a/pkg/wfdb/run_properties.go +++ b/pkg/wfdb/run_properties.go @@ -1,119 +1,119 @@ -package wfdb - -import ( - "fmt" - "sort" - "strings" - - "github.com/capillariesio/capillaries/pkg/cql" - "github.com/capillariesio/capillaries/pkg/ctx" - "github.com/capillariesio/capillaries/pkg/db" - "github.com/capillariesio/capillaries/pkg/l" - "github.com/capillariesio/capillaries/pkg/wfmodel" - "github.com/gocql/gocql" -) - -func GetRunAffectedNodes(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) ([]string, error) { - logger.PushF("wfdb.GetRunAffectedNodes") - defer logger.PopF() - - runPropsList, err := GetRunProperties(logger, cqlSession, keyspace, runId) - if err != nil { - return []string{}, err - } - if len(runPropsList) != 1 { - return []string{}, fmt.Errorf("run affected nodes for ks %s, run id %d returned wrong number of rows (%d), expected 1", keyspace, runId, len(runPropsList)) - } - return strings.Split(runPropsList[0].AffectedNodes, ","), nil -} - -// func GetAllRunsProperties(cqlSession *gocql.Session, keyspace string) ([]*wfmodel.RunAffectedNodes, error) { -// return getRunProperties(cqlSession, keyspace, 0) -// } - -func GetRunProperties(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.RunProperties, error) { - logger.PushF("wfdb.GetRunProperties") - defer logger.PopF() - - qb := cql.QueryBuilder{} - qb.Keyspace(keyspace) - if runId > 0 { - qb.Cond("run_id", "=", runId) - } - q := qb.Select(wfmodel.TableNameRunAffectedNodes, wfmodel.RunPropertiesAllFields()) - rows, err := cqlSession.Query(q).Iter().SliceMap() - if err != nil { - return []*wfmodel.RunProperties{}, db.WrapDbErrorWithQuery("cannot get all runs properties", q, err) - } - - runs := make([]*wfmodel.RunProperties, len(rows)) - for rowIdx, row := range rows { - rec, err := wfmodel.NewRunPropertiesFromMap(row, wfmodel.RunPropertiesAllFields()) - if err != nil { - return []*wfmodel.RunProperties{}, fmt.Errorf("%s, %s", err.Error(), q) - } - runs[rowIdx] = rec - } - - sort.Slice(runs, func(i, j int) bool { return runs[i].RunId < runs[j].RunId }) - - return runs, nil -} - -func HarvestRunIdsByAffectedNodes(logger *l.Logger, pCtx *ctx.MessageProcessingContext, nodeNames []string) ([]int16, map[string][]int16, error) { - logger.PushF("wfdb.HarvestRunIdsByAffectedNodes") - defer logger.PopF() - - 
fields := []string{"run_id", "affected_nodes"} - q := (&cql.QueryBuilder{}). - Keyspace(pCtx.BatchInfo.DataKeyspace). - Select(wfmodel.TableNameRunAffectedNodes, fields) - rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() - if err != nil { - return nil, nil, db.WrapDbErrorWithQuery("cannot get runs for affected nodes", q, err) - } - - runIds := make([]int16, len(rows)) - nodeAffectingRunIdsMap := map[string][]int16{} - for runIdx, r := range rows { - - rec, err := wfmodel.NewRunPropertiesFromMap(r, fields) - if err != nil { - return nil, nil, fmt.Errorf("%s, %s", err.Error(), q) - } - - runIds[runIdx] = rec.RunId - - affectedNodes := strings.Split(rec.AffectedNodes, ",") - for _, affectedNodeName := range affectedNodes { - _, ok := nodeAffectingRunIdsMap[affectedNodeName] - if !ok { - nodeAffectingRunIdsMap[affectedNodeName] = make([]int16, 1) - nodeAffectingRunIdsMap[affectedNodeName][0] = rec.RunId - } else { - nodeAffectingRunIdsMap[affectedNodeName] = append(nodeAffectingRunIdsMap[affectedNodeName], rec.RunId) - - } - } - } - - return runIds, nodeAffectingRunIdsMap, nil -} - -func WriteRunProperties(logger *l.Logger, cqlSession *gocql.Session, keyspace string, runId int16, startNodes []string, affectedNodes []string, scriptUri string, scriptParamsUri string, runDescription string) error { - q := (&cql.QueryBuilder{}). - Keyspace(keyspace). - Write("run_id", runId). - Write("start_nodes", strings.Join(startNodes, ",")). - Write("affected_nodes", strings.Join(affectedNodes, ",")). - Write("script_uri", scriptUri). - Write("script_params_uri", scriptParamsUri). - Write("run_description", runDescription). - InsertUnpreparedQuery(wfmodel.TableNameRunAffectedNodes, cql.IgnoreIfExists) // If not exists. First one wins. - err := cqlSession.Query(q).Exec() - if err != nil { - return db.WrapDbErrorWithQuery("cannot write affected nodes", q, err) - } - - return nil -} +package wfdb + +import ( + "fmt" + "sort" + "strings" + + "github.com/capillariesio/capillaries/pkg/cql" + "github.com/capillariesio/capillaries/pkg/ctx" + "github.com/capillariesio/capillaries/pkg/db" + "github.com/capillariesio/capillaries/pkg/l" + "github.com/capillariesio/capillaries/pkg/wfmodel" + "github.com/gocql/gocql" +) + +func GetRunAffectedNodes(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) ([]string, error) { + logger.PushF("wfdb.GetRunAffectedNodes") + defer logger.PopF() + + runPropsList, err := GetRunProperties(logger, cqlSession, keyspace, runId) + if err != nil { + return []string{}, err + } + if len(runPropsList) != 1 { + return []string{}, fmt.Errorf("run affected nodes for ks %s, run id %d returned wrong number of rows (%d), expected 1", keyspace, runId, len(runPropsList)) + } + return strings.Split(runPropsList[0].AffectedNodes, ","), nil +} + +// func GetAllRunsProperties(cqlSession *gocql.Session, keyspace string) ([]*wfmodel.RunAffectedNodes, error) { +// return getRunProperties(cqlSession, keyspace, 0) +// } + +func GetRunProperties(logger *l.CapiLogger, cqlSession *gocql.Session, keyspace string, runId int16) ([]*wfmodel.RunProperties, error) { + logger.PushF("wfdb.GetRunProperties") + defer logger.PopF() + + qb := cql.QueryBuilder{} + qb.Keyspace(keyspace) + if runId > 0 { + qb.Cond("run_id", "=", runId) + } + q := qb.Select(wfmodel.TableNameRunAffectedNodes, wfmodel.RunPropertiesAllFields()) + rows, err := cqlSession.Query(q).Iter().SliceMap() + if err != nil { + return []*wfmodel.RunProperties{}, db.WrapDbErrorWithQuery("cannot get all runs properties", q, err) + 
} + + runs := make([]*wfmodel.RunProperties, len(rows)) + for rowIdx, row := range rows { + rec, err := wfmodel.NewRunPropertiesFromMap(row, wfmodel.RunPropertiesAllFields()) + if err != nil { + return []*wfmodel.RunProperties{}, fmt.Errorf("%s, %s", err.Error(), q) + } + runs[rowIdx] = rec + } + + sort.Slice(runs, func(i, j int) bool { return runs[i].RunId < runs[j].RunId }) + + return runs, nil +} + +func HarvestRunIdsByAffectedNodes(logger *l.CapiLogger, pCtx *ctx.MessageProcessingContext) ([]int16, map[string][]int16, error) { + logger.PushF("wfdb.HarvestRunIdsByAffectedNodes") + defer logger.PopF() + + fields := []string{"run_id", "affected_nodes"} + q := (&cql.QueryBuilder{}). + Keyspace(pCtx.BatchInfo.DataKeyspace). + Select(wfmodel.TableNameRunAffectedNodes, fields) + rows, err := pCtx.CqlSession.Query(q).Iter().SliceMap() + if err != nil { + return nil, nil, db.WrapDbErrorWithQuery("cannot get runs for affected nodes", q, err) + } + + runIds := make([]int16, len(rows)) + nodeAffectingRunIdsMap := map[string][]int16{} + for runIdx, r := range rows { + + rec, err := wfmodel.NewRunPropertiesFromMap(r, fields) + if err != nil { + return nil, nil, fmt.Errorf("%s, %s", err.Error(), q) + } + + runIds[runIdx] = rec.RunId + + affectedNodes := strings.Split(rec.AffectedNodes, ",") + for _, affectedNodeName := range affectedNodes { + _, ok := nodeAffectingRunIdsMap[affectedNodeName] + if !ok { + nodeAffectingRunIdsMap[affectedNodeName] = make([]int16, 1) + nodeAffectingRunIdsMap[affectedNodeName][0] = rec.RunId + } else { + nodeAffectingRunIdsMap[affectedNodeName] = append(nodeAffectingRunIdsMap[affectedNodeName], rec.RunId) + + } + } + } + + return runIds, nodeAffectingRunIdsMap, nil +} + +func WriteRunProperties(cqlSession *gocql.Session, keyspace string, runId int16, startNodes []string, affectedNodes []string, scriptUri string, scriptParamsUri string, runDescription string) error { + q := (&cql.QueryBuilder{}). + Keyspace(keyspace). + Write("run_id", runId). + Write("start_nodes", strings.Join(startNodes, ",")). + Write("affected_nodes", strings.Join(affectedNodes, ",")). + Write("script_uri", scriptUri). + Write("script_params_uri", scriptParamsUri). + Write("run_description", runDescription). + InsertUnpreparedQuery(wfmodel.TableNameRunAffectedNodes, cql.IgnoreIfExists) // If not exists. First one wins. 
+ err := cqlSession.Query(q).Exec() + if err != nil { + return db.WrapDbErrorWithQuery("cannot write affected nodes", q, err) + } + + return nil +} diff --git a/pkg/wfmodel/batch_history.go b/pkg/wfmodel/batch_history.go index 3c9314d..11ce241 100644 --- a/pkg/wfmodel/batch_history.go +++ b/pkg/wfmodel/batch_history.go @@ -1,88 +1,88 @@ -package wfmodel - -import ( - "fmt" - "time" -) - -type NodeBatchStatusType int8 - -// In priority order -const ( - NodeBatchNone NodeBatchStatusType = 0 - NodeBatchStart NodeBatchStatusType = 1 - NodeBatchSuccess NodeBatchStatusType = 2 - NodeBatchFail NodeBatchStatusType = 3 // Biz logicerror or data table (not WF) error - NodeBatchRunStopReceived NodeBatchStatusType = 104 -) - -const TableNameBatchHistory = "wf_batch_history" - -// Object model with tags that allow to create cql CREATE TABLE queries and to print object -type BatchHistoryEvent struct { - Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` - RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` - ScriptNode string `header:"script_node" format:"%20v" column:"script_node" type:"text" key:"true" json:"script_node"` - BatchIdx int16 `header:"bnum" format:"%5v" column:"batch_idx" type:"int" key:"true" json:"batch_idx"` - BatchesTotal int16 `header:"tbtchs" format:"%6v" column:"batches_total" type:"int" json:"batches_total"` - Status NodeBatchStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` - FirstToken int64 `header:"ftoken" format:"%21v" column:"first_token" type:"bigint" json:"first_token"` - LastToken int64 `header:"ltoken" format:"%21v" column:"last_token" type:"bigint" json:"last_token"` - Instance string `header:"instance" format:"%21v" column:"instance" type:"text" json:"instance"` - Thread int64 `header:"thread" format:"%4v" column:"thread" type:"bigint" json:"thread"` - Comment string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` -} - -func BatchHistoryEventAllFields() []string { - return []string{"ts", "run_id", "script_node", "batch_idx", "batches_total", "status", "first_token", "last_token", "instance", "thread", "comment"} -} -func NewBatchHistoryEventFromMap(r map[string]interface{}, fields []string) (*BatchHistoryEvent, error) { - res := &BatchHistoryEvent{} - for _, fieldName := range fields { - var err error - switch fieldName { - case "ts": - res.Ts, err = ReadTimeFromRow(fieldName, r) - case "run_id": - res.RunId, err = ReadInt16FromRow(fieldName, r) - case "script_node": - res.ScriptNode, err = ReadStringFromRow(fieldName, r) - case "batch_idx": - res.BatchIdx, err = ReadInt16FromRow(fieldName, r) - case "batches_total": - res.BatchesTotal, err = ReadInt16FromRow(fieldName, r) - case "status": - res.Status, err = ReadNodeBatchStatusFromRow(fieldName, r) - case "first_token": - res.FirstToken, err = ReadInt64FromRow(fieldName, r) - case "last_token": - res.LastToken, err = ReadInt64FromRow(fieldName, r) - case "instance": - res.Instance, err = ReadStringFromRow(fieldName, r) - case "thread": - res.Thread, err = ReadInt64FromRow(fieldName, r) - case "comment": - res.Comment, err = ReadStringFromRow(fieldName, r) - default: - return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameNodeHistory) - } - if err != nil { - return nil, err - } - } - return res, nil -} - -// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod -// func (n BatchHistoryEvent) ToSpacedString() string { -// t := 
reflect.TypeOf(n) -// formats := GetObjectModelFieldFormats(t) -// values := make([]string, t.NumField()) - -// v := reflect.ValueOf(&n).Elem() -// for i := 0; i < v.NumField(); i++ { -// fv := v.Field(i) -// values[i] = fmt.Sprintf(formats[i], fv) -// } -// return strings.Join(values, PrintTableDelimiter) -// } +package wfmodel + +import ( + "fmt" + "time" +) + +type NodeBatchStatusType int8 + +// In priority order +const ( + NodeBatchNone NodeBatchStatusType = 0 + NodeBatchStart NodeBatchStatusType = 1 + NodeBatchSuccess NodeBatchStatusType = 2 + NodeBatchFail NodeBatchStatusType = 3 // Biz logicerror or data table (not WF) error + NodeBatchRunStopReceived NodeBatchStatusType = 104 +) + +const TableNameBatchHistory = "wf_batch_history" + +// Object model with tags that allow to create cql CREATE TABLE queries and to print object +type BatchHistoryEvent struct { + Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` + RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` + ScriptNode string `header:"script_node" format:"%20v" column:"script_node" type:"text" key:"true" json:"script_node"` + BatchIdx int16 `header:"bnum" format:"%5v" column:"batch_idx" type:"int" key:"true" json:"batch_idx"` + BatchesTotal int16 `header:"tbtchs" format:"%6v" column:"batches_total" type:"int" json:"batches_total"` + Status NodeBatchStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` + FirstToken int64 `header:"ftoken" format:"%21v" column:"first_token" type:"bigint" json:"first_token"` + LastToken int64 `header:"ltoken" format:"%21v" column:"last_token" type:"bigint" json:"last_token"` + Instance string `header:"instance" format:"%21v" column:"instance" type:"text" json:"instance"` + Thread int64 `header:"thread" format:"%4v" column:"thread" type:"bigint" json:"thread"` + Comment string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` +} + +func BatchHistoryEventAllFields() []string { + return []string{"ts", "run_id", "script_node", "batch_idx", "batches_total", "status", "first_token", "last_token", "instance", "thread", "comment"} +} +func NewBatchHistoryEventFromMap(r map[string]any, fields []string) (*BatchHistoryEvent, error) { + res := &BatchHistoryEvent{} + for _, fieldName := range fields { + var err error + switch fieldName { + case "ts": + res.Ts, err = ReadTimeFromRow(fieldName, r) + case "run_id": + res.RunId, err = ReadInt16FromRow(fieldName, r) + case "script_node": + res.ScriptNode, err = ReadStringFromRow(fieldName, r) + case "batch_idx": + res.BatchIdx, err = ReadInt16FromRow(fieldName, r) + case "batches_total": + res.BatchesTotal, err = ReadInt16FromRow(fieldName, r) + case "status": + res.Status, err = ReadNodeBatchStatusFromRow(fieldName, r) + case "first_token": + res.FirstToken, err = ReadInt64FromRow(fieldName, r) + case "last_token": + res.LastToken, err = ReadInt64FromRow(fieldName, r) + case "instance": + res.Instance, err = ReadStringFromRow(fieldName, r) + case "thread": + res.Thread, err = ReadInt64FromRow(fieldName, r) + case "comment": + res.Comment, err = ReadStringFromRow(fieldName, r) + default: + return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameNodeHistory) + } + if err != nil { + return nil, err + } + } + return res, nil +} + +// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod +// func (n BatchHistoryEvent) ToSpacedString() string { +// t := reflect.TypeOf(n) +// formats := 
GetObjectModelFieldFormats(t) +// values := make([]string, t.NumField()) + +// v := reflect.ValueOf(&n).Elem() +// for i := 0; i < v.NumField(); i++ { +// fv := v.Field(i) +// values[i] = fmt.Sprintf(formats[i], fv) +// } +// return strings.Join(values, PrintTableDelimiter) +// } diff --git a/pkg/wfmodel/dependency_node_event.go b/pkg/wfmodel/dependency_node_event.go index 98289f2..b797a4f 100644 --- a/pkg/wfmodel/dependency_node_event.go +++ b/pkg/wfmodel/dependency_node_event.go @@ -1,93 +1,93 @@ -package wfmodel - -import ( - "fmt" - "strings" - "time" - - "github.com/capillariesio/capillaries/pkg/eval" -) - -const DependencyNodeEventTableName string = "e" - -type DependencyNodeEvent struct { - RunId int16 - RunIsCurrent bool - RunStartTs time.Time - RunFinalStatus RunStatusType - RunCompletedTs time.Time - RunStoppedTs time.Time - NodeIsStarted bool - NodeStartTs time.Time - NodeStatus NodeBatchStatusType - NodeStatusTs time.Time - SortKey string -} - -func (e *DependencyNodeEvent) ToVars() eval.VarValuesMap { - return eval.VarValuesMap{ - DependencyNodeEventTableName: map[string]interface{}{ - "run_id": int64(e.RunId), - "run_is_current": e.RunIsCurrent, - "run_start_ts": e.RunStartTs, - "run_final_status": int64(e.RunFinalStatus), - "run_completed_ts": e.RunCompletedTs, - "run_stopped_ts": e.RunStoppedTs, - "node_is_started": e.NodeIsStarted, - "node_start_ts": e.NodeStartTs, - "node_status": int64(e.NodeStatus), - "node_status_ts": e.NodeStatusTs}} -} - -func (e *DependencyNodeEvent) ToString() string { - sb := strings.Builder{} - sb.WriteString("{") - sb.WriteString(fmt.Sprintf("run_id:%d,", e.RunId)) - sb.WriteString(fmt.Sprintf("run_is_current:%t,", e.RunIsCurrent)) - sb.WriteString(fmt.Sprintf("run_start_ts:%s,", e.RunStartTs.Format(LogTsFormatQuoted))) - sb.WriteString(fmt.Sprintf("run_final_status:%s,", e.RunFinalStatus.ToString())) - sb.WriteString(fmt.Sprintf("run_completed_ts:%s,", e.RunCompletedTs.Format(LogTsFormatQuoted))) - sb.WriteString(fmt.Sprintf("run_stopped_ts:%s,", e.RunStoppedTs.Format(LogTsFormatQuoted))) - sb.WriteString(fmt.Sprintf("node_is_started:%t,", e.NodeIsStarted)) - sb.WriteString(fmt.Sprintf("node_start_ts:%s,", e.NodeStartTs.Format(LogTsFormatQuoted))) - sb.WriteString(fmt.Sprintf("node_status:%s,", e.NodeStatus.ToString())) - sb.WriteString(fmt.Sprintf("node_status_ts:%s", e.NodeStatusTs.Format(LogTsFormatQuoted))) - sb.WriteString("}") - return sb.String() -} - -type DependencyNodeEvents []DependencyNodeEvent - -func (events DependencyNodeEvents) ToString() string { - items := make([]string, len(events)) - for eventIdx := 0; eventIdx < len(events); eventIdx++ { - items[eventIdx] = events[eventIdx].ToString() - } - return fmt.Sprintf("[%s]", strings.Join(items, ", ")) -} - -func NewVarsFromDepCtx(runId int16, e DependencyNodeEvent) eval.VarValuesMap { - m := eval.VarValuesMap{} - m[WfmodelNamespace] = map[string]interface{}{ - "NodeBatchNone": int64(NodeBatchNone), - "NodeBatchStart": int64(NodeBatchStart), - "NodeBatchSuccess": int64(NodeBatchSuccess), - "NodeBatchFail": int64(NodeBatchFail), - "NodeBatchRunStopReceived": int64(NodeBatchRunStopReceived), - "RunNone": int64(RunNone), - "RunStart": int64(RunStart), - "RunComplete": int64(RunComplete), - "RunStop": int64(RunStop)} - m[DependencyNodeEventTableName] = map[string]interface{}{ - "run_id": int64(e.RunId), - "run_is_current": e.RunIsCurrent, - "run_start_ts": e.RunStartTs, - "run_final_status": int64(e.RunFinalStatus), - "run_completed_ts": e.RunCompletedTs, - "run_stopped_ts": 
e.RunStoppedTs, - "node_is_started": e.NodeIsStarted, - "node_start_ts": e.NodeStartTs, - "node_status": int64(e.NodeStatus), - "node_status_ts": e.NodeStatusTs} - return m -} +package wfmodel + +import ( + "fmt" + "strings" + "time" + + "github.com/capillariesio/capillaries/pkg/eval" +) + +const DependencyNodeEventTableName string = "e" + +type DependencyNodeEvent struct { + RunId int16 + RunIsCurrent bool + RunStartTs time.Time + RunFinalStatus RunStatusType + RunCompletedTs time.Time + RunStoppedTs time.Time + NodeIsStarted bool + NodeStartTs time.Time + NodeStatus NodeBatchStatusType + NodeStatusTs time.Time + SortKey string +} + +func (e *DependencyNodeEvent) ToVars() eval.VarValuesMap { + return eval.VarValuesMap{ + DependencyNodeEventTableName: map[string]any{ + "run_id": int64(e.RunId), + "run_is_current": e.RunIsCurrent, + "run_start_ts": e.RunStartTs, + "run_final_status": int64(e.RunFinalStatus), + "run_completed_ts": e.RunCompletedTs, + "run_stopped_ts": e.RunStoppedTs, + "node_is_started": e.NodeIsStarted, + "node_start_ts": e.NodeStartTs, + "node_status": int64(e.NodeStatus), + "node_status_ts": e.NodeStatusTs}} +} + +func (e *DependencyNodeEvent) ToString() string { + sb := strings.Builder{} + sb.WriteString("{") + sb.WriteString(fmt.Sprintf("run_id:%d,", e.RunId)) + sb.WriteString(fmt.Sprintf("run_is_current:%t,", e.RunIsCurrent)) + sb.WriteString(fmt.Sprintf("run_start_ts:%s,", e.RunStartTs.Format(LogTsFormatQuoted))) + sb.WriteString(fmt.Sprintf("run_final_status:%s,", e.RunFinalStatus.ToString())) + sb.WriteString(fmt.Sprintf("run_completed_ts:%s,", e.RunCompletedTs.Format(LogTsFormatQuoted))) + sb.WriteString(fmt.Sprintf("run_stopped_ts:%s,", e.RunStoppedTs.Format(LogTsFormatQuoted))) + sb.WriteString(fmt.Sprintf("node_is_started:%t,", e.NodeIsStarted)) + sb.WriteString(fmt.Sprintf("node_start_ts:%s,", e.NodeStartTs.Format(LogTsFormatQuoted))) + sb.WriteString(fmt.Sprintf("node_status:%s,", e.NodeStatus.ToString())) + sb.WriteString(fmt.Sprintf("node_status_ts:%s", e.NodeStatusTs.Format(LogTsFormatQuoted))) + sb.WriteString("}") + return sb.String() +} + +type DependencyNodeEvents []DependencyNodeEvent + +func (events DependencyNodeEvents) ToString() string { + items := make([]string, len(events)) + for eventIdx := 0; eventIdx < len(events); eventIdx++ { + items[eventIdx] = events[eventIdx].ToString() + } + return fmt.Sprintf("[%s]", strings.Join(items, ", ")) +} + +func NewVarsFromDepCtx(e DependencyNodeEvent) eval.VarValuesMap { + m := eval.VarValuesMap{} + m[WfmodelNamespace] = map[string]any{ + "NodeBatchNone": int64(NodeBatchNone), + "NodeBatchStart": int64(NodeBatchStart), + "NodeBatchSuccess": int64(NodeBatchSuccess), + "NodeBatchFail": int64(NodeBatchFail), + "NodeBatchRunStopReceived": int64(NodeBatchRunStopReceived), + "RunNone": int64(RunNone), + "RunStart": int64(RunStart), + "RunComplete": int64(RunComplete), + "RunStop": int64(RunStop)} + m[DependencyNodeEventTableName] = map[string]any{ + "run_id": int64(e.RunId), + "run_is_current": e.RunIsCurrent, + "run_start_ts": e.RunStartTs, + "run_final_status": int64(e.RunFinalStatus), + "run_completed_ts": e.RunCompletedTs, + "run_stopped_ts": e.RunStoppedTs, + "node_is_started": e.NodeIsStarted, + "node_start_ts": e.NodeStartTs, + "node_status": int64(e.NodeStatus), + "node_status_ts": e.NodeStatusTs} + return m +} diff --git a/pkg/wfmodel/message.go b/pkg/wfmodel/message.go index 04e4a65..9189f4c 100644 --- a/pkg/wfmodel/message.go +++ b/pkg/wfmodel/message.go @@ -1,104 +1,104 @@ -package wfmodel - -import ( - 
"encoding/json" - "fmt" - "strings" -) - -/* -MessagePayloadComment - generic paylod -Comment - unstructured, can be anything, like information about the sender of the signal -*/ -type MessagePayloadComment struct { - Comment string `json:"comment"` -} - -// Message types, payload depends on it -const ( - MessageTypeDataBatch = 1 - MessageTypeShutown = 101 // pass processor_id - MessageTypeSetLoggingLevel = 102 // pass processor_id and logging level - MessageTypeCancelProcessInstance = 103 // Pass process id and process instance -) - -/* -Message - carries data and signals to processors/nodes -1. No version support. Premature optimization is the root of all evil. -2. Used for data transfer and for control signals. -3. For faster de/serialization, consider custom parser not involving reflection -4. Timestamps are int (not uint) because Unix epoch is int -*/ -type Message struct { - Ts int64 `json:"ts"` - MessageType int `json:"message_type"` - Payload interface{} `json:"payload"` // This depends on MessageType -} - -func (msg Message) ToString() string { - var sb strings.Builder - - sb.WriteString(fmt.Sprintf("Ts:%d, MessageType:%d. ", msg.Ts, msg.MessageType)) - if msg.MessageType == MessageTypeDataBatch && msg.Payload != nil { - batchPayload, ok := msg.Payload.(MessagePayloadDataBatch) - if ok { - sb.WriteString(batchPayload.ToString()) - } - } - return sb.String() -} - -func (msg Message) Serialize() ([]byte, error) { - jsonBytes, err := json.Marshal(msg) - if err != nil { - // This is really unexpected, log the whole msg - return nil, fmt.Errorf("cannot serialize message: %s. %v", msg.ToString(), err) - } - return jsonBytes, nil -} - -func (msg *Message) Deserialize(jsonBytes []byte) error { - var payload json.RawMessage - msg.Payload = &payload - err := json.Unmarshal(jsonBytes, &msg) - if err != nil { - // This is realy unexpected, log the whole json as bytes - return fmt.Errorf("cannot deserialize message: %v. 
%v", jsonBytes, err) - } - - switch msg.MessageType { - case MessageTypeDataBatch: - var payloadDataChunk MessagePayloadDataBatch - err := json.Unmarshal(payload, &payloadDataChunk) - if err != nil { - return err - } - msg.Payload = payloadDataChunk - case MessageTypeCancelProcessInstance: - payloadComment := MessagePayloadComment{} - err := json.Unmarshal(payload, &payloadComment) - if err != nil { - return err - } - msg.Payload = payloadComment - default: - return fmt.Errorf("cannot deserialize message, unknown message type: %s", msg.ToString()) - } - - return nil -} - -// func (tgtMsg *Message) NewDataBatchFromCtx(context *ctx.MessageProcessingContext, targetNodeName string, firstToken int64, lastToken int64, batchIdx int16, batchesTotal int16) { -// tgtMsg.Ts = time.Now().UnixMilli() -// tgtMsg.MessageType = MessageTypeDataBatch -// tgtMsg.Payload = MessagePayloadDataBatch{ -// ScriptURI: context.BatchInfo.ScriptURI, -// ScriptParamsURI: context.BatchInfo.ScriptParamsURI, -// DataKeyspace: context.BatchInfo.DataKeyspace, -// RunId: context.BatchInfo.RunId, -// TargetNodeName: targetNodeName, -// FirstToken: firstToken, -// LastToken: lastToken, -// BatchIdx: batchIdx, -// BatchesTotal: batchesTotal} -// } +package wfmodel + +import ( + "encoding/json" + "fmt" + "strings" +) + +/* +MessagePayloadComment - generic paylod +Comment - unstructured, can be anything, like information about the sender of the signal +*/ +type MessagePayloadComment struct { + Comment string `json:"comment"` +} + +// Message types, payload depends on it +const ( + MessageTypeDataBatch = 1 + MessageTypeShutown = 101 // pass processor_id + MessageTypeSetLoggingLevel = 102 // pass processor_id and logging level + MessageTypeCancelProcessInstance = 103 // Pass process id and process instance +) + +/* +Message - carries data and signals to processors/nodes +1. No version support. Premature optimization is the root of all evil. +2. Used for data transfer and for control signals. +3. For faster de/serialization, consider custom parser not involving reflection +4. Timestamps are int (not uint) because Unix epoch is int +*/ +type Message struct { + Ts int64 `json:"ts"` + MessageType int `json:"message_type"` + Payload any `json:"payload"` // This depends on MessageType +} + +func (msg Message) ToString() string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("Ts:%d, MessageType:%d. ", msg.Ts, msg.MessageType)) + if msg.MessageType == MessageTypeDataBatch && msg.Payload != nil { + batchPayload, ok := msg.Payload.(MessagePayloadDataBatch) + if ok { + sb.WriteString(batchPayload.ToString()) + } + } + return sb.String() +} + +func (msg Message) Serialize() ([]byte, error) { + jsonBytes, err := json.Marshal(msg) + if err != nil { + // This is really unexpected, log the whole msg + return nil, fmt.Errorf("cannot serialize message: %s. %v", msg.ToString(), err) + } + return jsonBytes, nil +} + +func (msg *Message) Deserialize(jsonBytes []byte) error { + var payload json.RawMessage + msg.Payload = &payload + err := json.Unmarshal(jsonBytes, &msg) + if err != nil { + // This is really unexpected, log the whole json as bytes + return fmt.Errorf("cannot deserialize message: %v. 
%v", jsonBytes, err) + } + + switch msg.MessageType { + case MessageTypeDataBatch: + var payloadDataChunk MessagePayloadDataBatch + err := json.Unmarshal(payload, &payloadDataChunk) + if err != nil { + return err + } + msg.Payload = payloadDataChunk + case MessageTypeCancelProcessInstance: + payloadComment := MessagePayloadComment{} + err := json.Unmarshal(payload, &payloadComment) + if err != nil { + return err + } + msg.Payload = payloadComment + default: + return fmt.Errorf("cannot deserialize message, unknown message type: %s", msg.ToString()) + } + + return nil +} + +// func (tgtMsg *Message) NewDataBatchFromCtx(context *ctx.MessageProcessingContext, targetNodeName string, firstToken int64, lastToken int64, batchIdx int16, batchesTotal int16) { +// tgtMsg.Ts = time.Now().UnixMilli() +// tgtMsg.MessageType = MessageTypeDataBatch +// tgtMsg.Payload = MessagePayloadDataBatch{ +// ScriptURI: context.BatchInfo.ScriptURI, +// ScriptParamsURI: context.BatchInfo.ScriptParamsURI, +// DataKeyspace: context.BatchInfo.DataKeyspace, +// RunId: context.BatchInfo.RunId, +// TargetNodeName: targetNodeName, +// FirstToken: firstToken, +// LastToken: lastToken, +// BatchIdx: batchIdx, +// BatchesTotal: batchesTotal} +// } diff --git a/pkg/wfmodel/message_payload_data_batch.go b/pkg/wfmodel/message_payload_data_batch.go index c3155f0..0728bca 100644 --- a/pkg/wfmodel/message_payload_data_batch.go +++ b/pkg/wfmodel/message_payload_data_batch.go @@ -1,43 +1,40 @@ -package wfmodel - -import ( - "encoding/json" - "fmt" -) - -type MessagePayloadDataBatch struct { - ScriptURI string `json:"script_uri"` - ScriptParamsURI string `json:"script_params_uri"` - DataKeyspace string `json:"data_keyspace"` // Instance/process id - RunId int16 `json:"run_id"` - TargetNodeName string `json:"target_node"` - FirstToken int64 `json:"first_token"` - LastToken int64 `json:"last_token"` - BatchIdx int16 `json:"batch_idx"` - BatchesTotal int16 `json:"batches_total"` -} - -func (dc *MessagePayloadDataBatch) FullBatchId() string { - return fmt.Sprintf("%s/%d/%s/%d", dc.DataKeyspace, dc.RunId, dc.TargetNodeName, dc.BatchIdx) -} - -func (dc *MessagePayloadDataBatch) ToString() string { - return fmt.Sprintf("ScriptURI:%s,ScriptParamsURI:%s, DataKeyspace:%s, RunId:%d, TargetNodeName:%s, FirstToken:%d, LastToken:%d, BatchIdx:%d, BatchesTotal:%d. 
", - dc.ScriptURI, dc.ScriptParamsURI, dc.DataKeyspace, dc.RunId, dc.TargetNodeName, dc.FirstToken, dc.LastToken, dc.BatchIdx, dc.BatchesTotal) -} - -func (dc *MessagePayloadDataBatch) Deserialize(jsonBytes []byte) error { - if err := json.Unmarshal(jsonBytes, dc); err != nil { - return err - } - return nil -} - -func (dc MessagePayloadDataBatch) Serialize() ([]byte, error) { - var jsonBytes []byte - jsonBytes, err := json.Marshal(dc) - if err != nil { - return nil, err - } - return jsonBytes, nil -} +package wfmodel + +import ( + "encoding/json" + "fmt" +) + +type MessagePayloadDataBatch struct { + ScriptURI string `json:"script_uri"` + ScriptParamsURI string `json:"script_params_uri"` + DataKeyspace string `json:"data_keyspace"` // Instance/process id + RunId int16 `json:"run_id"` + TargetNodeName string `json:"target_node"` + FirstToken int64 `json:"first_token"` + LastToken int64 `json:"last_token"` + BatchIdx int16 `json:"batch_idx"` + BatchesTotal int16 `json:"batches_total"` +} + +func (dc *MessagePayloadDataBatch) FullBatchId() string { + return fmt.Sprintf("%s/%d/%s/%d", dc.DataKeyspace, dc.RunId, dc.TargetNodeName, dc.BatchIdx) +} + +func (dc *MessagePayloadDataBatch) ToString() string { + return fmt.Sprintf("ScriptURI:%s,ScriptParamsURI:%s, DataKeyspace:%s, RunId:%d, TargetNodeName:%s, FirstToken:%d, LastToken:%d, BatchIdx:%d, BatchesTotal:%d. ", + dc.ScriptURI, dc.ScriptParamsURI, dc.DataKeyspace, dc.RunId, dc.TargetNodeName, dc.FirstToken, dc.LastToken, dc.BatchIdx, dc.BatchesTotal) +} + +func (dc *MessagePayloadDataBatch) Deserialize(jsonBytes []byte) error { + return json.Unmarshal(jsonBytes, dc) +} + +func (dc MessagePayloadDataBatch) Serialize() ([]byte, error) { + var jsonBytes []byte + jsonBytes, err := json.Marshal(dc) + if err != nil { + return nil, err + } + return jsonBytes, nil +} diff --git a/pkg/wfmodel/node_history.go b/pkg/wfmodel/node_history.go index 5da5761..efdca0a 100644 --- a/pkg/wfmodel/node_history.go +++ b/pkg/wfmodel/node_history.go @@ -1,163 +1,163 @@ -package wfmodel - -import ( - "fmt" - "strings" - "time" -) - -type NodeStatusType int8 - -const ( - NodehNone NodeBatchStatusType = 0 - NodeStart NodeBatchStatusType = 1 - NodeSuccess NodeBatchStatusType = 2 - NodeFail NodeBatchStatusType = 3 - NodeRunStopReceived NodeBatchStatusType = 104 -) - -const TableNameNodeHistory = "wf_node_history" - -func (status NodeBatchStatusType) ToString() string { - switch status { - case NodehNone: - return "none" - case NodeStart: - return "start" - case NodeSuccess: - return "success" - case NodeFail: - return "fail" - case NodeRunStopReceived: - return "stopreceived" - default: - return "unknown" - } -} - -type NodeStatusMap map[string]NodeBatchStatusType - -func (m NodeStatusMap) ToString() string { - sb := strings.Builder{} - sb.WriteString("{") - for nodeName, nodeStatus := range m { - if sb.Len() > 1 { - sb.WriteString(",") - } - sb.WriteString(fmt.Sprintf(`"%s":"%s"`, nodeName, nodeStatus.ToString())) - } - sb.WriteString("}") - return sb.String() -} - -type RunBatchStatusMap map[int16]NodeBatchStatusType - -func (m RunBatchStatusMap) ToString() string { - sb := strings.Builder{} - sb.WriteString("{") - for runId, nodeStatus := range m { - sb.WriteString(fmt.Sprintf("%d:%s ", runId, nodeStatus.ToString())) - } - sb.WriteString("}") - return sb.String() -} - -type NodeRunBatchStatusMap map[string]RunBatchStatusMap - -func (m NodeRunBatchStatusMap) ToString() string { - sb := strings.Builder{} - sb.WriteString("{") - for nodeName, runBatchStatusMap := range m { 
- sb.WriteString(fmt.Sprintf("%s:%s ", nodeName, runBatchStatusMap.ToString())) - } - sb.WriteString("}") - return sb.String() -} - -// Object model with tags that allow to create cql CREATE TABLE queries and to print object -type NodeHistoryEvent struct { - Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` - RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` - ScriptNode string `header:"script_node" format:"%20v" column:"script_node" type:"text" key:"true" json:"script_node"` - Status NodeBatchStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` - Comment string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` -} - -func NodeHistoryEventAllFields() []string { - return []string{"ts", "run_id", "script_node", "status", "comment"} -} -func NewNodeHistoryEventFromMap(r map[string]interface{}, fields []string) (*NodeHistoryEvent, error) { - res := &NodeHistoryEvent{} - for _, fieldName := range fields { - var err error - switch fieldName { - case "ts": - res.Ts, err = ReadTimeFromRow(fieldName, r) - case "run_id": - res.RunId, err = ReadInt16FromRow(fieldName, r) - case "script_node": - res.ScriptNode, err = ReadStringFromRow(fieldName, r) - case "status": - res.Status, err = ReadNodeBatchStatusFromRow(fieldName, r) - case "comment": - res.Comment, err = ReadStringFromRow(fieldName, r) - default: - return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameNodeHistory) - } - if err != nil { - return nil, err - } - } - return res, nil -} - -// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod -// func (n NodeHistoryEvent) ToSpacedString() string { -// t := reflect.TypeOf(n) -// formats := GetObjectModelFieldFormats(t) -// values := make([]string, t.NumField()) - -// v := reflect.ValueOf(&n).Elem() -// for i := 0; i < v.NumField(); i++ { -// fv := v.Field(i) -// values[i] = fmt.Sprintf(formats[i], fv) -// } -// return strings.Join(values, PrintTableDelimiter) -// } - -type NodeLifespan struct { - StartTs time.Time - LastStatus NodeBatchStatusType - LastStatusTs time.Time -} - -func (ls NodeLifespan) ToString() string { - return fmt.Sprintf("{start_ts:%s, last_status:%s, last_status_ts:%s}", - ls.StartTs.Format(LogTsFormatQuoted), - ls.LastStatus.ToString(), - ls.LastStatusTs.Format(LogTsFormatQuoted)) -} - -type NodeLifespanMap map[string]*NodeLifespan - -func (m NodeLifespanMap) ToString() string { - items := make([]string, len(m)) - nodeIdx := 0 - for nodeName, ls := range m { - items[nodeIdx] = fmt.Sprintf("%s:%s", nodeName, ls.ToString()) - nodeIdx++ - } - return fmt.Sprintf("{%s}", strings.Join(items, ", ")) -} - -type RunNodeLifespanMap map[int16]NodeLifespanMap - -func (m RunNodeLifespanMap) ToString() string { - items := make([]string, len(m)) - runIdx := 0 - for runId, ls := range m { - items[runIdx] = fmt.Sprintf("%d:%s", runId, ls.ToString()) - runIdx++ - } - return fmt.Sprintf("{%s}", strings.Join(items, ", ")) -} +package wfmodel + +import ( + "fmt" + "strings" + "time" +) + +type NodeStatusType int8 + +const ( + NodehNone NodeBatchStatusType = 0 + NodeStart NodeBatchStatusType = 1 + NodeSuccess NodeBatchStatusType = 2 + NodeFail NodeBatchStatusType = 3 + NodeRunStopReceived NodeBatchStatusType = 104 +) + +const TableNameNodeHistory = "wf_node_history" + +func (status NodeBatchStatusType) ToString() string { + switch status { + case NodehNone: + return "none" + case NodeStart: + return "start" + 
case NodeSuccess: + return "success" + case NodeFail: + return "fail" + case NodeRunStopReceived: + return "stopreceived" + default: + return "unknown" + } +} + +type NodeStatusMap map[string]NodeBatchStatusType + +func (m NodeStatusMap) ToString() string { + sb := strings.Builder{} + sb.WriteString("{") + for nodeName, nodeStatus := range m { + if sb.Len() > 1 { + sb.WriteString(",") + } + sb.WriteString(fmt.Sprintf(`"%s":"%s"`, nodeName, nodeStatus.ToString())) + } + sb.WriteString("}") + return sb.String() +} + +type RunBatchStatusMap map[int16]NodeBatchStatusType + +func (m RunBatchStatusMap) ToString() string { + sb := strings.Builder{} + sb.WriteString("{") + for runId, nodeStatus := range m { + sb.WriteString(fmt.Sprintf("%d:%s ", runId, nodeStatus.ToString())) + } + sb.WriteString("}") + return sb.String() +} + +type NodeRunBatchStatusMap map[string]RunBatchStatusMap + +func (m NodeRunBatchStatusMap) ToString() string { + sb := strings.Builder{} + sb.WriteString("{") + for nodeName, runBatchStatusMap := range m { + sb.WriteString(fmt.Sprintf("%s:%s ", nodeName, runBatchStatusMap.ToString())) + } + sb.WriteString("}") + return sb.String() +} + +// Object model with tags that allow to create cql CREATE TABLE queries and to print object +type NodeHistoryEvent struct { + Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` + RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` + ScriptNode string `header:"script_node" format:"%20v" column:"script_node" type:"text" key:"true" json:"script_node"` + Status NodeBatchStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` + Comment string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` +} + +func NodeHistoryEventAllFields() []string { + return []string{"ts", "run_id", "script_node", "status", "comment"} +} +func NewNodeHistoryEventFromMap(r map[string]any, fields []string) (*NodeHistoryEvent, error) { + res := &NodeHistoryEvent{} + for _, fieldName := range fields { + var err error + switch fieldName { + case "ts": + res.Ts, err = ReadTimeFromRow(fieldName, r) + case "run_id": + res.RunId, err = ReadInt16FromRow(fieldName, r) + case "script_node": + res.ScriptNode, err = ReadStringFromRow(fieldName, r) + case "status": + res.Status, err = ReadNodeBatchStatusFromRow(fieldName, r) + case "comment": + res.Comment, err = ReadStringFromRow(fieldName, r) + default: + return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameNodeHistory) + } + if err != nil { + return nil, err + } + } + return res, nil +} + +// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod +// func (n NodeHistoryEvent) ToSpacedString() string { +// t := reflect.TypeOf(n) +// formats := GetObjectModelFieldFormats(t) +// values := make([]string, t.NumField()) + +// v := reflect.ValueOf(&n).Elem() +// for i := 0; i < v.NumField(); i++ { +// fv := v.Field(i) +// values[i] = fmt.Sprintf(formats[i], fv) +// } +// return strings.Join(values, PrintTableDelimiter) +// } + +type NodeLifespan struct { + StartTs time.Time + LastStatus NodeBatchStatusType + LastStatusTs time.Time +} + +func (ls NodeLifespan) ToString() string { + return fmt.Sprintf("{start_ts:%s, last_status:%s, last_status_ts:%s}", + ls.StartTs.Format(LogTsFormatQuoted), + ls.LastStatus.ToString(), + ls.LastStatusTs.Format(LogTsFormatQuoted)) +} + +type NodeLifespanMap map[string]*NodeLifespan + +func (m NodeLifespanMap) ToString() string { 
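// Editor's note (annotation, not part of the commit): NodeLifespanMap is keyed by
// script node name, so ToString renders something like (node name hypothetical):
//
//	{read_orders:{start_ts:"...", last_status:success, last_status_ts:"..."}, ...}
//
// Go map iteration order is unspecified, so the order of entries in the rendered
// string may differ between calls; that only matters if the output is compared verbatim.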
+ items := make([]string, len(m)) + nodeIdx := 0 + for nodeName, ls := range m { + items[nodeIdx] = fmt.Sprintf("%s:%s", nodeName, ls.ToString()) + nodeIdx++ + } + return fmt.Sprintf("{%s}", strings.Join(items, ", ")) +} + +type RunNodeLifespanMap map[int16]NodeLifespanMap + +func (m RunNodeLifespanMap) ToString() string { + items := make([]string, len(m)) + runIdx := 0 + for runId, ls := range m { + items[runIdx] = fmt.Sprintf("%d:%s", runId, ls.ToString()) + runIdx++ + } + return fmt.Sprintf("{%s}", strings.Join(items, ", ")) +} diff --git a/pkg/wfmodel/run_counter.go b/pkg/wfmodel/run_counter.go index b839860..9b59b2a 100644 --- a/pkg/wfmodel/run_counter.go +++ b/pkg/wfmodel/run_counter.go @@ -1,9 +1,9 @@ -package wfmodel - -const TableNameRunCounter = "wf_run_counter" - -// Object model with tags that allow to create cql CREATE TABLE queries and to print object -type RunCounter struct { - Keyspace int `header:"ks" format:"%20s" column:"ks" type:"text" key:"true"` - LastRun int `header:"lr" format:"%3d" column:"last_run" type:"int"` -} +package wfmodel + +const TableNameRunCounter = "wf_run_counter" + +// Object model with tags that allow to create cql CREATE TABLE queries and to print object +type RunCounter struct { + Keyspace int `header:"ks" format:"%20s" column:"ks" type:"text" key:"true"` + LastRun int `header:"lr" format:"%3d" column:"last_run" type:"int"` +} diff --git a/pkg/wfmodel/run_history.go b/pkg/wfmodel/run_history.go index 98e9386..98655e1 100644 --- a/pkg/wfmodel/run_history.go +++ b/pkg/wfmodel/run_history.go @@ -1,144 +1,144 @@ -package wfmodel - -import ( - "fmt" - "strings" - "time" -) - -type RunStatusType int8 - -const ( - RunNone RunStatusType = 0 - RunStart RunStatusType = 1 - RunComplete RunStatusType = 2 - RunStop RunStatusType = 3 -) - -const TableNameRunHistory = "wf_run_history" - -func (status RunStatusType) ToString() string { - switch status { - case RunNone: - return "none" - case RunStart: - return "start" - case RunComplete: - return "complete" - case RunStop: - return "stop" - default: - return "unknown" - } -} - -type RunStatusMap map[int16]RunStatusType -type RunStartTsMap map[int16]time.Time - -func (m RunStartTsMap) ToString() string { - sb := strings.Builder{} - for runId, ts := range m { - sb.WriteString(fmt.Sprintf("%d:%s,", runId, ts.Format(LogTsFormatQuoted))) - } - return sb.String() -} - -func (m RunStatusMap) ToString() string { - sb := strings.Builder{} - for runId, runStatus := range m { - sb.WriteString(fmt.Sprintf("%d:%s,", runId, runStatus.ToString())) - } - return sb.String() -} - -// Object model with tags that allow to create cql CREATE TABLE queries and to print object -type RunHistoryEvent struct { - Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` - RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` - Status RunStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` - Comment string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` -} - -func RunHistoryEventAllFields() []string { - return []string{"ts", "run_id", "status", "comment"} -} - -func NewRunHistoryEventFromMap(r map[string]interface{}, fields []string) (*RunHistoryEvent, error) { - res := &RunHistoryEvent{} - for _, fieldName := range fields { - var err error - switch fieldName { - case "ts": - res.Ts, err = ReadTimeFromRow(fieldName, r) - case "run_id": - res.RunId, err = ReadInt16FromRow(fieldName, r) - case "status": - 
res.Status, err = ReadRunStatusFromRow(fieldName, r) - case "comment": - res.Comment, err = ReadStringFromRow(fieldName, r) - default: - return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameRunHistory) - } - if err != nil { - return nil, err - } - } - return res, nil -} - -// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod -// func (n RunHistoryEvent) ToSpacedString() string { -// t := reflect.TypeOf(n) -// formats := GetObjectModelFieldFormats(t) -// values := make([]string, t.NumField()) - -// v := reflect.ValueOf(&n).Elem() -// for i := 0; i < v.NumField(); i++ { -// fv := v.Field(i) -// values[i] = fmt.Sprintf(formats[i], fv) -// } -// return strings.Join(values, PrintTableDelimiter) -// } - -type RunLifespan struct { - RunId int16 `json:"run_id"` - StartTs time.Time `json:"start_ts"` - StartComment string `json:"start_comment"` - FinalStatus RunStatusType `json:"final_status"` - CompletedTs time.Time `json:"completed_ts"` - CompletedComment string `json:"completed_comment"` - StoppedTs time.Time `json:"stopped_ts"` - StoppedComment string `json:"stopped_comment"` -} - -func (ls RunLifespan) ToString() string { - return fmt.Sprintf("{run_id: %d, start_ts:%s, final_status:%s, completed_ts:%s, stopped_ts:%s}", - ls.RunId, - ls.StartTs.Format(LogTsFormatQuoted), - ls.FinalStatus.ToString(), - ls.CompletedTs.Format(LogTsFormatQuoted), - ls.StoppedTs.Format(LogTsFormatQuoted)) -} - -type RunLifespanMap map[int16]*RunLifespan - -func (m RunLifespanMap) ToString() string { - items := make([]string, len(m)) - itemIdx := 0 - for runId, ls := range m { - items[itemIdx] = fmt.Sprintf("%d:%s", runId, ls.ToString()) - itemIdx++ - } - return fmt.Sprintf("{%s}", strings.Join(items, ", ")) -} - -// func InheritNodeBatchStatusToRunStatus(nodeBatchStatus NodeBatchStatusType) (RunStatusType, error) { -// switch nodeBatchStatus { -// case NodeBatchFail: -// return RunFail, nil -// case NodeBatchSuccess: -// return RunSuccess, nil -// default: -// return RunNone, fmt.Errorf("cannot inherit run from node batch status %d", nodeBatchStatus) -// } -// } +package wfmodel + +import ( + "fmt" + "strings" + "time" +) + +type RunStatusType int8 + +const ( + RunNone RunStatusType = 0 + RunStart RunStatusType = 1 + RunComplete RunStatusType = 2 + RunStop RunStatusType = 3 +) + +const TableNameRunHistory = "wf_run_history" + +func (status RunStatusType) ToString() string { + switch status { + case RunNone: + return "none" + case RunStart: + return "start" + case RunComplete: + return "complete" + case RunStop: + return "stop" + default: + return "unknown" + } +} + +type RunStatusMap map[int16]RunStatusType +type RunStartTsMap map[int16]time.Time + +func (m RunStartTsMap) ToString() string { + sb := strings.Builder{} + for runId, ts := range m { + sb.WriteString(fmt.Sprintf("%d:%s,", runId, ts.Format(LogTsFormatQuoted))) + } + return sb.String() +} + +func (m RunStatusMap) ToString() string { + sb := strings.Builder{} + for runId, runStatus := range m { + sb.WriteString(fmt.Sprintf("%d:%s,", runId, runStatus.ToString())) + } + return sb.String() +} + +// Object model with tags that allow to create cql CREATE TABLE queries and to print object +type RunHistoryEvent struct { + Ts time.Time `header:"ts" format:"%-33v" column:"ts" type:"timestamp" json:"ts"` + RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` + Status RunStatusType `header:"sts" format:"%3v" column:"status" type:"tinyint" key:"true" json:"status"` + Comment 
string `header:"comment" format:"%v" column:"comment" type:"text" json:"comment"` +} + +func RunHistoryEventAllFields() []string { + return []string{"ts", "run_id", "status", "comment"} +} + +func NewRunHistoryEventFromMap(r map[string]any, fields []string) (*RunHistoryEvent, error) { + res := &RunHistoryEvent{} + for _, fieldName := range fields { + var err error + switch fieldName { + case "ts": + res.Ts, err = ReadTimeFromRow(fieldName, r) + case "run_id": + res.RunId, err = ReadInt16FromRow(fieldName, r) + case "status": + res.Status, err = ReadRunStatusFromRow(fieldName, r) + case "comment": + res.Comment, err = ReadStringFromRow(fieldName, r) + default: + return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameRunHistory) + } + if err != nil { + return nil, err + } + } + return res, nil +} + +// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod +// func (n RunHistoryEvent) ToSpacedString() string { +// t := reflect.TypeOf(n) +// formats := GetObjectModelFieldFormats(t) +// values := make([]string, t.NumField()) + +// v := reflect.ValueOf(&n).Elem() +// for i := 0; i < v.NumField(); i++ { +// fv := v.Field(i) +// values[i] = fmt.Sprintf(formats[i], fv) +// } +// return strings.Join(values, PrintTableDelimiter) +// } + +type RunLifespan struct { + RunId int16 `json:"run_id"` + StartTs time.Time `json:"start_ts"` + StartComment string `json:"start_comment"` + FinalStatus RunStatusType `json:"final_status"` + CompletedTs time.Time `json:"completed_ts"` + CompletedComment string `json:"completed_comment"` + StoppedTs time.Time `json:"stopped_ts"` + StoppedComment string `json:"stopped_comment"` +} + +func (ls RunLifespan) ToString() string { + return fmt.Sprintf("{run_id: %d, start_ts:%s, final_status:%s, completed_ts:%s, stopped_ts:%s}", + ls.RunId, + ls.StartTs.Format(LogTsFormatQuoted), + ls.FinalStatus.ToString(), + ls.CompletedTs.Format(LogTsFormatQuoted), + ls.StoppedTs.Format(LogTsFormatQuoted)) +} + +type RunLifespanMap map[int16]*RunLifespan + +func (m RunLifespanMap) ToString() string { + items := make([]string, len(m)) + itemIdx := 0 + for runId, ls := range m { + items[itemIdx] = fmt.Sprintf("%d:%s", runId, ls.ToString()) + itemIdx++ + } + return fmt.Sprintf("{%s}", strings.Join(items, ", ")) +} + +// func InheritNodeBatchStatusToRunStatus(nodeBatchStatus NodeBatchStatusType) (RunStatusType, error) { +// switch nodeBatchStatus { +// case NodeBatchFail: +// return RunFail, nil +// case NodeBatchSuccess: +// return RunSuccess, nil +// default: +// return RunNone, fmt.Errorf("cannot inherit run from node batch status %d", nodeBatchStatus) +// } +// } diff --git a/pkg/wfmodel/run_properties.go b/pkg/wfmodel/run_properties.go index 0989a84..a34410b 100644 --- a/pkg/wfmodel/run_properties.go +++ b/pkg/wfmodel/run_properties.go @@ -1,62 +1,62 @@ -package wfmodel - -import ( - "fmt" -) - -const TableNameRunAffectedNodes = "wf_run_affected_nodes" - -// Object model with tags that allow to create cql CREATE TABLE queries and to print object -type RunProperties struct { - RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` - StartNodes string `header:"start_nodes" format:"%20v" column:"start_nodes" type:"text" json:"start_nodes"` - AffectedNodes string `header:"affected_nodes" format:"%20v" column:"affected_nodes" type:"text" json:"affected_nodes"` - ScriptUri string `header:"script_uri" format:"%20v" column:"script_uri" type:"text" json:"script_uri"` - ScriptParamsUri string 
`header:"script_params_uri" format:"%20v" column:"script_params_uri" type:"text" json:"script_params_uri"` - RunDescription string `header:"run_desc" format:"%20v" column:"run_description" type:"text" json:"run_description"` -} - -func RunPropertiesAllFields() []string { - return []string{"run_id", "start_nodes", "affected_nodes", "script_uri", "script_params_uri", "run_description"} -} - -func NewRunPropertiesFromMap(r map[string]interface{}, fields []string) (*RunProperties, error) { - res := &RunProperties{} - for _, fieldName := range fields { - var err error - switch fieldName { - case "run_id": - res.RunId, err = ReadInt16FromRow(fieldName, r) - case "start_nodes": - res.StartNodes, err = ReadStringFromRow(fieldName, r) - case "affected_nodes": - res.AffectedNodes, err = ReadStringFromRow(fieldName, r) - case "script_uri": - res.ScriptUri, err = ReadStringFromRow(fieldName, r) - case "script_params_uri": - res.ScriptParamsUri, err = ReadStringFromRow(fieldName, r) - case "run_description": - res.RunDescription, err = ReadStringFromRow(fieldName, r) - default: - return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameRunAffectedNodes) - } - if err != nil { - return nil, err - } - } - return res, nil -} - -// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod -// func (n RunProperties) ToSpacedString() string { -// t := reflect.TypeOf(n) -// formats := GetObjectModelFieldFormats(t) -// values := make([]string, t.NumField()) - -// v := reflect.ValueOf(&n).Elem() -// for i := 0; i < v.NumField(); i++ { -// fv := v.Field(i) -// values[i] = fmt.Sprintf(formats[i], fv) -// } -// return strings.Join(values, PrintTableDelimiter) -// } +package wfmodel + +import ( + "fmt" +) + +const TableNameRunAffectedNodes = "wf_run_affected_nodes" + +// Object model with tags that allow to create cql CREATE TABLE queries and to print object +type RunProperties struct { + RunId int16 `header:"run_id" format:"%6d" column:"run_id" type:"int" key:"true" json:"run_id"` + StartNodes string `header:"start_nodes" format:"%20v" column:"start_nodes" type:"text" json:"start_nodes"` + AffectedNodes string `header:"affected_nodes" format:"%20v" column:"affected_nodes" type:"text" json:"affected_nodes"` + ScriptUri string `header:"script_uri" format:"%20v" column:"script_uri" type:"text" json:"script_uri"` + ScriptParamsUri string `header:"script_params_uri" format:"%20v" column:"script_params_uri" type:"text" json:"script_params_uri"` + RunDescription string `header:"run_desc" format:"%20v" column:"run_description" type:"text" json:"run_description"` +} + +func RunPropertiesAllFields() []string { + return []string{"run_id", "start_nodes", "affected_nodes", "script_uri", "script_params_uri", "run_description"} +} + +func NewRunPropertiesFromMap(r map[string]any, fields []string) (*RunProperties, error) { + res := &RunProperties{} + for _, fieldName := range fields { + var err error + switch fieldName { + case "run_id": + res.RunId, err = ReadInt16FromRow(fieldName, r) + case "start_nodes": + res.StartNodes, err = ReadStringFromRow(fieldName, r) + case "affected_nodes": + res.AffectedNodes, err = ReadStringFromRow(fieldName, r) + case "script_uri": + res.ScriptUri, err = ReadStringFromRow(fieldName, r) + case "script_params_uri": + res.ScriptParamsUri, err = ReadStringFromRow(fieldName, r) + case "run_description": + res.RunDescription, err = ReadStringFromRow(fieldName, r) + default: + return nil, fmt.Errorf("unknown %s field %s", fieldName, TableNameRunAffectedNodes) + 
} + if err != nil { + return nil, err + } + } + return res, nil +} + +// ToSpacedString - prints formatted field values, uses reflection, shoud not be used in prod +// func (n RunProperties) ToSpacedString() string { +// t := reflect.TypeOf(n) +// formats := GetObjectModelFieldFormats(t) +// values := make([]string, t.NumField()) + +// v := reflect.ValueOf(&n).Elem() +// for i := 0; i < v.NumField(); i++ { +// fv := v.Field(i) +// values[i] = fmt.Sprintf(formats[i], fv) +// } +// return strings.Join(values, PrintTableDelimiter) +// } diff --git a/pkg/wfmodel/util.go b/pkg/wfmodel/util.go index cfab2b1..6b1d72e 100644 --- a/pkg/wfmodel/util.go +++ b/pkg/wfmodel/util.go @@ -1,134 +1,134 @@ -package wfmodel - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -const WfmodelNamespace string = "wfmodel" -const PrintTableDelimiter = "/" - -const LogTsFormatQuoted = `"2006-01-02T15:04:05.000-0700"` - -// GetSpacedHeader - prints formatted struct field names, uses reflection, shoud not be used in prod -func GetSpacedHeader(n interface{}) string { - t := reflect.TypeOf(n) - columns := make([]string, t.NumField()) - for i := 0; i < t.NumField(); i++ { - field := t.FieldByIndex([]int{i}) - h, ok := field.Tag.Lookup("header") - if ok { - f, ok := field.Tag.Lookup("format") - if ok { - columns[i] = fmt.Sprintf(f, h) - } else { - columns[i] = fmt.Sprintf("%v", h) - } - } else { - columns[i] = "N/A" - } - - } - return strings.Join(columns, PrintTableDelimiter) -} - -/* -GetObjectModelFieldFormats - helper to get formats for each field of an object model -*/ -func GetObjectModelFieldFormats(t reflect.Type) []string { - formats := make([]string, t.NumField()) - - for i := 0; i < t.NumField(); i++ { - field := t.FieldByIndex([]int{i}) - f, ok := field.Tag.Lookup("format") - if ok { - formats[i] = f - } else { - formats[i] = "%v" - } - } - return formats -} - -func GetCreateTableCql(t reflect.Type, keyspace string, tableName string) string { - - columnDefs := make([]string, t.NumField()) - keyDefs := make([]string, t.NumField()) - keyCount := 0 - - for i := 0; i < t.NumField(); i++ { - field := t.FieldByIndex([]int{i}) - cqlColumn, ok := field.Tag.Lookup("column") - if ok { - cqlType, ok := field.Tag.Lookup("type") - if ok { - columnDefs[i] = fmt.Sprintf("%s %s", cqlColumn, cqlType) - cqlKeyFlag, ok := field.Tag.Lookup("key") - if ok && cqlKeyFlag == "true" { - keyDefs[keyCount] = cqlColumn - keyCount++ - } - } else { - columnDefs[i] = fmt.Sprintf("no type for field %s", field.Name) - } - } else { - columnDefs[i] = fmt.Sprintf("no column name for field %s", field.Name) - } - } - - return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (%s, PRIMARY KEY(%s));", - keyspace, - tableName, - strings.Join(columnDefs, ", "), - strings.Join(keyDefs[:keyCount], ", ")) -} - -func ReadTimeFromRow(fieldName string, r map[string]interface{}) (time.Time, error) { - v, ok := r[fieldName].(time.Time) - if !ok { - return v, fmt.Errorf("cannot read time %s from %v", fieldName, r) - } - return v, nil -} - -func ReadInt16FromRow(fieldName string, r map[string]interface{}) (int16, error) { - v, ok := r[fieldName].(int) - if !ok { - return int16(0), fmt.Errorf("cannot read int16 %s from %v", fieldName, r) - } - return int16(v), nil -} - -func ReadInt64FromRow(fieldName string, r map[string]interface{}) (int64, error) { - v, ok := r[fieldName].(int64) - if !ok { - return int64(0), fmt.Errorf("cannot read int64 %s from %v", fieldName, r) - } - return v, nil -} - -func ReadRunStatusFromRow(fieldName string, r 
map[string]interface{}) (RunStatusType, error) { - v, ok := r[fieldName].(int8) - if !ok { - return RunNone, fmt.Errorf("cannot read run status %s from %v", fieldName, r) - } - return RunStatusType(v), nil -} - -func ReadStringFromRow(fieldName string, r map[string]interface{}) (string, error) { - v, ok := r[fieldName].(string) - if !ok { - return v, fmt.Errorf("cannot read string %s from %v", fieldName, r) - } - return v, nil -} - -func ReadNodeBatchStatusFromRow(fieldName string, r map[string]interface{}) (NodeBatchStatusType, error) { - v, ok := r[fieldName].(int8) - if !ok { - return NodeBatchNone, fmt.Errorf("cannot read node/batch status %s from %v", fieldName, r) - } - return NodeBatchStatusType(v), nil -} +package wfmodel + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +const WfmodelNamespace string = "wfmodel" +const PrintTableDelimiter = "/" + +const LogTsFormatQuoted = `"2006-01-02T15:04:05.000-0700"` + +// GetSpacedHeader - prints formatted struct field names, uses reflection, shoud not be used in prod +func GetSpacedHeader(n any) string { + t := reflect.TypeOf(n) + columns := make([]string, t.NumField()) + for i := 0; i < t.NumField(); i++ { + field := t.FieldByIndex([]int{i}) + h, ok := field.Tag.Lookup("header") + if ok { + f, ok := field.Tag.Lookup("format") + if ok { + columns[i] = fmt.Sprintf(f, h) + } else { + columns[i] = fmt.Sprintf("%v", h) + } + } else { + columns[i] = "N/A" + } + + } + return strings.Join(columns, PrintTableDelimiter) +} + +/* +GetObjectModelFieldFormats - helper to get formats for each field of an object model +*/ +func GetObjectModelFieldFormats(t reflect.Type) []string { + formats := make([]string, t.NumField()) + + for i := 0; i < t.NumField(); i++ { + field := t.FieldByIndex([]int{i}) + f, ok := field.Tag.Lookup("format") + if ok { + formats[i] = f + } else { + formats[i] = "%v" + } + } + return formats +} + +func GetCreateTableCql(t reflect.Type, keyspace string, tableName string) string { + + columnDefs := make([]string, t.NumField()) + keyDefs := make([]string, t.NumField()) + keyCount := 0 + + for i := 0; i < t.NumField(); i++ { + field := t.FieldByIndex([]int{i}) + cqlColumn, ok := field.Tag.Lookup("column") + if ok { + cqlType, ok := field.Tag.Lookup("type") + if ok { + columnDefs[i] = fmt.Sprintf("%s %s", cqlColumn, cqlType) + cqlKeyFlag, ok := field.Tag.Lookup("key") + if ok && cqlKeyFlag == "true" { + keyDefs[keyCount] = cqlColumn + keyCount++ + } + } else { + columnDefs[i] = fmt.Sprintf("no type for field %s", field.Name) + } + } else { + columnDefs[i] = fmt.Sprintf("no column name for field %s", field.Name) + } + } + + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (%s, PRIMARY KEY(%s));", + keyspace, + tableName, + strings.Join(columnDefs, ", "), + strings.Join(keyDefs[:keyCount], ", ")) +} + +func ReadTimeFromRow(fieldName string, r map[string]any) (time.Time, error) { + v, ok := r[fieldName].(time.Time) + if !ok { + return v, fmt.Errorf("cannot read time %s from %v", fieldName, r) + } + return v, nil +} + +func ReadInt16FromRow(fieldName string, r map[string]any) (int16, error) { + v, ok := r[fieldName].(int) + if !ok { + return int16(0), fmt.Errorf("cannot read int16 %s from %v", fieldName, r) + } + return int16(v), nil +} + +func ReadInt64FromRow(fieldName string, r map[string]any) (int64, error) { + v, ok := r[fieldName].(int64) + if !ok { + return int64(0), fmt.Errorf("cannot read int64 %s from %v", fieldName, r) + } + return v, nil +} + +func ReadRunStatusFromRow(fieldName string, r map[string]any) 
(RunStatusType, error) { + v, ok := r[fieldName].(int8) + if !ok { + return RunNone, fmt.Errorf("cannot read run status %s from %v", fieldName, r) + } + return RunStatusType(v), nil +} + +func ReadStringFromRow(fieldName string, r map[string]any) (string, error) { + v, ok := r[fieldName].(string) + if !ok { + return v, fmt.Errorf("cannot read string %s from %v", fieldName, r) + } + return v, nil +} + +func ReadNodeBatchStatusFromRow(fieldName string, r map[string]any) (NodeBatchStatusType, error) { + v, ok := r[fieldName].(int8) + if !ok { + return NodeBatchNone, fmt.Errorf("cannot read node/batch status %s from %v", fieldName, r) + } + return NodeBatchStatusType(v), nil +} diff --git a/pkg/xfer/get_file_bytes.go b/pkg/xfer/get_file_bytes.go index 82849e3..b818cb4 100644 --- a/pkg/xfer/get_file_bytes.go +++ b/pkg/xfer/get_file_bytes.go @@ -2,7 +2,6 @@ package xfer import ( "fmt" - "io/ioutil" "net/url" "os" ) @@ -15,9 +14,15 @@ func GetFileBytes(uri string, certPath string, privateKeys map[string]string) ([ var bytes []byte if u.Scheme == UriSchemeFile || len(u.Scheme) == 0 { - bytes, err = ioutil.ReadFile(uri) + bytes, err = os.ReadFile(uri) + if err != nil { + return nil, err + } } else if u.Scheme == UriSchemeHttp || u.Scheme == UriSchemeHttps { bytes, err = readHttpFile(uri, u.Scheme, certPath) + if err != nil { + return nil, err + } } else if u.Scheme == UriSchemeSftp { // When dealing with sftp, we download the *whole* file, then we read all of it dstFile, err := os.CreateTemp("", "capi") @@ -34,17 +39,13 @@ func GetFileBytes(uri string, certPath string, privateKeys map[string]string) ([ defer os.Remove(dstFile.Name()) // Read - bytes, err = ioutil.ReadFile(dstFile.Name()) + bytes, err = os.ReadFile(dstFile.Name()) if err != nil { - err = fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), uri, err.Error()) + return nil, fmt.Errorf("cannot read from file %s downloaded from %s: %s", dstFile.Name(), uri, err.Error()) } } else { return nil, fmt.Errorf("uri scheme %s not supported: %s", u.Scheme, uri) } - if err != nil { - return nil, fmt.Errorf("cannot read input from %s: %s", uri, err.Error()) - } - return bytes, nil } diff --git a/pkg/xfer/http.go b/pkg/xfer/http.go index 2c5ac06..746ba09 100644 --- a/pkg/xfer/http.go +++ b/pkg/xfer/http.go @@ -5,8 +5,8 @@ import ( "crypto/x509" "fmt" "io" - "io/ioutil" "net/http" + "os" "path" "time" ) @@ -19,13 +19,13 @@ const UriSchemeSftp string = "sftp" func GetHttpReadCloser(uri string, scheme string, certDir string) (io.ReadCloser, error) { caCertPool := x509.NewCertPool() if scheme == UriSchemeHttps { - files, err := ioutil.ReadDir(certDir) + files, err := os.ReadDir(certDir) if err != nil { return nil, fmt.Errorf("cannot read ca dir with PEM certs %s: %s", certDir, err.Error()) } for _, f := range files { - caCert, err := ioutil.ReadFile(path.Join(certDir, f.Name())) + caCert, err := os.ReadFile(path.Join(certDir, f.Name())) if err != nil { return nil, fmt.Errorf("cannot read PEM cert %s: %s", f.Name(), err.Error()) } diff --git a/pkg/xfer/sftp.go b/pkg/xfer/sftp.go index 4a9fa95..76b5df9 100644 --- a/pkg/xfer/sftp.go +++ b/pkg/xfer/sftp.go @@ -37,7 +37,7 @@ func parseSftpUri(uri string, privateKeys map[string]string) (*ParsedSftpUri, er privateKeyPath, ok := privateKeys[userName] if !ok { - return nil, fmt.Errorf("username %s in sftp uri %s not found in enviroment configuration", userName, uri) + return nil, fmt.Errorf("username %s in sftp uri %s not found in environment configuration", userName, uri) } hostParts := 
strings.Split(u.Host, ":") @@ -53,14 +53,13 @@ func parseSftpUri(uri string, privateKeys map[string]string) (*ParsedSftpUri, er } func DownloadSftpFile(uri string, privateKeys map[string]string, dstFile *os.File) error { - //parsedUri, err := parseSftpUri(strings.ReplaceAll(uri, "sftp://", ""), privateKeys) parsedUri, err := parseSftpUri(uri, privateKeys) if err != nil { return err } // Assume empty key password "" - sshClientConfig, err := NewSshClientConfig(parsedUri.User, parsedUri.Host, parsedUri.Port, parsedUri.PrivateKeyPath, "") + sshClientConfig, err := NewSshClientConfig(parsedUri.User, parsedUri.PrivateKeyPath, "") if err != nil { return err } @@ -73,13 +72,13 @@ func DownloadSftpFile(uri string, privateKeys map[string]string, dstFile *os.Fil } defer sshClient.Close() - sftp, err := sftp.NewClient(sshClient) + sftpClient, err := sftp.NewClient(sshClient) if err != nil { return fmt.Errorf("cannot create sftp client to %s: %s", uri, err.Error()) } - defer sftp.Close() + defer sftpClient.Close() - srcFile, err := sftp.Open(parsedUri.RemotePath) + srcFile, err := sftpClient.Open(parsedUri.RemotePath) if err != nil { return fmt.Errorf("cannot open target file for sftp download %s: %s", uri, err.Error()) } @@ -100,7 +99,7 @@ func UploadSftpFile(srcPath string, uri string, privateKeys map[string]string) e } // Assume empty key password "" - sshClientConfig, err := NewSshClientConfig(parsedUri.User, parsedUri.Host, parsedUri.Port, parsedUri.PrivateKeyPath, "") + sshClientConfig, err := NewSshClientConfig(parsedUri.User, parsedUri.PrivateKeyPath, "") if err != nil { return err } @@ -113,13 +112,13 @@ func UploadSftpFile(srcPath string, uri string, privateKeys map[string]string) e } defer sshClient.Close() - sftp, err := sftp.NewClient(sshClient) + sftpClient, err := sftp.NewClient(sshClient) if err != nil { return fmt.Errorf("cannot create sftp client to %s: %s", uri, err.Error()) } - defer sftp.Close() + defer sftpClient.Close() - if err := sftp.MkdirAll(filepath.Dir(parsedUri.RemotePath)); err != nil { + if err := sftpClient.MkdirAll(filepath.Dir(parsedUri.RemotePath)); err != nil { return fmt.Errorf("cannot create target dir for %s: %s", uri, err.Error()) } @@ -129,7 +128,7 @@ func UploadSftpFile(srcPath string, uri string, privateKeys map[string]string) e } defer srcFile.Close() - dstFile, err := sftp.Create(parsedUri.RemotePath) + dstFile, err := sftpClient.Create(parsedUri.RemotePath) if err != nil { return fmt.Errorf("cannot create on upload %s: %s", uri, err.Error()) } diff --git a/pkg/xfer/ssh.go b/pkg/xfer/ssh.go index 1becbf3..201adb3 100644 --- a/pkg/xfer/ssh.go +++ b/pkg/xfer/ssh.go @@ -5,7 +5,6 @@ import ( "encoding/pem" "errors" "fmt" - "io/ioutil" "net" "os" "path/filepath" @@ -24,10 +23,10 @@ func signerFromPem(pemBytes []byte, password []byte) (ssh.Signer, error) { return nil, err } - // handle encrypted key - if x509.IsEncryptedPEMBlock(pemBlock) { + // handle key encrypted with password + if x509.IsEncryptedPEMBlock(pemBlock) { //nolint:all // decrypt PEM - pemBlock.Bytes, err = x509.DecryptPEMBlock(pemBlock, []byte(password)) + pemBlock.Bytes, err = x509.DecryptPEMBlock(pemBlock, []byte(password)) //nolint:all if err != nil { return nil, fmt.Errorf("cannot decrypt PEM block %s", err.Error()) } @@ -45,18 +44,18 @@ func signerFromPem(pemBytes []byte, password []byte) (ssh.Signer, error) { } return signer, nil - } else { - // generate signer instance from plain key - signer, err := ssh.ParsePrivateKey(pemBytes) - if err != nil { - return nil, fmt.Errorf("cannot parsie 
plain private key %s", err.Error()) - } + } - return signer, nil + // generate signer instance from plain key + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return nil, fmt.Errorf("cannot parsie plain private key %s", err.Error()) } + + return signer, nil } -func parsePemBlock(block *pem.Block) (interface{}, error) { +func parsePemBlock(block *pem.Block) (any, error) { switch block.Type { case "RSA PRIVATE KEY": key, err := x509.ParsePKCS1PrivateKey(block.Bytes) @@ -84,13 +83,13 @@ func parsePemBlock(block *pem.Block) (interface{}, error) { } } -func NewSshClientConfig(user string, host string, port int, privateKeyPath string, privateKeyPassword string) (*ssh.ClientConfig, error) { +func NewSshClientConfig(user string, privateKeyPath string, privateKeyPassword string) (*ssh.ClientConfig, error) { keyPath := privateKeyPath if strings.HasPrefix(keyPath, "~/") { homeDir, _ := os.UserHomeDir() keyPath = filepath.Join(homeDir, keyPath[2:]) } - pemBytes, err := ioutil.ReadFile(keyPath) + pemBytes, err := os.ReadFile(keyPath) if err != nil { return nil, fmt.Errorf("cannot read private key file %s: %s", keyPath, err.Error()) } diff --git a/test/code/lookup/generate_data.go b/test/code/lookup/generate_data.go index 7eaa62b..b90ded4 100644 --- a/test/code/lookup/generate_data.go +++ b/test/code/lookup/generate_data.go @@ -186,7 +186,7 @@ func shuffleAndSaveInOrders(inOrders []*Order, totalChunks int, basePath string, } if strings.Contains(formats, "parquet") { - if err := parquetWriter.FileWriter.AddData(map[string]interface{}{ + if err := parquetWriter.FileWriter.AddData(map[string]any{ "order_id": item.OrderId, "customer_id": item.CustomerId, "order_status": item.OrderStatus, @@ -300,7 +300,7 @@ func shuffleAndSaveInOrderItems(inOrderItems []*OrderItem, totalChunks int, base item.FreightValue)) } if strings.Contains(formats, "parquet") { - if err := parquetWriter.FileWriter.AddData(map[string]interface{}{ + if err := parquetWriter.FileWriter.AddData(map[string]any{ "order_id": item.OrderId, "order_item_id": item.OrderItemId, "product_id": item.ProductId, @@ -399,7 +399,7 @@ func sortAndSaveNoGroup(items []*NoGroupItem, fileBase string, formats string) { } for _, item := range items { - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "order_id": item.OrderId, "order_item_id": item.OrderItemId, "product_id": item.ProductId, @@ -502,7 +502,7 @@ func sortAndSaveGroup(items []*GroupItem, fileBase string, formats string) { } for _, item := range items { - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "total_value": storage.ParquetWriterDecimal2(item.TotalOrderValue), "order_purchase_timestamp": storage.ParquetWriterMilliTs(item.OrderPurchaseTs), "order_id": item.OrderId, diff --git a/test/code/parquet/capiparquet.go b/test/code/parquet/capiparquet.go index 358b66b..2fced85 100644 --- a/test/code/parquet/capiparquet.go +++ b/test/code/parquet/capiparquet.go @@ -29,7 +29,7 @@ func usage(flagset *flag.FlagSet) { fmt.Printf("Capillaries parquet tool\nUsage: capiparquet \nCommands:\n") fmt.Printf(" %s %s\n %s %s %s %s\n %s %s %s\n", CmdCat, "", - CmdDiff, "", "", "[optional paramaters]", + CmdDiff, "", "", "[optional parameters]", CmdSort, "", "") if flagset != nil { fmt.Printf("\n%s optional parameters:\n", flagset.Name()) @@ -327,7 +327,7 @@ func cat(path string) error { type IndexedRow struct { Key string - Row map[string]interface{} + Row map[string]any } func 
sortFile(path string, idxDef *sc.IdxDef) error { @@ -383,7 +383,7 @@ func sortFile(path string, idxDef *sc.IdxDef) error { return fmt.Errorf("cannot get row %d: %s", rowIdx, err.Error()) } - typedData := map[string]interface{}{} + typedData := map[string]any{} for colIdx, fieldName := range fields { se, _ := schemaElementMap[fieldName] diff --git a/test/code/portfolio/bigtest/generate_bigtest_data.go b/test/code/portfolio/bigtest/generate_bigtest_data.go index b2a0cdb..de8b4fb 100755 --- a/test/code/portfolio/bigtest/generate_bigtest_data.go +++ b/test/code/portfolio/bigtest/generate_bigtest_data.go @@ -113,7 +113,7 @@ func generateHoldings(fileQuickHoldingsPath string, fileInHoldingsPath string, b } } - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "account_id": accId, "d": storage.ParquetWriterMilliTs(d), "ticker": line[2], @@ -227,7 +227,7 @@ func generateTxns(fileQuickTxnsPath string, fileInTxnsPath string, bigAccountsMa } } - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "ts": storage.ParquetWriterMilliTs(d), "account_id": accId, "ticker": line[2], @@ -318,7 +318,7 @@ func generateOutTotals(fileQuickAccountYearPath string, fileOutAccountYearPath s return fmt.Errorf("cannot parse ret '%s' in account_year_perf_baseline: %s", line[3], err.Error()) } - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "ARK fund": accId, "Period": line[1], "Sector": line[2], @@ -391,7 +391,7 @@ func generateOutBySector(fileQuickAccountPeriodSectorPath string, fileOutAccount return fmt.Errorf("cannot parse ret '%s' in account_period_sector_perf_baseline: %s", line[3], err.Error()) } - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "ARK fund": accId, "Period": line[1], "Sector": line[2], @@ -441,7 +441,7 @@ func generateAccounts(fileInAccountsPath string, quickAccountsMap map[string]str if err != nil { return fmt.Errorf("cannot parse account earliest_period_start '%s': %s", eps, err.Error()) } - if err := w.FileWriter.AddData(map[string]interface{}{ + if err := w.FileWriter.AddData(map[string]any{ "account_id": accId, "earliest_period_start": storage.ParquetWriterMilliTs(d), }); err != nil {