Skip to content

Commit

Permalink
Merge ecdd0f2 into 9e26730
Browse files Browse the repository at this point in the history
  • Loading branch information
igormunkin authored Sep 10, 2024
2 parents 9e26730 + ecdd0f2 commit bd224a9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 21 deletions.
75 changes: 56 additions & 19 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_map_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ namespace NMiniKQL {

namespace {

template <typename TItem>
const TVector<TItem> TupleToVector(const TRuntimeNode tupleNode) {
const auto tupleLiteral = AS_VALUE(TTupleLiteral, tupleNode);
TVector<TItem> vector;
vector.reserve(tupleLiteral->GetValuesCount());
for (ui32 i = 0; i < tupleLiteral->GetValuesCount(); i++) {
const auto item = AS_VALUE(TDataLiteral, tupleLiteral->GetValue(i));
vector.emplace_back(item->AsValue().Get<TItem>());
}
return vector;
}

size_t CalcMaxBlockLength(const TVector<TType*>& items) {
return CalcBlockLen(std::accumulate(items.cbegin(), items.cend(), 0ULL,
[](size_t max, const TType* type) {
Expand All @@ -27,13 +39,14 @@ template <bool RightRequired>
class TBlockJoinState : public TBlockState {
public:
TBlockJoinState(TMemoryUsageInfo* memInfo, TComputationContext& ctx,
const TVector<TType*>& inputItems,
const TVector<TType*>& inputItems, const TSet<ui32>& keyDrops,
const TVector<TType*> outputItems,
NUdf::TUnboxedValue**const fields)
: TBlockState(memInfo, outputItems.size())
, InputWidth_(inputItems.size() - 1)
, OutputWidth_(outputItems.size() - 1)
, Inputs_(inputItems.size())
, KeyDrops_(keyDrops)
, InputsDescr_(ToValueDescr(inputItems))
{
const auto& pgBuilder = ctx.Builder->GetPgBuilder();
Expand All @@ -46,6 +59,9 @@ class TBlockJoinState : public TBlockState {
}
// The last output column (i.e. block length) doesn't require a block builder.
for (size_t i = 0; i < OutputWidth_; i++) {
if (KeyDrops_.contains(i)) {
continue;
}
const TType* blockItemType = AS_TYPE(TBlockType, outputItems[i])->GetItemType();
Builders_.push_back(MakeArrayBuilder(TTypeInfoHelper(), blockItemType, ctx.ArrowMemoryPool, MaxLength_, &pgBuilder, &BuilderAllocatedSize_));
}
Expand Down Expand Up @@ -145,10 +161,16 @@ class TBlockJoinState : public TBlockState {

private:
void AddItem(const TBlockItem& item, size_t idx) {
if (KeyDrops_.contains(idx)) {
return;
}
Builders_[idx]->Add(item);
}

void AddValue(const NUdf::TUnboxedValuePod& value, size_t idx) {
if (KeyDrops_.contains(idx)) {
return;
}
Builders_[idx]->Add(value);
}

Expand All @@ -164,6 +186,7 @@ class TBlockJoinState : public TBlockState {
size_t InputWidth_;
size_t OutputWidth_;
TUnboxedValueVector Inputs_;
const TSet<ui32> KeyDrops_;
const std::vector<arrow::ValueDescr> InputsDescr_;
TVector<std::unique_ptr<IBlockReader>> Readers_;
TVector<std::unique_ptr<IBlockItemConverter>> Converters_;
Expand All @@ -178,12 +201,13 @@ using TState = TBlockJoinState<RightRequired>;
public:
TBlockWideMapJoinWrapper(TComputationMutables& mutables,
const TVector<TType*>&& resultJoinItems, const TVector<TType*>&& leftFlowItems,
TVector<ui32>&& leftKeyColumns,
const TVector<ui32>&& leftKeyColumns, const TVector<ui32>&& leftKeyDrops,
IComputationWideFlowNode* flow, IComputationNode* dict)
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
, ResultJoinItems_(std::move(resultJoinItems))
, LeftFlowItems_(std::move(leftFlowItems))
, LeftKeyColumns_(std::move(leftKeyColumns))
, LeftKeyDrops_(leftKeyDrops.cbegin(), leftKeyDrops.cend())
, Flow_(flow)
, Dict_(dict)
, WideFieldsIndex_(mutables.IncrementWideFieldsIndex(LeftFlowItems_.size()))
Expand Down Expand Up @@ -248,7 +272,8 @@ using TState = TBlockJoinState<RightRequired>;
}

void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
state = ctx.HolderFactory.Create<TState>(ctx, LeftFlowItems_, ResultJoinItems_, ctx.WideFields.data() + WideFieldsIndex_);
state = ctx.HolderFactory.Create<TState>(ctx, LeftFlowItems_, LeftKeyDrops_,
ResultJoinItems_, ctx.WideFields.data() + WideFieldsIndex_);
}

TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
Expand All @@ -267,6 +292,7 @@ using TState = TBlockJoinState<RightRequired>;
const TVector<TType*> ResultJoinItems_;
const TVector<TType*> LeftFlowItems_;
const TVector<ui32> LeftKeyColumns_;
const TSet<ui32> LeftKeyDrops_;
IComputationWideFlowNode* const Flow_;
IComputationNode* const Dict_;
ui32 WideFieldsIndex_;
Expand All @@ -280,12 +306,13 @@ using TState = TBlockJoinState<RightRequired>;
public:
TBlockWideMultiMapJoinWrapper(TComputationMutables& mutables,
const TVector<TType*>&& resultJoinItems, const TVector<TType*>&& leftFlowItems,
TVector<ui32>&& leftKeyColumns,
const TVector<ui32>&& leftKeyColumns, const TVector<ui32>&& leftKeyDrops,
IComputationWideFlowNode* flow, IComputationNode* dict)
: TBaseComputation(mutables, flow, EValueRepresentation::Boxed, EValueRepresentation::Boxed)
, ResultJoinItems_(std::move(resultJoinItems))
, LeftFlowItems_(std::move(leftFlowItems))
, LeftKeyColumns_(std::move(leftKeyColumns))
, LeftKeyDrops_(leftKeyDrops.cbegin(), leftKeyDrops.cend())
, Flow_(flow)
, Dict_(dict)
, WideFieldsIndex_(mutables.IncrementWideFieldsIndex(LeftFlowItems_.size()))
Expand Down Expand Up @@ -357,7 +384,8 @@ using TState = TBlockJoinState<RightRequired>;
}

void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
state = ctx.HolderFactory.Create<TState>(ctx, LeftFlowItems_, ResultJoinItems_, ctx.WideFields.data() + WideFieldsIndex_);
state = ctx.HolderFactory.Create<TState>(ctx, LeftFlowItems_, LeftKeyDrops_,
ResultJoinItems_, ctx.WideFields.data() + WideFieldsIndex_);
}

TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
Expand Down Expand Up @@ -413,6 +441,7 @@ using TState = TBlockJoinState<RightRequired>;
const TVector<TType*> ResultJoinItems_;
const TVector<TType*> LeftFlowItems_;
const TVector<ui32> LeftKeyColumns_;
const TSet<ui32> LeftKeyDrops_;
IComputationWideFlowNode* const Flow_;
IComputationNode* const Dict_;
ui32 WideFieldsIndex_;
Expand All @@ -421,7 +450,7 @@ using TState = TBlockJoinState<RightRequired>;
} // namespace

IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args");
MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");

const auto joinType = callable.GetType()->GetReturnType();
MKQL_ENSURE(joinType->IsFlow(), "Expected WideFlow as a resulting stream");
Expand Down Expand Up @@ -459,16 +488,18 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
Y_ENSURE(joinKind == EJoinKind::Inner || joinKind == EJoinKind::Left ||
joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly);

const auto tupleLiteral = AS_VALUE(TTupleLiteral, callable.GetInput(3));
TVector<ui32> leftKeyColumns;
leftKeyColumns.reserve(tupleLiteral->GetValuesCount());
for (ui32 i = 0; i < tupleLiteral->GetValuesCount(); i++) {
const auto item = AS_VALUE(TDataLiteral, tupleLiteral->GetValue(i));
leftKeyColumns.emplace_back(item->AsValue().Get<ui32>());
}
const auto leftKeyColumns = TupleToVector<ui32>(callable.GetInput(3));
// TODO: Handle multi keys.
Y_ENSURE(leftKeyColumns.size() == 1);

const auto leftKeyDrops = TupleToVector<ui32>(callable.GetInput(4));
const TSet<ui32> leftKeySet(leftKeyColumns.cbegin(), leftKeyColumns.cend());
for (const auto& drop : leftKeyDrops) {
MKQL_ENSURE(leftKeySet.contains(drop),
"Only key columns has to be specified in drop column set");

}

const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
const auto dict = LocateNode(ctx.NodeLocator, callable, 1);

Expand All @@ -477,28 +508,34 @@ IComputationNode* WrapBlockMapJoinCore(TCallable& callable, const TComputationNo
case EJoinKind::Inner:
if (isMulti) {
return new TBlockWideMultiMapJoinWrapper<true>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
}
return new TBlockWideMapJoinWrapper<false, true>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
case EJoinKind::Left:
if (isMulti) {
return new TBlockWideMultiMapJoinWrapper<false>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
}
return new TBlockWideMapJoinWrapper<false, false>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
case EJoinKind::LeftSemi:
return new TBlockWideMapJoinWrapper<true, true>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
case EJoinKind::LeftOnly:
return new TBlockWideMapJoinWrapper<true, false>(ctx.Mutables,
std::move(joinItems), std::move(leftFlowItems), std::move(leftKeyColumns),
std::move(joinItems), std::move(leftFlowItems),
std::move(leftKeyColumns), std::move(leftKeyDrops),
static_cast<IComputationWideFlowNode*>(flow), dict);
default:
MKQL_ENSURE(false, "BlockMapJoinCore doesn't support %s join type"
Expand Down
16 changes: 15 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5855,7 +5855,8 @@ TRuntimeNode TProgramBuilder::ScalarApply(const TArrayRef<const TRuntimeNode>& a
}

TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode dict,
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns,
const TArrayRef<const ui32>& leftKeyDrops
) {
if constexpr (RuntimeVersion < 51U) {
THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
Expand All @@ -5864,6 +5865,11 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode d
joinKind == EJoinKind::LeftSemi || joinKind == EJoinKind::LeftOnly,
"Unsupported join kind");
MKQL_ENSURE(!leftKeyColumns.empty(), "At least one key column must be specified");
const TSet<ui32> leftKeySet(leftKeyColumns.cbegin(), leftKeyColumns.cend());
for (const auto& drop : leftKeyDrops) {
MKQL_ENSURE(leftKeySet.contains(drop),
"Only key columns has to be specified in drop column set");
}

TRuntimeNode::TList leftKeyColumnsNodes;
leftKeyColumnsNodes.reserve(leftKeyColumns.size());
Expand All @@ -5872,6 +5878,13 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode d
return NewDataLiteral(idx);
});

TRuntimeNode::TList leftKeyDropsNodes;
leftKeyDropsNodes.reserve(leftKeyDrops.size());
std::transform(leftKeyDrops.cbegin(), leftKeyDrops.cend(),
std::back_inserter(leftKeyDropsNodes), [this](const ui32 idx) {
return NewDataLiteral(idx);
});

auto returnJoinItems = ValidateBlockFlowType(flow.GetStaticType(), false);
const auto payloadType = AS_TYPE(TDictType, dict.GetStaticType())->GetPayloadType();
const auto payloadItemType = payloadType->IsList()
Expand Down Expand Up @@ -5907,6 +5920,7 @@ TRuntimeNode TProgramBuilder::BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode d
callableBuilder.Add(dict);
callableBuilder.Add(NewDataLiteral((ui32)joinKind));
callableBuilder.Add(NewTuple(leftKeyColumnsNodes));
callableBuilder.Add(NewTuple(leftKeyDropsNodes));

return TRuntimeNode(callableBuilder.Build(), false);
}
Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/minikql/mkql_program_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ class TProgramBuilder : public TTypeBuilder {
TRuntimeNode BlockPgResolvedCall(const std::string_view& name, ui32 id,
const TArrayRef<const TRuntimeNode>& args, TType* returnType);
TRuntimeNode BlockMapJoinCore(TRuntimeNode flow, TRuntimeNode dict,
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns);
EJoinKind joinKind, const TArrayRef<const ui32>& leftKeyColumns,
const TArrayRef<const ui32>& leftKeyDrops = {});

//-- logical functions
TRuntimeNode BlockNot(TRuntimeNode data);
Expand Down

0 comments on commit bd224a9

Please sign in to comment.