From d942d41eca62d67c7d05bc9ba49e3020a9165799 Mon Sep 17 00:00:00 2001 From: Zach Johnson Date: Tue, 22 Dec 2020 13:59:28 -0800 Subject: [PATCH] rusty-gd: allow parting out injected values this way we can inject only what we actually care about Bug: 171749953 Tag: #gd-refactor Test: gd/cert/run --rhost SimpleHalTest Change-Id: I3798084eabb7a8ffcc9a48b982715c792d64ff8d --- gd/rust/gddi-macros/src/lib.rs | 91 ++++++++++++++++++++++++++++++++---------- gd/rust/gddi/src/lib.rs | 14 ++----- gd/rust/hci/src/controller.rs | 6 +-- gd/rust/hci/src/facade.rs | 35 ++++++++++------ gd/rust/hci/src/lib.rs | 80 +++++++++++++++++++++++++------------ gd/rust/shim/src/hci.rs | 35 +++++++++------- gd/rust/shim/src/stack.rs | 13 +++--- 7 files changed, 182 insertions(+), 92 deletions(-) diff --git a/gd/rust/gddi-macros/src/lib.rs b/gd/rust/gddi-macros/src/lib.rs index cc8a58599..7febb110c 100644 --- a/gd/rust/gddi-macros/src/lib.rs +++ b/gd/rust/gddi-macros/src/lib.rs @@ -5,7 +5,10 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::parse::{Parse, ParseStream, Result}; use syn::punctuated::Punctuated; -use syn::{braced, parse, parse_macro_input, FnArg, Ident, ItemFn, Token, Type, DeriveInput, Path}; +use syn::{ + braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path, + Token, Type, +}; /// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry #[proc_macro_attribute] @@ -57,6 +60,7 @@ enum ModuleEntry { struct ProviderDef { ty: Type, ident: Ident, + parts: bool, } impl Parse for ModuleDef { @@ -75,30 +79,34 @@ impl Parse for ModuleDef { panic!("providers specified more than once"); } providers = value; - }, + } ModuleEntry::Submodules(value) => { if !submodules.is_empty() { panic!("submodules specified more than once"); } submodules = value; - }, + } } } - Ok(ModuleDef { - name, - providers, - submodules, - }) + Ok(ModuleDef { name, providers, submodules }) } } impl Parse for ProviderDef { fn parse(input: ParseStream) -> Result { + let parts = input.peek3(Token![=>]); + if parts { + match input.parse::()?.to_string().as_str() { + "parts" => {} + keyword => panic!("expected 'parts', got '{}'", keyword), + } + } + // A provider definition follows this format: -> let ty = input.parse()?; input.parse::]>()?; let ident = input.parse()?; - Ok(ProviderDef { ty, ident }) + Ok(ProviderDef { ty, ident, parts }) } } @@ -108,16 +116,12 @@ impl Parse for ModuleEntry { "providers" => { let entries; braced!(entries in input); - Ok(ModuleEntry::Providers( - entries.parse_terminated(ProviderDef::parse)?, - )) + Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?)) } "submodules" => { let entries; braced!(entries in input); - Ok(ModuleEntry::Submodules( - entries.parse_terminated(Path::parse)?, - )) + Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?)) } keyword => { panic!("unexpected keyword: {}", keyword); @@ -131,20 +135,30 @@ impl Parse for ModuleEntry { pub fn module(item: TokenStream) -> TokenStream { let module = parse_macro_input!(item as ModuleDef); let init_ident = module.name.clone(); - let types = module.providers.iter().map(|p| p.ty.clone()); - let provider_idents = module - .providers - .iter() - .map(|p| format_ident!("__gddi_{}_injected", p.ident.clone())); + let providers = module.providers.iter(); + let types = providers.clone().map(|p| p.ty.clone()); + let provider_idents = + providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone())); + let parting_functions = providers.filter_map(|p| match &p.ty { + Type::Path(ty) if p.parts => Some(format_ident!( + "__gddi_part_out_{}", + ty.path.get_ident().unwrap().to_string().to_lowercase() + )), + _ => None, + }); let submodule_idents = module.submodules.iter(); let emitted_code = quote! { #[doc(hidden)] #[allow(missing_docs)] pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder { // Register all providers on this module - builder#(.register_provider::<#types>(Box::new(#provider_idents)))* + let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))* // Register all submodules on this module - #(.register_module(#submodule_idents))* + #(.register_module(#submodule_idents))*; + + #(let ret = #parting_functions(ret);)* + + ret } }; emitted_code.into() @@ -160,3 +174,36 @@ pub fn derive_nop_stop(item: TokenStream) -> TokenStream { }; emitted_code.into() } + +/// Generates the code necessary to split up a type into its components +#[proc_macro_attribute] +pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream { + let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions"); + let struct_ident = struct_.ident.clone(); + let fields = match struct_.fields.clone() { + Fields::Named(f) => f, + _ => panic!("can only be applied to structs with named fields"), + } + .named; + + let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name")); + let field_types = fields.iter().map(|f| f.ty.clone()); + + let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase()); + + let emitted_code = quote! { + #struct_ + + fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder { + builder#(.register_provider::<#field_types>(Box::new( + |registry: std::sync::Arc| -> std::pin::Pin { + Box::pin(async move { + Box::new(async move { + registry.get::<#struct_ident>().await.#field_names + }) as Box + }) + })))* + } + }; + emitted_code.into() +} diff --git a/gd/rust/gddi/src/lib.rs b/gd/rust/gddi/src/lib.rs index f34c8ed84..1e005431f 100644 --- a/gd/rust/gddi/src/lib.rs +++ b/gd/rust/gddi/src/lib.rs @@ -7,7 +7,7 @@ use std::pin::Pin; use std::sync::Arc; use tokio::sync::Mutex; -pub use gddi_macros::{module, provides, Stoppable}; +pub use gddi_macros::{module, part_out, provides, Stoppable}; type InstanceBox = Box; /// A box around a future for a provider that is safe to send between threads @@ -46,9 +46,7 @@ impl Default for RegistryBuilder { impl RegistryBuilder { /// Creates a new RegistryBuilder pub fn new() -> Self { - RegistryBuilder { - providers: HashMap::new(), - } + RegistryBuilder { providers: HashMap::new() } } /// Registers a module with this registry @@ -61,8 +59,7 @@ impl RegistryBuilder { /// Registers a provider function with this registry pub fn register_provider(mut self, f: ProviderFnBox) -> Self { - self.providers - .insert(TypeId::of::(), Provider { f: Arc::new(f) }); + self.providers.insert(TypeId::of::(), Provider { f: Arc::new(f) }); self } @@ -84,10 +81,7 @@ impl Registry { { let instances = self.instances.lock().await; if let Some(value) = instances.get(&typeid) { - return value - .downcast_ref::() - .expect("was not correct type") - .clone(); + return value.downcast_ref::().expect("was not correct type").clone(); } } diff --git a/gd/rust/hci/src/controller.rs b/gd/rust/hci/src/controller.rs index d3d807a78..5ab9f1dca 100644 --- a/gd/rust/hci/src/controller.rs +++ b/gd/rust/hci/src/controller.rs @@ -1,6 +1,6 @@ //! Loads info from the controller at startup -use crate::{Address, Hci}; +use crate::{Address, CommandSender}; use bt_packets::hci::{ Enable, ErrorCode, LeMaximumDataLength, LeReadBufferSizeV1Builder, LeReadBufferSizeV2Builder, LeReadConnectListSizeBuilder, LeReadLocalSupportedFeaturesBuilder, @@ -34,7 +34,7 @@ macro_rules! assert_success { } #[provides] -async fn provide_controller(mut hci: Hci) -> Arc { +async fn provide_controller(mut hci: CommandSender) -> Arc { assert_success!(hci.send(LeSetEventMaskBuilder { le_event_mask: 0x0000000000021e7f })); assert_success!(hci.send(SetEventMaskBuilder { event_mask: 0x3dbfffffffffffff })); assert_success!( @@ -167,7 +167,7 @@ async fn provide_controller(mut hci: Hci) -> Arc { }) } -async fn read_features(hci: &mut Hci) -> SupportedFeatures { +async fn read_features(hci: &mut CommandSender) -> SupportedFeatures { let mut features = Vec::new(); let mut page_number: u8 = 0; let mut max_page_number: u8 = 1; diff --git a/gd/rust/hci/src/facade.rs b/gd/rust/hci/src/facade.rs index a7b2efac1..b319d510d 100644 --- a/gd/rust/hci/src/facade.rs +++ b/gd/rust/hci/src/facade.rs @@ -1,6 +1,6 @@ //! HCI layer facade -use crate::Hci; +use crate::{EventRegistry, HciForAcl, RawCommandSender}; use bt_common::GrpcFacade; use bt_facade_proto::common::Data; use bt_facade_proto::empty::Empty; @@ -26,11 +26,18 @@ module! { } #[provides] -async fn provide_facade(hci: Hci, rt: Arc) -> HciFacadeService { +async fn provide_facade( + commands: RawCommandSender, + events: EventRegistry, + acl: HciForAcl, + rt: Arc, +) -> HciFacadeService { let (from_hci_evt_tx, to_grpc_evt_rx) = channel::(10); let (from_hci_le_evt_tx, to_grpc_le_evt_rx) = channel::(10); HciFacadeService { - hci, + commands, + events, + acl, rt, from_hci_evt_tx, to_grpc_evt_rx: Arc::new(Mutex::new(to_grpc_evt_rx)), @@ -42,7 +49,9 @@ async fn provide_facade(hci: Hci, rt: Arc) -> HciFacadeService { /// HCI layer facade service #[derive(Clone, Stoppable)] pub struct HciFacadeService { - hci: Hci, + commands: RawCommandSender, + events: EventRegistry, + acl: HciForAcl, rt: Arc, from_hci_evt_tx: Sender, to_grpc_evt_rx: Arc>>, @@ -59,16 +68,18 @@ impl GrpcFacade for HciFacadeService { impl HciFacade for HciFacadeService { fn send_command(&mut self, _ctx: RpcContext<'_>, mut data: Data, sink: UnarySink) { self.rt - .block_on(self.hci.send_raw(CommandPacket::parse(&data.take_payload()).unwrap())) + .block_on(self.commands.send(CommandPacket::parse(&data.take_payload()).unwrap())) .unwrap(); sink.success(Empty::default()); } fn request_event(&mut self, _ctx: RpcContext<'_>, req: EventRequest, sink: UnarySink) { - self.rt.block_on(self.hci.register_event_handler( - EventCode::from_u32(req.get_code()).unwrap(), - self.from_hci_evt_tx.clone(), - )); + self.rt.block_on( + self.events.register( + EventCode::from_u32(req.get_code()).unwrap(), + self.from_hci_evt_tx.clone(), + ), + ); sink.success(Empty::default()); } @@ -78,7 +89,7 @@ impl HciFacade for HciFacadeService { req: EventRequest, sink: UnarySink, ) { - self.rt.block_on(self.hci.register_le_event_handler( + self.rt.block_on(self.events.register_le( SubeventCode::from_u32(req.get_code()).unwrap(), self.from_hci_le_evt_tx.clone(), )); @@ -86,7 +97,7 @@ impl HciFacade for HciFacadeService { } fn send_acl(&mut self, _ctx: RpcContext<'_>, mut packet: Data, sink: UnarySink) { - let acl_tx = self.hci.acl_tx.clone(); + let acl_tx = self.acl.tx.clone(); self.rt.block_on(async move { acl_tx.send(AclPacket::parse(&packet.take_payload()).unwrap()).await.unwrap(); }); @@ -133,7 +144,7 @@ impl HciFacade for HciFacadeService { _req: Empty, mut resp: ServerStreamingSink, ) { - let acl_rx = self.hci.acl_rx.clone(); + let acl_rx = self.acl.rx.clone(); self.rt.spawn(async move { while let Some(data) = acl_rx.lock().await.recv().await { diff --git a/gd/rust/hci/src/lib.rs b/gd/rust/hci/src/lib.rs index cadf5e914..3108e0e26 100644 --- a/gd/rust/hci/src/lib.rs +++ b/gd/rust/hci/src/lib.rs @@ -21,7 +21,7 @@ use bt_packets::hci::{ LeMetaEventPacket, ResetBuilder, SubeventCode, }; use error::Result; -use gddi::{module, provides, Stoppable}; +use gddi::{module, part_out, provides, Stoppable}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -37,10 +37,19 @@ module! { controller::controller_module, }, providers { - Hci => provide_hci, + parts Hci => provide_hci, }, } +#[part_out] +#[derive(Clone, Stoppable)] +struct Hci { + raw_commands: RawCommandSender, + commands: CommandSender, + events: EventRegistry, + acl: HciForAcl, +} + #[provides] async fn provide_hci(hal: Hal, rt: Arc) -> Hci { let (cmd_tx, cmd_rx) = channel::(10); @@ -55,15 +64,20 @@ async fn provide_hci(hal: Hal, rt: Arc) -> Hci { cmd_rx, )); - let mut hci = - Hci { cmd_tx, evt_handlers, le_evt_handlers, acl_tx: hal.acl_tx, acl_rx: hal.acl_rx }; + let raw_commands = RawCommandSender { cmd_tx }; + let mut commands = CommandSender { raw: raw_commands.clone() }; assert!( - hci.send(ResetBuilder {}).await.get_status() == ErrorCode::Success, + commands.send(ResetBuilder {}).await.get_status() == ErrorCode::Success, "reset did not complete successfully" ); - hci + Hci { + raw_commands, + commands, + events: EventRegistry { evt_handlers, le_evt_handlers }, + acl: HciForAcl { tx: hal.acl_tx, rx: hal.acl_rx }, + } } #[derive(Debug)] @@ -72,39 +86,59 @@ struct QueuedCommand { fut: oneshot::Sender, } -/// HCI interface +/// Sends raw commands. Only useful for facades & shims, or wrapped as a CommandSender. #[derive(Clone, Stoppable)] -pub struct Hci { +pub struct RawCommandSender { cmd_tx: Sender, - evt_handlers: Arc>>>, - le_evt_handlers: Arc>>>, - /// Transmit end of a channel used to send ACL data - pub acl_tx: Sender, - /// Receive end of a channel used to receive ACL data - pub acl_rx: Arc>>, } -impl Hci { +impl RawCommandSender { /// Send a command, but does not automagically associate the expected returning event type. /// /// Only really useful for facades & shims. - pub async fn send_raw(&mut self, cmd: CommandPacket) -> Result { + pub async fn send(&mut self, cmd: CommandPacket) -> Result { let (tx, rx) = oneshot::channel::(); self.cmd_tx.send(QueuedCommand { cmd, fut: tx }).await?; let event = rx.await?; Ok(event) } +} + +/// Sends commands to the controller +#[derive(Clone, Stoppable)] +pub struct CommandSender { + raw: RawCommandSender, +} +impl CommandSender { /// Send a command to the controller, getting an expected response back pub async fn send + CommandExpectations>( &mut self, cmd: T, ) -> T::ResponseType { - T::_to_response_type(self.send_raw(cmd.into()).await.unwrap()) + T::_to_response_type(self.raw.send(cmd.into()).await.unwrap()) } +} + +/// Exposes the ACL send/receive interface +#[derive(Clone, Stoppable)] +pub struct HciForAcl { + /// Transmit end + pub tx: Sender, + /// Receive end + pub rx: Arc>>, +} + +/// Provides ability to register and unregister for HCI events +#[derive(Clone, Stoppable)] +pub struct EventRegistry { + evt_handlers: Arc>>>, + le_evt_handlers: Arc>>>, +} +impl EventRegistry { /// Indicate interest in specific HCI events - pub async fn register_event_handler(&mut self, code: EventCode, sender: Sender) { + pub async fn register(&mut self, code: EventCode, sender: Sender) { match code { EventCode::CommandStatus | EventCode::CommandComplete @@ -123,16 +157,12 @@ impl Hci { } /// Remove interest in specific HCI events - pub async fn unregister_event_handler(&mut self, code: EventCode) { + pub async fn unregister(&mut self, code: EventCode) { self.evt_handlers.lock().await.remove(&code); } /// Indicate interest in specific LE events - pub async fn register_le_event_handler( - &mut self, - code: SubeventCode, - sender: Sender, - ) { + pub async fn register_le(&mut self, code: SubeventCode, sender: Sender) { assert!( self.le_evt_handlers.lock().await.insert(code, sender).is_none(), "A handler for {:?} is already registered", @@ -141,7 +171,7 @@ impl Hci { } /// Remove interest in specific LE events - pub async fn unregister_le_event_handler(&mut self, code: SubeventCode) { + pub async fn unregister_le(&mut self, code: SubeventCode) { self.le_evt_handlers.lock().await.remove(&code); } } diff --git a/gd/rust/shim/src/hci.rs b/gd/rust/shim/src/hci.rs index d3a17cef5..2441b2d15 100644 --- a/gd/rust/shim/src/hci.rs +++ b/gd/rust/shim/src/hci.rs @@ -1,5 +1,6 @@ //! Hci shim +use bt_hci::{EventRegistry, HciForAcl, RawCommandSender}; use bt_packets::hci::{ AclPacket, CommandPacket, EventCode, EventPacket, LeMetaEventPacket, SubeventCode, }; @@ -40,8 +41,10 @@ unsafe impl Send for ffi::u8SliceCallback {} unsafe impl Send for ffi::u8SliceOnceCallback {} pub struct Hci { + commands: RawCommandSender, + events: EventRegistry, + acl: HciForAcl, rt: Arc, - internal: bt_hci::Hci, acl_callback_set: bool, evt_callback_set: bool, le_evt_callback_set: bool, @@ -52,12 +55,19 @@ pub struct Hci { } impl Hci { - pub fn new(rt: Arc, internal: bt_hci::Hci) -> Self { + pub fn new( + rt: Arc, + commands: RawCommandSender, + events: EventRegistry, + acl: HciForAcl, + ) -> Self { let (evt_tx, evt_rx) = channel::(10); let (le_evt_tx, le_evt_rx) = channel::(10); Self { rt, - internal, + commands, + events, + acl, acl_callback_set: false, evt_callback_set: false, le_evt_callback_set: false, @@ -75,35 +85,32 @@ pub fn hci_send_command( callback: cxx::UniquePtr, ) { let packet = CommandPacket::parse(data).unwrap(); - let mut clone_internal = hci.internal.clone(); + let mut commands = hci.commands.clone(); hci.rt.spawn(async move { - let resp = clone_internal.send_raw(packet).await.unwrap(); + let resp = commands.send(packet).await.unwrap(); callback.Run(&resp.to_bytes()); }); } pub fn hci_send_acl(hci: &mut Hci, data: &[u8]) { - hci.rt.block_on(hci.internal.acl_tx.send(AclPacket::parse(data).unwrap())).unwrap(); + hci.rt.block_on(hci.acl.tx.send(AclPacket::parse(data).unwrap())).unwrap(); } pub fn hci_register_event(hci: &mut Hci, event: u8) { - hci.rt.block_on( - hci.internal.register_event_handler(EventCode::from_u8(event).unwrap(), hci.evt_tx.clone()), - ); + hci.rt.block_on(hci.events.register(EventCode::from_u8(event).unwrap(), hci.evt_tx.clone())); } pub fn hci_register_le_event(hci: &mut Hci, subevent: u8) { - hci.rt.block_on(hci.internal.register_le_event_handler( - SubeventCode::from_u8(subevent).unwrap(), - hci.le_evt_tx.clone(), - )); + hci.rt.block_on( + hci.events.register_le(SubeventCode::from_u8(subevent).unwrap(), hci.le_evt_tx.clone()), + ); } pub fn hci_set_acl_callback(hci: &mut Hci, callback: cxx::UniquePtr) { assert!(!hci.acl_callback_set); hci.acl_callback_set = true; - let stream = hci.internal.acl_rx.clone(); + let stream = hci.acl.rx.clone(); hci.rt.spawn(async move { while let Some(item) = stream.lock().await.recv().await { callback.Run(&item.to_bytes()); diff --git a/gd/rust/shim/src/stack.rs b/gd/rust/shim/src/stack.rs index 03f8e86ee..7d0628e38 100644 --- a/gd/rust/shim/src/stack.rs +++ b/gd/rust/shim/src/stack.rs @@ -41,12 +41,8 @@ pub fn stack_create() -> Box { }) } -pub fn stack_start(stack: &mut Stack) { +pub fn stack_start(_stack: &mut Stack) { assert!(init_flags::gd_rust_is_enabled()); - - if init_flags::gd_hci_is_enabled() { - stack.get_blocking::(); - } } pub fn stack_stop(stack: &mut Stack) { @@ -59,7 +55,12 @@ pub fn get_hci(stack: &mut Stack) -> Box { assert!(init_flags::gd_rust_is_enabled()); assert!(init_flags::gd_hci_is_enabled()); - Box::new(Hci::new(stack.get_runtime(), stack.get_blocking::())) + Box::new(Hci::new( + stack.get_runtime(), + stack.get_blocking::(), + stack.get_blocking::(), + stack.get_blocking::(), + )) } pub fn get_controller(stack: &mut Stack) -> Box { -- 2.11.0