Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SELECT * bug fixes #7

Merged
merged 3 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/binder/bind_node_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void BindNodeVisitor::Visit(common::ManagedPointer<parser::CopyStatement> node)

// If the table is given, we're either writing or reading all columns
std::vector<common::ManagedPointer<parser::AbstractExpression>> new_select_list;
context_->GenerateAllColumnExpressions(sherpa_->GetParseResult(), common::ManagedPointer(&new_select_list));
context_->GenerateAllColumnExpressions(sherpa_->GetParseResult(), common::ManagedPointer(&new_select_list), node->GetSelectStatement(), "");
auto col = node->GetSelectStatement()->GetSelectColumns();
col.insert(std::end(col), std::begin(new_select_list), std::end(new_select_list));
} else {
Expand Down Expand Up @@ -493,7 +493,8 @@ void BindNodeVisitor::Visit(common::ManagedPointer<parser::SelectStatement> node
BINDER_LOG_TRACE("Gathering select columns...");
for (auto &select_element : node->GetSelectColumns()) {
if (select_element->GetExpressionType() == parser::ExpressionType::STAR) {
context_->GenerateAllColumnExpressions(sherpa_->GetParseResult(), common::ManagedPointer(&new_select_list));
auto star_expression = select_element.CastManagedPointerTo<terrier::parser::StarExpression>();
context_->GenerateAllColumnExpressions(sherpa_->GetParseResult(), common::ManagedPointer(&new_select_list), node, star_expression->GetTableName());
continue;
}
if (select_element->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) {
Expand Down
33 changes: 21 additions & 12 deletions src/binder/binder_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,18 +290,27 @@ bool BinderContext::CheckNestedTableColumn(const std::string &alias, const std::

void BinderContext::GenerateAllColumnExpressions(
common::ManagedPointer<parser::ParseResult> parse_result,
common::ManagedPointer<std::vector<common::ManagedPointer<parser::AbstractExpression>>> exprs) {
common::ManagedPointer<std::vector<common::ManagedPointer<parser::AbstractExpression>>> exprs,
common::ManagedPointer<parser::SelectStatement> stmt,
std::string table_name) {
// Set containing tables whose columns are to be included in the SELECT * query results
std::unordered_set<std::string> constituent_table_aliases;
// for (auto &stmt : parse_result->GetStatements()) {
// if (stmt->GetType() == parser::StatementType::SELECT) {
// auto select_stmt = stmt.CastManagedPointerTo<parser::SelectStatement>();
// select_stmt->GetSelectTable()->GetConstituentTableAliases(&constituent_table_aliases);
// }
// }
stmt->GetSelectTable()->GetConstituentTableAliases(&constituent_table_aliases);
if (!table_name.empty()) {
if (constituent_table_aliases.count(table_name) == 0) {
// SELECT table_name.* FROM ..., where the from clause does not contain table_name
throw BINDER_EXCEPTION(fmt::format("missing FROM-clause entry for table \"{}\"", table_name),
common::ErrorCode::ERRCODE_UNDEFINED_TABLE);
}
else {
constituent_table_aliases.clear();
constituent_table_aliases.insert(table_name);
}
}

for (auto &entry : regular_table_alias_map_) {
// auto &table_alias = entry.first;
// if (constituent_table_aliases.count(table_alias) > 0) {
auto &table_alias = entry.first;
if (constituent_table_aliases.count(table_alias) > 0) {
auto &schema = std::get<2>(entry.second);
auto col_cnt = schema.GetColumns().size();
for (uint32_t i = 0; i < col_cnt; i++) {
Expand All @@ -320,12 +329,12 @@ void BinderContext::GenerateAllColumnExpressions(
auto new_tv_expr = common::ManagedPointer(parse_result->GetExpressions().back());
exprs->push_back(new_tv_expr);
}
// }
}
}

for (auto &entry : nested_table_alias_map_) {
auto &table_alias = entry.first;
// if (constituent_table_aliases.count(table_alias) != 0) {
if (constituent_table_aliases.count(table_alias) > 0) {
auto &cols = entry.second;
for (auto &col_entry : cols) {
auto tv_expr = new parser::ColumnValueExpression(std::string(table_alias), std::string(col_entry.first));
Expand All @@ -340,7 +349,7 @@ void BinderContext::GenerateAllColumnExpressions(
// All derived columns do not have bound oids, thus keep them as INVALID_OIDs
exprs->push_back(new_tv_expr);
}
// }
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/include/binder/binder_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,14 @@ class BinderContext {
* Generate list of column value expression that covers all columns in the alias maps of the current context
* @param parse_result Result generated by the parser. A collection of statements and expressions in the query.
* @param exprs Pointer to the list of column value expression.
* @param table_name Name of the table associated with the star expression
* The generated column value expressions will be placed in this list.
*/
void GenerateAllColumnExpressions(
common::ManagedPointer<parser::ParseResult> parse_result,
common::ManagedPointer<std::vector<common::ManagedPointer<parser::AbstractExpression>>> exprs);
common::ManagedPointer<std::vector<common::ManagedPointer<parser::AbstractExpression>>> exprs,
common::ManagedPointer<parser::SelectStatement> stmt,
std::string table_name);

/**
* Return the binder context's metadata for the provided @p table_name.
Expand Down
21 changes: 20 additions & 1 deletion src/include/parser/expression/star_expression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <memory>
#include <string>
#include <vector>
#include "parser/expression/abstract_expression.h"

Expand All @@ -13,7 +14,22 @@ class StarExpression : public AbstractExpression {
/**
* Instantiates a new star expression, e.g. as in COUNT(*).
*/
StarExpression() : AbstractExpression(ExpressionType::STAR, type::TypeId::INTEGER, {}) {}
StarExpression() : AbstractExpression(ExpressionType::STAR, type::TypeId::INTEGER, {}) {
table_name_ = "";
}

/**
* Instantiates a new star expression with a table name, e.g. as in xxx.*
*/
StarExpression(std::string table_name) : AbstractExpression(ExpressionType::STAR, type::TypeId::INTEGER, {}) {
table_name_ = std::move(table_name);
}

/**
* Returns the table name associated with the star expression
* @return table name
*/
std::string GetTableName() {return table_name_;}

/**
* Copies this StarExpression
Expand All @@ -36,6 +52,9 @@ class StarExpression : public AbstractExpression {
}

void Accept(common::ManagedPointer<binder::SqlNodeVisitor> v) override { v->Visit(common::ManagedPointer(this)); }

private:
std::string table_name_;
};

DEFINE_JSON_HEADER_DECLARATIONS(StarExpression);
Expand Down
10 changes: 8 additions & 2 deletions src/parser/postgresparser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,15 @@ std::unique_ptr<AbstractExpression> PostgresParser::ColumnRefTransform(ParseResu
col_name = reinterpret_cast<value *>(node)->val_.str_;
table_name = "";
} else {
auto next_node = reinterpret_cast<Node *>(fields->head->next->data.ptr_value);
col_name = reinterpret_cast<value *>(next_node)->val_.str_;
table_name = reinterpret_cast<value *>(node)->val_.str_;
auto next_node = reinterpret_cast<Node *>(fields->head->next->data.ptr_value);
if (next_node->type == T_A_Star) {
result = std::make_unique<StarExpression>(table_name);
break;
}
else {
col_name = reinterpret_cast<value *>(next_node)->val_.str_;
}
}

if (alias != nullptr)
Expand Down
2 changes: 1 addition & 1 deletion src/parser/table_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ std::unique_ptr<TableRef> TableRef::Copy() {

void TableRef::GetConstituentTableAliases(std::unordered_set<std::string> *aliases) {
if (!alias_.empty()) {
aliases->insert(alias_);
aliases->insert(GetAlias());
}
if (join_ != nullptr) {
join_->GetLeftTable()->GetConstituentTableAliases(aliases);
Expand Down