Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions cc_bindings_from_rs/generate_bindings/generate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use quote::quote;
use rustc_hir::attrs::AttributeKind;
use rustc_hir::{self as hir, def::DefKind};
use rustc_middle::mir::Mutability;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, TraitRef, Ty, TyCtxt};
use rustc_span::def_id::DefId;
use rustc_span::symbol::Symbol;
use std::collections::BTreeSet;
Expand Down Expand Up @@ -697,6 +697,54 @@ fn get_function_cc_name(db: &BindingsGenerator, def_id: DefId) -> Result<Ident>
.context("Error formatting function name")
}

fn format_trait_ref_for_cc<'tcx>(
db: &BindingsGenerator<'tcx>,
trait_ref: &TraitRef<'tcx>,
) -> Result<CcSnippet<'tcx>> {
let trait_name = db
.symbol_canonical_name(trait_ref.def_id)
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
.expect("Generated trait method for a trait with an invalid cc name");
let mut trait_args = trait_ref.args[1..].iter().filter_map(|arg| arg.as_type()).peekable();
let mut prereqs = CcPrerequisites::default();
let tokens = if trait_args.peek().is_none() {
quote! { #trait_name }
} else {
let arg_tokens = trait_args
.map(|ty_arg| {
Ok(db.format_ty_for_cc(ty_arg, TypeLocation::Other)?.into_tokens(&mut prereqs))
})
.collect::<Result<Vec<_>>>()?;
quote! { #trait_name<#(#arg_tokens),*> }
};
Ok(CcSnippet { prereqs, tokens })
}

fn format_trait_ref_for_rs<'tcx>(
db: &BindingsGenerator<'tcx>,
trait_ref: &TraitRef<'tcx>,
) -> Result<TokenStream> {
let trait_name = db
.symbol_canonical_name(trait_ref.def_id)
.map(|fully_qualified_name| fully_qualified_name.format_for_rs())
.expect("Generated trait method for a trait with an invalid rs name");
let mut trait_args = trait_ref.args[1..].iter().filter_map(|arg| arg.as_type()).peekable();
if trait_args.peek().is_none() {
Ok(quote! { #trait_name })
} else {
let arg_tokens = trait_args
.map(|ty_arg| {
let static_ty_arg = crate::generate_function_thunk::replace_all_regions_with_static(
db.tcx(),
ty_arg,
);
db.format_ty_for_rs(static_ty_arg)
})
.collect::<Result<Vec<_>>>()?;
Ok(quote! { #trait_name<#(#arg_tokens),*> })
}
}

/// Implementation of `BindingsGenerator::generate_function`.
pub fn generate_function<'tcx>(
db: &BindingsGenerator<'tcx>,
Expand Down Expand Up @@ -946,15 +994,14 @@ pub fn generate_function<'tcx>(
let decl_name = trait_ref
.as_ref()
.map(|trait_ref| {
let trait_name = db
.symbol_canonical_name(trait_ref.def_id)
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
.expect("Generated trait method for a trait with an invalid rust name");
let struct_name = struct_name
.as_ref()
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
.expect("Generated trait method for an ADT with an invalid rust name");
quote! { rs_std :: impl <#struct_name, #trait_name> :: #bracketed_decl_name }
let trait_name_with_args = format_trait_ref_for_cc(db, trait_ref)
.expect("Implementation of trait containing invalid type requested. Caller should have verified type arguments were valid.")
.into_tokens(&mut prereqs);
quote! { rs_std :: impl <#struct_name, #trait_name_with_args> :: #bracketed_decl_name }
})
.or_else(|| {
struct_name.as_ref().map(|fully_qualified_name| {
Expand Down Expand Up @@ -993,7 +1040,8 @@ pub fn generate_function<'tcx>(
.map(|fully_qualified_name| fully_qualified_name.format_for_rs())
.expect("Generated trait method for an ADT with an invalid rust name");
let fn_name = make_rs_ident(unqualified_rust_fn_name.as_str());
quote! { <#struct_name as #trait_name>::#fn_name }
let trait_name_with_args = format_trait_ref_for_rs(db, trait_ref).expect("Implementation of trait containing invalid type requested. Caller should have verified type arguments were valid.");
quote! { <#struct_name as #trait_name_with_args>::#fn_name }
})
// Inherent method
.or_else(|| struct_name.as_ref().map(|struct_name| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;
use std::rc::Rc;

fn has_type_or_const_vars() -> TypeFlags {
pub(crate) fn has_type_or_const_vars() -> TypeFlags {
TypeFlags::HAS_TY_PARAM
| TypeFlags::HAS_CT_PARAM
| TypeFlags::HAS_TY_INFER
Expand Down
175 changes: 115 additions & 60 deletions cc_bindings_from_rs/generate_bindings/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::generate_function::{generate_function, must_use_attr_of};
use crate::generate_function_thunk::{generate_trait_thunks, TraitThunks};
use crate::generate_struct_and_union::{
adt_needs_bindings, cpp_enum_cpp_underlying_type, from_trait_impls_by_argument, generate_adt,
generate_adt_core, generate_associated_item, scalar_value_to_string,
generate_adt_core, generate_associated_item, has_type_or_const_vars, scalar_value_to_string,
};
use arc_anyhow::{Context, Error, Result};
use code_gen_utils::{format_cc_includes, CcConstQualifier, CcInclude, NamespaceQualifier};
Expand All @@ -58,7 +58,7 @@ use rustc_abi::{AddressSpace, BackendRepr, Integer, Primitive, Scalar};
use rustc_hir::def::{DefKind, Res};
use rustc_middle::metadata::{ModChild, Reexport};
use rustc_middle::mir::ConstValue;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, GenericParamDefKind, Ty, TyCtxt};
use rustc_span::def_id::{CrateNum, DefId, LOCAL_CRATE};
use rustc_span::symbol::{sym, Symbol};
use std::cmp::Ordering;
Expand Down Expand Up @@ -862,12 +862,13 @@ fn supported_traits(db: &BindingsGenerator<'_>) -> Rc<[DefId]> {
&& crate_name.as_str() != "alloc";

let generics = tcx.generics_of(*trait_id);
// TODO: b/259749095 - Support generics in Traits.
// Traits will have a single parameter for the self type which is allowed.
let no_generic_args = (generics.has_self
&& generics.own_params.iter().filter(|param| param.kind.is_ty_or_const()).count()
== 1)
|| !generics.requires_monomorphization(tcx);
// Traits do not support const generics.
let no_generic_args = generics
.own_params
.iter()
.filter(|param| matches!(param.kind, GenericParamDefKind::Const { .. }))
.count()
== 0;

let is_exposed_trait = db.symbol_canonical_name(*trait_id).is_some();
// We might want to explicitly omit certain marker traits here that are already handled by the bindings for structs/enums (Copy, Clone, Default, etc.).
Expand Down Expand Up @@ -896,12 +897,33 @@ fn generate_trait<'tcx>(
let rs_type = canonical_name.format_for_rs().to_string();
let attributes = vec![quote! {CRUBIT_INTERNAL_RUST_TYPE(#rs_type)}];

let tcx = db.tcx();
let generics = tcx.generics_of(trait_id);
let own_params: Vec<_> = generics
.own_params
.iter()
.filter(|param| matches!(param.kind, GenericParamDefKind::Type { .. }))
.collect();
let trait_params = if generics.has_self { &own_params[1..] } else { &own_params[..] };

let (template_prefix, trait_name_with_args) = if trait_params.is_empty() {
(quote! {}, quote! { #trait_name })
} else {
let template_params = trait_params.iter().enumerate().map(|(i, _)| {
let param_name = format_ident!("T{}", i);
quote! { typename #param_name }
});
let template_args = trait_params.iter().enumerate().map(|(i, _)| format_ident!("T{}", i));
(quote! { template <#(#template_params),*> }, quote! { #trait_name<#(#template_args),*> })
};

let main_api = CcSnippet::with_include(
quote! {
__NEWLINE__ #doc_comment
#template_prefix
struct #(#attributes)* #trait_name {
template <typename T>
using impl = rs_std::impl<T, #trait_name>;
using impl = rs_std::impl<T, #trait_name_with_args>;
};
__NEWLINE__
},
Expand Down Expand Up @@ -1659,62 +1681,95 @@ fn generate_trait_impls<'a, 'tcx>(
.map(move |impl_def_id| (adt_cc_name.clone(), trait_def_id, impl_def_id))
})
})
.map(move |(adt_cc_name, trait_def_id, impl_def_id)| {
let canonical_name = db.symbol_canonical_name(trait_def_id).expect(
"symbol_canonical_name was unexpectedly called on a trait without a canonical name",
);
let trait_name = canonical_name.format_for_cc(db).map_err(|err| (impl_def_id, err))?;
let mut prereqs = CcPrerequisites::default();
if trait_def_id.krate == db.source_crate_num() {
prereqs.defs.insert(trait_def_id);
} else {
let other_crate_name = tcx.crate_name(trait_def_id.krate);
let crate_name_to_include_paths = db.crate_name_to_include_paths();
let includes = crate_name_to_include_paths
.get(other_crate_name.as_str())
.ok_or_else(|| {
let trait_name = tcx.def_path_str(trait_def_id);
(
impl_def_id,
anyhow!(
"Trait `{trait_name}` comes from the `{other_crate_name}` crate, \
but no `--crate-header` was specified for this crate"
),
)
})?;
prereqs.includes.extend(includes.iter().cloned());
}
.map(
move |(adt_cc_name, trait_def_id, impl_def_id)| -> Result<ApiSnippets, (DefId, Error)> {
let trait_header = tcx.impl_trait_header(impl_def_id);
#[rustversion::before(2025-10-17)]
let trait_header = trait_header.expect("Trait impl should have a trait header");
let trait_ref = trait_header.trait_ref.instantiate_identity();

let canonical_name = db.symbol_canonical_name(trait_def_id).expect(
"symbol_canonical_name was unexpectedly called on a trait without a canonical name",
);
let trait_name =
canonical_name.format_for_cc(db).map_err(|err| (impl_def_id, err))?;

let mut member_function_names = HashSet::new();
let assoc_items: ApiSnippets = tcx
.associated_items(impl_def_id)
.in_definition_order()
.flat_map(|assoc_item| {
generate_associated_item(db, assoc_item, &mut member_function_names)
})
.collect();
let mut prereqs = CcPrerequisites::default();
let trait_args: Vec<_> = trait_ref
.args
.iter()
// Skip self type.
.skip(1)
.filter_map(|arg| arg.as_type())
.map(|arg| {
if arg.flags().contains(has_type_or_const_vars()) {
return Err((impl_def_id, anyhow!("Implementation of traits must specify all types to receive bindings.")));
}
db.format_ty_for_cc(arg, TypeLocation::Other)
.map(|snippet| snippet.into_tokens(&mut prereqs))
.map_err(|err| (impl_def_id, err))
})
.collect::<Result<Vec<_>, _>>()?;

let main_api = assoc_items.main_api.into_tokens(&mut prereqs);
prereqs.includes.insert(db.support_header("rs_std/traits.h"));
let type_args = if trait_args.is_empty() {
quote! {}
} else {
quote! { <#(#trait_args),*> }
};

Ok(ApiSnippets {
main_api: CcSnippet {
tokens: quote! {
__NEWLINE__
template<>
struct rs_std::impl<#adt_cc_name, #trait_name> {
static constexpr bool kIsImplemented = true;
let trait_name_with_args = quote! { #trait_name #type_args };

#main_api
};
__NEWLINE__
if trait_def_id.krate == db.source_crate_num() {
prereqs.defs.insert(trait_def_id);
} else {
let other_crate_name = tcx.crate_name(trait_def_id.krate);
let crate_name_to_include_paths = db.crate_name_to_include_paths();
let includes = crate_name_to_include_paths
.get(other_crate_name.as_str())
.ok_or_else(|| {
let trait_name = tcx.def_path_str(trait_def_id);
(
impl_def_id,
anyhow!(
"Trait `{trait_name}` comes from the `{other_crate_name}` crate, \
but no `--crate-header` was specified for this crate"
),
)
})?;
prereqs.includes.extend(includes.iter().cloned());
}

let mut member_function_names = HashSet::new();
let assoc_items: ApiSnippets = tcx
.associated_items(impl_def_id)
.in_definition_order()
.flat_map(|assoc_item| {
generate_associated_item(db, assoc_item, &mut member_function_names)
})
.collect();

let main_api = assoc_items.main_api.into_tokens(&mut prereqs);
prereqs.includes.insert(db.support_header("rs_std/traits.h"));

Ok(ApiSnippets {
main_api: CcSnippet {
tokens: quote! {
__NEWLINE__
template<>
struct rs_std::impl<#adt_cc_name, #trait_name_with_args> {
static constexpr bool kIsImplemented = true;

#main_api
};
__NEWLINE__
},
prereqs,
},
prereqs,
},
cc_details: assoc_items.cc_details,
rs_details: assoc_items.rs_details,
})
})
cc_details: assoc_items.cc_details,
rs_details: assoc_items.rs_details,
})
},
)
.map(|results_snippets| {
results_snippets.unwrap_or_else(|(def_id, err)| {
generate_unsupported_def(db, def_id, err).into_main_api()
Expand Down
34 changes: 34 additions & 0 deletions cc_bindings_from_rs/test/traits/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,37 @@ crubit_cc_test(
"//testing/base/public:gunit_main",
],
)

rust_library(
name = "generic_traits",
srcs = ["generic_traits.rs"],
aspect_hints = [
"//features:experimental",
],
proc_macro_deps = [
"//support:crubit_annotate",
],
)

cc_bindings_from_rust(
name = "generic_traits_cc_api",
testonly = 1,
crate = ":generic_traits",
)

golden_test(
name = "generic_traits_golden_test",
basename = "generic_traits",
golden_h = "generic_traits_cc_api.h",
golden_rs = "generic_traits_cc_api_impl.rs",
rust_library = "generic_traits",
)

crubit_cc_test(
name = "generic_traits_test",
srcs = ["generic_traits_test.cc"],
deps = [
":generic_traits_cc_api",
"//testing/base/public:gunit_main",
],
)
Loading