From 992facdff33580e73376278b0740bb604cfe793a Mon Sep 17 00:00:00 2001 From: Shamali P Date: Thu, 6 Mar 2025 16:53:53 +0000 Subject: [PATCH] Unregister widget prediction callback on clear Noticed that predictor holds the registered callback and caused leak. Using named listener classes for better stacks and also unregisters them Bug: N/A Flag: EXEMPT BUGFIX Test: WidgetPredictionsRequesterTest Change-Id: I94211ddbc77077c98b804827bb1cecdefe57703b --- .../launcher3/WidgetPickerActivity.java | 8 +- .../model/WidgetPredictionsRequester.java | 98 ++++++++++------- .../model/WidgetsPredictionsRequesterTest.kt | 102 +++++++++++++++--- 3 files changed, 149 insertions(+), 59 deletions(-) diff --git a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java index 7f3e615d31..4d3e3bea67 100644 --- a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java +++ b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java @@ -68,7 +68,8 @@ import java.util.function.Predicate; import java.util.regex.Pattern; /** An Activity that can host Launcher's widget picker. */ -public class WidgetPickerActivity extends BaseActivity { +public class WidgetPickerActivity extends BaseActivity implements + WidgetPredictionsRequester.WidgetPredictionsListener { private static final String TAG = "WidgetPickerActivity"; /** * Name of the extra that indicates that a widget being dragged. @@ -322,7 +323,7 @@ public class WidgetPickerActivity extends BaseActivity { if (mUiSurface != null) { mWidgetPredictionsRequester = new WidgetPredictionsRequester(app.getContext(), mUiSurface, mModel.getWidgetsByComponentKeyForPicker()); - mWidgetPredictionsRequester.request(mAddedWidgets, this::bindRecommendedWidgets); + mWidgetPredictionsRequester.request(mAddedWidgets, /*listener=*/ this); } }); } @@ -355,7 +356,8 @@ public class WidgetPickerActivity extends BaseActivity { }); } - private void bindRecommendedWidgets(List recommendedWidgets) { + @Override + public void onPredictionsAvailable(List recommendedWidgets) { // Bind recommendations once picker has finished open animation. MAIN_EXECUTOR.getHandler().postDelayed( () -> mWidgetPickerDataProvider.setWidgetRecommendations(recommendedWidgets), diff --git a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java index d3ac975052..f9cec82834 100644 --- a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java +++ b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java @@ -48,7 +48,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Consumer; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -56,7 +55,7 @@ import java.util.stream.Collectors; * Works with app predictor to fetch and process widget predictions displayed in a standalone * widget picker activity for a UI surface. */ -public class WidgetPredictionsRequester { +public class WidgetPredictionsRequester implements AppPredictor.Callback { private static final int NUM_OF_RECOMMENDED_WIDGETS_PREDICATION = 20; private static final String BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets"; // container/screenid/[positionx,positiony]/[spanx,spany] @@ -71,6 +70,9 @@ public class WidgetPredictionsRequester { @NonNull private final String mUiSurface; private boolean mPredictionsAvailable; + @Nullable + private WidgetPredictionsListener mPredictionsListener = null; + @Nullable Predicate mFilter = null; @NonNull private final Map mAllWidgets; @@ -81,36 +83,49 @@ public class WidgetPredictionsRequester { mAllWidgets = Collections.unmodifiableMap(allWidgets); } + // AppPredictor.Callback -> onTargetsAvailable + @Override + @WorkerThread + public void onTargetsAvailable(List targets) { + List filteredPredictions = filterPredictions(targets, mAllWidgets, mFilter); + List mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions); + + if (!mPredictionsAvailable && mPredictionsListener != null) { + mPredictionsAvailable = true; + MAIN_EXECUTOR.execute( + () -> mPredictionsListener.onPredictionsAvailable(mappedPredictions)); + } + } + /** * Requests one time predictions from the app predictions manager and invokes provided callback - * once predictions are available. + * once predictions are available. Any previous requests may be cancelled. * * @param existingWidgets widgets that are currently added to the surface; - * @param callback consumer of prediction results to be called when predictions are - * available + * @param listener consumer of prediction results to be called when predictions are + * available; any previous listener will no longer receive updates. */ + @WorkerThread // e.g. MODEL_EXECUTOR public void request(List existingWidgets, - Consumer> callback) { + WidgetPredictionsListener listener) { + clear(); + mPredictionsListener = listener; + mFilter = notOnUiSurfaceFilter(existingWidgets); + + AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class); + if (apm == null) { + return; + } + Bundle bundle = buildBundleForPredictionSession(existingWidgets); - Predicate filter = notOnUiSurfaceFilter(existingWidgets); - - MODEL_EXECUTOR.execute(() -> { - clear(); - AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class); - if (apm == null) { - return; - } - - mAppPredictor = apm.createAppPredictionSession( - new AppPredictionContext.Builder(mContext) - .setUiSurface(mUiSurface) - .setExtras(bundle) - .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION) - .build()); - mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR, - targets -> bindPredictions(targets, filter, callback)); - mAppPredictor.requestPredictionUpdate(); - }); + mAppPredictor = apm.createAppPredictionSession( + new AppPredictionContext.Builder(mContext) + .setUiSurface(mUiSurface) + .setExtras(bundle) + .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION) + .build()); + mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR, /*callback=*/ this); + mAppPredictor.requestPredictionUpdate(); } /** @@ -158,27 +173,14 @@ public class WidgetPredictionsRequester { return widgetItem -> !existingComponentKeys.contains(widgetItem); } - /** Provides the predictions returned by the predictor to the registered callback. */ - @WorkerThread - private void bindPredictions(List targets, Predicate filter, - Consumer> callback) { - if (!mPredictionsAvailable) { - mPredictionsAvailable = true; - List filteredPredictions = filterPredictions(targets, mAllWidgets, filter); - List mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions); - - MAIN_EXECUTOR.execute(() -> callback.accept(mappedPredictions)); - MODEL_EXECUTOR.execute(this::clear); - } - } - /** * Applies the provided filter (e.g. widgets not on workspace) on the predictions returned by * the predictor. */ @VisibleForTesting static List filterPredictions(List predictions, - Map allWidgets, Predicate filter) { + @NonNull Map allWidgets, + @Nullable Predicate filter) { List servicePredictedItems = new ArrayList<>(); for (AppTarget prediction : predictions) { @@ -187,7 +189,7 @@ public class WidgetPredictionsRequester { WidgetItem widgetItem = allWidgets.get( new ComponentKey(new ComponentName(prediction.getPackageName(), className), prediction.getUser())); - if (widgetItem != null && filter.test(widgetItem)) { + if (widgetItem != null && (filter == null || filter.test(widgetItem))) { servicePredictedItems.add(widgetItem); } } @@ -218,9 +220,23 @@ public class WidgetPredictionsRequester { /** Cleans up any open prediction sessions. */ public void clear() { if (mAppPredictor != null) { + mAppPredictor.unregisterPredictionUpdates(this); mAppPredictor.destroy(); mAppPredictor = null; } + mPredictionsListener = null; mPredictionsAvailable = false; + mFilter = null; + } + + /** + * Listener class to listen to updates from the {@link WidgetPredictionsRequester} + */ + public interface WidgetPredictionsListener { + /** + * Callback method that is called when the predicted widgets are available. + * @param predictions list of predicted widgets {@link PendingAddWidgetInfo} + */ + void onPredictionsAvailable(List predictions); } } diff --git a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt index 4ea74df776..d445189039 100644 --- a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt +++ b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt @@ -16,6 +16,8 @@ package com.android.launcher3.model +import android.app.prediction.AppPredictionManager +import android.app.prediction.AppPredictor import android.app.prediction.AppTarget import android.app.prediction.AppTargetEvent import android.app.prediction.AppTargetId @@ -36,9 +38,15 @@ import com.android.launcher3.model.WidgetPredictionsRequester.filterPredictions import com.android.launcher3.model.WidgetPredictionsRequester.notOnUiSurfaceFilter import com.android.launcher3.util.ActivityContextWrapper import com.android.launcher3.util.ComponentKey +import com.android.launcher3.util.Executors +import com.android.launcher3.util.Executors.MODEL_EXECUTOR +import com.android.launcher3.util.TestUtil import com.android.launcher3.util.WidgetUtils.createAppWidgetProviderInfo import com.android.launcher3.widget.LauncherAppWidgetProviderInfo +import com.android.launcher3.widget.PendingAddWidgetInfo import com.google.common.truth.Truth.assertThat +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit import java.util.function.Predicate import junit.framework.Assert.assertNotNull import org.junit.Before @@ -46,6 +54,9 @@ import org.junit.Test import org.junit.runner.RunWith import org.mockito.Mock import org.mockito.MockitoAnnotations +import org.mockito.kotlin.any +import org.mockito.kotlin.doAnswer +import org.mockito.kotlin.whenever @RunWith(AndroidJUnit4::class) class WidgetsPredictionsRequesterTest { @@ -67,11 +78,26 @@ class WidgetsPredictionsRequesterTest { @Mock private lateinit var iconCache: IconCache + @Mock private lateinit var apmMock: AppPredictionManager + + @Mock private lateinit var predictorMock: AppPredictor + @Before fun setUp() { MockitoAnnotations.initMocks(this) mUserHandle = myUserHandle() - context = ActivityContextWrapper(ApplicationProvider.getApplicationContext()) + + whenever(apmMock.createAppPredictionSession(any())).thenReturn(predictorMock) + + context = + object : ActivityContextWrapper(ApplicationProvider.getApplicationContext()) { + override fun getSystemService(name: String): Any? { + if (name == "app_prediction") { + return apmMock + } + return super.getSystemService(name) + } + } testInvariantProfile = LauncherAppState.getIDP(context) deviceProfile = testInvariantProfile.getDeviceProfile(context).copy(context) @@ -114,21 +140,67 @@ class WidgetsPredictionsRequesterTest { buildExpectedAppTargetEvent( /*pkg=*/ APP_1_PACKAGE_NAME, /*providerClassName=*/ APP_1_PROVIDER_A_CLASS_NAME, - /*user=*/ mUserHandle + /*user=*/ mUserHandle, ), buildExpectedAppTargetEvent( /*pkg=*/ APP_1_PACKAGE_NAME, /*providerClassName=*/ APP_1_PROVIDER_B_CLASS_NAME, - /*user=*/ mUserHandle + /*user=*/ mUserHandle, ), buildExpectedAppTargetEvent( /*pkg=*/ APP_2_PACKAGE_NAME, /*providerClassName=*/ APP_2_PROVIDER_1_CLASS_NAME, - /*user=*/ mUserHandle - ) + /*user=*/ mUserHandle, + ), ) } + @Test + fun request_invokesCallbackWithPredictedItems() { + TestUtil.runOnExecutorSync(MODEL_EXECUTOR) { + val underTest = WidgetPredictionsRequester(context, TEST_UI_SURFACE, allWidgets) + val existingWidgets = arrayListOf(widget1aInfo, widget1bInfo) + val predictions = + listOf( + // (existing) already on surface + AppTarget( + AppTargetId(APP_1_PACKAGE_NAME), + APP_1_PACKAGE_NAME, + APP_1_PROVIDER_B_CLASS_NAME, + mUserHandle, + ), + // eligible + AppTarget( + AppTargetId(APP_2_PACKAGE_NAME), + APP_2_PACKAGE_NAME, + APP_2_PROVIDER_1_CLASS_NAME, + mUserHandle, + ), + ) + doAnswer { + underTest.onTargetsAvailable(predictions) + null + } + .whenever(predictorMock) + .requestPredictionUpdate() + val testCountDownLatch = CountDownLatch(1) + val listener = + WidgetPredictionsRequester.WidgetPredictionsListener { itemInfos -> + if (itemInfos.size == 1 && itemInfos[0] is PendingAddWidgetInfo) { + // only one item was eligible. + testCountDownLatch.countDown() + } else { + println("Unexpected prediction items found: ${itemInfos.size}") + } + } + + underTest.request(existingWidgets, listener) + TestUtil.runOnExecutorSync(Executors.MAIN_EXECUTOR) {} + + assertThat(testCountDownLatch.await(TEST_TIMEOUT, TimeUnit.SECONDS)).isTrue() + } + } + @Test fun filterPredictions_notOnUiSurfaceFilter_returnsOnlyEligiblePredictions() { val widgetsAlreadyOnSurface = arrayListOf(widget1bInfo) @@ -141,15 +213,15 @@ class WidgetsPredictionsRequesterTest { AppTargetId(APP_1_PACKAGE_NAME), APP_1_PACKAGE_NAME, APP_1_PROVIDER_B_CLASS_NAME, - mUserHandle + mUserHandle, ), // eligible AppTarget( AppTargetId(APP_2_PACKAGE_NAME), APP_2_PACKAGE_NAME, APP_2_PROVIDER_1_CLASS_NAME, - mUserHandle - ) + mUserHandle, + ), ) // only 2 was eligible @@ -167,27 +239,27 @@ class WidgetsPredictionsRequesterTest { AppTargetId(APP_1_PACKAGE_NAME), APP_1_PACKAGE_NAME, "$APP_1_PACKAGE_NAME.SomeActivity", - mUserHandle + mUserHandle, ), AppTarget( AppTargetId(APP_2_PACKAGE_NAME), APP_2_PACKAGE_NAME, "$APP_2_PACKAGE_NAME.SomeActivity2", - mUserHandle + mUserHandle, ), ) assertThat(filterPredictions(predictions, allWidgets, filter)).isEmpty() } - private fun createWidgetItem( - providerInfo: AppWidgetProviderInfo, - ): WidgetItem { + private fun createWidgetItem(providerInfo: AppWidgetProviderInfo): WidgetItem { val widgetInfo = LauncherAppWidgetProviderInfo.fromProviderInfo(context, providerInfo) return WidgetItem(widgetInfo, testInvariantProfile, iconCache, context) } companion object { + const val TEST_TIMEOUT = 3L + const val TEST_UI_SURFACE = "widgets_test" const val BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets" @@ -203,13 +275,13 @@ class WidgetsPredictionsRequesterTest { private fun buildExpectedAppTargetEvent( pkg: String, providerClassName: String, - userHandle: UserHandle + userHandle: UserHandle, ): AppTargetEvent { val appTarget = AppTarget.Builder( /*id=*/ AppTargetId("widget:$pkg"), /*packageName=*/ pkg, - /*user=*/ userHandle + /*user=*/ userHandle, ) .setClassName(providerClassName) .build()