Skip to content

Commit

Permalink
Merge pull request #4239 from danlark1/master
Browse files Browse the repository at this point in the history
produce hints for typo functions and types
  • Loading branch information
alexey-milovidov authored Feb 2, 2019
2 parents c411362 + 3d00aaa commit 14f208b
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 5 deletions.
6 changes: 5 additions & 1 deletion dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
}

throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
auto hints = this->getHints(name);
if (!hints.empty())
throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
else
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}


Expand Down
13 changes: 13 additions & 0 deletions dbms/src/Common/IFactoryWithAliases.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <Common/Exception.h>
#include <Common/NamePrompter.h>
#include <Core/Types.h>
#include <Poco/String.h>

Expand Down Expand Up @@ -105,6 +106,12 @@ class IFactoryWithAliases
return aliases.count(name) || case_insensitive_aliases.count(name);
}

std::vector<String> getHints(const String & name) const
{
static const auto registered_names = getAllRegisteredNames();
return prompter.getHints(name, registered_names);
}

virtual ~IFactoryWithAliases() {}

private:
Expand All @@ -120,6 +127,12 @@ class IFactoryWithAliases

/// Case insensitive aliases
AliasMap case_insensitive_aliases;

/**
* prompter for names, if a person makes a typo for some function or type, it
* helps to find best possible match (in particular, edit distance is one or two symbols)
*/
NamePrompter</*MistakeFactor=*/2, /*MaxNumHints=*/2> prompter;
};

}
83 changes: 83 additions & 0 deletions dbms/src/Common/NamePrompter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#pragma once

#include <Core/Types.h>

#include <algorithm>
#include <cctype>
#include <queue>
#include <utility>

namespace DB
{
template <size_t MistakeFactor, size_t MaxNumHints>
class NamePrompter
{
public:
using DistanceIndex = std::pair<size_t, size_t>;
using DistanceIndexQueue = std::priority_queue<DistanceIndex>;

static std::vector<String> getHints(const String & name, const std::vector<String> & prompting_strings)
{
DistanceIndexQueue queue;
for (size_t i = 0; i < prompting_strings.size(); ++i)
appendToQueue(i, name, queue, prompting_strings);
return release(queue, prompting_strings);
}

private:
static size_t levenshteinDistance(const String & lhs, const String & rhs)
{
size_t n = lhs.size();
size_t m = rhs.size();
std::vector<std::vector<size_t>> dp(n + 1, std::vector<size_t>(m + 1));

for (size_t i = 1; i <= n; ++i)
dp[i][0] = i;

for (size_t i = 1; i <= m; ++i)
dp[0][i] = i;

for (size_t j = 1; j <= m; ++j)
{
for (size_t i = 1; i <= n; ++i)
{
if (std::tolower(lhs[i - 1]) == std::tolower(rhs[j - 1]))
dp[i][j] = dp[i - 1][j - 1];
else
dp[i][j] = std::min(dp[i - 1][j] + 1, std::min(dp[i][j - 1] + 1, dp[i - 1][j - 1] + 1));
}
}

return dp[n][m];
}

static void appendToQueue(size_t ind, const String & name, DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
{
if (prompting_strings[ind].size() <= name.size() + MistakeFactor && prompting_strings[ind].size() + MistakeFactor >= name.size())
{
size_t distance = levenshteinDistance(prompting_strings[ind], name);
if (distance <= MistakeFactor)
{
queue.emplace(distance, ind);
if (queue.size() > MaxNumHints)
queue.pop();
}
}
}

static std::vector<String> release(DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
{
std::vector<String> ans;
ans.reserve(queue.size());
while (!queue.empty())
{
auto top = queue.top();
queue.pop();
ans.push_back(prompting_strings[top.second]);
}
std::reverse(ans.begin(), ans.end());
return ans;
}
};

}
8 changes: 6 additions & 2 deletions dbms/src/DataTypes/DataTypeFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <Common/typeid_cast.h>
#include <Poco/String.h>
#include <Common/StringUtils/StringUtils.h>

#include <IO/WriteHelpers.h>

namespace DB
{
Expand Down Expand Up @@ -87,7 +87,11 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
return it->second(parameters);
}

throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
auto hints = this->getHints(family_name);
if (!hints.empty())
throw Exception("Unknown data type family: " + family_name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_TYPE);
else
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
}


Expand Down
10 changes: 9 additions & 1 deletion dbms/src/Functions/FunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <Poco/String.h>

#include <IO/WriteHelpers.h>

namespace DB
{

Expand Down Expand Up @@ -43,7 +45,13 @@ FunctionBuilderPtr FunctionFactory::get(
{
auto res = tryGet(name, context);
if (!res)
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION);
{
auto hints = this->getHints(name);
if (!hints.empty())
throw Exception("Unknown function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_FUNCTION);
else
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION);
}
return res;
}

Expand Down
13 changes: 12 additions & 1 deletion dbms/src/Interpreters/ActionsVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,18 @@ void ActionsVisitor::visit(const ASTPtr & ast)
? context.getQueryContext()
: context;

const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(node->name, function_context);
FunctionBuilderPtr function_builder;
try
{
function_builder = FunctionFactory::instance().get(node->name, function_context);
}
catch (DB::Exception & e)
{
auto hints = AggregateFunctionFactory::instance().getHints(node->name);
if (!hints.empty())
e.addMessage("Or unknown aggregate function " + node->name + ". Maybe you meant: " + toString(hints));
e.rethrow();
}

Names argument_names;
DataTypes argument_types;
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e

CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. $CURDIR/../shell_config.sh

$CLICKHOUSE_CLIENT -q "select c23ount(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['count'" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select cunt(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['count'" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select positin(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['position'" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select POSITIO(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['position'" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select fount(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['count'" | grep "Maybe you meant: \['round'" | grep "Or unknown aggregate function" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select positin(*) from system.functions;" 2>&1 | grep -v "Or unknown aggregate function" &>/dev/null;
$CLICKHOUSE_CLIENT -q "select pov(*) from system.functions;" 2>&1 | grep "Maybe you meant: \['pow','cos'\]" &>/dev/null;

0 comments on commit 14f208b

Please sign in to comment.