Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConnectionContext and ShardingSphereMetaData in QueryContext, call setCurrentDatabaseName in ConnectionSession #31971

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.agent.plugin.metrics.core.fixture.collector.MetricsCollectorFixture;
import org.apache.shardingsphere.infra.binder.context.statement.UnknownSQLStatementContext;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLDeleteStatement;
Expand Down Expand Up @@ -53,25 +54,29 @@ void reset() {

@Test
void assertInsertRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLInsertStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLInsertStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "INSERT=1");
}

@Test
void assertUpdateRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLUpdateStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLUpdateStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "UPDATE=1");
}

@Test
void assertDeleteRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLDeleteStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLDeleteStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "DELETE=1");
}

@Test
void assertSelectRoute() {
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLSelectStatement()), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(new UnknownSQLStatementContext(new MySQLSelectStatement()), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class),
mock(ShardingSphereMetaData.class));
assertRoute(queryContext, "SELECT=1");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
Expand Down Expand Up @@ -132,7 +133,8 @@ private ShardingSphereDatabase mockSingleDatabase() {
private QueryContext createQueryContext() {
CreateTableStatement createTableStatement = new MySQLCreateTableStatement(false);
createTableStatement.setTable(new SimpleTableSegment(new TableNameSegment(1, 2, new IdentifierValue("t_order"))));
return new QueryContext(new CreateTableStatementContext(createTableStatement, DefaultDatabase.LOGIC_NAME), "CREATE TABLE", new LinkedList<>(), new HintValueContext());
return new QueryContext(new CreateTableStatementContext(createTableStatement, DefaultDatabase.LOGIC_NAME), "CREATE TABLE", new LinkedList<>(), new HintValueContext(),
mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
}

private Map<String, DataSource> createMultiDataSourceMap() throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.instance.ComputeNodeInstanceContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
Expand Down Expand Up @@ -84,7 +85,8 @@ void setUp() {
@Test
void assertDecorateRouteContextToPrimaryDataSource() {
RouteContext actual = mockRouteContext();
QueryContext queryContext = new QueryContext(mock(SQLStatementContext.class), "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext =
new QueryContext(mock(SQLStatementContext.class), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand All @@ -100,7 +102,7 @@ void assertDecorateRouteContextToReplicaDataSource() {
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(selectStatement.getLock()).thenReturn(Optional.empty());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand All @@ -116,7 +118,7 @@ void assertDecorateRouteContextToPrimaryDataSourceWithLock() {
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(selectStatement.getLock()).thenReturn(Optional.of(mock(LockSegment.class)));
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.shardingsphere.shadow.route.engine;

import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.DeleteStatementContext;
Expand Down Expand Up @@ -46,13 +48,17 @@ class ShadowRouteEngineFactoryTest {

@Test
void assertNewInstance() {
ShadowRouteEngine shadowInsertRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createInsertSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowInsertRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createInsertSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowInsertRouteEngine, instanceOf(ShadowInsertStatementRoutingEngine.class));
ShadowRouteEngine shadowUpdateRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createUpdateSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowUpdateRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createUpdateSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowUpdateRouteEngine, instanceOf(ShadowUpdateStatementRoutingEngine.class));
ShadowRouteEngine shadowDeleteRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createDeleteSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowDeleteRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createDeleteSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowDeleteRouteEngine, instanceOf(ShadowDeleteStatementRoutingEngine.class));
ShadowRouteEngine shadowSelectRouteEngine = ShadowRouteEngineFactory.newInstance(new QueryContext(createSelectSqlStatementContext(), "", Collections.emptyList(), new HintValueContext()));
ShadowRouteEngine shadowSelectRouteEngine = ShadowRouteEngineFactory.newInstance(
new QueryContext(createSelectSqlStatementContext(), "", Collections.emptyList(), new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)));
assertThat(shadowSelectRouteEngine, instanceOf(ShadowSelectStatementRoutingEngine.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sharding.api.config.strategy.audit.ShardingAuditStrategyConfiguration;
import org.apache.shardingsphere.sharding.exception.audit.DMLWithoutShardingKeyException;
Expand Down Expand Up @@ -79,15 +81,17 @@ void setUp() {
@Test
void assertCheckSuccess() {
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)),
globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}

@Test
void assertCheckSuccessByDisableAuditNames() {
when(auditStrategy.isAllowHintDisable()).thenReturn(true);
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)),
globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1"), times(0)).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}

Expand All @@ -97,7 +101,8 @@ void assertCheckFailed() {
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
doThrow(new DMLWithoutShardingKeyException()).when(auditAlgorithm).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
DMLWithoutShardingKeyException ex = assertThrows(DMLWithoutShardingKeyException.class, () -> new ShardingSQLAuditor().audit(
new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), globalRuleMetaData, databases.get("foo_db"), rule));
new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext, mock(ConnectionContext.class), mock(ShardingSphereMetaData.class)), globalRuleMetaData,
databases.get("foo_db"), rule));
assertThat(ex.getMessage(), is("Not allow DML operation without sharding conditions."));
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), globalRuleMetaData, databases.get("foo_db"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
import org.apache.shardingsphere.infra.parser.sql.SQLStatementParserEngine;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sharding.api.config.ShardingRuleConfiguration;
Expand Down Expand Up @@ -145,7 +146,7 @@ private ShardingSphereDatabase createDatabase(final ShardingRule shardingRule, f

private QueryContext createQueryContext(final ShardingSphereDatabase database, final String sql, final List<Object> params) {
SQLStatementContext sqlStatementContext = new SQLBindEngine(createShardingSphereMetaData(database), DATABASE_NAME, new HintValueContext()).bind(parse(sql), params);
return new QueryContext(sqlStatementContext, sql, params, new HintValueContext());
return new QueryContext(sqlStatementContext, sql, params, new HintValueContext(), mock(ConnectionContext.class), mock(ShardingSphereMetaData.class));
}

private ShardingSphereMetaData createShardingSphereMetaData(final ShardingSphereDatabase database) {
Expand Down
Loading