diff --git a/snail-kotlin/src/main/java/com/compass/snail/Replay.kt b/snail-kotlin/src/main/java/com/compass/snail/Replay.kt
index a421c3f..5633e41 100644
--- a/snail-kotlin/src/main/java/com/compass/snail/Replay.kt
+++ b/snail-kotlin/src/main/java/com/compass/snail/Replay.kt
@@ -4,9 +4,13 @@ package com.compass.snail
 
 import com.compass.snail.disposer.Disposable
 import kotlinx.coroutines.CoroutineDispatcher
+import java.util.concurrent.locks.ReentrantLock
+import kotlin.concurrent.withLock
 
 open class Replay<T>(private val threshold: Int) : Observable<T>() {
     private var values: MutableList<T> = mutableListOf()
+    // Guards every read and write of `values`; it is never held while subscriber callbacks run.
+    private val lock = ReentrantLock()
 
     override fun subscribe(dispatcher: CoroutineDispatcher?, next: ((T) -> Unit)?, error: ((Throwable) -> Unit)?, done: (() -> Unit)?): Disposable {
         replay(dispatcher, createHandler(next, error, done))
@@ -14,12 +18,18 @@ open class Replay<T>(private val threshold: Int) : Observable<T>() {
     }
 
     override fun next(value: T) {
-        values.add(value)
-        values = values.takeLast(threshold).toMutableList()
+        lock.withLock {
+            values.add(value)
+            values = values.takeLast(threshold).toMutableList()
+        }
         super.next(value)
     }
 
     private fun replay(dispatcher: CoroutineDispatcher?, handler: (Event<T>) -> Unit) {
-        values.forEach { notify(Subscriber(dispatcher, handler, this), Event(next = Next(it))) }
+        // Copy the buffer under the lock, then deliver outside it. Delivering while locked would block all
+        // producers for the duration of the callbacks, and (if notify runs the handler synchronously) a callback
+        // calling next() on this thread would re-enter the lock and mutate the list being iterated.
+        val snapshot = lock.withLock { values.toList() }
+        snapshot.forEach { notify(Subscriber(dispatcher, handler, this), Event(next = Next(it))) }
     }
 }
diff --git a/snail-kotlin/src/test/java/com/compass/snail/ReplayTests.kt b/snail-kotlin/src/test/java/com/compass/snail/ReplayTests.kt
index 4486ace..02f8fdb 100644
--- a/snail-kotlin/src/test/java/com/compass/snail/ReplayTests.kt
+++ b/snail-kotlin/src/test/java/com/compass/snail/ReplayTests.kt
@@ -5,6 +5,10 @@ package com.compass.snail
 import org.junit.Assert.assertEquals
+import org.junit.Assert.assertTrue
 import org.junit.Before
 import org.junit.Test
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import kotlin.concurrent.thread
 
 class ReplayTests {
     private var subject: Replay<String>? = null
@@ -38,4 +42,68 @@ class ReplayTests {
         assertEquals("2", b[0])
         assertEquals(2, b.size)
     }
+
+    @Test
+    fun testMultiThreadedBehavior() {
+        val subject = Replay<Int>(1)
+        var a = 0
+        var b = 0
+
+        subject.subscribe(next = {
+            a = it
+        })
+        subject.subscribe(next = {
+            b = it
+        })
+
+        val producers = 2
+        val latch = CountDownLatch(producers)
+        repeat(producers) {
+            thread {
+                try {
+                    for (i in 1..100) {
+                        subject.next(i)
+                    }
+                } finally {
+                    // Count down even if next() throws, so a failure surfaces immediately instead of as a timeout.
+                    latch.countDown()
+                }
+            }
+        }
+        assertTrue("producers did not finish in time", latch.await(10, TimeUnit.SECONDS))
+
+        // A subscriber arriving after all producers finished must be replayed the last buffered value.
+        var c = 0
+        subject.subscribe(next = {
+            c = it
+        })
+
+        subject.removeSubscribers()
+
+        assertEquals(100, a)
+        assertEquals(100, b)
+        assertEquals(100, c)
+    }
+
+    @Test
+    fun testNextFromReplayCallback() {
+        val subject = Replay<Int>(1)
+        subject.next(1)
+
+        // Publishing from inside a replay callback must not fail (e.g. with ConcurrentModificationException).
+        subject.subscribe(next = {
+            if (it == 1) {
+                subject.next(2)
+            }
+        })
+
+        var last = 0
+        subject.subscribe(next = {
+            last = it
+        })
+
+        subject.removeSubscribers()
+
+        assertEquals(2, last)
+    }
 }