OSDN Git Service

rusty-gd: implement ACL connection close
authorZach Johnson <zachoverflow@google.com>
Wed, 30 Dec 2020 02:07:52 +0000 (18:07 -0800)
committerZach Johnson <zachoverflow@google.com>
Wed, 20 Jan 2021 00:55:52 +0000 (16:55 -0800)
Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost
Change-Id: Iaf9770676077aa5e9bfb3f2ff49545d9b71c282f

gd/rust/acl/src/fragment.rs
gd/rust/acl/src/lib.rs

index f946c94..b66b855 100644 (file)
@@ -9,6 +9,7 @@ use bytes::{Buf, Bytes, BytesMut};
 use futures::stream::{self, StreamExt};
 use log::{error, info, warn};
 use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::oneshot;
 
 const L2CAP_BASIC_FRAME_HEADER_LEN: usize = 4;
 
@@ -90,6 +91,7 @@ pub fn fragmenting_stream(
     mtu: usize,
     handle: u16,
     bt: Bluetooth,
+    close_rx: oneshot::Receiver<()>,
 ) -> std::pin::Pin<
     std::boxed::Box<dyn futures::Stream<Item = bt_packets::hci::AclPacket> + std::marker::Send>,
 > {
@@ -113,5 +115,6 @@ pub fn fragmenting_stream(
                 .collect::<Vec<AclPacket>>(),
         )
     })
+    .take_until(close_rx)
     .boxed()
 }
index 5415828..6c1d0fe 100644 (file)
@@ -31,35 +31,65 @@ module! {
 pub struct Connection {
     rx: Receiver<Bytes>,
     tx: Sender<Bytes>,
+    handle: u16,
+    requests: Sender<Request>,
 }
 
 struct ConnectionInternal {
     reassembler: Reassembler,
     bt: Bluetooth,
+    close_tx: oneshot::Sender<()>,
+}
+
+impl Connection {
+    /// Close this connection. Consumes self.
+    pub async fn close(self) {
+        let (tx, rx) = oneshot::channel();
+        self.requests
+            .send(Request::Unregister(UnregisterRequest { handle: self.handle, fut: tx }))
+            .await
+            .unwrap();
+        rx.await.unwrap()
+    }
 }
 
 /// Manages rx and tx for open ACL connections
 #[derive(Clone, Stoppable)]
 pub struct AclDispatch {
-    requests: Sender<RegistrationRequest>,
+    requests: Sender<Request>,
 }
 
 impl AclDispatch {
     /// Register the provided connection with the ACL dispatch
     pub async fn register(&mut self, handle: u16, bt: Bluetooth) -> Connection {
         let (tx, rx) = oneshot::channel();
-        self.requests.send(RegistrationRequest { handle, bt, fut: tx }).await.unwrap();
+        self.requests
+            .send(Request::Register(RegisterRequest { handle, bt, fut: tx }))
+            .await
+            .unwrap();
         rx.await.unwrap()
     }
 }
 
 #[derive(Debug)]
-struct RegistrationRequest {
+enum Request {
+    Register(RegisterRequest),
+    Unregister(UnregisterRequest),
+}
+
+#[derive(Debug)]
+struct RegisterRequest {
     handle: u16,
     bt: Bluetooth,
     fut: oneshot::Sender<Connection>,
 }
 
+#[derive(Debug)]
+struct UnregisterRequest {
+    handle: u16,
+    fut: oneshot::Sender<()>,
+}
+
 const QCOM_DEBUG_HANDLE: u16 = 0xedc;
 
 #[provides]
@@ -69,7 +99,8 @@ async fn provide_acl_dispatch(
     mut events: EventRegistry,
     rt: Arc<Runtime>,
 ) -> AclDispatch {
-    let (req_tx, mut req_rx) = channel::<RegistrationRequest>(10);
+    let (req_tx, mut req_rx) = channel::<Request>(10);
+    let req_tx_clone = req_tx.clone();
 
     rt.spawn(async move {
         let mut connections: HashMap<u16, ConnectionInternal> = HashMap::new();
@@ -84,28 +115,43 @@ async fn provide_acl_dispatch(
         loop {
             select! {
                 Some(req) = req_rx.recv() => {
-                    let (out_tx, out_rx) = channel(10);
-                    let (in_tx, in_rx) = channel(10);
-
-                    assert!(connections.insert(
-                        req.handle,
-                        ConnectionInternal {
-                            reassembler: Reassembler::new(out_tx),
-                            bt: req.bt,
-                        }).is_none());
-
-                    match req.bt {
-                        Classic => {
-                            classic_outbound.push(fragmenting_stream(
-                                in_rx, controller.acl_buffer_length.into(), req.handle, req.bt));
-                        },
-                        Le => {
-                            le_outbound.push(fragmenting_stream(
-                                in_rx, controller.le_buffer_length.into(), req.handle, req.bt));
+                    match req {
+                        Request::Register(req) => {
+                            let (out_tx, out_rx) = channel(10);
+                            let (in_tx, in_rx) = channel(10);
+                            let (close_tx, close_rx) = oneshot::channel();
+
+                            assert!(connections.insert(
+                                req.handle,
+                                ConnectionInternal {
+                                    reassembler: Reassembler::new(out_tx),
+                                    bt: req.bt,
+                                    close_tx,
+                                }).is_none());
+
+                            match req.bt {
+                                Classic => {
+                                    classic_outbound.push(fragmenting_stream(
+                                        in_rx, controller.acl_buffer_length.into(), req.handle, req.bt, close_rx));
+                                },
+                                Le => {
+                                    le_outbound.push(fragmenting_stream(
+                                        in_rx, controller.le_buffer_length.into(), req.handle, req.bt, close_rx));
+                                },
+                            }
+
+                            req.fut.send(Connection {
+                                rx: out_rx,
+                                tx: in_tx,
+                                handle: req.handle,
+                                requests: req_tx_clone.clone()}).unwrap();
                         },
+                        Request::Unregister(req) => {
+                            if let Some(connection) = connections.remove(&req.handle) {
+                                connection.close_tx.send(()).unwrap();
+                            }
+                        }
                     }
-
-                    req.fut.send(Connection { rx: out_rx, tx: in_tx }).unwrap();
                 },
                 Some(packet) = consume(&acl.rx) => {
                     match connections.get_mut(&packet.get_handle()) {