Skip to content

Commit

Permalink
Added Role-based access control integration tests for Spanner Change …
Browse files Browse the repository at this point in the history
…Streams (#25246)
  • Loading branch information
nuggetwheat committed Feb 7, 2023
1 parent bf5114b commit a9e80d2
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ public interface ChangeStreamTestPipelineOptions extends IOTestPipelineOptions,
void setInstanceId(String value);

@Description("Database ID prefix to write to in Spanner")
@Default.String("changestream")
@Default.String("cstest_primary")
String getDatabaseId();

void setDatabaseId(String value);

@Description("Metadata database ID prefix to write to in Spanner")
@Default.String("cstest_metadata")
String getMetadataDatabaseId();

void setMetadataDatabaseId(String value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
Expand All @@ -48,18 +49,21 @@ public class IntegrationTestEnv extends ExternalResource {
private static final String METADATA_TABLE_NAME_PREFIX = "TestMetadata";
private static final String SINGERS_TABLE_NAME_PREFIX = "Singers";
private static final String CHANGE_STREAM_NAME_PREFIX = "SingersStream";
private static final String DATABASE_ROLE = "test_role";
private List<String> changeStreams;
private List<String> tables;

private String projectId;
private String instanceId;
private String databaseId;
private String metadataDatabaseId;
private String metadataTableName;
private Spanner spanner;
private final String host = "https://spanner.googleapis.com";
private DatabaseAdminClient databaseAdminClient;
private DatabaseClient databaseClient;
private boolean isPostgres;
public boolean useSeparateMetadataDb;

@Override
protected void before() throws Throwable {
Expand All @@ -70,14 +74,13 @@ protected void before() throws Throwable {
Optional.ofNullable(options.getProjectId())
.orElseGet(() -> options.as(GcpOptions.class).getProject());
instanceId = options.getInstanceId();
databaseId = generateDatabaseName(options.getDatabaseId());
generateDatabaseIds(options);
spanner =
SpannerOptions.newBuilder().setProjectId(projectId).setHost(host).build().getService();
databaseAdminClient = spanner.getDatabaseAdminClient();
metadataTableName = generateTableName(METADATA_TABLE_NAME_PREFIX);

recreateDatabase(databaseAdminClient, instanceId, databaseId, isPostgres);

databaseClient = spanner.getDatabaseClient(DatabaseId.of(projectId, instanceId, databaseId));

changeStreams = new ArrayList<>();
Expand Down Expand Up @@ -144,10 +147,17 @@ protected void after() {
} catch (Exception e) {
LOG.error("Failed to drop database " + databaseId + ". Skipping...", e);
}

if (useSeparateMetadataDb) {
databaseAdminClient.dropDatabase(instanceId, metadataDatabaseId);
}
spanner.close();
}

void createMetadataDatabase() throws ExecutionException, InterruptedException, TimeoutException {
recreateDatabase(databaseAdminClient, instanceId, metadataDatabaseId, isPostgres);
useSeparateMetadataDb = true;
}

String createSingersTable() throws InterruptedException, ExecutionException, TimeoutException {
final String tableName = generateTableName(SINGERS_TABLE_NAME_PREFIX);
LOG.info("Creating table " + tableName);
Expand All @@ -168,7 +178,6 @@ String createSingersTable() throws InterruptedException, ExecutionException, Tim
+ ")"),
null)
.get(TIMEOUT_MINUTES, TimeUnit.MINUTES);
tables.add(tableName);
} else {
databaseAdminClient
.updateDatabaseDdl(
Expand All @@ -185,8 +194,8 @@ String createSingersTable() throws InterruptedException, ExecutionException, Tim
+ " ) PRIMARY KEY (SingerId)"),
null)
.get(TIMEOUT_MINUTES, TimeUnit.MINUTES);
tables.add(tableName);
}
tables.add(tableName);
return tableName;
}

Expand Down Expand Up @@ -214,10 +223,32 @@ String createChangeStreamFor(String tableName)
.get(TIMEOUT_MINUTES, TimeUnit.MINUTES);
}
changeStreams.add(changeStreamName);

return changeStreamName;
}

void createRoleAndGrantPrivileges(String table, String changeStream)
throws InterruptedException, ExecutionException, TimeoutException {
if (this.isPostgres) {
LOG.error("Database roles not supported with Postgres dialect.");
return;
}
databaseAdminClient
.updateDatabaseDdl(
instanceId,
databaseId,
Arrays.asList(
"CREATE ROLE " + DATABASE_ROLE,
"GRANT INSERT, UPDATE, DELETE ON TABLE " + table + " TO ROLE " + DATABASE_ROLE,
"GRANT SELECT ON CHANGE STREAM " + changeStream + " TO ROLE " + DATABASE_ROLE,
"GRANT EXECUTE ON TABLE FUNCTION READ_"
+ changeStream
+ " TO ROLE "
+ DATABASE_ROLE),
null)
.get(TIMEOUT_MINUTES, TimeUnit.MINUTES);
return;
}

String getProjectId() {
return projectId;
}
Expand All @@ -230,6 +261,14 @@ String getDatabaseId() {
return databaseId;
}

String getMetadataDatabaseId() {
return metadataDatabaseId;
}

String getDatabaseRole() {
return DATABASE_ROLE;
}

String getMetadataTableName() {
return metadataTableName;
}
Expand Down Expand Up @@ -282,10 +321,13 @@ private String generateChangeStreamName() {
MAX_CHANGE_STREAM_NAME_LENGTH - 1 - CHANGE_STREAM_NAME_PREFIX.length());
}

private String generateDatabaseName(String prefix) {
return prefix
+ "_"
+ RandomStringUtils.randomAlphanumeric(MAX_DATABASE_NAME_LENGTH - 1 - prefix.length())
private void generateDatabaseIds(ChangeStreamTestPipelineOptions options) {
int prefixLength =
Math.max(options.getDatabaseId().length(), options.getMetadataDatabaseId().length());
String suffix =
RandomStringUtils.randomAlphanumeric(MAX_DATABASE_NAME_LENGTH - 1 - prefixLength)
.toLowerCase(Locale.ROOT);
databaseId = options.getDatabaseId() + "_" + suffix;
metadataDatabaseId = options.getMetadataDatabaseId() + "_" + suffix;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

import com.google.cloud.Timestamp;
import com.google.cloud.spanner.DatabaseClient;
Expand All @@ -38,10 +39,12 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.beam.runners.direct.DirectRunner;
import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.DataChangeRecord;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.model.Mod;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.DoFn;
Expand All @@ -56,6 +59,7 @@
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

Expand All @@ -64,8 +68,12 @@
public class SpannerChangeStreamIT {

@ClassRule public static final IntegrationTestEnv ENV = new IntegrationTestEnv();

@Rule public final transient TestPipeline pipeline = TestPipeline.create();

/** Rule for exception testing. */
@Rule public ExpectedException exception = ExpectedException.none();

private static String instanceId;
private static String projectId;
private static String databaseId;
Expand All @@ -83,6 +91,8 @@ public static void beforeClass() throws Exception {
changeStreamTableName = ENV.createSingersTable();
changeStreamName = ENV.createChangeStreamFor(changeStreamTableName);
databaseClient = ENV.getDatabaseClient();
ENV.createMetadataDatabase();
ENV.createRoleAndGrantPrivileges(changeStreamTableName, changeStreamName);
}

@Before
Expand All @@ -93,6 +103,23 @@ public void before() {

@Test
public void testReadSpannerChangeStream() {
testReadSpannerChangeStreamImpl(pipeline, null);
}

@Test
public void testReadSpannerChangeStreamWithAuthorizedRole() {
testReadSpannerChangeStreamImpl(pipeline, ENV.getDatabaseRole());
}

@Test
public void testReadSpannerChangeStreamWithUnauthorizedRole() {
assumeTrue(pipeline.getOptions().getRunner() == DirectRunner.class);
exception.expect(SpannerException.class);
exception.expectMessage("Role not found: bad_role.");
testReadSpannerChangeStreamImpl(pipeline.enableAbandonedNodeEnforcement(false), "bad_role");
}

public void testReadSpannerChangeStreamImpl(TestPipeline testPipeline, String role) {
// Defines how many rows are going to be inserted / updated / deleted in the test
final int numRows = 5;
// Inserts numRows rows and uses the first commit timestamp as the startAt for reading the
Expand All @@ -106,19 +133,22 @@ public void testReadSpannerChangeStream() {
final Pair<Timestamp, Timestamp> deleteTimestamps = deleteRows(numRows);
final Timestamp endAt = deleteTimestamps.getRight();

final SpannerConfig spannerConfig =
SpannerConfig spannerConfig =
SpannerConfig.create()
.withProjectId(projectId)
.withInstanceId(instanceId)
.withDatabaseId(databaseId);
if (role != null) {
spannerConfig = spannerConfig.withDatabaseRole(StaticValueProvider.of(role));
}

final PCollection<String> tokens =
pipeline
testPipeline
.apply(
SpannerIO.readChangeStream()
.withSpannerConfig(spannerConfig)
.withChangeStreamName(changeStreamName)
.withMetadataDatabase(databaseId)
.withMetadataDatabase(ENV.getMetadataDatabaseId())
.withMetadataTable(metadataTableName)
.withInclusiveStartAt(startAt)
.withInclusiveEndAt(endAt))
Expand All @@ -143,7 +173,7 @@ public void testReadSpannerChangeStream() {
"DELETE,3,Updated First Name 3,Updated Last Name 3,null,null",
"DELETE,4,Updated First Name 4,Updated Last Name 4,null,null",
"DELETE,5,Updated First Name 5,Updated Last Name 5,null,null");
pipeline.run().waitUntilFinish();
testPipeline.run().waitUntilFinish();

assertMetadataTableHasBeenDropped();
}
Expand Down Expand Up @@ -176,7 +206,7 @@ public void testReadSpannerChangeStreamFilteredByTransactionTag() {
SpannerIO.readChangeStream()
.withSpannerConfig(spannerConfig)
.withChangeStreamName(changeStreamName)
.withMetadataDatabase(databaseId)
.withMetadataDatabase(ENV.getMetadataDatabaseId())
.withMetadataTable(metadataTableName)
.withInclusiveStartAt(startAt)
.withInclusiveEndAt(endAt))
Expand Down

0 comments on commit a9e80d2

Please sign in to comment.