From 6978c42b304bd8c5429c88bcd7d6ed20ac3fd98a Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Tue, 9 May 2023 11:13:53 -0700 Subject: [PATCH] feat: masking policy in v2 sdk (#1777) * masking policy in v2 sdk * add more tests and docs * add back in region check * linting * linting * linting * linting * fix unit test * update docs * address comments * address comments * remove tag ddl type * Update docs/index.md Co-authored-by: Nathan Gaberel * Update templates/index.md.tmpl Co-authored-by: Nathan Gaberel * fix PR comments * chore: misc linting errors (#1779) * fix misc linting errors * linting fixes * int tests * cleanup env vars * fix int test for masking policy tag * fix linting errors * remove useless test * skip email notification test * fix int test * fix int test * fix int test * adjust timeout * update docs * disable-if-return * add in temp warehouse * if-return revert * update-docs --------- Co-authored-by: Nathan Gaberel --- .golangci.yml | 5 +- CONTRIBUTING.md | 14 +- Makefile | 2 +- docs/index.md | 48 +- .../email_notification_integration.md | 48 ++ docs/resources/masking_policy.md | 45 +- .../tag_masking_policy_association.md | 4 +- examples/provider/provider.tf | 12 +- .../resource.tf | 3 +- .../snowflake_masking_policy/resource.tf | 25 +- go.mod | 3 + go.sum | 7 + pkg/datasources/masking_policies.go | 55 ++- .../masking_policies_acceptance_test.go | 7 +- pkg/datasources/role.go | 1 - pkg/helpers/helpers.go | 15 +- pkg/helpers/random.go | 27 ++ pkg/provider/provider.go | 84 +++- pkg/provider/provider_test.go | 53 ++- pkg/resources/account.go | 2 +- pkg/resources/account_grant.go | 4 +- pkg/resources/alert.go | 5 +- pkg/resources/database_grant.go | 4 +- pkg/resources/database_role.go | 1 - .../database_role_acceptance_test.go | 2 +- ...otification_integration_acceptance_test.go | 1 + pkg/resources/external_table.go | 1 - pkg/resources/external_table_grant.go | 4 +- pkg/resources/failover_group.go | 27 +- pkg/resources/file_format.go | 1 - pkg/resources/file_format_grant.go | 4 +- pkg/resources/function_grant.go | 4 +- pkg/resources/grant_helpers.go | 3 +- pkg/resources/helpers.go | 25 ++ pkg/resources/helpers_test.go | 18 - pkg/resources/integration_grant.go | 4 +- pkg/resources/masking_policy.go | 355 ++++++++------- .../masking_policy_acceptance_test.go | 101 ++++- pkg/resources/masking_policy_grant.go | 4 +- .../masking_policy_grant_acceptance_test.go | 11 +- pkg/resources/masking_policy_test.go | 84 ---- pkg/resources/materialized_view_grant.go | 4 +- pkg/resources/network_policy_attachment.go | 7 +- pkg/resources/object_parameter.go | 1 - pkg/resources/password_policy.go | 82 ++-- pkg/resources/pipe.go | 15 +- pkg/resources/pipe_grant.go | 4 +- pkg/resources/procedure.go | 4 +- pkg/resources/procedure_grant.go | 4 +- pkg/resources/resource_monitor.go | 1 - pkg/resources/resource_monitor_grant.go | 4 +- pkg/resources/role.go | 4 +- pkg/resources/role_grants.go | 4 +- pkg/resources/row_access_policy_grant.go | 4 +- pkg/resources/saml_integration.go | 6 +- pkg/resources/schema_grant.go | 4 +- pkg/resources/sequence.go | 1 - pkg/resources/sequence_grant.go | 4 +- pkg/resources/stage.go | 1 - pkg/resources/stage_grant.go | 4 +- pkg/resources/stream.go | 8 +- pkg/resources/stream_grant.go | 4 +- pkg/resources/table.go | 4 +- ...king_policy_application_acceptance_test.go | 7 +- pkg/resources/table_constraint.go | 6 +- pkg/resources/table_grant.go | 10 +- pkg/resources/tag.go | 11 +- pkg/resources/tag_association.go | 5 +- pkg/resources/tag_grant.go | 4 +- .../tag_masking_policy_association.go | 57 ++- ...king_policy_association_acceptance_test.go | 7 +- .../tag_masking_policy_association_test.go | 94 ---- pkg/resources/task.go | 75 ++-- pkg/resources/task_grant.go | 4 +- pkg/resources/user_grant.go | 4 +- pkg/resources/view.go | 12 +- pkg/resources/view_grant.go | 4 +- pkg/resources/warehouse_grant.go | 4 +- pkg/sdk/client.go | 40 +- pkg/sdk/client_integration_test.go | 67 +++ pkg/sdk/client_test.go | 111 ----- pkg/sdk/common_types.go | 11 + pkg/sdk/config.go | 127 ++++++ pkg/sdk/config_test.go | 121 +++++ pkg/sdk/context_functions.go | 57 ++- pkg/sdk/context_functions_integration_test.go | 40 ++ pkg/sdk/data_types.go | 36 +- pkg/sdk/data_types_test.go | 80 ++++ pkg/sdk/databases.go | 12 + pkg/sdk/errors.go | 1 + pkg/sdk/helper_test.go | 195 +++++++- pkg/sdk/identifier_helpers.go | 109 +++-- pkg/sdk/masking_policy.go | 359 +++++++++++++++ pkg/sdk/masking_policy_integration_test.go | 421 ++++++++++++++++++ pkg/sdk/masking_policy_test.go | 261 +++++++++++ pkg/sdk/password_policy.go | 26 +- pkg/sdk/password_policy_integration_test.go | 42 +- pkg/sdk/password_policy_test.go | 28 +- pkg/sdk/schemas.go | 11 + pkg/sdk/sessions.go | 38 ++ pkg/sdk/sessions_integration_test.go | 48 ++ pkg/sdk/sql_builder.go | 193 ++++++-- pkg/sdk/sql_builder_test.go | 47 +- pkg/sdk/system_functions.go | 29 ++ pkg/sdk/system_functions_integration_test.go | 52 +++ pkg/sdk/tags.go | 14 + pkg/sdk/validations.go | 6 + pkg/sdk/validations_test.go | 19 + pkg/sdk/warehouses.go | 82 ++++ pkg/snowflake/all_grant.go | 2 +- pkg/snowflake/future_grant.go | 2 +- pkg/snowflake/masking_policy.go | 123 +---- pkg/snowflake/masking_policy_test.go | 80 ---- pkg/snowflake/parser.go | 6 +- pkg/snowflake/role_ownership_grant_test.go | 12 +- pkg/snowflake/user_ownership_grant_test.go | 12 +- pkg/validation/validation.go | 2 +- templates/index.md.tmpl | 28 +- 118 files changed, 3336 insertions(+), 1215 deletions(-) create mode 100644 docs/resources/email_notification_integration.md create mode 100644 pkg/helpers/random.go delete mode 100644 pkg/resources/masking_policy_test.go delete mode 100644 pkg/resources/tag_masking_policy_association_test.go create mode 100644 pkg/sdk/client_integration_test.go delete mode 100644 pkg/sdk/client_test.go create mode 100644 pkg/sdk/config.go create mode 100644 pkg/sdk/config_test.go create mode 100644 pkg/sdk/data_types_test.go create mode 100644 pkg/sdk/databases.go create mode 100644 pkg/sdk/masking_policy.go create mode 100644 pkg/sdk/masking_policy_integration_test.go create mode 100644 pkg/sdk/masking_policy_test.go create mode 100644 pkg/sdk/schemas.go create mode 100644 pkg/sdk/sessions.go create mode 100644 pkg/sdk/sessions_integration_test.go create mode 100644 pkg/sdk/system_functions.go create mode 100644 pkg/sdk/system_functions_integration_test.go create mode 100644 pkg/sdk/tags.go create mode 100644 pkg/sdk/validations.go create mode 100644 pkg/sdk/validations_test.go create mode 100644 pkg/sdk/warehouses.go delete mode 100644 pkg/snowflake/masking_policy_test.go diff --git a/.golangci.yml b/.golangci.yml index 817e2ae80f..6f0b07cead 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,7 +11,10 @@ linters-settings: packages: - github.com/pkg/error - io/ioutil - + revive: + rules: + - name: if-return + disabled: true linters: disable-all: true enable: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 88345c1df9..bc4fdc2daa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,6 +30,18 @@ export SNOWFLAKE_REGION=us-west-2 export TF_ACC=true ``` +You can also read the config from a `~/.snowflake/config` file, although you will still need to set `TF_ACC` to true. + + +~/.snowflake/config +```sh +[default] +account='TESTACCOUNT' +user='TEST_USER' +password='hunter2' +role='ACCOUNTADMIN' +``` + **Note: PRs for new resources will not be accepted without passing acceptance tests.** For the Terraform resources, there are 3 levels of testing - internal, unit and acceptance tests. @@ -86,4 +98,4 @@ Releases will be performed as needed, typically once every 1-2 weeks. If your ch Releases are done by [goreleaser](https://goreleaser.com/) and run by our make files. There two goreleaser configs, `.goreleaser.yml` for regular releases and `.goreleaser.prerelease.yml` for doing prereleases (for testing). -Releases are [published to the terraform registry](https://registry.terraform.io/providers/chanzuckerberg/snowflake/latest), which requires that releases by signed. \ No newline at end of file +Releases are [published to the terraform registry](https://registry.terraform.io/providers/chanzuckerberg/snowflake/latest), which requires that releases by signed. diff --git a/Makefile b/Makefile index 983784e341..bd83f677e1 100644 --- a/Makefile +++ b/Makefile @@ -62,7 +62,7 @@ test: ## run the tests (except sdk tests) .PHONY: test test-acceptance: ## runs all tests, including the acceptance tests which create and destroys real resources - SKIP_MANAGED_ACCOUNT_TEST=1 TF_ACC=1 $(go_test) -v -coverprofile=coverage.txt -covermode=atomic $(TESTARGS) ./... + SKIP_MANAGED_ACCOUNT_TEST=1 SKIP_EMAIL_INTEGRATION_TESTS=1 TF_ACC=1 $(go_test) -timeout 900s -v -coverprofile=coverage.txt -covermode=atomic $(TESTARGS) ./... .PHONY: test-acceptance deps: diff --git a/docs/index.md b/docs/index.md index cbfc4c3d4f..b62913541f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -13,11 +13,8 @@ Coverage is focused on part of Snowflake related to access control. ```terraform provider "snowflake" { - // required - username = "..." - account = "..." # the Snowflake account identifier - - // optional, exactly one must be set + account = "..." # required if not using profile. Can also be set via SNOWFLAKE_ACCOUNT env var + username = "..." # required if not using profile or token. Can also be set via SNOWFLAKE_USER env var password = "..." oauth_access_token = "..." private_key_path = "..." @@ -35,6 +32,11 @@ provider "snowflake" { host = "..." warehouse = "..." } + + +provider snowflake { + profile = "securityadmin" +} ``` ## Configuration Schema @@ -44,13 +46,9 @@ provider "snowflake" { ## Schema -### Required - -- `account` (String) The name of the Snowflake account. Can also come from the `SNOWFLAKE_ACCOUNT` environment variable. -- `username` (String) Username for username+password authentication. Can come from the `SNOWFLAKE_USER` environment variable. - ### Optional +- `account` (String) The name of the Snowflake account. Can also come from the `SNOWFLAKE_ACCOUNT` environment variable. Required unless using profile. - `browser_auth` (Boolean) Required when `oauth_refresh_token` is used. Can be sourced from `SNOWFLAKE_USE_BROWSER_AUTH` environment variable. - `host` (String) Supports passing in a custom host value to the snowflake go driver for use with privatelink. - `insecure_mode` (Boolean) If true, bypass the Online Certificate Status Protocol (OCSP) certificate revocation check. IMPORTANT: Change the default value for testing or emergency situations only. @@ -65,9 +63,11 @@ provider "snowflake" { - `private_key` (String, Sensitive) Private Key for username+private-key auth. Cannot be used with `browser_auth` or `password`. Can be sourced from `SNOWFLAKE_PRIVATE_KEY` environment variable. - `private_key_passphrase` (String, Sensitive) Supports the encryption ciphers aes-128-cbc, aes-128-gcm, aes-192-cbc, aes-192-gcm, aes-256-cbc, aes-256-gcm, and des-ede3-cbc - `private_key_path` (String, Sensitive) Path to a private key for using keypair authentication. Cannot be used with `browser_auth`, `oauth_access_token` or `password`. Can be sourced from `SNOWFLAKE_PRIVATE_KEY_PATH` environment variable. +- `profile` (String) Sets the profile to read from ~/.snowflake/config file. - `protocol` (String) Support custom protocols to snowflake go driver. Can be sourced from `SNOWFLAKE_PROTOCOL` environment variable. - `region` (String) [Snowflake region](https://docs.snowflake.com/en/user-guide/intro-regions.html) to use. Required if using the [legacy format for the `account` identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#format-2-legacy-account-locator-in-a-region) in the form of `.`. Can be sourced from the `SNOWFLAKE_REGION` environment variable. - `role` (String) Snowflake role to use for operations. If left unset, default role for user will be used. Can be sourced from the `SNOWFLAKE_ROLE` environment variable. +- `username` (String) Username for username+password authentication. Can come from the `SNOWFLAKE_USER` environment variable. Required unless using profile. - `warehouse` (String) Sets the default warehouse. Optional. Can be sourced from SNOWFLAKE_WAREHOUSE environment variable. ## Authentication @@ -79,8 +79,9 @@ The Snowflake provider support multiple ways to authenticate: * OAuth Refresh Token * Browser Auth * Private Key +* Config File -In all cases account and region are required. +In all cases account and username are required. ### Keypair Authentication Environment Variables @@ -156,3 +157,28 @@ If you choose to use Username and Password Authentication, export these credenti export SNOWFLAKE_USER='...' export SNOWFLAKE_PASSWORD='...' ``` + +### Config File + +If you choose to use a config file, the optional `profile` attribute specifies the profile to use from the config file. If no profile is specified, the default profile is used. The Snowflake config file lives at `~/.snowflake/config` and uses [TOML](https://toml.io/) format. You can override this location by setting the `SNOWFLAKE_CONFIG_PATH` environment variable. If no username and account are specified, the provider will fall back to reading the config file. + +```shell +[default] +account='TESTACCOUNT' +user='TEST_USER' +password='hunter2' +role='ACCOUNTADMIN' + +[securityadmin] +account='TESTACCOUNT' +user='TEST_USER' +password='hunter2' +role='SECURITYADMIN' +``` + +## Order Precedence + +The Snowflake provider will use the following order of precedence when determining which credentials to use: +1) Provider Configuration +2) Environment Variables +3) Config File diff --git a/docs/resources/email_notification_integration.md b/docs/resources/email_notification_integration.md new file mode 100644 index 0000000000..2dd4e78d6d --- /dev/null +++ b/docs/resources/email_notification_integration.md @@ -0,0 +1,48 @@ +--- +# generated by https://github.com/hashicorp/terraform-plugin-docs +page_title: "snowflake_email_notification_integration Resource - terraform-provider-snowflake" +subcategory: "" +description: |- + +--- + +# snowflake_email_notification_integration (Resource) + + + +## Example Usage + +```terraform +resource "snowflake_email_notification_integration" "email_int" { + name = "notification" + comment = "A notification integration." + + enabled = true + allowed_recipients = ["john.doe@gmail.com"] +} +``` + + +## Schema + +### Required + +- `allowed_recipients` (Set of String) List of email addresses that should receive notifications. +- `enabled` (Boolean) +- `name` (String) + +### Optional + +- `comment` (String) A comment for the email integration. + +### Read-Only + +- `id` (String) The ID of this resource. + +## Import + +Import is supported using the following syntax: + +```shell +terraform import snowflake_email_notification_integration.example name +``` diff --git a/docs/resources/masking_policy.md b/docs/resources/masking_policy.md index 901bf3cbe7..6c06a839e4 100644 --- a/docs/resources/masking_policy.md +++ b/docs/resources/masking_policy.md @@ -13,13 +13,28 @@ description: |- ## Example Usage ```terraform -resource "snowflake_masking_policy" "example_masking_policy" { - name = "EXAMPLE_MASKING_POLICY" +resource "snowflake_masking_policy" "test" { + name = "EXAMPLE_MASKING_POLICY" database = "EXAMPLE_DB" schema = "EXAMPLE_SCHEMA" - value_data_type = "string" - masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" - return_data_type = "string" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = <<-EOF + case + when current_role() in ('ROLE_A') then + val + when is_role_in_session( 'ROLE_B' ) then + 'ABC123' + else + '******' + end + EOF + + return_data_type = "VARCHAR" } ``` @@ -33,17 +48,35 @@ resource "snowflake_masking_policy" "example_masking_policy" { - `name` (String) Specifies the identifier for the masking policy; must be unique for the database and schema in which the masking policy is created. - `return_data_type` (String) Specifies the data type to return. - `schema` (String) The schema in which to create the masking policy. -- `value_data_type` (String) Specifies the data type to mask. +- `signature` (Block List, Min: 1, Max: 1) The signature for the masking policy; specifies the input columns and data types to evaluate at query runtime. (see [below for nested schema](#nestedblock--signature)) ### Optional - `comment` (String) Specifies a comment for the masking policy. +- `exempt_other_policies` (Boolean) Specifies whether the row access policy or conditional masking policy can reference a column that is already protected by a masking policy. +- `if_not_exists` (Boolean) Prevent overwriting a previous masking policy with the same name. +- `or_replace` (Boolean) Whether to override a previous masking policy with the same name. ### Read-Only - `id` (String) The ID of this resource. - `qualified_name` (String) Specifies the qualified identifier for the masking policy. + +### Nested Schema for `signature` + +Required: + +- `column` (Block List, Min: 1) (see [below for nested schema](#nestedblock--signature--column)) + + +### Nested Schema for `signature.column` + +Required: + +- `name` (String) Specifies the column name to mask. +- `type` (String) Specifies the column type to mask. + ## Import Import is supported using the following syntax: diff --git a/docs/resources/tag_masking_policy_association.md b/docs/resources/tag_masking_policy_association.md index 3fa0429ab1..bfcdfedabe 100644 --- a/docs/resources/tag_masking_policy_association.md +++ b/docs/resources/tag_masking_policy_association.md @@ -3,12 +3,12 @@ page_title: "snowflake_tag_masking_policy_association Resource - terraform-provider-snowflake" subcategory: "" description: |- - + Attach a masking policy to a tag. Requires a current warehouse to be set. Either with SNOWFLAKE_WAREHOUSE env variable or in current session. If no warehouse is provided, a temporary warehouse will be created. --- # snowflake_tag_masking_policy_association (Resource) - +Attach a masking policy to a tag. Requires a current warehouse to be set. Either with SNOWFLAKE_WAREHOUSE env variable or in current session. If no warehouse is provided, a temporary warehouse will be created. ## Example Usage diff --git a/examples/provider/provider.tf b/examples/provider/provider.tf index d6ddaae6b2..830ac2aa8d 100644 --- a/examples/provider/provider.tf +++ b/examples/provider/provider.tf @@ -1,9 +1,6 @@ provider "snowflake" { - // required - username = "..." - account = "..." # the Snowflake account identifier - - // optional, exactly one must be set + account = "..." # required if not using profile. Can also be set via SNOWFLAKE_ACCOUNT env var + username = "..." # required if not using profile or token. Can also be set via SNOWFLAKE_USER env var password = "..." oauth_access_token = "..." private_key_path = "..." @@ -21,3 +18,8 @@ provider "snowflake" { host = "..." warehouse = "..." } + + +provider snowflake { + profile = "securityadmin" +} diff --git a/examples/resources/snowflake_email_notification_integration/resource.tf b/examples/resources/snowflake_email_notification_integration/resource.tf index 1f12bb31ee..ab737f76ce 100644 --- a/examples/resources/snowflake_email_notification_integration/resource.tf +++ b/examples/resources/snowflake_email_notification_integration/resource.tf @@ -3,6 +3,5 @@ resource "snowflake_email_notification_integration" "email_int" { comment = "A notification integration." enabled = true - allowed_recipients = ['john.doe@gmail.com'] - + allowed_recipients = ["john.doe@gmail.com"] } diff --git a/examples/resources/snowflake_masking_policy/resource.tf b/examples/resources/snowflake_masking_policy/resource.tf index e8630ef1d9..fc4d352ece 100644 --- a/examples/resources/snowflake_masking_policy/resource.tf +++ b/examples/resources/snowflake_masking_policy/resource.tf @@ -1,8 +1,23 @@ -resource "snowflake_masking_policy" "example_masking_policy" { - name = "EXAMPLE_MASKING_POLICY" +resource "snowflake_masking_policy" "test" { + name = "EXAMPLE_MASKING_POLICY" database = "EXAMPLE_DB" schema = "EXAMPLE_SCHEMA" - value_data_type = "string" - masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" - return_data_type = "string" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = <<-EOF + case + when current_role() in ('ROLE_A') then + val + when is_role_in_session( 'ROLE_B' ) then + 'ABC123' + else + '******' + end + EOF + + return_data_type = "VARCHAR" } diff --git a/go.mod b/go.mod index 3023523cf1..4c20c14133 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 require ( github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Pallinder/go-randomdata v1.2.0 + github.com/buger/jsonparser v1.1.1 github.com/hashicorp/terraform-plugin-docs v0.14.1 github.com/hashicorp/terraform-plugin-sdk/v2 v2.26.1 github.com/jmoiron/sqlx v1.3.5 @@ -48,6 +49,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.30.6 // indirect github.com/aws/smithy-go v1.13.5 // indirect github.com/bgentry/speakeasy v0.1.0 // indirect + github.com/brianvoe/gofakeit/v6 v6.21.0 // indirect github.com/danieljoos/wincred v1.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dvsekhvalnov/jose2go v1.5.0 // indirect @@ -99,6 +101,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/mtibben/percent v0.2.1 // indirect github.com/oklog/run v1.1.0 // indirect + github.com/pelletier/go-toml/v2 v2.0.7 // indirect github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 7883904b93..a706d60067 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,10 @@ github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/brianvoe/gofakeit/v6 v6.21.0 h1:tNkm9yxEbpuPK8Bx39tT4sSc5i9SUGiciLdNix+VDQY= +github.com/brianvoe/gofakeit/v6 v6.21.0/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= @@ -279,6 +283,8 @@ github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ib github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= +github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us= +github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= @@ -321,6 +327,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/vmihailenco/msgpack v3.3.3+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= diff --git a/pkg/datasources/masking_policies.go b/pkg/datasources/masking_policies.go index 0dc922d341..637c8f2afc 100644 --- a/pkg/datasources/masking_policies.go +++ b/pkg/datasources/masking_policies.go @@ -1,12 +1,11 @@ package datasources import ( + "context" "database/sql" - "errors" - "fmt" - "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -65,33 +64,29 @@ func ReadMaskingPolicies(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) - - currentMaskingPolicies, err := snowflake.ListMaskingPolicies(databaseName, schemaName, db) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh - log.Printf("[DEBUG] masking policies in schema (%s) not found", d.Id()) - d.SetId("") - return nil - } else if err != nil { - log.Printf("[DEBUG] unable to parse masking policies in schema (%s)", d.Id()) - d.SetId("") - return nil + client := sdk.NewClientFromDB(db) + ctx := context.Background() + maskingPolicies, err := client.MaskingPolicies.Show(ctx, &sdk.MaskingPolicyShowOptions{ + In: &sdk.In{ + Schema: sdk.NewSchemaIdentifier(databaseName, schemaName), + }, + }) + if err != nil { + return err } - - maskingPolicies := []map[string]interface{}{} - - for _, maskingPolicy := range currentMaskingPolicies { + maskingPoliciesList := []map[string]interface{}{} + for _, maskingPolicy := range maskingPolicies { maskingPolicyMap := map[string]interface{}{} - - maskingPolicyMap["name"] = maskingPolicy.Name.String - maskingPolicyMap["database"] = maskingPolicy.DatabaseName.String - maskingPolicyMap["schema"] = maskingPolicy.SchemaName.String - maskingPolicyMap["comment"] = maskingPolicy.Comment.String - maskingPolicyMap["kind"] = maskingPolicy.Kind.String - - maskingPolicies = append(maskingPolicies, maskingPolicyMap) + maskingPolicyMap["name"] = maskingPolicy.Name + maskingPolicyMap["database"] = maskingPolicy.DatabaseName + maskingPolicyMap["schema"] = maskingPolicy.SchemaName + maskingPolicyMap["comment"] = maskingPolicy.Comment + maskingPolicyMap["kind"] = maskingPolicy.Kind + maskingPoliciesList = append(maskingPoliciesList, maskingPolicyMap) } - - d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName)) - return d.Set("masking_policies", maskingPolicies) + if err := d.Set("masking_policies", maskingPoliciesList); err != nil { + return err + } + d.SetId(helpers.EncodeSnowflakeID(databaseName, schemaName)) + return nil } diff --git a/pkg/datasources/masking_policies_acceptance_test.go b/pkg/datasources/masking_policies_acceptance_test.go index 4b1c12e34a..ab94d5ce11 100644 --- a/pkg/datasources/masking_policies_acceptance_test.go +++ b/pkg/datasources/masking_policies_acceptance_test.go @@ -47,7 +47,12 @@ func maskingPolicies(databaseName string, schemaName string, maskingPolicyName s name = "%v" database = snowflake_database.test.name schema = snowflake_schema.test.name - value_data_type = "VARCHAR" + signature { + column { + name = "val" + type = "VARCHAR" + } + } masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" return_data_type = "VARCHAR(16777216)" comment = "Terraform acceptance test" diff --git a/pkg/datasources/role.go b/pkg/datasources/role.go index 14a9e73581..6cf81c4271 100644 --- a/pkg/datasources/role.go +++ b/pkg/datasources/role.go @@ -55,6 +55,5 @@ func ReadRole(d *schema.ResourceData, meta interface{}) error { if err := d.Set("comment", role.Comment.String); err != nil { return err } - return nil } diff --git a/pkg/helpers/helpers.go b/pkg/helpers/helpers.go index e64cf0ef35..08a231c3f7 100644 --- a/pkg/helpers/helpers.go +++ b/pkg/helpers/helpers.go @@ -52,8 +52,19 @@ func StringToBool(s string) bool { return strings.ToLower(s) == "true" } -// SnowflakeID generates a unique ID for a resource. -func SnowflakeID(attributes ...interface{}) string { +// EncodeSnowflakeID generates a unique ID for a resource. +func EncodeSnowflakeID(attributes ...interface{}) string { + // is attribute already an object identifier? + if len(attributes) == 1 { + if id, ok := attributes[0].(sdk.ObjectIdentifier); ok { + // remove quotes and replace dots with pipes + parts := strings.Split(id.FullyQualifiedName(), ".") + for i, part := range parts { + parts[i] = strings.Trim(part, `"`) + } + return strings.Join(parts, IDDelimiter) + } + } var parts []string for i, attr := range attributes { if attr == nil { diff --git a/pkg/helpers/random.go b/pkg/helpers/random.go new file mode 100644 index 0000000000..9a3dfde964 --- /dev/null +++ b/pkg/helpers/random.go @@ -0,0 +1,27 @@ +package helpers + +import ( + "github.com/brianvoe/gofakeit/v6" +) + +func RandomBool() bool { + return gofakeit.Bool() +} + +func RandomString() string { + return gofakeit.Password(true, true, true, true, false, 28) +} + +func RandomStringRange(min, max int) string { + if min > max { + return "" + } + return gofakeit.Password(true, true, true, true, false, RandomIntRange(min, max)) +} + +func RandomIntRange(min, max int) int { + if min > max { + return 0 + } + return gofakeit.IntRange(min, max) +} diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 305d068ef6..b41e35d11c 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -1,6 +1,7 @@ package provider import ( + "context" "crypto/rsa" "database/sql" "encoding/json" @@ -8,6 +9,7 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/url" "os" @@ -23,6 +25,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/datasources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/db" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" ) // Provider is a provider. @@ -31,14 +34,14 @@ func Provider() *schema.Provider { Schema: map[string]*schema.Schema{ "account": { Type: schema.TypeString, - Description: "The name of the Snowflake account. Can also come from the `SNOWFLAKE_ACCOUNT` environment variable.", - Required: true, + Description: "The name of the Snowflake account. Can also come from the `SNOWFLAKE_ACCOUNT` environment variable. Required unless using profile.", + Optional: true, DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_ACCOUNT", nil), }, "username": { Type: schema.TypeString, - Description: "Username for username+password authentication. Can come from the `SNOWFLAKE_USER` environment variable.", - Required: true, + Description: "Username for username+password authentication. Can come from the `SNOWFLAKE_USER` environment variable. Required unless using profile.", + Optional: true, DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_USER", nil), }, "password": { @@ -176,6 +179,12 @@ func Provider() *schema.Provider { Optional: true, DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_WAREHOUSE", nil), }, + "profile": { + Type: schema.TypeString, + Description: "Sets the profile to read from ~/.snowflake/config file.", + Optional: true, + DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_PROFILE", "default"), + }, }, ResourcesMap: getResources(), DataSourcesMap: getDataSources(), @@ -334,6 +343,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { port := s.Get("port").(int) warehouse := s.Get("warehouse").(string) insecureMode := s.Get("insecure_mode").(bool) + profile := s.Get("profile").(string) if oauthRefreshToken != "" { accessToken, err := GetOauthAccessToken(oauthEndpoint, oauthClientID, oauthClientSecret, GetOauthData(oauthRefreshToken, oauthRedirectURL)) @@ -359,6 +369,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { port, warehouse, insecureMode, + profile, ) if err != nil { return nil, fmt.Errorf("could not build dsn for snowflake connection err = %w", err) @@ -366,7 +377,21 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { db, err := db.Open(dsn) if err != nil { - return nil, fmt.Errorf("Could not open snowflake database err = %w", err) + return nil, fmt.Errorf("could not open snowflake database err = %w", err) + } + log.Printf("[INFO] account: %s\n", account) + log.Printf("[INFO] user: %s\n", user) + log.Printf("[INFO] role: %s\n", role) + log.Printf("[INFO] warehouse: %s\n", warehouse) + log.Printf("[INFO] dsn: %s\n", dsn) + client := sdk.NewClientFromDB(db) + sessionID, err := client.ContextFunctions.CurrentSession(context.Background()) + if err != nil { + return nil, fmt.Errorf("could not retrieve session id err = %w", err) + } + log.Printf("[INFO] Snowflake DB connection opened, session ID : %s\n", sessionID) + if err != nil { + return nil, fmt.Errorf("could not open snowflake database err = %w", err) } return db, nil @@ -388,6 +413,7 @@ func DSN( port int, warehouse string, insecureMode bool, + profile string, ) (string, error) { // us-west-2 is Snowflake's default region, but if you actually specify that it won't trigger the default code // https://github.com/snowflakedb/gosnowflake/blob/52137ce8c32eaf93b0bd22fc5c7297beff339812/dsn.go#L61 @@ -395,12 +421,11 @@ func DSN( region = "" } - config := gosnowflake.Config{ + config := &gosnowflake.Config{ Account: account, User: user, Region: region, Role: role, - Application: "terraform-provider-snowflake", Port: port, Protocol: protocol, InsecureMode: insecureMode, @@ -420,18 +445,18 @@ func DSN( if privateKeyPath != "" { //nolint:gocritic // todo: please fix this to pass gocritic privateKeyBytes, err := ReadPrivateKeyFile(privateKeyPath) if err != nil { - return "", fmt.Errorf("Private Key file could not be read err = %w", err) + return "", fmt.Errorf("private Key file could not be read err = %w", err) } rsaPrivateKey, err := ParsePrivateKey(privateKeyBytes, []byte(privateKeyPassphrase)) if err != nil { - return "", fmt.Errorf("Private Key could not be parsed err = %w", err) + return "", fmt.Errorf("private Key could not be parsed err = %w", err) } config.PrivateKey = rsaPrivateKey config.Authenticator = gosnowflake.AuthTypeJwt } else if privateKey != "" { rsaPrivateKey, err := ParsePrivateKey([]byte(privateKey), []byte(privateKeyPassphrase)) if err != nil { - return "", fmt.Errorf("Private Key could not be parsed err = %w", err) + return "", fmt.Errorf("private Key could not be parsed err = %w", err) } config.PrivateKey = rsaPrivateKey config.Authenticator = gosnowflake.AuthTypeJwt @@ -442,26 +467,32 @@ func DSN( config.Token = oauthAccessToken } else if password != "" { config.Password = password - } else { - return "", errors.New("no authentication method provided") + } else if account == "" && user == "" { + // If account and user are empty then we need to fall back on using profile config + log.Printf("[DEBUG] No account or user provided, falling back to profile %s\n", profile) + profileConfig, err := sdk.ProfileConfig(profile) + if err != nil { + return "", errors.New("no authentication method provided") + } + config = sdk.MergeConfig(config, profileConfig) } - - return gosnowflake.DSN(&config) + config.Application = "terraform-provider-snowflake" + return gosnowflake.DSN(config) } func ReadPrivateKeyFile(privateKeyPath string) ([]byte, error) { expandedPrivateKeyPath, err := homedir.Expand(privateKeyPath) if err != nil { - return nil, fmt.Errorf("Invalid Path to private key err = %w", err) + return nil, fmt.Errorf("invalid Path to private key err = %w", err) } privateKeyBytes, err := os.ReadFile(expandedPrivateKeyPath) if err != nil { - return nil, fmt.Errorf("Could not read private key err = %w", err) + return nil, fmt.Errorf("could not read private key err = %w", err) } if len(privateKeyBytes) == 0 { - return nil, errors.New("Private key is empty") + return nil, errors.New("private key is empty") } return privateKeyBytes, nil @@ -479,14 +510,14 @@ func ParsePrivateKey(privateKeyBytes []byte, passhrase []byte) (*rsa.PrivateKey, } privateKey, err := pkcs8.ParsePKCS8PrivateKeyRSA(privateKeyBlock.Bytes, passhrase) if err != nil { - return nil, fmt.Errorf("Could not parse encrypted private key with passphrase, only ciphers aes-128-cbc, aes-128-gcm, aes-192-cbc, aes-192-gcm, aes-256-cbc, aes-256-gcm, and des-ede3-cbc are supported err = %w", err) + return nil, fmt.Errorf("could not parse encrypted private key with passphrase, only ciphers aes-128-cbc, aes-128-gcm, aes-192-cbc, aes-192-gcm, aes-256-cbc, aes-256-gcm, and des-ede3-cbc are supported err = %w", err) } return privateKey, nil } privateKey, err := ssh.ParseRawPrivateKey(privateKeyBytes) if err != nil { - return nil, fmt.Errorf("Could not parse private key err = %w", err) + return nil, fmt.Errorf("could not parse private key err = %w", err) } rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) @@ -513,7 +544,7 @@ func GetOauthData(refreshToken, redirectURL string) url.Values { func GetOauthRequest(dataContent io.Reader, endPoint, clientID, clientSecret string) (*http.Request, error) { request, err := http.NewRequest("POST", endPoint, dataContent) if err != nil { - return nil, fmt.Errorf("Request to the endpoint could not be completed %w", err) + return nil, fmt.Errorf("request to the endpoint could not be completed %w", err) } request.SetBasicAuth(clientID, clientSecret) request.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") @@ -536,19 +567,19 @@ func GetOauthAccessToken( response, err := client.Do(request) if err != nil { - return "", fmt.Errorf("Response status returned an err = %w", err) + return "", fmt.Errorf("response status returned an err = %w", err) } if response.StatusCode != 200 { - return "", fmt.Errorf("Response status code: %s: %s err = %w", strconv.Itoa(response.StatusCode), http.StatusText(response.StatusCode), err) + return "", fmt.Errorf("response status code: %s: %s err = %w", strconv.Itoa(response.StatusCode), http.StatusText(response.StatusCode), err) } defer response.Body.Close() body, err := io.ReadAll(response.Body) if err != nil { - return "", fmt.Errorf("Response body was not able to be parsed err = %w", err) + return "", fmt.Errorf("response body was not able to be parsed err = %w", err) } err = json.Unmarshal(body, &result) if err != nil { - return "", fmt.Errorf("Error parsing JSON from Snowflake err = %w", err) + return "", fmt.Errorf("error parsing JSON from Snowflake err = %w", err) } return result.AccessToken, nil } @@ -567,6 +598,10 @@ func GetDatabaseHandleFromEnv() (db *sql.DB, err error) { host := os.Getenv("SNOWFLAKE_HOST") warehouse := os.Getenv("SNOWFLAKE_WAREHOUSE") protocol := os.Getenv("SNOWFLAKE_PROTOCOL") + profile := os.Getenv("SNOWFLAKE_PROFILE") + if profile == "" { + profile = "default" + } port, err := strconv.Atoi(os.Getenv("SNOWFLAKE_PORT")) if err != nil { port = 443 @@ -587,6 +622,7 @@ func GetDatabaseHandleFromEnv() (db *sql.DB, err error) { port, warehouse, false, + profile, ) if err != nil { return nil, err diff --git a/pkg/provider/provider_test.go b/pkg/provider/provider_test.go index 2e7d9bbe75..1daefa30d7 100644 --- a/pkg/provider/provider_test.go +++ b/pkg/provider/provider_test.go @@ -6,6 +6,8 @@ import ( "io" "net/http" "net/url" + "os" + "path/filepath" "reflect" "strconv" "strings" @@ -22,17 +24,31 @@ func TestProvider(t *testing.T) { } func TestDSN(t *testing.T) { + dat := []byte(` + [default] + account='TEST_ACCOUNT' + user='TEST_USER' + password='abcd1234' + role='ACCOUNTADMIN' + `) + path := filepath.Join(t.TempDir(), "config") + err := os.WriteFile(path, dat, 0o600) + require.NoError(t, err) + os.Setenv("SNOWFLAKE_CONFIG_PATH", path) + type args struct { - account string - user string - password string - browserAuth bool - region string - role string - host string - protocol string - port int - warehouse string + account string + user string + password string + browserAuth bool + region string + role string + host string + protocol string + port int + warehouse string + insecureMode bool + profile string } tests := []struct { name string @@ -42,29 +58,34 @@ func TestDSN(t *testing.T) { }{ { "simple", - args{"acct", "user", "pass", false, "region", "role", "", "https", 443, ""}, + args{"acct", "user", "pass", false, "region", "role", "", "https", 443, "", false, "default"}, "user:pass@acct.region.snowflakecomputing.com:443?application=terraform-provider-snowflake&ocspFailOpen=true®ion=region&role=role&validateDefaultParameters=true", false, }, { "us-west-2 special case", - args{"acct2", "user2", "pass2", false, "us-west-2", "role2", "", "https", 443, ""}, + args{"acct2", "user2", "pass2", false, "us-west-2", "role2", "", "https", 443, "", false, "default"}, "user2:pass2@acct2.snowflakecomputing.com:443?application=terraform-provider-snowflake&ocspFailOpen=true&role=role2&validateDefaultParameters=true", false, }, { "customhostwregion", - args{"acct3", "user3", "pass3", false, "", "role3", "zha123.us-east-1.privatelink.snowflakecomputing.com", "https", 443, ""}, + args{"acct3", "user3", "pass3", false, "", "role3", "zha123.us-east-1.privatelink.snowflakecomputing.com", "https", 443, "", false, "default"}, "user3:pass3@zha123.us-east-1.privatelink.snowflakecomputing.com:443?account=acct3&application=terraform-provider-snowflake&ocspFailOpen=true&role=role3&validateDefaultParameters=true", false, }, { "customhostignoreregion", - args{"acct4", "user4", "pass4", false, "fakeregion", "role4", "zha1234.us-east-1.privatelink.snowflakecomputing.com", "https", 8443, ""}, + args{"acct4", "user4", "pass4", false, "fakeregion", "role4", "zha1234.us-east-1.privatelink.snowflakecomputing.com", "https", 8443, "", false, "default"}, "user4:pass4@zha1234.us-east-1.privatelink.snowflakecomputing.com:8443?account=acct4&application=terraform-provider-snowflake&ocspFailOpen=true&role=role4&validateDefaultParameters=true", false, }, + { + "profile", + args{"", "", "", false, "", "", "", "", 0, "", false, "default"}, + "TEST_USER:abcd1234@TEST_ACCOUNT.snowflakecomputing.com:443?application=terraform-provider-snowflake&ocspFailOpen=true&role=ACCOUNTADMIN&validateDefaultParameters=true", false, + }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role, tt.args.host, tt.args.protocol, tt.args.port, "", false) + got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role, tt.args.host, tt.args.protocol, tt.args.port, tt.args.warehouse, tt.args.insecureMode, tt.args.profile) if (err != nil) != tt.wantErr { t.Errorf("DSN() error = %v, wantErr %v", err, tt.wantErr) return @@ -114,7 +135,7 @@ func TestOAuthDSN(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role, tt.args.host, tt.args.protocol, tt.args.port, "", false) + got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role, tt.args.host, tt.args.protocol, tt.args.port, "", false, "default") if (err != nil) != tt.wantErr { t.Errorf("DSN() error = %v, dsn = %v, wantErr %v", err, got, tt.wantErr) diff --git a/pkg/resources/account.go b/pkg/resources/account.go index 28e0658e5d..55f5fee83c 100644 --- a/pkg/resources/account.go +++ b/pkg/resources/account.go @@ -343,6 +343,6 @@ func UpdateAccount(d *schema.ResourceData, meta interface{}) error { } // DeleteAccount implements schema.DeleteFunc. -func DeleteAccount(d *schema.ResourceData, meta interface{}) error { +func DeleteAccount(_ *schema.ResourceData, _ interface{}) error { return fmt.Errorf("cannot delete Snowflake accounts because there is no self service API allowing Terraform to do so. To delete an account, contact Snowflake Support and provide a unique identifier for your account, which can be one of the following:\n Account name\n Account locator\nOnce you contact Snowflake Support, it may take up to six weeks for the account to be fully deleted. This delay allows you to recover the account within 30 days of the request. Snowflake usually deducts the account from the number of accounts allowed for your organization within a few days of the initial request") } diff --git a/pkg/resources/account_grant.go b/pkg/resources/account_grant.go index bfb8589bdd..269113df90 100644 --- a/pkg/resources/account_grant.go +++ b/pkg/resources/account_grant.go @@ -120,7 +120,7 @@ func CreateAccountGrant(d *schema.ResourceData, meta interface{}) error { privilege := d.Get("privilege").(string) roles := expandStringList(d.Get("roles").(*schema.Set).List()) withGrantOption := d.Get("with_grant_option").(bool) - grantID := helpers.SnowflakeID(privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(privilege, withGrantOption, roles) d.SetId(grantID) return ReadAccountGrant(d, meta) @@ -138,7 +138,7 @@ func ReadAccountGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(privilege, withGrantOption, roles) // if the ID is not in the new format, rewrite it if grantID != d.Id() { d.SetId(grantID) diff --git a/pkg/resources/alert.go b/pkg/resources/alert.go index 0760ccc72b..83d5535872 100644 --- a/pkg/resources/alert.go +++ b/pkg/resources/alert.go @@ -259,7 +259,6 @@ func ReadAlert(d *schema.ResourceData, meta interface{}) error { if err := d.Set("action", alert.Action); err != nil { return err } - return nil } @@ -367,8 +366,8 @@ func UpdateAlert(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("alert_schedule") { - _, new := d.GetChange("alert_schedule") - alertSchedule := new.([]interface{})[0].(map[string]interface{}) + _, n := d.GetChange("alert_schedule") + alertSchedule := n.([]interface{})[0].(map[string]interface{}) log.Printf("[DEBUG] alertSchedule: %v", alertSchedule) log.Printf("[DEBUG] alertSchedule[cron]: %v", alertSchedule["cron"]) c := alertSchedule["cron"].([]interface{}) diff --git a/pkg/resources/database_grant.go b/pkg/resources/database_grant.go index 61fba5b59e..e7ff5bcc32 100644 --- a/pkg/resources/database_grant.go +++ b/pkg/resources/database_grant.go @@ -114,7 +114,7 @@ func CreateDatabaseGrant(d *schema.ResourceData, meta interface{}) error { roles := expandStringList(d.Get("roles").(*schema.Set).List()) shares := expandStringList(d.Get("shares").(*schema.Set).List()) withGrantOption := d.Get("with_grant_option").(bool) - grantID := helpers.SnowflakeID(databaseName, privilege, withGrantOption, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, privilege, withGrantOption, roles, shares) d.SetId(grantID) return ReadDatabaseGrant(d, meta) @@ -141,7 +141,7 @@ func ReadDatabaseGrant(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("error reading database grant: %w", err) } - grantID := helpers.SnowflakeID(databaseName, privilege, withGrantOption, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, privilege, withGrantOption, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/database_role.go b/pkg/resources/database_role.go index 9f2f3d0383..7dba3281c6 100644 --- a/pkg/resources/database_role.go +++ b/pkg/resources/database_role.go @@ -136,7 +136,6 @@ func ReadDatabaseRole(d *schema.ResourceData, meta interface{}) error { if err := d.Set("comment", databaseRole.Comment); err != nil { return err } - return nil } diff --git a/pkg/resources/database_role_acceptance_test.go b/pkg/resources/database_role_acceptance_test.go index ecee74abbd..88420ac5fb 100644 --- a/pkg/resources/database_role_acceptance_test.go +++ b/pkg/resources/database_role_acceptance_test.go @@ -42,7 +42,7 @@ func TestAcc_DatabaseRole(t *testing.T) { }) } -func databaseRoleConfig(dbName string, dbRoleName string, comment string) string { //nolint +func databaseRoleConfig(dbName string, dbRoleName string, comment string) string { s := ` resource "snowflake_database" "test_db" { name = "%s" diff --git a/pkg/resources/email_notification_integration_acceptance_test.go b/pkg/resources/email_notification_integration_acceptance_test.go index 3546ef8ef9..a96282d43b 100644 --- a/pkg/resources/email_notification_integration_acceptance_test.go +++ b/pkg/resources/email_notification_integration_acceptance_test.go @@ -15,6 +15,7 @@ func TestAcc_EmailNotificationIntegration(t *testing.T) { if _, ok := os.LookupEnv("SKIP_EMAIL_INTEGRATION_TESTS"); ok { t.Skip("Skipping TestAcc_EmailNotificationIntegration") } + resource.Test(t, resource.TestCase{ Providers: providers(), CheckDestroy: nil, diff --git a/pkg/resources/external_table.go b/pkg/resources/external_table.go index 56eb10a578..b236a0c4cf 100644 --- a/pkg/resources/external_table.go +++ b/pkg/resources/external_table.go @@ -288,7 +288,6 @@ func ReadExternalTable(d *schema.ResourceData, meta interface{}) error { if err := d.Set("owner", externalTable.Owner.String); err != nil { return err } - return nil } diff --git a/pkg/resources/external_table_grant.go b/pkg/resources/external_table_grant.go index aef993fb93..34c45806ce 100644 --- a/pkg/resources/external_table_grant.go +++ b/pkg/resources/external_table_grant.go @@ -179,7 +179,7 @@ func CreateExternalTableGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, externalTableName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, externalTableName, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadExternalTableGrant(d, meta) @@ -212,7 +212,7 @@ func ReadExternalTableGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, externalTableName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, externalTableName, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/failover_group.go b/pkg/resources/failover_group.go index ab8541ee02..aa9b4bb254 100644 --- a/pkg/resources/failover_group.go +++ b/pkg/resources/failover_group.go @@ -433,7 +433,6 @@ func ReadFailoverGroup(d *schema.ResourceData, meta interface{}) error { if err := d.Set("allowed_shares", nil); err != nil { return err } - return nil } @@ -444,8 +443,8 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { builder := snowflake.NewFailoverGroupBuilder(name) if d.HasChange("object_types") { - _, new := d.GetChange("object_types") - newObjectTypes := new.(*schema.Set).List() + _, n := d.GetChange("object_types") + newObjectTypes := n.(*schema.Set).List() var objectTypes []string for _, v := range newObjectTypes { @@ -458,13 +457,13 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("allowed_databases") { - old, new := d.GetChange("allowed_databases") - oad := old.(*schema.Set).List() + o, n := d.GetChange("allowed_databases") + oad := o.(*schema.Set).List() oldAllowedDatabases := make([]string, len(oad)) for i, v := range oad { oldAllowedDatabases[i] = v.(string) } - nad := new.(*schema.Set).List() + nad := n.(*schema.Set).List() newAllowedDatabases := make([]string, len(nad)) for i, v := range nad { newAllowedDatabases[i] = v.(string) @@ -499,13 +498,13 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("allowed_shares") { - old, new := d.GetChange("allowed_shares") - oad := old.(*schema.Set).List() + o, n := d.GetChange("allowed_shares") + oad := o.(*schema.Set).List() oldAllowedShares := make([]string, len(oad)) for i, v := range oad { oldAllowedShares[i] = v.(string) } - nad := new.(*schema.Set).List() + nad := n.(*schema.Set).List() newAllowedShares := make([]string, len(nad)) for i, v := range nad { newAllowedShares[i] = v.(string) @@ -552,13 +551,13 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("allowed_accounts") { - old, new := d.GetChange("allowed_accounts") - oad := old.(*schema.Set).List() + o, n := d.GetChange("allowed_accounts") + oad := o.(*schema.Set).List() oldAllowedAccounts := make([]string, len(oad)) for i, v := range oad { oldAllowedAccounts[i] = v.(string) } - nad := new.(*schema.Set).List() + nad := n.(*schema.Set).List() newAllowedAccounts := make([]string, len(nad)) for i, v := range nad { newAllowedAccounts[i] = v.(string) @@ -595,8 +594,8 @@ func UpdateFailoverGroup(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("replication_schedule") { - _, new := d.GetChange("replication_schedule") - replicationSchedule := new.([]interface{})[0].(map[string]interface{}) + _, n := d.GetChange("replication_schedule") + replicationSchedule := n.([]interface{})[0].(map[string]interface{}) log.Printf("[DEBUG] replicationSchedule: %v", replicationSchedule) log.Printf("[DEBUG] replicationSchedule[cron]: %v", replicationSchedule["cron"]) c := replicationSchedule["cron"].([]interface{}) diff --git a/pkg/resources/file_format.go b/pkg/resources/file_format.go index 5a326af3ac..cf296071b5 100644 --- a/pkg/resources/file_format.go +++ b/pkg/resources/file_format.go @@ -691,7 +691,6 @@ func ReadFileFormat(d *schema.ResourceData, meta interface{}) error { if err := d.Set("comment", f.Comment.String); err != nil { return err } - return nil } diff --git a/pkg/resources/file_format_grant.go b/pkg/resources/file_format_grant.go index d7bbef9e47..ce844fbcc9 100644 --- a/pkg/resources/file_format_grant.go +++ b/pkg/resources/file_format_grant.go @@ -164,7 +164,7 @@ func CreateFileFormatGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, fileFormatName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, fileFormatName, privilege, withGrantOption, onFuture, onAll, roles) d.SetId(grantID) return ReadFileFormatGrant(d, meta) @@ -196,7 +196,7 @@ func ReadFileFormatGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, fileFormatName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, fileFormatName, privilege, withGrantOption, onFuture, onAll, roles) if d.Id() != grantID { d.SetId(grantID) } diff --git a/pkg/resources/function_grant.go b/pkg/resources/function_grant.go index dc2860ce05..be88d78523 100644 --- a/pkg/resources/function_grant.go +++ b/pkg/resources/function_grant.go @@ -196,7 +196,7 @@ func CreateFunctionGrant(d *schema.ResourceData, meta interface{}) error { if err := createGenericGrant(d, meta, builder); err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, functionName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, functionName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadFunctionGrant(d, meta) } @@ -228,7 +228,7 @@ func ReadFunctionGrant(d *schema.ResourceData, meta interface{}) error { if err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, functionName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, functionName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/grant_helpers.go b/pkg/resources/grant_helpers.go index 3317797130..6c54db684e 100644 --- a/pkg/resources/grant_helpers.go +++ b/pkg/resources/grant_helpers.go @@ -113,7 +113,7 @@ func readGenericGrant( builder snowflake.GrantBuilder, futureObjects bool, allObjects bool, - validPrivileges PrivilegeSet, + _ PrivilegeSet, ) error { db := meta.(*sql.DB) var grants []*grant @@ -236,7 +236,6 @@ func readGenericGrant( if err := d.Set("with_grant_option", grantOption); err != nil { return err } - return nil } diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index a786fa5524..66697eb678 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -1,5 +1,30 @@ package resources +import ( + "fmt" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" +) + func isOk(_ interface{}, ok bool) bool { return ok } + +func dataTypeValidateFunc(val interface{}, _ string) (warns []string, errs []error) { + if ok := sdk.IsValidDataType(val.(string)); !ok { + errs = append(errs, fmt.Errorf("%v is not a valid data type", val)) + } + return +} + +func dataTypeDiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + oldDT := sdk.DataTypeFromString(old) + newDT := sdk.DataTypeFromString(new) + return oldDT == newDT +} + +func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.TrimSpace(old) == strings.TrimSpace(new) +} diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index 0e26d5b6a8..775588ad39 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -117,15 +117,6 @@ func managedAccount(t *testing.T, id string, params map[string]interface{}) *sch return d } -func maskingPolicy(t *testing.T, id string, params map[string]interface{}) *schema.ResourceData { - t.Helper() - r := require.New(t) - d := schema.TestResourceDataRaw(t, resources.MaskingPolicy().Schema, params) - r.NotNil(d) - d.SetId(id) - return d -} - func networkPolicy(t *testing.T, id string, params map[string]interface{}) *schema.ResourceData { t.Helper() r := require.New(t) @@ -268,15 +259,6 @@ func oauthIntegration(t *testing.T, id string, params map[string]interface{}) *s return d } -func externalOauthIntegration(t *testing.T, id string, params map[string]interface{}) *schema.ResourceData { - t.Helper() - r := require.New(t) - d := schema.TestResourceDataRaw(t, resources.ExternalOauthIntegration().Schema, params) - r.NotNil(d) - d.SetId(id) - return d -} - func externalFunction(t *testing.T, id string, params map[string]interface{}) *schema.ResourceData { t.Helper() r := require.New(t) diff --git a/pkg/resources/integration_grant.go b/pkg/resources/integration_grant.go index b9a7229b26..d3c82a6c54 100644 --- a/pkg/resources/integration_grant.go +++ b/pkg/resources/integration_grant.go @@ -100,7 +100,7 @@ func CreateIntegrationGrant(d *schema.ResourceData, meta interface{}) error { if err := createGenericGrant(d, meta, builder); err != nil { return err } - grantID := helpers.SnowflakeID(integrationName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(integrationName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadIntegrationGrant(d, meta) @@ -120,7 +120,7 @@ func ReadIntegrationGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(integrationName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(integrationName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/masking_policy.go b/pkg/resources/masking_policy.go index ab6539ea68..cdfa3267c1 100644 --- a/pkg/resources/masking_policy.go +++ b/pkg/resources/masking_policy.go @@ -1,22 +1,36 @@ package resources import ( - "bytes" + "context" "database/sql" - "encoding/csv" "fmt" - "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "golang.org/x/exp/slices" -) - -const ( - maskingPolicyIDDelimiter = '|' ) var maskingPolicySchema = map[string]*schema.Schema{ + "or_replace": { + Type: schema.TypeBool, + Optional: true, + Default: false, + Description: "Whether to override a previous masking policy with the same name.", + DiffSuppressOnRefresh: true, + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return old != new + }, + }, + "if_not_exists": { + Type: schema.TypeBool, + Optional: true, + Default: false, + Description: "Prevent overwriting a previous masking policy with the same name.", + DiffSuppressOnRefresh: true, + DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { + return old != new + }, + }, "name": { Type: schema.TypeString, Required: true, @@ -35,32 +49,58 @@ var maskingPolicySchema = map[string]*schema.Schema{ Description: "The schema in which to create the masking policy.", ForceNew: true, }, - "value_data_type": { - Type: schema.TypeString, + "signature": { + Type: schema.TypeList, Required: true, - Description: "Specifies the data type to mask.", - ForceNew: true, - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - // these are all equivalent as per https://docs.snowflake.com/en/sql-reference/data-types-text.html - varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} - return slices.Contains(varcharType, new) && slices.Contains(varcharType, old) + Description: "The signature for the masking policy; specifies the input columns and data types to evaluate at query runtime.", + MinItems: 1, + MaxItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "column": { + Type: schema.TypeList, + Required: true, + MinItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "name": { + Type: schema.TypeString, + Required: true, + Description: "Specifies the column name to mask.", + }, + "type": { + Type: schema.TypeString, + Required: true, + Description: "Specifies the column type to mask.", + ForceNew: true, + ValidateFunc: dataTypeValidateFunc, + DiffSuppressFunc: dataTypeDiffSuppressFunc, + }, + }, + }, + }, + }, }, }, "masking_expression": { - Type: schema.TypeString, - Required: true, - Description: "Specifies the SQL expression that transforms the data.", + Type: schema.TypeString, + Required: true, + Description: "Specifies the SQL expression that transforms the data.", + DiffSuppressFunc: ignoreTrimSpaceSuppressFunc, }, "return_data_type": { - Type: schema.TypeString, - Required: true, - Description: "Specifies the data type to return.", - ForceNew: true, - DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool { - // these are all equivalent as per https://docs.snowflake.com/en/sql-reference/data-types-text.html - varcharType := []string{"VARCHAR(16777216)", "VARCHAR", "text", "string", "NVARCHAR", "NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"} - return slices.Contains(varcharType, new) && slices.Contains(varcharType, old) - }, + Type: schema.TypeString, + Required: true, + Description: "Specifies the data type to return.", + ForceNew: true, + ValidateFunc: dataTypeValidateFunc, + DiffSuppressFunc: dataTypeDiffSuppressFunc, + }, + "exempt_other_policies": { + Type: schema.TypeBool, + Optional: true, + Description: "Specifies whether the row access policy or conditional masking policy can reference a column that is already protected by a masking policy.", + Default: false, }, "comment": { Type: schema.TypeString, @@ -74,50 +114,7 @@ var maskingPolicySchema = map[string]*schema.Schema{ }, } -type maskingPolicyID struct { - DatabaseName string - SchemaName string - MaskingPolicyName string -} - -// String() takes in a maskingPolicyID object and returns a pipe-delimited string: // DatabaseName|SchemaName|MaskingPolicyName. -func (mpi *maskingPolicyID) String() (string, error) { - var buf bytes.Buffer - csvWriter := csv.NewWriter(&buf) - csvWriter.Comma = maskingPolicyIDDelimiter - dataIdentifiers := [][]string{{mpi.DatabaseName, mpi.SchemaName, mpi.MaskingPolicyName}} - if err := csvWriter.WriteAll(dataIdentifiers); err != nil { - return "", err - } - strMaskingPolicyID := strings.TrimSpace(buf.String()) - return strMaskingPolicyID, nil -} - -// / maskingPolicyIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|MaskingPolicyName -// and returns a maskingPolicyID object. -func maskingPolicyIDFromString(stringID string) (*maskingPolicyID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = maskingPolicyIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line per masking policy") - } - if len(lines[0]) != 3 { - return nil, fmt.Errorf("3 fields allowed") - } - - maskingPolicyResult := &maskingPolicyID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - MaskingPolicyName: lines[0][2], - } - return maskingPolicyResult, nil -} // MaskingPolicy returns a pointer to the resource representing a masking policy. func MaskingPolicy() *schema.Resource { @@ -137,39 +134,48 @@ func MaskingPolicy() *schema.Resource { // CreateMaskingPolicy implements schema.CreateFunc. func CreateMaskingPolicy(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + client := sdk.NewClientFromDB(db) + name := d.Get("name").(string) - database := d.Get("database").(string) - schema := d.Get("schema").(string) - valueDataType := d.Get("value_data_type").(string) - maskingExpression := d.Get("masking_expression").(string) + databaseName := d.Get("database").(string) + schemaName := d.Get("schema").(string) + + expression := d.Get("masking_expression").(string) returnDataType := d.Get("return_data_type").(string) - builder := snowflake.MaskingPolicy(name, database, schema) + ctx := context.Background() + objectIdentifier := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) + + signatureList := d.Get("signature").([]interface{}) + signature := []sdk.TableColumnSignature{} + for _, s := range signatureList { + m := s.(map[string]interface{}) + columns := m["column"].([]interface{}) + for _, c := range columns { + cm := c.(map[string]interface{}) + dt := sdk.DataTypeFromString(cm["type"].(string)) + signature = append(signature, sdk.TableColumnSignature{ + Name: cm["name"].(string), + Type: dt, + }) + } + } - builder.WithValueDataType(valueDataType) - builder.WithMaskingExpression(maskingExpression) - builder.WithReturnDataType(returnDataType) + returns := sdk.DataTypeFromString(returnDataType) - // Set optionals - if v, ok := d.GetOk("comment"); ok { - builder.WithComment(v.(string)) + opts := &sdk.MaskingPolicyCreateOptions{} + if comment, ok := d.Get("comment").(string); ok { + opts.Comment = sdk.String(comment) } - - stmt := builder.Create() - if err := snowflake.Exec(db, stmt); err != nil { - return fmt.Errorf("error creating masking policy %v err = %w", name, err) + if exemptOtherPolicies := d.Get("exempt_other_policies").(bool); exemptOtherPolicies { + opts.ExemptOtherPolicies = sdk.Bool(exemptOtherPolicies) } - maskingPolicyID := &maskingPolicyID{ - DatabaseName: database, - SchemaName: schema, - MaskingPolicyName: name, - } - dataIDInput, err := maskingPolicyID.String() + err := client.MaskingPolicies.Create(ctx, objectIdentifier, signature, returns, expression, opts) if err != nil { return err } - d.SetId(dataIDInput) + d.SetId(helpers.EncodeSnowflakeID(objectIdentifier)) return ReadMaskingPolicy(d, meta) } @@ -177,76 +183,75 @@ func CreateMaskingPolicy(d *schema.ResourceData, meta interface{}) error { // ReadMaskingPolicy implements schema.ReadFunc. func ReadMaskingPolicy(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - maskingPolicyID, err := maskingPolicyIDFromString(d.Id()) - if err != nil { - return err - } - - dbName := maskingPolicyID.DatabaseName - schema := maskingPolicyID.SchemaName - policyName := maskingPolicyID.MaskingPolicyName + client := sdk.NewClientFromDB(db) + objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - builder := snowflake.MaskingPolicy(policyName, dbName, schema) - - showSQL := builder.Show() - - row := snowflake.QueryRow(db, showSQL) - - s, err := snowflake.ScanMaskingPolicies(row) + ctx := context.Background() + opts := &sdk.MaskingPolicyShowOptions{ + Like: &sdk.Like{ + Pattern: sdk.String(objectIdentifier.Name()), + }, + In: &sdk.In{ + Schema: sdk.NewSchemaIdentifier(objectIdentifier.DatabaseName(), objectIdentifier.SchemaName()), + }, + } + maskingPolicies, err := client.MaskingPolicies.Show(ctx, opts) if err != nil { return err } - - if err := d.Set("name", s.Name.String); err != nil { + if len(maskingPolicies) == 0 { + return fmt.Errorf("masking policy %v not found", d.Id()) + } + maskingPolicy := maskingPolicies[0] + if err := d.Set("name", maskingPolicy.Name); err != nil { return err } - if err := d.Set("database", s.DatabaseName.String); err != nil { + if err := d.Set("database", maskingPolicy.DatabaseName); err != nil { return err } - if err := d.Set("schema", s.SchemaName.String); err != nil { + if err := d.Set("schema", maskingPolicy.SchemaName); err != nil { return err } - if err := d.Set("comment", s.Comment.String); err != nil { + if err := d.Set("exempt_other_policies", maskingPolicy.ExemptOtherPolicies); err != nil { return err } - if err := d.Set("qualified_name", builder.QualifiedName()); err != nil { + if err := d.Set("comment", maskingPolicy.Comment); err != nil { return err } - descSQL := builder.Describe() - rows, err := snowflake.Query(db, descSQL) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, objectIdentifier) if err != nil { return err } - var ( - name string - signature string - returnType string - body string - ) - for rows.Next() { - if err := rows.Scan(&name, &signature, &returnType, &body); err != nil { - return err - } - - if err := d.Set("masking_expression", body); err != nil { - return err - } + if err := d.Set("masking_expression", maskingPolicyDetails.Body); err != nil { + return err + } - if err := d.Set("return_data_type", returnType); err != nil { - return err - } + if err := d.Set("return_data_type", maskingPolicyDetails.ReturnType); err != nil { + return err + } - // format in database is `(VAL )` - valueDataType := strings.TrimSuffix(strings.Split(signature, " ")[1], ")") - if err := d.Set("value_data_type", valueDataType); err != nil { - return err - } + signature := []map[string]interface{}{} + for _, s := range maskingPolicyDetails.Signature { + signature = append(signature, map[string]interface{}{ + "column": []map[string]interface{}{ + { + "name": s.Name, + "type": s.Type, + }, + }, + }) + } + if err := d.Set("signature", signature); err != nil { + return err + } + if err := d.Set("qualified_name", objectIdentifier.FullyQualifiedName()); err != nil { + return err } return err @@ -255,39 +260,51 @@ func ReadMaskingPolicy(d *schema.ResourceData, meta interface{}) error { // UpdateMaskingPolicy implements schema.UpdateFunc. func UpdateMaskingPolicy(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + client := sdk.NewClientFromDB(db) + objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + ctx := context.Background() - maskingPolicyID, err := maskingPolicyIDFromString(d.Id()) - if err != nil { - return err + if d.HasChange("masking_expression") { + alterOptions := &sdk.MaskingPolicyAlterOptions{} + _, n := d.GetChange("masking_expression") + alterOptions.Set = &sdk.MaskingPolicySet{ + Body: sdk.String(n.(string)), + } + err := client.MaskingPolicies.Alter(ctx, objectIdentifier, alterOptions) + if err != nil { + return err + } } - dbName := maskingPolicyID.DatabaseName - schema := maskingPolicyID.SchemaName - policyName := maskingPolicyID.MaskingPolicyName - - builder := snowflake.MaskingPolicy(policyName, dbName, schema) - if d.HasChange("comment") { - comment := d.Get("comment") - if c := comment.(string); c == "" { - q := builder.RemoveComment() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error unsetting comment for masking policy on %v err = %w", d.Id(), err) + alterOptions := &sdk.MaskingPolicyAlterOptions{} + if v, ok := d.GetOk("comment"); ok { + alterOptions.Set = &sdk.MaskingPolicySet{ + Comment: sdk.String(v.(string)), } } else { - q := builder.ChangeComment(c) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating comment for masking policy on %v err = %w", d.Id(), err) + alterOptions.Unset = &sdk.MaskingPolicyUnset{ + Comment: sdk.Bool(true), } } + err := client.MaskingPolicies.Alter(ctx, objectIdentifier, alterOptions) + if err != nil { + return err + } } - if d.HasChange("masking_expression") { - maskingExpression := d.Get("masking_expression") - q := builder.ChangeMaskingExpression(maskingExpression.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating masking policy expression on %v err = %w", d.Id(), err) + if d.HasChange("name") { + _, n := d.GetChange("name") + newName := n.(string) + newID := sdk.NewSchemaObjectIdentifier(objectIdentifier.DatabaseName(), objectIdentifier.SchemaName(), newName) + alterOptions := &sdk.MaskingPolicyAlterOptions{ + NewName: newID, + } + err := client.MaskingPolicies.Alter(ctx, objectIdentifier, alterOptions) + if err != nil { + return err } + d.SetId(helpers.EncodeSnowflakeID(newID)) } return ReadMaskingPolicy(d, meta) @@ -296,21 +313,15 @@ func UpdateMaskingPolicy(d *schema.ResourceData, meta interface{}) error { // DeleteMaskingPolicy implements schema.DeleteFunc. func DeleteMaskingPolicy(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - maskingPolicyID, err := maskingPolicyIDFromString(d.Id()) + client := sdk.NewClientFromDB(db) + ctx := context.Background() + objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + err := client.MaskingPolicies.Drop(ctx, objectIdentifier) if err != nil { return err } - dbName := maskingPolicyID.DatabaseName - schema := maskingPolicyID.SchemaName - policyName := maskingPolicyID.MaskingPolicyName - - q := snowflake.MaskingPolicy(policyName, dbName, schema).Drop() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error deleting masking policy %v err = %w", d.Id(), err) - } - d.SetId("") - return nil } diff --git a/pkg/resources/masking_policy_acceptance_test.go b/pkg/resources/masking_policy_acceptance_test.go index 9a6a22c783..a81f21432c 100644 --- a/pkg/resources/masking_policy_acceptance_test.go +++ b/pkg/resources/masking_policy_acceptance_test.go @@ -2,7 +2,6 @@ package resources_test import ( "fmt" - "os" "strings" "testing" @@ -11,32 +10,63 @@ import ( ) func TestAcc_MaskingPolicy(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_MASKING_POLICY_TESTS"); ok { - t.Skip("Skipping TestAccMaskingPolicy") - } - accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - + accName2 := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + comment := "Terraform acceptance test" + comment2 := "Terraform acceptance test 2" resource.ParallelTest(t, resource.TestCase{ Providers: providers(), CheckDestroy: nil, Steps: []resource.TestStep{ { - Config: maskingPolicyConfig(accName), + Config: maskingPolicyConfig(accName, accName, comment), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "database", accName), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "schema", accName), - resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", "Terraform acceptance test"), - resource.TestCheckResourceAttr("snowflake_masking_policy.test", "value_data_type", "VARCHAR"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", comment), resource.TestCheckResourceAttr("snowflake_masking_policy.test", "masking_expression", "case when current_role() in ('ANALYST') then val else sha2(val, 512) end"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "return_data_type", "VARCHAR"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "signature.#", "1"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "signature.0.column.#", "1"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "signature.0.column.0.name", "val"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "signature.0.column.0.type", "VARCHAR"), + ), + }, + // change comment + { + Config: maskingPolicyConfig(accName, accName, comment2), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", comment2), ), }, + // rename + { + Config: maskingPolicyConfig(accName, accName2, comment2), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "name", accName2), + ), + }, + // change body and unset comment + { + Config: maskingPolicyConfigMultiline(accName, accName2), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "masking_expression", "case \n\twhen current_role() in ('ROLE_A') then \n\t\tval \n\twhen is_role_in_session( 'ROLE_B' ) then \n\t\t'ABC123'\n\telse\n\t\t'******'\nend"), + resource.TestCheckResourceAttr("snowflake_masking_policy.test", "comment", ""), + ), + }, + // IMPORT + { + ResourceName: "snowflake_masking_policy.test", + ImportState: true, + ImportStateVerify: true, + }, }, }) } -func maskingPolicyConfig(n string) string { +func maskingPolicyConfig(n string, name string, comment string) string { return fmt.Sprintf(` resource "snowflake_database" "test" { name = "%v" @@ -50,13 +80,56 @@ resource "snowflake_schema" "test" { } resource "snowflake_masking_policy" "test" { - name = "%v" + name = "%s" database = snowflake_database.test.name schema = snowflake_schema.test.name - value_data_type = "VARCHAR" + signature { + column { + name = "val" + type = "VARCHAR" + } + } masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" return_data_type = "VARCHAR" - comment = "Terraform acceptance test" + comment = "%s" } -`, n, n, n) +`, n, n, name, comment) +} + +func maskingPolicyConfigMultiline(n string, name string) string { + return fmt.Sprintf(` + resource "snowflake_database" "test" { + name = "%v" + comment = "Terraform acceptance test" + } + + resource "snowflake_schema" "test" { + name = "%v" + database = snowflake_database.test.name + comment = "Terraform acceptance test" + } + + resource "snowflake_masking_policy" "test" { + name = "%s" + database = snowflake_database.test.name + schema = snowflake_schema.test.name + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = <<-EOF + case + when current_role() in ('ROLE_A') then + val + when is_role_in_session( 'ROLE_B' ) then + 'ABC123' + else + '******' + end + EOF + return_data_type = "VARCHAR" + } + `, n, n, name) } diff --git a/pkg/resources/masking_policy_grant.go b/pkg/resources/masking_policy_grant.go index 9897adb547..2ea2551fa1 100644 --- a/pkg/resources/masking_policy_grant.go +++ b/pkg/resources/masking_policy_grant.go @@ -123,7 +123,7 @@ func CreateMaskingPolicyGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, maskingPolicyName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, maskingPolicyName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadMaskingPolicyGrant(d, meta) @@ -145,7 +145,7 @@ func ReadMaskingPolicyGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, maskingPolicyName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, maskingPolicyName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/masking_policy_grant_acceptance_test.go b/pkg/resources/masking_policy_grant_acceptance_test.go index b9017731d6..3c1cb5ab18 100644 --- a/pkg/resources/masking_policy_grant_acceptance_test.go +++ b/pkg/resources/masking_policy_grant_acceptance_test.go @@ -2,7 +2,6 @@ package resources_test import ( "fmt" - "os" "strings" "testing" @@ -11,9 +10,6 @@ import ( ) func TestAcc_MaskingPolicyGrant(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_MASKING_POLICY_TESTS"); ok { - t.Skip("Skipping TestAccMaskingPolicy") - } accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.Test(t, resource.TestCase{ @@ -64,7 +60,12 @@ func maskingPolicyGrantConfig(name string) string { name = "%v" database = snowflake_database.test.name schema = snowflake_schema.test.name - value_data_type = "VARCHAR" + signature { + column { + name = "val" + type = "VARCHAR" + } + } masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" return_data_type = "VARCHAR" comment = "Terraform acceptance test" diff --git a/pkg/resources/masking_policy_test.go b/pkg/resources/masking_policy_test.go deleted file mode 100644 index 89fab7cedb..0000000000 --- a/pkg/resources/masking_policy_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package resources_test - -import ( - "database/sql" - "testing" - "time" - - sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" - . "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/testhelpers" - "github.com/stretchr/testify/require" -) - -func TestMaskingPolicy(t *testing.T) { - r := require.New(t) - err := resources.MaskingPolicy().InternalValidate(provider.Provider().Schema, true) - r.NoError(err) -} - -func TestMaskingPolicyCreate(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "policy_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "value_data_type": "string", - "masking_expression": "case when current_role() in ('ANALYST') then val else sha2(val, 512) end", - "return_data_type": "string", - } - - d := maskingPolicy(t, "database_name|schema_name|policy_name", in) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec( - `^CREATE MASKING POLICY "database_name"."schema_name"."policy_name" AS \(VAL string\) RETURNS string -> case when current_role\(\) in \('ANALYST'\) then val else sha2\(val, 512\) end COMMENT = \'great comment\'$`, - ).WillReturnResult(sqlmock.NewResult(1, 1)) - expectReadMaskingPolicy(mock) - err := resources.CreateMaskingPolicy(d, db) - r.NoError(err) - r.Equal("policy_name", d.Get("name").(string)) - }) -} - -func expectReadMaskingPolicy(mock sqlmock.Sqlmock) { - showRows := sqlmock.NewRows([]string{ - "created_on", "name", "database_name", "schema_name", "kind", "owner", "comment", - }).AddRow( - time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), "policy_name", "database_name", "schema_name", "MASKING_POLICY", "test", "this is a comment", - ) - mock.ExpectQuery(`^SHOW MASKING POLICIES LIKE 'policy_name' IN SCHEMA "database_name"."schema_name"$`).WillReturnRows(showRows) - - descRows := sqlmock.NewRows([]string{ - "name", "signature", "return_type", "body", - }).AddRow( - "policy_name", "(VAL VARCHAR)", "VARCHAR(16777216)", "case when current_role() in ('ANALYST') then val else sha2(val, 512) end", - ) - mock.ExpectQuery(`^DESCRIBE MASKING POLICY "database_name"."schema_name"."policy_name"$`).WillReturnRows(descRows) -} - -func TestMaskingPolicyDelete(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "name": "policy_name", - "database": "database_name", - "schema": "schema_name", - "comment": "great comment", - "value_data_type": "string", - "masking_expression": "case when current_role() in ('ANALYST') then val else sha2(val, 512) end", - "return_data_type": "string", - } - - d := maskingPolicy(t, "database_name|schema_name|policy_name", in) - r.NotNil(d) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`^DROP MASKING POLICY "database_name"."schema_name"."policy_name"$`).WillReturnResult(sqlmock.NewResult(1, 1)) - err := resources.DeleteMaskingPolicy(d, db) - r.NoError(err) - }) -} diff --git a/pkg/resources/materialized_view_grant.go b/pkg/resources/materialized_view_grant.go index 3dc7257a76..21740abf6b 100644 --- a/pkg/resources/materialized_view_grant.go +++ b/pkg/resources/materialized_view_grant.go @@ -183,7 +183,7 @@ func CreateMaterializedViewGrant(d *schema.ResourceData, meta interface{}) error return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, materializedViewName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, materializedViewName, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadMaterializedViewGrant(d, meta) @@ -225,7 +225,7 @@ func ReadMaterializedViewGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, materializedViewName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, materializedViewName, privilege, withGrantOption, onFuture, onAll, roles, shares) // if the ID is not in the new format, rewrite it if d.Id() != grantID { d.SetId(grantID) diff --git a/pkg/resources/network_policy_attachment.go b/pkg/resources/network_policy_attachment.go index 2ece38c3ea..ecce6eaf06 100644 --- a/pkg/resources/network_policy_attachment.go +++ b/pkg/resources/network_policy_attachment.go @@ -119,7 +119,6 @@ func ReadNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) error if err := d.Set("set_for_account", isSetOnAccount); err != nil { return err } - return nil } @@ -139,9 +138,9 @@ func UpdateNetworkPolicyAttachment(d *schema.ResourceData, meta interface{}) err } if d.HasChange("users") { - old, new := d.GetChange("users") - oldUsersSet := old.(*schema.Set) - newUsersSet := new.(*schema.Set) + o, n := d.GetChange("users") + oldUsersSet := o.(*schema.Set) + newUsersSet := n.(*schema.Set) removedUsers := expandStringList(oldUsersSet.Difference(newUsersSet).List()) addedUsers := expandStringList(newUsersSet.Difference(oldUsersSet).List()) diff --git a/pkg/resources/object_parameter.go b/pkg/resources/object_parameter.go index 4cfbd24d84..ab33bee112 100644 --- a/pkg/resources/object_parameter.go +++ b/pkg/resources/object_parameter.go @@ -177,7 +177,6 @@ func ReadObjectParameter(d *schema.ResourceData, meta interface{}) error { if err := d.Set("value", p.Value.String); err != nil { return err } - return nil } diff --git a/pkg/resources/password_policy.go b/pkg/resources/password_policy.go index cb52713b3f..7c7b5ebc5f 100644 --- a/pkg/resources/password_policy.go +++ b/pkg/resources/password_policy.go @@ -26,7 +26,6 @@ var passwordPolicySchema = map[string]*schema.Schema{ "name": { Type: schema.TypeString, Required: true, - ForceNew: true, Description: "Identifier for the password policy; must be unique for your account.", }, "or_replace": { @@ -142,11 +141,7 @@ func CreatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { name := d.Get("name").(string) database := d.Get("database").(string) schema := d.Get("schema").(string) - objectIdentifier := sdk.SchemaObjectIdentifier{ - DatabaseName: database, - SchemaName: schema, - Name: name, - } + objectIdentifier := sdk.NewSchemaObjectIdentifier(database, schema, name) createOptions := &sdk.PasswordPolicyCreateOptions{} if v, ok := d.GetOk("or_replace"); ok { @@ -201,8 +196,7 @@ func CreatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if err != nil { return err } - id := helpers.SnowflakeID(database, schema, name) - d.SetId(id) + d.SetId(helpers.EncodeSnowflakeID(objectIdentifier)) return ReadPasswordPolicy(d, meta) } @@ -212,10 +206,10 @@ func ReadPasswordPolicy(d *schema.ResourceData, meta interface{}) error { client := sdk.NewClientFromDB(db) ctx := context.Background() objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - schemaIdentifier := sdk.NewSchemaIdentifier(objectIdentifier.DatabaseName, objectIdentifier.SchemaName) + schemaIdentifier := sdk.NewSchemaIdentifier(objectIdentifier.DatabaseName(), objectIdentifier.SchemaName()) passwordPolicyList, err := client.PasswordPolicies.Show(ctx, &sdk.PasswordPolicyShowOptions{ Like: &sdk.Like{ - Pattern: sdk.String(objectIdentifier.Name), + Pattern: sdk.String(objectIdentifier.Name()), }, In: &sdk.In{ Schema: schemaIdentifier, @@ -273,7 +267,6 @@ func ReadPasswordPolicy(d *schema.ResourceData, meta interface{}) error { if err := d.Set("lockout_time_mins", passwordPolicyDetails.PasswordLockoutTimeMins.Value); err != nil { return err } - return nil } @@ -285,27 +278,14 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { objectIdentifier := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - if d.HasChange("name") { - _, n := d.GetChange("name") - databaseName := d.Get("database").(string) - schemaName := d.Get("schema").(string) - alterOptions := &sdk.PasswordPolicyAlterOptions{ - NewName: sdk.NewSchemaObjectIdentifier(databaseName, schemaName, n.(string)), - } - err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) - if err != nil { - return err - } - } - if d.HasChange("min_length") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("min_length"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMinLength: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMinLength: sdk.Bool(true), } } @@ -317,11 +297,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("max_length") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("max_length"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMaxLength: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMaxLength: sdk.Bool(true), } } @@ -333,11 +313,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("min_upper_case_chars") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("min_upper_case_chars"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMinUpperCaseChars: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMinUpperCaseChars: sdk.Bool(true), } } @@ -349,11 +329,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("min_lower_case_chars") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("min_lower_case_chars"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMinLowerCaseChars: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMinLowerCaseChars: sdk.Bool(true), } } @@ -366,11 +346,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("min_numeric_chars") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("min_numeric_chars"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMinNumericChars: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMinNumericChars: sdk.Bool(true), } } @@ -383,11 +363,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("min_special_chars") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("min_special_chars"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMinSpecialChars: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMinSpecialChars: sdk.Bool(true), } } @@ -400,11 +380,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("max_age_days") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("max_age_days"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMaxAgeDays: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMaxAgeDays: sdk.Bool(true), } } @@ -417,11 +397,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("max_retries") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("max_retries"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordMaxRetries: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordMaxRetries: sdk.Bool(true), } } @@ -433,11 +413,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("lockout_time_mins") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("lockout_time_mins"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ PasswordLockoutTimeMins: sdk.Int(v.(int)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ PasswordLockoutTimeMins: sdk.Bool(true), } } @@ -450,11 +430,11 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { if d.HasChange("comment") { alterOptions := &sdk.PasswordPolicyAlterOptions{} if v, ok := d.GetOk("comment"); ok { - alterOptions.Set = &sdk.PasswordPolicyAlterSet{ + alterOptions.Set = &sdk.PasswordPolicySet{ Comment: sdk.String(v.(string)), } } else { - alterOptions.Unset = &sdk.PasswordPolicyAlterUnset{ + alterOptions.Unset = &sdk.PasswordPolicyUnset{ Comment: sdk.Bool(true), } } @@ -464,6 +444,20 @@ func UpdatePasswordPolicy(d *schema.ResourceData, meta interface{}) error { } } + if d.HasChange("name") { + _, n := d.GetChange("name") + newName := n.(string) + newID := sdk.NewSchemaObjectIdentifier(objectIdentifier.DatabaseName(), objectIdentifier.SchemaName(), newName) + alterOptions := &sdk.PasswordPolicyAlterOptions{ + NewName: newID, + } + err := client.PasswordPolicies.Alter(ctx, objectIdentifier, alterOptions) + if err != nil { + return err + } + d.SetId(helpers.EncodeSnowflakeID(newID)) + } + return nil } diff --git a/pkg/resources/pipe.go b/pkg/resources/pipe.go index 3a7c7cc29f..e056e0347f 100644 --- a/pkg/resources/pipe.go +++ b/pkg/resources/pipe.go @@ -97,13 +97,13 @@ func Pipe() *schema.Resource { } } -func pipeCopyStatementDiffSuppress(k, old, new string, d *schema.ResourceData) bool { +func pipeCopyStatementDiffSuppress(_, o, n string, _ *schema.ResourceData) bool { // standardize line endings - old = strings.ReplaceAll(old, "\r\n", "\n") - new = strings.ReplaceAll(new, "\r\n", "\n") + o = strings.ReplaceAll(o, "\r\n", "\n") + n = strings.ReplaceAll(n, "\r\n", "\n") // trim off any trailing line endings - return strings.TrimRight(old, ";\r\n") == strings.TrimRight(new, ";\r\n") + return strings.TrimRight(o, ";\r\n") == strings.TrimRight(n, ";\r\n") } type pipeID struct { @@ -272,11 +272,8 @@ func ReadPipe(d *schema.ResourceData, meta interface{}) error { pipe.ErrorIntegration.Valid = false pipe.ErrorIntegration.String = "" } - if err := d.Set("error_integration", pipe.ErrorIntegration.String); err != nil { - return err - } - - return nil + err = d.Set("error_integration", pipe.ErrorIntegration.String) + return err } // UpdatePipe implements schema.UpdateFunc. diff --git a/pkg/resources/pipe_grant.go b/pkg/resources/pipe_grant.go index 54fb6b81c2..bdd94fb00b 100644 --- a/pkg/resources/pipe_grant.go +++ b/pkg/resources/pipe_grant.go @@ -152,7 +152,7 @@ func CreatePipeGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, pipeName, privilege, withGrantOption, onFuture, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, pipeName, privilege, withGrantOption, onFuture, roles) d.SetId(grantID) return ReadPipeGrant(d, meta) @@ -188,7 +188,7 @@ func ReadPipeGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, pipeName, privilege, withGrantOption, onFuture, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, pipeName, privilege, withGrantOption, onFuture, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index c6304032ab..d0e74ded10 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -164,8 +164,8 @@ var procedureSchema = map[string]*schema.Schema{ }, } -func DiffTypes(k, old, new string, d *schema.ResourceData) bool { - return strings.EqualFold(strings.ToUpper(old), strings.ToUpper(new)) +func DiffTypes(_, o, n string, _ *schema.ResourceData) bool { + return strings.EqualFold(strings.ToUpper(o), strings.ToUpper(n)) } // Procedure returns a pointer to the resource representing a stored procedure. diff --git a/pkg/resources/procedure_grant.go b/pkg/resources/procedure_grant.go index 7436936301..e476f65de9 100644 --- a/pkg/resources/procedure_grant.go +++ b/pkg/resources/procedure_grant.go @@ -197,7 +197,7 @@ func CreateProcedureGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, procedureName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, procedureName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadProcedureGrant(d, meta) } @@ -230,7 +230,7 @@ func ReadProcedureGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, procedureName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, procedureName, argumentDataTypes, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/resource_monitor.go b/pkg/resources/resource_monitor.go index 6080f34c58..8c519f0faa 100644 --- a/pkg/resources/resource_monitor.go +++ b/pkg/resources/resource_monitor.go @@ -204,7 +204,6 @@ func CreateResourceMonitor(d *schema.ResourceData, meta interface{}) error { if err := ReadResourceMonitor(d, meta); err != nil { return err } - return nil } diff --git a/pkg/resources/resource_monitor_grant.go b/pkg/resources/resource_monitor_grant.go index 7b90afe584..8157d44ad3 100644 --- a/pkg/resources/resource_monitor_grant.go +++ b/pkg/resources/resource_monitor_grant.go @@ -100,7 +100,7 @@ func CreateResourceMonitorGrant(d *schema.ResourceData, meta interface{}) error return err } - grantID := helpers.SnowflakeID(monitorName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(monitorName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadResourceMonitorGrant(d, meta) @@ -119,7 +119,7 @@ func ReadResourceMonitorGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(monitorName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(monitorName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/role.go b/pkg/resources/role.go index 5bfa8933a9..5c3ef41554 100644 --- a/pkg/resources/role.go +++ b/pkg/resources/role.go @@ -118,8 +118,8 @@ func UpdateRole(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("tag") { - old, new := d.GetChange("tag") - removed, added, changed := getTags(old).diffs(getTags(new)) + o, n := d.GetChange("tag") + removed, added, changed := getTags(o).diffs(getTags(n)) for _, tA := range removed { err := builder.UnsetTag(tA.toSnowflakeTagValue()) if err != nil { diff --git a/pkg/resources/role_grants.go b/pkg/resources/role_grants.go index d4fe16e04e..542ce4f589 100644 --- a/pkg/resources/role_grants.go +++ b/pkg/resources/role_grants.go @@ -85,7 +85,7 @@ func CreateRoleGrants(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("no users or roles specified for role grants") } - grantID := helpers.SnowflakeID(roleName, roles, users) + grantID := helpers.EncodeSnowflakeID(roleName, roles, users) d.SetId(grantID) for _, role := range roles { @@ -170,7 +170,7 @@ func ReadRoleGrants(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(roleName, roles, users) + grantID := helpers.EncodeSnowflakeID(roleName, roles, users) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/row_access_policy_grant.go b/pkg/resources/row_access_policy_grant.go index 893068ade4..172f88cd56 100644 --- a/pkg/resources/row_access_policy_grant.go +++ b/pkg/resources/row_access_policy_grant.go @@ -128,7 +128,7 @@ func CreateRowAccessPolicyGrant(d *schema.ResourceData, meta interface{}) error return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, rowAccessPolicyName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, rowAccessPolicyName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadRowAccessPolicyGrant(d, meta) @@ -150,7 +150,7 @@ func ReadRowAccessPolicyGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, rowAccessPolicyName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, rowAccessPolicyName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/saml_integration.go b/pkg/resources/saml_integration.go index d9f9b54326..b6bb0697e2 100644 --- a/pkg/resources/saml_integration.go +++ b/pkg/resources/saml_integration.go @@ -292,7 +292,7 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("unable to set saml2_sp_initiated_login_page_label for security integration") } case "SAML2_ENABLE_SP_INITIATED": - b := false + var b bool switch v2 := v.(type) { case bool: b = v2 @@ -312,7 +312,7 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("unable to set saml2_snowflake_x509_cert for security integration err = %w", err) } case "SAML2_SIGN_REQUEST": - b := false + var b bool switch v2 := v.(type) { case bool: b = v2 @@ -336,7 +336,7 @@ func ReadSAMLIntegration(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("unable to set saml2_post_logout_redirect_url for security integration err = %w", err) } case "SAML2_FORCE_AUTHN": - b := false + var b bool switch v2 := v.(type) { case bool: b = v2 diff --git a/pkg/resources/schema_grant.go b/pkg/resources/schema_grant.go index d23893a7d1..3c4a224b75 100644 --- a/pkg/resources/schema_grant.go +++ b/pkg/resources/schema_grant.go @@ -184,7 +184,7 @@ func CreateSchemaGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadSchemaGrant(d, meta) @@ -280,7 +280,7 @@ func ReadSchemaGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/sequence.go b/pkg/resources/sequence.go index 2ba77ad02e..5735996218 100644 --- a/pkg/resources/sequence.go +++ b/pkg/resources/sequence.go @@ -196,7 +196,6 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error { if err := d.Set("fully_qualified_name", seq.Address()); err != nil { return err } - return nil } diff --git a/pkg/resources/sequence_grant.go b/pkg/resources/sequence_grant.go index b3e24eb306..64c519379b 100644 --- a/pkg/resources/sequence_grant.go +++ b/pkg/resources/sequence_grant.go @@ -164,7 +164,7 @@ func CreateSequenceGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, sequenceName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, sequenceName, privilege, withGrantOption, onFuture, onAll, roles) d.SetId(grantID) return ReadSequenceGrant(d, meta) @@ -196,7 +196,7 @@ func ReadSequenceGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, sequenceName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, sequenceName, privilege, withGrantOption, onFuture, onAll, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/stage.go b/pkg/resources/stage.go index 87a0ec89d3..9e26b24235 100644 --- a/pkg/resources/stage.go +++ b/pkg/resources/stage.go @@ -302,7 +302,6 @@ func ReadStage(d *schema.ResourceData, meta interface{}) error { if err := d.Set("snowflake_iam_user", stageDesc.SnowflakeIamUser); err != nil { return err } - return nil } diff --git a/pkg/resources/stage_grant.go b/pkg/resources/stage_grant.go index b8e14b9a4b..1076e01c36 100644 --- a/pkg/resources/stage_grant.go +++ b/pkg/resources/stage_grant.go @@ -173,7 +173,7 @@ func CreateStageGrant(d *schema.ResourceData, meta interface{}) error { } roles := expandStringList(d.Get("roles").(*schema.Set).List()) - grantID := helpers.SnowflakeID(databaseName, schemaName, stageName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, stageName, privilege, withGrantOption, onFuture, onAll, roles) d.SetId(grantID) return ReadStageGrant(d, meta) @@ -205,7 +205,7 @@ func ReadStageGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, stageName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, stageName, privilege, withGrantOption, onFuture, onAll, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/stream.go b/pkg/resources/stream.go index aa8764b5c5..2c570c1c31 100644 --- a/pkg/resources/stream.go +++ b/pkg/resources/stream.go @@ -317,15 +317,16 @@ func ReadStream(d *schema.ResourceData, meta interface{}) error { return err } - if stream.SourceType.String == "Stage" { + switch stream.SourceType.String { + case "Stage": if err := d.Set("on_stage", stream.TableName.String); err != nil { return err } - } else if stream.SourceType.String == "View" { + case "View": if err := d.Set("on_view", stream.TableName.String); err != nil { return err } - } else { + default: if err := d.Set("on_table", stream.TableName.String); err != nil { return err } @@ -350,7 +351,6 @@ func ReadStream(d *schema.ResourceData, meta interface{}) error { if err := d.Set("owner", stream.Owner.String); err != nil { return err } - return nil } diff --git a/pkg/resources/stream_grant.go b/pkg/resources/stream_grant.go index 549c666e48..ec20008793 100644 --- a/pkg/resources/stream_grant.go +++ b/pkg/resources/stream_grant.go @@ -164,7 +164,7 @@ func CreateStreamGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, streamName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, streamName, privilege, withGrantOption, onFuture, onAll, roles) d.SetId(grantID) return ReadStreamGrant(d, meta) @@ -196,7 +196,7 @@ func ReadStreamGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, streamName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, streamName, privilege, withGrantOption, onFuture, onAll, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/table.go b/pkg/resources/table.go index 7da719e020..54afb5af85 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -646,8 +646,8 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } } if d.HasChange("column") { - t, new := d.GetChange("column") - removed, added, changed := getColumns(t).diffs(getColumns(new)) + t, n := d.GetChange("column") + removed, added, changed := getColumns(t).diffs(getColumns(n)) for _, cA := range removed { q := builder.DropColumn(cA.name) if err := snowflake.Exec(db, q); err != nil { diff --git a/pkg/resources/table_column_masking_policy_application_acceptance_test.go b/pkg/resources/table_column_masking_policy_application_acceptance_test.go index f1fa38bbb7..d4d0f8a816 100644 --- a/pkg/resources/table_column_masking_policy_application_acceptance_test.go +++ b/pkg/resources/table_column_masking_policy_application_acceptance_test.go @@ -47,7 +47,12 @@ resource "snowflake_masking_policy" "test" { name = "mypolicy" database = snowflake_database.test.name schema = snowflake_schema.test.name - value_data_type = "VARCHAR" + signature { + column { + name = "val" + type = "VARCHAR" + } + } masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" return_data_type = "VARCHAR" comment = "Terraform acceptance test" diff --git a/pkg/resources/table_constraint.go b/pkg/resources/table_constraint.go index f2237e73d8..0876c8ea63 100644 --- a/pkg/resources/table_constraint.go +++ b/pkg/resources/table_constraint.go @@ -286,7 +286,7 @@ func CreateTableConstraint(d *schema.ResourceData, meta interface{}) error { } // ReadTableConstraint implements schema.ReadFunc. -func ReadTableConstraint(d *schema.ResourceData, meta interface{}) error { +func ReadTableConstraint(_ *schema.ResourceData, _ interface{}) error { // commenting this out since it requires an active warehouse to be set which may not be intuitive. // also it takes a while for the database to reflect changes. Would likely need to add a validation // step like in tag association. People don't like waiting 40 minutes for Terraform to run. @@ -324,8 +324,8 @@ func UpdateTableConstraint(d *schema.ResourceData, meta interface{}) error { }*/ if d.HasChange("name") { - _, new := d.GetChange("name") - _, err := db.Exec(builder.Rename(new.(string))) + _, n := d.GetChange("name") + _, err := db.Exec(builder.Rename(n.(string))) if err != nil { return fmt.Errorf("error renaming table constraint %v err = %w", tc.name, err) } diff --git a/pkg/resources/table_grant.go b/pkg/resources/table_grant.go index 034c894f56..fac5cd5a86 100644 --- a/pkg/resources/table_grant.go +++ b/pkg/resources/table_grant.go @@ -184,7 +184,7 @@ func CreateTableGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, tableName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, tableName, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadTableGrant(d, meta) } @@ -215,7 +215,7 @@ func ReadTableGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, tableName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, tableName, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } @@ -251,9 +251,9 @@ func UpdateTableGrant(d *schema.ResourceData, meta interface{}) error { // difference calculates roles/shares to add/revoke difference := func(key string) (toAdd []string, toRevoke []string) { - old, new := d.GetChange(key) - oldSet := old.(*schema.Set) - newSet := new.(*schema.Set) + o, n := d.GetChange(key) + oldSet := o.(*schema.Set) + newSet := n.(*schema.Set) toAdd = expandStringList(newSet.Difference(oldSet).List()) toRevoke = expandStringList(oldSet.Difference(newSet).List()) return diff --git a/pkg/resources/tag.go b/pkg/resources/tag.go index c4d71c04dd..01359edf34 100644 --- a/pkg/resources/tag.go +++ b/pkg/resources/tag.go @@ -97,8 +97,8 @@ type TagBuilder interface { func handleTagChanges(db *sql.DB, d *schema.ResourceData, builder TagBuilder) error { if d.HasChange("tag") { - old, new := d.GetChange("tag") - removed, added, changed := getTags(old).diffs(getTags(new)) + o, n := d.GetChange("tag") + removed, added, changed := getTags(o).diffs(getTags(n)) for _, tA := range removed { q := builder.UnsetTag(tA.toSnowflakeTagValue()) if err := snowflake.Exec(db, q); err != nil { @@ -260,11 +260,8 @@ func ReadTag(d *schema.ResourceData, meta interface{}) error { av := strings.ReplaceAll(t.AllowedValues.String, "\"", "") av = strings.TrimPrefix(av, "[") av = strings.TrimSuffix(av, "]") - if err := d.Set("allowed_values", helpers.StringListToList(av)); err != nil { - return err - } - - return nil + err = d.Set("allowed_values", helpers.StringListToList(av)) + return err } // UpdateTag implements schema.UpdateFunc. diff --git a/pkg/resources/tag_association.go b/pkg/resources/tag_association.go index 3464911366..a71ae67b2e 100644 --- a/pkg/resources/tag_association.go +++ b/pkg/resources/tag_association.go @@ -178,7 +178,6 @@ func ReadTagAssociation(d *schema.ResourceData, meta interface{}) error { if err := d.Set("tag_value", ta.TagValue.String); err != nil { return err } - return nil } @@ -193,8 +192,8 @@ func UpdateTagAssociation(d *schema.ResourceData, meta interface{}) error { builder := snowflake.NewTagAssociationBuilder(tagID).WithObjectIdentifier(fullyQualifierObjectIdentifier).WithObjectType(objectType) if d.HasChange("skip_validation") { - old, new := d.GetChange("skip_validation") - log.Printf("[DEBUG] skip_validation changed from %v to %v", old, new) + o, n := d.GetChange("skip_validation") + log.Printf("[DEBUG] skip_validation changed from %v to %v", o, n) } if d.HasChange("tag_value") { diff --git a/pkg/resources/tag_grant.go b/pkg/resources/tag_grant.go index 55678a8710..b627563d26 100644 --- a/pkg/resources/tag_grant.go +++ b/pkg/resources/tag_grant.go @@ -121,7 +121,7 @@ func CreateTagGrant(d *schema.ResourceData, meta interface{}) error { if err := createGenericGrant(d, meta, builder); err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, tagName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, tagName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadTagGrant(d, meta) @@ -143,7 +143,7 @@ func ReadTagGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, tagName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, tagName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/tag_masking_policy_association.go b/pkg/resources/tag_masking_policy_association.go index 099844ff36..4753ad9372 100644 --- a/pkg/resources/tag_masking_policy_association.go +++ b/pkg/resources/tag_masking_policy_association.go @@ -2,6 +2,7 @@ package resources import ( "bytes" + "context" "database/sql" "encoding/csv" "errors" @@ -11,6 +12,8 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" snowflakeValidation "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/validation" ) @@ -97,6 +100,7 @@ func TagMaskingPolicyAssociation() *schema.Resource { Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, + Description: "Attach a masking policy to a tag. Requires a current warehouse to be set. Either with SNOWFLAKE_WAREHOUSE env variable or in current session. If no warehouse is provided, a temporary warehouse will be created.", } } @@ -113,13 +117,10 @@ func CreateTagMaskingPolicyAssociation(d *schema.ResourceData, meta interface{}) tagName := tagIDStruct.TagName mpID := d.Get("masking_policy_id").(string) - mpIDStruct, mpIDErr := maskingPolicyIDFromString(mpID) - if mpIDErr != nil { - return mpIDErr - } - mpDB := mpIDStruct.DatabaseName - mpSchema := mpIDStruct.SchemaName - mpName := mpIDStruct.MaskingPolicyName + mpIDStruct := helpers.DecodeSnowflakeID(mpID).(sdk.SchemaObjectIdentifier) + mpDB := mpIDStruct.DatabaseName() + mpSchema := mpIDStruct.SchemaName() + mpName := mpIDStruct.Name() mP := snowflake.MaskingPolicy(mpName, mpDB, mpSchema) builder := snowflake.NewTagBuilder(tagName).WithDB(tagDB).WithSchema(tagSchema).WithMaskingPolicy(mP) @@ -165,6 +166,36 @@ func ReadTagMaskingPolicyAssociation(d *schema.ResourceData, meta interface{}) e mP := snowflake.MaskingPolicy(mpName, mpDBName, mpSchameName) builder := snowflake.NewTagBuilder(tagName).WithDB(tagDBName).WithSchema(tagSchemaName).WithMaskingPolicy(mP) + // create temp warehouse to query the tag, and make sure to clean it up + client := sdk.NewClientFromDB(db) + ctx := context.Background() + originalWarehouse, err := client.ContextFunctions.CurrentWarehouse(ctx) + if err != nil { + return err + } + if originalWarehouse == "" { + log.Printf("[DEBUG] no current warehouse set, creating a temporary warehouse") + randomWarehouseName := fmt.Sprintf("terraform-provider-snowflake-%v", helpers.RandomString()) + tempWarehouseID := sdk.NewAccountObjectIdentifier(randomWarehouseName) + err = client.Warehouses.Create(ctx, tempWarehouseID, nil) + if err != nil { + return err + } + defer func() { + err := client.Warehouses.Drop(ctx, tempWarehouseID, nil) + if err != nil { + log.Printf("[WARN] error cleaning up temp warehouse %v", err) + } + err = client.Sessions.UseWarehouse(ctx, sdk.NewAccountObjectIdentifier(originalWarehouse)) + if err != nil { + log.Printf("[WARN] error resetting warehouse %v", err) + } + }() + err = client.Sessions.UseWarehouse(ctx, tempWarehouseID) + if err != nil { + return err + } + } row := snowflake.QueryRow(db, builder.ShowAttachedPolicy()) t, err := snowflake.ScanTagPolicy(row) if errors.Is(err, sql.ErrNoRows) { @@ -189,16 +220,7 @@ func ReadTagMaskingPolicyAssociation(d *schema.ResourceData, meta interface{}) e return err } - mpID := maskingPolicyID{ - DatabaseName: t.PolicyDB.String, - SchemaName: t.PolicySchema.String, - MaskingPolicyName: t.PolicyName.String, - } - - mpIDString, err := mpID.String() - if err != nil { - return err - } + mpIDString := helpers.EncodeSnowflakeID(t.PolicyDB.String, t.PolicySchema.String, t.PolicyName.String) if err := d.Set("tag_id", tagIDString); err != nil { return err @@ -207,7 +229,6 @@ func ReadTagMaskingPolicyAssociation(d *schema.ResourceData, meta interface{}) e if err := d.Set("masking_policy_id", mpIDString); err != nil { return err } - return nil } diff --git a/pkg/resources/tag_masking_policy_association_acceptance_test.go b/pkg/resources/tag_masking_policy_association_acceptance_test.go index 5841e23c70..d1c0d768aa 100644 --- a/pkg/resources/tag_masking_policy_association_acceptance_test.go +++ b/pkg/resources/tag_masking_policy_association_acceptance_test.go @@ -52,7 +52,12 @@ resource "snowflake_masking_policy" "test" { name = "%[1]v" database = snowflake_database.test.name schema = snowflake_schema.test.name - value_data_type = "VARCHAR" + signature { + column { + name = "val" + type = "VARCHAR" + } + } masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" return_data_type = "VARCHAR(16777216)" comment = "Terraform acceptance test" diff --git a/pkg/resources/tag_masking_policy_association_test.go b/pkg/resources/tag_masking_policy_association_test.go deleted file mode 100644 index 4da54e9ce8..0000000000 --- a/pkg/resources/tag_masking_policy_association_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package resources_test - -import ( - "database/sql" - "regexp" - "testing" - - sqlmock "github.com/DATA-DOG/go-sqlmock" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" - . "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/testhelpers" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/stretchr/testify/require" -) - -func TestTagMaskingPolicyAssociation(t *testing.T) { - r := require.New(t) - err := resources.Tag().InternalValidate(provider.Provider().Schema, true) - r.NoError(err) -} - -func TestTagMaskingPolicyAssociationCreate(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "tag_id": "tag_db|tag_schema|tag_name", - "masking_policy_id": "mp_db|mp_schema|mp_name", - } - d := schema.TestResourceDataRaw(t, resources.TagMaskingPolicyAssociation().Schema, in) - r.NotNil(d) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`^ALTER TAG "tag_db"."tag_schema"."tag_name" SET MASKING POLICY "mp_db"."mp_schema"."mp_name"$`).WillReturnResult(sqlmock.NewResult(1, 1)) - - expectReadTestTagMaskingPolicyAssociation(mock) - err := resources.CreateTagMaskingPolicyAssociation(d, db) - r.NoError(err) - }) -} - -func TestTagMaskingPolicyAssociationDelete(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "tag_id": "tag_db|tag_schema|tag_name", - "masking_policy_id": "mp_db|mp_schema|mp_name", - } - - d := schema.TestResourceDataRaw(t, resources.TagMaskingPolicyAssociation().Schema, in) - d.SetId("tag_db|tag_schema|tag_name|mp_db|mp_schema|mp_name") - r.NotNil(d) - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`^ALTER TAG "tag_db"."tag_schema"."tag_name" UNSET MASKING POLICY "mp_db"."mp_schema"."mp_name"$`).WillReturnResult(sqlmock.NewResult(1, 1)) - - err := resources.DeleteTagMaskingPolicyAssociation(d, db) - - r.NoError(err) - }) -} - -func TestTagMaskingPolicyAssociationRead(t *testing.T) { - r := require.New(t) - - in := map[string]interface{}{ - "tag_id": "tag_db|tag_schema|tag_name", - "masking_policy_id": "mp_db|mp_schema|mp_name", - } - - d := schema.TestResourceDataRaw(t, resources.TagMaskingPolicyAssociation().Schema, in) - d.SetId("tag_db|tag_schema|tag_name|mp_db|mp_schema|mp_name") - - WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - // Test when resource is not found, checking if state will be empty - r.NotEmpty(d.State()) - mP := snowflake.MaskingPolicy("mp_name", "mp_db", "mp_schema") - q := snowflake.NewTagBuilder("tag_name").WithDB("tag_db").WithSchema("tag_schema").WithMaskingPolicy(mP).ShowAttachedPolicy() - mock.ExpectQuery(regexp.QuoteMeta(q)).WillReturnError(sql.ErrNoRows) - err := resources.ReadTagMaskingPolicyAssociation(d, db) - - r.Empty(d.State()) - r.Nil(err) - }) -} - -func expectReadTestTagMaskingPolicyAssociation(mock sqlmock.Sqlmock) { - rows := sqlmock.NewRows([]string{ - "POLICY_DB", "POLICY_SCHEMA", "POLICY_NAME", "POLICY_KIND", "REF_DATABASE_NAME", "REF_SCHEMA_NAME", "REF_ENTITY_NAME", "REF_ENTITY_DOMAIN", - }, - ).AddRow("mp_db", "mp_schema", "mp_name", "MASKING", "tag_db", "tag_schema", "tag_name", "TAG") - mock.ExpectQuery(regexp.QuoteMeta(`SELECT * from table ("tag_db".information_schema.policy_references(ref_entity_name => '"tag_db"."tag_schema"."tag_name"', ref_entity_domain => 'TAG')) where policy_db='mp_db' and policy_schema='mp_schema' and policy_name='mp_name'`)).WillReturnRows(rows) -} diff --git a/pkg/resources/task.go b/pkg/resources/task.go index d35f34aebd..148fcb87b8 100644 --- a/pkg/resources/task.go +++ b/pkg/resources/task.go @@ -405,12 +405,7 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error { // resume the task after modifications are complete as long as it is not a standalone task if !(rootTask.Name == name) { - defer func() { - q = rootTask.Resume() - if err := snowflake.Exec(db, q); err != nil { - log.Printf("[WARN] failed to resume task %s", rootTask.Name) - } - }() + defer resumeTask(db, rootTask) } } } @@ -448,6 +443,13 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error { return ReadTask(d, meta) } +func resumeTask(db *sql.DB, rootTask *snowflake.Task) { + q := rootTask.Resume() + if err := snowflake.Exec(db, q); err != nil { + log.Printf("[WARN] failed to resume task %s", rootTask.Name) + } +} + // UpdateTask implements schema.UpdateFunc. func UpdateTask(d *schema.ResourceData, meta interface{}) error { taskID, err := taskIDFromString(d.Id()) @@ -475,12 +477,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if !(rootTask.Name == name) { // resume the task after modifications are complete, as long as it is not a standalone task - defer func() { - q = rootTask.Resume() - if err := snowflake.Exec(db, q); err != nil { - log.Printf("[WARN] failed to resume task %s", rootTask.Name) - } - }() + defer resumeTask(db, rootTask) } } } @@ -539,15 +536,15 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { return fmt.Errorf("error suspending task %v", d.Id()) } - old, new := d.GetChange("after") + o, n := d.GetChange("after") var oldAfter []string - if len(old.([]interface{})) > 0 { - oldAfter = expandStringList(old.([]interface{})) + if len(o.([]interface{})) > 0 { + oldAfter = expandStringList(o.([]interface{})) } var newAfter []string - if len(new.([]interface{})) > 0 { - newAfter = expandStringList(new.([]interface{})) + if len(n.([]interface{})) > 0 { + newAfter = expandStringList(n.([]interface{})) } // Remove old dependencies that are not in new dependencies @@ -587,12 +584,7 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if !(rootTask.Name == name) { // resume the task after modifications are complete, as long as it is not a standalone task - defer func() { - q = rootTask.Resume() - if err := snowflake.Exec(db, q); err != nil { - log.Printf("[WARN] failed to resume task %s", rootTask.Name) - } - }() + defer resumeTask(db, rootTask) } } } @@ -606,11 +598,11 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if d.HasChange("schedule") { var q string - old, new := d.GetChange("schedule") - if old != "" && new == "" { + o, n := d.GetChange("schedule") + if o != "" && n == "" { q = builder.RemoveSchedule() } else { - q = builder.ChangeSchedule(new.(string)) + q = builder.ChangeSchedule(n.(string)) } if err := snowflake.Exec(db, q); err != nil { @@ -620,11 +612,11 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if d.HasChange("user_task_timeout_ms") { var q string - old, new := d.GetChange("user_task_timeout_ms") - if old.(int) > 0 && new.(int) == 0 { + o, n := d.GetChange("user_task_timeout_ms") + if o.(int) > 0 && n.(int) == 0 { q = builder.RemoveTimeout() } else { - q = builder.ChangeTimeout(new.(int)) + q = builder.ChangeTimeout(n.(int)) } if err := snowflake.Exec(db, q); err != nil { return fmt.Errorf("error updating user task timeout on task %v", d.Id()) @@ -633,11 +625,11 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if d.HasChange("comment") { var q string - old, new := d.GetChange("comment") - if old != "" && new == "" { + o, n := d.GetChange("comment") + if o != "" && n == "" { q = builder.RemoveComment() } else { - q = builder.ChangeComment(new.(string)) + q = builder.ChangeComment(n.(string)) } if err := snowflake.Exec(db, q); err != nil { @@ -647,8 +639,8 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { if d.HasChange("allow_overlapping_execution") { var q string - _, new := d.GetChange("allow_overlapping_execution") - flag := new.(bool) + _, n := d.GetChange("allow_overlapping_execution") + flag := n.(bool) if flag { q = builder.SetAllowOverlappingExecutionParameter() } else { @@ -692,16 +684,16 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error { } if d.HasChange("when") { - new := d.Get("when") - q := builder.ChangeCondition(new.(string)) + n := d.Get("when") + q := builder.ChangeCondition(n.(string)) if err := snowflake.Exec(db, q); err != nil { return fmt.Errorf("error updating when condition on task %v", d.Id()) } } if d.HasChange("sql_statement") { - new := d.Get("sql_statement") - q := builder.ChangeSQLStatement(new.(string)) + n := d.Get("sql_statement") + q := builder.ChangeSQLStatement(n.(string)) if err := snowflake.Exec(db, q); err != nil { return fmt.Errorf("error updating sql statement on task %v", d.Id()) } @@ -747,12 +739,7 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error { if !(rootTask.Name == name) { // resume the task after modifications are complete, as long as it is not a standalone task - defer func() { - q = rootTask.Resume() - if err := snowflake.Exec(db, q); err != nil { - log.Printf("[WARN] failed to resume task %s", rootTask.Name) - } - }() + defer resumeTask(db, rootTask) } } } diff --git a/pkg/resources/task_grant.go b/pkg/resources/task_grant.go index d0121a0b89..c1a4d63781 100644 --- a/pkg/resources/task_grant.go +++ b/pkg/resources/task_grant.go @@ -164,7 +164,7 @@ func CreateTaskGrant(d *schema.ResourceData, meta interface{}) error { if err := createGenericGrant(d, meta, builder); err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, taskName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, taskName, privilege, withGrantOption, onFuture, onAll, roles) d.SetId(grantID) return ReadTaskGrant(d, meta) @@ -196,7 +196,7 @@ func ReadTaskGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, taskName, privilege, withGrantOption, onFuture, onAll, roles) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, taskName, privilege, withGrantOption, onFuture, onAll, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/user_grant.go b/pkg/resources/user_grant.go index c99fc9447f..bb9c64f191 100644 --- a/pkg/resources/user_grant.go +++ b/pkg/resources/user_grant.go @@ -98,7 +98,7 @@ func CreateUserGrant(d *schema.ResourceData, meta interface{}) error { if err := createGenericGrant(d, meta, builder); err != nil { return err } - grantID := helpers.SnowflakeID(userName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(userName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadUserGrant(d, meta) @@ -118,7 +118,7 @@ func ReadUserGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(userName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(userName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/view.go b/pkg/resources/view.go index c43d6e3368..acceaf2953 100644 --- a/pkg/resources/view.go +++ b/pkg/resources/view.go @@ -82,7 +82,7 @@ func normalizeQuery(str string) string { // If we can find a sql parser that can handle the snowflake dialect then we should switch to parsing // queries and either comparing ASTs or emitting a canonical serialization for comparison. I couldn't // find such a library. -func DiffSuppressStatement(_, old, new string, d *schema.ResourceData) bool { +func DiffSuppressStatement(_, old, new string, _ *schema.ResourceData) bool { return strings.EqualFold(normalizeQuery(old), normalizeQuery(new)) } @@ -255,10 +255,8 @@ func ReadView(d *schema.ResourceData, meta interface{}) error { if err = d.Set("statement", substringOfQuery); err != nil { return err } - if err = d.Set("database", v.DatabaseName.String); err != nil { - return err - } - return nil + err = d.Set("database", v.DatabaseName.String) + return err } // UpdateView implements schema.UpdateFunc. @@ -343,8 +341,8 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error { return tagChangeErr } if d.HasChange("tag") { - old, new := d.GetChange("tag") - removed, added, changed := getTags(old).diffs(getTags(new)) + o, n := d.GetChange("tag") + removed, added, changed := getTags(o).diffs(getTags(n)) for _, tA := range removed { q := builder.UnsetTag(tA.toSnowflakeTagValue()) if err := snowflake.Exec(db, q); err != nil { diff --git a/pkg/resources/view_grant.go b/pkg/resources/view_grant.go index 4511cbf654..758ac9b8ba 100644 --- a/pkg/resources/view_grant.go +++ b/pkg/resources/view_grant.go @@ -181,7 +181,7 @@ func CreateViewGrant(d *schema.ResourceData, meta interface{}) error { if err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, viewName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, viewName, privilege, withGrantOption, onFuture, onAll, roles, shares) d.SetId(grantID) return ReadViewGrant(d, meta) } @@ -212,7 +212,7 @@ func ReadViewGrant(d *schema.ResourceData, meta interface{}) error { if err != nil { return err } - grantID := helpers.SnowflakeID(databaseName, schemaName, viewName, privilege, withGrantOption, onFuture, onAll, roles, shares) + grantID := helpers.EncodeSnowflakeID(databaseName, schemaName, viewName, privilege, withGrantOption, onFuture, onAll, roles, shares) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/resources/warehouse_grant.go b/pkg/resources/warehouse_grant.go index 26ff2de241..14b31b7eaa 100644 --- a/pkg/resources/warehouse_grant.go +++ b/pkg/resources/warehouse_grant.go @@ -104,7 +104,7 @@ func CreateWarehouseGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(warehouseName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(warehouseName, privilege, withGrantOption, roles) d.SetId(grantID) return ReadWarehouseGrant(d, meta) @@ -124,7 +124,7 @@ func ReadWarehouseGrant(d *schema.ResourceData, meta interface{}) error { return err } - grantID := helpers.SnowflakeID(warehouseName, privilege, withGrantOption, roles) + grantID := helpers.EncodeSnowflakeID(warehouseName, privilege, withGrantOption, roles) if grantID != d.Id() { d.SetId(grantID) } diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index ee5aa9c33f..20fa4a9c4e 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "log" - "os" "github.com/jmoiron/sqlx" "github.com/luna-duclos/instrumentedsql" @@ -17,34 +16,26 @@ import ( // ObjectType is the type of object. type ObjectType string +const ( + ObjectTypeMaskingPolicy ObjectType = "MASKING POLICY" + ObjectTypePasswordPolicy ObjectType = "PASSWORD POLICY" +) + func (o ObjectType) String() string { return string(o) } -func DefaultConfig() *gosnowflake.Config { - cfg := &gosnowflake.Config{ - Account: os.Getenv("SNOWFLAKE_ACCOUNT"), - User: os.Getenv("SNOWFLAKE_USER"), - Password: os.Getenv("SNOWFLAKE_PASSWORD"), - Region: os.Getenv("SNOWFLAKE_REGION"), - Role: os.Getenv("SNOWFLAKE_ROLE"), - Host: os.Getenv("SNOWFLAKE_HOST"), - Warehouse: os.Getenv("SNOWFLAKE_WAREHOUSE"), - } - // us-west-2 is Snowflake's default region, but if you actually specify that it won't trigger the default code - // https://github.com/snowflakedb/gosnowflake/blob/52137ce8c32eaf93b0bd22fc5c7297beff339812/dsn.go#L61 - if cfg.Region == "us-west-2" { - cfg.Region = "" - } - return cfg -} - type Client struct { + config *gosnowflake.Config db *sqlx.DB dryRun bool ContextFunctions ContextFunctions + MaskingPolicies MaskingPolicies PasswordPolicies PasswordPolicies + Sessions Sessions + SystemFunctions SystemFunctions + Warehouses Warehouses } func NewDefaultClient() (*Client, error) { @@ -52,7 +43,9 @@ func NewDefaultClient() (*Client, error) { } func NewClient(cfg *gosnowflake.Config) (*Client, error) { + var err error if cfg == nil { + log.Printf("[DEBUG] Searching for default config in credentials chain...\n") cfg = DefaultConfig() } @@ -81,7 +74,8 @@ func NewClient(cfg *gosnowflake.Config) (*Client, error) { client := &Client{ // snowflake does not adhere to the normal sql driver interface, so we have to use unsafe - db: db.Unsafe(), + db: db.Unsafe(), + config: cfg, } client.initialize() @@ -110,8 +104,12 @@ func NewClientFromDB(db *sql.DB) *Client { func (c *Client) initialize() { b := &sqlBuilder{} - c.PasswordPolicies = &passwordPolicies{client: c, builder: b} c.ContextFunctions = &contextFunctions{client: c, builder: b} + c.MaskingPolicies = &maskingPolicies{client: c, builder: b} + c.PasswordPolicies = &passwordPolicies{client: c, builder: b} + c.Sessions = &sessions{client: c, builder: b} + c.SystemFunctions = &systemFunctions{client: c, builder: b} + c.Warehouses = &warehouses{client: c, builder: b} } func (c *Client) SetDryRun(dryRun bool) { diff --git a/pkg/sdk/client_integration_test.go b/pkg/sdk/client_integration_test.go new file mode 100644 index 0000000000..965b607009 --- /dev/null +++ b/pkg/sdk/client_integration_test.go @@ -0,0 +1,67 @@ +package sdk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClient_newClient(t *testing.T) { + t.Run("with default config", func(t *testing.T) { + config := DefaultConfig() + _, err := NewClient(config) + require.NoError(t, err) + }) + + t.Run("uses env vars if values are missing", func(t *testing.T) { + cleanupEnvVars := setupEnvVars(t, "TEST_ACCOUNT", "TEST_USER", "abcd1234", "ACCOUNTADMIN", "") + t.Cleanup(cleanupEnvVars) + config := EnvConfig() + _, err := NewClient(config) + require.Error(t, err) + }) +} + +func TestClient_ping(t *testing.T) { + client := testClient(t) + err := client.Ping() + require.NoError(t, err) +} + +func TestClient_close(t *testing.T) { + client := testClient(t) + err := client.Close() + require.NoError(t, err) +} + +func TestClient_exec(t *testing.T) { + client := testClient(t) + ctx := context.Background() + _, err := client.exec(ctx, "SELECT 1") + require.NoError(t, err) +} + +func TestClient_query(t *testing.T) { + client := testClient(t) + ctx := context.Background() + rows := []struct { + One int `db:"ONE"` + }{} + err := client.query(ctx, &rows, "SELECT 1 AS ONE") + require.NoError(t, err) + require.NotNil(t, rows) + require.Equal(t, 1, len(rows)) + require.Equal(t, 1, rows[0].One) +} + +func TestClient_queryOne(t *testing.T) { + client := testClient(t) + ctx := context.Background() + row := struct { + One int `db:"ONE"` + }{} + err := client.queryOne(ctx, &row, "SELECT 1 AS ONE") + require.NoError(t, err) + require.Equal(t, 1, row.One) +} diff --git a/pkg/sdk/client_test.go b/pkg/sdk/client_test.go deleted file mode 100644 index c9e6180b50..0000000000 --- a/pkg/sdk/client_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package sdk - -import ( - "context" - "os" - "testing" - - "github.com/snowflakedb/gosnowflake" - "github.com/stretchr/testify/require" -) - -func TestClient_newClient(t *testing.T) { - config := &gosnowflake.Config{} - t.Run("uses env vars if values are missing", func(t *testing.T) { - cleanupEnvVars := setupEnvVars(t, "TEST_ACCOUNT", "TEST_USER", "abcd1234", "ACCOUNTADMIN") - t.Cleanup(cleanupEnvVars) - _, err := NewClient(config) - require.ErrorIs(t, err, ErrAccountIsEmpty) - }) - - t.Run("with default config", func(t *testing.T) { - config := DefaultConfig() - _, err := NewClient(config) - require.NoError(t, err) - }) -} - -func TestClient_ping(t *testing.T) { - client := testClient(t) - err := client.Ping() - require.NoError(t, err) -} - -func TestClient_close(t *testing.T) { - client := testClient(t) - err := client.Close() - require.NoError(t, err) -} - -func TestClient_exec(t *testing.T) { - client := testClient(t) - ctx := context.Background() - _, err := client.exec(ctx, "SELECT 1") - require.NoError(t, err) -} - -func TestClient_query(t *testing.T) { - client := testClient(t) - ctx := context.Background() - rows := []struct { - One int `db:"ONE"` - }{} - err := client.query(ctx, &rows, "SELECT 1 AS ONE") - require.NoError(t, err) - require.NotNil(t, rows) - require.Equal(t, 1, len(rows)) - require.Equal(t, 1, rows[0].One) -} - -func TestClient_queryOne(t *testing.T) { - client := testClient(t) - ctx := context.Background() - row := struct { - One int `db:"ONE"` - }{} - err := client.queryOne(ctx, &row, "SELECT 1 AS ONE") - require.NoError(t, err) - require.Equal(t, 1, row.One) -} - -func TestClient_defaultConfig(t *testing.T) { - t.Run("with no environment variables", func(t *testing.T) { - cleanupEnvVars := setupEnvVars(t, "", "", "", "") - t.Cleanup(cleanupEnvVars) - config := DefaultConfig() - require.Equal(t, "", config.Account) - require.Equal(t, "", config.User) - require.Equal(t, "", config.Password) - require.Equal(t, "", config.Role) - }) - - t.Run("with environment variables", func(t *testing.T) { - cleanupEnvVars := setupEnvVars(t, "TEST_ACCOUNT", "TEST_USER", "abcd1234", "ACCOUNTADMIN") - t.Cleanup(cleanupEnvVars) - config := DefaultConfig() - require.Equal(t, "TEST_ACCOUNT", config.Account) - require.Equal(t, "TEST_USER", config.User) - require.Equal(t, "abcd1234", config.Password) - require.Equal(t, "ACCOUNTADMIN", config.Role) - }) -} - -func setupEnvVars(t *testing.T, account, user, password, role string) func() { - t.Helper() - orginalAccount := os.Getenv("SNOWFLAKE_ACCOUNT") - orginalUser := os.Getenv("SNOWFLAKE_USER") - originalPassword := os.Getenv("SNOWFLAKE_PASSWORD") - originalRole := os.Getenv("SNOWFLAKE_ROLE") - - os.Setenv("SNOWFLAKE_ACCOUNT", account) - os.Setenv("SNOWFLAKE_USER", user) - os.Setenv("SNOWFLAKE_PASSWORD", password) - os.Setenv("SNOWFLAKE_ROLE", role) - - return func() { - os.Setenv("SNOWFLAKE_ACCOUNT", orginalAccount) - os.Setenv("SNOWFLAKE_USER", orginalUser) - os.Setenv("SNOWFLAKE_PASSWORD", originalPassword) - os.Setenv("SNOWFLAKE_ROLE", originalRole) - } -} diff --git a/pkg/sdk/common_types.go b/pkg/sdk/common_types.go index bb45fdf004..fa21125f57 100644 --- a/pkg/sdk/common_types.go +++ b/pkg/sdk/common_types.go @@ -10,6 +10,17 @@ type Like struct { Pattern *string `ddl:"keyword,single_quotes"` } +type TagAssociation struct { + Name ObjectIdentifier `ddl:"identifier"` + eq bool `ddl:"static" db:"="` //lint:ignore U1000 This is used in the ddl tag + Value string `ddl:"keyword,single_quotes"` +} + +type TableColumnSignature struct { + Name string `ddl:"keyword,double_quotes"` + Type DataType `ddl:"keyword"` +} + type StringProperty struct { Value string DefaultValue string diff --git a/pkg/sdk/config.go b/pkg/sdk/config.go new file mode 100644 index 0000000000..cd0c964a1c --- /dev/null +++ b/pkg/sdk/config.go @@ -0,0 +1,127 @@ +package sdk + +import ( + "log" + "os" + "path/filepath" + + "github.com/pelletier/go-toml/v2" + "github.com/snowflakedb/gosnowflake" +) + +func DefaultConfig() *gosnowflake.Config { + config, err := ProfileConfig("default") + if err != nil || config == nil { + log.Printf("[DEBUG] No Snowflake config file found, falling back to environment variables: %v\n", err) + return EnvConfig() + } + return config +} + +func ProfileConfig(profile string) (*gosnowflake.Config, error) { + configs, err := loadConfigFile() + if err != nil { + return nil, err + } + + if profile == "" { + profile = "default" + } + var config *gosnowflake.Config + if cfg, ok := configs[profile]; ok { + log.Printf("[DEBUG] loading config for profile: \"%s\"", profile) + config = cfg + } + + // us-west-2 is Snowflake's default region, but if you actually specify that it won't trigger the default code + // https://github.com/snowflakedb/gosnowflake/blob/52137ce8c32eaf93b0bd22fc5c7297beff339812/dsn.go#L61 + if config.Region == "us-west-2" { + config.Region = "" + } + + return config, nil +} + +func MergeConfig(baseConfig *gosnowflake.Config, mergeConfig *gosnowflake.Config) *gosnowflake.Config { + if baseConfig == nil { + return mergeConfig + } + if mergeConfig.Account != "" { + baseConfig.Account = mergeConfig.Account + } + if mergeConfig.User != "" { + baseConfig.User = mergeConfig.User + } + if mergeConfig.Password != "" { + baseConfig.Password = mergeConfig.Password + } + if mergeConfig.Role != "" { + baseConfig.Role = mergeConfig.Role + } + if mergeConfig.Region != "" { + baseConfig.Region = mergeConfig.Region + } + if mergeConfig.Host != "" { + baseConfig.Host = mergeConfig.Host + } + return baseConfig +} + +func configFile() (string, error) { + // has the user overwridden the default config path? + if configPath, ok := os.LookupEnv("SNOWFLAKE_CONFIG_PATH"); ok { + return configPath, nil + } + dir, err := os.UserHomeDir() + if err != nil { + return "", err + } + // default config path is ~/.snowflake/config. + return filepath.Join(dir, ".snowflake", "config"), nil +} + +func EnvConfig() *gosnowflake.Config { + config := &gosnowflake.Config{} + + if account, ok := os.LookupEnv("SNOWFLAKE_ACCOUNT"); ok { + config.Account = account + } + if user, ok := os.LookupEnv("SNOWFLAKE_USER"); ok { + config.User = user + } + if password, ok := os.LookupEnv("SNOWFLAKE_PASSWORD"); ok { + config.Password = password + } + if role, ok := os.LookupEnv("SNOWFLAKE_ROLE"); ok { + config.Role = role + } + if region, ok := os.LookupEnv("SNOWFLAKE_REGION"); ok { + config.Region = region + } + if host, ok := os.LookupEnv("SNOWFLAKE_HOST"); ok { + config.Host = host + } + if warehouse, ok := os.LookupEnv("SNOWFLAKE_WAREHOUSE"); ok { + config.Warehouse = warehouse + } + + return config +} + +func loadConfigFile() (map[string]*gosnowflake.Config, error) { + path, err := configFile() + if err != nil { + return nil, err + } + dat, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var s map[string]*gosnowflake.Config + err = toml.Unmarshal(dat, &s) + if err != nil { + log.Printf("[DEBUG] error unmarshalling config file: %v\n", err) + return nil, nil + } + return s, nil +} diff --git a/pkg/sdk/config_test.go b/pkg/sdk/config_test.go new file mode 100644 index 0000000000..3ab19bdf1b --- /dev/null +++ b/pkg/sdk/config_test.go @@ -0,0 +1,121 @@ +package sdk + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadConfigFile(t *testing.T) { + c := ` + [default] + account='TEST_ACCOUNT' + user='TEST_USER' + password='abcd1234' + role='ACCOUNTADMIN' + + [securityadmin] + account='TEST_ACCOUNT' + user='TEST_USER' + password='abcd1234' + role='SECURITYADMIN' + ` + configPath := testFile(t, "config", []byte(c)) + cleanupEnvVars := setupEnvVars(t, "", "", "", "", configPath) + t.Cleanup(cleanupEnvVars) + m, err := loadConfigFile() + require.NoError(t, err) + assert.Equal(t, "TEST_ACCOUNT", m["default"].Account) + assert.Equal(t, "TEST_USER", m["default"].User) + assert.Equal(t, "abcd1234", m["default"].Password) + assert.Equal(t, "ACCOUNTADMIN", m["default"].Role) + assert.Equal(t, "TEST_ACCOUNT", m["securityadmin"].Account) + assert.Equal(t, "TEST_USER", m["securityadmin"].User) + assert.Equal(t, "abcd1234", m["securityadmin"].Password) + assert.Equal(t, "SECURITYADMIN", m["securityadmin"].Role) +} + +func TestEnvConfig(t *testing.T) { + cleanupEnvVars := setupEnvVars(t, "TEST_ACCOUNT", "TEST_USER", "abcd1234", "ACCOUNTADMIN", "") + t.Cleanup(cleanupEnvVars) + config := EnvConfig() + assert.Equal(t, "TEST_ACCOUNT", config.Account) + assert.Equal(t, "TEST_USER", config.User) + assert.Equal(t, "abcd1234", config.Password) + assert.Equal(t, "ACCOUNTADMIN", config.Role) +} + +func TestProfileConfig(t *testing.T) { + c := ` + [securityadmin] + account='TEST_ACCOUNT' + user='TEST_USER' + password='abcd1234' + role='SECURITYADMIN' + ` + configPath := testFile(t, "config", []byte(c)) + cleanupEnvVars := setupEnvVars(t, "", "", "", "", configPath) + t.Cleanup(cleanupEnvVars) + config, err := ProfileConfig("securityadmin") + require.NoError(t, err) + assert.Equal(t, "TEST_ACCOUNT", config.Account) + assert.Equal(t, "TEST_USER", config.User) + assert.Equal(t, "abcd1234", config.Password) + assert.Equal(t, "SECURITYADMIN", config.Role) +} + +func TestDefaultConfig(t *testing.T) { + t.Run("with no environment variables", func(t *testing.T) { + cleanupEnvVars := setupEnvVars(t, "", "", "", "", "") + t.Cleanup(cleanupEnvVars) + config := DefaultConfig() + assert.Equal(t, "", config.Account) + assert.Equal(t, "", config.User) + assert.Equal(t, "", config.Password) + assert.Equal(t, "", config.Role) + }) + + t.Run("with environment variables", func(t *testing.T) { + cleanupEnvVars := setupEnvVars(t, "TEST_ACCOUNT", "TEST_USER", "abcd1234", "ACCOUNTADMIN", "") + t.Cleanup(cleanupEnvVars) + config := DefaultConfig() + assert.Equal(t, "TEST_ACCOUNT", config.Account) + assert.Equal(t, "TEST_USER", config.User) + assert.Equal(t, "abcd1234", config.Password) + assert.Equal(t, "ACCOUNTADMIN", config.Role) + }) +} + +func testFile(t *testing.T, filename string, dat []byte) string { + t.Helper() + path := filepath.Join(t.TempDir(), filename) + err := os.WriteFile(path, dat, 0o600) + require.NoError(t, err) + return path +} + +func setupEnvVars(t *testing.T, account, user, password, role, configPath string) func() { + t.Helper() + orginalAccount := os.Getenv("SNOWFLAKE_ACCOUNT") + orginalUser := os.Getenv("SNOWFLAKE_USER") + originalPassword := os.Getenv("SNOWFLAKE_PASSWORD") + originalRole := os.Getenv("SNOWFLAKE_ROLE") + originalPath := os.Getenv("SNOWFLAKE_CONFIG_PATH") + + os.Setenv("SNOWFLAKE_ACCOUNT", account) + os.Setenv("SNOWFLAKE_USER", user) + os.Setenv("SNOWFLAKE_PASSWORD", password) + os.Setenv("SNOWFLAKE_ROLE", role) + os.Setenv("SNOWFLAKE_CONFIG_PATH", configPath) + + return func() { + os.Setenv("SNOWFLAKE_ACCOUNT", orginalAccount) + os.Setenv("SNOWFLAKE_USER", orginalUser) + os.Setenv("SNOWFLAKE_PASSWORD", originalPassword) + os.Setenv("SNOWFLAKE_ROLE", originalRole) + os.Setenv("SNOWFLAKE_CONFIG_PATH", originalPath) + } +} diff --git a/pkg/sdk/context_functions.go b/pkg/sdk/context_functions.go index 81d8b2af9c..2c821d9ad4 100644 --- a/pkg/sdk/context_functions.go +++ b/pkg/sdk/context_functions.go @@ -1,9 +1,18 @@ package sdk -import "context" +import ( + "context" + "database/sql" +) type ContextFunctions interface { + // Session functions. CurrentSession(ctx context.Context) (string, error) + + // Session Object functions. + CurrentDatabase(ctx context.Context) (string, error) + CurrentSchema(ctx context.Context) (string, error) + CurrentWarehouse(ctx context.Context) (string, error) } var _ ContextFunctions = (*contextFunctions)(nil) @@ -15,11 +24,53 @@ type contextFunctions struct { func (c *contextFunctions) CurrentSession(ctx context.Context) (string, error) { s := &struct { - CurrentSession string `db:"CURRENT_SESSION()"` + CurrentSession string `db:"CURRENT_SESSION"` }{} - err := c.client.queryOne(ctx, s, "SELECT CURRENT_SESSION()") + err := c.client.queryOne(ctx, s, "SELECT CURRENT_SESSION() as CURRENT_SESSION") if err != nil { return "", err } return s.CurrentSession, nil } + +func (c *contextFunctions) CurrentDatabase(ctx context.Context) (string, error) { + s := &struct { + CurrentDatabase sql.NullString `db:"CURRENT_DATABASE"` + }{} + err := c.client.queryOne(ctx, s, "SELECT CURRENT_DATABASE() as CURRENT_DATABASE") + if err != nil { + return "", err + } + if !s.CurrentDatabase.Valid { + return "", nil + } + return s.CurrentDatabase.String, nil +} + +func (c *contextFunctions) CurrentSchema(ctx context.Context) (string, error) { + s := &struct { + CurrentSchema sql.NullString `db:"CURRENT_SCHEMA"` + }{} + err := c.client.queryOne(ctx, s, "SELECT CURRENT_SCHEMA() as CURRENT_SCHEMA") + if err != nil { + return "", err + } + if !s.CurrentSchema.Valid { + return "", nil + } + return s.CurrentSchema.String, nil +} + +func (c *contextFunctions) CurrentWarehouse(ctx context.Context) (string, error) { + s := &struct { + CurrentWarehouse sql.NullString `db:"CURRENT_WAREHOUSE"` + }{} + err := c.client.queryOne(ctx, s, "SELECT CURRENT_WAREHOUSE() as CURRENT_WAREHOUSE") + if err != nil { + return "", err + } + if !s.CurrentWarehouse.Valid { + return "", nil + } + return s.CurrentWarehouse.String, nil +} diff --git a/pkg/sdk/context_functions_integration_test.go b/pkg/sdk/context_functions_integration_test.go index d17af080f8..d0af3dfede 100644 --- a/pkg/sdk/context_functions_integration_test.go +++ b/pkg/sdk/context_functions_integration_test.go @@ -16,3 +16,43 @@ func TestInt_CurrentSession(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, session) } + +func TestInt_CurrentDatabase(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + err := client.Sessions.UseDatabase(ctx, databaseTest.ID()) + require.NoError(t, err) + db, err := client.ContextFunctions.CurrentDatabase(ctx) + require.NoError(t, err) + assert.NotEmpty(t, db) +} + +func TestInt_CurrentSchema(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + err := client.Sessions.UseSchema(ctx, schemaTest.ID()) + require.NoError(t, err) + schema, err := client.ContextFunctions.CurrentSchema(ctx) + require.NoError(t, err) + assert.NotEmpty(t, schema) +} + +func TestInt_CurrentWarehouse(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + warehouseTest, warehouseCleanup := createWarehouse(t, client) + t.Cleanup(warehouseCleanup) + err := client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) + require.NoError(t, err) + warehouse, err := client.ContextFunctions.CurrentWarehouse(ctx) + require.NoError(t, err) + assert.NotEmpty(t, warehouse) +} diff --git a/pkg/sdk/data_types.go b/pkg/sdk/data_types.go index b2455e6c31..f2e66837b6 100644 --- a/pkg/sdk/data_types.go +++ b/pkg/sdk/data_types.go @@ -24,10 +24,35 @@ const ( DataTypeArray DataType = "ARRAY" DataTypeGeography DataType = "GEOGRAPHY" DataTypeGeometry DataType = "GEOMETRY" + + // DataTypeUnknown is used for testing purposes only. + DataTypeUnknown DataType = "UNKNOWN" ) -func NewDataType(s string) DataType { +func DataTypeFromString(s string) DataType { dType := strings.ToUpper(s) + + switch dType { + case "DATE": + return DataTypeDate + case "TIME": + return DataTypeTime + case "TIMESTAMP_LTZ": + return DataTypeTimestampLTZ + case "TIMESTAMP_TZ": + return DataTypeTimestampTZ + case "VARIANT": + return DataTypeVariant + case "OBJECT": + return DataTypeObject + case "ARRAY": + return DataTypeArray + case "GEOGRAPHY": + return DataTypeGeography + case "GEOMETRY": + return DataTypeGeometry + } + numberSynonyms := []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} if slices.ContainsFunc(numberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { return DataTypeNumber @@ -49,6 +74,11 @@ func NewDataType(s string) DataType { if slices.Contains(booleanSynonyms, dType) { return DataTypeBoolean } - // todo: date, time, timestamp, variant, object, array, geography, geometry - return DataType(dType) + + timestampNTZSynonyms := []string{"DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ"} + if slices.ContainsFunc(timestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { + return DataTypeTimestampNTZ + } + + return DataTypeUnknown } diff --git a/pkg/sdk/data_types_test.go b/pkg/sdk/data_types_test.go new file mode 100644 index 0000000000..6a062a2a1b --- /dev/null +++ b/pkg/sdk/data_types_test.go @@ -0,0 +1,80 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDataTypeFromString(t *testing.T) { + type test struct { + input string + want DataType + } + + tests := []test{ + // case insensitive. + {input: "STRING", want: DataTypeVARCHAR}, + {input: "string", want: DataTypeVARCHAR}, + {input: "String", want: DataTypeVARCHAR}, + + // number types. + {input: "number", want: DataTypeNumber}, + {input: "decimal", want: DataTypeNumber}, + {input: "numeric", want: DataTypeNumber}, + {input: "int", want: DataTypeNumber}, + {input: "integer", want: DataTypeNumber}, + {input: "bigint", want: DataTypeNumber}, + {input: "smallint", want: DataTypeNumber}, + {input: "tinyint", want: DataTypeNumber}, + {input: "byteint", want: DataTypeNumber}, + + // float types. + {input: "float", want: DataTypeFloat}, + {input: "float4", want: DataTypeFloat}, + {input: "float8", want: DataTypeFloat}, + {input: "double", want: DataTypeFloat}, + {input: "double precision", want: DataTypeFloat}, + {input: "real", want: DataTypeFloat}, + + // varchar types. + {input: "varchar", want: DataTypeVARCHAR}, + {input: "char", want: DataTypeVARCHAR}, + {input: "character", want: DataTypeVARCHAR}, + {input: "string", want: DataTypeVARCHAR}, + {input: "text", want: DataTypeVARCHAR}, + + // binary types. + {input: "binary", want: DataTypeBinary}, + {input: "varbinary", want: DataTypeBinary}, + {input: "boolean", want: DataTypeBoolean}, + + // boolean types. + {input: "boolean", want: DataTypeBoolean}, + {input: "bool", want: DataTypeBoolean}, + + // timestamp ntz types. + {input: "datetime", want: DataTypeTimestampNTZ}, + {input: "timestamp", want: DataTypeTimestampNTZ}, + {input: "timestamp_ntz", want: DataTypeTimestampNTZ}, + + // all othertypes + {input: "date", want: DataTypeDate}, + {input: "time", want: DataTypeTime}, + {input: "timestamp_ltz", want: DataTypeTimestampLTZ}, + {input: "timestamp_tz", want: DataTypeTimestampTZ}, + {input: "variant", want: DataTypeVariant}, + {input: "object", want: DataTypeObject}, + {input: "array", want: DataTypeArray}, + {input: "geography", want: DataTypeGeography}, + {input: "geometry", want: DataTypeGeometry}, + {input: "invalid", want: DataTypeUnknown}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + got := DataTypeFromString(tc.input) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/sdk/databases.go b/pkg/sdk/databases.go new file mode 100644 index 0000000000..26d38ec218 --- /dev/null +++ b/pkg/sdk/databases.go @@ -0,0 +1,12 @@ +package sdk + +// placeholder for the real implementation. +type DatabaseCreateOptions struct{} + +type Database struct { + Name string +} + +func (v *Database) ID() AccountObjectIdentifier { + return NewAccountObjectIdentifier(v.Name) +} diff --git a/pkg/sdk/errors.go b/pkg/sdk/errors.go index 02f423fa5d..640a5378cc 100644 --- a/pkg/sdk/errors.go +++ b/pkg/sdk/errors.go @@ -7,6 +7,7 @@ import ( ) var ( + // go-snowflake errors. ErrObjectNotExistOrAuthorized = errors.New("object does not exist or not authorized") ErrAccountIsEmpty = errors.New("account is empty") ) diff --git a/pkg/sdk/helper_test.go b/pkg/sdk/helper_test.go index 8084214af5..4ba0b847d1 100644 --- a/pkg/sdk/helper_test.go +++ b/pkg/sdk/helper_test.go @@ -5,23 +5,65 @@ import ( "fmt" "testing" + "github.com/brianvoe/gofakeit/v6" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" ) func randomSchemaObjectIdentifier(t *testing.T) SchemaObjectIdentifier { t.Helper() - return NewSchemaObjectIdentifier(randomString(t), randomString(t), randomString(t)) + return NewSchemaObjectIdentifier(randomStringRange(t, 8, 12), randomStringRange(t, 8, 12), randomStringRange(t, 8, 12)) } func randomSchemaIdentifier(t *testing.T) SchemaIdentifier { t.Helper() - return NewSchemaIdentifier(randomString(t), randomString(t)) + return NewSchemaIdentifier(randomStringRange(t, 8, 12), randomStringRange(t, 8, 12)) } func randomAccountObjectIdentifier(t *testing.T) AccountObjectIdentifier { t.Helper() - return NewAccountObjectIdentifier(randomString(t)) + return NewAccountObjectIdentifier(randomStringRange(t, 8, 12)) +} + +func useDatabase(t *testing.T, client *Client, databaseID AccountObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + orgDB, err := client.ContextFunctions.CurrentDatabase(ctx) + require.NoError(t, err) + err = client.Sessions.UseDatabase(ctx, databaseID) + require.NoError(t, err) + return func() { + err := client.Sessions.UseDatabase(ctx, NewAccountObjectIdentifier(orgDB)) + require.NoError(t, err) + } +} + +func useSchema(t *testing.T, client *Client, schemaID SchemaIdentifier) func() { + t.Helper() + ctx := context.Background() + orgDB, err := client.ContextFunctions.CurrentDatabase(ctx) + require.NoError(t, err) + orgSchema, err := client.ContextFunctions.CurrentSchema(ctx) + require.NoError(t, err) + err = client.Sessions.UseSchema(ctx, schemaID) + require.NoError(t, err) + return func() { + err := client.Sessions.UseSchema(ctx, NewSchemaIdentifier(orgDB, orgSchema)) + require.NoError(t, err) + } +} + +func useWarehouse(t *testing.T, client *Client, warehouseID AccountObjectIdentifier) func() { + t.Helper() + ctx := context.Background() + orgWarehouse, err := client.ContextFunctions.CurrentWarehouse(ctx) + require.NoError(t, err) + err = client.Sessions.UseWarehouse(ctx, warehouseID) + require.NoError(t, err) + return func() { + err := client.Sessions.UseWarehouse(ctx, NewAccountObjectIdentifier(orgWarehouse)) + require.NoError(t, err) + } } func testBuilder(t *testing.T) *sqlBuilder { @@ -31,6 +73,7 @@ func testBuilder(t *testing.T) *sqlBuilder { func testClient(t *testing.T) *Client { t.Helper() + client, err := NewDefaultClient() if err != nil { t.Fatal(err) @@ -39,31 +82,62 @@ func testClient(t *testing.T) *Client { return client } -// mock structs until we have more of the SDK implemented. -type DatabaseCreateOptions struct{} +func randomUUID(t *testing.T) string { + t.Helper() + v, err := uuid.GenerateUUID() + require.NoError(t, err) + return v +} -type Database struct { - Name string +func randomComment(t *testing.T) string { + t.Helper() + return gofakeit.Sentence(10) } -func (v *Database) Identifier() AccountObjectIdentifier { - return NewAccountObjectIdentifier(v.Name) +func randomBool(t *testing.T) bool { + t.Helper() + return gofakeit.Bool() } -type Schema struct { - DatabaseName string - Name string +func randomString(t *testing.T) string { + t.Helper() + return gofakeit.Password(true, true, true, true, false, 28) } -func (v *Schema) Identifier() SchemaIdentifier { - return NewSchemaIdentifier(v.DatabaseName, v.Name) +func randomStringRange(t *testing.T, min, max int) string { + t.Helper() + if min > max { + t.Errorf("min %d is greater than max %d", min, max) + } + return gofakeit.Password(true, true, true, true, false, randomIntRange(t, min, max)) } -func randomString(t *testing.T) string { +func randomIntRange(t *testing.T, min, max int) int { t.Helper() - v, err := uuid.GenerateUUID() + if min > max { + t.Errorf("min %d is greater than max %d", min, max) + } + return gofakeit.IntRange(min, max) +} + +func createWarehouse(t *testing.T, client *Client) (*Warehouse, func()) { + t.Helper() + return createWarehouseWithOptions(t, client, &WarehouseCreateOptions{}) +} + +func createWarehouseWithOptions(t *testing.T, client *Client, _ *WarehouseCreateOptions) (*Warehouse, func()) { + t.Helper() + name := randomStringRange(t, 8, 28) + id := NewAccountObjectIdentifier(name) + ctx := context.Background() + err := client.Warehouses.Create(ctx, id, nil) require.NoError(t, err) - return v + return &Warehouse{ + Name: name, + }, func() { + err := client.Warehouses.Drop(ctx, id, nil) + require.NoError(t, err) + } } func createDatabase(t *testing.T, client *Client) (*Database, func()) { @@ -73,7 +147,7 @@ func createDatabase(t *testing.T, client *Client) (*Database, func()) { func createDatabaseWithOptions(t *testing.T, client *Client, _ *DatabaseCreateOptions) (*Database, func()) { t.Helper() - name := randomString(t) + name := randomStringRange(t, 8, 28) ctx := context.Background() _, err := client.exec(ctx, fmt.Sprintf("CREATE DATABASE \"%s\"", name)) require.NoError(t, err) @@ -87,7 +161,7 @@ func createDatabaseWithOptions(t *testing.T, client *Client, _ *DatabaseCreateOp func createSchema(t *testing.T, client *Client, database *Database) (*Schema, func()) { t.Helper() - name := randomString(t) + name := randomStringRange(t, 8, 28) ctx := context.Background() _, err := client.exec(ctx, fmt.Sprintf("CREATE SCHEMA \"%s\".\"%s\"", database.Name, name)) require.NoError(t, err) @@ -100,6 +174,27 @@ func createSchema(t *testing.T, client *Client, database *Database) (*Schema, fu } } +func createTag(t *testing.T, client *Client, database *Database, schema *Schema) (*Tag, func()) { + t.Helper() + return createTagWithOptions(t, client, database, schema, &TagCreateOptions{}) +} + +func createTagWithOptions(t *testing.T, client *Client, database *Database, schema *Schema, _ *TagCreateOptions) (*Tag, func()) { + t.Helper() + name := randomStringRange(t, 8, 28) + ctx := context.Background() + _, err := client.exec(ctx, fmt.Sprintf("CREATE TAG \"%s\".\"%s\".\"%s\"", database.Name, schema.Name, name)) + require.NoError(t, err) + return &Tag{ + Name: name, + DatabaseName: database.Name, + SchemaName: schema.Name, + }, func() { + _, err := client.exec(ctx, fmt.Sprintf("DROP TAG \"%s\".\"%s\".\"%s\"", database.Name, schema.Name, name)) + require.NoError(t, err) + } +} + func createPasswordPolicyWithOptions(t *testing.T, client *Client, database *Database, schema *Schema, options *PasswordPolicyCreateOptions) (*PasswordPolicy, func()) { t.Helper() var databaseCleanup func() @@ -110,7 +205,7 @@ func createPasswordPolicyWithOptions(t *testing.T, client *Client, database *Dat if schema == nil { schema, schemaCleanup = createSchema(t, client, database) } - name := randomString(t) + name := randomUUID(t) id := NewSchemaObjectIdentifier(schema.DatabaseName, schema.Name, name) ctx := context.Background() err := client.PasswordPolicies.Create(ctx, id, options) @@ -121,7 +216,7 @@ func createPasswordPolicyWithOptions(t *testing.T, client *Client, database *Dat Pattern: String(name), }, In: &In{ - Schema: schema.Identifier(), + Schema: schema.ID(), }, } passwordPolicyList, err := client.PasswordPolicies.Show(ctx, showOptions) @@ -143,3 +238,61 @@ func createPasswordPolicy(t *testing.T, client *Client, database *Database, sche t.Helper() return createPasswordPolicyWithOptions(t, client, database, schema, nil) } + +func createMaskingPolicyWithOptions(t *testing.T, client *Client, database *Database, schema *Schema, signature []TableColumnSignature, returns DataType, expression string, options *MaskingPolicyCreateOptions) (*MaskingPolicy, func()) { + t.Helper() + var databaseCleanup func() + if database == nil { + database, databaseCleanup = createDatabase(t, client) + } + var schemaCleanup func() + if schema == nil { + schema, schemaCleanup = createSchema(t, client, database) + } + name := randomString(t) + id := NewSchemaObjectIdentifier(schema.DatabaseName, schema.Name, name) + ctx := context.Background() + err := client.MaskingPolicies.Create(ctx, id, signature, returns, expression, options) + require.NoError(t, err) + + showOptions := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(name), + }, + In: &In{ + Schema: schema.ID(), + }, + } + maskingPolicyList, err := client.MaskingPolicies.Show(ctx, showOptions) + require.NoError(t, err) + require.Equal(t, 1, len(maskingPolicyList)) + return maskingPolicyList[0], func() { + err := client.MaskingPolicies.Drop(ctx, id) + require.NoError(t, err) + if schemaCleanup != nil { + schemaCleanup() + } + if databaseCleanup != nil { + databaseCleanup() + } + } +} + +func createMaskingPolicy(t *testing.T, client *Client, database *Database, schema *Schema) (*MaskingPolicy, func()) { + t.Helper() + signature := []TableColumnSignature{ + { + Name: randomString(t), + Type: DataTypeVARCHAR, + }, + } + n := randomIntRange(t, 0, 5) + for i := 0; i < n; i++ { + signature = append(signature, TableColumnSignature{ + Name: randomString(t), + Type: DataTypeVARCHAR, + }) + } + expression := "REPLACE('X', 1, 2)" + return createMaskingPolicyWithOptions(t, client, database, schema, signature, DataTypeVARCHAR, expression, &MaskingPolicyCreateOptions{}) +} diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 745a8a8294..b4bc7dcd24 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -6,105 +6,146 @@ import ( ) type ObjectIdentifier interface { + Name() string FullyQualifiedName() string } type AccountObjectIdentifier struct { - Name string + name string } func NewAccountObjectIdentifier(name string) AccountObjectIdentifier { - return AccountObjectIdentifier{Name: name} + return AccountObjectIdentifier{name: name} +} + +func (i AccountObjectIdentifier) Name() string { + return i.name } func (i AccountObjectIdentifier) FullyQualifiedName() string { - if i.Name == "" { + if i.name == "" { return "" } - return fmt.Sprintf(`"%v"`, i.Name) + return fmt.Sprintf(`"%v"`, i.name) } type SchemaIdentifier struct { - DatabaseName string - SchemaName string + databaseName string + schemaName string } func NewSchemaIdentifier(databaseName, schemaName string) SchemaIdentifier { return SchemaIdentifier{ - DatabaseName: strings.Trim(databaseName, `"`), - SchemaName: strings.Trim(schemaName, `"`), + databaseName: strings.Trim(databaseName, `"`), + schemaName: strings.Trim(schemaName, `"`), } } func NewSchemaIdentifierFromFullyQualifiedName(fullyQualifiedName string) SchemaIdentifier { parts := strings.Split(fullyQualifiedName, ".") return SchemaIdentifier{ - DatabaseName: strings.Trim(parts[0], `"`), - SchemaName: strings.Trim(parts[1], `"`), + databaseName: strings.Trim(parts[0], `"`), + schemaName: strings.Trim(parts[1], `"`), } } +func (i SchemaIdentifier) DatabaseName() string { + return i.databaseName +} + +func (i SchemaIdentifier) Name() string { + return i.schemaName +} + func (i SchemaIdentifier) FullyQualifiedName() string { - if i.SchemaName == "" && i.DatabaseName == "" { + if i.schemaName == "" && i.databaseName == "" { return "" } - return fmt.Sprintf(`"%v"."%v"`, i.DatabaseName, i.SchemaName) + return fmt.Sprintf(`"%v"."%v"`, i.databaseName, i.schemaName) } type SchemaObjectIdentifier struct { - DatabaseName string - SchemaName string - Name string + databaseName string + schemaName string + name string } func NewSchemaObjectIdentifier(databaseName, schemaName, name string) SchemaObjectIdentifier { return SchemaObjectIdentifier{ - DatabaseName: strings.Trim(databaseName, `"`), - SchemaName: strings.Trim(schemaName, `"`), - Name: strings.Trim(name, `"`), + databaseName: strings.Trim(databaseName, `"`), + schemaName: strings.Trim(schemaName, `"`), + name: strings.Trim(name, `"`), } } func NewSchemaObjectIdentifierFromFullyQualifiedName(fullyQualifiedName string) SchemaObjectIdentifier { parts := strings.Split(fullyQualifiedName, ".") return SchemaObjectIdentifier{ - DatabaseName: strings.Trim(parts[0], `"`), - SchemaName: strings.Trim(parts[1], `"`), - Name: strings.Trim(parts[2], `"`), + databaseName: strings.Trim(parts[0], `"`), + schemaName: strings.Trim(parts[1], `"`), + name: strings.Trim(parts[2], `"`), } } +func (i SchemaObjectIdentifier) DatabaseName() string { + return i.databaseName +} + +func (i SchemaObjectIdentifier) SchemaName() string { + return i.schemaName +} + +func (i SchemaObjectIdentifier) Name() string { + return i.name +} + func (i SchemaObjectIdentifier) FullyQualifiedName() string { - if i.SchemaName == "" && i.DatabaseName == "" && i.Name == "" { + if i.schemaName == "" && i.databaseName == "" && i.name == "" { return "" } - return fmt.Sprintf(`"%v"."%v"."%v"`, i.DatabaseName, i.SchemaName, i.Name) + return fmt.Sprintf(`"%v"."%v"."%v"`, i.databaseName, i.schemaName, i.name) } type TableColumnIdentifier struct { - DatabaseName string - SchemaName string - TableName string - ColumnName string + databaseName string + schemaName string + tableName string + columnName string } func NewTableColumnIdentifier(databaseName, schemaName, tableName, columnName string) TableColumnIdentifier { - return TableColumnIdentifier{DatabaseName: databaseName, SchemaName: schemaName, TableName: tableName, ColumnName: columnName} + return TableColumnIdentifier{databaseName: databaseName, schemaName: schemaName, tableName: tableName, columnName: columnName} } func NewTableColumnIdentifierFromFullyQualifiedName(fullyQualifiedName string) TableColumnIdentifier { parts := strings.Split(fullyQualifiedName, ".") return TableColumnIdentifier{ - DatabaseName: strings.Trim(parts[0], `"`), - SchemaName: strings.Trim(parts[1], `"`), - TableName: strings.Trim(parts[2], `"`), - ColumnName: strings.Trim(parts[3], `"`), + databaseName: strings.Trim(parts[0], `"`), + schemaName: strings.Trim(parts[1], `"`), + tableName: strings.Trim(parts[2], `"`), + columnName: strings.Trim(parts[3], `"`), } } +func (i TableColumnIdentifier) DatabaseName() string { + return i.databaseName +} + +func (i TableColumnIdentifier) SchemaName() string { + return i.schemaName +} + +func (i TableColumnIdentifier) TableName() string { + return i.tableName +} + +func (i TableColumnIdentifier) Name() string { + return i.columnName +} + func (i TableColumnIdentifier) FullyQualifiedName() string { - if i.SchemaName == "" && i.DatabaseName == "" && i.TableName == "" && i.ColumnName == "" { + if i.schemaName == "" && i.databaseName == "" && i.tableName == "" && i.columnName == "" { return "" } - return fmt.Sprintf(`"%v"."%v"."%v"."%v"`, i.DatabaseName, i.SchemaName, i.TableName, i.ColumnName) + return fmt.Sprintf(`"%v"."%v"."%v"."%v"`, i.databaseName, i.schemaName, i.tableName, i.columnName) } diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go new file mode 100644 index 0000000000..2d6a905972 --- /dev/null +++ b/pkg/sdk/masking_policy.go @@ -0,0 +1,359 @@ +package sdk + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/buger/jsonparser" +) + +// Compile-time proof of interface implementation. +var _ MaskingPolicies = (*maskingPolicies)(nil) + +// MaskingPolicies describes all the masking policy related methods that the +// Snowflake API supports. +type MaskingPolicies interface { + // Create creates a new masking policy. + Create(ctx context.Context, id SchemaObjectIdentifier, signature []TableColumnSignature, returns DataType, expression string, opts *MaskingPolicyCreateOptions) error + // Alter modifies an existing masking policy. + Alter(ctx context.Context, id SchemaObjectIdentifier, opts *MaskingPolicyAlterOptions) error + // Drop removes a masking policy. + Drop(ctx context.Context, id SchemaObjectIdentifier) error + // Show returns a list of masking policies. + Show(ctx context.Context, opts *MaskingPolicyShowOptions) ([]*MaskingPolicy, error) + // Describe returns the details of a masking policy. + Describe(ctx context.Context, id SchemaObjectIdentifier) (*MaskingPolicyDetails, error) +} + +// maskingPolicies implements MaskingPolicies. +type maskingPolicies struct { + client *Client + builder *sqlBuilder +} + +type MaskingPolicyCreateOptions struct { + create bool `ddl:"static" db:"CREATE"` //lint:ignore U1000 This is used in the ddl tag + OrReplace *bool `ddl:"keyword" db:"OR REPLACE"` + maskingPolicy bool `ddl:"static" db:"MASKING POLICY"` //lint:ignore U1000 This is used in the ddl tag + IfNotExists *bool `ddl:"keyword" db:"IF NOT EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + + // required + signature []TableColumnSignature `ddl:"list" db:"AS"` + returns DataType `ddl:"command" db:"RETURNS"` + arrow bool `ddl:"static" db:"->"` //lint:ignore U1000 This is used in the ddl tag + body string `ddl:"keyword" db:"EXPRESSION"` + + // optional + Comment *string `ddl:"parameter,single_quotes" db:"COMMENT"` + ExemptOtherPolicies *bool `ddl:"parameter" db:"EXEMPT_OTHER_POLICIES"` +} + +func (opts *MaskingPolicyCreateOptions) validate() error { + if opts.name.FullyQualifiedName() == "" { + return fmt.Errorf("name is required") + } + + return nil +} + +func (v *maskingPolicies) Create(ctx context.Context, id SchemaObjectIdentifier, signature []TableColumnSignature, returns DataType, body string, opts *MaskingPolicyCreateOptions) error { + if opts == nil { + opts = &MaskingPolicyCreateOptions{} + } + opts.name = id + opts.signature = signature + opts.returns = returns + opts.body = body + if err := opts.validate(); err != nil { + return err + } + clauses, err := v.builder.parseStruct(opts) + if err != nil { + return err + } + stmt := v.builder.sql(clauses...) + _, err = v.client.exec(ctx, stmt) + return err +} + +type MaskingPolicyAlterOptions struct { + alter bool `ddl:"static" db:"ALTER"` //lint:ignore U1000 This is used in the ddl tag + maskingPolicy bool `ddl:"static" db:"MASKING POLICY"` //lint:ignore U1000 This is used in the ddl tag + IfExists *bool `ddl:"keyword" db:"IF EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + NewName SchemaObjectIdentifier `ddl:"identifier" db:"RENAME TO"` + Set *MaskingPolicySet `ddl:"keyword" db:"SET"` + Unset *MaskingPolicyUnset `ddl:"keyword" db:"UNSET"` +} + +func (opts *MaskingPolicyAlterOptions) validate() error { + if opts.name.FullyQualifiedName() == "" { + return errors.New("name must not be empty") + } + + if opts.Set == nil && opts.Unset == nil { + if opts.NewName.FullyQualifiedName() == "" { + return errors.New("new name must not be empty") + } + } + + if opts.Set != nil && opts.Unset != nil { + return errors.New("cannot set and unset parameters in the same ALTER statement") + } + + if opts.Set != nil { + count := 0 + if opts.Set.Body != nil { + count++ + } + if opts.Set.Tag != nil { + count++ + } + if opts.Set.Comment != nil { + count++ + } + if count != 1 { + return errors.New("only one parameter must be set") + } + } + + if opts.Unset != nil { + count := 0 + if opts.Unset.Tag != nil { + count++ + } + if opts.Unset.Comment != nil { + count++ + } + if count != 1 { + return errors.New("only one parameter can be unset at a time") + } + } + + return nil +} + +type MaskingPolicySet struct { + Body *string `ddl:"command" db:"BODY ->"` + Tag []TagAssociation `ddl:"list,no_parentheses" db:"TAG"` + Comment *string `ddl:"parameter,single_quotes" db:"COMMENT"` +} + +type MaskingPolicyUnset struct { + Tag []ObjectIdentifier `ddl:"list,no_parentheses" db:"TAG"` + Comment *bool `ddl:"keyword" db:"COMMENT"` +} + +func (v *maskingPolicies) Alter(ctx context.Context, id SchemaObjectIdentifier, opts *MaskingPolicyAlterOptions) error { + if opts == nil { + opts = &MaskingPolicyAlterOptions{} + } + opts.name = id + if err := opts.validate(); err != nil { + return err + } + clauses, err := v.builder.parseStruct(opts) + if err != nil { + return err + } + stmt := v.builder.sql(clauses...) + _, err = v.client.exec(ctx, stmt) + return err +} + +type MaskingPolicyDropOptions struct { + drop bool `ddl:"static" db:"DROP"` //lint:ignore U1000 This is used in the ddl tag + maskingPolicy bool `ddl:"static" db:"MASKING POLICY"` //lint:ignore U1000 This is used in the ddl tag + name SchemaObjectIdentifier `ddl:"identifier"` +} + +func (opts *MaskingPolicyDropOptions) validate() error { + if opts.name.FullyQualifiedName() == "" { + return errors.New("name must not be empty") + } + return nil +} + +func (v *maskingPolicies) Drop(ctx context.Context, id SchemaObjectIdentifier) error { + // masking policy drop does not support [IF EXISTS] so there are no drop options. + opts := &MaskingPolicyDropOptions{ + name: id, + } + if err := opts.validate(); err != nil { + return fmt.Errorf("validate drop options: %w", err) + } + clauses, err := v.builder.parseStruct(opts) + if err != nil { + return err + } + stmt := v.builder.sql(clauses...) + _, err = v.client.exec(ctx, stmt) + if err != nil { + return decodeDriverError(err) + } + return err +} + +// MaskingPolicyShowOptions represents the options for listing masking policies. +type MaskingPolicyShowOptions struct { + show bool `ddl:"static" db:"SHOW"` //lint:ignore U1000 This is used in the ddl tag + maskingPolicies bool `ddl:"static" db:"MASKING POLICIES"` //lint:ignore U1000 This is used in the ddl tag + Like *Like `ddl:"keyword" db:"LIKE"` + In *In `ddl:"keyword" db:"IN"` + Limit *int `ddl:"command,no_quotes" db:"LIMIT"` +} + +func (input *MaskingPolicyShowOptions) validate() error { + return nil +} + +// MaskingPolicys is a user friendly result for a CREATE MASKING POLICY query. +type MaskingPolicy struct { + CreatedOn time.Time + Name string + DatabaseName string + SchemaName string + Kind string + Owner string + Comment string + ExemptOtherPolicies bool +} + +func (v *MaskingPolicy) ID() SchemaObjectIdentifier { + return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) +} + +// maskingPolicyDBRow is used to decode the result of a CREATE MASKING POLICY query. +type maskingPolicyDBRow struct { + CreatedOn time.Time `db:"created_on"` + Name string `db:"name"` + DatabaseName string `db:"database_name"` + SchemaName string `db:"schema_name"` + Kind string `db:"kind"` + Owner string `db:"owner"` + Comment string `db:"comment"` + OwnerRoleType string `db:"owner_role_type"` + Options string `db:"options"` +} + +func (row maskingPolicyDBRow) toMaskingPolicy() *MaskingPolicy { + exemptOtherPolicies, err := jsonparser.GetBoolean([]byte(row.Options), "EXEMPT_OTHER_POLICIES") + if err != nil { + exemptOtherPolicies = false + } + return &MaskingPolicy{ + CreatedOn: row.CreatedOn, + Name: row.Name, + DatabaseName: row.DatabaseName, + SchemaName: row.SchemaName, + Kind: row.Kind, + Owner: row.Owner, + Comment: row.Comment, + ExemptOtherPolicies: exemptOtherPolicies, + } +} + +// List all the masking policies by pattern. +func (v *maskingPolicies) Show(ctx context.Context, opts *MaskingPolicyShowOptions) ([]*MaskingPolicy, error) { + if opts == nil { + opts = &MaskingPolicyShowOptions{} + } + if err := opts.validate(); err != nil { + return nil, err + } + clauses, err := v.builder.parseStruct(opts) + if err != nil { + return nil, err + } + stmt := v.builder.sql(clauses...) + dest := []maskingPolicyDBRow{} + + err = v.client.query(ctx, &dest, stmt) + if err != nil { + return nil, decodeDriverError(err) + } + resultList := make([]*MaskingPolicy, len(dest)) + for i, row := range dest { + resultList[i] = row.toMaskingPolicy() + } + + return resultList, nil +} + +type maskingPolicyDescribeOptions struct { + describe bool `ddl:"static" db:"DESCRIBE"` //lint:ignore U1000 This is used in the ddl tag + maskingPolicy bool `ddl:"static" db:"MASKING POLICY"` //lint:ignore U1000 This is used in the ddl tag + name SchemaObjectIdentifier `ddl:"identifier"` +} + +func (v *maskingPolicyDescribeOptions) validate() error { + if v.name.FullyQualifiedName() == "" { + return fmt.Errorf("name is required") + } + return nil +} + +type MaskingPolicyDetails struct { + Name string + Signature []TableColumnSignature + ReturnType DataType + Body string +} + +type maskingPolicyDetailsRow struct { + Name string `db:"name"` + Signature string `db:"signature"` + ReturnType string `db:"return_type"` + Body string `db:"body"` +} + +func (row maskingPolicyDetailsRow) toMaskingPolicyDetails() *MaskingPolicyDetails { + dataType := DataTypeFromString(row.ReturnType) + v := &MaskingPolicyDetails{ + Name: row.Name, + Signature: []TableColumnSignature{}, + ReturnType: dataType, + Body: row.Body, + } + s := strings.Trim(row.Signature, "()") + parts := strings.Split(s, ",") + for _, part := range parts { + p := strings.Split(strings.TrimSpace(part), " ") + if len(p) != 2 { + continue + } + dType := DataTypeFromString(p[1]) + v.Signature = append(v.Signature, TableColumnSignature{ + Name: p[0], + Type: dType, + }) + } + + return v +} + +func (v *maskingPolicies) Describe(ctx context.Context, id SchemaObjectIdentifier) (*MaskingPolicyDetails, error) { + opts := &maskingPolicyDescribeOptions{ + name: id, + } + if err := opts.validate(); err != nil { + return nil, err + } + + clauses, err := v.builder.parseStruct(opts) + if err != nil { + return nil, err + } + stmt := v.builder.sql(clauses...) + dest := maskingPolicyDetailsRow{} + err = v.client.queryOne(ctx, &dest, stmt) + if err != nil { + return nil, decodeDriverError(err) + } + + return dest.toMaskingPolicyDetails(), nil +} diff --git a/pkg/sdk/masking_policy_integration_test.go b/pkg/sdk/masking_policy_integration_test.go new file mode 100644 index 0000000000..9f2fc1b0ae --- /dev/null +++ b/pkg/sdk/masking_policy_integration_test.go @@ -0,0 +1,421 @@ +package sdk + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt_MaskingPoliciesShow(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + maskingPolicyTest, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicyCleanup) + + maskingPolicy2Test, maskingPolicy2Cleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicy2Cleanup) + + t.Run("without show options", func(t *testing.T) { + useDatabaseCleanup := useDatabase(t, client, databaseTest.ID()) + t.Cleanup(useDatabaseCleanup) + useSchemaCleanup := useSchema(t, client, schemaTest.ID()) + t.Cleanup(useSchemaCleanup) + + maskingPolicies, err := client.MaskingPolicies.Show(ctx, nil) + require.NoError(t, err) + assert.Equal(t, 2, len(maskingPolicies)) + }) + + t.Run("with show options", func(t *testing.T) { + showOptions := &MaskingPolicyShowOptions{ + In: &In{ + Schema: schemaTest.ID(), + }, + } + maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) + require.NoError(t, err) + assert.Contains(t, maskingPolicies, maskingPolicyTest) + assert.Contains(t, maskingPolicies, maskingPolicy2Test) + assert.Equal(t, 2, len(maskingPolicies)) + }) + + t.Run("with show options and like", func(t *testing.T) { + showOptions := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(maskingPolicyTest.Name), + }, + In: &In{ + Database: databaseTest.ID(), + }, + } + maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) + require.NoError(t, err) + assert.Contains(t, maskingPolicies, maskingPolicyTest) + assert.Equal(t, 1, len(maskingPolicies)) + }) + + t.Run("when searching a non-existent masking policy", func(t *testing.T) { + showOptions := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String("non-existent"), + }, + } + maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) + require.NoError(t, err) + assert.Equal(t, 0, len(maskingPolicies)) + }) + + /* + // there appears to be a bug in the Snowflake API. LIMIT is not actually limiting the number of results + t.Run("when limiting the number of results", func(t *testing.T) { + showOptions := &MaskingPolicyShowOptions{ + In: &In{ + Schema: schemaTest.ID(), + }, + Limit: Int(1), + } + maskingPolicies, err := client.MaskingPolicies.Show(ctx, showOptions) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicies)) + }) + */ +} + +func TestInt_MaskingPolicyCreate(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + t.Run("test complete case", func(t *testing.T) { + name := randomString(t) + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + signature := []TableColumnSignature{ + { + Name: "col1", + Type: DataTypeVARCHAR, + }, + { + Name: "col2", + Type: DataTypeVARCHAR, + }, + } + expression := "REPLACE('X', 1, 2)" + comment := randomComment(t) + exemptOtherPolicies := randomBool(t) + err := client.MaskingPolicies.Create(ctx, id, signature, DataTypeVARCHAR, expression, &MaskingPolicyCreateOptions{ + OrReplace: Bool(true), + IfNotExists: Bool(false), + Comment: String(comment), + ExemptOtherPolicies: Bool(exemptOtherPolicies), + }) + require.NoError(t, err) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, id) + require.NoError(t, err) + assert.Equal(t, name, maskingPolicyDetails.Name) + assert.Equal(t, signature, maskingPolicyDetails.Signature) + assert.Equal(t, DataTypeVARCHAR, maskingPolicyDetails.ReturnType) + assert.Equal(t, expression, maskingPolicyDetails.Body) + + maskingPolicy, err := client.MaskingPolicies.Show(ctx, &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(name), + }, + In: &In{ + Schema: schemaTest.ID(), + }, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicy)) + assert.Equal(t, name, maskingPolicy[0].Name) + assert.Equal(t, comment, maskingPolicy[0].Comment) + assert.Equal(t, exemptOtherPolicies, maskingPolicy[0].ExemptOtherPolicies) + }) + + t.Run("test if_not_exists", func(t *testing.T) { + name := randomString(t) + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + signature := []TableColumnSignature{ + { + Name: "col1", + Type: DataTypeVARCHAR, + }, + { + Name: "col2", + Type: DataTypeVARCHAR, + }, + } + expression := "REPLACE('X', 1, 2)" + comment := randomComment(t) + err := client.MaskingPolicies.Create(ctx, id, signature, DataTypeVARCHAR, expression, &MaskingPolicyCreateOptions{ + OrReplace: Bool(false), + IfNotExists: Bool(true), + Comment: String(comment), + ExemptOtherPolicies: Bool(true), + }) + require.NoError(t, err) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, id) + require.NoError(t, err) + assert.Equal(t, name, maskingPolicyDetails.Name) + assert.Equal(t, signature, maskingPolicyDetails.Signature) + assert.Equal(t, DataTypeVARCHAR, maskingPolicyDetails.ReturnType) + assert.Equal(t, expression, maskingPolicyDetails.Body) + + maskingPolicy, err := client.MaskingPolicies.Show(ctx, &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(name), + }, + In: &In{ + Schema: schemaTest.ID(), + }, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicy)) + assert.Equal(t, name, maskingPolicy[0].Name) + assert.Equal(t, comment, maskingPolicy[0].Comment) + assert.Equal(t, true, maskingPolicy[0].ExemptOtherPolicies) + }) + + t.Run("test no options", func(t *testing.T) { + name := randomString(t) + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + signature := []TableColumnSignature{ + { + Name: "col1", + Type: DataTypeVARCHAR, + }, + } + expression := "REPLACE('X', 1, 2)" + err := client.MaskingPolicies.Create(ctx, id, signature, DataTypeVARCHAR, expression, nil) + require.NoError(t, err) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, id) + require.NoError(t, err) + assert.Equal(t, name, maskingPolicyDetails.Name) + assert.Equal(t, signature, maskingPolicyDetails.Signature) + assert.Equal(t, DataTypeVARCHAR, maskingPolicyDetails.ReturnType) + assert.Equal(t, expression, maskingPolicyDetails.Body) + + maskingPolicy, err := client.MaskingPolicies.Show(ctx, &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(name), + }, + In: &In{ + Schema: schemaTest.ID(), + }, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicy)) + assert.Equal(t, name, maskingPolicy[0].Name) + assert.Equal(t, "", maskingPolicy[0].Comment) + assert.Equal(t, false, maskingPolicy[0].ExemptOtherPolicies) + }) + + t.Run("test multiline expression", func(t *testing.T) { + name := randomString(t) + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) + signature := []TableColumnSignature{ + { + Name: "val", + Type: DataTypeVARCHAR, + }, + } + expression := ` + case + when current_role() in ('ROLE_A') then + val + when is_role_in_session( 'ROLE_B' ) then + 'ABC123' + else + '******' + end + ` + err := client.MaskingPolicies.Create(ctx, id, signature, DataTypeVARCHAR, expression, nil) + require.NoError(t, err) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, id) + require.NoError(t, err) + assert.Equal(t, name, maskingPolicyDetails.Name) + assert.Equal(t, signature, maskingPolicyDetails.Signature) + assert.Equal(t, DataTypeVARCHAR, maskingPolicyDetails.ReturnType) + assert.Equal(t, strings.TrimSpace(expression), maskingPolicyDetails.Body) + }) +} + +func TestInt_MaskingPolicyDescribe(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + maskingPolicy, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicyCleanup) + + t.Run("when masking policy exists", func(t *testing.T) { + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, maskingPolicy.ID()) + require.NoError(t, err) + assert.Equal(t, maskingPolicy.Name, maskingPolicyDetails.Name) + }) + + t.Run("when masking policy does not exist", func(t *testing.T) { + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, "does_not_exist") + _, err := client.MaskingPolicies.Describe(ctx, id) + assert.ErrorIs(t, err, ErrObjectNotExistOrAuthorized) + }) +} + +func TestInt_MaskingPolicyAlter(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + t.Run("when setting and unsetting a value", func(t *testing.T) { + maskingPolicy, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicyCleanup) + comment := randomComment(t) + alterOptions := &MaskingPolicyAlterOptions{ + Set: &MaskingPolicySet{ + Comment: String(comment), + }, + } + err := client.MaskingPolicies.Alter(ctx, maskingPolicy.ID(), alterOptions) + require.NoError(t, err) + maskingPolicies, err := client.MaskingPolicies.Show(ctx, &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(maskingPolicy.Name), + }, + In: &In{ + Schema: schemaTest.ID(), + }, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicies)) + assert.Equal(t, comment, maskingPolicies[0].Comment) + + err = client.MaskingPolicies.Alter(ctx, maskingPolicy.ID(), alterOptions) + require.NoError(t, err) + alterOptions = &MaskingPolicyAlterOptions{ + Unset: &MaskingPolicyUnset{ + Comment: Bool(true), + }, + } + err = client.MaskingPolicies.Alter(ctx, maskingPolicy.ID(), alterOptions) + require.NoError(t, err) + maskingPolicies, err = client.MaskingPolicies.Show(ctx, &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(maskingPolicy.Name), + }, + In: &In{ + Schema: schemaTest.ID(), + }, + }) + require.NoError(t, err) + assert.Equal(t, 1, len(maskingPolicies)) + assert.Equal(t, "", maskingPolicies[0].Comment) + }) + + t.Run("when renaming", func(t *testing.T) { + maskingPolicy, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + oldID := maskingPolicy.ID() + t.Cleanup(maskingPolicyCleanup) + newName := randomString(t) + newID := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, newName) + alterOptions := &MaskingPolicyAlterOptions{ + NewName: newID, + } + err := client.MaskingPolicies.Alter(ctx, oldID, alterOptions) + require.NoError(t, err) + maskingPolicyDetails, err := client.MaskingPolicies.Describe(ctx, newID) + require.NoError(t, err) + assert.Equal(t, newName, maskingPolicyDetails.Name) + // rename back to original name so it can be cleaned up + alterOptions = &MaskingPolicyAlterOptions{ + NewName: oldID, + } + err = client.MaskingPolicies.Alter(ctx, newID, alterOptions) + require.NoError(t, err) + }) + + t.Run("setting and unsetting tags", func(t *testing.T) { + maskingPolicy, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + id := maskingPolicy.ID() + t.Cleanup(maskingPolicyCleanup) + + tag, tagCleanup := createTag(t, client, databaseTest, schemaTest) + t.Cleanup(tagCleanup) + + tag2, tag2Cleanup := createTag(t, client, databaseTest, schemaTest) + t.Cleanup(tag2Cleanup) + + tagAssociations := []TagAssociation{{Name: tag.ID(), Value: "value1"}, {Name: tag2.ID(), Value: "value2"}} + alterOptions := &MaskingPolicyAlterOptions{ + Set: &MaskingPolicySet{ + Tag: tagAssociations, + }, + } + err := client.MaskingPolicies.Alter(ctx, id, alterOptions) + require.NoError(t, err) + tagValue, err := client.SystemFunctions.GetTag(ctx, tag.ID(), id, ObjectTypeMaskingPolicy) + require.NoError(t, err) + assert.Equal(t, tagAssociations[0].Value, tagValue) + tag2Value, err := client.SystemFunctions.GetTag(ctx, tag2.ID(), id, ObjectTypeMaskingPolicy) + require.NoError(t, err) + assert.Equal(t, tagAssociations[1].Value, tag2Value) + + // unset tag + alterOptions = &MaskingPolicyAlterOptions{ + Unset: &MaskingPolicyUnset{ + Tag: []ObjectIdentifier{tag.ID()}, + }, + } + err = client.MaskingPolicies.Alter(ctx, id, alterOptions) + require.NoError(t, err) + _, err = client.SystemFunctions.GetTag(ctx, tag.ID(), id, ObjectTypeMaskingPolicy) + assert.Error(t, err) + }) +} + +func TestInt_MaskingPolicyDrop(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + t.Run("when masking policy exists", func(t *testing.T) { + maskingPolicy, _ := createMaskingPolicy(t, client, databaseTest, schemaTest) + id := maskingPolicy.ID() + err := client.MaskingPolicies.Drop(ctx, id) + require.NoError(t, err) + _, err = client.PasswordPolicies.Describe(ctx, id) + assert.ErrorIs(t, err, ErrObjectNotExistOrAuthorized) + }) + + t.Run("when masking policy does not exist", func(t *testing.T) { + id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, "does_not_exist") + err := client.MaskingPolicies.Drop(ctx, id) + assert.ErrorIs(t, err, ErrObjectNotExistOrAuthorized) + }) +} diff --git a/pkg/sdk/masking_policy_test.go b/pkg/sdk/masking_policy_test.go new file mode 100644 index 0000000000..5c15381c9e --- /dev/null +++ b/pkg/sdk/masking_policy_test.go @@ -0,0 +1,261 @@ +package sdk + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func TestMaskingPolicyCreate(t *testing.T) { + builder := testBuilder(t) + id := randomSchemaObjectIdentifier(t) + + t.Run("empty options", func(t *testing.T) { + opts := &MaskingPolicyCreateOptions{} + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "CREATE MASKING POLICY RETURNS ->" + assert.Equal(t, expected, actual) + }) + + t.Run("with complete options", func(t *testing.T) { + signature := []TableColumnSignature{ + { + Name: "col1", + Type: DataTypeVARCHAR, + }, + { + Name: "col2", + Type: DataTypeVARCHAR, + }, + } + expression := "REPLACE('X', 1, 2)" + comment := randomString(t) + + opts := &MaskingPolicyCreateOptions{ + OrReplace: Bool(true), + name: id, + IfNotExists: Bool(true), + signature: signature, + body: expression, + returns: DataTypeVARCHAR, + Comment: String(comment), + ExemptOtherPolicies: Bool(true), + } + + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf(`CREATE OR REPLACE MASKING POLICY IF NOT EXISTS %s AS ("col1" VARCHAR,"col2" VARCHAR) RETURNS %s -> %s COMMENT = '%s' EXEMPT_OTHER_POLICIES = %t`, id.FullyQualifiedName(), DataTypeVARCHAR, expression, comment, true) + assert.Equal(t, expected, actual) + }) +} + +func TestMaskingPolicyAlter(t *testing.T) { + builder := testBuilder(t) + id := randomSchemaObjectIdentifier(t) + + t.Run("empty options", func(t *testing.T) { + opts := &MaskingPolicyAlterOptions{} + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "ALTER MASKING POLICY" + assert.Equal(t, expected, actual) + }) + + t.Run("only name", func(t *testing.T) { + opts := &MaskingPolicyAlterOptions{ + name: id, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("ALTER MASKING POLICY %s", id.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) + + t.Run("with set", func(t *testing.T) { + newComment := randomString(t) + opts := &MaskingPolicyAlterOptions{ + name: id, + Set: &MaskingPolicySet{ + Comment: String(newComment), + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("ALTER MASKING POLICY %s SET COMMENT = '%s'", id.FullyQualifiedName(), newComment) + assert.Equal(t, expected, actual) + }) + + t.Run("with unset", func(t *testing.T) { + opts := &MaskingPolicyAlterOptions{ + name: id, + Unset: &MaskingPolicyUnset{ + Comment: Bool(true), + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("ALTER MASKING POLICY %s UNSET COMMENT", id.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) + + t.Run("rename", func(t *testing.T) { + newID := NewSchemaObjectIdentifier(id.databaseName, id.schemaName, randomUUID(t)) + opts := &MaskingPolicyAlterOptions{ + name: id, + NewName: newID, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("ALTER MASKING POLICY %s RENAME TO %s", id.FullyQualifiedName(), newID.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) +} + +func TestMaskingPolicyDrop(t *testing.T) { + builder := testBuilder(t) + id := randomSchemaObjectIdentifier(t) + + t.Run("empty options", func(t *testing.T) { + opts := &MaskingPolicyDropOptions{} + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "DROP MASKING POLICY" + assert.Equal(t, expected, actual) + }) + + t.Run("only name", func(t *testing.T) { + opts := &MaskingPolicyDropOptions{ + name: id, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("DROP MASKING POLICY %s", id.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) +} + +func TestMaskingPolicyShow(t *testing.T) { + builder := testBuilder(t) + id := randomSchemaObjectIdentifier(t) + + t.Run("empty options", func(t *testing.T) { + opts := &MaskingPolicyShowOptions{} + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "SHOW MASKING POLICIES" + assert.Equal(t, expected, actual) + }) + + t.Run("with like", func(t *testing.T) { + opts := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(id.Name()), + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("SHOW MASKING POLICIES LIKE '%s'", id.Name()) + assert.Equal(t, expected, actual) + }) + + t.Run("with like and in account", func(t *testing.T) { + opts := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(id.Name()), + }, + In: &In{ + Account: Bool(true), + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("SHOW MASKING POLICIES LIKE '%s' IN ACCOUNT", id.Name()) + assert.Equal(t, expected, actual) + }) + + t.Run("with like and in database", func(t *testing.T) { + databaseIdentifier := NewAccountObjectIdentifier(id.DatabaseName()) + opts := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(id.Name()), + }, + In: &In{ + Database: databaseIdentifier, + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("SHOW MASKING POLICIES LIKE '%s' IN DATABASE %s", id.Name(), databaseIdentifier.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) + + t.Run("with like and in schema", func(t *testing.T) { + schemaIdentifier := NewSchemaIdentifier(id.DatabaseName(), id.SchemaName()) + opts := &MaskingPolicyShowOptions{ + Like: &Like{ + Pattern: String(id.Name()), + }, + In: &In{ + Schema: schemaIdentifier, + }, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("SHOW MASKING POLICIES LIKE '%s' IN SCHEMA %s", id.Name(), schemaIdentifier.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) + + t.Run("with limit", func(t *testing.T) { + opts := &MaskingPolicyShowOptions{ + Limit: Int(10), + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "SHOW MASKING POLICIES LIMIT 10" + assert.Equal(t, expected, actual) + }) +} + +func TestMaskingPolicyDescribe(t *testing.T) { + builder := testBuilder(t) + id := randomSchemaObjectIdentifier(t) + + t.Run("empty options", func(t *testing.T) { + opts := &maskingPolicyDescribeOptions{} + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := "DESCRIBE MASKING POLICY" + assert.Equal(t, expected, actual) + }) + + t.Run("only name", func(t *testing.T) { + opts := &maskingPolicyDescribeOptions{ + name: id, + } + clauses, err := builder.parseStruct(opts) + require.NoError(t, err) + actual := builder.sql(clauses...) + expected := fmt.Sprintf("DESCRIBE MASKING POLICY %s", id.FullyQualifiedName()) + assert.Equal(t, expected, actual) + }) +} diff --git a/pkg/sdk/password_policy.go b/pkg/sdk/password_policy.go index fc00861109..7f59af0753 100644 --- a/pkg/sdk/password_policy.go +++ b/pkg/sdk/password_policy.go @@ -35,8 +35,8 @@ type PasswordPolicyCreateOptions struct { create bool `ddl:"static" db:"CREATE"` //lint:ignore U1000 This is used in the ddl tag OrReplace *bool `ddl:"keyword" db:"OR REPLACE"` passwordPolicy bool `ddl:"static" db:"PASSWORD POLICY"` //lint:ignore U1000 This is used in the ddl tag - name SchemaObjectIdentifier `ddl:"identifier"` IfNotExists *bool `ddl:"keyword" db:"IF NOT EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` PasswordMinLength *int `ddl:"parameter" db:"PASSWORD_MIN_LENGTH"` PasswordMaxLength *int `ddl:"parameter" db:"PASSWORD_MAX_LENGTH"` @@ -77,13 +77,13 @@ func (v *passwordPolicies) Create(ctx context.Context, id SchemaObjectIdentifier } type PasswordPolicyAlterOptions struct { - alter bool `ddl:"static" db:"ALTER"` //lint:ignore U1000 This is used in the ddl tag - passwordPolicy bool `ddl:"static" db:"PASSWORD POLICY"` //lint:ignore U1000 This is used in the ddl tag - IfExists *bool `ddl:"keyword" db:"IF EXISTS"` - name SchemaObjectIdentifier `ddl:"identifier"` - NewName SchemaObjectIdentifier `ddl:"identifier" db:"RENAME TO"` - Set *PasswordPolicyAlterSet `ddl:"keyword" db:"SET"` - Unset *PasswordPolicyAlterUnset `ddl:"keyword" db:"UNSET"` + alter bool `ddl:"static" db:"ALTER"` //lint:ignore U1000 This is used in the ddl tag + passwordPolicy bool `ddl:"static" db:"PASSWORD POLICY"` //lint:ignore U1000 This is used in the ddl tag + IfExists *bool `ddl:"keyword" db:"IF EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + NewName SchemaObjectIdentifier `ddl:"identifier" db:"RENAME TO"` + Set *PasswordPolicySet `ddl:"keyword" db:"SET"` + Unset *PasswordPolicyUnset `ddl:"keyword" db:"UNSET"` } func (opts *PasswordPolicyAlterOptions) validate() error { @@ -181,7 +181,7 @@ func (opts *PasswordPolicyAlterOptions) validate() error { return nil } -type PasswordPolicyAlterSet struct { +type PasswordPolicySet struct { PasswordMinLength *int `ddl:"parameter" db:"PASSWORD_MIN_LENGTH"` PasswordMaxLength *int `ddl:"parameter" db:"PASSWORD_MAX_LENGTH"` PasswordMinUpperCaseChars *int `ddl:"parameter" db:"PASSWORD_MIN_UPPER_CASE_CHARS"` @@ -194,7 +194,7 @@ type PasswordPolicyAlterSet struct { Comment *string `ddl:"parameter,single_quotes" db:"COMMENT"` } -type PasswordPolicyAlterUnset struct { +type PasswordPolicyUnset struct { PasswordMinLength *bool `ddl:"keyword" db:"PASSWORD_MIN_LENGTH"` PasswordMaxLength *bool `ddl:"keyword" db:"PASSWORD_MAX_LENGTH"` PasswordMinUpperCaseChars *bool `ddl:"keyword" db:"PASSWORD_MIN_UPPER_CASE_CHARS"` @@ -282,7 +282,7 @@ type PasswordPolicy struct { Comment string } -func (v *PasswordPolicy) Identifier() SchemaObjectIdentifier { +func (v *PasswordPolicy) ID() SchemaObjectIdentifier { return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) } @@ -299,7 +299,7 @@ type passwordPolicyDBRow struct { Options string `db:"options"` } -func passwordPolicyFromRow(row passwordPolicyDBRow) *PasswordPolicy { +func (row passwordPolicyDBRow) toPasswordPolicy() *PasswordPolicy { return &PasswordPolicy{ CreatedOn: row.CreatedOn, Name: row.Name, @@ -332,7 +332,7 @@ func (v *passwordPolicies) Show(ctx context.Context, opts *PasswordPolicyShowOpt } resultList := make([]*PasswordPolicy, len(dest)) for i, row := range dest { - resultList[i] = passwordPolicyFromRow(row) + resultList[i] = row.toPasswordPolicy() } return resultList, nil diff --git a/pkg/sdk/password_policy_integration_test.go b/pkg/sdk/password_policy_integration_test.go index 448e355bfe..cf6b721e8e 100644 --- a/pkg/sdk/password_policy_integration_test.go +++ b/pkg/sdk/password_policy_integration_test.go @@ -32,7 +32,7 @@ func TestInt_PasswordPoliciesShow(t *testing.T) { t.Run("with show options", func(t *testing.T) { showOptions := &PasswordPolicyShowOptions{ In: &In{ - Schema: schemaTest.Identifier(), + Schema: schemaTest.ID(), }, } passwordPolicies, err := client.PasswordPolicies.Show(ctx, showOptions) @@ -48,7 +48,7 @@ func TestInt_PasswordPoliciesShow(t *testing.T) { Pattern: String(passwordPolicyTest.Name), }, In: &In{ - Database: databaseTest.Identifier(), + Database: databaseTest.ID(), }, } passwordPolicies, err := client.PasswordPolicies.Show(ctx, showOptions) @@ -90,9 +90,8 @@ func TestInt_PasswordPolicyCreate(t *testing.T) { schemaTest, schemaCleanup := createSchema(t, client, databaseTest) t.Cleanup(schemaCleanup) - - t.Run("test complete case", func(t *testing.T) { - name := randomString(t) + t.Run("test complete", func(t *testing.T) { + name := randomUUID(t) id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) err := client.PasswordPolicies.Create(ctx, id, &PasswordPolicyCreateOptions{ OrReplace: Bool(true), @@ -123,11 +122,12 @@ func TestInt_PasswordPolicyCreate(t *testing.T) { assert.Equal(t, 30, passwordPolicyDetails.PasswordLockoutTimeMins.Value) }) - t.Run("test no on_on replace", func(t *testing.T) { - name := randomString(t) + t.Run("test if_not_exists", func(t *testing.T) { + name := randomUUID(t) id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) err := client.PasswordPolicies.Create(ctx, id, &PasswordPolicyCreateOptions{ OrReplace: Bool(false), + IfNotExists: Bool(true), PasswordMinLength: Int(10), PasswordMaxLength: Int(20), PasswordMinUpperCaseChars: Int(5), @@ -144,7 +144,7 @@ func TestInt_PasswordPolicyCreate(t *testing.T) { }) t.Run("test no options", func(t *testing.T) { - name := randomString(t) + name := randomUUID(t) id := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, name) err := client.PasswordPolicies.Create(ctx, id, nil) require.NoError(t, err) @@ -178,7 +178,7 @@ func TestInt_PasswordPolicyDescribe(t *testing.T) { t.Cleanup(passwordPolicyCleanup) t.Run("when password policy exists", func(t *testing.T) { - passwordPolicyDetails, err := client.PasswordPolicies.Describe(ctx, passwordPolicy.Identifier()) + passwordPolicyDetails, err := client.PasswordPolicies.Describe(ctx, passwordPolicy.ID()) require.NoError(t, err) assert.Equal(t, passwordPolicy.Name, passwordPolicyDetails.Name.Value) assert.Equal(t, passwordPolicy.Comment, passwordPolicyDetails.Comment.Value) @@ -205,14 +205,14 @@ func TestInt_PasswordPolicyAlter(t *testing.T) { passwordPolicy, passwordPolicyCleanup := createPasswordPolicy(t, client, databaseTest, schemaTest) t.Cleanup(passwordPolicyCleanup) alterOptions := &PasswordPolicyAlterOptions{ - Set: &PasswordPolicyAlterSet{ + Set: &PasswordPolicySet{ PasswordMinLength: Int(10), PasswordMaxLength: Int(20), }, } - err := client.PasswordPolicies.Alter(ctx, passwordPolicy.Identifier(), alterOptions) + err := client.PasswordPolicies.Alter(ctx, passwordPolicy.ID(), alterOptions) require.NoError(t, err) - passwordPolicyDetails, err := client.PasswordPolicies.Describe(ctx, passwordPolicy.Identifier()) + passwordPolicyDetails, err := client.PasswordPolicies.Describe(ctx, passwordPolicy.ID()) require.NoError(t, err) assert.Equal(t, passwordPolicy.Name, passwordPolicyDetails.Name.Value) assert.Equal(t, 10, passwordPolicyDetails.PasswordMinLength.Value) @@ -221,9 +221,9 @@ func TestInt_PasswordPolicyAlter(t *testing.T) { t.Run("when renaming", func(t *testing.T) { passwordPolicy, passwordPolicyCleanup := createPasswordPolicy(t, client, databaseTest, schemaTest) - oldID := passwordPolicy.Identifier() + oldID := passwordPolicy.ID() t.Cleanup(passwordPolicyCleanup) - newName := randomString(t) + newName := randomUUID(t) newID := NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, newName) alterOptions := &PasswordPolicyAlterOptions{ NewName: newID, @@ -247,17 +247,17 @@ func TestInt_PasswordPolicyAlter(t *testing.T) { PasswordMaxRetries: Int(10), } passwordPolicy, passwordPolicyCleanup := createPasswordPolicyWithOptions(t, client, databaseTest, schemaTest, createOptions) - id := passwordPolicy.Identifier() + id := passwordPolicy.ID() t.Cleanup(passwordPolicyCleanup) alterOptions := &PasswordPolicyAlterOptions{ - Unset: &PasswordPolicyAlterUnset{ + Unset: &PasswordPolicyUnset{ PasswordMaxRetries: Bool(true), }, } err := client.PasswordPolicies.Alter(ctx, id, alterOptions) require.NoError(t, err) alterOptions = &PasswordPolicyAlterOptions{ - Unset: &PasswordPolicyAlterUnset{ + Unset: &PasswordPolicyUnset{ Comment: Bool(true), }, } @@ -276,10 +276,10 @@ func TestInt_PasswordPolicyAlter(t *testing.T) { PasswordMaxRetries: Int(10), } passwordPolicy, passwordPolicyCleanup := createPasswordPolicyWithOptions(t, client, databaseTest, schemaTest, createOptions) - id := passwordPolicy.Identifier() + id := passwordPolicy.ID() t.Cleanup(passwordPolicyCleanup) alterOptions := &PasswordPolicyAlterOptions{ - Unset: &PasswordPolicyAlterUnset{ + Unset: &PasswordPolicyUnset{ Comment: Bool(true), PasswordMaxRetries: Bool(true), }, @@ -301,7 +301,7 @@ func TestInt_PasswordPolicyDrop(t *testing.T) { t.Run("when password policy exists", func(t *testing.T) { passwordPolicy, _ := createPasswordPolicy(t, client, databaseTest, schemaTest) - id := passwordPolicy.Identifier() + id := passwordPolicy.ID() err := client.PasswordPolicies.Drop(ctx, id, nil) require.NoError(t, err) _, err = client.PasswordPolicies.Describe(ctx, id) @@ -316,7 +316,7 @@ func TestInt_PasswordPolicyDrop(t *testing.T) { t.Run("when password policy exists and if exists is true", func(t *testing.T) { passwordPolicy, _ := createPasswordPolicy(t, client, databaseTest, schemaTest) - id := passwordPolicy.Identifier() + id := passwordPolicy.ID() dropOptions := &PasswordPolicyDropOptions{IfExists: Bool(true)} err := client.PasswordPolicies.Drop(ctx, id, dropOptions) require.NoError(t, err) diff --git a/pkg/sdk/password_policy_test.go b/pkg/sdk/password_policy_test.go index 61074efe53..82de8f83bd 100644 --- a/pkg/sdk/password_policy_test.go +++ b/pkg/sdk/password_policy_test.go @@ -52,7 +52,7 @@ func TestPasswordPolicyCreate(t *testing.T) { clauses, err := builder.parseStruct(opts) require.NoError(t, err) actual := builder.sql(clauses...) - expected := fmt.Sprintf(`CREATE OR REPLACE PASSWORD POLICY %s IF NOT EXISTS PASSWORD_MIN_LENGTH = 10 PASSWORD_MAX_LENGTH = 20 PASSWORD_MIN_UPPER_CASE_CHARS = 1 PASSWORD_MIN_LOWER_CASE_CHARS = 1 PASSWORD_MIN_NUMERIC_CHARS = 1 PASSWORD_MIN_SPECIAL_CHARS = 1 PASSWORD_MAX_AGE_DAYS = 30 PASSWORD_MAX_RETRIES = 5 PASSWORD_LOCKOUT_TIME_MINS = 30 COMMENT = 'test comment'`, id.FullyQualifiedName()) + expected := fmt.Sprintf(`CREATE OR REPLACE PASSWORD POLICY IF NOT EXISTS %s PASSWORD_MIN_LENGTH = 10 PASSWORD_MAX_LENGTH = 20 PASSWORD_MIN_UPPER_CASE_CHARS = 1 PASSWORD_MIN_LOWER_CASE_CHARS = 1 PASSWORD_MIN_NUMERIC_CHARS = 1 PASSWORD_MIN_SPECIAL_CHARS = 1 PASSWORD_MAX_AGE_DAYS = 30 PASSWORD_MAX_RETRIES = 5 PASSWORD_LOCKOUT_TIME_MINS = 30 COMMENT = 'test comment'`, id.FullyQualifiedName()) assert.Equal(t, expected, actual) }) } @@ -84,7 +84,7 @@ func TestPasswordPolicyAlter(t *testing.T) { t.Run("with set", func(t *testing.T) { opts := &PasswordPolicyAlterOptions{ name: id, - Set: &PasswordPolicyAlterSet{ + Set: &PasswordPolicySet{ PasswordMinLength: Int(10), PasswordMaxLength: Int(20), PasswordMinUpperCaseChars: Int(1), @@ -100,7 +100,7 @@ func TestPasswordPolicyAlter(t *testing.T) { t.Run("with unset", func(t *testing.T) { opts := &PasswordPolicyAlterOptions{ name: id, - Unset: &PasswordPolicyAlterUnset{ + Unset: &PasswordPolicyUnset{ PasswordMinLength: Bool(true), }, } @@ -112,7 +112,7 @@ func TestPasswordPolicyAlter(t *testing.T) { }) t.Run("rename", func(t *testing.T) { - newID := NewSchemaObjectIdentifier(id.DatabaseName, id.SchemaName, randomString(t)) + newID := NewSchemaObjectIdentifier(id.databaseName, id.schemaName, randomUUID(t)) opts := &PasswordPolicyAlterOptions{ name: id, NewName: newID, @@ -178,20 +178,20 @@ func TestPasswordPolicyShow(t *testing.T) { t.Run("with like", func(t *testing.T) { opts := &PasswordPolicyShowOptions{ Like: &Like{ - Pattern: String(id.Name), + Pattern: String(id.Name()), }, } clauses, err := builder.parseStruct(opts) require.NoError(t, err) actual := builder.sql(clauses...) - expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s'", id.Name) + expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s'", id.Name()) assert.Equal(t, expected, actual) }) t.Run("with like and in account", func(t *testing.T) { opts := &PasswordPolicyShowOptions{ Like: &Like{ - Pattern: String(id.Name), + Pattern: String(id.Name()), }, In: &In{ Account: Bool(true), @@ -200,15 +200,15 @@ func TestPasswordPolicyShow(t *testing.T) { clauses, err := builder.parseStruct(opts) require.NoError(t, err) actual := builder.sql(clauses...) - expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN ACCOUNT", id.Name) + expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN ACCOUNT", id.Name()) assert.Equal(t, expected, actual) }) t.Run("with like and in database", func(t *testing.T) { - databaseIdentifier := NewAccountObjectIdentifier(id.DatabaseName) + databaseIdentifier := NewAccountObjectIdentifier(id.DatabaseName()) opts := &PasswordPolicyShowOptions{ Like: &Like{ - Pattern: String(id.Name), + Pattern: String(id.Name()), }, In: &In{ Database: databaseIdentifier, @@ -217,15 +217,15 @@ func TestPasswordPolicyShow(t *testing.T) { clauses, err := builder.parseStruct(opts) require.NoError(t, err) actual := builder.sql(clauses...) - expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN DATABASE %s", id.Name, databaseIdentifier.FullyQualifiedName()) + expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN DATABASE %s", id.Name(), databaseIdentifier.FullyQualifiedName()) assert.Equal(t, expected, actual) }) t.Run("with like and in schema", func(t *testing.T) { - schemaIdentifier := NewSchemaIdentifier(id.DatabaseName, id.SchemaName) + schemaIdentifier := NewSchemaIdentifier(id.DatabaseName(), id.SchemaName()) opts := &PasswordPolicyShowOptions{ Like: &Like{ - Pattern: String(id.Name), + Pattern: String(id.Name()), }, In: &In{ Schema: schemaIdentifier, @@ -234,7 +234,7 @@ func TestPasswordPolicyShow(t *testing.T) { clauses, err := builder.parseStruct(opts) require.NoError(t, err) actual := builder.sql(clauses...) - expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN SCHEMA %s", id.Name, schemaIdentifier.FullyQualifiedName()) + expected := fmt.Sprintf("SHOW PASSWORD POLICIES LIKE '%s' IN SCHEMA %s", id.Name(), schemaIdentifier.FullyQualifiedName()) assert.Equal(t, expected, actual) }) diff --git a/pkg/sdk/schemas.go b/pkg/sdk/schemas.go new file mode 100644 index 0000000000..a7e97b6edd --- /dev/null +++ b/pkg/sdk/schemas.go @@ -0,0 +1,11 @@ +package sdk + +// placeholder for the real implementation. +type Schema struct { + DatabaseName string + Name string +} + +func (v *Schema) ID() SchemaIdentifier { + return NewSchemaIdentifier(v.DatabaseName, v.Name) +} diff --git a/pkg/sdk/sessions.go b/pkg/sdk/sessions.go new file mode 100644 index 0000000000..ece4983b48 --- /dev/null +++ b/pkg/sdk/sessions.go @@ -0,0 +1,38 @@ +package sdk + +import ( + "context" + "fmt" +) + +type Sessions interface { + // Context functions. + UseWarehouse(ctx context.Context, warehouse AccountObjectIdentifier) error + UseDatabase(ctx context.Context, database AccountObjectIdentifier) error + UseSchema(ctx context.Context, schema SchemaIdentifier) error +} + +var _ Sessions = (*sessions)(nil) + +type sessions struct { + client *Client + builder *sqlBuilder +} + +func (c *sessions) UseWarehouse(ctx context.Context, warehouse AccountObjectIdentifier) error { + sql := fmt.Sprintf(`USE WAREHOUSE %s`, warehouse.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} + +func (c *sessions) UseDatabase(ctx context.Context, database AccountObjectIdentifier) error { + sql := fmt.Sprintf(`USE DATABASE %s`, database.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} + +func (c *sessions) UseSchema(ctx context.Context, schema SchemaIdentifier) error { + sql := fmt.Sprintf(`USE SCHEMA %s`, schema.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} diff --git a/pkg/sdk/sessions_integration_test.go b/pkg/sdk/sessions_integration_test.go new file mode 100644 index 0000000000..cf628eec41 --- /dev/null +++ b/pkg/sdk/sessions_integration_test.go @@ -0,0 +1,48 @@ +package sdk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt_UseWarehouse(t *testing.T) { + client := testClient(t) + ctx := context.Background() + warehouseTest, warehouseCleanup := createWarehouse(t, client) + t.Cleanup(warehouseCleanup) + err := client.Sessions.UseWarehouse(ctx, warehouseTest.ID()) + require.NoError(t, err) + warehouse, err := client.ContextFunctions.CurrentWarehouse(ctx) + require.NoError(t, err) + assert.Equal(t, warehouseTest.Name, warehouse) +} + +func TestInt_UseDatabase(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + err := client.Sessions.UseDatabase(ctx, databaseTest.ID()) + require.NoError(t, err) + db, err := client.ContextFunctions.CurrentDatabase(ctx) + require.NoError(t, err) + assert.Equal(t, databaseTest.Name, db) +} + +func TestInt_UseSchema(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + err := client.Sessions.UseSchema(ctx, schemaTest.ID()) + require.NoError(t, err) + s, err := client.ContextFunctions.CurrentSchema(ctx) + require.NoError(t, err) + assert.Equal(t, schemaTest.Name, s) +} diff --git a/pkg/sdk/sql_builder.go b/pkg/sdk/sql_builder.go index be63b90179..b5df6b3851 100644 --- a/pkg/sdk/sql_builder.go +++ b/pkg/sdk/sql_builder.go @@ -7,6 +7,11 @@ import ( "unsafe" ) +// couple of helper functions. +func parentheses(s string) string { + return fmt.Sprintf("(%s)", s) +} + type quoteType string const ( @@ -53,20 +58,38 @@ func getQuoteTypeFromTag(tag reflect.StructTag, tagName string) quoteType { parts := strings.Split(t, ",") for _, part := range parts { if strings.Contains(part, "quotes") { - return quoteType(part) + return quoteType(strings.TrimSpace(part)) } } return NoQuotes } +func getUseParenthesesFromTag(tag reflect.StructTag, tagName string, defaultParentheses bool) bool { + t := strings.ToLower(tag.Get(tagName)) + if t == "" { + return defaultParentheses + } + parts := strings.Split(t, ",") + for _, part := range parts { + switch strings.TrimSpace(part) { + case "parentheses": + return true + case "no_parentheses": + return false + } + } + return defaultParentheses +} + type sqlBuilder struct{} // sql builds a SQL statement from sqlClauses. func (b *sqlBuilder) sql(clauses ...sqlClause) string { - sList := make([]string, len(clauses)) - for i, c := range clauses { - if c != nil { - sList[i] = c.String() + // remove nil and empty strings + sList := make([]string, 0) + for _, c := range clauses { + if c != nil && c.String() != "" { + sList = append(sList, c.String()) } } @@ -106,6 +129,53 @@ func (b *sqlBuilder) parseStruct(s interface{}) ([]sqlClause, error) { value = value.Elem() } + if value.Kind() == reflect.Slice { + // check if there is any keyword + ddlTag := field.Tag.Get("ddl") + if ddlTag != "" { + ddlTagParts := strings.Split(ddlTag, ",") + ddlType := ddlTagParts[0] + switch ddlType { + case "keyword": + clauses = append(clauses, sqlKeywordClause{ + value: field.Tag.Get("db"), + qt: getQuoteTypeFromTag(field.Tag, "ddl"), + }) + case "list": + listClauses := make([]sqlClause, 0) + // loop through the slice call parseStruct on each element (since the elements could be structs) + for i := 0; i < value.Len(); i++ { + v := value.Index(i).Interface() + // test if v is an ObjectIdentifier. If it is it needs to be handled separately + objectIdentifer, ok := v.(ObjectIdentifier) + if ok { + listClauses = append(listClauses, sqlIdentifierClause{ + value: objectIdentifer, + }) + continue + } + structClauses, err := b.parseStruct(value.Index(i).Interface()) + if err != nil { + return nil, err + } + // each element of the slice needs to be pre-rendered before the commas are added + renderedStructClauses := b.sql(structClauses...) + sClause := sqlStaticClause(renderedStructClauses) + listClauses = append(listClauses, sClause) + } + if len(listClauses) < 1 { + continue + } + clauses = append(clauses, sqlListClause{ + clauses: listClauses, + sep: ",", + useParentheses: getUseParenthesesFromTag(field.Tag, "ddl", true), + keyword: field.Tag.Get("db"), + }) + } + } + } + if value.Kind() == reflect.Struct { // check if there is any keyword on the struct // if there is, then we need to add it to the clause @@ -117,7 +187,7 @@ func (b *sqlBuilder) parseStruct(s interface{}) ([]sqlClause, error) { ddlType := ddlTagParts[0] switch ddlType { case "keyword": - clauses = append(clauses, sqlClauseKeyword{ + clauses = append(clauses, sqlKeywordClause{ value: field.Tag.Get("db"), qt: getQuoteTypeFromTag(field.Tag, "ddl"), }) @@ -125,7 +195,7 @@ func (b *sqlBuilder) parseStruct(s interface{}) ([]sqlClause, error) { if value.Interface().(ObjectIdentifier).FullyQualifiedName() == "" { continue } - clauses = append(clauses, sqlClauseIdentifier{ + clauses = append(clauses, sqlIdentifierClause{ key: field.Tag.Get("db"), value: value.Interface().(ObjectIdentifier), }) @@ -171,7 +241,7 @@ func (b *sqlBuilder) parseField(field reflect.StructField, value reflect.Value) // static must be applied no matter what if ddlTag == "static" { - clauses = append(clauses, sqlClauseStatic(dbTag)) + clauses = append(clauses, sqlStaticClause(dbTag)) return clauses, nil } @@ -184,7 +254,7 @@ func (b *sqlBuilder) parseField(field reflect.StructField, value reflect.Value) if value.Kind() == reflect.Bool { useKeyword := value.Interface().(bool) if useKeyword { - clause = sqlClauseKeyword{ + clause = sqlKeywordClause{ value: dbTag, qt: getQuoteTypeFromTag(field.Tag, "ddl"), } @@ -192,20 +262,24 @@ func (b *sqlBuilder) parseField(field reflect.StructField, value reflect.Value) return nil, nil } } else { - clause = sqlClauseKeyword{ - value: value.Interface().(string), + clause = sqlKeywordClause{ + value: value.Interface(), qt: getQuoteTypeFromTag(field.Tag, "ddl"), } } case "command": - clause = sqlClauseCommand{ + clause = sqlCommandClause{ key: dbTag, value: value.Interface(), qt: getQuoteTypeFromTag(field.Tag, "ddl"), } - + case "identifier": + clause = sqlIdentifierClause{ + key: dbTag, + value: value.Interface().(ObjectIdentifier), + } case "parameter": - clause = sqlClauseParameter{ + clause = sqlParameterClause{ key: dbTag, value: value.Interface(), qt: getQuoteTypeFromTag(field.Tag, "ddl"), @@ -232,58 +306,123 @@ func (b *sqlBuilder) parseUnexportedField(field reflect.StructField, value refle dbTag := field.Tag.Get("db") var clause sqlClause switch ddlType { + case "list": + // if it is a list just get the type and go back to parseStruct + f := b.getUnexportedField(value) + if f == nil { + return nil, nil + } + + listClauses := make([]sqlClause, 0) + // loop through the slice call parseStruct on each element (since the elements could be structs) + for i := 0; i < value.Len(); i++ { + u := b.getUnexportedField(value.Index(i)) + structClauses, err := b.parseStruct(u) + if err != nil { + return nil, err + } + // each element of the slice needs to be pre-rendered before the commas are added + renderedStructClauses := b.sql(structClauses...) + sClause := sqlStaticClause(renderedStructClauses) + listClauses = append(listClauses, sClause) + } + clauses = append(clauses, sqlListClause{ + clauses: listClauses, + sep: ",", + keyword: field.Tag.Get("db"), + useParentheses: getUseParenthesesFromTag(field.Tag, "ddl", true), + }) + return clauses, nil case "identifier": id := b.getUnexportedField(value).(ObjectIdentifier) if id.FullyQualifiedName() != "" { - clause = sqlClauseIdentifier{ + clause = sqlIdentifierClause{ key: dbTag, value: id, } } + case "keyword": + clause = sqlKeywordClause{ + value: b.getUnexportedField(value), + qt: getQuoteTypeFromTag(field.Tag, "ddl"), + } + case "command": + clause = sqlCommandClause{ + key: dbTag, + value: b.getUnexportedField(value), + qt: getQuoteTypeFromTag(field.Tag, "ddl"), + } case "static": - clause = sqlClauseStatic(dbTag) + clause = sqlStaticClause(dbTag) } return append(clauses, clause), nil } +type sqlListClause struct { + keyword string + clauses []sqlClause + sep string + useParentheses bool +} + +func (v sqlListClause) String() string { + var s string + // unclear if we should return parentheses at all. + if len(v.clauses) == 0 { + return s + } + clauseStrings := make([]string, len(v.clauses)) + for i, clause := range v.clauses { + clauseStrings[i] = clause.String() + } + s = strings.Join(clauseStrings, v.sep) + if v.useParentheses { + s = parentheses(s) + } + if v.keyword != "" { + s = fmt.Sprintf("%s %s", v.keyword, s) + } + return s +} + type sqlClause interface { String() string } -type sqlClauseStatic string +type sqlStaticClause string -func (v sqlClauseStatic) String() string { +func (v sqlStaticClause) String() string { return string(v) } -type sqlClauseKeyword struct { - value string +type sqlKeywordClause struct { + value interface{} qt quoteType } -func (v sqlClauseKeyword) String() string { +func (v sqlKeywordClause) String() string { return v.qt.Quote(v.value) } -type sqlClauseIdentifier struct { +type sqlIdentifierClause struct { key string value ObjectIdentifier } -func (v sqlClauseIdentifier) String() string { +func (v sqlIdentifierClause) String() string { if v.key != "" { return fmt.Sprintf("%s %s", v.key, v.value.FullyQualifiedName()) } return v.value.FullyQualifiedName() } -type sqlClauseParameter struct { +type sqlParameterClause struct { key string value interface{} // string list, string, string literal, bool, int qt quoteType } -func (v sqlClauseParameter) String() string { +func (v sqlParameterClause) String() string { vType := reflect.TypeOf(v.value) var result string if v.key != "" { @@ -298,12 +437,12 @@ func (v sqlClauseParameter) String() string { return result } -type sqlClauseCommand struct { +type sqlCommandClause struct { key string value interface{} qt quoteType } -func (v sqlClauseCommand) String() string { +func (v sqlCommandClause) String() string { return fmt.Sprintf("%s %s", v.key, v.qt.Quote(v.value)) } diff --git a/pkg/sdk/sql_builder_test.go b/pkg/sdk/sql_builder_test.go index 4c21b5d766..c472508e5a 100644 --- a/pkg/sdk/sql_builder_test.go +++ b/pkg/sdk/sql_builder_test.go @@ -455,6 +455,49 @@ func TestBuilder_parseStruct(t *testing.T) { assert.Equal(t, "EXAMPLE_PARAMETER = example", clauses[2].String()) assert.Equal(t, "EXAMPLE_COMMAND example", clauses[3].String()) }) + + t.Run("struct with a slice field using ddl: list", func(t *testing.T) { + type testListElement struct { + K *string `ddl:"parameter,single_quotes" db:"KEY"` + K2 *string `ddl:"parameter,single_quotes" db:"KEY2"` + } + s := &struct { + List []testListElement `ddl:"list" db:"TAG"` + }{ + List: []testListElement{{K: String("abc"), K2: String("def")}, {K: String("123"), K2: String("456")}}, + } + clauses, err := builder.parseStruct(s) + assert.NoError(t, err) + assert.Len(t, clauses, 1) + assert.Equal(t, "TAG (KEY = 'abc' KEY2 = 'def',KEY = '123' KEY2 = '456')", clauses[0].String()) + }) + + t.Run("struct with a slice field using ddl: list (no elements)", func(t *testing.T) { + type testListElement struct { + K *string `ddl:"parameter,single_quotes" db:"KEY"` + } + s := &struct { + List []testListElement `ddl:"list"` + }{} + clauses, err := builder.parseStruct(s) + assert.NoError(t, err) + assert.Len(t, clauses, 0) + }) + + t.Run("struct with a slice field using ddl: list (no parentheses)", func(t *testing.T) { + type testListElement struct { + K *string `ddl:"parameter,single_quotes" db:"KEY"` + } + s := &struct { + List []testListElement `ddl:"list,no_parentheses"` + }{ + List: []testListElement{{K: String("abc")}, {K: String("123")}}, + } + clauses, err := builder.parseStruct(s) + assert.NoError(t, err) + assert.Len(t, clauses, 1) + assert.Equal(t, "KEY = 'abc',KEY = '123'", clauses[0].String()) + }) } func TestBuilder_sql(t *testing.T) { @@ -467,8 +510,8 @@ func TestBuilder_sql(t *testing.T) { t.Run("test sql with clauses", func(t *testing.T) { clauses := []sqlClause{ - sqlClauseStatic("EXAMPLE_STATIC"), - sqlClauseParameter{ + sqlStaticClause("EXAMPLE_STATIC"), + sqlParameterClause{ key: "EXAMPLE_KEYWORD", value: "example", }, diff --git a/pkg/sdk/system_functions.go b/pkg/sdk/system_functions.go new file mode 100644 index 0000000000..7998bdf558 --- /dev/null +++ b/pkg/sdk/system_functions.go @@ -0,0 +1,29 @@ +package sdk + +import ( + "context" + "fmt" +) + +type SystemFunctions interface { + GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, typ ObjectType) (string, error) +} + +var _ SystemFunctions = (*systemFunctions)(nil) + +type systemFunctions struct { + client *Client + builder *sqlBuilder +} + +func (c *systemFunctions) GetTag(ctx context.Context, tagID ObjectIdentifier, objectID ObjectIdentifier, objectType ObjectType) (string, error) { + s := &struct { + Tag string `db:"TAG"` + }{} + sql := fmt.Sprintf(`SELECT SYSTEM$GET_TAG('%s', '%s', '%v') AS "TAG"`, tagID.FullyQualifiedName(), objectID.FullyQualifiedName(), objectType) + err := c.client.queryOne(ctx, s, sql) + if err != nil { + return "", err + } + return s.Tag, nil +} diff --git a/pkg/sdk/system_functions_integration_test.go b/pkg/sdk/system_functions_integration_test.go new file mode 100644 index 0000000000..6338267dfc --- /dev/null +++ b/pkg/sdk/system_functions_integration_test.go @@ -0,0 +1,52 @@ +package sdk + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt_GetTag(t *testing.T) { + client := testClient(t) + ctx := context.Background() + databaseTest, databaseCleanup := createDatabase(t, client) + t.Cleanup(databaseCleanup) + + schemaTest, schemaCleanup := createSchema(t, client, databaseTest) + t.Cleanup(schemaCleanup) + + tagTest, tagCleanup := createTag(t, client, databaseTest, schemaTest) + t.Cleanup(tagCleanup) + + t.Run("masking policy tag", func(t *testing.T) { + maskingPolicyTest, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicyCleanup) + + tagValue := randomString(t) + err := client.MaskingPolicies.Alter(ctx, maskingPolicyTest.ID(), &MaskingPolicyAlterOptions{ + Set: &MaskingPolicySet{ + Tag: []TagAssociation{ + { + Name: tagTest.ID(), + Value: tagValue, + }, + }, + }, + }) + require.NoError(t, err) + s, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), maskingPolicyTest.ID(), ObjectTypeMaskingPolicy) + require.NoError(t, err) + assert.Equal(t, tagValue, s) + }) + + t.Run("masking policy with no set tag", func(t *testing.T) { + maskingPolicyTest, maskingPolicyCleanup := createMaskingPolicy(t, client, databaseTest, schemaTest) + t.Cleanup(maskingPolicyCleanup) + + s, err := client.SystemFunctions.GetTag(ctx, tagTest.ID(), maskingPolicyTest.ID(), ObjectTypeMaskingPolicy) + require.Error(t, err) + assert.Equal(t, "", s) + }) +} diff --git a/pkg/sdk/tags.go b/pkg/sdk/tags.go new file mode 100644 index 0000000000..4648e7da61 --- /dev/null +++ b/pkg/sdk/tags.go @@ -0,0 +1,14 @@ +package sdk + +// placeholder for the real implementation. +type TagCreateOptions struct{} + +type Tag struct { + DatabaseName string + SchemaName string + Name string +} + +func (v *Tag) ID() SchemaObjectIdentifier { + return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) +} diff --git a/pkg/sdk/validations.go b/pkg/sdk/validations.go new file mode 100644 index 0000000000..6bbc4fb018 --- /dev/null +++ b/pkg/sdk/validations.go @@ -0,0 +1,6 @@ +package sdk + +func IsValidDataType(v string) bool { + dt := DataTypeFromString(v) + return dt != DataTypeUnknown +} diff --git a/pkg/sdk/validations_test.go b/pkg/sdk/validations_test.go new file mode 100644 index 0000000000..80f71800ed --- /dev/null +++ b/pkg/sdk/validations_test.go @@ -0,0 +1,19 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsValidDataType(t *testing.T) { + t.Run("with valid data type", func(t *testing.T) { + ok := IsValidDataType("VARCHAR") + assert.Equal(t, ok, true) + }) + + t.Run("with invalid data type", func(t *testing.T) { + ok := IsValidDataType("foo") + assert.Equal(t, ok, false) + }) +} diff --git a/pkg/sdk/warehouses.go b/pkg/sdk/warehouses.go new file mode 100644 index 0000000000..a61627148b --- /dev/null +++ b/pkg/sdk/warehouses.go @@ -0,0 +1,82 @@ +package sdk + +import ( + "context" + "fmt" +) + +type Warehouses interface { + // Create creates a warehouse. + Create(ctx context.Context, id AccountObjectIdentifier, opts *WarehouseCreateOptions) error + // Alter modifies an existing warehouse + Alter(ctx context.Context, id AccountObjectIdentifier, opts *WarehouseAlterOptions) error + // Drop removes a warehouse. + Drop(ctx context.Context, id AccountObjectIdentifier, opts *WarehouseDropOptions) error + // Show returns a list of warehouses. + Show(ctx context.Context, opts *WarehouseShowOptions) ([]*Warehouse, error) + // Describe returns the details of a warehouse. + Describe(ctx context.Context, id AccountObjectIdentifier) (*WarehouseDetails, error) +} + +var _ Warehouses = (*warehouses)(nil) + +type warehouses struct { + client *Client + builder *sqlBuilder +} + +type Warehouse struct { + Name string +} + +// placeholder for the real implementation. +type WarehouseCreateOptions struct{} + +func (c *warehouses) Create(ctx context.Context, id AccountObjectIdentifier, _ *WarehouseCreateOptions) error { + sql := fmt.Sprintf(`CREATE WAREHOUSE %s`, id.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} + +// placeholder for the real implementation. +type WarehouseAlterOptions struct{} + +func (c *warehouses) Alter(ctx context.Context, id AccountObjectIdentifier, _ *WarehouseAlterOptions) error { + sql := fmt.Sprintf(`ALTER WAREHOUSE %s`, id.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} + +// placeholder for the real implementation. +type WarehouseDropOptions struct{} + +func (c *warehouses) Drop(ctx context.Context, id AccountObjectIdentifier, _ *WarehouseDropOptions) error { + sql := fmt.Sprintf(`DROP WAREHOUSE %s`, id.FullyQualifiedName()) + _, err := c.client.exec(ctx, sql) + return err +} + +// placeholder for the real implementation. +type WarehouseShowOptions struct{} + +func (c *warehouses) Show(ctx context.Context, _ *WarehouseShowOptions) ([]*Warehouse, error) { + sql := `SHOW WAREHOUSES` + var warehouses []*Warehouse + err := c.client.query(ctx, &warehouses, sql) + return warehouses, err +} + +type WarehouseDetails struct { + Name string +} + +func (c *warehouses) Describe(ctx context.Context, id AccountObjectIdentifier) (*WarehouseDetails, error) { + sql := fmt.Sprintf(`DESCRIBE WAREHOUSE %s`, id.FullyQualifiedName()) + var details WarehouseDetails + err := c.client.queryOne(ctx, &details, sql) + return &details, err +} + +func (v *Warehouse) ID() AccountObjectIdentifier { + return NewAccountObjectIdentifier(v.Name) +} diff --git a/pkg/snowflake/all_grant.go b/pkg/snowflake/all_grant.go index 546069f83a..c2cd0f8917 100644 --- a/pkg/snowflake/all_grant.go +++ b/pkg/snowflake/all_grant.go @@ -209,7 +209,7 @@ func (fgb *AllGrantBuilder) Role(n string) GrantExecutable { } // Share is not implemented because all objects cannot be granted to shares. -func (fgb *AllGrantBuilder) Share(n string) GrantExecutable { +func (fgb *AllGrantBuilder) Share(_ string) GrantExecutable { return nil } diff --git a/pkg/snowflake/future_grant.go b/pkg/snowflake/future_grant.go index 9fd3d7d002..c2adeddb2f 100644 --- a/pkg/snowflake/future_grant.go +++ b/pkg/snowflake/future_grant.go @@ -228,7 +228,7 @@ func (fgb *FutureGrantBuilder) Role(n string) GrantExecutable { } // Share is not implemented because future objects cannot be granted to shares. -func (fgb *FutureGrantBuilder) Share(n string) GrantExecutable { +func (fgb *FutureGrantBuilder) Share(_ string) GrantExecutable { return nil } diff --git a/pkg/snowflake/masking_policy.go b/pkg/snowflake/masking_policy.go index 6137058c8c..0037a6ac7a 100644 --- a/pkg/snowflake/masking_policy.go +++ b/pkg/snowflake/masking_policy.go @@ -1,24 +1,15 @@ package snowflake import ( - "database/sql" - "errors" "fmt" - "log" "strings" - - "github.com/jmoiron/sqlx" ) // MaskingPolicyBuilder abstracts the creation of SQL queries for a Snowflake Masking Policy. type MaskingPolicyBuilder struct { - name string - db string - schema string - comment string - valueDataType string - maskingExpression string - returnDataType string + name string + db string + schema string } // QualifiedName prepends the db and schema if set and escapes everything nicely. @@ -42,30 +33,6 @@ func (mpb *MaskingPolicyBuilder) QualifiedName() string { return n.String() } -// WithComment adds a comment to the MaskingPolicyBuilder. -func (mpb *MaskingPolicyBuilder) WithComment(c string) *MaskingPolicyBuilder { - mpb.comment = EscapeString(c) - return mpb -} - -// WithValueDataType adds valueDataType to the MaskingPolicyBuilder. -func (mpb *MaskingPolicyBuilder) WithValueDataType(dataType string) *MaskingPolicyBuilder { - mpb.valueDataType = dataType - return mpb -} - -// WithMaskingExpression adds maskingExpression to the MaskingPolicyBuilder. -func (mpb *MaskingPolicyBuilder) WithMaskingExpression(maskingExpression string) *MaskingPolicyBuilder { - mpb.maskingExpression = maskingExpression - return mpb -} - -// WithReturnDataType adds returnDataType to the MaskingPolicyBuilder. -func (mpb *MaskingPolicyBuilder) WithReturnDataType(dataType string) *MaskingPolicyBuilder { - mpb.returnDataType = dataType - return mpb -} - // MaskingPolicy returns a pointer to a Builder that abstracts the DDL operations for a masking policy. // // Supported DDL operations are: @@ -83,87 +50,3 @@ func MaskingPolicy(name, db, schema string) *MaskingPolicyBuilder { schema: schema, } } - -// Create returns the SQL query that will create a masking policy. -func (mpb *MaskingPolicyBuilder) Create() string { - q := strings.Builder{} - q.WriteString(fmt.Sprintf(`CREATE MASKING POLICY %v AS (VAL %v) RETURNS %v -> `, mpb.QualifiedName(), mpb.valueDataType, mpb.returnDataType)) - - q.WriteString(mpb.maskingExpression) - - if mpb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(mpb.comment))) - } - - return q.String() -} - -// Describe returns the SQL query that will describe a masking policy. -func (mpb *MaskingPolicyBuilder) Describe() string { - return fmt.Sprintf(`DESCRIBE MASKING POLICY %v`, mpb.QualifiedName()) -} - -// ChangeComment returns the SQL query that will update the comment on the masking policy. -func (mpb *MaskingPolicyBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER MASKING POLICY %v SET COMMENT = '%v'`, mpb.QualifiedName(), EscapeString(c)) -} - -// RemoveComment returns the SQL query that will remove the comment on the masking policy. -func (mpb *MaskingPolicyBuilder) RemoveComment() string { - return fmt.Sprintf(`ALTER MASKING POLICY %v UNSET COMMENT`, mpb.QualifiedName()) -} - -// ChangeMaskingExpression returns the SQL query that will update the masking expression on the masking policy. -func (mpb *MaskingPolicyBuilder) ChangeMaskingExpression(maskingExpression string) string { - q := strings.Builder{} - q.WriteString(fmt.Sprintf(`ALTER MASKING POLICY %v SET BODY -> `, mpb.QualifiedName())) - - q.WriteString(maskingExpression) - - return q.String() -} - -// Drop returns the SQL query that will drop a masking policy. -func (mpb *MaskingPolicyBuilder) Drop() string { - return fmt.Sprintf(`DROP MASKING POLICY %v`, mpb.QualifiedName()) -} - -// Show returns the SQL query that will show a masking policy. -func (mpb *MaskingPolicyBuilder) Show() string { - return fmt.Sprintf(`SHOW MASKING POLICIES LIKE '%v' IN SCHEMA "%v"."%v"`, mpb.name, mpb.db, mpb.schema) -} - -type MaskingPolicyStruct struct { - CreatedOn sql.NullString `db:"created_on"` - Name sql.NullString `db:"name"` - DatabaseName sql.NullString `db:"database_name"` - SchemaName sql.NullString `db:"schema_name"` - Kind sql.NullString `db:"kind"` - Owner sql.NullString `db:"owner"` - Comment sql.NullString `db:"comment"` -} - -func ScanMaskingPolicies(row *sqlx.Row) (*MaskingPolicyStruct, error) { - m := &MaskingPolicyStruct{} - err := row.StructScan(m) - return m, err -} - -func ListMaskingPolicies(databaseName string, schemaName string, db *sql.DB) ([]MaskingPolicyStruct, error) { - stmt := fmt.Sprintf(`SHOW MASKING POLICIES IN SCHEMA "%s"."%v"`, databaseName, schemaName) - rows, err := Query(db, stmt) - if err != nil { - return nil, err - } - defer rows.Close() - - dbs := []MaskingPolicyStruct{} - if err := sqlx.StructScan(rows, &dbs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - log.Println("[DEBUG] no masking policies found") - return nil, nil - } - return nil, fmt.Errorf("unable to scan row for %s err = %w", stmt, err) - } - return dbs, nil -} diff --git a/pkg/snowflake/masking_policy_test.go b/pkg/snowflake/masking_policy_test.go deleted file mode 100644 index 91026f51cf..0000000000 --- a/pkg/snowflake/masking_policy_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package snowflake_test - -import ( - "testing" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" - "github.com/stretchr/testify/require" -) - -func TestMaskingPolicyCreate(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - m.WithValueDataType("string") - m.WithMaskingExpression(`case - when current_role() in ('ANALYST') then val - else '*********' -end`) - m.WithReturnDataType("string") - m.WithComment("This is a test comment") - - q := m.Create() - r.Equal(`CREATE MASKING POLICY "test_db"."test_schema"."test_masking_policy" AS (VAL string) RETURNS string -> case - when current_role() in ('ANALYST') then val - else '*********' -end COMMENT = 'This is a test comment'`, q) -} - -func TestMaskingPolicyDescribe(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - q := m.Describe() - r.Equal(`DESCRIBE MASKING POLICY "test_db"."test_schema"."test_masking_policy"`, q) -} - -func TestMaskingPolicyDrop(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - q := m.Drop() - r.Equal(`DROP MASKING POLICY "test_db"."test_schema"."test_masking_policy"`, q) -} - -func TestMaskingPolicyChangeComment(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - q := m.ChangeComment("test comment!") - r.Equal(`ALTER MASKING POLICY "test_db"."test_schema"."test_masking_policy" SET COMMENT = 'test comment!'`, q) -} - -func TestMaskingPolicyRemoveComment(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - q := m.RemoveComment() - r.Equal(`ALTER MASKING POLICY "test_db"."test_schema"."test_masking_policy" UNSET COMMENT`, q) -} - -func TestMaskingChangeMaskingExpression(t *testing.T) { - r := require.New(t) - m := snowflake.MaskingPolicy("test_masking_policy", "test_db", "test_schema") - r.NotNil(m) - - q := m.ChangeMaskingExpression(`case - when current_role() in ('ANALYST') then val - else sha2(val, 512) -end`) - - r.Equal(`ALTER MASKING POLICY "test_db"."test_schema"."test_masking_policy" SET BODY -> case - when current_role() in ('ANALYST') then val - else sha2(val, 512) -end`, q) -} diff --git a/pkg/snowflake/parser.go b/pkg/snowflake/parser.go index a326603b74..bdd429886e 100644 --- a/pkg/snowflake/parser.go +++ b/pkg/snowflake/parser.go @@ -39,7 +39,7 @@ func (e *ViewSelectStatementExtractor) Extract() (string, error) { e.consumeSpace() e.consumeToken("if not exists") e.consumeSpace() - e.consumeIdentifier() + e.consumeID() // TODO column list e.consumeSpace() e.consumeToken("copy grants") @@ -70,7 +70,7 @@ func (e *ViewSelectStatementExtractor) ExtractMaterializedView() (string, error) e.consumeSpace() e.consumeToken("if not exists") e.consumeSpace() - e.consumeIdentifier() + e.consumeID() // TODO copy grants // TODO column list e.consumeComment() @@ -118,7 +118,7 @@ func (e *ViewSelectStatementExtractor) consumeSpace() { e.pos += found } -func (e *ViewSelectStatementExtractor) consumeIdentifier() { +func (e *ViewSelectStatementExtractor) consumeID() { e.consumeNonSpace() } diff --git a/pkg/snowflake/role_ownership_grant_test.go b/pkg/snowflake/role_ownership_grant_test.go index 85fb44be2f..f932bfd4bd 100644 --- a/pkg/snowflake/role_ownership_grant_test.go +++ b/pkg/snowflake/role_ownership_grant_test.go @@ -9,18 +9,18 @@ import ( func TestRoleOwnershipGrantQuery(t *testing.T) { r := require.New(t) - copy := snowflake.NewRoleOwnershipGrantBuilder("role1", "COPY") - revoke := snowflake.NewRoleOwnershipGrantBuilder("role1", "REVOKE") + copyBuilder := snowflake.NewRoleOwnershipGrantBuilder("role1", "COPY") + revokeBuilder := snowflake.NewRoleOwnershipGrantBuilder("role1", "REVOKE") - g1 := copy.Role("role2").Grant() + g1 := copyBuilder.Role("role2").Grant() r.Equal(`GRANT OWNERSHIP ON ROLE "role1" TO ROLE "role2" COPY CURRENT GRANTS`, g1) - r1 := copy.Role("ACCOUNTADMIN").Revoke() + r1 := copyBuilder.Role("ACCOUNTADMIN").Revoke() r.Equal(`GRANT OWNERSHIP ON ROLE "role1" TO ROLE "ACCOUNTADMIN" COPY CURRENT GRANTS`, r1) - g2 := revoke.Role("role2").Grant() + g2 := revokeBuilder.Role("role2").Grant() r.Equal(`GRANT OWNERSHIP ON ROLE "role1" TO ROLE "role2" REVOKE CURRENT GRANTS`, g2) - r2 := revoke.Role("ACCOUNTADMIN").Revoke() + r2 := revokeBuilder.Role("ACCOUNTADMIN").Revoke() r.Equal(`GRANT OWNERSHIP ON ROLE "role1" TO ROLE "ACCOUNTADMIN" REVOKE CURRENT GRANTS`, r2) } diff --git a/pkg/snowflake/user_ownership_grant_test.go b/pkg/snowflake/user_ownership_grant_test.go index bcf58974d3..3750d45c6d 100644 --- a/pkg/snowflake/user_ownership_grant_test.go +++ b/pkg/snowflake/user_ownership_grant_test.go @@ -9,18 +9,18 @@ import ( func TestUserOwnershipGrantQuery(t *testing.T) { r := require.New(t) - copy := snowflake.NewUserOwnershipGrantBuilder("user1", "COPY") - revoke := snowflake.NewUserOwnershipGrantBuilder("user1", "REVOKE") + copyBuilder := snowflake.NewUserOwnershipGrantBuilder("user1", "COPY") + revokeBuilder := snowflake.NewUserOwnershipGrantBuilder("user1", "REVOKE") - g1 := copy.Role("role1").Grant() + g1 := copyBuilder.Role("role1").Grant() r.Equal(`GRANT OWNERSHIP ON USER "user1" TO ROLE "role1" COPY CURRENT GRANTS`, g1) - r1 := copy.Role("ACCOUNTADMIN").Revoke() + r1 := copyBuilder.Role("ACCOUNTADMIN").Revoke() r.Equal(`GRANT OWNERSHIP ON USER "user1" TO ROLE "ACCOUNTADMIN" COPY CURRENT GRANTS`, r1) - g2 := revoke.Role("role1").Grant() + g2 := revokeBuilder.Role("role1").Grant() r.Equal(`GRANT OWNERSHIP ON USER "user1" TO ROLE "role1" REVOKE CURRENT GRANTS`, g2) - r2 := revoke.Role("ACCOUNTADMIN").Revoke() + r2 := revokeBuilder.Role("ACCOUNTADMIN").Revoke() r.Equal(`GRANT OWNERSHIP ON USER "user1" TO ROLE "ACCOUNTADMIN" REVOKE CURRENT GRANTS`, r2) } diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index 670ac304d5..25756ad3b6 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -137,7 +137,7 @@ func ValidateAdminName(i interface{}, k string) (s []string, errors []error) { return } -func ValidateFullyQualifiedObjectID(i interface{}, k string) (s []string, errors []error) { +func ValidateFullyQualifiedObjectID(i interface{}, _ string) (s []string, errors []error) { v, _ := i.(string) if strings.Contains(v, ".") { //nolint:gocritic // todo: please fix this tagArray := strings.Split(v, ".") diff --git a/templates/index.md.tmpl b/templates/index.md.tmpl index 1d7eae13cb..c0ef6dd459 100644 --- a/templates/index.md.tmpl +++ b/templates/index.md.tmpl @@ -28,8 +28,9 @@ The Snowflake provider support multiple ways to authenticate: * OAuth Refresh Token * Browser Auth * Private Key +* Config File -In all cases account and region are required. +In all cases account and username are required. ### Keypair Authentication Environment Variables @@ -105,3 +106,28 @@ If you choose to use Username and Password Authentication, export these credenti export SNOWFLAKE_USER='...' export SNOWFLAKE_PASSWORD='...' ``` + +### Config File + +If you choose to use a config file, the optional `profile` attribute specifies the profile to use from the config file. If no profile is specified, the default profile is used. The Snowflake config file lives at `~/.snowflake/config` and uses [TOML](https://toml.io/) format. You can override this location by setting the `SNOWFLAKE_CONFIG_PATH` environment variable. If no username and account are specified, the provider will fall back to reading the config file. + +```shell +[default] +account='TESTACCOUNT' +user='TEST_USER' +password='hunter2' +role='ACCOUNTADMIN' + +[securityadmin] +account='TESTACCOUNT' +user='TEST_USER' +password='hunter2' +role='SECURITYADMIN' +``` + +## Order Precedence + +The Snowflake provider will use the following order of precedence when determining which credentials to use: +1) Provider Configuration +2) Environment Variables +3) Config File