Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dummy transaction from extension #4239

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions extension/delta/src/main/delta_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

#include "function/delta_scan.h"
#include "main/client_context.h"
#include "main/database.h"
#include "s3_download_options.h"

namespace kuzu {
namespace delta_extension {

void DeltaExtension::load(main::ClientContext* context) {
auto& db = *context->getDatabase();
extension::ExtensionUtils::addTableFunc<DeltaScanFunction>(db);
extension::ExtensionUtils::addTableFunc<DeltaScanFunction>(context->getTransaction(), db);
httpfs::S3DownloadOptions::registerExtensionOptions(&db);
httpfs::S3DownloadOptions::setEnvValue(context);
}
Expand Down
9 changes: 5 additions & 4 deletions extension/duckdb/src/catalog/duckdb_catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ DuckDBCatalog::DuckDBCatalog(std::string dbPath, std::string catalogName,
}
}

void DuckDBCatalog::init() {
void DuckDBCatalog::init(transaction::Transaction* transaction) {
auto query = common::stringFormat(
"select table_name from information_schema.tables where table_catalog = '{}' and "
"table_schema = '{}' order by table_name;",
Expand All @@ -51,7 +51,7 @@ void DuckDBCatalog::init() {
conversionFunc(resultChunk->data[0], tableNamesVector, resultChunk->size());
for (auto i = 0u; i < resultChunk->size(); i++) {
auto tableName = tableNamesVector.getValue<common::ku_string_t>(i).getAsString();
createForeignTable(tableName);
createForeignTable(transaction, tableName);
}
}

Expand All @@ -74,7 +74,8 @@ static std::string getQuery(const binder::BoundCreateTableInfo& info) {
extraInfo->schemaName, info.tableName);
}

void DuckDBCatalog::createForeignTable(const std::string& tableName) {
void DuckDBCatalog::createForeignTable(const transaction::Transaction* transaction,
const std::string& tableName) {
auto info = bindCreateTableInfo(tableName);
if (info == nullptr) {
return;
Expand All @@ -93,7 +94,7 @@ void DuckDBCatalog::createForeignTable(const std::string& tableName) {
for (auto& definition : extraInfo->propertyDefinitions) {
tableEntry->addProperty(definition);
}
tables->createEntry(&transaction::DUMMY_TRANSACTION, std::move(tableEntry));
tables->createEntry(transaction, std::move(tableEntry));
}

static bool getTableInfo(const DuckDBConnector& connector, const std::string& tableName,
Expand Down
20 changes: 8 additions & 12 deletions extension/duckdb/src/function/clear_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "binder/binder.h"
#include "catalog/duckdb_catalog.h"
#include "main/database_manager.h"
#include "processor/execution_context.h"
#include "storage/duckdb_storage.h"

using namespace kuzu::function;
Expand All @@ -12,20 +13,14 @@ using namespace kuzu::common;
namespace kuzu {
namespace duckdb_extension {

static offset_t clearCacheTableFunc(const TableFuncInput& input,
const TableFuncOutput& /*output*/) {
const auto sharedState = input.sharedState->ptrCast<SimpleTableFuncSharedState>();
const auto morsel = sharedState->getMorsel();
if (!morsel.hasMoreToOutput()) {
return 0;
}
static offset_t clearCacheTableFunc(const TableFuncInput& input, const TableFuncOutput&) {
const auto bindData = input.bindData->constPtrCast<ClearCacheBindData>();
bindData->databaseManager->invalidateCache();
return 1;
bindData->databaseManager->invalidateCache(input.context->clientContext->getTransaction());
return 0;
}

static std::unique_ptr<TableFuncBindData> clearCacheBindFunc(const ClientContext* context,
const TableFuncBindInput* /*input*/) {
const TableFuncBindInput*) {
return std::make_unique<ClearCacheBindData>(context->getDatabaseManager());
}

Expand All @@ -34,8 +29,9 @@ function_set ClearCacheFunction::getFunctionSet() {
auto function = std::make_unique<TableFunction>(name, std::vector<LogicalTypeID>{});
function->tableFunc = clearCacheTableFunc;
function->bindFunc = clearCacheBindFunc;
function->initSharedStateFunc = SimpleTableFunction::initSharedState;
function->initLocalStateFunc = SimpleTableFunction::initEmptyLocalState;
function->initSharedStateFunc = initSharedState;
function->initLocalStateFunc = initEmptyLocalState;
function->canParallelFunc = []() { return false; };
functionSet.push_back(std::move(function));
return functionSet;
}
Expand Down
13 changes: 6 additions & 7 deletions extension/duckdb/src/include/catalog/duckdb_catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct AttachOption;

namespace duckdb_extension {

struct BoundExtraCreateDuckDBTableInfo : public binder::BoundExtraCreateTableInfo {
struct BoundExtraCreateDuckDBTableInfo final : binder::BoundExtraCreateTableInfo {
std::string catalogName;
std::string schemaName;

Expand All @@ -29,13 +29,13 @@ struct BoundExtraCreateDuckDBTableInfo : public binder::BoundExtraCreateTableInf
}
};

class DuckDBCatalog : public extension::CatalogExtension {
class DuckDBCatalog final : public extension::CatalogExtension {
public:
DuckDBCatalog(std::string dbPath, std::string catalogName, std::string defaultSchemaName,
main::ClientContext* context, const DuckDBConnector& connector,
const binder::AttachOption& attachOption);

void init() override;
void init(transaction::Transaction* transaction) override;

static std::string bindSchemaName(const binder::AttachOption& options,
const std::string& defaultName);
Expand All @@ -45,11 +45,10 @@ class DuckDBCatalog : public extension::CatalogExtension {
std::vector<binder::PropertyDefinition>& propertyDefinitions);

private:
virtual std::unique_ptr<binder::BoundCreateTableInfo> bindCreateTableInfo(
const std::string& tableName);
std::unique_ptr<binder::BoundCreateTableInfo> bindCreateTableInfo(const std::string& tableName);

private:
void createForeignTable(const std::string& tableName);
void createForeignTable(const transaction::Transaction* transaction,
const std::string& tableName);

protected:
std::string dbPath;
Expand Down
6 changes: 3 additions & 3 deletions extension/duckdb/src/include/function/clear_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ struct ClearCacheBindData final : function::SimpleTableFuncBindData {
main::DatabaseManager* databaseManager;

explicit ClearCacheBindData(main::DatabaseManager* databaseManager)
: SimpleTableFuncBindData{binder::expression_vector{}, 1 /* maxOffset */},
databaseManager{databaseManager} {}
: SimpleTableFuncBindData{0}, databaseManager{databaseManager} {}

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<ClearCacheBindData>(databaseManager);
}
};

struct ClearCacheFunction final : function::TableFunction {
struct ClearCacheFunction final : function::SimpleTableFunction {
static constexpr const char* name = "clear_attached_db_cache";

static function::function_set getFunctionSet();
};

Expand Down
2 changes: 1 addition & 1 deletion extension/duckdb/src/include/storage/duckdb_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class DuckDBStorageExtension final : public storage::StorageExtension {

static constexpr bool SKIP_UNSUPPORTED_TABLE_DEFAULT_VAL = false;

explicit DuckDBStorageExtension(main::Database* database);
explicit DuckDBStorageExtension(transaction::Transaction* transaction, main::Database* db);

bool canHandleDB(std::string dbType) const override;
};
Expand Down
3 changes: 2 additions & 1 deletion extension/duckdb/src/main/duckdb_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace duckdb_extension {

void DuckDBExtension::load(main::ClientContext* context) {
auto db = context->getDatabase();
db->registerStorageExtension(EXTENSION_NAME, std::make_unique<DuckDBStorageExtension>(db));
db->registerStorageExtension(EXTENSION_NAME,
std::make_unique<DuckDBStorageExtension>(context->getTransaction(), db));
httpfs::S3DownloadOptions::registerExtensionOptions(db);
httpfs::S3DownloadOptions::setEnvValue(context);
}
Expand Down
8 changes: 4 additions & 4 deletions extension/duckdb/src/storage/duckdb_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <filesystem>

#include "binder/bound_attach_info.h"
#include "catalog/duckdb_catalog.h"
#include "common/file_system/virtual_file_system.h"
#include "common/string_utils.h"
Expand Down Expand Up @@ -32,14 +31,15 @@ std::unique_ptr<main::AttachedDatabase> attachDuckDB(std::string dbName, std::st

auto duckdbCatalog = std::make_unique<DuckDBCatalog>(std::move(dbPath), std::move(catalogName),
schemaName, clientContext, *connector, attachOption);
duckdbCatalog->init();
duckdbCatalog->init(clientContext->getTransaction());
return std::make_unique<AttachedDuckDBDatabase>(dbName, DuckDBStorageExtension::DB_TYPE,
std::move(duckdbCatalog), std::move(connector));
}

DuckDBStorageExtension::DuckDBStorageExtension(main::Database* db)
DuckDBStorageExtension::DuckDBStorageExtension(transaction::Transaction* transaction,
main::Database* db)
: StorageExtension{attachDuckDB} {
extension::ExtensionUtils::addStandaloneTableFunc<ClearCacheFunction>(*db);
extension::ExtensionUtils::addStandaloneTableFunc<ClearCacheFunction>(transaction, *db);
}

bool DuckDBStorageExtension::canHandleDB(std::string dbType_) const {
Expand Down
14 changes: 7 additions & 7 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "catalog/catalog.h"
#include "catalog/fts_index_catalog_entry.h"
#include "common/serializer/buffered_reader.h"
#include "function/create_fts_index.h"
#include "function/drop_fts_index.h"
#include "function/query_fts_index.h"
Expand All @@ -15,20 +14,21 @@ namespace fts_extension {

using namespace extension;

static void initFTSEntries(const transaction::Transaction* transaction, catalog::Catalog& catalog) {
static void initFTSEntries(const transaction::Transaction* transaction,
const catalog::Catalog& catalog) {
for (auto& indexEntry : catalog.getIndexEntries(transaction)) {
if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) {
indexEntry->setAuxInfo(FTSIndexAuxInfo::deserialize(indexEntry->getAuxBufferReader()));
}
}
}

void FTSExtension::load(main::ClientContext* context) {
void FTSExtension::load(const main::ClientContext* context) {
auto& db = *context->getDatabase();
ExtensionUtils::addScalarFunc<StemFunction>(db);
ExtensionUtils::addGDSFunc<QueryFTSFunction>(db);
ExtensionUtils::addStandaloneTableFunc<CreateFTSFunction>(db);
ExtensionUtils::addStandaloneTableFunc<DropFTSFunction>(db);
ExtensionUtils::addScalarFunc<StemFunction>(context->getTransaction(), db);
ExtensionUtils::addGDSFunc<QueryFTSFunction>(context->getTransaction(), db);
ExtensionUtils::addStandaloneTableFunc<CreateFTSFunction>(context->getTransaction(), db);
ExtensionUtils::addStandaloneTableFunc<DropFTSFunction>(context->getTransaction(), db);
initFTSEntries(context->getTransaction(), *db.getCatalog());
}

Expand Down
2 changes: 1 addition & 1 deletion extension/fts/src/include/fts_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class FTSExtension final : public extension::Extension {
static constexpr char EXTENSION_NAME[] = "FTS";

public:
static void load(main::ClientContext* context);
static void load(const main::ClientContext* context);

static constexpr const char* EN_STOP_WORDS[] = {"a", "a's", "able", "about", "above",
"according", "accordingly", "across", "actually", "after", "afterwards", "again", "against",
Expand Down
7 changes: 3 additions & 4 deletions extension/iceberg/src/iceberg_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "function/iceberg_functions.h"
#include "main/client_context.h"
#include "main/database.h"
#include "s3_download_options.h"

namespace kuzu {
Expand All @@ -12,9 +11,9 @@ using namespace kuzu::extension;

void IcebergExtension::load(main::ClientContext* context) {
auto& db = *context->getDatabase();
extension::ExtensionUtils::addTableFunc<IcebergScanFunction>(db);
extension::ExtensionUtils::addTableFunc<IcebergMetadataFunction>(db);
extension::ExtensionUtils::addTableFunc<IcebergSnapshotsFunction>(db);
ExtensionUtils::addTableFunc<IcebergScanFunction>(context->getTransaction(), db);
ExtensionUtils::addTableFunc<IcebergMetadataFunction>(context->getTransaction(), db);
ExtensionUtils::addTableFunc<IcebergSnapshotsFunction>(context->getTransaction(), db);
httpfs::S3DownloadOptions::registerExtensionOptions(&db);
httpfs::S3DownloadOptions::setEnvValue(context);
}
Expand Down
2 changes: 1 addition & 1 deletion extension/json/src/include/json_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class JsonExtension final : public extension::Extension {
static constexpr char JSON_TYPE_NAME[] = "json";
static constexpr common::idx_t JSON_SCAN_FILE_IDX = 0;

static void load(main::ClientContext* context);
static void load(const main::ClientContext* context);
};

} // namespace json_extension
Expand Down
53 changes: 27 additions & 26 deletions extension/json/src/json_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,40 @@ namespace json_extension {

using namespace kuzu::extension;

static void addJsonCreationFunction(main::Database& db) {
ExtensionUtils::addScalarFunc<ToJsonFunction>(db);
ExtensionUtils::addScalarFuncAlias<JsonQuoteFunction>(db);
ExtensionUtils::addScalarFuncAlias<ArrayToJsonFunction>(db);
ExtensionUtils::addScalarFuncAlias<RowToJsonFunction>(db);
ExtensionUtils::addScalarFuncAlias<CastToJsonFunction>(db);
ExtensionUtils::addScalarFunc<JsonArrayFunction>(db);
ExtensionUtils::addScalarFunc<JsonObjectFunction>(db);
ExtensionUtils::addScalarFunc<JsonMergePatchFunction>(db);
static void addJsonCreationFunction(transaction::Transaction* transaction, main::Database& db) {
ExtensionUtils::addScalarFunc<ToJsonFunction>(transaction, db);
ExtensionUtils::addScalarFuncAlias<JsonQuoteFunction>(transaction, db);
ExtensionUtils::addScalarFuncAlias<ArrayToJsonFunction>(transaction, db);
ExtensionUtils::addScalarFuncAlias<RowToJsonFunction>(transaction, db);
ExtensionUtils::addScalarFuncAlias<CastToJsonFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonArrayFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonObjectFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonMergePatchFunction>(transaction, db);
}

static void addJsonExtractFunction(main::Database& db) {
ExtensionUtils::addScalarFunc<JsonExtractFunction>(db);
static void addJsonExtractFunction(transaction::Transaction* transaction, main::Database& db) {
ExtensionUtils::addScalarFunc<JsonExtractFunction>(transaction, db);
}

static void addJsonScalarFunction(main::Database& db) {
ExtensionUtils::addScalarFunc<JsonArrayLengthFunction>(db);
ExtensionUtils::addScalarFunc<JsonContainsFunction>(db);
ExtensionUtils::addScalarFunc<JsonKeysFunction>(db);
ExtensionUtils::addScalarFunc<JsonStructureFunction>(db);
ExtensionUtils::addScalarFunc<JsonValidFunction>(db);
ExtensionUtils::addScalarFunc<MinifyJsonFunction>(db);
static void addJsonScalarFunction(transaction::Transaction* transaction, main::Database& db) {
ExtensionUtils::addScalarFunc<JsonArrayLengthFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonContainsFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonKeysFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonStructureFunction>(transaction, db);
ExtensionUtils::addScalarFunc<JsonValidFunction>(transaction, db);
ExtensionUtils::addScalarFunc<MinifyJsonFunction>(transaction, db);
}

void JsonExtension::load(main::ClientContext* context) {
void JsonExtension::load(const main::ClientContext* context) {
auto transaction = context->getTransaction();
auto& db = *context->getDatabase();
db.getCatalog()->createType(&transaction::DUMMY_TRANSACTION, JSON_TYPE_NAME,
JsonType::getJsonType());
addJsonCreationFunction(db);
addJsonExtractFunction(db);
addJsonScalarFunction(db);
ExtensionUtils::addScalarFunc<JsonExportFunction>(db);
ExtensionUtils::addTableFunc<JsonScan>(db);
KU_ASSERT(!db.getCatalog()->containsType(context->getTransaction(), JSON_TYPE_NAME));
db.getCatalog()->createType(context->getTransaction(), JSON_TYPE_NAME, JsonType::getJsonType());
addJsonCreationFunction(transaction, db);
addJsonExtractFunction(transaction, db);
addJsonScalarFunction(transaction, db);
ExtensionUtils::addScalarFunc<JsonExportFunction>(transaction, db);
ExtensionUtils::addTableFunc<JsonScan>(transaction, db);
}

} // namespace json_extension
Expand Down
2 changes: 1 addition & 1 deletion extension/postgres/src/include/main/postgres_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class PostgresExtension final : public extension::Extension {
static constexpr char EXTENSION_NAME[] = "POSTGRES";

public:
static void load(main::ClientContext* context);
static void load(const main::ClientContext* context);
};

} // namespace postgres_extension
Expand Down
3 changes: 2 additions & 1 deletion extension/postgres/src/include/storage/postgres_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class PostgresStorageExtension final : public storage::StorageExtension {

static constexpr const char* DEFAULT_SCHEMA_NAME = "public";

explicit PostgresStorageExtension(main::Database* database);
explicit PostgresStorageExtension(transaction::Transaction* transaction,
main::Database* database);

bool canHandleDB(std::string dbType) const override;
};
Expand Down
5 changes: 3 additions & 2 deletions extension/postgres/src/main/postgres_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
namespace kuzu {
namespace postgres_extension {

void PostgresExtension::load(main::ClientContext* context) {
void PostgresExtension::load(const main::ClientContext* context) {
auto db = context->getDatabase();
db->registerStorageExtension(EXTENSION_NAME, std::make_unique<PostgresStorageExtension>(db));
db->registerStorageExtension(EXTENSION_NAME,
std::make_unique<PostgresStorageExtension>(context->getTransaction(), db));
}

} // namespace postgres_extension
Expand Down
7 changes: 4 additions & 3 deletions extension/postgres/src/storage/postgres_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ std::unique_ptr<main::AttachedDatabase> attachPostgres(std::string dbName, std::
connector->connect(dbPath, catalogName, schemaName, clientContext);
auto catalog = std::make_unique<duckdb_extension::DuckDBCatalog>(dbPath, catalogName,
schemaName, clientContext, *connector, attachOption);
catalog->init();
catalog->init(clientContext->getTransaction());
return std::make_unique<duckdb_extension::AttachedDuckDBDatabase>(dbName,
PostgresStorageExtension::DB_TYPE, std::move(catalog), std::move(connector));
}

PostgresStorageExtension::PostgresStorageExtension(main::Database* database)
PostgresStorageExtension::PostgresStorageExtension(transaction::Transaction* transaction,
main::Database* database)
: StorageExtension{attachPostgres} {
extension::ExtensionUtils::addStandaloneTableFunc<duckdb_extension::ClearCacheFunction>(
*database);
transaction, *database);
}

bool PostgresStorageExtension::canHandleDB(std::string dbType_) const {
Expand Down
Loading