Skip to content

Commit

Permalink
feat(database): add SQL query execution support #171
Browse files Browse the repository at this point in the history
Added `executeSqlQuery` method in `DatabaseSchemaAssistant` to execute SQL queries. Introduced `Execute` enum value in `DatabaseFunctionProvider` to handle SQL query requests. This extends the database tool's functionality to directly execute SQL statements.
  • Loading branch information
phodal committed Dec 30, 2024
1 parent e7dbb02 commit 660eb67
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 85 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package com.phodal.shire.database

import com.intellij.database.console.JdbcConsoleProvider
import com.intellij.database.model.DasTable
import com.intellij.database.model.ObjectKind
import com.intellij.database.model.RawDataSource
import com.intellij.database.psi.DbDataSource
import com.intellij.database.psi.DbPsiFacade
import com.intellij.database.settings.DatabaseSettings
import com.intellij.database.util.DasUtil
import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiManager
import com.intellij.sql.psi.SqlPsiFacade
import com.intellij.testFramework.LightVirtualFile

object DatabaseSchemaAssistant {
fun getDataSources(project: Project): List<DbDataSource> {
Expand Down Expand Up @@ -43,6 +49,39 @@ object DatabaseSchemaAssistant {
return dasTables.filter { it.name == tableName }.toList()
}

fun executeSqlQuery(project: Project, sql: String): List<Map<String, Any?>> {
val file = LightVirtualFile("temp.sql", sql)
val psiFile = PsiManager.getInstance(project).findFile(file)
?: throw IllegalArgumentException("ShireError[Database]: No file found")

val fileEditor = FileEditorManager.getInstance(project).openFile(file).firstOrNull()
?: throw IllegalArgumentException("ShireError[Database]: No editor found")

val editor = FileEditorManager.getInstance(project).selectedTextEditor
?: throw IllegalArgumentException("ShireError[Database]: No editor found")

val dataSource = getAllRawDatasource(project).firstOrNull()
?: throw IllegalArgumentException("ShireError[Database]: No database found")

// val activeConnections = DatabaseConnectionManager.getInstance().activeConnections
// val first: DatabaseConnection = activeConnections.firstOrNull()

val execOptions = DatabaseSettings.getSettings().execOptions.last()
val console = JdbcConsoleProvider.getValidConsole(project, file)
// val elementAt = JdbcConsoleProvider.elementAt(psiFile, null, editor)
// JdbcConsoleProvider.findScriptModel(psiFile, elementAt, editor, execOption)
// val dbSession = JdbcConsoleProvider.findOrCreateSession(project, file)
// val info = JdbcConsoleProvider.Info()
val scriptModel =
if (console != null) console.scriptModel else SqlPsiFacade.getInstance(project).createScriptModel(psiFile)
// val m = ScriptModelUtil.adjustModelForSelection(model, document, selectionRange, execOption)
// JdbcConsoleProvider.Info(file, file, editor as EditorEx?, m, execOption, null as NotNullFunction<*, *>?)

// JdbcConsoleProvider.doRunQueryInConsole(console, info)
console!!.executeQueries(editor, scriptModel, execOptions)
return emptyList()
}

private fun isSQLiteTable(
rawDataSource: RawDataSource,
table: DasTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.phodal.shirecore.provider.function.ToolchainFunctionProvider
enum class DatabaseFunction(val funName: String) {
Table("table"),
Column("column"),

Execute("execute")
;

companion object {
Expand All @@ -34,114 +34,127 @@ class DatabaseFunctionProvider : ToolchainFunctionProvider {
val databaseFunction = DatabaseFunction.fromString(funcName)
?: throw IllegalArgumentException("Shire[Database]: Invalid Database function name")

when (databaseFunction) {
DatabaseFunction.Table -> {
if (args.isEmpty()) {
val dataSource = DatabaseSchemaAssistant.getAllRawDatasource(project).firstOrNull()
?: return "ShireError[Database]: No database found"
return DatabaseSchemaAssistant.getTableByDataSource(dataSource)
return when (databaseFunction) {
DatabaseFunction.Table -> executeTableFunction(args, project)
DatabaseFunction.Column -> executeColumnFunction(args, project)
DatabaseFunction.Execute -> executeSqlFunction(args, project)
}
}

private fun executeTableFunction(args: List<Any>, project: Project): Any {
if (args.isEmpty()) {
val dataSource = DatabaseSchemaAssistant.getAllRawDatasource(project).firstOrNull()
?: return "ShireError[Database]: No database found"
return DatabaseSchemaAssistant.getTableByDataSource(dataSource)
}

val dbName = args.first()
// for example: [accounts, payment_limits, transactions]
var result = mutableListOf<DasTable>()
when (dbName) {
is String -> {
if (dbName.startsWith("[") && dbName.endsWith("]")) {
val tableNames = dbName.substring(1, dbName.length - 1).split(",")
result = tableNames.map {
getTable(project, it.trim())
}.flatten().toMutableList()
} else {
result = getTable(project, dbName).toMutableList()
}
}

val dbName = args.first()
// for example: [accounts, payment_limits, transactions]
var result = mutableListOf<DasTable>()
when (dbName) {
is String -> {
if (dbName.startsWith("[") && dbName.endsWith("]")) {
val tableNames = dbName.substring(1, dbName.length - 1).split(",")
result = tableNames.map {
getTable(project, it.trim())
}.flatten().toMutableList()
} else {
result = getTable(project, dbName).toMutableList()
}
}
is List<*> -> {
result = dbName.map {
getTable(project, it as String)
}.flatten().toMutableList()
}

is List<*> -> {
result = dbName.map {
getTable(project, it as String)
}.flatten().toMutableList()
}
else -> {

else -> {
}
}

}
}
return result
}

private fun executeSqlFunction(args: List<Any>, project: Project): Any {
if (args.isEmpty()) {
return "ShireError[DBTool]: SQL function requires a SQL query"
}

val sqlQuery = args.first()
return DatabaseSchemaAssistant.executeSqlQuery(project, sqlQuery as String)
}

return result
private fun executeColumnFunction(args: List<Any>, project: Project): Any {
if (args.isEmpty()) {
val allTables = DatabaseSchemaAssistant.getAllTables(project)
return allTables.map {
DatabaseSchemaAssistant.getTableColumn(it)
}
}

DatabaseFunction.Column -> {
if (args.isEmpty()) {
val allTables = DatabaseSchemaAssistant.getAllTables(project)
return allTables.map {
DatabaseSchemaAssistant.getTableColumn(it)
}
when (val first = args[0]) {
is RawDataSource -> {
return if (args.size == 1) {
DatabaseSchemaAssistant.getTableByDataSource(first)
} else {
DatabaseSchemaAssistant.getTable(first, args[1] as String)
}
}

when (val first = args[0]) {
is DasTable -> {
return DatabaseSchemaAssistant.getTableColumn(first)
}

is List<*> -> {
return when (first.first()) {
is RawDataSource -> {
return if (args.size == 1) {
DatabaseSchemaAssistant.getTableByDataSource(first)
} else {
DatabaseSchemaAssistant.getTable(first, args[1] as String)
return first.map {
DatabaseSchemaAssistant.getTableByDataSource(it as RawDataSource)
}
}

is DasTable -> {
return DatabaseSchemaAssistant.getTableColumn(first)
return first.map {
DatabaseSchemaAssistant.getTableColumn(it as DasTable)
}
}

is List<*> -> {
return when (first.first()) {
is RawDataSource -> {
return first.map {
DatabaseSchemaAssistant.getTableByDataSource(it as RawDataSource)
}
}

is DasTable -> {
return first.map {
DatabaseSchemaAssistant.getTableColumn(it as DasTable)
}
}

else -> {
"ShireError[DBTool]: Table function requires a data source or a list of table names"
}
}
else -> {
"ShireError[DBTool]: Table function requires a data source or a list of table names"
}
}
}

is String -> {
val allTables = DatabaseSchemaAssistant.getAllTables(project)
if (first.startsWith("[") && first.endsWith("]")) {
val tableNames = first.substring(1, first.length - 1).split(",")
return tableNames.mapNotNull {
val dasTable = allTables.firstOrNull { table ->
table.name == it.trim()
}

dasTable?.let {
DatabaseSchemaAssistant.getTableColumn(it)
}
}
} else {
val dasTable = allTables.firstOrNull { table ->
table.name == first
}

return dasTable?.let {
DatabaseSchemaAssistant.getTableColumn(it)
} ?: "ShireError[DBTool]: Table not found"
is String -> {
val allTables = DatabaseSchemaAssistant.getAllTables(project)
if (first.startsWith("[") && first.endsWith("]")) {
val tableNames = first.substring(1, first.length - 1).split(",")
return tableNames.mapNotNull {
val dasTable = allTables.firstOrNull { table ->
table.name == it.trim()
}
}

else -> {
logger<DatabaseFunctionProvider>().error("ShireError[DBTool] args types: ${first.javaClass}")
return "ShireError[DBTool]: Table function requires a data source or a list of table names"
dasTable?.let {
DatabaseSchemaAssistant.getTableColumn(it)
}
}
} else {
val dasTable = allTables.firstOrNull { table ->
table.name == first
}

return dasTable?.let {
DatabaseSchemaAssistant.getTableColumn(it)
} ?: "ShireError[DBTool]: Table not found"
}
}

else -> {
logger<DatabaseFunctionProvider>().error("ShireError[DBTool] args types: ${first.javaClass}")
return "ShireError[DBTool]: Table function requires a data source or a list of table names"
}
}
}

Expand Down

0 comments on commit 660eb67

Please sign in to comment.