aaudio: fix possible race condition in close()

Move increment and decrement of reference count under
the same lock used for finding the stream.

Bug: 79693915
Test: run CTS tests
Change-Id: I206eb09724a81f2d79a03fa756adbcbb8abf5efa
diff --git a/services/oboeservice/AAudioService.cpp b/services/oboeservice/AAudioService.cpp
index 5675b0b..6a72e5b 100644
--- a/services/oboeservice/AAudioService.cpp
+++ b/services/oboeservice/AAudioService.cpp
@@ -144,15 +144,14 @@
 // If a close request is pending then close the stream
 bool AAudioService::releaseStream(const sp<AAudioServiceStreamBase> &serviceStream) {
     bool closed = false;
-    if ((serviceStream->decrementServiceReferenceCount() == 0) && serviceStream->isCloseNeeded()) {
-        // removeStreamByHandle() uses a lock so that if there are two simultaneous closes
-        // then only one will get the pointer and do the close.
-        sp<AAudioServiceStreamBase> foundStream = mStreamTracker.removeStreamByHandle(serviceStream->getHandle());
-        if (foundStream.get() != nullptr) {
-            foundStream->close();
-            pid_t pid = foundStream->getOwnerProcessId();
-            AAudioClientTracker::getInstance().unregisterClientStream(pid, foundStream);
-        }
+    // decrementAndRemoveStreamByHandle() uses a lock so that if there are two simultaneous closes
+    // then only one will get the pointer and do the close.
+    sp<AAudioServiceStreamBase> foundStream = mStreamTracker.decrementAndRemoveStreamByHandle(
+            serviceStream->getHandle());
+    if (foundStream.get() != nullptr) {
+        foundStream->close();
+        pid_t pid = foundStream->getOwnerProcessId();
+        AAudioClientTracker::getInstance().unregisterClientStream(pid, foundStream);
         closed = true;
     }
     return closed;
@@ -175,14 +174,15 @@
     pid_t pid = serviceStream->getOwnerProcessId();
     AAudioClientTracker::getInstance().unregisterClientStream(pid, serviceStream);
 
-    serviceStream->setCloseNeeded(true);
+    serviceStream->markCloseNeeded();
     (void) releaseStream(serviceStream);
     return AAUDIO_OK;
 }
 
 sp<AAudioServiceStreamBase> AAudioService::convertHandleToServiceStream(
         aaudio_handle_t streamHandle) {
-    sp<AAudioServiceStreamBase> serviceStream = mStreamTracker.getStreamByHandle(streamHandle);
+    sp<AAudioServiceStreamBase> serviceStream = mStreamTracker.getStreamByHandleAndIncrement(
+            streamHandle);
     if (serviceStream.get() != nullptr) {
         // Only allow owner or the aaudio service to access the stream.
         const uid_t callingUserId = IPCThreadState::self()->getCallingUid();
@@ -194,9 +194,9 @@
         if (!allowed) {
             ALOGE("AAudioService: calling uid %d cannot access stream 0x%08X owned by %d",
                   callingUserId, streamHandle, ownerUserId);
+            // We incremented the reference count so we must check if it needs to be closed.
+            checkForPendingClose(serviceStream, AAUDIO_OK);
             serviceStream.clear();
-        } else {
-            serviceStream->incrementServiceReferenceCount();
         }
     }
     return serviceStream;
@@ -328,12 +328,11 @@
 aaudio_result_t AAudioService::disconnectStreamByPortHandle(audio_port_handle_t portHandle) {
     ALOGD("%s(%d) called", __func__, portHandle);
     sp<AAudioServiceStreamBase> serviceStream =
-            mStreamTracker.findStreamByPortHandle(portHandle);
+            mStreamTracker.findStreamByPortHandleAndIncrement(portHandle);
     if (serviceStream.get() == nullptr) {
         ALOGE("%s(), could not find stream with portHandle = %d", __func__, portHandle);
         return AAUDIO_ERROR_INVALID_HANDLE;
     }
-    serviceStream->incrementServiceReferenceCount();
     aaudio_result_t result = serviceStream->stop();
     serviceStream->disconnect();
     return checkForPendingClose(serviceStream, result);
diff --git a/services/oboeservice/AAudioServiceStreamBase.cpp b/services/oboeservice/AAudioServiceStreamBase.cpp
index 48d8002..9af8af3 100644
--- a/services/oboeservice/AAudioServiceStreamBase.cpp
+++ b/services/oboeservice/AAudioServiceStreamBase.cpp
@@ -414,12 +414,13 @@
     sendServiceEvent(AAUDIO_SERVICE_EVENT_VOLUME, volume);
 }
 
-int32_t AAudioServiceStreamBase::incrementServiceReferenceCount() {
-    std::lock_guard<std::mutex> lock(mCallingCountLock);
+int32_t AAudioServiceStreamBase::incrementServiceReferenceCount_l() {
     return ++mCallingCount;
 }
 
-int32_t AAudioServiceStreamBase::decrementServiceReferenceCount() {
-    std::lock_guard<std::mutex> lock(mCallingCountLock);
-    return --mCallingCount;
+int32_t AAudioServiceStreamBase::decrementServiceReferenceCount_l() {
+    int32_t count = --mCallingCount;
+    // Each call to increment should be balanced with one call to decrement.
+    assert(count >= 0);
+    return count;
 }
diff --git a/services/oboeservice/AAudioServiceStreamBase.h b/services/oboeservice/AAudioServiceStreamBase.h
index 0ad015e..a1815d0 100644
--- a/services/oboeservice/AAudioServiceStreamBase.h
+++ b/services/oboeservice/AAudioServiceStreamBase.h
@@ -205,22 +205,33 @@
 
     /**
      * Atomically increment the number of active references to the stream by AAudioService.
+     *
+     * This is called under a global lock in AAudioStreamTracker.
+     *
      * @return value after the increment
      */
-    int32_t incrementServiceReferenceCount();
+    int32_t incrementServiceReferenceCount_l();
 
     /**
      * Atomically decrement the number of active references to the stream by AAudioService.
+     * This should only be called after incrementServiceReferenceCount_l().
+     *
+     * This is called under a global lock in AAudioStreamTracker.
+     *
      * @return value after the decrement
      */
-    int32_t decrementServiceReferenceCount();
+    int32_t decrementServiceReferenceCount_l();
 
     bool isCloseNeeded() const {
         return mCloseNeeded.load();
     }
 
-    void setCloseNeeded(bool needed) {
-        mCloseNeeded.store(needed);
+    /**
+     * Mark this stream as needing to be closed.
+     * Once marked for closing, it cannot be unmarked.
+     */
+    void markCloseNeeded() {
+        mCloseNeeded.store(true);
     }
 
     virtual const char *getTypeText() const { return "Base"; }
@@ -290,8 +301,9 @@
     aaudio_handle_t         mHandle = -1;
     bool                    mFlowing = false;
 
-    std::mutex              mCallingCountLock;
-    std::atomic<int32_t>    mCallingCount{0};
+    // This is modified under a global lock in AAudioStreamTracker.
+    int32_t                 mCallingCount = 0;
+
     std::atomic<bool>       mCloseNeeded{false};
 };
 
diff --git a/services/oboeservice/AAudioStreamTracker.cpp b/services/oboeservice/AAudioStreamTracker.cpp
index 9d5d8fc..3328159 100644
--- a/services/oboeservice/AAudioStreamTracker.cpp
+++ b/services/oboeservice/AAudioStreamTracker.cpp
@@ -30,34 +30,40 @@
 using namespace android;
 using namespace aaudio;
 
-sp<AAudioServiceStreamBase> AAudioStreamTracker::removeStreamByHandle(
+sp<AAudioServiceStreamBase> AAudioStreamTracker::decrementAndRemoveStreamByHandle(
+        aaudio_handle_t streamHandle) {
+    std::lock_guard<std::mutex> lock(mHandleLock);
+    sp<AAudioServiceStreamBase> serviceStream;
+    auto it = mStreamsByHandle.find(streamHandle);
+    if (it != mStreamsByHandle.end()) {
+        sp<AAudioServiceStreamBase> tempStream = it->second;
+        // Does the caller need to close the stream?
+        // The reference count should never be negative.
+        // But it is safer to check for <= 0 than == 0.
+        if ((tempStream->decrementServiceReferenceCount_l() <= 0) && tempStream->isCloseNeeded()) {
+            serviceStream = tempStream; // Only return stream if ready to be closed.
+            mStreamsByHandle.erase(it);
+        }
+    }
+    return serviceStream;
+}
+
+sp<AAudioServiceStreamBase> AAudioStreamTracker::getStreamByHandleAndIncrement(
         aaudio_handle_t streamHandle) {
     std::lock_guard<std::mutex> lock(mHandleLock);
     sp<AAudioServiceStreamBase> serviceStream;
     auto it = mStreamsByHandle.find(streamHandle);
     if (it != mStreamsByHandle.end()) {
         serviceStream = it->second;
-        mStreamsByHandle.erase(it);
+        serviceStream->incrementServiceReferenceCount_l();
     }
     return serviceStream;
 }
 
-sp<AAudioServiceStreamBase> AAudioStreamTracker::getStreamByHandle(
-        aaudio_handle_t streamHandle) {
-    std::lock_guard<std::mutex> lock(mHandleLock);
-    sp<AAudioServiceStreamBase> serviceStream;
-    auto it = mStreamsByHandle.find(streamHandle);
-    if (it != mStreamsByHandle.end()) {
-        serviceStream = it->second;
-    }
-    return serviceStream;
-}
-
-
 // The port handle is only available when the stream is started.
 // So we have to iterate over all the streams.
 // Luckily this rarely happens.
-sp<AAudioServiceStreamBase> AAudioStreamTracker::findStreamByPortHandle(
+sp<AAudioServiceStreamBase> AAudioStreamTracker::findStreamByPortHandleAndIncrement(
         audio_port_handle_t portHandle) {
     std::lock_guard<std::mutex> lock(mHandleLock);
     sp<AAudioServiceStreamBase> serviceStream;
@@ -66,6 +72,7 @@
         auto candidate = it->second;
         if (candidate->getPortHandle() == portHandle) {
             serviceStream = candidate;
+            serviceStream->incrementServiceReferenceCount_l();
             break;
         }
         it++;
@@ -86,7 +93,7 @@
 
 aaudio_handle_t AAudioStreamTracker::addStreamForHandle(sp<AAudioServiceStreamBase> serviceStream) {
     std::lock_guard<std::mutex> lock(mHandleLock);
-    aaudio_handle_t handle = mPreviousHandle.load();
+    aaudio_handle_t handle = mPreviousHandle;
     // Assign a unique handle.
     while (true) {
         handle = bumpHandle(handle);
@@ -98,7 +105,7 @@
             break;
         }
     }
-    mPreviousHandle.store(handle);
+    mPreviousHandle = handle;
     return handle;
 }
 
diff --git a/services/oboeservice/AAudioStreamTracker.h b/services/oboeservice/AAudioStreamTracker.h
index 54e46ca..57ec426 100644
--- a/services/oboeservice/AAudioStreamTracker.h
+++ b/services/oboeservice/AAudioStreamTracker.h
@@ -32,25 +32,35 @@
 
 public:
     /**
-     * Remove the stream associated with the handle.
+     * Find the stream associated with the handle.
+     * Decrement its reference counter. If zero and the stream needs
+     * to be closed then remove the stream and return a pointer to the stream.
+     * Otherwise return null if it does not need to be closed.
+     *
      * @param streamHandle
-     * @return strong pointer to the stream if found or to nullptr
+     * @return strong pointer to the stream if it needs to be closed, or nullptr
      */
-    android::sp<AAudioServiceStreamBase> removeStreamByHandle(aaudio_handle_t streamHandle);
+    android::sp<AAudioServiceStreamBase> decrementAndRemoveStreamByHandle(
+            aaudio_handle_t streamHandle);
 
     /**
      * Look up a stream based on the handle.
+     * Increment its service reference count if found.
+     *
      * @param streamHandle
-     * @return strong pointer to the stream if found or to nullptr
+     * @return strong pointer to the stream if found, or nullptr
      */
-    android::sp<aaudio::AAudioServiceStreamBase> getStreamByHandle(aaudio_handle_t streamHandle);
+    android::sp<aaudio::AAudioServiceStreamBase> getStreamByHandleAndIncrement(
+            aaudio_handle_t streamHandle);
 
     /**
      * Look up a stream based on the AudioPolicy portHandle.
+     * Increment its service reference count if found.
+     *
      * @param portHandle
-     * @return strong pointer to the stream if found or to nullptr
+     * @return strong pointer to the stream if found, or nullptr
      */
-    android::sp<aaudio::AAudioServiceStreamBase> findStreamByPortHandle(
+    android::sp<aaudio::AAudioServiceStreamBase> findStreamByPortHandleAndIncrement(
             audio_port_handle_t portHandle);
 
     /**
@@ -71,7 +81,9 @@
 
     // Track stream using a unique handle that wraps. Only use positive half.
     mutable std::mutex                mHandleLock;
-    std::atomic<aaudio_handle_t>      mPreviousHandle{0};
+    // protected by mHandleLock
+    aaudio_handle_t                   mPreviousHandle = 0;
+    // protected by mHandleLock
     std::map<aaudio_handle_t, android::sp<aaudio::AAudioServiceStreamBase>> mStreamsByHandle;
 };