aaudio: fix possible race condition in close()
Move increment and decrement of reference count under
the same lock used for finding the stream.
Bug: 79693915
Test: run CTS tests
Change-Id: I206eb09724a81f2d79a03fa756adbcbb8abf5efa
diff --git a/services/oboeservice/AAudioService.cpp b/services/oboeservice/AAudioService.cpp
index 5675b0b..6a72e5b 100644
--- a/services/oboeservice/AAudioService.cpp
+++ b/services/oboeservice/AAudioService.cpp
@@ -144,15 +144,14 @@
// If a close request is pending then close the stream
bool AAudioService::releaseStream(const sp<AAudioServiceStreamBase> &serviceStream) {
bool closed = false;
- if ((serviceStream->decrementServiceReferenceCount() == 0) && serviceStream->isCloseNeeded()) {
- // removeStreamByHandle() uses a lock so that if there are two simultaneous closes
- // then only one will get the pointer and do the close.
- sp<AAudioServiceStreamBase> foundStream = mStreamTracker.removeStreamByHandle(serviceStream->getHandle());
- if (foundStream.get() != nullptr) {
- foundStream->close();
- pid_t pid = foundStream->getOwnerProcessId();
- AAudioClientTracker::getInstance().unregisterClientStream(pid, foundStream);
- }
+ // decrementAndRemoveStreamByHandle() uses a lock so that if there are two simultaneous closes
+ // then only one will get the pointer and do the close.
+ sp<AAudioServiceStreamBase> foundStream = mStreamTracker.decrementAndRemoveStreamByHandle(
+ serviceStream->getHandle());
+ if (foundStream.get() != nullptr) {
+ foundStream->close();
+ pid_t pid = foundStream->getOwnerProcessId();
+ AAudioClientTracker::getInstance().unregisterClientStream(pid, foundStream);
closed = true;
}
return closed;
@@ -175,14 +174,15 @@
pid_t pid = serviceStream->getOwnerProcessId();
AAudioClientTracker::getInstance().unregisterClientStream(pid, serviceStream);
- serviceStream->setCloseNeeded(true);
+ serviceStream->markCloseNeeded();
(void) releaseStream(serviceStream);
return AAUDIO_OK;
}
sp<AAudioServiceStreamBase> AAudioService::convertHandleToServiceStream(
aaudio_handle_t streamHandle) {
- sp<AAudioServiceStreamBase> serviceStream = mStreamTracker.getStreamByHandle(streamHandle);
+ sp<AAudioServiceStreamBase> serviceStream = mStreamTracker.getStreamByHandleAndIncrement(
+ streamHandle);
if (serviceStream.get() != nullptr) {
// Only allow owner or the aaudio service to access the stream.
const uid_t callingUserId = IPCThreadState::self()->getCallingUid();
@@ -194,9 +194,9 @@
if (!allowed) {
ALOGE("AAudioService: calling uid %d cannot access stream 0x%08X owned by %d",
callingUserId, streamHandle, ownerUserId);
+ // We incremented the reference count so we must check if it needs to be closed.
+ checkForPendingClose(serviceStream, AAUDIO_OK);
serviceStream.clear();
- } else {
- serviceStream->incrementServiceReferenceCount();
}
}
return serviceStream;
@@ -328,12 +328,11 @@
aaudio_result_t AAudioService::disconnectStreamByPortHandle(audio_port_handle_t portHandle) {
ALOGD("%s(%d) called", __func__, portHandle);
sp<AAudioServiceStreamBase> serviceStream =
- mStreamTracker.findStreamByPortHandle(portHandle);
+ mStreamTracker.findStreamByPortHandleAndIncrement(portHandle);
if (serviceStream.get() == nullptr) {
ALOGE("%s(), could not find stream with portHandle = %d", __func__, portHandle);
return AAUDIO_ERROR_INVALID_HANDLE;
}
- serviceStream->incrementServiceReferenceCount();
aaudio_result_t result = serviceStream->stop();
serviceStream->disconnect();
return checkForPendingClose(serviceStream, result);
diff --git a/services/oboeservice/AAudioServiceStreamBase.cpp b/services/oboeservice/AAudioServiceStreamBase.cpp
index 48d8002..9af8af3 100644
--- a/services/oboeservice/AAudioServiceStreamBase.cpp
+++ b/services/oboeservice/AAudioServiceStreamBase.cpp
@@ -414,12 +414,13 @@
sendServiceEvent(AAUDIO_SERVICE_EVENT_VOLUME, volume);
}
-int32_t AAudioServiceStreamBase::incrementServiceReferenceCount() {
- std::lock_guard<std::mutex> lock(mCallingCountLock);
+int32_t AAudioServiceStreamBase::incrementServiceReferenceCount_l() {
return ++mCallingCount;
}
-int32_t AAudioServiceStreamBase::decrementServiceReferenceCount() {
- std::lock_guard<std::mutex> lock(mCallingCountLock);
- return --mCallingCount;
+int32_t AAudioServiceStreamBase::decrementServiceReferenceCount_l() {
+ int32_t count = --mCallingCount;
+ // Each call to increment should be balanced with one call to decrement.
+ assert(count >= 0);
+ return count;
}
diff --git a/services/oboeservice/AAudioServiceStreamBase.h b/services/oboeservice/AAudioServiceStreamBase.h
index 0ad015e..a1815d0 100644
--- a/services/oboeservice/AAudioServiceStreamBase.h
+++ b/services/oboeservice/AAudioServiceStreamBase.h
@@ -205,22 +205,33 @@
/**
* Atomically increment the number of active references to the stream by AAudioService.
+ *
+ * This is called under a global lock in AAudioStreamTracker.
+ *
* @return value after the increment
*/
- int32_t incrementServiceReferenceCount();
+ int32_t incrementServiceReferenceCount_l();
/**
* Atomically decrement the number of active references to the stream by AAudioService.
+ * This should only be called after incrementServiceReferenceCount_l().
+ *
+ * This is called under a global lock in AAudioStreamTracker.
+ *
* @return value after the decrement
*/
- int32_t decrementServiceReferenceCount();
+ int32_t decrementServiceReferenceCount_l();
bool isCloseNeeded() const {
return mCloseNeeded.load();
}
- void setCloseNeeded(bool needed) {
- mCloseNeeded.store(needed);
+ /**
+ * Mark this stream as needing to be closed.
+ * Once marked for closing, it cannot be unmarked.
+ */
+ void markCloseNeeded() {
+ mCloseNeeded.store(true);
}
virtual const char *getTypeText() const { return "Base"; }
@@ -290,8 +301,9 @@
aaudio_handle_t mHandle = -1;
bool mFlowing = false;
- std::mutex mCallingCountLock;
- std::atomic<int32_t> mCallingCount{0};
+ // This is modified under a global lock in AAudioStreamTracker.
+ int32_t mCallingCount = 0;
+
std::atomic<bool> mCloseNeeded{false};
};
diff --git a/services/oboeservice/AAudioStreamTracker.cpp b/services/oboeservice/AAudioStreamTracker.cpp
index 9d5d8fc..3328159 100644
--- a/services/oboeservice/AAudioStreamTracker.cpp
+++ b/services/oboeservice/AAudioStreamTracker.cpp
@@ -30,34 +30,40 @@
using namespace android;
using namespace aaudio;
-sp<AAudioServiceStreamBase> AAudioStreamTracker::removeStreamByHandle(
+sp<AAudioServiceStreamBase> AAudioStreamTracker::decrementAndRemoveStreamByHandle(
+ aaudio_handle_t streamHandle) {
+ std::lock_guard<std::mutex> lock(mHandleLock);
+ sp<AAudioServiceStreamBase> serviceStream;
+ auto it = mStreamsByHandle.find(streamHandle);
+ if (it != mStreamsByHandle.end()) {
+ sp<AAudioServiceStreamBase> tempStream = it->second;
+ // Does the caller need to close the stream?
+ // The reference count should never be negative.
+ // But it is safer to check for <= 0 than == 0.
+ if ((tempStream->decrementServiceReferenceCount_l() <= 0) && tempStream->isCloseNeeded()) {
+ serviceStream = tempStream; // Only return stream if ready to be closed.
+ mStreamsByHandle.erase(it);
+ }
+ }
+ return serviceStream;
+}
+
+sp<AAudioServiceStreamBase> AAudioStreamTracker::getStreamByHandleAndIncrement(
aaudio_handle_t streamHandle) {
std::lock_guard<std::mutex> lock(mHandleLock);
sp<AAudioServiceStreamBase> serviceStream;
auto it = mStreamsByHandle.find(streamHandle);
if (it != mStreamsByHandle.end()) {
serviceStream = it->second;
- mStreamsByHandle.erase(it);
+ serviceStream->incrementServiceReferenceCount_l();
}
return serviceStream;
}
-sp<AAudioServiceStreamBase> AAudioStreamTracker::getStreamByHandle(
- aaudio_handle_t streamHandle) {
- std::lock_guard<std::mutex> lock(mHandleLock);
- sp<AAudioServiceStreamBase> serviceStream;
- auto it = mStreamsByHandle.find(streamHandle);
- if (it != mStreamsByHandle.end()) {
- serviceStream = it->second;
- }
- return serviceStream;
-}
-
-
// The port handle is only available when the stream is started.
// So we have to iterate over all the streams.
// Luckily this rarely happens.
-sp<AAudioServiceStreamBase> AAudioStreamTracker::findStreamByPortHandle(
+sp<AAudioServiceStreamBase> AAudioStreamTracker::findStreamByPortHandleAndIncrement(
audio_port_handle_t portHandle) {
std::lock_guard<std::mutex> lock(mHandleLock);
sp<AAudioServiceStreamBase> serviceStream;
@@ -66,6 +72,7 @@
auto candidate = it->second;
if (candidate->getPortHandle() == portHandle) {
serviceStream = candidate;
+ serviceStream->incrementServiceReferenceCount_l();
break;
}
it++;
@@ -86,7 +93,7 @@
aaudio_handle_t AAudioStreamTracker::addStreamForHandle(sp<AAudioServiceStreamBase> serviceStream) {
std::lock_guard<std::mutex> lock(mHandleLock);
- aaudio_handle_t handle = mPreviousHandle.load();
+ aaudio_handle_t handle = mPreviousHandle;
// Assign a unique handle.
while (true) {
handle = bumpHandle(handle);
@@ -98,7 +105,7 @@
break;
}
}
- mPreviousHandle.store(handle);
+ mPreviousHandle = handle;
return handle;
}
diff --git a/services/oboeservice/AAudioStreamTracker.h b/services/oboeservice/AAudioStreamTracker.h
index 54e46ca..57ec426 100644
--- a/services/oboeservice/AAudioStreamTracker.h
+++ b/services/oboeservice/AAudioStreamTracker.h
@@ -32,25 +32,35 @@
public:
/**
- * Remove the stream associated with the handle.
+ * Find the stream associated with the handle.
+ * Decrement its reference counter. If zero and the stream needs
+ * to be closed then remove the stream and return a pointer to the stream.
+ * Otherwise return null if it does not need to be closed.
+ *
* @param streamHandle
- * @return strong pointer to the stream if found or to nullptr
+ * @return strong pointer to the stream if it needs to be closed, or nullptr
*/
- android::sp<AAudioServiceStreamBase> removeStreamByHandle(aaudio_handle_t streamHandle);
+ android::sp<AAudioServiceStreamBase> decrementAndRemoveStreamByHandle(
+ aaudio_handle_t streamHandle);
/**
* Look up a stream based on the handle.
+ * Increment its service reference count if found.
+ *
* @param streamHandle
- * @return strong pointer to the stream if found or to nullptr
+ * @return strong pointer to the stream if found, or nullptr
*/
- android::sp<aaudio::AAudioServiceStreamBase> getStreamByHandle(aaudio_handle_t streamHandle);
+ android::sp<aaudio::AAudioServiceStreamBase> getStreamByHandleAndIncrement(
+ aaudio_handle_t streamHandle);
/**
* Look up a stream based on the AudioPolicy portHandle.
+ * Increment its service reference count if found.
+ *
* @param portHandle
- * @return strong pointer to the stream if found or to nullptr
+ * @return strong pointer to the stream if found, or nullptr
*/
- android::sp<aaudio::AAudioServiceStreamBase> findStreamByPortHandle(
+ android::sp<aaudio::AAudioServiceStreamBase> findStreamByPortHandleAndIncrement(
audio_port_handle_t portHandle);
/**
@@ -71,7 +81,9 @@
// Track stream using a unique handle that wraps. Only use positive half.
mutable std::mutex mHandleLock;
- std::atomic<aaudio_handle_t> mPreviousHandle{0};
+ // protected by mHandleLock
+ aaudio_handle_t mPreviousHandle = 0;
+ // protected by mHandleLock
std::map<aaudio_handle_t, android::sp<aaudio::AAudioServiceStreamBase>> mStreamsByHandle;
};