type_resolver: fix fn detection for comptime arg type (fix #23454) (#23456)

This commit is contained in:
Felipe Pena 2025-01-14 18:05:13 -03:00 committed by GitHub
parent 9ba294bc73
commit 6ab25623e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 239 additions and 142 deletions

View File

@ -861,6 +861,7 @@ pub mut:
pos token.Pos
should_be_ptr bool // fn expects a ptr for this arg
// tmp_name string // for autofree
ct_expr bool // true, when the expression is a comptime/generic expression
}
// function return statement

View File

@ -412,51 +412,9 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) {
}
}
}
if right is ast.ComptimeSelector {
if is_decl {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
}
} else if mut right is ast.InfixExpr {
right_ct_var := c.comptime.get_ct_type_var(right.left)
if right_ct_var != .no_comptime {
left.obj.ct_type_var = right_ct_var
}
} else if mut right is ast.IndexExpr
&& c.comptime.is_comptime(right) {
right_ct_var := c.comptime.get_ct_type_var(right.left)
if right_ct_var != .no_comptime {
left.obj.ct_type_var = right_ct_var
}
} else if mut right is ast.Ident && right.obj is ast.Var
&& right.or_expr.kind == .absent {
right_obj_var := right.obj as ast.Var
if right_obj_var.ct_type_var != .no_comptime {
ctyp := c.type_resolver.get_type(right)
if ctyp != ast.void_type {
left.obj.ct_type_var = right_obj_var.ct_type_var
left.obj.typ = ctyp
}
}
} else if right is ast.DumpExpr
&& right.expr is ast.ComptimeSelector {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
} else if mut right is ast.CallExpr {
if right.left_type != 0
&& c.table.type_kind(right.left_type) == .array
&& right.name == 'map' && right.args.len > 0
&& right.args[0].expr is ast.AsCast
&& right.args[0].expr.typ.has_flag(.generic) {
left.obj.ct_type_var = .generic_var
} else if left.obj.ct_type_var in [.generic_var, .no_comptime]
&& c.table.cur_fn != unsafe { nil }
&& c.table.cur_fn.generic_names.len != 0
&& !right.comptime_ret_val
&& c.type_resolver.is_generic_expr(right) {
// mark variable as generic var because its type changes according to fn return generic resolution type
left.obj.ct_type_var = .generic_var
}
// flag the variable as comptime/generic related on its declaration
if is_decl {
c.change_flags_if_comptime_expr(mut left, right)
}
}
ast.GlobalField {
@ -962,3 +920,46 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.',
c.error('assign statement left type number mismatch', node.pos)
}
}
// change_flags_if_comptime_expr changes the flags of the left variable if the right expression is comptime/generic expr
fn (mut c Checker) change_flags_if_comptime_expr(mut left ast.Ident, right ast.Expr) {
if mut left.obj is ast.Var {
if right is ast.ComptimeSelector {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
} else if right is ast.InfixExpr {
right_ct_var := c.comptime.get_ct_type_var(right.left)
if right_ct_var != .no_comptime {
left.obj.ct_type_var = right_ct_var
}
} else if right is ast.IndexExpr && c.comptime.is_comptime(right) {
right_ct_var := c.comptime.get_ct_type_var(right.left)
if right_ct_var != .no_comptime {
left.obj.ct_type_var = right_ct_var
}
} else if right is ast.Ident && right.obj is ast.Var && right.or_expr.kind == .absent {
right_obj_var := right.obj as ast.Var
if right_obj_var.ct_type_var != .no_comptime {
ctyp := c.type_resolver.get_type(right)
if ctyp != ast.void_type {
left.obj.ct_type_var = right_obj_var.ct_type_var
left.obj.typ = ctyp
}
}
} else if right is ast.DumpExpr && right.expr is ast.ComptimeSelector {
left.obj.ct_type_var = .field_var
left.obj.typ = c.comptime.comptime_for_field_type
} else if right is ast.CallExpr {
if right.left_type != 0 && c.table.type_kind(right.left_type) == .array
&& right.name == 'map' && right.args.len > 0 && right.args[0].expr is ast.AsCast
&& right.args[0].expr.typ.has_flag(.generic) {
left.obj.ct_type_var = .generic_var
} else if left.obj.ct_type_var in [.generic_var, .no_comptime]
&& c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len != 0
&& !right.comptime_ret_val && c.type_resolver.is_generic_expr(right) {
// mark variable as generic var because its type changes according to fn return generic resolution type
left.obj.ct_type_var = .generic_var
}
}
}
}

View File

@ -1421,6 +1421,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
}
// println / eprintln / panic can print anything
if node.args.len > 0 && fn_name in print_everything_fns {
node.args[0].ct_expr = c.comptime.is_comptime(node.args[0].expr)
c.builtin_args(mut node, fn_name, func)
if c.pref.skip_unused && !c.is_builtin_mod && c.mod != 'math.bits'
&& node.args[0].expr !is ast.StringLiteral {
@ -1442,6 +1443,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
if node.args.len == 1 && fn_name == 'error' {
mut arg := node.args[0]
node.args[0].typ = c.expr(mut arg.expr)
node.args[0].ct_expr = c.comptime.is_comptime(node.args[0].expr)
if node.args[0].typ == ast.error_type {
c.warn('`error(${arg})` can be shortened to just `${arg}`', node.pos)
}
@ -1456,6 +1458,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
if func.params.len == 0 {
continue
}
if !c.inside_recheck {
call_arg.ct_expr = c.comptime.is_comptime(call_arg.expr)
}
if !func.is_variadic && has_decompose {
c.error('cannot have parameter after array decompose', node.pos)
}
@ -2383,6 +2388,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr, mut continue_check &bool)
} else {
method.params[i + 1].typ
}
if !c.inside_recheck {
arg.ct_expr = c.comptime.is_comptime(arg.expr)
}
// If initialize a generic struct with short syntax,
// need to get the parameter information from the original generic method
if is_method_from_embed && arg.expr is ast.StructInit {

View File

@ -481,8 +481,22 @@ fn (mut g Gen) gen_array_map(node ast.CallExpr) {
}
return_type := if g.type_resolver.is_generic_expr(node.args[0].expr) {
ast.new_type(g.table.find_or_register_array(g.type_resolver.unwrap_generic_expr(node.args[0].expr,
node.return_type)))
mut ctyp := ast.void_type
if node.args[0].expr is ast.CallExpr && node.args[0].expr.return_type_generic != 0
&& node.args[0].expr.return_type_generic.has_flag(.generic) {
ctyp = g.resolve_return_type(node.args[0].expr)
if g.table.type_kind(node.args[0].expr.return_type_generic) in [.array, .array_fixed] {
ctyp = ast.new_type(g.table.find_or_register_array(ctyp))
}
}
if ctyp == ast.void_type {
ctyp = g.type_resolver.unwrap_generic_expr(node.args[0].expr, node.return_type)
}
if g.table.type_kind(g.unwrap_generic(ctyp)) !in [.array, .array_fixed] {
ast.new_type(g.table.find_or_register_array(ctyp))
} else {
ctyp
}
} else {
node.return_type
}
@ -498,7 +512,6 @@ fn (mut g Gen) gen_array_map(node ast.CallExpr) {
(ret_sym.info as ast.ArrayFixed).elem_type
}
mut ret_elem_styp := g.styp(ret_elem_type)
inp_elem_type := if left_is_array {
(inp_sym.info as ast.Array).elem_type
} else {

View File

@ -294,95 +294,99 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
}
}
if mut left.obj is ast.Var {
if val is ast.Ident && val.ct_expr {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val))
if ctyp != ast.void_type {
var_type = ctyp
val_type = var_type
if is_decl {
if val is ast.Ident && val.ct_expr {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val))
if ctyp != ast.void_type {
var_type = ctyp
val_type = var_type
gen_or = val.or_expr.kind != .absent
if gen_or {
var_type = val_type.clear_flag(.option)
}
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if val is ast.ComptimeSelector {
if val.typ_key != '' {
if is_decl {
var_type = g.type_resolver.get_ct_type_or_default(val.typ_key,
var_type)
val_type = var_type
left.obj.typ = var_type
} else {
val_type = g.type_resolver.get_ct_type_or_default(val.typ_key,
var_type)
}
g.assign_ct_type = var_type
}
} else if val is ast.ComptimeCall {
key_str := '${val.method_name}.return_type'
var_type = g.type_resolver.get_ct_type_or_default(key_str, var_type)
left.obj.typ = var_type
g.assign_ct_type = var_type
} else if val is ast.Ident && val.info is ast.IdentVar {
val_info := (val as ast.Ident).info
gen_or = val.or_expr.kind != .absent
if gen_or {
if val_info.is_option && gen_or {
var_type = val_type.clear_flag(.option)
}
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if val is ast.ComptimeSelector {
if val.typ_key != '' {
if is_decl {
var_type = g.type_resolver.get_ct_type_or_default(val.typ_key,
var_type)
val_type = var_type
left.obj.typ = var_type
} else {
val_type = g.type_resolver.get_ct_type_or_default(val.typ_key,
var_type)
}
g.assign_ct_type = var_type
}
} else if val is ast.ComptimeCall {
key_str := '${val.method_name}.return_type'
var_type = g.type_resolver.get_ct_type_or_default(key_str, var_type)
left.obj.typ = var_type
g.assign_ct_type = var_type
} else if is_decl && val is ast.Ident && val.info is ast.IdentVar {
val_info := (val as ast.Ident).info
gen_or = val.or_expr.kind != .absent
if val_info.is_option && gen_or {
var_type = val_type.clear_flag(.option)
left.obj.typ = var_type
}
} else if val is ast.DumpExpr {
if val.expr is ast.ComptimeSelector {
if val.expr.typ_key != '' {
var_type = g.type_resolver.get_ct_type_or_default(val.expr.typ_key,
var_type)
val_type = var_type
left.obj.typ = var_type
}
g.assign_ct_type = var_type
}
} else if val is ast.IndexExpr && (val.left is ast.Ident && val.left.ct_expr) {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val))
if ctyp != ast.void_type {
var_type = ctyp
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr {
if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) {
fn_ret_type := g.resolve_return_type(val)
if fn_ret_type != ast.void_type {
} else if val is ast.DumpExpr {
if val.expr is ast.ComptimeSelector {
if val.expr.typ_key != '' {
var_type = g.type_resolver.get_ct_type_or_default(val.expr.typ_key,
var_type)
val_type = var_type
left.obj.typ = var_type
}
g.assign_ct_type = var_type
}
} else if val is ast.IndexExpr && (val.left is ast.Ident && val.left.ct_expr) {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val))
if ctyp != ast.void_type {
var_type = ctyp
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr {
if val.return_type_generic != 0
&& val.return_type_generic.has_flag(.generic) {
fn_ret_type := g.resolve_return_type(val)
if fn_ret_type != ast.void_type {
var_type = fn_ret_type
val_type = var_type
left.obj.typ = var_type
}
} else if val.is_static_method && val.left_type.has_flag(.generic) {
fn_ret_type := g.resolve_return_type(val)
var_type = fn_ret_type
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
} else if val.left_type != 0 && g.table.type_kind(val.left_type) == .array
&& val.name == 'map' && val.args.len > 0
&& val.args[0].expr is ast.AsCast
&& val.args[0].expr.typ.has_flag(.generic) {
var_type = g.table.find_or_register_array(g.unwrap_generic((val.args[0].expr as ast.AsCast).typ))
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if val.is_static_method && val.left_type.has_flag(.generic) {
fn_ret_type := g.resolve_return_type(val)
var_type = fn_ret_type
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
} else if val.left_type != 0 && g.table.type_kind(val.left_type) == .array
&& val.name == 'map' && val.args.len > 0 && val.args[0].expr is ast.AsCast
&& val.args[0].expr.typ.has_flag(.generic) {
var_type = g.table.find_or_register_array(g.unwrap_generic((val.args[0].expr as ast.AsCast).typ))
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
} else if val is ast.InfixExpr && val.op in [.plus, .minus, .mul, .div, .mod]
&& val.left_ct_expr {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val.left))
if ctyp != ast.void_type {
ct_type_var := g.comptime.get_ct_type_var(val.left)
if ct_type_var in [.key_var, .value_var] {
g.type_resolver.update_ct_type(left.name, g.unwrap_generic(ctyp))
} else if val is ast.InfixExpr && val.op in [.plus, .minus, .mul, .div, .mod]
&& val.left_ct_expr {
ctyp := g.unwrap_generic(g.type_resolver.get_type(val.left))
if ctyp != ast.void_type {
ct_type_var := g.comptime.get_ct_type_var(val.left)
if ct_type_var in [.key_var, .value_var] {
g.type_resolver.update_ct_type(left.name, g.unwrap_generic(ctyp))
}
var_type = ctyp
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
var_type = ctyp
val_type = var_type
left.obj.typ = var_type
g.assign_ct_type = var_type
}
}
is_auto_heap = left.obj.is_auto_heap

View File

@ -2028,8 +2028,7 @@ fn (mut g Gen) fn_call(node ast.CallExpr) {
// Handle `print(x)`
mut print_auto_str := false
if is_print && (node.args[0].typ != ast.string_type
|| g.comptime.comptime_for_method != unsafe { nil }
|| g.comptime.is_comptime(node.args[0].expr)) {
|| g.comptime.comptime_for_method != unsafe { nil } || node.args[0].ct_expr) {
g.inside_interface_deref = true
defer {
g.inside_interface_deref = false
@ -2577,7 +2576,7 @@ fn (mut g Gen) keep_alive_call_postgen(node ast.CallExpr, tmp_cnt_save int) {
@[inline]
fn (mut g Gen) ref_or_deref_arg(arg ast.CallArg, expected_type ast.Type, lang ast.Language, is_smartcast bool) {
arg_typ := if arg.expr is ast.ComptimeSelector {
arg_typ := if arg.ct_expr {
g.unwrap_generic(g.type_resolver.get_type(arg.expr))
} else {
g.unwrap_generic(arg.typ)

View File

@ -0,0 +1,38 @@
pub type DesiredCapabilities = FireFox | Edge
struct FireFox {
browser_name string = 'firefox'
accept_insecure_certs bool = true
moz_debugger_address bool = true
}
struct Edge {
browser_name string = 'MicrosoftEdge'
}
fn struct_values[T](s T) map[string]string {
mut res := map[string]string{}
$if T is $struct {
$for field in T.fields {
res[field.name] = s.$(field.name).str()
}
}
return res
}
fn useit(dc DesiredCapabilities) string {
$for v in dc.variants {
if dc is v {
$if v is $struct {
result := struct_values(dc)
return result.str()
}
}
}
return ''
}
fn test_main() {
assert useit(Edge{}) == "{'browser_name': 'MicrosoftEdge'}"
assert useit(FireFox{}) == "{'browser_name': 'firefox', 'accept_insecure_certs': 'true', 'moz_debugger_address': 'true'}"
}

View File

@ -0,0 +1,18 @@
import math
fn t[T](a1 []int, a2 []int) T {
mut a := 0 * a1[0] - a2[0]
a = math.max(a, 10)
mut t := T{}
for i in a .. 20 {
t += i
for j in a .. 20 {
t += j
}
}
return t
}
fn test_main() {
assert t[int]([1, 2, 3], [4, 5, 6]) == 1595
}

View File

@ -47,7 +47,7 @@ pub fn (t &ResolverInfo) is_comptime(node ast.Expr) bool {
}
}
ast.SelectorExpr {
return node.expr is ast.Ident && node.expr.ct_expr
return node.expr is ast.Ident && node.expr.ct_expr && node.field_name != 'len'
}
ast.InfixExpr {
return node.left_ct_expr || node.right_ct_expr
@ -55,6 +55,9 @@ pub fn (t &ResolverInfo) is_comptime(node ast.Expr) bool {
ast.ParExpr {
return t.is_comptime(node.expr)
}
ast.ComptimeSelector {
return true
}
else {
false
}

View File

@ -67,6 +67,9 @@ pub fn (t &TypeResolver) is_generic_expr(node ast.Expr) bool {
if node.is_static_method && node.left_type.has_flag(.generic) {
return true
}
if node.return_type_generic != 0 && node.return_type_generic.has_flag(.generic) {
return true
}
// fn[T]() or generic_var.fn[T]()
node.concrete_types.any(it.has_flag(.generic))
}
@ -119,6 +122,7 @@ pub fn (mut t TypeResolver) resolve_args(cur_fn &ast.FnDecl, func &ast.Fn, mut n
mut comptime_args := map[int]ast.Type{}
has_dynamic_vars := (cur_fn != unsafe { nil } && cur_fn.generic_names.len > 0)
|| t.info.comptime_for_field_var != ''
|| func.generic_names.len != node_.raw_concrete_types.len
if !has_dynamic_vars {
return comptime_args
}

View File

@ -101,15 +101,10 @@ pub fn (mut t TypeResolver) get_type_or_default(node ast.Expr, default_typ ast.T
}
ast.SelectorExpr {
if node.expr is ast.Ident && node.expr.ct_expr {
struct_typ := t.resolver.unwrap_generic(t.get_type(node.expr))
struct_sym := t.table.final_sym(struct_typ)
// Struct[T] can have field with generic type
if struct_sym.info is ast.Struct && struct_sym.info.generic_types.len > 0 {
if field := t.table.find_field(struct_sym, node.field_name) {
return field.typ
}
}
ctyp := t.get_type(node)
return if ctyp != ast.void_type { ctyp } else { default_typ }
}
return default_typ
}
ast.ParExpr {
return t.get_type_or_default(node.expr, default_typ)
@ -189,8 +184,21 @@ pub fn (mut t TypeResolver) get_type(node ast.Expr) ast.Type {
} else if node is ast.ComptimeSelector {
// val.$(field.name)
return t.get_comptime_selector_type(node, ast.void_type)
} else if node is ast.SelectorExpr && t.info.is_comptime_selector_type(node) {
return t.get_type_from_comptime_var(node.expr as ast.Ident)
} else if node is ast.SelectorExpr {
if t.info.is_comptime_selector_type(node) {
return t.get_type_from_comptime_var(node.expr as ast.Ident)
}
if node.expr is ast.Ident && node.expr.ct_expr {
struct_typ := t.resolver.unwrap_generic(t.get_type(node.expr))
struct_sym := t.table.final_sym(struct_typ)
// Struct[T] can have field with generic type
if struct_sym.info is ast.Struct && struct_sym.info.generic_types.len > 0 {
if field := t.table.find_field(struct_sym, node.field_name) {
return field.typ
}
}
}
return node.typ
} else if node is ast.ComptimeCall {
method_name := t.info.comptime_for_method.name
left_sym := t.table.sym(t.resolver.unwrap_generic(node.left_type))