Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for decl literals #2113

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 101 additions & 60 deletions src/analysis.zig
Original file line number Diff line number Diff line change
Expand Up @@ -811,28 +811,37 @@ fn findReturnStatement(tree: Ast, body: Ast.Node.Index) ?Ast.Node.Index {
return findReturnStatementInternal(tree, body, &already_found);
}

pub fn resolveReturnType(analyser: *Analyser, fn_decl: Ast.full.FnProto, handle: *DocumentStore.Handle, fn_body: ?Ast.Node.Index) error{OutOfMemory}!?Type {
const tree = handle.tree;
if (isTypeFunction(tree, fn_decl) and fn_body != null) {
pub fn resolveReturnType(analyser: *Analyser, func_type_param: Type) error{OutOfMemory}!?Type {
const func_type = try analyser.resolveFuncProtoOfCallable(func_type_param) orelse return null;
const func_node_handle = func_type.data.other; // this assumes that function types can only be Ast nodes
const tree = func_node_handle.handle.tree;
const func_node = func_node_handle.node;

var buf: [1]Ast.Node.Index = undefined;
const fn_proto = tree.fullFnProto(&buf, func_node).?;
const has_body = tree.nodes.items(.tag)[func_node] == .fn_decl;

if (isTypeFunction(tree, fn_proto) and has_body) {
const body = tree.nodes.items(.data)[func_node].rhs;
// If this is a type function and it only contains a single return statement that returns
// a container declaration, we will return that declaration.
const ret = findReturnStatement(tree, fn_body.?) orelse return null;
const ret = findReturnStatement(tree, body) orelse return null;
const data = tree.nodes.items(.data)[ret];
if (data.lhs != 0) {
return try analyser.resolveTypeOfNodeInternal(.{ .node = data.lhs, .handle = handle });
return try analyser.resolveTypeOfNodeInternal(.{ .node = data.lhs, .handle = func_node_handle.handle });
}

return null;
}

if (fn_decl.ast.return_type == 0) return null;
const return_type = fn_decl.ast.return_type;
const ret: NodeWithHandle = .{ .node = return_type, .handle = handle };
if (fn_proto.ast.return_type == 0) return null;
const return_type = fn_proto.ast.return_type;
const ret: NodeWithHandle = .{ .node = return_type, .handle = func_node_handle.handle };
const child_type = (try analyser.resolveTypeOfNodeInternal(ret)) orelse
return null;
if (!child_type.is_type_val) return null;

if (ast.hasInferredError(tree, fn_decl)) {
if (ast.hasInferredError(tree, fn_proto)) {
const child_type_ptr = try analyser.arena.allocator().create(Type);
child_type_ptr.* = child_type;
return Type{
Expand Down Expand Up @@ -1542,11 +1551,7 @@ fn resolveTypeOfNodeUncached(analyser: *Analyser, node_handle: NodeWithHandle) e
}, argument_type);
}

const has_body = func_tree.nodes.items(.tag)[func_node] == .fn_decl;
const body = func_tree.nodes.items(.data)[func_node].rhs;
if (try analyser.resolveReturnType(fn_proto, func_handle, if (has_body) body else null)) |ret| {
return ret;
}
return try analyser.resolveReturnType(func_ty);
},
.container_field,
.container_field_init,
Expand Down Expand Up @@ -2646,18 +2651,27 @@ pub const Type = struct {
}
}

fn isContainerKind(self: Type, container_kind_tok: std.zig.Token.Tag) bool {
pub fn isContainerType(self: Type) bool {
return self.data == .container;
}

fn getContainerKind(self: Type) ?std.zig.Token.Tag {
const scope_handle = switch (self.data) {
.container => |s| s,
else => return false,
else => return null,
};
if (scope_handle.scope == .root) return .keyword_struct;

const node = scope_handle.toNode();

const tree = scope_handle.handle.tree;
const main_tokens = tree.nodes.items(.main_token);
const tags = tree.tokens.items(.tag);
return tags[main_tokens[node]] == container_kind_tok;
return tags[main_tokens[node]];
}

fn isContainerKind(self: Type, container_kind_tok: std.zig.Token.Tag) bool {
return self.getContainerKind() == container_kind_tok;
}

pub fn isStructType(self: Type) bool {
Expand Down Expand Up @@ -2714,6 +2728,17 @@ pub const Type = struct {
}
}

pub fn resolveDeclLiteralResultType(ty: Type) Type {
var result_type = ty;
while (true) {
result_type = switch (result_type.data) {
.optional => |child_ty| child_ty.*,
.error_union => |info| info.payload.*,
else => return result_type,
};
}
}

pub fn isTypeFunc(self: Type) bool {
var buf: [1]Ast.Node.Index = undefined;
return switch (self.data) {
Expand Down Expand Up @@ -3231,22 +3256,9 @@ pub fn getFieldAccessType(

// Can't call a function type, we need a function type instance.
if (current_type.?.is_type_val) return null;
// this assumes that function types can only be Ast nodes
const current_type_node_handle = ty.data.other;
const current_type_node = current_type_node_handle.node;
const current_type_handle = current_type_node_handle.handle;

const cur_tree = current_type_handle.tree;
var buf: [1]Ast.Node.Index = undefined;
const func = cur_tree.fullFnProto(&buf, current_type_node).?;
// Check if the function has a body and if so, pass it
// so the type can be resolved if it's a generic function returning
// an anonymous struct
const has_body = cur_tree.nodes.items(.tag)[current_type_node] == .fn_decl;
const body = cur_tree.nodes.items(.data)[current_type_node].rhs;

// TODO Actually bind params here when calling functions instead of just skipping args.
current_type = try analyser.resolveReturnType(func, current_type_handle, if (has_body) body else null) orelse return null;
current_type = try analyser.resolveReturnType(ty) orelse return null;

if (do_unwrap_error_payload) {
if (try analyser.resolveUnwrapErrorUnionType(current_type.?, .payload)) |unwrapped| current_type = unwrapped;
Expand Down Expand Up @@ -4470,7 +4482,7 @@ pub fn lookupSymbolFieldInit(
analyser: *Analyser,
handle: *DocumentStore.Handle,
field_name: []const u8,
nodes: []Ast.Node.Index,
nodes: []const Ast.Node.Index,
) error{OutOfMemory}!?DeclWithHandle {
if (nodes.len == 0) return null;

Expand All @@ -4480,29 +4492,53 @@ pub fn lookupSymbolFieldInit(
nodes[1..],
)) orelse return null;

const is_struct_init = switch (handle.tree.nodes.items(.tag)[nodes[0]]) {
.struct_init_one,
.struct_init_one_comma,
.struct_init_dot_two,
.struct_init_dot_two_comma,
.struct_init_dot,
.struct_init_dot_comma,
.struct_init,
.struct_init_comma,
=> true,
else => false,
};

if (try analyser.resolveUnwrapErrorUnionType(container_type, .payload)) |unwrapped|
container_type = unwrapped;

if (try analyser.resolveOptionalUnwrap(container_type)) |unwrapped|
container_type = unwrapped;

const container_scope_handle = switch (container_type.data) {
const container_scope = switch (container_type.data) {
.container => |s| s,
else => return null,
};
if (is_struct_init) {
return try analyser.lookupSymbolContainer(container_scope, field_name, .field);
}

return analyser.lookupSymbolContainer(
container_scope_handle,
field_name,
.field,
);
// Assume we are doing decl literals
switch (container_type.getContainerKind() orelse return null) {
.keyword_struct => {
const decl = try analyser.lookupSymbolContainer(container_scope, field_name, .other) orelse return null;
var resolved_type = try decl.resolveType(analyser) orelse return null;
resolved_type = try analyser.resolveReturnType(resolved_type) orelse resolved_type;
resolved_type = resolved_type.resolveDeclLiteralResultType();
if (resolved_type.eql(container_type) or resolved_type.eql(container_type.typeOf(analyser))) return decl;
return null;
},
.keyword_enum, .keyword_union => return try analyser.lookupSymbolContainer(container_scope, field_name, .field),
else => return null,
}
}

pub fn resolveExpressionType(
analyser: *Analyser,
handle: *DocumentStore.Handle,
node: Ast.Node.Index,
ancestors: []Ast.Node.Index,
ancestors: []const Ast.Node.Index,
) error{OutOfMemory}!?Type {
return (try analyser.resolveExpressionTypeFromAncestors(
handle,
Expand All @@ -4518,7 +4554,7 @@ pub fn resolveExpressionTypeFromAncestors(
analyser: *Analyser,
handle: *DocumentStore.Handle,
node: Ast.Node.Index,
ancestors: []Ast.Node.Index,
ancestors: []const Ast.Node.Index,
) error{OutOfMemory}!?Type {
if (ancestors.len == 0) return null;

Expand Down Expand Up @@ -4682,34 +4718,32 @@ pub fn resolveExpressionTypeFromAncestors(
=> {
var buffer: [1]Ast.Node.Index = undefined;
const call = tree.fullCall(&buffer, ancestors[0]).?;

if (call.ast.fn_expr == node) {
return try analyser.resolveExpressionType(
handle,
ancestors[0],
ancestors[1..],
);
}

const arg_index = std.mem.indexOfScalar(Ast.Node.Index, call.ast.params, node) orelse return null;

const ty = try analyser.resolveTypeOfNode(.{ .node = call.ast.fn_expr, .handle = handle }) orelse return null;
const fn_type = try analyser.resolveFuncProtoOfCallable(ty) orelse return null;
if (fn_type.is_type_val) return null;

const fn_node_handle = fn_type.data.other; // this assumes that function types can only be Ast nodes
const fn_node = fn_node_handle.node;
const fn_handle = fn_node_handle.handle;
const fn_tree = fn_handle.tree;

var fn_buf: [1]Ast.Node.Index = undefined;
const fn_proto = fn_tree.fullFnProto(&fn_buf, fn_node).?;

var param_iter = fn_proto.iterate(&fn_tree);
if (try analyser.isInstanceCall(handle, call, fn_type)) {
_ = ast.nextFnParam(&param_iter);
}
const param_decl: Declaration.Param = .{
.param_index = @truncate(arg_index + @intFromBool(try analyser.hasSelfParam(fn_type))),
.func = fn_node_handle.node,
};
const param = param_decl.get(fn_node_handle.handle.tree) orelse return null;

var param_index: usize = 0;
while (ast.nextFnParam(&param_iter)) |param| : (param_index += 1) {
if (param_index == arg_index) {
return try analyser.resolveTypeOfNode(.{
.node = param.type_expr,
.handle = fn_handle,
});
}
}
return try analyser.resolveTypeOfNode(.{
.node = param.type_expr,
.handle = fn_node_handle.handle,
});
},
.assign => {
if (node == datas[ancestors[0]].rhs) {
Expand Down Expand Up @@ -4785,6 +4819,13 @@ pub fn resolveExpressionTypeFromAncestors(
ancestors[index + 1 ..],
);
},
.@"try" => {
return try analyser.resolveExpressionType(
handle,
ancestors[0],
ancestors[1..],
);
},

else => {}, // TODO: Implement more expressions; better safe than sorry
}
Expand Down
Loading
Loading