cgen, checker: allow using smartcasted sumtype variant values in the ORM queries (fix #23239) (#23241)

This commit is contained in:
Swastik Baranwal 2024-12-23 19:49:27 +05:30 committed by GitHub
parent e0a63dba62
commit f089ba9ff2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 1 deletions

View File

@ -152,6 +152,7 @@ const skip_with_fsanitize_memory = [
'vlib/orm/orm_create_and_drop_test.v',
'vlib/orm/orm_insert_test.v',
'vlib/orm/orm_insert_reserved_name_test.v',
'vlib/orm/orm_sum_type_insert_test.v',
'vlib/orm/orm_fn_calls_test.v',
'vlib/orm/orm_last_id_test.v',
'vlib/orm/orm_string_interpolation_in_where_test.v',
@ -199,6 +200,7 @@ const skip_with_fsanitize_address = [
'vlib/orm/orm_create_and_drop_test.v',
'vlib/orm/orm_insert_test.v',
'vlib/orm/orm_insert_reserved_name_test.v',
'vlib/orm/orm_sum_type_insert_test.v',
'vlib/orm/orm_references_test.v',
'vlib/v/tests/websocket_logger_interface_should_compile_test.v',
'vlib/v/tests/orm_enum_test.v',
@ -212,6 +214,7 @@ const skip_with_fsanitize_undefined = [
'vlib/orm/orm_create_and_drop_test.v',
'vlib/orm/orm_insert_test.v',
'vlib/orm/orm_insert_reserved_name_test.v',
'vlib/orm/orm_sum_type_insert_test.v',
'vlib/orm/orm_references_test.v',
'vlib/v/tests/orm_enum_test.v',
'vlib/v/tests/orm_sub_array_struct_test.v',
@ -254,6 +257,7 @@ const skip_on_ubuntu_musl = [
'vlib/orm/orm_create_and_drop_test.v',
'vlib/orm/orm_insert_test.v',
'vlib/orm/orm_insert_reserved_name_test.v',
'vlib/orm/orm_sum_type_insert_test.v',
'vlib/orm/orm_fn_calls_test.v',
'vlib/orm/orm_null_test.v',
'vlib/orm/orm_last_id_test.v',

View File

@ -0,0 +1,26 @@
import db.sqlite
struct SomeStruct {
foo int
bar string
}
struct OtherStruct {
baz f64
}
type SomeSum = SomeStruct | OtherStruct
fn test_sum_type_insert() {
db := sqlite.connect(':memory:')!
sql db {
create table SomeStruct
}!
some := SomeSum(SomeStruct{})
if some is SomeStruct {
sql db {
insert some into SomeStruct
}!
}
}

View File

@ -253,7 +253,8 @@ fn (mut c Checker) sql_stmt_line(mut node ast.SqlStmtLine) ast.Type {
inserting_object_type = inserting_object.typ.deref()
}
if inserting_object_type != node.table_expr.typ {
if inserting_object_type != node.table_expr.typ
&& !c.table.sumtype_has_variant(inserting_object_type, node.table_expr.typ, false) {
table_name := table_sym.name
inserting_type_name := c.table.sym(inserting_object_type).name

View File

@ -334,6 +334,7 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v
is_serial := primary_field.attrs.contains_arg('sql', 'serial')
&& primary_field.typ == ast.int_type
mut inserting_object_type := ast.void_type
mut member_access_type := '.'
if node.scope != unsafe { nil } {
inserting_object := node.scope.find(node.object_var) or {
@ -342,8 +343,10 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v
if inserting_object.typ.is_ptr() {
member_access_type = '->'
}
inserting_object_type = inserting_object.typ
}
inserting_object_sym := g.table.sym(inserting_object_type)
for i, mut sub in subs {
if subs_unwrapped_c_typ[i].len > 0 {
var := '${node.object_var}${member_access_type}${sub.object_var}'
@ -418,6 +421,10 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v
var := '${node.object_var}${member_access_type}${c_name(field.name)}'
if field.typ.has_flag(.option) {
g.writeln('${var}.state == 2? _const_orm__null_primitive : orm__${typ}_to_primitive(*(${ctyp}*)(${var}.data)),')
} else if inserting_object_sym.kind == .sum_type {
table_sym := g.table.sym(node.table_expr.typ)
sum_type_var := '(*${node.object_var}._${table_sym.cname})${member_access_type}${c_name(field.name)}'
g.writeln('orm__${typ}_to_primitive(${sum_type_var}),')
} else {
g.writeln('orm__${typ}_to_primitive(${var}),')
}