Add support for head-tracker sensor type and reference reset

Pose listeners can now be notified whenever the received pose has a
different frame of reference than previous samples and this results in
a recenter.

In addition, the private sensor type used for head tracking is now
supported.


Test: Manual verification using uhid-sample
Bug: 188502620
Change-Id: Ibe86e654656b2797ecc6fc936769e1148d9e02fd
diff --git a/media/libheadtracking/SensorPoseProvider-example.cpp b/media/libheadtracking/SensorPoseProvider-example.cpp
index a246e8b..88e222e 100644
--- a/media/libheadtracking/SensorPoseProvider-example.cpp
+++ b/media/libheadtracking/SensorPoseProvider-example.cpp
@@ -40,7 +40,7 @@
 class Listener : public SensorPoseProvider::Listener {
   public:
     void onPose(int64_t timestamp, int32_t handle, const Pose3f& pose,
-                const std::optional<Twist3f>& twist) override {
+                const std::optional<Twist3f>& twist, bool isNewReference) override {
         int64_t now = elapsedRealtimeNano();
 
         std::cout << "onPose t=" << timestamp
@@ -53,7 +53,7 @@
         } else {
             std::cout << "<none>";
         }
-        std::cout << std::endl;
+        std::cout << " isNewReference=" << isNewReference << std::endl;
     }
 };
 
@@ -67,11 +67,15 @@
 
     std::unique_ptr<SensorPoseProvider> provider =
             SensorPoseProvider::create(kPackageName, &listener);
-    int32_t headHandle = provider->startSensor(headSensor->getHandle(), 500ms);
+    if (!provider->startSensor(headSensor->getHandle(), 500ms)) {
+        std::cout << "Failed to start head sensor" << std::endl;
+    }
     sleep(2);
-    provider->startSensor(screenSensor->getHandle(), 500ms);
+    if (!provider->startSensor(screenSensor->getHandle(), 500ms)) {
+        std::cout << "Failed to start screenSensor sensor" << std::endl;
+    }
     sleep(2);
-    provider->stopSensor(headHandle);
+    provider->stopSensor(headSensor->getHandle());
     sleep(2);
     return 0;
 }
diff --git a/media/libheadtracking/SensorPoseProvider.cpp b/media/libheadtracking/SensorPoseProvider.cpp
index c4c031d..ec5e1ec 100644
--- a/media/libheadtracking/SensorPoseProvider.cpp
+++ b/media/libheadtracking/SensorPoseProvider.cpp
@@ -24,6 +24,7 @@
 #include <map>
 #include <thread>
 
+#include <android-base/thread_annotations.h>
 #include <log/log_main.h>
 #include <sensor/Sensor.h>
 #include <sensor/SensorEventQueue.h>
@@ -86,7 +87,7 @@
         if (mSensor != SensorPoseProvider::INVALID_HANDLE) {
             int ret = mQueue->disableSensor(mSensor);
             if (ret) {
-                ALOGE("Failed to disable sensor: %s\n", strerror(ret));
+                ALOGE("Failed to disable sensor: %s", strerror(ret));
             }
         }
     }
@@ -123,9 +124,23 @@
     }
 
     bool startSensor(int32_t sensor, std::chrono::microseconds samplingPeriod) override {
+        // Figure out the sensor's data format.
+        DataFormat format = getSensorFormat(sensor);
+        if (format == DataFormat::kUnknown) {
+            ALOGE("Unknown format for sensor %" PRId32, sensor);
+            return false;
+        }
+
+        {
+            std::lock_guard lock(mMutex);
+            mEnabledSensorFormats.emplace(sensor, format);
+        }
+
         // Enable the sensor.
         if (mQueue->enableSensor(sensor, samplingPeriod.count(), 0, 0)) {
             ALOGE("Failed to enable sensor");
+            std::lock_guard lock(mMutex);
+            mEnabledSensorFormats.erase(sensor);
             return false;
         }
 
@@ -133,14 +148,32 @@
         return true;
     }
 
-    void stopSensor(int handle) override { mEnabledSensors.erase(handle); }
+    void stopSensor(int handle) override {
+        mEnabledSensors.erase(handle);
+        std::lock_guard lock(mMutex);
+        mEnabledSensorFormats.erase(handle);
+    }
 
   private:
+    enum DataFormat {
+        kUnknown,
+        kQuaternion,
+        kRotationVectorsAndFlags,
+    };
+
+    struct PoseEvent {
+        Pose3f pose;
+        std::optional<Twist3f> twist;
+        bool isNewReference;
+    };
+
     sp<Looper> mLooper;
     Listener* const mListener;
-
+    SensorManager* const mSensorManager;
     std::thread mThread;
+    std::mutex mMutex;
     std::map<int32_t, SensorEnableGuard> mEnabledSensors;
+    std::map<int32_t, DataFormat> mEnabledSensorFormats GUARDED_BY(mMutex);
     sp<SensorEventQueue> mQueue;
 
     // We must do some of the initialization operations on the worker thread, because the API relies
@@ -153,21 +186,19 @@
 
     SensorPoseProviderImpl(const char* packageName, Listener* listener)
         : mListener(listener),
-          mThread([this, p = std::string(packageName)] { threadFunc(p.c_str()); }) {}
+          mSensorManager(&SensorManager::getInstanceForPackage(String16(packageName))),
+          mThread([this] { threadFunc(); }) {}
 
     void initFinished(bool success) { mInitPromise.set_value(success); }
 
     bool waitInitFinished() { return mInitPromise.get_future().get(); }
 
-    void threadFunc(const char* packageName) {
+    void threadFunc() {
         // Obtain looper.
         mLooper = Looper::prepare(ALOOPER_PREPARE_ALLOW_NON_CALLBACKS);
 
-        // Obtain sensor manager.
-        SensorManager& sensorManager = SensorManager::getInstanceForPackage(String16(packageName));
-
         // Create event queue.
-        mQueue = sensorManager.createEventQueue();
+        mQueue = mSensorManager->createEventQueue();
 
         if (mQueue == nullptr) {
             ALOGE("Failed to create a sensor event queue");
@@ -217,24 +248,98 @@
     }
 
     void handleEvent(const ASensorEvent& event) {
-        auto value = parseEvent(event);
-        mListener->onPose(event.timestamp, event.sensor, std::get<0>(value), std::get<1>(value));
+        DataFormat format;
+        {
+            std::lock_guard lock(mMutex);
+            auto iter = mEnabledSensorFormats.find(event.sensor);
+            if (iter == mEnabledSensorFormats.end()) {
+                // This can happen if we have any pending events shortly after stopping.
+                return;
+            }
+            format = iter->second;
+        }
+        auto value = parseEvent(event, format);
+        mListener->onPose(event.timestamp, event.sensor, value.pose, value.twist,
+                          value.isNewReference);
     }
 
-    static std::tuple<Pose3f, std::optional<Twist3f>> parseEvent(const ASensorEvent& event) {
+    DataFormat getSensorFormat(int32_t handle) {
+        std::optional<const Sensor> sensor = getSensorByHandle(handle);
+        if (!sensor) {
+            ALOGE("Sensor not found: %d", handle);
+            return DataFormat::kUnknown;
+        }
+        if (sensor->getType() == ASENSOR_TYPE_ROTATION_VECTOR ||
+            sensor->getType() == ASENSOR_TYPE_GAME_ROTATION_VECTOR) {
+            return DataFormat::kQuaternion;
+        }
+
+        if (sensor->getStringType() == "com.google.hardware.sensor.hid_dynamic.headtracker") {
+            return DataFormat::kRotationVectorsAndFlags;
+        }
+
+        return DataFormat::kUnknown;
+    }
+
+    std::optional<const Sensor> getSensorByHandle(int32_t handle) {
+        const Sensor* const* list;
+        ssize_t size;
+
+        // Search static sensor list.
+        size = mSensorManager->getSensorList(&list);
+        if (size < 0) {
+            ALOGE("getSensorList failed with error code %zd", size);
+            return std::nullopt;
+        }
+        for (size_t i = 0; i < size; ++i) {
+            if (list[i]->getHandle() == handle) {
+                return *list[i];
+            }
+        }
+
+        // Search dynamic sensor list.
+        Vector<Sensor> dynList;
+        size = mSensorManager->getDynamicSensorList(dynList);
+        if (size < 0) {
+            ALOGE("getDynamicSensorList failed with error code %zd", size);
+            return std::nullopt;
+        }
+        for (size_t i = 0; i < size; ++i) {
+            if (dynList[i].getHandle() == handle) {
+                return dynList[i];
+            }
+        }
+
+        return std::nullopt;
+    }
+
+    static PoseEvent parseEvent(const ASensorEvent& event, DataFormat format) {
         // TODO(ytai): Add more types.
-        switch (event.type) {
-            case ASENSOR_TYPE_ROTATION_VECTOR:
-            case ASENSOR_TYPE_GAME_ROTATION_VECTOR: {
+        switch (format) {
+            case DataFormat::kQuaternion: {
                 Eigen::Quaternionf quat(event.data[3], event.data[0], event.data[1], event.data[2]);
                 // Adapt to different frame convention.
                 quat *= rotateX(-M_PI_2);
-                return std::make_tuple(Pose3f(quat), std::optional<Twist3f>());
+                return PoseEvent{Pose3f(quat), std::optional<Twist3f>(), false};
+            }
+
+            case DataFormat::kRotationVectorsAndFlags: {
+                // Custom sensor, assumed to contain:
+                // 3 floats representing orientation as a rotation vector (in rad).
+                // 3 floats representing angular velocity as a rotation vector (in rad/s).
+                // 1 uint32_t of flags, where:
+                // - LSb is '1' iff the given sample is the first one in a new frame of reference.
+                // - The rest of the bits are reserved for future use.
+                Eigen::Vector3f rotation = {event.data[0], event.data[1], event.data[2]};
+                Eigen::Vector3f twist = {event.data[3], event.data[4], event.data[5]};
+                Eigen::Quaternionf quat = rotationVectorToQuaternion(rotation);
+                uint32_t flags = *reinterpret_cast<const uint32_t*>(&event.data[6]);
+                return PoseEvent{Pose3f(quat), Twist3f(Eigen::Vector3f::Zero(), twist),
+                                 (flags & (1 << 0)) != 0};
             }
 
             default:
-                ALOGE("Unsupported sensor type: %" PRId32, event.type);
-                return std::make_tuple(Pose3f(), std::optional<Twist3f>());
+                LOG_ALWAYS_FATAL("Unexpected sensor type: %d", static_cast<int>(format));
         }
     }
 };
diff --git a/media/libheadtracking/include/media/SensorPoseProvider.h b/media/libheadtracking/include/media/SensorPoseProvider.h
index 1a5deb0..d2a6b77 100644
--- a/media/libheadtracking/include/media/SensorPoseProvider.h
+++ b/media/libheadtracking/include/media/SensorPoseProvider.h
@@ -61,7 +61,7 @@
         virtual ~Listener() = default;
 
         virtual void onPose(int64_t timestamp, int32_t handle, const Pose3f& pose,
-                            const std::optional<Twist3f>& twist) = 0;
+                            const std::optional<Twist3f>& twist, bool isNewReference) = 0;
     };
 
     /**
diff --git a/services/audiopolicy/service/SpatializerPoseController.cpp b/services/audiopolicy/service/SpatializerPoseController.cpp
index eb23298..ffedf63 100644
--- a/services/audiopolicy/service/SpatializerPoseController.cpp
+++ b/services/audiopolicy/service/SpatializerPoseController.cpp
@@ -224,13 +224,19 @@
 }
 
 void SpatializerPoseController::onPose(int64_t timestamp, int32_t sensor, const Pose3f& pose,
-                                       const std::optional<Twist3f>& twist) {
+                                       const std::optional<Twist3f>& twist, bool isNewReference) {
     std::lock_guard lock(mMutex);
     if (sensor == mHeadSensor) {
         mProcessor->setWorldToHeadPose(timestamp, pose, twist.value_or(Twist3f()));
+        if (isNewReference) {
+            mProcessor->recenter(true, false);
+        }
     }
     if (sensor == mScreenSensor) {
         mProcessor->setWorldToScreenPose(timestamp, pose);
+        if (isNewReference) {
+            mProcessor->recenter(false, true);
+        }
     }
 }
 
diff --git a/services/audiopolicy/service/SpatializerPoseController.h b/services/audiopolicy/service/SpatializerPoseController.h
index c579622..2b5c189 100644
--- a/services/audiopolicy/service/SpatializerPoseController.h
+++ b/services/audiopolicy/service/SpatializerPoseController.h
@@ -130,7 +130,7 @@
     bool mCalculated = false;
 
     void onPose(int64_t timestamp, int32_t sensor, const media::Pose3f& pose,
-                const std::optional<media::Twist3f>& twist) override;
+                const std::optional<media::Twist3f>& twist, bool isNewReference) override;
 
     /**
      * Calculates the new outputs and updates internal state. Must be called with the lock held.