diff --git a/api.txt b/api.txt index d5c9a173af1..142dd386e4f 100644 --- a/api.txt +++ b/api.txt @@ -1414,6 +1414,19 @@ package androidx.media3.datasource.okhttp { } +package androidx.media3.datasource.ktor { + + public class KtorDataSource implements androidx.media3.datasource.DataSource androidx.media3.datasource.HttpDataSource { + } + + public static final class KtorDataSource.Factory implements androidx.media3.datasource.HttpDataSource.Factory { + ctor public KtorDataSource.Factory(io.ktor.client.HttpClient); + ctor public KtorDataSource.Factory(io.ktor.client.HttpClient, kotlinx.coroutines.CoroutineScope); + method public androidx.media3.datasource.ktor.KtorDataSource.Factory setUserAgent(@Nullable String); + } + +} + package androidx.media3.exoplayer { public final class ExoPlaybackException extends androidx.media3.common.PlaybackException { diff --git a/constants.gradle b/constants.gradle index 40513baf94d..c37b4d9a35c 100644 --- a/constants.gradle +++ b/constants.gradle @@ -65,6 +65,7 @@ project.ext { desugarJdkLibsVersion = '2.1.5' lottieVersion = '6.6.0' truthVersion = '1.4.0' + ktorVersion = '3.0.3' okhttpVersion = '4.12.0' testParameterInjectorVersion = '1.18' modulePrefix = ':' diff --git a/core_settings.gradle b/core_settings.gradle index 73f3c542f30..b4108e0216b 100644 --- a/core_settings.gradle +++ b/core_settings.gradle @@ -68,6 +68,8 @@ include modulePrefix + 'lib-datasource-rtmp' project(modulePrefix + 'lib-datasource-rtmp').projectDir = new File(rootDir, 'libraries/datasource_rtmp') include modulePrefix + 'lib-datasource-okhttp' project(modulePrefix + 'lib-datasource-okhttp').projectDir = new File(rootDir, 'libraries/datasource_okhttp') +include modulePrefix + 'lib-datasource-ktor' +project(modulePrefix + 'lib-datasource-ktor').projectDir = new File(rootDir, 'libraries/datasource_ktor') include modulePrefix + 'lib-decoder' project(modulePrefix + 'lib-decoder').projectDir = new File(rootDir, 'libraries/decoder') diff --git a/libraries/datasource_ktor/README.md b/libraries/datasource_ktor/README.md new file mode 100644 index 00000000000..965f8465bb3 --- /dev/null +++ b/libraries/datasource_ktor/README.md @@ -0,0 +1,64 @@ +# Ktor DataSource module + +This module provides an [HttpDataSource][] implementation that uses [Ktor][]. + +Ktor is a multiplatform HTTP client developed by JetBrains. It supports HTTP/2, +WebSocket, and Kotlin coroutines for asynchronous operations. + +[HttpDataSource]: ../datasource/src/main/java/androidx/media3/datasource/HttpDataSource.java +[Ktor]: https://ktor.io/ + +## Getting the module + +The easiest way to get the module is to add it as a gradle dependency: + +```groovy +implementation 'androidx.media3:media3-datasource-ktor:1.X.X' +``` + +where `1.X.X` is the version, which must match the version of the other media +modules being used. + +Alternatively, you can clone this GitHub project and depend on the module +locally. Instructions for doing this can be found in the [top level README][]. + +[top level README]: ../../README.md + +## Using the module + +Media components request data through `DataSource` instances. These instances +are obtained from instances of `DataSource.Factory`, which are instantiated and +injected from application code. + +If your application only needs to play http(s) content, using the Ktor +extension is as simple as updating any `DataSource.Factory` instantiations in +your application code to use `KtorDataSource.Factory`. If your application +also needs to play non-http(s) content such as local files, use: +``` +new DefaultDataSourceFactory( + ... + /* baseDataSourceFactory= */ new KtorDataSource.Factory(...)); +``` + +### Using with OkHttp engine + +```kotlin +val dataSourceFactory = KtorDataSource.Factory(OkHttp.create()) +``` + +### Using with a custom HttpClient + +```kotlin +val httpClient = HttpClient(OkHttp) { + engine { + // Configure OkHttp engine + } +} +val dataSourceFactory = KtorDataSource.Factory(httpClient) +``` + +## Links + +* [Javadoc][] + +[Javadoc]: https://developer.android.com/reference/androidx/media3/datasource/ktor/package-summary diff --git a/libraries/datasource_ktor/build.gradle b/libraries/datasource_ktor/build.gradle new file mode 100644 index 00000000000..582de699a39 --- /dev/null +++ b/libraries/datasource_ktor/build.gradle @@ -0,0 +1,41 @@ +apply from: "$gradle.ext.androidxMediaSettingsDir/common_library_config.gradle" + +apply plugin: 'kotlin-android' + +android { + namespace 'androidx.media3.datasource.ktor' + + defaultConfig.minSdkVersion project.ext.minSdkVersion + + publishing { + singleVariant('release') { + withSourcesJar() + } + } + + kotlinOptions { + jvmTarget = '1.8' + } +} + +dependencies { + api project(modulePrefix + 'lib-common') + api project(modulePrefix + 'lib-datasource') + implementation 'androidx.annotation:annotation:' + androidxAnnotationVersion + compileOnly 'com.google.errorprone:error_prone_annotations:' + errorProneVersion + compileOnly 'org.checkerframework:checker-qual:' + checkerframeworkVersion + compileOnly 'org.jetbrains.kotlin:kotlin-annotations-jvm:' + kotlinAnnotationsVersion + implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:' + kotlinxCoroutinesVersion + testImplementation project(modulePrefix + 'test-utils') + testImplementation 'com.squareup.okhttp3:mockwebserver:' + okhttpVersion + testImplementation 'org.robolectric:robolectric:' + robolectricVersion + testImplementation 'io.ktor:ktor-client-okhttp:' + ktorVersion + api 'io.ktor:ktor-client-core:' + ktorVersion +} + +ext { + releaseArtifactId = 'media3-datasource-ktor' + releaseName = 'Media3 Ktor DataSource module' + +} +apply from: '../../publish.gradle' diff --git a/libraries/datasource_ktor/proguard-rules.txt b/libraries/datasource_ktor/proguard-rules.txt new file mode 100644 index 00000000000..feb77bb649d --- /dev/null +++ b/libraries/datasource_ktor/proguard-rules.txt @@ -0,0 +1,11 @@ +# Proguard rules specific to the Ktor extension. + +# Options for Ktor and Okio +-dontwarn io.ktor.** +-dontwarn okio.** +-dontwarn javax.annotation.** +-dontwarn org.conscrypt.** + +# Keep Ktor client classes +-keep class io.ktor.** { *; } +-keep class kotlinx.coroutines.** { *; } diff --git a/libraries/datasource_ktor/src/main/AndroidManifest.xml b/libraries/datasource_ktor/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..499171ca5de --- /dev/null +++ b/libraries/datasource_ktor/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + diff --git a/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/KtorDataSource.kt b/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/KtorDataSource.kt new file mode 100644 index 00000000000..03339f77841 --- /dev/null +++ b/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/KtorDataSource.kt @@ -0,0 +1,519 @@ +package androidx.media3.datasource.ktor + +import android.net.Uri +import androidx.media3.common.C +import androidx.media3.common.MediaLibraryInfo +import androidx.media3.common.PlaybackException +import androidx.media3.common.util.UnstableApi +import androidx.media3.common.util.Util +import androidx.media3.datasource.BaseDataSource +import androidx.media3.datasource.DataSourceException +import androidx.media3.datasource.DataSpec +import androidx.media3.datasource.HttpDataSource +import androidx.media3.datasource.HttpUtil +import androidx.media3.datasource.TransferListener +import com.google.common.base.Predicate +import com.google.common.net.HttpHeaders +import io.ktor.client.HttpClient +import io.ktor.client.request.headers +import io.ktor.client.request.prepareRequest +import io.ktor.client.request.setBody +import io.ktor.client.request.url +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel +import io.ktor.client.statement.request +import io.ktor.http.HttpMethod +import io.ktor.http.contentLength +import io.ktor.http.contentType +import io.ktor.utils.io.jvm.javaio.toInputStream +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.launch +import java.io.IOException +import java.io.InterruptedIOException +import java.util.TreeMap +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicReference +import kotlin.math.min + +/** + * An [HttpDataSource] that delegates to Ktor's [HttpClient]. + * + * Note: HTTP request headers will be set using all parameters passed via (in order of decreasing + * priority) the `dataSpec`, [setRequestProperty] and the default parameters used to construct + * the instance. + */ +class KtorDataSource private constructor( + private val httpClient: HttpClient, + private val coroutineScope: CoroutineScope, + private val userAgent: String?, + private val cacheControl: String?, + private val defaultRequestProperties: HttpDataSource.RequestProperties?, + private val contentTypePredicate: Predicate?, + private val requestProperties: HttpDataSource.RequestProperties +) : BaseDataSource(true), HttpDataSource { + + companion object { + private const val TAG = "KtorDataSource" + + init { + MediaLibraryInfo.registerModule("media3.datasource.ktor") + } + } + + /** + * [androidx.media3.datasource.DataSource.Factory] for [KtorDataSource] instances. + * + * @param httpClient A [HttpClient] for use by the sources created by the factory. + * @param scope A [CoroutineScope] for running suspend functions. If not provided, a default + * scope with [Dispatchers.IO] and a [SupervisorJob] will be created. + */ + class Factory( + private val httpClient: HttpClient, + private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + ) : HttpDataSource.Factory { + + private val defaultRequestProperties = HttpDataSource.RequestProperties() + + private var userAgent: String? = null + + private var cacheControl: String? = null + + private var transferListener: TransferListener? = null + + private var contentTypePredicate: Predicate? = null + + @UnstableApi + override fun setDefaultRequestProperties(defaultRequestProperties: Map): Factory { + this.defaultRequestProperties.clearAndSet(defaultRequestProperties) + return this + } + + /** + * Sets the user agent that will be used. + * + * The default is `null`, which causes the default user agent of the underlying [HttpClient] + * to be used. + * + * @param userAgent The user agent that will be used, or `null` to use the default user + * agent of the underlying [HttpClient]. + * @return This factory. + */ + fun setUserAgent(userAgent: String?): Factory { + this.userAgent = userAgent + return this + } + + /** + * Sets the Cache-Control header that will be used. + * + * The default is `null`. + * + * @param cacheControl The cache control header value that will be used, or `null` to clear + * a previously set value. + * @return This factory. + */ + @UnstableApi + fun setCacheControl(cacheControl: String?): Factory { + this.cacheControl = cacheControl + return this + } + + /** + * Sets a content type [Predicate]. If a content type is rejected by the predicate then a + * [HttpDataSource.InvalidContentTypeException] is thrown from [KtorDataSource.open]. + * + * The default is `null`. + * + * @param contentTypePredicate The content type [Predicate], or `null` to clear a predicate + * that was previously set. + * @return This factory. + */ + @UnstableApi + fun setContentTypePredicate(contentTypePredicate: Predicate?): Factory { + this.contentTypePredicate = contentTypePredicate + return this + } + + /** + * Sets the [TransferListener] that will be used. + * + * The default is `null`. + * + * See [androidx.media3.datasource.DataSource.addTransferListener]. + * + * @param transferListener The listener that will be used. + * @return This factory. + */ + @UnstableApi + fun setTransferListener(transferListener: TransferListener?): Factory { + this.transferListener = transferListener + return this + } + + @UnstableApi + override fun createDataSource(): KtorDataSource { + val client = httpClient + val dataSource = KtorDataSource( + client, + scope, + userAgent, + cacheControl, + defaultRequestProperties, + contentTypePredicate, + HttpDataSource.RequestProperties() + ) + transferListener?.let { dataSource.addTransferListener(it) } + return dataSource + } + } + + private var dataSpec: DataSpec? = null + + private var response: HttpResponse? = null + + private var responseInputStream: java.io.InputStream? = null + + private var currentJob: Job? = null + + private var connectionEstablished = false + private var bytesToRead: Long = 0 + private var bytesRead: Long = 0 + + @UnstableApi + override fun getUri(): Uri? { + return if (response != null) { + Uri.parse(response!!.request.url.toString()) + } else if (dataSpec != null) { + dataSpec!!.uri + } else { + null + } + } + + @UnstableApi + override fun getResponseCode(): Int { + return response?.status?.value ?: -1 + } + + @UnstableApi + override fun getResponseHeaders(): Map> { + val httpResponse = response ?: return emptyMap() + val headers = TreeMap>(String.CASE_INSENSITIVE_ORDER) + httpResponse.headers.names().forEach { name -> + headers[name] = httpResponse.headers.getAll(name) ?: emptyList() + } + return headers + } + + @UnstableApi + override fun setRequestProperty(name: String, value: String) { + requireNotNull(name) { "name cannot be null" } + requireNotNull(value) { "value cannot be null" } + requestProperties.set(name, value) + } + + @UnstableApi + override fun clearRequestProperty(name: String) { + requireNotNull(name) { "name cannot be null" } + requestProperties.remove(name) + } + + @UnstableApi + override fun clearAllRequestProperties() { + requestProperties.clear() + } + + @UnstableApi + @Throws(HttpDataSource.HttpDataSourceException::class) + override fun open(dataSpec: DataSpec): Long { + this.dataSpec = dataSpec + bytesRead = 0 + bytesToRead = 0 + transferInitializing(dataSpec) + + try { + val httpResponse = executeRequest(dataSpec) + this.response = httpResponse + this.responseInputStream = + executeSuspend { httpResponse.bodyAsChannel().toInputStream() } + } catch (e: IOException) { + if (e is HttpDataSource.HttpDataSourceException) throw e + throw HttpDataSource.HttpDataSourceException.createForIOException( + e, dataSpec, HttpDataSource.HttpDataSourceException.TYPE_OPEN + ) + } + + val httpResponse = this.response!! + val responseCode = httpResponse.status.value + + if (responseCode !in 200..299) { + if (responseCode == 416) { + val contentRange = httpResponse.headers[HttpHeaders.CONTENT_RANGE] + val documentSize = HttpUtil.getDocumentSize(contentRange) + if (dataSpec.position == documentSize) { + connectionEstablished = true + transferStarted(dataSpec) + return if (dataSpec.length != C.LENGTH_UNSET.toLong()) dataSpec.length else 0 + } + } + + val errorResponseBody: ByteArray = try { + responseInputStream?.readBytes() ?: Util.EMPTY_BYTE_ARRAY + } catch (e: IOException) { + Util.EMPTY_BYTE_ARRAY + } + + val headers = getResponseHeaders() + closeConnectionQuietly() + + val cause: IOException? = if (responseCode == 416) { + DataSourceException(PlaybackException.ERROR_CODE_IO_READ_POSITION_OUT_OF_RANGE) + } else { + null + } + + throw HttpDataSource.InvalidResponseCodeException( + responseCode, + httpResponse.status.description, + cause, + headers, + dataSpec, + errorResponseBody + ) + } + + val contentType = httpResponse.contentType()?.toString() ?: "" + if (contentTypePredicate != null && !contentTypePredicate.apply(contentType)) { + closeConnectionQuietly() + throw HttpDataSource.InvalidContentTypeException(contentType, dataSpec) + } + + val bytesToSkip = + if (responseCode == 200 && dataSpec.position != 0L) dataSpec.position else 0L + + if (dataSpec.length != C.LENGTH_UNSET.toLong()) { + bytesToRead = dataSpec.length + } else { + val contentLength = httpResponse.contentLength() ?: -1L + bytesToRead = + if (contentLength >= 0) contentLength - bytesToSkip else C.LENGTH_UNSET.toLong() + } + + connectionEstablished = true + transferStarted(dataSpec) + + try { + skipFully(bytesToSkip, dataSpec) + } catch (e: HttpDataSource.HttpDataSourceException) { + closeConnectionQuietly() + throw e + } + + return bytesToRead + } + + @UnstableApi + @Throws(HttpDataSource.HttpDataSourceException::class) + override fun read(buffer: ByteArray, offset: Int, length: Int): Int { + return try { + readInternal(buffer, offset, length) + } catch (e: IOException) { + throw HttpDataSource.HttpDataSourceException.createForIOException( + e, dataSpec!!, HttpDataSource.HttpDataSourceException.TYPE_READ + ) + } + } + + @UnstableApi + override fun close() { + if (connectionEstablished) { + connectionEstablished = false + transferEnded() + closeConnectionQuietly() + } + response = null + dataSpec = null + } + + @Throws(IOException::class) + private fun executeRequest(dataSpec: DataSpec): HttpResponse { + val urlString = dataSpec.uri.toString() + + val uri = Uri.parse(urlString) + val scheme = uri.scheme + if (scheme == null || !scheme.lowercase().startsWith("http")) { + throw HttpDataSource.HttpDataSourceException( + "Malformed URL", + dataSpec, + PlaybackException.ERROR_CODE_FAILED_RUNTIME_CHECK, + HttpDataSource.HttpDataSourceException.TYPE_OPEN + ) + } + + val mergedHeaders = HashMap() + defaultRequestProperties?.snapshot?.forEach { (key, value) -> + mergedHeaders[key] = value + } + requestProperties.snapshot.forEach { (key, value) -> + mergedHeaders[key] = value + } + dataSpec.httpRequestHeaders.forEach { (key, value) -> + mergedHeaders[key] = value + } + + return executeSuspend { + httpClient.prepareRequest { + url(urlString) + + headers { + mergedHeaders.forEach { (key, value) -> + append(key, value) + } + + val rangeHeader = + HttpUtil.buildRangeRequestHeader(dataSpec.position, dataSpec.length) + if (rangeHeader != null) { + append(HttpHeaders.RANGE, rangeHeader) + } + + if (userAgent != null) { + append(HttpHeaders.USER_AGENT, userAgent) + } + + if (cacheControl != null) { + append(HttpHeaders.CACHE_CONTROL, cacheControl) + } + + if (!dataSpec.isFlagSet(DataSpec.FLAG_ALLOW_GZIP)) { + append(HttpHeaders.ACCEPT_ENCODING, "identity") + } + } + + method = when (dataSpec.httpMethod) { + DataSpec.HTTP_METHOD_GET -> HttpMethod.Get + DataSpec.HTTP_METHOD_POST -> HttpMethod.Post + DataSpec.HTTP_METHOD_HEAD -> HttpMethod.Head + else -> HttpMethod.Get + } + + if (dataSpec.httpBody != null) { + setBody(dataSpec.httpBody!!) + } else if (dataSpec.httpMethod == DataSpec.HTTP_METHOD_POST) { + setBody(ByteArray(0)) + } + }.execute() + } + } + + @Throws(IOException::class) + private fun executeSuspend(block: suspend () -> T): T { + val exceptionRef = AtomicReference(null) + val resultRef = AtomicReference(null) + val latch = CountDownLatch(1) + + currentJob = coroutineScope.launch { + try { + resultRef.set(block()) + } catch (e: CancellationException) { + exceptionRef.set(InterruptedIOException()) + } catch (e: Exception) { + exceptionRef.set(e) + } finally { + latch.countDown() + } + } + + try { + latch.await() + } catch (e: InterruptedException) { + currentJob?.cancel() + throw InterruptedIOException() + } + + exceptionRef.get()?.let { throwable -> + when (throwable) { + is IOException -> throw throwable + is InterruptedIOException -> throw throwable + else -> throw IOException(throwable) + } + } + + @Suppress("UNCHECKED_CAST") + return resultRef.get() as T + } + + @Throws(HttpDataSource.HttpDataSourceException::class) + private fun skipFully(bytesToSkip: Long, dataSpec: DataSpec) { + if (bytesToSkip == 0L) return + + val skipBuffer = ByteArray(4096) + var remaining = bytesToSkip + + try { + val inputStream = responseInputStream ?: throw IOException("Stream closed") + while (remaining > 0) { + val readLength = min(remaining.toInt(), skipBuffer.size) + val read = inputStream.read(skipBuffer, 0, readLength) + + if (Thread.currentThread().isInterrupted) { + throw InterruptedIOException() + } + + if (read < 0) { + throw HttpDataSource.HttpDataSourceException( + dataSpec, + PlaybackException.ERROR_CODE_IO_READ_POSITION_OUT_OF_RANGE, + HttpDataSource.HttpDataSourceException.TYPE_OPEN + ) + } + + remaining -= read + bytesTransferred(read) + } + } catch (e: IOException) { + if (e is HttpDataSource.HttpDataSourceException) throw e + throw HttpDataSource.HttpDataSourceException( + dataSpec, + PlaybackException.ERROR_CODE_IO_UNSPECIFIED, + HttpDataSource.HttpDataSourceException.TYPE_OPEN + ) + } + } + + @Throws(IOException::class) + private fun readInternal(buffer: ByteArray, offset: Int, readLength: Int): Int { + if (readLength == 0) return 0 + + if (bytesToRead != C.LENGTH_UNSET.toLong()) { + val bytesRemaining = bytesToRead - bytesRead + if (bytesRemaining == 0L) return C.RESULT_END_OF_INPUT + + val actualReadLength = min(readLength.toLong(), bytesRemaining).toInt() + return readFromStream(buffer, offset, actualReadLength) + } + + return readFromStream(buffer, offset, readLength) + } + + @Throws(IOException::class) + private fun readFromStream(buffer: ByteArray, offset: Int, readLength: Int): Int { + val inputStream = responseInputStream ?: return C.RESULT_END_OF_INPUT + val read = inputStream.read(buffer, offset, readLength) + + if (read < 0) return C.RESULT_END_OF_INPUT + + bytesRead += read + bytesTransferred(read) + return read + } + + private fun closeConnectionQuietly() { + responseInputStream?.close() + responseInputStream = null + currentJob = null + } +} diff --git a/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/package-info.java b/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/package-info.java new file mode 100644 index 00000000000..2c8ed6752e3 --- /dev/null +++ b/libraries/datasource_ktor/src/main/java/androidx/media3/datasource/ktor/package-info.java @@ -0,0 +1,4 @@ +@NonNullApi +package androidx.media3.datasource.ktor; + +import androidx.media3.common.util.NonNullApi; diff --git a/libraries/datasource_ktor/src/test/AndroidManifest.xml b/libraries/datasource_ktor/src/test/AndroidManifest.xml new file mode 100644 index 00000000000..7e4fa50124c --- /dev/null +++ b/libraries/datasource_ktor/src/test/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + diff --git a/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceContractTest.kt b/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceContractTest.kt new file mode 100644 index 00000000000..d79e5cc6616 --- /dev/null +++ b/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceContractTest.kt @@ -0,0 +1,37 @@ +package androidx.media3.datasource.ktor + +import androidx.media3.datasource.DataSource +import androidx.media3.test.utils.DataSourceContractTest +import androidx.media3.test.utils.HttpDataSourceTestEnv +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.google.common.collect.ImmutableList +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import org.junit.Rule +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class KtorDataSourceContractTest : DataSourceContractTest() { + + @JvmField + @Rule + var httpDataSourceTestEnv = HttpDataSourceTestEnv() + val httpClient = HttpClient() + + private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + + override fun createDataSource(): DataSource { + return KtorDataSource.Factory(httpClient, coroutineScope).createDataSource() + } + + override fun getTestResources(): ImmutableList { + return httpDataSourceTestEnv.servedResources + } + + override fun getNotFoundResources(): MutableList { + return httpDataSourceTestEnv.notFoundResources + } +} diff --git a/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceTest.kt b/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceTest.kt new file mode 100644 index 00000000000..043693383a0 --- /dev/null +++ b/libraries/datasource_ktor/src/test/java/androidx/media3/datasource/ktor/KtorDataSourceTest.kt @@ -0,0 +1,153 @@ +package androidx.media3.datasource.ktor + +import androidx.media3.datasource.DataSpec +import androidx.media3.datasource.HttpDataSource +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.google.common.truth.Truth.assertThat +import io.ktor.client.HttpClient +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import java.nio.charset.StandardCharsets +import java.util.HashMap +import java.util.concurrent.TimeUnit +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import org.junit.Assert.assertThrows +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class KtorDataSourceTest { + + private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + val httpClient = HttpClient() + + @Test + @Throws(Exception::class) + fun open_setsCorrectHeaders() { + val mockWebServer = MockWebServer() + mockWebServer.enqueue(MockResponse()) + + val propertyFromFactory = "fromFactory" + val defaultRequestProperties = HashMap() + defaultRequestProperties["0"] = propertyFromFactory + defaultRequestProperties["1"] = propertyFromFactory + defaultRequestProperties["2"] = propertyFromFactory + defaultRequestProperties["4"] = propertyFromFactory + + val dataSource = KtorDataSource.Factory(httpClient, coroutineScope) + .setDefaultRequestProperties(defaultRequestProperties) + .createDataSource() + + val propertyFromSetter = "fromSetter" + dataSource.setRequestProperty("1", propertyFromSetter) + dataSource.setRequestProperty("2", propertyFromSetter) + dataSource.setRequestProperty("3", propertyFromSetter) + dataSource.setRequestProperty("5", propertyFromSetter) + + val propertyFromDataSpec = "fromDataSpec" + val dataSpecRequestProperties = HashMap() + dataSpecRequestProperties["2"] = propertyFromDataSpec + dataSpecRequestProperties["3"] = propertyFromDataSpec + dataSpecRequestProperties["4"] = propertyFromDataSpec + dataSpecRequestProperties["6"] = propertyFromDataSpec + + val dataSpec = DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpRequestHeaders(dataSpecRequestProperties) + .build() + + dataSource.open(dataSpec) + + val request = mockWebServer.takeRequest(10, TimeUnit.SECONDS) + assertThat(request).isNotNull() + val headers = request!!.headers + assertThat(headers["0"]).isEqualTo(propertyFromFactory) + assertThat(headers["1"]).isEqualTo(propertyFromSetter) + assertThat(headers["2"]).isEqualTo(propertyFromDataSpec) + assertThat(headers["3"]).isEqualTo(propertyFromDataSpec) + assertThat(headers["4"]).isEqualTo(propertyFromDataSpec) + assertThat(headers["5"]).isEqualTo(propertyFromSetter) + assertThat(headers["6"]).isEqualTo(propertyFromDataSpec) + } + + @Test + fun open_invalidResponseCode() { + val mockWebServer = MockWebServer() + mockWebServer.enqueue(MockResponse().setResponseCode(404).setBody("failure msg")) + + val dataSource = KtorDataSource.Factory(httpClient, coroutineScope).createDataSource() + + val dataSpec = DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .build() + + val exception = assertThrows( + HttpDataSource.InvalidResponseCodeException::class.java + ) { dataSource.open(dataSpec) } + + assertThat(exception.responseCode).isEqualTo(404) + assertThat(exception.responseBody).isEqualTo("failure msg".toByteArray(StandardCharsets.UTF_8)) + } + + @Test + @Throws(Exception::class) + fun factory_setRequestPropertyAfterCreation_setsCorrectHeaders() { + val mockWebServer = MockWebServer() + mockWebServer.enqueue(MockResponse()) + val dataSpec = DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .build() + + val factory = KtorDataSource.Factory(httpClient, coroutineScope) + val dataSource = factory.createDataSource() + + val defaultRequestProperties = HashMap() + defaultRequestProperties["0"] = "afterCreation" + factory.setDefaultRequestProperties(defaultRequestProperties) + dataSource.open(dataSpec) + + val request = mockWebServer.takeRequest(10, TimeUnit.SECONDS) + assertThat(request).isNotNull() + val headers = request!!.headers + assertThat(headers["0"]).isEqualTo("afterCreation") + } + + @Test + fun open_malformedUrl_throwsException() { + val dataSource = KtorDataSource.Factory(httpClient, coroutineScope).createDataSource() + + val dataSpec = DataSpec.Builder() + .setUri("not-a-valid-url") + .build() + + val exception = assertThrows( + HttpDataSource.HttpDataSourceException::class.java + ) { dataSource.open(dataSpec) } + + assertThat(exception.message).contains("Malformed URL") + } + + @Test + @Throws(Exception::class) + fun open_httpPost_sendsPostRequest() { + val mockWebServer = MockWebServer() + mockWebServer.enqueue(MockResponse()) + + val dataSource = KtorDataSource.Factory(httpClient, coroutineScope).createDataSource() + + val dataSpec = DataSpec.Builder() + .setUri(mockWebServer.url("/test-path").toString()) + .setHttpMethod(DataSpec.HTTP_METHOD_POST) + .setHttpBody("test body".toByteArray(StandardCharsets.UTF_8)) + .build() + + dataSource.open(dataSpec) + + val request = mockWebServer.takeRequest(10, TimeUnit.SECONDS) + assertThat(request).isNotNull() + assertThat(request!!.method).isEqualTo("POST") + assertThat(request.body.readUtf8()).isEqualTo("test body") + } +}