diff --git a/vlib/v/checker/check_types.v b/vlib/v/checker/check_types.v index 8e6b963d90..8d9b526f2e 100644 --- a/vlib/v/checker/check_types.v +++ b/vlib/v/checker/check_types.v @@ -1069,6 +1069,15 @@ fn (mut c Checker) infer_fn_generic_types(func ast.Fn, mut node ast.CallExpr) { if param_sym.info.func.return_type.nr_muls() > 0 && typ.nr_muls() > 0 { typ = typ.set_nr_muls(0) } + // resolve lambda with generic return type + if arg.expr is ast.LambdaExpr && typ.has_flag(.generic) { + typ = c.comptime.resolve_generic_expr(arg.expr.expr, typ) + if typ.has_flag(.generic) { + lambda_ret_gt_name := c.table.type_to_str(typ) + idx := func.generic_names.index(lambda_ret_gt_name) + typ = node.concrete_types[idx] + } + } } } } else if arg_sym.kind in [.struct_, .interface_, .sum_type] { diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 9b2f3990e2..504ddf4fab 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -447,7 +447,7 @@ fn (mut c Checker) fn_decl(mut node ast.FnDecl) { } node.params = [node.params[0], ctx_param] node.params << params[1..] - println('new params ${node.name}') + // println('new params ${node.name}') // println(node.params) } // sym := c.table.sym(typ_veb_context) diff --git a/vlib/v/checker/return.v b/vlib/v/checker/return.v index 57d6740d10..807e4153ef 100644 --- a/vlib/v/checker/return.v +++ b/vlib/v/checker/return.v @@ -261,6 +261,11 @@ fn (mut c Checker) return_stmt(mut node ast.Return) { } else { got_type_sym.name } + // ignore generic casting expr on lambda in this phase + if c.inside_lambda && exp_type.has_flag(.generic) + && node.exprs[expr_idxs[i]] is ast.CastExpr { + continue + } c.error('cannot use `${got_type_name}` as ${c.error_type_name(exp_type)} in return argument', pos) } diff --git a/vlib/v/comptime/comptimeinfo.v b/vlib/v/comptime/comptimeinfo.v index bc1edf3e6b..efc7e9d249 100644 --- a/vlib/v/comptime/comptimeinfo.v +++ b/vlib/v/comptime/comptimeinfo.v @@ -279,6 +279,36 @@ fn (mut ct ComptimeInfo) comptime_get_kind_var(var ast.Ident) ?ast.ComptimeForKi } } +pub fn (mut ct ComptimeInfo) resolve_generic_expr(expr ast.Expr, default_typ ast.Type) ast.Type { + match expr { + ast.ParExpr { + return ct.resolve_generic_expr(expr.expr, default_typ) + } + ast.CastExpr { + return expr.typ + } + ast.InfixExpr { + if ct.is_comptime_var(expr.left) { + return ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr.left)) + } + if ct.is_comptime_var(expr.right) { + return ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr.right)) + } + return default_typ + } + ast.Ident { + return if ct.is_comptime_var(expr) { + ct.resolver.unwrap_generic(ct.get_comptime_var_type(expr)) + } else { + default_typ + } + } + else { + return default_typ + } + } +} + pub struct DummyResolver { mut: file &ast.File = unsafe { nil } diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 7fd7edc667..cd7e55fbd6 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -1469,39 +1469,12 @@ fn (mut g Gen) resolve_receiver_name(node ast.CallExpr, unwrapped_rec_type ast.T return receiver_type_name } -fn (mut g Gen) resolve_generic_expr(expr ast.Expr, default_typ ast.Type) ast.Type { - match expr { - ast.ParExpr { - return g.resolve_generic_expr(expr.expr, default_typ) - } - ast.InfixExpr { - if g.comptime.is_comptime_var(expr.left) { - return g.unwrap_generic(g.comptime.get_comptime_var_type(expr.left)) - } - if g.comptime.is_comptime_var(expr.right) { - return g.unwrap_generic(g.comptime.get_comptime_var_type(expr.right)) - } - return default_typ - } - ast.Ident { - return if g.comptime.is_comptime_var(expr) { - g.unwrap_generic(g.comptime.get_comptime_var_type(expr)) - } else { - default_typ - } - } - else { - return default_typ - } - } -} - fn (mut g Gen) resolve_receiver_type(node ast.CallExpr) (ast.Type, &ast.TypeSymbol) { left_type := g.unwrap_generic(node.left_type) mut unwrapped_rec_type := node.receiver_type if g.cur_fn != unsafe { nil } && g.cur_fn.generic_names.len > 0 { // in generic fn unwrapped_rec_type = g.unwrap_generic(node.receiver_type) - unwrapped_rec_type = g.resolve_generic_expr(node.left, unwrapped_rec_type) + unwrapped_rec_type = g.comptime.resolve_generic_expr(node.left, unwrapped_rec_type) } else { // in non-generic fn sym := g.table.sym(node.receiver_type) match sym.info { diff --git a/vlib/v/tests/generic_lambda_expr_test.v b/vlib/v/tests/generic_lambda_expr_test.v new file mode 100644 index 0000000000..2a3f0a7eb7 --- /dev/null +++ b/vlib/v/tests/generic_lambda_expr_test.v @@ -0,0 +1,15 @@ +pub fn mymap[T, R](input []T, f fn (T) R) []R { + mut results := []R{cap: input.len} + for x in input { + results << f(x) + } + return results +} + +fn test_main() { + assert dump(mymap([1, 2, 3, 4, 5], fn (i int) int { + return i * i + })) == [1, 4, 9, 16, 25] + assert dump(mymap([1, 2, 3, 4, 5], |x| x * x)) == [1, 4, 9, 16, 25] + assert dump(mymap([1, 2, 3, 4, 5], |x| u16(x * x))) == [u16(1), 4, 9, 16, 25] +}