From 6a60db8768aaf6cc07fbd268f20cc588cd95a7cf Mon Sep 17 00:00:00 2001 From: Felipe Pena Date: Mon, 10 Apr 2023 04:42:49 -0300 Subject: [PATCH] cgen, checker: fix generic/comptime parameter concrete type resolution in some cases (#17762) --- vlib/v/ast/ast.v | 1 + vlib/v/checker/assign.v | 8 +- vlib/v/checker/check_types.v | 48 +++++ vlib/v/checker/checker.v | 5 +- vlib/v/checker/comptime.v | 32 ++-- vlib/v/checker/fn.v | 164 ++++++++++++---- vlib/v/checker/for.v | 41 +++- vlib/v/checker/postfix.v | 2 +- vlib/v/gen/c/assign.v | 24 ++- vlib/v/gen/c/cgen.v | 16 +- vlib/v/gen/c/comptime.v | 53 +++-- vlib/v/gen/c/fn.v | 234 ++++++++++++++++------- vlib/v/gen/c/for.v | 42 ++-- vlib/v/parser/fn.v | 8 +- vlib/v/tests/generic_recursive_fn_test.v | 46 +++++ vlib/v/tests/generic_resolve_test.v | 50 +++++ vlib/v/tests/resolve_generic_2_test.v | 20 ++ 17 files changed, 619 insertions(+), 175 deletions(-) create mode 100644 vlib/v/tests/generic_recursive_fn_test.v create mode 100644 vlib/v/tests/generic_resolve_test.v create mode 100644 vlib/v/tests/resolve_generic_2_test.v diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 1e463665af..e1814f045d 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -755,6 +755,7 @@ pub enum ComptimeVarKind { key_var // map key from `for k,v in t.$(field.name)` value_var // map value from `for k,v in t.$(field.name)` field_var // comptime field var `a := t.$(field.name)` + generic_param // generic fn parameter } [minify] diff --git a/vlib/v/checker/assign.v b/vlib/v/checker/assign.v index 9b52d9ab1e..3553cba664 100644 --- a/vlib/v/checker/assign.v +++ b/vlib/v/checker/assign.v @@ -329,7 +329,13 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) { left.obj.typ = c.comptime_fields_default_type } else if right is ast.Ident && (right as ast.Ident).obj is ast.Var && (right as ast.Ident).or_expr.kind == .absent { - left.obj.ct_type_var = ((right as ast.Ident).obj as ast.Var).ct_type_var + if ((right as ast.Ident).obj as ast.Var).ct_type_var != .no_comptime { + ctyp := c.get_comptime_var_type(right) + if ctyp != ast.void_type { + left.obj.ct_type_var = ((right as ast.Ident).obj as ast.Var).ct_type_var + left.obj.typ = ctyp + } + } } else if right is ast.DumpExpr && (right as ast.DumpExpr).expr is ast.ComptimeSelector { left.obj.ct_type_var = .field_var diff --git a/vlib/v/checker/check_types.v b/vlib/v/checker/check_types.v index a71aefff5f..33efba6dab 100644 --- a/vlib/v/checker/check_types.v +++ b/vlib/v/checker/check_types.v @@ -801,6 +801,36 @@ fn (mut c Checker) infer_struct_generic_types(typ ast.Type, node ast.StructInit) return concrete_types } +fn (g Checker) get_generic_array_element_type(array ast.Array) ast.Type { + mut cparam_elem_info := array as ast.Array + mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type) + mut typ := ast.void_type + for { + if cparam_elem_sym.kind == .array { + cparam_elem_info = cparam_elem_sym.info as ast.Array + cparam_elem_sym = g.table.sym(cparam_elem_info.elem_type) + } else { + return cparam_elem_info.elem_type.set_nr_muls(0) + } + } + return typ +} + +fn (g Checker) get_generic_array_fixed_element_type(array ast.ArrayFixed) ast.Type { + mut cparam_elem_info := array as ast.ArrayFixed + mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type) + mut typ := ast.void_type + for { + if cparam_elem_sym.kind == .array_fixed { + cparam_elem_info = cparam_elem_sym.info as ast.ArrayFixed + cparam_elem_sym = g.table.sym(cparam_elem_info.elem_type) + } else { + return cparam_elem_info.elem_type.set_nr_muls(0) + } + } + return typ +} + fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr) { mut inferred_types := []ast.Type{} for gi, gt_name in func.generic_names { @@ -979,6 +1009,24 @@ fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr) { idx := generic_names.index(gt_name) typ = concrete_types[idx] } + } else if arg_sym.kind == .any && c.table.cur_fn.generic_names.len > 0 + && c.table.cur_fn.params.len > 0 && func.generic_names.len > 0 + && arg.expr is ast.Ident { + var_name := (arg.expr as ast.Ident).name + for cur_param in c.table.cur_fn.params { + if !cur_param.typ.has_flag(.generic) || cur_param.name != var_name { + continue + } + typ = cur_param.typ + mut cparam_type_sym := c.table.sym(c.unwrap_generic(typ)) + if cparam_type_sym.kind == .array { + typ = c.get_generic_array_element_type(cparam_type_sym.info as ast.Array) + } else if cparam_type_sym.kind == .array_fixed { + typ = c.get_generic_array_fixed_element_type(cparam_type_sym.info as ast.ArrayFixed) + } + typ = c.unwrap_generic(typ) + break + } } } } diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 1f83d5fa55..10bf890dfa 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -96,8 +96,6 @@ mut: timers &util.Timers = util.get_timers() comptime_for_field_var string comptime_fields_default_type ast.Type - comptime_fields_key_type ast.Type // key type on `$for k, v in val.$(field.name)` - comptime_fields_val_type ast.Type // value type on `$for k, v in val.$(field.name)` comptime_fields_type map[string]ast.Type comptime_for_field_value ast.StructField // value of the field variable comptime_enum_field_value string // current enum value name @@ -3668,8 +3666,9 @@ fn (c &Checker) has_return(stmts []ast.Stmt) ?bool { return none } +[inline] pub fn (mut c Checker) is_comptime_var(node ast.Expr) bool { - return c.inside_comptime_for_field && node is ast.Ident + return node is ast.Ident && (node as ast.Ident).info is ast.IdentVar && (node as ast.Ident).kind == .variable && ((node as ast.Ident).obj as ast.Var).ct_type_var != .no_comptime } diff --git a/vlib/v/checker/comptime.v b/vlib/v/checker/comptime.v index db4e685c0b..9a8654e8eb 100644 --- a/vlib/v/checker/comptime.v +++ b/vlib/v/checker/comptime.v @@ -10,23 +10,31 @@ import v.util import v.pkgconfig import v.checker.constants -[inline] -fn (mut c Checker) get_comptime_var_type_from_kind(kind ast.ComptimeVarKind) ast.Type { - return match kind { - .key_var { c.comptime_fields_key_type } - .value_var { c.comptime_fields_val_type } - .field_var { c.comptime_fields_default_type } - else { ast.void_type } - } -} - [inline] fn (mut c Checker) get_comptime_var_type(node ast.Expr) ast.Type { if node is ast.Ident && (node as ast.Ident).obj is ast.Var { - return c.get_comptime_var_type_from_kind((node.obj as ast.Var).ct_type_var) + return match (node.obj as ast.Var).ct_type_var { + .generic_param { + // generic parameter from current function + node.obj.typ + } + .key_var, .value_var { + // key and value variables from normal for stmt + c.comptime_fields_type[node.name] or { ast.void_type } + } + .field_var { + // field var from $for loop + c.comptime_fields_default_type + } + else { + ast.void_type + } + } } else if node is ast.ComptimeSelector { + // val.$(field.name) return c.get_comptime_selector_type(node, ast.void_type) } else if node is ast.SelectorExpr && c.is_comptime_selector_type(node as ast.SelectorExpr) { + // field_var.typ from $for field return c.comptime_fields_default_type } return ast.void_type @@ -850,7 +858,7 @@ fn (mut c Checker) get_comptime_selector_type(node ast.ComptimeSelector, default return default_type } -// check_comptime_is_field_selector checks if the SelectorExpr is related to $for variable +// check_comptime_is_field_selector checks if the SelectorExpr is related to $for variable accessing .typ field [inline] fn (mut c Checker) is_comptime_selector_type(node ast.SelectorExpr) bool { if c.inside_comptime_for_field && node.expr is ast.Ident { diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 90f49b71d7..a4bfbdc37b 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -910,7 +910,6 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. c.error('unknown function: ${fn_name}', node.pos) return ast.void_type } - node.is_file_translated = func.is_file_translated node.is_noreturn = func.is_noreturn node.is_ctor_new = func.is_ctor_new @@ -1251,6 +1250,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. // no type arguments given in call, attempt implicit instantiation c.infer_fn_generic_types(func, mut node) concrete_types = node.concrete_types.map(c.unwrap_generic(it)) + c.resolve_fn_generic_args(func, mut node) } if func.generic_names.len > 0 { for i, mut call_arg in node.args { @@ -1330,22 +1330,131 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. return func.return_type } -fn (mut c Checker) get_comptime_args(node_ ast.CallExpr) map[int]ast.Type { +fn (mut c Checker) get_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type { mut comptime_args := map[int]ast.Type{} - for i, call_arg in node_.args { - if call_arg.expr is ast.Ident { - if call_arg.expr.obj is ast.Var { - if call_arg.expr.obj.ct_type_var != .no_comptime { - comptime_args[i] = c.get_comptime_var_type_from_kind(call_arg.expr.obj.ct_type_var) - } + has_dynamic_vars := (c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len > 0) + || c.inside_comptime_for_field + if has_dynamic_vars { + offset := if func.is_method { 1 } else { 0 } + for i, call_arg in node_.args { + param := if func.is_variadic && i >= func.params.len - (offset + 1) { + func.params.last() + } else { + func.params[offset + i] + } + if !param.typ.has_flag(.generic) { + continue + } + if call_arg.expr is ast.Ident { + if call_arg.expr.obj is ast.Var { + if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] { + mut ctyp := c.get_comptime_var_type(call_arg.expr) + if ctyp != ast.void_type { + arg_sym := c.table.sym(ctyp) + param_typ := param.typ + if arg_sym.kind == .array && param_typ.has_flag(.generic) + && c.table.final_sym(param_typ).kind == .array { + ctyp = (arg_sym.info as ast.Array).elem_type + } + comptime_args[i] = ctyp + } + } else if call_arg.expr.obj.ct_type_var == .generic_param { + mut ctyp := c.get_comptime_var_type(call_arg.expr) + if ctyp != ast.void_type { + param_typ := param.typ + arg_sym := c.table.final_sym(call_arg.typ) + param_typ_sym := c.table.sym(param_typ) + + if param_typ.has_flag(.variadic) { + ctyp = ast.mktyp(ctyp) + comptime_args[i] = ctyp + } else if arg_sym.kind == .array && param_typ.has_flag(.generic) + && param_typ_sym.kind == .array { + ctyp = c.get_generic_array_element_type(arg_sym.info as ast.Array) + comptime_args[i] = ctyp + } else if arg_sym.kind in [.struct_, .interface_, .sum_type] { + mut generic_types := []ast.Type{} + match arg_sym.info { + ast.Struct, ast.Interface, ast.SumType { + if param_typ_sym.generic_types.len > 0 { + generic_types = param_typ_sym.generic_types.clone() + } else { + generic_types = arg_sym.info.generic_types.clone() + } + } + else {} + } + generic_names := generic_types.map(c.table.sym(it).name) + for _, gt_name in c.table.cur_fn.generic_names { + if gt_name in generic_names + && generic_types.len == concrete_types.len { + idx := generic_names.index(gt_name) + comptime_args[i] = concrete_types[idx] + break + } + } + } else if arg_sym.kind == .any { + mut cparam_type_sym := c.table.sym(c.unwrap_generic(ctyp)) + if param_typ_sym.kind == .array && cparam_type_sym.kind == .array { + ctyp = (cparam_type_sym.info as ast.Array).elem_type + comptime_args[i] = ctyp + } else { + if node_.args[i].expr.is_auto_deref_var() { + ctyp = ctyp.deref() + } + if ctyp.nr_muls() > 0 && param_typ.nr_muls() > 0 { + ctyp = ctyp.set_nr_muls(0) + } + comptime_args[i] = ctyp + } + } else { + comptime_args[i] = ctyp + } + } + } + } + } else if call_arg.expr is ast.ComptimeSelector && c.is_comptime_var(call_arg.expr) { + comptime_args[i] = c.get_comptime_var_type(call_arg.expr) } - } else if call_arg.expr is ast.ComptimeSelector && c.is_comptime_var(call_arg.expr) { - comptime_args[i] = c.get_comptime_var_type(call_arg.expr) } } return comptime_args } +fn (mut c Checker) resolve_fn_generic_args(func ast.Fn, mut node ast.CallExpr) []ast.Type { + mut concrete_types := node.concrete_types.map(c.unwrap_generic(it)) + + // dynamic values from comptime and generic parameters + // overwrite concrete_types[ receiver_concrete_type + arg number ] + if concrete_types.len > 0 { + mut rec_len := 0 + // discover receiver concrete_type len + if func.is_method && node.left_type.has_flag(.generic) { + rec_sym := c.table.final_sym(c.unwrap_generic(node.left_type)) + match rec_sym.info { + ast.Struct, ast.Interface, ast.SumType { + rec_len += rec_sym.info.generic_types.len + } + else {} + } + } + + mut comptime_args := c.get_comptime_args(func, node, concrete_types) + if comptime_args.len > 0 { + for k, v in comptime_args { + if (rec_len + k) < concrete_types.len { + concrete_types[rec_len + k] = c.unwrap_generic(v) + } + } + if c.table.register_fn_concrete_types(func.fkey(), concrete_types) { + c.need_recheck_generic_fns = true + } + } + } + + return concrete_types +} + fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { left_type := c.expr(node.left) if left_type == ast.void_type { @@ -1531,36 +1640,10 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { } else {} } - mut concrete_types := []ast.Type{} - for concrete_type in node.concrete_types { - if concrete_type.has_flag(.generic) { - concrete_types << c.unwrap_generic(concrete_type) - } else { - concrete_types << concrete_type - } - } - if c.inside_comptime_for_field && concrete_types.len > 0 { - mut comptime_args := c.get_comptime_args(node) - mut comptime_types := concrete_types.clone() - for k, v in comptime_args { - arg_sym := c.table.sym(v) - if method.generic_names.len > 0 && arg_sym.kind == .array - && method.params[k + 1].typ.has_flag(.generic) - && c.table.final_sym(method.params[k + 1].typ).kind == .array { - comptime_types[k] = (arg_sym.info as ast.Array).elem_type - } else { - comptime_types[k] = v - } - } - if comptime_args.len > 0 - && c.table.register_fn_concrete_types(method.fkey(), comptime_types) { - c.need_recheck_generic_fns = true - } - } - if concrete_types.len > 0 { - if c.table.register_fn_concrete_types(method.fkey(), concrete_types) { - c.need_recheck_generic_fns = true - } + mut concrete_types := node.concrete_types.map(c.unwrap_generic(it)) + if concrete_types.len > 0 + && c.table.register_fn_concrete_types(method.fkey(), concrete_types) { + c.need_recheck_generic_fns = true } node.is_noreturn = method.is_noreturn node.is_ctor_new = method.is_ctor_new @@ -1854,6 +1937,7 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type { } if concrete_types.len > 0 && !concrete_types[0].has_flag(.generic) { c.table.register_fn_concrete_types(method.fkey(), concrete_types) + c.resolve_fn_generic_args(method, mut node) } // resolve return generics struct to concrete type diff --git a/vlib/v/checker/for.v b/vlib/v/checker/for.v index f5fbce8c58..87abf3ffc7 100644 --- a/vlib/v/checker/for.v +++ b/vlib/v/checker/for.v @@ -84,9 +84,13 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { mut is_comptime := false if (node.cond is ast.Ident && c.is_comptime_var(node.cond)) || node.cond is ast.ComptimeSelector { - is_comptime = true - typ = c.unwrap_generic(c.comptime_fields_default_type) + ctyp := c.get_comptime_var_type(node.cond) + if ctyp != ast.void_type { + is_comptime = true + typ = ctyp + } } + mut sym := c.table.final_sym(typ) if sym.kind != .string { match mut node.cond { @@ -139,6 +143,15 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.kind = sym.kind node.val_type = val_type node.scope.update_var_type(node.val_var, val_type) + + if is_comptime { + c.comptime_fields_type[node.val_var] = val_type + node.scope.update_ct_var_kind(node.val_var, .value_var) + + defer { + c.comptime_fields_type.delete(node.val_var) + } + } } else if sym.kind == .any { node.cond_type = typ node.kind = sym.kind @@ -155,8 +168,12 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.scope.update_var_type(node.key_var, key_type) if is_comptime { - c.comptime_fields_key_type = key_type + c.comptime_fields_type[node.key_var] = key_type node.scope.update_ct_var_kind(node.key_var, .key_var) + + defer { + c.comptime_fields_type.delete(node.key_var) + } } } @@ -164,8 +181,12 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.scope.update_var_type(node.val_var, value_type) if is_comptime { - c.comptime_fields_val_type = value_type + c.comptime_fields_type[node.val_var] = value_type node.scope.update_ct_var_kind(node.val_var, .value_var) + + defer { + c.comptime_fields_type.delete(node.val_var) + } } } else { if sym.kind == .map && !(node.key_var.len > 0 && node.val_var.len > 0) { @@ -182,8 +203,12 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.scope.update_var_type(node.key_var, key_type) if is_comptime { - c.comptime_fields_key_type = key_type + c.comptime_fields_type[node.key_var] = key_type node.scope.update_ct_var_kind(node.key_var, .key_var) + + defer { + c.comptime_fields_type.delete(node.key_var) + } } } mut value_type := c.table.value_type(typ) @@ -235,8 +260,12 @@ fn (mut c Checker) for_in_stmt(mut node ast.ForInStmt) { node.val_type = value_type node.scope.update_var_type(node.val_var, value_type) if is_comptime { - c.comptime_fields_val_type = value_type + c.comptime_fields_type[node.val_var] = value_type node.scope.update_ct_var_kind(node.val_var, .value_var) + + defer { + c.comptime_fields_type.delete(node.val_var) + } } } } diff --git a/vlib/v/checker/postfix.v b/vlib/v/checker/postfix.v index 6fb548c0f7..6eab0ec231 100644 --- a/vlib/v/checker/postfix.v +++ b/vlib/v/checker/postfix.v @@ -25,7 +25,7 @@ fn (mut c Checker) postfix_expr(mut node ast.PostfixExpr) ast.Type { if !(typ_sym.is_number() || ((c.inside_unsafe || c.pref.translated) && is_non_void_pointer)) { if c.inside_comptime_for_field { if c.is_comptime_var(node.expr) || node.expr is ast.ComptimeSelector { - return c.comptime_fields_default_type + return c.unwrap_generic(c.get_comptime_var_type(node.expr)) } } diff --git a/vlib/v/gen/c/assign.v b/vlib/v/gen/c/assign.v index 842e81ece9..96fded29c4 100644 --- a/vlib/v/gen/c/assign.v +++ b/vlib/v/gen/c/assign.v @@ -225,13 +225,16 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) { } if mut left.obj is ast.Var { if val is ast.Ident && g.is_comptime_var(val) { - var_type = g.unwrap_generic(g.get_comptime_var_type(val)) - val_type = var_type - gen_or = val.or_expr.kind != .absent - if gen_or { - var_type = val_type.clear_flag(.option) + ctyp := g.unwrap_generic(g.get_comptime_var_type(val)) + if ctyp != ast.void_type { + var_type = ctyp + val_type = var_type + gen_or = val.or_expr.kind != .absent + if gen_or { + var_type = val_type.clear_flag(.option) + } + left.obj.typ = var_type } - left.obj.typ = var_type } else if val is ast.ComptimeSelector { key_str := g.get_comptime_selector_key_type(val) if key_str != '' { @@ -257,6 +260,15 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) { val_type = var_type left.obj.typ = var_type } + } else if val is ast.IndexExpr { + if val.left is ast.Ident && g.is_generic_param_var((val as ast.IndexExpr).left) { + ctyp := g.unwrap_generic(g.get_gn_var_type((val as ast.IndexExpr).left as ast.Ident)) + if ctyp != ast.void_type { + var_type = ctyp + val_type = var_type + left.obj.typ = var_type + } + } } is_auto_heap = left.obj.is_auto_heap } diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index cb104ea99e..acd05280c2 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -143,7 +143,6 @@ mut: inside_const bool inside_const_opt_or_res bool inside_lambda bool - inside_for_in_any_cond bool inside_cinit bool last_tmp_call_var []string loop_depth int @@ -203,8 +202,6 @@ mut: comptime_for_field_var string // $for field in T.fields {}; the variable name comptime_for_field_value ast.StructField // value of the field variable comptime_for_field_type ast.Type // type of the field variable inferred from `$if field.typ is T {}` - comptime_for_field_key_type ast.Type // type of key on comptime for on map field - comptime_for_field_val_type ast.Type // type of value on comptime for on map field comptime_enum_field_value string // value of enum name comptime_var_type_map map[string]ast.Type comptime_values_stack []CurrentComptimeValues // stores the values from the above on each $for loop, to make nesting them easier @@ -4049,9 +4046,16 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) { } } +[inline] +pub fn (mut g Gen) is_generic_param_var(node ast.Expr) bool { + return node is ast.Ident + && (node as ast.Ident).info is ast.IdentVar && (node as ast.Ident).obj is ast.Var && ((node as ast.Ident).obj as ast.Var).ct_type_var == .generic_param +} + +[inline] pub fn (mut g Gen) is_comptime_var(node ast.Expr) bool { - return g.inside_comptime_for_field && node is ast.Ident - && (node as ast.Ident).info is ast.IdentVar && ((node as ast.Ident).obj as ast.Var).ct_type_var != .no_comptime + return node is ast.Ident + && (node as ast.Ident).info is ast.IdentVar && (node as ast.Ident).obj is ast.Var && ((node as ast.Ident).obj as ast.Var).ct_type_var != .no_comptime } fn (mut g Gen) ident(node ast.Ident) { @@ -4084,7 +4088,7 @@ fn (mut g Gen) ident(node ast.Ident) { mut is_auto_heap := false if node.info is ast.IdentVar { if node.obj is ast.Var { - if !g.is_assign_lhs && node.obj.ct_type_var != .no_comptime { + if !g.is_assign_lhs && node.obj.ct_type_var !in [.generic_param, .no_comptime] { comptime_type := g.get_comptime_var_type(node) if comptime_type.has_flag(.option) { if (g.inside_opt_or_res || g.left_is_opt) && node.or_expr.kind == .absent { diff --git a/vlib/v/gen/c/comptime.v b/vlib/v/gen/c/comptime.v index 0c7c96489b..8d8425aa3c 100644 --- a/vlib/v/gen/c/comptime.v +++ b/vlib/v/gen/c/comptime.v @@ -647,15 +647,13 @@ fn (mut g Gen) comptime_if_cond(cond ast.Expr, pkg_exist bool) (bool, bool) { // struct CurrentComptimeValues { - inside_comptime_for_field bool - comptime_for_method string - comptime_for_method_var string - comptime_for_field_var string - comptime_for_field_value ast.StructField - comptime_for_field_type ast.Type - comptime_for_field_key_type ast.Type - comptime_for_field_val_type ast.Type - comptime_var_type_map map[string]ast.Type + inside_comptime_for_field bool + comptime_for_method string + comptime_for_method_var string + comptime_for_field_var string + comptime_for_field_value ast.StructField + comptime_for_field_type ast.Type + comptime_var_type_map map[string]ast.Type } fn (mut g Gen) push_existing_comptime_values() { @@ -666,8 +664,6 @@ fn (mut g Gen) push_existing_comptime_values() { comptime_for_field_var: g.comptime_for_field_var comptime_for_field_value: g.comptime_for_field_value comptime_for_field_type: g.comptime_for_field_type - comptime_for_field_key_type: g.comptime_for_field_key_type - comptime_for_field_val_type: g.comptime_for_field_val_type comptime_var_type_map: g.comptime_var_type_map.clone() } } @@ -680,29 +676,46 @@ fn (mut g Gen) pop_existing_comptime_values() { g.comptime_for_field_var = old.comptime_for_field_var g.comptime_for_field_value = old.comptime_for_field_value g.comptime_for_field_type = old.comptime_for_field_type - g.comptime_for_field_key_type = old.comptime_for_field_key_type - g.comptime_for_field_val_type = old.comptime_for_field_val_type g.comptime_var_type_map = old.comptime_var_type_map.clone() } +// check_comptime_is_field_selector checks if the SelectorExpr is related to $for variable accessing .typ field [inline] -fn (mut g Gen) get_comptime_var_type_from_kind(kind ast.ComptimeVarKind) ast.Type { - return match kind { - .key_var { g.comptime_for_field_key_type } - .value_var { g.comptime_for_field_val_type } - .field_var { g.comptime_for_field_type } - else { ast.void_type } +fn (mut g Gen) is_comptime_selector_type(node ast.SelectorExpr) bool { + if g.inside_comptime_for_field && node.expr is ast.Ident { + return (node.expr as ast.Ident).name == g.comptime_for_field_var && node.field_name == 'typ' } + return false } fn (mut g Gen) get_comptime_var_type(node ast.Expr) ast.Type { if node is ast.Ident && (node as ast.Ident).obj is ast.Var { - return g.get_comptime_var_type_from_kind((node.obj as ast.Var).ct_type_var) + return match (node.obj as ast.Var).ct_type_var { + .generic_param { + // generic parameter from current function + node.obj.typ + } + .key_var, .value_var { + // key and value variables from normal for stmt + g.comptime_var_type_map[node.name] or { ast.void_type } + } + .field_var { + // field var from $for loop + g.comptime_for_field_type + } + else { + ast.void_type + } + } } else if node is ast.ComptimeSelector { + // val.$(field.name) key_str := g.get_comptime_selector_key_type(node) if key_str != '' { return g.comptime_var_type_map[key_str] or { ast.void_type } } + } else if node is ast.SelectorExpr && g.is_comptime_selector_type(node as ast.SelectorExpr) { + // field_var.typ from $for field + return g.comptime_for_field_type } return ast.void_type } diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 9175ddc256..d24c4e6fd2 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -972,18 +972,133 @@ fn (mut g Gen) gen_to_str_method_call(node ast.CallExpr) bool { return false } -fn (mut g Gen) change_comptime_args(mut node_ ast.CallExpr) map[int]ast.Type { - mut comptime_args := map[int]ast.Type{} - for i, mut call_arg in node_.args { - if mut call_arg.expr is ast.Ident { - if mut call_arg.expr.obj is ast.Var { - node_.args[i].typ = call_arg.expr.obj.typ - if call_arg.expr.obj.ct_type_var != .no_comptime { - comptime_args[i] = g.get_comptime_var_type_from_kind(call_arg.expr.obj.ct_type_var) - } +fn (mut g Gen) get_gn_var_type(var ast.Ident) ast.Type { + if g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0 { + for k, cur_param in g.cur_fn.params { + if (k == 0 && g.cur_fn.is_method) || !cur_param.typ.has_flag(.generic) + || var.name != cur_param.name { + continue + } + mut typ := cur_param.typ + mut cparam_type_sym := g.table.sym(g.unwrap_generic(typ)) + + if cparam_type_sym.kind == .array { + typ = g.unwrap_generic((cparam_type_sym.info as ast.Array).elem_type) + } else if cparam_type_sym.kind == .array_fixed { + typ = g.unwrap_generic((cparam_type_sym.info as ast.ArrayFixed).elem_type) + } + return typ + } + } + return ast.void_type +} + +fn (g Gen) get_generic_array_element_type(array ast.Array) ast.Type { + mut cparam_elem_info := array as ast.Array + mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type) + mut typ := ast.void_type + for { + if cparam_elem_sym.kind == .array { + cparam_elem_info = cparam_elem_sym.info as ast.Array + cparam_elem_sym = g.table.sym(cparam_elem_info.elem_type) + } else { + typ = cparam_elem_info.elem_type + if cparam_elem_info.elem_type.nr_muls() > 0 && typ.nr_muls() > 0 { + typ = typ.set_nr_muls(0) + } + break + } + } + return typ +} + +fn (mut g Gen) change_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type { + mut comptime_args := map[int]ast.Type{} + has_dynamic_vars := (g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0) + || g.inside_comptime_for_field + if has_dynamic_vars { + offset := if func.is_method { 1 } else { 0 } + for i, mut call_arg in node_.args { + param := if func.is_variadic && i >= func.params.len - (offset + 1) { + func.params.last() + } else { + func.params[offset + i] + } + if !param.typ.has_flag(.generic) { + continue + } + if mut call_arg.expr is ast.Ident { + if mut call_arg.expr.obj is ast.Var { + node_.args[i].typ = call_arg.expr.obj.typ + if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] { + mut ctyp := g.get_comptime_var_type(call_arg.expr) + if ctyp != ast.void_type { + arg_sym := g.table.sym(ctyp) + param_typ := param.typ + if arg_sym.kind == .array && param_typ.has_flag(.generic) + && g.table.final_sym(param_typ).kind == .array { + ctyp = (arg_sym.info as ast.Array).elem_type + } + comptime_args[i] = ctyp + } + } else if call_arg.expr.obj.ct_type_var == .generic_param { + mut ctyp := g.get_comptime_var_type(call_arg.expr) + if ctyp != ast.void_type { + param_typ := param.typ + arg_sym := g.table.final_sym(call_arg.typ) + param_typ_sym := g.table.sym(param_typ) + + if param_typ.has_flag(.variadic) { + ctyp = ast.mktyp(ctyp) + comptime_args[i] = ctyp + } else if arg_sym.kind == .array && param_typ.has_flag(.generic) + && param_typ_sym.kind == .array { + ctyp = g.get_generic_array_element_type(arg_sym.info as ast.Array) + comptime_args[i] = ctyp + } else if arg_sym.kind in [.struct_, .interface_, .sum_type] { + mut generic_types := []ast.Type{} + match arg_sym.info { + ast.Struct, ast.Interface, ast.SumType { + if param_typ_sym.generic_types.len > 0 { + generic_types = param_typ_sym.generic_types.clone() + } else { + generic_types = arg_sym.info.generic_types.clone() + } + } + else {} + } + generic_names := generic_types.map(g.table.sym(it).name) + for _, gt_name in g.cur_fn.generic_names { + if gt_name in generic_names + && generic_types.len == concrete_types.len { + idx := generic_names.index(gt_name) + comptime_args[i] = concrete_types[idx] + break + } + } + } else if arg_sym.kind == .any { + mut cparam_type_sym := g.table.sym(g.unwrap_generic(ctyp)) + if param_typ_sym.kind == .array && cparam_type_sym.kind == .array { + ctyp = (cparam_type_sym.info as ast.Array).elem_type + comptime_args[i] = ctyp + } else { + if node_.args[i].expr.is_auto_deref_var() { + ctyp = ctyp.deref() + } + if ctyp.nr_muls() > 0 && param_typ.nr_muls() > 0 { + ctyp = ctyp.set_nr_muls(0) + } + comptime_args[i] = ctyp + } + } else { + comptime_args[i] = ctyp + } + } + } + } + } else if mut call_arg.expr is ast.ComptimeSelector { + comptime_args[i] = g.comptime_for_field_type } - } else if mut call_arg.expr is ast.ComptimeSelector { - comptime_args[i] = g.comptime_for_field_type } } return comptime_args @@ -999,8 +1114,6 @@ fn (mut g Gen) method_call(node ast.CallExpr) { } left_type := g.unwrap_generic(node.left_type) mut unwrapped_rec_type := node.receiver_type - mut for_in_any_var_type := ast.void_type - mut comptime_args := map[int]ast.Type{} if g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0 { // in generic fn unwrapped_rec_type = g.unwrap_generic(node.receiver_type) } else { // in non-generic fn @@ -1019,7 +1132,6 @@ fn (mut g Gen) method_call(node ast.CallExpr) { else {} } } - if node.from_embed_types.len == 0 && node.left is ast.Ident { if node.left.obj is ast.Var { if node.left.obj.smartcasts.len > 0 { @@ -1031,21 +1143,6 @@ fn (mut g Gen) method_call(node ast.CallExpr) { } } } - - if g.inside_comptime_for_field { - mut node_ := unsafe { node } - comptime_args = g.change_comptime_args(mut node_) - } - if g.inside_for_in_any_cond { - for call_arg in node.args { - if call_arg.expr is ast.Ident { - if call_arg.expr.obj is ast.Var { - for_in_any_var_type = call_arg.expr.obj.typ - } - } - } - } - mut typ_sym := g.table.sym(unwrapped_rec_type) // non-option alias type that undefined this method (not include `str`) need to use parent type if !left_type.has_flag(.option) && typ_sym.kind == .alias && node.name != 'str' @@ -1207,28 +1304,33 @@ fn (mut g Gen) method_call(node ast.CallExpr) { } } - if g.comptime_for_field_type != 0 && g.inside_comptime_for_field && comptime_args.len > 0 { - mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) - arg_sym := g.table.sym(g.comptime_for_field_type) - if m := g.table.find_method(g.table.sym(node.left_type), node.name) { - for k, v in comptime_args { - if m.generic_names.len > 0 && arg_sym.kind == .array - && m.params[k + 1].typ.has_flag(.generic) - && g.table.final_sym(m.params[k + 1].typ).kind == .array { - concrete_types[k] = (arg_sym.info as ast.Array).elem_type - } else { - concrete_types[k] = v + if node.concrete_types.len > 0 { + mut rec_len := 0 + if node.left_type.has_flag(.generic) { + rec_sym := g.table.final_sym(g.unwrap_generic(node.left_type)) + match rec_sym.info { + ast.Struct, ast.Interface, ast.SumType { + rec_len += rec_sym.info.generic_types.len } + else {} } } - name = g.generic_fn_name(concrete_types, name) - } else if g.inside_for_in_any_cond && for_in_any_var_type != ast.void_type { - name = g.generic_fn_name([for_in_any_var_type], name) - } else { - concrete_types := node.concrete_types.map(g.unwrap_generic(it)) - name = g.generic_fn_name(concrete_types, name) + mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) + if m := g.table.find_method(g.table.sym(node.left_type), node.name) { + mut node_ := unsafe { node } + comptime_args := g.change_comptime_args(m, mut node_, concrete_types) + for k, v in comptime_args { + if (rec_len + k) < concrete_types.len { + if !node.concrete_types[k].has_flag(.generic) { + concrete_types[rec_len + k] = g.unwrap_generic(v) + } + } + } + name = g.generic_fn_name(concrete_types, name) + } else { + name = g.generic_fn_name(concrete_types, name) + } } - // TODO2 // g.generate_tmp_autofree_arg_vars(node, name) if !node.receiver_type.is_ptr() && left_type.is_ptr() && node.name == 'str' { @@ -1348,7 +1450,6 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { // will be `0` for `foo()` mut is_interface_call := false mut is_selector_call := false - mut comptime_args := map[int]ast.Type{} if node.left_type != 0 { left_sym := g.table.sym(node.left_type) if left_sym.kind == .interface_ { @@ -1373,10 +1474,6 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { } is_selector_call = true } - if g.inside_comptime_for_field { - mut node_ := unsafe { node } - comptime_args = g.change_comptime_args(mut node_) - } mut name := node.name is_print := name in ['print', 'println', 'eprint', 'eprintln', 'panic'] print_method := name @@ -1460,24 +1557,18 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { } if !is_selector_call { if func := g.table.find_fn(node.name) { - if func.generic_names.len > 0 { - if g.comptime_for_field_type != 0 && g.inside_comptime_for_field - && comptime_args.len > 0 { - mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) - arg_sym := g.table.sym(g.comptime_for_field_type) - for k, v in comptime_args { - if arg_sym.kind == .array && func.params[k].typ.has_flag(.generic) - && g.table.sym(func.params[k].typ).kind == .array { - concrete_types[k] = (arg_sym.info as ast.Array).elem_type - } else { - concrete_types[k] = v + mut concrete_types := node.concrete_types.map(g.unwrap_generic(it)) + mut node_ := unsafe { node } + comptime_args := g.change_comptime_args(func, mut node_, concrete_types) + if concrete_types.len > 0 { + for k, v in comptime_args { + if k < concrete_types.len { + if !node.concrete_types[k].has_flag(.generic) { + concrete_types[k] = g.unwrap_generic(v) } } - name = g.generic_fn_name(concrete_types, name) - } else { - concrete_types := node.concrete_types.map(g.unwrap_generic(it)) - name = g.generic_fn_name(concrete_types, name) } + name = g.generic_fn_name(concrete_types, name) } } } @@ -1489,8 +1580,15 @@ fn (mut g Gen) fn_call(node ast.CallExpr) { // g.generate_tmp_autofree_arg_vars(node, name) // Handle `print(x)` mut print_auto_str := false - if is_print && (node.args[0].typ != ast.string_type || g.comptime_for_method.len > 0) { + if is_print && (node.args[0].typ != ast.string_type + || g.comptime_for_method.len > 0 || g.is_comptime_var(node.args[0].expr)) { mut typ := node.args[0].typ + if g.is_comptime_var(node.args[0].expr) { + ctyp := g.get_comptime_var_type(node.args[0].expr) + if ctyp != ast.void_type { + typ = ctyp + } + } if typ == 0 { g.checker_bug('print arg.typ is 0', node.pos) } diff --git a/vlib/v/gen/c/for.v b/vlib/v/gen/c/for.v index ffa49e17d8..872fea1094 100644 --- a/vlib/v/gen/c/for.v +++ b/vlib/v/gen/c/for.v @@ -133,16 +133,28 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { mut is_comptime := false if (node.cond is ast.Ident && g.is_comptime_var(node.cond)) || node.cond is ast.ComptimeSelector { - is_comptime = true - mut unwrapped_typ := g.unwrap_generic(g.comptime_for_field_type) + mut unwrapped_typ := g.unwrap_generic(node.cond_type) + ctyp := g.get_comptime_var_type(node.cond) + if ctyp != ast.void_type { + unwrapped_typ = g.unwrap_generic(ctyp) + is_comptime = true + } + mut unwrapped_sym := g.table.sym(unwrapped_typ) + node.cond_type = unwrapped_typ node.val_type = g.table.value_type(unwrapped_typ) node.scope.update_var_type(node.val_var, node.val_type) node.kind = unwrapped_sym.kind - g.comptime_for_field_val_type = node.val_type - node.scope.update_ct_var_kind(node.val_var, .value_var) + if is_comptime { + g.comptime_var_type_map[node.val_var] = node.val_type + node.scope.update_ct_var_kind(node.val_var, .value_var) + + defer { + g.comptime_var_type_map.delete(node.val_var) + } + } if node.key_var.len > 0 { key_type := match unwrapped_sym.kind { @@ -152,13 +164,18 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { node.key_type = key_type node.scope.update_var_type(node.key_var, key_type) - g.comptime_for_field_key_type = node.key_type - node.scope.update_ct_var_kind(node.key_var, .key_var) + if is_comptime { + g.comptime_var_type_map[node.key_var] = node.key_type + node.scope.update_ct_var_kind(node.key_var, .key_var) + + defer { + g.comptime_var_type_map.delete(node.key_var) + } + } } } if node.kind == .any && !is_comptime { - g.inside_for_in_any_cond = true mut unwrapped_typ := g.unwrap_generic(node.cond_type) mut unwrapped_sym := g.table.sym(unwrapped_typ) node.kind = unwrapped_sym.kind @@ -192,9 +209,14 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { // g.writeln('// FOR IN array') mut styp := g.typ(node.val_type) mut val_sym := g.table.sym(node.val_type) + op_field := g.dot_or_ptr(node.cond_type) - if g.is_comptime_var(node.cond) { - unwrapped_typ := g.unwrap_generic(g.comptime_for_field_type) + if is_comptime && g.is_comptime_var(node.cond) { + mut unwrapped_typ := g.unwrap_generic(node.cond_type) + ctyp := g.unwrap_generic(g.get_comptime_var_type(node.cond)) + if ctyp != ast.void_type { + unwrapped_typ = ctyp + } val_sym = g.table.sym(unwrapped_typ) node.val_type = g.table.value_type(unwrapped_typ) styp = g.typ(node.val_type) @@ -218,7 +240,6 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { g.writeln(';') } i := if node.key_var in ['', '_'] { g.new_tmp_var() } else { node.key_var } - op_field := g.dot_or_ptr(node.cond_type) g.empty_line = true opt_expr := '(*(${g.typ(node.cond_type.clear_flag(.option))}*)${cond_var}${op_field}data)' cond_expr := if node.cond_type.has_flag(.option) { @@ -447,6 +468,5 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) { if node.label.len > 0 { g.writeln('\t${node.label}__break: {}') } - g.inside_for_in_any_cond = false g.loop_depth-- } diff --git a/vlib/v/parser/fn.v b/vlib/v/parser/fn.v index c04cc513e9..44fb613c99 100644 --- a/vlib/v/parser/fn.v +++ b/vlib/v/parser/fn.v @@ -372,7 +372,7 @@ fn (mut p Parser) fn_decl() ast.FnDecl { } params << args2 if !are_args_type_only { - for param in params { + for k, param in params { if p.scope.known_var(param.name) { p.error_with_pos('redefinition of parameter `${param.name}`', param.pos) return ast.FnDecl{ @@ -389,6 +389,12 @@ fn (mut p Parser) fn_decl() ast.FnDecl { pos: param.pos is_used: true is_arg: true + ct_type_var: if (!is_method || k > 0) && param.typ.has_flag(.generic) + && !param.typ.has_flag(.variadic) { + .generic_param + } else { + .no_comptime + } }) } } diff --git a/vlib/v/tests/generic_recursive_fn_test.v b/vlib/v/tests/generic_recursive_fn_test.v new file mode 100644 index 0000000000..d1e989b4ac --- /dev/null +++ b/vlib/v/tests/generic_recursive_fn_test.v @@ -0,0 +1,46 @@ +import strings + +fn myprintln[T](data T, mut str strings.Builder) T { + $if T is $array { + str.write_string('array: [') + for i, elem in data { + myprintln(elem, mut str) + if i < data.len - 1 { + str.write_string(', ') + } + } + str.write_string(']') + } $else $if T is $map { + str.write_string('map: {') + for key, val in data { + str.write_string('(key) ') + myprintln(key, mut str) + str.write_string(' -> (value) ') + myprintln(val, mut str) + } + str.write_string('}') + } $else { + str.write_string(data.str()) + } + return data +} + +struct Test { +mut: + s strings.Builder +} + +fn test_recursive_array() { + mut t := Test{} + myprintln([[1], [2], [3]], mut t.s) + assert t.s.str() == 'array: [array: [1], array: [2], array: [3]]' +} + +fn test_recursive_map() { + mut t2 := Test{} + myprintln({ + 'a': [1, 2, 3] + 'b': [1000] + }, mut t2.s) + assert t2.s.str() == 'map: {(key) a -> (value) array: [1, 2, 3](key) b -> (value) array: [1000]}' +} diff --git a/vlib/v/tests/generic_resolve_test.v b/vlib/v/tests/generic_resolve_test.v new file mode 100644 index 0000000000..d4a03d926a --- /dev/null +++ b/vlib/v/tests/generic_resolve_test.v @@ -0,0 +1,50 @@ +struct Encoder {} + +struct StructType[T] { +mut: + val T +} + +fn (e &Encoder) encode_struct[U](val U) ! { + $for field in U.fields { + value := val.$(field.name) + $if field.typ is $struct { + e.encode_struct(value)! + } $else $if field.typ is $map { + e.encode_map(value)! + } + } +} + +fn (e &Encoder) encode_map[U](val U) ! { + for k, v in val { + e.encode_value_with_level(v)! + } +} + +fn (e &Encoder) encode_value_with_level[U](val U) ! { + $if U is $struct { + e.encode_struct(val)! + } $else $if U is $map { + e.encode_map(val)! + } +} + +fn test_simple_cases() { + e := Encoder{} + e.encode_struct(StructType[map[string]string]{ + val: { + '1': '1' + } + })! + e.encode_struct(StructType[map[string]map[string]int]{})! + e.encode_struct(StructType[map[string]map[string]int]{ + val: { + 'a': { + '1': 1 + } + } + })! + + assert true +} diff --git a/vlib/v/tests/resolve_generic_2_test.v b/vlib/v/tests/resolve_generic_2_test.v new file mode 100644 index 0000000000..d3b9c7b9ca --- /dev/null +++ b/vlib/v/tests/resolve_generic_2_test.v @@ -0,0 +1,20 @@ +fn test_resolve_generic_params() { + assert encode(true) == [] + assert encode([true]) == ['[]bool'] + assert encode(1) == [] + assert encode([1]) == ['[]int'] + assert encode('1') == [] + assert encode(['1']) == ['[]string'] +} + +fn encode[U](val U) []string { + mut c := []string{} + $if U is $array { + c << g_array(val) + } + return c +} + +fn g_array[T](t []T) string { + return typeof(t).name +}