Implement DrmSessionManager w mediaresourcemanager

Bug: 134787536
Test: DrmSessionManagerTest
Test: DrmSessionManager_test
Test: ResourceManagerService_test
Change-Id: Iab9f4f681c83f46b043cefc8633bb3e513a8e75a
diff --git a/drm/libmediadrm/Android.bp b/drm/libmediadrm/Android.bp
index 39b048a..8507729 100644
--- a/drm/libmediadrm/Android.bp
+++ b/drm/libmediadrm/Android.bp
@@ -40,9 +40,11 @@
         "libcutils",
         "libdl",
         "liblog",
+        "libmedia",
         "libmediadrmmetrics_lite",
         "libmediametrics",
         "libmediautils",
+        "libresourcemanagerservice",
         "libstagefright_foundation",
         "libutils",
         "android.hardware.drm@1.0",
diff --git a/drm/libmediadrm/DrmHal.cpp b/drm/libmediadrm/DrmHal.cpp
index 7cfe900..bd4b521 100644
--- a/drm/libmediadrm/DrmHal.cpp
+++ b/drm/libmediadrm/DrmHal.cpp
@@ -293,38 +293,45 @@
     }
 }
 
-
 Mutex DrmHal::mLock;
 
-struct DrmSessionClient : public DrmSessionClientInterface {
-    explicit DrmSessionClient(DrmHal* drm) : mDrm(drm) {}
-
-    virtual bool reclaimSession(const Vector<uint8_t>& sessionId) {
-        sp<DrmHal> drm = mDrm.promote();
-        if (drm == NULL) {
-            return true;
-        }
-        status_t err = drm->closeSession(sessionId);
-        if (err != OK) {
-            return false;
-        }
-        drm->sendEvent(EventType::SESSION_RECLAIMED,
-                toHidlVec(sessionId), hidl_vec<uint8_t>());
+bool DrmHal::DrmSessionClient::reclaimResource() {
+    sp<DrmHal> drm = mDrm.promote();
+    if (drm == NULL) {
         return true;
     }
+    status_t err = drm->closeSession(mSessionId);
+    if (err != OK) {
+        return false;
+    }
+    drm->sendEvent(EventType::SESSION_RECLAIMED,
+            toHidlVec(mSessionId), hidl_vec<uint8_t>());
+    return true;
+}
 
-protected:
-    virtual ~DrmSessionClient() {}
+String8 DrmHal::DrmSessionClient::getName() {
+    String8 name;
+    sp<DrmHal> drm = mDrm.promote();
+    if (drm == NULL) {
+        name.append("<deleted>");
+    } else if (drm->getPropertyStringInternal(String8("vendor"), name) != OK
+        || name.isEmpty()) {
+      name.append("<Get vendor failed or is empty>");
+    }
+    name.append("[");
+    for (size_t i = 0; i < mSessionId.size(); ++i) {
+        name.appendFormat("%02x", mSessionId[i]);
+    }
+    name.append("]");
+    return name;
+}
 
-private:
-    wp<DrmHal> mDrm;
-
-    DISALLOW_EVIL_CONSTRUCTORS(DrmSessionClient);
-};
+DrmHal::DrmSessionClient::~DrmSessionClient() {
+    DrmSessionManager::Instance()->removeSession(mSessionId);
+}
 
 DrmHal::DrmHal()
-   : mDrmSessionClient(new DrmSessionClient(this)),
-     mFactories(makeDrmFactories()),
+   : mFactories(makeDrmFactories()),
      mInitCheck((mFactories.size() == 0) ? ERROR_UNSUPPORTED : NO_INIT) {
 }
 
@@ -333,14 +340,13 @@
     auto openSessions = mOpenSessions;
     for (size_t i = 0; i < openSessions.size(); i++) {
         mLock.unlock();
-        closeSession(openSessions[i]);
+        closeSession(openSessions[i]->mSessionId);
         mLock.lock();
     }
     mOpenSessions.clear();
 }
 
 DrmHal::~DrmHal() {
-    DrmSessionManager::Instance()->removeDrm(mDrmSessionClient);
 }
 
 void DrmHal::cleanup() {
@@ -746,9 +752,9 @@
     } while (retry);
 
     if (err == OK) {
-        DrmSessionManager::Instance()->addSession(getCallingPid(),
-                mDrmSessionClient, sessionId);
-        mOpenSessions.push(sessionId);
+        sp<DrmSessionClient> client(new DrmSessionClient(this, sessionId));
+        DrmSessionManager::Instance()->addSession(getCallingPid(), client, sessionId);
+        mOpenSessions.push(client);
         mMetrics.SetSessionStart(sessionId);
     }
 
@@ -765,7 +771,7 @@
         if (status == Status::OK) {
             DrmSessionManager::Instance()->removeSession(sessionId);
             for (size_t i = 0; i < mOpenSessions.size(); i++) {
-                if (mOpenSessions[i] == sessionId) {
+                if (isEqualSessionId(mOpenSessions[i]->mSessionId, sessionId)) {
                     mOpenSessions.removeAt(i);
                     break;
                 }
diff --git a/drm/libmediadrm/DrmSessionManager.cpp b/drm/libmediadrm/DrmSessionManager.cpp
index 375644c..0b927ef 100644
--- a/drm/libmediadrm/DrmSessionManager.cpp
+++ b/drm/libmediadrm/DrmSessionManager.cpp
@@ -21,12 +21,17 @@
 #include <binder/IPCThreadState.h>
 #include <binder/IProcessInfoService.h>
 #include <binder/IServiceManager.h>
-#include <media/stagefright/ProcessInfo.h>
-#include <mediadrm/DrmSessionClientInterface.h>
+#include <cutils/properties.h>
+#include <media/IResourceManagerClient.h>
+#include <media/MediaResource.h>
 #include <mediadrm/DrmSessionManager.h>
 #include <unistd.h>
 #include <utils/String8.h>
 
+#include <vector>
+
+#include "ResourceManagerService.h"
+
 namespace android {
 
 static String8 GetSessionIdString(const Vector<uint8_t> &sessionId) {
@@ -37,6 +42,35 @@
     return sessionIdStr;
 }
 
+static std::vector<uint8_t> toStdVec(const Vector<uint8_t> &vector) {
+    const uint8_t *v = vector.array();
+    std::vector<uint8_t> vec(v, v + vector.size());
+    return vec;
+}
+
+static uint64_t toClientId(const sp<IResourceManagerClient>& drm) {
+    return reinterpret_cast<int64_t>(drm.get());
+}
+
+static Vector<MediaResource> toResourceVec(const Vector<uint8_t> &sessionId) {
+    Vector<MediaResource> resources;
+    // use UINT64_MAX to decrement through addition overflow
+    resources.push_back(MediaResource(MediaResource::kDrmSession, toStdVec(sessionId), UINT64_MAX));
+    return resources;
+}
+
+static sp<IResourceManagerService> getResourceManagerService() {
+    if (property_get_bool("persist.device_config.media_native.mediadrmserver", 1)) {
+        return new ResourceManagerService();
+    }
+    sp<IServiceManager> sm = defaultServiceManager();
+    if (sm == NULL) {
+        return NULL;
+    }
+    sp<IBinder> binder = sm->getService(String16("media.resource_manager"));
+    return interface_cast<IResourceManagerService>(binder);
+}
+
 bool isEqualSessionId(const Vector<uint8_t> &sessionId1, const Vector<uint8_t> &sessionId2) {
     if (sessionId1.size() != sessionId2.size()) {
         return false;
@@ -51,189 +85,114 @@
 
 sp<DrmSessionManager> DrmSessionManager::Instance() {
     static sp<DrmSessionManager> drmSessionManager = new DrmSessionManager();
+    drmSessionManager->init();
     return drmSessionManager;
 }
 
 DrmSessionManager::DrmSessionManager()
-    : mProcessInfo(new ProcessInfo()),
-      mTime(0) {}
+    : DrmSessionManager(getResourceManagerService()) {
+}
 
-DrmSessionManager::DrmSessionManager(sp<ProcessInfoInterface> processInfo)
-    : mProcessInfo(processInfo),
-      mTime(0) {}
+DrmSessionManager::DrmSessionManager(const sp<IResourceManagerService> &service)
+    : mService(service),
+      mInitialized(false) {
+    if (mService == NULL) {
+        ALOGE("Failed to init ResourceManagerService");
+    }
+}
 
-DrmSessionManager::~DrmSessionManager() {}
+DrmSessionManager::~DrmSessionManager() {
+    if (mService != NULL) {
+        IInterface::asBinder(mService)->unlinkToDeath(this);
+    }
+}
 
-void DrmSessionManager::addSession(
-        int pid, const sp<DrmSessionClientInterface>& drm, const Vector<uint8_t> &sessionId) {
-    ALOGV("addSession(pid %d, drm %p, sessionId %s)", pid, drm.get(),
+void DrmSessionManager::init() {
+    Mutex::Autolock lock(mLock);
+    if (mInitialized) {
+        return;
+    }
+    mInitialized = true;
+    if (mService != NULL) {
+        IInterface::asBinder(mService)->linkToDeath(this);
+    }
+}
+
+void DrmSessionManager::addSession(int pid,
+        const sp<IResourceManagerClient>& drm, const Vector<uint8_t> &sessionId) {
+    uid_t uid = IPCThreadState::self()->getCallingUid();
+    ALOGV("addSession(pid %d, uid %d, drm %p, sessionId %s)", pid, uid, drm.get(),
             GetSessionIdString(sessionId).string());
 
     Mutex::Autolock lock(mLock);
-    SessionInfo info;
-    info.drm = drm;
-    info.sessionId = sessionId;
-    info.timeStamp = getTime_l();
-    ssize_t index = mSessionMap.indexOfKey(pid);
-    if (index < 0) {
-        // new pid
-        SessionInfos infosForPid;
-        infosForPid.push_back(info);
-        mSessionMap.add(pid, infosForPid);
-    } else {
-        mSessionMap.editValueAt(index).push_back(info);
+    if (mService == NULL) {
+        return;
     }
+
+    int64_t clientId = toClientId(drm);
+    mSessionMap[toStdVec(sessionId)] = (SessionInfo){pid, uid, clientId};
+    mService->addResource(pid, uid, clientId, drm, toResourceVec(sessionId));
 }
 
 void DrmSessionManager::useSession(const Vector<uint8_t> &sessionId) {
     ALOGV("useSession(%s)", GetSessionIdString(sessionId).string());
 
     Mutex::Autolock lock(mLock);
-    for (size_t i = 0; i < mSessionMap.size(); ++i) {
-        SessionInfos& infos = mSessionMap.editValueAt(i);
-        for (size_t j = 0; j < infos.size(); ++j) {
-            SessionInfo& info = infos.editItemAt(j);
-            if (isEqualSessionId(sessionId, info.sessionId)) {
-                info.timeStamp = getTime_l();
-                return;
-            }
-        }
+    auto it = mSessionMap.find(toStdVec(sessionId));
+    if (mService == NULL || it == mSessionMap.end()) {
+        return;
     }
+
+    auto info = it->second;
+    mService->addResource(info.pid, info.uid, info.clientId, NULL, toResourceVec(sessionId));
 }
 
 void DrmSessionManager::removeSession(const Vector<uint8_t> &sessionId) {
     ALOGV("removeSession(%s)", GetSessionIdString(sessionId).string());
 
     Mutex::Autolock lock(mLock);
-    for (size_t i = 0; i < mSessionMap.size(); ++i) {
-        SessionInfos& infos = mSessionMap.editValueAt(i);
-        for (size_t j = 0; j < infos.size(); ++j) {
-            if (isEqualSessionId(sessionId, infos[j].sessionId)) {
-                infos.removeAt(j);
-                return;
-            }
-        }
+    auto it = mSessionMap.find(toStdVec(sessionId));
+    if (mService == NULL || it == mSessionMap.end()) {
+        return;
     }
-}
 
-void DrmSessionManager::removeDrm(const sp<DrmSessionClientInterface>& drm) {
-    ALOGV("removeDrm(%p)", drm.get());
-
-    Mutex::Autolock lock(mLock);
-    bool found = false;
-    for (size_t i = 0; i < mSessionMap.size(); ++i) {
-        SessionInfos& infos = mSessionMap.editValueAt(i);
-        for (size_t j = 0; j < infos.size();) {
-            if (infos[j].drm == drm) {
-                ALOGV("removed session (%s)", GetSessionIdString(infos[j].sessionId).string());
-                j = infos.removeAt(j);
-                found = true;
-            } else {
-                ++j;
-            }
-        }
-        if (found) {
-            break;
-        }
-    }
+    auto info = it->second;
+    mService->removeResource(info.pid, info.clientId, toResourceVec(sessionId));
+    mSessionMap.erase(it);
 }
 
 bool DrmSessionManager::reclaimSession(int callingPid) {
     ALOGV("reclaimSession(%d)", callingPid);
 
-    sp<DrmSessionClientInterface> drm;
-    Vector<uint8_t> sessionId;
-    int lowestPriorityPid;
-    int lowestPriority;
-    {
-        Mutex::Autolock lock(mLock);
-        int callingPriority;
-        if (!mProcessInfo->getPriority(callingPid, &callingPriority)) {
-            return false;
-        }
-        if (!getLowestPriority_l(&lowestPriorityPid, &lowestPriority)) {
-            return false;
-        }
-        if (lowestPriority <= callingPriority) {
-            return false;
-        }
+    // unlock early because reclaimResource might callback into removeSession
+    mLock.lock();
+    sp<IResourceManagerService> service(mService);
+    mLock.unlock();
 
-        if (!getLeastUsedSession_l(lowestPriorityPid, &drm, &sessionId)) {
-            return false;
-        }
-    }
-
-    if (drm == NULL) {
+    if (service == NULL) {
         return false;
     }
 
-    ALOGV("reclaim session(%s) opened by pid %d",
-            GetSessionIdString(sessionId).string(), lowestPriorityPid);
-
-    return drm->reclaimSession(sessionId);
+    // cannot update mSessionMap because we do not know which sessionId is reclaimed;
+    // we rely on IResourceManagerClient to removeSession in reclaimResource
+    Vector<uint8_t> dummy;
+    return service->reclaimResource(callingPid, toResourceVec(dummy));
 }
 
-int64_t DrmSessionManager::getTime_l() {
-    return mTime++;
+size_t DrmSessionManager::getSessionCount() const {
+    Mutex::Autolock lock(mLock);
+    return mSessionMap.size();
 }
 
-bool DrmSessionManager::getLowestPriority_l(int* lowestPriorityPid, int* lowestPriority) {
-    int pid = -1;
-    int priority = -1;
-    for (size_t i = 0; i < mSessionMap.size(); ++i) {
-        if (mSessionMap.valueAt(i).size() == 0) {
-            // no opened session by this process.
-            continue;
-        }
-        int tempPid = mSessionMap.keyAt(i);
-        int tempPriority;
-        if (!mProcessInfo->getPriority(tempPid, &tempPriority)) {
-            // shouldn't happen.
-            return false;
-        }
-        if (pid == -1) {
-            pid = tempPid;
-            priority = tempPriority;
-        } else {
-            if (tempPriority > priority) {
-                pid = tempPid;
-                priority = tempPriority;
-            }
-        }
-    }
-    if (pid != -1) {
-        *lowestPriorityPid = pid;
-        *lowestPriority = priority;
-    }
-    return (pid != -1);
+bool DrmSessionManager::containsSession(const Vector<uint8_t>& sessionId) const {
+    Mutex::Autolock lock(mLock);
+    return mSessionMap.count(toStdVec(sessionId));
 }
 
-bool DrmSessionManager::getLeastUsedSession_l(
-        int pid, sp<DrmSessionClientInterface>* drm, Vector<uint8_t>* sessionId) {
-    ssize_t index = mSessionMap.indexOfKey(pid);
-    if (index < 0) {
-        return false;
-    }
-
-    int leastUsedIndex = -1;
-    int64_t minTs = LLONG_MAX;
-    const SessionInfos& infos = mSessionMap.valueAt(index);
-    for (size_t j = 0; j < infos.size(); ++j) {
-        if (leastUsedIndex == -1) {
-            leastUsedIndex = j;
-            minTs = infos[j].timeStamp;
-        } else {
-            if (infos[j].timeStamp < minTs) {
-                leastUsedIndex = j;
-                minTs = infos[j].timeStamp;
-            }
-        }
-    }
-    if (leastUsedIndex != -1) {
-        *drm = infos[leastUsedIndex].drm;
-        *sessionId = infos[leastUsedIndex].sessionId;
-    }
-    return (leastUsedIndex != -1);
+void DrmSessionManager::binderDied(const wp<IBinder>& /*who*/) {
+    ALOGW("ResourceManagerService died.");
+    Mutex::Autolock lock(mLock);
+    mService.clear();
 }
 
 }  // namespace android
diff --git a/drm/libmediadrm/include/mediadrm/DrmHal.h b/drm/libmediadrm/include/mediadrm/DrmHal.h
index bdf1b30..542d300 100644
--- a/drm/libmediadrm/include/mediadrm/DrmHal.h
+++ b/drm/libmediadrm/include/mediadrm/DrmHal.h
@@ -26,8 +26,10 @@
 #include <android/hardware/drm/1.2/IDrmPlugin.h>
 #include <android/hardware/drm/1.2/IDrmPluginListener.h>
 
+#include <media/IResourceManagerService.h>
 #include <media/MediaAnalyticsItem.h>
 #include <mediadrm/DrmMetrics.h>
+#include <mediadrm/DrmSessionManager.h>
 #include <mediadrm/IDrm.h>
 #include <mediadrm/IDrmClient.h>
 #include <utils/threads.h>
@@ -59,6 +61,26 @@
 struct DrmHal : public BnDrm,
                 public IBinder::DeathRecipient,
                 public IDrmPluginListener_V1_2 {
+
+    struct DrmSessionClient : public BnResourceManagerClient {
+        explicit DrmSessionClient(DrmHal* drm, const Vector<uint8_t>& sessionId)
+          : mSessionId(sessionId),
+            mDrm(drm) {}
+
+        virtual bool reclaimResource();
+        virtual String8 getName();
+
+        const Vector<uint8_t> mSessionId;
+
+    protected:
+        virtual ~DrmSessionClient();
+
+    private:
+        wp<DrmHal> mDrm;
+
+        DISALLOW_EVIL_CONSTRUCTORS(DrmSessionClient);
+    };
+
     DrmHal();
     virtual ~DrmHal();
 
@@ -193,8 +215,6 @@
 private:
     static Mutex mLock;
 
-    sp<DrmSessionClientInterface> mDrmSessionClient;
-
     sp<IDrmClient> mListener;
     mutable Mutex mEventLock;
     mutable Mutex mNotifyLock;
@@ -208,7 +228,7 @@
     // Mutable to allow modification within GetPropertyByteArray.
     mutable MediaDrmMetrics mMetrics;
 
-    Vector<Vector<uint8_t>> mOpenSessions;
+    Vector<sp<DrmSessionClient>> mOpenSessions;
     void closeOpenSessions();
     void cleanup();
 
diff --git a/drm/libmediadrm/include/mediadrm/DrmSessionManager.h b/drm/libmediadrm/include/mediadrm/DrmSessionManager.h
index ba27199..b1ad580 100644
--- a/drm/libmediadrm/include/mediadrm/DrmSessionManager.h
+++ b/drm/libmediadrm/include/mediadrm/DrmSessionManager.h
@@ -18,56 +18,61 @@
 
 #define DRM_SESSION_MANAGER_H_
 
+#include <binder/IBinder.h>
+#include <media/IResourceManagerService.h>
 #include <media/stagefright/foundation/ABase.h>
 #include <utils/RefBase.h>
 #include <utils/KeyedVector.h>
 #include <utils/threads.h>
 #include <utils/Vector.h>
 
+#include <map>
+#include <utility>
+#include <vector>
+
 namespace android {
 
 class DrmSessionManagerTest;
-struct DrmSessionClientInterface;
-struct ProcessInfoInterface;
+class IResourceManagerClient;
 
 bool isEqualSessionId(const Vector<uint8_t> &sessionId1, const Vector<uint8_t> &sessionId2);
 
 struct SessionInfo {
-    sp<DrmSessionClientInterface> drm;
-    Vector<uint8_t> sessionId;
-    int64_t timeStamp;
+    pid_t pid;
+    uid_t uid;
+    int64_t clientId;
 };
 
-typedef Vector<SessionInfo > SessionInfos;
-typedef KeyedVector<int, SessionInfos > PidSessionInfosMap;
+typedef std::map<std::vector<uint8_t>, SessionInfo> SessionInfoMap;
 
-struct DrmSessionManager : public RefBase {
+struct DrmSessionManager : public IBinder::DeathRecipient {
     static sp<DrmSessionManager> Instance();
 
     DrmSessionManager();
-    explicit DrmSessionManager(sp<ProcessInfoInterface> processInfo);
+    explicit DrmSessionManager(const sp<IResourceManagerService> &service);
 
-    void addSession(int pid, const sp<DrmSessionClientInterface>& drm, const Vector<uint8_t>& sessionId);
+    void addSession(int pid, const sp<IResourceManagerClient>& drm, const Vector<uint8_t>& sessionId);
     void useSession(const Vector<uint8_t>& sessionId);
     void removeSession(const Vector<uint8_t>& sessionId);
-    void removeDrm(const sp<DrmSessionClientInterface>& drm);
     bool reclaimSession(int callingPid);
 
+    // sanity check APIs
+    size_t getSessionCount() const;
+    bool containsSession(const Vector<uint8_t>& sessionId) const;
+
+    // implements DeathRecipient
+    virtual void binderDied(const wp<IBinder>& /*who*/);
+
 protected:
     virtual ~DrmSessionManager();
 
 private:
-    friend class DrmSessionManagerTest;
+    void init();
 
-    int64_t getTime_l();
-    bool getLowestPriority_l(int* lowestPriorityPid, int* lowestPriority);
-    bool getLeastUsedSession_l(
-            int pid, sp<DrmSessionClientInterface>* drm, Vector<uint8_t>* sessionId);
-
-    sp<ProcessInfoInterface> mProcessInfo;
+    sp<IResourceManagerService> mService;
     mutable Mutex mLock;
-    PidSessionInfosMap mSessionMap;
-    int64_t mTime;
+    SessionInfoMap mSessionMap;
+    bool mInitialized;
 
     DISALLOW_EVIL_CONSTRUCTORS(DrmSessionManager);
 };
diff --git a/media/libmedia/MediaResource.cpp b/media/libmedia/MediaResource.cpp
index e636a50..8626009 100644
--- a/media/libmedia/MediaResource.cpp
+++ b/media/libmedia/MediaResource.cpp
@@ -19,6 +19,8 @@
 #include <utils/Log.h>
 #include <media/MediaResource.h>
 
+#include <vector>
+
 namespace android {
 
 MediaResource::MediaResource()
@@ -36,26 +38,48 @@
           mSubType(subType),
           mValue(value) {}
 
+MediaResource::MediaResource(Type type, const std::vector<uint8_t> &id, uint64_t value)
+        : mType(type),
+          mSubType(kUnspecifiedSubType),
+          mValue(value),
+          mId(id) {}
+
 void MediaResource::readFromParcel(const Parcel &parcel) {
     mType = static_cast<Type>(parcel.readInt32());
     mSubType = static_cast<SubType>(parcel.readInt32());
     mValue = parcel.readUint64();
+    parcel.readByteVector(&mId);
 }
 
 void MediaResource::writeToParcel(Parcel *parcel) const {
     parcel->writeInt32(static_cast<int32_t>(mType));
     parcel->writeInt32(static_cast<int32_t>(mSubType));
     parcel->writeUint64(mValue);
+    parcel->writeByteVector(mId);
+}
+
+static String8 bytesToHexString(const std::vector<uint8_t> &bytes) {
+    String8 str;
+    for (auto &b : bytes) {
+        str.appendFormat("%02x", b);
+    }
+    return str;
 }
 
 String8 MediaResource::toString() const {
     String8 str;
-    str.appendFormat("%s/%s:%llu", asString(mType), asString(mSubType), (unsigned long long)mValue);
+    str.appendFormat("%s/%s:[%s]:%llu",
+        asString(mType), asString(mSubType),
+        bytesToHexString(mId).c_str(),
+        (unsigned long long)mValue);
     return str;
 }
 
 bool MediaResource::operator==(const MediaResource &other) const {
-    return (other.mType == mType) && (other.mSubType == mSubType) && (other.mValue == mValue);
+    return (other.mType == mType)
+      && (other.mSubType == mSubType)
+      && (other.mValue == mValue)
+      && (other.mId == mId);
 }
 
 bool MediaResource::operator!=(const MediaResource &other) const {
diff --git a/media/libmedia/include/media/MediaResource.h b/media/libmedia/include/media/MediaResource.h
index 10a07bb..e9684f0 100644
--- a/media/libmedia/include/media/MediaResource.h
+++ b/media/libmedia/include/media/MediaResource.h
@@ -20,6 +20,7 @@
 
 #include <binder/Parcel.h>
 #include <utils/String8.h>
+#include <vector>
 
 namespace android {
 
@@ -32,6 +33,7 @@
         kGraphicMemory,
         kCpuBoost,
         kBattery,
+        kDrmSession,
     };
 
     enum SubType {
@@ -43,6 +45,7 @@
     MediaResource();
     MediaResource(Type type, uint64_t value);
     MediaResource(Type type, SubType subType, uint64_t value);
+    MediaResource(Type type, const std::vector<uint8_t> &id, uint64_t value);
 
     void readFromParcel(const Parcel &parcel);
     void writeToParcel(Parcel *parcel) const;
@@ -55,6 +58,8 @@
     Type mType;
     SubType mSubType;
     uint64_t mValue;
+    // for kDrmSession-type mId is the unique session id obtained via MediaDrm#openSession
+    std::vector<uint8_t> mId;
 };
 
 inline static const char *asString(MediaResource::Type i, const char *def = "??") {
@@ -65,6 +70,7 @@
         case MediaResource::kGraphicMemory:  return "graphic-memory";
         case MediaResource::kCpuBoost:       return "cpu-boost";
         case MediaResource::kBattery:        return "battery";
+        case MediaResource::kDrmSession:     return "drm-session";
         default:                             return def;
     }
 }
diff --git a/media/libmediaplayerservice/tests/Android.bp b/media/libmediaplayerservice/tests/Android.bp
index f8c89e5..8357925 100644
--- a/media/libmediaplayerservice/tests/Android.bp
+++ b/media/libmediaplayerservice/tests/Android.bp
@@ -6,14 +6,22 @@
 
     shared_libs: [
         "liblog",
+        "libbinder",
+        "libmedia",
         "libmediaplayerservice",
         "libmediadrm",
+        "libresourcemanagerservice",
         "libutils",
         "android.hardware.drm@1.0",
         "android.hardware.drm@1.1",
         "android.hardware.drm@1.2",
     ],
 
+    include_dirs: [
+        "frameworks/av/include",
+        "frameworks/av/services/mediaresourcemanager",
+    ],
+
     cflags: [
         "-Werror",
         "-Wall",
diff --git a/media/libmediaplayerservice/tests/DrmSessionManager_test.cpp b/media/libmediaplayerservice/tests/DrmSessionManager_test.cpp
index d81ee05..58e4bee 100644
--- a/media/libmediaplayerservice/tests/DrmSessionManager_test.cpp
+++ b/media/libmediaplayerservice/tests/DrmSessionManager_test.cpp
@@ -20,14 +20,29 @@
 
 #include <gtest/gtest.h>
 
+#include <media/IResourceManagerService.h>
+#include <media/IResourceManagerClient.h>
 #include <media/stagefright/foundation/ADebug.h>
 #include <media/stagefright/ProcessInfoInterface.h>
 #include <mediadrm/DrmHal.h>
 #include <mediadrm/DrmSessionClientInterface.h>
 #include <mediadrm/DrmSessionManager.h>
 
+#include <algorithm>
+#include <vector>
+
+#include "ResourceManagerService.h"
+
 namespace android {
 
+static Vector<uint8_t> toAndroidVector(const std::vector<uint8_t> &vec) {
+    Vector<uint8_t> aVec;
+    for (auto b : vec) {
+        aVec.push_back(b);
+    }
+    return aVec;
+}
+
 struct FakeProcessInfo : public ProcessInfoInterface {
     FakeProcessInfo() {}
     virtual ~FakeProcessInfo() {}
@@ -47,173 +62,128 @@
     DISALLOW_EVIL_CONSTRUCTORS(FakeProcessInfo);
 };
 
-struct FakeDrm : public DrmSessionClientInterface {
-    FakeDrm() {}
+struct FakeDrm : public BnResourceManagerClient {
+    FakeDrm(const std::vector<uint8_t>& sessionId, const sp<DrmSessionManager>& manager)
+        : mSessionId(toAndroidVector(sessionId)),
+          mReclaimed(false),
+          mDrmSessionManager(manager) {}
+
     virtual ~FakeDrm() {}
 
-    virtual bool reclaimSession(const Vector<uint8_t>& sessionId) {
-        mReclaimedSessions.push_back(sessionId);
+    virtual bool reclaimResource() {
+        mReclaimed = true;
+        mDrmSessionManager->removeSession(mSessionId);
         return true;
     }
 
-    const Vector<Vector<uint8_t> >& reclaimedSessions() const {
-        return mReclaimedSessions;
+    virtual String8 getName() {
+        String8 name("FakeDrm[");
+        for (size_t i = 0; i < mSessionId.size(); ++i) {
+            name.appendFormat("%02x", mSessionId[i]);
+        }
+        name.append("]");
+        return name;
     }
 
+    bool isReclaimed() const {
+        return mReclaimed;
+    }
+
+    const Vector<uint8_t> mSessionId;
+
 private:
-    Vector<Vector<uint8_t> > mReclaimedSessions;
+    bool mReclaimed;
+    const sp<DrmSessionManager> mDrmSessionManager;
 
     DISALLOW_EVIL_CONSTRUCTORS(FakeDrm);
 };
 
+struct FakeSystemCallback :
+        public ResourceManagerService::SystemCallbackInterface {
+    FakeSystemCallback() {}
+
+    virtual void noteStartVideo(int /*uid*/) override {}
+
+    virtual void noteStopVideo(int /*uid*/) override {}
+
+    virtual void noteResetVideo() override {}
+
+    virtual bool requestCpusetBoost(
+            bool /*enable*/, const sp<IInterface> &/*client*/) override {
+        return true;
+    }
+
+protected:
+    virtual ~FakeSystemCallback() {}
+
+private:
+
+    DISALLOW_EVIL_CONSTRUCTORS(FakeSystemCallback);
+};
+
 static const int kTestPid1 = 30;
 static const int kTestPid2 = 20;
-static const uint8_t kTestSessionId1[] = {1, 2, 3};
-static const uint8_t kTestSessionId2[] = {4, 5, 6, 7, 8};
-static const uint8_t kTestSessionId3[] = {9, 0};
+static const std::vector<uint8_t> kTestSessionId1{1, 2, 3};
+static const std::vector<uint8_t> kTestSessionId2{4, 5, 6, 7, 8};
+static const std::vector<uint8_t> kTestSessionId3{9, 0};
 
 class DrmSessionManagerTest : public ::testing::Test {
 public:
     DrmSessionManagerTest()
-        : mDrmSessionManager(new DrmSessionManager(new FakeProcessInfo())),
-          mTestDrm1(new FakeDrm()),
-          mTestDrm2(new FakeDrm()) {
-        GetSessionId(kTestSessionId1, ARRAY_SIZE(kTestSessionId1), &mSessionId1);
-        GetSessionId(kTestSessionId2, ARRAY_SIZE(kTestSessionId2), &mSessionId2);
-        GetSessionId(kTestSessionId3, ARRAY_SIZE(kTestSessionId3), &mSessionId3);
+        : mService(new ResourceManagerService(new FakeProcessInfo(), new FakeSystemCallback())),
+          mDrmSessionManager(new DrmSessionManager(mService)),
+          mTestDrm1(new FakeDrm(kTestSessionId1, mDrmSessionManager)),
+          mTestDrm2(new FakeDrm(kTestSessionId2, mDrmSessionManager)),
+          mTestDrm3(new FakeDrm(kTestSessionId3, mDrmSessionManager)) {
+        DrmSessionManager *ptr = new DrmSessionManager(mService);
+        EXPECT_NE(ptr, nullptr);
+        /* mDrmSessionManager = ptr; */
     }
 
 protected:
-    static void GetSessionId(const uint8_t* ids, size_t num, Vector<uint8_t>* sessionId) {
-        for (size_t i = 0; i < num; ++i) {
-            sessionId->push_back(ids[i]);
-        }
-    }
-
-    static void ExpectEqSessionInfo(const SessionInfo& info, sp<DrmSessionClientInterface> drm,
-            const Vector<uint8_t>& sessionId, int64_t timeStamp) {
-        EXPECT_EQ(drm, info.drm);
-        EXPECT_TRUE(isEqualSessionId(sessionId, info.sessionId));
-        EXPECT_EQ(timeStamp, info.timeStamp);
-    }
-
     void addSession() {
-        mDrmSessionManager->addSession(kTestPid1, mTestDrm1, mSessionId1);
-        mDrmSessionManager->addSession(kTestPid2, mTestDrm2, mSessionId2);
-        mDrmSessionManager->addSession(kTestPid2, mTestDrm2, mSessionId3);
-        const PidSessionInfosMap& map = sessionMap();
-        EXPECT_EQ(2u, map.size());
-        ssize_t index1 = map.indexOfKey(kTestPid1);
-        ASSERT_GE(index1, 0);
-        const SessionInfos& infos1 = map[index1];
-        EXPECT_EQ(1u, infos1.size());
-        ExpectEqSessionInfo(infos1[0], mTestDrm1, mSessionId1, 0);
-
-        ssize_t index2 = map.indexOfKey(kTestPid2);
-        ASSERT_GE(index2, 0);
-        const SessionInfos& infos2 = map[index2];
-        EXPECT_EQ(2u, infos2.size());
-        ExpectEqSessionInfo(infos2[0], mTestDrm2, mSessionId2, 1);
-        ExpectEqSessionInfo(infos2[1], mTestDrm2, mSessionId3, 2);
+        mDrmSessionManager->addSession(kTestPid1, mTestDrm1, mTestDrm1->mSessionId);
+        mDrmSessionManager->addSession(kTestPid2, mTestDrm2, mTestDrm2->mSessionId);
+        mDrmSessionManager->addSession(kTestPid2, mTestDrm3, mTestDrm3->mSessionId);
     }
 
-    const PidSessionInfosMap& sessionMap() {
-        return mDrmSessionManager->mSessionMap;
-    }
-
-    void testGetLowestPriority() {
-        int pid;
-        int priority;
-        EXPECT_FALSE(mDrmSessionManager->getLowestPriority_l(&pid, &priority));
-
-        addSession();
-        EXPECT_TRUE(mDrmSessionManager->getLowestPriority_l(&pid, &priority));
-
-        EXPECT_EQ(kTestPid1, pid);
-        FakeProcessInfo processInfo;
-        int priority1;
-        processInfo.getPriority(kTestPid1, &priority1);
-        EXPECT_EQ(priority1, priority);
-    }
-
-    void testGetLeastUsedSession() {
-        sp<DrmSessionClientInterface> drm;
-        Vector<uint8_t> sessionId;
-        EXPECT_FALSE(mDrmSessionManager->getLeastUsedSession_l(kTestPid1, &drm, &sessionId));
-
-        addSession();
-
-        EXPECT_TRUE(mDrmSessionManager->getLeastUsedSession_l(kTestPid1, &drm, &sessionId));
-        EXPECT_EQ(mTestDrm1, drm);
-        EXPECT_TRUE(isEqualSessionId(mSessionId1, sessionId));
-
-        EXPECT_TRUE(mDrmSessionManager->getLeastUsedSession_l(kTestPid2, &drm, &sessionId));
-        EXPECT_EQ(mTestDrm2, drm);
-        EXPECT_TRUE(isEqualSessionId(mSessionId2, sessionId));
-
-        // mSessionId2 is no longer the least used session.
-        mDrmSessionManager->useSession(mSessionId2);
-        EXPECT_TRUE(mDrmSessionManager->getLeastUsedSession_l(kTestPid2, &drm, &sessionId));
-        EXPECT_EQ(mTestDrm2, drm);
-        EXPECT_TRUE(isEqualSessionId(mSessionId3, sessionId));
-    }
-
+    sp<IResourceManagerService> mService;
     sp<DrmSessionManager> mDrmSessionManager;
     sp<FakeDrm> mTestDrm1;
     sp<FakeDrm> mTestDrm2;
-    Vector<uint8_t> mSessionId1;
-    Vector<uint8_t> mSessionId2;
-    Vector<uint8_t> mSessionId3;
+    sp<FakeDrm> mTestDrm3;
 };
 
 TEST_F(DrmSessionManagerTest, addSession) {
     addSession();
+
+    EXPECT_EQ(3u, mDrmSessionManager->getSessionCount());
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm1->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm2->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm3->mSessionId));
 }
 
 TEST_F(DrmSessionManagerTest, useSession) {
     addSession();
 
-    mDrmSessionManager->useSession(mSessionId1);
-    mDrmSessionManager->useSession(mSessionId3);
+    mDrmSessionManager->useSession(mTestDrm1->mSessionId);
+    mDrmSessionManager->useSession(mTestDrm3->mSessionId);
 
-    const PidSessionInfosMap& map = sessionMap();
-    const SessionInfos& infos1 = map.valueFor(kTestPid1);
-    const SessionInfos& infos2 = map.valueFor(kTestPid2);
-    ExpectEqSessionInfo(infos1[0], mTestDrm1, mSessionId1, 3);
-    ExpectEqSessionInfo(infos2[1], mTestDrm2, mSessionId3, 4);
+    EXPECT_EQ(3u, mDrmSessionManager->getSessionCount());
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm1->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm2->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm3->mSessionId));
 }
 
 TEST_F(DrmSessionManagerTest, removeSession) {
     addSession();
 
-    mDrmSessionManager->removeSession(mSessionId2);
+    mDrmSessionManager->removeSession(mTestDrm2->mSessionId);
 
-    const PidSessionInfosMap& map = sessionMap();
-    EXPECT_EQ(2u, map.size());
-    const SessionInfos& infos1 = map.valueFor(kTestPid1);
-    const SessionInfos& infos2 = map.valueFor(kTestPid2);
-    EXPECT_EQ(1u, infos1.size());
-    EXPECT_EQ(1u, infos2.size());
-    // mSessionId2 has been removed.
-    ExpectEqSessionInfo(infos2[0], mTestDrm2, mSessionId3, 2);
-}
-
-TEST_F(DrmSessionManagerTest, removeDrm) {
-    addSession();
-
-    sp<FakeDrm> drm = new FakeDrm;
-    const uint8_t ids[] = {123};
-    Vector<uint8_t> sessionId;
-    GetSessionId(ids, ARRAY_SIZE(ids), &sessionId);
-    mDrmSessionManager->addSession(kTestPid2, drm, sessionId);
-
-    mDrmSessionManager->removeDrm(mTestDrm2);
-
-    const PidSessionInfosMap& map = sessionMap();
-    const SessionInfos& infos2 = map.valueFor(kTestPid2);
-    EXPECT_EQ(1u, infos2.size());
-    // mTestDrm2 has been removed.
-    ExpectEqSessionInfo(infos2[0], drm, sessionId, 3);
+    EXPECT_EQ(2u, mDrmSessionManager->getSessionCount());
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm1->mSessionId));
+    EXPECT_FALSE(mDrmSessionManager->containsSession(mTestDrm2->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm3->mSessionId));
 }
 
 TEST_F(DrmSessionManagerTest, reclaimSession) {
@@ -224,30 +194,63 @@
     EXPECT_FALSE(mDrmSessionManager->reclaimSession(50));
 
     EXPECT_TRUE(mDrmSessionManager->reclaimSession(10));
-    EXPECT_EQ(1u, mTestDrm1->reclaimedSessions().size());
-    EXPECT_TRUE(isEqualSessionId(mSessionId1, mTestDrm1->reclaimedSessions()[0]));
-
-    mDrmSessionManager->removeSession(mSessionId1);
+    EXPECT_TRUE(mTestDrm1->isReclaimed());
 
     // add a session from a higher priority process.
-    sp<FakeDrm> drm = new FakeDrm;
-    const uint8_t ids[] = {1, 3, 5};
-    Vector<uint8_t> sessionId;
-    GetSessionId(ids, ARRAY_SIZE(ids), &sessionId);
-    mDrmSessionManager->addSession(15, drm, sessionId);
+    const std::vector<uint8_t> sid{1, 3, 5};
+    sp<FakeDrm> drm = new FakeDrm(sid, mDrmSessionManager);
+    mDrmSessionManager->addSession(15, drm, drm->mSessionId);
 
+    // make sure mTestDrm2 is reclaimed next instead of mTestDrm3
+    mDrmSessionManager->useSession(mTestDrm3->mSessionId);
     EXPECT_TRUE(mDrmSessionManager->reclaimSession(18));
-    EXPECT_EQ(1u, mTestDrm2->reclaimedSessions().size());
-    // mSessionId2 is reclaimed.
-    EXPECT_TRUE(isEqualSessionId(mSessionId2, mTestDrm2->reclaimedSessions()[0]));
+    EXPECT_TRUE(mTestDrm2->isReclaimed());
+
+    EXPECT_EQ(2u, mDrmSessionManager->getSessionCount());
+    EXPECT_FALSE(mDrmSessionManager->containsSession(mTestDrm1->mSessionId));
+    EXPECT_FALSE(mDrmSessionManager->containsSession(mTestDrm2->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm3->mSessionId));
+    EXPECT_TRUE(mDrmSessionManager->containsSession(drm->mSessionId));
 }
 
-TEST_F(DrmSessionManagerTest, getLowestPriority) {
-    testGetLowestPriority();
-}
+TEST_F(DrmSessionManagerTest, reclaimAfterUse) {
+    // nothing to reclaim yet
+    EXPECT_FALSE(mDrmSessionManager->reclaimSession(kTestPid1));
+    EXPECT_FALSE(mDrmSessionManager->reclaimSession(kTestPid2));
 
-TEST_F(DrmSessionManagerTest, getLeastUsedSession_l) {
-    testGetLeastUsedSession();
+    // add sessions from same pid
+    mDrmSessionManager->addSession(kTestPid2, mTestDrm1, mTestDrm1->mSessionId);
+    mDrmSessionManager->addSession(kTestPid2, mTestDrm2, mTestDrm2->mSessionId);
+    mDrmSessionManager->addSession(kTestPid2, mTestDrm3, mTestDrm3->mSessionId);
+
+    // use some but not all sessions
+    mDrmSessionManager->useSession(mTestDrm1->mSessionId);
+    mDrmSessionManager->useSession(mTestDrm1->mSessionId);
+    mDrmSessionManager->useSession(mTestDrm2->mSessionId);
+
+    // calling pid priority is too low
+    int lowPriorityPid = kTestPid2 + 1;
+    EXPECT_FALSE(mDrmSessionManager->reclaimSession(lowPriorityPid));
+
+    // unused session is reclaimed first
+    int highPriorityPid = kTestPid2 - 1;
+    EXPECT_TRUE(mDrmSessionManager->reclaimSession(highPriorityPid));
+    EXPECT_FALSE(mTestDrm1->isReclaimed());
+    EXPECT_FALSE(mTestDrm2->isReclaimed());
+    EXPECT_TRUE(mTestDrm3->isReclaimed());
+    mDrmSessionManager->removeSession(mTestDrm3->mSessionId);
+
+    // less-used session is reclaimed next
+    EXPECT_TRUE(mDrmSessionManager->reclaimSession(highPriorityPid));
+    EXPECT_FALSE(mTestDrm1->isReclaimed());
+    EXPECT_TRUE(mTestDrm2->isReclaimed());
+    EXPECT_TRUE(mTestDrm3->isReclaimed());
+
+    // most-used session still open
+    EXPECT_EQ(1u, mDrmSessionManager->getSessionCount());
+    EXPECT_TRUE(mDrmSessionManager->containsSession(mTestDrm1->mSessionId));
+    EXPECT_FALSE(mDrmSessionManager->containsSession(mTestDrm2->mSessionId));
+    EXPECT_FALSE(mDrmSessionManager->containsSession(mTestDrm3->mSessionId));
 }
 
 } // namespace android
diff --git a/media/utils/ProcessInfo.cpp b/media/utils/ProcessInfo.cpp
index 27f1a79..113e4a7 100644
--- a/media/utils/ProcessInfo.cpp
+++ b/media/utils/ProcessInfo.cpp
@@ -23,6 +23,7 @@
 #include <binder/IPCThreadState.h>
 #include <binder/IProcessInfoService.h>
 #include <binder/IServiceManager.h>
+#include <private/android_filesystem_config.h>
 
 namespace android {
 
@@ -55,8 +56,9 @@
 
 bool ProcessInfo::isValidPid(int pid) {
     int callingPid = IPCThreadState::self()->getCallingPid();
+    int callingUid = IPCThreadState::self()->getCallingUid();
     // Trust it if this is called from the same process otherwise pid has to match the calling pid.
-    return (callingPid == getpid()) || (callingPid == pid);
+    return (callingPid == getpid()) || (callingPid == pid) || (callingUid == AID_MEDIA);
 }
 
 ProcessInfo::~ProcessInfo() {}
diff --git a/services/mediadrm/Android.mk b/services/mediadrm/Android.mk
index d4bb48a..707a2aa 100644
--- a/services/mediadrm/Android.mk
+++ b/services/mediadrm/Android.mk
@@ -26,6 +26,7 @@
 LOCAL_SHARED_LIBRARIES:= \
     libbinder \
     liblog \
+    libmedia \
     libmediadrm \
     libutils \
     libhidlbase \
diff --git a/services/mediaresourcemanager/Android.bp b/services/mediaresourcemanager/Android.bp
index f3339a0..d468406 100644
--- a/services/mediaresourcemanager/Android.bp
+++ b/services/mediaresourcemanager/Android.bp
@@ -23,4 +23,6 @@
         "-Wall",
     ],
 
+    export_include_dirs: ["."],
+
 }
diff --git a/services/mediaresourcemanager/ResourceManagerService.cpp b/services/mediaresourcemanager/ResourceManagerService.cpp
index bdcd5e4..45eea0f 100644
--- a/services/mediaresourcemanager/ResourceManagerService.cpp
+++ b/services/mediaresourcemanager/ResourceManagerService.cpp
@@ -290,6 +290,18 @@
     }
 }
 
+void ResourceManagerService::mergeResources(
+        MediaResource& r1, const MediaResource& r2) {
+    if (r1.mType == MediaResource::kDrmSession) {
+        // This means we are using a session. Each session's mValue is initialized to UINT64_MAX.
+        // The oftener a session is used the lower it's mValue. During reclaim the session with
+        // the highest mValue/lowest usage would be closed.
+        r1.mValue -= (r1.mValue == 0 ? 0 : 1);
+    } else {
+        r1.mValue += r2.mValue;
+    }
+}
+
 void ResourceManagerService::addResource(
         int pid,
         int uid,
@@ -309,15 +321,16 @@
     ResourceInfo& info = getResourceInfoForEdit(uid, clientId, client, infos);
 
     for (size_t i = 0; i < resources.size(); ++i) {
-        const auto resType = std::make_pair(resources[i].mType, resources[i].mSubType);
+        const auto &res = resources[i];
+        const auto resType = std::tuple(res.mType, res.mSubType, res.mId);
         if (info.resources.find(resType) == info.resources.end()) {
-            onFirstAdded(resources[i], info);
-            info.resources[resType] = resources[i];
+            onFirstAdded(res, info);
+            info.resources[resType] = res;
         } else {
-            info.resources[resType].mValue += resources[i].mValue;
+            mergeResources(info.resources[resType], res);
         }
     }
-    if (info.deathNotifier == nullptr) {
+    if (info.deathNotifier == nullptr && client != nullptr) {
         info.deathNotifier = new DeathNotifier(this, pid, clientId);
         IInterface::asBinder(client)->linkToDeath(info.deathNotifier);
     }
@@ -351,14 +364,17 @@
     ResourceInfo &info = infos.editValueAt(index);
 
     for (size_t i = 0; i < resources.size(); ++i) {
-        const auto resType = std::make_pair(resources[i].mType, resources[i].mSubType);
+        const auto &res = resources[i];
+        const auto resType = std::tuple(res.mType, res.mSubType, res.mId);
         // ignore if we don't have it
         if (info.resources.find(resType) != info.resources.end()) {
             MediaResource &resource = info.resources[resType];
-            if (resource.mValue > resources[i].mValue) {
-                resource.mValue -= resources[i].mValue;
+            if (resource.mValue > res.mValue) {
+                resource.mValue -= res.mValue;
             } else {
-                onLastRemoved(resources[i], info);
+                // drm sessions always take this branch because res.mValue is set
+                // to UINT64_MAX
+                onLastRemoved(res, info);
                 info.resources.erase(resType);
             }
         }
@@ -430,6 +446,7 @@
         const MediaResource *secureCodec = NULL;
         const MediaResource *nonSecureCodec = NULL;
         const MediaResource *graphicMemory = NULL;
+        const MediaResource *drmSession = NULL;
         for (size_t i = 0; i < resources.size(); ++i) {
             MediaResource::Type type = resources[i].mType;
             if (resources[i].mType == MediaResource::kSecureCodec) {
@@ -438,6 +455,8 @@
                 nonSecureCodec = &resources[i];
             } else if (type == MediaResource::kGraphicMemory) {
                 graphicMemory = &resources[i];
+            } else if (type == MediaResource::kDrmSession) {
+                drmSession = &resources[i];
             }
         }
 
@@ -461,6 +480,12 @@
                 }
             }
         }
+        if (drmSession != NULL) {
+            getClientForResource_l(callingPid, drmSession, &clients);
+            if (clients.size() == 0) {
+                return false;
+            }
+        }
 
         if (clients.size() == 0) {
             // if no secure/non-secure codec conflict, run second pass to handle other resources.
diff --git a/services/mediaresourcemanager/ResourceManagerService.h b/services/mediaresourcemanager/ResourceManagerService.h
index f086dc3..44d0c28 100644
--- a/services/mediaresourcemanager/ResourceManagerService.h
+++ b/services/mediaresourcemanager/ResourceManagerService.h
@@ -33,7 +33,7 @@
 class ServiceLog;
 struct ProcessInfoInterface;
 
-typedef std::map<std::pair<MediaResource::Type, MediaResource::SubType>, MediaResource> ResourceList;
+typedef std::map<std::tuple<MediaResource::Type, MediaResource::SubType, std::vector<uint8_t>>, MediaResource> ResourceList;
 struct ResourceInfo {
     int64_t clientId;
     uid_t uid;
@@ -126,6 +126,9 @@
     void onFirstAdded(const MediaResource& res, const ResourceInfo& clientInfo);
     void onLastRemoved(const MediaResource& res, const ResourceInfo& clientInfo);
 
+    // Merge r2 into r1
+    void mergeResources(MediaResource& r1, const MediaResource& r2);
+
     mutable Mutex mLock;
     sp<ProcessInfoInterface> mProcessInfo;
     sp<SystemCallbackInterface> mSystemCB;
diff --git a/services/mediaresourcemanager/test/ResourceManagerService_test.cpp b/services/mediaresourcemanager/test/ResourceManagerService_test.cpp
index ae97ec8..9e14151 100644
--- a/services/mediaresourcemanager/test/ResourceManagerService_test.cpp
+++ b/services/mediaresourcemanager/test/ResourceManagerService_test.cpp
@@ -173,8 +173,9 @@
         // convert resource1 to ResourceList
         ResourceList r1;
         for (size_t i = 0; i < resources1.size(); ++i) {
-            const auto resType = std::make_pair(resources1[i].mType, resources1[i].mSubType);
-            r1[resType] = resources1[i];
+            const auto &res = resources1[i];
+            const auto resType = std::tuple(res.mType, res.mSubType, res.mId);
+            r1[resType] = res;
         }
         return r1 == resources2;
     }