diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 7463c9ba091..f9bc9804bfa 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -25,3 +25,5 @@ jobs: run: test -z $(gofmt -l .) - name: Go test run: go test -short -v ./... + env: + RILL_RUNTIME_DRUID_TEST_DSN: ${{ secrets.RILL_RUNTIME_DRUID_TEST_DSN }} diff --git a/runtime/drivers/clickhouse/olap.go b/runtime/drivers/clickhouse/olap.go index 0ac69d45418..3c86192a748 100644 --- a/runtime/drivers/clickhouse/olap.go +++ b/runtime/drivers/clickhouse/olap.go @@ -685,6 +685,8 @@ func databaseTypeToPB(dbt string, nullable bool) (*runtimev1.Type, error) { t.Code = runtimev1.Type_CODE_STRING case "OTHER": t.Code = runtimev1.Type_CODE_JSON + case "NOTHING": + t.Code = runtimev1.Type_CODE_STRING case "POINT": return databaseTypeToPB("Array(Float64)", nullable) case "RING": diff --git a/runtime/drivers/druid/druid.go b/runtime/drivers/druid/druid.go index b57d1622492..850946366cb 100644 --- a/runtime/drivers/druid/druid.go +++ b/runtime/drivers/druid/druid.go @@ -304,6 +304,9 @@ func dsnFromConfig(conf *configProperties) (string, error) { func correctURL(dsn string) (string, error) { u, err := url.Parse(dsn) if err != nil { + if strings.Contains(err.Error(), dsn) { // avoid returning the actual DSN with the password which will be logged + return "", fmt.Errorf("%s", strings.ReplaceAll(err.Error(), dsn, "")) + } return "", err } diff --git a/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go b/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go index c4edf0cca92..cd993ec2471 100644 --- a/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go +++ b/runtime/drivers/druid/druidsqldriver/druid_api_sql_driver.go @@ -12,6 +12,7 @@ import ( "reflect" "regexp" "strconv" + "strings" "time" "github.com/google/uuid" @@ -96,12 +97,18 @@ func (c *sqlConnection) QueryContext(ctx context.Context, query string, args []d req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.dsn, bodyReader) if err != nil { + if strings.Contains(err.Error(), c.dsn) { // avoid returning the actual DSN with the password which will be logged + return nil, retrier.Fail, fmt.Errorf("%s", strings.ReplaceAll(err.Error(), c.dsn, "")) + } return nil, retrier.Fail, err } req.Header.Add("Content-Type", "application/json") resp, err := c.client.Do(req) if err != nil { + if strings.Contains(err.Error(), c.dsn) { // avoid returning the actual DSN with the password which will be logged + return nil, retrier.Fail, fmt.Errorf("%s", strings.ReplaceAll(err.Error(), c.dsn, "")) + } return nil, retrier.Fail, err } diff --git a/runtime/queries/metricsview_aggregation_test.go b/runtime/queries/metricsview_aggregation_test.go index b0e91c68af2..86e71262089 100644 --- a/runtime/queries/metricsview_aggregation_test.go +++ b/runtime/queries/metricsview_aggregation_test.go @@ -2666,7 +2666,6 @@ func TestMetricsViewsAggregation_Druid_comparison_no_time_dim(t *testing.T) { Name: "measure_1", }, }, - TimeRange: &runtimev1.TimeRange{ Start: timestamppb.New(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), End: timestamppb.New(time.Date(2022, 1, 3, 0, 0, 0, 0, time.UTC)), @@ -2675,7 +2674,8 @@ func TestMetricsViewsAggregation_Druid_comparison_no_time_dim(t *testing.T) { Start: timestamppb.New(time.Date(2022, 1, 3, 0, 0, 0, 0, time.UTC)), End: timestamppb.New(time.Date(2022, 1, 5, 0, 0, 0, 0, time.UTC)), }, - Limit: &limit, + Limit: &limit, + SecurityClaims: testClaims(), } err = q.Resolve(context.Background(), rt, instanceID, 0) require.NoError(t, err) diff --git a/runtime/queries/metricsview_comparison_toplist_test.go b/runtime/queries/metricsview_comparison_toplist_test.go index aed31044bb4..b9ae9ef6b73 100644 --- a/runtime/queries/metricsview_comparison_toplist_test.go +++ b/runtime/queries/metricsview_comparison_toplist_test.go @@ -222,7 +222,8 @@ func TestMetricsViewsComparison_Druid_dim_order(t *testing.T) { Desc: true, }, }, - Limit: 250, + Limit: 250, + SecurityClaims: testClaims(), } err = q.Resolve(context.Background(), rt, instanceID, 0) diff --git a/runtime/resolver.go b/runtime/resolver.go index e701eece8c9..01bf190fd02 100644 --- a/runtime/resolver.go +++ b/runtime/resolver.go @@ -264,6 +264,9 @@ func (r *driverResolverResult) MarshalJSON() ([]byte, error) { if r.rows.Err() != nil { return nil, r.rows.Err() } + if out == nil { // fixes 'null' output when there are no rows + out = []map[string]any{} + } return json.Marshal(out) } diff --git a/runtime/resolvers/resolvers_test.go b/runtime/resolvers/resolvers_test.go new file mode 100644 index 00000000000..dda237419d4 --- /dev/null +++ b/runtime/resolvers/resolvers_test.go @@ -0,0 +1,190 @@ +package resolvers + +import ( + "bytes" + "context" + "encoding/csv" + "encoding/json" + "flag" + "maps" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "testing" + + "github.com/rilldata/rill/runtime" + "github.com/rilldata/rill/runtime/testruntime" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +type Resolvers struct { + Project map[string]yaml.Node + Connectors map[string]*testruntime.InstanceOptionsForResolvers + Tests map[string]*Test +} + +type Test struct { + Options struct { + InstanceID string + Resolver string + ResolverProperties map[string]any "yaml:\"resolver_properties\"" + Args map[string]any + Claims struct { + UserAttributes map[string]any "yaml:\"user_attributes\"" + } + } + Result []map[string]any + CSVResult string "yaml:\"csv_result\"" + ErrorContains string "yaml:\"error_contains\"" +} + +var update = flag.Bool("update", false, "Update test results") + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(m.Run()) +} + +func TestResolvers(t *testing.T) { + files, err := filepath.Glob("./testdata/*_resolvers_test.yaml") + require.NoError(t, err) + for _, f := range files { + t.Log("Running with", f) + yamlFile, err := os.ReadFile(f) + require.NoError(t, err) + var r Resolvers + err = yaml.Unmarshal(yamlFile, &r) + require.NoError(t, err) + + files := make(map[string]string) + for name, node := range r.Project { + bytes, err := yaml.Marshal(&node) + require.NoError(t, err) + files[name] = string(bytes) + } + + for connector, opts := range r.Connectors { + t.Log("Running with", connector) + if opts == nil { + opts = &testruntime.InstanceOptionsForResolvers{} + } + if opts.Files == nil { + opts.Files = map[string]string{"rill.yaml": ""} + } + + switch connector { + case "druid": + opts.OLAPDriver = "druid" + case "clickhouse": + opts.OLAPDriver = "clickhouse" + } + + maps.Copy(opts.Files, files) + rt, instanceID := testruntime.NewInstanceForResolvers(t, *opts) + for testName, test := range r.Tests { + t.Run(testName, func(t *testing.T) { + t.Log("======================") + t.Log("Running ", testName, "with", f, "and", connector) + testruntime.RequireReconcileState(t, rt, instanceID, -1, 0, 0) + + ropts := test.Options + ro := &runtime.ResolveOptions{} + ro.InstanceID = instanceID + ro.Resolver = ropts.Resolver + ro.ResolverProperties = ropts.ResolverProperties + ro.Args = ropts.Args + ro.Claims = &runtime.SecurityClaims{ + UserAttributes: ropts.Claims.UserAttributes, + } + res, err := rt.Resolve(context.Background(), ro) + if test.ErrorContains != "" { + if *update { + require.Error(t, err) + test.ErrorContains = err.Error() + } else { + require.ErrorContains(t, err, test.ErrorContains) + } + return + } else { + require.NoError(t, err) + } + var rows []map[string]interface{} + b, err := res.MarshalJSON() + require.NoError(t, err) + require.NoError(t, json.Unmarshal(b, &rows), string(b)) + if *update { + test.Result = rows + for _, m := range test.Result { + for k, v := range m { + node := yaml.Node{} + node.Kind = yaml.ScalarNode + switch val := v.(type) { + case float32: + node.Value = strconv.FormatFloat(float64(val), 'f', 2, 32) + m[k] = &node + case float64: + node.Value = strconv.FormatFloat(val, 'f', 2, 64) + m[k] = &node + } + } + } + } else { + expected := test.Result + if test.CSVResult != "" { + expected = readCSV(t, test.CSVResult) + } + require.Equal(t, expected, rows) + } + t.Log("======================") + }) + } + if *update { + buf := bytes.Buffer{} + yamlEncoder := yaml.NewEncoder(&buf) + yamlEncoder.SetIndent(2) + err := yamlEncoder.Encode(r) + require.NoError(t, err) + require.NoError(t, os.WriteFile(f, buf.Bytes(), 0644)) + } + } + } +} + +func readCSV(t *testing.T, in string) []map[string]any { + var digitCheck = regexp.MustCompile(`^[0-9]+$`) + var numericCheck = regexp.MustCompile(`^[0-9\.]+$`) + + r := csv.NewReader(strings.NewReader(in)) + records, err := r.ReadAll() + require.NoError(t, err) + + rows := make([]map[string]any, 0, len(records)) + headers := records[0] + for i := 1; i < len(records); i++ { + m := make(map[string]any, len(headers)) + for j, h := range headers { + str := records[i][j] + + if str == "" { + m[h] = nil + continue + } + if digitCheck.MatchString(str) { + num, err := strconv.Atoi(str) + require.NoError(t, err) + m[h] = num + } else if numericCheck.MatchString(str) { + num, err := strconv.ParseFloat(str, 64) + require.NoError(t, err) + m[h] = num + } else { + m[h] = records[i][j] + } + } + rows = append(rows, m) + } + return rows +} diff --git a/runtime/resolvers/testdata/clickhouse_resolvers_test.yaml b/runtime/resolvers/testdata/clickhouse_resolvers_test.yaml new file mode 100644 index 00000000000..460bb0b3d6e --- /dev/null +++ b/runtime/resolvers/testdata/clickhouse_resolvers_test.yaml @@ -0,0 +1,88 @@ +project: + sources: {} + "models/ad_bids_mini.yaml": + type: model + sql: SELECT * FROM url('https://raw.githubusercontent.com/rilldata/rill/main/runtime/testruntime/testdata/ad_bids/data/AdBids_mini.csv', CSV) + output: + columns: (id UInt32,timestamp DateTime64,publisher varchar,domain varchar,bid_price Float32,volume UInt8,impressions UInt8,"ad words" varchar,clicks Float32,device varchar) + materialize: true + incremental_strategy: append + "dashboards/ad_bids_mini_metrics_with_policy.yaml": + model: ad_bids_mini + display_name: Ad bids + description: + timeseries: timestamp + smallest_time_grain: "" + dimensions: + - label: Publisher + name: publisher + expression: publisher + description: "" + - label: Domain + property: domain + description: "" + measures: + - label: "Number of bids" + name: bid's number + expression: count(*) + - label: "Total volume" + name: total volume + expression: sum(volume) + - label: "Total impressions" + name: total impressions + expression: sum(impressions) + - label: "Total clicks" + name: total click"s + expression: sum(clicks) + security: + access: true + row_filter: "domain = '{{ .user.domain }}'" + exclude: + - if: "'{{ .user.domain }}' != 'msn.com'" + names: + - total volume +connectors: + clickhouse: null +tests: + empty: + resolver: mv_sql_policy_api + options: + resolver: "metrics_sql" + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\"\nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: google.com + email: user@google.com + result: [] + msn: + resolver: mv_sql_policy_api + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\"\nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: msn.com + email: user@msn.com + result: + - domain: msn.com + publisher: "" + total impressions: 3.00 + simple: + resolver: mv_sql_policy_api + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\"\nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: yahoo.com + email: user@yahoo.com + result: + - domain: yahoo.com + publisher: Yahoo + total impressions: 3.00 diff --git a/runtime/resolvers/testdata/druid_resolvers_test.yaml b/runtime/resolvers/testdata/druid_resolvers_test.yaml new file mode 100644 index 00000000000..6774ffd0ce6 --- /dev/null +++ b/runtime/resolvers/testdata/druid_resolvers_test.yaml @@ -0,0 +1,49 @@ +project: + sources: {} + models: {} + "dashboards/ad_bids_mini_metrics_with_policy.yaml": + model: AdBids + display_name: Ad bids + description: + timeseries: __time + smallest_time_grain: "" + dimensions: + - label: Publisher + name: publisher + expression: publisher + description: "" + - label: Domain + property: domain + description: "" + measures: + - label: "Number of bids" + name: bid's number + expression: count(*) + - label: "Max bid price" + name: max bid price + expression: max(bid_price) + - label: "min bid price" + name: min bid price + expression: min(bid_price) + security: + access: true + # row_filter: "domain = '{{ .user.domain }}'" is not supported in Druid + # exclude: is not supported in Druid +connectors: + druid: null +tests: + simple: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"min bid price\",\n \"max bid price\"\nFROM \n ad_bids_mini_metrics_with_policy \nWHERE\n publisher is not null AND domain = 'news.yahoo.com'\nORDER BY \n publisher,\n domain\nLIMIT 1\n" + args: {} + claims: + user_attributes: + domain: yahoo.com + email: user@yahoo.com + result: + - domain: news.yahoo.com + max bid price: 6.00 + min bid price: 1.00 + publisher: Yahoo diff --git a/runtime/resolvers/testdata/duckdb_resolvers_test.yaml b/runtime/resolvers/testdata/duckdb_resolvers_test.yaml new file mode 100644 index 00000000000..26e4a9665f5 --- /dev/null +++ b/runtime/resolvers/testdata/duckdb_resolvers_test.yaml @@ -0,0 +1,122 @@ +project: + "sources/ad_bids_mini_source.yaml": + connector: https + path: https://raw.githubusercontent.com/rilldata/rill/main/runtime/testruntime/testdata/ad_bids/data/AdBids_mini.csv + "models/ad_bids_mini.yaml": + sql: | + select + id, + timestamp, + publisher, + domain, + volume, + impressions, + clicks + from ad_bids_mini_source + "dashboards/ad_bids_mini_metrics_with_policy.yaml": + model: ad_bids_mini + display_name: Ad bids + description: + timeseries: timestamp + smallest_time_grain: "" + dimensions: + - label: Publisher + name: publisher + expression: upper(publisher) + description: "" + - label: Domain + property: domain + description: "" + measures: + - label: "Number of bids" + name: bid's number + expression: count(*) + - label: "Total volume" + name: total volume + expression: sum(volume) + - label: "Total impressions" + name: total impressions + expression: sum(impressions) + - label: "Total clicks" + name: total click"s + expression: sum(clicks) + security: + access: true + row_filter: "domain = '{{ .user.domain }}'" + exclude: + - if: "'{{ .user.domain }}' != 'msn.com'" + names: + - total volume +connectors: + duckdb: null +tests: + csv: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\" \nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: msn.com + email: user@msn.com + csv_result: | + domain,publisher,total impressions + msn.com,,3.00 + empty: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\" \nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: google.com + email: user@google.com + result: [] + msn: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\" \nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: + domain: msn.com + email: user@msn.com + result: + - domain: msn.com + publisher: null + total impressions: 3.00 + sql: + options: + resolver: sql + resolver_properties: + sql: "select \n publisher,\n domain \n \nFROM \n ad_bids_mini where publisher = 'Yahoo' limit 1\n" + args: {} + claims: + user_attributes: + domain: msn.com + email: user@msn.com + result: + - publisher: Yahoo + domain: yahoo.com + simple: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n domain, \n \"total impressions\" \nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: {} + result: [] + error: + options: + resolver: metrics_sql + resolver_properties: + sql: "select \n publisher,\n dom1, \n \"total impressions\" \nFROM \n ad_bids_mini_metrics_with_policy \n" + args: {} + claims: + user_attributes: {} + error_contains: "selected column `dom1` not found" + diff --git a/runtime/testruntime/reconcile.go b/runtime/testruntime/reconcile.go index 24d4a1dc01a..9d6badd69ac 100644 --- a/runtime/testruntime/reconcile.go +++ b/runtime/testruntime/reconcile.go @@ -118,7 +118,9 @@ func RequireReconcileState(t testing.TB, rt *runtime.Runtime, id string, lenReso require.Equal(t, lenParseErrs, len(parseErrs), "parse errors: %s", strings.Join(parseErrs, "\n")) require.Equal(t, lenReconcileErrs, len(reconcileErrs), "reconcile errors: %s", strings.Join(reconcileErrs, "\n")) - require.Equal(t, lenResources, len(rs), "resources: %s", strings.Join(names, "\n")) + if lenResources != -1 { + require.Equal(t, lenResources, len(rs), "resources: %s", strings.Join(names, "\n")) + } } func RequireResource(t testing.TB, rt *runtime.Runtime, id string, a *runtimev1.Resource) { @@ -211,7 +213,7 @@ func RequireParseErrors(t testing.TB, rt *runtime.Runtime, id string, expectedPa for _, pe := range pp.GetProjectParser().State.ParseErrors { parseErrs[pe.FilePath] = pe.Message } - require.Len(t, parseErrs, len(expectedParseErrors)) + require.Len(t, parseErrs, len(expectedParseErrors), "Should have %d parse errors", len(expectedParseErrors)) for f, pe := range parseErrs { // Checking parseError using Contains instead of Equal diff --git a/runtime/testruntime/testdata/users.xml b/runtime/testruntime/testdata/users.xml new file mode 100644 index 00000000000..215b82f7599 --- /dev/null +++ b/runtime/testruntime/testdata/users.xml @@ -0,0 +1,28 @@ + + + + + best_effort + + + + + + + + + + + + + 3600 + + 0 + 0 + 0 + 0 + 0 + + + + \ No newline at end of file diff --git a/runtime/testruntime/testruntime.go b/runtime/testruntime/testruntime.go index 7adac70ec95..111c6cb5a7c 100644 --- a/runtime/testruntime/testruntime.go +++ b/runtime/testruntime/testruntime.go @@ -18,6 +18,8 @@ import ( "github.com/rilldata/rill/runtime/pkg/activity" "github.com/rilldata/rill/runtime/pkg/email" "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/clickhouse" "go.uber.org/zap" // Load database drivers for testing. @@ -88,6 +90,12 @@ type InstanceOptions struct { StageChanges bool } +type InstanceOptionsForResolvers struct { + InstanceOptions + OLAPDriver string + OLAPDSN string +} + // NewInstanceWithOptions creates a runtime and an instance for use in tests. // The instance's repo is a temp directory that will be cleared when the tests finish. func NewInstanceWithOptions(t TestingT, opts InstanceOptions) (*runtime.Runtime, string) { @@ -157,6 +165,105 @@ func NewInstanceWithOptions(t TestingT, opts InstanceOptions) (*runtime.Runtime, return rt, inst.ID } +func NewInstanceForResolvers(t *testing.T, opts InstanceOptionsForResolvers) (*runtime.Runtime, string) { + rt := New(t) + + if opts.OLAPDriver == "" { + opts.OLAPDriver = "duckdb" + opts.OLAPDSN = ":memory:" + } + + tmpDir := t.TempDir() + + switch opts.OLAPDriver { + case "clickhouse": + ctx := context.Background() + clickHouseContainer, err := clickhouse.Run( + ctx, + "clickhouse/clickhouse-server:24.6.2.17", + clickhouse.WithUsername("clickhouse"), + clickhouse.WithPassword("clickhouse"), + clickhouse.WithConfigFile("../testruntime/testdata/clickhouse-config.xml"), + withUsersConfig("../testruntime/testdata/users.xml"), + ) + require.NoError(t, err) + t.Cleanup(func() { + err := clickHouseContainer.Terminate(ctx) + require.NoError(t, err) + }) + + host, err := clickHouseContainer.Host(ctx) + require.NoError(t, err) + port, err := clickHouseContainer.MappedPort(ctx, "9000/tcp") + require.NoError(t, err) + + opts.OLAPDSN = fmt.Sprintf("clickhouse://clickhouse:clickhouse@%v:%v", host, port.Port()) + case "druid": + _, currentFile, _, _ := goruntime.Caller(0) + envPath := filepath.Join(currentFile, "..", "..", "..", ".env") + _, err := os.Stat(envPath) + if err == nil { // avoid .env in CI environment + require.NoError(t, godotenv.Load(envPath)) + } + + opts.OLAPDSN = os.Getenv("RILL_RUNTIME_DRUID_TEST_DSN") + require.NotEqual(t, "", opts.OLAPDSN) + } + + vars := make(map[string]string) + maps.Copy(vars, opts.Variables) + vars["rill.stage_changes"] = strconv.FormatBool(opts.StageChanges) + + inst := &drivers.Instance{ + Environment: "test", + OLAPConnector: opts.OLAPDriver, + RepoConnector: "repo", + CatalogConnector: "catalog", + Connectors: []*runtimev1.Connector{ + { + Type: "file", + Name: "repo", + Config: map[string]string{"dsn": tmpDir}, + }, + { + Type: opts.OLAPDriver, + Name: opts.OLAPDriver, + Config: map[string]string{"dsn": opts.OLAPDSN}, + }, + { + Type: "sqlite", + Name: "catalog", + // Setting a test-specific name ensures a unique connection when "cache=shared" is enabled. + // "cache=shared" is needed to prevent threading problems. + Config: map[string]string{"dsn": fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())}, + }, + }, + Variables: vars, + WatchRepo: opts.WatchRepo, + } + + for path, data := range opts.Files { + abs := filepath.Join(tmpDir, path) + require.NoError(t, os.MkdirAll(filepath.Dir(abs), os.ModePerm)) + require.NoError(t, os.WriteFile(abs, []byte(data), 0o644)) + } + + err := rt.CreateInstance(context.Background(), inst) + require.NoError(t, err) + require.NotEmpty(t, inst.ID) + + ctrl, err := rt.Controller(context.Background(), inst.ID) + require.NoError(t, err) + + _, err = ctrl.Get(context.Background(), runtime.GlobalProjectParserName, false) + require.NoError(t, err) + + err = ctrl.WaitUntilIdle(context.Background(), opts.WatchRepo) + require.NoError(t, err) + + return rt, inst.ID +} + // NewInstance is a convenience wrapper around NewInstanceWithOptions, using defaults sensible for most tests. func NewInstance(t TestingT) (*runtime.Runtime, string) { return NewInstanceWithOptions(t, InstanceOptions{ @@ -185,9 +292,22 @@ func NewInstanceForProject(t TestingT, name string) (*runtime.Runtime, string) { _, currentFile, _, _ := goruntime.Caller(0) projectPath := filepath.Join(currentFile, "..", "testdata", name) + olapDriver := os.Getenv("RILL_RUNTIME_TEST_OLAP_DRIVER") // todo: refactor a couple of tests that use envs + if olapDriver == "" { + olapDriver = "duckdb" + } + olapDSN := os.Getenv("RILL_RUNTIME_TEST_OLAP_DSN") + if olapDSN == "" { + olapDSN = ":memory:" + } + embedCatalog := true + if olapDriver == "clickhouse" { + embedCatalog = false + } + inst := &drivers.Instance{ Environment: "test", - OLAPConnector: "duckdb", + OLAPConnector: olapDriver, RepoConnector: "repo", CatalogConnector: "catalog", Connectors: []*runtimev1.Connector{ @@ -197,9 +317,9 @@ func NewInstanceForProject(t TestingT, name string) (*runtime.Runtime, string) { Config: map[string]string{"dsn": projectPath}, }, { - Type: "duckdb", - Name: "duckdb", - Config: map[string]string{"dsn": ":memory:"}, + Type: olapDriver, + Name: olapDriver, + Config: map[string]string{"dsn": olapDSN}, }, { Type: "sqlite", @@ -209,7 +329,7 @@ func NewInstanceForProject(t TestingT, name string) (*runtime.Runtime, string) { Config: map[string]string{"dsn": fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())}, }, }, - EmbedCatalog: true, + EmbedCatalog: embedCatalog, } err := rt.CreateInstance(context.Background(), inst) @@ -288,3 +408,16 @@ func NewInstanceForDruidProject(t *testing.T) (*runtime.Runtime, string, error) return rt, inst.ID, nil } + +func withUsersConfig(configFile string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + cf := testcontainers.ContainerFile{ + HostFilePath: configFile, + ContainerFilePath: "/etc/clickhouse-server/users.xml", + FileMode: 0o755, + } + req.Files = append(req.Files, cf) + + return nil + } +}