From 6ab25623e3006e10a4196d5157833c4cd9dab509 Mon Sep 17 00:00:00 2001 From: Felipe Pena Date: Tue, 14 Jan 2025 18:05:13 -0300 Subject: [PATCH] type_resolver: fix fn detection for comptime arg type (fix #23454) (#23456) --- vlib/v/ast/ast.v | 1 + vlib/v/checker/assign.v | 91 +++++----- vlib/v/checker/fn.v | 8 + vlib/v/gen/c/array.v | 19 +- vlib/v/gen/c/assign.v | 164 +++++++++--------- vlib/v/gen/c/fn.v | 5 +- .../comptime_generic_comptime_variant_test.v | 38 ++++ vlib/v/tests/generics/generic_var_loop_test.v | 18 ++ vlib/v/type_resolver/comptime_resolver.v | 5 +- vlib/v/type_resolver/generic_resolver.v | 4 + vlib/v/type_resolver/type_resolver.v | 28 +-- 11 files changed, 239 insertions(+), 142 deletions(-) create mode 100644 vlib/v/tests/comptime_generic_comptime_variant_test.v create mode 100644 vlib/v/tests/generics/generic_var_loop_test.v diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index bc348cda9f..355234d4f0 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -861,6 +861,7 @@ pub mut: pos token.Pos should_be_ptr bool // fn expects a ptr for this arg // tmp_name string // for autofree + ct_expr bool // true, when the expression is a comptime/generic expression } // function return statement diff --git a/vlib/v/checker/assign.v b/vlib/v/checker/assign.v index 2c451d5330..8d2c592e02 100644 --- a/vlib/v/checker/assign.v +++ b/vlib/v/checker/assign.v @@ -412,51 +412,9 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) { } } } - if right is ast.ComptimeSelector { - if is_decl { - left.obj.ct_type_var = .field_var - left.obj.typ = c.comptime.comptime_for_field_type - } - } else if mut right is ast.InfixExpr { - right_ct_var := c.comptime.get_ct_type_var(right.left) - if right_ct_var != .no_comptime { - left.obj.ct_type_var = right_ct_var - } - } else if mut right is ast.IndexExpr - && c.comptime.is_comptime(right) { - right_ct_var := c.comptime.get_ct_type_var(right.left) - if right_ct_var != .no_comptime { - left.obj.ct_type_var = right_ct_var - } - } else if mut right is ast.Ident && right.obj is ast.Var - && right.or_expr.kind == .absent { - right_obj_var := right.obj as ast.Var - if right_obj_var.ct_type_var != .no_comptime { - ctyp := c.type_resolver.get_type(right) - if ctyp != ast.void_type { - left.obj.ct_type_var = right_obj_var.ct_type_var - left.obj.typ = ctyp - } - } - } else if right is ast.DumpExpr - && right.expr is ast.ComptimeSelector { - left.obj.ct_type_var = .field_var - left.obj.typ = c.comptime.comptime_for_field_type - } else if mut right is ast.CallExpr { - if right.left_type != 0 - && c.table.type_kind(right.left_type) == .array - && right.name == 'map' && right.args.len > 0 - && right.args[0].expr is ast.AsCast - && right.args[0].expr.typ.has_flag(.generic) { - left.obj.ct_type_var = .generic_var - } else if left.obj.ct_type_var in [.generic_var, .no_comptime] - && c.table.cur_fn != unsafe { nil } - && c.table.cur_fn.generic_names.len != 0 - && !right.comptime_ret_val - && c.type_resolver.is_generic_expr(right) { - // mark variable as generic var because its type changes according to fn return generic resolution type - left.obj.ct_type_var = .generic_var - } + // flag the variable as comptime/generic related on its declaration + if is_decl { + c.change_flags_if_comptime_expr(mut left, right) } } ast.GlobalField { @@ -962,3 +920,46 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.', c.error('assign statement left type number mismatch', node.pos) } } + +// change_flags_if_comptime_expr changes the flags of the left variable if the right expression is comptime/generic expr +fn (mut c Checker) change_flags_if_comptime_expr(mut left ast.Ident, right ast.Expr) { + if mut left.obj is ast.Var { + if right is ast.ComptimeSelector { + left.obj.ct_type_var = .field_var + left.obj.typ = c.comptime.comptime_for_field_type + } else if right is ast.InfixExpr { + right_ct_var := c.comptime.get_ct_type_var(right.left) + if right_ct_var != .no_comptime { + left.obj.ct_type_var = right_ct_var + } + } else if right is ast.IndexExpr && c.comptime.is_comptime(right) { + right_ct_var := c.comptime.get_ct_type_var(right.left) + if right_ct_var != .no_comptime { + left.obj.ct_type_var = right_ct_var + } + } else if right is ast.Ident && right.obj is ast.Var && right.or_expr.kind == .absent { + right_obj_var := right.obj as ast.Var + if right_obj_var.ct_type_var != .no_comptime { + ctyp := c.type_resolver.get_type(right) + if ctyp != ast.void_type { + left.obj.ct_type_var = right_obj_var.ct_type_var + left.obj.typ = ctyp + } + } + } else if right is ast.DumpExpr && right.expr is ast.ComptimeSelector { + left.obj.ct_type_var = .field_var + left.obj.typ = c.comptime.comptime_for_field_type + } else if right is ast.CallExpr { + if right.left_type != 0 && c.table.type_kind(right.left_type) == .array + && right.name == 'map' && right.args.len > 0 && right.args[0].expr is ast.AsCast + && right.args[0].expr.typ.has_flag(.generic) { + left.obj.ct_type_var = .generic_var + } else if left.obj.ct_type_var in [.generic_var, .no_comptime] + && c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len != 0 + && !right.comptime_ret_val && c.type_resolver.is_generic_expr(right) { + // mark variable as generic var because its type changes according to fn return generic resolution type + left.obj.ct_type_var = .generic_var + } + } + } +} diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 47f4dd1374..12cd03cd5e 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1421,6 +1421,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. } // println / eprintln / panic can print anything if node.args.len > 0 && fn_name in print_everything_fns { + node.args[0].ct_expr = c.comptime.is_comptime(node.args[0].expr) c.builtin_args(mut node, fn_name, func) if c.pref.skip_unused && !c.is_builtin_mod && c.mod != 'math.bits' && node.args[0].expr !is ast.StringLiteral { @@ -1442,6 +1443,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. if node.args.len == 1 && fn_name == 'error' { mut arg := node.args[0] node.args[0].typ = c.expr(mut arg.expr) + node.args[0].ct_expr = c.comptime.is_comptime(node.args[0].expr) if node.args[0].typ == ast.error_type { c.warn('`error(${arg})` can be shortened to just `${arg}`', node.pos) } @@ -1456,6 +1458,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. if func.params.len == 0 { continue } + if !c.inside_recheck { + call_arg.ct_expr = c.comptime.is_comptime(call_arg.expr) + } if !func.is_variadic && has_decompose { c.error('cannot have parameter after array decompose', node.pos) } @@ -2383,6 +2388,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr, mut continue_check &bool) } else { method.params[i + 1].typ } + if !c.inside_recheck { + arg.ct_expr = c.comptime.is_comptime(arg.expr) + } // If initialize a generic struct with short syntax, // need to get the parameter information from the original generic method if is_method_from_embed && arg.expr is ast.StructInit { diff --git a/vlib/v/gen/c/array.v b/vlib/v/gen/c/array.v index 0e69448599..22e81341b8 100644 --- a/vlib/v/gen/c/array.v +++ b/vlib/v/gen/c/array.v @@ -481,8 +481,22 @@ fn (mut g Gen) gen_array_map(node ast.CallExpr) { } return_type := if g.type_resolver.is_generic_expr(node.args[0].expr) { - ast.new_type(g.table.find_or_register_array(g.type_resolver.unwrap_generic_expr(node.args[0].expr, - node.return_type))) + mut ctyp := ast.void_type + if node.args[0].expr is ast.CallExpr && node.args[0].expr.return_type_generic != 0 + && node.args[0].expr.return_type_generic.has_flag(.generic) { + ctyp = g.resolve_return_type(node.args[0].expr) + if g.table.type_kind(node.args[0].expr.return_type_generic) in [.array, .array_fixed] { + ctyp = ast.new_type(g.table.find_or_register_array(ctyp)) + } + } + if ctyp == ast.void_type { + ctyp = g.type_resolver.unwrap_generic_expr(node.args[0].expr, node.return_type) + } + if g.table.type_kind(g.unwrap_generic(ctyp)) !in [.array, .array_fixed] { + ast.new_type(g.table.find_or_register_array(ctyp)) + } else { + ctyp + } } else { node.return_type } @@ -498,7 +512,6 @@ fn (mut g Gen) gen_array_map(node ast.CallExpr) { (ret_sym.info as ast.ArrayFixed).elem_type } mut ret_elem_styp := g.styp(ret_elem_type) - inp_elem_type := if left_is_array { (inp_sym.info as ast.Array).elem_type } else { diff --git a/vlib/v/gen/c/assign.v b/vlib/v/gen/c/assign.v index e28dbcd6c7..a9435a8e83 100644 --- a/vlib/v/gen/c/assign.v +++ b/vlib/v/gen/c/assign.v @@ -294,95 +294,99 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) { } } if mut left.obj is ast.Var { - if val is ast.Ident && val.ct_expr { - ctyp := g.unwrap_generic(g.type_resolver.get_type(val)) - if ctyp != ast.void_type { - var_type = ctyp - val_type = var_type + if is_decl { + if val is ast.Ident && val.ct_expr { + ctyp := g.unwrap_generic(g.type_resolver.get_type(val)) + if ctyp != ast.void_type { + var_type = ctyp + val_type = var_type + gen_or = val.or_expr.kind != .absent + if gen_or { + var_type = val_type.clear_flag(.option) + } + left.obj.typ = var_type + g.assign_ct_type = var_type + } + } else if val is ast.ComptimeSelector { + if val.typ_key != '' { + if is_decl { + var_type = g.type_resolver.get_ct_type_or_default(val.typ_key, + var_type) + val_type = var_type + left.obj.typ = var_type + } else { + val_type = g.type_resolver.get_ct_type_or_default(val.typ_key, + var_type) + } + g.assign_ct_type = var_type + } + } else if val is ast.ComptimeCall { + key_str := '${val.method_name}.return_type' + var_type = g.type_resolver.get_ct_type_or_default(key_str, var_type) + left.obj.typ = var_type + g.assign_ct_type = var_type + } else if val is ast.Ident && val.info is ast.IdentVar { + val_info := (val as ast.Ident).info gen_or = val.or_expr.kind != .absent - if gen_or { + if val_info.is_option && gen_or { var_type = val_type.clear_flag(.option) - } - left.obj.typ = var_type - g.assign_ct_type = var_type - } - } else if val is ast.ComptimeSelector { - if val.typ_key != '' { - if is_decl { - var_type = g.type_resolver.get_ct_type_or_default(val.typ_key, - var_type) - val_type = var_type - left.obj.typ = var_type - } else { - val_type = g.type_resolver.get_ct_type_or_default(val.typ_key, - var_type) - } - g.assign_ct_type = var_type - } - } else if val is ast.ComptimeCall { - key_str := '${val.method_name}.return_type' - var_type = g.type_resolver.get_ct_type_or_default(key_str, var_type) - left.obj.typ = var_type - g.assign_ct_type = var_type - } else if is_decl && val is ast.Ident && val.info is ast.IdentVar { - val_info := (val as ast.Ident).info - gen_or = val.or_expr.kind != .absent - if val_info.is_option && gen_or { - var_type = val_type.clear_flag(.option) - left.obj.typ = var_type - } - } else if val is ast.DumpExpr { - if val.expr is ast.ComptimeSelector { - if val.expr.typ_key != '' { - var_type = g.type_resolver.get_ct_type_or_default(val.expr.typ_key, - var_type) - val_type = var_type left.obj.typ = var_type } - g.assign_ct_type = var_type - } - } else if val is ast.IndexExpr && (val.left is ast.Ident && val.left.ct_expr) { - ctyp := g.unwrap_generic(g.type_resolver.get_type(val)) - if ctyp != ast.void_type { - var_type = ctyp - val_type = var_type - left.obj.typ = var_type - g.assign_ct_type = var_type - } - } else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr { - if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) { - fn_ret_type := g.resolve_return_type(val) - if fn_ret_type != ast.void_type { + } else if val is ast.DumpExpr { + if val.expr is ast.ComptimeSelector { + if val.expr.typ_key != '' { + var_type = g.type_resolver.get_ct_type_or_default(val.expr.typ_key, + var_type) + val_type = var_type + left.obj.typ = var_type + } + g.assign_ct_type = var_type + } + } else if val is ast.IndexExpr && (val.left is ast.Ident && val.left.ct_expr) { + ctyp := g.unwrap_generic(g.type_resolver.get_type(val)) + if ctyp != ast.void_type { + var_type = ctyp + val_type = var_type + left.obj.typ = var_type + g.assign_ct_type = var_type + } + } else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr { + if val.return_type_generic != 0 + && val.return_type_generic.has_flag(.generic) { + fn_ret_type := g.resolve_return_type(val) + if fn_ret_type != ast.void_type { + var_type = fn_ret_type + val_type = var_type + left.obj.typ = var_type + } + } else if val.is_static_method && val.left_type.has_flag(.generic) { + fn_ret_type := g.resolve_return_type(val) var_type = fn_ret_type val_type = var_type left.obj.typ = var_type + g.assign_ct_type = var_type + } else if val.left_type != 0 && g.table.type_kind(val.left_type) == .array + && val.name == 'map' && val.args.len > 0 + && val.args[0].expr is ast.AsCast + && val.args[0].expr.typ.has_flag(.generic) { + var_type = g.table.find_or_register_array(g.unwrap_generic((val.args[0].expr as ast.AsCast).typ)) + val_type = var_type + left.obj.typ = var_type + g.assign_ct_type = var_type } - } else if val.is_static_method && val.left_type.has_flag(.generic) { - fn_ret_type := g.resolve_return_type(val) - var_type = fn_ret_type - val_type = var_type - left.obj.typ = var_type - g.assign_ct_type = var_type - } else if val.left_type != 0 && g.table.type_kind(val.left_type) == .array - && val.name == 'map' && val.args.len > 0 && val.args[0].expr is ast.AsCast - && val.args[0].expr.typ.has_flag(.generic) { - var_type = g.table.find_or_register_array(g.unwrap_generic((val.args[0].expr as ast.AsCast).typ)) - val_type = var_type - left.obj.typ = var_type - g.assign_ct_type = var_type - } - } else if val is ast.InfixExpr && val.op in [.plus, .minus, .mul, .div, .mod] - && val.left_ct_expr { - ctyp := g.unwrap_generic(g.type_resolver.get_type(val.left)) - if ctyp != ast.void_type { - ct_type_var := g.comptime.get_ct_type_var(val.left) - if ct_type_var in [.key_var, .value_var] { - g.type_resolver.update_ct_type(left.name, g.unwrap_generic(ctyp)) + } else if val is ast.InfixExpr && val.op in [.plus, .minus, .mul, .div, .mod] + && val.left_ct_expr { + ctyp := g.unwrap_generic(g.type_resolver.get_type(val.left)) + if ctyp != ast.void_type { + ct_type_var := g.comptime.get_ct_type_var(val.left) + if ct_type_var in [.key_var, .value_var] { + g.type_resolver.update_ct_type(left.name, g.unwrap_generic(ctyp)) + } + var_type = ctyp + val_type = var_type + left.obj.typ = var_type + g.assign_ct_type = var_type } - var_type = ctyp - val_type = var_type - left.obj.typ = var_type - g.assign_ct_type = var_type } } is_auto_heap = left.obj.is_auto_heap diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index d84ba1040b..586bdaa8a0 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -2028,8 +2028,7 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { // Handle `print(x)` mut print_auto_str := false if is_print && (node.args[0].typ != ast.string_type - || g.comptime.comptime_for_method != unsafe { nil } - || g.comptime.is_comptime(node.args[0].expr)) { + || g.comptime.comptime_for_method != unsafe { nil } || node.args[0].ct_expr) { g.inside_interface_deref = true defer { g.inside_interface_deref = false @@ -2577,7 +2576,7 @@ fn (mut g Gen) keep_alive_call_postgen(node ast.CallExpr, tmp_cnt_save int) { @[inline] fn (mut g Gen) ref_or_deref_arg(arg ast.CallArg, expected_type ast.Type, lang ast.Language, is_smartcast bool) { - arg_typ := if arg.expr is ast.ComptimeSelector { + arg_typ := if arg.ct_expr { g.unwrap_generic(g.type_resolver.get_type(arg.expr)) } else { g.unwrap_generic(arg.typ) diff --git a/vlib/v/tests/comptime_generic_comptime_variant_test.v b/vlib/v/tests/comptime_generic_comptime_variant_test.v new file mode 100644 index 0000000000..b853863d1e --- /dev/null +++ b/vlib/v/tests/comptime_generic_comptime_variant_test.v @@ -0,0 +1,38 @@ +pub type DesiredCapabilities = FireFox | Edge + +struct FireFox { + browser_name string = 'firefox' + accept_insecure_certs bool = true + moz_debugger_address bool = true +} + +struct Edge { + browser_name string = 'MicrosoftEdge' +} + +fn struct_values[T](s T) map[string]string { + mut res := map[string]string{} + $if T is $struct { + $for field in T.fields { + res[field.name] = s.$(field.name).str() + } + } + return res +} + +fn useit(dc DesiredCapabilities) string { + $for v in dc.variants { + if dc is v { + $if v is $struct { + result := struct_values(dc) + return result.str() + } + } + } + return '' +} + +fn test_main() { + assert useit(Edge{}) == "{'browser_name': 'MicrosoftEdge'}" + assert useit(FireFox{}) == "{'browser_name': 'firefox', 'accept_insecure_certs': 'true', 'moz_debugger_address': 'true'}" +} diff --git a/vlib/v/tests/generics/generic_var_loop_test.v b/vlib/v/tests/generics/generic_var_loop_test.v new file mode 100644 index 0000000000..062c3ade52 --- /dev/null +++ b/vlib/v/tests/generics/generic_var_loop_test.v @@ -0,0 +1,18 @@ +import math + +fn t[T](a1 []int, a2 []int) T { + mut a := 0 * a1[0] - a2[0] + a = math.max(a, 10) + mut t := T{} + for i in a .. 20 { + t += i + for j in a .. 20 { + t += j + } + } + return t +} + +fn test_main() { + assert t[int]([1, 2, 3], [4, 5, 6]) == 1595 +} diff --git a/vlib/v/type_resolver/comptime_resolver.v b/vlib/v/type_resolver/comptime_resolver.v index d09e9a60af..8d4012a33f 100644 --- a/vlib/v/type_resolver/comptime_resolver.v +++ b/vlib/v/type_resolver/comptime_resolver.v @@ -47,7 +47,7 @@ pub fn (t &ResolverInfo) is_comptime(node ast.Expr) bool { } } ast.SelectorExpr { - return node.expr is ast.Ident && node.expr.ct_expr + return node.expr is ast.Ident && node.expr.ct_expr && node.field_name != 'len' } ast.InfixExpr { return node.left_ct_expr || node.right_ct_expr @@ -55,6 +55,9 @@ pub fn (t &ResolverInfo) is_comptime(node ast.Expr) bool { ast.ParExpr { return t.is_comptime(node.expr) } + ast.ComptimeSelector { + return true + } else { false } diff --git a/vlib/v/type_resolver/generic_resolver.v b/vlib/v/type_resolver/generic_resolver.v index 457c8a5234..aab11a4e09 100644 --- a/vlib/v/type_resolver/generic_resolver.v +++ b/vlib/v/type_resolver/generic_resolver.v @@ -67,6 +67,9 @@ pub fn (t &TypeResolver) is_generic_expr(node ast.Expr) bool { if node.is_static_method && node.left_type.has_flag(.generic) { return true } + if node.return_type_generic != 0 && node.return_type_generic.has_flag(.generic) { + return true + } // fn[T]() or generic_var.fn[T]() node.concrete_types.any(it.has_flag(.generic)) } @@ -119,6 +122,7 @@ pub fn (mut t TypeResolver) resolve_args(cur_fn &ast.FnDecl, func &ast.Fn, mut n mut comptime_args := map[int]ast.Type{} has_dynamic_vars := (cur_fn != unsafe { nil } && cur_fn.generic_names.len > 0) || t.info.comptime_for_field_var != '' + || func.generic_names.len != node_.raw_concrete_types.len if !has_dynamic_vars { return comptime_args } diff --git a/vlib/v/type_resolver/type_resolver.v b/vlib/v/type_resolver/type_resolver.v index 2e4cd38fbf..edcd34b408 100644 --- a/vlib/v/type_resolver/type_resolver.v +++ b/vlib/v/type_resolver/type_resolver.v @@ -101,15 +101,10 @@ pub fn (mut t TypeResolver) get_type_or_default(node ast.Expr, default_typ ast.T } ast.SelectorExpr { if node.expr is ast.Ident && node.expr.ct_expr { - struct_typ := t.resolver.unwrap_generic(t.get_type(node.expr)) - struct_sym := t.table.final_sym(struct_typ) - // Struct[T] can have field with generic type - if struct_sym.info is ast.Struct && struct_sym.info.generic_types.len > 0 { - if field := t.table.find_field(struct_sym, node.field_name) { - return field.typ - } - } + ctyp := t.get_type(node) + return if ctyp != ast.void_type { ctyp } else { default_typ } } + return default_typ } ast.ParExpr { return t.get_type_or_default(node.expr, default_typ) @@ -189,8 +184,21 @@ pub fn (mut t TypeResolver) get_type(node ast.Expr) ast.Type { } else if node is ast.ComptimeSelector { // val.$(field.name) return t.get_comptime_selector_type(node, ast.void_type) - } else if node is ast.SelectorExpr && t.info.is_comptime_selector_type(node) { - return t.get_type_from_comptime_var(node.expr as ast.Ident) + } else if node is ast.SelectorExpr { + if t.info.is_comptime_selector_type(node) { + return t.get_type_from_comptime_var(node.expr as ast.Ident) + } + if node.expr is ast.Ident && node.expr.ct_expr { + struct_typ := t.resolver.unwrap_generic(t.get_type(node.expr)) + struct_sym := t.table.final_sym(struct_typ) + // Struct[T] can have field with generic type + if struct_sym.info is ast.Struct && struct_sym.info.generic_types.len > 0 { + if field := t.table.find_field(struct_sym, node.field_name) { + return field.typ + } + } + } + return node.typ } else if node is ast.ComptimeCall { method_name := t.info.comptime_for_method.name left_sym := t.table.sym(t.resolver.unwrap_generic(node.left_type))