Skip to content

Commit

Permalink
Add type inference, ConcatArray
Browse files Browse the repository at this point in the history
  • Loading branch information
shkoo committed Aug 23, 2024
1 parent f627433 commit faa60c4
Show file tree
Hide file tree
Showing 16 changed files with 241 additions and 68 deletions.
2 changes: 1 addition & 1 deletion zirgen/Conversions/Typing/BuiltinComponents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void Builtins::addBuiltins() {
genTrivial("String", Zhlt::getStringType(ctx));

makeSpecialBuiltin<Zhlt::BuiltinArrayOp>("Array");
// makeSpecialBuiltin<Zhlt::BuiltinConcatArrayOp>("ConcatArray");
makeSpecialBuiltin<Zhlt::BuiltinArrayOp>("ConcatArray");
makeSpecialBuiltin<Zhlt::BuiltinLogOp>("Log");
}

Expand Down
52 changes: 37 additions & 15 deletions zirgen/Conversions/Typing/ComponentManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ using namespace zirgen::Zhl;

class ComponentManagerImpl : public Zhlt::ComponentManager {
public:
Zhlt::ComponentTypeAttr getGlobalReference(mlir::Location loc, mlir::StringAttr name) override;
Zhlt::ComponentTypeAttr specialize(mlir::Location loc,
Zhlt::ComponentTypeAttr orig,
llvm::ArrayRef<mlir::Attribute> typeArgs) override;
mlir::LogicalResult requireComponent(mlir::Location loc, Zhlt::ComponentTypeAttr name) override;
mlir::LogicalResult requireComponentInferringType(mlir::Location loc,
Zhlt::ComponentTypeAttr& name,
mlir::ValueRange constructArgs) override;
mlir::LogicalResult requireAbstractComponent(mlir::Location loc,
Zhlt::ComponentTypeAttr name) override;
mlir::Type getLayoutType(Zhlt::ComponentTypeAttr component) override;
Expand All @@ -63,6 +65,7 @@ class ComponentManagerImpl : public Zhlt::ComponentManager {
size_t distance = 0) override;
std::optional<llvm::SmallVector<mlir::Type>>
getConstructParams(Zhlt::ComponentTypeAttr component) override;
Zhlt::ComponentTypeAttr getNameForType(mlir::Type type) override;

private:
struct TypeInfo;
Expand Down Expand Up @@ -113,6 +116,10 @@ class ComponentManagerImpl : public Zhlt::ComponentManager {
// unmangled component name to use to reconstruct them.
llvm::DenseMap<mlir::Type, Zhlt::ComponentTypeAttr> reconstructTypes;

// Value types of components that have been required, and the
// unmangled component name.
llvm::DenseMap<mlir::Type, Zhlt::ComponentTypeAttr> valueTypes;

friend std::optional<mlir::ModuleOp> typeCheck(mlir::MLIRContext&, mlir::ModuleOp);
};

Expand All @@ -134,12 +141,12 @@ mlir::LogicalResult ComponentManagerImpl::requireComponent(Location loc,

// Otherwise, try to instantiate a generic zhl.component
if (!intf) {
if (auto zhlOp = zhlModule.lookupSymbol<ComponentOp>(name.getName()))
if (auto zhlOp = getUnloweredComponent(name.getName()))
intf = genComponent(loc, name.getName(), name.getTypeArgs());
}

if (!intf) {
emitError(loc) << "Unable to instantiate component";
emitError(loc) << "Unable to instantiate component " << name;
return failure();
}

Expand All @@ -150,9 +157,32 @@ mlir::LogicalResult ComponentManagerImpl::requireComponent(Location loc,
if (auto layout = intf.getLayoutType(this, name)) {
reconstructTypes[layout] = name;
}
if (auto valType = intf.getValueType(this, name)) {
valueTypes[valType] = name;
}
return success();
}

mlir::LogicalResult ComponentManagerImpl::requireComponentInferringType(
Location loc, Zhlt::ComponentTypeAttr& name, ValueRange constructArgs) {
if (requiredComponents.contains(name))
return success();

// Attempt to find an interface for specialization
Zhlt::ComponentOpInterface intf =
zhltModule.lookupSymbol<Zhlt::ComponentOpInterface>(name.getName());
if (!intf)
intf = zhlModule.lookupSymbol<Zhlt::ComponentOpInterface>(name.getName());

if (intf) {
intf.inferType(this, name, constructArgs);
} else if (auto zhlOp = zhlModule.lookupSymbol<ComponentOp>(name.getName())) {
// TODO: attempt to infer types for user-defined components
}

return requireComponent(loc, name);
}

mlir::LogicalResult ComponentManagerImpl::requireAbstractComponent(Location loc,
Zhlt::ComponentTypeAttr name) {
if (requiredComponents.contains(name))
Expand All @@ -167,18 +197,6 @@ mlir::LogicalResult ComponentManagerImpl::requireAbstractComponent(Location loc,
return failure();
}

// Fails and emits an error if the given name isn't resolvable. Doesn't require it to be any
// specific type.
Zhlt::ComponentTypeAttr ComponentManagerImpl::getGlobalReference(mlir::Location loc,
mlir::StringAttr name) {
if (zhlModule.lookupSymbol(name) || zhltModule.lookupSymbol(name)) {
return Zhlt::ComponentTypeAttr::get(name);
}

emitError(loc) << "Unable to resolve " << name;
return {};
}

Zhlt::ComponentTypeAttr ComponentManagerImpl::specialize(mlir::Location loc,
Zhlt::ComponentTypeAttr orig,
llvm::ArrayRef<mlir::Attribute> typeArgs) {
Expand Down Expand Up @@ -231,6 +249,10 @@ ComponentManagerImpl::getConstructParams(Zhlt::ComponentTypeAttr name) {
return getComponentInterface(name).getConstructParams(this, name);
}

Zhlt::ComponentTypeAttr ComponentManagerImpl::getNameForType(mlir::Type type) {
return valueTypes.lookup(type);
}

struct ComponentManagerImpl::DebugListener : public OpBuilder::Listener {
void notifyOperationInserted(Operation* op, IRRewriter::InsertPoint previous) override {
LLVM_DEBUG({
Expand Down
28 changes: 23 additions & 5 deletions zirgen/Conversions/Typing/ZhlComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class LoweringImpl {

// Evaluiates the given value as an instantiatable type.
ComponentTypeAttr asTypeName(Value v);
ComponentTypeAttr asTypeNameInferringType(Value v, ValueRange args);

// Evaluiates the given value as a potentially generic type.
ComponentTypeAttr asAbstractTypeName(Value v);
Expand Down Expand Up @@ -689,7 +690,7 @@ void LoweringImpl::gen(StringOp string, ComponentBuilder& cb) {
}

void LoweringImpl::gen(GlobalOp global, ComponentBuilder& cb) {
auto name = ComponentTypeAttr::get(global.getNameAttr());
auto name = ComponentTypeAttr::get(global.getNameAttr(), /*typeArgs=*/{});
if (failed(componentManager->requireAbstractComponent(global.getLoc(), name)))
throw(MalformedIRException());
typeNameMapping[global.getOut()] = name;
Expand Down Expand Up @@ -769,8 +770,8 @@ void LoweringImpl::gen(SpecializeOp specialize, ComponentBuilder& cb) {
}

void LoweringImpl::gen(ConstructOp construct, ComponentBuilder& cb) {
auto component = asTypeName(construct.getType());
auto args = llvm::map_to_vector(construct.getArgs(), [&](auto arg) { return asValue(arg); });
auto component = asTypeNameInferringType(construct.getType(), args);
Value layout = addOrExpandLayoutMember(
construct.getLoc(), cb, construct.getOut(), componentManager->getLayoutType(component));

Expand Down Expand Up @@ -1042,9 +1043,16 @@ void LoweringImpl::gen(SwitchOp sw, ComponentBuilder& cb) {

Type armResultType = Zhlt::getLeastCommonSuper(armTypes);
assert(armResultType);
Type commonArmLayoutType = Zhlt::getLeastCommonSuper(armLayouts, /*isLayout=*/1);
Value superLayout = muxContext.addLayoutMember(sw.getLoc(), "@super", commonArmLayoutType);
LLVM_DEBUG({ llvm::dbgs() << "Switch arm common layout: " << commonArmLayoutType << "\n"; });

Value superLayout;
Type commonArmLayoutType;
if (armLayouts.size() == size) {
// All arms have layouts; save the common super.
commonArmLayoutType = Zhlt::getLeastCommonSuper(armLayouts, /*isLayout=*/1);
superLayout = muxContext.addLayoutMember(sw.getLoc(), "@super", commonArmLayoutType);
LLVM_DEBUG({ llvm::dbgs() << "Switch arm common layout: " << commonArmLayoutType << "\n"; });
}

SmallVector<Value> selectorValues;
for (size_t i = 0; i != size; ++i) {
auto indexOp = builder.create<Zll::ConstOp>(sw.getLoc(), i);
Expand Down Expand Up @@ -1344,6 +1352,16 @@ ComponentTypeAttr LoweringImpl::asTypeName(Value v) {
throw MalformedIRException();
}

ComponentTypeAttr LoweringImpl::asTypeNameInferringType(Value v, ValueRange args) {
auto typeName = asAbstractTypeName(v);
if (succeeded(componentManager->requireComponentInferringType(v.getLoc(), typeName, args))) {
return typeName;
}

emitError(v.getLoc()) << "component type is not instantiatable";
throw MalformedIRException();
}

ComponentTypeAttr LoweringImpl::asAbstractTypeName(Value v) {
Attribute attr = asConstant(v);
if (!attr) {
Expand Down
5 changes: 1 addition & 4 deletions zirgen/Dialect/ZHLT/IR/Attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,9 @@ def ComponentTypeAttr : ZhltAttr<"ComponentType", "component_type"> {
AttrBuilder<(ins "mlir::StringRef":$name, "llvm::ArrayRef<mlir::Attribute>":$typeArgs), [{
return $_get(context, mlir::StringAttr::get(context, name), typeArgs);
}]>,
AttrBuilderWithInferredContext<(ins "mlir::StringAttr":$name), [{
return $_get(name.getContext(), name, llvm::ArrayRef<mlir::Attribute>{});
}]>,
AttrBuilderWithInferredContext<(ins "mlir::StringAttr":$name, "llvm::ArrayRef<mlir::Attribute>":$typeArgs), [{
return $_get(name.getContext(), name, typeArgs);
}]>
}]>,
];
let assemblyFormat = [{ $name `<` $typeArgs `>` }];
}
Expand Down
115 changes: 104 additions & 11 deletions zirgen/Dialect/ZHLT/IR/BuiltinOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
#include "zirgen/Dialect/Zll/IR/IR.h"
#include "zirgen/Dialect/Zll/IR/Interpreter.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define GET_OP_CLASSES
#include "zirgen/Dialect/ZHLT/IR/BuiltinOps.cpp.inc"

#define DEBUG_TYPE "zhlt"

using namespace mlir;

namespace zirgen::Zhlt {
Expand All @@ -31,11 +34,11 @@ Value BuiltinLogOp::buildConstruct(ComponentManager* manager,
auto valType = getValType(getContext());
for (auto arg : constructArgs.drop_front(1)) {
if (!isCoercibleTo(arg.getType(), valType)) {
mlir::emitError(loc) << "Cannot coerce to " << valType << "\n";
}
args.push_back(coerceTo(arg, valType, builder));
mlir::emitError(loc) << "Cannot coerce " << arg << " to " << valType << "\n";
} else
args.push_back(coerceTo(arg, valType, builder));
}
builder.create<Zll::ExternOp>(loc, /*outTypes=*/TypeRange{}, args, "log", fmt);
builder.create<Zll::ExternOp>(loc, /*outTypes=*/TypeRange{}, args, "Log", fmt);
return builder.create<ZStruct::PackOp>(loc, manager->getValueType(name), ValueRange{});
}

Expand All @@ -49,6 +52,9 @@ BuiltinArrayOp::requireComponent(ComponentManager* manager, Location loc, Compon
if (!elemType) {
return mlir::emitError(loc) << "Unknown type of array element";
}
if (failed(manager->requireComponent(loc, elemType))) {
return mlir::emitError(loc) << "Can't instantiate element type";
}
auto elemCount = llvm::dyn_cast<PolynomialAttr>(name.getTypeArgs()[1]);
if (!elemCount) {
return mlir::emitError(loc) << "Unable to determine array length";
Expand All @@ -63,8 +69,17 @@ BuiltinArrayOp::requireComponent(ComponentManager* manager, Location loc, Compon
mlir::Type BuiltinArrayOp::getValueType(ComponentManager* manager, ComponentTypeAttr name) {
auto elemType = llvm::cast<ComponentTypeAttr>(name.getTypeArgs()[0]);
auto elemCount = llvm::cast<PolynomialAttr>(name.getTypeArgs()[1]);
return ZStruct::ArrayType::get(getContext(), manager->getValueType(elemType), elemCount[0]);
}

return ZStruct::LayoutArrayType::get(getContext(), manager->getValueType(elemType), elemCount[0]);
mlir::Type BuiltinArrayOp::getLayoutType(ComponentManager* manager, ComponentTypeAttr name) {
auto elemType = llvm::cast<ComponentTypeAttr>(name.getTypeArgs()[0]);
auto elemCount = llvm::cast<PolynomialAttr>(name.getTypeArgs()[1]);
auto elemLayout = manager->getLayoutType(elemType);
if (elemLayout)
return ZStruct::LayoutArrayType::get(getContext(), elemLayout, elemCount[0]);
else
return {};
}

mlir::Value BuiltinArrayOp::buildConstruct(ComponentManager* manager,
Expand All @@ -73,16 +88,94 @@ mlir::Value BuiltinArrayOp::buildConstruct(ComponentManager* manager,
ComponentTypeAttr name,
ValueRange constructArgs,
Value layout) {
if (constructArgs.size() != 1) {
if (constructArgs.size() != 1 && name.getName() == "Array") {
mlir::emitError(loc) << "Array constructor must have exactly one argument";
return {};
}
Type valueType = manager->getValueType(name);
if (!isCoercibleTo(constructArgs[0].getType(), valueType)) {
mlir::emitError(loc) << "Unable to convert " << constructArgs[0] << " to " << valueType;
return {};

ZStruct::ArrayType valueType = getCoercibleArrayType(manager->getValueType(name));
assert(valueType &&
"Shouldn't have been able to instantiate this component without a valid value type");

if (constructArgs.size() == 1) {
if (!isCoercibleTo(constructArgs[0].getType(), valueType)) {
mlir::emitError(loc) << "Unable to convert " << constructArgs[0] << " to " << valueType;
return {};
}
return coerceTo(constructArgs[0], valueType, builder);
}

assert((name.getName() == "ConcatArray") && "`Array` component cannot be used to concatinate");

SmallVector<Value> concatValues;
for (auto arg : constructArgs) {
auto argType = getCoercibleArrayType(arg.getType());
if (!argType) {
mlir::emitError(loc) << "Unable to use " << argType << " as an array";
return {};
}
auto arrayArg = coerceToArray(arg, builder);
if (!isCoercibleTo(argType.getElement(), valueType.getElement())) {
mlir::emitError(loc) << "Unable to convert " << argType << " elements to " << valueType
<< "\n";
return {};
}
for (size_t i = 0; i != argType.getSize(); ++i) {
Value index = builder.create<Zll::ConstOp>(loc, i);
auto elem = builder.create<ZStruct::SubscriptOp>(loc, arrayArg, index);

concatValues.push_back(coerceTo(elem, valueType.getElement(), builder));
}
}

return builder.create<ZStruct::ArrayOp>(loc, valueType, concatValues);
}

void BuiltinArrayOp::inferType(ComponentManager* manager,
ComponentTypeAttr& name,
ValueRange constructArgs) {
LLVM_DEBUG({ llvm::dbgs() << "Attempting to infer type of " << name << "\n"; });
if (!name.getTypeArgs().empty()) {
// Already specialized
return;
}

assert(name.getName() == "Array" || name.getName() == "ConcatArray");
if (name.getName() == "Array" && constructArgs.size() != 1) {
LLVM_DEBUG({ llvm::dbgs() << "Array does not have a single constructor arg\n"; });

// To concatinate more than one array together they must be invoked as `ConcatArray`
return;
}
if (constructArgs.empty()) {
LLVM_DEBUG({ llvm::dbgs() << "Array can not be empty."; });

return;
}
size_t numElem = 0;
SmallVector<Type> elemTypes;
for (auto constructArg : constructArgs) {
auto arrType = getCoercibleArrayType(constructArg.getType());
if (!arrType) {
LLVM_DEBUG(
{ llvm::dbgs() << "Non-array type supplied: " << constructArg.getType() << "\n"; });

// Non-array argument
return;
}

numElem += arrType.getSize();
elemTypes.push_back(arrType.getElement());
}

ComponentTypeAttr elemType = manager->getNameForType(getLeastCommonSuper(elemTypes));
if (!elemType) {
LLVM_DEBUG({ llvm::dbgs() << "No common element type\n"; });
return;
}
return coerceTo(constructArgs[0], valueType, builder);
name = ComponentTypeAttr::get(
getContext(), name.getName(), {elemType, PolynomialAttr::get(getContext(), numElem)});
LLVM_DEBUG({ llvm::dbgs() << "Inferred " << name << "\n"; });
}

} // namespace zirgen::Zhlt
7 changes: 1 addition & 6 deletions zirgen/Dialect/ZHLT/IR/BuiltinOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,5 @@ class BuiltinComponentOp<string mnemonic, list<string> methods = []>

def BuiltinLogOp : BuiltinComponentOp<"log", ["buildConstruct"]>;

// def BuiltinConcatArrayOp : BuiltinComponentOp<"concat_array",
// ["requireComponent", "buildConstruct", "getValueType", "getLayoutType", "reconstructFromLayout",
// "isConcrete", "specialize"]>;

def BuiltinArrayOp : BuiltinComponentOp<"array",
["requireComponent", "buildConstruct", "getValueType"//, "getLayoutType", "reconstructFromLayout"
]>;
["requireComponent", "buildConstruct", "getValueType", "getLayoutType", "inferType"]>;
7 changes: 7 additions & 0 deletions zirgen/Dialect/ZHLT/IR/ComponentOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ LogicalResult verifyRegion(Operation* origOp,

} // namespace

LogicalResult
ComponentOp::requireComponent(ComponentManager* manager, Location loc, ComponentTypeAttr name) {
assert(name.getMangledName() == getName());

return mlir::success();
}

mlir::Type ComponentOp::getValueType(ComponentManager* manager, ComponentTypeAttr name) {
assert(name.getMangledName() == getName());
return getOutType();
Expand Down
2 changes: 1 addition & 1 deletion zirgen/Dialect/ZHLT/IR/ComponentOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def ComponentOp : ZFuncOp</*aspect=*/"",
Symbol,
DeclareOpInterfaceMethods<ComponentOpInterface,
["getLayoutType", "getValueType", "buildConstruct", "reconstructFromLayout",
"getConstructParams"]>],
"getConstructParams", "requireComponent"]>],
/*mnemonic=*/"component"> {
let symPrefix = "";
let summary = "Component declaration";
Expand Down
Loading

0 comments on commit faa60c4

Please sign in to comment.